GEMM/Conv+BiasAdd+ReLU+Add (#55)

* gemm+activation

* move C pointwise operation into threadwise copy

* add pointwise operation to A/B matrix

* update ckProfiler

* adding bias add

* adding bias add

* adding bias add

* added bias add; worked around compiler issues

* clean up

* clean up

* Update README.md

* Update README.md

* Update README.md

* clean up

* add conv_xdl example

* adding conv_xdl_bias_relu_add example

* add conv+bias+relu+add, but has register spill issue

* tweak

* tweak

* refactor

* Update README.md

update readme for example/2_gemm_xdl_bias_relu_add

* clean up

* Update README.md

update readme for example/3_conv_xdl

* Update README.md
This commit is contained in:
Chao Liu
2021-12-02 20:07:37 -06:00
committed by GitHub
parent d7a0a3f94c
commit 41cdd3801a
40 changed files with 4336 additions and 250 deletions

View File

@@ -8,6 +8,9 @@ namespace ck {
namespace tensor_operation {
namespace device {
template <typename InElementwiseOperation,
typename WeiElementwiseOperation,
typename OutElementwiseOperation>
struct DeviceConvFwd : public BaseOperator
{
virtual std::unique_ptr<BaseArgument>
@@ -23,11 +26,17 @@ struct DeviceConvFwd : public BaseOperator
std::vector<ck::index_t> conv_filter_strides,
std::vector<ck::index_t> conv_filter_dilations,
std::vector<ck::index_t> input_left_pads,
std::vector<ck::index_t> input_right_pads) = 0;
std::vector<ck::index_t> input_right_pads,
InElementwiseOperation in_element_op,
WeiElementwiseOperation wei_element_op,
OutElementwiseOperation out_element_op) = 0;
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
};
template <typename InElementwiseOperation,
typename WeiElementwiseOperation,
typename OutElementwiseOperation>
struct DeviceConvBwd : public BaseOperator
{
virtual std::unique_ptr<BaseArgument>
@@ -43,11 +52,17 @@ struct DeviceConvBwd : public BaseOperator
std::vector<ck::index_t> conv_filter_strides,
std::vector<ck::index_t> conv_filter_dilations,
std::vector<ck::index_t> input_left_pads,
std::vector<ck::index_t> input_right_pads) = 0;
std::vector<ck::index_t> input_right_pads,
InElementwiseOperation in_element_op,
WeiElementwiseOperation wei_element_op,
OutElementwiseOperation out_element_op) = 0;
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
};
template <typename InElementwiseOperation,
typename WeiElementwiseOperation,
typename OutElementwiseOperation>
struct DeviceConvWrw : public BaseOperator
{
virtual std::unique_ptr<BaseArgument>
@@ -63,14 +78,31 @@ struct DeviceConvWrw : public BaseOperator
std::vector<ck::index_t> conv_filter_strides,
std::vector<ck::index_t> conv_filter_dilations,
std::vector<ck::index_t> input_left_pads,
std::vector<ck::index_t> input_right_pads) = 0;
std::vector<ck::index_t> input_right_pads,
InElementwiseOperation in_element_op,
WeiElementwiseOperation wei_element_op,
OutElementwiseOperation out_element_op) = 0;
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
};
using DeviceConvFwdPtr = std::unique_ptr<DeviceConvFwd>;
using DeviceConvBwdPtr = std::unique_ptr<DeviceConvBwd>;
using DeviceConvWrwPtr = std::unique_ptr<DeviceConvWrw>;
template <typename InElementwiseOperation,
typename WeiElementwiseOperation,
typename OutElementwiseOperation>
using DeviceConvFwdPtr = std::unique_ptr<
DeviceConvFwd<InElementwiseOperation, WeiElementwiseOperation, OutElementwiseOperation>>;
template <typename InElementwiseOperation,
typename WeiElementwiseOperation,
typename OutElementwiseOperation>
using DeviceConvBwdPtr = std::unique_ptr<
DeviceConvBwd<InElementwiseOperation, WeiElementwiseOperation, OutElementwiseOperation>>;
template <typename InElementwiseOperation,
typename WeiElementwiseOperation,
typename OutElementwiseOperation>
using DeviceConvWrwPtr = std::unique_ptr<
DeviceConvWrw<InElementwiseOperation, WeiElementwiseOperation, OutElementwiseOperation>>;
} // namespace device
} // namespace tensor_operation

