Mirror of https://github.com/ROCm/composable_kernel.git (synced 2026-05-14 02:02:46 +00:00)
backward data (#7)
* enabled atomic add in tensor copy
* added gridwise GEMM
* added backward data conv using GEMM + atomic
* added backward data conv using GEMM, no atomic
[ROCm/composable_kernel commit: 8f5f64960e]
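For orientation: the two backward-data kernels this commit adds ("GEMM + atomic" and "GEMM, no atomic") both build on the same decomposition of dgrad into a GEMM followed by a col2im-style scatter-add. The CPU reference below is illustrative only (it is not code from this commit; all sizes are made up, N = 1 and K = 1 to keep it tiny) and shows why overlapping filter windows force either atomic accumulation or a restructured GEMM.

#include <vector>
#include <cstdio>

int main()
{
    // tiny problem: N=1, C=1, K=1, 3x3 input, 2x2 filter, stride 1, no padding
    const int C = 1, K = 1, Hi = 3, Wi = 3, Y = 2, X = 2;
    const int Ho = Hi - Y + 1, Wo = Wi - X + 1; // 2x2 output
    const int E  = C * Y * X;                   // GEMM M dimension
    const int B  = Ho * Wo;                     // GEMM N dimension (batch folded in)

    std::vector<float> wei(K * E, 1.f);  // weight          [K, E]
    std::vector<float> dout(K * B, 1.f); // output gradient [K, B]
    std::vector<float> col(E * B, 0.f);  // GEMM result     [E, B]
    std::vector<float> din(C * Hi * Wi, 0.f);

    // GEMM: col = transpose(wei) * dout
    for(int e = 0; e < E; ++e)
        for(int b = 0; b < B; ++b)
            for(int k = 0; k < K; ++k)
                col[e * B + b] += wei[k * E + e] * dout[k * B + b];

    // col2im: overlapping windows accumulate into the same input element, which is
    // why the GPU kernels in this commit either use atomic adds or restructure the
    // GEMM so every input-gradient element is written exactly once
    for(int c = 0; c < C; ++c)
        for(int y = 0; y < Y; ++y)
            for(int x = 0; x < X; ++x)
                for(int ho = 0; ho < Ho; ++ho)
                    for(int wo = 0; wo < Wo; ++wo)
                    {
                        const int e = (c * Y + y) * X + x;
                        const int b = ho * Wo + wo;
                        din[(c * Hi + ho + y) * Wi + wo + x] += col[e * B + b];
                    }

    std::printf("din[center] = %f\n", din[1 * Wi + 1]); // overlapped by all 4 windows -> 4.0
}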
@@ -55,9 +55,11 @@ include_directories(BEFORE
if(DEVICE_BACKEND STREQUAL "AMD")
    configure_file("${PROJECT_SOURCE_DIR}/composable_kernel/include/utility/config.amd.hpp.in" "${PROJECT_BINARY_DIR}/composable_kernel/include/utility/config.hpp")
    configure_file("${PROJECT_SOURCE_DIR}/composable_kernel/include/utility/float_type.amd.hpp.in" "${PROJECT_BINARY_DIR}/composable_kernel/include/utility/float_type.hpp")
    configure_file("${PROJECT_SOURCE_DIR}/composable_kernel/include/utility/in_memory_operation.amd.hpp.in" "${PROJECT_BINARY_DIR}/composable_kernel/include/utility/in_memory_operation.hpp")
elseif(DEVICE_BACKEND STREQUAL "NVIDIA")
    configure_file("${PROJECT_SOURCE_DIR}/composable_kernel/include/utility/config.nvidia.hpp.in" "${PROJECT_BINARY_DIR}/composable_kernel/include/utility/config.hpp")
    configure_file("${PROJECT_SOURCE_DIR}/composable_kernel/include/utility/float_type.nvidia.hpp.in" "${PROJECT_BINARY_DIR}/composable_kernel/include/utility/float_type.hpp")
    configure_file("${PROJECT_SOURCE_DIR}/composable_kernel/include/utility/in_memory_operation.nvidia.hpp.in" "${PROJECT_BINARY_DIR}/composable_kernel/include/utility/in_memory_operation.hpp")
endif()

add_subdirectory(driver)
composable_kernel/include/gridwise_operation_wrapper.hpp (new file, 10 lines)
@@ -0,0 +1,10 @@
#ifndef CK_GRIDWISE_OPERATION_KERNEL_WRAPPER
#define CK_GRIDWISE_OPERATION_KERNEL_WRAPPER

template <typename GridwiseOp, typename... Xs>
__global__ void run_gridwise_operation(GridwiseOp, Xs... xs)
{
    GridwiseOp{}.Run(xs...);
}

#endif
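A usage sketch (not part of the commit; FillWithValue, the buffer name and the launch sizes are made up): any default-constructible functor with a __device__ Run method can be launched through this wrapper without writing a dedicated __global__ function per operation.

#include <hip/hip_runtime.h>
#include "gridwise_operation_wrapper.hpp" // the file added above

struct FillWithValue
{
    __device__ void Run(float* p, float value) const
    {
        p[blockIdx.x * blockDim.x + threadIdx.x] = value;
    }
};

int main()
{
    float* p = nullptr;
    hipMalloc(&p, 64 * sizeof(float));

    // one block of 64 threads; the functor instance itself carries no state
    hipLaunchKernelGGL(run_gridwise_operation,
                       dim3(1), dim3(64), 0, 0, FillWithValue{}, p, 3.14f);

    hipDeviceSynchronize();
    hipFree(p);
    return 0;
}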
@@ -0,0 +1,130 @@
#ifndef CK_GRIDWISE_COL2IM_EB_NCHW_HPP
#define CK_GRIDWISE_COL2IM_EB_NCHW_HPP

#include "common_header.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
#include "blockwise_generic_tensor_slice_copy.hpp"

namespace ck {

// B = merge(N, Ho, Wo)
template <index_t GridSize,
          index_t BlockSize,
          typename Float,
          typename ColGlobalDesc,
          typename ImgGlobalDesc,
          typename FilterSizes,
          typename OutputSizes,
          typename ConvStrides,
          typename ConvDilations,
          typename LeftPads,
          typename RightPads,
          index_t EPerBlock,
          index_t BPerBlock,
          typename BlockCopySubLengths_E_B,
          typename BlockCopyClusterLengths_E_B,
          typename BlockCopyThreadClusterArrangeOrder,
          typename BlockCopySrcAccessOrder,
          typename BlockCopyDstAccessOrder,
          index_t BlockCopyDataPerAccess_B>
struct GridwiseCol2Im_eb_nchw
{
    __device__ void Run(const Float* const __restrict__ p_col_global,
                        Float* const __restrict__ p_img_global) const
    {
        constexpr auto col_e_b_global_desc       = ColGlobalDesc{};
        constexpr auto img_n_c_hi_wi_global_desc = ImgGlobalDesc{};

        constexpr index_t N  = img_n_c_hi_wi_global_desc.GetLengths()[0];
        constexpr index_t C  = img_n_c_hi_wi_global_desc.GetLengths()[1];
        constexpr index_t Hi = img_n_c_hi_wi_global_desc.GetLengths()[2];
        constexpr index_t Wi = img_n_c_hi_wi_global_desc.GetLengths()[3];

        constexpr index_t Ho = OutputSizes{}[0];
        constexpr index_t Wo = OutputSizes{}[1];

        constexpr index_t Y = FilterSizes{}[0];
        constexpr index_t X = FilterSizes{}[1];

        constexpr index_t ConvStrideH = ConvStrides{}[0];
        constexpr index_t ConvStrideW = ConvStrides{}[1];

        constexpr index_t ConvDilationH = ConvDilations{}[0];
        constexpr index_t ConvDilationW = ConvDilations{}[1];

        constexpr index_t E = C * Y * X;
        constexpr index_t B = N * Ho * Wo;

        // sanity-check for vectorized memory load
        static_assert((Wo == 1 || (ConvStrideW == 1 || BlockCopyDataPerAccess_B == 1)) &&
                          (X == 1 || ConvDilationW % BlockCopyDataPerAccess_B == 0),
                      "wrong! alignment requirement for vectorized global load of input tensor will "
                      "be violated");

        // divide block work by [E, B]
        static_assert(E % EPerBlock == 0 && B % BPerBlock == 0,
                      "wrong! cannot divide work evenly among blocks");

        constexpr index_t EBlockWork = E / EPerBlock;
        constexpr index_t BBlockWork = B / BPerBlock;

        constexpr auto block_work_desc =
            make_cluster_descriptor(Sequence<EBlockWork, BBlockWork>{});

        const auto block_work_id = block_work_desc.CalculateClusterIndex(get_block_1d_id());

        const index_t e_block_data_on_global = block_work_id[0] * EPerBlock;
        const index_t b_block_data_on_global = block_work_id[1] * BPerBlock;

        // construct img_eb_global_desc
        constexpr auto img_n_c_hip_wip_global_desc = transform_tensor_descriptor(
            img_n_c_hi_wi_global_desc,
            make_tuple(
                PassThrough<N>{}, PassThrough<C>{}, Pad<Sequence<Hi, Wi>, LeftPads, RightPads>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}));

        constexpr auto img_n_c_y_ho_x_wo_global_desc = transform_tensor_descriptor(
            img_n_c_hip_wip_global_desc,
            make_tuple(PassThrough<N>{},
                       PassThrough<C>{},
                       Embed<Sequence<Y, Ho>, Sequence<ConvDilationH, ConvStrideH, 0>>{},
                       Embed<Sequence<X, Wo>, Sequence<ConvDilationW, ConvStrideW, 0>>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}));

        constexpr auto img_e_b_global_desc = transform_tensor_descriptor(
            img_n_c_y_ho_x_wo_global_desc,
            make_tuple(Merge<Sequence<C, Y, X>>{}, Merge<Sequence<N, Ho, Wo>>{}),
            make_tuple(Sequence<1, 2, 4>{}, Sequence<0, 3, 5>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}));

        // blockwise atomic accumulation
        auto blockwise_copy = BlockwiseGenericTensorSliceCopy_v4<BlockSize,
                                                                 decltype(col_e_b_global_desc),
                                                                 decltype(img_e_b_global_desc),
                                                                 Sequence<EPerBlock, BPerBlock>,
                                                                 BlockCopySubLengths_E_B,
                                                                 BlockCopyClusterLengths_E_B,
                                                                 BlockCopyThreadClusterArrangeOrder,
                                                                 BlockCopySrcAccessOrder,
                                                                 BlockCopyDstAccessOrder,
                                                                 1,
                                                                 1,
                                                                 BlockCopyDataPerAccess_B,
                                                                 BlockCopyDataPerAccess_B,
                                                                 AddressSpace::vgpr,
                                                                 AddressSpace::vgpr,
                                                                 AddressSpace::global,
                                                                 InMemoryDataOperation::atomic_add>(
            {e_block_data_on_global, b_block_data_on_global},
            {e_block_data_on_global, b_block_data_on_global});

        // blockwise copy
        blockwise_copy.Run(p_col_global, p_img_global);
    }
};

} // namespace ck
#endif
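Illustrative index math (not part of the diff; all function and parameter names below are made up): the Pad + Embed + Merge chain above maps a column coordinate (e, b) to an image coordinate, and because several (e, b) pairs can land on the same (hi, wi), the blockwise copy is instantiated with InMemoryDataOperation::atomic_add.

#include <cstdio>

struct Col2ImCoord
{
    int n, c, hi, wi;
};

// e = (c, y, x) merged, b = (n, ho, wo) merged, as in the descriptors above
Col2ImCoord map_eb_to_nchw(int e, int b,
                           int Y, int X, int Ho, int Wo,
                           int stride_h, int stride_w,
                           int dilation_h, int dilation_w,
                           int left_pad_h, int left_pad_w)
{
    const int c  = e / (Y * X);
    const int y  = (e / X) % Y;
    const int x  = e % X;
    const int n  = b / (Ho * Wo);
    const int ho = (b / Wo) % Ho;
    const int wo = b % Wo;

    // same arithmetic the Embed (dilation, stride) and Pad transforms encode implicitly
    return {n, c,
            ho * stride_h + y * dilation_h - left_pad_h,
            wo * stride_w + x * dilation_w - left_pad_w};
}

int main()
{
    // 3x3 filter, stride 1, dilation 1, pad 1: e=4 is the filter center (c=0, y=1, x=1)
    const Col2ImCoord p = map_eb_to_nchw(4, 0, 3, 3, /*Ho=*/4, /*Wo=*/4, 1, 1, 1, 1, 1, 1);
    std::printf("e=4, b=0 -> n=%d c=%d hi=%d wi=%d\n", p.n, p.c, p.hi, p.wi);
}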
@@ -0,0 +1,157 @@
|
||||
#ifndef CK_GRIDWISE_CONVOLUTION_BACKWARD_DATA_IMPLICIT_GEMM_V1R1_NCHW_KCYX_NKHW_HPP
|
||||
#define CK_GRIDWISE_CONVOLUTION_BACKWARD_DATA_IMPLICIT_GEMM_V1R1_NCHW_KCYX_NKHW_HPP
|
||||
|
||||
#include "common_header.hpp"
|
||||
#include "tensor_descriptor.hpp"
|
||||
#include "tensor_descriptor_helper.hpp"
|
||||
#include "gridwise_gemm.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
template <index_t GridSize,
|
||||
index_t BlockSize,
|
||||
typename Float,
|
||||
typename AccFloat,
|
||||
typename InGlobalDesc,
|
||||
typename WeiGlobalDesc,
|
||||
typename OutGlobalDesc,
|
||||
typename ConvStrides,
|
||||
typename ConvDilations,
|
||||
typename LeftPads,
|
||||
typename RightPads,
|
||||
index_t EPerBlock,
|
||||
index_t BPerBlock,
|
||||
index_t KPerBlock,
|
||||
index_t GemmMPerThreadSubC,
|
||||
index_t GemmNPerThreadSubC,
|
||||
index_t GemmMLevel0Cluster,
|
||||
index_t GemmNLevel0Cluster,
|
||||
index_t GemmMLevel1Cluster,
|
||||
index_t GemmNLevel1Cluster,
|
||||
index_t GemmKPerThreadLoop,
|
||||
index_t GemmThreadGemmDataPerReadM,
|
||||
index_t GemmThreadGemmDataPerReadN,
|
||||
typename WeiBlockCopySubLengths_K_E,
|
||||
typename WeiBlockCopyClusterLengths_K_E,
|
||||
index_t WeiBlockCopyDataPerAccess_E,
|
||||
typename OutBlockCopySubLengths_K_B,
|
||||
typename OutBlockCopyClusterLengths_K_B,
|
||||
index_t OutBlockCopyDataPerAccess_B,
|
||||
index_t InThreadCopyDataPerAccess_B>
|
||||
struct GridwiseConvolutionBackwardDataImplicitGemm_v1r1_nchw_kcyx_nkhw
|
||||
{
|
||||
__device__ void Run(Float* __restrict__ p_in_global,
|
||||
const Float* __restrict__ p_wei_global,
|
||||
const Float* __restrict__ p_out_global) const
|
||||
{
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
constexpr auto I3 = Number<3>{};
|
||||
|
||||
constexpr auto True = integral_constant<bool, true>{};
|
||||
|
||||
constexpr auto in_n_c_hi_wi_global_desc = InGlobalDesc{};
|
||||
constexpr auto wei_k_c_y_x_global_desc = WeiGlobalDesc{};
|
||||
constexpr auto out_n_k_ho_wo_global_desc = OutGlobalDesc{};
|
||||
|
||||
constexpr index_t N = in_n_c_hi_wi_global_desc.GetLengths()[0];
|
||||
constexpr index_t C = in_n_c_hi_wi_global_desc.GetLengths()[1];
|
||||
constexpr index_t Hi = in_n_c_hi_wi_global_desc.GetLengths()[2];
|
||||
constexpr index_t Wi = in_n_c_hi_wi_global_desc.GetLengths()[3];
|
||||
|
||||
constexpr index_t K = out_n_k_ho_wo_global_desc.GetLengths()[1];
|
||||
constexpr index_t Ho = out_n_k_ho_wo_global_desc.GetLengths()[2];
|
||||
constexpr index_t Wo = out_n_k_ho_wo_global_desc.GetLengths()[3];
|
||||
|
||||
constexpr index_t Y = wei_k_c_y_x_global_desc.GetLengths()[2];
|
||||
constexpr index_t X = wei_k_c_y_x_global_desc.GetLengths()[3];
|
||||
|
||||
constexpr index_t ConvStrideH = ConvStrides{}[0];
|
||||
constexpr index_t ConvStrideW = ConvStrides{}[1];
|
||||
|
||||
constexpr index_t ConvDilationH = ConvDilations{}[0];
|
||||
constexpr index_t ConvDilationW = ConvDilations{}[1];
|
||||
|
||||
constexpr index_t E = C * Y * X;
|
||||
constexpr index_t B = N * Ho * Wo;
|
||||
|
||||
// sanity-check for vectorized memory load
|
||||
static_assert((Wo == 1 || (ConvStrideW == 1 || InThreadCopyDataPerAccess_B == 1)) &&
(X == 1 || ConvDilationW % InThreadCopyDataPerAccess_B == 0),
"wrong! alignment requirement for vectorized global load of input tensor will "
"be violated");
|
||||
|
||||
// output tensor
|
||||
constexpr auto out_n_k_howo_global_desc =
|
||||
unfold_tensor_descriptor(out_n_k_ho_wo_global_desc, I2, I3);
|
||||
|
||||
constexpr auto out_k_b_global_desc =
|
||||
transform_tensor_descriptor(out_n_k_howo_global_desc,
|
||||
make_tuple(PassThrough<K>{}, Merge<Sequence<N, Ho * Wo>>{}),
|
||||
make_tuple(Sequence<1>{}, Sequence<0, 2>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
|
||||
// weight tensor
|
||||
constexpr auto wei_k_e_global_desc =
|
||||
unfold_tensor_descriptor(wei_k_c_y_x_global_desc, I1, I3);
|
||||
|
||||
// input tensor
|
||||
constexpr auto in_n_c_hip_wip_global_desc = transform_tensor_descriptor(
|
||||
in_n_c_hi_wi_global_desc,
|
||||
make_tuple(
|
||||
PassThrough<N>{}, PassThrough<C>{}, Pad<Sequence<Hi, Wi>, LeftPads, RightPads>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}));
|
||||
|
||||
constexpr auto in_n_c_y_ho_x_wo_global_desc = transform_tensor_descriptor(
|
||||
in_n_c_hip_wip_global_desc,
|
||||
make_tuple(PassThrough<N>{},
|
||||
PassThrough<C>{},
|
||||
Embed<Sequence<Y, Ho>, Sequence<ConvDilationH, ConvStrideH, 0>>{},
|
||||
Embed<Sequence<X, Wo>, Sequence<ConvDilationW, ConvStrideW, 0>>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}));
|
||||
|
||||
constexpr auto in_e_b_global_desc = transform_tensor_descriptor(
|
||||
in_n_c_y_ho_x_wo_global_desc,
|
||||
make_tuple(Merge<Sequence<C, Y, X>>{}, Merge<Sequence<N, Ho, Wo>>{}),
|
||||
make_tuple(Sequence<1, 2, 4>{}, Sequence<0, 3, 5>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
|
||||
// GEMM: atomic add
|
||||
constexpr auto gridwise_gemm =
|
||||
GridwiseGemmTransposedANormalBNormalC_v1r1<GridSize,
|
||||
BlockSize,
|
||||
Float,
|
||||
AccFloat,
|
||||
decltype(wei_k_e_global_desc),
|
||||
decltype(out_k_b_global_desc),
|
||||
decltype(in_e_b_global_desc),
|
||||
InMemoryDataOperation::atomic_add,
|
||||
EPerBlock,
|
||||
BPerBlock,
|
||||
KPerBlock,
|
||||
GemmMPerThreadSubC,
|
||||
GemmNPerThreadSubC,
|
||||
GemmMLevel0Cluster,
|
||||
GemmNLevel0Cluster,
|
||||
GemmMLevel1Cluster,
|
||||
GemmNLevel1Cluster,
|
||||
GemmKPerThreadLoop,
|
||||
GemmThreadGemmDataPerReadM,
|
||||
GemmThreadGemmDataPerReadN,
|
||||
WeiBlockCopySubLengths_K_E,
|
||||
WeiBlockCopyClusterLengths_K_E,
|
||||
WeiBlockCopyDataPerAccess_E,
|
||||
OutBlockCopySubLengths_K_B,
|
||||
OutBlockCopyClusterLengths_K_B,
|
||||
OutBlockCopyDataPerAccess_B,
|
||||
InThreadCopyDataPerAccess_B>{};
|
||||
|
||||
gridwise_gemm.Run(p_wei_global, p_out_global, p_in_global);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck
|
||||
#endif
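The v1r1 kernel above maps backward-data convolution onto GridwiseGemmTransposedANormalBNormalC_v1r1 with the weight as matrix A [K, E], the output gradient as matrix B [K, B], and the overlapping input-gradient view as matrix C [E, B], accumulated with InMemoryDataOperation::atomic_add. A worked shape example (the tensor sizes are hypothetical, not taken from the commit):

#include <cstdio>

int main()
{
    constexpr int N = 128, C = 64, K = 256;
    constexpr int Y = 3, X = 3, Ho = 14, Wo = 14; // stride 1, dilation 1, pad 1

    constexpr int GemmM = C * Y * X;   // E = 576
    constexpr int GemmN = N * Ho * Wo; // B = 25088
    constexpr int GemmK = K;           // reduction over output channels

    std::printf("GemmM=%d GemmN=%d GemmK=%d\n", GemmM, GemmN, GemmK);
}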
@@ -0,0 +1,438 @@
|
||||
#ifndef CK_GRIDWISE_CONVOLUTION_BACKWARD_DATA_IMPLICIT_GEMM_V1R2_NCHW_KCYX_NKHW_LDS_DOUBLE_BUFFER_HPP
|
||||
#define CK_GRIDWISE_CONVOLUTION_BACKWARD_DATA_IMPLICIT_GEMM_V1R2_NCHW_KCYX_NKHW_LDS_DOUBLE_BUFFER_HPP
|
||||
|
||||
#include "common_header.hpp"
|
||||
#include "tensor_descriptor.hpp"
|
||||
#include "tensor_descriptor_helper.hpp"
|
||||
#include "ConstantMatrixDescriptor.hpp"
|
||||
#include "blockwise_generic_tensor_slice_copy.hpp"
|
||||
#include "threadwise_generic_tensor_slice_copy.hpp"
|
||||
#include "blockwise_gemm.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
template <index_t GridSize,
|
||||
index_t BlockSize,
|
||||
typename Float,
|
||||
typename AccFloat,
|
||||
typename InGlobalDesc,
|
||||
typename WeiGlobalDesc,
|
||||
typename OutGlobalDesc,
|
||||
typename ConvStrides,
|
||||
typename ConvDilations,
|
||||
typename LeftPads,
|
||||
typename RightPads,
|
||||
index_t EPerBlock,
|
||||
index_t BPerBlock,
|
||||
index_t KPerBlock,
|
||||
index_t GemmMPerThreadSubC,
|
||||
index_t GemmNPerThreadSubC,
|
||||
index_t GemmMLevel0Cluster,
|
||||
index_t GemmNLevel0Cluster,
|
||||
index_t GemmMLevel1Cluster,
|
||||
index_t GemmNLevel1Cluster,
|
||||
index_t GemmKPerThreadLoop,
|
||||
index_t GemmDataPerReadA,
|
||||
index_t GemmDataPerReadB,
|
||||
typename OutBlockCopySubLengths_K_B_N0,
|
||||
typename OutBlockCopyClusterLengths_K_B_N0,
|
||||
index_t OutBlockCopySrcDataPerRead_B,
|
||||
index_t OutBlockCopyDstDataPerWrite_N0,
|
||||
typename WeiBlockCopySubLengths_K_E_C0,
|
||||
typename WeiBlockCopyClusterLengths_K_E_C0,
|
||||
index_t WeiBlockCopySrcDataPerRead_E,
|
||||
index_t WeiBlockCopyDstDataPerWrite_C0,
|
||||
index_t InThreadCopyDstDataPerWrite_B>
|
||||
struct GridwiseConvolutionBackwardDataImplicitGemm_v1r2_nchw_kcyx_nkhw_lds_double_buffer
|
||||
{
|
||||
__device__ void Run(Float* const __restrict__ p_in_global,
|
||||
const Float* const __restrict__ p_wei_global,
|
||||
const Float* const __restrict__ p_out_global) const
|
||||
{
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
constexpr auto I3 = Number<3>{};
|
||||
|
||||
constexpr auto True = integral_constant<bool, true>{};
|
||||
|
||||
constexpr auto in_n_c_hi_wi_global_desc = InGlobalDesc{};
|
||||
constexpr auto wei_k_c_y_x_global_desc = WeiGlobalDesc{};
|
||||
constexpr auto out_n_k_ho_wo_global_desc = OutGlobalDesc{};
|
||||
|
||||
constexpr index_t N = in_n_c_hi_wi_global_desc.GetLengths()[0];
|
||||
constexpr index_t C = in_n_c_hi_wi_global_desc.GetLengths()[1];
|
||||
constexpr index_t Hi = in_n_c_hi_wi_global_desc.GetLengths()[2];
|
||||
constexpr index_t Wi = in_n_c_hi_wi_global_desc.GetLengths()[3];
|
||||
|
||||
constexpr index_t K = out_n_k_ho_wo_global_desc.GetLengths()[1];
|
||||
constexpr index_t Ho = out_n_k_ho_wo_global_desc.GetLengths()[2];
|
||||
constexpr index_t Wo = out_n_k_ho_wo_global_desc.GetLengths()[3];
|
||||
|
||||
constexpr index_t Y = wei_k_c_y_x_global_desc.GetLengths()[2];
|
||||
constexpr index_t X = wei_k_c_y_x_global_desc.GetLengths()[3];
|
||||
|
||||
constexpr index_t ConvStrideH = ConvStrides{}[0];
|
||||
constexpr index_t ConvStrideW = ConvStrides{}[1];
|
||||
|
||||
constexpr index_t ConvDilationH = ConvDilations{}[0];
|
||||
constexpr index_t ConvDilationW = ConvDilations{}[1];
|
||||
|
||||
constexpr index_t C0 = GemmMPerThreadSubC;
|
||||
constexpr index_t N0 = GemmNPerThreadSubC;
|
||||
|
||||
static_assert(C % C0 == 0 && N % N0 == 0, "wrong!");
|
||||
|
||||
constexpr index_t C1 = C / C0;
|
||||
constexpr index_t N1 = N / N0;
|
||||
|
||||
constexpr index_t E = C1 * Y * X;
|
||||
constexpr index_t B = N1 * Ho * Wo;
|
||||
|
||||
// sanity-check for vectorized memory load
|
||||
static_assert((Wo == 1 || (ConvStrideW == 1 || InThreadCopyDstDataPerWrite_B == 1)) &&
(X == 1 || ConvDilationW % InThreadCopyDstDataPerWrite_B == 0),
"wrong! alignment requirement for vectorized global load of input tensor will "
"be violated");
|
||||
|
||||
// divide block work by [K, B]
|
||||
static_assert(E % EPerBlock == 0 && B % BPerBlock == 0 && K % KPerBlock == 0,
"wrong! cannot divide work evenly among blocks");
|
||||
|
||||
constexpr index_t EBlockWork = E / EPerBlock;
|
||||
constexpr index_t BBlockWork = B / BPerBlock;
|
||||
|
||||
constexpr auto block_work_desc =
|
||||
make_cluster_descriptor(Sequence<EBlockWork, BBlockWork>{});
|
||||
|
||||
const auto block_work_id = block_work_desc.CalculateClusterIndex(get_block_1d_id());
|
||||
|
||||
const index_t e_block_data_on_global = block_work_id[0] * EPerBlock;
|
||||
const index_t b_block_data_on_global = block_work_id[1] * BPerBlock;
|
||||
|
||||
// output tensor
|
||||
// global tensor in global memory, src of blockwise copy
|
||||
constexpr auto out_n_k_howo_global_desc =
|
||||
unfold_tensor_descriptor(out_n_k_ho_wo_global_desc, I2, I3);
|
||||
|
||||
constexpr auto out_n0_n1_k_howo_global_desc = transform_tensor_descriptor(
|
||||
out_n_k_howo_global_desc,
|
||||
make_tuple(UnMerge<Sequence<N0, N1>>{}, PassThrough<K>{}, PassThrough<Ho * Wo>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
|
||||
make_tuple(Sequence<0, 1>{}, Sequence<2>{}, Sequence<3>{}));
|
||||
|
||||
constexpr auto out_k_b_n0_global_desc = transform_tensor_descriptor(
|
||||
out_n0_n1_k_howo_global_desc,
|
||||
make_tuple(PassThrough<K>{}, Merge<Sequence<N1, Ho * Wo>>{}, PassThrough<N0>{}),
|
||||
make_tuple(Sequence<2>{}, Sequence<1, 3>{}, Sequence<0>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
|
||||
|
||||
// block tensor in LDS memory, dst of blockwise copy
|
||||
// be careful of LDS alignment
|
||||
constexpr auto out_k_b_n0_block_desc = make_native_tensor_descriptor_aligned(
|
||||
Sequence<KPerBlock, BPerBlock, N0>{}, Number<OutBlockCopyDstDataPerWrite_N0>{});
|
||||
|
||||
// output tensor blockwise copy
|
||||
auto blockwise_out_copy =
|
||||
BlockwiseGenericTensorSliceCopy_v4<BlockSize,
|
||||
decltype(out_k_b_n0_global_desc),
|
||||
decltype(out_k_b_n0_block_desc),
|
||||
decltype(out_k_b_n0_block_desc.GetLengths()),
|
||||
OutBlockCopySubLengths_K_B_N0,
|
||||
OutBlockCopyClusterLengths_K_B_N0,
|
||||
Sequence<0, 1, 2>,
|
||||
Sequence<0, 1, 2>,
|
||||
Sequence<0, 1, 2>,
|
||||
1,
|
||||
2,
|
||||
OutBlockCopySrcDataPerRead_B,
|
||||
OutBlockCopyDstDataPerWrite_N0,
|
||||
AddressSpace::global,
|
||||
AddressSpace::vgpr,
|
||||
AddressSpace::lds,
|
||||
InMemoryDataOperation::none>(
|
||||
{0, b_block_data_on_global, 0}, {0, 0, 0});
|
||||
|
||||
// weight tensor
|
||||
// global tensor in global memory, src of blockwise copy
|
||||
constexpr auto wei_k_cyx_global_desc =
|
||||
unfold_tensor_descriptor(wei_k_c_y_x_global_desc, I1, I3);
|
||||
|
||||
constexpr auto wei_k_c0_e_global_desc =
|
||||
transform_tensor_descriptor(wei_k_cyx_global_desc,
|
||||
make_tuple(PassThrough<K>{}, UnMerge<Sequence<C0, E>>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1, 2>{}));
|
||||
|
||||
constexpr auto wei_k_e_c0_global_desc = reorder_tensor_descriptor_given_lower2upper(
|
||||
wei_k_c0_e_global_desc, Sequence<0, 2, 1>{});
|
||||
|
||||
// block tensor in LDS memory, dst of blockwise copy
|
||||
// be careful of LDS alignment
|
||||
constexpr auto wei_k_e_c0_block_desc = make_native_tensor_descriptor_aligned(
|
||||
Sequence<KPerBlock, EPerBlock, C0>{}, Number<WeiBlockCopyDstDataPerWrite_C0>{});
|
||||
|
||||
// weight tensor blockwise copy
|
||||
auto blockwise_wei_copy =
|
||||
BlockwiseGenericTensorSliceCopy_v4<BlockSize,
|
||||
decltype(wei_k_e_c0_global_desc),
|
||||
decltype(wei_k_e_c0_block_desc),
|
||||
decltype(wei_k_e_c0_block_desc.GetLengths()),
|
||||
WeiBlockCopySubLengths_K_E_C0,
|
||||
WeiBlockCopyClusterLengths_K_E_C0,
|
||||
Sequence<0, 1, 2>,
|
||||
Sequence<0, 1, 2>,
|
||||
Sequence<0, 1, 2>,
|
||||
1,
|
||||
2,
|
||||
WeiBlockCopySrcDataPerRead_E,
|
||||
WeiBlockCopyDstDataPerWrite_C0,
|
||||
AddressSpace::global,
|
||||
AddressSpace::vgpr,
|
||||
AddressSpace::lds,
|
||||
InMemoryDataOperation::none>(
|
||||
{0, e_block_data_on_global, 0}, {0, 0, 0});
|
||||
|
||||
// GEMM definition
|
||||
// c_mtx += transpose(a_mtx) * b_mtx
|
||||
// a_mtx[KPerBlock, EPerBlock*C0] is in LDS
|
||||
// b_mtx[KPerBlock, BPerBlock*N0] is in LDS
|
||||
// c_mtx[EPerBlock*C0, BPerBlock*N0] is distributed among threads, and saved in
|
||||
// register
|
||||
constexpr auto a_k_ec0_block_mtx_desc = make_ConstantMatrixDescriptor(
|
||||
wei_k_e_c0_block_desc.GetLength(I0),
|
||||
wei_k_e_c0_block_desc.GetLength(I1) * wei_k_e_c0_block_desc.GetLength(I2),
|
||||
wei_k_e_c0_block_desc.GetStride(I0));
|
||||
constexpr auto b_k_bn0_block_mtx_desc = make_ConstantMatrixDescriptor(
|
||||
out_k_b_n0_block_desc.GetLength(I0),
|
||||
out_k_b_n0_block_desc.GetLength(I1) * out_k_b_n0_block_desc.GetLength(I2),
|
||||
out_k_b_n0_block_desc.GetStride(I0));
|
||||
|
||||
// sanity check alignment
|
||||
// TODO: this check is ad-hoc, should enforce it by enforcing alignment of
|
||||
// wei_k_e_c0_block_desc and out_k_b_n0_block_desc
|
||||
static_assert(a_k_ec0_block_mtx_desc.RowStride() % GemmDataPerReadB == 0, "wrong!");
|
||||
static_assert(b_k_bn0_block_mtx_desc.RowStride() % GemmDataPerReadA == 0, "wrong!");
|
||||
|
||||
// sanity check
|
||||
static_assert(EPerBlock % (GemmMLevel0Cluster * GemmMLevel1Cluster) == 0 &&
|
||||
BPerBlock % (GemmNLevel0Cluster * GemmNLevel1Cluster) == 0,
|
||||
"wrong!");
|
||||
|
||||
constexpr index_t GemmMRepeat = EPerBlock / (GemmMLevel0Cluster * GemmMLevel1Cluster);
|
||||
constexpr index_t GemmNRepeat = BPerBlock / (GemmNLevel0Cluster * GemmNLevel1Cluster);
|
||||
|
||||
// c_thread_mtx definition: this is a mess
|
||||
// TODO: more elegant way of defining c_thread_mtx
|
||||
constexpr auto c_e0e1c0_b0b1n0_thread_mtx_desc = make_ConstantMatrixDescriptor_packed(
|
||||
Number<GemmMRepeat * GemmMPerThreadSubC>{}, Number<GemmNRepeat * GemmNPerThreadSubC>{});
|
||||
|
||||
const auto blockwise_gemm = BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2<
|
||||
BlockSize,
|
||||
decltype(a_k_ec0_block_mtx_desc),
|
||||
decltype(b_k_bn0_block_mtx_desc),
|
||||
decltype(c_e0e1c0_b0b1n0_thread_mtx_desc),
|
||||
GemmMPerThreadSubC,
|
||||
GemmNPerThreadSubC,
|
||||
GemmMLevel0Cluster,
|
||||
GemmNLevel0Cluster,
|
||||
GemmMLevel1Cluster,
|
||||
GemmNLevel1Cluster,
|
||||
GemmKPerThreadLoop,
|
||||
GemmDataPerReadA,
|
||||
GemmDataPerReadB>{};
|
||||
|
||||
// LDS allocation for input and weight: be careful of alignment
|
||||
constexpr index_t max_lds_align = math::lcm(WeiBlockCopyDstDataPerWrite_C0,
|
||||
OutBlockCopyDstDataPerWrite_N0,
|
||||
GemmDataPerReadA,
|
||||
GemmDataPerReadB);
|
||||
|
||||
constexpr index_t out_block_space =
|
||||
math::integer_least_multiple(out_k_b_n0_block_desc.GetElementSpace(), max_lds_align);
|
||||
|
||||
constexpr index_t wei_block_space =
|
||||
math::integer_least_multiple(wei_k_e_c0_block_desc.GetElementSpace(), max_lds_align);
|
||||
|
||||
__shared__ Float p_out_block_double[2 * out_block_space];
|
||||
__shared__ Float p_wei_block_double[2 * wei_block_space];
|
||||
|
||||
// register allocation for output
|
||||
AccFloat p_in_thread[c_e0e1c0_b0b1n0_thread_mtx_desc.GetElementSpace()];
|
||||
|
||||
// zero out threadwise output
|
||||
threadwise_matrix_set_zero(c_e0e1c0_b0b1n0_thread_mtx_desc, p_in_thread);
|
||||
|
||||
// LDS double buffer: preload data into LDS
|
||||
{
|
||||
blockwise_out_copy.Run(p_out_global, p_out_block_double);
|
||||
blockwise_wei_copy.Run(p_wei_global, p_wei_block_double);
|
||||
}
|
||||
|
||||
// LDS double buffer: main body
|
||||
for(index_t k_block_data_begin = 0; k_block_data_begin + 2 * KPerBlock < K;
|
||||
k_block_data_begin += 2 * KPerBlock)
|
||||
{
|
||||
#pragma unroll
|
||||
for(index_t iloop = 0; iloop < 2; ++iloop)
|
||||
{
|
||||
const bool even_loop = (iloop % 2 == 0);
|
||||
|
||||
Float* p_out_block_now =
|
||||
even_loop ? p_out_block_double : p_out_block_double + out_block_space;
|
||||
Float* p_wei_block_now =
|
||||
even_loop ? p_wei_block_double : p_wei_block_double + wei_block_space;
|
||||
|
||||
Float* p_out_block_next =
|
||||
even_loop ? p_out_block_double + out_block_space : p_out_block_double;
|
||||
Float* p_wei_block_next =
|
||||
even_loop ? p_wei_block_double + wei_block_space : p_wei_block_double;
|
||||
|
||||
Float p_out_thread_buffer[blockwise_out_copy.GetThreadBufferSize()];
|
||||
Float p_wei_thread_buffer[blockwise_wei_copy.GetThreadBufferSize()];
|
||||
|
||||
blockwise_out_copy.MoveSrcSliceWindow(Sequence<KPerBlock, 0, 0>{}, True);
|
||||
blockwise_wei_copy.MoveSrcSliceWindow(Sequence<KPerBlock, 0, 0>{}, True);
|
||||
|
||||
__syncthreads();
|
||||
|
||||
// LDS double buffer: load next data from device memory
|
||||
blockwise_out_copy.RunLoadThreadBuffer(p_out_global, p_out_thread_buffer);
|
||||
blockwise_wei_copy.RunLoadThreadBuffer(p_wei_global, p_wei_thread_buffer);
|
||||
|
||||
// LDS double buffer: GEMM on current data
|
||||
blockwise_gemm.Run(p_wei_block_now, p_out_block_now, p_in_thread);
|
||||
|
||||
// LDS double buffer: store next data to LDS
|
||||
blockwise_out_copy.RunStoreThreadBuffer(p_out_thread_buffer, p_out_block_next);
|
||||
blockwise_wei_copy.RunStoreThreadBuffer(p_wei_thread_buffer, p_wei_block_next);
|
||||
}
|
||||
}
|
||||
|
||||
// LDS double buffer: tail
|
||||
{
|
||||
constexpr bool has_two_iteration_left = (K % (2 * KPerBlock) == 0);
|
||||
|
||||
if(has_two_iteration_left) // if there are 2 iterations left
|
||||
{
|
||||
Float p_out_thread_buffer[blockwise_out_copy.GetThreadBufferSize()];
|
||||
Float p_wei_thread_buffer[blockwise_wei_copy.GetThreadBufferSize()];
|
||||
|
||||
blockwise_out_copy.MoveSrcSliceWindow(Sequence<KPerBlock, 0, 0>{}, True);
|
||||
blockwise_wei_copy.MoveSrcSliceWindow(Sequence<KPerBlock, 0, 0>{}, True);
|
||||
|
||||
__syncthreads();
|
||||
|
||||
// LDS double buffer: load last data from device mem
|
||||
blockwise_out_copy.RunLoadThreadBuffer(p_out_global, p_out_thread_buffer);
|
||||
blockwise_wei_copy.RunLoadThreadBuffer(p_wei_global, p_wei_thread_buffer);
|
||||
|
||||
// LDS double buffer: GEMM on 2nd-last data
|
||||
blockwise_gemm.Run(p_wei_block_double, p_out_block_double, p_in_thread);
|
||||
|
||||
// LDS double buffer: store last data to LDS
|
||||
blockwise_out_copy.RunStoreThreadBuffer(p_out_thread_buffer,
|
||||
p_out_block_double + out_block_space);
|
||||
blockwise_wei_copy.RunStoreThreadBuffer(p_wei_thread_buffer,
|
||||
p_wei_block_double + wei_block_space);
|
||||
|
||||
__syncthreads();
|
||||
|
||||
// LDS double buffer: GEMM on last data
|
||||
blockwise_gemm.Run(p_wei_block_double + wei_block_space,
|
||||
p_out_block_double + out_block_space,
|
||||
p_in_thread);
|
||||
}
|
||||
else // if there is 1 iteration left
|
||||
{
|
||||
__syncthreads();
|
||||
|
||||
// LDS double buffer: GEMM on last data
|
||||
blockwise_gemm.Run(p_wei_block_double, p_out_block_double, p_in_thread);
|
||||
}
|
||||
}
|
||||
|
||||
// input: register to global memory, atomic add
|
||||
{
|
||||
constexpr index_t E1 = GemmMLevel0Cluster * GemmMLevel1Cluster;
|
||||
constexpr index_t E0 = E / E1;
|
||||
|
||||
constexpr index_t B1 = GemmNLevel0Cluster * GemmNLevel1Cluster;
|
||||
constexpr index_t B0 = B / B1;
|
||||
|
||||
// define input tensor descriptor for threadwise copy
|
||||
// thread input tensor, src of threadwise copy
|
||||
constexpr auto in_e0_e1_c0_b0_b1_n0_thread_desc = make_native_tensor_descriptor_packed(
|
||||
Sequence<GemmMRepeat, 1, GemmMPerThreadSubC, GemmNRepeat, 1, GemmNPerThreadSubC>{});
|
||||
|
||||
// global input tensor, dst of threadwise copy
|
||||
constexpr auto in_n_c_hip_wip_global_desc = transform_tensor_descriptor(
|
||||
in_n_c_hi_wi_global_desc,
|
||||
make_tuple(PassThrough<N>{},
|
||||
PassThrough<C>{},
|
||||
Pad<Sequence<Hi, Wi>, LeftPads, RightPads>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}));
|
||||
|
||||
constexpr auto in_n0_n1_c0_c1_y_ho_x_wo_global_desc = transform_tensor_descriptor(
|
||||
in_n_c_hip_wip_global_desc,
|
||||
make_tuple(UnMerge<Sequence<N0, N1>>{},
|
||||
UnMerge<Sequence<C0, C1>>{},
|
||||
Embed<Sequence<Y, Ho>, Sequence<ConvDilationH, ConvStrideH, 0>>{},
|
||||
Embed<Sequence<X, Wo>, Sequence<ConvDilationW, ConvStrideW, 0>>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}, Sequence<6, 7>{}));
|
||||
|
||||
constexpr auto in_e_c0_b_n0_global_desc = transform_tensor_descriptor(
|
||||
in_n0_n1_c0_c1_y_ho_x_wo_global_desc,
|
||||
make_tuple(Merge<Sequence<C1, Y, X>>{},
|
||||
PassThrough<C0>{},
|
||||
Merge<Sequence<N1, Ho, Wo>>{},
|
||||
PassThrough<N0>{}),
|
||||
make_tuple(Sequence<3, 4, 6>{}, Sequence<2>{}, Sequence<1, 5, 7>{}, Sequence<0>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
|
||||
|
||||
constexpr auto in_e0_e1_c0_b0_b1_n0_global_desc = transform_tensor_descriptor(
|
||||
in_e_c0_b_n0_global_desc,
|
||||
make_tuple(UnMerge<Sequence<E0, E1>>{},
|
||||
PassThrough<C0>{},
|
||||
UnMerge<Sequence<B0, B1>>{},
|
||||
PassThrough<N0>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<0, 1>{}, Sequence<2>{}, Sequence<3, 4>{}, Sequence<5>{}));
|
||||
|
||||
// calculate origin of thread input tensor on global memory
|
||||
// blockwise GEMM c matrix starting index
|
||||
const auto c_thread_mtx_on_block =
|
||||
blockwise_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id());
|
||||
|
||||
const index_t e_thread_data_on_global =
|
||||
e_block_data_on_global + c_thread_mtx_on_block.row / GemmMPerThreadSubC;
|
||||
|
||||
const index_t b_thread_data_on_global =
|
||||
b_block_data_on_global + c_thread_mtx_on_block.col / GemmNPerThreadSubC;
|
||||
|
||||
ThreadwiseGenericTensorSliceCopy_v4r2<
|
||||
decltype(in_e0_e1_c0_b0_b1_n0_thread_desc),
|
||||
decltype(in_e0_e1_c0_b0_b1_n0_global_desc),
|
||||
decltype(in_e0_e1_c0_b0_b1_n0_thread_desc.GetLengths()),
|
||||
Sequence<0, 1, 2, 3, 4, 5>,
|
||||
4,
|
||||
1,
|
||||
InThreadCopyDstDataPerWrite_B,
|
||||
AddressSpace::vgpr,
|
||||
AddressSpace::global,
|
||||
InMemoryDataOperation::atomic_add>({0, 0, 0, 0, 0, 0},
|
||||
{e_thread_data_on_global / E1,
|
||||
e_thread_data_on_global % E1,
|
||||
0,
|
||||
b_thread_data_on_global / B1,
|
||||
b_thread_data_on_global % B1,
|
||||
0})
|
||||
.Run(p_in_thread, p_in_global);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck
|
||||
#endif
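A host-side toy (not the kernel above; K and KPerBlock are made-up sizes, the buffer indices stand in for the two halves of LDS) that only prints the ping-pong schedule the v1r2 kernel implements: the blockwise GEMM always consumes the buffer filled in the previous step while the next K-slice is prefetched into the other half.

#include <cstdio>

int main()
{
    const int K = 64, KPerBlock = 8;

    int buf = 0; // 0/1: which half of the double buffer the GEMM reads next
    std::printf("preload  -> buffer %d\n", buf);

    // main body: two sub-iterations per pass, as in the kernel's unrolled loop
    for(int k = 0; k + 2 * KPerBlock < K; k += 2 * KPerBlock)
        for(int iloop = 0; iloop < 2; ++iloop)
        {
            std::printf("prefetch slice k=%2d -> buffer %d, GEMM on buffer %d\n",
                        k + (iloop + 1) * KPerBlock, 1 - buf, buf);
            buf = 1 - buf;
        }

    // tail: two iterations left when K is a multiple of 2*KPerBlock, otherwise one
    if(K % (2 * KPerBlock) == 0)
    {
        std::printf("tail: prefetch -> buffer %d, GEMM on buffer %d\n", 1 - buf, buf);
        std::printf("tail: GEMM on buffer %d\n", 1 - buf);
    }
    else
    {
        std::printf("tail: GEMM on buffer %d\n", buf);
    }
}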
@@ -0,0 +1,211 @@
|
||||
#ifndef CK_GRIDWISE_CONVOLUTION_BACKWARD_DATA_IMPLICIT_GEMM_V2R1_NCHW_KCYX_NKHW_HPP
|
||||
#define CK_GRIDWISE_CONVOLUTION_BACKWARD_DATA_IMPLICIT_GEMM_V2R1_NCHW_KCYX_NKHW_HPP
|
||||
|
||||
#include "common_header.hpp"
|
||||
#include "tensor_descriptor.hpp"
|
||||
#include "tensor_descriptor_helper.hpp"
|
||||
#include "gridwise_gemm.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
// GemmK = K * Ydot * Xdot;
|
||||
// GemmM = C * Ytilda * Xtilda;
|
||||
// GemmN = N * Htilda * Wtilda;
|
||||
template <index_t GridSize,
|
||||
index_t BlockSize,
|
||||
typename Float,
|
||||
typename AccFloat,
|
||||
typename InGlobalDesc,
|
||||
typename WeiGlobalDesc,
|
||||
typename OutGlobalDesc,
|
||||
typename ConvStrides,
|
||||
typename ConvDilations,
|
||||
typename LeftPads,
|
||||
typename RightPads,
|
||||
index_t GemmMPerBlock,
|
||||
index_t GemmNPerBlock,
|
||||
index_t GemmKPerBlock,
|
||||
index_t GemmMPerThreadSubC,
|
||||
index_t GemmNPerThreadSubC,
|
||||
index_t GemmMLevel0Cluster,
|
||||
index_t GemmNLevel0Cluster,
|
||||
index_t GemmMLevel1Cluster,
|
||||
index_t GemmNLevel1Cluster,
|
||||
index_t GemmKPerThreadLoop,
|
||||
index_t GemmThreadGemmDataPerReadM,
|
||||
index_t GemmThreadGemmDataPerReadN,
|
||||
typename GemmABlockCopySubLengths, // Gemm-K, Gemm-M
|
||||
typename GemmABlockCopyClusterLengths, // Gemm-K, Gemm-M
|
||||
index_t GemmABlockCopyDataPerAccess, // Gemm-M
|
||||
typename GemmBBlockCopySubLengths, // Gemm-K, Gemm-N
|
||||
typename GemmBBlockCopyClusterLengths, // Gemm-K, Gemm-N
|
||||
index_t GemmBBlockCopyDataPerAccess, // Gemm-N
|
||||
index_t GemmCThreadCopyDataPerAccess // Gemm-N
|
||||
>
|
||||
struct GridwiseConvolutionBackwardDataImplicitGemm_v2r1_nchw_kcyx_nkhw
|
||||
{
|
||||
__device__ void Run(Float* __restrict__ p_in_global,
|
||||
const Float* __restrict__ p_wei_global,
|
||||
const Float* __restrict__ p_out_global) const
|
||||
{
|
||||
constexpr auto in_n_c_hi_wi_global_desc = InGlobalDesc{};
|
||||
constexpr auto wei_k_c_y_x_global_desc = WeiGlobalDesc{};
|
||||
constexpr auto out_n_k_ho_wo_global_desc = OutGlobalDesc{};
|
||||
|
||||
constexpr index_t N = in_n_c_hi_wi_global_desc.GetLengths()[0];
|
||||
constexpr index_t C = in_n_c_hi_wi_global_desc.GetLengths()[1];
|
||||
constexpr index_t Hi = in_n_c_hi_wi_global_desc.GetLengths()[2];
|
||||
constexpr index_t Wi = in_n_c_hi_wi_global_desc.GetLengths()[3];
|
||||
|
||||
constexpr index_t K = out_n_k_ho_wo_global_desc.GetLengths()[1];
|
||||
constexpr index_t Ho = out_n_k_ho_wo_global_desc.GetLengths()[2];
|
||||
constexpr index_t Wo = out_n_k_ho_wo_global_desc.GetLengths()[3];
|
||||
|
||||
constexpr index_t Y = wei_k_c_y_x_global_desc.GetLengths()[2];
|
||||
constexpr index_t X = wei_k_c_y_x_global_desc.GetLengths()[3];
|
||||
|
||||
constexpr index_t ConvStrideH = ConvStrides{}[0];
|
||||
constexpr index_t ConvStrideW = ConvStrides{}[1];
|
||||
|
||||
constexpr index_t ConvDilationH = ConvDilations{}[0];
|
||||
constexpr index_t ConvDilationW = ConvDilations{}[1];
|
||||
|
||||
// sanity-check for vectorized memory load
|
||||
static_assert((Wo == 1 || (ConvStrideW == 1 || GemmCThreadCopyDataPerAccess == 1)) &&
(X == 1 || ConvDilationW % GemmCThreadCopyDataPerAccess == 0),
"wrong! alignment requirement for vectorized global load of input tensor will "
"be violated");
|
||||
|
||||
constexpr index_t hcf_stride_dilation_h = math::hcf(ConvStrideH, ConvDilationH);
|
||||
constexpr index_t hcf_stride_dilation_w = math::hcf(ConvStrideW, ConvDilationW);
|
||||
|
||||
constexpr index_t Ytilda = ConvStrideH / hcf_stride_dilation_h;
|
||||
constexpr index_t Xtilda = ConvStrideW / hcf_stride_dilation_w;
|
||||
|
||||
constexpr index_t Ydot = math::integer_divide_ceil(Y, Ytilda);
|
||||
constexpr index_t Xdot = math::integer_divide_ceil(X, Xtilda);
|
||||
|
||||
constexpr index_t right_pad_ho = (ConvDilationH / hcf_stride_dilation_h) * (Y - Ytilda);
|
||||
constexpr index_t right_pad_wo = (ConvDilationW / hcf_stride_dilation_w) * (X - Xtilda);
|
||||
|
||||
constexpr index_t Htilda = Ho + right_pad_ho;
|
||||
constexpr index_t Wtilda = Wo + right_pad_wo;
|
||||
|
||||
// weight tensor
|
||||
constexpr auto wei_k_c_yp_xp_global_desc = transform_tensor_descriptor(
|
||||
wei_k_c_y_x_global_desc,
|
||||
make_tuple(PassThrough<K>{},
|
||||
PassThrough<C>{},
|
||||
Pad<Sequence<Y, X>,
|
||||
Sequence<0, 0>,
|
||||
Sequence<Ydot * Ytilda - Y, Xdot * Xtilda - X>,
|
||||
true>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}));
|
||||
|
||||
constexpr auto wei_k_c_ydot_ytilda_xdot_xtilda_global_desc = transform_tensor_descriptor(
|
||||
wei_k_c_yp_xp_global_desc,
|
||||
make_tuple(PassThrough<K>{},
|
||||
PassThrough<C>{},
|
||||
Embed<Sequence<Ydot, Ytilda>,
|
||||
Sequence<ConvStrideH / hcf_stride_dilation_h, 1, 0>>{},
|
||||
Embed<Sequence<Xdot, Xtilda>,
|
||||
Sequence<ConvStrideW / hcf_stride_dilation_w, 1, 0>>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}));
|
||||
|
||||
constexpr auto wei_gemmk_gemmm_global_desc = transform_tensor_descriptor(
|
||||
wei_k_c_ydot_ytilda_xdot_xtilda_global_desc,
|
||||
make_tuple(Merge<Sequence<K, Ydot, Xdot>>{}, Merge<Sequence<C, Ytilda, Xtilda>>{}),
|
||||
make_tuple(Sequence<0, 2, 4>{}, Sequence<1, 3, 5>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
|
||||
// output tensor
|
||||
constexpr auto out_n_k_hop_wop_global_desc =
|
||||
transform_tensor_descriptor(out_n_k_ho_wo_global_desc,
|
||||
make_tuple(PassThrough<N>{},
|
||||
PassThrough<K>{},
|
||||
Pad<Sequence<Ho, Wo>,
|
||||
Sequence<0, 0>,
|
||||
Sequence<right_pad_ho, right_pad_wo>,
|
||||
true>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}));
|
||||
|
||||
constexpr auto out_n_k_ydot_htilda_xdot_wtilda_global_desc = transform_tensor_descriptor(
|
||||
out_n_k_hop_wop_global_desc,
|
||||
make_tuple(PassThrough<N>{},
|
||||
PassThrough<K>{},
|
||||
Embed<Sequence<Ydot, Htilda>,
|
||||
Sequence<-ConvDilationH / hcf_stride_dilation_h, 1, 0>>{},
|
||||
Embed<Sequence<Xdot, Wtilda>,
|
||||
Sequence<-ConvDilationW / hcf_stride_dilation_w, 1, 0>>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}));
|
||||
|
||||
constexpr auto out_gemmk_gemmn_global_desc = transform_tensor_descriptor(
|
||||
out_n_k_ydot_htilda_xdot_wtilda_global_desc,
|
||||
make_tuple(Merge<Sequence<K, Ydot, Xdot>>{}, Merge<Sequence<N, Htilda, Wtilda>>{}),
|
||||
make_tuple(Sequence<1, 2, 4>{}, Sequence<0, 3, 5>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
|
||||
// input tensor
|
||||
constexpr auto in_n_c_hip_wip_global_desc = transform_tensor_descriptor(
|
||||
in_n_c_hi_wi_global_desc,
|
||||
make_tuple(PassThrough<N>{},
|
||||
PassThrough<C>{},
|
||||
Pad<Sequence<Hi, Wi>, LeftPads, RightPads, true>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}));
|
||||
|
||||
constexpr auto in_n_c_ytilda_htilda_xtilda_wtilda_global_desc = transform_tensor_descriptor(
|
||||
in_n_c_hip_wip_global_desc,
|
||||
make_tuple(PassThrough<N>{},
|
||||
PassThrough<C>{},
|
||||
Embed<Sequence<Ytilda, Htilda>, Sequence<ConvDilationH, ConvStrideH, 0>>{},
|
||||
Embed<Sequence<Xtilda, Wtilda>, Sequence<ConvDilationW, ConvStrideW, 0>>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}));
|
||||
|
||||
constexpr auto in_gemmm_gemmn_global_desc = transform_tensor_descriptor(
|
||||
in_n_c_ytilda_htilda_xtilda_wtilda_global_desc,
|
||||
make_tuple(Merge<Sequence<C, Ytilda, Xtilda>>{}, Merge<Sequence<N, Htilda, Wtilda>>{}),
|
||||
make_tuple(Sequence<1, 2, 4>{}, Sequence<0, 3, 5>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
|
||||
// GEMM
|
||||
constexpr auto gridwise_gemm =
|
||||
GridwiseGemmTransposedANormalBNormalC_v1r1<GridSize,
|
||||
BlockSize,
|
||||
Float,
|
||||
AccFloat,
|
||||
decltype(wei_gemmk_gemmm_global_desc),
|
||||
decltype(out_gemmk_gemmn_global_desc),
|
||||
decltype(in_gemmm_gemmn_global_desc),
|
||||
InMemoryDataOperation::none,
|
||||
GemmMPerBlock,
|
||||
GemmNPerBlock,
|
||||
GemmKPerBlock,
|
||||
GemmMPerThreadSubC,
|
||||
GemmNPerThreadSubC,
|
||||
GemmMLevel0Cluster,
|
||||
GemmNLevel0Cluster,
|
||||
GemmMLevel1Cluster,
|
||||
GemmNLevel1Cluster,
|
||||
GemmKPerThreadLoop,
|
||||
GemmThreadGemmDataPerReadM,
|
||||
GemmThreadGemmDataPerReadN,
|
||||
GemmABlockCopySubLengths,
|
||||
GemmABlockCopyClusterLengths,
|
||||
GemmABlockCopyDataPerAccess,
|
||||
GemmBBlockCopySubLengths,
|
||||
GemmBBlockCopyClusterLengths,
|
||||
GemmBBlockCopyDataPerAccess,
|
||||
GemmCThreadCopyDataPerAccess>{};
|
||||
|
||||
gridwise_gemm.Run(p_wei_global, p_out_global, p_in_global);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck
|
||||
#endif
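Worked example (a hypothetical 3x3 filter with stride 2 and dilation 1, not sizes from the commit) of the filter decomposition the v2r1 kernel above uses; "hcf" in the kernel is the highest common factor (gcd) of stride and dilation.

#include <cstdio>

int main()
{
    constexpr int Y = 3;
    constexpr int ConvStrideH = 2, ConvDilationH = 1;

    constexpr int hcf_h        = 1;                                      // gcd(2, 1)
    constexpr int Ytilda       = ConvStrideH / hcf_h;                    // 2
    constexpr int Ydot         = (Y + Ytilda - 1) / Ytilda;              // ceil(3/2) = 2
    constexpr int right_pad_ho = (ConvDilationH / hcf_h) * (Y - Ytilda); // 1

    static_assert(Ytilda == 2 && Ydot == 2 && right_pad_ho == 1, "see arithmetic above");

    // Ho is padded to Htilda = Ho + right_pad_ho, and the GEMM dimensions become
    //   GemmK = K * Ydot * Xdot, GemmM = C * Ytilda * Xtilda, GemmN = N * Htilda * Wtilda,
    // which is why this variant can run with InMemoryDataOperation::none
    // (the "no atomic" kernel from the commit message).
    std::printf("Ytilda=%d Ydot=%d right_pad_ho=%d\n", Ytilda, Ydot, right_pad_ho);
}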
@@ -93,8 +93,9 @@ struct GridwiseConvolutionDirect_v2_nchw_kcyx_nkhw
|
||||
constexpr auto wei_kcyx_thread_block_desc = make_ConstantTensorDescriptor(
|
||||
Sequence<KPerThread, CPerThread, Y, X>{}, wei_kcyx_block_desc.GetStrides());
|
||||
|
||||
constexpr auto out_nkhw_thread_desc = get_convolution_output_default_4d_tensor_descriptor(
|
||||
in_nchw_thread_block_desc, wei_kcyx_thread_block_desc);
|
||||
constexpr auto out_nkhw_thread_desc =
|
||||
get_convolution_output_default_4d_tensor_descriptor_deprecated(
|
||||
in_nchw_thread_block_desc, wei_kcyx_thread_block_desc);
|
||||
|
||||
// register
|
||||
Float p_out_thread[out_nkhw_thread_desc.GetElementSpace()];
|
||||
|
||||
@@ -107,11 +107,6 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer
|
||||
|
||||
constexpr auto True = integral_constant<bool, true>{};
|
||||
|
||||
constexpr auto generic_address_space =
|
||||
integral_constant<AddressSpace, AddressSpace::generic>{};
|
||||
constexpr auto global_address_space =
|
||||
integral_constant<AddressSpace, AddressSpace::global>{};
|
||||
|
||||
static_assert(ConvDirection == ConvolutionDirection::Forward ||
|
||||
ConvDirection == ConvolutionDirection::BackwardWeight,
|
||||
"wrong! this kernel only support convolution forward and backward-weight");
|
||||
@@ -130,17 +125,17 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer
|
||||
constexpr auto wei_k_c_y_x_global_desc = WeiGlobalDesc{};
|
||||
constexpr auto out_n_k_ho_wo_global_desc = OutGlobalDesc{};
|
||||
|
||||
constexpr index_t N = in_n_c_hi_wi_global_desc.GetLength(I0);
|
||||
constexpr index_t C = in_n_c_hi_wi_global_desc.GetLength(I1);
|
||||
constexpr index_t Hi = in_n_c_hi_wi_global_desc.GetLength(I2);
|
||||
constexpr index_t Wi = in_n_c_hi_wi_global_desc.GetLength(I3);
|
||||
constexpr index_t N = in_n_c_hi_wi_global_desc.GetLengths()[0];
|
||||
constexpr index_t C = in_n_c_hi_wi_global_desc.GetLengths()[1];
|
||||
constexpr index_t Hi = in_n_c_hi_wi_global_desc.GetLengths()[2];
|
||||
constexpr index_t Wi = in_n_c_hi_wi_global_desc.GetLengths()[3];
|
||||
|
||||
constexpr index_t K = out_n_k_ho_wo_global_desc.GetLength(I1);
|
||||
constexpr index_t Ho = out_n_k_ho_wo_global_desc.GetLength(I2);
|
||||
constexpr index_t Wo = out_n_k_ho_wo_global_desc.GetLength(I3);
|
||||
constexpr index_t K = out_n_k_ho_wo_global_desc.GetLengths()[1];
|
||||
constexpr index_t Ho = out_n_k_ho_wo_global_desc.GetLengths()[2];
|
||||
constexpr index_t Wo = out_n_k_ho_wo_global_desc.GetLengths()[3];
|
||||
|
||||
constexpr index_t Y = wei_k_c_y_x_global_desc.GetLength(I2);
|
||||
constexpr index_t X = wei_k_c_y_x_global_desc.GetLength(I3);
|
||||
constexpr index_t Y = wei_k_c_y_x_global_desc.GetLengths()[2];
|
||||
constexpr index_t X = wei_k_c_y_x_global_desc.GetLengths()[3];
|
||||
|
||||
constexpr index_t ConvStrideH = ConvStrides{}[0];
|
||||
constexpr index_t ConvStrideW = ConvStrides{}[1];
|
||||
@@ -230,7 +225,11 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer
|
||||
2,
|
||||
3,
|
||||
InBlockCopySrcDataPerRead_B,
|
||||
InBlockCopyDstDataPerWrite_N2>(
|
||||
InBlockCopyDstDataPerWrite_N2,
|
||||
AddressSpace::global,
|
||||
AddressSpace::vgpr,
|
||||
AddressSpace::lds,
|
||||
InMemoryDataOperation::none>(
|
||||
{0, 0, b_block_data_on_global, 0}, {0, 0, 0, 0});
|
||||
|
||||
// weight tensor
|
||||
@@ -266,7 +265,11 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer
|
||||
0,
|
||||
1,
|
||||
WeiBlockCopySrcDataPerRead_E,
|
||||
WeiBlockCopyDstDataPerWrite_K>(
|
||||
WeiBlockCopyDstDataPerWrite_K,
|
||||
AddressSpace::global,
|
||||
AddressSpace::vgpr,
|
||||
AddressSpace::lds,
|
||||
InMemoryDataOperation::none>(
|
||||
{0, k_block_data_on_global}, {0, 0});
|
||||
|
||||
// GEMM definition
|
||||
@@ -334,10 +337,8 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer
|
||||
|
||||
// LDS double buffer: preload data into LDS
|
||||
{
|
||||
blockwise_in_copy.Run(
|
||||
p_in_global, p_in_block_double, global_address_space, generic_address_space);
|
||||
blockwise_wei_copy.Run(
|
||||
p_wei_global, p_wei_block_double, global_address_space, generic_address_space);
|
||||
blockwise_in_copy.Run(p_in_global, p_in_block_double);
|
||||
blockwise_wei_copy.Run(p_wei_global, p_wei_block_double);
|
||||
}
|
||||
|
||||
// LDS double buffer: main body
|
||||
@@ -368,10 +369,8 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer
|
||||
__syncthreads();
|
||||
|
||||
// LDS double buffer: load next data from device memory
|
||||
blockwise_in_copy.RunLoadThreadBuffer(
|
||||
p_in_global, p_in_thread_buffer, global_address_space, generic_address_space);
|
||||
blockwise_wei_copy.RunLoadThreadBuffer(
|
||||
p_wei_global, p_wei_thread_buffer, global_address_space, generic_address_space);
|
||||
blockwise_in_copy.RunLoadThreadBuffer(p_in_global, p_in_thread_buffer);
|
||||
blockwise_wei_copy.RunLoadThreadBuffer(p_wei_global, p_wei_thread_buffer);
|
||||
|
||||
// LDS double buffer: GEMM on current data
|
||||
blockwise_gemm.Run(p_wei_block_now, p_in_block_now, p_out_thread);
|
||||
@@ -397,10 +396,8 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer
|
||||
__syncthreads();
|
||||
|
||||
// LDS double buffer: load last data from device mem
|
||||
blockwise_in_copy.RunLoadThreadBuffer(
|
||||
p_in_global, p_in_thread_buffer, global_address_space, generic_address_space);
|
||||
blockwise_wei_copy.RunLoadThreadBuffer(
|
||||
p_wei_global, p_wei_thread_buffer, global_address_space, generic_address_space);
|
||||
blockwise_in_copy.RunLoadThreadBuffer(p_in_global, p_in_thread_buffer);
|
||||
blockwise_wei_copy.RunLoadThreadBuffer(p_wei_global, p_wei_thread_buffer);
|
||||
|
||||
// LDS double buffer: GEMM on 2nd-last data
|
||||
blockwise_gemm.Run(p_wei_block_double, p_in_block_double, p_out_thread);
|
||||
@@ -474,20 +471,23 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer
|
||||
const index_t b_thread_data_on_global =
|
||||
b_block_data_on_global + c_thread_mtx_on_block.col / N2;
|
||||
|
||||
ThreadwiseGenericTensorSliceCopy_v4r2<decltype(out_k0_k1_n1_b_n2_thread_desc),
|
||||
decltype(out_k0_k1_n1_b_n2_global_desc),
|
||||
decltype(
|
||||
out_k0_k1_n1_b_n2_thread_desc.GetLengths()),
|
||||
arithmetic_sequence_gen<0, 5, 1>::type,
|
||||
3,
|
||||
1,
|
||||
1>({0, 0, 0, 0, 0},
|
||||
{k_thread_data_on_global / K1,
|
||||
k_thread_data_on_global % K1,
|
||||
0,
|
||||
b_thread_data_on_global,
|
||||
0})
|
||||
.Run(p_out_thread, p_out_global, generic_address_space, global_address_space);
|
||||
ThreadwiseGenericTensorSliceCopy_v4r2<
|
||||
decltype(out_k0_k1_n1_b_n2_thread_desc),
|
||||
decltype(out_k0_k1_n1_b_n2_global_desc),
|
||||
decltype(out_k0_k1_n1_b_n2_thread_desc.GetLengths()),
|
||||
arithmetic_sequence_gen<0, 5, 1>::type,
|
||||
3,
|
||||
1,
|
||||
1,
|
||||
AddressSpace::vgpr,
|
||||
AddressSpace::global,
|
||||
InMemoryDataOperation::none>({0, 0, 0, 0, 0},
|
||||
{k_thread_data_on_global / K1,
|
||||
k_thread_data_on_global % K1,
|
||||
0,
|
||||
b_thread_data_on_global,
|
||||
0})
|
||||
.Run(p_out_thread, p_out_global);
|
||||
}
|
||||
}
|
||||
};
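The hunks above for v4r1 and the ones below for v4r4 all apply the same interface change: address spaces and the destination memory operation move from per-call integral_constant tag arguments of Run/RunLoadThreadBuffer into template parameters of the copy types. Schematically (template argument lists abbreviated with "...", and src_origin/dst_origin are placeholder names, not a complete instantiation):

// before
//   auto copy = BlockwiseGenericTensorSliceCopy_v4<..., SrcDataPerAccess, DstDataPerAccess>(
//       src_origin, dst_origin);
//   copy.Run(p_src, p_dst, global_address_space, generic_address_space);
//
// after
//   auto copy = BlockwiseGenericTensorSliceCopy_v4<..., SrcDataPerAccess, DstDataPerAccess,
//                                                  AddressSpace::global,  // source
//                                                  AddressSpace::vgpr,    // thread buffer
//                                                  AddressSpace::lds,     // destination
//                                                  InMemoryDataOperation::none>(
//       src_origin, dst_origin);
//   copy.Run(p_src, p_dst);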
@@ -10,7 +10,6 @@
|
||||
#include "blockwise_gemm.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
// B = merge(N, Ho, Wo)
|
||||
template <index_t GridSize,
|
||||
index_t BlockSize,
|
||||
@@ -61,11 +60,6 @@ struct GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw_lds_double_buffer
|
||||
|
||||
constexpr auto True = integral_constant<bool, true>{};
|
||||
|
||||
constexpr auto generic_address_space =
|
||||
integral_constant<AddressSpace, AddressSpace::generic>{};
|
||||
constexpr auto global_address_space =
|
||||
integral_constant<AddressSpace, AddressSpace::global>{};
|
||||
|
||||
constexpr auto in_n_c_hi_wi_global_desc =
|
||||
make_native_tensor_descriptor(InGlobalDesc::GetLengths(), InGlobalDesc::GetStrides());
|
||||
constexpr auto wei_k_c_y_x_global_desc =
|
||||
@@ -158,7 +152,11 @@ struct GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw_lds_double_buffer
|
||||
1,
|
||||
1,
|
||||
InBlockCopyDataPerAccess_B,
|
||||
InBlockCopyDataPerAccess_B>(
|
||||
InBlockCopyDataPerAccess_B,
|
||||
AddressSpace::global,
|
||||
AddressSpace::vgpr,
|
||||
AddressSpace::lds,
|
||||
InMemoryDataOperation::none>(
|
||||
{0, b_block_data_on_global}, {0, 0});
|
||||
|
||||
// weight tensor
|
||||
@@ -192,7 +190,11 @@ struct GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw_lds_double_buffer
|
||||
0,
|
||||
1,
|
||||
WeiBlockCopySrcDataPerRead_E,
|
||||
WeiBlockCopyDstDataPerWrite_K>(
|
||||
WeiBlockCopyDstDataPerWrite_K,
|
||||
AddressSpace::global,
|
||||
AddressSpace::vgpr,
|
||||
AddressSpace::lds,
|
||||
InMemoryDataOperation::none>(
|
||||
{0, k_block_data_on_global}, {0, 0});
|
||||
|
||||
// GEMM definition
|
||||
@@ -202,7 +204,6 @@ struct GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw_lds_double_buffer
|
||||
// c_mtx[KPerBlock, BPerBlock] is distributed among threads, and saved in
|
||||
// register
|
||||
constexpr auto a_e_k_block_mtx_desc = make_ConstantMatrixDescriptor(wei_e_k_block_desc);
|
||||
|
||||
constexpr auto b_e_b_block_mtx_desc = make_ConstantMatrixDescriptor(in_e_b_block_desc);
|
||||
|
||||
// sanity check
|
||||
@@ -260,10 +261,8 @@ struct GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw_lds_double_buffer
|
||||
|
||||
// LDS double buffer: preload data into LDS
|
||||
{
|
||||
blockwise_in_copy.Run(
|
||||
p_in_global, p_in_block_double, global_address_space, generic_address_space);
|
||||
blockwise_wei_copy.Run(
|
||||
p_wei_global, p_wei_block_double, global_address_space, generic_address_space);
|
||||
blockwise_in_copy.Run(p_in_global, p_in_block_double);
|
||||
blockwise_wei_copy.Run(p_wei_global, p_wei_block_double);
|
||||
}
|
||||
|
||||
// LDS double buffer: main body
|
||||
@@ -294,10 +293,8 @@ struct GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw_lds_double_buffer
|
||||
__syncthreads();
|
||||
|
||||
// LDS double buffer: load next data from device memory
|
||||
blockwise_in_copy.RunLoadThreadBuffer(
|
||||
p_in_global, p_in_thread_buffer, global_address_space, generic_address_space);
|
||||
blockwise_wei_copy.RunLoadThreadBuffer(
|
||||
p_wei_global, p_wei_thread_buffer, global_address_space, generic_address_space);
|
||||
blockwise_in_copy.RunLoadThreadBuffer(p_in_global, p_in_thread_buffer);
|
||||
blockwise_wei_copy.RunLoadThreadBuffer(p_wei_global, p_wei_thread_buffer);
|
||||
|
||||
// LDS double buffer: GEMM on current data
|
||||
blockwise_gemm.Run(p_wei_block_now, p_in_block_now, p_out_thread);
|
||||
@@ -323,10 +320,8 @@ struct GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw_lds_double_buffer
|
||||
__syncthreads();
|
||||
|
||||
// LDS double buffer: load last data from device mem
|
||||
blockwise_in_copy.RunLoadThreadBuffer(
|
||||
p_in_global, p_in_thread_buffer, global_address_space, generic_address_space);
|
||||
blockwise_wei_copy.RunLoadThreadBuffer(
|
||||
p_wei_global, p_wei_thread_buffer, global_address_space, generic_address_space);
|
||||
blockwise_in_copy.RunLoadThreadBuffer(p_in_global, p_in_thread_buffer);
|
||||
blockwise_wei_copy.RunLoadThreadBuffer(p_wei_global, p_wei_thread_buffer);
|
||||
|
||||
// LDS double buffer: GEMM on 2nd-last data
|
||||
blockwise_gemm.Run(p_wei_block_double, p_in_block_double, p_out_thread);
|
||||
@@ -397,17 +392,14 @@ struct GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw_lds_double_buffer
|
||||
arithmetic_sequence_gen<0, 4, 1>::type,
|
||||
3,
|
||||
OutThreadCopyDataPerAccess_B,
|
||||
OutThreadCopyDataPerAccess_B>({0, 0, 0, 0},
|
||||
{k_thread_data_on_global / K1,
|
||||
k_thread_data_on_global % K1,
|
||||
b_thread_data_on_global / B1,
|
||||
b_thread_data_on_global % B1})
|
||||
#if 1
|
||||
.Run(p_out_thread, p_out_global, generic_address_space, global_address_space);
|
||||
#else // tweaking
|
||||
.Run_optimized_dst_address_calculation(
|
||||
p_out_thread, p_out_global, generic_address_space, global_address_space);
|
||||
#endif
|
||||
OutThreadCopyDataPerAccess_B,
|
||||
AddressSpace::vgpr,
|
||||
AddressSpace::global>({0, 0, 0, 0},
|
||||
{k_thread_data_on_global / K1,
|
||||
k_thread_data_on_global % K1,
|
||||
b_thread_data_on_global / B1,
|
||||
b_thread_data_on_global % B1})
|
||||
.Run(p_out_thread, p_out_global);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
@@ -60,7 +60,7 @@ __host__ __device__ constexpr auto
|
||||
|
||||
template <typename... Ts>
|
||||
__host__ __device__ constexpr auto
|
||||
make_ConstantMatrixDescriptor(ConstantTensorDescriptor_deprecated<Ts...>)
|
||||
make_ConstantMatrixDescriptor(ConstantTensorDescriptor_deprecated<Ts...>)
|
||||
{
|
||||
using TDesc = ConstantTensorDescriptor_deprecated<Ts...>;
|
||||
static_assert(TDesc::GetNumOfDimension() == 2, "wrong");
|
||||
|
||||
@@ -84,35 +84,14 @@ struct Pad
|
||||
__host__ __device__ constexpr bool
|
||||
IsUpperIndexMappedToValidLowerIndex(const UpperIndex& idx_up) const
|
||||
{
|
||||
#if 0
|
||||
struct lambda_no_pad
|
||||
{
|
||||
__host__ __device__ constexpr bool operator()(index_t x) const { return x == 0; }
|
||||
};
|
||||
bool flag = true;
|
||||
|
||||
if(sequence_all_of(LeftPads{}, lambda_no_pad{}) &&
|
||||
sequence_all_of(RightPads{}, lambda_no_pad{}))
|
||||
{
|
||||
return true;
|
||||
}
|
||||
else
|
||||
#endif
|
||||
{
|
||||
bool flag = true;
|
||||
static_for<0, nDim, 1>{}([&](auto idim) {
|
||||
flag = flag && (idx_up[idim] >= LeftPads::At(idim)) &&
|
||||
(idx_up[idim] < LeftPads::At(idim) + LowerLengths::At(idim));
|
||||
});
|
||||
|
||||
static_for<0, nDim, 1>{}([&](auto idim) {
|
||||
// only check if there is left-padding
|
||||
static_if<(LeftPads::At(idim) != 0)>{}(
|
||||
[&](auto) { flag = flag && idx_up[idim] >= LeftPads::At(idim); });
|
||||
|
||||
// only check if there is right-padding
|
||||
static_if<(RightPads::At(idim) != 0)>{}([&](auto) {
|
||||
flag = flag && (idx_up[idim] < LeftPads::At(idim) + LowerLengths::At(idim));
|
||||
});
|
||||
});
|
||||
|
||||
return flag;
|
||||
}
|
||||
return flag;
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
@@ -228,7 +228,7 @@ struct TensorCoordinate
|
||||
private:
|
||||
template <typename... Ts>
|
||||
__host__ __device__ static constexpr auto
|
||||
MakeDummyTensorCoordinate(NativeTensorDescriptor<Ts...>)
|
||||
MakeDummyTensorCoordinate(NativeTensorDescriptor<Ts...>)
|
||||
{
|
||||
return NativeTensorCoordinate<NativeTensorDescriptor<Ts...>>(
|
||||
make_zero_array<index_t, TensorDesc::GetNumOfDimension()>());
|
||||
@@ -236,7 +236,7 @@ struct TensorCoordinate
|
||||
|
||||
template <typename... Ts>
|
||||
__host__ __device__ static constexpr auto
|
||||
MakeDummyTensorCoordinate(TransformedTensorDescriptor<Ts...>)
|
||||
MakeDummyTensorCoordinate(TransformedTensorDescriptor<Ts...>)
|
||||
{
|
||||
return TransformedTensorCoordinate<TransformedTensorDescriptor<Ts...>>(
|
||||
make_zero_array<index_t, TensorDesc::GetNumOfDimension()>());
|
||||
|
||||
@@ -327,14 +327,14 @@ struct TensorCoordinate_deprecated
|
||||
private:
|
||||
template <class... Ts>
|
||||
__host__ __device__ static constexpr auto
|
||||
MakeDummyTensorCoordinate(ConstantTensorDescriptor_deprecated<Ts...>)
|
||||
MakeDummyTensorCoordinate(ConstantTensorDescriptor_deprecated<Ts...>)
|
||||
{
|
||||
return NormalTensorCoordinate_deprecated<ConstantTensorDescriptor_deprecated<Ts...>>();
|
||||
}
|
||||
|
||||
template <class... Ts>
|
||||
__host__ __device__ static constexpr auto
|
||||
MakeDummyTensorCoordinate(ConstantMergedTensorDescriptor_deprecated<Ts...>)
|
||||
{
|
||||
return MergedTensorCoordinate_deprecated<
|
||||
ConstantMergedTensorDescriptor_deprecated<Ts...>>();
|
||||
|
||||
@@ -64,10 +64,10 @@ template <typename LowerTensorDescriptor,
|
||||
index_t... LowerDimensionIds,
|
||||
index_t... UpperDimensionIds>
|
||||
__host__ __device__ constexpr auto
|
||||
reorder_transformed_tensor_descriptor_impl(LowerTensorDescriptor,
|
||||
Sequence<LowerLengths...>,
|
||||
Sequence<LowerDimensionIds...>,
|
||||
Sequence<UpperDimensionIds...>)
|
||||
{
|
||||
return TransformedTensorDescriptor<LowerTensorDescriptor,
|
||||
Tuple<PassThrough<LowerLengths>...>,
|
||||
@@ -78,7 +78,7 @@ __host__ __device__ constexpr auto
|
||||
// reorder a NativeTensorDescriptor
|
||||
template <typename... Ts, typename MapLower2Upper>
|
||||
__host__ __device__ constexpr auto
|
||||
reorder_tensor_descriptor_given_lower2upper(NativeTensorDescriptor<Ts...>, MapLower2Upper)
|
||||
{
|
||||
static_assert(is_valid_sequence_map<MapLower2Upper>{},
|
||||
"wrong! MapLower2Upper is not a valid map");
|
||||
@@ -96,7 +96,7 @@ __host__ __device__ constexpr auto
|
||||
// reorder a TransformedTensorDescriptor
|
||||
template <typename... Ts, typename MapLower2Upper>
|
||||
__host__ __device__ constexpr auto
|
||||
reorder_tensor_descriptor_given_lower2upper(TransformedTensorDescriptor<Ts...>, MapLower2Upper)
|
||||
{
|
||||
static_assert(is_valid_sequence_map<MapLower2Upper>{},
|
||||
"wrong! MapLower2Upper is not a valid map");
|
||||
|
||||
@@ -21,7 +21,11 @@ template <index_t BlockSize,
|
||||
index_t SrcVectorAccessDim,
|
||||
index_t DstVectorAccessDim,
|
||||
index_t SrcDataPerAccess,
|
||||
index_t DstDataPerAccess>
|
||||
index_t DstDataPerAccess,
|
||||
AddressSpace SrcAddressSpace = AddressSpace::generic,
|
||||
AddressSpace ThreadBufferAddressSpace = AddressSpace::generic,
|
||||
AddressSpace DstAddressSpace = AddressSpace::generic,
|
||||
InMemoryDataOperation DstInMemOp = InMemoryDataOperation::none>
|
||||
struct BlockwiseGenericTensorSliceCopy_v4
|
||||
{
|
||||
static constexpr index_t nDim = BlockSrcDesc::GetNumOfDimension();
|
||||
@@ -66,76 +70,21 @@ struct BlockwiseGenericTensorSliceCopy_v4
|
||||
return ThreadBufferDesc::GetElementSpace();
|
||||
}
|
||||
|
||||
template <typename BlockSrcData,
|
||||
typename ThreadBufferData,
|
||||
AddressSpace BlockSrcAddressSpace,
|
||||
AddressSpace ThreadBufferAddressSpace>
|
||||
__device__ void
|
||||
RunLoadThreadBuffer(const BlockSrcData* p_block_src,
|
||||
ThreadBufferData* p_thread_buffer,
|
||||
integral_constant<AddressSpace, BlockSrcAddressSpace>,
|
||||
integral_constant<AddressSpace, ThreadBufferAddressSpace>) const
|
||||
{
|
||||
constexpr auto block_src_address_space =
|
||||
integral_constant<AddressSpace, BlockSrcAddressSpace>{};
|
||||
constexpr auto thread_buffer_address_space =
|
||||
integral_constant<AddressSpace, ThreadBufferAddressSpace>{};
|
||||
|
||||
constexpr bool has_optimized_address_calculation =
|
||||
decltype(mThreadwiseStore)::HasWorkingOptimizedAddressCalculation();
|
||||
|
||||
// TODO: threadwise copy is still being tweaked
|
||||
if(has_optimized_address_calculation)
|
||||
{
|
||||
mThreadwiseLoad.Run_optimized_src_address_calculation(
|
||||
p_block_src, p_thread_buffer, block_src_address_space, thread_buffer_address_space);
|
||||
}
|
||||
else
|
||||
{
|
||||
mThreadwiseLoad.Run(
|
||||
p_block_src, p_thread_buffer, block_src_address_space, thread_buffer_address_space);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename BlockSrcData, typename ThreadBufferData>
|
||||
__device__ void RunLoadThreadBuffer(const BlockSrcData* p_block_src,
|
||||
ThreadBufferData* p_thread_buffer) const
|
||||
{
|
||||
constexpr auto generic_address_space =
|
||||
integral_constant<AddressSpace, AddressSpace::generic>{};
|
||||
|
||||
RunLoadThreadBuffer(
|
||||
p_block_src, p_thread_buffer, generic_address_space, generic_address_space);
|
||||
}
|
||||
|
||||
template <typename ThreadBufferData,
|
||||
typename BlockDstData,
|
||||
AddressSpace ThreadBufferAddressSpace,
|
||||
AddressSpace BlockDstAddressSpace>
|
||||
__device__ void
|
||||
RunStoreThreadBuffer(const ThreadBufferData* p_thread_buffer,
|
||||
BlockDstData* p_block_dst,
|
||||
integral_constant<AddressSpace, ThreadBufferAddressSpace>,
|
||||
integral_constant<AddressSpace, BlockDstAddressSpace>) const
|
||||
{
|
||||
constexpr auto thread_buffer_address_space =
|
||||
integral_constant<AddressSpace, ThreadBufferAddressSpace>{};
|
||||
constexpr auto block_dst_address_space =
|
||||
integral_constant<AddressSpace, BlockDstAddressSpace>{};
|
||||
|
||||
constexpr bool has_optimized_address_calculation =
|
||||
decltype(mThreadwiseStore)::HasWorkingOptimizedAddressCalculation();
|
||||
|
||||
// TODO: threadwise copy is still being tweaked
|
||||
if(has_optimized_address_calculation)
|
||||
{
|
||||
mThreadwiseStore.Run_optimized_dst_address_calculation(
|
||||
p_thread_buffer, p_block_dst, thread_buffer_address_space, block_dst_address_space);
|
||||
mThreadwiseLoad.Run_optimized_src_address_calculation(p_block_src, p_thread_buffer);
|
||||
}
|
||||
else
|
||||
{
|
||||
mThreadwiseStore.Run(
|
||||
p_thread_buffer, p_block_dst, thread_buffer_address_space, block_dst_address_space);
|
||||
mThreadwiseLoad.Run(p_block_src, p_thread_buffer);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -143,43 +92,35 @@ struct BlockwiseGenericTensorSliceCopy_v4
|
||||
__device__ void RunStoreThreadBuffer(const ThreadBufferData* p_thread_buffer,
|
||||
BlockDstData* p_block_dst) const
|
||||
{
|
||||
constexpr auto generic_address_space =
|
||||
integral_constant<AddressSpace, AddressSpace::generic>{};
|
||||
constexpr bool has_optimized_address_calculation =
|
||||
decltype(mThreadwiseStore)::HasWorkingOptimizedAddressCalculation();
|
||||
|
||||
RunStoreThreadBuffer(
|
||||
p_thread_buffer, p_block_dst, generic_address_space, generic_address_space);
|
||||
}
|
||||
|
||||
template <typename BlockSrcData,
|
||||
typename BlockDstData,
|
||||
AddressSpace BlockSrcAddressSpace,
|
||||
AddressSpace BlockDstAddressSpace>
|
||||
__device__ void
|
||||
Run(const BlockSrcData* p_block_src,
|
||||
BlockDstData* p_block_dst,
|
||||
integral_constant<AddressSpace, BlockSrcAddressSpace> block_src_address_space,
|
||||
integral_constant<AddressSpace, BlockDstAddressSpace> block_dst_address_space) const
|
||||
{
|
||||
BlockSrcData p_thread_buffer[GetThreadBufferSize()];
|
||||
|
||||
constexpr auto generic_address_space =
|
||||
integral_constant<AddressSpace, AddressSpace::generic>{};
|
||||
|
||||
RunLoadThreadBuffer(
|
||||
p_block_src, p_thread_buffer, block_src_address_space, generic_address_space);
|
||||
|
||||
// if there is type conversion, it's done during store
|
||||
RunStoreThreadBuffer(
|
||||
p_thread_buffer, p_block_dst, generic_address_space, block_dst_address_space);
|
||||
// TODO: threadwise copy is still being tweaked
|
||||
if(has_optimized_address_calculation)
|
||||
{
|
||||
mThreadwiseStore.Run_optimized_dst_address_calculation(p_thread_buffer, p_block_dst);
|
||||
}
|
||||
else
|
||||
{
|
||||
mThreadwiseStore.Run(p_thread_buffer, p_block_dst);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename BlockSrcData, typename BlockDstData>
|
||||
__device__ void Run(const BlockSrcData* p_block_src, BlockDstData* p_block_dst) const
|
||||
{
|
||||
constexpr auto generic_address_space =
|
||||
integral_constant<AddressSpace, AddressSpace::generic>{};
|
||||
static_assert(ThreadBufferAddressSpace == AddressSpace::vgpr,
|
||||
"wrong! This function use vgpr as its thread "
|
||||
"buffer. However, you have set RunLoadThreadBuffer and RunStoreThreadBuffer "
|
||||
"to use ThreadBufferAddressSpace as their thread buffer, which is not vgpr. "
|
||||
"Behavior may be different");
|
||||
|
||||
Run(p_block_src, p_block_dst, generic_address_space, generic_address_space);
|
||||
BlockSrcData p_thread_buffer[GetThreadBufferSize()];
|
||||
|
||||
RunLoadThreadBuffer(p_block_src, p_thread_buffer);
|
||||
|
||||
// if there is type conversion, it's done during store
|
||||
RunStoreThreadBuffer(p_thread_buffer, p_block_dst);
|
||||
}
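For reference, a minimal host-side sketch (illustrative names and types, not part of this commit) of the two-stage copy done by Run() above: data is staged into a small per-thread buffer in the source type, and any type conversion happens when the buffer is stored to the destination.

#include <cstdio>

template <typename SrcData, typename DstData, int BufferSize>
void two_stage_copy_sketch(const SrcData* p_src, DstData* p_dst, int n)
{
    for(int base = 0; base < n; base += BufferSize)
    {
        SrcData thread_buffer[BufferSize]; // "load" stage keeps the source type

        for(int i = 0; i < BufferSize && base + i < n; ++i)
            thread_buffer[i] = p_src[base + i];

        for(int i = 0; i < BufferSize && base + i < n; ++i)
            p_dst[base + i] = static_cast<DstData>(thread_buffer[i]); // convert on store
    }
}

int main()
{
    float src[4]  = {1.25f, 2.5f, 3.75f, 5.0f};
    double dst[4] = {};
    two_stage_copy_sketch<float, double, 2>(src, dst, 4);
    std::printf("%f %f %f %f\n", dst[0], dst[1], dst[2], dst[3]);
    return 0;
}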
|
||||
|
||||
template <typename T, bool PositiveDirection>
|
||||
@@ -207,7 +148,10 @@ struct BlockwiseGenericTensorSliceCopy_v4
|
||||
SrcDimAccessOrder,
|
||||
SrcVectorAccessDim,
|
||||
SrcDataPerAccess,
|
||||
1>;
|
||||
1,
|
||||
SrcAddressSpace,
|
||||
ThreadBufferAddressSpace,
|
||||
InMemoryDataOperation::none>;
|
||||
|
||||
using ThreadwiseStore = ThreadwiseGenericTensorSliceCopy_v4r2<ThreadBufferDesc,
|
||||
BlockDstDesc,
|
||||
@@ -215,7 +159,10 @@ struct BlockwiseGenericTensorSliceCopy_v4
|
||||
DstDimAccessOrder,
|
||||
DstVectorAccessDim,
|
||||
1,
|
||||
DstDataPerAccess>;
|
||||
DstDataPerAccess,
|
||||
ThreadBufferAddressSpace,
|
||||
DstAddressSpace,
|
||||
DstInMemOp>;
|
||||
|
||||
ThreadwiseLoad mThreadwiseLoad;
|
||||
ThreadwiseStore mThreadwiseStore;
|
||||
|
||||
325
composable_kernel/include/tensor_operation/gridwise_gemm.hpp
Normal file
@@ -0,0 +1,325 @@
|
||||
#ifndef CK_GRIDWISE_GEMM_HPP
|
||||
#define CK_GRIDWISE_GEMM_HPP
|
||||
|
||||
#include "common_header.hpp"
|
||||
#include "tensor_descriptor.hpp"
|
||||
#include "tensor_descriptor_helper.hpp"
|
||||
#include "ConstantMatrixDescriptor.hpp"
|
||||
#include "blockwise_generic_tensor_slice_copy.hpp"
|
||||
#include "threadwise_generic_tensor_slice_copy.hpp"
|
||||
#include "blockwise_gemm.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
template <index_t GridSize,
|
||||
index_t BlockSize,
|
||||
typename Float,
|
||||
typename AccFloat,
|
||||
typename AGlobalDesc,
|
||||
typename BGlobalDesc,
|
||||
typename CGlobalDesc,
|
||||
InMemoryDataOperation CGlobalMemoryDataOperation,
|
||||
index_t MPerBlock,
|
||||
index_t NPerBlock,
|
||||
index_t KPerBlock,
|
||||
index_t MPerThreadSubC,
|
||||
index_t NPerThreadSubC,
|
||||
index_t MLevel0Cluster,
|
||||
index_t NLevel0Cluster,
|
||||
index_t MLevel1Cluster,
|
||||
index_t NLevel1Cluster,
|
||||
index_t KPerThreadLoop,
|
||||
index_t ThreadGemmDataPerReadM,
|
||||
index_t ThreadGemmDataPerReadN,
|
||||
typename ABlockCopySubLengths_K_M,
|
||||
typename ABlockCopyClusterLengths_K_M,
|
||||
index_t ABlockCopyDataPerAccess_M,
|
||||
typename BBlockCopySubLengths_K_N,
|
||||
typename BBlockCopyClusterLengths_K_N,
|
||||
index_t BBlockCopyDataPerAccess_N,
|
||||
index_t CThreadCopyDataPerAccess_N>
|
||||
struct GridwiseGemmTransposedANormalBNormalC_v1r1
|
||||
{
|
||||
__device__ void Run(const Float* __restrict__ p_a_global,
|
||||
const Float* __restrict__ p_b_global,
|
||||
Float* __restrict__ p_c_global) const
|
||||
{
|
||||
constexpr auto True = integral_constant<bool, true>{};
|
||||
|
||||
constexpr auto a_k_m_global_desc = AGlobalDesc{};
|
||||
constexpr auto b_k_n_global_desc = BGlobalDesc{};
|
||||
constexpr auto c_m_n_global_desc = CGlobalDesc{};
|
||||
|
||||
constexpr auto K = a_k_m_global_desc.GetLengths()[0];
|
||||
constexpr auto M = a_k_m_global_desc.GetLengths()[1];
|
||||
constexpr auto N = b_k_n_global_desc.GetLengths()[1];
|
||||
|
||||
// lds max alignment
|
||||
constexpr index_t max_lds_align = math::lcm(ABlockCopyDataPerAccess_M,
|
||||
BBlockCopyDataPerAccess_N,
|
||||
ThreadGemmDataPerReadM,
|
||||
ThreadGemmDataPerReadN);
|
||||
|
||||
// divide block work by [M, N]
|
||||
static_assert(M % MPerBlock == 0 && N % NPerBlock == 0 && K % KPerBlock == 0,
|
||||
"wrong! cannot divide work evenly among block");
|
||||
|
||||
constexpr index_t MBlockWork = M / MPerBlock;
|
||||
constexpr index_t NBlockWork = N / NPerBlock;
|
||||
|
||||
constexpr auto block_work_desc =
|
||||
make_cluster_descriptor(Sequence<MBlockWork, NBlockWork>{});
|
||||
|
||||
const auto block_work_id = block_work_desc.CalculateClusterIndex(get_block_1d_id());
|
||||
|
||||
const index_t m_block_data_on_global = block_work_id[0] * MPerBlock;
|
||||
const index_t n_block_data_on_global = block_work_id[1] * NPerBlock;
|
||||
|
||||
// A matrix in LDS memory, dst of blockwise copy
|
||||
// be careful of LDS alignment
|
||||
constexpr auto a_k_m_block_desc = make_native_tensor_descriptor_aligned(
|
||||
Sequence<KPerBlock, MPerBlock>{}, Number<max_lds_align>{});
|
||||
|
||||
// A matrix blockwise copy
|
||||
auto a_blockwise_copy =
|
||||
BlockwiseGenericTensorSliceCopy_v4<BlockSize,
|
||||
decltype(a_k_m_global_desc),
|
||||
decltype(a_k_m_block_desc),
|
||||
decltype(a_k_m_block_desc.GetLengths()),
|
||||
ABlockCopySubLengths_K_M,
|
||||
ABlockCopyClusterLengths_K_M,
|
||||
Sequence<0, 1>,
|
||||
Sequence<0, 1>,
|
||||
Sequence<0, 1>,
|
||||
1,
|
||||
1,
|
||||
ABlockCopyDataPerAccess_M,
|
||||
ABlockCopyDataPerAccess_M,
|
||||
AddressSpace::global,
|
||||
AddressSpace::vgpr,
|
||||
AddressSpace::lds,
|
||||
InMemoryDataOperation::none>(
|
||||
{0, m_block_data_on_global}, {0, 0});
|
||||
|
||||
// B matrix in LDS memory, dst of blockwise copy
|
||||
// be careful of LDS alignment
|
||||
constexpr auto b_k_n_block_desc = make_native_tensor_descriptor_aligned(
|
||||
Sequence<KPerBlock, NPerBlock>{}, Number<max_lds_align>{});
|
||||
|
||||
// B matrix blockwise copy
|
||||
auto b_blockwise_copy =
|
||||
BlockwiseGenericTensorSliceCopy_v4<BlockSize,
|
||||
decltype(b_k_n_global_desc),
|
||||
decltype(b_k_n_block_desc),
|
||||
decltype(b_k_n_block_desc.GetLengths()),
|
||||
BBlockCopySubLengths_K_N,
|
||||
BBlockCopyClusterLengths_K_N,
|
||||
Sequence<0, 1>,
|
||||
Sequence<0, 1>,
|
||||
Sequence<0, 1>,
|
||||
1,
|
||||
1,
|
||||
BBlockCopyDataPerAccess_N,
|
||||
BBlockCopyDataPerAccess_N,
|
||||
AddressSpace::global,
|
||||
AddressSpace::vgpr,
|
||||
AddressSpace::lds,
|
||||
InMemoryDataOperation::none>(
|
||||
{0, n_block_data_on_global}, {0, 0});
|
||||
|
||||
// GEMM definition
|
||||
// c_mtx += transpose(a_mtx) * b_mtx
|
||||
// a_mtx[KPerBlock, MPerBlock] is in LDS
|
||||
// b_mtx[KPerBlock, NPerBlock] is in LDS
|
||||
// c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in
|
||||
// register
|
||||
constexpr auto a_k_m_block_mtx_desc = make_ConstantMatrixDescriptor(a_k_m_block_desc);
|
||||
constexpr auto b_k_n_block_mtx_desc = make_ConstantMatrixDescriptor(b_k_n_block_desc);
|
||||
|
||||
// sanity check
|
||||
static_assert(MPerBlock % (MPerThreadSubC * MLevel0Cluster * MLevel1Cluster) == 0 &&
|
||||
NPerBlock % (NPerThreadSubC * NLevel0Cluster * NLevel1Cluster) == 0,
|
||||
"wrong!");
|
||||
|
||||
constexpr index_t GemmMRepeat =
|
||||
MPerBlock / (MPerThreadSubC * MLevel0Cluster * MLevel1Cluster);
|
||||
|
||||
constexpr index_t GemmNRepeat =
|
||||
NPerBlock / (NPerThreadSubC * NLevel0Cluster * NLevel1Cluster);
|
||||
|
||||
// c_thread_mtx definition: this is a mess
|
||||
// TODO: more elegant way of defining c_thread_mtx
|
||||
constexpr auto c_m0m1_n0n1_thread_mtx_desc = make_ConstantMatrixDescriptor_packed(
|
||||
Number<GemmMRepeat * MPerThreadSubC>{}, Number<GemmNRepeat * NPerThreadSubC>{});
|
||||
|
||||
const auto blockwise_gemm = BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2<
|
||||
BlockSize,
|
||||
decltype(a_k_m_block_mtx_desc),
|
||||
decltype(b_k_n_block_mtx_desc),
|
||||
decltype(c_m0m1_n0n1_thread_mtx_desc),
|
||||
MPerThreadSubC,
|
||||
NPerThreadSubC,
|
||||
MLevel0Cluster,
|
||||
NLevel0Cluster,
|
||||
MLevel1Cluster,
|
||||
NLevel1Cluster,
|
||||
KPerThreadLoop,
|
||||
ThreadGemmDataPerReadM,
|
||||
ThreadGemmDataPerReadN>{};
|
||||
|
||||
// LDS allocation for A and B: be careful of alignment
|
||||
constexpr index_t a_block_space =
|
||||
math::integer_least_multiple(a_k_m_block_desc.GetElementSpace(), max_lds_align);
|
||||
|
||||
constexpr index_t b_block_space =
|
||||
math::integer_least_multiple(b_k_n_block_desc.GetElementSpace(), max_lds_align);
|
||||
|
||||
__shared__ Float p_a_block_double[2 * a_block_space];
|
||||
__shared__ Float p_b_block_double[2 * b_block_space];
|
||||
|
||||
// register allocation for output
|
||||
AccFloat p_c_thread[c_m0m1_n0n1_thread_mtx_desc.GetElementSpace()];
|
||||
|
||||
// zero out threadwise output
|
||||
threadwise_matrix_set_zero(c_m0m1_n0n1_thread_mtx_desc, p_c_thread);
|
||||
|
||||
// LDS double buffer: preload data into LDS
|
||||
{
|
||||
a_blockwise_copy.Run(p_a_global, p_a_block_double);
|
||||
b_blockwise_copy.Run(p_b_global, p_b_block_double);
|
||||
}
|
||||
|
||||
// LDS double buffer: main body
|
||||
for(index_t k_block_data_begin = 0; k_block_data_begin + 2 * KPerBlock < K;
|
||||
k_block_data_begin += 2 * KPerBlock)
|
||||
{
|
||||
#pragma unroll
|
||||
for(index_t iloop = 0; iloop < 2; ++iloop)
|
||||
{
|
||||
const bool even_loop = (iloop % 2 == 0);
|
||||
|
||||
Float* p_a_block_now =
|
||||
even_loop ? p_a_block_double : p_a_block_double + a_block_space;
|
||||
Float* p_b_block_now =
|
||||
even_loop ? p_b_block_double : p_b_block_double + b_block_space;
|
||||
|
||||
Float* p_a_block_next =
|
||||
even_loop ? p_a_block_double + a_block_space : p_a_block_double;
|
||||
Float* p_b_block_next =
|
||||
even_loop ? p_b_block_double + b_block_space : p_b_block_double;
|
||||
|
||||
Float p_a_thread_buffer[a_blockwise_copy.GetThreadBufferSize()];
|
||||
Float p_b_thread_buffer[b_blockwise_copy.GetThreadBufferSize()];
|
||||
|
||||
a_blockwise_copy.MoveSrcSliceWindow(Sequence<KPerBlock, 0>{}, True);
|
||||
b_blockwise_copy.MoveSrcSliceWindow(Sequence<KPerBlock, 0>{}, True);
|
||||
|
||||
__syncthreads();
|
||||
|
||||
// LDS double buffer: load next data from device mem
|
||||
a_blockwise_copy.RunLoadThreadBuffer(p_a_global, p_a_thread_buffer);
|
||||
b_blockwise_copy.RunLoadThreadBuffer(p_b_global, p_b_thread_buffer);
|
||||
|
||||
// LDS double buffer: GEMM on current data
|
||||
blockwise_gemm.Run(p_a_block_now, p_b_block_now, p_c_thread);
|
||||
|
||||
// LDS double buffer: store next data to LDS
|
||||
a_blockwise_copy.RunStoreThreadBuffer(p_a_thread_buffer, p_a_block_next);
|
||||
b_blockwise_copy.RunStoreThreadBuffer(p_b_thread_buffer, p_b_block_next);
|
||||
}
|
||||
}
|
||||
|
||||
// LDS double buffer: tail
|
||||
{
|
||||
constexpr bool has_two_iteration_left = (K % (2 * KPerBlock) == 0);
|
||||
|
||||
if(has_two_iteration_left) // if there are 2 iterations left
|
||||
{
|
||||
Float p_a_thread_buffer[a_blockwise_copy.GetThreadBufferSize()];
|
||||
Float p_b_thread_buffer[b_blockwise_copy.GetThreadBufferSize()];
|
||||
|
||||
a_blockwise_copy.MoveSrcSliceWindow(Sequence<KPerBlock, 0>{}, True);
|
||||
b_blockwise_copy.MoveSrcSliceWindow(Sequence<KPerBlock, 0>{}, True);
|
||||
|
||||
__syncthreads();
|
||||
|
||||
// LDS double buffer: load last data from device mem
|
||||
a_blockwise_copy.RunLoadThreadBuffer(p_a_global, p_a_thread_buffer);
|
||||
b_blockwise_copy.RunLoadThreadBuffer(p_b_global, p_b_thread_buffer);
|
||||
|
||||
// LDS double buffer: GEMM on 2nd-last data
|
||||
blockwise_gemm.Run(p_a_block_double, p_b_block_double, p_c_thread);
|
||||
|
||||
// LDS double buffer: store last data to LDS
|
||||
a_blockwise_copy.RunStoreThreadBuffer(p_a_thread_buffer,
|
||||
p_a_block_double + a_block_space);
|
||||
b_blockwise_copy.RunStoreThreadBuffer(p_b_thread_buffer,
|
||||
p_b_block_double + b_block_space);
|
||||
|
||||
__syncthreads();
|
||||
|
||||
// LDS double buffer: GEMM on last data
|
||||
blockwise_gemm.Run(
|
||||
p_a_block_double + a_block_space, p_b_block_double + b_block_space, p_c_thread);
|
||||
}
|
||||
else // if there is 1 iteration left
|
||||
{
|
||||
__syncthreads();
|
||||
|
||||
// LDS double buffer: GEMM on last data
|
||||
blockwise_gemm.Run(p_a_block_double, p_b_block_double, p_c_thread);
|
||||
}
|
||||
}
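A CPU-side sketch (purely illustrative; slabs are identified by their K offset) of the LDS double-buffering schedule above: the blockwise GEMM consumes the current buffer while the next K-slab is staged into the other buffer, and the last one or two slabs are handled by the tail.

#include <cstdio>

int main()
{
    const int K = 64, KPerBlock = 8;
    int buffer[2] = {0, -1}; // stands in for the two LDS halves; slab 0 is preloaded

    for(int k = 0; k + 2 * KPerBlock < K; k += 2 * KPerBlock)
    {
        for(int iloop = 0; iloop < 2; ++iloop)
        {
            const int now  = iloop % 2; // buffer being consumed by the GEMM
            const int next = 1 - now;   // buffer being filled for the next iteration
            buffer[next]   = k + (iloop + 1) * KPerBlock; // "load next slab"
            std::printf("GEMM on K-offset %d while staging K-offset %d\n",
                        buffer[now], buffer[next]);
        }
    }
    // tail: the remaining one or two slabs are consumed outside the unrolled loop
    return 0;
}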
|
||||
|
||||
// copy output: register to global memory
|
||||
{
|
||||
constexpr index_t M1 = MPerThreadSubC * MLevel0Cluster * MLevel1Cluster;
|
||||
constexpr index_t M0 = M / M1;
|
||||
|
||||
constexpr index_t N1 = NPerThreadSubC * NLevel0Cluster * NLevel1Cluster;
|
||||
constexpr index_t N0 = N / N1;
|
||||
|
||||
// define input tensor descriptor for threadwise copy
|
||||
// thread input tensor, src of threadwise copy
|
||||
constexpr auto c_m0_m1_n0_n1_thread_desc = make_native_tensor_descriptor_packed(
|
||||
Sequence<GemmMRepeat, MPerThreadSubC, GemmNRepeat, NPerThreadSubC>{});
|
||||
|
||||
constexpr auto c_m0_m1_n0_n1_global_desc = transform_tensor_descriptor(
|
||||
c_m_n_global_desc,
|
||||
make_tuple(UnMerge<Sequence<M0, M1>>{}, UnMerge<Sequence<N0, N1>>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{}));
|
||||
|
||||
// calculate origin of thread input tensor on global memory
|
||||
// blockwise GEMM c matrix starting index
|
||||
const auto c_thread_mtx_on_block =
|
||||
blockwise_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id());
|
||||
|
||||
const index_t m_thread_data_on_global =
|
||||
m_block_data_on_global + c_thread_mtx_on_block.row;
|
||||
|
||||
const index_t n_thread_data_on_global =
|
||||
n_block_data_on_global + c_thread_mtx_on_block.col;
|
||||
|
||||
ThreadwiseGenericTensorSliceCopy_v4r2<decltype(c_m0_m1_n0_n1_thread_desc),
|
||||
decltype(c_m0_m1_n0_n1_global_desc),
|
||||
decltype(c_m0_m1_n0_n1_thread_desc.GetLengths()),
|
||||
Sequence<0, 1, 2, 3>,
|
||||
3,
|
||||
CThreadCopyDataPerAccess_N,
|
||||
CThreadCopyDataPerAccess_N,
|
||||
AddressSpace::vgpr,
|
||||
AddressSpace::global,
|
||||
CGlobalMemoryDataOperation>(
|
||||
{0, 0, 0, 0},
|
||||
{m_thread_data_on_global / M1,
|
||||
m_thread_data_on_global % M1,
|
||||
n_thread_data_on_global / N1,
|
||||
n_thread_data_on_global % N1})
|
||||
.Run(p_c_thread, p_c_global);
|
||||
}
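A quick host-side check (hypothetical values) of the index split used for the write-back above: a flat row index m on the global C matrix is viewed as (m / M1, m % M1) in the unmerged [M0, M1] dimensions, and likewise for n.

#include <cstdio>

int main()
{
    const int MPerThreadSubC = 4, MLevel0Cluster = 4, MLevel1Cluster = 4;
    const int M1 = MPerThreadSubC * MLevel0Cluster * MLevel1Cluster; // 64
    const int m_thread_data_on_global = 200;

    const int m0 = m_thread_data_on_global / M1; // 3
    const int m1 = m_thread_data_on_global % M1; // 8
    std::printf("m = %d -> (m0 = %d, m1 = %d), recombined: %d\n",
                m_thread_data_on_global, m0, m1, m0 * M1 + m1);
    return 0;
}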
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck
|
||||
#endif
|
||||
@@ -21,7 +21,10 @@ template <typename SrcDesc,
|
||||
typename DimAccessOrder,
|
||||
index_t VectorAccessDim,
|
||||
index_t SrcDataPerAccess,
|
||||
index_t DstDataPerAccess>
|
||||
index_t DstDataPerAccess,
|
||||
AddressSpace SrcAddressSpace = AddressSpace::generic,
|
||||
AddressSpace DstAddressSpace = AddressSpace::generic,
|
||||
InMemoryDataOperation DstInMemOp = InMemoryDataOperation::none>
|
||||
struct ThreadwiseGenericTensorSliceCopy_v4r2
|
||||
{
|
||||
static constexpr index_t nDim = SliceLengths::Size();
|
||||
@@ -66,18 +69,9 @@ struct ThreadwiseGenericTensorSliceCopy_v4r2
|
||||
|
||||
// Will do padding check on src data: Read 0 if src data is in padding area.
|
||||
// Will do padding check on dst data: No write if dst data is in padding area.
|
||||
template <typename SrcData,
|
||||
typename DstData,
|
||||
AddressSpace SrcAddressSpace,
|
||||
AddressSpace DstAddressSpace>
|
||||
__device__ void Run(const SrcData* p_src,
|
||||
DstData* p_dst,
|
||||
integral_constant<AddressSpace, SrcAddressSpace>,
|
||||
integral_constant<AddressSpace, DstAddressSpace>) const
|
||||
template <typename SrcData, typename DstData>
|
||||
__device__ void Run(const SrcData* p_src, DstData* p_dst) const
|
||||
{
|
||||
using src_vector_t = typename vector_type<SrcData, SrcDataPerAccess>::MemoryType;
|
||||
using dst_vector_t = typename vector_type<DstData, DstDataPerAccess>::MemoryType;
|
||||
|
||||
constexpr auto vector_access_dim = Number<VectorAccessDim>{};
|
||||
|
||||
constexpr auto src_data_per_access = Number<SrcDataPerAccess>{};
|
||||
@@ -120,20 +114,12 @@ struct ThreadwiseGenericTensorSliceCopy_v4r2
|
||||
// has the same padding situation
|
||||
if(src_coord.IsUpperIndexMappedToValidOffset())
|
||||
{
|
||||
static_if<SrcAddressSpace == AddressSpace::global>{}([&](auto fwd) {
|
||||
#if CK_USE_AMD_BUFFER_ADDRESSING
|
||||
*reinterpret_cast<src_vector_t*>(&p_src_long_vector[buffer_offset]) =
|
||||
amd_intrinsic_buffer_load<SrcData, SrcDataPerAccess>(
|
||||
fwd(p_src), src_coord.GetOffset(), 0);
|
||||
#else
|
||||
*reinterpret_cast<src_vector_t*>(&p_src_long_vector[buffer_offset]) =
|
||||
*reinterpret_cast<const src_vector_t*>(&p_src[src_coord.GetOffset()]);
|
||||
#endif
|
||||
}).Else([&](auto) {
|
||||
// src can be all kinds of memory-space.
|
||||
*reinterpret_cast<src_vector_t*>(&p_src_long_vector[buffer_offset]) =
|
||||
*reinterpret_cast<const src_vector_t*>(&p_src[src_coord.GetOffset()]);
|
||||
});
|
||||
move_data<SrcData,
|
||||
SrcDataPerAccess,
|
||||
SrcAddressSpace,
|
||||
AddressSpace::vgpr,
|
||||
InMemoryDataOperation::none>(
|
||||
p_src, src_coord.GetOffset(), p_src_long_vector, buffer_offset);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -160,36 +146,17 @@ struct ThreadwiseGenericTensorSliceCopy_v4r2
|
||||
// has the same padding situation
|
||||
if(dst_coord.IsUpperIndexMappedToValidOffset())
|
||||
{
|
||||
static_if<DstAddressSpace == AddressSpace::global>{}([&](auto fwd) {
|
||||
#if CK_USE_AMD_BUFFER_ADDRESSING
|
||||
amd_intrinsic_buffer_store<DstData, DstDataPerAccess>(
|
||||
*reinterpret_cast<dst_vector_t*>(&p_dst_long_vector[buffer_offset]),
|
||||
fwd(p_dst),
|
||||
dst_coord.GetOffset(),
|
||||
0);
|
||||
#else
|
||||
*reinterpret_cast<dst_vector_t*>(&p_dst[dst_coord.GetOffset()]) =
|
||||
*reinterpret_cast<dst_vector_t*>(&p_dst_long_vector[buffer_offset]);
|
||||
#endif
|
||||
}).Else([&](auto) {
|
||||
// dst can be all kinds of memory-space
|
||||
*reinterpret_cast<dst_vector_t*>(&p_dst[dst_coord.GetOffset()]) =
|
||||
*reinterpret_cast<dst_vector_t*>(&p_dst_long_vector[buffer_offset]);
|
||||
});
|
||||
move_data<DstData,
|
||||
DstDataPerAccess,
|
||||
AddressSpace::vgpr,
|
||||
DstAddressSpace,
|
||||
DstInMemOp>(
|
||||
p_dst_long_vector, buffer_offset, p_dst, dst_coord.GetOffset());
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
template <typename SrcData, typename DstData>
|
||||
__device__ void Run(const SrcData* p_src, DstData* p_dst) const
|
||||
{
|
||||
constexpr auto generic_address_space =
|
||||
integral_constant<AddressSpace, AddressSpace::generic>{};
|
||||
|
||||
Run(p_src, p_dst, generic_address_space, generic_address_space);
|
||||
}
|
||||
|
||||
// Modify Length to 1, if Mask is set to false
|
||||
// Used for isolating linear dimension from non-linear dimensions
|
||||
template <index_t... Lengths, index_t... Mask>
|
||||
@@ -198,26 +165,14 @@ struct ThreadwiseGenericTensorSliceCopy_v4r2
|
||||
return Sequence<(Mask ? Lengths : 1)...>{};
|
||||
}
|
||||
|
||||
// p_src must be global-memory, p_dst can be any memory-space.
// User should make sure p_src is a block-invariant pointer, because
// buffer_load is used for loading from global-memory into register buffer.
// Will do padding check on src data: Read 0 if src data is in padding area.
// Will do padding check on dst data: No write if dst data is in padding area.
|
||||
// This version is optimized for address calculation of src tensor
|
||||
// TODO: this function is not compiled to expected ISA
|
||||
template <typename SrcData,
|
||||
typename DstData,
|
||||
AddressSpace SrcAddressSpace,
|
||||
AddressSpace DstAddressSpace>
|
||||
__device__ void
|
||||
Run_optimized_src_address_calculation(const SrcData* p_src,
|
||||
DstData* p_dst,
|
||||
integral_constant<AddressSpace, SrcAddressSpace>,
|
||||
integral_constant<AddressSpace, DstAddressSpace>) const
|
||||
template <typename SrcData, typename DstData>
|
||||
__device__ void Run_optimized_src_address_calculation(const SrcData* p_src,
|
||||
DstData* p_dst) const
|
||||
{
|
||||
using src_vector_t = typename vector_type<SrcData, SrcDataPerAccess>::MemoryType;
|
||||
using dst_vector_t = typename vector_type<DstData, DstDataPerAccess>::MemoryType;
|
||||
|
||||
constexpr auto vector_access_dim = Number<VectorAccessDim>{};
|
||||
|
||||
constexpr auto src_data_per_access = Number<SrcDataPerAccess>{};
|
||||
@@ -308,21 +263,15 @@ struct ThreadwiseGenericTensorSliceCopy_v4r2
|
||||
// the src vector has the same padding situation
|
||||
if(src_coord.IsUpperIndexMappedToValidOffset())
|
||||
{
|
||||
static_if<SrcAddressSpace == AddressSpace::global>{}([&](auto) {
|
||||
#if CK_USE_AMD_BUFFER_ADDRESSING
|
||||
*reinterpret_cast<src_vector_t*>(&p_src_long_vector[buffer_offset]) =
|
||||
amd_intrinsic_buffer_load<SrcData, SrcDataPerAccess>(
|
||||
p_src, src_nonlinear_coord.GetOffset(), src_linear_offset);
|
||||
#else
|
||||
*reinterpret_cast<src_vector_t*>(&p_src_long_vector[buffer_offset]) =
|
||||
*reinterpret_cast<const src_vector_t*>(
|
||||
&p_src[src_nonlinear_coord.GetOffset() + src_linear_offset]);
|
||||
#endif
|
||||
}).Else([&](auto) {
|
||||
*reinterpret_cast<src_vector_t*>(&p_src_long_vector[buffer_offset]) =
|
||||
*reinterpret_cast<const src_vector_t*>(
|
||||
&p_src[src_nonlinear_coord.GetOffset() + src_linear_offset]);
|
||||
});
|
||||
move_data<SrcData,
|
||||
SrcDataPerAccess,
|
||||
SrcAddressSpace,
|
||||
AddressSpace::vgpr,
|
||||
InMemoryDataOperation::none>(p_src,
|
||||
src_nonlinear_coord.GetOffset() +
|
||||
src_linear_offset,
|
||||
p_src_long_vector,
|
||||
buffer_offset);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -352,34 +301,26 @@ struct ThreadwiseGenericTensorSliceCopy_v4r2
|
||||
// the dst vector has the same padding situation
|
||||
if(dst_coord.IsUpperIndexMappedToValidOffset())
|
||||
{
|
||||
*reinterpret_cast<dst_vector_t*>(&p_dst[dst_coord.GetOffset()]) =
|
||||
*reinterpret_cast<dst_vector_t*>(&p_dst_long_vector[buffer_offset]);
|
||||
move_data<DstData,
|
||||
DstDataPerAccess,
|
||||
AddressSpace::vgpr,
|
||||
DstAddressSpace,
|
||||
DstInMemOp>(
|
||||
p_dst_long_vector, buffer_offset, p_dst, dst_coord.GetOffset());
|
||||
}
|
||||
}
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
// p_src could be any memory space, p_dst must be global memory.
// User should make sure p_dst is a block-invariant pointer, because
// buffer_store is used for storing data from register buffer into global-memory.
// Will do padding check on src data: Read 0 if src data is in padding area.
// Will do padding check on dst data: No write if dst data is in padding area.
|
||||
// This version is optimized for address calculation of dst tensor
|
||||
// TODO: this function is not compiled to expected ISA
|
||||
template <typename SrcData,
|
||||
typename DstData,
|
||||
AddressSpace SrcAddressSpace,
|
||||
AddressSpace DstAddressSpace>
|
||||
__device__ void
|
||||
Run_optimized_dst_address_calculation(const SrcData* p_src,
|
||||
DstData* p_dst,
|
||||
integral_constant<AddressSpace, SrcAddressSpace>,
|
||||
integral_constant<AddressSpace, DstAddressSpace>) const
|
||||
template <typename SrcData, typename DstData>
|
||||
__device__ void Run_optimized_dst_address_calculation(const SrcData* p_src,
|
||||
DstData* p_dst) const
|
||||
{
|
||||
using src_vector_t = typename vector_type<SrcData, SrcDataPerAccess>::MemoryType;
|
||||
using dst_vector_t = typename vector_type<DstData, DstDataPerAccess>::MemoryType;
|
||||
|
||||
constexpr auto vector_access_dim = Number<VectorAccessDim>{};
|
||||
|
||||
constexpr auto src_data_per_access = Number<SrcDataPerAccess>{};
|
||||
@@ -461,8 +402,12 @@ struct ThreadwiseGenericTensorSliceCopy_v4r2
|
||||
// the src vector has the same padding situation
|
||||
if(src_coord.IsUpperIndexMappedToValidOffset())
|
||||
{
|
||||
*reinterpret_cast<src_vector_t*>(&p_src_long_vector[buffer_offset]) =
|
||||
*reinterpret_cast<const src_vector_t*>(&p_src[src_coord.GetOffset()]);
|
||||
move_data<SrcData,
|
||||
SrcDataPerAccess,
|
||||
SrcAddressSpace,
|
||||
AddressSpace::vgpr,
|
||||
InMemoryDataOperation::none>(
|
||||
p_src, src_coord.GetOffset(), p_src_long_vector, buffer_offset);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -501,23 +446,14 @@ struct ThreadwiseGenericTensorSliceCopy_v4r2
|
||||
// the dst vector has the same padding situation
|
||||
if(dst_coord.IsUpperIndexMappedToValidOffset())
|
||||
{
|
||||
static_if<DstAddressSpace == AddressSpace::global>{}([&](auto) {
|
||||
#if CK_USE_AMD_BUFFER_ADDRESSING
|
||||
amd_intrinsic_buffer_store<DstData, DstDataPerAccess>(
|
||||
*reinterpret_cast<dst_vector_t*>(&p_dst_long_vector[buffer_offset]),
|
||||
p_dst,
|
||||
dst_nonlinear_coord.GetOffset(),
|
||||
dst_linear_offset);
|
||||
#else
|
||||
*reinterpret_cast<dst_vector_t*>(
|
||||
&p_dst[dst_nonlinear_coord.GetOffset() + dst_linear_offset]) =
|
||||
*reinterpret_cast<dst_vector_t*>(&p_dst_long_vector[buffer_offset]);
|
||||
#endif
|
||||
}).Else([&](auto) {
|
||||
*reinterpret_cast<dst_vector_t*>(
|
||||
&p_dst[dst_nonlinear_coord.GetOffset() + dst_linear_offset]) =
|
||||
*reinterpret_cast<dst_vector_t*>(&p_dst_long_vector[buffer_offset]);
|
||||
});
|
||||
move_data<DstData,
|
||||
DstDataPerAccess,
|
||||
AddressSpace::vgpr,
|
||||
DstAddressSpace,
|
||||
DstInMemOp>(p_dst_long_vector,
|
||||
buffer_offset,
|
||||
p_dst,
|
||||
dst_nonlinear_coord.GetOffset() + dst_linear_offset);
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
@@ -54,10 +54,18 @@ __device__ void __llvm_amdgcn_buffer_storex4(float4_t vdata,
bool glc,
bool slc) __asm("llvm.amdgcn.buffer.store.v4f32");

// buffer_load requires:
// 1) p_src must be in global memory space, p_dst must be vgpr
// 2) p_src to be a block-invariant pointer.
// It is the user's responsibility to make sure that is true.
template <typename T, index_t VectorSize>
__device__ typename vector_type<T, VectorSize>::MemoryType amd_intrinsic_buffer_load(
const T* p_src_block, index_t src_thread_data_offset, index_t src_const_data_offset);

// buffer_store requires:
// 1) p_src must be in vgpr space, p_dst must be global memory
// 2) p_dst to be a block-invariant pointer.
// It is the user's responsibility to make sure that is true.
template <typename T, index_t VectorSize>
__device__ void
amd_intrinsic_buffer_store(const typename vector_type<T, VectorSize>::MemoryType& src,
||||
@@ -15,6 +15,7 @@
|
||||
#include "functional2.hpp"
|
||||
#include "functional3.hpp"
|
||||
#include "functional4.hpp"
|
||||
#include "in_memory_operation.hpp"
|
||||
|
||||
#if CK_USE_AMD_INLINE_ASM
|
||||
#include "amd_inline_asm.hpp"
|
||||
|
||||
@@ -54,7 +54,15 @@ namespace ck {
|
||||
enum AddressSpace
|
||||
{
|
||||
generic,
|
||||
global
|
||||
global,
|
||||
lds,
|
||||
vgpr
|
||||
};
|
||||
|
||||
enum InMemoryDataOperation
|
||||
{
|
||||
none,
|
||||
atomic_add
|
||||
};
|
||||
|
||||
#if CK_UNSIGNED_INDEX_TYPE
|
||||
|
||||
@@ -33,7 +33,15 @@ namespace ck {
|
||||
enum AddressSpace
|
||||
{
|
||||
generic,
|
||||
global = generic
|
||||
global,
|
||||
lds,
|
||||
vgpr
|
||||
};
|
||||
|
||||
enum InMemoryDataOperation
|
||||
{
|
||||
none,
|
||||
atomic_add
|
||||
};
|
||||
|
||||
#if CK_UNSIGNED_INDEX_TYPE
|
||||
|
||||
@@ -64,9 +64,8 @@ struct static_if<true>
|
||||
}
|
||||
|
||||
template <typename F>
|
||||
__host__ __device__ static constexpr auto Else(F)
|
||||
__host__ __device__ static void Else(F)
|
||||
{
|
||||
return Type{};
|
||||
}
|
||||
};
|
||||
|
||||
@@ -82,14 +81,13 @@ struct static_if<false>
|
||||
}
|
||||
|
||||
template <typename F>
|
||||
__host__ __device__ static constexpr auto Else(F f)
|
||||
__host__ __device__ static void Else(F f)
|
||||
{
|
||||
// This is a trick for compiler:
|
||||
// Pass forwarder to lambda "f" as "auto" argument, and make sure "f" will use it,
|
||||
// this will make "f" a generic lambda, so that "f" won't be compiled until being
|
||||
// instantiated here
|
||||
f(forwarder{});
|
||||
return Type{};
|
||||
}
|
||||
};
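A minimal host-side sketch (simplified stand-in types, not the library's exact definitions) of the trick described in the comment above: because each branch body is a generic lambda taking the forwarder as an auto argument, only the branch that static_if actually invokes gets instantiated.

#include <cstdio>

struct forwarder_sketch
{
    template <typename T>
    T&& operator()(T&& x) const { return static_cast<T&&>(x); }
};

template <bool Predicate>
struct static_if_sketch // primary template handles the "true" case
{
    template <typename F>
    static void Then(F f) { f(forwarder_sketch{}); }
    template <typename F>
    static void Else(F) {} // lambda never called, so its body is never instantiated
};

template <>
struct static_if_sketch<false>
{
    template <typename F>
    static void Then(F) {}
    template <typename F>
    static void Else(F f) { f(forwarder_sketch{}); }
};

int main()
{
    static_if_sketch<(sizeof(void*) >= 4)>{}.Then([&](auto) { std::printf("then-branch\n"); });
    static_if_sketch<(sizeof(void*) >= 4)>{}.Else([&](auto) { std::printf("else-branch (skipped)\n"); });
    return 0;
}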
|
||||
|
||||
|
||||
@@ -0,0 +1,87 @@
|
||||
#ifndef CK_IN_MEMORY_OPERATION_AMD_HPP
|
||||
#define CK_IN_MEMORY_OPERATION_AMD_HPP
|
||||
|
||||
#include "float_type.hpp"
|
||||
#include "amd_buffer_addressing.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
template <typename T,
|
||||
index_t DataPerAccess,
|
||||
AddressSpace SrcAddressSpace,
|
||||
AddressSpace DstAddressSpace>
|
||||
__device__ void copy_data(const T* p_src, index_t src_offset, T* p_dst, index_t dst_offset)
|
||||
{
|
||||
using vector_t = typename vector_type<T, DataPerAccess>::MemoryType;
|
||||
|
||||
#if CK_USE_AMD_BUFFER_ADDRESSING
|
||||
// TODO: use static_if::ElseIf, instead of nested static_if
|
||||
static_if<SrcAddressSpace == AddressSpace::global && DstAddressSpace == vgpr>{}([&](auto) {
|
||||
// buffer_load requires:
|
||||
// 1) p_src must be in global memory space, p_dst must be vgpr
|
||||
// 2) p_src to be a block-invariant pointer.
|
||||
// It is user's responsibility to make sure that is true.
|
||||
*reinterpret_cast<vector_t*>(&p_dst[dst_offset]) =
|
||||
amd_intrinsic_buffer_load<T, DataPerAccess>(p_src, src_offset, 0);
|
||||
}).Else([&](auto) {
|
||||
static_if<SrcAddressSpace == AddressSpace::vgpr && DstAddressSpace == global>{}([&](auto) {
|
||||
// buffer_store requires:
|
||||
// 1) p_src must be in vgpr space, p_dst must be global memory
|
||||
// 2) p_dst to be a block-invariant pointer.
|
||||
// It is user's responsibility to make sure that is true.
|
||||
amd_intrinsic_buffer_store<T, DataPerAccess>(
|
||||
*reinterpret_cast<const vector_t*>(&p_src[src_offset]), p_dst, dst_offset, 0);
|
||||
}).Else([&](auto) {
|
||||
*reinterpret_cast<vector_t*>(&p_dst[dst_offset]) =
|
||||
*reinterpret_cast<const vector_t*>(&p_src[src_offset]);
|
||||
});
|
||||
});
|
||||
#else
|
||||
*reinterpret_cast<vector_t*>(&p_dst[dst_offset]) =
|
||||
*reinterpret_cast<const vector_t*>(&p_src[src_offset]);
|
||||
#endif
|
||||
}
|
||||
|
||||
template <typename T,
|
||||
index_t DataPerAccess,
|
||||
AddressSpace SrcAddressSpace,
|
||||
AddressSpace DstAddressSpace>
|
||||
__device__ void atomic_add_data(const T* p_src, index_t src_offset, T* p_dst, index_t dst_offset)
|
||||
{
|
||||
using vector_t = typename vector_type<T, DataPerAccess>::MemoryType;
|
||||
|
||||
static_if<SrcAddressSpace == AddressSpace::vgpr && DstAddressSpace == AddressSpace::global>{}(
|
||||
[&](auto) {
|
||||
atomicAdd(reinterpret_cast<vector_t*>(&p_dst[dst_offset]),
|
||||
*reinterpret_cast<const vector_t*>(&p_src[src_offset]));
|
||||
})
|
||||
.Else([&](auto fwd) {
|
||||
static_assert(fwd(false), "atomic_add doesn't support this memory space");
|
||||
});
|
||||
}
|
||||
|
||||
template <typename T,
|
||||
index_t DataPerAccess,
|
||||
AddressSpace SrcAddressSpace,
|
||||
AddressSpace DstAddressSpace,
|
||||
InMemoryDataOperation DstInMemOp>
|
||||
__device__ void move_data(const T* p_src, index_t src_offset, T* p_dst, index_t dst_offset)
|
||||
{
|
||||
static_assert(DstInMemOp == InMemoryDataOperation::none ||
|
||||
DstInMemOp == InMemoryDataOperation::atomic_add,
|
||||
"wrong! InMemoryDataOperation not supported!");
|
||||
|
||||
// TODO: use static_if::ElseIf
|
||||
static_if<DstInMemOp == InMemoryDataOperation::none>{}([&](auto) {
|
||||
copy_data<T, DataPerAccess, SrcAddressSpace, DstAddressSpace>(
|
||||
p_src, src_offset, p_dst, dst_offset);
|
||||
});
|
||||
|
||||
static_if<DstInMemOp == InMemoryDataOperation::atomic_add>{}([&](auto) {
|
||||
atomic_add_data<T, DataPerAccess, SrcAddressSpace, DstAddressSpace>(
|
||||
p_src, src_offset, p_dst, dst_offset);
|
||||
});
|
||||
}
|
||||
|
||||
} // namespace ck
|
||||
#endif
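A host-side sketch (assumed names, with std::atomic standing in for the GPU's atomicAdd) of the two destination operations dispatched by move_data above: "none" overwrites the destination, "atomic_add" accumulates into it, which is what a scatter-style write-back needs when several source elements map onto the same destination element.

#include <atomic>
#include <cstdio>

enum class InMemOpSketch { none, atomic_add };

template <InMemOpSketch Op>
void move_data_sketch(const float* p_src, int src_offset, std::atomic<float>* p_dst, int dst_offset)
{
    if(Op == InMemOpSketch::none)
    {
        p_dst[dst_offset].store(p_src[src_offset]); // plain overwrite
    }
    else
    {
        // accumulate; a CAS loop stands in for the GPU's atomicAdd
        float expected = p_dst[dst_offset].load();
        while(!p_dst[dst_offset].compare_exchange_weak(expected,
                                                       expected + p_src[src_offset]))
        {
        }
    }
}

int main()
{
    float src[2] = {1.5f, 2.5f};
    std::atomic<float> dst(0.0f);
    move_data_sketch<InMemOpSketch::atomic_add>(src, 0, &dst, 0);
    move_data_sketch<InMemOpSketch::atomic_add>(src, 1, &dst, 0);
    std::printf("accumulated: %f\n", dst.load()); // 4.000000
    return 0;
}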
|
||||
@@ -0,0 +1,60 @@
|
||||
#ifndef CK_IN_MEMORY_OPERATION_NVIDIA_HPP
|
||||
#define CK_IN_MEMORY_OPERATION_NVIDIA_HPP
|
||||
|
||||
namespace ck {
|
||||
|
||||
template <typename T,
|
||||
index_t DataPerAccess,
|
||||
AddressSpace SrcAddressSpace,
|
||||
AddressSpace DstAddressSpace>
|
||||
__device__ void copy_data(const T* p_src, index_t src_offset, T* p_dst, index_t dst_offset)
|
||||
{
|
||||
using vector_t = typename vector_type<T, DataPerAccess>::MemoryType;
|
||||
|
||||
*reinterpret_cast<vector_t*>(&p_dst[dst_offset]) =
|
||||
*reinterpret_cast<const vector_t*>(&p_src[src_offset]);
|
||||
}
|
||||
|
||||
template <typename T,
|
||||
index_t DataPerAccess,
|
||||
AddressSpace SrcAddressSpace,
|
||||
AddressSpace DstAddressSpace>
|
||||
__device__ void atomic_add_data(const T* p_src, index_t src_offset, T* p_dst, index_t dst_offset)
|
||||
{
|
||||
using vector_t = typename vector_type<T, DataPerAccess>::MemoryType;
|
||||
|
||||
static_if<SrcAddressSpace == AddressSpace::vgpr && DstAddressSpace == AddressSpace::global>{}(
|
||||
[&](auto) {
|
||||
atomicAdd(reinterpret_cast<vector_t*>(&p_dst[dst_offset]),
|
||||
*reinterpret_cast<const vector_t*>(&p_src[src_offset]));
|
||||
})
|
||||
.Else([&](auto fwd) {
|
||||
static_assert(fwd(false), "atomic_add doesn't support this memory space");
|
||||
});
|
||||
}
|
||||
|
||||
template <typename T,
|
||||
index_t DataPerAccess,
|
||||
AddressSpace SrcAddressSpace,
|
||||
AddressSpace DstAddressSpace,
|
||||
InMemoryDataOperation DstInMemOp>
|
||||
__device__ void move_data(const T* p_src, index_t src_offset, T* p_dst, index_t dst_offset)
|
||||
{
|
||||
static_assert(DstInMemOp == InMemoryDataOperation::none ||
|
||||
DstInMemOp == InMemoryDataOperation::atomic_add,
|
||||
"wrong! InMemoryDataOperation not supported!");
|
||||
|
||||
// TODO: use static_if::ElseIf
|
||||
static_if<DstInMemOp == InMemoryDataOperation::none>{}([&](auto) {
|
||||
copy_data<T, DataPerAccess, SrcAddressSpace, DstAddressSpace>(
|
||||
p_src, src_offset, p_dst, dst_offset);
|
||||
});
|
||||
|
||||
static_if<DstInMemOp == InMemoryDataOperation::atomic_add>{}([&](auto) {
|
||||
atomic_add_data<T, DataPerAccess, SrcAddressSpace, DstAddressSpace>(
|
||||
p_src, src_offset, p_dst, dst_offset);
|
||||
});
|
||||
}
|
||||
|
||||
} // namespace ck
|
||||
#endif
|
||||
@@ -97,12 +97,57 @@ __host__ __device__ constexpr T min(T x, Ts... xs)
|
||||
return x < y ? x : y;
|
||||
}
|
||||
|
||||
// this is WRONG
|
||||
// TODO: implement least common multiple properly, instead of calling max()
|
||||
template <class T, class... Ts>
|
||||
__host__ __device__ constexpr T lcm(T x, Ts... xs)
|
||||
// highest common factor
|
||||
template <typename T>
|
||||
__host__ __device__ constexpr T hcf(T x, T y)
|
||||
{
|
||||
return max(x, xs...);
|
||||
if(x == 0)
|
||||
{
|
||||
return y;
|
||||
}
|
||||
|
||||
if(y == 0)
|
||||
{
|
||||
return x;
|
||||
}
|
||||
|
||||
if(x == y)
|
||||
{
|
||||
return x;
|
||||
}
|
||||
|
||||
if(x > y)
|
||||
{
|
||||
return hcf(x - y, y);
|
||||
}
|
||||
|
||||
return hcf(x, y - x);
|
||||
}
|
||||
|
||||
template <index_t X, index_t Y>
|
||||
__host__ __device__ constexpr auto hcf(Number<X>, Number<Y>)
|
||||
{
|
||||
constexpr auto result = hcf(X, Y);
|
||||
return Number<result>{};
|
||||
}
|
||||
|
||||
template <typename X, typename... Ys>
|
||||
__host__ __device__ constexpr auto hcf(X x, Ys... ys)
|
||||
{
|
||||
return hcf(x, hcf(ys...));
|
||||
}
|
||||
|
||||
// least common multiple
|
||||
template <typename T>
|
||||
__host__ __device__ constexpr T lcm(T x, T y)
|
||||
{
|
||||
return (x * y) / hcf(x, y);
|
||||
}
|
||||
|
||||
template <typename X, typename Y, typename... Zs>
|
||||
__host__ __device__ constexpr auto lcm(X x, Y y, Zs... zs)
|
||||
{
|
||||
return lcm(x, lcm(y, zs...));
|
||||
}
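A standalone host-side check (plain C++, illustrative names) of the scheme above: hcf by repeated subtraction, lcm built on top of it, e.g. for deriving an LDS alignment from several per-access widths.

#include <cassert>

constexpr int hcf_sketch(int x, int y)
{
    return x == 0 ? y
                  : y == 0 ? x : x == y ? x : x > y ? hcf_sketch(x - y, y) : hcf_sketch(x, y - x);
}

constexpr int lcm_sketch(int x, int y) { return (x * y) / hcf_sketch(x, y); }

int main()
{
    static_assert(hcf_sketch(12, 18) == 6, "hcf(12, 18) == 6");
    // e.g. an alignment derived from several per-access widths: lcm(4, 2, 8) == 8
    static_assert(lcm_sketch(lcm_sketch(4, 2), 8) == 8, "lcm(4, 2, 8) == 8");
    assert(lcm_sketch(3, 4) == 12);
    return 0;
}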
|
||||
|
||||
template <class T>
|
||||
|
||||
@@ -15,10 +15,18 @@ install(TARGETS host LIBRARY DESTINATION lib)
|
||||
|
||||
|
||||
if(DEVICE_BACKEND STREQUAL "AMD")
|
||||
set(DRIVER_SOURCE src/driver.cpp)
|
||||
set(CONV_SOURCE src/conv_driver.cpp)
|
||||
set(COL2IM_SOURCE src/col2im_driver.cpp)
|
||||
set(CONV_BWD_DATA_SOURCE src/conv_bwd_data_driver.cpp)
|
||||
elseif(DEVICE_BACKEND STREQUAL "NVIDIA")
|
||||
set(DRIVER_SOURCE src/driver.cu)
|
||||
set(CONV_SOURCE src/conv_driver.cu)
|
||||
set(COL2IM_SOURCE src/col2im_driver.cu)
|
||||
set(CONV_BWD_DATA_SOURCE src/conv_bwd_data_driver.cu)
|
||||
endif()
|
||||
|
||||
add_executable(driver ${DRIVER_SOURCE})
|
||||
target_link_libraries(driver PRIVATE host)
|
||||
add_executable(conv ${CONV_SOURCE})
|
||||
add_executable(col2im ${COL2IM_SOURCE})
|
||||
add_executable(conv_bwd_data ${CONV_BWD_DATA_SOURCE})
|
||||
target_link_libraries(conv PRIVATE host)
|
||||
target_link_libraries(col2im PRIVATE host)
|
||||
target_link_libraries(conv_bwd_data PRIVATE host)
|
||||
|
||||
@@ -2,39 +2,7 @@
|
||||
#define CONV_COMMON_HPP
|
||||
|
||||
#include "ConstantTensorDescriptor_deprecated.hpp"
|
||||
|
||||
// this is ugly, only for 4d
|
||||
template <class InDesc, class WeiDesc>
|
||||
constexpr auto get_convolution_output_default_4d_tensor_descriptor(InDesc, WeiDesc)
|
||||
{
|
||||
using namespace ck;
|
||||
|
||||
constexpr auto in_desc = InDesc{};
|
||||
constexpr auto wei_desc = WeiDesc{};
|
||||
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
constexpr auto I3 = Number<3>{};
|
||||
|
||||
static_assert(in_desc.GetNumOfDimension() == 4, "input nDim is not 4");
|
||||
static_assert(wei_desc.GetNumOfDimension() == 4, "weight nDim is not 4");
|
||||
static_assert(in_desc.GetLength(I1) == wei_desc.GetLength(I1),
|
||||
"input & weight dimension not consistent");
|
||||
|
||||
constexpr auto N = in_desc.GetLength(I0);
|
||||
constexpr auto HI = in_desc.GetLength(I2);
|
||||
constexpr auto WI = in_desc.GetLength(I3);
|
||||
|
||||
constexpr auto K = wei_desc.GetLength(I0);
|
||||
constexpr auto Y = wei_desc.GetLength(I2);
|
||||
constexpr auto X = wei_desc.GetLength(I3);
|
||||
|
||||
constexpr auto HO = HI + 1 - Y;
|
||||
constexpr auto WO = WI + 1 - X;
|
||||
|
||||
return make_ConstantTensorDescriptor_packed(Sequence<N, K, HO, WO>{});
|
||||
}
|
||||
#include "tensor_descriptor.hpp"
|
||||
|
||||
template <class InDesc,
|
||||
class WeiDesc,
|
||||
@@ -42,7 +10,7 @@ template <class InDesc,
|
||||
class ConvDilations,
|
||||
class LowerPads,
|
||||
class UpperPads>
|
||||
constexpr auto get_convolution_with_padding_output_default_4d_tensor_descriptor(
|
||||
constexpr auto get_convolution_output_default_4d_tensor_descriptor_deprecated(
|
||||
InDesc, WeiDesc, ConvStrides, ConvDilations, LowerPads, UpperPads)
|
||||
{
|
||||
using namespace ck;
|
||||
@@ -83,6 +51,53 @@ constexpr auto get_convolution_with_padding_output_default_4d_tensor_descriptor(
|
||||
return make_ConstantTensorDescriptor_packed(Sequence<N, K, Ho, Wo>{});
|
||||
}
|
||||
|
||||
template <class InDesc,
|
||||
class WeiDesc,
|
||||
class ConvStrides,
|
||||
class ConvDilations,
|
||||
class LowerPads,
|
||||
class UpperPads>
|
||||
constexpr auto get_convolution_output_default_4d_tensor_descriptor(
|
||||
InDesc, WeiDesc, ConvStrides, ConvDilations, LowerPads, UpperPads)
|
||||
{
|
||||
using namespace ck;
|
||||
|
||||
constexpr auto in_desc = InDesc{};
|
||||
constexpr auto wei_desc = WeiDesc{};
|
||||
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
constexpr auto I3 = Number<3>{};
|
||||
|
||||
static_assert(in_desc.GetNumOfDimension() == 4, "input nDim is not 4");
|
||||
static_assert(wei_desc.GetNumOfDimension() == 4, "weight nDim is not 4");
|
||||
static_assert(in_desc.GetLength(I1) == wei_desc.GetLength(I1),
|
||||
"input & weight dimension not consistent");
|
||||
|
||||
constexpr index_t N = in_desc.GetLength(I0);
|
||||
constexpr index_t Hi = in_desc.GetLength(I2);
|
||||
constexpr index_t Wi = in_desc.GetLength(I3);
|
||||
|
||||
constexpr index_t K = wei_desc.GetLength(I0);
|
||||
constexpr index_t Y = wei_desc.GetLength(I2);
|
||||
constexpr index_t X = wei_desc.GetLength(I3);
|
||||
|
||||
constexpr index_t HPadLow = LowerPads{}.Get(I0);
|
||||
constexpr index_t WPadLow = LowerPads{}.Get(I1);
|
||||
|
||||
constexpr index_t HPadUp = UpperPads{}.Get(I0);
|
||||
constexpr index_t WPadUp = UpperPads{}.Get(I1);
|
||||
|
||||
constexpr index_t YEff = (Y - 1) * ConvDilations{}[0] + 1;
|
||||
constexpr index_t XEff = (X - 1) * ConvDilations{}[1] + 1;
|
||||
|
||||
constexpr index_t Ho = (Hi + HPadLow + HPadUp - YEff) / ConvStrides{}[0] + 1;
|
||||
constexpr index_t Wo = (Wi + WPadLow + WPadUp - XEff) / ConvStrides{}[1] + 1;
|
||||
|
||||
return make_native_tensor_descriptor_packed(Sequence<N, K, Ho, Wo>{});
|
||||
}
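A quick numeric check (hypothetical sizes) of the output-size formula used above: YEff = (Y - 1) * dilation + 1 and Ho = (Hi + pad_low + pad_up - YEff) / stride + 1.

#include <cstdio>

int main()
{
    const int Hi = 28, Y = 3, dilation = 1, stride = 1, pad_low = 1, pad_up = 1;
    const int YEff = (Y - 1) * dilation + 1;
    const int Ho   = (Hi + pad_low + pad_up - YEff) / stride + 1;
    std::printf("Ho = %d\n", Ho); // a 3x3 kernel with pad 1 and stride 1 keeps Ho == 28
    return 0;
}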
|
||||
|
||||
template <class InDesc, class WeiDesc, class OutDesc>
|
||||
constexpr std::size_t calculate_convolution_flops(InDesc, WeiDesc, OutDesc)
|
||||
{
|
||||
|
||||
108
driver/include/device_col2im_eb_nchw.hpp
Normal file
@@ -0,0 +1,108 @@
|
||||
#pragma once
|
||||
#include <unistd.h>
|
||||
#include "device.hpp"
|
||||
#include "tensor.hpp"
|
||||
#include "gridwise_operation_wrapper.hpp"
|
||||
#include "gridwise_col2im_eb_nchw.hpp"
|
||||
|
||||
template <typename T,
|
||||
typename ColDesc,
|
||||
typename ImgDesc,
|
||||
typename FilterSizes,
|
||||
typename OutputSizes,
|
||||
typename ConvStrides,
|
||||
typename ConvDilations,
|
||||
typename LeftPads,
|
||||
typename RightPads>
|
||||
void device_col2im_eb_nchw(ColDesc,
|
||||
const Tensor<T>& col_eb,
|
||||
ImgDesc,
|
||||
Tensor<T>& img_nchw,
|
||||
FilterSizes,
|
||||
OutputSizes,
|
||||
ConvStrides,
|
||||
ConvDilations,
|
||||
LeftPads,
|
||||
RightPads,
|
||||
std::size_t nrepeat)
|
||||
{
|
||||
using namespace ck;
|
||||
|
||||
constexpr auto col_eb_desc = ColDesc{};
|
||||
constexpr auto img_nchw_desc = ImgDesc{};
|
||||
|
||||
constexpr index_t N = img_nchw_desc.GetLengths()[0];
|
||||
constexpr index_t C = img_nchw_desc.GetLengths()[1];
|
||||
constexpr index_t Hi = img_nchw_desc.GetLengths()[2];
|
||||
constexpr index_t Wi = img_nchw_desc.GetLengths()[3];
|
||||
|
||||
constexpr index_t E = col_eb_desc.GetLengths()[0];
|
||||
constexpr index_t B = col_eb_desc.GetLengths()[1];
|
||||
|
||||
std::size_t data_sz = sizeof(T);
|
||||
DeviceMem col_eb_device_buf(data_sz * col_eb.mDesc.GetElementSpace());
|
||||
DeviceMem img_nchw_device_buf(data_sz * img_nchw.mDesc.GetElementSpace());
|
||||
|
||||
col_eb_device_buf.ToDevice(col_eb.mData.data());
|
||||
img_nchw_device_buf.ToDevice(img_nchw.mData.data());
|
||||
|
||||
#if 1
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t EPerBlock = 128;
|
||||
constexpr index_t BPerBlock = 128;
|
||||
|
||||
using BlockCopySubLengths_E_B = Sequence<8, 8>;
|
||||
using BlockCopyClusterLengths_E_B = Sequence<16, 16>;
|
||||
using BlockCopyThreadClusterArrangeOrder = Sequence<0, 1>; // [E, B]
|
||||
using BlockCopySrcAccessOrder = Sequence<0, 1>; // [E, B]
|
||||
using BlockCopyDstAccessOrder = Sequence<0, 1>; // [E, B]
|
||||
|
||||
constexpr index_t BlockCopyDataPerAccess_B = 1;
|
||||
#endif
|
||||
|
||||
constexpr index_t GridSize =
|
||||
((E + EPerBlock - 1) / EPerBlock) * ((B + BPerBlock - 1) / BPerBlock);
|
||||
|
||||
printf("%s: BlockSize %u, GridSize %u \n", __func__, BlockSize, GridSize);
|
||||
|
||||
constexpr auto gridwise_col2im = GridwiseCol2Im_eb_nchw<GridSize,
|
||||
BlockSize,
|
||||
T,
|
||||
ColDesc,
|
||||
ImgDesc,
|
||||
FilterSizes,
|
||||
OutputSizes,
|
||||
ConvStrides,
|
||||
ConvDilations,
|
||||
LeftPads,
|
||||
RightPads,
|
||||
EPerBlock,
|
||||
BPerBlock,
|
||||
BlockCopySubLengths_E_B,
|
||||
BlockCopyClusterLengths_E_B,
|
||||
BlockCopyThreadClusterArrangeOrder,
|
||||
BlockCopySrcAccessOrder,
|
||||
BlockCopyDstAccessOrder,
|
||||
BlockCopyDataPerAccess_B>{};
|
||||
|
||||
for(index_t i = 0; i < nrepeat; ++i)
|
||||
{
|
||||
float time = launch_kernel(run_gridwise_operation<decltype(gridwise_col2im),
|
||||
const T* const __restrict__,
|
||||
T* const __restrict__>,
|
||||
dim3(GridSize),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
gridwise_col2im,
|
||||
const_cast<const T* const __restrict__>(
|
||||
static_cast<T*>(col_eb_device_buf.GetDeviceBuffer())),
|
||||
const_cast<T* const __restrict__>(
|
||||
static_cast<T*>(img_nchw_device_buf.GetDeviceBuffer())));
|
||||
|
||||
printf("Elapsed time : %f ms\n", time);
|
||||
usleep(std::min(time * 1000, float(10000)));
|
||||
}
|
||||
|
||||
img_nchw_device_buf.FromDevice(img_nchw.mData.data());
|
||||
}
|
||||
@@ -0,0 +1,143 @@
|
||||
#pragma once
|
||||
#include <unistd.h>
|
||||
#include "device.hpp"
|
||||
#include "tensor.hpp"
|
||||
#include "gridwise_operation_wrapper.hpp"
|
||||
#include "gridwise_convolution_backward_data_implicit_gemm_v1r1_nchw_kcyx_nkhw.hpp"
|
||||
|
||||
template <typename T,
|
||||
typename InDesc,
|
||||
typename WeiDesc,
|
||||
typename OutDesc,
|
||||
typename ConvStrides,
|
||||
typename ConvDilations,
|
||||
typename LeftPads,
|
||||
typename RightPads>
|
||||
void device_convolution_backward_data_implicit_gemm_v1r1_nchw_kcyx_nkhw(InDesc in_nchw_desc,
|
||||
Tensor<T>& in_nchw,
|
||||
WeiDesc wei_kcyx_desc,
|
||||
const Tensor<T>& wei_kcyx,
|
||||
OutDesc out_nkhw_desc,
|
||||
const Tensor<T>& out_nkhw,
|
||||
ConvStrides,
|
||||
ConvDilations,
|
||||
LeftPads,
|
||||
RightPads,
|
||||
std::size_t nrepeat)
|
||||
{
|
||||
using namespace ck;
|
||||
|
||||
constexpr index_t N = out_nkhw_desc.GetLengths()[0];
|
||||
constexpr index_t K = out_nkhw_desc.GetLengths()[1];
|
||||
constexpr index_t Ho = out_nkhw_desc.GetLengths()[2];
|
||||
constexpr index_t Wo = out_nkhw_desc.GetLengths()[3];
|
||||
|
||||
constexpr index_t C = wei_kcyx_desc.GetLengths()[1];
|
||||
constexpr index_t Y = wei_kcyx_desc.GetLengths()[2];
|
||||
constexpr index_t X = wei_kcyx_desc.GetLengths()[3];
|
||||
|
||||
std::size_t data_sz = sizeof(T);
|
||||
DeviceMem in_nchw_device_buf(data_sz * in_nchw.mDesc.GetElementSpace());
|
||||
DeviceMem wei_kcyx_device_buf(data_sz * wei_kcyx.mDesc.GetElementSpace());
|
||||
DeviceMem out_nkhw_device_buf(data_sz * out_nkhw.mDesc.GetElementSpace());
|
||||
|
||||
in_nchw_device_buf.ToDevice(in_nchw.mData.data());
|
||||
wei_kcyx_device_buf.ToDevice(wei_kcyx.mData.data());
|
||||
out_nkhw_device_buf.ToDevice(out_nkhw.mData.data());
|
||||
|
||||
#if 1
|
||||
// BlockSize = 256, each thread hold 64 data
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t GemmMPerBlock = 128;
|
||||
constexpr index_t GemmNPerBlock = 128;
|
||||
constexpr index_t GemmKPerBlock = 8;
|
||||
constexpr index_t GemmMPerThreadSubC = 4;
|
||||
constexpr index_t GemmNPerThreadSubC = 4;
|
||||
constexpr index_t GemmMLevel0Cluster = 4;
|
||||
constexpr index_t GemmNLevel0Cluster = 4;
|
||||
constexpr index_t GemmMLevel1Cluster = 4;
|
||||
constexpr index_t GemmNLevel1Cluster = 4;
|
||||
constexpr index_t GemmKPerThreadLoop = 1;
|
||||
constexpr index_t GemmThreadGemmDataPerReadM = 4;
|
||||
constexpr index_t GemmThreadGemmDataPerReadN = 4;
|
||||
|
||||
using GemmABlockCopySubLengths = Sequence<1, 4>; // Gemm-K, Gemm-M
|
||||
using GemmABlockCopyClusterLengths = Sequence<8, 32>; // Gemm-K, Gemm-M
|
||||
|
||||
constexpr index_t GemmABlockCopyDataPerAccess = 4; // Gemm-M
|
||||
|
||||
using GemmBBlockCopySubLengths = Sequence<4, 1>; // Gemm-K, Gemm-N
|
||||
using GemmBBlockCopyClusterLengths = Sequence<2, 128>; // Gemm-K, Gemm-N
|
||||
|
||||
constexpr index_t GemmBBlockCopyDataPerAccess = 1; // Gemm-N
|
||||
|
||||
constexpr index_t GemmCThreadCopyDataPerAccess = 1; // Gemm-N
|
||||
#endif
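A quick host-side sanity check (using the same constants as the block above) that the per-thread copy sub-lengths times the thread-cluster lengths cover one block tile, and that the cluster accounts for all BlockSize threads.

#include <cstdio>

int main()
{
    const int BlockSize = 256, GemmKPerBlock = 8, GemmMPerBlock = 128;

    // GemmABlockCopySubLengths = <1, 4>, GemmABlockCopyClusterLengths = <8, 32>
    const int sub_k = 1, sub_m = 4, cluster_k = 8, cluster_m = 32;

    std::printf("tile K x M: %d x %d (expect %d x %d), threads: %d (expect %d)\n",
                sub_k * cluster_k, sub_m * cluster_m, GemmKPerBlock, GemmMPerBlock,
                cluster_k * cluster_m, BlockSize);
    return 0;
}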
|
||||
|
||||
constexpr index_t GemmM = C * Y * X;
|
||||
constexpr index_t GemmN = N * Ho * Wo;
|
||||
|
||||
constexpr index_t GridSize = ((GemmM + GemmMPerBlock - 1) / GemmMPerBlock) *
|
||||
((GemmN + GemmNPerBlock - 1) / GemmNPerBlock);
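A host-side sketch (hypothetical problem sizes) of how this backward-data GEMM is dimensioned: GemmM = C * Y * X, GemmN = N * Ho * Wo, with one workgroup launched per GemmMPerBlock x GemmNPerBlock tile.

#include <cstdio>

int main()
{
    const int N = 64, C = 128, Y = 3, X = 3, Ho = 14, Wo = 14;
    const int GemmMPerBlock = 128, GemmNPerBlock = 128;

    const int GemmM = C * Y * X;   // 1152
    const int GemmN = N * Ho * Wo; // 12544

    const int GridSize = ((GemmM + GemmMPerBlock - 1) / GemmMPerBlock) *
                         ((GemmN + GemmNPerBlock - 1) / GemmNPerBlock);

    std::printf("GemmM = %d, GemmN = %d, GridSize = %d\n", GemmM, GemmN, GridSize); // 9 * 98 = 882
    return 0;
}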
|
||||
|
||||
printf("%s: BlockSize %u, GridSize %u \n", __func__, BlockSize, GridSize);
|
||||
|
||||
constexpr auto gridwise_conv = GridwiseConvolutionBackwardDataImplicitGemm_v1r1_nchw_kcyx_nkhw<
|
||||
GridSize,
|
||||
BlockSize,
|
||||
T,
|
||||
T,
|
||||
decltype(in_nchw_desc),
|
||||
decltype(wei_kcyx_desc),
|
||||
decltype(out_nkhw_desc),
|
||||
ConvStrides,
|
||||
ConvDilations,
|
||||
LeftPads,
|
||||
RightPads,
|
||||
GemmMPerBlock,
|
||||
GemmNPerBlock,
|
||||
GemmKPerBlock,
|
||||
GemmMPerThreadSubC,
|
||||
GemmNPerThreadSubC,
|
||||
GemmMLevel0Cluster,
|
||||
GemmNLevel0Cluster,
|
||||
GemmMLevel1Cluster,
|
||||
GemmNLevel1Cluster,
|
||||
GemmKPerThreadLoop,
|
||||
GemmThreadGemmDataPerReadM,
|
||||
GemmThreadGemmDataPerReadN,
|
||||
GemmABlockCopySubLengths,
|
||||
GemmABlockCopyClusterLengths,
|
||||
GemmABlockCopyDataPerAccess,
|
||||
GemmBBlockCopySubLengths,
|
||||
GemmBBlockCopyClusterLengths,
|
||||
GemmBBlockCopyDataPerAccess,
|
||||
GemmCThreadCopyDataPerAccess>{};
|
||||
|
||||
for(index_t i = 0; i < nrepeat; ++i)
|
||||
{
|
||||
float time = launch_kernel(run_gridwise_operation<decltype(gridwise_conv),
|
||||
T* const __restrict__,
|
||||
const T* const __restrict__,
|
||||
const T* const __restrict__>,
|
||||
dim3(GridSize),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
gridwise_conv,
|
||||
const_cast<T* const __restrict__>(
|
||||
static_cast<T*>(in_nchw_device_buf.GetDeviceBuffer())),
|
||||
const_cast<const T* const __restrict__>(
|
||||
static_cast<T*>(wei_kcyx_device_buf.GetDeviceBuffer())),
|
||||
const_cast<const T* const __restrict__>(
|
||||
static_cast<T*>(out_nkhw_device_buf.GetDeviceBuffer())));
|
||||
|
||||
printf("Elapsed time : %f ms, %f TFlop/s\n",
|
||||
time,
|
||||
(float)calculate_convolution_flops(InDesc{}, WeiDesc{}, OutDesc{}) /
|
||||
(std::size_t(1000) * 1000 * 1000) / time);
|
||||
usleep(std::min(time * 1000, float(10000)));
|
||||
}
|
||||
|
||||
in_nchw_device_buf.FromDevice(in_nchw.mData.data());
|
||||
}
|
||||
@@ -0,0 +1,155 @@
|
||||
#pragma once
|
||||
#include <unistd.h>
|
||||
#include "device.hpp"
|
||||
#include "tensor.hpp"
|
||||
#include "gridwise_operation_wrapper.hpp"
|
||||
#include "gridwise_convolution_backward_data_implicit_gemm_v1r2_nchw_kcyx_nkhw_lds_double_buffer.hpp"
|
||||
|
||||
template <typename T,
|
||||
typename InDesc,
|
||||
typename WeiDesc,
|
||||
typename OutDesc,
|
||||
typename ConvStrides,
|
||||
typename ConvDilations,
|
||||
typename LeftPads,
|
||||
typename RightPads>
|
||||
void device_convolution_backward_data_implicit_gemm_v1r2_nchw_kcyx_nkhw(InDesc in_nchw_desc,
|
||||
Tensor<T>& in_nchw,
|
||||
WeiDesc wei_kcyx_desc,
|
||||
const Tensor<T>& wei_kcyx,
|
||||
OutDesc out_nkhw_desc,
|
||||
const Tensor<T>& out_nkhw,
|
||||
ConvStrides,
|
||||
ConvDilations,
|
||||
LeftPads,
|
||||
RightPads,
|
||||
std::size_t nrepeat)
|
||||
{
|
||||
using namespace ck;
|
||||
|
||||
constexpr index_t N = out_nkhw_desc.GetLengths()[0];
|
||||
constexpr index_t K = out_nkhw_desc.GetLengths()[1];
|
||||
constexpr index_t Ho = out_nkhw_desc.GetLengths()[2];
|
||||
constexpr index_t Wo = out_nkhw_desc.GetLengths()[3];
|
||||
|
||||
constexpr index_t C = wei_kcyx_desc.GetLengths()[1];
|
||||
constexpr index_t Y = wei_kcyx_desc.GetLengths()[2];
|
||||
constexpr index_t X = wei_kcyx_desc.GetLengths()[3];
|
||||
|
||||
std::size_t data_sz = sizeof(T);
|
||||
DeviceMem in_nchw_device_buf(data_sz * in_nchw.mDesc.GetElementSpace());
|
||||
DeviceMem wei_kcyx_device_buf(data_sz * wei_kcyx.mDesc.GetElementSpace());
|
||||
DeviceMem out_nkhw_device_buf(data_sz * out_nkhw.mDesc.GetElementSpace());
|
||||
|
||||
in_nchw_device_buf.ToDevice(in_nchw.mData.data());
|
||||
wei_kcyx_device_buf.ToDevice(wei_kcyx.mData.data());
|
||||
out_nkhw_device_buf.ToDevice(out_nkhw.mData.data());
|
||||
|
||||
#if 1
|
||||
// BlockSize = 256, each thread holds 64 elements
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t BPerBlock = 32;
|
||||
constexpr index_t EPerBlock = 32;
|
||||
constexpr index_t KPerBlock = 8;
|
||||
|
||||
constexpr index_t GemmMPerThreadSubC = 4;
|
||||
constexpr index_t GemmNPerThreadSubC = 4;
|
||||
constexpr index_t GemmMLevel0Cluster = 4;
|
||||
constexpr index_t GemmNLevel0Cluster = 4;
|
||||
constexpr index_t GemmMLevel1Cluster = 4;
|
||||
constexpr index_t GemmNLevel1Cluster = 4;
|
||||
constexpr index_t GemmKPerThreadLoop = 1;
|
||||
constexpr index_t GemmDataPerReadA = 4;
|
||||
constexpr index_t GemmDataPerReadB = 4;
|
||||
|
||||
using OutBlockCopySubLengths_K_B_N0 = Sequence<1, 1, 4>;
|
||||
using OutBlockCopyClusterLengths_K_B_N0 = Sequence<8, 32, 1>;
|
||||
|
||||
constexpr index_t OutBlockCopySrcDataPerRead_B = 1;
|
||||
constexpr index_t OutBlockCopyDstDataPerWrite_N0 = 4;
|
||||
|
||||
using WeiBlockCopySubLengths_K_E_C0 = Sequence<1, 4, 1>;
|
||||
using WeiBlockCopyClusterLengths_K_E_C0 = Sequence<8, 8, 4>;
|
||||
|
||||
constexpr index_t WeiBlockCopySrcDataPerRead_E = 4;
|
||||
constexpr index_t WeiBlockCopyDstDataPerWrite_C0 = 1;
|
||||
|
||||
constexpr index_t InThreadCopyDstDataPerWrite_B = 1;
|
||||
#endif
|
||||
|
||||
constexpr index_t C0 = GemmMPerThreadSubC;
|
||||
constexpr index_t N0 = GemmNPerThreadSubC;
|
||||
|
||||
constexpr index_t C1 = C / C0;
|
||||
constexpr index_t N1 = N / N0;
|
||||
|
||||
constexpr index_t E = C1 * Y * X;
|
||||
constexpr index_t B = (N1 * Ho * Wo);
|
||||
|
||||
constexpr index_t GridSize =
|
||||
((E + EPerBlock - 1) / EPerBlock) * ((B + BPerBlock - 1) / BPerBlock);
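// Worked example (illustrative only, using the same "3x3, 34x34" case from the
// driver: N = 64, C = 256, Y = X = 3, Ho = Wo = 32): with C0 = N0 = 4,
//   C1 = 256 / 4 = 64,   N1 = 64 / 4 = 16
//   E  = C1 * Y * X    = 64 * 9        = 576
//   B  = N1 * Ho * Wo  = 16 * 32 * 32  = 16384
// so GridSize = ceil(576 / 32) * ceil(16384 / 32) = 18 * 512 = 9216.
// The C0/N0 splits appear to fold one per-thread sub-tile dimension of the
// blockwise GEMM into the E and B coordinates.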
|
||||
|
||||
printf("%s: BlockSize %u, GridSize %u \n", __func__, BlockSize, GridSize);
|
||||
|
||||
constexpr auto gridwise_conv =
|
||||
GridwiseConvolutionBackwardDataImplicitGemm_v1r2_nchw_kcyx_nkhw_lds_double_buffer<
|
||||
GridSize,
|
||||
BlockSize,
|
||||
T,
|
||||
T,
|
||||
decltype(in_nchw_desc),
|
||||
decltype(wei_kcyx_desc),
|
||||
decltype(out_nkhw_desc),
|
||||
ConvStrides,
|
||||
ConvDilations,
|
||||
LeftPads,
|
||||
RightPads,
|
||||
EPerBlock,
|
||||
BPerBlock,
|
||||
KPerBlock,
|
||||
GemmMPerThreadSubC,
|
||||
GemmNPerThreadSubC,
|
||||
GemmMLevel0Cluster,
|
||||
GemmNLevel0Cluster,
|
||||
GemmMLevel1Cluster,
|
||||
GemmNLevel1Cluster,
|
||||
GemmKPerThreadLoop,
|
||||
GemmDataPerReadA,
|
||||
GemmDataPerReadB,
|
||||
OutBlockCopySubLengths_K_B_N0,
|
||||
OutBlockCopyClusterLengths_K_B_N0,
|
||||
OutBlockCopySrcDataPerRead_B,
|
||||
OutBlockCopyDstDataPerWrite_N0,
|
||||
WeiBlockCopySubLengths_K_E_C0,
|
||||
WeiBlockCopyClusterLengths_K_E_C0,
|
||||
WeiBlockCopySrcDataPerRead_E,
|
||||
WeiBlockCopyDstDataPerWrite_C0,
|
||||
InThreadCopyDstDataPerWrite_B>{};
|
||||
|
||||
for(index_t i = 0; i < nrepeat; ++i)
|
||||
{
|
||||
float time = launch_kernel(run_gridwise_operation<decltype(gridwise_conv),
|
||||
T* const __restrict__,
|
||||
const T* const __restrict__,
|
||||
const T* const __restrict__>,
|
||||
dim3(GridSize),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
gridwise_conv,
|
||||
const_cast<T* const __restrict__>(
|
||||
static_cast<T*>(in_nchw_device_buf.GetDeviceBuffer())),
|
||||
const_cast<const T* const __restrict__>(
|
||||
static_cast<T*>(wei_kcyx_device_buf.GetDeviceBuffer())),
|
||||
const_cast<const T* const __restrict__>(
|
||||
static_cast<T*>(out_nkhw_device_buf.GetDeviceBuffer())));
|
||||
|
||||
printf("Elapsed time : %f ms, %f TFlop/s\n",
|
||||
time,
|
||||
(float)calculate_convolution_flops(InDesc{}, WeiDesc{}, OutDesc{}) /
|
||||
(std::size_t(1000) * 1000 * 1000) / time);
|
||||
usleep(std::min(time * 1000, float(10000)));
|
||||
}
|
||||
|
||||
in_nchw_device_buf.FromDevice(in_nchw.mData.data());
|
||||
}
|
||||
@@ -0,0 +1,195 @@
|
||||
#pragma once
|
||||
#include <unistd.h>
|
||||
#include "device.hpp"
|
||||
#include "tensor.hpp"
|
||||
#include "gridwise_operation_wrapper.hpp"
|
||||
#include "gridwise_convolution_backward_data_implicit_gemm_v2r1_nchw_kcyx_nkhw.hpp"
|
||||
|
||||
template <typename T,
|
||||
typename InDesc,
|
||||
typename WeiDesc,
|
||||
typename OutDesc,
|
||||
typename ConvStrides,
|
||||
typename ConvDilations,
|
||||
typename LeftPads,
|
||||
typename RightPads>
|
||||
void device_convolution_backward_data_implicit_gemm_v2r1_nchw_kcyx_nkhw(InDesc in_nchw_desc,
|
||||
Tensor<T>& in_nchw,
|
||||
WeiDesc wei_kcyx_desc,
|
||||
const Tensor<T>& wei_kcyx,
|
||||
OutDesc out_nkhw_desc,
|
||||
const Tensor<T>& out_nkhw,
|
||||
ConvStrides,
|
||||
ConvDilations,
|
||||
LeftPads,
|
||||
RightPads,
|
||||
std::size_t nrepeat)
|
||||
{
|
||||
using namespace ck;
|
||||
|
||||
constexpr index_t N = out_nkhw_desc.GetLengths()[0];
|
||||
constexpr index_t K = out_nkhw_desc.GetLengths()[1];
|
||||
constexpr index_t Ho = out_nkhw_desc.GetLengths()[2];
|
||||
constexpr index_t Wo = out_nkhw_desc.GetLengths()[3];
|
||||
|
||||
constexpr index_t C = wei_kcyx_desc.GetLengths()[1];
|
||||
constexpr index_t Y = wei_kcyx_desc.GetLengths()[2];
|
||||
constexpr index_t X = wei_kcyx_desc.GetLengths()[3];
|
||||
|
||||
constexpr index_t ConvStrideH = ConvStrides{}[0];
|
||||
constexpr index_t ConvStrideW = ConvStrides{}[1];
|
||||
|
||||
constexpr index_t ConvDilationH = ConvDilations{}[0];
|
||||
constexpr index_t ConvDilationW = ConvDilations{}[1];
|
||||
|
||||
std::size_t data_sz = sizeof(T);
|
||||
DeviceMem in_nchw_device_buf(data_sz * in_nchw.mDesc.GetElementSpace());
|
||||
DeviceMem wei_kcyx_device_buf(data_sz * wei_kcyx.mDesc.GetElementSpace());
|
||||
DeviceMem out_nkhw_device_buf(data_sz * out_nkhw.mDesc.GetElementSpace());
|
||||
|
||||
in_nchw_device_buf.ToDevice(in_nchw.mData.data());
|
||||
wei_kcyx_device_buf.ToDevice(wei_kcyx.mData.data());
|
||||
out_nkhw_device_buf.ToDevice(out_nkhw.mData.data());
|
||||
|
||||
#if 1
|
||||
// BlockSize = 256, each thread holds 64 elements
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t GemmMPerBlock = 128;
|
||||
constexpr index_t GemmNPerBlock = 128;
|
||||
constexpr index_t GemmKPerBlock = 8;
|
||||
constexpr index_t GemmMPerThreadSubC = 4;
|
||||
constexpr index_t GemmNPerThreadSubC = 4;
|
||||
constexpr index_t GemmMLevel0Cluster = 4;
|
||||
constexpr index_t GemmNLevel0Cluster = 4;
|
||||
constexpr index_t GemmMLevel1Cluster = 4;
|
||||
constexpr index_t GemmNLevel1Cluster = 4;
|
||||
constexpr index_t GemmKPerThreadLoop = 1;
|
||||
constexpr index_t GemmThreadGemmDataPerReadM = 4;
|
||||
constexpr index_t GemmThreadGemmDataPerReadN = 4;
|
||||
|
||||
using GemmABlockCopySubLengths = Sequence<4, 1>; // Gemm-K, Gemm-M
|
||||
using GemmABlockCopyClusterLengths = Sequence<2, 128>; // Gemm-K, Gemm-M
|
||||
|
||||
constexpr index_t GemmABlockCopyDataPerAccess = 1; // Gemm-M
|
||||
|
||||
using GemmBBlockCopySubLengths = Sequence<4, 1>; // Gemm-K, Gemm-N
|
||||
using GemmBBlockCopyClusterLengths = Sequence<2, 128>; // Gemm-K, Gemm-N
|
||||
|
||||
constexpr index_t GemmBBlockCopyDataPerAccess = 1; // Gemm-N
|
||||
|
||||
constexpr index_t GemmCThreadCopyDataPerAccess = 1; // Gemm-N
|
||||
#elif 0
|
||||
// BlockSize = 256, each thread holds 64 elements
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t GemmMPerBlock = 128;
|
||||
constexpr index_t GemmNPerBlock = 128;
|
||||
constexpr index_t GemmKPerBlock = 8;
|
||||
constexpr index_t GemmMPerThreadSubC = 4;
|
||||
constexpr index_t GemmNPerThreadSubC = 4;
|
||||
constexpr index_t GemmMLevel0Cluster = 4;
|
||||
constexpr index_t GemmNLevel0Cluster = 4;
|
||||
constexpr index_t GemmMLevel1Cluster = 4;
|
||||
constexpr index_t GemmNLevel1Cluster = 4;
|
||||
constexpr index_t GemmKPerThreadLoop = 1;
|
||||
constexpr index_t GemmThreadGemmDataPerReadM = 4;
|
||||
constexpr index_t GemmThreadGemmDataPerReadN = 4;
|
||||
|
||||
using GemmABlockCopySubLengths = Sequence<1, 4>; // Gemm-K, Gemm-M
|
||||
using GemmABlockCopyClusterLengths = Sequence<8, 32>; // Gemm-K, Gemm-M
|
||||
|
||||
constexpr index_t GemmABlockCopyDataPerAccess = 4; // Gemm-M
|
||||
|
||||
using GemmBBlockCopySubLengths = Sequence<4, 1>; // Gemm-K, Gemm-N
|
||||
using GemmBBlockCopyClusterLengths = Sequence<2, 128>; // Gemm-K, Gemm-N
|
||||
|
||||
constexpr index_t GemmBBlockCopyDataPerAccess = 1; // Gemm-N
|
||||
|
||||
constexpr index_t GemmCThreadCopyDataPerAccess = 1; // Gemm-N
|
||||
#endif
|
||||
|
||||
// TODO: this algorithm supports any stride and dilation, but for now they are fixed to 1 for
|
||||
// simplicity
|
||||
constexpr index_t hcf_stride_dilation_h = math::hcf(ConvStrideH, ConvDilationH);
|
||||
constexpr index_t hcf_stride_dilation_w = math::hcf(ConvStrideW, ConvDilationW);
|
||||
|
||||
constexpr index_t Ytilda = ConvStrideH / hcf_stride_dilation_h; // may be wrong
|
||||
constexpr index_t Xtilda = ConvStrideW / hcf_stride_dilation_w; // may be wrong
|
||||
|
||||
constexpr index_t Ydot = math::integer_divide_ceil(Y, Ytilda);
|
||||
constexpr index_t Xdot = math::integer_divide_ceil(X, Xtilda);
|
||||
|
||||
constexpr index_t right_pad_ho = (ConvDilationH / hcf_stride_dilation_h) * (Y - Ytilda);
|
||||
constexpr index_t right_pad_wo = (ConvDilationW / hcf_stride_dilation_w) * (X - Xtilda);
|
||||
|
||||
constexpr index_t Htilda = Ho + right_pad_ho;
|
||||
constexpr index_t Wtilda = Wo + right_pad_wo;
|
||||
|
||||
constexpr index_t GemmK = K * Ydot * Xdot;
|
||||
constexpr index_t GemmM = C * Ytilda * Xtilda;
|
||||
constexpr index_t GemmN = N * Htilda * Wtilda;
|
||||
|
||||
constexpr index_t GridSize = ((GemmM + GemmMPerBlock - 1) / GemmMPerBlock) *
|
||||
((GemmN + GemmNPerBlock - 1) / GemmNPerBlock);
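// Worked example (illustrative sketch, not in the original source): for the
// unit-stride, unit-dilation "3x3, 34x34" case in conv_bwd_data_driver.cpp
// (N = 64, C = 256, K = 128, Y = X = 3, HI = WI = 34, Ho = Wo = 32):
//   hcf(1, 1) = 1            -> Ytilda = Xtilda = 1
//   Ydot = ceil(3 / 1) = 3,     Xdot = 3
//   right_pad_ho = 1 * (3 - 1) = 2 -> Htilda = 32 + 2 = 34 = HI (likewise Wtilda = 34)
//   GemmK = K * Ydot * Xdot      = 128 * 9       = 1152
//   GemmM = C * Ytilda * Xtilda  = 256 * 1 * 1   = 256
//   GemmN = N * Htilda * Wtilda  = 64 * 34 * 34  = 73984
//   GridSize = ceil(256 / 128) * ceil(73984 / 128) = 2 * 578 = 1156
// i.e. for stride 1 the "tilda" transform degenerates, which is consistent with
// a single GEMM of the form dIn[C, N*HI*WI] accumulated from
// Wei^T[C, K*Y*X] x dOut-gather[K*Y*X, N*HI*WI] (sketch only).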
|
||||
|
||||
printf("%s: BlockSize %u, GridSize %u \n", __func__, BlockSize, GridSize);
|
||||
|
||||
constexpr auto gridwise_conv = GridwiseConvolutionBackwardDataImplicitGemm_v2r1_nchw_kcyx_nkhw<
|
||||
GridSize,
|
||||
BlockSize,
|
||||
T,
|
||||
T,
|
||||
decltype(in_nchw_desc),
|
||||
decltype(wei_kcyx_desc),
|
||||
decltype(out_nkhw_desc),
|
||||
ConvStrides,
|
||||
ConvDilations,
|
||||
LeftPads,
|
||||
RightPads,
|
||||
GemmMPerBlock,
|
||||
GemmNPerBlock,
|
||||
GemmKPerBlock,
|
||||
GemmMPerThreadSubC,
|
||||
GemmNPerThreadSubC,
|
||||
GemmMLevel0Cluster,
|
||||
GemmNLevel0Cluster,
|
||||
GemmMLevel1Cluster,
|
||||
GemmNLevel1Cluster,
|
||||
GemmKPerThreadLoop,
|
||||
GemmThreadGemmDataPerReadM,
|
||||
GemmThreadGemmDataPerReadN,
|
||||
GemmABlockCopySubLengths,
|
||||
GemmABlockCopyClusterLengths,
|
||||
GemmABlockCopyDataPerAccess,
|
||||
GemmBBlockCopySubLengths,
|
||||
GemmBBlockCopyClusterLengths,
|
||||
GemmBBlockCopyDataPerAccess,
|
||||
GemmCThreadCopyDataPerAccess>{};
|
||||
|
||||
for(index_t i = 0; i < nrepeat; ++i)
|
||||
{
|
||||
float time = launch_kernel(run_gridwise_operation<decltype(gridwise_conv),
|
||||
T* const __restrict__,
|
||||
const T* const __restrict__,
|
||||
const T* const __restrict__>,
|
||||
dim3(GridSize),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
gridwise_conv,
|
||||
const_cast<T* const __restrict__>(
|
||||
static_cast<T*>(in_nchw_device_buf.GetDeviceBuffer())),
|
||||
const_cast<const T* const __restrict__>(
|
||||
static_cast<T*>(wei_kcyx_device_buf.GetDeviceBuffer())),
|
||||
const_cast<const T* const __restrict__>(
|
||||
static_cast<T*>(out_nkhw_device_buf.GetDeviceBuffer())));
|
||||
|
||||
printf("Elapsed time : %f ms, %f TFlop/s\n",
|
||||
time,
|
||||
(float)calculate_convolution_flops(InDesc{}, WeiDesc{}, OutDesc{}) /
|
||||
(std::size_t(1000) * 1000 * 1000) / time);
|
||||
usleep(std::min(time * 1000, float(10000)));
|
||||
}
|
||||
|
||||
in_nchw_device_buf.FromDevice(in_nchw.mData.data());
|
||||
}
|
||||
@@ -2,7 +2,7 @@
|
||||
#include <unistd.h>
|
||||
#include "device.hpp"
|
||||
#include "tensor.hpp"
|
||||
#include "gridwise_convolution_kernel_wrapper.hpp"
|
||||
#include "gridwise_operation_wrapper.hpp"
|
||||
#include "convolution_common.hpp"
|
||||
#include "gridwise_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer.hpp"
|
||||
|
||||
@@ -54,8 +54,8 @@ void device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw(InDesc,
|
||||
wei_kcyx_device_buf.ToDevice(wei_kcyx.mData.data());
|
||||
out_nkhw_device_buf.ToDevice(out_nkhw.mData.data());
|
||||
|
||||
#if 1
|
||||
// BlockSize = 256, each thread holds 64 elements
|
||||
#if 0
|
||||
// BlockSize = 256, EPerBlock = 8, each thread holds 64 elements
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t BPerBlock = 16;
|
||||
@@ -89,6 +89,43 @@ void device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw(InDesc,
|
||||
using WeiBlockCopySrcAccessOrder = Sequence<1, 0>; // [K, E]
|
||||
using WeiBlockCopyDstAccessOrder = Sequence<0, 1>; // [E, K]
|
||||
|
||||
constexpr index_t WeiBlockCopySrcDataPerRead_E = 4;
|
||||
constexpr index_t WeiBlockCopyDstDataPerWrite_K = 1;
|
||||
#elif 1
|
||||
// BlockSize = 256, EPerBlock = 16, each thread holds 64 elements
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t BPerBlock = 16;
|
||||
constexpr index_t KPerBlock = 128;
|
||||
constexpr index_t EPerBlock = 16;
|
||||
|
||||
constexpr index_t GemmNRepeat = 2;
|
||||
|
||||
constexpr index_t GemmMPerThreadSubC = 4;
|
||||
constexpr index_t GemmNPerThreadSubC = 4;
|
||||
constexpr index_t GemmMLevel0Cluster = 4;
|
||||
constexpr index_t GemmNLevel0Cluster = 4;
|
||||
constexpr index_t GemmMLevel1Cluster = 4;
|
||||
constexpr index_t GemmNLevel1Cluster = 4;
|
||||
constexpr index_t GemmKPerThreadLoop = 1;
|
||||
constexpr index_t GemmDataPerReadA = 4;
|
||||
constexpr index_t GemmDataPerReadB = 4;
|
||||
|
||||
using InBlockCopySubLengths_E_N1_B_N2 = Sequence<1, 2, 1, 4>;
|
||||
using InBlockCopyClusterLengths_E_N1_B_N2 = Sequence<16, 1, 16, 1>;
|
||||
using InBlockCopyThreadClusterArrangeOrder = Sequence<0, 1, 3, 2>; // [E, N1, N2, B]
|
||||
using InBlockCopySrcAccessOrder = Sequence<0, 2, 1, 3>; // [E, B, N1, N2]
|
||||
using InBlockCopyDstAccessOrder = Sequence<0, 1, 2, 3>; // [E, N1, B, N2]
|
||||
|
||||
constexpr index_t InBlockCopySrcDataPerRead_B = 1;
|
||||
constexpr index_t InBlockCopyDstDataPerWrite_N2 = 4;
|
||||
|
||||
using WeiBlockCopySubLengths_E_K = Sequence<4, 2>;
|
||||
using WeiBlockCopyClusterLengths_E_K = Sequence<4, 64>;
|
||||
using WeiBlockCopyThreadClusterArrangeOrder = Sequence<1, 0>; // [K, E]
|
||||
using WeiBlockCopySrcAccessOrder = Sequence<1, 0>; // [K, E]
|
||||
using WeiBlockCopyDstAccessOrder = Sequence<0, 1>; // [E, K]
|
||||
|
||||
constexpr index_t WeiBlockCopySrcDataPerRead_E = 4;
|
||||
constexpr index_t WeiBlockCopyDstDataPerWrite_K = 1;
|
||||
#elif 0
|
||||
@@ -221,13 +258,20 @@ void device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw(InDesc,
|
||||
|
||||
for(index_t i = 0; i < nrepeat; ++i)
|
||||
{
|
||||
float time = launch_kernel(run_gridwise_convolution_kernel<decltype(gridwise_conv), T>,
|
||||
float time = launch_kernel(run_gridwise_operation<decltype(gridwise_conv),
|
||||
const T* const __restrict__,
|
||||
const T* const __restrict__,
|
||||
T* const __restrict__>,
|
||||
dim3(GridSize),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
static_cast<T*>(in_nchw_device_buf.GetDeviceBuffer()),
|
||||
static_cast<T*>(wei_kcyx_device_buf.GetDeviceBuffer()),
|
||||
static_cast<T*>(out_nkhw_device_buf.GetDeviceBuffer()));
|
||||
gridwise_conv,
|
||||
const_cast<const T* const __restrict__>(
|
||||
static_cast<T*>(in_nchw_device_buf.GetDeviceBuffer())),
|
||||
const_cast<const T* const __restrict__>(
|
||||
static_cast<T*>(wei_kcyx_device_buf.GetDeviceBuffer())),
|
||||
const_cast<T* const __restrict__>(
|
||||
static_cast<T*>(out_nkhw_device_buf.GetDeviceBuffer())));
|
||||
|
||||
printf("Elapsed time : %f ms, %f TFlop/s\n",
|
||||
time,
|
||||
|
||||
@@ -46,7 +46,7 @@ void device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw_deprecated(InDesc,
|
||||
wei_kcyx_device_buf.ToDevice(wei_kcyx.mData.data());
|
||||
out_nkhw_device_buf.ToDevice(out_nkhw.mData.data());
|
||||
|
||||
#if 1
|
||||
#if 0
|
||||
// BlockSize = 256, blockwise-GEMM 128x128, each thread holds 64 elements
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
@@ -120,7 +120,7 @@ void device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw_deprecated(InDesc,
|
||||
|
||||
constexpr index_t WeiBlockCopySrcDataPerRead_E = 4;
|
||||
constexpr index_t WeiBlockCopyDstDataPerWrite_K = 1;
|
||||
#elif 1
|
||||
#elif 0
|
||||
// BlockSize = 256, blockwise-GEMM 64x128, each thread holds 32 elements
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
@@ -157,6 +157,42 @@ void device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw_deprecated(InDesc,
|
||||
|
||||
constexpr index_t WeiBlockCopySrcDataPerRead_E = 2;
|
||||
constexpr index_t WeiBlockCopyDstDataPerWrite_K = 1;
|
||||
#elif 1
|
||||
constexpr index_t BlockSize = 64;
|
||||
|
||||
constexpr index_t BPerBlock = 16;
|
||||
constexpr index_t KPerBlock = 32;
|
||||
constexpr index_t EPerBlock = 4;
|
||||
|
||||
constexpr index_t GemmNRepeat = 2;
|
||||
|
||||
constexpr index_t GemmMPerThreadSubC = 4;
|
||||
constexpr index_t GemmNPerThreadSubC = 4;
|
||||
constexpr index_t GemmMLevel0Cluster = 1;
|
||||
constexpr index_t GemmNLevel0Cluster = 4;
|
||||
constexpr index_t GemmMLevel1Cluster = 4;
|
||||
constexpr index_t GemmNLevel1Cluster = 4;
|
||||
constexpr index_t GemmKPerThreadLoop = 1;
|
||||
constexpr index_t GemmDataPerReadA = 4;
|
||||
constexpr index_t GemmDataPerReadB = 4;
|
||||
|
||||
using InBlockCopySubLengths_E_N1_B_N2 = Sequence<1, 2, 1, 4>;
|
||||
using InBlockCopyClusterLengths_E_N1_B_N2 = Sequence<4, 1, 16, 1>;
|
||||
using InBlockCopyThreadClusterArrangeOrder = Sequence<0, 1, 3, 2>; // [E, N1, N2, B]
|
||||
using InBlockCopySrcAccessOrder = Sequence<0, 2, 1, 3>; // [E, B, N1, N2]
|
||||
using InBlockCopyDstAccessOrder = Sequence<0, 1, 2, 3>; // [E, N1, B, N2]
|
||||
|
||||
constexpr index_t InBlockCopySrcDataPerRead_B = 1;
|
||||
constexpr index_t InBlockCopyDstDataPerWrite_N2 = 4;
|
||||
|
||||
using WeiBlockCopySubLengths_E_K = Sequence<1, 2>;
|
||||
using WeiBlockCopyClusterLengths_E_K = Sequence<4, 16>;
|
||||
using WeiBlockCopyThreadClusterArrangeOrder = Sequence<1, 0>; // [K, E]
|
||||
using WeiBlockCopySrcAccessOrder = Sequence<1, 0>; // [K, E]
|
||||
using WeiBlockCopyDstAccessOrder = Sequence<0, 1>; // [E, K]
|
||||
|
||||
constexpr index_t WeiBlockCopySrcDataPerRead_E = 1;
|
||||
constexpr index_t WeiBlockCopyDstDataPerWrite_K = 2;
|
||||
#endif
|
||||
|
||||
constexpr index_t N1 = GemmNRepeat;
|
||||
|
||||
@@ -51,6 +51,7 @@ void device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc,
|
||||
out_nkhw_device_buf.ToDevice(out_nkhw.mData.data());
|
||||
|
||||
#if 1
|
||||
// BlockSize = 256, EPerBlock = 8
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t BPerBlock = 128;
|
||||
@@ -85,7 +86,8 @@ void device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc,
|
||||
constexpr index_t WeiBlockCopyDstDataPerWrite_K = 1;
|
||||
|
||||
constexpr index_t OutThreadCopyDataPerAccess_B = 1;
|
||||
#elif 1
|
||||
#elif 0
|
||||
// BlockSize = 256, EPerBlock = 8
|
||||
// 1x1 filter, 8x8 image
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
@@ -122,6 +124,43 @@ void device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc,
|
||||
|
||||
constexpr index_t OutThreadCopyDataPerAccess_B = 4;
|
||||
#elif 0
|
||||
// BlockSize = 256, EPerBlock = 16
|
||||
// 1x1 filter, 8x8 image
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t BPerBlock = 128;
|
||||
constexpr index_t KPerBlock = 128;
|
||||
constexpr index_t EPerBlock = 16;
|
||||
|
||||
constexpr index_t GemmMPerThreadSubC = 4;
|
||||
constexpr index_t GemmNPerThreadSubC = 4;
|
||||
constexpr index_t GemmMLevel0Cluster = 4;
|
||||
constexpr index_t GemmNLevel0Cluster = 4;
|
||||
constexpr index_t GemmMLevel1Cluster = 4;
|
||||
constexpr index_t GemmNLevel1Cluster = 4;
|
||||
constexpr index_t GemmKPerThreadLoop = 1;
|
||||
constexpr index_t GemmDataPerReadA = 4;
|
||||
constexpr index_t GemmDataPerReadB = 4;
|
||||
|
||||
using InBlockCopySubLengths_E_B = Sequence<2, 4>;
|
||||
using InBlockCopyClusterLengths_E_B = Sequence<8, 32>;
|
||||
using InBlockCopyThreadClusterArrangeOrder = Sequence<0, 1>; // [E, B]
|
||||
using InBlockCopySrcAccessOrder = Sequence<0, 1>; // [E, B]
|
||||
using InBlockCopyDstAccessOrder = Sequence<0, 1>; // [E, B]
|
||||
|
||||
constexpr index_t InBlockCopyDataPerAccess_B = 4;
|
||||
|
||||
using WeiBlockCopySubLengths_E_K = Sequence<4, 2>;
|
||||
using WeiBlockCopyClusterLengths_E_K = Sequence<4, 64>;
|
||||
using WeiBlockCopyThreadClusterArrangeOrder = Sequence<1, 0>; // [K, E]
|
||||
using WeiBlockCopySrcAccessOrder = Sequence<1, 0>; // [K, E]
|
||||
using WeiBlockCopyDstAccessOrder = Sequence<0, 1>; // [E, K]
|
||||
|
||||
constexpr index_t WeiBlockCopySrcDataPerRead_E = 4;
|
||||
constexpr index_t WeiBlockCopyDstDataPerWrite_K = 1;
|
||||
|
||||
constexpr index_t OutThreadCopyDataPerAccess_B = 4;
|
||||
#elif 1
|
||||
// 1x1 filter, 14x14 image
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
@@ -167,47 +206,43 @@ void device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc,
|
||||
printf("%s: BlockSize %u, GridSize %u \n", __func__, BlockSize, GridSize);
|
||||
|
||||
constexpr auto gridwise_conv =
|
||||
#if 0
|
||||
GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw_padded
|
||||
#else
|
||||
GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw_lds_double_buffer
|
||||
#endif
|
||||
<GridSize,
|
||||
BlockSize,
|
||||
T,
|
||||
decltype(in_nchw_desc),
|
||||
decltype(wei_kcyx_desc),
|
||||
decltype(out_nkhw_desc),
|
||||
ConvStrides,
|
||||
ConvDilations,
|
||||
LeftPads,
|
||||
RightPads,
|
||||
BPerBlock,
|
||||
KPerBlock,
|
||||
EPerBlock,
|
||||
GemmMPerThreadSubC,
|
||||
GemmNPerThreadSubC,
|
||||
GemmMLevel0Cluster,
|
||||
GemmNLevel0Cluster,
|
||||
GemmMLevel1Cluster,
|
||||
GemmNLevel1Cluster,
|
||||
GemmKPerThreadLoop,
|
||||
GemmDataPerReadA,
|
||||
GemmDataPerReadB,
|
||||
InBlockCopySubLengths_E_B,
|
||||
InBlockCopyClusterLengths_E_B,
|
||||
InBlockCopyThreadClusterArrangeOrder,
|
||||
InBlockCopySrcAccessOrder,
|
||||
InBlockCopyDstAccessOrder,
|
||||
InBlockCopyDataPerAccess_B,
|
||||
WeiBlockCopySubLengths_E_K,
|
||||
WeiBlockCopyClusterLengths_E_K,
|
||||
WeiBlockCopyThreadClusterArrangeOrder,
|
||||
WeiBlockCopySrcAccessOrder,
|
||||
WeiBlockCopyDstAccessOrder,
|
||||
WeiBlockCopySrcDataPerRead_E,
|
||||
WeiBlockCopyDstDataPerWrite_K,
|
||||
OutThreadCopyDataPerAccess_B>{};
|
||||
GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw_lds_double_buffer<
|
||||
GridSize,
|
||||
BlockSize,
|
||||
T,
|
||||
decltype(in_nchw_desc),
|
||||
decltype(wei_kcyx_desc),
|
||||
decltype(out_nkhw_desc),
|
||||
ConvStrides,
|
||||
ConvDilations,
|
||||
LeftPads,
|
||||
RightPads,
|
||||
BPerBlock,
|
||||
KPerBlock,
|
||||
EPerBlock,
|
||||
GemmMPerThreadSubC,
|
||||
GemmNPerThreadSubC,
|
||||
GemmMLevel0Cluster,
|
||||
GemmNLevel0Cluster,
|
||||
GemmMLevel1Cluster,
|
||||
GemmNLevel1Cluster,
|
||||
GemmKPerThreadLoop,
|
||||
GemmDataPerReadA,
|
||||
GemmDataPerReadB,
|
||||
InBlockCopySubLengths_E_B,
|
||||
InBlockCopyClusterLengths_E_B,
|
||||
InBlockCopyThreadClusterArrangeOrder,
|
||||
InBlockCopySrcAccessOrder,
|
||||
InBlockCopyDstAccessOrder,
|
||||
InBlockCopyDataPerAccess_B,
|
||||
WeiBlockCopySubLengths_E_K,
|
||||
WeiBlockCopyClusterLengths_E_K,
|
||||
WeiBlockCopyThreadClusterArrangeOrder,
|
||||
WeiBlockCopySrcAccessOrder,
|
||||
WeiBlockCopyDstAccessOrder,
|
||||
WeiBlockCopySrcDataPerRead_E,
|
||||
WeiBlockCopyDstDataPerWrite_K,
|
||||
OutThreadCopyDataPerAccess_B>{};
|
||||
|
||||
for(index_t i = 0; i < nrepeat; ++i)
|
||||
{
|
||||
|
||||
28
driver/include/device_tensor.hpp
Normal file
@@ -0,0 +1,28 @@
|
||||
#pragma once
|
||||
#include "tensor.hpp"
|
||||
#include "common_header.hpp"
|
||||
#include "ConstantTensorDescriptor_deprecated.hpp"
|
||||
#include "tensor_descriptor.hpp"
|
||||
|
||||
template <typename ConstTensorDesc, std::size_t... Is>
|
||||
auto make_TensorDescriptor_impl(ConstTensorDesc, std::integer_sequence<std::size_t, Is...>)
|
||||
{
|
||||
std::initializer_list<std::size_t> lengths = {ConstTensorDesc::GetLengths()[Is]...};
|
||||
std::initializer_list<std::size_t> strides = {ConstTensorDesc::GetStrides()[Is]...};
|
||||
|
||||
return TensorDescriptor(lengths, strides);
|
||||
}
|
||||
|
||||
template <typename ConstTensorDesc>
|
||||
auto make_TensorDescriptor(ConstTensorDesc)
|
||||
{
|
||||
return make_TensorDescriptor_impl(
|
||||
ConstTensorDesc{},
|
||||
std::make_integer_sequence<std::size_t, ConstTensorDesc::GetNumOfDimension()>{});
|
||||
}
|
||||
|
||||
template <typename ConstTensorDesc>
|
||||
void ostream_ConstantTensorDescriptor(ConstTensorDesc, std::ostream& os = std::cout)
|
||||
{
|
||||
ostream_TensorDescriptor(make_TensorDescriptor(ConstTensorDesc{}), os);
|
||||
}
|
||||
71
driver/include/host_col2im.hpp
Normal file
@@ -0,0 +1,71 @@
|
||||
#pragma once
|
||||
#include "tensor.hpp"
|
||||
|
||||
template <typename T,
|
||||
typename FilterSizes,
|
||||
typename OutputSizes,
|
||||
typename ConvStrides,
|
||||
typename ConvDilations,
|
||||
typename LeftPads,
|
||||
typename RightPads>
|
||||
void host_col2im(const Tensor<T>& in_eb,
|
||||
Tensor<T>& in_nchw,
|
||||
FilterSizes,
|
||||
OutputSizes,
|
||||
ConvStrides,
|
||||
ConvDilations,
|
||||
LeftPads,
|
||||
RightPads)
|
||||
{
|
||||
using namespace ck;
|
||||
|
||||
int N = in_nchw.mDesc.GetLengths()[0];
|
||||
int C = in_nchw.mDesc.GetLengths()[1];
|
||||
int HI = in_nchw.mDesc.GetLengths()[2];
|
||||
int WI = in_nchw.mDesc.GetLengths()[3];
|
||||
|
||||
int Y = FilterSizes{}[0];
|
||||
int X = FilterSizes{}[1];
|
||||
|
||||
int HO = OutputSizes{}[0];
|
||||
int WO = OutputSizes{}[1];
|
||||
|
||||
auto f = [&](auto n, auto c, auto hi, auto wi) {
|
||||
double v = 0;
|
||||
|
||||
for(int y = 0; y < Y; ++y)
|
||||
{
|
||||
int h_tmp = hi + LeftPads{}[0] - y * ConvDilations{}[0];
|
||||
|
||||
if(h_tmp >= 0 && h_tmp < HI && h_tmp % ConvStrides{}[0] == 0)
|
||||
{
|
||||
int ho = h_tmp / ConvStrides{}[0];
|
||||
|
||||
for(int x = 0; x < X; ++x)
|
||||
{
|
||||
int w_tmp = wi + LeftPads{}[1] - x * ConvDilations{}[1];
|
||||
|
||||
if(w_tmp >= 0 && w_tmp < WI && w_tmp % ConvStrides{}[1] == 0)
|
||||
{
|
||||
int wo = w_tmp / ConvStrides{}[1];
|
||||
|
||||
int e = c * (Y * X) + y * X + x;
|
||||
int b = n * (HO * WO) + ho * WO + wo;
|
||||
|
||||
v += in_eb(e, b);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
in_nchw(n, c, hi, wi) = v;
|
||||
};
|
||||
|
||||
auto f_par = make_ParallelTensorFunctor(f,
|
||||
in_nchw.mDesc.GetLengths()[0],
|
||||
in_nchw.mDesc.GetLengths()[1],
|
||||
in_nchw.mDesc.GetLengths()[2],
|
||||
in_nchw.mDesc.GetLengths()[3]);
|
||||
|
||||
f_par(std::thread::hardware_concurrency());
|
||||
}
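// Worked trace (illustrative, using the first configuration in col2im_driver.cpp:
// Y = X = 4, LeftPads = {1, 1}, unit strides/dilations, HI = WI = 8, HO = WO = 8):
// the pixel in_nchw(n, c, 0, 0) only finds valid (y, x) in {0, 1} x {0, 1}, so it
// accumulates exactly four column entries,
//   (y, x) = (0, 0) -> (ho, wo) = (1, 1): col_eb(c*16 + 0, n*64 + 9)
//   (y, x) = (0, 1) -> (ho, wo) = (1, 0): col_eb(c*16 + 1, n*64 + 8)
//   (y, x) = (1, 0) -> (ho, wo) = (0, 1): col_eb(c*16 + 4, n*64 + 1)
//   (y, x) = (1, 1) -> (ho, wo) = (0, 0): col_eb(c*16 + 5, n*64 + 0)
// with e = c*(Y*X) + y*X + x and b = n*(HO*WO) + ho*WO + wo as in the lambda above.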
|
||||
@@ -1,49 +1,5 @@
|
||||
#pragma once
|
||||
#include "tensor.hpp"
|
||||
#include "common_header.hpp"
|
||||
#include "ConstantTensorDescriptor_deprecated.hpp"
|
||||
|
||||
// this is ugly, only for 4d
|
||||
template <class TConstTensorDesc>
|
||||
void ostream_ConstantTensorDescriptor(TConstTensorDesc, std::ostream& os = std::cout)
|
||||
{
|
||||
using namespace ck;
|
||||
|
||||
static_assert(TConstTensorDesc::nDim == 4, "nDim is not 4");
|
||||
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
constexpr auto I3 = Number<3>{};
|
||||
constexpr auto desc = TConstTensorDesc{};
|
||||
|
||||
os << "Lengths: {" << desc.GetLength(I0) << ", " << desc.GetLength(I1) << ", "
|
||||
<< desc.GetLength(I2) << ", " << desc.GetLength(I3) << "}, "
|
||||
<< "Strides: {" << desc.GetStride(I0) << ", " << desc.GetStride(I1) << ", "
|
||||
<< desc.GetStride(I2) << ", " << desc.GetStride(I3) << "}" << std::endl;
|
||||
}
|
||||
|
||||
// this is ugly, only for 4d
|
||||
template <class TConstTensorDesc>
|
||||
auto make_TensorDescriptor(TConstTensorDesc)
|
||||
{
|
||||
using namespace ck;
|
||||
|
||||
static_assert(TConstTensorDesc::nDim == 4, "nDim is not 4");
|
||||
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
constexpr auto I3 = Number<3>{};
|
||||
constexpr auto desc = TConstTensorDesc{};
|
||||
|
||||
std::initializer_list<index_t> lengths = {
|
||||
desc.GetLength(I0), desc.GetLength(I1), desc.GetLength(I2), desc.GetLength(I3)};
|
||||
std::initializer_list<index_t> strides = {
|
||||
desc.GetStride(I0), desc.GetStride(I1), desc.GetStride(I2), desc.GetStride(I3)};
|
||||
|
||||
return TensorDescriptor(lengths, strides);
|
||||
}
|
||||
|
||||
template <class TIn,
|
||||
class TWei,
|
||||
@@ -331,25 +287,3 @@ void host_winograd_3x3_convolution(const Tensor<TIn>& in_nchw,
|
||||
make_ParallelTensorFunctor(f_out_hold, N, K, HTile, WTile)(num_thread);
|
||||
make_ParallelTensorFunctor(f_out, N, K, HTile, WTile)(num_thread);
|
||||
}
|
||||
|
||||
template <class T>
|
||||
void check_error(const Tensor<T>& ref, const Tensor<T>& result)
|
||||
{
|
||||
float error = 0;
|
||||
float max_diff = -1;
|
||||
float ref_value = 0, result_value = 0;
|
||||
for(int i = 0; i < ref.mData.size(); ++i)
|
||||
{
|
||||
error += std::abs(double(ref.mData[i]) - double(result.mData[i]));
|
||||
float diff = std::abs(double(ref.mData[i]) - double(result.mData[i]));
|
||||
if(max_diff < diff)
|
||||
{
|
||||
max_diff = diff;
|
||||
ref_value = ref.mData[i];
|
||||
result_value = result.mData[i];
|
||||
}
|
||||
}
|
||||
|
||||
std::cout << "error: " << error << std::endl;
|
||||
std::cout << "max_diff: " << max_diff << ", " << ref_value << ", " << result_value << std::endl;
|
||||
}
|
||||
|
||||
77
driver/include/host_conv_bwd_data.hpp
Normal file
@@ -0,0 +1,77 @@
|
||||
#pragma once
|
||||
#include "tensor.hpp"
|
||||
|
||||
template <typename TIn,
|
||||
typename TWei,
|
||||
typename TOut,
|
||||
typename ConvStrides,
|
||||
typename ConvDilations,
|
||||
typename LeftPads,
|
||||
typename RightPads>
|
||||
void host_direct_convolution_backward_data(Tensor<TIn>& in_nchw,
|
||||
const Tensor<TWei>& wei_kcyx,
|
||||
const Tensor<TOut>& out_nkhw,
|
||||
ConvStrides,
|
||||
ConvDilations,
|
||||
LeftPads,
|
||||
RightPads)
|
||||
{
|
||||
using namespace ck;
|
||||
|
||||
int N = in_nchw.mDesc.GetLengths()[0];
|
||||
int C = in_nchw.mDesc.GetLengths()[1];
|
||||
int HI = in_nchw.mDesc.GetLengths()[2];
|
||||
int WI = in_nchw.mDesc.GetLengths()[3];
|
||||
|
||||
std::size_t K = wei_kcyx.mDesc.GetLengths()[0];
|
||||
std::size_t Y = wei_kcyx.mDesc.GetLengths()[2];
|
||||
std::size_t X = wei_kcyx.mDesc.GetLengths()[3];
|
||||
|
||||
std::size_t HO = out_nkhw.mDesc.GetLengths()[2];
|
||||
std::size_t WO = out_nkhw.mDesc.GetLengths()[3];
|
||||
|
||||
auto f = [&](auto n, auto c, auto hi, auto wi) {
|
||||
double v = 0;
|
||||
|
||||
for(int y = 0; y < Y; ++y)
|
||||
{
|
||||
int h_tmp = hi + LeftPads{}[0] - y * ConvDilations{}[0];
|
||||
|
||||
if(h_tmp % ConvStrides{}[0] == 0)
|
||||
{
|
||||
int ho = h_tmp / ConvStrides{}[0];
|
||||
|
||||
if(ho >= 0 && ho < HO)
|
||||
{
|
||||
for(int x = 0; x < X; ++x)
|
||||
{
|
||||
int w_tmp = wi + LeftPads{}[1] - x * ConvDilations{}[1];
|
||||
|
||||
if(w_tmp % ConvStrides{}[1] == 0)
|
||||
{
|
||||
int wo = w_tmp / ConvStrides{}[1];
|
||||
|
||||
if(wo >= 0 && wo < WO)
|
||||
{
|
||||
for(int k = 0; k < K; ++k)
|
||||
{
|
||||
v += out_nkhw(n, k, ho, wo) * wei_kcyx(k, c, y, x);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
in_nchw(n, c, hi, wi) = v;
|
||||
};
|
||||
|
||||
auto f_par = make_ParallelTensorFunctor(f,
|
||||
in_nchw.mDesc.GetLengths()[0],
|
||||
in_nchw.mDesc.GetLengths()[1],
|
||||
in_nchw.mDesc.GetLengths()[2],
|
||||
in_nchw.mDesc.GetLengths()[3]);
|
||||
|
||||
f_par(std::thread::hardware_concurrency());
|
||||
}
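// Note (reader aid, not part of the original source): the index arithmetic above
// is the forward map hi = ho * stride_h - pad_h + y * dilation_h solved for ho,
// so each input-gradient pixel in_nchw(n, c, hi, wi) sums
// out_nkhw(n, k, ho, wo) * wei_kcyx(k, c, y, x) over every k and every
// (y, ho) / (x, wo) pair for which that inversion lands on an integer ho in
// [0, HO) (and likewise for wo).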
|
||||
@@ -68,10 +68,12 @@ auto construct_f_unpack_args(F, T args)
|
||||
struct TensorDescriptor
|
||||
{
|
||||
TensorDescriptor() = delete;
|
||||
TensorDescriptor(std::initializer_list<std::size_t> lens);
|
||||
TensorDescriptor(std::initializer_list<std::size_t> lens,
|
||||
std::initializer_list<std::size_t> strides);
|
||||
TensorDescriptor(std::vector<std::size_t> lens, std::vector<std::size_t> strides);
|
||||
|
||||
template <typename X>
|
||||
TensorDescriptor(std::vector<X> lens);
|
||||
|
||||
template <typename X, typename Y>
|
||||
TensorDescriptor(std::vector<X> lens, std::vector<Y> strides);
|
||||
|
||||
void CalculateStrides();
|
||||
|
||||
@@ -269,4 +271,39 @@ struct Tensor
|
||||
std::vector<T> mData;
|
||||
};
|
||||
|
||||
void ostream_TensorDescriptor(const TensorDescriptor& desc, std::ostream& os = std::cout)
|
||||
{
|
||||
os << "dim " << desc.GetNumOfDimension() << ", ";
|
||||
|
||||
os << "lengths {";
|
||||
LogRange(os, desc.GetLengths(), ", ");
|
||||
os << "}, ";
|
||||
|
||||
os << "strides {";
|
||||
LogRange(os, desc.GetStrides(), ", ");
|
||||
os << "}" << std::endl;
|
||||
}
|
||||
|
||||
template <class T>
|
||||
void check_error(const Tensor<T>& ref, const Tensor<T>& result)
|
||||
{
|
||||
float error = 0;
|
||||
float max_diff = -1;
|
||||
float ref_value = 0, result_value = 0;
|
||||
for(int i = 0; i < ref.mData.size(); ++i)
|
||||
{
|
||||
error += std::abs(double(ref.mData[i]) - double(result.mData[i]));
|
||||
float diff = std::abs(double(ref.mData[i]) - double(result.mData[i]));
|
||||
if(max_diff < diff)
|
||||
{
|
||||
max_diff = diff;
|
||||
ref_value = ref.mData[i];
|
||||
result_value = result.mData[i];
|
||||
}
|
||||
}
|
||||
|
||||
std::cout << "error: " << error << std::endl;
|
||||
std::cout << "max_diff: " << max_diff << ", " << ref_value << ", " << result_value << std::endl;
|
||||
}
|
||||
|
||||
#endif
|
||||
|
||||
57
driver/include/tensor_generator.hpp
Normal file
@@ -0,0 +1,57 @@
|
||||
#ifndef TENSOR_GENERATOR_HPP
|
||||
#define TENSOR_GENERATOR_HPP
|
||||
|
||||
#include "config.hpp"
|
||||
|
||||
struct GeneratorTensor_1
|
||||
{
|
||||
int value = 1;
|
||||
|
||||
template <class... Is>
|
||||
double operator()(Is... is)
|
||||
{
|
||||
return value;
|
||||
}
|
||||
};
|
||||
|
||||
struct GeneratorTensor_2
|
||||
{
|
||||
int min_value = 0;
|
||||
int max_value = 1;
|
||||
|
||||
template <class... Is>
|
||||
double operator()(Is...)
|
||||
{
|
||||
return (std::rand() % (max_value - min_value)) + min_value;
|
||||
}
|
||||
};
|
||||
|
||||
struct GeneratorTensor_3
|
||||
{
|
||||
template <class... Is>
|
||||
double operator()(Is... is)
|
||||
{
|
||||
std::array<ck::index_t, sizeof...(Is)> dims = {{static_cast<ck::index_t>(is)...}};
|
||||
|
||||
auto f_acc = [](auto a, auto b) { return 10 * a + b; };
|
||||
|
||||
return std::accumulate(dims.begin(), dims.end(), ck::index_t(0), f_acc);
|
||||
}
|
||||
};
|
||||
|
||||
struct GeneratorTensor_Checkboard
|
||||
{
|
||||
template <class... Ts>
|
||||
double operator()(Ts... Xs) const
|
||||
{
|
||||
std::array<ck::index_t, sizeof...(Ts)> dims = {{Xs...}};
|
||||
return std::accumulate(dims.begin(),
|
||||
dims.end(),
|
||||
true,
|
||||
[](bool init, ck::index_t x) -> int { return init != (x % 2); })
|
||||
? 1
|
||||
: -1;
|
||||
}
|
||||
};
|
||||
|
||||
#endif
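// Usage sketch (mirroring how col2im_driver.cpp calls these generators): each
// generator is a functor that takes an element's multi-index and returns its
// value, so filling a tensor with random integers looks like
//
//   Tensor<float> col_eb(make_TensorDescriptor(col_eb_desc));
//   col_eb.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread);
//
// GeneratorTensor_2{min, max} draws (std::rand() % (max - min)) + min, so the
// upper bound is exclusive (here the values fall in [-5, 4]).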
|
||||
385
driver/src/col2im_driver.cpp
Normal file
@@ -0,0 +1,385 @@
|
||||
#include <iostream>
|
||||
#include <numeric>
|
||||
#include <initializer_list>
|
||||
#include <cstdlib>
|
||||
#include <stdlib.h>
|
||||
#include "config.hpp"
|
||||
#include "tensor_descriptor.hpp"
|
||||
#include "tensor_descriptor_helper.hpp"
|
||||
#include "print_array.hpp"
|
||||
#include "print_sequence.hpp"
|
||||
#include "device.hpp"
|
||||
#include "tensor_generator.hpp"
|
||||
#include "device_tensor.hpp"
|
||||
#include "conv_common.hpp"
|
||||
#include "host_col2im.hpp"
|
||||
#include "device_col2im_eb_nchw.hpp"
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
using namespace ck;
|
||||
|
||||
#if 1
|
||||
constexpr index_t N = 2;
|
||||
constexpr index_t C = 8;
|
||||
constexpr index_t HI = 8;
|
||||
constexpr index_t WI = 8;
|
||||
constexpr index_t K = 128;
|
||||
constexpr index_t Y = 4;
|
||||
constexpr index_t X = 4;
|
||||
|
||||
using ConvStrides = Sequence<1, 1>;
|
||||
using ConvDilations = Sequence<1, 1>;
|
||||
|
||||
using LeftPads = Sequence<1, 1>;
|
||||
using RightPads = Sequence<2, 2>;
|
||||
#elif 0
|
||||
// 3x3, 34x34
|
||||
constexpr index_t N = 64;
|
||||
constexpr index_t C = 256;
|
||||
constexpr index_t HI = 34;
|
||||
constexpr index_t WI = 34;
|
||||
constexpr index_t K = 128;
|
||||
constexpr index_t Y = 3;
|
||||
constexpr index_t X = 3;
|
||||
|
||||
using ConvStrides = Sequence<1, 1>;
|
||||
using ConvDilations = Sequence<1, 1>;
|
||||
|
||||
using LeftPads = Sequence<0, 0>;
|
||||
using RightPads = Sequence<0, 0>;
|
||||
#elif 0
|
||||
// 1x1 filter, 8x8 image
|
||||
// cudnn@V100 68%, ck@V100 72%, ck@P100 52%, ck@VII 42%
|
||||
constexpr index_t N = 64;
|
||||
constexpr index_t C = 1536;
|
||||
constexpr index_t HI = 8;
|
||||
constexpr index_t WI = 8;
|
||||
constexpr index_t K = 256;
|
||||
constexpr index_t Y = 1;
|
||||
constexpr index_t X = 1;
|
||||
|
||||
using ConvStrides = Sequence<1, 1>;
|
||||
using ConvDilations = Sequence<1, 1>;
|
||||
|
||||
using LeftPads = Sequence<0, 0>;
|
||||
using RightPads = Sequence<0, 0>;
|
||||
#elif 0
|
||||
// 1x1 filter, 8x8 image
|
||||
// cudnn@V100 77%, ck@V100 76%, ck@P100 79%, ck@VII 51%
|
||||
constexpr index_t N = 128;
|
||||
constexpr index_t C = 2048;
|
||||
constexpr index_t HI = 8;
|
||||
constexpr index_t WI = 8;
|
||||
constexpr index_t K = 384;
|
||||
constexpr index_t Y = 1;
|
||||
constexpr index_t X = 1;
|
||||
|
||||
using ConvStrides = Sequence<1, 1>;
|
||||
using ConvDilations = Sequence<1, 1>;
|
||||
|
||||
using LeftPads = Sequence<0, 0>;
|
||||
using RightPads = Sequence<0, 0>;
|
||||
#elif 0
|
||||
// 1x1 filter, 7x7 image
|
||||
// cudnn@V100 82%, ck@V100 76%, ck@P100 67%, ck@VII 64%
|
||||
constexpr index_t N = 128;
|
||||
constexpr index_t C = 832;
|
||||
constexpr index_t HI = 7;
|
||||
constexpr index_t WI = 7;
|
||||
constexpr index_t K = 384;
|
||||
constexpr index_t Y = 1;
|
||||
constexpr index_t X = 1;
|
||||
|
||||
using ConvStrides = Sequence<1, 1>;
|
||||
using ConvDilations = Sequence<1, 1>;
|
||||
|
||||
using LeftPads = Sequence<0, 0>;
|
||||
using RightPads = Sequence<0, 0>;
|
||||
#elif 0
|
||||
// 1x1 filter, 8x8 image
|
||||
// cudnn@V100 83%, ck@V100 75%, ck@P100 78%, ck@VII 65%
|
||||
constexpr index_t N = 128;
|
||||
constexpr index_t C = 1280;
|
||||
constexpr index_t HI = 8;
|
||||
constexpr index_t WI = 8;
|
||||
constexpr index_t K = 384;
|
||||
constexpr index_t Y = 1;
|
||||
constexpr index_t X = 1;
|
||||
|
||||
using ConvStrides = Sequence<1, 1>;
|
||||
using ConvDilations = Sequence<1, 1>;
|
||||
|
||||
using LeftPads = Sequence<0, 0>;
|
||||
using RightPads = Sequence<0, 0>;
|
||||
#elif 0
|
||||
// 1x1 filter, 14x14 image
|
||||
// cudnn@V100 62%, ck@V100 68%, ck@P100 70%, ck@VII 50%
|
||||
constexpr index_t N = 128;
|
||||
constexpr index_t C = 512;
|
||||
constexpr index_t HI = 14;
|
||||
constexpr index_t WI = 14;
|
||||
constexpr index_t K = 128;
|
||||
constexpr index_t Y = 1;
|
||||
constexpr index_t X = 1;
|
||||
|
||||
using ConvStrides = Sequence<1, 1>;
|
||||
using ConvDilations = Sequence<1, 1>;
|
||||
|
||||
using LeftPads = Sequence<0, 0>;
|
||||
using RightPads = Sequence<0, 0>;
|
||||
#elif 0
|
||||
// 1x1 filter, 8x8 image
|
||||
// cudnn@V100 74%, ck@V100 57%, ck@P100 78%, ck@VII 61%
|
||||
constexpr index_t N = 64;
|
||||
constexpr index_t C = 1536;
|
||||
constexpr index_t HI = 8;
|
||||
constexpr index_t WI = 8;
|
||||
constexpr index_t K = 384;
|
||||
constexpr index_t Y = 1;
|
||||
constexpr index_t X = 1;
|
||||
|
||||
using ConvStrides = Sequence<1, 1>;
|
||||
using ConvDilations = Sequence<1, 1>;
|
||||
|
||||
using LeftPads = Sequence<0, 0>;
|
||||
using RightPads = Sequence<0, 0>;
|
||||
#elif 0
|
||||
// 1x1 filter, 28x28 image
|
||||
// cudnn@V100 86%, ck@V100 84%, ck@P100 80%, ck@VII 69%
|
||||
constexpr index_t N = 128;
|
||||
constexpr index_t C = 256;
|
||||
constexpr index_t HI = 28;
|
||||
constexpr index_t WI = 28;
|
||||
constexpr index_t K = 128;
|
||||
constexpr index_t Y = 1;
|
||||
constexpr index_t X = 1;
|
||||
|
||||
using ConvStrides = Sequence<1, 1>;
|
||||
using ConvDilations = Sequence<1, 1>;
|
||||
|
||||
using LeftPads = Sequence<0, 0>;
|
||||
using RightPads = Sequence<0, 0>;
|
||||
#elif 0
|
||||
// 1x1 filter, 7x7 image
|
||||
// cudnn@V100 71%, ck@V100 55%, ck@P100 70%, ck@VII 62%
|
||||
constexpr index_t N = 128;
|
||||
constexpr index_t C = 832;
|
||||
constexpr index_t HI = 7;
|
||||
constexpr index_t WI = 7;
|
||||
constexpr index_t K = 256;
|
||||
constexpr index_t Y = 1;
|
||||
constexpr index_t X = 1;
|
||||
|
||||
using ConvStrides = Sequence<1, 1>;
|
||||
using ConvDilations = Sequence<1, 1>;
|
||||
|
||||
using LeftPads = Sequence<0, 0>;
|
||||
using RightPads = Sequence<0, 0>;
|
||||
#elif 0
|
||||
// 1x1 filter, 17x17 input
|
||||
// cudnn@V100 81%, ck@V100 76%, ck@P100 70%, ck@VII 76%
|
||||
constexpr index_t N = 128;
|
||||
constexpr index_t C = 768;
|
||||
constexpr index_t HI = 17;
|
||||
constexpr index_t WI = 17;
|
||||
constexpr index_t K = 128;
|
||||
constexpr index_t Y = 1;
|
||||
constexpr index_t X = 1;
|
||||
|
||||
using ConvStrides = Sequence<1, 1>;
|
||||
using ConvDilations = Sequence<1, 1>;
|
||||
|
||||
using LeftPads = Sequence<0, 0>;
|
||||
using RightPads = Sequence<0, 0>;
|
||||
#elif 0
|
||||
// 1x1 filter, 14x14 image
|
||||
// cudnn@V100 73%, ck@V100 71%, ck@P100 70%, ck@VII 64%
|
||||
constexpr index_t N = 128;
|
||||
constexpr index_t C = 528;
|
||||
constexpr index_t HI = 14;
|
||||
constexpr index_t WI = 14;
|
||||
constexpr index_t K = 128;
|
||||
constexpr index_t Y = 1;
|
||||
constexpr index_t X = 1;
|
||||
|
||||
using ConvStrides = Sequence<1, 1>;
|
||||
using ConvDilations = Sequence<1, 1>;
|
||||
|
||||
using LeftPads = Sequence<0, 0>;
|
||||
using RightPads = Sequence<0, 0>;
|
||||
#elif 0
|
||||
// 1x1 filter, 14x14 image
|
||||
// cudnn@V100 73%, ck@V100 72%, ck@P100 79%, ck@VII 75%
|
||||
constexpr index_t N = 128;
|
||||
constexpr index_t C = 528;
|
||||
constexpr index_t HI = 14;
|
||||
constexpr index_t WI = 14;
|
||||
constexpr index_t K = 256;
|
||||
constexpr index_t Y = 1;
|
||||
constexpr index_t X = 1;
|
||||
|
||||
using ConvStrides = Sequence<1, 1>;
|
||||
using ConvDilations = Sequence<1, 1>;
|
||||
|
||||
using LeftPads = Sequence<0, 0>;
|
||||
using RightPads = Sequence<0, 0>;
|
||||
#elif 0
|
||||
// 1x1 filter, 7x7 image
|
||||
// cudnn@V100 49%, ck@V100 50%, ck@P100 61%, ck@VII 52%
|
||||
constexpr index_t N = 128;
|
||||
constexpr index_t C = 832;
|
||||
constexpr index_t HI = 7;
|
||||
constexpr index_t WI = 7;
|
||||
constexpr index_t K = 128;
|
||||
constexpr index_t Y = 1;
|
||||
constexpr index_t X = 1;
|
||||
|
||||
using ConvStrides = Sequence<1, 1>;
|
||||
using ConvDilations = Sequence<1, 1>;
|
||||
|
||||
using LeftPads = Sequence<0, 0>;
|
||||
using RightPads = Sequence<0, 0>;
|
||||
#elif 0
|
||||
// 3x3 filter, 2x2 stride, 35x35 input, 17x17 output
|
||||
// cudnn@V100 90%, ck@V100 93%, ck@P100 83%, ck@VII 81%
|
||||
constexpr index_t N = 128;
|
||||
constexpr index_t C = 288;
|
||||
constexpr index_t HI = 35;
|
||||
constexpr index_t WI = 35;
|
||||
constexpr index_t K = 384;
|
||||
constexpr index_t Y = 3;
|
||||
constexpr index_t X = 3;
|
||||
|
||||
using ConvStrides = Sequence<2, 2>;
|
||||
using ConvDilations = Sequence<1, 1>;
|
||||
|
||||
using LeftPads = Sequence<0, 0>;
|
||||
using RightPads = Sequence<0, 0>;
|
||||
#elif 0
|
||||
// 5x5 filter, 2x2 pad, 7x7 input
|
||||
constexpr index_t N = 128;
|
||||
constexpr index_t C = 48;
|
||||
constexpr index_t HI = 7;
|
||||
constexpr index_t WI = 7;
|
||||
constexpr index_t K = 128;
|
||||
constexpr index_t Y = 5;
|
||||
constexpr index_t X = 5;
|
||||
|
||||
using ConvStrides = Sequence<1, 1>;
|
||||
using ConvDilations = Sequence<1, 1>;
|
||||
|
||||
using LeftPads = Sequence<2, 2>;
|
||||
using RightPads = Sequence<2, 2>;
|
||||
#elif 0
|
||||
// 7x1 filter, 3x0 pad, 17x17 input
|
||||
constexpr index_t N = 128;
|
||||
constexpr index_t C = 128;
|
||||
constexpr index_t HI = 17;
|
||||
constexpr index_t WI = 17;
|
||||
constexpr index_t K = 128;
|
||||
constexpr index_t Y = 7;
|
||||
constexpr index_t X = 1;
|
||||
|
||||
using ConvStrides = Sequence<1, 1>;
|
||||
using ConvDilations = Sequence<1, 1>;
|
||||
|
||||
using LeftPads = Sequence<3, 0>;
|
||||
using RightPads = Sequence<3, 0>;
|
||||
#elif 1
|
||||
// 1x7 filter, 0x3 pad, 17x17 input
|
||||
constexpr index_t N = 128;
|
||||
constexpr index_t C = 128;
|
||||
constexpr index_t HI = 17;
|
||||
constexpr index_t WI = 17;
|
||||
constexpr index_t K = 128;
|
||||
constexpr index_t Y = 1;
|
||||
constexpr index_t X = 7;
|
||||
|
||||
using ConvStrides = Sequence<1, 1>;
|
||||
using ConvDilations = Sequence<1, 1>;
|
||||
|
||||
using LeftPads = Sequence<0, 3>;
|
||||
using RightPads = Sequence<0, 3>;
|
||||
#endif
|
||||
|
||||
constexpr auto img_nchw_desc = make_native_tensor_descriptor_packed(Sequence<N, C, HI, WI>{});
|
||||
constexpr auto wei_kcyx_desc = make_native_tensor_descriptor_packed(Sequence<K, C, Y, X>{});
|
||||
constexpr auto out_nkhw_desc = get_convolution_output_default_4d_tensor_descriptor(
|
||||
img_nchw_desc, wei_kcyx_desc, ConvStrides{}, ConvDilations{}, LeftPads{}, RightPads{});
|
||||
|
||||
constexpr index_t HO = out_nkhw_desc.GetLengths()[2];
|
||||
constexpr index_t WO = out_nkhw_desc.GetLengths()[3];
|
||||
|
||||
constexpr auto col_eb_desc =
|
||||
make_native_tensor_descriptor_packed(Sequence<C * Y * X, N * HO * WO>{});
|
||||
|
||||
using FilterSizes = Sequence<Y, X>;
|
||||
using OutputSizes = Sequence<HO, WO>;
|
||||
|
||||
ostream_ConstantTensorDescriptor(col_eb_desc, std::cout << "col_eb_desc: ");
|
||||
ostream_ConstantTensorDescriptor(img_nchw_desc, std::cout << "img_nchw_desc: ");
|
||||
print_sequence("FilterSizes", FilterSizes{});
|
||||
print_sequence("OutputSizes", OutputSizes{});
|
||||
print_sequence("LeftPads", LeftPads{});
|
||||
print_sequence("LeftPads", LeftPads{});
|
||||
print_sequence("RightPads", RightPads{});
|
||||
print_sequence("ConvStrides", ConvStrides{});
|
||||
print_sequence("ConvDilations", ConvDilations{});
|
||||
|
||||
Tensor<float> col_eb(make_TensorDescriptor(col_eb_desc));
|
||||
Tensor<float> img_nchw_host(make_TensorDescriptor(img_nchw_desc));
|
||||
Tensor<float> img_nchw_device(make_TensorDescriptor(img_nchw_desc));
|
||||
|
||||
std::size_t num_thread = std::thread::hardware_concurrency();
|
||||
|
||||
if(argc != 3)
|
||||
{
|
||||
printf("arg1: do_verification, arg2: nrepeat\n");
|
||||
exit(1);
|
||||
}
|
||||
|
||||
bool do_verification = atoi(argv[1]);
|
||||
std::size_t nrepeat = atoi(argv[2]);
|
||||
|
||||
if(do_verification)
|
||||
{
|
||||
#if 0
|
||||
col_eb.GenerateTensorValue(GeneratorTensor_1{}, num_thread);
|
||||
#else
|
||||
col_eb.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread);
|
||||
#endif
|
||||
}
|
||||
|
||||
device_col2im_eb_nchw(col_eb_desc,
|
||||
col_eb,
|
||||
img_nchw_desc,
|
||||
img_nchw_device,
|
||||
FilterSizes{},
|
||||
OutputSizes{},
|
||||
ConvStrides{},
|
||||
ConvDilations{},
|
||||
LeftPads{},
|
||||
RightPads{},
|
||||
nrepeat);
|
||||
|
||||
if(do_verification)
|
||||
{
|
||||
host_col2im(col_eb,
|
||||
img_nchw_host,
|
||||
FilterSizes{},
|
||||
OutputSizes{},
|
||||
ConvStrides{},
|
||||
ConvDilations{},
|
||||
LeftPads{},
|
||||
RightPads{});
|
||||
|
||||
check_error(img_nchw_host, img_nchw_device);
|
||||
|
||||
#if 0
|
||||
LogRange(std::cout << "col_eb : ", col_eb.mData, ",") << std::endl;
|
||||
LogRange(std::cout << "img_nchw_host : ", img_nchw_host.mData, ",") << std::endl;
|
||||
LogRange(std::cout << "img_nchw_device : ", img_nchw_device.mData, ",") << std::endl;
|
||||
#endif
|
||||
}
|
||||
}
|
||||
1
driver/src/col2im_driver.cu
Symbolic link
@@ -0,0 +1 @@
|
||||
col2im_driver.cpp
|
||||
374
driver/src/conv_bwd_data_driver.cpp
Normal file
@@ -0,0 +1,374 @@
|
||||
#include <iostream>
|
||||
#include <numeric>
|
||||
#include <initializer_list>
|
||||
#include <cstdlib>
|
||||
#include <stdlib.h>
|
||||
#include "config.hpp"
|
||||
#include "tensor_descriptor.hpp"
|
||||
#include "tensor_descriptor_helper.hpp"
|
||||
#include "print_array.hpp"
|
||||
#include "print_sequence.hpp"
|
||||
#include "device.hpp"
|
||||
#include "tensor_generator.hpp"
|
||||
#include "device_tensor.hpp"
|
||||
#include "conv_common.hpp"
|
||||
#include "host_conv_bwd_data.hpp"
|
||||
#include "device_convolution_backward_data_implicit_gemm_v1r1_nchw_kcyx_nkhw.hpp"
|
||||
#include "device_convolution_backward_data_implicit_gemm_v1r2_nchw_kcyx_nkhw.hpp"
|
||||
#include "device_convolution_backward_data_implicit_gemm_v2r1_nchw_kcyx_nkhw.hpp"
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
using namespace ck;
|
||||
|
||||
#if 0
|
||||
constexpr index_t N = 8;
|
||||
constexpr index_t C = 128;
|
||||
constexpr index_t HI = 16;
|
||||
constexpr index_t WI = 16;
|
||||
constexpr index_t K = 8;
|
||||
constexpr index_t Y = 2;
|
||||
constexpr index_t X = 2;
|
||||
|
||||
using ConvStrides = Sequence<4, 4>;
|
||||
using ConvDilations = Sequence<2, 2>;
|
||||
|
||||
using LeftPads = Sequence<0, 0>;
|
||||
using RightPads = Sequence<0, 0>;
|
||||
#elif 0
|
||||
// 3x3, 34x34
|
||||
constexpr index_t N = 64;
|
||||
constexpr index_t C = 256;
|
||||
constexpr index_t HI = 34;
|
||||
constexpr index_t WI = 34;
|
||||
constexpr index_t K = 128;
|
||||
constexpr index_t Y = 3;
|
||||
constexpr index_t X = 3;
|
||||
|
||||
using ConvStrides = Sequence<1, 1>;
|
||||
using ConvDilations = Sequence<1, 1>;
|
||||
|
||||
using LeftPads = Sequence<0, 0>;
|
||||
using RightPads = Sequence<0, 0>;
|
||||
#elif 0
|
||||
// 1x1 filter, 8x8 image
|
||||
constexpr index_t N = 64;
|
||||
constexpr index_t C = 1536;
|
||||
constexpr index_t HI = 8;
|
||||
constexpr index_t WI = 8;
|
||||
constexpr index_t K = 256;
|
||||
constexpr index_t Y = 1;
|
||||
constexpr index_t X = 1;
|
||||
|
||||
using ConvStrides = Sequence<1, 1>;
|
||||
using ConvDilations = Sequence<1, 1>;
|
||||
|
||||
using LeftPads = Sequence<0, 0>;
|
||||
using RightPads = Sequence<0, 0>;
|
||||
#elif 0
|
||||
// 1x1 filter, 8x8 image
|
||||
constexpr index_t N = 128;
|
||||
constexpr index_t C = 2048;
|
||||
constexpr index_t HI = 8;
|
||||
constexpr index_t WI = 8;
|
||||
constexpr index_t K = 384;
constexpr index_t Y = 1;
constexpr index_t X = 1;

using ConvStrides = Sequence<1, 1>;
using ConvDilations = Sequence<1, 1>;

using LeftPads = Sequence<0, 0>;
using RightPads = Sequence<0, 0>;
#elif 0
// 1x1 filter, 7x7 image
constexpr index_t N = 128;
constexpr index_t C = 832;
constexpr index_t HI = 7;
constexpr index_t WI = 7;
constexpr index_t K = 384;
constexpr index_t Y = 1;
constexpr index_t X = 1;

using ConvStrides = Sequence<1, 1>;
using ConvDilations = Sequence<1, 1>;

using LeftPads = Sequence<0, 0>;
using RightPads = Sequence<0, 0>;
#elif 0
// 1x1 filter, 8x8 image
constexpr index_t N = 128;
constexpr index_t C = 1280;
constexpr index_t HI = 8;
constexpr index_t WI = 8;
constexpr index_t K = 384;
constexpr index_t Y = 1;
constexpr index_t X = 1;

using ConvStrides = Sequence<1, 1>;
using ConvDilations = Sequence<1, 1>;

using LeftPads = Sequence<0, 0>;
using RightPads = Sequence<0, 0>;
#elif 0
// 1x1 filter, 14x14 image
constexpr index_t N = 128;
constexpr index_t C = 512;
constexpr index_t HI = 14;
constexpr index_t WI = 14;
constexpr index_t K = 128;
constexpr index_t Y = 1;
constexpr index_t X = 1;

using ConvStrides = Sequence<1, 1>;
using ConvDilations = Sequence<1, 1>;

using LeftPads = Sequence<0, 0>;
using RightPads = Sequence<0, 0>;
#elif 0
// 1x1 filter, 8x8 image
constexpr index_t N = 64;
constexpr index_t C = 1536;
constexpr index_t HI = 8;
constexpr index_t WI = 8;
constexpr index_t K = 384;
constexpr index_t Y = 1;
constexpr index_t X = 1;

using ConvStrides = Sequence<1, 1>;
using ConvDilations = Sequence<1, 1>;

using LeftPads = Sequence<0, 0>;
using RightPads = Sequence<0, 0>;
#elif 0
// 1x1 filter, 28x28 image
constexpr index_t N = 128;
constexpr index_t C = 256;
constexpr index_t HI = 28;
constexpr index_t WI = 28;
constexpr index_t K = 128;
constexpr index_t Y = 1;
constexpr index_t X = 1;

using ConvStrides = Sequence<1, 1>;
using ConvDilations = Sequence<1, 1>;

using LeftPads = Sequence<0, 0>;
using RightPads = Sequence<0, 0>;
#elif 0
// 1x1 filter, 7x7 image
constexpr index_t N = 128;
constexpr index_t C = 832;
constexpr index_t HI = 7;
constexpr index_t WI = 7;
constexpr index_t K = 256;
constexpr index_t Y = 1;
constexpr index_t X = 1;

using ConvStrides = Sequence<1, 1>;
using ConvDilations = Sequence<1, 1>;

using LeftPads = Sequence<0, 0>;
using RightPads = Sequence<0, 0>;
#elif 0
// 1x1 filter, 17x17 input
constexpr index_t N = 128;
constexpr index_t C = 768;
constexpr index_t HI = 17;
constexpr index_t WI = 17;
constexpr index_t K = 128;
constexpr index_t Y = 1;
constexpr index_t X = 1;

using ConvStrides = Sequence<1, 1>;
using ConvDilations = Sequence<1, 1>;

using LeftPads = Sequence<0, 0>;
using RightPads = Sequence<0, 0>;
#elif 0
// 1x1 filter, 14x14 image
constexpr index_t N = 128;
constexpr index_t C = 528;
constexpr index_t HI = 14;
constexpr index_t WI = 14;
constexpr index_t K = 128;
constexpr index_t Y = 1;
constexpr index_t X = 1;

using ConvStrides = Sequence<1, 1>;
using ConvDilations = Sequence<1, 1>;

using LeftPads = Sequence<0, 0>;
using RightPads = Sequence<0, 0>;
#elif 0
// 1x1 filter, 14x14 image
constexpr index_t N = 128;
constexpr index_t C = 528;
constexpr index_t HI = 14;
constexpr index_t WI = 14;
constexpr index_t K = 256;
constexpr index_t Y = 1;
constexpr index_t X = 1;

using ConvStrides = Sequence<1, 1>;
using ConvDilations = Sequence<1, 1>;

using LeftPads = Sequence<0, 0>;
using RightPads = Sequence<0, 0>;
#elif 0
// 1x1 filter, 7x7 image
constexpr index_t N = 128;
constexpr index_t C = 832;
constexpr index_t HI = 7;
constexpr index_t WI = 7;
constexpr index_t K = 128;
constexpr index_t Y = 1;
constexpr index_t X = 1;

using ConvStrides = Sequence<1, 1>;
using ConvDilations = Sequence<1, 1>;

using LeftPads = Sequence<0, 0>;
using RightPads = Sequence<0, 0>;
#elif 0
// 3x3 filter, 2x2 stride, 35x35 input, 17x17 output
constexpr index_t N = 128;
constexpr index_t C = 288;
constexpr index_t HI = 35;
constexpr index_t WI = 35;
constexpr index_t K = 384;
constexpr index_t Y = 3;
constexpr index_t X = 3;

using ConvStrides = Sequence<2, 2>;
using ConvDilations = Sequence<1, 1>;

using LeftPads = Sequence<0, 0>;
using RightPads = Sequence<0, 0>;
#elif 0
// 5x5 filter, 2x2 pad, 7x7 input
constexpr index_t N = 128;
constexpr index_t C = 48;
constexpr index_t HI = 7;
constexpr index_t WI = 7;
constexpr index_t K = 128;
constexpr index_t Y = 5;
constexpr index_t X = 5;

using ConvStrides = Sequence<1, 1>;
using ConvDilations = Sequence<1, 1>;

using LeftPads = Sequence<2, 2>;
using RightPads = Sequence<2, 2>;
#elif 0
// 7x1 filter, 3x0 pad, 17x17 input
constexpr index_t N = 128;
constexpr index_t C = 128;
constexpr index_t HI = 17;
constexpr index_t WI = 17;
constexpr index_t K = 128;
constexpr index_t Y = 7;
constexpr index_t X = 1;

using ConvStrides = Sequence<1, 1>;
using ConvDilations = Sequence<1, 1>;

using LeftPads = Sequence<3, 0>;
using RightPads = Sequence<3, 0>;
#elif 1
// 1x7 filter, 0x3 pad, 17x17 input
constexpr index_t N = 128;
constexpr index_t C = 128;
constexpr index_t HI = 17;
constexpr index_t WI = 17;
constexpr index_t K = 128;
constexpr index_t Y = 1;
constexpr index_t X = 7;

using ConvStrides = Sequence<1, 1>;
using ConvDilations = Sequence<1, 1>;

using LeftPads = Sequence<0, 3>;
using RightPads = Sequence<0, 3>;
#endif

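// Illustration (not part of this commit): each configuration above implies output sizes
// via the usual convolution relation
//   Ho = (HI + LeftPadH + RightPadH - ((Y - 1) * DilationH + 1)) / StrideH + 1
// (and likewise for Wo). For the enabled 1x7 / 0x3-pad case, Wo = (17 + 3 + 3 - 7) / 1 + 1 = 17;
// for the 3x3 / 2x2-stride / 35x35 case, Ho = (35 + 0 + 0 - 3) / 2 + 1 = 17, matching the
// "17x17 output" noted in that block.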
constexpr auto in_nchw_desc = make_native_tensor_descriptor_packed(Sequence<N, C, HI, WI>{});
constexpr auto wei_kcyx_desc = make_native_tensor_descriptor_packed(Sequence<K, C, Y, X>{});
constexpr auto out_nkhw_desc = get_convolution_output_default_4d_tensor_descriptor(
in_nchw_desc, wei_kcyx_desc, ConvStrides{}, ConvDilations{}, LeftPads{}, RightPads{});

ostream_ConstantTensorDescriptor(in_nchw_desc, std::cout << "in_nchw_desc: ");
ostream_ConstantTensorDescriptor(wei_kcyx_desc, std::cout << "wei_kcyx_desc: ");
ostream_ConstantTensorDescriptor(out_nkhw_desc, std::cout << "out_nkhw_desc: ");
print_sequence("LeftPads", LeftPads{});
print_sequence("LeftPads", LeftPads{});
|
||||
print_sequence("RightPads", RightPads{});
|
||||
print_sequence("ConvStrides", ConvStrides{});
|
||||
print_sequence("ConvDilations", ConvDilations{});
|
||||
|
||||
Tensor<float> in_nchw_device(make_TensorDescriptor(in_nchw_desc));
Tensor<float> in_nchw_host(make_TensorDescriptor(in_nchw_desc));
Tensor<float> wei_kcyx(make_TensorDescriptor(wei_kcyx_desc));
Tensor<float> out_nkhw(make_TensorDescriptor(out_nkhw_desc));

std::size_t num_thread = std::thread::hardware_concurrency();

if(argc != 3)
{
printf("arg1: do_verification, arg2: nrepeat\n");
exit(1);
}
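// Illustration (not part of this commit): the driver takes two positional arguments,
// e.g. "<driver-binary> 1 5" to verify against the host reference and time 5 kernel
// launches; the binary name depends on how the driver target is built and is not
// specified in this diff.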

bool do_verification = atoi(argv[1]);
std::size_t nrepeat = atoi(argv[2]);

if(do_verification)
{
#if 0
wei_kcyx.GenerateTensorValue(GeneratorTensor_1{1}, num_thread);
out_nkhw.GenerateTensorValue(GeneratorTensor_1{1}, num_thread);
#else
wei_kcyx.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread);
out_nkhw.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread);
#endif
}

#if 0
device_convolution_backward_data_implicit_gemm_v1r1_nchw_kcyx_nkhw
#elif 0
device_convolution_backward_data_implicit_gemm_v1r2_nchw_kcyx_nkhw
#else
device_convolution_backward_data_implicit_gemm_v2r1_nchw_kcyx_nkhw
#endif
(in_nchw_desc,
in_nchw_device,
wei_kcyx_desc,
wei_kcyx,
out_nkhw_desc,
out_nkhw,
ConvStrides{},
ConvDilations{},
LeftPads{},
RightPads{},
nrepeat);

if(do_verification)
{
host_direct_convolution_backward_data(in_nchw_host,
wei_kcyx,
out_nkhw,
ConvStrides{},
ConvDilations{},
LeftPads{},
RightPads{});

check_error(in_nchw_host, in_nchw_device);

#if 0
LogRange(std::cout << "out_nkhw : ", out_nkhw.mData, ",") << std::endl;
LogRange(std::cout << "wei_kcyx : ", wei_kcyx.mData, ",") << std::endl;
LogRange(std::cout << "in_nchw_host : ", in_nchw_host.mData, ",") << std::endl;
LogRange(std::cout << "in_nchw_device : ", in_nchw_device.mData, ",") << std::endl;
#endif
}
}
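The verification path above fills wei_kcyx and out_nkhw, runs the device kernel into in_nchw_device, recomputes in_nchw_host with host_direct_convolution_backward_data, and compares the two with check_error. The host reference lives in host_conv.hpp and is not shown in this diff; the following stand-alone sketch (assumed names and plain NCHW vectors, not the repo's Tensor type) only illustrates the indexing relation backward data has to satisfy: each output-gradient element is scattered back to the input positions hi = ho * stride_h + y * dilation_h - left_pad_h and wi = wo * stride_w + x * dilation_w - left_pad_w.

    #include <vector>
    #include <cstddef>

    // Naive NCHW backward-data reference (illustrative only; names and layout are assumptions).
    // in_grad:  N*C*Hi*Wi (written by this function)
    // wei:      K*C*Y*X
    // out_grad: N*K*Ho*Wo
    void naive_conv_bwd_data_nchw(std::vector<float>& in_grad,
                                  const std::vector<float>& wei,
                                  const std::vector<float>& out_grad,
                                  int N, int C, int Hi, int Wi,
                                  int K, int Y, int X,
                                  int Ho, int Wo,
                                  int stride_h, int stride_w,
                                  int dilation_h, int dilation_w,
                                  int pad_h, int pad_w)
    {
        in_grad.assign(static_cast<std::size_t>(N) * C * Hi * Wi, 0.f);

        for(int n = 0; n < N; ++n)
            for(int k = 0; k < K; ++k)
                for(int ho = 0; ho < Ho; ++ho)
                    for(int wo = 0; wo < Wo; ++wo)
                    {
                        const float og = out_grad[((n * K + k) * Ho + ho) * Wo + wo];
                        for(int c = 0; c < C; ++c)
                            for(int y = 0; y < Y; ++y)
                                for(int x = 0; x < X; ++x)
                                {
                                    // input position this filter tap read from in the forward pass
                                    const int hi = ho * stride_h + y * dilation_h - pad_h;
                                    const int wi = wo * stride_w + x * dilation_w - pad_w;
                                    if(hi >= 0 && hi < Hi && wi >= 0 && wi < Wi)
                                    {
                                        in_grad[((n * C + c) * Hi + hi) * Wi + wi] +=
                                            wei[((k * C + c) * Y + y) * X + x] * og;
                                    }
                                }
                    }
    }

The device kernels exercised above compute the same quantity as an implicit GEMM; the naive loops here only make the scatter pattern explicit.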
1
driver/src/conv_bwd_data_driver.cu
Symbolic link
1
driver/src/conv_bwd_data_driver.cu
Symbolic link
@@ -0,0 +1 @@
conv_bwd_data_driver.cpp
@@ -8,8 +8,10 @@
#include "print_array.hpp"
#include "print_sequence.hpp"
#include "device.hpp"
#include "tensor_generator.hpp"
#include "conv_common.hpp"
#include "host_conv.hpp"
#include "device_tensor.hpp"
//#include "device_convolution_direct_v2_nchw_kcyx_nkhw.hpp"
//#include "device_convolution_implicit_gemm_v1_chwn_cyxk_khwn.hpp"
//#include "device_convolution_implicit_gemm_v1_chwn_cyxk_khwn_padded.hpp"
@@ -23,73 +25,24 @@
#include "device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw_deprecated.hpp"
#include "device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp"

struct GeneratorTensor_1
{
template <class... Is>
double operator()(Is... is)
{
return 1;
}
};

struct GeneratorTensor_2
{
int min_value = 0;
int max_value = 1;

template <class... Is>
double operator()(Is...)
{
return (std::rand() % (max_value - min_value)) + min_value;
}
};

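// Illustration (not from this commit): with the arguments used by the drivers,
// GeneratorTensor_2{-5, 5} returns std::rand() % 10 - 5, i.e. an integer-valued
// double in [-5, 5), independent of the element indices passed to operator().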
struct GeneratorTensor_3
{
template <class... Is>
double operator()(Is... is)
{
std::array<index_t, sizeof...(Is)> dims = {{static_cast<index_t>(is)...}};

auto f_acc = [](auto a, auto b) { return 10 * a + b; };

return std::accumulate(dims.begin(), dims.end(), index_t(0), f_acc);
}
};

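// Illustration (not from this commit): GeneratorTensor_3 folds the index tuple into a
// decimal number, so element (1, 2, 3) gets the value 10*(10*(10*0 + 1) + 2) + 3 = 123,
// which makes individual elements easy to identify when a tensor is dumped for debugging.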
struct GeneratorTensor_Checkboard
{
template <class... Ts>
double operator()(Ts... Xs) const
{
std::array<index_t, sizeof...(Ts)> dims = {{Xs...}};
return std::accumulate(dims.begin(),
dims.end(),
true,
[](bool init, index_t x) -> int { return init != (x % 2); })
? 1
: -1;
}
};

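// Illustration (not from this commit): the fold above XORs the parity of each index into
// an initial true, so the generator returns +1 when the coordinate sum is even and -1 when
// it is odd, e.g. (0, 0, 0, 0) -> +1 and (0, 0, 0, 1) -> -1: a checkerboard pattern.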
int main(int argc, char* argv[])
{
using namespace ck;

#if 0
constexpr index_t N = 128;
constexpr index_t C = 128;
constexpr index_t HI = 17;
constexpr index_t WI = 17;
constexpr index_t K = 128;
constexpr index_t Y = 1;
constexpr index_t X = 7;
constexpr index_t N = 8;
constexpr index_t C = 32;
constexpr index_t HI = 28;
constexpr index_t WI = 28;
constexpr index_t K = 32;
constexpr index_t Y = 5;
constexpr index_t X = 5;

using ConvStrides = Sequence<1, 1>;
using ConvDilations = Sequence<1, 1>;
using ConvDilations = Sequence<2, 2>;

using LeftPads = Sequence<0, 3>;
using RightPads = Sequence<0, 3>;
using LeftPads = Sequence<0, 0>;
using RightPads = Sequence<0, 0>;
#elif 0
// 3x3, 34x34
constexpr index_t N = 64;
@@ -297,7 +250,7 @@ int main(int argc, char* argv[])

using LeftPads = Sequence<0, 0>;
using RightPads = Sequence<0, 0>;
#elif 0
#elif 1
// 3x3 filter, 2x2 stride, 35x35 input, 17x17 output
// cudnn@V100 90%, ck@V100 93%, ck@P100 83%, ck@VII 81%
constexpr index_t N = 128;
@@ -343,7 +296,7 @@ int main(int argc, char* argv[])

using LeftPads = Sequence<3, 0>;
using RightPads = Sequence<3, 0>;
#elif 1
#elif 0
// 1x7 filter, 0x3 pad, 17x17 input
constexpr index_t N = 128;
constexpr index_t C = 128;
@@ -362,7 +315,7 @@ int main(int argc, char* argv[])

auto in_nchw_desc = make_ConstantTensorDescriptor_packed(Sequence<N, C, HI, WI>{});
auto wei_kcyx_desc = make_ConstantTensorDescriptor_packed(Sequence<K, C, Y, X>{});
auto out_nkhw_desc = get_convolution_with_padding_output_default_4d_tensor_descriptor(
auto out_nkhw_desc = get_convolution_output_default_4d_tensor_descriptor_deprecated(
in_nchw_desc, wei_kcyx_desc, ConvStrides{}, ConvDilations{}, LeftPads{}, RightPads{});

ostream_ConstantTensorDescriptor(in_nchw_desc, std::cout << "in_nchw_desc: ");
@@ -492,7 +445,7 @@ int main(int argc, char* argv[])
ConvStrides{},
ConvDilations{},
nrepeat);
#elif 0
#elif 1
device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw(in_nchw_desc,
in_nchw,
wei_kcyx_desc,
1
driver/src/conv_driver.cu
Symbolic link
1
driver/src/conv_driver.cu
Symbolic link
@@ -0,0 +1 @@
conv_driver.cpp
@@ -1 +0,0 @@
driver.cpp
@@ -3,12 +3,14 @@

#include "tensor.hpp"

TensorDescriptor::TensorDescriptor(std::initializer_list<std::size_t> lens) : mLens(lens)
template <typename X>
TensorDescriptor::TensorDescriptor(std::vector<X> lens) : mLens(lens)
{
this->CalculateStrides();
}

TensorDescriptor::TensorDescriptor(std::vector<std::size_t> lens, std::vector<std::size_t> strides)
template <typename X, typename Y>
TensorDescriptor::TensorDescriptor(std::vector<X> lens, std::vector<Y> strides)
: mLens(lens), mStrides(strides)
{
}

@@ -4,5 +4,5 @@
export KMDUMPLLVM=1
export KMDUMPDIR=$PWD

make -j driver
make -j $1
#/opt/rocm/hcc/bin/llvm-objdump -mcpu=gfx906 -source -line-numbers driver/dump-gfx906.isabin > driver/dump-gfx906.isabin.asm

3
script/docker-cuda.sh
Executable file
3
script/docker-cuda.sh
Executable file
@@ -0,0 +1,3 @@
WORKSPACE=$1
echo "workspace: " $WORKSPACE
sudo docker run -it -v $WORKSPACE:/root/workspace --group-add sudo --runtime=nvidia asroy/cuda:10.1-cudnn7-devel-ubuntu18.04-latest /bin/bash
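# Usage sketch (illustrative, not part of this commit; the path is an example):
#   ./script/docker-cuda.sh /path/to/composable_kernel
# mounts the given directory at /root/workspace inside the CUDA 10.1 container.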
@@ -1,12 +0,0 @@
for((i=0;i<=4096;i=i+64))
do
OFFSET=$i
echo "if(offset == $OFFSET)"
echo "{"
echo " asm volatile(\"\\n \\"
echo " ds_read_b128 %0, %1 offset:$OFFSET\n \\"
echo " \""
echo " : \"=v\"(r)"
echo " : \"v\"(__to_local(lds)));"
echo "}"
done