Mirror of https://github.com/ROCm/composable_kernel.git (synced 2026-05-14 02:02:46 +00:00)
Deprecate static kernel (#42)
* deprecate static kernels
[ROCm/composable_kernel commit: 81c942cd7e]
@@ -38,8 +38,6 @@ link_libraries(${OpenMP_pthread_LIBRARY})
#GPU backend
if(DEVICE_BACKEND STREQUAL "AMD")
    find_package(HIP REQUIRED)
elseif(DEVICE_BACKEND STREQUAL "NVIDIA")
    enable_language(CUDA)
endif()

#
@@ -64,13 +62,7 @@ endif()
if(DEVICE_BACKEND STREQUAL "AMD")
    configure_file("${PROJECT_SOURCE_DIR}/composable_kernel/include/utility/config.amd.hpp.in" "${PROJECT_BINARY_DIR}/composable_kernel/include/utility/config.hpp")
    configure_file("${PROJECT_SOURCE_DIR}/composable_kernel/include/utility/float_type.amd.hpp.in" "${PROJECT_BINARY_DIR}/composable_kernel/include/utility/float_type.hpp")
    configure_file("${PROJECT_SOURCE_DIR}/composable_kernel/include/utility/in_memory_operation.amd.hpp.in" "${PROJECT_BINARY_DIR}/composable_kernel/include/utility/in_memory_operation.hpp")
    configure_file("${PROJECT_SOURCE_DIR}/composable_kernel/include/utility/synchronization.amd.hpp.in" "${PROJECT_BINARY_DIR}/composable_kernel/include/utility/synchronization.hpp")
elseif(DEVICE_BACKEND STREQUAL "NVIDIA")
    configure_file("${PROJECT_SOURCE_DIR}/composable_kernel/include/utility/config.nvidia.hpp.in" "${PROJECT_BINARY_DIR}/composable_kernel/include/utility/config.hpp")
    configure_file("${PROJECT_SOURCE_DIR}/composable_kernel/include/utility/float_type.nvidia.hpp.in" "${PROJECT_BINARY_DIR}/composable_kernel/include/utility/float_type.hpp")
    configure_file("${PROJECT_SOURCE_DIR}/composable_kernel/include/utility/in_memory_operation.nvidia.hpp.in" "${PROJECT_BINARY_DIR}/composable_kernel/include/utility/in_memory_operation.hpp")
    configure_file("${PROJECT_SOURCE_DIR}/composable_kernel/include/utility/synchronization.nvidia.hpp.in" "${PROJECT_BINARY_DIR}/composable_kernel/include/utility/synchronization.hpp")
endif()

add_subdirectory(driver)
@@ -80,26 +72,17 @@ set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS}")
message("Compiling options for drivers: ${CMAKE_CXX_FLAGS}")

if(DEVICE_BACKEND STREQUAL "AMD")
    set(CONV_SOURCE driver/conv_driver.cpp)
    set(CONV_BWD_DATA_SOURCE driver/conv_bwd_data_driver.cpp)
    set(CONV_V2_SOURCE driver/conv_driver_v2.cpp)
    set(CONV_BWD_DATA_V2_SOURCE driver/conv_bwd_data_driver_v2.cpp)
    set(CONV_V2_OLC_SOURCE driver/conv_driver_v2_olc.cpp)
elseif(DEVICE_BACKEND STREQUAL "NVIDIA")
    set(CONV_SOURCE driver/conv_driver.cu)
    set(CONV_BWD_DATA_SOURCE driver/conv_bwd_data_driver.cu)
endif()

add_executable(conv_driver ${CONV_SOURCE})
add_executable(conv_bwd_data_driver ${CONV_BWD_DATA_SOURCE})
add_executable(conv_driver_v2 ${CONV_V2_SOURCE})
add_executable(conv_bwd_data_driver_v2 ${CONV_BWD_DATA_V2_SOURCE})
add_executable(conv_driver_v2_olc ${CONV_V2_OLC_SOURCE})

target_include_directories(conv_driver_v2_olc PRIVATE driver/olCompiling/include/)

target_link_libraries(conv_driver PRIVATE modConv)
target_link_libraries(conv_bwd_data_driver PRIVATE modConv)
target_link_libraries(conv_driver_v2 PRIVATE modConv)
target_link_libraries(conv_bwd_data_driver_v2 PRIVATE modConv)
target_link_libraries(conv_driver_v2_olc PRIVATE modConv)
@@ -1,172 +0,0 @@
#ifndef CK_GRIDWISE_CONVOLUTION_BACKWARD_DATA_IMPLICIT_GEMM_V1R1_NCHW_KCYX_NKHW_HPP
#define CK_GRIDWISE_CONVOLUTION_BACKWARD_DATA_IMPLICIT_GEMM_V1R1_NCHW_KCYX_NKHW_HPP

#include "common_header.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
#include "gridwise_gemm.hpp"

namespace ck {

// GemmM = C * Y * X
// GemmN = N * Ho * Wo
// GemmK = K
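// Reader note (illustrative numbers, not from this file): for N = 256, C = 128,
// K = 256, Hi = Wi = 14, Y = X = 3, unit stride/dilation and 1-pixel padding
// (so Ho = Wo = 14), the single implicit GEMM solved here is
//   GemmM = C * Y * X   = 128 * 3 * 3   = 1152
//   GemmN = N * Ho * Wo = 256 * 14 * 14 = 50176
//   GemmK = K           = 256
// i.e. dInput[GemmM, GemmN] (+)= Weight[GemmK, GemmM]^T * dOutput[GemmK, GemmN].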
template <index_t GridSize,
          index_t BlockSize,
          typename Float,
          typename AccFloat,
          typename InGlobalDesc,
          typename WeiGlobalDesc,
          typename OutGlobalDesc,
          typename ConvStrides,
          typename ConvDilations,
          typename InLeftPads,
          typename InRightPads,
          index_t GemmMPerBlock,
          index_t GemmNPerBlock,
          index_t GemmKPerBlock,
          index_t GemmMPerThread,
          index_t GemmNPerThread,
          index_t GemmKPerThread,
          index_t GemmMLevel0Cluster,
          index_t GemmNLevel0Cluster,
          index_t GemmMLevel1Cluster,
          index_t GemmNLevel1Cluster,
          index_t ThreadGemmDataPerRead_GemmM,
          index_t ThreadGemmDataPerRead_GemmN,
          typename GemmABlockCopyThreadSliceLengths_GemmK_GemmM,
          typename GemmABlockCopyThreadClusterLengths_GemmK_GemmM,
          index_t GemmABlockCopySrcDataPerRead_GemmM,
          index_t GemmABlockCopyDstDataPerWrite_GemmM,
          typename GemmBBlockCopyThreadSliceLengths_GemmK_GemmN,
          typename GemmBBlockCopyThreadClusterLengths_GemmK_GemmN,
          index_t GemmBBlockCopySrcDataPerRead_GemmN,
          index_t GemmBBlockCopyDstDataPerWrite_GemmN,
          index_t GemmCThreadCopyDstDataPerWrite_GemmN1>
struct GridwiseConvolutionBackwardDataImplicitGemm_v1r1_nchw_kcyx_nkhw
{
    __device__ void Run(Float* __restrict__ p_in_global,
                        const Float* __restrict__ p_wei_global,
                        const Float* __restrict__ p_out_global) const
    {
        constexpr auto I1 = Number<1>{};
        constexpr auto I2 = Number<2>{};
        constexpr auto I3 = Number<3>{};

        constexpr auto in_n_c_hi_wi_global_desc  = InGlobalDesc{};
        constexpr auto wei_k_c_y_x_global_desc   = WeiGlobalDesc{};
        constexpr auto out_n_k_ho_wo_global_desc = OutGlobalDesc{};

        constexpr index_t N  = in_n_c_hi_wi_global_desc.GetLengths()[0];
        constexpr index_t C  = in_n_c_hi_wi_global_desc.GetLengths()[1];
        constexpr index_t Hi = in_n_c_hi_wi_global_desc.GetLengths()[2];
        constexpr index_t Wi = in_n_c_hi_wi_global_desc.GetLengths()[3];

        constexpr index_t K  = out_n_k_ho_wo_global_desc.GetLengths()[1];
        constexpr index_t Ho = out_n_k_ho_wo_global_desc.GetLengths()[2];
        constexpr index_t Wo = out_n_k_ho_wo_global_desc.GetLengths()[3];

        constexpr index_t Y = wei_k_c_y_x_global_desc.GetLengths()[2];
        constexpr index_t X = wei_k_c_y_x_global_desc.GetLengths()[3];

        constexpr index_t ConvStrideH = ConvStrides{}[0];
        constexpr index_t ConvStrideW = ConvStrides{}[1];

        constexpr index_t ConvDilationH = ConvDilations{}[0];
        constexpr index_t ConvDilationW = ConvDilations{}[1];

        // \todo static_assert for global vector load/store
        // static_assert();

        // weight tensor
        constexpr auto wei_gemmk_gemmm_global_desc =
            unfold_tensor_descriptor(wei_k_c_y_x_global_desc, I1, I3);

        // input tensor
        constexpr auto in_n_c_hip_wip_global_desc = transform_tensor_descriptor(
            in_n_c_hi_wi_global_desc,
            make_tuple(PassThrough<N>{},
                       PassThrough<C>{},
                       Pad<Sequence<Hi, Wi>, InLeftPads, InRightPads>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}));

        constexpr index_t Hip = in_n_c_hip_wip_global_desc.GetLengths()[2];
        constexpr index_t Wip = in_n_c_hip_wip_global_desc.GetLengths()[3];

        constexpr auto in_n_c_y_ho_x_wo_global_desc = transform_tensor_descriptor(
            in_n_c_hip_wip_global_desc,
            make_tuple(PassThrough<N>{},
                       PassThrough<C>{},
                       Embed<Hip, Sequence<Y, Ho>, Sequence<ConvDilationH, ConvStrideH, 0>>{},
                       Embed<Wip, Sequence<X, Wo>, Sequence<ConvDilationW, ConvStrideW, 0>>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}));

        constexpr auto in_gemmm_gemmn_global_desc = transform_tensor_descriptor(
            in_n_c_y_ho_x_wo_global_desc,
            make_tuple(Merge<Sequence<C, Y, X>>{}, Merge<Sequence<N, Ho, Wo>>{}),
            make_tuple(Sequence<1, 2, 4>{}, Sequence<0, 3, 5>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}));

        // output tensor
        constexpr auto out_gemmk_gemmn_global_desc =
            transform_tensor_descriptor(unfold_tensor_descriptor(out_n_k_ho_wo_global_desc, I2, I3),
                                        make_tuple(PassThrough<K>{}, Merge<Sequence<N, Ho * Wo>>{}),
                                        make_tuple(Sequence<1>{}, Sequence<0, 2>{}),
                                        make_tuple(Sequence<0>{}, Sequence<1>{}));

        // GEMM
        // \todo there are more combinations of Y, ConvDilationH and ConvStrideH that don't need
        // atomic, find out all of them
        constexpr bool not_need_atomic = (ConvStrideH >= ConvDilationH * (Y - 1) + 1) and
                                         (ConvStrideW >= ConvDilationW * (X - 1) + 1);

        constexpr auto in_memory_op =
            not_need_atomic ? InMemoryDataOperation::Set : InMemoryDataOperation::AtomicAdd;
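        // Reader note: each output pixel scatters into an input window of extent
        // ConvDilation * (Y - 1) + 1, and the windows of neighboring output
        // pixels start ConvStride apart. When the stride is at least the window
        // extent (in both H and W), no two threads write the same input element
        // and a plain Set suffices; otherwise the windows overlap and the
        // destination write must be an AtomicAdd.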

        constexpr auto gridwise_gemm =
            GridwiseGemmTransposedANormalBNormalC_v1<GridSize,
                                                     BlockSize,
                                                     Float,
                                                     AccFloat,
                                                     decltype(wei_gemmk_gemmm_global_desc),
                                                     decltype(out_gemmk_gemmn_global_desc),
                                                     decltype(in_gemmm_gemmn_global_desc),
                                                     in_memory_op,
                                                     GemmMPerBlock,
                                                     GemmNPerBlock,
                                                     GemmKPerBlock,
                                                     GemmMPerThread,
                                                     GemmNPerThread,
                                                     GemmKPerThread,
                                                     GemmMLevel0Cluster,
                                                     GemmNLevel0Cluster,
                                                     GemmMLevel1Cluster,
                                                     GemmNLevel1Cluster,
                                                     ThreadGemmDataPerRead_GemmM,
                                                     ThreadGemmDataPerRead_GemmN,
                                                     GemmABlockCopyThreadSliceLengths_GemmK_GemmM,
                                                     GemmABlockCopyThreadClusterLengths_GemmK_GemmM,
                                                     Sequence<0, 1>,
                                                     Sequence<0, 1>,
                                                     1,
                                                     GemmABlockCopySrcDataPerRead_GemmM,
                                                     GemmABlockCopyDstDataPerWrite_GemmM,
                                                     GemmBBlockCopyThreadSliceLengths_GemmK_GemmN,
                                                     GemmBBlockCopyThreadClusterLengths_GemmK_GemmN,
                                                     Sequence<0, 1>,
                                                     Sequence<0, 1>,
                                                     1,
                                                     GemmBBlockCopySrcDataPerRead_GemmN,
                                                     GemmBBlockCopyDstDataPerWrite_GemmN,
                                                     Sequence<0, 1, 2, 3>,
                                                     3,
                                                     GemmCThreadCopyDstDataPerWrite_GemmN1>{};

        gridwise_gemm.Run(p_wei_global, p_out_global, p_in_global);
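        // Reader note: argument order is (A, B, C) — A = weight (consumed
        // transposed, per the GridwiseGemmTransposedANormalBNormalC_v1 name),
        // B = output gradient, C = the input gradient being produced.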
    }
};

} // namespace ck
#endif
@@ -1,450 +0,0 @@
#ifndef CK_GRIDWISE_CONVOLUTION_BACKWARD_DATA_IMPLICIT_GEMM_V1R2_NCHW_KCYX_NKHW_LDS_DOUBLE_BUFFER_HPP
#define CK_GRIDWISE_CONVOLUTION_BACKWARD_DATA_IMPLICIT_GEMM_V1R2_NCHW_KCYX_NKHW_LDS_DOUBLE_BUFFER_HPP

#include "common_header.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
#include "ConstantMatrixDescriptor.hpp"
#include "blockwise_generic_tensor_slice_copy.hpp"
#include "threadwise_generic_tensor_slice_copy.hpp"
#include "blockwise_gemm.hpp"

namespace ck {

template <index_t GridSize,
          index_t BlockSize,
          typename Float,
          typename AccFloat,
          typename InGlobalDesc,
          typename WeiGlobalDesc,
          typename OutGlobalDesc,
          typename ConvStrides,
          typename ConvDilations,
          typename LeftPads,
          typename RightPads,
          index_t EPerBlock,
          index_t BPerBlock,
          index_t KPerBlock,
          index_t GemmMPerThread,
          index_t GemmNPerThread,
          index_t GemmKPerThread,
          index_t GemmMLevel0Cluster,
          index_t GemmNLevel0Cluster,
          index_t GemmMLevel1Cluster,
          index_t GemmNLevel1Cluster,
          index_t GemmDataPerReadA,
          index_t GemmDataPerReadB,
          typename OutBlockCopySubLengths_K_B_N0,
          typename OutBlockCopyClusterLengths_K_B_N0,
          index_t OutBlockCopySrcDataPerRead_B,
          index_t OutBlockCopyDstDataPerWrite_N0,
          typename WeiBlockCopySubLengths_K_E_C0,
          typename WeiBlockCopyClusterLengths_K_E_C0,
          index_t WeiBlockCopySrcDataPerRead_E,
          index_t WeiBlockCopyDstDataPerWrite_C0,
          index_t InThreadCopyDstDataPerWrite_B>
struct GridwiseConvolutionBackwardDataImplicitGemm_v1r2_nchw_kcyx_nkhw_lds_double_buffer
{
    __device__ void Run(Float* const __restrict__ p_in_global,
                        const Float* const __restrict__ p_wei_global,
                        const Float* const __restrict__ p_out_global) const
    {
        constexpr auto I0 = Number<0>{};
        constexpr auto I1 = Number<1>{};
        constexpr auto I2 = Number<2>{};
        constexpr auto I3 = Number<3>{};

        constexpr auto True = integral_constant<bool, true>{};

        constexpr auto in_n_c_hi_wi_global_desc  = InGlobalDesc{};
        constexpr auto wei_k_c_y_x_global_desc   = WeiGlobalDesc{};
        constexpr auto out_n_k_ho_wo_global_desc = OutGlobalDesc{};

        constexpr index_t N  = in_n_c_hi_wi_global_desc.GetLengths()[0];
        constexpr index_t C  = in_n_c_hi_wi_global_desc.GetLengths()[1];
        constexpr index_t Hi = in_n_c_hi_wi_global_desc.GetLengths()[2];
        constexpr index_t Wi = in_n_c_hi_wi_global_desc.GetLengths()[3];

        constexpr index_t K  = out_n_k_ho_wo_global_desc.GetLengths()[1];
        constexpr index_t Ho = out_n_k_ho_wo_global_desc.GetLengths()[2];
        constexpr index_t Wo = out_n_k_ho_wo_global_desc.GetLengths()[3];

        constexpr index_t Y = wei_k_c_y_x_global_desc.GetLengths()[2];
        constexpr index_t X = wei_k_c_y_x_global_desc.GetLengths()[3];

        constexpr index_t ConvStrideH = ConvStrides{}[0];
        constexpr index_t ConvStrideW = ConvStrides{}[1];

        constexpr index_t ConvDilationH = ConvDilations{}[0];
        constexpr index_t ConvDilationW = ConvDilations{}[1];

        constexpr index_t C0 = GemmMPerThread;
        constexpr index_t N0 = GemmNPerThread;

        static_assert(C % C0 == 0 && N % N0 == 0, "wrong!");

        constexpr index_t C1 = C / C0;
        constexpr index_t N1 = N / N0;

        constexpr index_t E = C1 * Y * X;
        constexpr index_t B = N1 * Ho * Wo;
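        // Reader note: the GEMM-M and GEMM-N extents are factored so that C0
        // (= GemmMPerThread) and N0 (= GemmNPerThread) become the innermost,
        // per-thread vector dimensions: C = C0 * C1 and N = N0 * N1, giving
        // A[K, E * C0] with E = C1 * Y * X and B[K, B * N0] with B = N1 * Ho * Wo.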

        // sanity-check for vectorized memory load
        static_assert((Wo == 1 || (ConvStrideW == 1 || InThreadCopyDstDataPerWrite_B == 1)) &&
                          (X == 1 || ConvDilationW % InThreadCopyDstDataPerWrite_B == 0),
                      "wrong! alignment requirement for vectorized global load of input tensor will "
                      "be violated");

        // divide block work by [K, B]
        static_assert(E % EPerBlock == 0 && B % BPerBlock == 0 && K % KPerBlock == 0,
                      "wrong! cannot divide work evenly among block");

        constexpr index_t EBlockWork = E / EPerBlock;
        constexpr index_t BBlockWork = B / BPerBlock;

        constexpr auto block_work_desc =
            make_cluster_descriptor(Sequence<EBlockWork, BBlockWork>{});

        const auto block_work_id = block_work_desc.CalculateClusterIndex(get_block_1d_id());

        const index_t e_block_data_on_global = block_work_id[Number<0>{}] * EPerBlock;
        const index_t b_block_data_on_global = block_work_id[Number<1>{}] * BPerBlock;

        // output tensor
        // global tensor in global memory, src of blockwise copy
        constexpr auto out_n_k_howo_global_desc =
            unfold_tensor_descriptor(out_n_k_ho_wo_global_desc, I2, I3);

        constexpr auto out_n0_n1_k_howo_global_desc = transform_tensor_descriptor(
            out_n_k_howo_global_desc,
            make_tuple(UnMerge<Sequence<N0, N1>>{}, PassThrough<K>{}, PassThrough<Ho * Wo>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
            make_tuple(Sequence<0, 1>{}, Sequence<2>{}, Sequence<3>{}));

        constexpr auto out_k_b_n0_global_desc = transform_tensor_descriptor(
            out_n0_n1_k_howo_global_desc,
            make_tuple(PassThrough<K>{}, Merge<Sequence<N1, Ho * Wo>>{}, PassThrough<N0>{}),
            make_tuple(Sequence<2>{}, Sequence<1, 3>{}, Sequence<0>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));

        // block tensor in LDS memory, dst of blockwise copy
        // be careful of LDS alignment
        constexpr auto out_k_b_n0_block_desc = make_native_tensor_descriptor_aligned(
            Sequence<KPerBlock, BPerBlock, N0>{}, Number<OutBlockCopyDstDataPerWrite_N0>{});

        // output tensor blockwise copy
        auto blockwise_out_copy =
            BlockwiseGenericTensorSliceCopy_v4<BlockSize,
                                               decltype(out_k_b_n0_global_desc),
                                               decltype(out_k_b_n0_block_desc),
                                               decltype(out_k_b_n0_block_desc.GetLengths()),
                                               OutBlockCopySubLengths_K_B_N0,
                                               OutBlockCopyClusterLengths_K_B_N0,
                                               Sequence<0, 1, 2>,
                                               Sequence<0, 1, 2>,
                                               Sequence<0, 1, 2>,
                                               1,
                                               2,
                                               OutBlockCopySrcDataPerRead_B,
                                               OutBlockCopyDstDataPerWrite_N0,
                                               AddressSpace::Global,
                                               AddressSpace::Vgpr,
                                               AddressSpace::Lds,
                                               InMemoryDataOperation::Set>(
                make_multi_index(0, b_block_data_on_global, 0), make_multi_index(0, 0, 0));

        // weight tensor
        // global tensor in global memory, src of blockwise copy
        constexpr auto wei_k_cyx_global_desc =
            unfold_tensor_descriptor(wei_k_c_y_x_global_desc, I1, I3);

        constexpr auto wei_k_c0_e_global_desc =
            transform_tensor_descriptor(wei_k_cyx_global_desc,
                                        make_tuple(PassThrough<K>{}, UnMerge<Sequence<C0, E>>{}),
                                        make_tuple(Sequence<0>{}, Sequence<1>{}),
                                        make_tuple(Sequence<0>{}, Sequence<1, 2>{}));

        constexpr auto wei_k_e_c0_global_desc = reorder_tensor_descriptor_given_lower2upper(
            wei_k_c0_e_global_desc, Sequence<0, 2, 1>{});

        // block tensor in LDS memory, dst of blockwise copy
        // be careful of LDS alignment
        constexpr auto wei_k_e_c0_block_desc = make_native_tensor_descriptor_aligned(
            Sequence<KPerBlock, EPerBlock, C0>{}, Number<WeiBlockCopyDstDataPerWrite_C0>{});

        // weight tensor blockwise copy
        auto blockwise_wei_copy =
            BlockwiseGenericTensorSliceCopy_v4<BlockSize,
                                               decltype(wei_k_e_c0_global_desc),
                                               decltype(wei_k_e_c0_block_desc),
                                               decltype(wei_k_e_c0_block_desc.GetLengths()),
                                               WeiBlockCopySubLengths_K_E_C0,
                                               WeiBlockCopyClusterLengths_K_E_C0,
                                               Sequence<0, 1, 2>,
                                               Sequence<0, 1, 2>,
                                               Sequence<0, 1, 2>,
                                               1,
                                               2,
                                               WeiBlockCopySrcDataPerRead_E,
                                               WeiBlockCopyDstDataPerWrite_C0,
                                               AddressSpace::Global,
                                               AddressSpace::Vgpr,
                                               AddressSpace::Lds,
                                               InMemoryDataOperation::Set>(
                make_multi_index(0, e_block_data_on_global, 0), make_multi_index(0, 0, 0));

        // GEMM definition
        // c_mtx += transpose(a_mtx) * b_mtx
        // a_mtx[KPerBlock, EPerBlock*C0] is in LDS
        // b_mtx[KPerBlock, BPerBlock*N0] is in LDS
        // c_mtx[EPerBlock*C0, BPerBlock*N0] is distributed among threads, and saved in
        // register
        constexpr auto a_k_ec0_block_mtx_desc = make_ConstantMatrixDescriptor(
            wei_k_e_c0_block_desc.GetLength(I0),
            wei_k_e_c0_block_desc.GetLength(I1) * wei_k_e_c0_block_desc.GetLength(I2),
            wei_k_e_c0_block_desc.GetStride(I0));
        constexpr auto b_k_bn0_block_mtx_desc = make_ConstantMatrixDescriptor(
            out_k_b_n0_block_desc.GetLength(I0),
            out_k_b_n0_block_desc.GetLength(I1) * out_k_b_n0_block_desc.GetLength(I2),
            out_k_b_n0_block_desc.GetStride(I0));

        // sanity check alignment
        // TODO: this check is ad-hoc, should enforce it by enforcing alignment of
        // wei_k_e_c0_block_desc and out_k_b_n0_block_desc
        static_assert(a_k_ec0_block_mtx_desc.RowStride() % GemmDataPerReadB == 0, "wrong!");
        static_assert(b_k_bn0_block_mtx_desc.RowStride() % GemmDataPerReadA == 0, "wrong!");

        // sanity check
        static_assert(EPerBlock % (GemmMLevel0Cluster * GemmMLevel1Cluster) == 0 &&
                          BPerBlock % (GemmNLevel0Cluster * GemmNLevel1Cluster) == 0,
                      "wrong!");

        constexpr index_t GemmMRepeat = EPerBlock / (GemmMLevel0Cluster * GemmMLevel1Cluster);
        constexpr index_t GemmNRepeat = BPerBlock / (GemmNLevel0Cluster * GemmNLevel1Cluster);

        // c_thread_mtx definition: this is a mess
        // TODO: more elegant way of defining c_thread_mtx
        constexpr auto c_e0e1c0_b0b1n0_thread_mtx_desc = make_ConstantMatrixDescriptor_packed(
            Number<GemmMRepeat * GemmMPerThread>{}, Number<GemmNRepeat * GemmNPerThread>{});

        const auto blockwise_gemm = BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2<
            BlockSize,
            decltype(a_k_ec0_block_mtx_desc),
            decltype(b_k_bn0_block_mtx_desc),
            decltype(c_e0e1c0_b0b1n0_thread_mtx_desc),
            GemmMPerThread,
            GemmNPerThread,
            GemmKPerThread,
            GemmMLevel0Cluster,
            GemmNLevel0Cluster,
            GemmMLevel1Cluster,
            GemmNLevel1Cluster,
            GemmDataPerReadA,
            GemmDataPerReadB>{};

        // LDS allocation for input and weight: be careful of alignment
        constexpr index_t max_lds_align = math::lcm(WeiBlockCopyDstDataPerWrite_C0,
                                                    OutBlockCopyDstDataPerWrite_N0,
                                                    GemmDataPerReadA,
                                                    GemmDataPerReadB);

        constexpr index_t out_block_space =
            math::integer_least_multiple(out_k_b_n0_block_desc.GetElementSpace(), max_lds_align);

        constexpr index_t wei_block_space =
            math::integer_least_multiple(wei_k_e_c0_block_desc.GetElementSpace(), max_lds_align);

        __shared__ Float p_out_block_double[2 * out_block_space];
        __shared__ Float p_wei_block_double[2 * wei_block_space];
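        // Reader note: classic LDS double buffering — two K-slabs of LDS per
        // operand. While the blockwise GEMM consumes the "now" buffer, the next
        // slab is loaded from global memory into registers and only then stored
        // into the "next" buffer, so load latency overlaps the math, and one
        // __syncthreads() per iteration separates the LDS reuse hazards.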

        // register allocation for output
        AccFloat p_in_thread[c_e0e1c0_b0b1n0_thread_mtx_desc.GetElementSpace()];

        // zero out threadwise output
        threadwise_matrix_set_zero(c_e0e1c0_b0b1n0_thread_mtx_desc, p_in_thread);

        // LDS double buffer: preload data into LDS
        {
            blockwise_out_copy.Run(p_out_global, p_out_block_double);
            blockwise_wei_copy.Run(p_wei_global, p_wei_block_double);
        }

        // LDS double buffer: main body
        for(index_t k_block_data_begin = 0; k_block_data_begin + 2 * KPerBlock < K;
            k_block_data_begin += 2 * KPerBlock)
        {
#pragma unroll
            for(index_t iloop = 0; iloop < 2; ++iloop)
            {
                const bool even_loop = (iloop % 2 == 0);

                Float* p_out_block_now =
                    even_loop ? p_out_block_double : p_out_block_double + out_block_space;
                Float* p_wei_block_now =
                    even_loop ? p_wei_block_double : p_wei_block_double + wei_block_space;

                Float* p_out_block_next =
                    even_loop ? p_out_block_double + out_block_space : p_out_block_double;
                Float* p_wei_block_next =
                    even_loop ? p_wei_block_double + wei_block_space : p_wei_block_double;

                Float p_out_thread_buffer[blockwise_out_copy.GetThreadBufferSize()];
                Float p_wei_thread_buffer[blockwise_wei_copy.GetThreadBufferSize()];

                blockwise_out_copy.MoveSrcSliceWindow(Sequence<KPerBlock, 0, 0>{}, True);
                blockwise_wei_copy.MoveSrcSliceWindow(Sequence<KPerBlock, 0, 0>{}, True);

                __syncthreads();

                // LDS double buffer: load next data from device mem
                blockwise_out_copy.RunLoadThreadBuffer(p_out_global, p_out_thread_buffer);
                blockwise_wei_copy.RunLoadThreadBuffer(p_wei_global, p_wei_thread_buffer);

                // LDS double buffer: GEMM on current data
                blockwise_gemm.Run(p_wei_block_now, p_out_block_now, p_in_thread);

                // LDS double buffer: store next data to LDS
                blockwise_out_copy.RunStoreThreadBuffer(p_out_thread_buffer, p_out_block_next);
                blockwise_wei_copy.RunStoreThreadBuffer(p_wei_thread_buffer, p_wei_block_next);
            }
        }

        // LDS double buffer: tail
        {
            constexpr bool has_two_iteration_left = (K % (2 * KPerBlock) == 0);
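            // Reader note: the main loop stops while at least two K-slabs
            // remain (k + 2*KPerBlock < K), so the tail has exactly two slabs
            // left when 2*KPerBlock divides K, and one slab otherwise.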

            if(has_two_iteration_left) // two iterations left
            {
                Float p_out_thread_buffer[blockwise_out_copy.GetThreadBufferSize()];
                Float p_wei_thread_buffer[blockwise_wei_copy.GetThreadBufferSize()];

                blockwise_out_copy.MoveSrcSliceWindow(Sequence<KPerBlock, 0, 0>{}, True);
                blockwise_wei_copy.MoveSrcSliceWindow(Sequence<KPerBlock, 0, 0>{}, True);

                __syncthreads();

                // LDS double buffer: load last data from device mem
                blockwise_out_copy.RunLoadThreadBuffer(p_out_global, p_out_thread_buffer);
                blockwise_wei_copy.RunLoadThreadBuffer(p_wei_global, p_wei_thread_buffer);

                // LDS double buffer: GEMM on 2nd-last data
                blockwise_gemm.Run(p_wei_block_double, p_out_block_double, p_in_thread);

                // LDS double buffer: store last data to LDS
                blockwise_out_copy.RunStoreThreadBuffer(p_out_thread_buffer,
                                                        p_out_block_double + out_block_space);
                blockwise_wei_copy.RunStoreThreadBuffer(p_wei_thread_buffer,
                                                        p_wei_block_double + wei_block_space);

                __syncthreads();

                // LDS double buffer: GEMM on last data
                blockwise_gemm.Run(p_wei_block_double + wei_block_space,
                                   p_out_block_double + out_block_space,
                                   p_in_thread);
            }
            else // one iteration left
            {
                __syncthreads();

                // LDS double buffer: GEMM on last data
                blockwise_gemm.Run(p_wei_block_double, p_out_block_double, p_in_thread);
            }
        }

        {
#if 1 // debug
            // input: register to global memory, atomic add
            constexpr auto in_memory_op = (Y <= ConvStrideH && X <= ConvStrideW)
                                              ? InMemoryDataOperation::Set
                                              : InMemoryDataOperation::AtomicAdd;
#else
            constexpr auto in_memory_op = InMemoryDataOperation::AtomicAdd;
#endif

            constexpr index_t E1 = GemmMLevel0Cluster * GemmMLevel1Cluster;
            constexpr index_t E0 = E / E1;

            constexpr index_t B1 = GemmNLevel0Cluster * GemmNLevel1Cluster;
            constexpr index_t B0 = B / B1;

            // define input tensor descriptor for threadwise copy
            // thread input tensor, src of threadwise copy
            constexpr auto in_e0_e1_c0_b0_b1_n0_thread_desc = make_native_tensor_descriptor_packed(
                Sequence<GemmMRepeat, 1, GemmMPerThread, GemmNRepeat, 1, GemmNPerThread>{});

            // global input tensor, dst of threadwise copy
            constexpr auto in_n_c_hip_wip_global_desc = transform_tensor_descriptor(
                in_n_c_hi_wi_global_desc,
                make_tuple(PassThrough<N>{},
                           PassThrough<C>{},
                           Pad<Sequence<Hi, Wi>, LeftPads, RightPads>{}),
                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}),
                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}));

            constexpr auto in_n0_n1_c0_c1_y_ho_x_wo_global_desc = transform_tensor_descriptor(
                in_n_c_hip_wip_global_desc,
                make_tuple(UnMerge<Sequence<N0, N1>>{},
                           UnMerge<Sequence<C0, C1>>{},
                           Embed<Hi + LeftPads::At(0) + RightPads::At(0),
                                 Sequence<Y, Ho>,
                                 Sequence<ConvDilationH, ConvStrideH, 0>>{},
                           Embed<Wi + LeftPads::At(1) + RightPads::At(1),
                                 Sequence<X, Wo>,
                                 Sequence<ConvDilationW, ConvStrideW, 0>>{}),
                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
                make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}, Sequence<6, 7>{}));

            constexpr auto in_e_c0_b_n0_global_desc = transform_tensor_descriptor(
                in_n0_n1_c0_c1_y_ho_x_wo_global_desc,
                make_tuple(Merge<Sequence<C1, Y, X>>{},
                           PassThrough<C0>{},
                           Merge<Sequence<N1, Ho, Wo>>{},
                           PassThrough<N0>{}),
                make_tuple(Sequence<3, 4, 6>{}, Sequence<2>{}, Sequence<1, 5, 7>{}, Sequence<0>{}),
                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));

            constexpr auto in_e0_e1_c0_b0_b1_n0_global_desc = transform_tensor_descriptor(
                in_e_c0_b_n0_global_desc,
                make_tuple(UnMerge<Sequence<E0, E1>>{},
                           PassThrough<C0>{},
                           UnMerge<Sequence<B0, B1>>{},
                           PassThrough<N0>{}),
                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
                make_tuple(Sequence<0, 1>{}, Sequence<2>{}, Sequence<3, 4>{}, Sequence<5>{}));

            // calculate origin of thread input tensor on global memory
            // blockwise GEMM c matrix starting index
            const auto c_thread_mtx_on_block =
                blockwise_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id());

            const index_t e_thread_data_on_global =
                e_block_data_on_global + c_thread_mtx_on_block.row / GemmMPerThread;

            const index_t b_thread_data_on_global =
                b_block_data_on_global + c_thread_mtx_on_block.col / GemmNPerThread;

            ThreadwiseGenericTensorSliceCopy_v4r2<
                decltype(in_e0_e1_c0_b0_b1_n0_thread_desc),
                decltype(in_e0_e1_c0_b0_b1_n0_global_desc),
                decltype(in_e0_e1_c0_b0_b1_n0_thread_desc.GetLengths()),
                Sequence<0, 1, 2, 3, 4, 5>,
                4,
                1,
                InThreadCopyDstDataPerWrite_B,
                AddressSpace::Vgpr,
                AddressSpace::Global,
                in_memory_op>(make_multi_index(0, 0, 0, 0, 0, 0),
                              make_multi_index(e_thread_data_on_global / E1,
                                               e_thread_data_on_global % E1,
                                               0,
                                               b_thread_data_on_global / B1,
                                               b_thread_data_on_global % B1,
                                               0))
                .Run(p_in_thread, p_in_global);
        }
    }
};

} // namespace ck
#endif
@@ -1,418 +0,0 @@
#ifndef CK_GRIDWISE_CONVOLUTION_BACKWARD_DATA_IMPLICIT_GEMM_V4R1_NCHW_KCYX_NKHW_HPP
#define CK_GRIDWISE_CONVOLUTION_BACKWARD_DATA_IMPLICIT_GEMM_V4R1_NCHW_KCYX_NKHW_HPP

#include "common_header.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
#include "gridwise_gemm.hpp"

namespace ck {

// Number of GEMMs: YTilda * XTilda
// GemmM = C
// GemmN = N * HTildaSlice * WTildaSlice
// GemmK = K * YDotSlice * XDotSlice
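// Reader note (worked example, not from this file): stride 2 with dilation 1
// gives GcdStrideDilation = 1, so YTilda = XTilda = 2 and the backward-data
// problem splits into 2 * 2 = 4 independent GEMMs. For Y = X = 3,
// YDot = XDot = ceil(3 / 2) = 2, and the per-GEMM K-extent varies:
// iYTilda = 0 keeps YDotSlice = ceil((3 - 0) / 2) = 2 filter rows,
// iYTilda = 1 keeps YDotSlice = ceil((3 - 1) / 2) = 1, and likewise for X.
// Splitting along the stride/dilation structure is what lets each sub-GEMM
// write disjoint input pixels with plain stores (InMemoryDataOperation::Set)
// instead of atomics.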
template <index_t GridSize,
          index_t BlockSize,
          typename Float,
          typename AccFloat,
          typename InGlobalDesc,
          typename WeiGlobalDesc,
          typename OutGlobalDesc,
          typename ConvStrides,
          typename ConvDilations,
          typename InLeftPads,
          typename InRightPads,
          index_t GemmMPerBlock,
          index_t GemmNPerBlock,
          index_t GemmKPerBlock,
          index_t GemmMPerThread,
          index_t GemmNPerThread,
          index_t GemmKPerThread,
          index_t GemmMLevel0Cluster,
          index_t GemmNLevel0Cluster,
          index_t GemmMLevel1Cluster,
          index_t GemmNLevel1Cluster,
          index_t ThreadGemmDataPerRead_GemmM,
          index_t ThreadGemmDataPerRead_GemmN,
          typename GemmABlockCopyThreadSliceLengths_GemmK_GemmM,
          typename GemmABlockCopyThreadClusterLengths_GemmK_GemmM,
          index_t GemmABlockCopySrcDataPerRead_GemmM,
          index_t GemmABlockCopyDstDataPerWrite_GemmM,
          typename GemmBBlockCopyThreadSliceLengths_GemmK_GemmN,
          typename GemmBBlockCopyThreadClusterLengths_GemmK_GemmN,
          index_t GemmBBlockCopySrcDataPerRead_GemmN,
          index_t GemmBBlockCopyDstDataPerWrite_GemmN,
          index_t GemmCThreadCopyDstDataPerWrite_GemmN1>
struct GridwiseConvolutionBackwardDataImplicitGemm_v4r1_nchw_kcyx_nkhw
{
    __host__ __device__ static constexpr index_t GetNumberOfGemm()
    {
        constexpr index_t ConvStrideH = ConvStrides{}[0];
        constexpr index_t ConvStrideW = ConvStrides{}[1];

        constexpr index_t ConvDilationH = ConvDilations{}[0];
        constexpr index_t ConvDilationW = ConvDilations{}[1];

        constexpr index_t GcdStrideDilationH = math::gcd(ConvStrideH, ConvDilationH);
        constexpr index_t GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW);

        constexpr index_t YTilda = ConvStrideH / GcdStrideDilationH;
        constexpr index_t XTilda = ConvStrideW / GcdStrideDilationW;

        return YTilda * XTilda;
    }
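    // Hypothetical host-side usage (sketch, not part of this file): since
    // Run() is templated on GemmId, the sub-GEMMs would be dispatched with a
    // compile-time loop, e.g.
    //
    //   static_for<0, Gridwise::GetNumberOfGemm(), 1>{}([&](auto gemm_id) {
    //       // query GetGemmSize(gemm_id), derive the grid size, then launch
    //       // a kernel that calls Gridwise::Run(p_in, p_wei, p_out, gemm_id);
    //   });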

    __host__ __device__ static constexpr auto GetGemmSizeImpl(index_t iYTilda, index_t iXTilda)
    {
        constexpr index_t N  = InGlobalDesc::GetLengths()[0];
        constexpr index_t C  = InGlobalDesc::GetLengths()[1];
        constexpr index_t Hi = InGlobalDesc::GetLengths()[2];
        constexpr index_t Wi = InGlobalDesc::GetLengths()[3];

        constexpr index_t K  = OutGlobalDesc::GetLengths()[1];
        constexpr index_t Ho = OutGlobalDesc::GetLengths()[2];
        constexpr index_t Wo = OutGlobalDesc::GetLengths()[3];

        constexpr index_t Y = WeiGlobalDesc::GetLengths()[2];
        constexpr index_t X = WeiGlobalDesc::GetLengths()[3];

        constexpr index_t ConvStrideH = ConvStrides{}[0];
        constexpr index_t ConvStrideW = ConvStrides{}[1];

        constexpr index_t ConvDilationH = ConvDilations{}[0];
        constexpr index_t ConvDilationW = ConvDilations{}[1];

        constexpr index_t GcdStrideDilationH = math::gcd(ConvStrideH, ConvDilationH);
        constexpr index_t GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW);

        constexpr index_t YTilda = ConvStrideH / GcdStrideDilationH;
        constexpr index_t XTilda = ConvStrideW / GcdStrideDilationW;

        constexpr index_t YDot = math::integer_divide_ceil(Y, YTilda);
        constexpr index_t XDot = math::integer_divide_ceil(X, XTilda);

        constexpr index_t HTilda =
            Ho + math::integer_divide_ceil(ConvDilationH * (Y - 1), ConvStrideH);
        constexpr index_t WTilda =
            Wo + math::integer_divide_ceil(ConvDilationW * (X - 1), ConvStrideW);
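        // Reader note: HTilda x WTilda is the output grid enlarged by the
        // filter footprint (ceil(Dilation * (Y - 1) / Stride) extra rows/cols),
        // i.e. every position whose receptive field can touch some input pixel.
        // The iHTildaLeft/iHTildaRight window computed below then trims the
        // positions that only ever touch padding.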

        // only work on HTilda and WTilda that contribute to non-padding area of input tensor
        constexpr index_t iHTildaLeft = math::integer_divide_floor(
            math::max(0, InLeftPads{}[0] - ConvDilationH * (YTilda - 1)), ConvStrides{}[0]);
        constexpr index_t iWTildaLeft = math::integer_divide_floor(
            math::max(0, InLeftPads{}[1] - ConvDilationW * (XTilda - 1)), ConvStrides{}[1]);

        constexpr index_t iHTildaRight = math::min(
            HTilda, math::integer_divide_ceil(InLeftPads{}[0] + Hi - 1, ConvStrides{}[0]) + 1);
        constexpr index_t iWTildaRight = math::min(
            WTilda, math::integer_divide_ceil(InLeftPads{}[1] + Wi - 1, ConvStrides{}[1]) + 1);

        constexpr index_t HTildaSlice = iHTildaRight - iHTildaLeft;
        constexpr index_t WTildaSlice = iWTildaRight - iWTildaLeft;

        // GemmM and GemmN
        constexpr index_t GemmM = C;
        constexpr index_t GemmN = N * HTildaSlice * WTildaSlice;

        // GemmK is different for each GEMM
        index_t YDotSlice = math::integer_divide_ceil(Y - iYTilda, YTilda);
        index_t XDotSlice = math::integer_divide_ceil(X - iXTilda, XTilda);

        index_t GemmK = K * YDotSlice * XDotSlice;

        return Array<index_t, 3>{GemmM, GemmN, GemmK};
    }

    __host__ __device__ static constexpr auto GetGemmSize(index_t gemm_id)
    {
        constexpr index_t ConvStrideW = ConvStrides{}[1];

        constexpr index_t ConvDilationW = ConvDilations{}[1];

        constexpr index_t GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW);

        constexpr index_t XTilda = ConvStrideW / GcdStrideDilationW;

        index_t iYTilda = gemm_id / XTilda;
        index_t iXTilda = gemm_id % XTilda;

        return GetGemmSizeImpl(iYTilda, iXTilda);
    }

    template <index_t iYTilda, index_t iXTilda>
    __device__ static void RunImpl(Float* __restrict__ p_in_global,
                                   const Float* __restrict__ p_wei_global,
                                   const Float* __restrict__ p_out_global)
    {
        constexpr auto in_n_c_hi_wi_global_desc  = InGlobalDesc{};
        constexpr auto wei_k_c_y_x_global_desc   = WeiGlobalDesc{};
        constexpr auto out_n_k_ho_wo_global_desc = OutGlobalDesc{};

        constexpr index_t N  = in_n_c_hi_wi_global_desc.GetLengths()[0];
        constexpr index_t C  = in_n_c_hi_wi_global_desc.GetLengths()[1];
        constexpr index_t Hi = in_n_c_hi_wi_global_desc.GetLengths()[2];
        constexpr index_t Wi = in_n_c_hi_wi_global_desc.GetLengths()[3];

        constexpr index_t K  = out_n_k_ho_wo_global_desc.GetLengths()[1];
        constexpr index_t Ho = out_n_k_ho_wo_global_desc.GetLengths()[2];
        constexpr index_t Wo = out_n_k_ho_wo_global_desc.GetLengths()[3];

        constexpr index_t Y = wei_k_c_y_x_global_desc.GetLengths()[2];
        constexpr index_t X = wei_k_c_y_x_global_desc.GetLengths()[3];

        constexpr index_t ConvStrideH = ConvStrides{}[0];
        constexpr index_t ConvStrideW = ConvStrides{}[1];

        constexpr index_t ConvDilationH = ConvDilations{}[0];
        constexpr index_t ConvDilationW = ConvDilations{}[1];

        constexpr index_t GcdStrideDilationH = math::gcd(ConvStrideH, ConvDilationH);
        constexpr index_t GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW);

        constexpr index_t YTilda = ConvStrideH / GcdStrideDilationH;
        constexpr index_t XTilda = ConvStrideW / GcdStrideDilationW;

        constexpr index_t YDot = math::integer_divide_ceil(Y, YTilda);
        constexpr index_t XDot = math::integer_divide_ceil(X, XTilda);

        constexpr index_t YDotSlice = math::integer_divide_ceil(Y - iYTilda, YTilda);
        constexpr index_t XDotSlice = math::integer_divide_ceil(X - iXTilda, XTilda);

        constexpr index_t HTilda =
            Ho + math::integer_divide_ceil(ConvDilationH * (Y - 1), ConvStrideH);
        constexpr index_t WTilda =
            Wo + math::integer_divide_ceil(ConvDilationW * (X - 1), ConvStrideW);

        // only work on HTilda and WTilda that contribute to non-padding area of input tensor
        constexpr index_t iHTildaLeft = math::integer_divide_floor(
            math::max(0, InLeftPads{}[0] - ConvDilationH * (YTilda - 1)), ConvStrides{}[0]);
        constexpr index_t iWTildaLeft = math::integer_divide_floor(
            math::max(0, InLeftPads{}[1] - ConvDilationW * (XTilda - 1)), ConvStrides{}[1]);

        constexpr index_t iHTildaRight = math::min(
            HTilda, math::integer_divide_ceil(InLeftPads{}[0] + Hi - 1, ConvStrides{}[0]) + 1);
        constexpr index_t iWTildaRight = math::min(
            WTilda, math::integer_divide_ceil(InLeftPads{}[1] + Wi - 1, ConvStrides{}[1]) + 1);

        constexpr index_t HTildaSlice = iHTildaRight - iHTildaLeft;
        constexpr index_t WTildaSlice = iWTildaRight - iWTildaLeft;

        // A matrix: weight
        // weight out-of-bound check can be skipped
        constexpr bool wei_skip_out_of_bound_check = true;

        constexpr auto wei_k_c_ydot_ytilda_xdot_xtilda_global_desc = transform_tensor_descriptor(
            wei_k_c_y_x_global_desc,
            make_tuple(PassThrough<K>{},
                       PassThrough<C>{},
                       Embed<Y,
                             Sequence<YDot, YTilda>,
                             Sequence<ConvStrideH / GcdStrideDilationH, 1, 0>,
                             wei_skip_out_of_bound_check>{},
                       Embed<X,
                             Sequence<XDot, XTilda>,
                             Sequence<ConvStrideW / GcdStrideDilationW, 1, 0>,
                             wei_skip_out_of_bound_check>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}));

        constexpr auto wei_k_c_ydotslice_xdotslice_global_desc = transform_tensor_descriptor(
            wei_k_c_ydot_ytilda_xdot_xtilda_global_desc,
            make_tuple(
                PassThrough<K>{},
                PassThrough<C>{},
                Slice<Sequence<YDot, XDot>, Sequence<0, 0>, Sequence<YDotSlice, XDotSlice>>{},
                Freeze<Sequence<YTilda, XTilda>, Sequence<iYTilda, iXTilda>>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 4>{}, Sequence<3, 5>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<>{}));
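        // Reader note: Slice keeps only the first YDotSlice x XDotSlice taps of
        // the (YDot, XDot) axes, while Freeze pins the (YTilda, XTilda) axes to
        // this GEMM's compile-time (iYTilda, iXTilda) and drops them from the
        // descriptor (hence the empty Sequence<> on the upper side).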

        constexpr auto wei_gemmk_gemmm_global_desc = transform_tensor_descriptor(
            wei_k_c_ydotslice_xdotslice_global_desc,
            make_tuple(Merge<Sequence<K, YDotSlice, XDotSlice>>{}, PassThrough<C>{}),
            make_tuple(Sequence<0, 2, 3>{}, Sequence<1>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}));

        // B matrix: output tensor
        // TODO sometimes output tensor out-of-bound check can be skipped, find out all such
        // situations
#if !CK_EXPERIMENTAL_IMPLICIT_GEMM_BACKWARD_DATA_V4R1_OUTPUT_SKIP_OUT_OF_BOUND_CHECK
        constexpr bool out_skip_out_of_bound_check = false;
#else
        constexpr bool out_skip_out_of_bound_check = true;
#endif

        constexpr auto out_n_k_ydot_htilda_xdot_wtilda_global_desc = transform_tensor_descriptor(
            out_n_k_ho_wo_global_desc,
            make_tuple(PassThrough<N>{},
                       PassThrough<K>{},
                       Embed<Ho,
                             Sequence<YDot, HTilda>,
                             Sequence<-ConvDilationH / GcdStrideDilationH, 1, 0>,
                             out_skip_out_of_bound_check>{},
                       Embed<Wo,
                             Sequence<XDot, WTilda>,
                             Sequence<-ConvDilationW / GcdStrideDilationW, 1, 0>,
                             out_skip_out_of_bound_check>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}));

        constexpr auto out_n_k_ydot_htildaslice_xdot_wtildaslice_global_desc =
            transform_tensor_descriptor(
                out_n_k_ydot_htilda_xdot_wtilda_global_desc,
                make_tuple(PassThrough<N>{},
                           PassThrough<K>{},
                           PassThrough<YDot>{},
                           PassThrough<XDot>{},
                           Slice<Sequence<HTilda, WTilda>,
                                 Sequence<iHTildaLeft, iWTildaLeft>,
                                 Sequence<iHTildaRight, iWTildaRight>>{}),
                make_tuple(
                    Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<4>{}, Sequence<3, 5>{}),
                make_tuple(
                    Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<4>{}, Sequence<3, 5>{}));

        constexpr auto out_n_k_ydotslice_htildaslice_xdotslice_wtildaslice_global_desc =
            transform_tensor_descriptor(
                out_n_k_ydot_htildaslice_xdot_wtildaslice_global_desc,
                make_tuple(
                    PassThrough<N>{},
                    PassThrough<K>{},
                    PassThrough<HTildaSlice>{},
                    PassThrough<WTildaSlice>{},
                    Slice<Sequence<YDot, XDot>, Sequence<0, 0>, Sequence<YDotSlice, XDotSlice>>{}),
                make_tuple(
                    Sequence<0>{}, Sequence<1>{}, Sequence<3>{}, Sequence<5>{}, Sequence<2, 4>{}),
                make_tuple(
                    Sequence<0>{}, Sequence<1>{}, Sequence<3>{}, Sequence<5>{}, Sequence<2, 4>{}));

        constexpr auto out_gemmk_gemmn_global_desc = transform_tensor_descriptor(
            out_n_k_ydotslice_htildaslice_xdotslice_wtildaslice_global_desc,
            make_tuple(Merge<Sequence<K, YDotSlice, XDotSlice>>{},
                       Merge<Sequence<N, HTildaSlice, WTildaSlice>>{}),
            make_tuple(Sequence<1, 2, 4>{}, Sequence<0, 3, 5>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}));

        // C matrix: input tensor
        // TODO sometimes input out-of-bound check can be skipped, find out all such situations
#if !CK_EXPERIMENTAL_IMPLICIT_GEMM_BACKWARD_DATA_V4R1_INPUT_SKIP_OUT_OF_BOUND_CHECK
        constexpr bool in_skip_out_of_bound_check = false;
#else
        constexpr bool in_skip_out_of_bound_check = true;
#endif

        constexpr auto in_n_c_hip_wip_global_desc = transform_tensor_descriptor(
            in_n_c_hi_wi_global_desc,
            make_tuple(
                PassThrough<N>{},
                PassThrough<C>{},
                Pad<Sequence<Hi, Wi>, InLeftPads, InRightPads, in_skip_out_of_bound_check>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}));

        constexpr index_t Hip = in_n_c_hip_wip_global_desc.GetLengths()[2];
        constexpr index_t Wip = in_n_c_hip_wip_global_desc.GetLengths()[3];

        constexpr auto in_n_c_ytilda_htilda_xtilda_wtilda_global_desc = transform_tensor_descriptor(
            in_n_c_hip_wip_global_desc,
            make_tuple(PassThrough<N>{},
                       PassThrough<C>{},
                       Embed<Hip,
                             Sequence<YTilda, HTilda>,
                             Sequence<ConvDilationH, ConvStrideH, 0>,
                             in_skip_out_of_bound_check>{},
                       Embed<Wip,
                             Sequence<XTilda, WTilda>,
                             Sequence<ConvDilationW, ConvStrideW, 0>,
                             in_skip_out_of_bound_check>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}));

        constexpr auto in_n_c_htildaslice_wtildaslice_global_desc = transform_tensor_descriptor(
            in_n_c_ytilda_htilda_xtilda_wtilda_global_desc,
            make_tuple(PassThrough<N>{},
                       PassThrough<C>{},
                       Freeze<Sequence<YTilda, XTilda>, Sequence<iYTilda, iXTilda>>{},
                       Slice<Sequence<HTilda, WTilda>,
                             Sequence<iHTildaLeft, iWTildaLeft>,
                             Sequence<iHTildaRight, iWTildaRight>>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 4>{}, Sequence<3, 5>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<>{}, Sequence<2, 3>{}));

        constexpr auto in_gemmm_gemmn_global_desc = transform_tensor_descriptor(
            in_n_c_htildaslice_wtildaslice_global_desc,
            make_tuple(PassThrough<C>{}, Merge<Sequence<N, HTildaSlice, WTildaSlice>>{}),
            make_tuple(Sequence<1>{}, Sequence<0, 2, 3>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}));

        constexpr auto gridwise_gemm =
            GridwiseGemmTransposedANormalBNormalC_v1<GridSize,
                                                     BlockSize,
                                                     Float,
                                                     AccFloat,
                                                     decltype(wei_gemmk_gemmm_global_desc),
                                                     decltype(out_gemmk_gemmn_global_desc),
                                                     decltype(in_gemmm_gemmn_global_desc),
                                                     InMemoryDataOperation::Set,
                                                     GemmMPerBlock,
                                                     GemmNPerBlock,
                                                     GemmKPerBlock,
                                                     GemmMPerThread,
                                                     GemmNPerThread,
                                                     GemmKPerThread,
                                                     GemmMLevel0Cluster,
                                                     GemmNLevel0Cluster,
                                                     GemmMLevel1Cluster,
                                                     GemmNLevel1Cluster,
                                                     ThreadGemmDataPerRead_GemmM,
                                                     ThreadGemmDataPerRead_GemmN,
                                                     GemmABlockCopyThreadSliceLengths_GemmK_GemmM,
                                                     GemmABlockCopyThreadClusterLengths_GemmK_GemmM,
                                                     Sequence<0, 1>,
                                                     Sequence<0, 1>,
                                                     1,
                                                     GemmABlockCopySrcDataPerRead_GemmM,
                                                     GemmABlockCopyDstDataPerWrite_GemmM,
                                                     GemmBBlockCopyThreadSliceLengths_GemmK_GemmN,
                                                     GemmBBlockCopyThreadClusterLengths_GemmK_GemmN,
                                                     Sequence<0, 1>,
                                                     Sequence<0, 1>,
                                                     1,
                                                     GemmBBlockCopySrcDataPerRead_GemmN,
                                                     GemmBBlockCopyDstDataPerWrite_GemmN,
                                                     Sequence<0, 1, 2, 3>,
                                                     3,
                                                     GemmCThreadCopyDstDataPerWrite_GemmN1>{};

        gridwise_gemm.Run(p_wei_global, p_out_global, p_in_global);
    }

    template <index_t GemmId>
    __device__ static void Run(Float* __restrict__ p_in_global,
                               const Float* __restrict__ p_wei_global,
                               const Float* __restrict__ p_out_global,
                               Number<GemmId>)
    {
        constexpr index_t ConvStrideH = ConvStrides{}[0];
        constexpr index_t ConvStrideW = ConvStrides{}[1];

        constexpr index_t ConvDilationH = ConvDilations{}[0];
        constexpr index_t ConvDilationW = ConvDilations{}[1];

        constexpr index_t GcdStrideDilationH = math::gcd(ConvStrideH, ConvDilationH);
        constexpr index_t GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW);

        constexpr index_t YTilda = ConvStrideH / GcdStrideDilationH;
        constexpr index_t XTilda = ConvStrideW / GcdStrideDilationW;

        constexpr index_t iYTilda = GemmId / XTilda;
        constexpr index_t iXTilda = GemmId % XTilda;

        static_assert(iYTilda < YTilda && iXTilda < XTilda, "wrong! iYTilda, iXTilda");

        RunImpl<iYTilda, iXTilda>(p_in_global, p_wei_global, p_out_global);
    }
};

} // namespace ck
#endif
@@ -1,406 +0,0 @@
#ifndef CK_GRIDWISE_CONVOLUTION_BACKWARD_DATA_IMPLICIT_GEMM_V5R1_NHWC_KYXC_NHWK_HPP
#define CK_GRIDWISE_CONVOLUTION_BACKWARD_DATA_IMPLICIT_GEMM_V5R1_NHWC_KYXC_NHWK_HPP

#include "common_header.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
#include "gridwise_gemm.hpp"

namespace ck {

// Number of GEMMs = YTilda * XTilda
// GemmM = C
// GemmN = N * HTildaSlice * WTildaSlice
// GemmK0 = YDotSlice
// GemmK1 = XDotSlice
// GemmK2 = K
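// Reader note: unlike v4r1, which merges K * YDotSlice * XDotSlice into a
// single GemmK, this NHWC variant keeps the reduction as three separate
// dimensions (GemmK0, GemmK1, GemmK2) with K innermost. In the NHWC/NHWK
// layouts K is the fastest-varying axis, so keeping GemmK2 = K intact lets the
// B-matrix block copy read it as contiguous vectors (see
// GemmBBlockCopySrcDataPerRead_GemmK2 below).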
|
||||
template <index_t GridSize,
|
||||
index_t BlockSize,
|
||||
typename Float,
|
||||
typename AccFloat,
|
||||
typename InGlobalDesc,
|
||||
typename WeiGlobalDesc,
|
||||
typename OutGlobalDesc,
|
||||
typename ConvStrides,
|
||||
typename ConvDilations,
|
||||
typename InLeftPads,
|
||||
typename InRightPads,
|
||||
index_t GemmMPerBlock,
|
||||
index_t GemmNPerBlock,
|
||||
index_t GemmKPerBlock,
|
||||
index_t GemmMPerThread,
|
||||
index_t GemmNPerThread,
|
||||
index_t GemmKPerThread,
|
||||
index_t GemmMLevel0Cluster,
|
||||
index_t GemmNLevel0Cluster,
|
||||
index_t GemmMLevel1Cluster,
|
||||
index_t GemmNLevel1Cluster,
|
||||
index_t ThreadGemmDataPerRead_GemmM,
|
||||
index_t ThreadGemmDataPerRead_GemmN,
|
||||
typename GemmABlockCopyThreadSliceLengths_GemmK0_GemmK1_GemmK2_GemmM,
|
||||
typename GemmABlockCopyThreadClusterLengths_GemmK0_GemmK1_GemmK2_GemmM,
|
||||
index_t GemmABlockCopySrcDataPerRead_GemmM,
|
||||
index_t GemmABlockCopyDstDataPerWrite_GemmM,
|
||||
typename GemmBBlockCopyThreadSliceLengths_GemmK0_GemmK1_GemmK2_GemmN,
|
||||
typename GemmBBlockCopyThreadClusterLengths_GemmK0_GemmK1_GemmK2_GemmN,
|
||||
index_t GemmBBlockCopySrcDataPerRead_GemmK2,
|
||||
index_t GemmBBlockCopyDstDataPerWrite_GemmN,
|
||||
index_t GemmCThreadCopyDstDataPerWrite_GemmN1>
|
||||
struct GridwiseConvolutionBackwardDataImplicitGemm_v5r1_nhwc_kyxc_nhwk
|
||||
{
|
||||
__host__ __device__ static constexpr index_t GetNumberOfGemm()
|
||||
{
|
||||
constexpr index_t ConvStrideH = ConvStrides{}[0];
|
||||
constexpr index_t ConvStrideW = ConvStrides{}[1];
|
||||
|
||||
constexpr index_t ConvDilationH = ConvDilations{}[0];
|
||||
constexpr index_t ConvDilationW = ConvDilations{}[1];
|
||||
|
||||
constexpr index_t GcdStrideDilationH = math::gcd(ConvStrideH, ConvDilationH);
|
||||
constexpr index_t GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW);
|
||||
|
||||
constexpr index_t YTilda = ConvStrideH / GcdStrideDilationH;
|
||||
constexpr index_t XTilda = ConvStrideW / GcdStrideDilationW;
|
||||
|
||||
return YTilda * XTilda;
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr auto GetGemmSizeImpl(index_t iYTilda, index_t iXTilda)
|
||||
{
|
||||
constexpr index_t N = InGlobalDesc::GetLengths()[0];
|
||||
constexpr index_t Hi = InGlobalDesc::GetLengths()[1];
|
||||
constexpr index_t Wi = InGlobalDesc::GetLengths()[2];
|
||||
constexpr index_t C = InGlobalDesc::GetLengths()[3];
|
||||
|
||||
constexpr index_t Ho = OutGlobalDesc::GetLengths()[1];
|
||||
constexpr index_t Wo = OutGlobalDesc::GetLengths()[2];
|
||||
constexpr index_t K = OutGlobalDesc::GetLengths()[3];
|
||||
|
||||
constexpr index_t Y = WeiGlobalDesc::GetLengths()[1];
|
||||
constexpr index_t X = WeiGlobalDesc::GetLengths()[2];
|
||||
|
||||
constexpr index_t ConvStrideH = ConvStrides{}[0];
|
||||
constexpr index_t ConvStrideW = ConvStrides{}[1];
|
||||
|
||||
constexpr index_t ConvDilationH = ConvDilations{}[0];
|
||||
constexpr index_t ConvDilationW = ConvDilations{}[1];
|
||||
|
||||
constexpr index_t GcdStrideDilationH = math::gcd(ConvStrideH, ConvDilationH);
|
||||
constexpr index_t GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW);
|
||||
|
||||
constexpr index_t YTilda = ConvStrideH / GcdStrideDilationH;
|
||||
constexpr index_t XTilda = ConvStrideW / GcdStrideDilationW;
|
||||
|
||||
constexpr index_t YDot = math::integer_divide_ceil(Y, YTilda);
|
||||
constexpr index_t XDot = math::integer_divide_ceil(X, XTilda);
|
||||
|
||||
constexpr index_t HTilda =
|
||||
Ho + math::integer_divide_ceil(ConvDilationH * (Y - 1), ConvStrideH);
|
||||
constexpr index_t WTilda =
|
||||
Wo + math::integer_divide_ceil(ConvDilationW * (X - 1), ConvStrideW);
|
||||
|
||||
// only work on HTilda and WTilda that contribute to non-padding area of input tensor
|
||||
constexpr index_t iHTildaLeft = math::integer_divide_floor(
|
||||
math::max(0, InLeftPads{}[0] - ConvDilationH * (YTilda - 1)), ConvStrides{}[0]);
|
||||
constexpr index_t iWTildaLeft = math::integer_divide_floor(
|
||||
math::max(0, InLeftPads{}[1] - ConvDilationW * (XTilda - 1)), ConvStrides{}[1]);
|
||||
|
||||
constexpr index_t iHTildaRight = math::min(
|
||||
HTilda, math::integer_divide_ceil(InLeftPads{}[0] + Hi - 1, ConvStrides{}[0]) + 1);
|
||||
constexpr index_t iWTildaRight = math::min(
|
||||
WTilda, math::integer_divide_ceil(InLeftPads{}[1] + Wi - 1, ConvStrides{}[1]) + 1);
|
||||
|
||||
constexpr index_t HTildaSlice = iHTildaRight - iHTildaLeft;
|
||||
constexpr index_t WTildaSlice = iWTildaRight - iWTildaLeft;
|
||||
|
||||
// GemmM and GemmN
|
||||
constexpr index_t GemmM = C;
|
||||
constexpr index_t GemmN = N * HTildaSlice * WTildaSlice;
|
||||
|
||||
// GemmK is different for each GEMM
|
||||
index_t YDotSlice = math::integer_divide_ceil(Y - iYTilda, YTilda);
|
||||
index_t XDotSlice = math::integer_divide_ceil(X - iXTilda, XTilda);
|
||||
|
||||
index_t GemmK0 = YDotSlice;
|
||||
index_t GemmK1 = XDotSlice;
|
||||
index_t GemmK2 = K;
|
||||
|
||||
return make_multi_index(GemmM, GemmN, GemmK0, GemmK1, GemmK2);
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr auto GetGemmSize(index_t gemm_id)
|
||||
{
|
||||
constexpr index_t ConvStrideW = ConvStrides{}[1];
|
||||
|
||||
constexpr index_t ConvDilationW = ConvDilations{}[1];
|
||||
|
||||
constexpr index_t GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW);
|
||||
|
||||
constexpr index_t XTilda = ConvStrideW / GcdStrideDilationW;
|
||||
|
||||
index_t iYTilda = gemm_id / XTilda;
|
||||
index_t iXTilda = gemm_id % XTilda;
|
||||
|
||||
return GetGemmSizeImpl(iYTilda, iXTilda);
|
||||
}

    template <index_t iYTilda, index_t iXTilda>
    __device__ static void RunImpl(Float* __restrict__ p_in_global,
                                   const Float* __restrict__ p_wei_global,
                                   const Float* __restrict__ p_out_global)
    {
        constexpr auto in_n_hi_wi_c_global_desc  = InGlobalDesc{};
        constexpr auto wei_k_y_x_c_global_desc   = WeiGlobalDesc{};
        constexpr auto out_n_ho_wo_k_global_desc = OutGlobalDesc{};

        constexpr index_t N  = in_n_hi_wi_c_global_desc.GetLengths()[0];
        constexpr index_t Hi = in_n_hi_wi_c_global_desc.GetLengths()[1];
        constexpr index_t Wi = in_n_hi_wi_c_global_desc.GetLengths()[2];
        constexpr index_t C  = in_n_hi_wi_c_global_desc.GetLengths()[3];

        constexpr index_t Ho = out_n_ho_wo_k_global_desc.GetLengths()[1];
        constexpr index_t Wo = out_n_ho_wo_k_global_desc.GetLengths()[2];
        constexpr index_t K  = out_n_ho_wo_k_global_desc.GetLengths()[3];

        constexpr index_t Y = wei_k_y_x_c_global_desc.GetLengths()[1];
        constexpr index_t X = wei_k_y_x_c_global_desc.GetLengths()[2];

        constexpr index_t ConvStrideH = ConvStrides{}[0];
        constexpr index_t ConvStrideW = ConvStrides{}[1];

        constexpr index_t ConvDilationH = ConvDilations{}[0];
        constexpr index_t ConvDilationW = ConvDilations{}[1];

        constexpr index_t GcdStrideDilationH = math::gcd(ConvStrideH, ConvDilationH);
        constexpr index_t GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW);

        constexpr index_t YTilda = ConvStrideH / GcdStrideDilationH;
        constexpr index_t XTilda = ConvStrideW / GcdStrideDilationW;

        constexpr index_t YDot = math::integer_divide_ceil(Y, YTilda);
        constexpr index_t XDot = math::integer_divide_ceil(X, XTilda);

        constexpr index_t YDotSlice = math::integer_divide_ceil(Y - iYTilda, YTilda);
        constexpr index_t XDotSlice = math::integer_divide_ceil(X - iXTilda, XTilda);

        constexpr index_t HTilda =
            Ho + math::integer_divide_ceil(ConvDilationH * (Y - 1), ConvStrideH);
        constexpr index_t WTilda =
            Wo + math::integer_divide_ceil(ConvDilationW * (X - 1), ConvStrideW);

        // only work on the parts of HTilda and WTilda that contribute to the non-padding
        // area of the input tensor
        constexpr index_t iHTildaLeft = math::integer_divide_floor(
            math::max(0, InLeftPads{}[0] - ConvDilationH * (YTilda - 1)), ConvStrides{}[0]);
        constexpr index_t iWTildaLeft = math::integer_divide_floor(
            math::max(0, InLeftPads{}[1] - ConvDilationW * (XTilda - 1)), ConvStrides{}[1]);

        constexpr index_t iHTildaRight = math::min(
            HTilda, math::integer_divide_ceil(InLeftPads{}[0] + Hi - 1, ConvStrides{}[0]) + 1);
        constexpr index_t iWTildaRight = math::min(
            WTilda, math::integer_divide_ceil(InLeftPads{}[1] + Wi - 1, ConvStrides{}[1]) + 1);

        constexpr index_t HTildaSlice = iHTildaRight - iHTildaLeft;
        constexpr index_t WTildaSlice = iWTildaRight - iWTildaLeft;

        // A matrix: weight
        // weight out-of-bound check can be skipped
        constexpr bool wei_skip_out_of_bound_check = true;

        constexpr auto wei_k_ydot_ytilda_xdot_xtilda_c_global_desc = transform_tensor_descriptor(
            wei_k_y_x_c_global_desc,
            make_tuple(PassThrough<K>{},
                       Embed<Y,
                             Sequence<YDot, YTilda>,
                             Sequence<ConvStrideH / GcdStrideDilationH, 1, 0>,
                             wei_skip_out_of_bound_check>{},
                       Embed<X,
                             Sequence<XDot, XTilda>,
                             Sequence<ConvStrideW / GcdStrideDilationW, 1, 0>,
                             wei_skip_out_of_bound_check>{},
                       PassThrough<C>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
            make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}));
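
        // Note on the Embed transform above (descriptive; it follows the general Embed rule
        // idx_low = coefficients[0..nDimUp-1] * idx_up[0..nDimUp-1] + coefficients[nDimUp],
        // documented later in this diff): the lower dimension Y is decomposed into an upper
        // (YDot, YTilda) pair via
        //   y = (ConvStrideH / GcdStrideDilationH) * ydot + 1 * ytilda + 0,
        // so each (ydot, ytilda) upper index addresses one filter tap.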

        constexpr auto wei_k_ydotslice_xdotslice_c_global_desc = transform_tensor_descriptor(
            wei_k_ydot_ytilda_xdot_xtilda_c_global_desc,
            make_tuple(
                PassThrough<K>{},
                Slice<Sequence<YDot, XDot>, Sequence<0, 0>, Sequence<YDotSlice, XDotSlice>>{},
                Freeze<Sequence<YTilda, XTilda>, Sequence<iYTilda, iXTilda>>{},
                PassThrough<C>{}),
            make_tuple(Sequence<0>{}, Sequence<1, 3>{}, Sequence<2, 4>{}, Sequence<5>{}),
            make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<>{}, Sequence<3>{}));

        constexpr auto wei_gemmk0_gemmk1_gemmk2_gemmm_global_desc =
            reorder_tensor_descriptor_given_lower2upper(wei_k_ydotslice_xdotslice_c_global_desc,
                                                        Sequence<2, 0, 1, 3>{});

        // B matrix: output tensor
        // TODO: sometimes the output tensor out-of-bound check can be skipped; find all such
        // situations
#if !CK_EXPERIMENTAL_IMPLICIT_GEMM_BACKWARD_DATA_V4R1_OUTPUT_SKIP_OUT_OF_BOUND_CHECK
        constexpr bool out_skip_out_of_bound_check = false;
#else
        constexpr bool out_skip_out_of_bound_check = true;
#endif

        constexpr auto out_n_ydot_htilda_xdot_wtilda_k_global_desc = transform_tensor_descriptor(
            out_n_ho_wo_k_global_desc,
            make_tuple(PassThrough<N>{},
                       Embed<Ho,
                             Sequence<YDot, HTilda>,
                             Sequence<-ConvDilationH / GcdStrideDilationH, 1, 0>,
                             out_skip_out_of_bound_check>{},
                       Embed<Wo,
                             Sequence<XDot, WTilda>,
                             Sequence<-ConvDilationW / GcdStrideDilationW, 1, 0>,
                             out_skip_out_of_bound_check>{},
                       PassThrough<K>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
            make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}));

        constexpr auto out_n_ydotslice_htildaslice_xdotslice_wtildaslice_k_global_desc =
            transform_tensor_descriptor(
                out_n_ydot_htilda_xdot_wtilda_k_global_desc,
                make_tuple(
                    PassThrough<N>{},
                    Slice<Sequence<YDot, XDot>, Sequence<0, 0>, Sequence<YDotSlice, XDotSlice>>{},
                    Slice<Sequence<HTilda, WTilda>,
                          Sequence<iHTildaLeft, iWTildaLeft>,
                          Sequence<iHTildaRight, iWTildaRight>>{},
                    PassThrough<K>{}),
                make_tuple(Sequence<0>{}, Sequence<1, 3>{}, Sequence<2, 4>{}, Sequence<5>{}),
                make_tuple(Sequence<0>{}, Sequence<1, 3>{}, Sequence<2, 4>{}, Sequence<5>{}));

        constexpr auto out_gemmk0_gemmk1_gemmk2_gemmn_global_desc = transform_tensor_descriptor(
            out_n_ydotslice_htildaslice_xdotslice_wtildaslice_k_global_desc,
            make_tuple(PassThrough<YDotSlice>{},
                       PassThrough<XDotSlice>{},
                       PassThrough<K>{},
                       Merge<Sequence<N, HTildaSlice, WTildaSlice>>{}),
            make_tuple(Sequence<1>{}, Sequence<3>{}, Sequence<5>{}, Sequence<0, 2, 4>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));

        // C matrix: input tensor
        // TODO: sometimes the input out-of-bound check can be skipped; find all such situations
#if !CK_EXPERIMENTAL_IMPLICIT_GEMM_BACKWARD_DATA_V4R1_INPUT_SKIP_OUT_OF_BOUND_CHECK
        constexpr bool in_skip_out_of_bound_check = false;
#else
        constexpr bool in_skip_out_of_bound_check = true;
#endif

        constexpr auto in_n_hip_wip_c_global_desc = transform_tensor_descriptor(
            in_n_hi_wi_c_global_desc,
            make_tuple(PassThrough<N>{},
                       Pad<Sequence<Hi, Wi>, InLeftPads, InRightPads, in_skip_out_of_bound_check>{},
                       PassThrough<C>{}),
            make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}),
            make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}));

        constexpr index_t Hip = in_n_hip_wip_c_global_desc.GetLengths()[1];
        constexpr index_t Wip = in_n_hip_wip_c_global_desc.GetLengths()[2];

        constexpr auto in_n_ytilda_htilda_xtilda_wtilda_c_global_desc = transform_tensor_descriptor(
            in_n_hip_wip_c_global_desc,
            make_tuple(PassThrough<N>{},
                       Embed<Hip,
                             Sequence<YTilda, HTilda>,
                             Sequence<ConvDilationH, ConvStrideH, 0>,
                             in_skip_out_of_bound_check>{},
                       Embed<Wip,
                             Sequence<XTilda, WTilda>,
                             Sequence<ConvDilationW, ConvStrideW, 0>,
                             in_skip_out_of_bound_check>{},
                       PassThrough<C>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
            make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}));

        constexpr auto in_n_htildaslice_wtildaslice_c_global_desc = transform_tensor_descriptor(
            in_n_ytilda_htilda_xtilda_wtilda_c_global_desc,
            make_tuple(PassThrough<N>{},
                       Freeze<Sequence<YTilda, XTilda>, Sequence<iYTilda, iXTilda>>{},
                       Slice<Sequence<HTilda, WTilda>,
                             Sequence<iHTildaLeft, iWTildaLeft>,
                             Sequence<iHTildaRight, iWTildaRight>>{},
                       PassThrough<C>{}),
            make_tuple(Sequence<0>{}, Sequence<1, 3>{}, Sequence<2, 4>{}, Sequence<5>{}),
            make_tuple(Sequence<0>{}, Sequence<>{}, Sequence<1, 2>{}, Sequence<3>{}));

        constexpr auto in_gemmm_gemmn_global_desc = transform_tensor_descriptor(
            in_n_htildaslice_wtildaslice_c_global_desc,
            make_tuple(PassThrough<C>{}, Merge<Sequence<N, HTildaSlice, WTildaSlice>>{}),
            make_tuple(Sequence<3>{}, Sequence<0, 1, 2>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}));

        // call GEMM
        constexpr auto gridwise_gemm = GridwiseGemmTransposedANormalBNormalC_v2<
            GridSize,
            BlockSize,
            Float,
            AccFloat,
            decltype(wei_gemmk0_gemmk1_gemmk2_gemmm_global_desc),
            decltype(out_gemmk0_gemmk1_gemmk2_gemmn_global_desc),
            decltype(in_gemmm_gemmn_global_desc),
            InMemoryDataOperation::Set,
            GemmMPerBlock,
            GemmNPerBlock,
            GemmKPerBlock,
            GemmMPerThread,
            GemmNPerThread,
            GemmKPerThread,
            GemmMLevel0Cluster,
            GemmNLevel0Cluster,
            GemmMLevel1Cluster,
            GemmNLevel1Cluster,
            ThreadGemmDataPerRead_GemmM,
            ThreadGemmDataPerRead_GemmN,
            GemmABlockCopyThreadSliceLengths_GemmK0_GemmK1_GemmK2_GemmM,
            GemmABlockCopyThreadClusterLengths_GemmK0_GemmK1_GemmK2_GemmM,
            Sequence<0, 1, 2, 3>,
            Sequence<0, 1, 2, 3>,
            3,
            GemmABlockCopySrcDataPerRead_GemmM,
            GemmABlockCopyDstDataPerWrite_GemmM,
            GemmBBlockCopyThreadSliceLengths_GemmK0_GemmK1_GemmK2_GemmN,
            GemmBBlockCopyThreadClusterLengths_GemmK0_GemmK1_GemmK2_GemmN,
            Sequence<0, 1, 3, 2>,
            Sequence<0, 1, 3, 2>,
            2,
            GemmBBlockCopySrcDataPerRead_GemmK2,
            GemmBBlockCopyDstDataPerWrite_GemmN,
            Sequence<2, 3, 0, 1>,
            3,
            GemmCThreadCopyDstDataPerWrite_GemmN1>{};

        gridwise_gemm.Run(p_wei_global, p_out_global, p_in_global);
    }

    template <index_t GemmId>
    __device__ static void Run(Float* __restrict__ p_in_global,
                               const Float* __restrict__ p_wei_global,
                               const Float* __restrict__ p_out_global,
                               Number<GemmId>)
    {
        constexpr index_t ConvStrideH = ConvStrides{}[0];
        constexpr index_t ConvStrideW = ConvStrides{}[1];

        constexpr index_t ConvDilationH = ConvDilations{}[0];
        constexpr index_t ConvDilationW = ConvDilations{}[1];

        constexpr index_t GcdStrideDilationH = math::gcd(ConvStrideH, ConvDilationH);
        constexpr index_t GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW);

        constexpr index_t YTilda = ConvStrideH / GcdStrideDilationH;
        constexpr index_t XTilda = ConvStrideW / GcdStrideDilationW;

        constexpr index_t iYTilda = GemmId / XTilda;
        constexpr index_t iXTilda = GemmId % XTilda;

        static_assert(iYTilda < YTilda && iXTilda < XTilda,
                      "wrong! iYTilda or iXTilda is out of range");

        RunImpl<iYTilda, iXTilda>(p_in_global, p_wei_global, p_out_global);
    }
};

} // namespace ck
#endif
@@ -1,454 +0,0 @@
#ifndef CK_GRIDWISE_CONVOLUTION_FORWARD_IMPLICIT_GEMM_V4R1_NCHW_KCYX_NKHW_LDS_DOUBLE_BUFFER_HPP
#define CK_GRIDWISE_CONVOLUTION_FORWARD_IMPLICIT_GEMM_V4R1_NCHW_KCYX_NKHW_LDS_DOUBLE_BUFFER_HPP

#include "common_header.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
#include "ConstantMatrixDescriptor.hpp"
#include "blockwise_generic_tensor_slice_copy.hpp"
#include "threadwise_generic_tensor_slice_copy.hpp"
#include "blockwise_gemm.hpp"

namespace ck {

template <index_t GridSize,
          index_t BlockSize,
          typename Float,
          typename AccDataType,
          typename InGlobalDesc,
          typename WeiGlobalDesc,
          typename OutGlobalDesc,
          typename ConvStrides,
          typename ConvDilations,
          typename LeftPads,
          typename RightPads,
          index_t BPerBlock,
          index_t KPerBlock,
          index_t EPerBlock,
          index_t GemmNRepeat,
          index_t GemmMPerThread,
          index_t GemmNPerThread,
          index_t GemmKPerThread,
          index_t GemmMLevel0Cluster,
          index_t GemmNLevel0Cluster,
          index_t GemmMLevel1Cluster,
          index_t GemmNLevel1Cluster,
          index_t GemmDataPerReadA,
          index_t GemmDataPerReadB,
          typename InBlockCopySubLengths_E_N1_B_N2,
          typename InBlockCopyClusterLengths_E_N1_B_N2,
          typename InBlockCopyThreadClusterArrangeOrder,
          typename InBlockCopySrcAccessOrder,
          typename InBlockCopyDstAccessOrder,
          index_t InBlockCopySrcDataPerRead_B,
          index_t InBlockCopyDstDataPerWrite_N2,
          typename WeiBlockCopySubLengths_E_K,
          typename WeiBlockCopyClusterLengths_E_K,
          typename WeiBlockCopyThreadClusterArrangeOrder,
          typename WeiBlockCopySrcAccessOrder,
          typename WeiBlockCopyDstAccessOrder,
          index_t WeiBlockCopySrcDataPerRead_E,
          index_t WeiBlockCopyDstDataPerWrite_K>
struct GridwiseConvolutionForwardImplicitGemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer
{
    __device__ void Run(const Float* const __restrict__ p_in_global,
                        const Float* const __restrict__ p_wei_global,
                        Float* const __restrict__ p_out_global) const
    {
        constexpr auto I0 = Number<0>{};
        constexpr auto I1 = Number<1>{};
        constexpr auto I2 = Number<2>{};
        constexpr auto I3 = Number<3>{};

        constexpr auto True = integral_constant<bool, true>{};

        // this is a mess
        // TODO: find a more elegant way of specifying (or calculating) performance parameters
        constexpr index_t N1 = GemmNRepeat;
        constexpr index_t N2 = GemmNPerThread;

        static_assert(
            (N1 * N2 * BPerBlock) % (GemmNPerThread * GemmNLevel0Cluster * GemmNLevel1Cluster) == 0,
            "wrong!");

        constexpr auto in_n_c_hi_wi_global_desc  = InGlobalDesc{};
        constexpr auto wei_k_c_y_x_global_desc   = WeiGlobalDesc{};
        constexpr auto out_n_k_ho_wo_global_desc = OutGlobalDesc{};

        constexpr index_t N  = in_n_c_hi_wi_global_desc.GetLengths()[0];
        constexpr index_t C  = in_n_c_hi_wi_global_desc.GetLengths()[1];
        constexpr index_t Hi = in_n_c_hi_wi_global_desc.GetLengths()[2];
        constexpr index_t Wi = in_n_c_hi_wi_global_desc.GetLengths()[3];

        constexpr index_t K  = out_n_k_ho_wo_global_desc.GetLengths()[1];
        constexpr index_t Ho = out_n_k_ho_wo_global_desc.GetLengths()[2];
        constexpr index_t Wo = out_n_k_ho_wo_global_desc.GetLengths()[3];

        constexpr index_t Y = wei_k_c_y_x_global_desc.GetLengths()[2];
        constexpr index_t X = wei_k_c_y_x_global_desc.GetLengths()[3];

        constexpr index_t ConvStrideH = ConvStrides{}[0];
        constexpr index_t ConvStrideW = ConvStrides{}[1];

        constexpr index_t ConvDilationH = ConvDilations{}[0];
        constexpr index_t ConvDilationW = ConvDilations{}[1];

        static_assert(N % (N1 * N2) == 0, "wrong! cannot divide N evenly among threads");

        constexpr index_t N0 = N / (N1 * N2);

        constexpr index_t B = N0 * Ho * Wo;

        constexpr index_t E = C * Y * X;

        // sanity-check for vectorized memory load
        static_assert((Wo == 1 || (ConvStrideW == 1 || InBlockCopySrcDataPerRead_B == 1)) &&
                          (X == 1 || ConvDilationW % InBlockCopySrcDataPerRead_B == 0),
                      "wrong! alignment requirement for vectorized global load of input tensor "
                      "will be violated");

        // divide block work by [K, B]
        static_assert(K % KPerBlock == 0 && B % BPerBlock == 0 && E % EPerBlock == 0,
                      "wrong! cannot divide work evenly among blocks");

        constexpr index_t KBlockWork = K / KPerBlock;
        constexpr index_t BBlockWork = B / BPerBlock;

        constexpr auto block_work_desc =
            make_cluster_descriptor(Sequence<KBlockWork, BBlockWork>{});

        const auto block_work_id = block_work_desc.CalculateClusterIndex(get_block_1d_id());

        const index_t k_block_data_on_global = block_work_id[I0] * KPerBlock;
        const index_t b_block_data_on_global = block_work_id[I1] * BPerBlock;

        // input tensor
        // global tensor in global memory
        constexpr auto in_n_c_hip_wip_global_desc = transform_tensor_descriptor(
            in_n_c_hi_wi_global_desc,
            make_tuple(
                PassThrough<N>{}, PassThrough<C>{}, Pad<Sequence<Hi, Wi>, LeftPads, RightPads>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}));

        constexpr index_t Hip = in_n_c_hip_wip_global_desc.GetLengths()[2];
        constexpr index_t Wip = in_n_c_hip_wip_global_desc.GetLengths()[3];

        constexpr auto in_n0_n1_n2_c_y_ho_x_wo_global_desc = transform_tensor_descriptor(
            in_n_c_hip_wip_global_desc,
            make_tuple(UnMerge<Sequence<N0, N1, N2>>{},
                       PassThrough<C>{},
                       Embed<Hip, Sequence<Y, Ho>, Sequence<ConvDilationH, ConvStrideH, 0>>{},
                       Embed<Wip, Sequence<X, Wo>, Sequence<ConvDilationW, ConvStrideW, 0>>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
            make_tuple(Sequence<0, 1, 2>{}, Sequence<3>{}, Sequence<4, 5>{}, Sequence<6, 7>{}));

        // global tensor in global memory, src of blockwise copy
        constexpr auto in_e_n1_b_n2_global_desc = transform_tensor_descriptor(
            in_n0_n1_n2_c_y_ho_x_wo_global_desc,
            make_tuple(Merge<Sequence<C, Y, X>>{},
                       PassThrough<N1>{},
                       Merge<Sequence<N0, Ho, Wo>>{},
                       PassThrough<N2>{}),
            make_tuple(Sequence<3, 4, 6>{}, Sequence<1>{}, Sequence<0, 5, 7>{}, Sequence<2>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));

        // block tensor in LDS memory, dst of blockwise copy
        // be careful of LDS alignment
        constexpr auto in_e_n1_b_n2_block_desc = make_native_tensor_descriptor_aligned(
            Sequence<EPerBlock, N1, BPerBlock, N2>{}, Number<InBlockCopyDstDataPerWrite_N2>{});

        // this check is ad-hoc
        // TODO: need to properly implement tensor descriptor with multiple alignment
        // requirements
        static_assert(in_e_n1_b_n2_block_desc.GetStride(I1) % GemmDataPerReadB == 0,
                      "GemmDataPerReadB alignment requirement is not satisfied");

        // input tensor blockwise copy
        auto blockwise_in_copy =
            BlockwiseGenericTensorSliceCopy_v4<BlockSize,
                                               decltype(in_e_n1_b_n2_global_desc),
                                               decltype(in_e_n1_b_n2_block_desc),
                                               decltype(in_e_n1_b_n2_block_desc.GetLengths()),
                                               InBlockCopySubLengths_E_N1_B_N2,
                                               InBlockCopyClusterLengths_E_N1_B_N2,
                                               InBlockCopyThreadClusterArrangeOrder,
                                               InBlockCopySrcAccessOrder,
                                               InBlockCopyDstAccessOrder,
                                               2,
                                               3,
                                               InBlockCopySrcDataPerRead_B,
                                               InBlockCopyDstDataPerWrite_N2,
                                               AddressSpace::Global,
                                               AddressSpace::Vgpr,
                                               AddressSpace::Lds,
                                               InMemoryDataOperation::Set>(
                make_multi_index(0, 0, b_block_data_on_global, 0), make_multi_index(0, 0, 0, 0));

        // weight tensor
        // global tensor in global memory, src of blockwise copy
        // It is constructed differently depending on whether this is forward or
        // backward-weight convolution
        constexpr auto wei_e_k_global_desc =
            transform_tensor_descriptor(unfold_tensor_descriptor(wei_k_c_y_x_global_desc, I2, I3),
                                        make_tuple(Merge<Sequence<C, Y * X>>{}, PassThrough<K>{}),
                                        make_tuple(Sequence<1, 2>{}, Sequence<0>{}),
                                        make_tuple(Sequence<0>{}, Sequence<1>{}));

        // block tensor in LDS memory, dst of blockwise copy
        // be careful of LDS alignment
        constexpr auto wei_e_k_block_desc = make_native_tensor_descriptor_aligned(
            Sequence<EPerBlock, KPerBlock>{},
            Number<math::lcm(WeiBlockCopyDstDataPerWrite_K, GemmDataPerReadA)>{});

        // this check is ad-hoc
        // TODO: need to properly implement tensor descriptor with multiple alignment
        // requirements
        static_assert(wei_e_k_block_desc.GetStride(I0) % GemmDataPerReadA == 0,
                      "GemmDataPerReadA alignment requirement is not satisfied");

        // weight tensor blockwise copy
        auto blockwise_wei_copy =
            BlockwiseGenericTensorSliceCopy_v4<BlockSize,
                                               decltype(wei_e_k_global_desc),
                                               decltype(wei_e_k_block_desc),
                                               decltype(wei_e_k_block_desc.GetLengths()),
                                               WeiBlockCopySubLengths_E_K,
                                               WeiBlockCopyClusterLengths_E_K,
                                               WeiBlockCopyThreadClusterArrangeOrder,
                                               WeiBlockCopySrcAccessOrder,
                                               WeiBlockCopyDstAccessOrder,
                                               0,
                                               1,
                                               WeiBlockCopySrcDataPerRead_E,
                                               WeiBlockCopyDstDataPerWrite_K,
                                               AddressSpace::Global,
                                               AddressSpace::Vgpr,
                                               AddressSpace::Lds,
                                               InMemoryDataOperation::Set>(
                make_multi_index(0, k_block_data_on_global), make_multi_index(0, 0));

        // GEMM definition
        // c_mtx += transpose(a_mtx) * b_mtx
        // a_mtx[EPerBlock, KPerBlock] is in LDS
        // b_mtx[EPerBlock, N1 * BPerBlock * N2] is in LDS
        // c_mtx[KPerBlock, N1 * BPerBlock * N2] is distributed among threads, and saved in
        // register
        constexpr auto a_e_k_block_mtx_desc = make_ConstantMatrixDescriptor(wei_e_k_block_desc);

        constexpr auto b_e_n1bn2_block_mtx_desc = make_ConstantMatrixDescriptor(
            in_e_n1_b_n2_block_desc.GetLength(I0),
            in_e_n1_b_n2_block_desc.GetLength(I1) * in_e_n1_b_n2_block_desc.GetLength(I2) *
                in_e_n1_b_n2_block_desc.GetLength(I3),
            in_e_n1_b_n2_block_desc.GetStride(I0));
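
        // Illustrative shapes (hypothetical tuning values, not from this file): with
        // EPerBlock = 8, KPerBlock = 128, BPerBlock = 16, N1 = 2, N2 = 4, each main-loop
        // iteration computes c[128, 128] += transpose(a[8, 128]) * b[8, 128], where b's
        // 128 columns are the flattened N1 * BPerBlock * N2 dimension.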

        // sanity check
        static_assert(KPerBlock % (GemmMPerThread * GemmMLevel0Cluster * GemmMLevel1Cluster) == 0,
                      "wrong!");

        constexpr index_t GemmMRepeat =
            KPerBlock / (GemmMPerThread * GemmMLevel0Cluster * GemmMLevel1Cluster);

        // c_thread_mtx definition: this is a mess
        // TODO: find a more elegant way of defining c_thread_mtx
        constexpr auto c_k0k1_n1n2_thread_mtx_desc = make_ConstantMatrixDescriptor_packed(
            Number<GemmMRepeat * GemmMPerThread>{}, Number<GemmNRepeat * GemmNPerThread>{});

        const auto blockwise_gemm = BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2<
            BlockSize,
            decltype(a_e_k_block_mtx_desc),
            decltype(b_e_n1bn2_block_mtx_desc),
            decltype(c_k0k1_n1n2_thread_mtx_desc),
            GemmMPerThread,
            GemmNPerThread,
            GemmKPerThread,
            GemmMLevel0Cluster,
            GemmNLevel0Cluster,
            GemmMLevel1Cluster,
            GemmNLevel1Cluster,
            GemmDataPerReadA,
            GemmDataPerReadB>{};

        // LDS allocation for input and weight: be careful of alignment
        constexpr index_t max_align = math::lcm(InBlockCopyDstDataPerWrite_N2,
                                                WeiBlockCopyDstDataPerWrite_K,
                                                GemmDataPerReadA,
                                                GemmDataPerReadB);

        constexpr index_t in_block_space =
            math::integer_least_multiple(in_e_n1_b_n2_block_desc.GetElementSpace(), max_align);

        constexpr index_t wei_block_space =
            math::integer_least_multiple(wei_e_k_block_desc.GetElementSpace(), max_align);

        __shared__ Float p_in_block_double[2 * in_block_space];
        __shared__ Float p_wei_block_double[2 * wei_block_space];

        // register allocation for output
        AccDataType p_out_thread[c_k0k1_n1n2_thread_mtx_desc.GetElementSpace()];

        // zero out threadwise output
        threadwise_matrix_set_zero(c_k0k1_n1n2_thread_mtx_desc, p_out_thread);

        // LDS double buffer: preload data into LDS
        {
            blockwise_in_copy.Run(p_in_global, p_in_block_double);
            blockwise_wei_copy.Run(p_wei_global, p_wei_block_double);
        }
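
        // The main loop below ping-pongs between the two halves of each LDS buffer: while the
        // GEMM consumes the "now" half, the next tile is loaded from global memory into
        // registers and then stored into the "next" half, overlapping global loads with math.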

        constexpr auto in_block_slice_copy_steps  = Sequence<EPerBlock, 0, 0, 0>{};
        constexpr auto wei_block_slice_copy_steps = Sequence<EPerBlock, 0>{};

        // LDS double buffer: main body
        for(index_t e_block_data_begin = 0; e_block_data_begin + 2 * EPerBlock < E;
            e_block_data_begin += 2 * EPerBlock)
        {
#pragma unroll
            for(index_t iloop = 0; iloop < 2; ++iloop)
            {
                const bool even_loop = (iloop % 2 == 0);

                Float* p_in_block_now =
                    even_loop ? p_in_block_double : p_in_block_double + in_block_space;
                Float* p_wei_block_now =
                    even_loop ? p_wei_block_double : p_wei_block_double + wei_block_space;

                Float* p_in_block_next =
                    even_loop ? p_in_block_double + in_block_space : p_in_block_double;
                Float* p_wei_block_next =
                    even_loop ? p_wei_block_double + wei_block_space : p_wei_block_double;

                Float p_in_thread_buffer[blockwise_in_copy.GetThreadBufferSize()];
                Float p_wei_thread_buffer[blockwise_wei_copy.GetThreadBufferSize()];

                blockwise_in_copy.MoveSrcSliceWindow(in_block_slice_copy_steps, True);
                blockwise_wei_copy.MoveSrcSliceWindow(wei_block_slice_copy_steps, True);

                __syncthreads();

                // LDS double buffer: load next data from device mem
                blockwise_in_copy.RunLoadThreadBuffer(p_in_global, p_in_thread_buffer);
                blockwise_wei_copy.RunLoadThreadBuffer(p_wei_global, p_wei_thread_buffer);

                // LDS double buffer: GEMM on current data
                blockwise_gemm.Run(p_wei_block_now, p_in_block_now, p_out_thread);

                // LDS double buffer: store next data to LDS
                blockwise_in_copy.RunStoreThreadBuffer(p_in_thread_buffer, p_in_block_next);
                blockwise_wei_copy.RunStoreThreadBuffer(p_wei_thread_buffer, p_wei_block_next);
            }
        }

        // LDS double buffer: tail
        {
            constexpr bool has_two_iteration_left = (E % (2 * EPerBlock) == 0);

            if(has_two_iteration_left) // 2 iterations left
            {
                Float p_in_thread_buffer[blockwise_in_copy.GetThreadBufferSize()];
                Float p_wei_thread_buffer[blockwise_wei_copy.GetThreadBufferSize()];

                blockwise_in_copy.MoveSrcSliceWindow(in_block_slice_copy_steps, True);
                blockwise_wei_copy.MoveSrcSliceWindow(wei_block_slice_copy_steps, True);

                __syncthreads();

                // LDS double buffer: load last data from device mem
                blockwise_in_copy.RunLoadThreadBuffer(p_in_global, p_in_thread_buffer);
                blockwise_wei_copy.RunLoadThreadBuffer(p_wei_global, p_wei_thread_buffer);

                // LDS double buffer: GEMM on 2nd-to-last data
                blockwise_gemm.Run(p_wei_block_double, p_in_block_double, p_out_thread);

                // LDS double buffer: store last data to LDS
                blockwise_in_copy.RunStoreThreadBuffer(p_in_thread_buffer,
                                                       p_in_block_double + in_block_space);
                blockwise_wei_copy.RunStoreThreadBuffer(p_wei_thread_buffer,
                                                        p_wei_block_double + wei_block_space);

                __syncthreads();

                // LDS double buffer: GEMM on last data
                blockwise_gemm.Run(p_wei_block_double + wei_block_space,
                                   p_in_block_double + in_block_space,
                                   p_out_thread);
            }
            else // 1 iteration left
            {
                __syncthreads();

                // LDS double buffer: GEMM on last data
                blockwise_gemm.Run(p_wei_block_double, p_in_block_double, p_out_thread);
            }
        }

        // copy output: register to global memory
        {
            constexpr index_t K1 = GemmMPerThread * GemmMLevel0Cluster * GemmMLevel1Cluster;
            constexpr index_t K0 = K / K1;

            // define output tensor descriptor for threadwise copy
            // thread output tensor, src of threadwise copy
            constexpr auto out_k0_k1_n1_b_n2_thread_desc = make_native_tensor_descriptor_packed(
                Sequence<GemmMRepeat, GemmMPerThread, N1, 1, N2>{});

            // global output tensor
            constexpr auto out_n0_n1_n2_k0_k1_ho_wo_global_desc = transform_tensor_descriptor(
                out_n_k_ho_wo_global_desc,
                make_tuple(UnMerge<Sequence<N0, N1, N2>>{},
                           UnMerge<Sequence<K0, K1>>{},
                           PassThrough<Ho>{},
                           PassThrough<Wo>{}),
                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
                make_tuple(Sequence<0, 1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}, Sequence<6>{}));

            // global output tensor, dst of threadwise copy
            constexpr auto out_k0_k1_n1_b_n2_global_desc = transform_tensor_descriptor(
                out_n0_n1_n2_k0_k1_ho_wo_global_desc,
                make_tuple(PassThrough<K0>{},
                           PassThrough<K1>{},
                           PassThrough<N1>{},
                           Merge<Sequence<N0, Ho, Wo>>{},
                           PassThrough<N2>{}),
                make_tuple(Sequence<3>{},
                           Sequence<4>{},
                           Sequence<1>{},
                           Sequence<0, 5, 6>{},
                           Sequence<2>{}),
                make_tuple(
                    Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}));

            // calculate origin of thread output tensor on global memory
            // blockwise GEMM c matrix starting index
            const auto c_thread_mtx_on_block =
                blockwise_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id());

            const index_t k_thread_data_on_global =
                k_block_data_on_global + c_thread_mtx_on_block.row;

            const index_t b_thread_data_on_global =
                b_block_data_on_global + c_thread_mtx_on_block.col / N2;

            ThreadwiseGenericTensorSliceCopy_v4r2<
                decltype(out_k0_k1_n1_b_n2_thread_desc),
                decltype(out_k0_k1_n1_b_n2_global_desc),
                decltype(out_k0_k1_n1_b_n2_thread_desc.GetLengths()),
                arithmetic_sequence_gen<0, 5, 1>::type,
                3,
                1,
                1,
                AddressSpace::Vgpr,
                AddressSpace::Global,
                InMemoryDataOperation::Set>(make_multi_index(0, 0, 0, 0, 0),
                                            make_multi_index(k_thread_data_on_global / K1,
                                                             k_thread_data_on_global % K1,
                                                             0,
                                                             b_thread_data_on_global,
                                                             0))
                .Run(p_out_thread, p_out_global);
        }
    }
};

} // namespace ck
#endif
@@ -1,171 +0,0 @@
#ifndef CK_GRIDWISE_CONVOLUTION_FORWARD_IMPLICIT_GEMM_V4R4_NCHW_KCYX_NKHW_HPP
#define CK_GRIDWISE_CONVOLUTION_FORWARD_IMPLICIT_GEMM_V4R4_NCHW_KCYX_NKHW_HPP

#include "common_header.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
#include "gridwise_gemm.hpp"

namespace ck {

// GemmM = K
// GemmN = N * Ho * Wo
// GemmK = C * Y * X
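//
// Worked example (hypothetical sizes, for illustration only): N = 128, C = 64, Y = X = 3,
// K = 256, Ho = Wo = 14 gives an implicit GEMM of size
// GemmM x GemmN x GemmK = 256 x (128 * 14 * 14) x (64 * 3 * 3) = 256 x 25088 x 576.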
template <index_t GridSize,
          index_t BlockSize,
          typename Float,
          typename AccFloat,
          typename InGlobalDesc,
          typename WeiGlobalDesc,
          typename OutGlobalDesc,
          typename ConvStrides,
          typename ConvDilations,
          typename InLeftPads,
          typename InRightPads,
          index_t GemmMPerBlock,
          index_t GemmNPerBlock,
          index_t GemmKPerBlock,
          index_t GemmMPerThread,
          index_t GemmNPerThread,
          index_t GemmKPerThread,
          index_t GemmMLevel0Cluster,
          index_t GemmNLevel0Cluster,
          index_t GemmMLevel1Cluster,
          index_t GemmNLevel1Cluster,
          index_t ThreadGemmDataPerRead_GemmM,
          index_t ThreadGemmDataPerRead_GemmN,
          typename GemmABlockCopyThreadSliceLengths_GemmK_GemmM,
          typename GemmABlockCopyThreadClusterLengths_GemmK_GemmM,
          index_t GemmABlockCopySrcDataPerRead_GemmK,
          index_t GemmABlockCopyDstDataPerWrite_GemmM,
          typename GemmBBlockCopyThreadSliceLengths_GemmK_GemmN,
          typename GemmBBlockCopyThreadClusterLengths_GemmK_GemmN,
          index_t GemmBBlockCopySrcDataPerRead_GemmN,
          index_t GemmBBlockCopyDstDataPerWrite_GemmN,
          index_t GemmCThreadCopyDstDataPerWrite_GemmN1>
struct GridwiseConvolutionForwardImplicitGemm_v4r4_nchw_kcyx_nkhw
{
    __device__ void Run(const Float* const __restrict__ p_in_global,
                        const Float* const __restrict__ p_wei_global,
                        Float* const __restrict__ p_out_global) const
    {
        constexpr auto I1 = Number<1>{};
        constexpr auto I2 = Number<2>{};
        constexpr auto I3 = Number<3>{};

        constexpr auto in_n_c_hi_wi_global_desc  = InGlobalDesc{};
        constexpr auto wei_k_c_y_x_global_desc   = WeiGlobalDesc{};
        constexpr auto out_n_k_ho_wo_global_desc = OutGlobalDesc{};

        constexpr index_t N  = in_n_c_hi_wi_global_desc.GetLengths()[0];
        constexpr index_t C  = in_n_c_hi_wi_global_desc.GetLengths()[1];
        constexpr index_t Hi = in_n_c_hi_wi_global_desc.GetLengths()[2];
        constexpr index_t Wi = in_n_c_hi_wi_global_desc.GetLengths()[3];

        constexpr index_t K  = out_n_k_ho_wo_global_desc.GetLengths()[1];
        constexpr index_t Ho = out_n_k_ho_wo_global_desc.GetLengths()[2];
        constexpr index_t Wo = out_n_k_ho_wo_global_desc.GetLengths()[3];

        constexpr index_t Y = wei_k_c_y_x_global_desc.GetLengths()[2];
        constexpr index_t X = wei_k_c_y_x_global_desc.GetLengths()[3];

        constexpr index_t ConvStrideH = ConvStrides{}[0];
        constexpr index_t ConvStrideW = ConvStrides{}[1];

        constexpr index_t ConvDilationH = ConvDilations{}[0];
        constexpr index_t ConvDilationW = ConvDilations{}[1];

#if 0
        // sanity-check for vectorized memory load
        static_assert((Wo == 1 || (ConvStrideW == 1 || GemmBBlockCopySrcDataPerRead_GemmN == 1)) &&
                          (X == 1 || ConvDilationW % GemmBBlockCopySrcDataPerRead_GemmN == 0) &&
                          InLeftPads{}[1] % GemmBBlockCopySrcDataPerRead_GemmN == 0 &&
                          InRightPads{}[1] % GemmBBlockCopySrcDataPerRead_GemmN == 0,
                      "wrong! alignment requirement for vectorized global load of input tensor "
                      "will be violated");
#endif

        // weight tensor
        constexpr auto wei_gemmk_gemmm_global_desc = reorder_tensor_descriptor_given_upper2lower(
            unfold_tensor_descriptor(wei_k_c_y_x_global_desc, I1, I3), Sequence<1, 0>{});
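        // i.e. [K, C, Y, X] is first unfolded to [K, C * Y * X] and then transposed to
        // [GemmK, GemmM] = [C * Y * X, K]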

        // input tensor
        constexpr auto in_n_c_hip_wip_global_desc = transform_tensor_descriptor(
            in_n_c_hi_wi_global_desc,
            make_tuple(PassThrough<N>{},
                       PassThrough<C>{},
                       Pad<Sequence<Hi, Wi>, InLeftPads, InRightPads>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}));

        constexpr index_t Hip = in_n_c_hip_wip_global_desc.GetLengths()[2];
        constexpr index_t Wip = in_n_c_hip_wip_global_desc.GetLengths()[3];

        constexpr auto in_n_c_y_ho_x_wo_global_desc = transform_tensor_descriptor(
            in_n_c_hip_wip_global_desc,
            make_tuple(PassThrough<N>{},
                       PassThrough<C>{},
                       Embed<Hip, Sequence<Y, Ho>, Sequence<ConvDilationH, ConvStrideH, 0>>{},
                       Embed<Wip, Sequence<X, Wo>, Sequence<ConvDilationW, ConvStrideW, 0>>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}));

        constexpr auto in_gemmk_gemmn_global_desc = transform_tensor_descriptor(
            in_n_c_y_ho_x_wo_global_desc,
            make_tuple(Merge<Sequence<C, Y, X>>{}, Merge<Sequence<N, Ho, Wo>>{}),
            make_tuple(Sequence<1, 2, 4>{}, Sequence<0, 3, 5>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}));

        // output tensor
        constexpr auto out_gemmm_gemmn_global_desc =
            transform_tensor_descriptor(unfold_tensor_descriptor(out_n_k_ho_wo_global_desc, I2, I3),
                                        make_tuple(PassThrough<K>{}, Merge<Sequence<N, Ho * Wo>>{}),
                                        make_tuple(Sequence<1>{}, Sequence<0, 2>{}),
                                        make_tuple(Sequence<0>{}, Sequence<1>{}));

        // GEMM
        constexpr auto gridwise_gemm =
            GridwiseGemmTransposedANormalBNormalC_v1<GridSize,
                                                     BlockSize,
                                                     Float,
                                                     AccFloat,
                                                     decltype(wei_gemmk_gemmm_global_desc),
                                                     decltype(in_gemmk_gemmn_global_desc),
                                                     decltype(out_gemmm_gemmn_global_desc),
                                                     InMemoryDataOperation::Set,
                                                     GemmMPerBlock,
                                                     GemmNPerBlock,
                                                     GemmKPerBlock,
                                                     GemmMPerThread,
                                                     GemmNPerThread,
                                                     GemmKPerThread,
                                                     GemmMLevel0Cluster,
                                                     GemmNLevel0Cluster,
                                                     GemmMLevel1Cluster,
                                                     GemmNLevel1Cluster,
                                                     ThreadGemmDataPerRead_GemmM,
                                                     ThreadGemmDataPerRead_GemmN,
                                                     GemmABlockCopyThreadSliceLengths_GemmK_GemmM,
                                                     GemmABlockCopyThreadClusterLengths_GemmK_GemmM,
                                                     Sequence<1, 0>,
                                                     Sequence<1, 0>,
                                                     0,
                                                     GemmABlockCopySrcDataPerRead_GemmK,
                                                     GemmABlockCopyDstDataPerWrite_GemmM,
                                                     GemmBBlockCopyThreadSliceLengths_GemmK_GemmN,
                                                     GemmBBlockCopyThreadClusterLengths_GemmK_GemmN,
                                                     Sequence<0, 1>,
                                                     Sequence<0, 1>,
                                                     1,
                                                     GemmBBlockCopySrcDataPerRead_GemmN,
                                                     GemmBBlockCopyDstDataPerWrite_GemmN,
                                                     Sequence<2, 3, 0, 1>,
                                                     3,
                                                     GemmCThreadCopyDstDataPerWrite_GemmN1>{};

        gridwise_gemm.Run(p_wei_global, p_in_global, p_out_global);
    }
};

} // namespace ck
#endif
@@ -1,162 +0,0 @@
#ifndef CK_GRIDWISE_CONVOLUTION_FORWARD_IMPLICIT_GEMM_V4R4_NHWC_KYXC_NHWK_HPP
#define CK_GRIDWISE_CONVOLUTION_FORWARD_IMPLICIT_GEMM_V4R4_NHWC_KYXC_NHWK_HPP

#include "common_header.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
#include "gridwise_gemm.hpp"

namespace ck {

// GemmM = K
// GemmN = N * Ho * Wo
// GemmK = C * Y * X
template <index_t GridSize,
          index_t BlockSize,
          typename Float,
          typename AccFloat,
          typename InGlobalDesc,
          typename WeiGlobalDesc,
          typename OutGlobalDesc,
          typename ConvStrides,
          typename ConvDilations,
          typename InLeftPads,
          typename InRightPads,
          index_t GemmMPerBlock,
          index_t GemmNPerBlock,
          index_t GemmKPerBlock,
          index_t GemmMPerThread,
          index_t GemmNPerThread,
          index_t GemmKPerThread,
          index_t GemmMLevel0Cluster,
          index_t GemmNLevel0Cluster,
          index_t GemmMLevel1Cluster,
          index_t GemmNLevel1Cluster,
          index_t ThreadGemmDataPerRead_GemmM,
          index_t ThreadGemmDataPerRead_GemmN,
          typename GemmABlockCopyThreadSliceLengths_GemmK_GemmM,
          typename GemmABlockCopyThreadClusterLengths_GemmK_GemmM,
          index_t GemmABlockCopySrcDataPerRead_GemmK,
          index_t GemmABlockCopyDstDataPerWrite_GemmM,
          typename GemmBBlockCopyThreadSliceLengths_GemmK_GemmN,
          typename GemmBBlockCopyThreadClusterLengths_GemmK_GemmN,
          index_t GemmBBlockCopySrcDataPerRead_GemmK,
          index_t GemmBBlockCopyDstDataPerWrite_GemmN,
          index_t GemmCThreadCopyDstDataPerWrite_GemmM1>
struct GridwiseConvolutionForwardImplicitGemm_v4r4_nhwc_kyxc_nhwk
{
    __device__ void Run(const Float* const __restrict__ p_in_global,
                        const Float* const __restrict__ p_wei_global,
                        Float* const __restrict__ p_out_global) const
    {
        constexpr auto I0 = Number<0>{};
        constexpr auto I1 = Number<1>{};
        constexpr auto I2 = Number<2>{};
        constexpr auto I3 = Number<3>{};

        constexpr auto in_n_hi_wi_c_global_desc  = InGlobalDesc{};
        constexpr auto wei_k_y_x_c_global_desc   = WeiGlobalDesc{};
        constexpr auto out_n_ho_wo_k_global_desc = OutGlobalDesc{};

        constexpr index_t N  = in_n_hi_wi_c_global_desc.GetLengths()[I0];
        constexpr index_t Hi = in_n_hi_wi_c_global_desc.GetLengths()[I1];
        constexpr index_t Wi = in_n_hi_wi_c_global_desc.GetLengths()[I2];
        constexpr index_t C  = in_n_hi_wi_c_global_desc.GetLengths()[I3];

        constexpr index_t K  = out_n_ho_wo_k_global_desc.GetLengths()[I3];
        constexpr index_t Ho = out_n_ho_wo_k_global_desc.GetLengths()[I1];
        constexpr index_t Wo = out_n_ho_wo_k_global_desc.GetLengths()[I2];

        constexpr index_t Y = wei_k_y_x_c_global_desc.GetLengths()[I1];
        constexpr index_t X = wei_k_y_x_c_global_desc.GetLengths()[I2];

        constexpr index_t ConvStrideH = ConvStrides{}[I0];
        constexpr index_t ConvStrideW = ConvStrides{}[I1];

        constexpr index_t ConvDilationH = ConvDilations{}[I0];
        constexpr index_t ConvDilationW = ConvDilations{}[I1];

        // weight tensor
        constexpr auto wei_gemmk_gemmm_global_desc = reorder_tensor_descriptor_given_upper2lower(
            unfold_tensor_descriptor(wei_k_y_x_c_global_desc, I1, I3), Sequence<1, 0>{});

        // input tensor
        constexpr auto in_n_hip_wip_c_global_desc =
            transform_tensor_descriptor(in_n_hi_wi_c_global_desc,
                                        make_tuple(PassThrough<N>{},
                                                   Pad<Sequence<Hi, Wi>, InLeftPads, InRightPads>{},
                                                   PassThrough<C>{}),
                                        make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}),
                                        make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}));

        constexpr index_t Hip = in_n_hip_wip_c_global_desc.GetLengths()[I1];
        constexpr index_t Wip = in_n_hip_wip_c_global_desc.GetLengths()[I2];

        constexpr auto in_n_y_ho_x_wo_c_global_desc = transform_tensor_descriptor(
            in_n_hip_wip_c_global_desc,
            make_tuple(PassThrough<N>{},
                       Embed<Hip, Sequence<Y, Ho>, Sequence<ConvDilationH, ConvStrideH, 0>>{},
                       Embed<Wip, Sequence<X, Wo>, Sequence<ConvDilationW, ConvStrideW, 0>>{},
                       PassThrough<C>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
            make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}));

        constexpr auto in_gemmk_gemmn_global_desc = transform_tensor_descriptor(
            in_n_y_ho_x_wo_c_global_desc,
            make_tuple(Merge<Sequence<Y, X, C>>{}, Merge<Sequence<N, Ho, Wo>>{}),
            make_tuple(Sequence<1, 3, 5>{}, Sequence<0, 2, 4>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}));
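
        // Design note (inferred from the merge order above): unlike the NCHW variant, GemmK
        // here merges (Y, X, C) with C innermost, so GemmK is contiguous along the NHWC
        // channel dimension; this is what allows the B-block copy to read vectors along
        // GemmK (see GemmBBlockCopySrcDataPerRead_GemmK below).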

        // output tensor
        constexpr auto out_gemmm_gemmn_global_desc = transform_tensor_descriptor(
            unfold_tensor_descriptor(out_n_ho_wo_k_global_desc, I0, I2),
            make_tuple(PassThrough<K>{}, Merge<Sequence<N * Ho * Wo>>{}),
            make_tuple(Sequence<1>{}, Sequence<0>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}));

        // GEMM
        constexpr auto gridwise_gemm =
            GridwiseGemmTransposedANormalBNormalC_v1<GridSize,
                                                     BlockSize,
                                                     Float,
                                                     AccFloat,
                                                     decltype(wei_gemmk_gemmm_global_desc),
                                                     decltype(in_gemmk_gemmn_global_desc),
                                                     decltype(out_gemmm_gemmn_global_desc),
                                                     InMemoryDataOperation::Set,
                                                     GemmMPerBlock,
                                                     GemmNPerBlock,
                                                     GemmKPerBlock,
                                                     GemmMPerThread,
                                                     GemmNPerThread,
                                                     GemmKPerThread,
                                                     GemmMLevel0Cluster,
                                                     GemmNLevel0Cluster,
                                                     GemmMLevel1Cluster,
                                                     GemmNLevel1Cluster,
                                                     ThreadGemmDataPerRead_GemmM,
                                                     ThreadGemmDataPerRead_GemmN,
                                                     GemmABlockCopyThreadSliceLengths_GemmK_GemmM,
                                                     GemmABlockCopyThreadClusterLengths_GemmK_GemmM,
                                                     Sequence<1, 0>,
                                                     Sequence<1, 0>,
                                                     0,
                                                     GemmABlockCopySrcDataPerRead_GemmK,
                                                     GemmABlockCopyDstDataPerWrite_GemmM,
                                                     GemmBBlockCopyThreadSliceLengths_GemmK_GemmN,
                                                     GemmBBlockCopyThreadClusterLengths_GemmK_GemmN,
                                                     Sequence<1, 0>,
                                                     Sequence<1, 0>,
                                                     0,
                                                     GemmBBlockCopySrcDataPerRead_GemmK,
                                                     GemmBBlockCopyDstDataPerWrite_GemmN,
                                                     Sequence<2, 3, 0, 1>,
                                                     1,
                                                     GemmCThreadCopyDstDataPerWrite_GemmM1>{};

        gridwise_gemm.Run(p_wei_global, p_in_global, p_out_global);
    }
};

} // namespace ck
#endif
@@ -1,80 +0,0 @@
#ifndef CK_CONSTANT_MATRIX_DESCRIPTOR_HPP
#define CK_CONSTANT_MATRIX_DESCRIPTOR_HPP

#include "common_header.hpp"
#include "tensor_descriptor.hpp"

namespace ck {

template <index_t NRow_, index_t NCol_, index_t RowStride_>
struct ConstantMatrixDescriptor
{
    __host__ __device__ constexpr ConstantMatrixDescriptor()
    {
        static_assert(NCol_ <= RowStride_, "wrong! NCol > RowStride!");
    }

    __host__ __device__ static constexpr index_t NRow() { return NRow_; }

    __host__ __device__ static constexpr index_t NCol() { return NCol_; }

    __host__ __device__ static constexpr index_t RowStride() { return RowStride_; }

    __host__ __device__ static constexpr auto GetLengths() { return Sequence<NRow_, NCol_>{}; }

    __host__ __device__ static constexpr index_t GetElementSize() { return NRow_ * NCol_; }

    __host__ __device__ static constexpr index_t GetElementSpace() { return NRow_ * RowStride_; }

    __host__ __device__ static index_t GetOffsetFromMultiIndex(index_t irow, index_t icol)
    {
        return irow * RowStride_ + icol;
    }

    __host__ __device__ static index_t CalculateOffset(index_t irow, index_t icol)
    {
        return GetOffsetFromMultiIndex(irow, icol);
    }

    template <index_t SubNRow, index_t SubNCol>
    __host__ __device__ static constexpr auto MakeSubMatrixDescriptor(Number<SubNRow>,
                                                                      Number<SubNCol>)
    {
        return ConstantMatrixDescriptor<SubNRow, SubNCol, RowStride_>{};
    }
};

template <index_t NRow, index_t NCol>
__host__ __device__ constexpr auto make_ConstantMatrixDescriptor_packed(Number<NRow>, Number<NCol>)
{
    return ConstantMatrixDescriptor<NRow, NCol, NCol>{};
}
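
// Usage sketch (illustrative):
//   constexpr auto desc = make_ConstantMatrixDescriptor_packed(Number<4>{}, Number<8>{});
//   static_assert(desc.GetElementSpace() == 32, "packed means RowStride == NCol");
//   index_t offset = desc.GetOffsetFromMultiIndex(2, 5); // 2 * 8 + 5 = 21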

template <index_t NRow, index_t NCol, index_t RowStride>
__host__ __device__ constexpr auto
make_ConstantMatrixDescriptor(Number<NRow>, Number<NCol>, Number<RowStride>)
{
    return ConstantMatrixDescriptor<NRow, NCol, RowStride>{};
}

template <typename... Ts>
__host__ __device__ constexpr auto make_ConstantMatrixDescriptor(NativeTensorDescriptor<Ts...>)
{
    using TDesc = NativeTensorDescriptor<Ts...>;
    static_assert(TDesc::GetNumOfDimension() == 2, "wrong");
    static_assert(TDesc::GetStrides()[1] == 1, "wrong");
    return ConstantMatrixDescriptor<TDesc::GetLengths()[0],
                                    TDesc::GetLengths()[1],
                                    TDesc::GetStrides()[0]>{};
}

template <typename TDesc>
__host__ __device__ void print_ConstantMatrixDescriptor(TDesc, const char* s)
{
    printf(
        "%s NRow %u NCol %u RowStride %u\n", s, TDesc::NRow(), TDesc::NCol(), TDesc::RowStride());
}

} // namespace ck

#endif
@@ -2,50 +2,10 @@
#define CK_CLUSTER_DESCRIPTOR_HPP

#include "common_header.hpp"

// TODO: remove dependency on the deprecated tensor descriptor
#include "tensor_descriptor.hpp"
#include "tensor_adaptor.hpp"

namespace ck {

// a cluster maps a 1-d index to an N-d index
template <typename Lengths, typename ArrangeOrder>
struct ClusterDescriptor
{
    static constexpr index_t nDim = Lengths::Size();

    static constexpr auto mDesc = transform_tensor_descriptor(
        make_native_tensor_descriptor_packed(Lengths{}),
        make_tuple(Merge<decltype(Lengths::ReorderGivenNew2Old(ArrangeOrder{}))>{}),
        make_tuple(ArrangeOrder{}),
        make_tuple(Sequence<0>{}));

    __host__ __device__ constexpr ClusterDescriptor()
    {
        static_assert(Lengths::Size() == nDim && ArrangeOrder::Size() == nDim,
                      "wrong! size not the same");

        static_assert(is_valid_sequence_map<ArrangeOrder>{}, "wrong! ArrangeOrder is wrong");
    }

    __host__ __device__ static constexpr index_t GetElementSize() { return mDesc.GetElementSize(); }

    __host__ __device__ static constexpr auto CalculateClusterIndex(index_t idx_1d)
    {
        return mDesc.CalculateLowerIndex(MultiIndex<1>{idx_1d});
    }
};

template <typename Lengths,
          typename ArrangeOrder = typename arithmetic_sequence_gen<0, Lengths::Size(), 1>::type>
__host__ __device__ constexpr auto make_cluster_descriptor(
    Lengths, ArrangeOrder order = typename arithmetic_sequence_gen<0, Lengths::Size(), 1>::type{})
{
    return ClusterDescriptor<Lengths, decltype(order)>{};
}
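
// Usage sketch (illustrative): for a 2-d cluster of shape [4, 8] with ArrangeOrder
// Sequence<1, 0>, dimension 0 varies fastest, so
//   make_cluster_descriptor(Sequence<4, 8>{}, Sequence<1, 0>{}).CalculateClusterIndex(9)
// yields the multi-index (1, 2), since 9 = 2 * 4 + 1.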

#if 1
template <typename Lengths,
          typename ArrangeOrder = typename arithmetic_sequence_gen<0, Lengths::Size(), 1>::type>
__host__ __device__ constexpr auto make_cluster_descriptor_v2(
@@ -68,7 +28,6 @@ __host__ __device__ constexpr auto make_cluster_descriptor_v2(
    return make_single_stage_tensor_adaptor(
        make_tuple(transform), make_tuple(low_dim_old_top_ids), make_tuple(up_dim_new_top_ids));
}
#endif

} // namespace ck
#endif

@@ -1,17 +0,0 @@
#ifndef CK_DIMENSION_HPP
#define CK_DIMENSION_HPP

#include "common_header.hpp"

namespace ck {

template <index_t Length, index_t Stride>
struct NativeDimension
{
    __host__ __device__ static constexpr auto GetLength() { return Number<Length>{}; }

    __host__ __device__ static constexpr auto GetStride() { return Number<Stride>{}; }
};

} // namespace ck
#endif
@@ -1,523 +0,0 @@
#ifndef CK_MULTI_INDEX_TRANSFORM_HPP
#define CK_MULTI_INDEX_TRANSFORM_HPP

#include "common_header.hpp"
#include "multi_index.hpp"

namespace ck {

template <index_t Length>
struct PassThrough
{
    using LowerIndex = MultiIndex<1>;
    using UpperIndex = MultiIndex<1>;

    __host__ __device__ static constexpr auto GetNumOfLowerDimension() { return Number<1>{}; }

    __host__ __device__ static constexpr auto GetNumOfUpperDimension() { return Number<1>{}; }

    __host__ __device__ static constexpr auto GetUpperLengths() { return Sequence<Length>{}; }

    __host__ __device__ static constexpr auto CalculateLowerIndex(const UpperIndex& idx_up)
    {
        return idx_up;
    }

    __host__ __device__ static constexpr auto
    CalculateLowerIndexDiff(const UpperIndex& idx_up_diff,
                            const UpperIndex& /* idx_up_old */,
                            const LowerIndex& /* idx_low_old */)
    {
        return idx_up_diff;
    }

    __host__ __device__ static constexpr bool IsLinearTransform() { return true; }

    __host__ __device__ static constexpr bool IsValidUpperIndexAlwaysMappedToValidLowerIndex()
    {
        return true;
    }
};

// By default, this transform automatically judges whether the is-valid check for the
// upper-to-lower index mapping is necessary.
// However, the check is skipped if the user sets SkipIsValidCheck to true.
// LowerLengths: Sequence<...>
template <typename LowerLengths,
          typename LeftPads,
          typename RightPads,
          bool SkipIsValidCheck = false>
struct Pad
{
    static constexpr index_t nDim = LowerLengths::Size();

    using LowerIndex = MultiIndex<nDim>;
    using UpperIndex = MultiIndex<nDim>;

    __host__ __device__ constexpr Pad()
    {
        static_assert(LowerLengths::GetSize() == nDim && LeftPads::GetSize() == nDim &&
                          RightPads::GetSize() == nDim,
                      "wrong! # of dimensions not consistent");
    }

    __host__ __device__ static constexpr auto GetNumOfLowerDimension() { return Number<nDim>{}; }

    __host__ __device__ static constexpr auto GetNumOfUpperDimension() { return Number<nDim>{}; }

    __host__ __device__ static constexpr auto GetUpperLengths()
    {
        return LowerLengths{} + LeftPads{} + RightPads{};
    }

    __host__ __device__ static constexpr auto CalculateLowerIndex(const UpperIndex& idx_up)
    {
        return idx_up - LeftPads{};
    }

    __host__ __device__ static constexpr auto
    CalculateLowerIndexDiff(const UpperIndex& idx_up_diff,
                            const UpperIndex& /* idx_up_old */,
                            const LowerIndex& /* idx_low_old */)
    {
        return idx_up_diff;
    }

    __host__ __device__ static constexpr bool IsLinearTransform() { return true; }

    __host__ __device__ static constexpr bool IsValidUpperIndexAlwaysMappedToValidLowerIndex()
    {
        // skip the valid check if the user requests it
        if(SkipIsValidCheck)
        {
            return true;
        }

        bool flag = true;

        for(index_t i = 0; i < nDim; ++i)
        {
            flag = flag && LeftPads::At(i) == 0 && RightPads::At(i) == 0;
        }

        return flag;
    }
};
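
// Example (illustrative): Pad<Sequence<4, 4>, Sequence<1, 1>, Sequence<1, 1>> turns a 4x4
// lower tensor into a 6x6 upper tensor; upper index (0, 0) maps to lower index (-1, -1),
// which is out of bounds, so the is-valid check is needed unless SkipIsValidCheck is set.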

// LowerLengths: Sequence<...>
// SliceBegins: Sequence<...>
// SliceEnds: Sequence<...>
template <typename LowerLengths, typename SliceBegins, typename SliceEnds>
struct Slice
{
    static constexpr index_t nDim = LowerLengths::Size();

    using LowerIndex = MultiIndex<nDim>;
    using UpperIndex = MultiIndex<nDim>;

    __host__ __device__ constexpr Slice()
    {
        static_assert(LowerLengths::GetSize() == nDim && SliceBegins::GetSize() == nDim &&
                          SliceEnds::GetSize() == nDim,
                      "wrong! # of dimensions not consistent");

#if 0
        // TODO: this does not compile (constexpr error)
        static_for<0, nDim, 1>{}([&](auto idim) {
            static_assert(SliceBegins::At(idim) <= SliceEnds::At(idim) &&
                              SliceBegins::At(idim) >= 0 &&
                              SliceEnds::At(idim) <= LowerLengths::At(idim),
                          "wrong! Slice config is wrong");
        });
#endif
    }

    __host__ __device__ static constexpr auto GetNumOfLowerDimension() { return Number<nDim>{}; }

    __host__ __device__ static constexpr auto GetNumOfUpperDimension() { return Number<nDim>{}; }

    __host__ __device__ static constexpr auto GetUpperLengths()
    {
        return SliceEnds{} - SliceBegins{};
    }

    __host__ __device__ static constexpr auto CalculateLowerIndex(const UpperIndex& idx_up)
    {
        return idx_up + SliceBegins{};
    }

    __host__ __device__ static constexpr auto
    CalculateLowerIndexDiff(const UpperIndex& idx_up_diff,
                            const UpperIndex& /* idx_up_old */,
                            const LowerIndex& /* idx_low_old */)
    {
        return idx_up_diff;
    }

    __host__ __device__ static constexpr bool IsLinearTransform() { return true; }

    __host__ __device__ static constexpr bool IsValidUpperIndexAlwaysMappedToValidLowerIndex()
    {
        return true;
    }
};

// LowerLengths: Sequence<...>
template <typename LowerLengths>
struct Merge
{
    static constexpr index_t nDimLow = LowerLengths::Size();
    static constexpr index_t nDimUp  = 1;

    using LowerIndex = MultiIndex<nDimLow>;
    using UpperIndex = MultiIndex<nDimUp>;

    __host__ __device__ static constexpr auto GetNumOfLowerDimension() { return Number<nDimLow>{}; }

    __host__ __device__ static constexpr auto GetNumOfUpperDimension() { return Number<nDimUp>{}; }

    __host__ __device__ static constexpr auto GetUpperLengths()
    {
        return Sequence<reduce_on_sequence(
            LowerLengths{}, math::multiplies<index_t>{}, Number<1>{})>{};
    }

    // emulate constexpr lambda
    template <typename PseudoLowStrides>
    struct lambda_CalculateLowerIndex
    {
        index_t& itmp;
        LowerIndex& idx_low;

        __host__ __device__ constexpr lambda_CalculateLowerIndex(index_t& itmp_,
                                                                 LowerIndex& idx_low_)
            : itmp(itmp_), idx_low(idx_low_)
        {
        }

        template <typename IDim>
        __host__ __device__ constexpr void operator()(IDim idim) const
        {
            constexpr index_t stride = PseudoLowStrides::At(idim);
            idx_low(idim) = itmp / stride;
            itmp -= idx_low[idim] * stride;
        }
    };

    __host__ __device__ static constexpr auto CalculateLowerIndex(const UpperIndex& idx_up)
    {
        LowerIndex idx_low;

        index_t itmp = idx_up[Number<0>{}];

        constexpr auto pseudo_low_strides =
            reverse_inclusive_scan_sequence(
                LowerLengths::PopFront(), math::multiplies<index_t>{}, Number<1>{})
                .PushBack(Number<1>{});

        static_for<0, nDimLow - 1, 1>{}(
            lambda_CalculateLowerIndex<decltype(pseudo_low_strides)>(itmp, idx_low));

        idx_low(Number<nDimLow - 1>{}) = itmp / pseudo_low_strides[Number<nDimLow - 1>{}];

        return idx_low;
    }
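
    // Example (illustrative): LowerLengths = Sequence<2, 3, 4> gives pseudo_low_strides
    // (12, 4, 1), so the upper index 17 decomposes as 17 = 1 * 12 + 1 * 4 + 1,
    // i.e. the lower index (1, 1, 1).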

    // idx_low_diff depends on idx_low_old, so idx_low needs to be up-to-date.
    // If idx_up_diff is known at compile-time, many calculations can be optimized
    // away by the compiler.
    // This function assumes idx_low_old is not out of bounds.
    __host__ __device__ static constexpr auto
    CalculateLowerIndexDiff(const UpperIndex& idx_up_diff,
                            const UpperIndex& /* idx_up_old */,
                            const LowerIndex& idx_low_old)
    {
        if(idx_up_diff[Number<0>{}] == 0)
        {
            return make_zero_multi_index<nDimLow>();
        }
        else
        {
            // CalculateLowerIndex(idx_up_diff) has multiple integer divisions.
            // If idx_up_diff is known at compile-time, the calculation can
            // be done at compile-time. However, if idx_up_diff is only known
            // at run-time, then the calculation will also be computed at
            // run-time, and can be very expensive.
            LowerIndex idx_low_diff_tmp = CalculateLowerIndex(idx_up_diff);

            // find the last low dimension that changed
            index_t last_changed_low_dim = 0;

            static_for<0, nDimLow, 1>{}([&](auto i) {
                if(idx_low_diff_tmp[i] != 0)
                {
                    last_changed_low_dim = i;
                }
            });

            LowerIndex idx_low_new = idx_low_old + idx_low_diff_tmp;

            if(idx_up_diff[Number<0>{}] > 0)
            {
                // do a carry check on each low dimension in reverse order,
                // starting from the last low dimension that changed;
                // the highest dimension is not checked
                bool carry = false;

                static_for<nDimLow - 1, 0, -1>{}([&](auto i) {
                    if(i <= last_changed_low_dim)
                    {
                        if(carry)
                        {
                            ++idx_low_new(i);
                        }

                        carry = false;

                        if(idx_low_new[i] >= LowerLengths::At(i))
                        {
                            idx_low_new(i) -= LowerLengths::At(i);
                            carry = true;
                        }
                    }
                });

                // highest dimension, no out-of-bound check
                if(carry)
                {
                    ++idx_low_new(Number<0>{});
                }
            }
            else
            {
                // do a borrow check on each low dimension in reverse order,
                // starting from the last low dimension that changed;
                // the highest dimension is not checked
                bool borrow = false;

                static_for<nDimLow - 1, 0, -1>{}([&](auto i) {
                    if(i <= last_changed_low_dim)
                    {
                        if(borrow)
                        {
                            --idx_low_new(i);
                        }

                        borrow = false;

                        if(idx_low_new[i] < 0)
                        {
                            idx_low_new(i) += LowerLengths::At(i);
                            borrow = true;
                        }
                    }
                });

                // highest dimension, no out-of-bound check
                if(borrow)
                {
                    --idx_low_new(Number<0>{});
                }
            }

            return idx_low_new - idx_low_old;
        }
    }
|
||||
|
||||
__host__ __device__ static constexpr bool IsLinearTransform() { return false; }
|
||||
|
||||
__host__ __device__ static constexpr bool IsValidUpperIndexAlwaysMappedToValidLowerIndex()
|
||||
{
|
||||
return true;
|
||||
}
|
||||
};
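
// A worked illustration of the digit arithmetic above (a sketch, assuming the
// transform is instantiated with LowerLengths = Sequence<2, 3, 4>):
// pseudo_low_strides is {12, 4, 1}, so CalculateLowerIndex({17}) returns
// (1, 1, 1), because 17 = 1 * 12 + 1 * 4 + 1 * 1. Stepping by idx_up_diff = {1}
// from idx_low_old = (0, 2, 3) first gives (0, 2, 4); the carry pass then wraps
// dimension 2 (4 >= 4) and dimension 1 (3 >= 3), yielding (1, 0, 0).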

// UpperLengths: Sequence<...>
template <typename UpperLengths>
struct UnMerge
{
    static constexpr index_t nDimLow = 1;
    static constexpr index_t nDimUp  = UpperLengths::Size();

    using LowerIndex = MultiIndex<nDimLow>;
    using UpperIndex = MultiIndex<nDimUp>;

    __host__ __device__ static constexpr auto GetNumOfLowerDimension() { return Number<nDimLow>{}; }

    __host__ __device__ static constexpr auto GetNumOfUpperDimension() { return Number<nDimUp>{}; }

    __host__ __device__ static constexpr auto GetUpperLengths() { return UpperLengths{}; }

    __host__ __device__ static constexpr auto CalculateLowerIndex(const UpperIndex& idx_up)
    {
        LowerIndex idx_low = make_multi_index(0);

        constexpr auto pseudo_up_strides =
            reverse_inclusive_scan_sequence(
                UpperLengths::PopFront(), math::multiplies<index_t>{}, Number<1>{})
                .PushBack(Number<1>{});

        static_for<0, nDimUp, 1>{}(
            [&](auto idim) { idx_low(Number<0>{}) += idx_up[idim] * pseudo_up_strides[idim]; });

        return idx_low;
    }

    __host__ __device__ static constexpr auto
    CalculateLowerIndexDiff(const UpperIndex& idx_up_diff,
                            const UpperIndex& /* idx_up_old */,
                            const LowerIndex& /* idx_low_old */)
    {
        return CalculateLowerIndex(idx_up_diff);
    }

    __host__ __device__ static constexpr bool IsLinearTransform() { return true; }

    __host__ __device__ static constexpr bool IsValidUpperIndexAlwaysMappedToValidLowerIndex()
    {
        return true;
    }
};
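
// A minimal illustration (assuming UpperLengths = Sequence<2, 3, 4>): the
// pseudo strides are {12, 4, 1}, so the upper index (1, 1, 1) is flattened to
// the lower index 1 * 12 + 1 * 4 + 1 * 1 = 17, the inverse of the merge-style
// mapping above. Because the mapping is linear, applying the same formula to
// an index difference yields the lower-index difference directly.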

// By default, this will automatically judge whether the is-valid check for the
// upper-to-lower-index mapping is necessary
// However, the check will be skipped if SkipIsValidCheck is set to true by the user
// UpperLengths: Sequence<...>
// Coefficients: Sequence<...>
// idx_low = coefficients[0, ...nDimUp-1] * idx_up[0, ...nDimUp-1] + coefficients[nDimUp]
template <index_t LowerLength,
          typename UpperLengths,
          typename Coefficients,
          bool SkipIsValidCheck = false>
struct Embed
{
    static constexpr index_t nDimLow = 1;
    static constexpr index_t nDimUp  = UpperLengths::Size();

    using LowerIndex = MultiIndex<nDimLow>;
    using UpperIndex = MultiIndex<nDimUp>;

    __host__ __device__ constexpr Embed()
    {
        static_assert(UpperLengths::GetSize() == nDimUp && Coefficients::GetSize() == nDimUp + 1,
                      "wrong! # of dimensions not consistent");
    }

    __host__ __device__ static constexpr auto GetNumOfUpperDimension() { return Number<nDimUp>{}; }

    __host__ __device__ static constexpr auto GetNumOfLowerDimension() { return Number<nDimLow>{}; }

    __host__ __device__ static constexpr auto GetUpperLengths() { return UpperLengths{}; }

    __host__ __device__ static constexpr auto CalculateLowerIndex(const UpperIndex& idx_up)
    {
        LowerIndex idx_low = make_multi_index(Coefficients{}[Number<nDimUp>{}]);

        static_for<0, nDimUp, 1>{}(
            [&](auto i) { idx_low(Number<0>{}) += idx_up[i] * Coefficients{}[i]; });

        return idx_low;
    }

    __host__ __device__ static constexpr auto
    CalculateLowerIndexDiff(const UpperIndex& idx_up_diff,
                            const UpperIndex& /* idx_up_old */,
                            const LowerIndex& /* idx_low_old */)
    {
        LowerIndex idx_low_diff = make_multi_index(0);

        static_for<0, nDimUp, 1>{}(
            [&](auto i) { idx_low_diff(Number<0>{}) += idx_up_diff[i] * Coefficients{}[i]; });

        return idx_low_diff;
    }

    __host__ __device__ static constexpr bool IsLinearTransform() { return true; }

    __host__ __device__ static constexpr bool IsValidUpperIndexAlwaysMappedToValidLowerIndex()
    {
        // skip the valid check if the user requests it
        if(SkipIsValidCheck)
        {
            return true;
        }

        bool flag = true;

        index_t ncorner = 1;

        for(index_t idim = 0; idim < nDimUp; ++idim)
        {
            ncorner *= 2;
        }

        // loop over each corner of the upper tensor
        for(index_t icorner = 0; icorner < ncorner; ++icorner)
        {
            // generate upper index for each corner
            auto idx_up = make_zero_multi_index<nDimUp>();

            index_t itmp = icorner;

            static_for<nDimUp, 0, -1>{}([&](auto idim) {
                auto idim_m1 = idim - Number<1>{};
                idx_up(idim_m1) = itmp % 2 == 0 ? 0 : UpperLengths::At(idim_m1) - 1;
                itmp /= 2;
            });

            // calculate lower index
            auto idx_low = CalculateLowerIndex(idx_up);

            // judge if lower index is valid
            flag = flag && idx_low[Number<0>{}] >= 0 && idx_low[Number<0>{}] < LowerLength;
        }

        return flag;
    }
};
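
// A minimal illustration (assuming UpperLengths = Sequence<2, 3> and
// Coefficients = Sequence<3, 1, 0>): the upper index (1, 2) is embedded at
// lower index 1 * 3 + 2 * 1 + 0 = 5. The corner check above evaluates the
// mapping only at the 2^nDimUp extreme upper indices (0, 0), (0, 2), (1, 0)
// and (1, 2); since the mapping is affine, the transform is always-valid
// exactly when all corners land inside [0, LowerLength).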

// LowerLengths: Sequence<...>
// LowerFreezePoint: Sequence<...>
template <typename LowerLengths, typename LowerFreezePoint>
struct Freeze
{
    static constexpr index_t nDimLow = LowerLengths::Size();
    static constexpr index_t nDimUp  = 0;

    using LowerIndex = MultiIndex<nDimLow>;
    using UpperIndex = MultiIndex<nDimUp>;

    __host__ __device__ constexpr Freeze()
    {
        // TODO: sanity check: LowerFreezePoint should be within range of LowerLengths
    }

    __host__ __device__ static constexpr auto GetNumOfLowerDimension() { return Number<nDimLow>{}; }

    __host__ __device__ static constexpr auto GetNumOfUpperDimension() { return Number<0>{}; }

    __host__ __device__ static constexpr auto GetUpperLengths() { return Sequence<>{}; }

    __host__ __device__ static constexpr auto CalculateLowerIndex(const UpperIndex& /*idx_up*/)
    {
        return to_multi_index(LowerFreezePoint{});
    }

    __host__ __device__ static constexpr auto
    CalculateLowerIndexDiff(const UpperIndex& /* idx_up_diff */,
                            const UpperIndex& /* idx_up_old */,
                            const LowerIndex& /* idx_low_old */)
    {
        return make_zero_multi_index<nDimLow>();
    }

    __host__ __device__ static constexpr bool IsLinearTransform() { return true; }

    __host__ __device__ static constexpr bool IsValidUpperIndexAlwaysMappedToValidLowerIndex()
    {
        return true;
    }
};
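
// A minimal illustration (assuming LowerLengths = Sequence<4, 8> and
// LowerFreezePoint = Sequence<1, 2>): the transform removes both dimensions
// and pins the lower index to (1, 2), and every upper-index difference maps
// to the zero lower-index difference. This is what lets a fixed slice of a
// tensor be addressed through a lower-rank descriptor.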

} // namespace ck
#endif
@@ -1,173 +0,0 @@
#ifndef CK_PRINT_TENSOR_DESCRIPTOR_HPP
#define CK_PRINT_TENSOR_DESCRIPTOR_HPP

#include "common_header.hpp"
#include "tensor_descriptor.hpp"

namespace ck {

template <typename... NativeDimensions>
__host__ __device__ void
print_tensor_descriptor(const char* s, const NativeTensorDescriptor<NativeDimensions...>& desc)
{
    print_tensor_descriptor_impl(s, desc.GetLengths(), desc.GetStrides());
}

template <typename... Ts>
__host__ __device__ void print_tensor_descriptor(const char* s,
                                                 const TransformedTensorDescriptor<Ts...>& desc)
{
    print_tensor_descriptor_impl(s, desc.GetLengths());
}

template <index_t... Lengths, index_t... Strides>
__host__ __device__ void
print_tensor_descriptor_impl(const char* s, Sequence<Lengths...>, Sequence<Strides...>)
{
    constexpr index_t nDim = sizeof...(Lengths);

    static_assert(nDim > 0 && nDim <= 12, "wrong!");

    static_if<nDim == 1>{}([&](auto) {
        printf("%s dim %u, lengths {%u}, strides {%u}\n", s, nDim, Lengths..., Strides...);
    });

    static_if<nDim == 2>{}([&](auto) {
        printf("%s dim %u, lengths {%u %u}, strides {%u %u}\n", s, nDim, Lengths..., Strides...);
    });

    static_if<nDim == 3>{}([&](auto) {
        printf(
            "%s dim %u, lengths {%u %u %u}, strides {%u %u %u}\n", s, nDim, Lengths..., Strides...);
    });

    static_if<nDim == 4>{}([&](auto) {
        printf("%s dim %u, lengths {%u %u %u %u}, strides {%u %u %u %u}\n",
               s,
               nDim,
               Lengths...,
               Strides...);
    });

    static_if<nDim == 5>{}([&](auto) {
        printf("%s dim %u, lengths {%u %u %u %u %u}, strides {%u %u %u %u %u}\n",
               s,
               nDim,
               Lengths...,
               Strides...);
    });

    static_if<nDim == 6>{}([&](auto) {
        printf("%s dim %u, lengths {%u %u %u %u %u %u}, strides {%u %u %u %u %u %u}\n",
               s,
               nDim,
               Lengths...,
               Strides...);
    });

    static_if<nDim == 7>{}([&](auto) {
        printf("%s dim %u, lengths {%u %u %u %u %u %u %u}, strides {%u %u %u %u %u %u %u}\n",
               s,
               nDim,
               Lengths...,
               Strides...);
    });

    static_if<nDim == 8>{}([&](auto) {
        printf("%s dim %u, lengths {%u %u %u %u %u %u %u %u}, strides {%u %u %u %u %u %u %u %u}\n",
               s,
               nDim,
               Lengths...,
               Strides...);
    });

    static_if<nDim == 9>{}([&](auto) {
        printf("%s dim %u, lengths {%u %u %u %u %u %u %u %u %u}, strides {%u %u %u %u %u %u %u %u "
               "%u}\n",
               s,
               nDim,
               Lengths...,
               Strides...);
    });

    static_if<nDim == 10>{}([&](auto) {
        printf("%s dim %u, lengths {%u %u %u %u %u %u %u %u %u %u}, strides {%u %u %u %u %u %u %u "
               "%u %u %u}\n",
               s,
               nDim,
               Lengths...,
               Strides...);
    });

    static_if<nDim == 11>{}([&](auto) {
        printf("%s dim %u, lengths {%u %u %u %u %u %u %u %u %u %u %u}, strides {%u %u %u %u %u %u "
               "%u %u "
               "%u %u %u}\n",
               s,
               nDim,
               Lengths...,
               Strides...);
    });

    static_if<nDim == 12>{}([&](auto) {
        printf("%s dim %u, lengths {%u %u %u %u %u %u %u %u %u %u %u %u}, strides {%u %u %u %u %u "
               "%u %u %u %u "
               "%u %u %u}\n",
               s,
               nDim,
               Lengths...,
               Strides...);
    });
}

template <index_t... Lengths>
__host__ __device__ void print_tensor_descriptor_impl(const char* s, Sequence<Lengths...>)
{
    constexpr index_t nDim = sizeof...(Lengths);

    static_assert(nDim > 0 && nDim <= 12, "wrong!");

    static_if<nDim == 1>{}([&](auto) { printf("%s dim %u, lengths {%u}\n", s, nDim, Lengths...); });

    static_if<nDim == 2>{}(
        [&](auto) { printf("%s dim %u, lengths {%u %u}\n", s, nDim, Lengths...); });

    static_if<nDim == 3>{}(
        [&](auto) { printf("%s dim %u, lengths {%u %u %u}\n", s, nDim, Lengths...); });

    static_if<nDim == 4>{}(
        [&](auto) { printf("%s dim %u, lengths {%u %u %u %u}\n", s, nDim, Lengths...); });

    static_if<nDim == 5>{}(
        [&](auto) { printf("%s dim %u, lengths {%u %u %u %u %u}\n", s, nDim, Lengths...); });

    static_if<nDim == 6>{}(
        [&](auto) { printf("%s dim %u, lengths {%u %u %u %u %u %u}\n", s, nDim, Lengths...); });

    static_if<nDim == 7>{}(
        [&](auto) { printf("%s dim %u, lengths {%u %u %u %u %u %u %u}\n", s, nDim, Lengths...); });

    static_if<nDim == 8>{}([&](auto) {
        printf("%s dim %u, lengths {%u %u %u %u %u %u %u %u}\n", s, nDim, Lengths...);
    });

    static_if<nDim == 9>{}([&](auto) {
        printf("%s dim %u, lengths {%u %u %u %u %u %u %u %u %u}\n", s, nDim, Lengths...);
    });

    static_if<nDim == 10>{}([&](auto) {
        printf("%s dim %u, lengths {%u %u %u %u %u %u %u %u %u %u}\n", s, nDim, Lengths...);
    });

    static_if<nDim == 11>{}([&](auto) {
        printf("%s dim %u, lengths {%u %u %u %u %u %u %u %u %u %u %u}\n", s, nDim, Lengths...);
    });

    static_if<nDim == 12>{}([&](auto) {
        printf("%s dim %u, lengths {%u %u %u %u %u %u %u %u %u %u %u %u}\n", s, nDim, Lengths...);
    });
}
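
// A minimal usage sketch (assuming a 4-d packed descriptor):
//   constexpr auto desc = make_native_tensor_descriptor_packed(Sequence<2, 3, 4, 5>{});
//   print_tensor_descriptor("desc", desc);
// would print: desc dim 4, lengths {2 3 4 5}, strides {60 20 5 1}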

} // namespace ck

#endif
@@ -1,289 +0,0 @@
#ifndef CK_TENSOR_COORDINATE_HPP
#define CK_TENSOR_COORDINATE_HPP

#include "common_header.hpp"
#include "dimension.hpp"
#include "multi_index_transform.hpp"
#include "tensor_descriptor.hpp"

namespace ck {

// A "tensor coordinate" is an opaque object that represents a "point of location" inside a tensor
// At the bare minimum, a user should be able to query the following information from a tensor
// coordinate:
// 1. Tensor descriptor
// 2. Location, represented in the form of a multi-index
// 3. Location, represented in the form of the offset to the origin of the tensor
// 4. Whether the location is inside an invalid area, i.e. the padding area of an implicitly padded
//    tensor is considered invalid, because the padding area doesn't have any physical memory
//    allocation
// A tensor coordinate also provides the following functionality:
// 1. Given a step size in each dimension, update itself, or return a new tensor coordinate, so the
//    user can freely move the "point of location" inside the tensor

// wrapper class for NativeTensorCoordinate and TransformedTensorCoordinate
template <typename TensorDesc>
struct TensorCoordinate;

// tensor coordinate for native tensor
template <typename NativeTensorDesc>
struct NativeTensorCoordinate
{
    using type             = NativeTensorCoordinate;
    using tensor_desc_type = NativeTensorDesc;
    static constexpr index_t nDim = tensor_desc_type::GetNumOfDimension();
    using Index = MultiIndex<nDim>;

    __host__ __device__ constexpr NativeTensorCoordinate(Index idx)
        : mIndex(idx), mOffset(tensor_desc_type::CalculateOffset(idx))
    {
    }

    template <typename... Xs>
    __host__ __device__ constexpr NativeTensorCoordinate(Xs... xs)
        : NativeTensorCoordinate(make_multi_index(xs...))
    {
    }

    template <index_t... Xs>
    __host__ __device__ constexpr NativeTensorCoordinate(Sequence<Xs...>)
        : NativeTensorCoordinate(make_multi_index(Xs...))
    {
    }

    __host__ __device__ static constexpr auto GetTensorDescriptor() { return tensor_desc_type{}; }

    __host__ __device__ constexpr const Index& GetUpperIndex() const { return mIndex; }

    __host__ __device__ constexpr const Index& GetIndex() const { return mIndex; }

    __host__ __device__ constexpr const index_t& GetOffset() const { return mOffset; }

    __host__ __device__ constexpr type operator+=(const Index& idx_diff)
    {
        // mIndex is updated here, but some (or all) of its entries may never be used;
        // the compiler should remove those entries as dead code
        mIndex += idx_diff;

        mOffset += tensor_desc_type::CalculateOffsetDiff(idx_diff);

        return *this;
    }

    __host__ __device__ constexpr type operator-=(const Index& idx_diff)
    {
        // mIndex is updated here, but some (or all) of its entries may never be used;
        // the compiler should remove those entries as dead code
        mIndex -= idx_diff;

        mOffset -= tensor_desc_type::CalculateOffsetDiff(idx_diff);

        return *this;
    }

    __host__ __device__ constexpr type operator+(const Index& idx_diff) const
    {
        type coord = *this;
        coord += idx_diff;
        return coord;
    }

    __host__ __device__ constexpr type operator-(const Index& idx_diff) const
    {
        type coord = *this;
        coord -= idx_diff;
        return coord;
    }

    __host__ __device__ static constexpr index_t CalculateOffsetDiff(const Index& idx_diff)
    {
        return tensor_desc_type::CalculateOffsetDiff(idx_diff);
    }

    // evaluated at run-time
    __host__ __device__ constexpr bool IsUpperIndexValid() const
    {
        return tensor_desc_type::IsUpperIndexValid(GetUpperIndex());
    }

    // evaluated at run-time
    __host__ __device__ constexpr bool IsOffsetValid() const
    {
        // For a native tensor, the offset is valid if the upper-index is valid
        return IsUpperIndexValid();
    }

    // evaluated at compile-time
    __host__ __device__ static constexpr bool IsOffsetValidAssumingUpperIndexIsValid()
    {
        return true;
    }

    private:
    // mIndex may be saved and updated, however, the value of some (or all) of its entries may
    // never be used. The compiler should be able to remove these entries, as well as their
    // calculation, as dead code.
    // TODO: make sure the compiler indeed removes this dead code
    Index mIndex;
    index_t mOffset;
};
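
// A minimal illustration (assuming a native 2-d descriptor with strides
// {8, 1}): a coordinate constructed at (2, 3) caches offset 2 * 8 + 3 = 19;
// applying += make_multi_index(0, 2) adds only CalculateOffsetDiff({0, 2}) =
// 0 * 8 + 2 * 1 = 2 to the cached offset, moving the coordinate to (2, 5) at
// offset 21.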

// tensor coordinate for transformed tensor
template <typename TransformedTensorDesc>
struct TransformedTensorCoordinate
{
    using tensor_desc_type = TransformedTensorDesc;
    using LowerCoord =
        typename TensorCoordinate<decltype(tensor_desc_type::GetLowerTensorDescriptor())>::type;
    using UpperCoord = TransformedTensorCoordinate;
    static constexpr index_t nDim = tensor_desc_type::GetNumOfDimension();
    using UpperIndex = MultiIndex<nDim>;

    __host__ __device__ constexpr TransformedTensorCoordinate(UpperIndex idx)
        : mIndexUp{idx}, mCoordLow{tensor_desc_type::CalculateLowerIndex(idx)}
    {
    }

    template <typename... Xs>
    __host__ __device__ constexpr TransformedTensorCoordinate(Xs... xs)
        : TransformedTensorCoordinate(UpperIndex{xs...})
    {
    }

    template <index_t... Xs>
    __host__ __device__ constexpr TransformedTensorCoordinate(Sequence<Xs...>)
        : TransformedTensorCoordinate(UpperIndex{Xs...})
    {
    }

    __host__ __device__ static constexpr auto GetTensorDescriptor() { return tensor_desc_type{}; }

    __host__ __device__ constexpr const LowerCoord& GetLowerCoordinate() const { return mCoordLow; }

    __host__ __device__ constexpr const UpperIndex& GetUpperIndex() const { return mIndexUp; }

    __host__ __device__ constexpr const UpperIndex& GetIndex() const { return GetUpperIndex(); }

    __host__ __device__ constexpr const index_t& GetOffset() const
    {
        return GetLowerCoordinate().GetOffset();
    }

    __host__ __device__ constexpr UpperCoord operator+=(const UpperIndex& idx_up_diff)
    {
        // For transformation of a multi-index difference, not all transformation functions need to
        // know the old lower-index or the old upper-index. We pass both of them to the
        // transformation function. The transformation function itself decides to use them or not.
        mCoordLow += tensor_desc_type::CalculateLowerIndexDiff(
            idx_up_diff, GetIndex(), GetLowerCoordinate().GetIndex());

        // mIndexUp is updated here, but some (or all) of its entries may never be used;
        // the compiler should remove those entries as dead code
        mIndexUp += idx_up_diff;

        return *this;
    }

    __host__ __device__ constexpr UpperCoord operator-=(const UpperIndex& idx_up_diff)
    {
        mCoordLow -= tensor_desc_type::CalculateLowerIndexDiff(
            idx_up_diff, GetIndex(), GetLowerCoordinate().GetIndex());

        // mIndexUp is updated here, but some (or all) of its entries may never be used;
        // the compiler should remove those entries as dead code
        mIndexUp -= idx_up_diff;

        return *this;
    }

    __host__ __device__ constexpr UpperCoord operator+(const UpperIndex& idx_up_diff) const
    {
        UpperCoord coord_up = *this;
        coord_up += idx_up_diff;
        return coord_up;
    }

    __host__ __device__ constexpr UpperCoord operator-(const UpperIndex& idx_up_diff) const
    {
        UpperCoord coord_up = *this;
        coord_up -= idx_up_diff;
        return coord_up;
    }

    // Calculate the offset diff without updating the tensor-coordinate
    // If idx_up_diff is known at compile time, and has non-zero entries only on linear dimensions,
    // then all calculation can be done at compile-time.
    // TODO: this function is not compiled to the expected ISA
    __host__ __device__ constexpr index_t CalculateOffsetDiff(const UpperIndex& idx_up_diff) const
    {
        // For transformation of a multi-index difference, not all transformation functions need to
        // know the old lower-index or the old upper-index. We pass both of them to the
        // transformation function. The transformation function itself decides to use them or not.
        const auto idx_low_diff = tensor_desc_type::CalculateLowerIndexDiff(
            idx_up_diff, GetIndex(), GetLowerCoordinate().GetIndex());

        return GetLowerCoordinate().CalculateOffsetDiff(idx_low_diff);
    }

    // evaluated at run-time
    __host__ __device__ constexpr bool IsUpperIndexValid() const
    {
        return tensor_desc_type::IsUpperIndexValid(GetUpperIndex());
    }

    // evaluated at run-time
    __host__ __device__ constexpr bool IsOffsetValid() const
    {
        return IsUpperIndexValid() && GetLowerCoordinate().IsOffsetValid();
    }

    // most evaluation is done at compile-time
    __host__ __device__ constexpr bool IsLowerIndexValidAssumingUpperIndexIsValid() const
    {
        return tensor_desc_type::IsLowerIndexValidAssumingUpperIndexIsValid(
            GetLowerCoordinate().GetIndex());
    }

    // most evaluation is done at compile-time
    __host__ __device__ constexpr bool IsOffsetValidAssumingUpperIndexIsValid() const
    {
        return IsLowerIndexValidAssumingUpperIndexIsValid() &&
               GetLowerCoordinate().IsOffsetValidAssumingUpperIndexIsValid();
    }

    private:
    // mIndexUp may be calculated and updated, however, the value of some (or all) of its entries
    // may never be used. The compiler should be able to remove these entries, as well as their
    // calculation, as dead code.
    // TODO: make sure the compiler indeed removes this dead code
    UpperIndex mIndexUp;
    LowerCoord mCoordLow;
};
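
// A minimal illustration (assuming an upper descriptor that merges a 2-d
// native tensor of lengths {3, 4} into a single dimension of length 12):
// moving a transformed coordinate by idx_up_diff = {1} delegates to the merge
// transform's CalculateLowerIndexDiff, which turns the 1-d step into the
// equivalent (row, column) step with carry handling; the wrapped lower
// coordinate then updates its cached offset from that lower-index difference.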

template <typename TensorDesc>
struct TensorCoordinate
{
    private:
    template <typename... Ts>
    __host__ __device__ static constexpr auto
    MakeDummyTensorCoordinate(NativeTensorDescriptor<Ts...>)
    {
        return NativeTensorCoordinate<NativeTensorDescriptor<Ts...>>(
            make_zero_multi_index<TensorDesc::GetNumOfDimension()>());
    }

    template <typename... Ts>
    __host__ __device__ static constexpr auto
    MakeDummyTensorCoordinate(TransformedTensorDescriptor<Ts...>)
    {
        return TransformedTensorCoordinate<TransformedTensorDescriptor<Ts...>>(
            make_zero_multi_index<TensorDesc::GetNumOfDimension()>());
    }

    public:
    using type = decltype(MakeDummyTensorCoordinate(TensorDesc{}));
};

} // namespace ck
#endif
@@ -1,526 +0,0 @@
#ifndef CK_TENSOR_DESCRIPTOR_HPP
#define CK_TENSOR_DESCRIPTOR_HPP

#include "common_header.hpp"
#include "dimension.hpp"
#include "multi_index_transform.hpp"

namespace ck {

// tensor descriptor for "native tensor"
// A "native tensor" is a "true" tensor that can be represented by Lengths and Strides
template <typename... NativeDimensions>
struct NativeTensorDescriptor
{
    using type = NativeTensorDescriptor;
    static constexpr index_t nDim     = sizeof...(NativeDimensions);
    static constexpr auto mDimensions = make_tuple(NativeDimensions{}...);

    using Index = MultiIndex<nDim>;

    __host__ __device__ static constexpr auto GetNumOfDimension() { return Number<nDim>{}; }

    template <index_t IDim>
    __host__ __device__ static constexpr auto GetLength(Number<IDim>)
    {
        return mDimensions.At(Number<IDim>{}).GetLength();
    }

    template <index_t IDim>
    __host__ __device__ static constexpr auto GetStride(Number<IDim>)
    {
        return mDimensions.At(Number<IDim>{}).GetStride();
    }

    template <index_t... IDims>
    __host__ __device__ static constexpr auto GetLengths(Sequence<IDims...>)
    {
        return Sequence<GetLength(Number<IDims>{})...>{};
    }

    template <index_t... IDims>
    __host__ __device__ static constexpr auto GetStrides(Sequence<IDims...>)
    {
        return Sequence<GetStride(Number<IDims>{})...>{};
    }

    template <index_t IDim, index_t... IDims>
    __host__ __device__ static constexpr auto GetLengths(Number<IDim>, Number<IDims>...)
    {
        return GetLengths(Sequence<IDim, IDims...>{});
    }

    template <index_t IDim, index_t... IDims>
    __host__ __device__ static constexpr auto GetStrides(Number<IDim>, Number<IDims>...)
    {
        return GetStrides(Sequence<IDim, IDims...>{});
    }

    __host__ __device__ static constexpr auto GetLengths()
    {
        return GetLengths(typename arithmetic_sequence_gen<0, nDim, 1>::type{});
    }

    __host__ __device__ static constexpr auto GetStrides()
    {
        return GetStrides(typename arithmetic_sequence_gen<0, nDim, 1>::type{});
    }

    __host__ __device__ static constexpr index_t GetElementSize()
    {
        return reduce_on_sequence(GetLengths(), math::multiplies<index_t>{}, Number<1>{});
    }

    __host__ __device__ static constexpr index_t GetElementSpace()
    {
        return reduce_on_sequence(
            (GetLengths() - Number<1>{}) * GetStrides(), math::plus<index_t>{}, Number<1>{});
    }

    // TODO: this cannot return constexpr because of the use of a lambda
    __host__ __device__ static constexpr index_t CalculateOffset(const Index& idx)
    {
        index_t offset = 0;

        static_for<0, nDim, 1>{}([&](auto idim) { offset += idx[idim] * GetStride(idim); });

        return offset;
    }

    __host__ __device__ static constexpr index_t CalculateOffsetDiff(const Index& idx_diff)
    {
        index_t offset_diff = 0;

        static_for<0, nDim, 1>{}(
            [&](auto idim) { offset_diff += idx_diff[idim] * GetStride(idim); });

        return offset_diff;
    }

    template <index_t IDim>
    __host__ __device__ static constexpr bool IsLinearDimension(Number<IDim>)
    {
        return true;
    }

    __host__ __device__ static constexpr auto GetLinearDimensionMask()
    {
        return typename uniform_sequence_gen<nDim, 1>::type{};
    }

    __host__ __device__ static constexpr auto GetNonLinearDimensionMask()
    {
        return typename uniform_sequence_gen<nDim, 0>::type{};
    }

    __host__ __device__ static constexpr auto GetNonLinearDimensions() { return Sequence<>{}; }

    __host__ __device__ static constexpr auto GetNonLinearIndependentDimensionGroups()
    {
        return Tuple<>{};
    }

    // a multi-index is valid if there is a corresponding point for it in the tensor
    __host__ __device__ static constexpr bool IsUpperIndexValid(const Index& idx)
    {
        bool flag = true;

        for(index_t i = 0; i < nDim; ++i)
        {
            flag = flag && idx[i] >= 0 && idx[i] < GetLengths()[i];
        }

        return flag;
    }
};
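
// A minimal illustration (assuming a 3-d descriptor with lengths {2, 3, 4}
// and strides {20, 5, 1}): GetElementSize() is 2 * 3 * 4 = 24, while
// GetElementSpace() is 1 * 20 + 2 * 5 + 3 * 1 + 1 = 34, the span actually
// addressed when the strides are not packed. CalculateOffset({1, 2, 3})
// returns 1 * 20 + 2 * 5 + 3 * 1 = 33.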

// Tensor descriptor for "transformed tensor"
template <typename LowTensorDescriptor, // NativeTensorDescriptor or TransformedTensorDescriptor
          typename Transforms,          // Tuple<MultIndexTransforms...>
          typename LowDimensionIds,     // Tuple<Sequence<...>>
          typename UpDimensionIds>      // Tuple<Sequence<...>>
struct TransformedTensorDescriptor
{
    using type = TransformedTensorDescriptor;
    static constexpr index_t nTransform = Transforms::Size();

    struct lambda_merge_sequences
    {
        template <typename... Seqs>
        __host__ __device__ constexpr auto operator()(Seqs... seqs) const
        {
            return merge_sequences(seqs...);
        }
    };

    __host__ __device__ static constexpr auto GetNumOfLowerDimension()
    {
        // Here, we assume all lower-dimensions are active
        // TODO: sanity-check that all lower-dimensions are indeed active

        using duplicated_low_active_dims =
            decltype(unpack(lambda_merge_sequences{}, LowDimensionIds{}));

        using low_active_dims = typename sequence_unique_sort<duplicated_low_active_dims,
                                                              math::less<index_t>,
                                                              math::equal<index_t>>::type;

        return low_active_dims::Size();
    }

    __host__ __device__ static constexpr auto GetNumOfUpperDimension()
    {
        using duplicated_up_active_dims =
            decltype(unpack(lambda_merge_sequences{}, UpDimensionIds{}));

        using up_active_dims = typename sequence_unique_sort<duplicated_up_active_dims,
                                                             math::less<index_t>,
                                                             math::equal<index_t>>::type;

        return up_active_dims::Size();
    }

    static constexpr index_t nDimUp  = GetNumOfUpperDimension();
    static constexpr index_t nDimLow = GetNumOfLowerDimension();

    using UpperIndex = MultiIndex<nDimUp>;
    using LowerIndex = MultiIndex<nDimLow>;

    __host__ __device__ constexpr TransformedTensorDescriptor()
    {
        static_assert(nTransform == Transforms::Size() && nTransform == LowDimensionIds::Size() &&
                          nTransform == UpDimensionIds::Size(),
                      "wrong! # of transformations not the same");

        // sanity check:
        // LowDimensionIds should include all low-dimensions,
        // UpDimensionIds should include all up-dimensions
        using mingled_up_dimension_ids =
            decltype(unpack(lambda_merge_sequences{}, UpDimensionIds{}));

        using sorted_up_dimension_ids =
            typename sequence_sort<mingled_up_dimension_ids, math::less<index_t>>::type;

        static_assert(sorted_up_dimension_ids::Size() == nDimUp &&
                          is_valid_sequence_map<sorted_up_dimension_ids>{},
                      "wrong! UpDimensionIds is not configured correctly");

        using mingled_low_dimension_ids =
            decltype(unpack(lambda_merge_sequences{}, LowDimensionIds{}));

        using sorted_low_dimension_ids =
            typename sequence_sort<mingled_low_dimension_ids, math::less<index_t>>::type;

        static_assert(sorted_low_dimension_ids::Size() == nDimLow &&
                          is_valid_sequence_map<sorted_low_dimension_ids>{},
                      "wrong! LowDimensionIds is not configured correctly");

        // TODO: sanity check: while an up-dimension could be associated with multiple
        // transformations, a low-dimension should be associated with only one transformation

        // TODO: sanity-check: GetLowerLengths of each transform should be consistent with lengths
        // of lower-tensor-descriptor
    }

    __host__ __device__ static constexpr auto GetNumOfDimension()
    {
        return GetNumOfUpperDimension();
    }

    __host__ __device__ static constexpr auto GetLowerTensorDescriptor()
    {
        return LowTensorDescriptor{};
    }

    struct lambda_GetUpperLengths
    {
        template <typename Transform>
        __host__ __device__ constexpr auto operator()(const Transform& tran) const
        {
            return tran.GetUpperLengths();
        }
    };

    __host__ __device__ static constexpr auto GetUpperLengths()
    {
        constexpr auto tuple_of_up_lengths =
            transform_tuples(lambda_GetUpperLengths{}, Transforms{});

        constexpr auto mingled_up_lengths = unpack(lambda_merge_sequences{}, tuple_of_up_lengths);

        constexpr auto mingled_up_dimension_ids =
            unpack(lambda_merge_sequences{}, UpDimensionIds{});

        // TODO: sanity-check that mingled_up_dimension_ids contains all upper-dimensions
        // TODO: sanity-check that mingled_up_lengths has no conflicting upper-length

        // sort by upper-dimension-ids
        using sort_up_dimension_ids = sequence_unique_sort<decltype(mingled_up_dimension_ids),
                                                           math::less<index_t>,
                                                           math::equal<index_t>>;

        // sanity-check: sorted-upper-dimension-ids should be Sequence<0, 1, ... nDimUp-1>
        static_assert(is_same<typename sort_up_dimension_ids::type,
                              typename arithmetic_sequence_gen<0, nDimUp, 1>::type>{},
                      "wrong! UpDimensionIds is not configured correctly");

        constexpr auto sorted2unsorted_map = typename sort_up_dimension_ids::sorted2unsorted_map{};

        constexpr auto sorted_up_lengths =
            pick_sequence_elements_by_ids(mingled_up_lengths, sorted2unsorted_map);

        return sorted_up_lengths;
    }

    __host__ __device__ static constexpr auto GetLengths() { return GetUpperLengths(); }

    template <index_t IDim>
    __host__ __device__ static constexpr auto GetLength(Number<IDim>)
    {
        return GetLengths()[IDim];
    }

    template <index_t... IDims>
    __host__ __device__ static constexpr auto GetLengths(Sequence<IDims...>)
    {
        return Sequence<GetLength(Number<IDims>{})...>{};
    }

    template <index_t IDim, index_t... IDims>
    __host__ __device__ static constexpr auto GetLengths(Number<IDim>, Number<IDims>...)
    {
        return GetLengths(Sequence<IDim, IDims...>{});
    }

    __host__ __device__ static constexpr index_t GetElementSize()
    {
        return reduce_on_sequence(GetLengths(), math::multiplies<index_t>{}, Number<1>{});
    }

    __host__ __device__ static constexpr index_t GetElementSpace()
    {
        // TODO: Is this the correct definition for a transformed tensor?
        return GetLowerTensorDescriptor().GetElementSpace();
    }

    // TODO: right now the return value is not constexpr because of the use of a non-constexpr
    // lambda
    __host__ __device__ static constexpr LowerIndex CalculateLowerIndex(const UpperIndex& idx_up)
    {
        LowerIndex idx_low;

        static_for<0, nTransform, 1>{}([&](auto itran) {
            constexpr auto tran = Transforms{}.At(itran);

            const auto idx_up_part = pick_container_element(idx_up, UpDimensionIds{}.At(itran));
            auto idx_low_part      = pick_container_element(idx_low, LowDimensionIds{}.At(itran));

            // this assumes each lower (single) index is associated with only one transformation,
            // which is required for index transformation, and has been checked during the
            // constructor of TransformedTensorDescriptor
            idx_low_part = tran.CalculateLowerIndex(to_multi_index(idx_up_part));
        });

        return idx_low;
    }

    // TODO: right now the return value is not constexpr because of the use of a non-constexpr
    // lambda
    __host__ __device__ static constexpr LowerIndex CalculateLowerIndexDiff(
        const UpperIndex& idx_up_diff, const UpperIndex& idx_up_old, const LowerIndex& idx_low_old)
    {
        LowerIndex idx_low_diff;

        static_for<0, nTransform, 1>{}([&](auto itran) {
            constexpr auto tran = Transforms{}.At(itran);

            const auto idx_up_diff_part =
                pick_container_element(idx_up_diff, UpDimensionIds{}.At(itran));

            const auto idx_up_old_part =
                pick_container_element(idx_up_old, UpDimensionIds{}.At(itran));

            const auto idx_low_old_part =
                pick_container_element(idx_low_old, LowDimensionIds{}.At(itran));

            auto idx_low_diff_part =
                pick_container_element(idx_low_diff, LowDimensionIds{}.At(itran));

            // this assumes each lower (single) index is associated with only one transformation,
            // which is required for index transformation, and has been checked during the
            // constructor of TransformedTensorDescriptor
            idx_low_diff_part = tran.CalculateLowerIndexDiff(to_multi_index(idx_up_diff_part),
                                                             to_multi_index(idx_up_old_part),
                                                             to_multi_index(idx_low_old_part));
        });

        return idx_low_diff;
    }

    __host__ __device__ static constexpr index_t CalculateOffset(const UpperIndex& idx_up)
    {
        return GetLowerTensorDescriptor().CalculateOffset(CalculateLowerIndex(idx_up));
    }

    struct lambda_sequence_logical_and
    {
        template <typename... Seqs>
        __host__ __device__ constexpr auto operator()(Seqs...) const
        {
            return typename sequence_reduce<logical_and<index_t>, Seqs...>::type{};
        }
    };

    template <typename T>
    struct lambda_is_true
    {
        __host__ __device__ constexpr auto operator()(const T& x) const
        {
            // TODO: remove static_cast once Sequence can take bool as entries
            return static_cast<bool>(x) == true;
        }
    };

    struct lambda_get_linear_dimension_mask_of_single_tranform
    {
        // check only one transform at a time
        template <typename Transform, typename LowDimensionId, typename UpDimensionId>
        __host__ __device__ constexpr auto
        operator()(Transform, LowDimensionId, UpDimensionId) const
        {
            // judge if the transformation is linear
            constexpr bool is_linear_transform = Transform::IsLinearTransform();

            // judge if all lower dimensions are linear
            constexpr bool are_all_low_dim_linear = sequence_all_of(
                pick_sequence_elements_by_ids(GetLowerTensorDescriptor().GetLinearDimensionMask(),
                                              LowDimensionId{}),
                lambda_is_true<index_t>{});

            // create linear mask for upper dimensions
            constexpr bool are_up_dim_linear = is_linear_transform && are_all_low_dim_linear;

            constexpr auto mask_of_up_linear_dims = modify_sequence_elements_by_ids(
                typename uniform_sequence_gen<nDimUp, 1>::type{},
                typename uniform_sequence_gen<UpDimensionId::Size(), are_up_dim_linear>::type{},
                UpDimensionId{});

            return mask_of_up_linear_dims;
        }
    };

    // TODO: this is a hack; transform_tuples() doesn't compile, it complains about constexpr
    template <typename F, typename X, typename Y, typename Z, index_t... Is>
    __host__ __device__ static constexpr auto
    dummy_transform_tuples_impl(F f, X x, Y y, Z z, Sequence<Is...>)
    {
        return make_tuple(f(x.At(Number<Is>{}), y.At(Number<Is>{}), z.At(Number<Is>{}))...);
    }

    __host__ __device__ static constexpr auto GetLinearDimensionMask()
    {
#if 0
        // create tuple of linear dimension masks, for all transformations
        // TODO: this doesn't compile, because transform_tuples() complains about constexpr
        constexpr auto tuple_of_linear_dimension_mask =
            transform_tuples(lambda_get_linear_dimension_mask_of_single_tranform{},
                             Transforms{},
                             LowDimensionIds{},
                             UpDimensionIds{});
#else
        // create tuple of linear dimension masks, for all transformations
        // TODO: this is a hack
        constexpr auto tuple_of_linear_dimension_mask = dummy_transform_tuples_impl(
            lambda_get_linear_dimension_mask_of_single_tranform{},
            Transforms{},
            LowDimensionIds{},
            UpDimensionIds{},
            typename arithmetic_sequence_gen<0, Transforms::Size(), 1>::type{});
#endif

        // reduce tuple of masks into one mask
        constexpr auto linear_dimension_mask =
            unpack(lambda_sequence_logical_and{}, tuple_of_linear_dimension_mask);

        return linear_dimension_mask;
    }

    __host__ __device__ static constexpr auto GetNonLinearDimensionMask()
    {
        return GetLinearDimensionMask().Transform(logical_not<index_t>{});
    }

    template <index_t IDim>
    __host__ __device__ static constexpr bool IsLinearDimension(Number<IDim>)
    {
        return GetLinearDimensionMask().At(Number<IDim>{});
    }

    __host__ __device__ static constexpr auto GetLinearDimensions()
    {
        constexpr auto linear_dimension_mask = GetLinearDimensionMask();

        return pick_sequence_elements_by_mask(
            typename arithmetic_sequence_gen<0, nDimUp, 1>::type{}, linear_dimension_mask);
    }

    __host__ __device__ static constexpr auto GetNonLinearDimensions()
    {
        constexpr auto nonlinear_dimension_mask = GetNonLinearDimensionMask();

        return pick_sequence_elements_by_mask(
            typename arithmetic_sequence_gen<0, nDimUp, 1>::type{}, nonlinear_dimension_mask);
    }

#if 0
    __host__ __device__ static constexpr auto GetNonLinearIndependentDimensionGroups()
    {
        // TODO: not implemented
    }
#endif

    // a multi-index is valid if there is a corresponding point for it in the tensor
    __host__ __device__ constexpr bool IsUpperIndexValid(const UpperIndex& idx_up) const
    {
        bool flag = true;

        for(index_t i = 0; i < nDimUp; ++i)
        {
            flag = flag && idx_up[i] >= 0 && idx_up[i] < GetLengths()[i];
        }

        return flag;
    }

    // this function is for optimization purposes; it's called by the tensor coordinate
    // this function tells you whether a lower-index is valid, assuming the upper index is valid
    __host__ __device__ static constexpr bool
    IsLowerIndexValidAssumingUpperIndexIsValid(const LowerIndex& idx_low)
    {
        bool flag = true;

        static_for<0, nTransform, 1>{}([&](auto itran) {
            constexpr auto tran = Transforms{}.At(itran);

            // only check a transformation if it does not always have a valid mapping
            constexpr bool is_valid_up_always_mapped_to_valid_low =
                decltype(tran)::IsValidUpperIndexAlwaysMappedToValidLowerIndex();

            if(!is_valid_up_always_mapped_to_valid_low)
            {
                constexpr auto low_dims_part = LowDimensionIds{}.At(itran);
                constexpr auto low_lengths_part =
                    GetLowerTensorDescriptor().GetLengths(low_dims_part);
                const auto idx_low_part =
                    to_multi_index(pick_container_element(idx_low, low_dims_part));

                static_for<0, decltype(low_dims_part)::Size(), 1>{}([&](auto i) {
                    flag = flag && idx_low_part[i] >= 0 && idx_low_part[i] < low_lengths_part[i];
                });
            }
        });

        return flag;
    }
};

} // namespace ck
#endif
@@ -1,176 +0,0 @@
#ifndef CK_TENSOR_DESCRIPTOR_HELPER_HPP
#define CK_TENSOR_DESCRIPTOR_HELPER_HPP

#include "common_header.hpp"
#include "tensor_descriptor.hpp"

namespace ck {

template <typename Lengths>
__host__ __device__ constexpr auto calculate_tensor_strides_packed(Lengths)
{
    return reverse_inclusive_scan_sequence(
               Lengths{}.PopFront(), math::multiplies<index_t>{}, Number<1>{})
        .PushBack(Number<1>{});
}

template <typename Lengths, index_t Align>
__host__ __device__ constexpr auto calculate_tensor_strides_aligned(Lengths, Number<Align>)
{
    constexpr index_t L_back_align =
        Align * math::integer_divide_ceiler<index_t>{}(Lengths{}.Back(), Align);

    return calculate_tensor_strides_packed(
        Lengths{}.Modify(Number<Lengths{}.GetSize() - 1>{}, Number<L_back_align>{}));
}

template <index_t... Lengths, index_t... Strides>
__host__ __device__ constexpr auto make_native_tensor_descriptor(Sequence<Lengths...>,
                                                                 Sequence<Strides...>)
{
    return NativeTensorDescriptor<NativeDimension<Lengths, Strides>...>{};
}

template <typename Lengths>
__host__ __device__ constexpr auto make_native_tensor_descriptor_packed(Lengths)
{
    constexpr auto strides = calculate_tensor_strides_packed(Lengths{});

    return make_native_tensor_descriptor(Lengths{}, strides);
}

template <typename Lengths, index_t Align>
__host__ __device__ constexpr auto make_native_tensor_descriptor_aligned(Lengths, Number<Align>)
{
    constexpr auto strides = calculate_tensor_strides_aligned(Lengths{}, Number<Align>{});
    return make_native_tensor_descriptor(Lengths{}, strides);
}
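
// A minimal illustration: calculate_tensor_strides_packed(Sequence<2, 3, 4>{})
// yields Sequence<12, 4, 1>, the row-major packed strides. With an alignment
// of 8, calculate_tensor_strides_aligned(Sequence<2, 3, 4>{}, Number<8>{})
// first pads the innermost length 4 up to 8 and yields Sequence<24, 8, 1>, so
// each innermost row starts on an 8-element boundary.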

template <typename LowTensorDescriptor,
          typename Transforms,
          typename LowDimensionIds,
          typename UpDimensionIds>
__host__ __device__ constexpr auto
transform_tensor_descriptor(LowTensorDescriptor, Transforms, LowDimensionIds, UpDimensionIds)
{
    return TransformedTensorDescriptor<LowTensorDescriptor,
                                       Transforms,
                                       LowDimensionIds,
                                       UpDimensionIds>{};
}
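
// A minimal usage sketch (assuming the Merge and PassThrough transforms from
// multi_index_transform.hpp take these template arguments): folding the two
// outer dimensions of a 3-d packed descriptor into one yields a 2-d view:
//   constexpr auto desc_3d = make_native_tensor_descriptor_packed(Sequence<2, 3, 4>{});
//   constexpr auto desc_2d = transform_tensor_descriptor(
//       desc_3d,
//       make_tuple(Merge<Sequence<2, 3>>{}, PassThrough<4>{}),
//       make_tuple(Sequence<0, 1>{}, Sequence<2>{}),
//       make_tuple(Sequence<0>{}, Sequence<1>{}));
// desc_2d then has lengths {6, 4}, and upper index (5, 2) maps to the same
// offset as lower index (1, 2, 2).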

template <typename LowerTensorDescriptor,
          index_t... LowerLengths,
          index_t... LowerDimensionIds,
          index_t... UpperDimensionIds>
__host__ __device__ constexpr auto
reorder_transformed_tensor_descriptor_impl(LowerTensorDescriptor,
                                           Sequence<LowerLengths...>,
                                           Sequence<LowerDimensionIds...>,
                                           Sequence<UpperDimensionIds...>)
{
    return TransformedTensorDescriptor<LowerTensorDescriptor,
                                       Tuple<PassThrough<LowerLengths>...>,
                                       Tuple<Sequence<LowerDimensionIds>...>,
                                       Tuple<Sequence<UpperDimensionIds>...>>{};
}

// reorder a NativeTensorDescriptor
template <typename... Ts, typename MapLower2Upper>
__host__ __device__ constexpr auto
reorder_tensor_descriptor_given_lower2upper(NativeTensorDescriptor<Ts...>, MapLower2Upper)
{
    static_assert(is_valid_sequence_map<MapLower2Upper>{},
                  "wrong! MapLower2Upper is not a valid map");

    constexpr auto old_desc = NativeTensorDescriptor<Ts...>{};

    static_assert(old_desc.GetNumOfDimension() == MapLower2Upper::Size(), "wrong!");

    constexpr auto new_lengths = old_desc.GetLengths().ReorderGivenOld2New(MapLower2Upper{});
    constexpr auto new_strides = old_desc.GetStrides().ReorderGivenOld2New(MapLower2Upper{});

    return make_native_tensor_descriptor(new_lengths, new_strides);
}

// reorder a TransformedTensorDescriptor
template <typename... Ts, typename MapLower2Upper>
__host__ __device__ constexpr auto
reorder_tensor_descriptor_given_lower2upper(TransformedTensorDescriptor<Ts...>, MapLower2Upper)
{
    static_assert(is_valid_sequence_map<MapLower2Upper>{},
                  "wrong! MapLower2Upper is not a valid map");

    constexpr auto low_desc = TransformedTensorDescriptor<Ts...>{};

    static_assert(low_desc.GetNumOfDimension() == MapLower2Upper::Size(), "wrong!");

    return reorder_transformed_tensor_descriptor_impl(
        low_desc,
        low_desc.GetLengths(),
        typename arithmetic_sequence_gen<0, low_desc.GetNumOfDimension(), 1>::type{},
        MapLower2Upper{});
}

template <typename LowerTensorDescriptor, typename MapUpper2Lower>
__host__ __device__ constexpr auto
reorder_tensor_descriptor_given_upper2lower(LowerTensorDescriptor, MapUpper2Lower)
{
    return reorder_tensor_descriptor_given_lower2upper(
        LowerTensorDescriptor{}, typename sequence_map_inverse<MapUpper2Lower>::type{});
}

template <typename Lengths, typename Strides>
__host__ __device__ constexpr bool are_dimensions_unfoldable(Lengths, Strides)
{
    static_assert(Lengths::Size() == Strides::Size(), "wrong!");

    bool flag = true;

    for(index_t i = 0; i < Lengths::Size() - 1; ++i)
    {
        flag = flag && Strides::At(i) == Strides::At(i + 1) * Lengths::At(i + 1);
    }

    return flag;
}

// unfold only supports NativeTensorDescriptor, for now
template <index_t FirstUnfoldDim, index_t LastUnfoldDim, typename... Ts>
__host__ __device__ constexpr auto unfold_tensor_descriptor(NativeTensorDescriptor<Ts...> desc,
                                                            Number<FirstUnfoldDim>,
                                                            Number<LastUnfoldDim>)
{
    constexpr index_t nDim = desc.GetNumOfDimension();

    static_assert(FirstUnfoldDim >= 0 && LastUnfoldDim < nDim && FirstUnfoldDim <= LastUnfoldDim,
                  "wrong! should have FirstUnfoldDim <= LastUnfoldDim!");

    // left and right
    constexpr auto left = typename arithmetic_sequence_gen<0, FirstUnfoldDim, 1>::type{};
    constexpr auto middle =
        typename arithmetic_sequence_gen<FirstUnfoldDim, LastUnfoldDim + 1, 1>::type{};
    constexpr auto right = typename arithmetic_sequence_gen<LastUnfoldDim + 1, nDim, 1>::type{};

    // sanity-check if unfold-able
    static_assert(are_dimensions_unfoldable(desc.GetLengths(middle), desc.GetStrides(middle)),
                  "wrong! not unfold-able");

    // unfolded length, stride
    constexpr index_t unfold_length =
        reduce_on_sequence(desc.GetLengths(middle), math::multiplies<index_t>{}, Number<1>{});

    constexpr index_t unfold_stride = desc.GetStride(Number<LastUnfoldDim>{});

    // new lengths, strides
    constexpr auto new_lengths =
        desc.GetLengths(left).PushBack(Number<unfold_length>{}).PushBack(desc.GetLengths(right));

    constexpr auto new_strides =
        desc.GetStrides(left).PushBack(Number<unfold_stride>{}).PushBack(desc.GetStrides(right));

    return make_native_tensor_descriptor(new_lengths, new_strides);
}
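
// A minimal illustration (assuming a packed descriptor with lengths {2, 3, 4}
// and strides {12, 4, 1}): unfolding dimensions 1 and 2 is legal because
// Strides::At(1) == Strides::At(2) * Lengths::At(2) (4 == 1 * 4), and the
// result is a 2-d descriptor with lengths {2, 12} and strides {12, 1}.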
|
||||
|
||||
} // namespace ck
|
||||
#endif
|
||||
@@ -1,406 +0,0 @@
|
||||
#ifndef CK_BLOCKWISE_BATCHED_GEMM_HPP
|
||||
#define CK_BLOCKWISE_BATCHED_GEMM_HPP
|
||||
|
||||
#include "common_header.hpp"
|
||||
#include "ConstantMatrixDescriptor.hpp"
|
||||
#include "threadwise_gemm.hpp"
|
||||
|
||||
#ifndef CK_BLOCKWISE_GEMM_USE_AMD_INLINE_ASM
|
||||
#define CK_BLOCKWISE_GEMM_USE_AMD_INLINE_ASM 1
|
||||
#endif
|
||||
|
||||
namespace ck {
|
||||
|
||||
template <index_t BlockSize,
|
||||
class BlockMatrixA,
|
||||
class BlockMatrixB,
|
||||
class ThreadMatrixC,
|
||||
index_t BlockMatrixStrideA,
|
||||
index_t BlockMatrixStrideB,
|
||||
index_t ThreadMatrixStrideC,
|
||||
index_t BatchSize,
|
||||
index_t MPerThreadSubC,
|
||||
index_t NPerThreadSubC,
|
||||
index_t MLevel0Cluster,
|
||||
index_t NLevel0Cluster,
|
||||
index_t MLevel1Cluster,
|
||||
index_t NLevel1Cluster,
|
||||
index_t KPerThreadLoop,
|
||||
index_t BatchPerThread,
|
||||
index_t DataPerReadA,
|
||||
index_t DataPerReadB>
|
||||
struct BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2
|
||||
{
|
||||
index_t mMyThreadOffsetA = 0;
|
||||
index_t mMyThreadOffsetB = 0;
|
||||
|
||||
struct MatrixIndex
|
||||
{
|
||||
index_t batch;
|
||||
index_t row;
|
||||
index_t col;
|
||||
};
|
||||
|
||||
__device__ BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2()
|
||||
{
|
||||
static_assert(BatchSize % BatchPerThread == 0,
|
||||
"wrong! BatchSize is not dividable by BatchPerThread");
|
||||
|
||||
constexpr index_t BatchThreadWork = BatchSize / BatchPerThread;
|
||||
|
||||
constexpr index_t ThreadPerLevel1Cluster =
|
||||
MLevel0Cluster * NLevel0Cluster * MLevel1Cluster * NLevel1Cluster;
|
||||
        static_assert(BlockSize == BatchThreadWork * ThreadPerLevel1Cluster,
                      "wrong! wrong blocksize\n");

        constexpr auto a_block_mtx  = BlockMatrixA{};
        constexpr auto b_block_mtx  = BlockMatrixB{};
        constexpr auto c_thread_mtx = ThreadMatrixC{};

        static_assert(a_block_mtx.NRow() == b_block_mtx.NRow(),
                      "wrong! K dimension not consistent\n");

        constexpr index_t M = a_block_mtx.NCol(); // A is transposed
        constexpr index_t N = b_block_mtx.NCol();

        constexpr index_t MPerThread = c_thread_mtx.NRow();
        constexpr index_t NPerThread = c_thread_mtx.NCol();

        static_assert((MPerThread % MPerThreadSubC == 0) && (NPerThread % NPerThreadSubC == 0),
                      "wrong! Cannot evenly divide thread work among repeat\n");

        constexpr index_t MRepeat = MPerThread / MPerThreadSubC;
        constexpr index_t NRepeat = NPerThread / NPerThreadSubC;

        static_assert((M % MRepeat == 0) && (N % NRepeat == 0),
                      "wrong! Cannot evenly divide work among repeat\n");

        constexpr index_t MPerLevel1Cluster = M / MRepeat;
        constexpr index_t NPerLevel1Cluster = N / NRepeat;

        static_assert((MPerLevel1Cluster % MLevel1Cluster == 0) &&
                          (NPerLevel1Cluster % NLevel1Cluster == 0),
                      "wrong! Cannot evenly divide work among Level1Cluster\n");

        constexpr index_t MPerLevel0Cluster = MPerLevel1Cluster / MLevel1Cluster;
        constexpr index_t NPerLevel0Cluster = NPerLevel1Cluster / NLevel1Cluster;

        static_assert((MPerLevel0Cluster % MLevel0Cluster == 0) &&
                          (NPerLevel0Cluster % NLevel0Cluster == 0),
                      "wrong! Cannot evenly divide work among Level0Cluster\n");

        static_assert((MPerThreadSubC == MPerLevel0Cluster / MLevel0Cluster) &&
                          (NPerThreadSubC == NPerLevel0Cluster / NLevel0Cluster),
                      "wrong! thread work size is wrong\n");

        const auto c_thread_mtx_index = GetBeginOfThreadMatrixC(get_thread_local_1d_id());

        mMyThreadOffsetA = c_thread_mtx_index.batch * BlockMatrixStrideA +
                           a_block_mtx.GetOffsetFromMultiIndex(0, c_thread_mtx_index.row);

        mMyThreadOffsetB = c_thread_mtx_index.batch * BlockMatrixStrideB +
                           b_block_mtx.GetOffsetFromMultiIndex(0, c_thread_mtx_index.col);
    }

    __device__ MatrixIndex GetBeginOfThreadMatrixC(index_t thread_id) const
    {
        constexpr index_t ThreadPerLevel1Cluster =
            MLevel0Cluster * NLevel0Cluster * MLevel1Cluster * NLevel1Cluster;

        constexpr index_t ThreadPerLevel0Cluster = MLevel0Cluster * NLevel0Cluster;

        index_t batch_work_id = thread_id / ThreadPerLevel1Cluster;
        index_t cluster_id    = thread_id - batch_work_id * ThreadPerLevel1Cluster;

        index_t level1_id   = cluster_id / ThreadPerLevel0Cluster;
        index_t level1_m_id = level1_id / NLevel1Cluster;
        index_t level1_n_id = level1_id % NLevel1Cluster;

        index_t level0_id   = cluster_id % ThreadPerLevel0Cluster;
        index_t level0_m_id = level0_id / NLevel0Cluster;
        index_t level0_n_id = level0_id % NLevel0Cluster;

        constexpr index_t MPerLevel0Cluster = MPerThreadSubC * MLevel0Cluster;
        constexpr index_t NPerLevel0Cluster = NPerThreadSubC * NLevel0Cluster;

        return MatrixIndex{batch_work_id * BatchPerThread,
                           level1_m_id * MPerLevel0Cluster + level0_m_id * MPerThreadSubC,
                           level1_n_id * NPerLevel0Cluster + level0_n_id * NPerThreadSubC};
    }
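
    // Editor's note (illustrative walk-through, not in the original source):
    // with MLevel0Cluster = NLevel0Cluster = 2, MLevel1Cluster = NLevel1Cluster = 4
    // and MPerThreadSubC = NPerThreadSubC = 4, ThreadPerLevel0Cluster = 4 and
    // ThreadPerLevel1Cluster = 64. For thread_id = 70:
    //   batch_work_id = 70 / 64 = 1, cluster_id = 6
    //   level1_id = 6 / 4 = 1 -> level1_m_id = 0, level1_n_id = 1
    //   level0_id = 6 % 4 = 2 -> level0_m_id = 1, level0_n_id = 0
    // so the thread's C sub-tile begins at row = 0 * 8 + 1 * 4 = 4 and
    // col = 1 * 8 + 0 * 4 = 8 within the block tile (before repeat strides).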

    // this should be optimized away because input will be known at compile time
    __device__ static MatrixIndex
    GetDistanceFromBeginOfThreadMatrixC(index_t batch_in_c, index_t m_in_c, index_t n_in_c)
    {
        constexpr auto c_thread_mtx = ThreadMatrixC{};

        constexpr index_t MPerThread = c_thread_mtx.NRow();
        constexpr index_t NPerThread = c_thread_mtx.NCol();

        constexpr index_t MRepeat = MPerThread / MPerThreadSubC;
        constexpr index_t NRepeat = NPerThread / NPerThreadSubC;

        constexpr index_t MPerLevel1Cluster = MPerThreadSubC * MLevel0Cluster * MLevel1Cluster;
        constexpr index_t NPerLevel1Cluster = NPerThreadSubC * NLevel0Cluster * NLevel1Cluster;

        index_t m_repeat = m_in_c / MPerThreadSubC;
        index_t n_repeat = n_in_c / NPerThreadSubC;

        index_t m_in_sub_c = m_in_c % MPerThreadSubC;
        index_t n_in_sub_c = n_in_c % NPerThreadSubC;

        return MatrixIndex{batch_in_c,
                           m_repeat * MPerLevel1Cluster + m_in_sub_c,
                           n_repeat * NPerLevel1Cluster + n_in_sub_c};
    }

    template <class FloatA, class FloatB, class FloatC>
    __device__ void Run_source(const FloatA* __restrict__ p_a_block,
                               const FloatB* __restrict__ p_b_block,
                               FloatC* __restrict__ p_c_thread) const
    {
        constexpr auto True  = integral_constant<bool, true>{};
        constexpr auto False = integral_constant<bool, false>{};

        constexpr auto a_block_mtx  = BlockMatrixA{};
        constexpr auto b_block_mtx  = BlockMatrixB{};
        constexpr auto c_thread_mtx = ThreadMatrixC{};

        constexpr index_t KPerBlock = a_block_mtx.NRow(); // A is transposed

        constexpr index_t MPerThread = c_thread_mtx.NRow();
        constexpr index_t NPerThread = c_thread_mtx.NCol();

        // thread A, B for GEMM
        // A is transposed, B is not
        constexpr auto a_thread_mtx =
            make_ConstantMatrixDescriptor_packed(Number<KPerThreadLoop>{}, Number<MPerThread>{});

        constexpr auto b_thread_mtx =
            make_ConstantMatrixDescriptor_packed(Number<KPerThreadLoop>{}, Number<NPerThread>{});

        // thread A-sub, B-sub for copy
        constexpr auto a_thread_sub_mtx = make_ConstantMatrixDescriptor(
            Number<KPerThreadLoop>{}, Number<MPerThreadSubC>{}, Number<MPerThread>{});

        constexpr auto b_thread_sub_mtx = make_ConstantMatrixDescriptor(
            Number<KPerThreadLoop>{}, Number<NPerThreadSubC>{}, Number<NPerThread>{});

        FloatA p_a_thread[a_thread_mtx.GetElementSpace()];
        FloatB p_b_thread[b_thread_mtx.GetElementSpace()];

        constexpr index_t MPerLevel1Cluster = MPerThreadSubC * MLevel0Cluster * MLevel1Cluster;
        constexpr index_t NPerLevel1Cluster = NPerThreadSubC * NLevel0Cluster * NLevel1Cluster;

        constexpr index_t MRepeat = MPerThread / MPerThreadSubC;
        constexpr index_t NRepeat = NPerThread / NPerThreadSubC;

        // loop over k
#pragma unroll
        for(index_t k_begin = 0; k_begin < KPerBlock; k_begin += KPerThreadLoop)
        {
            // loop over batch
#pragma unroll
            for(index_t ib = 0; ib < BatchPerThread; ++ib)
            {
                // read next batch of a, b
                if(BlockMatrixStrideA != 0 or ib == 0)
                {
#pragma unroll
                    for(index_t m_repeat = 0; m_repeat < MRepeat; ++m_repeat)
                    {
                        threadwise_matrix_copy(a_block_mtx,
                                               p_a_block +
                                                   a_block_mtx.GetOffsetFromMultiIndex(
                                                       k_begin, m_repeat * MPerLevel1Cluster) +
                                                   ib * BlockMatrixStrideA + mMyThreadOffsetA,
                                               a_thread_mtx,
                                               p_a_thread + a_thread_mtx.GetOffsetFromMultiIndex(
                                                                0, m_repeat * MPerThreadSubC),
                                               a_thread_sub_mtx.GetLengths(),
                                               Number<DataPerReadA>{});
                    }
                }

                if(BlockMatrixStrideB != 0 or ib == 0)
                {
#pragma unroll
                    for(index_t n_repeat = 0; n_repeat < NRepeat; ++n_repeat)
                    {
                        threadwise_matrix_copy(b_block_mtx,
                                               p_b_block +
                                                   b_block_mtx.GetOffsetFromMultiIndex(
                                                       k_begin, n_repeat * NPerLevel1Cluster) +
                                                   ib * BlockMatrixStrideB + mMyThreadOffsetB,
                                               b_thread_mtx,
                                               p_b_thread + b_thread_mtx.GetOffsetFromMultiIndex(
                                                                0, n_repeat * NPerThreadSubC),
                                               b_thread_sub_mtx.GetLengths(),
                                               Number<DataPerReadB>{});
                    }
                }

                threadwise_gemm(a_thread_mtx,
                                True,
                                p_a_thread,
                                b_thread_mtx,
                                False,
                                p_b_thread,
                                c_thread_mtx,
                                False,
                                p_c_thread + ib * ThreadMatrixStrideC);
            }
        }
    }

#if CK_USE_AMD_INLINE_ASM
    template <class FloatA, class FloatB, class FloatC>
    __device__ void Run_amd_asm(const FloatA* __restrict__ p_a_block,
                                const FloatB* __restrict__ p_b_block,
                                FloatC* __restrict__ p_c_thread) const
    {
        constexpr auto a_block_mtx  = BlockMatrixA{};
        constexpr auto b_block_mtx  = BlockMatrixB{};
        constexpr auto c_thread_mtx = ThreadMatrixC{};

        constexpr index_t K = a_block_mtx.NRow(); // A is transposed

        constexpr index_t MPerThread = c_thread_mtx.NRow();
        constexpr index_t NPerThread = c_thread_mtx.NCol();

        // thread A, B for GEMM
        // A is transposed, B is not
        constexpr auto a_thread_mtx =
            make_ConstantMatrixDescriptor_packed(Number<KPerThreadLoop>{}, Number<MPerThread>{});

        constexpr auto b_thread_mtx =
            make_ConstantMatrixDescriptor_packed(Number<KPerThreadLoop>{}, Number<NPerThread>{});

        // thread A-sub, B-sub for copy
        constexpr auto a_thread_sub_mtx = make_ConstantMatrixDescriptor(
            Number<KPerThreadLoop>{}, Number<MPerThreadSubC>{}, Number<MPerThread>{});

        constexpr auto b_thread_sub_mtx = make_ConstantMatrixDescriptor(
            Number<KPerThreadLoop>{}, Number<NPerThreadSubC>{}, Number<NPerThread>{});

        FloatA p_a_thread[a_thread_mtx.GetElementSpace()];
        FloatB p_b_thread[b_thread_mtx.GetElementSpace()];

        constexpr index_t MPerLevel1Cluster = MPerThreadSubC * MLevel0Cluster * MLevel1Cluster;
        constexpr index_t NPerLevel1Cluster = NPerThreadSubC * NLevel0Cluster * NLevel1Cluster;

        // assertions for inline asm
        static_assert(is_same<FloatA, float>{} && is_same<FloatB, float>{} &&
                          is_same<FloatC, float>{},
                      "Run_amd_asm only deals with float\n");

        static_assert(MPerThreadSubC == 4 && NPerThreadSubC == 4 && KPerThreadLoop == 1 &&
                          MPerThread == 8 && NPerThread == 8,
                      "Run_amd_asm cannot deal with this GEMM shape yet\n");

        static_assert(DataPerReadA == 4 && DataPerReadB == 4,
                      "Run_amd_asm only does float4 reads\n");

        static_assert(BlockMatrixStrideA == 0 && BatchPerThread == 1,
                      "Run_amd_asm can only deal with BlockMatrixStrideA == 0 && BatchPerThread == "
                      "1 for now\n");

        using Float4 = vector_type<float, 4>::type;

        Float4* reg_a = (Float4*)(p_a_thread);
        Float4* reg_b = (Float4*)(p_b_thread);
        Float4* reg_c = (Float4*)(p_c_thread);

        reg_a[0] = *reinterpret_cast<const Float4*>(&p_a_block[mMyThreadOffsetA]);
        reg_b[0] = *reinterpret_cast<const Float4*>(&p_b_block[mMyThreadOffsetB]);
        reg_b[1] = *reinterpret_cast<const Float4*>(
            &p_b_block[b_block_mtx.GetOffsetFromMultiIndex(0, NPerLevel1Cluster) +
                       mMyThreadOffsetB]);
        reg_a[1] = *reinterpret_cast<const Float4*>(
            &p_a_block[a_block_mtx.GetOffsetFromMultiIndex(0, MPerLevel1Cluster) +
                       mMyThreadOffsetA]);
        outerProduct4x4(reg_a[0], reg_b[0], reg_c[0], reg_c[2], reg_c[4], reg_c[6]);
        outerProduct4x4(reg_a[0], reg_b[1], reg_c[1], reg_c[3], reg_c[5], reg_c[7]);

#pragma unroll
        for(index_t k = 1; k < K; ++k)
        {
            reg_a[0] = *reinterpret_cast<const Float4*>(
                &p_a_block[a_block_mtx.GetOffsetFromMultiIndex(k, 0) + mMyThreadOffsetA]);
            outerProduct4x4(reg_a[1], reg_b[0], reg_c[8], reg_c[10], reg_c[12], reg_c[14]);
            reg_b[0] = *reinterpret_cast<const Float4*>(
                &p_b_block[b_block_mtx.GetOffsetFromMultiIndex(k, 0) + mMyThreadOffsetB]);
            outerProduct4x4(reg_a[1], reg_b[1], reg_c[9], reg_c[11], reg_c[13], reg_c[15]);
            reg_b[1] = *reinterpret_cast<const Float4*>(
                &p_b_block[b_block_mtx.GetOffsetFromMultiIndex(k, NPerLevel1Cluster) +
                           mMyThreadOffsetB]);
            reg_a[1] = *reinterpret_cast<const Float4*>(
                &p_a_block[a_block_mtx.GetOffsetFromMultiIndex(k, MPerLevel1Cluster) +
                           mMyThreadOffsetA]);
            outerProduct4x4(reg_a[0], reg_b[0], reg_c[0], reg_c[2], reg_c[4], reg_c[6]);
            outerProduct4x4(reg_a[0], reg_b[1], reg_c[1], reg_c[3], reg_c[5], reg_c[7]);
        }
        outerProduct4x4(reg_a[1], reg_b[0], reg_c[8], reg_c[10], reg_c[12], reg_c[14]);
        outerProduct4x4(reg_a[1], reg_b[1], reg_c[9], reg_c[11], reg_c[13], reg_c[15]);
    }
#endif
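
    // Editor's note (sketch, not in the original source): Run_amd_asm keeps the
    // 8x8 C tile as sixteen float4 accumulators and software-pipelines the k
    // loop: the float4 loads for step k are interleaved with the outer products
    // that still consume step k-1, hiding LDS read latency. Assuming
    // outerProduct4x4(a, b, c0, c1, c2, c3) performs the rank-1 update
    // c_i += a[i] * b on four float4 rows, each call is equivalent to
    //
    //   for(int i = 0; i < 4; ++i)
    //       for(int j = 0; j < 4; ++j)
    //           c[i][j] += a[i] * b[j];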

    template <class FloatA, class FloatB, class FloatC>
    __device__ void Run(const FloatA* __restrict__ p_a_block,
                        const FloatB* __restrict__ p_b_block,
                        FloatC* __restrict__ p_c_thread) const
    {
#if CK_USE_AMD_INLINE_ASM && CK_BLOCKWISE_GEMM_USE_AMD_INLINE_ASM
        Run_amd_asm(p_a_block, p_b_block, p_c_thread);
#else
        Run_source(p_a_block, p_b_block, p_c_thread);
#endif
    }

    template <class BlockMatrixC, index_t BlockMatrixStrideC, class FloatC>
    __device__ void CopyThreadMatrixCToBlockMatrixC(const FloatC* __restrict__ p_c_thread,
                                                    FloatC* __restrict__ p_c_block) const
    {
        constexpr auto c_block_mtx  = BlockMatrixC{};
        constexpr auto c_thread_mtx = ThreadMatrixC{};

        constexpr index_t MPerThread = c_thread_mtx.NRow();
        constexpr index_t NPerThread = c_thread_mtx.NCol();

        constexpr auto c_thread_sub_mtx = make_ConstantMatrixDescriptor(
            Number<MPerThreadSubC>{}, Number<NPerThreadSubC>{}, Number<NPerThread>{});

        constexpr index_t MPerLevel1Cluster = MPerThreadSubC * MLevel0Cluster * MLevel1Cluster;
        constexpr index_t NPerLevel1Cluster = NPerThreadSubC * NLevel0Cluster * NLevel1Cluster;

        constexpr index_t MRepeat = MPerThread / MPerThreadSubC;
        constexpr index_t NRepeat = NPerThread / NPerThreadSubC;

        const auto c_thread_mtx_begin = GetBeginOfThreadMatrixC(get_thread_local_1d_id());

        const index_t c_thread_offset =
            c_thread_mtx_begin.batch * BlockMatrixStrideC +
            c_block_mtx.GetOffsetFromMultiIndex(c_thread_mtx_begin.row, c_thread_mtx_begin.col);

        for(index_t m_repeat = 0; m_repeat < MRepeat; ++m_repeat)
        {
            for(index_t n_repeat = 0; n_repeat < NRepeat; ++n_repeat)
            {
                threadwise_matrix_copy(
                    c_thread_sub_mtx,
                    p_c_thread + c_thread_sub_mtx.GetOffsetFromMultiIndex(
                                     m_repeat * MPerLevel1Cluster, n_repeat * NPerLevel1Cluster),
                    c_block_mtx,
                    p_c_block +
                        c_block_mtx.GetOffsetFromMultiIndex(m_repeat * MPerLevel1Cluster,
                                                            n_repeat * NPerLevel1Cluster) +
                        c_thread_offset,
                    c_thread_sub_mtx.GetLengths());
            }
        }
    }
};

} // namespace ck
#endif
@@ -1,334 +0,0 @@
#ifndef CK_BLOCKWISE_GEMM_HPP
#define CK_BLOCKWISE_GEMM_HPP

#include "common_header.hpp"
#include "ConstantMatrixDescriptor.hpp"
#include "threadwise_gemm.hpp"

namespace ck {

// blockwise GEMM: C += transpose(A) * B
// A and B are visible to the whole block; C is distributed among the threads.
// If the following numbers are powers of 2, index calculations are greatly reduced:
// MPerThreadSubC, NPerThreadSubC, MLevel0ThreadCluster, NLevel0ThreadCluster,
// MLevel1ThreadCluster, NLevel1ThreadCluster
template <index_t BlockSize,
          typename BlockMatrixA,
          typename BlockMatrixB,
          typename ThreadMatrixC,
          index_t MPerThreadSubC,
          index_t NPerThreadSubC,
          index_t KPerThreadLoop,
          index_t MLevel0ThreadCluster,
          index_t NLevel0ThreadCluster,
          index_t MLevel1ThreadCluster,
          index_t NLevel1ThreadCluster,
          index_t ThreadGemmADataPerRead_M,
          index_t ThreadGemmBDataPerRead_N>
struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
{
    struct MatrixIndex
    {
        index_t row;
        index_t col;
    };

    index_t mMyThreadOffsetA;
    index_t mMyThreadOffsetB;

    __device__ BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2()
    {
        constexpr index_t ThreadPerLevel1Cluster = MLevel0ThreadCluster * NLevel0ThreadCluster *
                                                   MLevel1ThreadCluster * NLevel1ThreadCluster;

        static_assert(BlockSize == ThreadPerLevel1Cluster, "wrong! wrong blocksize\n");

        static_assert(BlockMatrixA::NRow() == BlockMatrixB::NRow(),
                      "wrong! K dimension not consistent\n");

        constexpr index_t M = BlockMatrixA::NCol(); // A is transposed
        constexpr index_t N = BlockMatrixB::NCol();

        static_assert(M % (MPerThreadSubC * MLevel0ThreadCluster * MLevel1ThreadCluster) == 0 &&
                          N % (NPerThreadSubC * NLevel0ThreadCluster * NLevel1ThreadCluster) == 0,
                      "wrong! Cannot evenly divide work among threads\n");

        static_assert(
            is_same<decltype(ThreadMatrixC::GetLengths()), decltype(GetThreadMatrixCLengths())>{},
            "wrong! ThreadMatrixC lengths are wrong");

        auto c_thread_mtx_index = GetBeginOfThreadMatrixC(get_thread_local_1d_id());

        mMyThreadOffsetA = BlockMatrixA::GetOffsetFromMultiIndex(0, c_thread_mtx_index.row);
        mMyThreadOffsetB = BlockMatrixB::GetOffsetFromMultiIndex(0, c_thread_mtx_index.col);
    }

    __device__ static constexpr auto GetThreadMatrixCLengths()
    {
        constexpr index_t M = BlockMatrixA::NCol(); // A is transposed
        constexpr index_t N = BlockMatrixB::NCol();

        constexpr index_t MRepeat =
            M / (MPerThreadSubC * MLevel0ThreadCluster * MLevel1ThreadCluster);
        constexpr index_t NRepeat =
            N / (NPerThreadSubC * NLevel0ThreadCluster * NLevel1ThreadCluster);

        return Sequence<MRepeat * MPerThreadSubC, NRepeat * NPerThreadSubC>{};
    }

    __device__ static MatrixIndex GetBeginOfThreadMatrixC(index_t thread_id)
    {
        constexpr index_t ThreadPerLevel0Cluster = MLevel0ThreadCluster * NLevel0ThreadCluster;

        index_t level1_id   = thread_id / ThreadPerLevel0Cluster;
        index_t level1_m_id = level1_id / NLevel1ThreadCluster;
        index_t level1_n_id = level1_id % NLevel1ThreadCluster;

        index_t level0_id   = thread_id % ThreadPerLevel0Cluster;
        index_t level0_m_id = level0_id / NLevel0ThreadCluster;
        index_t level0_n_id = level0_id % NLevel0ThreadCluster;

        constexpr index_t MPerLevel0Cluster = MPerThreadSubC * MLevel0ThreadCluster;
        constexpr index_t NPerLevel0Cluster = NPerThreadSubC * NLevel0ThreadCluster;

        return MatrixIndex{level1_m_id * MPerLevel0Cluster + level0_m_id * MPerThreadSubC,
                           level1_n_id * NPerLevel0Cluster + level0_n_id * NPerThreadSubC};
    }

    template <typename FloatA, typename FloatB, typename FloatC>
    __device__ void
    Run_naive(const FloatA* p_a_block, const FloatB* p_b_block, FloatC* p_c_thread) const
    {
        constexpr auto a_block_mtx  = BlockMatrixA{};
        constexpr auto b_block_mtx  = BlockMatrixB{};
        constexpr auto c_thread_mtx = ThreadMatrixC{};

        constexpr index_t K = a_block_mtx.NRow();

        constexpr index_t MPerThread = c_thread_mtx.NRow();
        constexpr index_t NPerThread = c_thread_mtx.NCol();

        constexpr index_t MPerLevel1Cluster =
            MPerThreadSubC * MLevel0ThreadCluster * MLevel1ThreadCluster;
        constexpr index_t NPerLevel1Cluster =
            NPerThreadSubC * NLevel0ThreadCluster * NLevel1ThreadCluster;

        constexpr index_t MRepeat = MPerThread / MPerThreadSubC;
        constexpr index_t NRepeat = NPerThread / NPerThreadSubC;

        // thread A, B for GEMM
        constexpr auto a_thread_mtx =
            make_ConstantMatrixDescriptor_packed(Number<KPerThreadLoop>{}, Number<MPerThread>{});

        constexpr auto b_thread_mtx =
            make_ConstantMatrixDescriptor_packed(Number<KPerThreadLoop>{}, Number<NPerThread>{});

        FloatA p_a_thread[a_thread_mtx.GetElementSpace()];
        FloatB p_b_thread[b_thread_mtx.GetElementSpace()];

        constexpr auto a_thread_copy = ThreadwiseMatrixSliceCopy<BlockMatrixA,
                                                                 decltype(a_thread_mtx),
                                                                 KPerThreadLoop,
                                                                 MPerThreadSubC,
                                                                 ThreadGemmADataPerRead_M>{};

        constexpr auto b_thread_copy = ThreadwiseMatrixSliceCopy<BlockMatrixB,
                                                                 decltype(b_thread_mtx),
                                                                 KPerThreadLoop,
                                                                 NPerThreadSubC,
                                                                 ThreadGemmBDataPerRead_N>{};

        constexpr auto threadwise_gemm =
            ThreadwiseGemmTransANormalBNormalC<decltype(a_thread_mtx),
                                               decltype(b_thread_mtx),
                                               decltype(c_thread_mtx)>{};

        // loop over k
#pragma unroll
        for(index_t k_begin = 0; k_begin < K; k_begin += KPerThreadLoop)
        {
            // read A
#pragma unroll
            for(index_t m_repeat = 0; m_repeat < MRepeat; ++m_repeat)
            {
                a_thread_copy.Run(
                    p_a_block + a_block_mtx.CalculateOffset(k_begin, m_repeat * MPerLevel1Cluster) +
                        mMyThreadOffsetA,
                    p_a_thread + a_thread_mtx.CalculateOffset(0, m_repeat * MPerThreadSubC));
            }

            // read B
#pragma unroll
            for(index_t n_repeat = 0; n_repeat < NRepeat; ++n_repeat)
            {
                b_thread_copy.Run(
                    p_b_block + b_block_mtx.CalculateOffset(k_begin, n_repeat * NPerLevel1Cluster) +
                        mMyThreadOffsetB,
                    p_b_thread + b_thread_mtx.CalculateOffset(0, n_repeat * NPerThreadSubC));
            }

            // C += A * B
            threadwise_gemm.Run(p_a_thread, p_b_thread, p_c_thread);
        }
    }

    template <typename FloatA, typename FloatB, typename FloatC>
    __device__ void
    Run_pipelined_2x2(const FloatA* p_a_block, const FloatB* p_b_block, FloatC* p_c_thread) const
    {
        constexpr auto a_block_mtx  = BlockMatrixA{};
        constexpr auto b_block_mtx  = BlockMatrixB{};
        constexpr auto c_thread_mtx = ThreadMatrixC{};

        constexpr index_t K = a_block_mtx.NRow();

        constexpr index_t MPerThread = c_thread_mtx.NRow();
        constexpr index_t NPerThread = c_thread_mtx.NCol();

        constexpr index_t MPerLevel1Cluster =
            MPerThreadSubC * MLevel0ThreadCluster * MLevel1ThreadCluster;
        constexpr index_t NPerLevel1Cluster =
            NPerThreadSubC * NLevel0ThreadCluster * NLevel1ThreadCluster;

        constexpr index_t MRepeat = MPerThread / MPerThreadSubC;
        constexpr index_t NRepeat = NPerThread / NPerThreadSubC;

        static_assert(MRepeat == 2 && NRepeat == 2,
                      "wrong! Run_pipelined_2x2 cannot deal with this GEMM config yet");

        // thread A, B
        constexpr auto a_thread_mtx =
            make_ConstantMatrixDescriptor_packed(Number<KPerThreadLoop>{}, Number<MPerThread>{});
        constexpr auto b_thread_mtx =
            make_ConstantMatrixDescriptor_packed(Number<KPerThreadLoop>{}, Number<NPerThread>{});

        // thread A-sub, B-sub
        constexpr auto a_thread_sub_mtx = a_thread_mtx.MakeSubMatrixDescriptor(
            Number<KPerThreadLoop>{}, Number<MPerThreadSubC>{});
        constexpr auto b_thread_sub_mtx = b_thread_mtx.MakeSubMatrixDescriptor(
            Number<KPerThreadLoop>{}, Number<NPerThreadSubC>{});

        // thread C-sub
        constexpr auto c_thread_sub_mtx = ThreadMatrixC::MakeSubMatrixDescriptor(
            Number<MPerThreadSubC>{}, Number<NPerThreadSubC>{});

        FloatA p_a_thread[a_thread_mtx.GetElementSpace()];
        FloatB p_b_thread[b_thread_mtx.GetElementSpace()];

        constexpr auto a_thread_copy = ThreadwiseMatrixSliceCopy<BlockMatrixA,
                                                                 decltype(a_thread_mtx),
                                                                 KPerThreadLoop,
                                                                 MPerThreadSubC,
                                                                 ThreadGemmADataPerRead_M>{};

        constexpr auto b_thread_copy = ThreadwiseMatrixSliceCopy<BlockMatrixB,
                                                                 decltype(b_thread_mtx),
                                                                 KPerThreadLoop,
                                                                 NPerThreadSubC,
                                                                 ThreadGemmBDataPerRead_N>{};

        constexpr auto threadwise_gemm =
            ThreadwiseGemmTransANormalBNormalC<decltype(a_thread_sub_mtx),
                                               decltype(b_thread_sub_mtx),
                                               decltype(c_thread_sub_mtx)>{};

        const FloatA* p_a_block_off = p_a_block + mMyThreadOffsetA;
        const FloatB* p_b_block_off = p_b_block + mMyThreadOffsetB;

        // read A_sub_0
        a_thread_copy.Run(p_a_block_off, p_a_thread);

        // read B_sub_0
        b_thread_copy.Run(p_b_block_off, p_b_thread);

        // read B_sub_1
        b_thread_copy.Run(p_b_block_off + b_block_mtx.CalculateOffset(0, NPerLevel1Cluster),
                          p_b_thread + b_thread_mtx.CalculateOffset(0, NPerThreadSubC));

        // read A_sub_1
        a_thread_copy.Run(p_a_block_off + a_block_mtx.CalculateOffset(0, MPerLevel1Cluster),
                          p_a_thread + a_thread_mtx.CalculateOffset(0, MPerThreadSubC));

        // C_sub_00 += transpose(A_sub_0) * B_sub_0
        threadwise_gemm.Run(p_a_thread, p_b_thread, p_c_thread);

        // C_sub_01 += transpose(A_sub_0) * B_sub_1
        threadwise_gemm.Run(p_a_thread,
                            p_b_thread + b_thread_mtx.CalculateOffset(0, NPerThreadSubC),
                            p_c_thread + ThreadMatrixC::CalculateOffset(0, NPerThreadSubC));

        // loop over rest of k
#pragma unroll
        for(index_t k = KPerThreadLoop; k < K; k += KPerThreadLoop)
        {
            // read A_sub_0
            a_thread_copy.Run(p_a_block_off + a_block_mtx.CalculateOffset(k, 0), p_a_thread);

            // C_sub_10 += transpose(A_sub_1) * B_sub_0
            threadwise_gemm.Run(p_a_thread + a_thread_mtx.CalculateOffset(0, MPerThreadSubC),
                                p_b_thread,
                                p_c_thread + ThreadMatrixC::CalculateOffset(MPerThreadSubC, 0));

            // read B_sub_0
            b_thread_copy.Run(p_b_block_off + b_block_mtx.CalculateOffset(k, 0), p_b_thread);

            // C_sub_11 += transpose(A_sub_1) * B_sub_1
            threadwise_gemm.Run(p_a_thread + a_thread_mtx.CalculateOffset(0, MPerThreadSubC),
                                p_b_thread + b_thread_mtx.CalculateOffset(0, NPerThreadSubC),
                                p_c_thread +
                                    ThreadMatrixC::CalculateOffset(MPerThreadSubC, NPerThreadSubC));

            // read B_sub_1
            b_thread_copy.Run(p_b_block_off + b_block_mtx.CalculateOffset(k, NPerLevel1Cluster),
                              p_b_thread + b_thread_mtx.CalculateOffset(0, NPerThreadSubC));

            // read A_sub_1
            a_thread_copy.Run(p_a_block_off + a_block_mtx.CalculateOffset(k, MPerLevel1Cluster),
                              p_a_thread + a_thread_mtx.CalculateOffset(0, MPerThreadSubC));

            // C_sub_00 += transpose(A_sub_0) * B_sub_0
            threadwise_gemm.Run(p_a_thread, p_b_thread, p_c_thread);

            // C_sub_01 += transpose(A_sub_0) * B_sub_1
            threadwise_gemm.Run(p_a_thread,
                                p_b_thread + b_thread_mtx.CalculateOffset(0, NPerThreadSubC),
                                p_c_thread + ThreadMatrixC::CalculateOffset(0, NPerThreadSubC));
        }

        // C_sub_10 += transpose(A_sub_1) * B_sub_0
        threadwise_gemm.Run(p_a_thread + a_thread_mtx.CalculateOffset(0, MPerThreadSubC),
                            p_b_thread,
                            p_c_thread + ThreadMatrixC::CalculateOffset(MPerThreadSubC, 0));

        // C_sub_11 += transpose(A_sub_1) * B_sub_1
        threadwise_gemm.Run(p_a_thread + a_thread_mtx.CalculateOffset(0, MPerThreadSubC),
                            p_b_thread + b_thread_mtx.CalculateOffset(0, NPerThreadSubC),
                            p_c_thread +
                                ThreadMatrixC::CalculateOffset(MPerThreadSubC, NPerThreadSubC));
    }
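
    // Editor's note (schedule summary, not in the original source): in steady
    // state each k step of Run_pipelined_2x2 issues
    //   read A_sub_0(k); gemm C_sub_10; read B_sub_0(k); gemm C_sub_11;
    //   read B_sub_1(k); read A_sub_1(k); gemm C_sub_00; gemm C_sub_01;
    // so C_sub_10/C_sub_11 consume the previous step's A_sub_1 while the new
    // slices are being fetched, overlapping copy and math within a thread.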

    template <typename FloatA, typename FloatB, typename FloatC>
    __device__ void Run(const FloatA* p_a_block, const FloatB* p_b_block, FloatC* p_c_thread) const
    {
#if CK_EXPERIMENTAL_BLOCKWISE_GEMM_USE_PIPELINE
        constexpr index_t MPerThread = ThreadMatrixC::NRow();
        constexpr index_t NPerThread = ThreadMatrixC::NCol();

        constexpr index_t MRepeat = MPerThread / MPerThreadSubC;
        constexpr index_t NRepeat = NPerThread / NPerThreadSubC;

        if constexpr(MRepeat == 2 && NRepeat == 2)
        {
            Run_pipelined_2x2(p_a_block, p_b_block, p_c_thread);
        }
        else
        {
            Run_naive(p_a_block, p_b_block, p_c_thread);
        }
#else
        Run_naive(p_a_block, p_b_block, p_c_thread);
#endif
    }
};

} // namespace ck
#endif
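
// Editor's sketch (hypothetical instantiation, not in the original source): a
// 256-thread block computing an 8x8 C tile per thread from 8x128 LDS tiles of
// A and B could instantiate the struct above roughly as
//
//   constexpr auto a_mtx = make_ConstantMatrixDescriptor_packed(Number<8>{}, Number<128>{});
//   constexpr auto b_mtx = make_ConstantMatrixDescriptor_packed(Number<8>{}, Number<128>{});
//   constexpr auto c_mtx = make_ConstantMatrixDescriptor_packed(Number<8>{}, Number<8>{});
//
//   const auto blockwise_gemm = BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2<
//       256, decltype(a_mtx), decltype(b_mtx), decltype(c_mtx),
//       4, 4, 1,    // MPerThreadSubC, NPerThreadSubC, KPerThreadLoop
//       4, 4, 4, 4, // M/N Level0 and Level1 thread clusters (2*2*4*4... here 4*4*4*4 = 256)
//       4, 4>{};    // ThreadGemmA/BDataPerRead
//   blockwise_gemm.Run(p_a_lds, p_b_lds, p_c_reg); // C += transpose(A) * B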
@@ -1,189 +0,0 @@
#ifndef CK_BLOCKWISE_GENERIC_TENSOR_SLICE_COPY_HPP
#define CK_BLOCKWISE_GENERIC_TENSOR_SLICE_COPY_HPP

#include "common_header.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
#include "tensor_coordinate.hpp"
#include "cluster_descriptor.hpp"
#include "threadwise_generic_tensor_slice_copy.hpp"

namespace ck {

// This blockwise copy allows vector access of src and dst.
// The vector size can be different on src and dst.
// The dimension of vector access can be different for src and dst.
// The dimension access order can be different for src and dst.
// Does a valid-mapping check on src data: reads 0 if src data has an invalid mapping.
// Does a valid-mapping check on dst data: no write if dst data has an invalid mapping.
// BlockSize can be equal to or larger than the ThreadCluster size, which means some
// threads may not do threadwise copy.
template <index_t BlockSize,
          typename BlockSrcDesc,
          typename BlockDstDesc,
          typename BlockSliceLengths,
          typename ThreadSliceLengths,
          typename ThreadClusterLengths,
          typename ThreadClusterArrangeOrder,
          typename SrcDimAccessOrder,
          typename DstDimAccessOrder,
          index_t SrcVectorReadDim,
          index_t DstVectorWriteDim,
          index_t SrcDataPerRead,
          index_t DstDataPerWrite,
          AddressSpace SrcAddressSpace          = AddressSpace::Generic,
          AddressSpace ThreadBufferAddressSpace = AddressSpace::Generic,
          AddressSpace DstAddressSpace          = AddressSpace::Generic,
          InMemoryDataOperation DstInMemOp      = InMemoryDataOperation::Set,
          index_t SrcDataStride                 = 1,
          index_t DstDataStride                 = 1>
struct BlockwiseGenericTensorSliceCopy_v4
{
    static constexpr index_t nDim = BlockSrcDesc::GetNumOfDimension();
    using Index                   = MultiIndex<nDim>;

    __device__ constexpr BlockwiseGenericTensorSliceCopy_v4(const Index& src_block_slice_origin,
                                                            const Index& dst_block_slice_origin)
    {
        static_assert(nDim == BlockSrcDesc::GetNumOfDimension() &&
                          nDim == BlockDstDesc::GetNumOfDimension() &&
                          nDim == BlockSliceLengths::Size() && nDim == ThreadSliceLengths::Size() &&
                          nDim == ThreadClusterLengths::Size() &&
                          nDim == ThreadClusterArrangeOrder::Size() &&
                          nDim == SrcDimAccessOrder::Size() && nDim == DstDimAccessOrder::Size(),
                      "wrong! nDim not consistent");

        static_assert(
            is_same<BlockSliceLengths, decltype(ThreadSliceLengths{} * ThreadClusterLengths{})>{},
            "wrong! threads should be mapped to cover entire slicing window");

        static_assert(BlockSize >= mThreadClusterDesc.GetElementSize(),
                      "wrong! BlockSize too small");

        if(BlockSize == mThreadClusterDesc.GetElementSize() or
           get_thread_local_1d_id() < mThreadClusterDesc.GetElementSize())
        {
            const auto thread_cluster_id =
                mThreadClusterDesc.CalculateClusterIndex(get_thread_local_1d_id());

            const auto thread_data_id_begin = thread_cluster_id * ThreadSliceLengths{};

            mThreadwiseLoad.SetSrcSliceOrigin(src_block_slice_origin + thread_data_id_begin);
            mThreadwiseLoad.SetDstSliceOrigin(make_zero_multi_index<nDim>());

            mThreadwiseStore.SetSrcSliceOrigin(make_zero_multi_index<nDim>());
            mThreadwiseStore.SetDstSliceOrigin(dst_block_slice_origin + thread_data_id_begin);
        }
    }

    __device__ static constexpr index_t GetThreadBufferSize()
    {
        return ThreadBufferDesc::GetElementSpace();
    }

    template <typename BlockSrcData, typename ThreadBufferData>
    __device__ void RunLoadThreadBuffer(const BlockSrcData* p_block_src,
                                        ThreadBufferData* p_thread_buffer) const
    {
        if(BlockSize == mThreadClusterDesc.GetElementSize() or
           get_thread_local_1d_id() < mThreadClusterDesc.GetElementSize())
        {
            mThreadwiseLoad.Run(p_block_src, p_thread_buffer);
        }
    }

    template <typename ThreadBufferData, typename BlockDstData>
    __device__ void RunStoreThreadBuffer(const ThreadBufferData* p_thread_buffer,
                                         BlockDstData* p_block_dst) const
    {
        if(BlockSize == mThreadClusterDesc.GetElementSize() or
           get_thread_local_1d_id() < mThreadClusterDesc.GetElementSize())
        {
            mThreadwiseStore.Run(p_thread_buffer, p_block_dst);
        }
    }

    template <typename BlockSrcData, typename BlockDstData>
    __device__ void Run(const BlockSrcData* p_block_src, BlockDstData* p_block_dst) const
    {
        static_assert(ThreadBufferAddressSpace == AddressSpace::Vgpr,
                      "wrong! This function uses vgpr as its thread buffer. However, "
                      "RunLoadThreadBuffer and RunStoreThreadBuffer have been set to use "
                      "ThreadBufferAddressSpace as their thread buffer, which is not vgpr. "
                      "Behavior may differ");

        BlockSrcData p_thread_buffer[GetThreadBufferSize()];

        if(BlockSize == mThreadClusterDesc.GetElementSize() or
           get_thread_local_1d_id() < mThreadClusterDesc.GetElementSize())
        {
            RunLoadThreadBuffer(p_block_src, p_thread_buffer);

            // if there is type conversion, it's done during store
            RunStoreThreadBuffer(p_thread_buffer, p_block_dst);
        }
    }

    template <typename T, bool PositiveDirection>
    __device__ void
    MoveSrcSliceWindow(const T& step_sizes,
                       integral_constant<bool, PositiveDirection> positive_direction)
    {
        if(BlockSize == mThreadClusterDesc.GetElementSize() or
           get_thread_local_1d_id() < mThreadClusterDesc.GetElementSize())
        {
            mThreadwiseLoad.MoveSrcSliceWindow(step_sizes, positive_direction);
        }
    }

    template <typename T, bool PositiveDirection>
    __device__ void
    MoveDstSliceWindow(const T& step_sizes,
                       integral_constant<bool, PositiveDirection> positive_direction)
    {
        if(BlockSize == mThreadClusterDesc.GetElementSize() or
           get_thread_local_1d_id() < mThreadClusterDesc.GetElementSize())
        {
            mThreadwiseStore.MoveDstSliceWindow(step_sizes, positive_direction);
        }
    }

    private:
    using ThreadBufferDesc = decltype(make_native_tensor_descriptor_packed(ThreadSliceLengths{}));

    using ThreadwiseLoad = ThreadwiseGenericTensorSliceCopy_v4r2<BlockSrcDesc,
                                                                 ThreadBufferDesc,
                                                                 ThreadSliceLengths,
                                                                 SrcDimAccessOrder,
                                                                 SrcVectorReadDim,
                                                                 SrcDataPerRead,
                                                                 1,
                                                                 SrcAddressSpace,
                                                                 ThreadBufferAddressSpace,
                                                                 InMemoryDataOperation::Set,
                                                                 SrcDataStride,
                                                                 1>;

    using ThreadwiseStore = ThreadwiseGenericTensorSliceCopy_v4r2<ThreadBufferDesc,
                                                                  BlockDstDesc,
                                                                  ThreadSliceLengths,
                                                                  DstDimAccessOrder,
                                                                  DstVectorWriteDim,
                                                                  1,
                                                                  DstDataPerWrite,
                                                                  ThreadBufferAddressSpace,
                                                                  DstAddressSpace,
                                                                  DstInMemOp,
                                                                  1,
                                                                  DstDataStride>;

    static constexpr auto mThreadClusterDesc =
        make_cluster_descriptor(ThreadClusterLengths{}, ThreadClusterArrangeOrder{});

    ThreadwiseLoad mThreadwiseLoad;
    ThreadwiseStore mThreadwiseStore;
};

} // namespace ck

#endif
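
// Editor's sketch (hypothetical usage, not in the original source): a gridwise
// kernel typically owns one of these copies per operand and alternates window
// moves with split load/store phases so the copy can overlap with math, e.g.
//
//   auto a_copy = BlockwiseGenericTensorSliceCopy_v4<...>(
//       make_multi_index(0, m_block_data_on_global), make_multi_index(0, 0));
//   a_copy.Run(p_a_global, p_a_lds);                          // prologue: fill LDS
//   a_copy.MoveSrcSliceWindow(Sequence<KPerBlock, 0>{}, True);
//   a_copy.RunLoadThreadBuffer(p_a_global, p_a_reg);          // global -> vgpr
//   /* ... blockwise GEMM on the resident LDS buffer ... */
//   a_copy.RunStoreThreadBuffer(p_a_reg, p_a_lds_other);      // vgpr -> other LDS buffer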
@@ -1,785 +0,0 @@
#ifndef CK_GRIDWISE_GEMM_HPP
#define CK_GRIDWISE_GEMM_HPP

#include "common_header.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
#include "ConstantMatrixDescriptor.hpp"
#include "blockwise_generic_tensor_slice_copy.hpp"
#include "threadwise_generic_tensor_slice_copy.hpp"
#include "blockwise_gemm.hpp"

namespace ck {

template <index_t GridSize,
          index_t BlockSize,
          typename Float,
          typename AccFloat,
          typename AGlobalDesc,
          typename BGlobalDesc,
          typename CGlobalDesc,
          InMemoryDataOperation CGlobalMemoryDataOperation,
          index_t MPerBlock,
          index_t NPerBlock,
          index_t KPerBlock,
          index_t MPerThread,
          index_t NPerThread,
          index_t KPerThread,
          index_t MLevel0Cluster,
          index_t NLevel0Cluster,
          index_t MLevel1Cluster,
          index_t NLevel1Cluster,
          index_t ThreadGemmAThreadCopySrcDataPerRead_M,
          index_t ThreadGemmBThreadCopySrcDataPerRead_N,
          typename ABlockCopyThreadSliceLengths_K_M,
          typename ABlockCopyThreadClusterLengths_K_M,
          typename ABlockCopyThreadClusterArrangeOrder,
          typename ABlockCopySrcAccessOrder,
          index_t ABlockCopySrcVectorReadDim,
          index_t ABlockCopySrcDataPerRead,
          index_t ABlockCopyDstDataPerWrite_M,
          typename BBlockCopyThreadSliceLengths_K_N,
          typename BBlockCopyThreadClusterLengths_K_N,
          typename BBlockCopyThreadClusterArrangeOrder,
          typename BBlockCopySrcAccessOrder,
          index_t BBlockCopySrcVectorReadDim,
          index_t BBlockCopySrcDataPerRead,
          index_t BBlockCopyDstDataPerWrite_N,
          typename CThreadCopySrcDstAccessOrder,
          index_t CThreadCopySrcDstVectorReadWriteDim,
          index_t CThreadCopyDstDataPerWrite>
struct GridwiseGemmTransposedANormalBNormalC_v1
{
    __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
    {
        constexpr index_t max_lds_align = math::lcm(ABlockCopyDstDataPerWrite_M,
                                                    BBlockCopyDstDataPerWrite_N,
                                                    ThreadGemmAThreadCopySrcDataPerRead_M,
                                                    ThreadGemmBThreadCopySrcDataPerRead_N);

        // A matrix in LDS memory, dst of blockwise copy
        // be careful of LDS alignment
        constexpr auto a_k_m_block_desc = make_native_tensor_descriptor_aligned(
            Sequence<KPerBlock, MPerBlock>{}, Number<max_lds_align>{});

        // B matrix in LDS memory, dst of blockwise copy
        // be careful of LDS alignment
        constexpr auto b_k_n_block_desc = make_native_tensor_descriptor_aligned(
            Sequence<KPerBlock, NPerBlock>{}, Number<max_lds_align>{});

        // LDS allocation for A and B: be careful of alignment
        constexpr index_t a_block_space_size =
            math::integer_least_multiple(a_k_m_block_desc.GetElementSpace(), max_lds_align);

        constexpr index_t b_block_space_size =
            math::integer_least_multiple(b_k_n_block_desc.GetElementSpace(), max_lds_align);

        return 2 * (a_block_space_size + b_block_space_size) * sizeof(Float);
    }
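
    // Editor's note (worked example, not in the original source): the factor 2
    // above accounts for double buffering. With Float = float, KPerBlock = 8
    // and MPerBlock = NPerBlock = 128 (and max_lds_align dividing 8 * 128),
    // this is 2 * (1024 + 1024) * 4 bytes = 16 KiB of LDS per workgroup.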

    __device__ void Run(const Float* __restrict__ p_a_global,
                        const Float* __restrict__ p_b_global,
                        Float* __restrict__ p_c_global,
                        Float* __restrict__ p_shared_block) const
    {
        constexpr auto True = integral_constant<bool, true>{};

        constexpr auto a_k_m_global_desc = AGlobalDesc{};
        constexpr auto b_k_n_global_desc = BGlobalDesc{};
        constexpr auto c_m_n_global_desc = CGlobalDesc{};

        constexpr auto K = a_k_m_global_desc.GetLengths()[0];
        constexpr auto M = a_k_m_global_desc.GetLengths()[1];
        constexpr auto N = b_k_n_global_desc.GetLengths()[1];

        // don't do anything if K == 0
        if(K == 0)
        {
            return;
        }

        // lds max alignment
        constexpr index_t max_lds_align = math::lcm(ABlockCopyDstDataPerWrite_M,
                                                    BBlockCopyDstDataPerWrite_N,
                                                    ThreadGemmAThreadCopySrcDataPerRead_M,
                                                    ThreadGemmBThreadCopySrcDataPerRead_N);

        // divide block work by [M, N]
        static_assert(M % MPerBlock == 0 && N % NPerBlock == 0 && K % KPerBlock == 0,
                      "wrong! cannot divide work evenly among blocks");

        constexpr index_t MBlockWork = M / MPerBlock;
        constexpr index_t NBlockWork = N / NPerBlock;

        constexpr auto block_work_desc =
            make_cluster_descriptor(Sequence<MBlockWork, NBlockWork>{});

        const auto block_work_id = block_work_desc.CalculateClusterIndex(get_block_1d_id());

        const index_t m_block_data_on_global = block_work_id[Number<0>{}] * MPerBlock;
        const index_t n_block_data_on_global = block_work_id[Number<1>{}] * NPerBlock;

        // A matrix in LDS memory, dst of blockwise copy
        // be careful of LDS alignment
        constexpr auto a_k_m_block_desc = make_native_tensor_descriptor_aligned(
            Sequence<KPerBlock, MPerBlock>{}, Number<max_lds_align>{});

        // A matrix blockwise copy
        auto a_blockwise_copy =
            BlockwiseGenericTensorSliceCopy_v4<BlockSize,
                                               decltype(a_k_m_global_desc),
                                               decltype(a_k_m_block_desc),
                                               decltype(a_k_m_block_desc.GetLengths()),
                                               ABlockCopyThreadSliceLengths_K_M,
                                               ABlockCopyThreadClusterLengths_K_M,
                                               ABlockCopyThreadClusterArrangeOrder,
                                               ABlockCopySrcAccessOrder,
                                               Sequence<0, 1>,
                                               ABlockCopySrcVectorReadDim,
                                               1,
                                               ABlockCopySrcDataPerRead,
                                               ABlockCopyDstDataPerWrite_M,
                                               AddressSpace::Global,
                                               AddressSpace::Vgpr,
                                               AddressSpace::Lds,
                                               InMemoryDataOperation::Set>(
                make_multi_index(0, m_block_data_on_global), make_multi_index(0, 0));

        // B matrix in LDS memory, dst of blockwise copy
        // be careful of LDS alignment
        constexpr auto b_k_n_block_desc = make_native_tensor_descriptor_aligned(
            Sequence<KPerBlock, NPerBlock>{}, Number<max_lds_align>{});

        // B matrix blockwise copy
        auto b_blockwise_copy =
            BlockwiseGenericTensorSliceCopy_v4<BlockSize,
                                               decltype(b_k_n_global_desc),
                                               decltype(b_k_n_block_desc),
                                               decltype(b_k_n_block_desc.GetLengths()),
                                               BBlockCopyThreadSliceLengths_K_N,
                                               BBlockCopyThreadClusterLengths_K_N,
                                               BBlockCopyThreadClusterArrangeOrder,
                                               BBlockCopySrcAccessOrder,
                                               Sequence<0, 1>,
                                               BBlockCopySrcVectorReadDim,
                                               1,
                                               BBlockCopySrcDataPerRead,
                                               BBlockCopyDstDataPerWrite_N,
                                               AddressSpace::Global,
                                               AddressSpace::Vgpr,
                                               AddressSpace::Lds,
                                               InMemoryDataOperation::Set>(
                make_multi_index(0, n_block_data_on_global), make_multi_index(0, 0));

        // GEMM definition
        // c_mtx += transpose(a_mtx) * b_mtx
        // a_mtx[KPerBlock, MPerBlock] is in LDS
        // b_mtx[KPerBlock, NPerBlock] is in LDS
        // c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in
        // register
        constexpr auto a_k_m_block_mtx_desc = make_ConstantMatrixDescriptor(a_k_m_block_desc);
        constexpr auto b_k_n_block_mtx_desc = make_ConstantMatrixDescriptor(b_k_n_block_desc);

        // sanity check
        static_assert(MPerBlock % (MPerThread * MLevel0Cluster * MLevel1Cluster) == 0 &&
                          NPerBlock % (NPerThread * NLevel0Cluster * NLevel1Cluster) == 0,
                      "wrong!");

        constexpr index_t GemmMRepeat = MPerBlock / (MPerThread * MLevel0Cluster * MLevel1Cluster);
        constexpr index_t GemmNRepeat = NPerBlock / (NPerThread * NLevel0Cluster * NLevel1Cluster);

        // c_thread_mtx definition: this is a mess
        // TODO: more elegant way of defining c_thread_mtx
        constexpr auto c_m0m1_n0n1_thread_mtx_desc = make_ConstantMatrixDescriptor_packed(
            Number<GemmMRepeat * MPerThread>{}, Number<GemmNRepeat * NPerThread>{});

        const auto blockwise_gemm = BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2<
            BlockSize,
            decltype(a_k_m_block_mtx_desc),
            decltype(b_k_n_block_mtx_desc),
            decltype(c_m0m1_n0n1_thread_mtx_desc),
            MPerThread,
            NPerThread,
            KPerThread,
            MLevel0Cluster,
            NLevel0Cluster,
            MLevel1Cluster,
            NLevel1Cluster,
            ThreadGemmAThreadCopySrcDataPerRead_M,
            ThreadGemmBThreadCopySrcDataPerRead_N>{};

        // LDS allocation for A and B: be careful of alignment
        constexpr index_t a_block_space_size =
            math::integer_least_multiple(a_k_m_block_desc.GetElementSpace(), max_lds_align);

        constexpr index_t b_block_space_size =
            math::integer_least_multiple(b_k_n_block_desc.GetElementSpace(), max_lds_align);

        Float* p_a_block_double = p_shared_block;
        Float* p_b_block_double = p_shared_block + 2 * a_block_space_size;

        // register allocation for output
        AccFloat p_c_thread[c_m0m1_n0n1_thread_mtx_desc.GetElementSpace()];

        // zero out threadwise output
        threadwise_matrix_set_zero(c_m0m1_n0n1_thread_mtx_desc, p_c_thread);

        // LDS double buffer: preload data into LDS
        {
            a_blockwise_copy.Run(p_a_global, p_a_block_double);
            b_blockwise_copy.Run(p_b_global, p_b_block_double);
        }

        constexpr auto a_block_slice_copy_step = Sequence<KPerBlock, 0>{};
        constexpr auto b_block_slice_copy_step = Sequence<KPerBlock, 0>{};

        Float* p_a_block_even = p_a_block_double;
        Float* p_b_block_even = p_b_block_double;

        Float* p_a_block_odd = p_a_block_double + a_block_space_size;
        Float* p_b_block_odd = p_b_block_double + b_block_space_size;

        // LDS double buffer: main body
        for(index_t k_block_data_begin = 0; k_block_data_begin < K - 2 * KPerBlock;
            k_block_data_begin += 2 * KPerBlock)
        {
            Float p_a_thread_buffer[a_blockwise_copy.GetThreadBufferSize()];
            Float p_b_thread_buffer[b_blockwise_copy.GetThreadBufferSize()];

            // even iteration
            a_blockwise_copy.MoveSrcSliceWindow(a_block_slice_copy_step, True);
            b_blockwise_copy.MoveSrcSliceWindow(b_block_slice_copy_step, True);

            __syncthreads();

            // LDS double buffer: load next data from device mem
            a_blockwise_copy.RunLoadThreadBuffer(p_a_global, p_a_thread_buffer);
            b_blockwise_copy.RunLoadThreadBuffer(p_b_global, p_b_thread_buffer);

            // LDS double buffer: GEMM on current data
            blockwise_gemm.Run(p_a_block_even, p_b_block_even, p_c_thread);

            // LDS double buffer: store next data to LDS
            a_blockwise_copy.RunStoreThreadBuffer(p_a_thread_buffer, p_a_block_odd);
            b_blockwise_copy.RunStoreThreadBuffer(p_b_thread_buffer, p_b_block_odd);

            // odd iteration
            a_blockwise_copy.MoveSrcSliceWindow(a_block_slice_copy_step, True);
            b_blockwise_copy.MoveSrcSliceWindow(b_block_slice_copy_step, True);

            __syncthreads();

            // LDS double buffer: load next data from device mem
            a_blockwise_copy.RunLoadThreadBuffer(p_a_global, p_a_thread_buffer);
            b_blockwise_copy.RunLoadThreadBuffer(p_b_global, p_b_thread_buffer);

            // LDS double buffer: GEMM on current data
            blockwise_gemm.Run(p_a_block_odd, p_b_block_odd, p_c_thread);

            // LDS double buffer: store next data to LDS
            a_blockwise_copy.RunStoreThreadBuffer(p_a_thread_buffer, p_a_block_even);
            b_blockwise_copy.RunStoreThreadBuffer(p_b_thread_buffer, p_b_block_even);
        }

        // LDS double buffer: tail
        {
            constexpr bool has_two_iteration_left = (K % (2 * KPerBlock) == 0);

            if(has_two_iteration_left) // if there are 2 iterations left
            {
                Float p_a_thread_buffer[a_blockwise_copy.GetThreadBufferSize()];
                Float p_b_thread_buffer[b_blockwise_copy.GetThreadBufferSize()];

                a_blockwise_copy.MoveSrcSliceWindow(a_block_slice_copy_step, True);
                b_blockwise_copy.MoveSrcSliceWindow(b_block_slice_copy_step, True);

                __syncthreads();

                // LDS double buffer: load last data from device mem
                a_blockwise_copy.RunLoadThreadBuffer(p_a_global, p_a_thread_buffer);
                b_blockwise_copy.RunLoadThreadBuffer(p_b_global, p_b_thread_buffer);

                // LDS double buffer: GEMM on 2nd-last data
                blockwise_gemm.Run(p_a_block_double, p_b_block_double, p_c_thread);

                // LDS double buffer: store last data to LDS
                a_blockwise_copy.RunStoreThreadBuffer(p_a_thread_buffer,
                                                      p_a_block_double + a_block_space_size);
                b_blockwise_copy.RunStoreThreadBuffer(p_b_thread_buffer,
                                                      p_b_block_double + b_block_space_size);

                __syncthreads();

                // LDS double buffer: GEMM on last data
                blockwise_gemm.Run(p_a_block_double + a_block_space_size,
                                   p_b_block_double + b_block_space_size,
                                   p_c_thread);
            }
            else // if there is 1 iteration left
            {
                __syncthreads();

                // LDS double buffer: GEMM on last data
                blockwise_gemm.Run(p_a_block_double, p_b_block_double, p_c_thread);
            }
        }

        // copy output: register to global memory
        {
            constexpr index_t M1 = MPerThread * MLevel0Cluster * MLevel1Cluster;
            constexpr index_t M0 = M / M1;

            constexpr index_t N1 = NPerThread * NLevel0Cluster * NLevel1Cluster;
            constexpr index_t N0 = N / N1;

            // define input tensor descriptor for threadwise copy
            // thread input tensor, src of threadwise copy
            constexpr auto c_m0_m1_n0_n1_thread_desc = make_native_tensor_descriptor_packed(
                Sequence<GemmMRepeat, MPerThread, GemmNRepeat, NPerThread>{});

            constexpr auto c_m0_m1_n0_n1_global_desc = transform_tensor_descriptor(
                c_m_n_global_desc,
                make_tuple(UnMerge<Sequence<M0, M1>>{}, UnMerge<Sequence<N0, N1>>{}),
                make_tuple(Sequence<0>{}, Sequence<1>{}),
                make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{}));

            // calculate origin of thread input tensor on global memory
            // blockwise GEMM c matrix starting index
            const auto c_thread_mtx_on_block =
                blockwise_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id());

            const index_t m_thread_data_on_global =
                m_block_data_on_global + c_thread_mtx_on_block.row;

            const index_t n_thread_data_on_global =
                n_block_data_on_global + c_thread_mtx_on_block.col;

            ThreadwiseGenericTensorSliceCopy_v4r2<decltype(c_m0_m1_n0_n1_thread_desc),
                                                  decltype(c_m0_m1_n0_n1_global_desc),
                                                  decltype(c_m0_m1_n0_n1_thread_desc.GetLengths()),
                                                  CThreadCopySrcDstAccessOrder,
                                                  CThreadCopySrcDstVectorReadWriteDim,
                                                  1,
                                                  CThreadCopyDstDataPerWrite,
                                                  AddressSpace::Vgpr,
                                                  AddressSpace::Global,
                                                  CGlobalMemoryDataOperation>(
                make_multi_index(0, 0, 0, 0),
                make_multi_index(m_thread_data_on_global / M1,
                                 m_thread_data_on_global % M1,
                                 n_thread_data_on_global / N1,
                                 n_thread_data_on_global % N1))
                .Run(p_c_thread, p_c_global);
        }
    }

    __device__ void Run(const Float* __restrict__ p_a_global,
                        const Float* __restrict__ p_b_global,
                        Float* __restrict__ p_c_global) const
    {
        constexpr index_t shared_block_size = GetSharedMemoryNumberOfByte() / sizeof(Float);

        __shared__ Float p_shared_block[shared_block_size];

        Run(p_a_global, p_b_global, p_c_global, p_shared_block);
    }
};
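
// Editor's sketch (hypothetical wrapper, not in the original source): this
// struct is device-side only; a driver typically wraps it in a trivial
// __global__ kernel and launches it with the GridSize/BlockSize it was
// instantiated with, e.g.
//
//   template <typename GridwiseGemm, typename Float>
//   __global__ void run_gridwise_gemm(const Float* p_a, const Float* p_b, Float* p_c)
//   {
//       GridwiseGemm{}.Run(p_a, p_b, p_c);
//   }
//
//   hipLaunchKernelGGL(run_gridwise_gemm<Gemm, float>,
//                      dim3(GridSize), dim3(BlockSize), 0, nullptr, p_a, p_b, p_c);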
|
||||
|
||||
template <index_t GridSize,
|
||||
index_t BlockSize,
|
||||
typename Float,
|
||||
typename AccFloat,
|
||||
typename AGlobalDesc,
|
||||
typename BGlobalDesc,
|
||||
typename CGlobalDesc,
|
||||
InMemoryDataOperation CGlobalMemoryDataOperation,
|
||||
index_t MPerBlock,
|
||||
index_t NPerBlock,
|
||||
index_t KPerBlock,
|
||||
index_t MPerThread,
|
||||
index_t NPerThread,
|
||||
index_t KPerThread,
|
||||
index_t MLevel0Cluster,
|
||||
index_t NLevel0Cluster,
|
||||
index_t MLevel1Cluster,
|
||||
index_t NLevel1Cluster,
|
||||
index_t ThreadGemmAThreadCopySrcDataPerRead_M,
|
||||
index_t ThreadGemmBThreadCopySrcDataPerRead_N,
|
||||
typename ABlockCopyThreadSliceLengths_K0_K1_K2_M,
|
||||
typename ABlockCopyThreadClusterLengths_K0_K1_K2_M,
|
||||
typename ABlockCopyThreadClusterArrangeOrder,
|
||||
typename ABlockCopySrcAccessOrder,
|
||||
index_t ABlockCopySrcVectorReadDim,
|
||||
index_t ABlockCopySrcDataPerRead,
|
||||
index_t ABlockCopyDstDataPerWrite_M,
|
||||
typename BBlockCopyThreadSliceLengths_K0_K1_K2_N,
|
||||
typename BBlockCopyThreadClusterLengths_K0_K1_K2_N,
|
||||
typename BBlockCopyThreadClusterArrangeOrder,
|
||||
typename BBlockCopySrcAccessOrder,
|
||||
index_t BBlockCopySrcVectorReadDim,
|
||||
index_t BBlockCopySrcDataPerRead,
|
||||
index_t BBlockCopyDstDataPerWrite_N,
|
||||
typename CThreadCopySrcDstAccessOrder,
|
||||
index_t CThreadCopySrcDstVectorReadWriteDim,
|
||||
index_t CThreadCopyDstDataPerWrite>
|
||||
struct GridwiseGemmTransposedANormalBNormalC_v2
|
||||
{
|
||||
__host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
|
||||
{
|
||||
constexpr index_t max_lds_align = math::lcm(ABlockCopyDstDataPerWrite_M,
|
||||
BBlockCopyDstDataPerWrite_N,
|
||||
ThreadGemmAThreadCopySrcDataPerRead_M,
|
||||
ThreadGemmBThreadCopySrcDataPerRead_N);
|
||||
|
||||
// A matrix in LDS memory, dst of blockwise copy
|
||||
// be careful of LDS alignment
|
||||
constexpr auto a_k_m_block_desc = make_native_tensor_descriptor_aligned(
|
||||
Sequence<KPerBlock, MPerBlock>{}, Number<max_lds_align>{});
|
||||
|
||||
// B matrix in LDS memory, dst of blockwise copy
|
||||
// be careful of LDS alignment
|
||||
constexpr auto b_k_n_block_desc = make_native_tensor_descriptor_aligned(
|
||||
Sequence<KPerBlock, NPerBlock>{}, Number<max_lds_align>{});
|
||||
|
||||
// LDS allocation for A and B: be careful of alignment
|
||||
constexpr index_t a_block_space_size =
|
||||
math::integer_least_multiple(a_k_m_block_desc.GetElementSpace(), max_lds_align);
|
||||
|
||||
constexpr index_t b_block_space_size =
|
||||
math::integer_least_multiple(b_k_n_block_desc.GetElementSpace(), max_lds_align);
|
||||
|
||||
return 2 * (a_block_space_size + b_block_space_size) * sizeof(Float);
|
||||
}
|
||||
|
||||
__device__ void Run(const Float* __restrict__ p_a_global,
|
||||
const Float* __restrict__ p_b_global,
|
||||
Float* __restrict__ p_c_global,
|
||||
Float* __restrict__ p_shared_block) const
|
||||
{
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
constexpr auto I3 = Number<3>{};
|
||||
|
||||
constexpr auto True = integral_constant<bool, true>{};
|
||||
constexpr auto False = integral_constant<bool, false>{};
|
||||
|
||||
constexpr auto a_k0_k1_k2_m_global_desc = AGlobalDesc{};
|
||||
constexpr auto b_k0_k1_k2_n_global_desc = BGlobalDesc{};
|
||||
constexpr auto c_m_n_global_desc = CGlobalDesc{};
|
||||
|
||||
constexpr auto K0 = a_k0_k1_k2_m_global_desc.GetLengths()[I0];
|
||||
constexpr auto K1 = a_k0_k1_k2_m_global_desc.GetLengths()[I1];
|
||||
constexpr auto K = a_k0_k1_k2_m_global_desc.GetLengths()[I2];
|
||||
constexpr auto M = c_m_n_global_desc.GetLengths()[I0];
|
||||
constexpr auto N = c_m_n_global_desc.GetLengths()[I1];
|
||||
|
||||
// don't do anything if K == 0
|
||||
if(K == 0)
|
||||
{
|
||||
return;
|
||||
}
|
||||
|
||||
// lds max alignment
|
||||
constexpr index_t max_lds_align = math::lcm(ABlockCopyDstDataPerWrite_M,
|
||||
BBlockCopyDstDataPerWrite_N,
|
||||
ThreadGemmAThreadCopySrcDataPerRead_M,
|
||||
ThreadGemmBThreadCopySrcDataPerRead_N);
|
||||
|
||||
// divide block work by [M, N]
|
||||
static_assert(M % MPerBlock == 0 && N % NPerBlock == 0 && K % KPerBlock == 0,
|
||||
"wrong! cannot divide work evenly among block");
|
||||
|
||||
constexpr index_t MBlockWork = M / MPerBlock;
|
||||
constexpr index_t NBlockWork = N / NPerBlock;
|
||||
|
||||
constexpr auto block_work_desc =
|
||||
make_cluster_descriptor(Sequence<MBlockWork, NBlockWork>{});
|
||||
|
||||
const auto block_work_id = block_work_desc.CalculateClusterIndex(get_block_1d_id());
|
||||
|
||||
const index_t m_block_data_on_global = block_work_id[I0] * MPerBlock;
|
||||
const index_t n_block_data_on_global = block_work_id[I1] * NPerBlock;
|
||||
|
||||
// A matrix in LDS memory, dst of blockwise copy
|
||||
// be careful of LDS alignment
|
||||
constexpr auto a_k0_k1_k2_m_block_desc = make_native_tensor_descriptor_aligned(
|
||||
Sequence<1, 1, KPerBlock, MPerBlock>{}, Number<max_lds_align>{});
|
||||
|
||||
// A matrix blockwise copy
|
||||
auto a_blockwise_copy =
|
||||
BlockwiseGenericTensorSliceCopy_v4<BlockSize,
|
||||
decltype(a_k0_k1_k2_m_global_desc),
|
||||
decltype(a_k0_k1_k2_m_block_desc),
|
||||
decltype(a_k0_k1_k2_m_block_desc.GetLengths()),
|
||||
ABlockCopyThreadSliceLengths_K0_K1_K2_M,
|
||||
ABlockCopyThreadClusterLengths_K0_K1_K2_M,
|
||||
ABlockCopyThreadClusterArrangeOrder,
|
||||
ABlockCopySrcAccessOrder,
|
||||
Sequence<0, 1, 2, 3>,
|
||||
ABlockCopySrcVectorReadDim,
|
||||
3,
|
||||
ABlockCopySrcDataPerRead,
|
||||
ABlockCopyDstDataPerWrite_M,
|
||||
AddressSpace::Global,
|
||||
AddressSpace::Vgpr,
|
||||
AddressSpace::Lds,
|
||||
InMemoryDataOperation::Set>(
|
||||
make_multi_index(0, 0, 0, m_block_data_on_global), make_multi_index(0, 0, 0, 0));
|
||||
|
||||
// B matrix in LDS memory, dst of blockwise copy
|
||||
// be careful of LDS alignment
|
||||
constexpr auto b_k0_k1_k2_n_block_desc = make_native_tensor_descriptor_aligned(
|
||||
Sequence<1, 1, KPerBlock, NPerBlock>{}, Number<max_lds_align>{});
|
||||
|
||||
// B matrix blockwise copy
|
||||
auto b_blockwise_copy =
|
||||
BlockwiseGenericTensorSliceCopy_v4<BlockSize,
|
||||
decltype(b_k0_k1_k2_n_global_desc),
|
||||
decltype(b_k0_k1_k2_n_block_desc),
|
||||
decltype(b_k0_k1_k2_n_block_desc.GetLengths()),
|
||||
BBlockCopyThreadSliceLengths_K0_K1_K2_N,
|
||||
BBlockCopyThreadClusterLengths_K0_K1_K2_N,
|
||||
BBlockCopyThreadClusterArrangeOrder,
|
||||
BBlockCopySrcAccessOrder,
|
||||
Sequence<0, 1, 2, 3>,
|
||||
BBlockCopySrcVectorReadDim,
|
||||
3,
|
||||
BBlockCopySrcDataPerRead,
|
||||
BBlockCopyDstDataPerWrite_N,
|
||||
AddressSpace::Global,
|
||||
AddressSpace::Vgpr,
|
||||
AddressSpace::Lds,
|
||||
InMemoryDataOperation::Set>(
|
||||
make_multi_index(0, 0, 0, n_block_data_on_global), make_multi_index(0, 0, 0, 0));
|
||||
|
||||
// GEMM definition
|
||||
// c_mtx += transpose(a_mtx) * b_mtx
|
||||
// a_mtx[KPerBlock, MPerBlock] is in LDS
|
||||
// b_mtx[KPerBlocl, NPerBlock] is in LDS
|
||||
// c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in
|
||||
// register
|
||||
constexpr auto a_k_m_block_mtx_desc = make_ConstantMatrixDescriptor(
|
||||
unfold_tensor_descriptor(a_k0_k1_k2_m_block_desc, I0, I2));
|
||||
constexpr auto b_k_n_block_mtx_desc = make_ConstantMatrixDescriptor(
|
||||
unfold_tensor_descriptor(b_k0_k1_k2_n_block_desc, I0, I2));
|
||||
|
||||
// sanity check
|
||||
static_assert(MPerBlock % (MPerThread * MLevel0Cluster * MLevel1Cluster) == 0 &&
|
||||
NPerBlock % (NPerThread * NLevel0Cluster * NLevel1Cluster) == 0,
|
||||
"wrong!");
|
||||
|
||||
constexpr index_t GemmMRepeat = MPerBlock / (MPerThread * MLevel0Cluster * MLevel1Cluster);
|
||||
constexpr index_t GemmNRepeat = NPerBlock / (NPerThread * NLevel0Cluster * NLevel1Cluster);

        // c_thread_mtx definition: this is a mess
        // TODO: find a more elegant way of defining c_thread_mtx
        constexpr auto c_m0m1_n0n1_thread_mtx_desc = make_ConstantMatrixDescriptor_packed(
            Number<GemmMRepeat * MPerThread>{}, Number<GemmNRepeat * NPerThread>{});

        const auto blockwise_gemm = BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2<
            BlockSize,
            decltype(a_k_m_block_mtx_desc),
            decltype(b_k_n_block_mtx_desc),
            decltype(c_m0m1_n0n1_thread_mtx_desc),
            MPerThread,
            NPerThread,
            KPerThread,
            MLevel0Cluster,
            NLevel0Cluster,
            MLevel1Cluster,
            NLevel1Cluster,
            ThreadGemmAThreadCopySrcDataPerRead_M,
            ThreadGemmBThreadCopySrcDataPerRead_N>{};

        // LDS allocation for A and B: be careful of alignment
        constexpr index_t a_block_space_size =
            math::integer_least_multiple(a_k0_k1_k2_m_block_desc.GetElementSpace(), max_lds_align);

        constexpr index_t b_block_space_size =
            math::integer_least_multiple(b_k0_k1_k2_n_block_desc.GetElementSpace(), max_lds_align);

        Float* p_a_block_double = p_shared_block;
        Float* p_b_block_double = p_shared_block + 2 * a_block_space_size;
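        // LDS layout of the double buffer: A's ping and pong copies sit back to back in
        // [0, 2 * a_block_space_size), and B's two copies follow at the same relative
        // offsets from p_b_block_double.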

        // register allocation for output
        AccFloat p_c_thread[c_m0m1_n0n1_thread_mtx_desc.GetElementSpace()];

        // zero out threadwise output
        threadwise_matrix_set_zero(c_m0m1_n0n1_thread_mtx_desc, p_c_thread);

        for(index_t k0 = 0; k0 < K0; ++k0)
        {
            for(index_t k1 = 0; k1 < K1; ++k1)
            {
                // LDS double buffer: preload data into LDS
                {
                    a_blockwise_copy.Run(p_a_global, p_a_block_double);
                    b_blockwise_copy.Run(p_b_global, p_b_block_double);
                }

                constexpr auto a_block_slice_copy_step = Sequence<0, 0, KPerBlock, 0>{};
                constexpr auto b_block_slice_copy_step = Sequence<0, 0, KPerBlock, 0>{};
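                // Ping-pong scheme of the main body below: on even iterations the GEMM consumes
                // buffer 0 while the next K-slice is prefetched into buffer 1; on odd iterations
                // the roles swap. Each iteration advances the source window by KPerBlock.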

                // LDS double buffer: main body
                for(index_t k_block_data_begin = 0; k_block_data_begin + 2 * KPerBlock < K;
                    k_block_data_begin += 2 * KPerBlock)
                {
#pragma unroll
                    for(index_t iloop = 0; iloop < 2; ++iloop)
                    {
                        const bool even_loop = (iloop % 2 == 0);

                        Float* p_a_block_now =
                            even_loop ? p_a_block_double : p_a_block_double + a_block_space_size;
                        Float* p_b_block_now =
                            even_loop ? p_b_block_double : p_b_block_double + b_block_space_size;

                        Float* p_a_block_next =
                            even_loop ? p_a_block_double + a_block_space_size : p_a_block_double;
                        Float* p_b_block_next =
                            even_loop ? p_b_block_double + b_block_space_size : p_b_block_double;

                        Float p_a_thread_buffer[a_blockwise_copy.GetThreadBufferSize()];
                        Float p_b_thread_buffer[b_blockwise_copy.GetThreadBufferSize()];

                        a_blockwise_copy.MoveSrcSliceWindow(a_block_slice_copy_step, True);
                        b_blockwise_copy.MoveSrcSliceWindow(b_block_slice_copy_step, True);

                        __syncthreads();

                        // LDS double buffer: load next data from device mem
                        a_blockwise_copy.RunLoadThreadBuffer(p_a_global, p_a_thread_buffer);
                        b_blockwise_copy.RunLoadThreadBuffer(p_b_global, p_b_thread_buffer);

                        // LDS double buffer: GEMM on current data
                        blockwise_gemm.Run(p_a_block_now, p_b_block_now, p_c_thread);

                        // LDS double buffer: store next data to LDS
                        a_blockwise_copy.RunStoreThreadBuffer(p_a_thread_buffer, p_a_block_next);
                        b_blockwise_copy.RunStoreThreadBuffer(p_b_thread_buffer, p_b_block_next);
                    }
                }

                // LDS double buffer: tail
                {
                    constexpr bool has_two_iteration_left = (K % (2 * KPerBlock) == 0);
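                    // The main body stops while at least 2 * KPerBlock of K remains, so the tail
                    // handles either two KPerBlock slices (when K is a multiple of 2 * KPerBlock)
                    // or a single one.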

                    if(has_two_iteration_left) // 2 iterations left
                    {
                        Float p_a_thread_buffer[a_blockwise_copy.GetThreadBufferSize()];
                        Float p_b_thread_buffer[b_blockwise_copy.GetThreadBufferSize()];

                        a_blockwise_copy.MoveSrcSliceWindow(a_block_slice_copy_step, True);
                        b_blockwise_copy.MoveSrcSliceWindow(b_block_slice_copy_step, True);

                        __syncthreads();

                        // LDS double buffer: load last data from device mem
                        a_blockwise_copy.RunLoadThreadBuffer(p_a_global, p_a_thread_buffer);
                        b_blockwise_copy.RunLoadThreadBuffer(p_b_global, p_b_thread_buffer);

                        // LDS double buffer: GEMM on 2nd-last data
                        blockwise_gemm.Run(p_a_block_double, p_b_block_double, p_c_thread);

                        // LDS double buffer: store last data to LDS
                        a_blockwise_copy.RunStoreThreadBuffer(
                            p_a_thread_buffer, p_a_block_double + a_block_space_size);
                        b_blockwise_copy.RunStoreThreadBuffer(
                            p_b_thread_buffer, p_b_block_double + b_block_space_size);

                        __syncthreads();

                        // LDS double buffer: GEMM on last data
                        blockwise_gemm.Run(p_a_block_double + a_block_space_size,
                                           p_b_block_double + b_block_space_size,
                                           p_c_thread);
                    }
                    else // 1 iteration left
                    {
                        __syncthreads();

                        // LDS double buffer: GEMM on last data
                        blockwise_gemm.Run(p_a_block_double, p_b_block_double, p_c_thread);
                    }
                }

                // reset slice window on K2 dimension, then move forward on K1 dimension
                a_blockwise_copy.MoveSrcSliceWindow(Sequence<0, 0, K - KPerBlock, 0>{}, False);
                b_blockwise_copy.MoveSrcSliceWindow(Sequence<0, 0, K - KPerBlock, 0>{}, False);

                a_blockwise_copy.MoveSrcSliceWindow(Sequence<0, 1, 0, 0>{}, True);
                b_blockwise_copy.MoveSrcSliceWindow(Sequence<0, 1, 0, 0>{}, True);
            }

            // reset slice window on K1 dimension, then move forward on K0 dimension
            a_blockwise_copy.MoveSrcSliceWindow(Sequence<0, K1, 0, 0>{}, False);
            b_blockwise_copy.MoveSrcSliceWindow(Sequence<0, K1, 0, 0>{}, False);

            a_blockwise_copy.MoveSrcSliceWindow(Sequence<1, 0, 0, 0>{}, True);
            b_blockwise_copy.MoveSrcSliceWindow(Sequence<1, 0, 0, 0>{}, True);
        }

        // output: copy C from register to global memory
        {
            constexpr index_t M1 = MPerThread * MLevel0Cluster * MLevel1Cluster;
            constexpr index_t M0 = M / M1;

            constexpr index_t N1 = NPerThread * NLevel0Cluster * NLevel1Cluster;
            constexpr index_t N0 = N / N1;
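            // A C element's global row index m is decomposed as m = m0 * M1 + m1, where M1 is
            // the number of rows the block's thread cluster covers per repeat; likewise
            // n = n0 * N1 + n1. This matches the UnMerge transforms and the / and % indexing below.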

            // define input tensor descriptor for threadwise copy
            // thread input tensor, src of threadwise copy
            constexpr auto c_m0_m1_n0_n1_thread_desc = make_native_tensor_descriptor_packed(
                Sequence<GemmMRepeat, MPerThread, GemmNRepeat, NPerThread>{});

            constexpr auto c_m0_m1_n0_n1_global_desc = transform_tensor_descriptor(
                c_m_n_global_desc,
                make_tuple(UnMerge<Sequence<M0, M1>>{}, UnMerge<Sequence<N0, N1>>{}),
                make_tuple(Sequence<0>{}, Sequence<1>{}),
                make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{}));

            // calculate origin of thread input tensor on global memory
            // blockwise GEMM c matrix starting index
            const auto c_thread_mtx_on_block =
                blockwise_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id());

            const index_t m_thread_data_on_global =
                m_block_data_on_global + c_thread_mtx_on_block.row;

            const index_t n_thread_data_on_global =
                n_block_data_on_global + c_thread_mtx_on_block.col;

            ThreadwiseGenericTensorSliceCopy_v4r2<decltype(c_m0_m1_n0_n1_thread_desc),
                                                  decltype(c_m0_m1_n0_n1_global_desc),
                                                  decltype(c_m0_m1_n0_n1_thread_desc.GetLengths()),
                                                  CThreadCopySrcDstAccessOrder,
                                                  CThreadCopySrcDstVectorReadWriteDim,
                                                  1,
                                                  CThreadCopyDstDataPerWrite,
                                                  AddressSpace::Vgpr,
                                                  AddressSpace::Global,
                                                  CGlobalMemoryDataOperation>(
                make_multi_index(0, 0, 0, 0),
                make_multi_index(m_thread_data_on_global / M1,
                                 m_thread_data_on_global % M1,
                                 n_thread_data_on_global / N1,
                                 n_thread_data_on_global % N1))
                .Run(p_c_thread, p_c_global);
        }
    }

    __device__ void Run(const Float* __restrict__ p_a_global,
                        const Float* __restrict__ p_b_global,
                        Float* __restrict__ p_c_global) const
    {
        constexpr index_t shared_block_size = GetSharedMemoryNumberOfByte() / sizeof(Float);

        __shared__ Float p_shared_block[shared_block_size];

        Run(p_a_global, p_b_global, p_c_global, p_shared_block);
    }
};

} // namespace ck
#endif
@@ -1,228 +0,0 @@
#ifndef CK_THREADWISE_DIRECT_CONVOLUTION_HPP
#define CK_THREADWISE_DIRECT_CONVOLUTION_HPP

#include "common_header.hpp"
#include "ConstantTensorDescriptor_deprecated.hpp"
#include "threadwise_tensor_slice_copy.hpp"

namespace ck {

// optimized for the scenario where p_in, p_wei and p_out are all in registers
template <class TInWei, class TOut, class InDesc, class WeiDesc, class OutDesc>
__device__ void threadwise_direct_convolution_1(InDesc,
                                                TInWei* const __restrict__ p_in,
                                                WeiDesc,
                                                TInWei* const __restrict__ p_wei,
                                                OutDesc,
                                                TOut* __restrict__ p_out)
{
    constexpr auto I0 = Number<0>{};
    constexpr auto I1 = Number<1>{};
    constexpr auto I2 = Number<2>{};
    constexpr auto I3 = Number<3>{};

    constexpr auto in_desc  = InDesc{};
    constexpr auto wei_desc = WeiDesc{};
    constexpr auto out_desc = OutDesc{};

#if 0
    if(blockIdx.x == 0 && get_thread_local_1d_id() == 0)
    {
        print_ConstantTensorDescriptor(in_desc, "threadwise_direct_convolution: in_desc: ");
        print_ConstantTensorDescriptor(wei_desc, "threadwise_direct_convolution: wei_desc: ");
        print_ConstantTensorDescriptor(out_desc, "threadwise_direct_convolution: out_desc: ");
    }
#endif

    for(index_t n = 0; n < out_desc.GetLength(I0); ++n)
    {
        for(index_t k = 0; k < out_desc.GetLength(I1); ++k)
        {
            for(index_t ho = 0; ho < out_desc.GetLength(I2); ++ho)
            {
                for(index_t wo = 0; wo < out_desc.GetLength(I3); ++wo)
                {
                    for(index_t c = 0; c < wei_desc.GetLength(I1); ++c)
                    {
                        for(index_t y = 0; y < wei_desc.GetLength(I2); ++y)
                        {
                            for(index_t x = 0; x < wei_desc.GetLength(I3); ++x)
                            {
                                const index_t hi = ho + y;
                                const index_t wi = wo + x;
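                                // unit stride, no padding, no dilation: the input coordinate is
                                // the output coordinate plus the filter offset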

                                const index_t in_index =
                                    in_desc.GetOffsetFromMultiIndex(n, c, hi, wi);

                                const index_t wei_index =
                                    wei_desc.GetOffsetFromMultiIndex(k, c, y, x);

                                const index_t out_index =
                                    out_desc.GetOffsetFromMultiIndex(n, k, ho, wo);

                                fused_multiply_accumulate(
                                    p_out[out_index], p_wei[wei_index], p_in[in_index]);
                            }
                        }
                    }
                }
            }
        }
    }
}

// optimized for the scenario where p_in and p_wei are in LDS and p_out is in registers;
// copy in and wei into registers before doing the convolution
template <class TInWei, class TOut, class InDesc, class WeiDesc, class OutDesc>
__device__ void threadwise_direct_convolution_2(InDesc,
                                                TInWei* const __restrict__ p_in,
                                                WeiDesc,
                                                TInWei* const __restrict__ p_wei,
                                                OutDesc,
                                                TOut* __restrict__ p_out)
{
    constexpr auto in_desc  = InDesc{};
    constexpr auto wei_desc = WeiDesc{};
    constexpr auto out_desc = OutDesc{};

    constexpr auto in_reg_desc  = make_ConstantTensorDescriptor_packed(in_desc.GetLengths());
    constexpr auto wei_reg_desc = make_ConstantTensorDescriptor_packed(wei_desc.GetLengths());

    // register
    TInWei p_in_reg[in_reg_desc.GetElementSpace()];
    TInWei p_wei_reg[wei_reg_desc.GetElementSpace()];

    // copy input tensor into register
    threadwise_tensor_slice_copy(
        in_desc, p_in, in_reg_desc, p_in_reg, in_reg_desc.GetLengths(), Number<1>{});

    // copy weight tensor into register
    threadwise_tensor_slice_copy(
        wei_desc, p_wei, wei_reg_desc, p_wei_reg, wei_reg_desc.GetLengths(), Number<1>{});

    // do convolution
    threadwise_direct_convolution_1(
        in_reg_desc, p_in_reg, wei_reg_desc, p_wei_reg, out_desc, p_out);
}

// optimized for scenario where p_in and p_wei are in LDS, p_out is in register
// break down a non-1x1 convolution into a sequence of 1x1 convolutions,
// load 1x1 weight into register, and do 1x1 convolution in register.
template <class Data, class InDesc, class WeiDesc, class OutDesc>
__device__ void threadwise_direct_convolution_3(InDesc,
                                                Data* const __restrict__ p_in,
                                                WeiDesc,
                                                Data* const __restrict__ p_wei,
                                                OutDesc,
                                                Data* __restrict__ p_out)
{
    constexpr auto I0 = Number<0>{};
    constexpr auto I1 = Number<1>{};
    constexpr auto I2 = Number<2>{};
    constexpr auto I3 = Number<3>{};

    constexpr auto in_desc  = InDesc{};
    constexpr auto wei_desc = WeiDesc{};
    constexpr auto out_desc = OutDesc{};

    constexpr auto in_reg_desc = make_ConstantTensorDescriptor(Sequence<in_desc.GetLength(I0),
                                                                        in_desc.GetLength(I1),
                                                                        out_desc.GetLength(I2),
                                                                        out_desc.GetLength(I3)>{});

    constexpr auto wei_reg_desc = make_ConstantTensorDescriptor(
        Sequence<wei_desc.GetLength(I0), wei_desc.GetLength(I1), 1, 1>{});
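    // wei_reg_desc holds a single 1x1 slice of the filter (all K and C, Y = X = 1), so the
    // loops below bring one (y, x) tap into registers at a time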

    Data p_in_reg[in_reg_desc.GetElementSpace()];
    Data p_wei_reg[wei_reg_desc.GetElementSpace()];

    constexpr index_t in_w_new_read = 1;

    constexpr auto in_desc_reg_new_read =
        make_ConstantTensorDescriptor(Sequence<in_reg_desc.GetLength(I0),
                                               in_reg_desc.GetLength(I1),
                                               in_reg_desc.GetLength(I2),
                                               in_w_new_read>{});

#if 0
    // this version reuses old input data in register and reads new data from LDS
    // loop over vertical direction
    for(index_t y = 0; y < wei_desc.GetLength(I2); ++y)
    {
        // read first input
        threadwise_4d_tensor_copy(in_desc,
                                  p_in + in_desc.GetOffsetFromMultiIndex(0, 0, y, 0),
                                  in_reg_desc,
                                  p_in_reg,
                                  in_reg_desc.GetLengths());

        // read first 1x1 weight
        threadwise_4d_tensor_copy(wei_desc,
                                  p_wei + wei_desc.GetOffsetFromMultiIndex(0, 0, y, 0),
                                  wei_reg_desc,
                                  p_wei_reg,
                                  wei_reg_desc.GetLengths());

        // do first 1x1 conv
        threadwise_direct_convolution_1(
            in_reg_desc, p_in_reg, wei_reg_desc, p_wei_reg, out_desc, p_out);

        // loop over horizontal direction
        for(index_t x = 1; x < wei_desc.GetLength(I3); ++x)
        {
            // read new weight
            threadwise_4d_tensor_copy(wei_desc,
                                      p_wei + wei_desc.GetOffsetFromMultiIndex(0, 0, y, x),
                                      wei_reg_desc,
                                      p_wei_reg,
                                      wei_reg_desc.GetLengths());

            // shift old input to the left
            threadwise_4d_tensor_shift_down(in_reg_desc, p_in_reg, I3, Number<in_w_new_read>{});

            // read new input
            threadwise_4d_tensor_copy(
                in_desc,
                p_in + in_desc.GetOffsetFromMultiIndex(0, 0, y, x + in_reg_desc.GetLength(I3) - 1),
                in_reg_desc,
                p_in_reg +
                    in_reg_desc.GetOffsetFromMultiIndex(
                        0, 0, 0, in_reg_desc.GetLength(I3) - in_w_new_read),
                in_desc_reg_new_read.GetLengths());

            // do 1x1 conv
            threadwise_direct_convolution_1(
                in_reg_desc, p_in_reg, wei_reg_desc, p_wei_reg, out_desc, p_out);
        }
    }
#elif 1
    // this version reads all input from LDS when the filter moves
    // loop over vertical direction
    for(index_t y = 0; y < wei_desc.GetLength(I2); ++y)
    {
        // loop over horizontal direction
        for(index_t x = 0; x < wei_desc.GetLength(I3); ++x)
        {
            // read new weight
            threadwise_4d_tensor_copy(wei_desc,
                                      p_wei + wei_desc.GetOffsetFromMultiIndex(0, 0, y, x),
                                      wei_reg_desc,
                                      p_wei_reg,
                                      wei_reg_desc.GetLengths());

            // read new input
            threadwise_4d_tensor_copy(in_desc,
                                      p_in + in_desc.GetOffsetFromMultiIndex(0, 0, y, x),
                                      in_reg_desc,
                                      p_in_reg,
                                      in_reg_desc.GetLengths());

            // do 1x1 conv
            threadwise_direct_convolution_1(
                in_reg_desc, p_in_reg, wei_reg_desc, p_wei_reg, out_desc, p_out);
        }
    }
#endif
}

} // namespace ck
#endif
@@ -1,165 +0,0 @@
#ifndef CK_THREADWISE_GEMM_HPP
#define CK_THREADWISE_GEMM_HPP

#include "common_header.hpp"
#include "ConstantMatrixDescriptor.hpp"
#include "math.hpp"

namespace ck {

template <typename Float, class Matrix>
__device__ void threadwise_matrix_set_zero(Matrix, Float* __restrict__ p_thread)
{
    for(index_t i = 0; i < Matrix::NRow(); ++i)
    {
        for(index_t j = 0; j < Matrix::NCol(); ++j)
        {
            const index_t id = Matrix::CalculateOffset(i, j);
            p_thread[id]     = Float(0);
        }
    }
}

template <typename SrcMatrix,
          typename DstMatrix,
          index_t NSliceRow,
          index_t NSliceCol,
          index_t DataPerAccess>
struct ThreadwiseMatrixSliceCopy
{
    __device__ constexpr ThreadwiseMatrixSliceCopy()
    {
        static_assert(SrcMatrix::RowStride() % DataPerAccess == 0 &&
                          DstMatrix::RowStride() % DataPerAccess == 0,
                      "wrong! wrong alignment");
        static_assert(NSliceCol % DataPerAccess == 0,
                      "wrong! should be NSliceCol % DataPerAccess == 0");
    }

    template <typename Data>
    __device__ static void Run(const Data* p_src, Data* p_dst)
    {
        using vector_t = typename vector_type<Data, DataPerAccess>::type;
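        // vector_t packs DataPerAccess elements of Data, so each assignment below is a single
        // vectorized load and store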

        for(index_t i = 0; i < NSliceRow; ++i)
        {
            for(index_t j = 0; j < NSliceCol; j += DataPerAccess)
            {
                const index_t src_index = SrcMatrix::CalculateOffset(i, j);
                const index_t dst_index = DstMatrix::CalculateOffset(i, j);

                *reinterpret_cast<vector_t*>(&p_dst[dst_index]) =
                    *reinterpret_cast<const vector_t*>(&p_src[src_index]);
            }
        }
    }
};

// C += transpose(A) * B
// elements of the matrices can be vectorized data
template <typename MatrixA, typename MatrixB, typename MatrixC>
struct ThreadwiseGemmTransANormalBNormalC
{
    __device__ constexpr ThreadwiseGemmTransANormalBNormalC()
    {
        static_assert(MatrixA::NRow() == MatrixB::NRow() && MatrixA::NCol() == MatrixC::NRow() &&
                          MatrixB::NCol() == MatrixC::NCol(),
                      "wrong!");
    }

    template <typename FloatA, typename FloatB, typename FloatC>
    __device__ static void Run_source(const FloatA* p_a, const FloatB* p_b, FloatC* p_c)
    {
        constexpr index_t M = MatrixC::NRow();
        constexpr index_t N = MatrixC::NCol();
        constexpr index_t K = MatrixA::NRow(); // A is transposed

        for(index_t k = 0; k < K; ++k)
        {
            for(index_t m = 0; m < M; ++m)
            {
                for(index_t n = 0; n < N; ++n)
                {
                    const index_t aindex = MatrixA::CalculateOffset(k, m); // A is transposed
                    const index_t bindex = MatrixB::CalculateOffset(k, n);
                    const index_t cindex = MatrixC::CalculateOffset(m, n);

                    p_c[cindex] +=
                        inner_product_with_conversion<FloatC>{}(p_a[aindex], p_b[bindex]);
                }
            }
        }
    }

#if CK_THREADWISE_GEMM_USE_AMD_INLINE_ASM
    template <typename FloatA, typename FloatB, typename FloatC>
    __device__ static void Run_amd_asm(const FloatA* p_a, const FloatB* p_b, FloatC* p_c)
    {
        constexpr index_t M = MatrixC::NRow();
        constexpr index_t N = MatrixC::NCol();
        constexpr index_t K = MatrixA::NRow(); // A is transposed

        static_assert(N == 4 || N == 2, "wrong! this config not supported by asm yet");

        for(index_t k = 0; k < K; ++k)
        {
            for(index_t m = 0; m < M; ++m)
            {
                const index_t aindex = MatrixA::CalculateOffset(k, m); // A is transposed

                static_if<N == 2>{}([&](auto) {
                    const index_t bindex_0 = MatrixB::CalculateOffset(k, 0);
                    const index_t bindex_1 = MatrixB::CalculateOffset(k, 1);

                    const index_t cindex_0 = MatrixC::CalculateOffset(m, 0);
                    const index_t cindex_1 = MatrixC::CalculateOffset(m, 1);

                    amd_assembly_outer_product_1x2(
                        p_a[aindex], p_b[bindex_0], p_b[bindex_1], p_c[cindex_0], p_c[cindex_1]);
                });

                static_if<N == 4>{}([&](auto) {
                    const index_t bindex_0 = MatrixB::CalculateOffset(k, 0);
                    const index_t bindex_1 = MatrixB::CalculateOffset(k, 1);
                    const index_t bindex_2 = MatrixB::CalculateOffset(k, 2);
                    const index_t bindex_3 = MatrixB::CalculateOffset(k, 3);

                    const index_t cindex_0 = MatrixC::CalculateOffset(m, 0);
                    const index_t cindex_1 = MatrixC::CalculateOffset(m, 1);
                    const index_t cindex_2 = MatrixC::CalculateOffset(m, 2);
                    const index_t cindex_3 = MatrixC::CalculateOffset(m, 3);

                    amd_assembly_outer_product_1x4(p_a[aindex],
                                                   p_b[bindex_0],
                                                   p_b[bindex_1],
                                                   p_b[bindex_2],
                                                   p_b[bindex_3],
                                                   p_c[cindex_0],
                                                   p_c[cindex_1],
                                                   p_c[cindex_2],
                                                   p_c[cindex_3]);
                });
            }
        }
    }
#endif

    template <typename FloatA, typename FloatB, typename FloatC>
    __device__ static void Run(const FloatA* p_a, const FloatB* p_b, FloatC* p_c)
    {
#if CK_THREADWISE_GEMM_USE_AMD_INLINE_ASM
        constexpr bool has_amd_asm = is_same<FloatC, float>{} &&
                                     ((is_same<FloatA, float>{} && is_same<FloatB, float>{}) ||
                                      (is_same<FloatA, half2_t>{} && is_same<FloatB, half2_t>{}) ||
                                      (is_same<FloatA, half4_t>{} && is_same<FloatB, half4_t>{}));

        static_if<has_amd_asm>{}([&](auto fwd) { Run_amd_asm(p_a, p_b, fwd(p_c)); })
            .Else([&](auto) { Run_source(p_a, p_b, p_c); });
#else
        Run_source(p_a, p_b, p_c);
#endif
    }
};

} // namespace ck
#endif
@@ -1,20 +0,0 @@
#ifndef CK_THREADWISE_GENERIC_TENSOR_OP_HPP
#define CK_THREADWISE_GENERIC_TENSOR_OP_HPP

#include "common_header.hpp"
#include "ConstantTensorDescriptor_deprecated.hpp"
#include "ConstantMergedTensorDescriptor_deprecated.hpp"

namespace ck {
template <class Float, class TDesc>
__device__ void threadwise_generic_tensor_set_zero(TDesc, Float* __restrict__ p)
{
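    // static_ford unrolls a compile-time loop over every multi-index of TDesc's lengths, so
    // each offset below is a constant expression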
    static_ford<decltype(TDesc::GetLengths())>{}([&](auto multi_id) {
        constexpr index_t offset = TDesc::GetOffsetFromMultiIndex(multi_id);

        p[offset] = static_cast<Float>(0);
    });
}

} // namespace ck
#endif
@@ -1,191 +0,0 @@
#ifndef CK_THREADWISE_GENERIC_TENSOR_SLICE_COPY_HPP
#define CK_THREADWISE_GENERIC_TENSOR_SLICE_COPY_HPP

#include "common_header.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
#include "tensor_coordinate.hpp"

namespace ck {

// This threadwise copy allows vector access of src and dst.
// It allows the vector size to be different on src and dst.
// The dimensions of vector access should be the same on src and dst.
// The dimension access order should be the same on src and dst.
// Will do valid mapping check on src data: read 0 if src data has an invalid mapping.
// Will do valid mapping check on dst data: no write if dst data has an invalid mapping.
template <typename SrcDesc,
          typename DstDesc,
          typename SliceLengths,
          typename SrcDstDimAccessOrder,
          index_t SrcDstVectorReadWriteDim,
          index_t SrcDataPerRead,
          index_t DstDataPerWrite,
          AddressSpace SrcAddressSpace     = AddressSpace::Generic,
          AddressSpace DstAddressSpace     = AddressSpace::Generic,
          InMemoryDataOperation DstInMemOp = InMemoryDataOperation::Set,
          index_t SrcDataStride            = 1,
          index_t DstDataStride            = 1>
struct ThreadwiseGenericTensorSliceCopy_v4r2
{
    static constexpr index_t nDim = SliceLengths::Size();
    using Index                   = MultiIndex<nDim>;

    using SrcCoord = typename TensorCoordinate<SrcDesc>::type;
    using DstCoord = typename TensorCoordinate<DstDesc>::type;

    __device__ constexpr ThreadwiseGenericTensorSliceCopy_v4r2(const Index& src_slice_origin,
                                                               const Index& dst_slice_origin)
        : mSrcSliceOrigin(src_slice_origin), mDstSliceOrigin(dst_slice_origin)
    {
        static_assert(nDim == SrcDesc::GetNumOfDimension() &&
                          nDim == DstDesc::GetNumOfDimension() && nDim == SliceLengths::Size() &&
                          nDim == SrcDstDimAccessOrder::Size(),
                      "wrong! # of dimensions not the same");

        static_assert(is_valid_sequence_map<SrcDstDimAccessOrder>{}, "wrong! map is not valid");

        static_assert(SliceLengths{}[SrcDstVectorReadWriteDim] %
                              math::lcm(SrcDataPerRead, DstDataPerWrite) ==
                          0,
                      "wrong! cannot evenly divide");

        // TODO: sanity-check if vectorized memory read/write is allowed on src and dst
    }

    __device__ constexpr ThreadwiseGenericTensorSliceCopy_v4r2()
        : ThreadwiseGenericTensorSliceCopy_v4r2(make_zero_multi_index<nDim>(),
                                                make_zero_multi_index<nDim>())
    {
    }

    __device__ void SetSrcSliceOrigin(SrcCoord src_slice_origin)
    {
        mSrcSliceOrigin = src_slice_origin;
    }

    __device__ void SetDstSliceOrigin(DstCoord dst_slice_origin)
    {
        mDstSliceOrigin = dst_slice_origin;
    }

    template <typename SrcData, typename DstData>
    __device__ void Run(const SrcData* p_src, DstData* p_dst) const
    {
        constexpr auto vector_access_dim = Number<SrcDstVectorReadWriteDim>{};

        constexpr auto src_data_per_access = Number<SrcDataPerRead>{};
        constexpr auto dst_data_per_access = Number<DstDataPerWrite>{};

        constexpr auto long_vector_size = Number<math::lcm(SrcDataPerRead, DstDataPerWrite)>{};
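        // A "long vector" spans the least common multiple of the two access widths. Illustrative
        // values (not from this file): SrcDataPerRead = 4 and DstDataPerWrite = 2 give
        // long_vector_size = 4, i.e. one 4-wide read feeding two 2-wide writes.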

        constexpr auto long_vector_access_lengths = SliceLengths::Modify(
            vector_access_dim, SliceLengths::Get(vector_access_dim) / long_vector_size);

        ford<decltype(long_vector_access_lengths), SrcDstDimAccessOrder>{}(
            [&](auto long_vector_access_id) {

                // data id w.r.t slicing-window
                auto long_vector_data_begin_id = long_vector_access_id;
                long_vector_data_begin_id(vector_access_dim) =
                    long_vector_size * long_vector_access_id[vector_access_dim];

                // buffer to hold a src long-vector
                SrcData p_src_long_vector[long_vector_size];

                // load data from src to the long-vector buffer
                static_for<0, long_vector_size / src_data_per_access, 1>{}([&](auto i) {
                    auto scalar_id               = make_zero_multi_index<nDim>();
                    scalar_id(vector_access_dim) = i * src_data_per_access;

                    const index_t buffer_offset = i * src_data_per_access;

                    const auto src_coord =
                        mSrcSliceOrigin + (long_vector_data_begin_id + scalar_id);

                    // Only the first data in this src vector is checked for a valid mapping.
                    // It is the user's responsibility to make sure all data in the src vector
                    // share the same valid/invalid mapping.
                    transfer_data<SrcData,
                                  SrcDataPerRead,
                                  SrcAddressSpace,
                                  AddressSpace::Vgpr,
                                  InMemoryDataOperation::Set,
                                  SrcDataStride,
                                  1>(p_src,
                                     src_coord.GetOffset(),
                                     src_coord.IsOffsetValidAssumingUpperIndexIsValid(),
                                     SrcDesc::GetElementSpace(),
                                     p_src_long_vector,
                                     buffer_offset,
                                     true,
                                     long_vector_size);
                });

                // SrcData to DstData conversion
                DstData p_dst_long_vector[long_vector_size];

                static_for<0, long_vector_size, 1>{}([&](auto i) {
                    p_dst_long_vector[i] = type_convert<DstData>{}(p_src_long_vector[i]);
                });

                // store data from the long-vector buffer to dst
                static_for<0, long_vector_size / dst_data_per_access, 1>{}([&](auto i) {
                    auto scalar_id               = make_zero_multi_index<nDim>();
                    scalar_id(vector_access_dim) = i * dst_data_per_access;

                    const index_t buffer_offset = i * dst_data_per_access;

                    const auto dst_coord =
                        mDstSliceOrigin + (long_vector_data_begin_id + scalar_id);

                    // Only the first data in this dst vector is checked for a valid mapping.
                    // It is the user's responsibility to make sure all data in the dst vector
                    // share the same valid/invalid mapping.
                    transfer_data<DstData,
                                  DstDataPerWrite,
                                  AddressSpace::Vgpr,
                                  DstAddressSpace,
                                  DstInMemOp,
                                  1,
                                  DstDataStride>(p_dst_long_vector,
                                                 buffer_offset,
                                                 true,
                                                 long_vector_size,
                                                 p_dst,
                                                 dst_coord.GetOffset(),
                                                 dst_coord.IsOffsetValidAssumingUpperIndexIsValid(),
                                                 DstDesc::GetElementSpace());
                });
            });
    }

    template <typename T, bool PositiveDirection>
    __device__ void MoveSrcSliceWindow(const T& step_sizes_,
                                       integral_constant<bool, PositiveDirection>)
    {
        const auto step_sizes = to_multi_index(step_sizes_);

        static_if<PositiveDirection>{}([&](auto) { mSrcSliceOrigin += to_multi_index(step_sizes); })
            .Else([&](auto) { mSrcSliceOrigin -= step_sizes; });
    }

    template <typename T, bool PositiveDirection>
    __device__ void MoveDstSliceWindow(const T& step_sizes_,
                                       integral_constant<bool, PositiveDirection>)
    {
        const auto step_sizes = to_multi_index(step_sizes_);

        static_if<PositiveDirection>{}([&](auto) { mDstSliceOrigin += step_sizes; })
            .Else([&](auto) { mDstSliceOrigin -= step_sizes; });
    }

    private:
    SrcCoord mSrcSliceOrigin;
    DstCoord mDstSliceOrigin;
};

} // namespace ck
#endif
@@ -2,7 +2,6 @@
#define CK_XDLOPS_GEMM_HPP

#include "common_header.hpp"
#include "ConstantMatrixDescriptor.hpp"
#include "math.hpp"
#include "amd_xdlops.hpp"

File diff suppressed because it is too large
@@ -23,6 +23,48 @@ amd_inner_product_dlop<float, float, float>(const float& a, const float& b, floa
#endif
}

template <>
__device__ void
amd_inner_product_dlop<float2_t, float2_t, float>(const float2_t& a, const float2_t& b, float& c)
{
    constexpr auto I0 = Number<0>{};
    constexpr auto I1 = Number<1>{};

    amd_inner_product_dlop(vector_type<float, 2>{a}.AsType<float>()[I0],
                           vector_type<float, 2>{b}.AsType<float>()[I0],
                           c);

    amd_inner_product_dlop(vector_type<float, 2>{a}.AsType<float>()[I1],
                           vector_type<float, 2>{b}.AsType<float>()[I1],
                           c);
}

template <>
__device__ void
amd_inner_product_dlop<float4_t, float4_t, float>(const float4_t& a, const float4_t& b, float& c)
{
    constexpr auto I0 = Number<0>{};
    constexpr auto I1 = Number<1>{};
    constexpr auto I2 = Number<2>{};
    constexpr auto I3 = Number<3>{};

    amd_inner_product_dlop(vector_type<float, 4>{a}.AsType<float>()[I0],
                           vector_type<float, 4>{b}.AsType<float>()[I0],
                           c);

    amd_inner_product_dlop(vector_type<float, 4>{a}.AsType<float>()[I1],
                           vector_type<float, 4>{b}.AsType<float>()[I1],
                           c);

    amd_inner_product_dlop(vector_type<float, 4>{a}.AsType<float>()[I2],
                           vector_type<float, 4>{b}.AsType<float>()[I2],
                           c);

    amd_inner_product_dlop(vector_type<float, 4>{a}.AsType<float>()[I3],
                           vector_type<float, 4>{b}.AsType<float>()[I3],
                           c);
}

#if CK_USE_AMD_DLOP
template <>
__device__ void

@@ -13,7 +13,6 @@
#include "functional2.hpp"
#include "functional3.hpp"
#include "functional4.hpp"
#include "in_memory_operation.hpp"
#include "integral_constant.hpp"
#include "math.hpp"
#include "number.hpp"
@@ -25,6 +24,7 @@
#include "type.hpp"
#include "utility.hpp"
#include "magic_division.hpp"
#include "amd_buffer_addressing_v2.hpp"
#include "static_buffer.hpp"
#include "dynamic_buffer.hpp"

@@ -1,54 +0,0 @@
#ifndef CK_CONFIG_NVIDIA_HPP
#define CK_CONFIG_NVIDIA_HPP

#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include <nvToolsExt.h>

// index type: unsigned or signed
#define CK_UNSIGNED_INDEX_TYPE 0

// device backend
#define CK_DEVICE_BACKEND_NVIDIA 1

// disable AMD inline asm and intrinsics
#define CK_USE_AMD_INLINE_ASM 0
#define CK_THREADWISE_GEMM_USE_AMD_INLINE_ASM 0
#define CK_USE_AMD_BUFFER_ADDRESSING 0
#define CK_USE_AMD_BUFFER_ADDRESSING_INTRINSIC 0
#define CK_USE_AMD_XDLOPS 0
#define CK_USE_AMD_XDLOPS_INLINE_ASM 0
#define CK_USE_AMD_XDLOPS_EMULATE 0

// experimental implementation
#define CK_EXPERIMENTAL_BLOCKWISE_GEMM_USE_PIPELINE 0
#define CK_EXPERIMENTAL_TENSOR_COORDINATE_USE_CALCULATE_OFFSET_DIFF 0
#define CK_EXPERIMENTAL_THREADWISE_COPY_V4R2_USE_OPTIMIZED_ADDRESS_CACLULATION 0
#define CK_EXPERIMENTAL_USE_MORE_COMPILE_STATIC_BLOCKWISE_GENERIC_SLICE_COPY_V1 0
#define CK_EXPERIMENTAL_USE_MORE_COMPILE_STATIC_THREADWISE_GENERIC_TENSOR_SLICE_COPY_V1R2 0
#define CK_EXPERIMENTAL_USE_MORE_COMPILE_STATIC_THREADWISE_GENERIC_TENSOR_SLICE_COPY_V2R1 0

namespace ck {

enum AddressSpace
{
    Generic,
    Global,
    Lds,
    Vgpr
};

enum InMemoryDataOperation
{
    Set,
    AtomicAdd
};

#if CK_UNSIGNED_INDEX_TYPE
using index_t = uint32_t;
#else
using index_t = int32_t;
#endif

} // namespace ck
#endif
@@ -1,180 +0,0 @@
#ifndef CK_FLOAT_TYPE_NVIDIA_HPP
#define CK_FLOAT_TYPE_NVIDIA_HPP

#include "number.hpp"

namespace ck {

// For some reason, CUDA needs this definition; otherwise the compiler won't generate optimal
// load and store instructions, and the kernel would produce wrong results, indicating that the
// compiler failed to generate correct instructions.
// float
using float2_t = float2;
using float4_t = float4;

// float
typedef float float32_t __attribute__((ext_vector_type(32)));

// bfloat16
typedef ushort ushort2_t __attribute__((ext_vector_type(2)));
typedef ushort ushort4_t __attribute__((ext_vector_type(4)));
typedef ushort ushort8_t __attribute__((ext_vector_type(8)));

// fp16
using half_t  = half;
using half2_t = half2;
using half4_t = float2;

template <class T, index_t N>
struct vector_type
{
    typedef struct
    {
        T scalar[N];
    } type;
};

template <>
struct vector_type<float, 1>
{
    using type = float;

    template <index_t I>
    __host__ __device__ static void SetScalar(type& v, float s, Number<I>)
    {
        static_assert(I < 1, "wrong");
        *(reinterpret_cast<float*>(&v) + I) = s;
    }
};

template <>
struct vector_type<float, 2>
{
    using type = float2_t;

    union DataType
    {
        type vector;
        float scalar[2];
    };

    template <index_t I>
    __host__ __device__ static void SetScalar(type& v, float s, Number<I>)
    {
        static_assert(I < 2, "wrong");
        *(reinterpret_cast<float*>(&v) + I) = s;
    }

    __host__ __device__ static type Pack(float s0, float s1)
    {
        DataType data;
        data.scalar[0] = s0;
        data.scalar[1] = s1;
        return data.vector;
    }
};

template <>
struct vector_type<float, 4>
{
    using type = float4_t;

    __host__ __device__ static constexpr index_t GetSize() { return 4; }

    template <index_t I>
    __host__ __device__ static void SetScalar(type& v, float s, Number<I>)
    {
        static_assert(I < 4, "wrong");
        *(reinterpret_cast<float*>(&v) + I) = s;
    }
};

template <>
struct vector_type<half_t, 1>
{
    using type = half_t;

    template <index_t I>
    __host__ __device__ static void SetScalar(type& v, half_t s, Number<I>)
    {
        static_assert(I < 1, "wrong");
        *(reinterpret_cast<half_t*>(&v) + I) = s;
    }
};

template <>
struct vector_type<half_t, 2>
{
    using type = half2_t;

    union DataType
    {
        type vector;
        half_t scalar[2];
    };

    template <index_t I>
    __host__ __device__ static void SetScalar(type& v, half_t s, Number<I>)
    {
        static_assert(I < 2, "wrong");
        *(reinterpret_cast<half_t*>(&v) + I) = s;
    }

    __host__ __device__ static type Pack(half_t s0, half_t s1)
    {
        DataType data;
        data.scalar[0] = s0;
        data.scalar[1] = s1;
        return data.vector;
    }
};

// data type conversion
template <typename T>
struct type_convert
{
    template <typename X>
    __device__ T operator()(const X& x) const
    {
        return static_cast<T>(x);
    }
};

template <typename T>
struct inner_product_with_conversion
{
    static constexpr auto convert = type_convert<T>();

    __device__ T operator()(float a, float b) const { return convert(a) * convert(b); }

    __device__ T operator()(half2_t a, half2_t b) const
    {
        const half_t* p_a_half = reinterpret_cast<const half_t*>(&a);
        const half_t* p_b_half = reinterpret_cast<const half_t*>(&b);

        T acc = 0;
        for(index_t v = 0; v < 2; ++v)
        {
            acc += convert(p_a_half[v]) * convert(p_b_half[v]);
        }

        return acc;
    }

    __device__ T operator()(half4_t a, half4_t b) const
    {
        const half_t* p_a_half = reinterpret_cast<const half_t*>(&a);
        const half_t* p_b_half = reinterpret_cast<const half_t*>(&b);

        T acc = 0;
        for(index_t v = 0; v < 4; ++v)
        {
            acc += convert(p_a_half[v]) * convert(p_b_half[v]);
        }
        return acc;
    }
};

} // namespace ck
#endif
@@ -1,241 +0,0 @@
#ifndef CK_IN_MEMORY_OPERATION_AMD_HPP
#define CK_IN_MEMORY_OPERATION_AMD_HPP

#include "float_type.hpp"

#if CK_USE_AMD_BUFFER_ADDRESSING
#include "amd_buffer_addressing.hpp"
#include "amd_buffer_addressing_v2.hpp"
#endif

namespace ck {

template <typename T>
__device__ void atomic_add_impl(T* p_dst, T src)
{
    atomicAdd(p_dst, src);
}

// atomicAdd for float does not support vector types
template <>
__device__ void atomic_add_impl<float2_t>(float2_t* p_dst, float2_t src)
{
    float* p_dst_float       = reinterpret_cast<float*>(p_dst);
    const float* p_src_float = reinterpret_cast<const float*>(&src);

    for(index_t i = 0; i < 2; ++i)
    {
        atomicAdd(&(p_dst_float[i]), p_src_float[i]);
    }
}

template <>
__device__ void atomic_add_impl<float4_t>(float4_t* p_dst, float4_t src)
{
    float* p_dst_float       = reinterpret_cast<float*>(p_dst);
    const float* p_src_float = reinterpret_cast<const float*>(&src);

    for(index_t i = 0; i < 4; ++i)
    {
        atomicAdd(&(p_dst_float[i]), p_src_float[i]);
    }
}

template <typename T, index_t DataPerAccess>
struct SetData
{
    using vector_t = typename vector_type<T, DataPerAccess>::type;

    // This version is only for compatibility; avoid using it if possible
    template <AddressSpace SrcAddressSpace, AddressSpace DstAddressSpace>
    __device__ void Run(const T* p_src,
                        index_t src_offset,
                        bool src_valid,
                        index_t /* src_range */,
                        T* p_dst,
                        index_t dst_offset,
                        bool dst_valid,
                        index_t /* dst_range */) const
    {
        if(dst_valid)
        {
            if(src_valid)
            {
#if 0
                *reinterpret_cast<vector_t*>(&p_dst[dst_offset]) =
                    *reinterpret_cast<const vector_t*>(&p_src[src_offset]);
#else
                *reinterpret_cast<vector_t*>(&p_dst[dst_offset]) =
                    *reinterpret_cast<const vector_t*>(&p_src[0x3fffffff & src_offset]);
#endif
            }
            else
            {
                *reinterpret_cast<vector_t*>(&p_dst[dst_offset]) = 0;
            }
        }
    }

#if CK_USE_AMD_BUFFER_ADDRESSING
    // buffer_load requires:
    // 1) p_src_thread must be in global memory space, p_dst_thread must be vgpr
    // 2) p_src_thread must be a wavewise pointer.
    // It is the user's responsibility to make sure that is true.
    template <>
    __device__ void Run<AddressSpace::Global, AddressSpace::Vgpr>(const T* p_src,
                                                                  index_t src_offset,
                                                                  bool src_valid,
                                                                  index_t src_range,
                                                                  T* p_dst,
                                                                  index_t dst_offset,
                                                                  bool dst_valid,
                                                                  index_t /* dst_range */) const
    {
        if(dst_valid)
        {
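            // src_valid and src_range are forwarded to the buffer load rather than branched on
            // here; presumably the buffer descriptor's range check turns invalid lanes into
            // zero reads (see amd_buffer_addressing_v2.hpp)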
            *reinterpret_cast<vector_t*>(&p_dst[dst_offset]) =
                amd_buffer_load_v2<T, DataPerAccess>(p_src, src_offset, src_valid, src_range);
        }
    }

    // buffer_store requires:
    // 1) p_src_thread must be in vgpr space, p_dst_thread must be global memory
    // 2) p_dst_thread must be a wavewise pointer.
    // It is the user's responsibility to make sure that is true.
    template <>
    __device__ void Run<AddressSpace::Vgpr, AddressSpace::Global>(const T* p_src,
                                                                  index_t src_offset,
                                                                  bool src_valid,
                                                                  index_t /* src_range */,
                                                                  T* p_dst,
                                                                  index_t dst_offset,
                                                                  bool dst_valid,
                                                                  index_t dst_range) const
    {
        const auto zeros = vector_t(0);

        amd_buffer_store_v2<T, DataPerAccess>(
            src_valid ? *reinterpret_cast<const vector_t*>(&(p_src[src_offset])) : zeros,
            p_dst,
            dst_offset,
            dst_valid,
            dst_range);
    }
#endif
};

template <typename T, index_t DataPerAccess>
struct AtomicAddData
{
    using vector_t = typename vector_type<T, DataPerAccess>::type;

    // This version is only for compatibility; avoid using it if possible
    template <AddressSpace SrcAddressSpace, AddressSpace DstAddressSpace>
    __device__ void Run(const T* p_src,
                        index_t src_offset,
                        bool src_valid,
                        index_t /* src_range */,
                        T* p_dst,
                        index_t dst_offset,
                        bool dst_valid,
                        index_t /* dst_range */) const
    {
        if(src_valid && dst_valid)
        {
            atomic_add_impl(reinterpret_cast<vector_t*>(&p_dst[dst_offset]),
                            *reinterpret_cast<const vector_t*>(&p_src[src_offset]));
        }
    }

#if CK_USE_AMD_BUFFER_ADDRESSING && CK_USE_AMD_BUFFER_ATOMIC_FADD
    // buffer_atomic requires:
    // 1) p_src_thread must be in vgpr space, p_dst_thread must be global memory
    // 2) p_dst_thread must be a wavewise pointer.
    // It is the user's responsibility to make sure that is true.
    template <>
    __device__ void Run<AddressSpace::Vgpr, AddressSpace::Global>(const T* p_src,
                                                                  index_t src_offset,
                                                                  bool src_valid,
                                                                  index_t /* src_range */,
                                                                  T* p_dst,
                                                                  index_t dst_offset,
                                                                  bool dst_valid,
                                                                  index_t dst_range) const
    {
        const auto zeros = vector_t(0);

        amd_buffer_atomic_add<T, DataPerAccess>(
            src_valid ? &(p_src[src_offset]) : &zeros, p_dst, dst_offset, dst_valid, dst_range);
    }
#endif
};

template <typename T,
          index_t DataPerAccess,
          AddressSpace SrcAddressSpace,
          AddressSpace DstAddressSpace,
          InMemoryDataOperation DstInMemOp,
          index_t SrcDataStride = 1,
          index_t DstDataStride = 1>
__device__ void transfer_data(const T* p_src,
                              index_t src_offset,
                              bool src_valid,
                              index_t src_range,
                              T* p_dst,
                              index_t dst_offset,
                              bool dst_valid,
                              index_t dst_range)
{
    static_assert(DstInMemOp == InMemoryDataOperation::Set ||
                      DstInMemOp == InMemoryDataOperation::AtomicAdd,
                  "wrong! InMemoryDataOperation not supported!");

    // keep it simple; don't use static_if here, otherwise the compiler will do weird things
    if constexpr(SrcDataStride == 1 && DstDataStride == 1)
    {
        if constexpr(DstInMemOp == InMemoryDataOperation::Set)
        {
            SetData<T, DataPerAccess>{}.template Run<SrcAddressSpace, DstAddressSpace>(
                p_src, src_offset, src_valid, src_range, p_dst, dst_offset, dst_valid, dst_range);
        }
        else if constexpr(DstInMemOp == InMemoryDataOperation::AtomicAdd)
        {
            AtomicAddData<T, DataPerAccess>{}.template Run<SrcAddressSpace, DstAddressSpace>(
                p_src, src_offset, src_valid, src_range, p_dst, dst_offset, dst_valid, dst_range);
        }
    }
    else
    {
#pragma unroll
        for(index_t i = 0; i < DataPerAccess; ++i)
        {
            if constexpr(DstInMemOp == InMemoryDataOperation::Set)
            {
                SetData<T, 1>{}.template Run<SrcAddressSpace, DstAddressSpace>(
                    p_src,
                    src_offset + i * SrcDataStride,
                    src_valid,
                    src_range,
                    p_dst,
                    dst_offset + i * DstDataStride,
                    dst_valid,
                    dst_range);
            }
            else if constexpr(DstInMemOp == InMemoryDataOperation::AtomicAdd)
            {
                AtomicAddData<T, 1>{}.template Run<SrcAddressSpace, DstAddressSpace>(
                    p_src,
                    src_offset + i * SrcDataStride,
                    src_valid,
                    src_range,
                    p_dst,
                    dst_offset + i * DstDataStride,
                    dst_valid,
                    dst_range);
            }
        }
    }
}

} // namespace ck
#endif
@@ -1,109 +0,0 @@
#ifndef CK_IN_MEMORY_OPERATION_NVIDIA_HPP
#define CK_IN_MEMORY_OPERATION_NVIDIA_HPP

namespace ck {

template <typename T>
__device__ void atomic_add_impl(T* p_dst, T src)
{
    atomicAdd(p_dst, src);
}

// atomicAdd for float does not support vector types
template <>
__device__ void atomic_add_impl<float2_t>(float2_t* p_dst, float2_t src)
{
    float* p_dst_float       = reinterpret_cast<float*>(p_dst);
    const float* p_src_float = reinterpret_cast<const float*>(&src);

    for(index_t i = 0; i < 2; ++i)
    {
        atomicAdd(&(p_dst_float[i]), p_src_float[i]);
    }
}

template <>
__device__ void atomic_add_impl<float4_t>(float4_t* p_dst, float4_t src)
{
    float* p_dst_float       = reinterpret_cast<float*>(p_dst);
    const float* p_src_float = reinterpret_cast<const float*>(&src);

    for(index_t i = 0; i < 4; ++i)
    {
        atomicAdd(&(p_dst_float[i]), p_src_float[i]);
    }
}

template <typename T, index_t DataPerAccess>
struct SetData
{
    using vector_t = typename vector_type<T, DataPerAccess>::type;

    template <AddressSpace SrcAddressSpace, AddressSpace DstAddressSpace>
    __device__ void Run(const T* p_src, index_t src_offset, T* p_dst, index_t dst_offset) const
    {
        *reinterpret_cast<vector_t*>(&p_dst[dst_offset]) =
            *reinterpret_cast<const vector_t*>(&p_src[src_offset]);
    }
};

template <typename T, index_t DataPerAccess>
struct AtomicAddData
{
    using vector_t = typename vector_type<T, DataPerAccess>::type;

    template <AddressSpace SrcAddressSpace, AddressSpace DstAddressSpace>
    __device__ void Run(const T* p_src, index_t src_offset, T* p_dst, index_t dst_offset) const
    {
        atomic_add_impl(reinterpret_cast<vector_t*>(&p_dst[dst_offset]),
                        *reinterpret_cast<const vector_t*>(&p_src[src_offset]));
    }
};

template <typename T,
          index_t DataPerAccess,
          AddressSpace SrcAddressSpace,
          AddressSpace DstAddressSpace,
          InMemoryDataOperation DstInMemOp,
          index_t SrcDataStride = 1,
          index_t DstDataStride = 1>
__device__ void transfer_data(const T* p_src, index_t src_offset, T* p_dst, index_t dst_offset)
{
    static_assert(DstInMemOp == InMemoryDataOperation::Set ||
                      DstInMemOp == InMemoryDataOperation::AtomicAdd,
                  "wrong! InMemoryDataOperation not supported!");

    // keep it simple; don't use static_if for the stride check, otherwise the compiler will
    // do weird things
    if(SrcDataStride == 1 && DstDataStride == 1)
    {
        // TODO: use static_if::ElseIf
        static_if<DstInMemOp == InMemoryDataOperation::Set>{}([&](auto) {
            SetData<T, DataPerAccess>{}.template Run<SrcAddressSpace, DstAddressSpace>(
                p_src, src_offset, p_dst, dst_offset);
        });

        static_if<DstInMemOp == InMemoryDataOperation::AtomicAdd>{}([&](auto) {
            AtomicAddData<T, DataPerAccess>{}.template Run<SrcAddressSpace, DstAddressSpace>(
                p_src, src_offset, p_dst, dst_offset);
        });
    }
    else
    {
        for(index_t i = 0; i < DataPerAccess; i++)
        {
            // TODO: use static_if::ElseIf
            static_if<DstInMemOp == InMemoryDataOperation::Set>{}([&](auto) {
                SetData<T, 1>{}.template Run<SrcAddressSpace, DstAddressSpace>(
                    p_src, src_offset + i * SrcDataStride, p_dst, dst_offset + i * DstDataStride);
            });

            static_if<DstInMemOp == InMemoryDataOperation::AtomicAdd>{}([&](auto) {
                AtomicAddData<T, 1>{}.template Run<SrcAddressSpace, DstAddressSpace>(
                    p_src, src_offset + i * SrcDataStride, p_dst, dst_offset + i * DstDataStride);
            });
        }
    }
}

} // namespace ck
#endif
@@ -1,13 +0,0 @@
#ifndef CK_SYNCHRONIZATION_NVIDIA_HPP
#define CK_SYNCHRONIZATION_NVIDIA_HPP

#include "config.hpp"

namespace ck {

__device__ void block_sync_lds() { __syncthreads(); }

__device__ void block_sync_lds_vmem() { __syncthreads(); }

} // namespace ck
#endif
@@ -1,8 +0,0 @@

extern "C" __global__ void
gridwise_convolution_forward_implicit_gemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer(
    const void* const __restrict__ p_in_global,
    const void* const __restrict__ p_wei_global,
    void* const __restrict__ p_out_global){

};
@@ -1,7 +0,0 @@

extern "C" __global__ void gridwise_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw(
    const void* const __restrict__ p_in_global,
    const void* const __restrict__ p_wei_global,
    void* const __restrict__ p_out_global){

};
@@ -1,8 +0,0 @@

extern "C" __global__ void gridwise_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk(
    const void* const __restrict__ p_in_global,
    const void* const __restrict__ p_wei_global,
    void* const __restrict__ p_out_global){

};
@@ -1,299 +0,0 @@
#include <iostream>
|
||||
#include <numeric>
|
||||
#include <initializer_list>
|
||||
#include <cstdlib>
|
||||
#include <stdlib.h>
|
||||
#include "config.hpp"
|
||||
#include "print.hpp"
|
||||
#include "device.hpp"
|
||||
#include "host_tensor_generator.hpp"
|
||||
#include "device_tensor.hpp"
|
||||
#include "conv_common.hpp"
|
||||
#include "host_conv_bwd_data.hpp"
|
||||
#include "device_convolution_backward_data_implicit_gemm_v1r1_nchw_kcyx_nkhw.hpp"
|
||||
#include "device_convolution_backward_data_implicit_gemm_v1r2_nchw_kcyx_nkhw.hpp"
|
||||
#include "device_convolution_backward_data_implicit_gemm_v4r1_nchw_kcyx_nkhw.hpp"
|
||||
#include "device_convolution_backward_data_implicit_gemm_v5r1_nhwc_kyxc_nhwk.hpp"
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
using namespace launcher;
|
||||
|
||||
#if 1
|
||||
// 1x1 filter, 14x14 image
|
||||
constexpr index_t N = 1;
|
||||
constexpr index_t C = 256;
|
||||
constexpr index_t HI = 1;
|
||||
constexpr index_t WI = 128;
|
||||
constexpr index_t K = 16;
|
||||
constexpr index_t Y = 1;
|
||||
constexpr index_t X = 1;
|
||||
|
||||
using ConvStrides = Sequence<1, 1>;
|
||||
using ConvDilations = Sequence<1, 1>;
|
||||
|
||||
using LeftPads = Sequence<0, 0>;
|
||||
using RightPads = Sequence<0, 0>;
|
||||
#elif 0
|
||||
constexpr index_t N = 64;
|
||||
constexpr index_t C = 256;
|
||||
constexpr index_t HI = 56;
|
||||
constexpr index_t WI = 56;
|
||||
constexpr index_t K = 256;
|
||||
constexpr index_t Y = 1;
|
||||
constexpr index_t X = 1;
|
||||
|
||||
using ConvStrides = Sequence<1, 1>;
|
||||
using ConvDilations = Sequence<1, 1>;
|
||||
|
||||
using LeftPads = Sequence<0, 0>;
|
||||
using RightPads = Sequence<0, 0>;
|
||||
#elif 0
|
||||
// 3x3, 34x34
|
||||
    constexpr index_t N = 64;
    constexpr index_t C = 256;
    constexpr index_t HI = 34;
    constexpr index_t WI = 34;
    constexpr index_t K = 256;
    constexpr index_t Y = 3;
    constexpr index_t X = 3;

    using ConvStrides = Sequence<1, 1>;
    using ConvDilations = Sequence<1, 1>;

    using LeftPads = Sequence<0, 0>;
    using RightPads = Sequence<0, 0>;
#elif 0
    // 3x3, 28x28
    constexpr index_t N = 128;
    constexpr index_t C = 128;
    constexpr index_t HI = 28;
    constexpr index_t WI = 28;
    constexpr index_t K = 128;
    constexpr index_t Y = 3;
    constexpr index_t X = 3;

    using ConvStrides = Sequence<1, 1>;
    using ConvDilations = Sequence<1, 1>;

    using LeftPads = Sequence<1, 1>;
    using RightPads = Sequence<1, 1>;
#elif 0
    // 1x1 filter, 8x8 image
    constexpr index_t N = 256;
    constexpr index_t C = 1024;
    constexpr index_t HI = 8;
    constexpr index_t WI = 8;
    constexpr index_t K = 1024;
    constexpr index_t Y = 1;
    constexpr index_t X = 1;

    using ConvStrides = Sequence<1, 1>;
    using ConvDilations = Sequence<1, 1>;

    using LeftPads = Sequence<0, 0>;
    using RightPads = Sequence<0, 0>;
#elif 0
    // 1x1 filter, 7x7 image
    constexpr index_t N = 128;
    constexpr index_t C = 1024;
    constexpr index_t HI = 7;
    constexpr index_t WI = 7;
    constexpr index_t K = 1024;
    constexpr index_t Y = 1;
    constexpr index_t X = 1;

    using ConvStrides = Sequence<1, 1>;
    using ConvDilations = Sequence<1, 1>;

    using LeftPads = Sequence<0, 0>;
    using RightPads = Sequence<0, 0>;
#elif 1
    // 1x1 filter, 14x14 image
    constexpr index_t N = 128;
    constexpr index_t C = 512;
    constexpr index_t HI = 14;
    constexpr index_t WI = 14;
    constexpr index_t K = 128;
    constexpr index_t Y = 1;
    constexpr index_t X = 1;

    using ConvStrides = Sequence<1, 1>;
    using ConvDilations = Sequence<1, 1>;

    using LeftPads = Sequence<0, 0>;
    using RightPads = Sequence<0, 0>;
#elif 0
    // 1x1 filter, 28x28 image
    constexpr index_t N = 128;
    constexpr index_t C = 128;
    constexpr index_t HI = 28;
    constexpr index_t WI = 28;
    constexpr index_t K = 128;
    constexpr index_t Y = 1;
    constexpr index_t X = 1;

    using ConvStrides = Sequence<1, 1>;
    using ConvDilations = Sequence<1, 1>;

    using LeftPads = Sequence<0, 0>;
    using RightPads = Sequence<0, 0>;
#elif 0
    // 1x1 filter, 17x17 input
    constexpr index_t N = 128;
    constexpr index_t C = 1024;
    constexpr index_t HI = 17;
    constexpr index_t WI = 17;
    constexpr index_t K = 1024;
    constexpr index_t Y = 1;
    constexpr index_t X = 1;

    using ConvStrides = Sequence<1, 1>;
    using ConvDilations = Sequence<1, 1>;

    using LeftPads = Sequence<0, 0>;
    using RightPads = Sequence<0, 0>;
#elif 0
    // 5x5 filter, 2x2 pad, 7x7 input
    constexpr index_t N = 128;
    constexpr index_t C = 1024;
    constexpr index_t HI = 7;
    constexpr index_t WI = 7;
    constexpr index_t K = 1024;
    constexpr index_t Y = 5;
    constexpr index_t X = 5;

    using ConvStrides = Sequence<1, 1>;
    using ConvDilations = Sequence<1, 1>;

    using LeftPads = Sequence<2, 2>;
    using RightPads = Sequence<2, 2>;
#elif 0
    // 1x7 filter, 0x3 pad, 17x17 input
    constexpr index_t N = 128;
    constexpr index_t C = 128;
    constexpr index_t HI = 17;
    constexpr index_t WI = 17;
    constexpr index_t K = 128;
    constexpr index_t Y = 1;
    constexpr index_t X = 7;

    using ConvStrides = Sequence<1, 1>;
    using ConvDilations = Sequence<1, 1>;

    using LeftPads = Sequence<0, 3>;
    using RightPads = Sequence<0, 3>;
#elif 0
    // 7x1 filter, 3x0 pad, 17x17 input
    constexpr index_t N = 128;
    constexpr index_t C = 256;
    constexpr index_t HI = 17;
    constexpr index_t WI = 17;
    constexpr index_t K = 1024;
    constexpr index_t Y = 7;
    constexpr index_t X = 1;

    using ConvStrides = Sequence<1, 1>;
    using ConvDilations = Sequence<1, 1>;

    using LeftPads = Sequence<3, 0>;
    using RightPads = Sequence<3, 0>;
#elif 1
    // 3x3 filter, 2x2 stride, 2x2 dilation, 35x35 input
    constexpr index_t N = 128;
    constexpr index_t C = 256;
    constexpr index_t HI = 35;
    constexpr index_t WI = 35;
    constexpr index_t K = 1280;
    constexpr index_t Y = 3;
    constexpr index_t X = 3;

    using ConvStrides = Sequence<2, 2>;
    using ConvDilations = Sequence<2, 2>;

    using LeftPads = Sequence<0, 0>;
    using RightPads = Sequence<0, 0>;
#endif

    constexpr auto in_nchw_desc = make_native_tensor_descriptor_packed(Sequence<N, C, HI, WI>{});
    constexpr auto wei_kcyx_desc = make_native_tensor_descriptor_packed(Sequence<K, C, Y, X>{});
    constexpr auto out_nkhw_desc = get_convolution_output_default_4d_tensor_descriptor(
        in_nchw_desc, wei_kcyx_desc, ConvStrides{}, ConvDilations{}, LeftPads{}, RightPads{});

    ostream_tensor_descriptor(in_nchw_desc, std::cout << "in_nchw_desc: ");
    ostream_tensor_descriptor(wei_kcyx_desc, std::cout << "wei_kcyx_desc: ");
    ostream_tensor_descriptor(out_nkhw_desc, std::cout << "out_nkhw_desc: ");
    print_array("LeftPads", LeftPads{});
    print_array("RightPads", RightPads{});
    print_array("ConvStrides", ConvStrides{});
    print_array("ConvDilations", ConvDilations{});

    Tensor<float> in_nchw_device(make_HostTensorDescriptor(in_nchw_desc));
    Tensor<float> in_nchw_host(make_HostTensorDescriptor(in_nchw_desc));
    Tensor<float> wei_kcyx(make_HostTensorDescriptor(wei_kcyx_desc));
    Tensor<float> out_nkhw(make_HostTensorDescriptor(out_nkhw_desc));

    std::size_t num_thread = std::thread::hardware_concurrency();

    if(argc != 3)
    {
        printf("arg1: do_verification, arg2: nrepeat\n");
        exit(1);
    }

    bool do_verification = atoi(argv[1]);
    std::size_t nrepeat = atoi(argv[2]);

    if(do_verification)
    {
#if 0
        wei_kcyx.GenerateTensorValue(GeneratorTensor_1{1}, num_thread);
        out_nkhw.GenerateTensorValue(GeneratorTensor_1{1}, num_thread);
#else
        wei_kcyx.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread);
        out_nkhw.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread);
#endif
    }

#if 0
    device_convolution_backward_data_implicit_gemm_v1r1_nchw_kcyx_nkhw
#elif 0
    device_convolution_backward_data_implicit_gemm_v1r2_nchw_kcyx_nkhw
#elif 0
    device_convolution_backward_data_implicit_gemm_v4r1_nchw_kcyx_nkhw
#elif 1
    device_convolution_backward_data_implicit_gemm_v5r1_nhwc_kyxc_nhwk
#endif
    (in_nchw_desc,
     in_nchw_device,
     wei_kcyx_desc,
     wei_kcyx,
     out_nkhw_desc,
     out_nkhw,
     ConvStrides{},
     ConvDilations{},
     LeftPads{},
     RightPads{},
     nrepeat);

    if(do_verification)
    {
        host_direct_convolution_backward_data(in_nchw_host,
                                              wei_kcyx,
                                              out_nkhw,
                                              ConvStrides{},
                                              ConvDilations{},
                                              LeftPads{},
                                              RightPads{});

        check_error(in_nchw_host, in_nchw_device);

#if 0
        LogRange(std::cout << "out_nkhw : ", out_nkhw.mData, ",") << std::endl;
        LogRange(std::cout << "wei_kcyx : ", wei_kcyx.mData, ",") << std::endl;
        LogRange(std::cout << "in_nchw_host : ", in_nchw_host.mData, ",") << std::endl;
        LogRange(std::cout << "in_nchw_device : ", in_nchw_device.mData, ",") << std::endl;
#endif
    }
}
@@ -1,780 +0,0 @@
#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>
#include <stdlib.h>
#include <half.hpp>
#include "config.hpp"
#include "print.hpp"
#include "device.hpp"
#include "host_tensor_generator.hpp"
#include "conv_common.hpp"
#include "host_conv.hpp"
#include "device_tensor.hpp"
#include "device_convolution_forward_implicit_gemm_v4r1_nchw_kcyx_nkhw.hpp"
#include "device_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp"
#include "device_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk.hpp"

int main(int argc, char* argv[])
{
    using namespace ck;

    if(argc != 5)
    {
        printf("arg1: do_verification, arg2: do_log, arg3: init_method, arg4: nrepeat\n");
        exit(1);
    }

    const bool do_verification = atoi(argv[1]);
    const bool do_log = atoi(argv[2]);
    const int init_method = atoi(argv[3]);
    const int nrepeat = atoi(argv[4]);

#if 0
    constexpr index_t N = 256;
    constexpr index_t C = 256;
    constexpr index_t Hi = 16;
    constexpr index_t Wi = 16;
    constexpr index_t K = 256;
    constexpr index_t Y = 1;
    constexpr index_t X = 1;

    using ConvStrides = Sequence<1, 1>;
    using ConvDilations = Sequence<1, 1>;

    using InLeftPads = Sequence<0, 0>;
    using InRightPads = Sequence<0, 0>;
#elif 0
    constexpr index_t N = 1;
    constexpr index_t C = 16;
    constexpr index_t Hi = 1080;
    constexpr index_t Wi = 1920;
    constexpr index_t K = 16;
    constexpr index_t Y = 1;
    constexpr index_t X = 1;

    using ConvStrides = Sequence<1, 1>;
    using ConvDilations = Sequence<1, 1>;

    using InLeftPads = Sequence<1, 1>;
    using InRightPads = Sequence<1, 1>;
#elif 0
    constexpr index_t N = 1;
    constexpr index_t C = 16;
    constexpr index_t Hi = 540;
    constexpr index_t Wi = 960;
    constexpr index_t K = 16;
    constexpr index_t Y = 1;
    constexpr index_t X = 1;

    using ConvStrides = Sequence<1, 1>;
    using ConvDilations = Sequence<1, 1>;

    using InLeftPads = Sequence<0, 0>;
    using InRightPads = Sequence<0, 0>;
#elif 0
    constexpr index_t N = 1;
    constexpr index_t C = 16;
    constexpr index_t Hi = 270;
    constexpr index_t Wi = 480;
    constexpr index_t K = 16;
    constexpr index_t Y = 1;
    constexpr index_t X = 1;

    using ConvStrides = Sequence<1, 1>;
    using ConvDilations = Sequence<1, 1>;

    using InLeftPads = Sequence<0, 0>;
    using InRightPads = Sequence<0, 0>;
#elif 0
    constexpr index_t N = 1;
    constexpr index_t C = 16;
    constexpr index_t Hi = 1080;
    constexpr index_t Wi = 1920;
    constexpr index_t K = 16;
    constexpr index_t Y = 3;
    constexpr index_t X = 3;

    using ConvStrides = Sequence<1, 1>;
    using ConvDilations = Sequence<1, 1>;

    using InLeftPads = Sequence<1, 1>;
    using InRightPads = Sequence<1, 1>;
#elif 0
    constexpr index_t N = 1;
    constexpr index_t C = 1;
    constexpr index_t Hi = 1024;
    constexpr index_t Wi = 2048;
    constexpr index_t K = 4;
    constexpr index_t Y = 3;
    constexpr index_t X = 3;

    using ConvStrides = Sequence<1, 1>;
    using ConvDilations = Sequence<1, 1>;

    using InLeftPads = Sequence<1, 1>;
    using InRightPads = Sequence<1, 1>;
#elif 0
    constexpr index_t N = 1;
    constexpr index_t C = 16;
    constexpr index_t Hi = 540;
    constexpr index_t Wi = 960;
    constexpr index_t K = 16;
    constexpr index_t Y = 3;
    constexpr index_t X = 3;

    using ConvStrides = Sequence<1, 1>;
    using ConvDilations = Sequence<1, 1>;

    using InLeftPads = Sequence<1, 1>;
    using InRightPads = Sequence<1, 1>;
#elif 0
    constexpr index_t N = 1;
    constexpr index_t C = 16;
    constexpr index_t Hi = 270;
    constexpr index_t Wi = 480;
    constexpr index_t K = 16;
    constexpr index_t Y = 3;
    constexpr index_t X = 3;

    using ConvStrides = Sequence<1, 1>;
    using ConvDilations = Sequence<1, 1>;

    using InLeftPads = Sequence<1, 1>;
    using InRightPads = Sequence<1, 1>;
#elif 0
    // 3x3, 37x37, stride 2
    constexpr index_t N = 128;
    constexpr index_t C = 192;
    constexpr index_t Hi = 37;
    constexpr index_t Wi = 37;
    constexpr index_t K = 384;
    constexpr index_t Y = 3;
    constexpr index_t X = 3;

    using ConvStrides = Sequence<2, 2>;
    using ConvDilations = Sequence<1, 1>;

    using InLeftPads = Sequence<0, 0>;
    using InRightPads = Sequence<0, 0>;
#elif 0
    // 3x3, 35x35, stride 2
    constexpr index_t N = 128;
    constexpr index_t C = 192;
    constexpr index_t Hi = 35;
    constexpr index_t Wi = 35;
    constexpr index_t K = 384;
    constexpr index_t Y = 3;
    constexpr index_t X = 3;

    using ConvStrides = Sequence<2, 2>;
    using ConvDilations = Sequence<1, 1>;

    using InLeftPads = Sequence<0, 0>;
    using InRightPads = Sequence<0, 0>;
#elif 0
    // 3x3, 71x71
    constexpr index_t N = 128;
    constexpr index_t C = 192;
    constexpr index_t Hi = 71;
    constexpr index_t Wi = 71;
    constexpr index_t K = 256;
    constexpr index_t Y = 3;
    constexpr index_t X = 3;

    using ConvStrides = Sequence<2, 2>;
    using ConvDilations = Sequence<1, 1>;

    using InLeftPads = Sequence<1, 1>;
    using InRightPads = Sequence<1, 1>;
#elif 0
    // 1x1, 8x8
    constexpr index_t N = 128;
    constexpr index_t C = 1536;
    constexpr index_t Hi = 8;
    constexpr index_t Wi = 8;
    constexpr index_t K = 256;
    constexpr index_t Y = 1;
    constexpr index_t X = 1;

    using ConvStrides = Sequence<1, 1>;
    using ConvDilations = Sequence<1, 1>;

    using InLeftPads = Sequence<0, 0>;
    using InRightPads = Sequence<0, 0>;
#elif 0
    // 1x1, 73x73
    constexpr index_t N = 128;
    constexpr index_t C = 160;
    constexpr index_t Hi = 73;
    constexpr index_t Wi = 73;
    constexpr index_t K = 64;
    constexpr index_t Y = 1;
    constexpr index_t X = 1;

    using ConvStrides = Sequence<1, 1>;
    using ConvDilations = Sequence<1, 1>;

    using InLeftPads = Sequence<0, 0>;
    using InRightPads = Sequence<0, 0>;
#elif 0
    // 3x3, 35x35
    constexpr index_t N = 128;
    constexpr index_t C = 96;
    constexpr index_t Hi = 35;
    constexpr index_t Wi = 35;
    constexpr index_t K = 128;
    constexpr index_t Y = 3;
    constexpr index_t X = 3;

    using ConvStrides = Sequence<1, 1>;
    using ConvDilations = Sequence<1, 1>;

    using InLeftPads = Sequence<1, 1>;
    using InRightPads = Sequence<1, 1>;
#elif 0
    // 3x3, 71x71
    constexpr index_t N = 128;
    constexpr index_t C = 192;
    constexpr index_t Hi = 71;
    constexpr index_t Wi = 71;
    constexpr index_t K = 192;
    constexpr index_t Y = 3;
    constexpr index_t X = 3;

    using ConvStrides = Sequence<2, 2>;
    using ConvDilations = Sequence<1, 1>;

    using InLeftPads = Sequence<1, 1>;
    using InRightPads = Sequence<1, 1>;
#elif 0
    // 7x1, 17x17
    constexpr index_t N = 128;
    constexpr index_t C = 128;
    constexpr index_t Hi = 17;
    constexpr index_t Wi = 17;
    constexpr index_t K = 128;
    constexpr index_t Y = 7;
    constexpr index_t X = 1;

    using ConvStrides = Sequence<1, 1>;
    using ConvDilations = Sequence<1, 1>;

    using InLeftPads = Sequence<3, 0>;
    using InRightPads = Sequence<3, 0>;
#elif 1
    // 1x7, 17x17
    constexpr index_t N = 128;
    constexpr index_t C = 128;
    constexpr index_t Hi = 17;
    constexpr index_t Wi = 17;
    constexpr index_t K = 128;
    constexpr index_t Y = 1;
    constexpr index_t X = 7;

    using ConvStrides = Sequence<1, 1>;
    using ConvDilations = Sequence<1, 1>;

    using InLeftPads = Sequence<0, 3>;
    using InRightPads = Sequence<0, 3>;
#elif 0
    // 3x3, 299x299 stride=2
    constexpr index_t N = 128;
    constexpr index_t C = 3;
    constexpr index_t Hi = 299;
    constexpr index_t Wi = 299;
    constexpr index_t K = 32;
    constexpr index_t Y = 3;
    constexpr index_t X = 3;

    using ConvStrides = Sequence<2, 2>;
    using ConvDilations = Sequence<1, 1>;

    using InLeftPads = Sequence<0, 0>;
    using InRightPads = Sequence<0, 0>;
#elif 0
    // 3x3, 147x147
    constexpr index_t N = 128;
    constexpr index_t C = 128;
    constexpr index_t Hi = 147;
    constexpr index_t Wi = 147;
    constexpr index_t K = 128;
    constexpr index_t Y = 3;
    constexpr index_t X = 3;

    using ConvStrides = Sequence<1, 1>;
    using ConvDilations = Sequence<1, 1>;

    using InLeftPads = Sequence<1, 1>;
    using InRightPads = Sequence<1, 1>;
#elif 0
    // 3x3, 149x149
    constexpr index_t N = 128;
    constexpr index_t C = 32;
    constexpr index_t Hi = 149;
    constexpr index_t Wi = 149;
    constexpr index_t K = 32;
    constexpr index_t Y = 3;
    constexpr index_t X = 3;

    using ConvStrides = Sequence<1, 1>;
    using ConvDilations = Sequence<1, 1>;

    using InLeftPads = Sequence<0, 0>;
    using InRightPads = Sequence<0, 0>;
#elif 0
    // 3x3, 17x17, stride 2
    constexpr index_t N = 128;
    constexpr index_t C = 192;
    constexpr index_t Hi = 17;
    constexpr index_t Wi = 17;
    constexpr index_t K = 192;
    constexpr index_t Y = 3;
    constexpr index_t X = 3;

    using ConvStrides = Sequence<2, 2>;
    using ConvDilations = Sequence<1, 1>;

    using InLeftPads = Sequence<0, 0>;
    using InRightPads = Sequence<0, 0>;
#elif 0
    // 1x1, 35x35
    constexpr index_t N = 128;
    constexpr index_t C = 384;
    constexpr index_t Hi = 35;
    constexpr index_t Wi = 35;
    constexpr index_t K = 96;
    constexpr index_t Y = 1;
    constexpr index_t X = 1;

    using ConvStrides = Sequence<1, 1>;
    using ConvDilations = Sequence<1, 1>;

    using InLeftPads = Sequence<0, 0>;
    using InRightPads = Sequence<0, 0>;
#elif 0
    // 3x3, 35x35, stride 2
    constexpr index_t N = 128;
    constexpr index_t C = 288;
    constexpr index_t Hi = 35;
    constexpr index_t Wi = 35;
    constexpr index_t K = 384;
    constexpr index_t Y = 3;
    constexpr index_t X = 3;

    using ConvStrides = Sequence<2, 2>;
    using ConvDilations = Sequence<1, 1>;

    using InLeftPads = Sequence<0, 0>;
    using InRightPads = Sequence<0, 0>;
#elif 0
    // 1x3, 8x8
    constexpr index_t N = 128;
    constexpr index_t C = 384;
    constexpr index_t Hi = 8;
    constexpr index_t Wi = 8;
    constexpr index_t K = 448;
    constexpr index_t Y = 1;
    constexpr index_t X = 3;

    using ConvStrides = Sequence<1, 1>;
    using ConvDilations = Sequence<1, 1>;

    using InLeftPads = Sequence<0, 1>;
    using InRightPads = Sequence<0, 1>;
#elif 0
    // 3x1, 8x8
    constexpr index_t N = 128;
    constexpr index_t C = 448;
    constexpr index_t Hi = 8;
    constexpr index_t Wi = 8;
    constexpr index_t K = 512;
    constexpr index_t Y = 3;
    constexpr index_t X = 1;

    using ConvStrides = Sequence<1, 1>;
    using ConvDilations = Sequence<1, 1>;

    using InLeftPads = Sequence<1, 0>;
    using InRightPads = Sequence<1, 0>;
#elif 0
    // 3x3, 147x147
    constexpr index_t N = 128;
    constexpr index_t C = 64;
    constexpr index_t Hi = 147;
    constexpr index_t Wi = 147;
    constexpr index_t K = 96;
    constexpr index_t Y = 3;
    constexpr index_t X = 3;

    using ConvStrides = Sequence<2, 2>;
    using ConvDilations = Sequence<1, 1>;

    using InLeftPads = Sequence<0, 0>;
    using InRightPads = Sequence<0, 0>;
#elif 0
    // 7x1, 73x73
    constexpr index_t N = 128;
    constexpr index_t C = 64;
    constexpr index_t Hi = 73;
    constexpr index_t Wi = 73;
    constexpr index_t K = 64;
    constexpr index_t Y = 7;
    constexpr index_t X = 1;

    using ConvStrides = Sequence<1, 1>;
    using ConvDilations = Sequence<1, 1>;

    using InLeftPads = Sequence<3, 0>;
    using InRightPads = Sequence<3, 0>;
#elif 0
    // 3x3, 73x73
    constexpr index_t N = 128;
    constexpr index_t C = 64;
    constexpr index_t Hi = 73;
    constexpr index_t Wi = 73;
    constexpr index_t K = 96;
    constexpr index_t Y = 3;
    constexpr index_t X = 3;

    using ConvStrides = Sequence<1, 1>;
    using ConvDilations = Sequence<1, 1>;

    using InLeftPads = Sequence<0, 0>;
    using InRightPads = Sequence<0, 0>;
#elif 0
    // 1x1, 14x14, stride 2
    constexpr index_t N = 256;
    constexpr index_t C = 1024;
    constexpr index_t Hi = 14;
    constexpr index_t Wi = 14;
    constexpr index_t K = 2048;
    constexpr index_t Y = 1;
    constexpr index_t X = 1;

    using ConvStrides = Sequence<2, 2>;
    using ConvDilations = Sequence<1, 1>;

    using InLeftPads = Sequence<0, 0>;
    using InRightPads = Sequence<0, 0>;
#elif 0
    // 1x1, 14x14
    constexpr index_t N = 256;
    constexpr index_t C = 1024;
    constexpr index_t Hi = 14;
    constexpr index_t Wi = 14;
    constexpr index_t K = 256;
    constexpr index_t Y = 1;
    constexpr index_t X = 1;

    using ConvStrides = Sequence<1, 1>;
    using ConvDilations = Sequence<1, 1>;

    using InLeftPads = Sequence<0, 0>;
    using InRightPads = Sequence<0, 0>;
#elif 0
    // 1x1, 14x14, stride 2
    constexpr index_t N = 128;
    constexpr index_t C = 1024;
    constexpr index_t Hi = 14;
    constexpr index_t Wi = 14;
    constexpr index_t K = 512;
    constexpr index_t Y = 1;
    constexpr index_t X = 1;

    using ConvStrides = Sequence<2, 2>;
    using ConvDilations = Sequence<1, 1>;

    using InLeftPads = Sequence<0, 0>;
    using InRightPads = Sequence<0, 0>;
#elif 1
    // 3x3, 28x28
    constexpr index_t N = 128;
    constexpr index_t C = 128;
    constexpr index_t Hi = 28;
    constexpr index_t Wi = 28;
    constexpr index_t K = 128;
    constexpr index_t Y = 3;
    constexpr index_t X = 3;

    using ConvStrides = Sequence<1, 1>;
    using ConvDilations = Sequence<1, 1>;

    using InLeftPads = Sequence<1, 1>;
    using InRightPads = Sequence<1, 1>;
#elif 1
    // 3x3, 14x14
    constexpr index_t N = 128;
    constexpr index_t C = 256;
    constexpr index_t Hi = 14;
    constexpr index_t Wi = 14;
    constexpr index_t K = 256;
    constexpr index_t Y = 3;
    constexpr index_t X = 3;

    using ConvStrides = Sequence<1, 1>;
    using ConvDilations = Sequence<1, 1>;

    using InLeftPads = Sequence<1, 1>;
    using InRightPads = Sequence<1, 1>;
#elif 0
    // 1x1, 56x56, stride 2
    constexpr index_t N = 128;
    constexpr index_t C = 256;
    constexpr index_t Hi = 56;
    constexpr index_t Wi = 56;
    constexpr index_t K = 128;
    constexpr index_t Y = 1;
    constexpr index_t X = 1;

    using ConvStrides = Sequence<2, 2>;
    using ConvDilations = Sequence<1, 1>;

    using InLeftPads = Sequence<0, 0>;
    using InRightPads = Sequence<0, 0>;
#elif 0
    // 7x7, 230x230 stride=2
    constexpr index_t N = 128;
    constexpr index_t C = 3;
    constexpr index_t Hi = 230;
    constexpr index_t Wi = 230;
    constexpr index_t K = 64;
    constexpr index_t Y = 7;
    constexpr index_t X = 7;

    using ConvStrides = Sequence<2, 2>;
    using ConvDilations = Sequence<1, 1>;

    using InLeftPads = Sequence<0, 0>;
    using InRightPads = Sequence<0, 0>;
#elif 0
    // 1x1, 28x28, stride 2
    constexpr index_t N = 128;
    constexpr index_t C = 512;
    constexpr index_t Hi = 28;
    constexpr index_t Wi = 28;
    constexpr index_t K = 1024;
    constexpr index_t Y = 1;
    constexpr index_t X = 1;

    using ConvStrides = Sequence<2, 2>;
    using ConvDilations = Sequence<1, 1>;

    using InLeftPads = Sequence<0, 0>;
    using InRightPads = Sequence<0, 0>;
#elif 0
    // 1x1, 28x28, stride 2
    constexpr index_t N = 128;
    constexpr index_t C = 512;
    constexpr index_t Hi = 28;
    constexpr index_t Wi = 28;
    constexpr index_t K = 256;
    constexpr index_t Y = 1;
    constexpr index_t X = 1;

    using ConvStrides = Sequence<2, 2>;
    using ConvDilations = Sequence<1, 1>;

    using InLeftPads = Sequence<0, 0>;
    using InRightPads = Sequence<0, 0>;
#elif 1
    // 1x1, 7x7
    constexpr index_t N = 128;
    constexpr index_t C = 512;
    constexpr index_t Hi = 7;
    constexpr index_t Wi = 7;
    constexpr index_t K = 2048;
    constexpr index_t Y = 1;
    constexpr index_t X = 1;

    using ConvStrides = Sequence<1, 1>;
    using ConvDilations = Sequence<1, 1>;

    using InLeftPads = Sequence<0, 0>;
    using InRightPads = Sequence<0, 0>;
#elif 0
    // 3x3, 7x7
    constexpr index_t N = 128;
    constexpr index_t C = 512;
    constexpr index_t Hi = 7;
    constexpr index_t Wi = 7;
    constexpr index_t K = 512;
    constexpr index_t Y = 3;
    constexpr index_t X = 3;

    using ConvStrides = Sequence<1, 1>;
    using ConvDilations = Sequence<1, 1>;

    using InLeftPads = Sequence<1, 1>;
    using InRightPads = Sequence<1, 1>;
#elif 0
    // 1x1, 56x56
    constexpr index_t N = 128;
    constexpr index_t C = 64;
    constexpr index_t Hi = 56;
    constexpr index_t Wi = 56;
    constexpr index_t K = 64;
    constexpr index_t Y = 1;
    constexpr index_t X = 1;

    using ConvStrides = Sequence<1, 1>;
    using ConvDilations = Sequence<1, 1>;

    using InLeftPads = Sequence<0, 0>;
    using InRightPads = Sequence<0, 0>;
#elif 0
    // 3x3, 56x56
    constexpr index_t N = 128;
    constexpr index_t C = 64;
    constexpr index_t Hi = 56;
    constexpr index_t Wi = 56;
    constexpr index_t K = 64;
    constexpr index_t Y = 3;
    constexpr index_t X = 3;

    using ConvStrides = Sequence<1, 1>;
    using ConvDilations = Sequence<1, 1>;

    using InLeftPads = Sequence<1, 1>;
    using InRightPads = Sequence<1, 1>;
#endif

    constexpr index_t YEff = (Y - 1) * ConvDilations{}[0] + 1;
    constexpr index_t XEff = (X - 1) * ConvDilations{}[1] + 1;

    constexpr index_t Ho = (Hi + InLeftPads{}[0] + InRightPads{}[0] - YEff) / ConvStrides{}[0] + 1;
    constexpr index_t Wo = (Wi + InLeftPads{}[1] + InRightPads{}[1] - XEff) / ConvStrides{}[1] + 1;

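    // Worked example (editor's annotation, using the active 1x7, 17x17, 0x3-pad
    // config above): XEff = (7 - 1) * 1 + 1 = 7 and
    // Wo = (17 + 3 + 3 - 7) / 1 + 1 = 17, so the asymmetric padding preserves
    // the 17x17 spatial size.
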
#if 1
    constexpr index_t in_vector_size = 1;
    using in_data_t = typename vector_type<float, in_vector_size>::type;
    using acc_data_t = float;
    using out_data_t = float;
#elif 1
    using in_data_t = half_t;
    constexpr index_t in_vector_size = 1;
    using acc_data_t = float;
    using out_data_t = half_t;
#elif 0
    constexpr index_t in_vector_size = 1;
    using in_data_t = typename vector_type<float, in_vector_size>::type;
    using acc_data_t = float;
    using out_data_t = int8_t;
#elif 1
    constexpr index_t in_vector_size = 16;
    using in_data_t = typename vector_type<int8_t, in_vector_size>::type;
    using acc_data_t = int32_t;
    using out_data_t = int8_t;
#endif

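    // Note (editor's annotation): the int8 branch packs in_vector_size = 16 values
    // into one vector_type element, so a single load moves 16 bytes at a time,
    // while acc_data_t stays int32_t so sums of int8 products cannot overflow.
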
    Tensor<in_data_t> in_nchw(HostTensorDescriptor(std::initializer_list<index_t>{N, C, Hi, Wi}));
    Tensor<in_data_t> wei_kcyx(HostTensorDescriptor(std::initializer_list<index_t>{K, C, Y, X}));
    Tensor<out_data_t> out_nkhw_host(
        HostTensorDescriptor(std::initializer_list<index_t>{N, K, Ho, Wo}));
    Tensor<out_data_t> out_nkhw_device(
        HostTensorDescriptor(std::initializer_list<index_t>{N, K, Ho, Wo}));

    ostream_HostTensorDescriptor(in_nchw.mDesc, std::cout << "in_nchw_desc: ");
    ostream_HostTensorDescriptor(wei_kcyx.mDesc, std::cout << "wei_kcyx_desc: ");
    ostream_HostTensorDescriptor(out_nkhw_host.mDesc, std::cout << "out_nkhw_desc: ");

    print_array("InLeftPads", InLeftPads{});
    print_array("InRightPads", InRightPads{});
    print_array("ConvStrides", ConvStrides{});
    print_array("ConvDilations", ConvDilations{});

    std::size_t num_thread = std::thread::hardware_concurrency();

    if(do_verification)
    {
        switch(init_method)
        {
        case 0:
            in_nchw.GenerateTensorValue(GeneratorTensor_1{}, num_thread);
            wei_kcyx.GenerateTensorValue(GeneratorTensor_1{}, num_thread);
            break;
        case 1:
            in_nchw.GenerateTensorValue(GeneratorTensor_1{}, num_thread);
            wei_kcyx.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread);
            break;
        case 2:
            in_nchw.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread);
            wei_kcyx.GenerateTensorValue(GeneratorTensor_1{}, num_thread);
            break;
        case 3:
            in_nchw.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread);
            wei_kcyx.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread);
            break;
        default:
            in_nchw.GenerateTensorValue(GeneratorTensor_2{1, 5}, num_thread);

            auto gen_wei = [](auto... is) {
                return GeneratorTensor_2{1, 5}(is...) * GeneratorTensor_Checkboard{}(is...);
            };
            wei_kcyx.GenerateTensorValue(gen_wei, num_thread);
        }
    }

    constexpr auto in_nchw_desc = make_native_tensor_descriptor_packed(Sequence<N, C, Hi, Wi>{});
    constexpr auto wei_kcyx_desc = make_native_tensor_descriptor_packed(Sequence<K, C, Y, X>{});
    constexpr auto out_nkhw_desc = make_native_tensor_descriptor_packed(Sequence<N, K, Ho, Wo>{});

#if 1
    device_convolution_forward_implicit_gemm_v4r1_nchw_kcyx_nkhw(in_nchw_desc,
                                                                 in_nchw,
                                                                 wei_kcyx_desc,
                                                                 wei_kcyx,
                                                                 out_nkhw_desc,
                                                                 out_nkhw_device,
                                                                 ConvStrides{},
                                                                 ConvDilations{},
                                                                 InLeftPads{},
                                                                 InRightPads{},
                                                                 nrepeat);
#elif 0
    device_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw(in_nchw_desc,
                                                                 in_nchw,
                                                                 wei_kcyx_desc,
                                                                 wei_kcyx,
                                                                 out_nkhw_desc,
                                                                 out_nkhw_device,
                                                                 ConvStrides{},
                                                                 ConvDilations{},
                                                                 InLeftPads{},
                                                                 InRightPads{},
                                                                 nrepeat);
#elif 0
    device_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk(in_nchw_desc,
                                                                 in_nchw,
                                                                 wei_kcyx_desc,
                                                                 wei_kcyx,
                                                                 out_nkhw_desc,
                                                                 out_nkhw_device,
                                                                 ConvStrides{},
                                                                 ConvDilations{},
                                                                 InLeftPads{},
                                                                 InRightPads{},
                                                                 nrepeat);
#endif

    if(do_verification)
    {
        host_direct_convolution(in_nchw,
                                wei_kcyx,
                                out_nkhw_host,
                                ConvStrides{},
                                ConvDilations{},
                                InLeftPads{},
                                InRightPads{});

        check_error(out_nkhw_host, out_nkhw_device);

        if(do_log)
        {
            LogRange(std::cout << "in_nchw : ", in_nchw.mData, ",") << std::endl;
            LogRange(std::cout << "wei_kcyx: ", wei_kcyx.mData, ",") << std::endl;
            LogRange(std::cout << "out_nkhw_host : ", out_nkhw_host.mData, ",") << std::endl;
            LogRange(std::cout << "out_nkhw_device: ", out_nkhw_device.mData, ",") << std::endl;
        }
    }
}
@@ -24,16 +24,16 @@
#include "device_dynamic_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk.hpp"

#define USE_DYNAMIC_MODE 1
#define USE_CONV_FWD_V4R4_NCHW 0
#define USE_CONV_FWD_V4R4_NHWC 0
#define USE_CONV_FWD_V4R4R2_NHWC 0
#define USE_CONV_FWD_V4R5_NCHW 0
#define USE_CONV_FWD_V4R4_NCHW 1
#define USE_CONV_FWD_V4R4_NHWC 1
#define USE_CONV_FWD_V4R4R2_NHWC 1
#define USE_CONV_FWD_V4R5_NCHW 1
#define USE_CONV_FWD_V4R5R2_NCHW 1
#define USE_CONV_FWD_V5R1_NCHW 0
#define USE_CONV_FWD_V4R4_XDL_NCHW 0
#define USE_CONV_FWD_V4R4R2_XDL_NHWC 0
#define USE_CONV_FWD_V4R4R3_XDL_NHWC 0
#define USE_CONV_FWD_V4R4R4_XDL_NHWC 0
#define USE_CONV_FWD_V4R4_XDL_NCHW 1
#define USE_CONV_FWD_V4R4R2_XDL_NHWC 1
#define USE_CONV_FWD_V4R4R3_XDL_NHWC 1
#define USE_CONV_FWD_V4R4R4_XDL_NHWC 1

enum ConvForwardAlgo
{

@@ -1,7 +1,6 @@
#ifndef CONV_COMMON_HPP
#define CONV_COMMON_HPP

#include "tensor_descriptor.hpp"
#include "dynamic_tensor_descriptor.hpp"

enum ConvTensorLayout
@@ -13,53 +12,6 @@ enum ConvTensorLayout
    NHWCc
};

template <class InDesc,
          class WeiDesc,
          class ConvStrides,
          class ConvDilations,
          class LeftPads,
          class RightPads>
constexpr auto get_convolution_output_default_4d_tensor_descriptor(
    InDesc, WeiDesc, ConvStrides, ConvDilations, LeftPads, RightPads)
{
    using namespace ck;

    constexpr auto in_desc = InDesc{};
    constexpr auto wei_desc = WeiDesc{};

    constexpr auto I0 = Number<0>{};
    constexpr auto I1 = Number<1>{};
    constexpr auto I2 = Number<2>{};
    constexpr auto I3 = Number<3>{};

    static_assert(in_desc.GetNumOfDimension() == 4, "input nDim is not 4");
    static_assert(wei_desc.GetNumOfDimension() == 4, "weight nDim is not 4");
    static_assert(in_desc.GetLength(I1) == wei_desc.GetLength(I1),
                  "input & weight dimension not consistent");

    constexpr index_t N = in_desc.GetLength(I0);
    constexpr index_t Hi = in_desc.GetLength(I2);
    constexpr index_t Wi = in_desc.GetLength(I3);

    constexpr index_t K = wei_desc.GetLength(I0);
    constexpr index_t Y = wei_desc.GetLength(I2);
    constexpr index_t X = wei_desc.GetLength(I3);

    constexpr index_t LeftPadH = LeftPads{}.Get(I0);
    constexpr index_t LeftPadW = LeftPads{}.Get(I1);

    constexpr index_t RightPadH = RightPads{}.Get(I0);
    constexpr index_t RightPadW = RightPads{}.Get(I1);

    constexpr index_t YEff = (Y - 1) * ConvDilations{}[0] + 1;
    constexpr index_t XEff = (X - 1) * ConvDilations{}[1] + 1;

    constexpr index_t Ho = (Hi + LeftPadH + RightPadH - YEff) / ConvStrides{}[0] + 1;
    constexpr index_t Wo = (Wi + LeftPadW + RightPadW - XEff) / ConvStrides{}[1] + 1;

    return make_native_tensor_descriptor_packed(Sequence<N, K, Ho, Wo>{});
}

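// Usage sketch (editor's annotation, not in the original header): the drivers above
// call this as, e.g.,
//   constexpr auto out_nkhw_desc = get_convolution_output_default_4d_tensor_descriptor(
//       in_nchw_desc, wei_kcyx_desc, ConvStrides{}, ConvDilations{}, LeftPads{}, RightPads{});
// which for a 3x3 filter with 1x1 pads and unit strides/dilations yields an
// N x K x Ho x Wo descriptor with Ho == Hi and Wo == Wi.
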
template <typename... InDesc,
          typename... WeiDesc,
          typename ConvStrides,
@@ -131,30 +83,4 @@ calculate_convolution_flops(const InDesc& in_desc, const WeiDesc& wei_desc, cons
    return std::size_t(2) * N * K * Ho * Wo * C * Y * X;
}

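// Note (editor's annotation): 2 * N * K * Ho * Wo * C * Y * X counts one multiply
// and one add per MAC over every output point and filter tap; e.g. the active
// 1x7, 17x17, N = 128 driver config gives 2 * 128 * 128 * 17 * 17 * 128 * 1 * 7,
// roughly 8.5 GFlop per convolution.
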
template <class Float, class InDesc, class WeiDesc, class OutDesc>
constexpr std::size_t calculate_convolution_memory_size(Float, InDesc, WeiDesc, OutDesc)
{
    using namespace ck;

    constexpr auto wei_desc = WeiDesc{};
    constexpr auto out_desc = OutDesc{};

    constexpr auto I0 = Number<0>{};
    constexpr auto I1 = Number<1>{};
    constexpr auto I2 = Number<2>{};
    constexpr auto I3 = Number<3>{};

    constexpr index_t N = out_desc.GetLength(I0);
    constexpr index_t K = out_desc.GetLength(I1);
    constexpr index_t Ho = out_desc.GetLength(I2);
    constexpr index_t Wo = out_desc.GetLength(I3);

    constexpr index_t C = wei_desc.GetLength(I1);
    constexpr index_t Y = wei_desc.GetLength(I2);
    constexpr index_t X = wei_desc.GetLength(I3);

    return sizeof(Float) *
           (InDesc::GetElementSpace() + WeiDesc::GetElementSpace() + OutDesc::GetElementSpace());
}

#endif

@@ -1,221 +0,0 @@
#pragma once
#include <unistd.h>
#include "device.hpp"
#include "host_tensor.hpp"
#include "gridwise_operation_wrapper.hpp"
#include "gridwise_convolution_backward_data_implicit_gemm_v1r1_nchw_kcyx_nkhw.hpp"

namespace launcher {

using namespace ck;

template <typename T,
          typename InDesc,
          typename WeiDesc,
          typename OutDesc,
          typename ConvStrides,
          typename ConvDilations,
          typename InLeftPads,
          typename InRightPads>
void device_convolution_backward_data_implicit_gemm_v1r1_nchw_kcyx_nkhw(InDesc in_nchw_desc,
                                                                        Tensor<T>& in_nchw,
                                                                        WeiDesc wei_kcyx_desc,
                                                                        const Tensor<T>& wei_kcyx,
                                                                        OutDesc out_nkhw_desc,
                                                                        const Tensor<T>& out_nkhw,
                                                                        ConvStrides,
                                                                        ConvDilations,
                                                                        InLeftPads,
                                                                        InRightPads,
                                                                        std::size_t nrepeat)
{
    using namespace ck;

    constexpr index_t N = out_nkhw_desc.GetLengths()[0];
    constexpr index_t K = out_nkhw_desc.GetLengths()[1];
    constexpr index_t Ho = out_nkhw_desc.GetLengths()[2];
    constexpr index_t Wo = out_nkhw_desc.GetLengths()[3];

    constexpr index_t C = wei_kcyx_desc.GetLengths()[1];
    constexpr index_t Y = wei_kcyx_desc.GetLengths()[2];
    constexpr index_t X = wei_kcyx_desc.GetLengths()[3];

    std::size_t data_sz = sizeof(T);
    DeviceMem in_nchw_device_buf(data_sz * in_nchw.mDesc.GetElementSpace());
    DeviceMem wei_kcyx_device_buf(data_sz * wei_kcyx.mDesc.GetElementSpace());
    DeviceMem out_nkhw_device_buf(data_sz * out_nkhw.mDesc.GetElementSpace());

    in_nchw_device_buf.ToDevice(in_nchw.mData.data());
    wei_kcyx_device_buf.ToDevice(wei_kcyx.mData.data());
    out_nkhw_device_buf.ToDevice(out_nkhw.mData.data());

#if 1
    // BlockSize = 256, each thread holds 64 elements
    constexpr index_t BlockSize = 256;

    constexpr index_t GemmMPerBlock = 128;
    constexpr index_t GemmNPerBlock = 128;
    constexpr index_t GemmKPerBlock = 8;
    constexpr index_t GemmMPerThread = 4;
    constexpr index_t GemmNPerThread = 4;
    constexpr index_t GemmKPerThread = 1;
    constexpr index_t GemmMLevel0Cluster = 4;
    constexpr index_t GemmNLevel0Cluster = 4;
    constexpr index_t GemmMLevel1Cluster = 4;
    constexpr index_t GemmNLevel1Cluster = 4;
    constexpr index_t GemmThreadGemmDataPerReadM = 4;
    constexpr index_t GemmThreadGemmDataPerReadN = 4;

    using GemmABlockCopyThreadSliceLengths_GemmK_GemmM = Sequence<1, 4>;
    using GemmABlockCopyThreadClusterLengths_GemmK_GemmM = Sequence<8, 32>;

    constexpr index_t GemmABlockCopySrcDataPerRead_GemmM = 4;
    constexpr index_t GemmABlockCopyDstDataPerWrite_GemmM = 4;

    using GemmBBlockCopyThreadSliceLengths_GemmK_GemmN = Sequence<4, 1>;
    using GemmBBlockCopyThreadClusterLengths_GemmK_GemmN = Sequence<2, 128>;

    constexpr index_t GemmBBlockCopySrcDataPerRead_GemmN = 1;
    constexpr index_t GemmBBlockCopyDstDataPerWrite_GemmN = 1;

    constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmN1 = 1;
#elif 1
    // BlockSize = 256, each thread holds 64 elements
    constexpr index_t BlockSize = 256;

    constexpr index_t GemmMPerBlock = 128;
    constexpr index_t GemmNPerBlock = 128;
    constexpr index_t GemmKPerBlock = 8;
    constexpr index_t GemmMPerThread = 4;
    constexpr index_t GemmNPerThread = 4;
    constexpr index_t GemmKPerThread = 1;
    constexpr index_t GemmMLevel0Cluster = 4;
    constexpr index_t GemmNLevel0Cluster = 4;
    constexpr index_t GemmMLevel1Cluster = 4;
    constexpr index_t GemmNLevel1Cluster = 4;
    constexpr index_t GemmThreadGemmDataPerReadM = 4;
    constexpr index_t GemmThreadGemmDataPerReadN = 4;

    using GemmABlockCopyThreadSliceLengths_GemmK_GemmM = Sequence<1, 4>;
    using GemmABlockCopyThreadClusterLengths_GemmK_GemmM = Sequence<8, 32>;

    constexpr index_t GemmABlockCopySrcDataPerRead_GemmM = 4;
    constexpr index_t GemmABlockCopyDstDataPerWrite_GemmM = 4;

    using GemmBBlockCopyThreadSliceLengths_GemmK_GemmN = Sequence<1, 4>;
    using GemmBBlockCopyThreadClusterLengths_GemmK_GemmN = Sequence<8, 32>;

    constexpr index_t GemmBBlockCopySrcDataPerRead_GemmN = 4;
    constexpr index_t GemmBBlockCopyDstDataPerWrite_GemmN = 4;

    constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmN1 = 4;
#elif 1
    // BlockSize = 256, each thread holds 64 elements
    constexpr index_t BlockSize = 256;

    constexpr index_t GemmMPerBlock = 128;
    constexpr index_t GemmNPerBlock = 128;
    constexpr index_t GemmKPerBlock = 16;
    constexpr index_t GemmMPerThread = 4;
    constexpr index_t GemmNPerThread = 4;
    constexpr index_t GemmKPerThread = 1;
    constexpr index_t GemmMLevel0Cluster = 4;
    constexpr index_t GemmNLevel0Cluster = 4;
    constexpr index_t GemmMLevel1Cluster = 4;
    constexpr index_t GemmNLevel1Cluster = 4;
    constexpr index_t GemmThreadGemmDataPerReadM = 4;
    constexpr index_t GemmThreadGemmDataPerReadN = 4;

    using GemmABlockCopyThreadSliceLengths_GemmK_GemmM = Sequence<2, 4>;
    using GemmABlockCopyThreadClusterLengths_GemmK_GemmM = Sequence<8, 32>;

    constexpr index_t GemmABlockCopySrcDataPerRead_GemmM = 4;
    constexpr index_t GemmABlockCopyDstDataPerWrite_GemmM = 4;

    using GemmBBlockCopyThreadSliceLengths_GemmK_GemmN = Sequence<2, 4>;
    using GemmBBlockCopyThreadClusterLengths_GemmK_GemmN = Sequence<8, 32>;

    constexpr index_t GemmBBlockCopySrcDataPerRead_GemmN = 4;
    constexpr index_t GemmBBlockCopyDstDataPerWrite_GemmN = 4;

    constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmN1 = 4;
#endif

    constexpr index_t GemmM = C * Y * X;
    constexpr index_t GemmN = N * Ho * Wo;

    constexpr index_t GridSize = math::integer_divide_ceil(GemmM, GemmMPerBlock) *
                                 math::integer_divide_ceil(GemmN, GemmNPerBlock);

    printf("%s: BlockSize %u, GridSize %u \n", __func__, BlockSize, GridSize);

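    // Note (editor's annotation): v1r1 treats backward data as a single implicit
    // GEMM with GemmM = C * Y * X rows (one per input channel and filter tap) and
    // GemmN = N * Ho * Wo columns (one per output-gradient point); the grid then
    // tiles that GemmM x GemmN problem into GemmMPerBlock x GemmNPerBlock workgroups.
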
    using gridwise_conv_bwd_data = GridwiseConvolutionBackwardDataImplicitGemm_v1r1_nchw_kcyx_nkhw<
        GridSize,
        BlockSize,
        T,
        T,
        decltype(in_nchw_desc),
        decltype(wei_kcyx_desc),
        decltype(out_nkhw_desc),
        ConvStrides,
        ConvDilations,
        InLeftPads,
        InRightPads,
        GemmMPerBlock,
        GemmNPerBlock,
        GemmKPerBlock,
        GemmMPerThread,
        GemmNPerThread,
        GemmKPerThread,
        GemmMLevel0Cluster,
        GemmNLevel0Cluster,
        GemmMLevel1Cluster,
        GemmNLevel1Cluster,
        GemmThreadGemmDataPerReadM,
        GemmThreadGemmDataPerReadN,
        GemmABlockCopyThreadSliceLengths_GemmK_GemmM,
        GemmABlockCopyThreadClusterLengths_GemmK_GemmM,
        GemmABlockCopySrcDataPerRead_GemmM,
        GemmABlockCopyDstDataPerWrite_GemmM,
        GemmBBlockCopyThreadSliceLengths_GemmK_GemmN,
        GemmBBlockCopyThreadClusterLengths_GemmK_GemmN,
        GemmBBlockCopySrcDataPerRead_GemmN,
        GemmBBlockCopyDstDataPerWrite_GemmN,
        GemmCThreadCopyDstDataPerWrite_GemmN1>;

    for(index_t i = 0; i < 1; ++i)
    {
        std::cout << "Start running " << nrepeat << " times..." << std::endl;

        KernelTimer timer;
        timer.Start();

        for(index_t j = 0; j < nrepeat; ++j)
        {
            launch_kernel(run_gridwise_operation<gridwise_conv_bwd_data,
                                                 T* const __restrict__,
                                                 const T* const __restrict__,
                                                 const T* const __restrict__>,
                          dim3(GridSize),
                          dim3(BlockSize),
                          0,
                          0,
                          static_cast<T*>(in_nchw_device_buf.GetDeviceBuffer()),
                          static_cast<T*>(wei_kcyx_device_buf.GetDeviceBuffer()),
                          static_cast<T*>(out_nkhw_device_buf.GetDeviceBuffer()));
        }

        timer.End();

        float ave_time = timer.GetElapsedTime() / nrepeat;

        float perf = (float)calculate_convolution_flops(InDesc{}, WeiDesc{}, OutDesc{}) /
                     (std::size_t(1000) * 1000 * 1000) / ave_time;

        std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s" << std::endl;
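        // Units check (editor's annotation): flops / 1e9 / milliseconds equals
        // flops / (1e12 * seconds), so the figure printed above is indeed TFlop/s
        // even though the explicit divisor is only 1e9.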
    }

    in_nchw_device_buf.FromDevice(in_nchw.mData.data());
}

} // namespace launcher
@@ -1,171 +0,0 @@
#pragma once
#include <unistd.h>
#include "device.hpp"
#include "host_tensor.hpp"
#include "gridwise_operation_wrapper.hpp"
#include "gridwise_convolution_backward_data_implicit_gemm_v1r2_nchw_kcyx_nkhw_lds_double_buffer.hpp"

namespace launcher {

using namespace ck;

template <typename T,
          typename InDesc,
          typename WeiDesc,
          typename OutDesc,
          typename ConvStrides,
          typename ConvDilations,
          typename LeftPads,
          typename RightPads>
void device_convolution_backward_data_implicit_gemm_v1r2_nchw_kcyx_nkhw(InDesc in_nchw_desc,
                                                                        Tensor<T>& in_nchw,
                                                                        WeiDesc wei_kcyx_desc,
                                                                        const Tensor<T>& wei_kcyx,
                                                                        OutDesc out_nkhw_desc,
                                                                        const Tensor<T>& out_nkhw,
                                                                        ConvStrides,
                                                                        ConvDilations,
                                                                        LeftPads,
                                                                        RightPads,
                                                                        std::size_t nrepeat)
{
    using namespace ck;

    constexpr index_t N = out_nkhw_desc.GetLengths()[0];
    constexpr index_t K = out_nkhw_desc.GetLengths()[1];
    constexpr index_t Ho = out_nkhw_desc.GetLengths()[2];
    constexpr index_t Wo = out_nkhw_desc.GetLengths()[3];

    constexpr index_t C = wei_kcyx_desc.GetLengths()[1];
    constexpr index_t Y = wei_kcyx_desc.GetLengths()[2];
    constexpr index_t X = wei_kcyx_desc.GetLengths()[3];

    std::size_t data_sz = sizeof(T);
    DeviceMem in_nchw_device_buf(data_sz * in_nchw.mDesc.GetElementSpace());
    DeviceMem wei_kcyx_device_buf(data_sz * wei_kcyx.mDesc.GetElementSpace());
    DeviceMem out_nkhw_device_buf(data_sz * out_nkhw.mDesc.GetElementSpace());

    in_nchw_device_buf.ToDevice(in_nchw.mData.data());
    wei_kcyx_device_buf.ToDevice(wei_kcyx.mData.data());
    out_nkhw_device_buf.ToDevice(out_nkhw.mData.data());

#if 1
    // BlockSize = 256, each thread holds 64 elements
    constexpr index_t BlockSize = 256;

    constexpr index_t BPerBlock = 32;
    constexpr index_t EPerBlock = 32;
    constexpr index_t KPerBlock = 16;

    constexpr index_t GemmMPerThread = 4;
    constexpr index_t GemmNPerThread = 4;
    constexpr index_t GemmKPerThread = 1;

    constexpr index_t GemmMLevel0Cluster = 4;
    constexpr index_t GemmNLevel0Cluster = 4;
    constexpr index_t GemmMLevel1Cluster = 4;
    constexpr index_t GemmNLevel1Cluster = 4;

    constexpr index_t GemmDataPerReadA = 4;
    constexpr index_t GemmDataPerReadB = 4;

    using OutBlockCopySubLengths_K_B_N0 = Sequence<2, 1, 4>;
    using OutBlockCopyClusterLengths_K_B_N0 = Sequence<8, 32, 1>;

    constexpr index_t OutBlockCopySrcDataPerRead_B = 1;
    constexpr index_t OutBlockCopyDstDataPerWrite_N0 = 4;

    using WeiBlockCopySubLengths_K_E_C0 = Sequence<2, 4, 1>;
    using WeiBlockCopyClusterLengths_K_E_C0 = Sequence<8, 8, 4>;

    constexpr index_t WeiBlockCopySrcDataPerRead_E = 4;
    constexpr index_t WeiBlockCopyDstDataPerWrite_C0 = 1;

    constexpr index_t InThreadCopyDstDataPerWrite_B = 1;
#endif

    constexpr index_t C0 = GemmMPerThread;
    constexpr index_t N0 = GemmNPerThread;

    constexpr index_t C1 = C / C0;
    constexpr index_t N1 = N / N0;

    constexpr index_t E = C1 * Y * X;
    constexpr index_t B = (N1 * Ho * Wo);

    constexpr index_t GridSize =
        ((E + EPerBlock - 1) / EPerBlock) * ((B + BPerBlock - 1) / BPerBlock);

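    // Note (editor's annotation): v1r2 splits C into C1 * C0 and N into N1 * N0,
    // with C0/N0 being the per-thread tile sizes, so the implicit GEMM runs over
    // E = C1 * Y * X and B = N1 * Ho * Wo, and the grid covers the E x B space
    // in EPerBlock x BPerBlock tiles (rounding both up, as computed above).
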
printf("%s: BlockSize %u, GridSize %u \n", __func__, BlockSize, GridSize);
|
||||
|
||||
using gridwise_conv_bwd_data =
|
||||
GridwiseConvolutionBackwardDataImplicitGemm_v1r2_nchw_kcyx_nkhw_lds_double_buffer<
|
||||
GridSize,
|
||||
BlockSize,
|
||||
T,
|
||||
T,
|
||||
decltype(in_nchw_desc),
|
||||
decltype(wei_kcyx_desc),
|
||||
decltype(out_nkhw_desc),
|
||||
ConvStrides,
|
||||
ConvDilations,
|
||||
LeftPads,
|
||||
RightPads,
|
||||
EPerBlock,
|
||||
BPerBlock,
|
||||
KPerBlock,
|
||||
GemmMPerThread,
|
||||
GemmNPerThread,
|
||||
GemmKPerThread,
|
||||
GemmMLevel0Cluster,
|
||||
GemmNLevel0Cluster,
|
||||
GemmMLevel1Cluster,
|
||||
GemmNLevel1Cluster,
|
||||
GemmDataPerReadA,
|
||||
GemmDataPerReadB,
|
||||
OutBlockCopySubLengths_K_B_N0,
|
||||
OutBlockCopyClusterLengths_K_B_N0,
|
||||
OutBlockCopySrcDataPerRead_B,
|
||||
OutBlockCopyDstDataPerWrite_N0,
|
||||
WeiBlockCopySubLengths_K_E_C0,
|
||||
WeiBlockCopyClusterLengths_K_E_C0,
|
||||
WeiBlockCopySrcDataPerRead_E,
|
||||
WeiBlockCopyDstDataPerWrite_C0,
|
||||
InThreadCopyDstDataPerWrite_B>;
|
||||
|
||||
for(index_t i = 0; i < 5; ++i)
|
||||
{
|
||||
std::cout << "Start running " << nrepeat << " times..." << std::endl;
|
||||
|
||||
KernelTimer timer;
|
||||
timer.Start();
|
||||
|
||||
for(index_t j = 0; j < nrepeat; ++j)
|
||||
{
|
||||
launch_kernel(run_gridwise_operation<gridwise_conv_bwd_data,
|
||||
T* const __restrict__,
|
||||
const T* const __restrict__,
|
||||
const T* const __restrict__>,
|
||||
dim3(GridSize),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
0,
|
||||
static_cast<T*>(in_nchw_device_buf.GetDeviceBuffer()),
|
||||
static_cast<T*>(wei_kcyx_device_buf.GetDeviceBuffer()),
|
||||
static_cast<T*>(out_nkhw_device_buf.GetDeviceBuffer()));
|
||||
}
|
||||
|
||||
timer.End();
|
||||
|
||||
float ave_time = timer.GetElapsedTime() / nrepeat;
|
||||
|
||||
float perf = (float)calculate_convolution_flops(InDesc{}, WeiDesc{}, OutDesc{}) /
|
||||
(std::size_t(1000) * 1000 * 1000) / ave_time;
|
||||
|
||||
std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s" << std::endl;
|
||||
}
|
||||
|
||||
in_nchw_device_buf.FromDevice(in_nchw.mData.data());
|
||||
}
|
||||
|
||||
} // namespace launcher
|
||||
@@ -1,267 +0,0 @@
#pragma once
#include <unistd.h>
#include "device.hpp"
#include "host_tensor.hpp"
#include "gridwise_operation_wrapper.hpp"
#include "gridwise_convolution_backward_data_implicit_gemm_v4r1_nchw_kcyx_nkhw.hpp"

namespace launcher {

using namespace ck;

template <typename T,
          typename InDesc,
          typename WeiDesc,
          typename OutDesc,
          typename ConvStrides,
          typename ConvDilations,
          typename InLeftPads,
          typename InRightPads>
void device_convolution_backward_data_implicit_gemm_v4r1_nchw_kcyx_nkhw(InDesc in_nchw_desc,
                                                                        Tensor<T>& in_nchw,
                                                                        WeiDesc wei_kcyx_desc,
                                                                        const Tensor<T>& wei_kcyx,
                                                                        OutDesc out_nkhw_desc,
                                                                        const Tensor<T>& out_nkhw,
                                                                        ConvStrides,
                                                                        ConvDilations,
                                                                        InLeftPads,
                                                                        InRightPads,
                                                                        std::size_t nrepeat)
{
    constexpr index_t N = out_nkhw_desc.GetLengths()[0];
    constexpr index_t K = out_nkhw_desc.GetLengths()[1];
    constexpr index_t C = wei_kcyx_desc.GetLengths()[1];

    constexpr index_t Hi = in_nchw_desc.GetLengths()[2];
    constexpr index_t Wi = in_nchw_desc.GetLengths()[3];

    constexpr index_t Ho = out_nkhw_desc.GetLengths()[2];
    constexpr index_t Wo = out_nkhw_desc.GetLengths()[3];

    constexpr index_t Y = wei_kcyx_desc.GetLengths()[2];
    constexpr index_t X = wei_kcyx_desc.GetLengths()[3];

    constexpr index_t ConvStrideH = ConvStrides{}[0];
    constexpr index_t ConvStrideW = ConvStrides{}[1];

    constexpr index_t ConvDilationH = ConvDilations{}[0];
    constexpr index_t ConvDilationW = ConvDilations{}[1];

    std::size_t data_sz = sizeof(T);
    DeviceMem in_nchw_device_buf(data_sz * in_nchw.mDesc.GetElementSpace());
    DeviceMem wei_kcyx_device_buf(data_sz * wei_kcyx.mDesc.GetElementSpace());
    DeviceMem out_nkhw_device_buf(data_sz * out_nkhw.mDesc.GetElementSpace());

    in_nchw_device_buf.ToDevice(in_nchw.mData.data());
    wei_kcyx_device_buf.ToDevice(wei_kcyx.mData.data());
    out_nkhw_device_buf.ToDevice(out_nkhw.mData.data());

#if 1
    // cdata = 64, BlockSize = 256, 128x128x8
    constexpr index_t BlockSize = 256;

    constexpr index_t GemmMPerBlock = 128;
    constexpr index_t GemmNPerBlock = 128;
    constexpr index_t GemmKPerBlock = 8;
    constexpr index_t GemmMPerThread = 4;
    constexpr index_t GemmNPerThread = 4;
    constexpr index_t GemmKPerThread = 1;
    constexpr index_t GemmMLevel0Cluster = 2;
    constexpr index_t GemmNLevel0Cluster = 2;
    constexpr index_t GemmMLevel1Cluster = 8;
    constexpr index_t GemmNLevel1Cluster = 8;
    constexpr index_t GemmThreadGemmDataPerReadM = 4;
    constexpr index_t GemmThreadGemmDataPerReadN = 4;

    using GemmABlockCopyThreadSliceLengths_GemmK_GemmM = Sequence<1, 4>;
    using GemmABlockCopyThreadClusterLengths_GemmK_GemmM = Sequence<8, 32>;

    constexpr index_t GemmABlockCopySrcDataPerRead_GemmM = 1;
    constexpr index_t GemmABlockCopyDstDataPerWrite_GemmM = 4;

    using GemmBBlockCopyThreadSliceLengths_GemmK_GemmN = Sequence<4, 1>;
    using GemmBBlockCopyThreadClusterLengths_GemmK_GemmN = Sequence<2, 128>;

    constexpr index_t GemmBBlockCopySrcDataPerRead_GemmN = 1;
    constexpr index_t GemmBBlockCopyDstDataPerWrite_GemmN = 1;

    constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmN1 = 1;
#elif 1
    // cdata = 64, BlockSize = 256, 128x128x8
    // GemmABlockCopySrcDataPerRead_GemmM = 4
    constexpr index_t BlockSize = 256;

    constexpr index_t GemmMPerBlock = 128;
    constexpr index_t GemmNPerBlock = 128;
    constexpr index_t GemmKPerBlock = 8;
    constexpr index_t GemmMPerThread = 4;
    constexpr index_t GemmNPerThread = 4;
    constexpr index_t GemmKPerThread = 1;
    constexpr index_t GemmMLevel0Cluster = 4;
    constexpr index_t GemmNLevel0Cluster = 4;
    constexpr index_t GemmMLevel1Cluster = 4;
    constexpr index_t GemmNLevel1Cluster = 4;
    constexpr index_t GemmThreadGemmDataPerReadM = 4;
    constexpr index_t GemmThreadGemmDataPerReadN = 4;

    using GemmABlockCopyThreadSliceLengths_GemmK_GemmM = Sequence<1, 4>;
    using GemmABlockCopyThreadClusterLengths_GemmK_GemmM = Sequence<8, 32>;

    constexpr index_t GemmABlockCopySrcDataPerRead_GemmM = 4;
    constexpr index_t GemmABlockCopyDstDataPerWrite_GemmM = 4;

    using GemmBBlockCopyThreadSliceLengths_GemmK_GemmN = Sequence<4, 1>;
    using GemmBBlockCopyThreadClusterLengths_GemmK_GemmN = Sequence<2, 128>;

    constexpr index_t GemmBBlockCopySrcDataPerRead_GemmN = 1;
    constexpr index_t GemmBBlockCopyDstDataPerWrite_GemmN = 1;

    constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmN1 = 1;
#elif 1
    // cdata = 64, BlockSize = 256, 128x128x16
    constexpr index_t BlockSize = 256;

    constexpr index_t GemmMPerBlock = 128;
    constexpr index_t GemmNPerBlock = 128;
    constexpr index_t GemmKPerBlock = 16;
    constexpr index_t GemmMPerThread = 4;
    constexpr index_t GemmNPerThread = 4;
    constexpr index_t GemmKPerThread = 1;
    constexpr index_t GemmMLevel0Cluster = 4;
    constexpr index_t GemmNLevel0Cluster = 4;
    constexpr index_t GemmMLevel1Cluster = 4;
    constexpr index_t GemmNLevel1Cluster = 4;
    constexpr index_t GemmThreadGemmDataPerReadM = 4;
    constexpr index_t GemmThreadGemmDataPerReadN = 4;

    using GemmABlockCopyThreadSliceLengths_GemmK_GemmM = Sequence<1, 8>;
    using GemmABlockCopyThreadClusterLengths_GemmK_GemmM = Sequence<16, 16>;

    constexpr index_t GemmABlockCopySrcDataPerRead_GemmM = 1;
    constexpr index_t GemmABlockCopyDstDataPerWrite_GemmM = 4;

    using GemmBBlockCopyThreadSliceLengths_GemmK_GemmN = Sequence<8, 1>;
    using GemmBBlockCopyThreadClusterLengths_GemmK_GemmN = Sequence<2, 128>;

    constexpr index_t GemmBBlockCopySrcDataPerRead_GemmN = 1;
    constexpr index_t GemmBBlockCopyDstDataPerWrite_GemmN = 1;

    constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmN1 = 1;
#endif

    constexpr index_t GcdStrideDilationH = math::gcd(ConvStrideH, ConvDilationH);
    constexpr index_t GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW);

    constexpr index_t YTilda = ConvStrideH / GcdStrideDilationH;
    constexpr index_t XTilda = ConvStrideW / GcdStrideDilationW;

    constexpr index_t YDot = math::integer_divide_ceil(Y, YTilda);
    constexpr index_t XDot = math::integer_divide_ceil(X, XTilda);

    constexpr index_t HTilda = Ho + math::integer_divide_ceil(ConvDilationH * (Y - 1), ConvStrideH);
    constexpr index_t WTilda = Wo + math::integer_divide_ceil(ConvDilationW * (X - 1), ConvStrideW);

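    // Note (editor's annotation): v4r1 splits the filter along the stride/dilation
    // structure. For example, with ConvStrideH = 2 and ConvDilationH = 1 the gcd is 1,
    // so YTilda = 2, and a Y = 3 filter gives YDot = ceil(3 / 2) = 2; each of the
    // YTilda x XTilda residues becomes its own (possibly empty) GEMM in the loop below.
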
    constexpr index_t HTildaLeft = math::integer_divide_floor(
        math::max(0, InLeftPads{}[0] - ConvDilationH * (YTilda - 1)), ConvStrides{}[0]);
    constexpr index_t WTildaLeft = math::integer_divide_floor(
        math::max(0, InLeftPads{}[1] - ConvDilationW * (XTilda - 1)), ConvStrides{}[1]);

    constexpr index_t HTildaRight = math::min(
        HTilda, math::integer_divide_ceil(InLeftPads{}[0] + Hi - 1, ConvStrides{}[0]) + 1);
    constexpr index_t WTildaRight = math::min(
        WTilda, math::integer_divide_ceil(InLeftPads{}[1] + Wi - 1, ConvStrides{}[1]) + 1);

    constexpr index_t HTildaSlice = HTildaRight - HTildaLeft;
    constexpr index_t WTildaSlice = WTildaRight - WTildaLeft;

    constexpr index_t GemmM = C;
    constexpr index_t GemmN = N * HTildaSlice * WTildaSlice;

    constexpr index_t GridSize = math::integer_divide_ceil(GemmM, GemmMPerBlock) *
                                 math::integer_divide_ceil(GemmN, GemmNPerBlock);

    printf("%s: BlockSize %u, GridSize %u \n", __func__, BlockSize, GridSize);
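    // Note (editor's annotation): the HTilda/WTilda window is clipped to
    // [HTildaLeft, HTildaRight) x [WTildaLeft, WTildaRight), so
    // GemmN = N * HTildaSlice * WTildaSlice only covers positions that can
    // actually reach the padded input, keeping each per-residue GEMM grid minimal.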
|
||||
for(index_t i = 0; i < 5; ++i)
|
||||
{
|
||||
std::cout << "Start running " << nrepeat << " times..." << std::endl;
|
||||
|
||||
KernelTimer timer;
|
||||
timer.Start();
|
||||
|
||||
for(index_t i = 0; i < nrepeat; ++i)
|
||||
{
|
||||
using GridwiseConvBwdData =
|
||||
GridwiseConvolutionBackwardDataImplicitGemm_v4r1_nchw_kcyx_nkhw<
|
||||
GridSize,
|
||||
BlockSize,
|
||||
T,
|
||||
T,
|
||||
decltype(in_nchw_desc),
|
||||
decltype(wei_kcyx_desc),
|
||||
decltype(out_nkhw_desc),
|
||||
ConvStrides,
|
||||
ConvDilations,
|
||||
InLeftPads,
|
||||
InRightPads,
|
||||
GemmMPerBlock,
|
||||
GemmNPerBlock,
|
||||
GemmKPerBlock,
|
||||
GemmMPerThread,
|
||||
GemmNPerThread,
|
||||
GemmKPerThread,
|
||||
GemmMLevel0Cluster,
|
||||
GemmNLevel0Cluster,
|
||||
GemmMLevel1Cluster,
|
||||
GemmNLevel1Cluster,
|
||||
GemmThreadGemmDataPerReadM,
|
||||
GemmThreadGemmDataPerReadN,
|
||||
GemmABlockCopyThreadSliceLengths_GemmK_GemmM,
|
||||
GemmABlockCopyThreadClusterLengths_GemmK_GemmM,
|
||||
GemmABlockCopySrcDataPerRead_GemmM,
|
||||
GemmABlockCopyDstDataPerWrite_GemmM,
|
||||
GemmBBlockCopyThreadSliceLengths_GemmK_GemmN,
|
||||
GemmBBlockCopyThreadClusterLengths_GemmK_GemmN,
|
||||
GemmBBlockCopySrcDataPerRead_GemmN,
|
||||
GemmBBlockCopyDstDataPerWrite_GemmN,
|
||||
GemmCThreadCopyDstDataPerWrite_GemmN1>;
|
||||
|
||||
static_for<0, GridwiseConvBwdData::GetNumberOfGemm(), 1>{}([&](auto gemm_id) {
|
||||
constexpr auto gemm_sizes = GridwiseConvBwdData::GetGemmSize(gemm_id);
|
||||
constexpr index_t gemm_k = gemm_sizes.At(2);
|
||||
constexpr bool is_gemm_not_empty = gemm_k > 0;
|
||||
|
||||
// only compile and run if GEMM is no empty
|
||||
static_if<is_gemm_not_empty>{}([&](auto fwd) {
|
||||
launch_kernel(run_gridwise_operation<GridwiseConvBwdData,
|
||||
T* const __restrict__,
|
||||
const T* const __restrict__,
|
||||
const T* const __restrict__,
|
||||
decltype(gemm_id)>,
|
||||
dim3(GridSize),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
0,
|
||||
static_cast<T*>(in_nchw_device_buf.GetDeviceBuffer()),
|
||||
static_cast<T*>(wei_kcyx_device_buf.GetDeviceBuffer()),
|
||||
static_cast<T*>(out_nkhw_device_buf.GetDeviceBuffer()),
|
||||
fwd(gemm_id));
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
timer.End();
|
||||
|
||||
float ave_time = timer.GetElapsedTime() / nrepeat;
|
||||
|
||||
float perf = (float)calculate_convolution_flops(InDesc{}, WeiDesc{}, OutDesc{}) /
|
||||
(std::size_t(1000) * 1000 * 1000) / ave_time;
|
||||
|
||||
std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s" << std::endl;
|
||||
}
|
||||
|
||||
in_nchw_device_buf.FromDevice(in_nchw.mData.data());
|
||||
}
|
||||
|
||||
} // namespace launcher
|
||||
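The static_for / static_if pair used by this launcher comes from ck's utility headers. As a rough standalone analogue (a sketch, not the library's actual implementation), the same "skip empty sub-GEMMs at compile time" dispatch can be expressed with C++17 fold expressions and if constexpr, with a hypothetical kernel_traits in place of GetGemmSize:

#include <cstdio>
#include <utility>

// Hypothetical per-sub-GEMM sizes: pretend sub-GEMM 1 is empty.
template <int Id>
struct kernel_traits
{
    static constexpr int gemm_k = (Id == 1) ? 0 : 64;
};

template <int Id>
void launch_one()
{
    if constexpr(kernel_traits<Id>::gemm_k > 0) // static_if analogue
        std::printf("launching sub-GEMM %d (k = %d)\n", Id, kernel_traits<Id>::gemm_k);
}

template <int... Ids>
void launch_all(std::integer_sequence<int, Ids...>) // static_for analogue
{
    (launch_one<Ids>(), ...);
}

int main()
{
    // 4 sub-GEMMs, like YTilda * XTilda above; the empty one compiles to nothing
    launch_all(std::make_integer_sequence<int, 4>{});
    return 0;
}
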
@@ -1,266 +0,0 @@
#pragma once
#include <unistd.h>
#include "device.hpp"
#include "host_tensor.hpp"
#include "gridwise_operation_wrapper.hpp"
#include "gridwise_convolution_backward_data_implicit_gemm_v5r1_nhwc_kyxc_nhwk.hpp"

namespace launcher {

using namespace ck;

template <typename T,
          typename InDesc,
          typename WeiDesc,
          typename OutDesc,
          typename ConvStrides,
          typename ConvDilations,
          typename InLeftPads,
          typename InRightPads>
void device_convolution_backward_data_implicit_gemm_v5r1_nhwc_kyxc_nhwk(InDesc in_nchw_desc,
                                                                        Tensor<T>& in_nchw,
                                                                        WeiDesc wei_kcyx_desc,
                                                                        const Tensor<T>& wei_kcyx,
                                                                        OutDesc out_nkhw_desc,
                                                                        const Tensor<T>& out_nkhw,
                                                                        ConvStrides,
                                                                        ConvDilations,
                                                                        InLeftPads,
                                                                        InRightPads,
                                                                        std::size_t nrepeat)
{
    constexpr index_t N = out_nkhw_desc.GetLengths()[0];
    constexpr index_t K = out_nkhw_desc.GetLengths()[1];
    constexpr index_t C = wei_kcyx_desc.GetLengths()[1];

    constexpr index_t Hi = in_nchw_desc.GetLengths()[2];
    constexpr index_t Wi = in_nchw_desc.GetLengths()[3];

    constexpr index_t Ho = out_nkhw_desc.GetLengths()[2];
    constexpr index_t Wo = out_nkhw_desc.GetLengths()[3];

    constexpr index_t Y = wei_kcyx_desc.GetLengths()[2];
    constexpr index_t X = wei_kcyx_desc.GetLengths()[3];

    constexpr index_t ConvStrideH = ConvStrides{}[0];
    constexpr index_t ConvStrideW = ConvStrides{}[1];

    constexpr index_t ConvDilationH = ConvDilations{}[0];
    constexpr index_t ConvDilationW = ConvDilations{}[1];

    constexpr auto in_nhwc_desc = make_native_tensor_descriptor_packed(Sequence<N, Hi, Wi, C>{});
    constexpr auto wei_kyxc_desc = make_native_tensor_descriptor_packed(Sequence<K, Y, X, C>{});
    constexpr auto out_nhwk_desc = make_native_tensor_descriptor_packed(Sequence<N, Ho, Wo, K>{});

    Tensor<float> in_nhwc(make_HostTensorDescriptor(in_nhwc_desc));
    Tensor<float> wei_kyxc(make_HostTensorDescriptor(wei_kyxc_desc));
    Tensor<float> out_nhwk(make_HostTensorDescriptor(out_nhwk_desc));

    auto f_nchw2nhwc = [&](auto n, auto hi, auto wi, auto c) {
        in_nhwc(n, hi, wi, c) = in_nchw(n, c, hi, wi);
    };

    auto f_kcyx2kyxc = [&](auto k, auto y, auto x, auto c) {
        wei_kyxc(k, y, x, c) = wei_kcyx(k, c, y, x);
    };

    auto f_nkhw2nhwk = [&](auto n, auto ho, auto wo, auto k) {
        out_nhwk(n, ho, wo, k) = out_nkhw(n, k, ho, wo);
    };

    make_ParallelTensorFunctor(f_nchw2nhwc, N, Hi, Wi, C)(std::thread::hardware_concurrency());
    make_ParallelTensorFunctor(f_kcyx2kyxc, K, Y, X, C)(std::thread::hardware_concurrency());
    make_ParallelTensorFunctor(f_nkhw2nhwk, N, Ho, Wo, K)(std::thread::hardware_concurrency());

    std::size_t data_sz = sizeof(T);
    DeviceMem in_nhwc_device_buf(data_sz * in_nhwc.mDesc.GetElementSpace());
    DeviceMem wei_kyxc_device_buf(data_sz * wei_kyxc.mDesc.GetElementSpace());
    DeviceMem out_nhwk_device_buf(data_sz * out_nhwk.mDesc.GetElementSpace());

    in_nhwc_device_buf.ToDevice(in_nhwc.mData.data());
    wei_kyxc_device_buf.ToDevice(wei_kyxc.mData.data());
    out_nhwk_device_buf.ToDevice(out_nhwk.mData.data());

#if 0
    // cdata = 64, BlockSize = 256, 128x128x8
    constexpr index_t BlockSize = 256;

    constexpr index_t GemmMPerBlock = 128;
    constexpr index_t GemmNPerBlock = 128;
    constexpr index_t GemmKPerBlock = 8;
    constexpr index_t GemmMPerThread = 4;
    constexpr index_t GemmNPerThread = 4;
    constexpr index_t GemmKPerThread = 1;
    constexpr index_t GemmMLevel0Cluster = 4;
    constexpr index_t GemmNLevel0Cluster = 4;
    constexpr index_t GemmMLevel1Cluster = 4;
    constexpr index_t GemmNLevel1Cluster = 4;
    constexpr index_t GemmThreadGemmDataPerReadM = 4;
    constexpr index_t GemmThreadGemmDataPerReadN = 4;

    using GemmABlockCopyThreadSliceLengths_GemmK0_GemmK1_GemmK2_GemmM = Sequence<1, 1, 1, 4>;
    using GemmABlockCopyThreadClusterLengths_GemmK0_GemmK1_GemmK2_GemmM = Sequence<1, 1, 8, 32>;

    constexpr index_t GemmABlockCopySrcDataPerRead_GemmM = 4;
    constexpr index_t GemmABlockCopyDstDataPerWrite_GemmM = 4;

    using GemmBBlockCopyThreadSliceLengths_GemmK0_GemmK1_GemmK2_GemmN = Sequence<1, 1, 4, 1>;
    using GemmBBlockCopyThreadClusterLengths_GemmK0_GemmK1_GemmK2_GemmN = Sequence<1, 1, 2, 128>;

    constexpr index_t GemmBBlockCopySrcDataPerRead_GemmK2 = 4;
    constexpr index_t GemmBBlockCopyDstDataPerWrite_GemmN = 1;

    constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmN1 = 1;
#elif 1
    // cdata = 64, BlockSize = 256, 128x128x16
    constexpr index_t BlockSize = 256;

    constexpr index_t GemmMPerBlock = 128;
    constexpr index_t GemmNPerBlock = 128;
    constexpr index_t GemmKPerBlock = 16;
    constexpr index_t GemmMPerThread = 4;
    constexpr index_t GemmNPerThread = 4;
    constexpr index_t GemmKPerThread = 1;
    constexpr index_t GemmMLevel0Cluster = 4;
    constexpr index_t GemmNLevel0Cluster = 4;
    constexpr index_t GemmMLevel1Cluster = 4;
    constexpr index_t GemmNLevel1Cluster = 4;
    constexpr index_t GemmThreadGemmDataPerReadM = 4;
    constexpr index_t GemmThreadGemmDataPerReadN = 4;

    using GemmABlockCopyThreadSliceLengths_GemmK0_GemmK1_GemmK2_GemmM = Sequence<1, 1, 2, 4>;
    using GemmABlockCopyThreadClusterLengths_GemmK0_GemmK1_GemmK2_GemmM = Sequence<1, 1, 8, 32>;

    constexpr index_t GemmABlockCopySrcDataPerRead_GemmM = 4;
    constexpr index_t GemmABlockCopyDstDataPerWrite_GemmM = 4;

    using GemmBBlockCopyThreadSliceLengths_GemmK0_GemmK1_GemmK2_GemmN = Sequence<1, 1, 8, 1>;
    using GemmBBlockCopyThreadClusterLengths_GemmK0_GemmK1_GemmK2_GemmN = Sequence<1, 1, 2, 128>;

    constexpr index_t GemmBBlockCopySrcDataPerRead_GemmK2 = 4;
    constexpr index_t GemmBBlockCopyDstDataPerWrite_GemmN = 1;

    constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmN1 = 1;
#endif

    constexpr index_t GcdStrideDilationH = math::gcd(ConvStrideH, ConvDilationH);
    constexpr index_t GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW);

    constexpr index_t YTilda = ConvStrideH / GcdStrideDilationH;
    constexpr index_t XTilda = ConvStrideW / GcdStrideDilationW;

    constexpr index_t YDot = math::integer_divide_ceil(Y, YTilda);
    constexpr index_t XDot = math::integer_divide_ceil(X, XTilda);

    constexpr index_t HTilda = Ho + math::integer_divide_ceil(ConvDilationH * (Y - 1), ConvStrideH);
    constexpr index_t WTilda = Wo + math::integer_divide_ceil(ConvDilationW * (X - 1), ConvStrideW);

    constexpr index_t HTildaLeft = math::integer_divide_floor(
        math::max(0, InLeftPads{}[0] - ConvDilationH * (YTilda - 1)), ConvStrides{}[0]);
    constexpr index_t WTildaLeft = math::integer_divide_floor(
        math::max(0, InLeftPads{}[1] - ConvDilationW * (XTilda - 1)), ConvStrides{}[1]);

    constexpr index_t HTildaRight = math::min(
        HTilda, math::integer_divide_ceil(InLeftPads{}[0] + Hi - 1, ConvStrides{}[0]) + 1);
    constexpr index_t WTildaRight = math::min(
        WTilda, math::integer_divide_ceil(InLeftPads{}[1] + Wi - 1, ConvStrides{}[1]) + 1);

    constexpr index_t HTildaSlice = HTildaRight - HTildaLeft;
    constexpr index_t WTildaSlice = WTildaRight - WTildaLeft;

    constexpr index_t GemmM = C;
    constexpr index_t GemmN = N * HTildaSlice * WTildaSlice;

    constexpr index_t GridSize = math::integer_divide_ceil(GemmM, GemmMPerBlock) *
                                 math::integer_divide_ceil(GemmN, GemmNPerBlock);

    printf("%s: BlockSize %u, GridSize %u \n", __func__, BlockSize, GridSize);

    for(index_t i = 0; i < 5; ++i)
    {
        std::cout << "Start running " << nrepeat << " times..." << std::endl;

        KernelTimer timer;
        timer.Start();

        for(index_t j = 0; j < nrepeat; ++j)
        {
            using GridwiseConvBwdData =
                GridwiseConvolutionBackwardDataImplicitGemm_v5r1_nhwc_kyxc_nhwk<
                    GridSize,
                    BlockSize,
                    T,
                    T,
                    decltype(in_nhwc_desc),
                    decltype(wei_kyxc_desc),
                    decltype(out_nhwk_desc),
                    ConvStrides,
                    ConvDilations,
                    InLeftPads,
                    InRightPads,
                    GemmMPerBlock,
                    GemmNPerBlock,
                    GemmKPerBlock,
                    GemmMPerThread,
                    GemmNPerThread,
                    GemmKPerThread,
                    GemmMLevel0Cluster,
                    GemmNLevel0Cluster,
                    GemmMLevel1Cluster,
                    GemmNLevel1Cluster,
                    GemmThreadGemmDataPerReadM,
                    GemmThreadGemmDataPerReadN,
                    GemmABlockCopyThreadSliceLengths_GemmK0_GemmK1_GemmK2_GemmM,
                    GemmABlockCopyThreadClusterLengths_GemmK0_GemmK1_GemmK2_GemmM,
                    GemmABlockCopySrcDataPerRead_GemmM,
                    GemmABlockCopyDstDataPerWrite_GemmM,
                    GemmBBlockCopyThreadSliceLengths_GemmK0_GemmK1_GemmK2_GemmN,
                    GemmBBlockCopyThreadClusterLengths_GemmK0_GemmK1_GemmK2_GemmN,
                    GemmBBlockCopySrcDataPerRead_GemmK2,
                    GemmBBlockCopyDstDataPerWrite_GemmN,
                    GemmCThreadCopyDstDataPerWrite_GemmN1>;

            static_for<0, GridwiseConvBwdData::GetNumberOfGemm(), 1>{}([&](auto gemm_id) {
                constexpr auto gemm_sizes = GridwiseConvBwdData::GetGemmSize(gemm_id);
                constexpr index_t gemm_k2 = gemm_sizes[Number<4>{}];
                constexpr bool is_gemm_not_empty = gemm_k2 > 0;

                // only compile and run the kernel if this GEMM is not empty
                static_if<is_gemm_not_empty>{}([&](auto fwd) {
                    launch_kernel(run_gridwise_operation<GridwiseConvBwdData,
                                                         T* const __restrict__,
                                                         const T* const __restrict__,
                                                         const T* const __restrict__,
                                                         decltype(gemm_id)>,
                                  dim3(GridSize),
                                  dim3(BlockSize),
                                  0,
                                  0,
                                  static_cast<T*>(in_nhwc_device_buf.GetDeviceBuffer()),
                                  static_cast<T*>(wei_kyxc_device_buf.GetDeviceBuffer()),
                                  static_cast<T*>(out_nhwk_device_buf.GetDeviceBuffer()),
                                  fwd(gemm_id));
                });
            });
        }

        timer.End();

        float ave_time = timer.GetElapsedTime() / nrepeat;

        float perf = (float)calculate_convolution_flops(InDesc{}, WeiDesc{}, OutDesc{}) /
                     (std::size_t(1000) * 1000 * 1000) / ave_time;

        std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s" << std::endl;
    }

    in_nhwc_device_buf.FromDevice(in_nhwc.mData.data());

    auto f_nhwc2nchw = [&](auto n, auto c, auto hi, auto wi) {
        in_nchw(n, c, hi, wi) = in_nhwc(n, hi, wi, c);
    };

    make_ParallelTensorFunctor(f_nhwc2nchw, N, C, Hi, Wi)(std::thread::hardware_concurrency());
}

} // namespace launcher

@@ -1,849 +0,0 @@
#pragma once
#include <unistd.h>
#include "device.hpp"
#include "host_tensor.hpp"
#include "gridwise_operation_wrapper.hpp"
#include "gridwise_convolution_forward_implicit_gemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer.hpp"

template <typename T,
          typename InDesc,
          typename WeiDesc,
          typename OutDesc,
          typename ConvStrides,
          typename ConvDilations,
          typename LeftPads,
          typename RightPads>
void device_convolution_forward_implicit_gemm_v4r1_nchw_kcyx_nkhw(InDesc,
                                                                  const Tensor<T>& in_nchw,
                                                                  WeiDesc,
                                                                  const Tensor<T>& wei_kcyx,
                                                                  OutDesc,
                                                                  Tensor<T>& out_nkhw,
                                                                  ConvStrides,
                                                                  ConvDilations,
                                                                  LeftPads,
                                                                  RightPads,
                                                                  ck::index_t nrepeat)
{
    std::cout << "device_convolution_forward_implicit_gemm_v4r1_nchw_kcyx_nkhw" << std::endl;

    using namespace ck;

    using TDevice = typename conditional<is_same<half_float::half, T>::value, half_t, T>::type;

    constexpr auto I0 = Number<0>{};
    constexpr auto I1 = Number<1>{};
    constexpr auto I2 = Number<2>{};
    constexpr auto I3 = Number<3>{};

    constexpr auto in_nchw_desc =
        make_native_tensor_descriptor(InDesc::GetLengths(), InDesc::GetStrides());
    constexpr auto wei_kcyx_desc =
        make_native_tensor_descriptor(WeiDesc::GetLengths(), WeiDesc::GetStrides());
    constexpr auto out_nkhw_desc =
        make_native_tensor_descriptor(OutDesc::GetLengths(), OutDesc::GetStrides());

    constexpr index_t N = out_nkhw_desc.GetLength(I0);
    constexpr index_t K = out_nkhw_desc.GetLength(I1);
    constexpr index_t Ho = out_nkhw_desc.GetLength(I2);
    constexpr index_t Wo = out_nkhw_desc.GetLength(I3);

    std::size_t data_sz = sizeof(T);
    DeviceMem in_nchw_device_buf(data_sz * in_nchw.mDesc.GetElementSpace());
    DeviceMem wei_kcyx_device_buf(data_sz * wei_kcyx.mDesc.GetElementSpace());
    DeviceMem out_nkhw_device_buf(data_sz * out_nkhw.mDesc.GetElementSpace());

    in_nchw_device_buf.ToDevice(in_nchw.mData.data());
    wei_kcyx_device_buf.ToDevice(wei_kcyx.mData.data());
    out_nkhw_device_buf.ToDevice(out_nkhw.mData.data());

#if 0
    // cdata = 64, BlockSize = 256, 64x256x8
    constexpr index_t BlockSize = 256;

    constexpr index_t KPerBlock = 64;
    constexpr index_t BPerBlock = 32;
    constexpr index_t EPerBlock = 8;

    constexpr index_t GemmNRepeat = 2;

    constexpr index_t GemmMPerThread = 4;
    constexpr index_t GemmNPerThread = 4;
    constexpr index_t GemmKPerThread = 1;

    constexpr index_t GemmMLevel0Cluster = 2;
    constexpr index_t GemmNLevel0Cluster = 2;
    constexpr index_t GemmMLevel1Cluster = 4;
    constexpr index_t GemmNLevel1Cluster = 16;

    constexpr index_t GemmDataPerReadA = 4;
    constexpr index_t GemmDataPerReadB = 4;

    using InBlockCopySubLengths_E_N1_B_N2 = Sequence<1, 2, 1, 4>;
    using InBlockCopyClusterLengths_E_N1_B_N2 = Sequence<8, 1, 32, 1>;
    using InBlockCopyThreadClusterArrangeOrder = Sequence<0, 1, 3, 2>; // [E, N1, N2, B]
    using InBlockCopySrcAccessOrder = Sequence<0, 2, 1, 3>;            // [E, B, N1, N2]
    using InBlockCopyDstAccessOrder = Sequence<0, 1, 2, 3>;            // [E, N1, B, N2]

    constexpr index_t InBlockCopySrcDataPerRead_B = 1;
    constexpr index_t InBlockCopyDstDataPerWrite_N2 = 4;

    using WeiBlockCopySubLengths_E_K = Sequence<2, 1>;
    using WeiBlockCopyClusterLengths_E_K = Sequence<4, 64>;
    using WeiBlockCopyThreadClusterArrangeOrder = Sequence<1, 0>; // [K, E]
    using WeiBlockCopySrcAccessOrder = Sequence<1, 0>;            // [K, E]
    using WeiBlockCopyDstAccessOrder = Sequence<0, 1>;            // [E, K]

    constexpr index_t WeiBlockCopySrcDataPerRead_E = 2;
    constexpr index_t WeiBlockCopyDstDataPerWrite_K = 1;
#elif 0
    // cdata = 64, BlockSize = 256, 128x128x4
    constexpr index_t BlockSize = 256;

    constexpr index_t KPerBlock = 128;
    constexpr index_t BPerBlock = 16;
    constexpr index_t EPerBlock = 4;

    constexpr index_t GemmNRepeat = 2;

    constexpr index_t GemmMPerThread = 4;
    constexpr index_t GemmNPerThread = 4;
    constexpr index_t GemmKPerThread = 1;

    constexpr index_t GemmMLevel0Cluster = 2;
    constexpr index_t GemmNLevel0Cluster = 2;
    constexpr index_t GemmMLevel1Cluster = 8;
    constexpr index_t GemmNLevel1Cluster = 8;

    constexpr index_t GemmDataPerReadA = 4;
    constexpr index_t GemmDataPerReadB = 4;

    using InBlockCopySubLengths_E_N1_B_N2 = Sequence<1, 1, 1, 2>;
    using InBlockCopyClusterLengths_E_N1_B_N2 = Sequence<4, 2, 16, 2>;
    using InBlockCopyThreadClusterArrangeOrder = Sequence<0, 1, 3, 2>; // [E, N1, N2, B]
    using InBlockCopySrcAccessOrder = Sequence<0, 2, 1, 3>;            // [E, B, N1, N2]
    using InBlockCopyDstAccessOrder = Sequence<0, 1, 2, 3>;            // [E, N1, B, N2]

    constexpr index_t InBlockCopySrcDataPerRead_B = 1;
    constexpr index_t InBlockCopyDstDataPerWrite_N2 = 2;

    using WeiBlockCopySubLengths_E_K = Sequence<2, 1>;
    using WeiBlockCopyClusterLengths_E_K = Sequence<2, 128>;
    using WeiBlockCopyThreadClusterArrangeOrder = Sequence<1, 0>; // [K, E]
    using WeiBlockCopySrcAccessOrder = Sequence<1, 0>;            // [K, E]
    using WeiBlockCopyDstAccessOrder = Sequence<0, 1>;            // [E, K]

    constexpr index_t WeiBlockCopySrcDataPerRead_E = 2;
    constexpr index_t WeiBlockCopyDstDataPerWrite_K = 1;
#elif 1
    // cdata = 64, BlockSize = 256, 128x128x8
    constexpr index_t BlockSize = 256;

    constexpr index_t KPerBlock = 128;
    constexpr index_t BPerBlock = 16;
    constexpr index_t EPerBlock = 8;

    constexpr index_t GemmNRepeat = 2;

    constexpr index_t GemmMPerThread = 4;
    constexpr index_t GemmNPerThread = 4;
    constexpr index_t GemmKPerThread = 1;

    constexpr index_t GemmMLevel0Cluster = 2;
    constexpr index_t GemmNLevel0Cluster = 2;
    constexpr index_t GemmMLevel1Cluster = 8;
    constexpr index_t GemmNLevel1Cluster = 8;

    constexpr index_t GemmDataPerReadA = 4;
    constexpr index_t GemmDataPerReadB = 4;

    using InBlockCopySubLengths_E_N1_B_N2 = Sequence<1, 1, 1, 4>;
    using InBlockCopyClusterLengths_E_N1_B_N2 = Sequence<8, 2, 16, 1>;
    using InBlockCopyThreadClusterArrangeOrder = Sequence<0, 1, 3, 2>; // [E, N1, N2, B]
    using InBlockCopySrcAccessOrder = Sequence<0, 2, 1, 3>;            // [E, B, N1, N2]
    using InBlockCopyDstAccessOrder = Sequence<0, 1, 2, 3>;            // [E, N1, B, N2]

    constexpr index_t InBlockCopySrcDataPerRead_B = 1;
    constexpr index_t InBlockCopyDstDataPerWrite_N2 = 4;

    using WeiBlockCopySubLengths_E_K = Sequence<4, 1>;
    using WeiBlockCopyClusterLengths_E_K = Sequence<2, 128>;
    using WeiBlockCopyThreadClusterArrangeOrder = Sequence<1, 0>; // [K, E]
    using WeiBlockCopySrcAccessOrder = Sequence<1, 0>;            // [K, E]
    using WeiBlockCopyDstAccessOrder = Sequence<0, 1>;            // [E, K]

    constexpr index_t WeiBlockCopySrcDataPerRead_E = 4;
    constexpr index_t WeiBlockCopyDstDataPerWrite_K = 1;
#elif 1
    // cdata = 64, BlockSize = 256, 128x128x16
    constexpr index_t BlockSize = 256;

    constexpr index_t KPerBlock = 128;
    constexpr index_t BPerBlock = 16;
    constexpr index_t EPerBlock = 16;

    constexpr index_t GemmNRepeat = 2;

    constexpr index_t GemmMPerThread = 4;
    constexpr index_t GemmNPerThread = 4;
    constexpr index_t GemmKPerThread = 1;

    constexpr index_t GemmMLevel0Cluster = 4;
    constexpr index_t GemmNLevel0Cluster = 4;
    constexpr index_t GemmMLevel1Cluster = 4;
    constexpr index_t GemmNLevel1Cluster = 4;

    constexpr index_t GemmDataPerReadA = 4;
    constexpr index_t GemmDataPerReadB = 4;

    using InBlockCopySubLengths_E_N1_B_N2 = Sequence<1, 2, 1, 4>;
    using InBlockCopyClusterLengths_E_N1_B_N2 = Sequence<16, 1, 16, 1>;
    using InBlockCopyThreadClusterArrangeOrder = Sequence<0, 1, 3, 2>; // [E, N1, N2, B]
    using InBlockCopySrcAccessOrder = Sequence<0, 2, 1, 3>;            // [E, B, N1, N2]
    using InBlockCopyDstAccessOrder = Sequence<0, 1, 2, 3>;            // [E, N1, B, N2]

    constexpr index_t InBlockCopySrcDataPerRead_B = 1;
    constexpr index_t InBlockCopyDstDataPerWrite_N2 = 4;

    using WeiBlockCopySubLengths_E_K = Sequence<4, 2>;
    using WeiBlockCopyClusterLengths_E_K = Sequence<4, 64>;
    using WeiBlockCopyThreadClusterArrangeOrder = Sequence<1, 0>; // [K, E]
    using WeiBlockCopySrcAccessOrder = Sequence<1, 0>;            // [K, E]
    using WeiBlockCopyDstAccessOrder = Sequence<0, 1>;            // [E, K]

    constexpr index_t WeiBlockCopySrcDataPerRead_E = 4;
    constexpr index_t WeiBlockCopyDstDataPerWrite_K = 2;
#elif 0
    // cdata = 4, BlockSize = 256, 128x128x16
    // for 1x1
    constexpr index_t BlockSize = 256;

    constexpr index_t KPerBlock = 128;
    constexpr index_t BPerBlock = 16;
    constexpr index_t EPerBlock = 16;

    constexpr index_t GemmNRepeat = 2;

    constexpr index_t GemmMPerThread = 4;
    constexpr index_t GemmNPerThread = 4;
    constexpr index_t GemmKPerThread = 1;

    constexpr index_t GemmMLevel0Cluster = 4;
    constexpr index_t GemmNLevel0Cluster = 4;
    constexpr index_t GemmMLevel1Cluster = 4;
    constexpr index_t GemmNLevel1Cluster = 4;

    constexpr index_t GemmDataPerReadA = 4;
    constexpr index_t GemmDataPerReadB = 4;

    using InBlockCopySubLengths_E_N1_B_N2 = Sequence<4, 1, 1, 2>;
    using InBlockCopyClusterLengths_E_N1_B_N2 = Sequence<4, 2, 16, 2>;
    using InBlockCopyThreadClusterArrangeOrder = Sequence<0, 1, 3, 2>; // [E, N1, N2, B]
    using InBlockCopySrcAccessOrder = Sequence<0, 2, 1, 3>;            // [E, B, N1, N2]
    using InBlockCopyDstAccessOrder = Sequence<0, 1, 2, 3>;            // [E, N1, B, N2]

    constexpr index_t InBlockCopySrcDataPerRead_B = 1;
    constexpr index_t InBlockCopyDstDataPerWrite_N2 = 2;

    using WeiBlockCopySubLengths_E_K = Sequence<4, 2>;
    using WeiBlockCopyClusterLengths_E_K = Sequence<4, 64>;
    using WeiBlockCopyThreadClusterArrangeOrder = Sequence<1, 0>; // [K, E]
    using WeiBlockCopySrcAccessOrder = Sequence<1, 0>;            // [K, E]
    using WeiBlockCopyDstAccessOrder = Sequence<0, 1>;            // [E, K]

    constexpr index_t WeiBlockCopySrcDataPerRead_E = 4;
    constexpr index_t WeiBlockCopyDstDataPerWrite_K = 2;
#elif 0
    // cdata = 64, BlockSize = 128, 64x128x4
    constexpr index_t BlockSize = 128;

    constexpr index_t KPerBlock = 64;
    constexpr index_t BPerBlock = 16;
    constexpr index_t EPerBlock = 4;

    constexpr index_t GemmNRepeat = 2;

    constexpr index_t GemmMPerThread = 4;
    constexpr index_t GemmNPerThread = 4;
    constexpr index_t GemmKPerThread = 1;

    constexpr index_t GemmMLevel0Cluster = 2;
    constexpr index_t GemmNLevel0Cluster = 2;
    constexpr index_t GemmMLevel1Cluster = 4;
    constexpr index_t GemmNLevel1Cluster = 8;

    constexpr index_t GemmDataPerReadA = 4;
    constexpr index_t GemmDataPerReadB = 4;

    using InBlockCopySubLengths_E_N1_B_N2 = Sequence<1, 1, 1, 4>;
    using InBlockCopyClusterLengths_E_N1_B_N2 = Sequence<4, 2, 16, 1>;
    using InBlockCopyThreadClusterArrangeOrder = Sequence<0, 1, 3, 2>; // [E, N1, N2, B]
    using InBlockCopySrcAccessOrder = Sequence<0, 2, 1, 3>;            // [E, B, N1, N2]
    using InBlockCopyDstAccessOrder = Sequence<0, 1, 2, 3>;            // [E, N1, B, N2]

    constexpr index_t InBlockCopySrcDataPerRead_B = 1;
    constexpr index_t InBlockCopyDstDataPerWrite_N2 = 4;

    using WeiBlockCopySubLengths_E_K = Sequence<2, 1>;
    using WeiBlockCopyClusterLengths_E_K = Sequence<2, 64>;
    using WeiBlockCopyThreadClusterArrangeOrder = Sequence<1, 0>; // [K, E]
    using WeiBlockCopySrcAccessOrder = Sequence<1, 0>;            // [K, E]
    using WeiBlockCopyDstAccessOrder = Sequence<0, 1>;            // [E, K]

    constexpr index_t WeiBlockCopySrcDataPerRead_E = 2;
    constexpr index_t WeiBlockCopyDstDataPerWrite_K = 1;
#elif 0
    // cdata = 64, BlockSize = 128, 64x128x8
    constexpr index_t BlockSize = 128;

    constexpr index_t KPerBlock = 64;
    constexpr index_t BPerBlock = 16;
    constexpr index_t EPerBlock = 8;

    constexpr index_t GemmNRepeat = 2;

    constexpr index_t GemmMPerThread = 4;
    constexpr index_t GemmNPerThread = 4;
    constexpr index_t GemmKPerThread = 1;

    constexpr index_t GemmMLevel0Cluster = 2;
    constexpr index_t GemmNLevel0Cluster = 2;
    constexpr index_t GemmMLevel1Cluster = 4;
    constexpr index_t GemmNLevel1Cluster = 8;

    constexpr index_t GemmDataPerReadA = 4;
    constexpr index_t GemmDataPerReadB = 4;

    using InBlockCopySubLengths_E_N1_B_N2 = Sequence<1, 2, 1, 4>;
    using InBlockCopyClusterLengths_E_N1_B_N2 = Sequence<8, 1, 16, 1>;
    using InBlockCopyThreadClusterArrangeOrder = Sequence<0, 1, 3, 2>; // [E, N1, N2, B]
    using InBlockCopySrcAccessOrder = Sequence<0, 2, 1, 3>;            // [E, B, N1, N2]
    using InBlockCopyDstAccessOrder = Sequence<0, 1, 2, 3>;            // [E, N1, B, N2]

    constexpr index_t InBlockCopySrcDataPerRead_B = 1;
    constexpr index_t InBlockCopyDstDataPerWrite_N2 = 4;

    using WeiBlockCopySubLengths_E_K = Sequence<4, 1>;
    using WeiBlockCopyClusterLengths_E_K = Sequence<2, 64>;
    using WeiBlockCopyThreadClusterArrangeOrder = Sequence<1, 0>; // [K, E]
    using WeiBlockCopySrcAccessOrder = Sequence<1, 0>;            // [K, E]
    using WeiBlockCopyDstAccessOrder = Sequence<0, 1>;            // [E, K]

    constexpr index_t WeiBlockCopySrcDataPerRead_E = 4;
    constexpr index_t WeiBlockCopyDstDataPerWrite_K = 1;
#elif 0
    // cdata = 64, BlockSize = 128, 64x128x16
    constexpr index_t BlockSize = 128;

    constexpr index_t KPerBlock = 64;
    constexpr index_t BPerBlock = 16;
    constexpr index_t EPerBlock = 16;

    constexpr index_t GemmNRepeat = 2;

    constexpr index_t GemmMPerThread = 4;
    constexpr index_t GemmNPerThread = 4;
    constexpr index_t GemmKPerThread = 1;

    constexpr index_t GemmMLevel0Cluster = 4;
    constexpr index_t GemmNLevel0Cluster = 4;
    constexpr index_t GemmMLevel1Cluster = 2;
    constexpr index_t GemmNLevel1Cluster = 4;

    constexpr index_t GemmDataPerReadA = 4;
    constexpr index_t GemmDataPerReadB = 4;

    using InBlockCopySubLengths_E_N1_B_N2 = Sequence<2, 2, 1, 4>;
    using InBlockCopyClusterLengths_E_N1_B_N2 = Sequence<8, 1, 16, 1>;
    using InBlockCopyThreadClusterArrangeOrder = Sequence<0, 1, 3, 2>; // [E, N1, N2, B]
    using InBlockCopySrcAccessOrder = Sequence<0, 2, 1, 3>;            // [E, B, N1, N2]
    using InBlockCopyDstAccessOrder = Sequence<0, 1, 2, 3>;            // [E, N1, B, N2]

    constexpr index_t InBlockCopySrcDataPerRead_B = 1;
    constexpr index_t InBlockCopyDstDataPerWrite_N2 = 4;

    using WeiBlockCopySubLengths_E_K = Sequence<4, 2>;
    using WeiBlockCopyClusterLengths_E_K = Sequence<4, 32>;
    using WeiBlockCopyThreadClusterArrangeOrder = Sequence<1, 0>; // [K, E]
    using WeiBlockCopySrcAccessOrder = Sequence<1, 0>;            // [K, E]
    using WeiBlockCopyDstAccessOrder = Sequence<0, 1>;            // [E, K]

    constexpr index_t WeiBlockCopySrcDataPerRead_E = 4;
    constexpr index_t WeiBlockCopyDstDataPerWrite_K = 2;
#elif 0
    // cdata = 64, BlockSize = 128, 128x64x4
    constexpr index_t BlockSize = 128;

    constexpr index_t KPerBlock = 128;
    constexpr index_t BPerBlock = 8;
    constexpr index_t EPerBlock = 4;

    constexpr index_t GemmNRepeat = 2;

    constexpr index_t GemmMPerThread = 4;
    constexpr index_t GemmNPerThread = 4;
    constexpr index_t GemmKPerThread = 1;

    constexpr index_t GemmMLevel0Cluster = 2;
    constexpr index_t GemmNLevel0Cluster = 2;
    constexpr index_t GemmMLevel1Cluster = 8;
    constexpr index_t GemmNLevel1Cluster = 4;

    constexpr index_t GemmDataPerReadA = 4;
    constexpr index_t GemmDataPerReadB = 4;

    using InBlockCopySubLengths_E_N1_B_N2 = Sequence<1, 1, 1, 2>;
    using InBlockCopyClusterLengths_E_N1_B_N2 = Sequence<4, 2, 8, 2>;
    using InBlockCopyThreadClusterArrangeOrder = Sequence<0, 1, 3, 2>; // [E, N1, N2, B]
    using InBlockCopySrcAccessOrder = Sequence<0, 2, 1, 3>;            // [E, B, N1, N2]
    using InBlockCopyDstAccessOrder = Sequence<0, 1, 2, 3>;            // [E, N1, B, N2]

    constexpr index_t InBlockCopySrcDataPerRead_B = 1;
    constexpr index_t InBlockCopyDstDataPerWrite_N2 = 2;

    using WeiBlockCopySubLengths_E_K = Sequence<2, 2>;
    using WeiBlockCopyClusterLengths_E_K = Sequence<2, 64>;
    using WeiBlockCopyThreadClusterArrangeOrder = Sequence<1, 0>; // [K, E]
    using WeiBlockCopySrcAccessOrder = Sequence<1, 0>;            // [K, E]
    using WeiBlockCopyDstAccessOrder = Sequence<0, 1>;            // [E, K]

    constexpr index_t WeiBlockCopySrcDataPerRead_E = 2;
    constexpr index_t WeiBlockCopyDstDataPerWrite_K = 2;
#elif 0
    // cdata = 64, BlockSize = 128, 128x64x8
    constexpr index_t BlockSize = 128;

    constexpr index_t KPerBlock = 128;
    constexpr index_t BPerBlock = 8;
    constexpr index_t EPerBlock = 8;

    constexpr index_t GemmNRepeat = 2;

    constexpr index_t GemmMPerThread = 4;
    constexpr index_t GemmNPerThread = 4;
    constexpr index_t GemmKPerThread = 1;

    constexpr index_t GemmMLevel0Cluster = 2;
    constexpr index_t GemmNLevel0Cluster = 2;
    constexpr index_t GemmMLevel1Cluster = 8;
    constexpr index_t GemmNLevel1Cluster = 4;

    constexpr index_t GemmDataPerReadA = 4;
    constexpr index_t GemmDataPerReadB = 4;

    using InBlockCopySubLengths_E_N1_B_N2 = Sequence<1, 1, 1, 4>;
    using InBlockCopyClusterLengths_E_N1_B_N2 = Sequence<8, 2, 8, 1>;
    using InBlockCopyThreadClusterArrangeOrder = Sequence<0, 1, 3, 2>; // [E, N1, N2, B]
    using InBlockCopySrcAccessOrder = Sequence<0, 2, 1, 3>;            // [E, B, N1, N2]
    using InBlockCopyDstAccessOrder = Sequence<0, 1, 2, 3>;            // [E, N1, B, N2]

    constexpr index_t InBlockCopySrcDataPerRead_B = 1;
    constexpr index_t InBlockCopyDstDataPerWrite_N2 = 4;

    using WeiBlockCopySubLengths_E_K = Sequence<4, 2>;
    using WeiBlockCopyClusterLengths_E_K = Sequence<2, 64>;
    using WeiBlockCopyThreadClusterArrangeOrder = Sequence<1, 0>; // [K, E]
    using WeiBlockCopySrcAccessOrder = Sequence<1, 0>;            // [K, E]
    using WeiBlockCopyDstAccessOrder = Sequence<0, 1>;            // [E, K]

    constexpr index_t WeiBlockCopySrcDataPerRead_E = 4;
    constexpr index_t WeiBlockCopyDstDataPerWrite_K = 2;
#elif 0
    // cdata = 64, BlockSize = 128, 128x64x16
    constexpr index_t BlockSize = 128;

    constexpr index_t KPerBlock = 128;
    constexpr index_t BPerBlock = 8;
    constexpr index_t EPerBlock = 16;

    constexpr index_t GemmNRepeat = 2;

    constexpr index_t GemmMPerThread = 4;
    constexpr index_t GemmNPerThread = 4;
    constexpr index_t GemmKPerThread = 1;

    constexpr index_t GemmMLevel0Cluster = 2;
    constexpr index_t GemmNLevel0Cluster = 2;
    constexpr index_t GemmMLevel1Cluster = 8;
    constexpr index_t GemmNLevel1Cluster = 4;

    constexpr index_t GemmDataPerReadA = 4;
    constexpr index_t GemmDataPerReadB = 4;

    using InBlockCopySubLengths_E_N1_B_N2 = Sequence<1, 2, 1, 4>;
    using InBlockCopyClusterLengths_E_N1_B_N2 = Sequence<16, 1, 8, 1>;
    using InBlockCopyThreadClusterArrangeOrder = Sequence<0, 1, 3, 2>; // [E, N1, N2, B]
    using InBlockCopySrcAccessOrder = Sequence<0, 2, 1, 3>;            // [E, B, N1, N2]
    using InBlockCopyDstAccessOrder = Sequence<0, 1, 2, 3>;            // [E, N1, B, N2]

    constexpr index_t InBlockCopySrcDataPerRead_B = 1;
    constexpr index_t InBlockCopyDstDataPerWrite_N2 = 4;

    using WeiBlockCopySubLengths_E_K = Sequence<4, 4>;
    using WeiBlockCopyClusterLengths_E_K = Sequence<4, 32>;
    using WeiBlockCopyThreadClusterArrangeOrder = Sequence<1, 0>; // [K, E]
    using WeiBlockCopySrcAccessOrder = Sequence<1, 0>;            // [K, E]
    using WeiBlockCopyDstAccessOrder = Sequence<0, 1>;            // [E, K]

    constexpr index_t WeiBlockCopySrcDataPerRead_E = 4;
    constexpr index_t WeiBlockCopyDstDataPerWrite_K = 4;
#elif 0
    // cdata = 64, BlockSize = 64, 64x64x8
    constexpr index_t BlockSize = 64;

    constexpr index_t KPerBlock = 64;
    constexpr index_t BPerBlock = 8;
    constexpr index_t EPerBlock = 8;

    constexpr index_t GemmNRepeat = 2;

    constexpr index_t GemmMPerThread = 4;
    constexpr index_t GemmNPerThread = 4;
    constexpr index_t GemmKPerThread = 1;

    constexpr index_t GemmMLevel0Cluster = 4;
    constexpr index_t GemmNLevel0Cluster = 4;
    constexpr index_t GemmMLevel1Cluster = 2;
    constexpr index_t GemmNLevel1Cluster = 2;

    constexpr index_t GemmDataPerReadA = 4;
    constexpr index_t GemmDataPerReadB = 4;

    using InBlockCopySubLengths_E_N1_B_N2 = Sequence<1, 2, 1, 4>;
    using InBlockCopyClusterLengths_E_N1_B_N2 = Sequence<8, 1, 8, 1>;
    using InBlockCopyThreadClusterArrangeOrder = Sequence<0, 1, 3, 2>; // [E, N1, N2, B]
    using InBlockCopySrcAccessOrder = Sequence<0, 2, 1, 3>;            // [E, B, N1, N2]
    using InBlockCopyDstAccessOrder = Sequence<0, 1, 2, 3>;            // [E, N1, B, N2]

    constexpr index_t InBlockCopySrcDataPerRead_B = 1;
    constexpr index_t InBlockCopyDstDataPerWrite_N2 = 4;

    using WeiBlockCopySubLengths_E_K = Sequence<4, 2>;
    using WeiBlockCopyClusterLengths_E_K = Sequence<2, 32>;
    using WeiBlockCopyThreadClusterArrangeOrder = Sequence<1, 0>; // [K, E]
    using WeiBlockCopySrcAccessOrder = Sequence<1, 0>;            // [K, E]
    using WeiBlockCopyDstAccessOrder = Sequence<0, 1>;            // [E, K]

    constexpr index_t WeiBlockCopySrcDataPerRead_E = 4;
    constexpr index_t WeiBlockCopyDstDataPerWrite_K = 1;
#elif 0
    // cdata = 64, BlockSize = 32, 32x64x3
    constexpr index_t BlockSize = 32;

    constexpr index_t KPerBlock = 32;
    constexpr index_t BPerBlock = 8;
    constexpr index_t EPerBlock = 3;

    constexpr index_t GemmNRepeat = 2;

    constexpr index_t GemmMPerThread = 4;
    constexpr index_t GemmNPerThread = 4;
    constexpr index_t GemmKPerThread = 1;

    constexpr index_t GemmMLevel0Cluster = 4;
    constexpr index_t GemmNLevel0Cluster = 4;
    constexpr index_t GemmMLevel1Cluster = 1;
    constexpr index_t GemmNLevel1Cluster = 2;

    constexpr index_t GemmDataPerReadA = 4;
    constexpr index_t GemmDataPerReadB = 4;

    using InBlockCopySubLengths_E_N1_B_N2 = Sequence<3, 1, 1, 2>;
    using InBlockCopyClusterLengths_E_N1_B_N2 = Sequence<1, 2, 8, 2>;
    using InBlockCopyThreadClusterArrangeOrder = Sequence<0, 1, 3, 2>; // [E, N1, N2, B]
    using InBlockCopySrcAccessOrder = Sequence<0, 2, 1, 3>;            // [E, B, N1, N2]
    using InBlockCopyDstAccessOrder = Sequence<0, 1, 2, 3>;            // [E, N1, B, N2]

    constexpr index_t InBlockCopySrcDataPerRead_B = 1;
    constexpr index_t InBlockCopyDstDataPerWrite_N2 = 2;

    using WeiBlockCopySubLengths_E_K = Sequence<3, 1>;
    using WeiBlockCopyClusterLengths_E_K = Sequence<1, 32>;
    using WeiBlockCopyThreadClusterArrangeOrder = Sequence<1, 0>; // [K, E]
    using WeiBlockCopySrcAccessOrder = Sequence<1, 0>;            // [K, E]
    using WeiBlockCopyDstAccessOrder = Sequence<0, 1>;            // [E, K]

    constexpr index_t WeiBlockCopySrcDataPerRead_E = 1;
    constexpr index_t WeiBlockCopyDstDataPerWrite_K = 1;
#elif 0
    // cdata = 64, BlockSize = 64, 32x128x3
    constexpr index_t BlockSize = 64;

    constexpr index_t KPerBlock = 32;
    constexpr index_t BPerBlock = 16;
    constexpr index_t EPerBlock = 3;

    constexpr index_t GemmNRepeat = 2;

    constexpr index_t GemmMPerThread = 4;
    constexpr index_t GemmNPerThread = 4;
    constexpr index_t GemmKPerThread = 1;

    constexpr index_t GemmMLevel0Cluster = 4;
    constexpr index_t GemmNLevel0Cluster = 4;
    constexpr index_t GemmMLevel1Cluster = 1;
    constexpr index_t GemmNLevel1Cluster = 4;

    constexpr index_t GemmDataPerReadA = 4;
    constexpr index_t GemmDataPerReadB = 4;

    using InBlockCopySubLengths_E_N1_B_N2 = Sequence<3, 1, 1, 2>;
    using InBlockCopyClusterLengths_E_N1_B_N2 = Sequence<1, 2, 16, 2>;
    using InBlockCopyThreadClusterArrangeOrder = Sequence<0, 1, 3, 2>; // [E, N1, N2, B]
    using InBlockCopySrcAccessOrder = Sequence<0, 2, 1, 3>;            // [E, B, N1, N2]
    using InBlockCopyDstAccessOrder = Sequence<0, 1, 2, 3>;            // [E, N1, B, N2]

    constexpr index_t InBlockCopySrcDataPerRead_B = 1;
    constexpr index_t InBlockCopyDstDataPerWrite_N2 = 2;

    using WeiBlockCopySubLengths_E_K = Sequence<3, 1>;
    using WeiBlockCopyClusterLengths_E_K = Sequence<1, 32>;
    using WeiBlockCopyThreadClusterArrangeOrder = Sequence<1, 0>; // [K, E]
    using WeiBlockCopySrcAccessOrder = Sequence<1, 0>;            // [K, E]
    using WeiBlockCopyDstAccessOrder = Sequence<0, 1>;            // [E, K]

    constexpr index_t WeiBlockCopySrcDataPerRead_E = 1;
    constexpr index_t WeiBlockCopyDstDataPerWrite_K = 1;
#elif 0
    // cdata = 64, BlockSize = 64, 64x64x3
    constexpr index_t BlockSize = 64;

    constexpr index_t KPerBlock = 64;
    constexpr index_t BPerBlock = 8;
    constexpr index_t EPerBlock = 3;

    constexpr index_t GemmNRepeat = 2;

    constexpr index_t GemmMPerThread = 4;
    constexpr index_t GemmNPerThread = 4;
    constexpr index_t GemmKPerThread = 1;

    constexpr index_t GemmMLevel0Cluster = 2;
    constexpr index_t GemmNLevel0Cluster = 2;
    constexpr index_t GemmMLevel1Cluster = 4;
    constexpr index_t GemmNLevel1Cluster = 4;

    constexpr index_t GemmDataPerReadA = 4;
    constexpr index_t GemmDataPerReadB = 4;

    using InBlockCopySubLengths_E_N1_B_N2 = Sequence<3, 1, 1, 1>;
    using InBlockCopyClusterLengths_E_N1_B_N2 = Sequence<1, 2, 8, 4>;
    using InBlockCopyThreadClusterArrangeOrder = Sequence<0, 1, 3, 2>; // [E, N1, N2, B]
    using InBlockCopySrcAccessOrder = Sequence<0, 2, 1, 3>;            // [E, B, N1, N2]
    using InBlockCopyDstAccessOrder = Sequence<0, 1, 2, 3>;            // [E, N1, B, N2]

    constexpr index_t InBlockCopySrcDataPerRead_B = 1;
    constexpr index_t InBlockCopyDstDataPerWrite_N2 = 1;

    using WeiBlockCopySubLengths_E_K = Sequence<3, 1>;
    using WeiBlockCopyClusterLengths_E_K = Sequence<1, 64>;
    using WeiBlockCopyThreadClusterArrangeOrder = Sequence<1, 0>; // [K, E]
    using WeiBlockCopySrcAccessOrder = Sequence<1, 0>;            // [K, E]
    using WeiBlockCopyDstAccessOrder = Sequence<0, 1>;            // [E, K]

    constexpr index_t WeiBlockCopySrcDataPerRead_E = 1;
    constexpr index_t WeiBlockCopyDstDataPerWrite_K = 1;
#elif 0
    // cdata = 64, BlockSize = 64, 32x128x4
    constexpr index_t BlockSize = 64;

    constexpr index_t KPerBlock = 32;
    constexpr index_t BPerBlock = 16;
    constexpr index_t EPerBlock = 4;

    constexpr index_t GemmNRepeat = 2;

    constexpr index_t GemmMPerThread = 4;
    constexpr index_t GemmNPerThread = 4;
    constexpr index_t GemmKPerThread = 1;

    constexpr index_t GemmMLevel0Cluster = 2;
    constexpr index_t GemmNLevel0Cluster = 2;
    constexpr index_t GemmMLevel1Cluster = 2;
    constexpr index_t GemmNLevel1Cluster = 8;

    constexpr index_t GemmDataPerReadA = 4;
    constexpr index_t GemmDataPerReadB = 4;

    using InBlockCopySubLengths_E_N1_B_N2 = Sequence<1, 2, 1, 4>;
    using InBlockCopyClusterLengths_E_N1_B_N2 = Sequence<4, 1, 16, 1>;
    using InBlockCopyThreadClusterArrangeOrder = Sequence<0, 1, 3, 2>; // [E, N1, N2, B]
    using InBlockCopySrcAccessOrder = Sequence<0, 2, 1, 3>;            // [E, B, N1, N2]
    using InBlockCopyDstAccessOrder = Sequence<0, 1, 2, 3>;            // [E, N1, B, N2]

    constexpr index_t InBlockCopySrcDataPerRead_B = 1;
    constexpr index_t InBlockCopyDstDataPerWrite_N2 = 4;

    using WeiBlockCopySubLengths_E_K = Sequence<2, 1>;
    using WeiBlockCopyClusterLengths_E_K = Sequence<2, 32>;
    using WeiBlockCopyThreadClusterArrangeOrder = Sequence<1, 0>; // [K, E]
    using WeiBlockCopySrcAccessOrder = Sequence<1, 0>;            // [K, E]
    using WeiBlockCopyDstAccessOrder = Sequence<0, 1>;            // [E, K]

    constexpr index_t WeiBlockCopySrcDataPerRead_E = 2;
    constexpr index_t WeiBlockCopyDstDataPerWrite_K = 1;
#elif 0
    // cdata = 64, BlockSize = 64, 32x128x8
    constexpr index_t BlockSize = 64;

    constexpr index_t KPerBlock = 32;
    constexpr index_t BPerBlock = 16;
    constexpr index_t EPerBlock = 8;

    constexpr index_t GemmNRepeat = 2;

    constexpr index_t GemmMPerThread = 4;
    constexpr index_t GemmNPerThread = 4;
    constexpr index_t GemmKPerThread = 1;

    constexpr index_t GemmMLevel0Cluster = 4;
    constexpr index_t GemmNLevel0Cluster = 4;
    constexpr index_t GemmMLevel1Cluster = 1;
    constexpr index_t GemmNLevel1Cluster = 4;

    constexpr index_t GemmDataPerReadA = 4;
    constexpr index_t GemmDataPerReadB = 4;

    using InBlockCopySubLengths_E_N1_B_N2 = Sequence<2, 2, 1, 4>;
    using InBlockCopyClusterLengths_E_N1_B_N2 = Sequence<4, 1, 16, 1>;
    using InBlockCopyThreadClusterArrangeOrder = Sequence<0, 1, 3, 2>; // [E, N1, N2, B]
    using InBlockCopySrcAccessOrder = Sequence<0, 2, 1, 3>;            // [E, B, N1, N2]
    using InBlockCopyDstAccessOrder = Sequence<0, 1, 2, 3>;            // [E, N1, B, N2]

    constexpr index_t InBlockCopySrcDataPerRead_B = 1;
    constexpr index_t InBlockCopyDstDataPerWrite_N2 = 4;

    using WeiBlockCopySubLengths_E_K = Sequence<4, 1>;
    using WeiBlockCopyClusterLengths_E_K = Sequence<2, 32>;
    using WeiBlockCopyThreadClusterArrangeOrder = Sequence<1, 0>; // [K, E]
    using WeiBlockCopySrcAccessOrder = Sequence<1, 0>;            // [K, E]
    using WeiBlockCopyDstAccessOrder = Sequence<0, 1>;            // [E, K]

    constexpr index_t WeiBlockCopySrcDataPerRead_E = 4;
    constexpr index_t WeiBlockCopyDstDataPerWrite_K = 1;
#elif 0
    // cdata = 32, BlockSize = 256, 64x128x8
    constexpr index_t BlockSize = 256;

    constexpr index_t KPerBlock = 64;
    constexpr index_t BPerBlock = 16;
    constexpr index_t EPerBlock = 8;

    constexpr index_t GemmNRepeat = 2;

    constexpr index_t GemmMPerThread = 2;
    constexpr index_t GemmNPerThread = 4;
    constexpr index_t GemmKPerThread = 1;

    constexpr index_t GemmMLevel0Cluster = 4;
    constexpr index_t GemmNLevel0Cluster = 4;
    constexpr index_t GemmMLevel1Cluster = 4;
    constexpr index_t GemmNLevel1Cluster = 4;

    constexpr index_t GemmDataPerReadA = 2;
    constexpr index_t GemmDataPerReadB = 4;

    using InBlockCopySubLengths_E_N1_B_N2 = Sequence<1, 1, 1, 4>;
    using InBlockCopyClusterLengths_E_N1_B_N2 = Sequence<8, 2, 16, 1>;
    using InBlockCopyThreadClusterArrangeOrder = Sequence<0, 1, 3, 2>; // [E, N1, N2, B]
    using InBlockCopySrcAccessOrder = Sequence<0, 2, 1, 3>;            // [E, B, N1, N2]
    using InBlockCopyDstAccessOrder = Sequence<0, 1, 2, 3>;            // [E, N1, B, N2]

    constexpr index_t InBlockCopySrcDataPerRead_B = 1;
    constexpr index_t InBlockCopyDstDataPerWrite_N2 = 4;

    using WeiBlockCopySubLengths_E_K = Sequence<2, 1>;
    using WeiBlockCopyClusterLengths_E_K = Sequence<4, 64>;
    using WeiBlockCopyThreadClusterArrangeOrder = Sequence<1, 0>; // [K, E]
    using WeiBlockCopySrcAccessOrder = Sequence<1, 0>;            // [K, E]
    using WeiBlockCopyDstAccessOrder = Sequence<0, 1>;            // [E, K]

    constexpr index_t WeiBlockCopySrcDataPerRead_E = 2;
    constexpr index_t WeiBlockCopyDstDataPerWrite_K = 1;
#endif

    constexpr index_t N1 = GemmNRepeat;
    constexpr index_t N2 = GemmNPerThread;

    constexpr index_t B = (N * Ho * Wo) / (N1 * N2);

    constexpr index_t GridSize =
        ((B + BPerBlock - 1) / BPerBlock) * ((K + KPerBlock - 1) / KPerBlock);

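    // Worked example (hypothetical sizes, for illustration): N = 128, K = 256,
    // Ho = Wo = 16 with the active 128x128x8 branch (BPerBlock = 16,
    // KPerBlock = 128, GemmNRepeat = 2, GemmNPerThread = 4). Note that B
    // assumes N * Ho * Wo is divisible by N1 * N2.
    //   B        = 128 * 16 * 16 / (2 * 4)           = 4096
    //   GridSize = ceil(4096 / 16) * ceil(256 / 128) = 256 * 2 = 512
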
    printf("%s: BlockSize %u, GridSize %u \n", __func__, BlockSize, GridSize);

    using gridwise_conv =
        GridwiseConvolutionForwardImplicitGemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer<
            GridSize,
            BlockSize,
            T,
            T,
            decltype(in_nchw_desc),
            decltype(wei_kcyx_desc),
            decltype(out_nkhw_desc),
            ConvStrides,
            ConvDilations,
            LeftPads,
            RightPads,
            BPerBlock,
            KPerBlock,
            EPerBlock,
            GemmNRepeat,
            GemmMPerThread,
            GemmNPerThread,
            GemmKPerThread,
            GemmMLevel0Cluster,
            GemmNLevel0Cluster,
            GemmMLevel1Cluster,
            GemmNLevel1Cluster,
            GemmDataPerReadA,
            GemmDataPerReadB,
            InBlockCopySubLengths_E_N1_B_N2,
            InBlockCopyClusterLengths_E_N1_B_N2,
            InBlockCopyThreadClusterArrangeOrder,
            InBlockCopySrcAccessOrder,
            InBlockCopyDstAccessOrder,
            InBlockCopySrcDataPerRead_B,
            InBlockCopyDstDataPerWrite_N2,
            WeiBlockCopySubLengths_E_K,
            WeiBlockCopyClusterLengths_E_K,
            WeiBlockCopyThreadClusterArrangeOrder,
            WeiBlockCopySrcAccessOrder,
            WeiBlockCopyDstAccessOrder,
            WeiBlockCopySrcDataPerRead_E,
            WeiBlockCopyDstDataPerWrite_K>;

    for(index_t i = 0; i < 5; ++i)
    {
        std::cout << "Start running " << nrepeat << " times..." << std::endl;

        KernelTimer timer;
        timer.Start();

        for(index_t j = 0; j < nrepeat; ++j)
        {
            launch_kernel(run_gridwise_operation<gridwise_conv,
                                                 const TDevice* const __restrict__,
                                                 const TDevice* const __restrict__,
                                                 TDevice* const __restrict__>,
                          dim3(GridSize),
                          dim3(BlockSize),
                          0,
                          0,
                          static_cast<TDevice*>(in_nchw_device_buf.GetDeviceBuffer()),
                          static_cast<TDevice*>(wei_kcyx_device_buf.GetDeviceBuffer()),
                          static_cast<TDevice*>(out_nkhw_device_buf.GetDeviceBuffer()));
        }

        timer.End();

        float ave_time = timer.GetElapsedTime() / nrepeat;

        float perf = (float)calculate_convolution_flops(InDesc{}, WeiDesc{}, OutDesc{}) /
                     (std::size_t(1000) * 1000 * 1000) / ave_time;

        std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s" << std::endl;
    }

    out_nkhw_device_buf.FromDevice(out_nkhw.mData.data());
}
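The TFlop/s figure printed by each of these drivers follows from the units: calculate_convolution_flops returns total flops, ave_time is in milliseconds, and flops / 1e9 / ms equals (flops / s) / 1e12. A minimal standalone units check, assuming the usual 2-flops-per-multiply-accumulate convention for convolution flop counts (an assumption; the exact definition lives in calculate_convolution_flops):

#include <cstdio>

int main()
{
    // hypothetical problem: N=128, K=256, C=192, Ho=Wo=28, Y=X=3
    const unsigned long long flops =
        2ULL * 128 * 256 * 192 * 28 * 28 * 3 * 3; // ~88.8 Gflop
    const float ave_time_ms = 10.0f;              // pretend measured average

    // flops / 1e9 / time_ms == (flops / time_s) / 1e12 == TFlop/s
    const float tflops = (float)flops / (1000ULL * 1000 * 1000) / ave_time_ms;
    std::printf("%.2f TFlop/s\n", tflops); // ~8.88
    return 0;
}
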
@@ -1,207 +0,0 @@
#include <unistd.h>
#include "device.hpp"
#include "host_tensor.hpp"
#include "gridwise_operation_wrapper.hpp"
#include "gridwise_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk.hpp"

template <class T,
          class InDesc,
          class WeiDesc,
          class OutDesc,
          class ConvStrides,
          class ConvDilations,
          class InLeftPads,
          class InRightPads>
void device_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk(InDesc,
                                                                  const Tensor<T>& in_nchw,
                                                                  WeiDesc,
                                                                  const Tensor<T>& wei_kcyx,
                                                                  OutDesc,
                                                                  Tensor<T>& out_nkhw,
                                                                  ConvStrides,
                                                                  ConvDilations,
                                                                  InLeftPads,
                                                                  InRightPads,
                                                                  ck::index_t nrepeat)
{
    std::cout << "device_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk" << std::endl;

    using namespace ck;

    using TDevice = typename conditional<is_same<half_float::half, T>::value, half_t, T>::type;

    constexpr auto I0 = Number<0>{};
    constexpr auto I1 = Number<1>{};
    constexpr auto I2 = Number<2>{};
    constexpr auto I3 = Number<3>{};

    constexpr auto N = OutDesc::GetLengths()[I0];
    constexpr auto K = OutDesc::GetLengths()[I1];
    constexpr auto C = WeiDesc::GetLengths()[I1];

    constexpr auto Hi = InDesc::GetLengths()[I2];
    constexpr auto Wi = InDesc::GetLengths()[I3];

    constexpr auto Ho = OutDesc::GetLengths()[I2];
    constexpr auto Wo = OutDesc::GetLengths()[I3];

    constexpr auto Y = WeiDesc::GetLengths()[I2];
    constexpr auto X = WeiDesc::GetLengths()[I3];

    // compile-time variables
    constexpr auto in_n_hi_wi_c_desc =
        make_native_tensor_descriptor_packed(Sequence<N, Hi, Wi, C>{});
    constexpr auto wei_k_y_x_c_desc = make_native_tensor_descriptor_packed(Sequence<K, Y, X, C>{});
    constexpr auto out_n_ho_wo_k_desc =
        make_native_tensor_descriptor_packed(Sequence<N, Ho, Wo, K>{});

    Tensor<float> in_nhwc(
        make_HostTensorDescriptor(make_native_tensor_descriptor_packed(Sequence<N, Hi, Wi, C>{})));
    Tensor<float> wei_kyxc(
        make_HostTensorDescriptor(make_native_tensor_descriptor_packed(Sequence<K, Y, X, C>{})));
    Tensor<float> out_nhwk(
        make_HostTensorDescriptor(make_native_tensor_descriptor_packed(Sequence<N, Ho, Wo, K>{})));

    auto f_nchw2nhwc = [&](auto n, auto hi, auto wi, auto c) {
        in_nhwc(n, hi, wi, c) = in_nchw(n, c, hi, wi);
    };

    auto f_kcyx2kyxc = [&](auto k, auto y, auto x, auto c) {
        wei_kyxc(k, y, x, c) = wei_kcyx(k, c, y, x);
    };

    auto f_nkhw2nhwk = [&](auto n, auto ho, auto wo, auto k) {
        out_nhwk(n, ho, wo, k) = out_nkhw(n, k, ho, wo);
    };

    make_ParallelTensorFunctor(f_nchw2nhwc, N, Hi, Wi, C)(std::thread::hardware_concurrency());
    make_ParallelTensorFunctor(f_kcyx2kyxc, K, Y, X, C)(std::thread::hardware_concurrency());
    make_ParallelTensorFunctor(f_nkhw2nhwk, N, Ho, Wo, K)(std::thread::hardware_concurrency());

    std::size_t data_sz = sizeof(T);

    DeviceMem in_nhwc_device_buf(data_sz * in_nhwc.mDesc.GetElementSpace());
    DeviceMem wei_kyxc_device_buf(data_sz * wei_kyxc.mDesc.GetElementSpace());
    DeviceMem out_nhwk_device_buf(data_sz * out_nhwk.mDesc.GetElementSpace());

    in_nhwc_device_buf.ToDevice(in_nhwc.mData.data());
    wei_kyxc_device_buf.ToDevice(wei_kyxc.mData.data());
    out_nhwk_device_buf.ToDevice(out_nhwk.mData.data());

#if 1
    // cdata = 16, BlockSize = 64, 16x64x4
    constexpr index_t BlockSize = 64;

    constexpr index_t GemmMPerBlock = 16;
    constexpr index_t GemmNPerBlock = 64;
    constexpr index_t GemmKPerBlock = 4;

    constexpr index_t GemmMPerThread = 2;
    constexpr index_t GemmNPerThread = 2;
    constexpr index_t GemmKPerThread = 1;

    constexpr index_t GemmMLevel0Cluster = 2;
    constexpr index_t GemmNLevel0Cluster = 2;
    constexpr index_t GemmMLevel1Cluster = 2;
    constexpr index_t GemmNLevel1Cluster = 8;

    constexpr index_t ThreadGemmDataPerReadM = 2;
    constexpr index_t ThreadGemmDataPerReadN = 2;

    using GemmABlockCopyThreadSliceLengths_GemmK_GemmM = Sequence<1, 1>;
    using GemmABlockCopyThreadClusterLengths_GemmK_GemmM = Sequence<4, 16>;

    constexpr index_t GemmABlockCopySrcDataPerRead_GemmK = 1;
    constexpr index_t GemmABlockCopyDstDataPerWrite_GemmM = 1;

    using GemmBBlockCopyThreadSliceLengths_GemmK_GemmN = Sequence<4, 1>;
    using GemmBBlockCopyThreadClusterLengths_GemmK_GemmN = Sequence<1, 64>;

    constexpr index_t GemmBBlockCopySrcDataPerRead_GemmK = 4;
    constexpr index_t GemmBBlockCopyDstDataPerWrite_GemmN = 1;

    constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmM1 = 2;
#endif

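    // Implicit-GEMM view used by v4r4 on nhwc (im2col without materializing the
    // patch matrix): GemmM x GemmN x GemmK = K x (N*Ho*Wo) x (C*Y*X), with GemmM
    // and GemmN computed just below. Hypothetical example: N = 16, K = 64,
    // C = 32, Ho = Wo = 14, Y = X = 3 with the 16x64x4 tile above gives
    //   GridSize = ceil(64 / 16) * ceil(16 * 14 * 14 / 64) = 4 * 49 = 196.
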
constexpr index_t GemmM = K;
|
||||
constexpr index_t GemmN = N * Ho * Wo;
|
||||
|
||||
constexpr index_t GridSize = math::integer_divide_ceil(GemmM, GemmMPerBlock) *
|
||||
math::integer_divide_ceil(GemmN, GemmNPerBlock);
|
||||
|
||||
printf("%s: BlockSize %u, GridSize %u \n", __func__, BlockSize, GridSize);
|
||||
|
||||
using gridwise_conv = GridwiseConvolutionForwardImplicitGemm_v4r4_nhwc_kyxc_nhwk<
|
||||
GridSize,
|
||||
BlockSize,
|
||||
TDevice,
|
||||
TDevice,
|
||||
decltype(in_n_hi_wi_c_desc),
|
||||
decltype(wei_k_y_x_c_desc),
|
||||
decltype(out_n_ho_wo_k_desc),
|
||||
ConvStrides,
|
||||
ConvDilations,
|
||||
InLeftPads,
|
||||
InRightPads,
|
||||
GemmMPerBlock,
|
||||
GemmNPerBlock,
|
||||
GemmKPerBlock,
|
||||
GemmMPerThread,
|
||||
GemmNPerThread,
|
||||
GemmKPerThread,
|
||||
GemmMLevel0Cluster,
|
||||
GemmNLevel0Cluster,
|
||||
GemmMLevel1Cluster,
|
||||
GemmNLevel1Cluster,
|
||||
ThreadGemmDataPerReadM,
|
||||
ThreadGemmDataPerReadN,
|
||||
GemmABlockCopyThreadSliceLengths_GemmK_GemmM,
|
||||
GemmABlockCopyThreadClusterLengths_GemmK_GemmM,
|
||||
GemmABlockCopySrcDataPerRead_GemmK,
|
||||
GemmABlockCopyDstDataPerWrite_GemmM,
|
||||
GemmBBlockCopyThreadSliceLengths_GemmK_GemmN,
|
||||
GemmBBlockCopyThreadClusterLengths_GemmK_GemmN,
|
||||
GemmBBlockCopySrcDataPerRead_GemmK,
|
||||
GemmBBlockCopyDstDataPerWrite_GemmN,
|
||||
GemmCThreadCopyDstDataPerWrite_GemmM1>;
|
||||
|
||||
for(index_t i = 0; i < 5; ++i)
|
||||
{
|
||||
std::cout << "Start running " << nrepeat << " times..." << std::endl;
|
||||
|
||||
KernelTimer timer;
|
||||
timer.Start();
|
||||
|
||||
for(index_t j = 0; j < nrepeat; ++j)
|
||||
{
|
||||
launch_kernel(run_gridwise_operation<gridwise_conv,
|
||||
const TDevice* const __restrict__,
|
||||
const TDevice* const __restrict__,
|
||||
TDevice* const __restrict__>,
|
||||
dim3(GridSize),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
0,
|
||||
static_cast<TDevice*>(in_nhwc_device_buf.GetDeviceBuffer()),
|
||||
static_cast<TDevice*>(wei_kyxc_device_buf.GetDeviceBuffer()),
|
||||
static_cast<TDevice*>(out_nhwk_device_buf.GetDeviceBuffer()));
|
||||
}
|
||||
|
||||
timer.End();
|
||||
|
||||
float ave_time = timer.GetElapsedTime() / nrepeat;
|
||||
|
||||
float perf = (float)calculate_convolution_flops(InDesc{}, WeiDesc{}, OutDesc{}) /
|
||||
(std::size_t(1000) * 1000 * 1000) / ave_time;
|
||||
|
||||
std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s" << std::endl;
|
||||
}
|
||||
|
||||
out_nhwk_device_buf.FromDevice(out_nhwk.mData.data());
|
||||
|
||||
auto f_nhwk2nkhw = [&](auto n, auto k, auto ho, auto wo) {
|
||||
out_nkhw(n, k, ho, wo) = out_nhwk(n, ho, wo, k);
|
||||
};
|
||||
|
||||
make_ParallelTensorFunctor(f_nhwk2nkhw, N, K, Ho, Wo)(std::thread::hardware_concurrency());
|
||||
}
|
||||
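Note: the grid-size and throughput arithmetic above reduces to ordinary round-up division plus a unit conversion (flop / 1e9 / ms equals TFlop/s). The following standalone sketch reproduces that arithmetic; math::integer_divide_ceil is assumed to be plain round-up division, and the problem sizes are hypothetical, not from the driver.

#include <cstdint>
#include <cstdio>

// Round-up integer division, as assumed for math::integer_divide_ceil.
static std::int64_t integer_divide_ceil(std::int64_t a, std::int64_t b)
{
    return (a + b - 1) / b;
}

int main()
{
    // Hypothetical problem: K = 256 output channels, N*Ho*Wo output pixels.
    const std::int64_t GemmM = 256;
    const std::int64_t GemmN = 64 * 56 * 56;

    // Tile sizes from the 16x64x4 config above.
    const std::int64_t GemmMPerBlock = 16;
    const std::int64_t GemmNPerBlock = 64;

    const std::int64_t GridSize = integer_divide_ceil(GemmM, GemmMPerBlock) *
                                  integer_divide_ceil(GemmN, GemmNPerBlock);

    // flop / 1e9 / time_in_ms == TFlop/s, matching the driver's perf print.
    const double flop   = 2.0 * GemmM * GemmN * 1152; // GemmK = C*Y*X, e.g. 128*3*3
    const double ave_ms = 0.5;                        // hypothetical measured time
    const double tflops = flop / 1.0e9 / ave_ms;

    std::printf("GridSize %lld, %.2f TFlop/s\n", (long long)GridSize, tflops);
    return 0;
}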
@@ -257,7 +257,6 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nk
            GemmKPerBlock,
            GemmMPerWave,
            GemmNPerWave,
            GemmK1,
            MRepeat,
            NRepeat,
            GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1,
@@ -57,7 +57,35 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r2_xdlops_nhwc_kyxc_nh
        make_dynamic_naive_tensor_descriptor_packed_v2(out_n_ho_wo_k_lengths);

#if 1
    // [M, N, K0, K1] = [256, 128, 4, 8]
    // [M, N, K0, K1] = [256, 128, 4, 4] for fp32
    constexpr index_t BlockSize = 256;

    constexpr index_t GemmMPerBlock = 256;
    constexpr index_t GemmNPerBlock = 128;
    constexpr index_t GemmKPerBlock = 4;

    constexpr index_t GemmMPerWave = 64;
    constexpr index_t GemmNPerWave = 64;
    constexpr index_t GemmK1     = 4;

    constexpr index_t MRepeat = 2;
    constexpr index_t NRepeat = 1;

    using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1   = Sequence<1, 4, 4>;
    using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>;

    constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 4;
    constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 4;

    using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1   = Sequence<1, 2, 4>;
    using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>;

    constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK1 = 4;
    constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 4;

    constexpr index_t GemmCThreadTransferDstScalarPerVector = 4;
#elif 1
    // [M, N, K0, K1] = [256, 128, 4, 8] for fp16
    constexpr index_t BlockSize = 256;

    constexpr index_t GemmMPerBlock = 256;
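Note: these parameter sets have to be mutually consistent. Assuming the usual xdlops decomposition GemmMPerBlock = MRepeat * GemmMPerWave * MWaves (and likewise for N), the wave grid times the 64-lane wavefront must equal BlockSize. A hypothetical compile-time check mirroring the config above:

#include <cstddef>

// Hypothetical consistency check for an xdlops tuning config, assuming
// GemmMPerBlock = MRepeat * GemmMPerWave * MWaves (same for N) and
// 64-lane wavefronts; values mirror the [256, 128, ...] config above.
constexpr std::size_t WaveSize      = 64;
constexpr std::size_t BlockSize     = 256;
constexpr std::size_t GemmMPerBlock = 256, GemmNPerBlock = 128;
constexpr std::size_t GemmMPerWave  = 64,  GemmNPerWave  = 64;
constexpr std::size_t MRepeat       = 2,   NRepeat       = 1;

constexpr std::size_t MWaves = GemmMPerBlock / (MRepeat * GemmMPerWave); // 2
constexpr std::size_t NWaves = GemmNPerBlock / (NRepeat * GemmNPerWave); // 2

static_assert(MWaves * NWaves * WaveSize == BlockSize,
              "wave tiling must cover exactly one workgroup");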
@@ -56,7 +56,63 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r3_xdlops_nhwc_kyxc_nh
    const auto out_n_ho_wo_k_desc =
        make_dynamic_naive_tensor_descriptor_packed_v2(out_n_ho_wo_k_lengths);

#if 0
#if 1
    // [M, N, K0, K1] = [256, 128, 4, 4] for fp32
    constexpr index_t BlockSize = 256;

    constexpr index_t GemmMPerBlock = 256;
    constexpr index_t GemmNPerBlock = 128;
    constexpr index_t GemmKPerBlock = 4;

    constexpr index_t GemmMPerWave = 32;
    constexpr index_t GemmNPerWave = 32;
    constexpr index_t GemmK1 = 4;

    constexpr index_t MRepeat = 4;
    constexpr index_t NRepeat = 2;

    using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1   = Sequence<1, 4, 4>;
    using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>;

    constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 4;
    constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 4;

    using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1   = Sequence<1, 2, 4>;
    using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>;

    constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK1 = 4;
    constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 4;

    constexpr index_t GemmCThreadTransferDstScalarPerVector = 4;
#elif 1
    // [M, N, K0, K1] = [128, 128, 4, 4] for fp32
    constexpr index_t BlockSize = 256;

    constexpr index_t GemmMPerBlock = 128;
    constexpr index_t GemmNPerBlock = 128;
    constexpr index_t GemmKPerBlock = 4;

    constexpr index_t GemmMPerWave = 32;
    constexpr index_t GemmNPerWave = 32;
    constexpr index_t GemmK1 = 4;

    constexpr index_t MRepeat = 2;
    constexpr index_t NRepeat = 2;

    using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1   = Sequence<1, 2, 4>;
    using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>;

    constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 4;
    constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 4;

    using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1   = Sequence<1, 2, 4>;
    using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>;

    constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK1 = 4;
    constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 4;

    constexpr index_t GemmCThreadTransferDstScalarPerVector = 4;
#elif 0
    // [M, N, K0, K1] = [256, 256, 4, 8] for fp16
    constexpr index_t BlockSize = 256;
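Note: the thread slice and cluster lengths multiply elementwise to the per-block tile. For the A-copy of the [256, 128, 4, 4] fp32 config, Sequence<1, 4, 4> slices times Sequence<4, 64, 1> clusters give a [GemmK0, GemmM, GemmK1] = [4, 256, 4] tile carried by 4*64*1 = 256 threads, i.e. one BlockSize. A small sketch of that invariant, under the assumption that slice and cluster lengths compose multiplicatively per dimension:

#include <array>
#include <cstddef>

// Assumed invariant: per dimension, SliceLength * ClusterLength equals the
// per-block tile length, and the product of cluster lengths equals BlockSize.
template <std::size_t N>
constexpr bool covers(const std::array<std::size_t, N>& slice,
                      const std::array<std::size_t, N>& cluster,
                      const std::array<std::size_t, N>& block_tile,
                      std::size_t block_size)
{
    std::size_t threads = 1;
    for(std::size_t i = 0; i < N; ++i)
    {
        if(slice[i] * cluster[i] != block_tile[i])
            return false;
        threads *= cluster[i];
    }
    return threads == block_size;
}

// A-matrix copy of the [256, 128, 4, 4] fp32 config above:
static_assert(covers<3>({1, 4, 4},   // ThreadSliceLengths_GemmK0_GemmM_GemmK1
                        {4, 64, 1},  // ThreadClusterLengths_GemmK0_GemmM_GemmK1
                        {4, 256, 4}, // block tile [GemmKPerBlock, GemmMPerBlock, GemmK1]
                        256),        // BlockSize
              "A block copy must tile the block exactly");

The B-copy checks out the same way: Sequence<1, 2, 4> times Sequence<4, 64, 1> gives the [4, 128, 4] tile matching GemmNPerBlock = 128.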
@@ -111,34 +167,6 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r3_xdlops_nhwc_kyxc_nh
    constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK1 = 8;
    constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8;

    constexpr index_t GemmCThreadTransferDstScalarPerVector = 4;
#elif 0
    // [M, N, K0, K1] = [128, 128, 4, 4] for fp32
    constexpr index_t BlockSize = 256;

    constexpr index_t GemmMPerBlock = 128;
    constexpr index_t GemmNPerBlock = 128;
    constexpr index_t GemmKPerBlock = 4;

    constexpr index_t GemmMPerWave = 32;
    constexpr index_t GemmNPerWave = 32;
    constexpr index_t GemmK1 = 4;

    constexpr index_t MRepeat = 2;
    constexpr index_t NRepeat = 2;

    using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1   = Sequence<1, 2, 4>;
    using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>;

    constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 4;
    constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 4;

    using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1   = Sequence<1, 2, 4>;
    using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>;

    constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK1 = 4;
    constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 4;

    constexpr index_t GemmCThreadTransferDstScalarPerVector = 4;
#endif
@@ -56,7 +56,63 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nh
    const auto out_n_ho_wo_k_desc =
        make_dynamic_naive_tensor_descriptor_packed_v2(out_n_ho_wo_k_lengths);

#if 0
#if 1
    // [M, N, K0, K1] = [256, 128, 4, 4] for fp32
    constexpr index_t BlockSize = 256;

    constexpr index_t GemmMPerBlock = 256;
    constexpr index_t GemmNPerBlock = 128;
    constexpr index_t GemmKPerBlock = 4;

    constexpr index_t GemmMPerWave = 32;
    constexpr index_t GemmNPerWave = 32;
    constexpr index_t GemmK1 = 4;

    constexpr index_t MRepeat = 4;
    constexpr index_t NRepeat = 2;

    using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1   = Sequence<1, 4, 4>;
    using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>;

    constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 4;
    constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 4;

    using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1   = Sequence<1, 2, 4>;
    using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>;

    constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK1 = 4;
    constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 4;

    constexpr index_t GemmCThreadTransferDstScalarPerVector = 1;
#elif 1
    // [M, N, K0, K1] = [128, 128, 4, 4] for fp32
    constexpr index_t BlockSize = 256;

    constexpr index_t GemmMPerBlock = 128;
    constexpr index_t GemmNPerBlock = 128;
    constexpr index_t GemmKPerBlock = 4;

    constexpr index_t GemmMPerWave = 32;
    constexpr index_t GemmNPerWave = 32;
    constexpr index_t GemmK1 = 4;

    constexpr index_t MRepeat = 2;
    constexpr index_t NRepeat = 2;

    using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1   = Sequence<1, 2, 4>;
    using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>;

    constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 4;
    constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 4;

    using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1   = Sequence<1, 2, 4>;
    using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>;

    constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK1 = 4;
    constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 4;

    constexpr index_t GemmCThreadTransferDstScalarPerVector = 1;
#elif 1
    // [M, N, K0, K1] = [256, 256, 4, 8] for fp16
    constexpr index_t BlockSize = 256;
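Note: v4r4r3 above uses GemmCThreadTransferDstScalarPerVector = 4 while v4r4r4 uses 1. In general a vectorized store of width w along the contiguous output dimension is only legal when that dimension's extent (and the thread's write offset) is divisible by w; width 1 is always safe. A hedged sketch of that legality check (the helper name is made up):

#include <cstddef>

// Hypothetical legality check: a vectorized C-thread store of width
// scalar_per_vector along the contiguous output dimension requires that
// dimension's length to be divisible by the vector width.
constexpr bool dst_vector_store_ok(std::size_t contiguous_len,
                                   std::size_t scalar_per_vector)
{
    return contiguous_len % scalar_per_vector == 0;
}

static_assert(dst_vector_store_ok(256, 4), "K = 256 allows float4 stores");
static_assert(dst_vector_store_ok(250, 1), "width 1 is always legal");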
@@ -139,34 +195,6 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nh
    constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK1 = 8;
    constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8;

    constexpr index_t GemmCThreadTransferDstScalarPerVector = 1;
#elif 1
    // [M, N, K0, K1] = [128, 128, 4, 4] for fp32
    constexpr index_t BlockSize = 256;

    constexpr index_t GemmMPerBlock = 128;
    constexpr index_t GemmNPerBlock = 128;
    constexpr index_t GemmKPerBlock = 4;

    constexpr index_t GemmMPerWave = 32;
    constexpr index_t GemmNPerWave = 32;
    constexpr index_t GemmK1 = 4;

    constexpr index_t MRepeat = 2;
    constexpr index_t NRepeat = 2;

    using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1   = Sequence<1, 2, 4>;
    using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>;

    constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 4;
    constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 4;

    using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1   = Sequence<1, 2, 4>;
    using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>;

    constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK1 = 4;
    constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 4;

    constexpr index_t GemmCThreadTransferDstScalarPerVector = 1;
#elif 1
    // [M, N, K0, K1] = [128, 128, 4, 8] for fp16
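Note: across these configs GemmK1 is 4 for fp32 and 8 for fp16 (see the [M, N, K0, K1] comments above), consistent with sizing K1 so that one thread's K1-contiguous elements fill a single 16-byte global load. A hedged sketch of that assumed sizing rule:

#include <cstddef>

// Assumed sizing rule: pick GemmK1 so a K1-long contiguous run fills one
// 16-byte load (float4 for fp32, half8 for fp16).
constexpr std::size_t gemm_k1_for(std::size_t bytes_per_element)
{
    return 16 / bytes_per_element;
}

static_assert(gemm_k1_for(4) == 4, "fp32: K1 = 4");
static_assert(gemm_k1_for(2) == 8, "fp16: K1 = 8");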
@@ -215,7 +215,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r5r2_nchw_kcyx_nkhw(
            Sequence<4, 3, 2, 0, 1>, // BBlockTransferSrcAccessOrder
            GemmBBlockTransferSrcVectorTensorLengths_GK0_GN0_GN10_GN11_GK1,
            GemmBBlockTransferDstVectorTensorLengths_GK0_GN0_GN10_GN11_GK1,
            Sequence<0, 1, 2, 3, 4>, // BBlockTransferSrcVectorTensorContiguousDimOrder
            Sequence<0, 1, 2, 3, 4>, // BBlockTransferDstVectorTensorContiguousDimOrder
            Sequence<3, 4, 5, 0, 1, 2>, // CThreadTransferSrcDstAccessOrder
            5,                          // CThreadTransferSrcDstVectorDim
            GemmCThreadTransferDstScalarPerVector_BN1,
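Note: the access-order Sequences are permutations of the tensor dimensions; under one common reading, Sequence<4, 3, 2, 0, 1> traverses dimension 4 outermost and dimension 1 innermost. A small illustrative helper (not the library's) showing the permutation mechanics:

#include <array>
#include <cstddef>

// Illustrative only: reorder a multi-index so that order[0] names the
// first-visited dimension, mimicking what an AccessOrder Sequence encodes.
template <std::size_t N>
constexpr std::array<std::size_t, N>
apply_access_order(const std::array<std::size_t, N>& idx,
                   const std::array<std::size_t, N>& order)
{
    std::array<std::size_t, N> reordered{};
    for(std::size_t i = 0; i < N; ++i)
        reordered[i] = idx[order[i]];
    return reordered;
}

// With order {4, 3, 2, 0, 1}, index {10, 11, 12, 13, 14} becomes {14, 13, 12, 10, 11}.
constexpr auto r = apply_access_order<5>({10, 11, 12, 13, 14}, {4, 3, 2, 0, 1});
static_assert(r[0] == 14 && r[3] == 10 && r[4] == 11,
              "dimension 4 is visited first, dimensions 0 and 1 last");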
@@ -1,23 +1,6 @@
#pragma once
#include "host_tensor.hpp"
#include "common_header.hpp"
#include "tensor_descriptor.hpp"

template <typename TensorDesc, std::size_t... Is>
auto make_HostTensorDescriptor_impl(TensorDesc, std::integer_sequence<std::size_t, Is...>)
{
    std::initializer_list<std::size_t> lengths = {TensorDesc::GetLengths()[Is]...};
    std::initializer_list<std::size_t> strides = {TensorDesc::GetStrides()[Is]...};

    return HostTensorDescriptor(lengths, strides);
}

template <typename TensorDesc>
auto make_HostTensorDescriptor(TensorDesc)
{
    return make_HostTensorDescriptor_impl(
        TensorDesc{}, std::make_integer_sequence<std::size_t, TensorDesc::GetNumOfDimension()>{});
}

template <typename TensorDesc>
void ostream_tensor_descriptor(TensorDesc, std::ostream& os = std::cout)
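Note: make_HostTensorDescriptor expands a compile-time descriptor's lengths and strides into runtime initializer lists via an index pack. A hypothetical usage sketch; in CK the descriptors return Sequence objects, so the std::array mock below only stands in for anything exposing GetNumOfDimension(), GetLengths(), and GetStrides():

#include <array>
#include <cstddef>

// Hypothetical 4x8 row-major descriptor standing in for a CK tensor descriptor.
struct Desc2D
{
    static constexpr std::size_t GetNumOfDimension() { return 2; }
    static constexpr std::array<std::size_t, 2> GetLengths() { return {{4, 8}}; }
    static constexpr std::array<std::size_t, 2> GetStrides() { return {{8, 1}}; }
};

// auto host_desc = make_HostTensorDescriptor(Desc2D{});
// -> HostTensorDescriptor with lengths {4, 8} and strides {8, 1}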
@@ -1,20 +0,0 @@
|
||||
#!/bin/bash
|
||||
|
||||
MY_PROJECT_SOURCE=../../../
|
||||
|
||||
export CUDA_ROOT=/usr/local/cuda
|
||||
export CPATH=$CPATH:$CUDA_ROOT/include
|
||||
export LIBRARY_PATH=$LIBRARY_PATH:$CUDA_ROOT/lib64
|
||||
export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:$CUDA_ROOT/lib64
|
||||
|
||||
cmake \
|
||||
-D CMAKE_CXX_COMPILER=clang++ \
|
||||
-D CMAKE_BUILD_TYPE=Release \
|
||||
-D CMAKE_VERBOSE_MAKEFILE:BOOL=ON \
|
||||
-D DEVICE_BACKEND=NVIDIA \
|
||||
-D CMAKE_CUDA_FLAGS="-ccbin clang++ -m64 -Xcompiler -fopenmp -lineinfo --source-in-ptx -keep -Xptxas -v -gencode=arch=compute_61,code=sm_61 -Xptxas -v -maxrregcount=128" \
|
||||
${MY_PROJECT_SOURCE}
|
||||
|
||||
|
||||
#-D CMAKE_CUDA_FLAGS="-ccbin clang++ -m64 -Xcompiler -fopenmp -lineinfo --source-in-ptx -keep -Xptxas -v -gencode=arch=compute_70,code=sm_70" \
|
||||
#-D CMAKE_CUDA_FLAGS="-ccbin clang++ -m64 -Xcompiler -fopenmp -lineinfo --source-in-ptx -keep -Xptxas -v -gencode=arch=compute_70,code=sm_70 -Xptxas -v -maxrregcount=128" \
|
||||
@@ -1,3 +0,0 @@
DRIVER=$1
ARCH=$2
cuobjdump -xelf $ARCH ./driver/$DRIVER && nvdisasm --print-code -g $DRIVER.$ARCH.cubin > $DRIVER.$ARCH.asm