View File

@@ -23,6 +23,9 @@ template <ck::index_t NDimSpatial,
typename InLayout,
typename WeiLayout,
typename OutLayout,
typename InElementwiseOperation,
typename WeiElementwiseOperation,
typename OutElementwiseOperation,
ck::index_t BlockSize,
ck::index_t MPerBlock,
ck::index_t NPerBlock,

View File

@@ -22,6 +22,9 @@ template <typename InDataType,
typename WeiDataType,
typename OutDataType,
typename AccDataType,
typename InElementwiseOperation,
typename WeiElementwiseOperation,
typename OutElementwiseOperation,
ck::index_t BlockSize,
ck::index_t MPerBlock,
ck::index_t NPerBlock,
@@ -58,6 +61,9 @@ struct DeviceConvFwdXdl<
ck::tensor_layout::convolution::NHWC, // typename InLayout,
ck::tensor_layout::convolution::KYXC, // typename WeiLayout,
ck::tensor_layout::convolution::NHWK, // typename OutLayout,
InElementwiseOperation, // typename InElementwiseOperation,
WeiElementwiseOperation, // typename WeiElementwiseOperation,
OutElementwiseOperation, // typename OutElementwiseOperation,
BlockSize, // ck::index_t BlockSize,
MPerBlock, // ck::index_t MPerBlock,
NPerBlock, // ck::index_t NPerBlock,
@@ -87,7 +93,8 @@ struct DeviceConvFwdXdl<
CThreadTransferDstScalarPerVector, // ck::index_t CThreadTransferDstScalarPerVector,
ABlockLdsAddExtraM, // bool ABlockLdsAddExtraM,
BBlockLdsAddExtraN // bool BBlockLdsAddExtraN>
> : public DeviceConvFwd
>
: public DeviceConvFwd<InElementwiseOperation, WeiElementwiseOperation, OutElementwiseOperation>
{
using ADataType = InDataType;
using BDataType = WeiDataType;
@@ -293,6 +300,9 @@ struct DeviceConvFwdXdl<
AGridDesc_K0_M_K1,
BGridDesc_K0_N_K1,
CGridDesc_M_N,
InElementwiseOperation,
WeiElementwiseOperation,
OutElementwiseOperation,
MPerBlock,
NPerBlock,
K0PerBlock,
@@ -351,7 +361,10 @@ struct DeviceConvFwdXdl<
std::vector<ck::index_t> input_left_pads,
std::vector<ck::index_t> input_right_pads,
ck::index_t M01,
ck::index_t N01)
ck::index_t N01,
InElementwiseOperation in_element_op,
WeiElementwiseOperation wei_element_op,
OutElementwiseOperation out_element_op)
: p_a_grid_{p_in_grid},
p_b_grid_{p_wei_grid},
p_c_grid_{p_out_grid},
@@ -361,7 +374,10 @@ struct DeviceConvFwdXdl<
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_{},
block_2_ctile_map_{},
M01_{M01},
N01_{N01}
N01_{N01},
in_element_op_{in_element_op},
wei_element_op_{wei_element_op},
out_element_op_{out_element_op}
{
const auto descs = DeviceConvFwdXdl::MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(
N,
@@ -400,6 +416,9 @@ struct DeviceConvFwdXdl<
Block2CTileMap block_2_ctile_map_;
index_t M01_;
index_t N01_;
InElementwiseOperation in_element_op_;
WeiElementwiseOperation wei_element_op_;
OutElementwiseOperation out_element_op_;
};
// Invoker
@@ -449,6 +468,9 @@ struct DeviceConvFwdXdl<
remove_reference_t<DeviceConvFwdXdl::AGridDesc_K0_M_K1>,
remove_reference_t<DeviceConvFwdXdl::BGridDesc_K0_N_K1>,
remove_reference_t<DeviceConvFwdXdl::CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2>,
InElementwiseOperation,
WeiElementwiseOperation,
OutElementwiseOperation,
remove_reference_t<DeviceConvFwdXdl::Block2CTileMap>,
true>;
@@ -463,6 +485,9 @@ struct DeviceConvFwdXdl<
arg.a_grid_desc_k0_m_k1_,
arg.b_grid_desc_k0_n_k1_,
arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_,
arg.in_element_op_,
arg.wei_element_op_,
arg.out_element_op_,
arg.block_2_ctile_map_);
}
else
@@ -474,6 +499,9 @@ struct DeviceConvFwdXdl<
remove_reference_t<DeviceConvFwdXdl::AGridDesc_K0_M_K1>,
remove_reference_t<DeviceConvFwdXdl::BGridDesc_K0_N_K1>,
remove_reference_t<DeviceConvFwdXdl::CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2>,
InElementwiseOperation,
WeiElementwiseOperation,
OutElementwiseOperation,
remove_reference_t<DeviceConvFwdXdl::Block2CTileMap>,
false>;
@@ -488,6 +516,9 @@ struct DeviceConvFwdXdl<
arg.a_grid_desc_k0_m_k1_,
arg.b_grid_desc_k0_n_k1_,
arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_,
arg.in_element_op_,
arg.wei_element_op_,
arg.out_element_op_,
arg.block_2_ctile_map_);
}
@@ -534,7 +565,10 @@ struct DeviceConvFwdXdl<
std::vector<ck::index_t> conv_filter_strides,
std::vector<ck::index_t> conv_filter_dilations,
std::vector<ck::index_t> input_left_pads,
std::vector<ck::index_t> input_right_pads)
std::vector<ck::index_t> input_right_pads,
InElementwiseOperation in_element_op,
WeiElementwiseOperation wei_element_op,
OutElementwiseOperation out_element_op)
{
return Argument{p_in_grid,
p_wei_grid,
@@ -550,7 +584,10 @@ struct DeviceConvFwdXdl<
input_left_pads,
input_right_pads,
1,
1};
1,
in_element_op,
wei_element_op,
out_element_op};
}
static auto MakeInvoker() { return Invoker{}; }
@@ -569,7 +606,10 @@ struct DeviceConvFwdXdl<
std::vector<ck::index_t> conv_filter_strides,
std::vector<ck::index_t> conv_filter_dilations,
std::vector<ck::index_t> input_left_pads,
std::vector<ck::index_t> input_right_pads) override
std::vector<ck::index_t> input_right_pads,
InElementwiseOperation in_element_op,
WeiElementwiseOperation wei_element_op,
OutElementwiseOperation out_element_op) override
{
return std::make_unique<Argument>(static_cast<const InDataType*>(p_in_grid),
static_cast<const WeiDataType*>(p_wei_grid),
@@ -585,7 +625,10 @@ struct DeviceConvFwdXdl<
input_left_pads,
input_right_pads,
1,
1);
1,
in_element_op,
wei_element_op,
out_element_op);
}
// polymorphic
@@ -593,7 +636,7 @@ struct DeviceConvFwdXdl<
{
return std::make_unique<Invoker>(Invoker{});
}
};
}; // namespace device
} // namespace device
} // namespace tensor_operation

View File

@@ -2,6 +2,7 @@
#define DEVICE_CONV_INSTANTCE_HPP
#include "device_conv.hpp"
#include "element_wise_operation.hpp"
namespace ck {
namespace tensor_operation {
@@ -15,7 +16,10 @@ template <ck::index_t NDimSpatial,
typename InLayout,
typename WeiLayout,
typename OutLayout>
void add_device_conv_fwd_instance(std::vector<DeviceConvFwdPtr>&);
void add_device_conv_fwd_instance(
std::vector<DeviceConvFwdPtr<ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough>>&);
template <ck::index_t NDimSpatial,
typename InDataType,
@@ -24,7 +28,10 @@ template <ck::index_t NDimSpatial,
typename InLayout,
typename WeiLayout,
typename OutLayout>
void add_device_conv_bwd_instance(std::vector<DeviceConvBwdPtr>&);
void add_device_conv_bwd_instance(
std::vector<DeviceConvBwdPtr<ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough>>&);
template <ck::index_t NDimSpatial,
typename InDataType,
@@ -33,7 +40,10 @@ template <ck::index_t NDimSpatial,
typename InLayout,
typename WeiLayout,
typename OutLayout>
void add_device_conv_wrw_instance(std::vector<DeviceConvWrwPtr>&);
void add_device_conv_wrw_instance(
std::vector<DeviceConvWrwPtr<ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough>>&);
} // namespace device_conv_instance
} // namespace device

View File

@@ -8,22 +8,33 @@ namespace ck {
namespace tensor_operation {
namespace device {
template <typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation>
struct DeviceGemm : public BaseOperator
{
virtual std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a,
const void* p_b,
void* p_c,
ck::index_t M,
ck::index_t N,
ck::index_t K,
ck::index_t StrideA,
ck::index_t StrideB,
ck::index_t StrideC) = 0;
virtual std::unique_ptr<BaseArgument>
MakeArgumentPointer(const void* p_a,
const void* p_b,
void* p_c,
ck::index_t M,
ck::index_t N,
ck::index_t K,
ck::index_t StrideA,
ck::index_t StrideB,
ck::index_t StrideC,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op) = 0;
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
};
using DeviceGemmPtr = std::unique_ptr<DeviceGemm>;
template <typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation>
using DeviceGemmPtr = std::unique_ptr<
DeviceGemm<AElementwiseOperation, BElementwiseOperation, CElementwiseOperation>>;
} // namespace device
} // namespace tensor_operation

View File

@@ -2,6 +2,7 @@
#define DEVICE_GEMM_INSTANTCE_HPP
#include "device_gemm.hpp"
#include "element_wise_operation.hpp"
namespace ck {
namespace tensor_operation {
@@ -14,7 +15,10 @@ template <typename ADataType,
typename ALayout,
typename BLayout,
typename CLayout>
void add_device_gemm_instance(std::vector<DeviceGemmPtr>&);
void add_device_gemm_instance(
std::vector<DeviceGemmPtr<ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough>>&);
} // namespace device_gemm_instance
} // namespace device

View File

@@ -22,6 +22,9 @@ template <typename ADataType,
typename ALayout,
typename BLayout,
typename CLayout,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation,
ck::index_t BlockSize,
ck::index_t MPerBlock,
ck::index_t NPerBlock,
@@ -49,7 +52,8 @@ template <typename ADataType,
ck::index_t CThreadTransferDstScalarPerVector,
bool ABlockLdsAddExtraM,
bool BBlockLdsAddExtraN>
struct DeviceGemmXdl : public DeviceGemm
struct DeviceGemmXdl
: public DeviceGemm<AElementwiseOperation, BElementwiseOperation, CElementwiseOperation>
{
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
@@ -176,6 +180,9 @@ struct DeviceGemmXdl : public DeviceGemm
AGridDesc_K0_M_K1,
BGridDesc_K0_N_K1,
CGridDesc_M_N,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation,
MPerBlock,
NPerBlock,
K0PerBlock,
@@ -230,7 +237,10 @@ struct DeviceGemmXdl : public DeviceGemm
index_t StrideB,
index_t StrideC,
index_t M01,
index_t N01)
index_t N01,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op)
: p_a_grid_{p_a_grid},
p_b_grid_{p_b_grid},
p_c_grid_{p_c_grid},
@@ -240,7 +250,10 @@ struct DeviceGemmXdl : public DeviceGemm
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_{},
block_2_ctile_map_{},
M01_{M01},
N01_{N01}
N01_{N01},
a_element_op_{a_element_op},
b_element_op_{b_element_op},
c_element_op_{c_element_op}
{
a_grid_desc_k0_m_k1_ = DeviceGemmXdl::MakeAGridDescriptor_K0_M_K1(M, K, StrideA);
b_grid_desc_k0_n_k1_ = DeviceGemmXdl::MakeBGridDescriptor_K0_N_K1(K, N, StrideB);
@@ -267,6 +280,9 @@ struct DeviceGemmXdl : public DeviceGemm
Block2CTileMap block_2_ctile_map_;
index_t M01_;
index_t N01_;
AElementwiseOperation a_element_op_;
BElementwiseOperation b_element_op_;
CElementwiseOperation c_element_op_;
};
// Invoker
@@ -316,6 +332,9 @@ struct DeviceGemmXdl : public DeviceGemm
remove_reference_t<DeviceGemmXdl::AGridDesc_K0_M_K1>,
remove_reference_t<DeviceGemmXdl::BGridDesc_K0_N_K1>,
remove_reference_t<DeviceGemmXdl::CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2>,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation,
remove_reference_t<DeviceGemmXdl::Block2CTileMap>,
true>;
@@ -330,6 +349,9 @@ struct DeviceGemmXdl : public DeviceGemm
arg.a_grid_desc_k0_m_k1_,
arg.b_grid_desc_k0_n_k1_,
arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_,
arg.a_element_op_,
arg.b_element_op_,
arg.c_element_op_,
arg.block_2_ctile_map_);
}
else
@@ -341,6 +363,9 @@ struct DeviceGemmXdl : public DeviceGemm
remove_reference_t<DeviceGemmXdl::AGridDesc_K0_M_K1>,
remove_reference_t<DeviceGemmXdl::BGridDesc_K0_N_K1>,
remove_reference_t<DeviceGemmXdl::CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2>,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation,
remove_reference_t<DeviceGemmXdl::Block2CTileMap>,
false>;
@@ -355,6 +380,9 @@ struct DeviceGemmXdl : public DeviceGemm
arg.a_grid_desc_k0_m_k1_,
arg.b_grid_desc_k0_n_k1_,
arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_,
arg.a_element_op_,
arg.b_element_op_,
arg.c_element_op_,
arg.block_2_ctile_map_);
}
@@ -397,9 +425,25 @@ struct DeviceGemmXdl : public DeviceGemm
index_t K,
index_t StrideA,
index_t StrideB,
index_t StrideC)
index_t StrideC,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op)
{
return Argument{p_a, p_b, p_c, M, N, K, StrideA, StrideB, StrideC, 1, 1};
return Argument{p_a,
p_b,
p_c,
M,
N,
K,
StrideA,
StrideB,
StrideC,
1,
1,
a_element_op,
b_element_op,
c_element_op};
}
static auto MakeInvoker() { return Invoker{}; }
@@ -413,7 +457,10 @@ struct DeviceGemmXdl : public DeviceGemm
index_t K,
index_t StrideA,
index_t StrideB,
index_t StrideC) override
index_t StrideC,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op) override
{
return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
static_cast<const BDataType*>(p_b),
@@ -425,7 +472,10 @@ struct DeviceGemmXdl : public DeviceGemm
StrideB,
StrideC,
1,
1);
1,
a_element_op,
b_element_op,
c_element_op);
}
// polymorphic

View File

@@ -0,0 +1,20 @@
#ifndef ELEMENT_WISE_OPERATION_HPP
#define ELEMENT_WISE_OPERATION_HPP
namespace ck {
namespace tensor_operation {
namespace element_wise {
struct PassThrough
{
template <typename T>
__host__ __device__ constexpr T operator()(T v) const
{
return v;
}
};
} // namespace element_wise
} // namespace tensor_operation
} // namespace ck
#endif