mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-11 08:50:17 +00:00
GEMM/Conv+BiasAdd+ReLU+Add (#55)
* gemm+activation
* move C pointwise operation into threadwise copy
* add pointwise operation to A/B matrix
* update ckProfiler
* adding bias add
* adding bias add
* adding bias add
* added bias add; worked around compiler issues
* clean up
* clean up
* Update README.md
* Update README.md
* Update README.md
* clean up
* add conv_xdl example
* adding conv_xdl_bias_relu_add example
* add conv+bias+relu+add, but has register spill issue
* tweak
* tweak
* refactor
* Update README.md: update readme for example/2_gemm_xdl_bias_relu_add
* clean up
* Update README.md: update readme for example/3_conv_xdl
* Update README.md
This commit is contained in:
@@ -8,6 +8,9 @@ namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
|
||||
template <typename InElementwiseOperation,
|
||||
typename WeiElementwiseOperation,
|
||||
typename OutElementwiseOperation>
|
||||
struct DeviceConvFwd : public BaseOperator
|
||||
{
|
||||
virtual std::unique_ptr<BaseArgument>
|
||||
@@ -23,11 +26,17 @@ struct DeviceConvFwd : public BaseOperator
|
||||
std::vector<ck::index_t> conv_filter_strides,
|
||||
std::vector<ck::index_t> conv_filter_dilations,
|
||||
std::vector<ck::index_t> input_left_pads,
|
||||
std::vector<ck::index_t> input_right_pads) = 0;
|
||||
std::vector<ck::index_t> input_right_pads,
|
||||
InElementwiseOperation in_element_op,
|
||||
WeiElementwiseOperation wei_element_op,
|
||||
OutElementwiseOperation out_element_op) = 0;
|
||||
|
||||
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
|
||||
};
|
||||
|
||||
template <typename InElementwiseOperation,
|
||||
typename WeiElementwiseOperation,
|
||||
typename OutElementwiseOperation>
|
||||
struct DeviceConvBwd : public BaseOperator
|
||||
{
|
||||
virtual std::unique_ptr<BaseArgument>
|
||||
@@ -43,11 +52,17 @@ struct DeviceConvBwd : public BaseOperator
|
||||
std::vector<ck::index_t> conv_filter_strides,
|
||||
std::vector<ck::index_t> conv_filter_dilations,
|
||||
std::vector<ck::index_t> input_left_pads,
|
||||
std::vector<ck::index_t> input_right_pads) = 0;
|
||||
std::vector<ck::index_t> input_right_pads,
|
||||
InElementwiseOperation in_element_op,
|
||||
WeiElementwiseOperation wei_element_op,
|
||||
OutElementwiseOperation out_element_op) = 0;
|
||||
|
||||
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
|
||||
};
|
||||
|
||||
template <typename InElementwiseOperation,
|
||||
typename WeiElementwiseOperation,
|
||||
typename OutElementwiseOperation>
|
||||
struct DeviceConvWrw : public BaseOperator
|
||||
{
|
||||
virtual std::unique_ptr<BaseArgument>
|
||||
@@ -63,14 +78,31 @@ struct DeviceConvWrw : public BaseOperator
|
||||
std::vector<ck::index_t> conv_filter_strides,
|
||||
std::vector<ck::index_t> conv_filter_dilations,
|
||||
std::vector<ck::index_t> input_left_pads,
|
||||
std::vector<ck::index_t> input_right_pads) = 0;
|
||||
std::vector<ck::index_t> input_right_pads,
|
||||
InElementwiseOperation in_element_op,
|
||||
WeiElementwiseOperation wei_element_op,
|
||||
OutElementwiseOperation out_element_op) = 0;
|
||||
|
||||
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
|
||||
};
|
||||
|
||||
using DeviceConvFwdPtr = std::unique_ptr<DeviceConvFwd>;
|
||||
using DeviceConvBwdPtr = std::unique_ptr<DeviceConvBwd>;
|
||||
using DeviceConvWrwPtr = std::unique_ptr<DeviceConvWrw>;
|
||||
template <typename InElementwiseOperation,
|
||||
typename WeiElementwiseOperation,
|
||||
typename OutElementwiseOperation>
|
||||
using DeviceConvFwdPtr = std::unique_ptr<
|
||||
DeviceConvFwd<InElementwiseOperation, WeiElementwiseOperation, OutElementwiseOperation>>;
|
||||
|
||||
template <typename InElementwiseOperation,
|
||||
typename WeiElementwiseOperation,
|
||||
typename OutElementwiseOperation>
|
||||
using DeviceConvBwdPtr = std::unique_ptr<
|
||||
DeviceConvBwd<InElementwiseOperation, WeiElementwiseOperation, OutElementwiseOperation>>;
|
||||
|
||||
template <typename InElementwiseOperation,
|
||||
typename WeiElementwiseOperation,
|
||||
typename OutElementwiseOperation>
|
||||
using DeviceConvWrwPtr = std::unique_ptr<
|
||||
DeviceConvWrw<InElementwiseOperation, WeiElementwiseOperation, OutElementwiseOperation>>;
|
||||
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
|
||||
@@ -23,6 +23,9 @@ template <ck::index_t NDimSpatial,
|
||||
typename InLayout,
|
||||
typename WeiLayout,
|
||||
typename OutLayout,
|
||||
typename InElementwiseOperation,
|
||||
typename WeiElementwiseOperation,
|
||||
typename OutElementwiseOperation,
|
||||
ck::index_t BlockSize,
|
||||
ck::index_t MPerBlock,
|
||||
ck::index_t NPerBlock,
|
||||
|
||||
@@ -22,6 +22,9 @@ template <typename InDataType,
|
||||
typename WeiDataType,
|
||||
typename OutDataType,
|
||||
typename AccDataType,
|
||||
typename InElementwiseOperation,
|
||||
typename WeiElementwiseOperation,
|
||||
typename OutElementwiseOperation,
|
||||
ck::index_t BlockSize,
|
||||
ck::index_t MPerBlock,
|
||||
ck::index_t NPerBlock,
|
||||
@@ -58,6 +61,9 @@ struct DeviceConvFwdXdl<
|
||||
ck::tensor_layout::convolution::NHWC, // typename InLayout,
|
||||
ck::tensor_layout::convolution::KYXC, // typename WeiLayout,
|
||||
ck::tensor_layout::convolution::NHWK, // typename OutLayout,
|
||||
InElementwiseOperation, // typename InElementwiseOperation,
|
||||
WeiElementwiseOperation, // typename WeiElementwiseOperation,
|
||||
OutElementwiseOperation, // typename OutElementwiseOperation,
|
||||
BlockSize, // ck::index_t BlockSize,
|
||||
MPerBlock, // ck::index_t MPerBlock,
|
||||
NPerBlock, // ck::index_t NPerBlock,
|
||||
@@ -87,7 +93,8 @@ struct DeviceConvFwdXdl<
|
||||
CThreadTransferDstScalarPerVector, // ck::index_t CThreadTransferDstScalarPerVector,
|
||||
ABlockLdsAddExtraM, // bool ABlockLdsAddExtraM,
|
||||
BBlockLdsAddExtraN // bool BBlockLdsAddExtraN>
|
||||
> : public DeviceConvFwd
|
||||
>
|
||||
: public DeviceConvFwd<InElementwiseOperation, WeiElementwiseOperation, OutElementwiseOperation>
|
||||
{
|
||||
using ADataType = InDataType;
|
||||
using BDataType = WeiDataType;
|
||||
@@ -293,6 +300,9 @@ struct DeviceConvFwdXdl<
|
||||
AGridDesc_K0_M_K1,
|
||||
BGridDesc_K0_N_K1,
|
||||
CGridDesc_M_N,
|
||||
InElementwiseOperation,
|
||||
WeiElementwiseOperation,
|
||||
OutElementwiseOperation,
|
||||
MPerBlock,
|
||||
NPerBlock,
|
||||
K0PerBlock,
|
||||
@@ -351,7 +361,10 @@ struct DeviceConvFwdXdl<
|
||||
std::vector<ck::index_t> input_left_pads,
|
||||
std::vector<ck::index_t> input_right_pads,
|
||||
ck::index_t M01,
|
||||
ck::index_t N01)
|
||||
ck::index_t N01,
|
||||
InElementwiseOperation in_element_op,
|
||||
WeiElementwiseOperation wei_element_op,
|
||||
OutElementwiseOperation out_element_op)
|
||||
: p_a_grid_{p_in_grid},
|
||||
p_b_grid_{p_wei_grid},
|
||||
p_c_grid_{p_out_grid},
|
||||
@@ -361,7 +374,10 @@ struct DeviceConvFwdXdl<
|
||||
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_{},
|
||||
block_2_ctile_map_{},
|
||||
M01_{M01},
|
||||
N01_{N01}
|
||||
N01_{N01},
|
||||
in_element_op_{in_element_op},
|
||||
wei_element_op_{wei_element_op},
|
||||
out_element_op_{out_element_op}
|
||||
{
|
||||
const auto descs = DeviceConvFwdXdl::MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(
|
||||
N,
|
||||
@@ -400,6 +416,9 @@ struct DeviceConvFwdXdl<
|
||||
Block2CTileMap block_2_ctile_map_;
|
||||
index_t M01_;
|
||||
index_t N01_;
|
||||
InElementwiseOperation in_element_op_;
|
||||
WeiElementwiseOperation wei_element_op_;
|
||||
OutElementwiseOperation out_element_op_;
|
||||
};
|
||||
|
||||
// Invoker
|
||||
@@ -449,6 +468,9 @@ struct DeviceConvFwdXdl<
|
||||
remove_reference_t<DeviceConvFwdXdl::AGridDesc_K0_M_K1>,
|
||||
remove_reference_t<DeviceConvFwdXdl::BGridDesc_K0_N_K1>,
|
||||
remove_reference_t<DeviceConvFwdXdl::CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2>,
|
||||
InElementwiseOperation,
|
||||
WeiElementwiseOperation,
|
||||
OutElementwiseOperation,
|
||||
remove_reference_t<DeviceConvFwdXdl::Block2CTileMap>,
|
||||
true>;
|
||||
|
||||
@@ -463,6 +485,9 @@ struct DeviceConvFwdXdl<
|
||||
arg.a_grid_desc_k0_m_k1_,
|
||||
arg.b_grid_desc_k0_n_k1_,
|
||||
arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_,
|
||||
arg.in_element_op_,
|
||||
arg.wei_element_op_,
|
||||
arg.out_element_op_,
|
||||
arg.block_2_ctile_map_);
|
||||
}
|
||||
else
|
||||
@@ -474,6 +499,9 @@ struct DeviceConvFwdXdl<
|
||||
remove_reference_t<DeviceConvFwdXdl::AGridDesc_K0_M_K1>,
|
||||
remove_reference_t<DeviceConvFwdXdl::BGridDesc_K0_N_K1>,
|
||||
remove_reference_t<DeviceConvFwdXdl::CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2>,
|
||||
InElementwiseOperation,
|
||||
WeiElementwiseOperation,
|
||||
OutElementwiseOperation,
|
||||
remove_reference_t<DeviceConvFwdXdl::Block2CTileMap>,
|
||||
false>;
|
||||
|
||||
@@ -488,6 +516,9 @@ struct DeviceConvFwdXdl<
|
||||
arg.a_grid_desc_k0_m_k1_,
|
||||
arg.b_grid_desc_k0_n_k1_,
|
||||
arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_,
|
||||
arg.in_element_op_,
|
||||
arg.wei_element_op_,
|
||||
arg.out_element_op_,
|
||||
arg.block_2_ctile_map_);
|
||||
}
|
||||
|
||||
@@ -534,7 +565,10 @@ struct DeviceConvFwdXdl<
|
||||
std::vector<ck::index_t> conv_filter_strides,
|
||||
std::vector<ck::index_t> conv_filter_dilations,
|
||||
std::vector<ck::index_t> input_left_pads,
|
||||
std::vector<ck::index_t> input_right_pads)
|
||||
std::vector<ck::index_t> input_right_pads,
|
||||
InElementwiseOperation in_element_op,
|
||||
WeiElementwiseOperation wei_element_op,
|
||||
OutElementwiseOperation out_element_op)
|
||||
{
|
||||
return Argument{p_in_grid,
|
||||
p_wei_grid,
|
||||
@@ -550,7 +584,10 @@ struct DeviceConvFwdXdl<
|
||||
input_left_pads,
|
||||
input_right_pads,
|
||||
1,
|
||||
1};
|
||||
1,
|
||||
in_element_op,
|
||||
wei_element_op,
|
||||
out_element_op};
|
||||
}
|
||||
|
||||
static auto MakeInvoker() { return Invoker{}; }
|
||||
@@ -569,7 +606,10 @@ struct DeviceConvFwdXdl<
|
||||
std::vector<ck::index_t> conv_filter_strides,
|
||||
std::vector<ck::index_t> conv_filter_dilations,
|
||||
std::vector<ck::index_t> input_left_pads,
|
||||
std::vector<ck::index_t> input_right_pads) override
|
||||
std::vector<ck::index_t> input_right_pads,
|
||||
InElementwiseOperation in_element_op,
|
||||
WeiElementwiseOperation wei_element_op,
|
||||
OutElementwiseOperation out_element_op) override
|
||||
{
|
||||
return std::make_unique<Argument>(static_cast<const InDataType*>(p_in_grid),
|
||||
static_cast<const WeiDataType*>(p_wei_grid),
|
||||
@@ -585,7 +625,10 @@ struct DeviceConvFwdXdl<
|
||||
input_left_pads,
|
||||
input_right_pads,
|
||||
1,
|
||||
1);
|
||||
1,
|
||||
in_element_op,
|
||||
wei_element_op,
|
||||
out_element_op);
|
||||
}
|
||||
|
||||
// polymorphic
|
||||
@@ -593,7 +636,7 @@ struct DeviceConvFwdXdl<
|
||||
{
|
||||
return std::make_unique<Invoker>(Invoker{});
|
||||
}
|
||||
};
|
||||
}; // namespace device
|
||||
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
#define DEVICE_CONV_INSTANTCE_HPP
|
||||
|
||||
#include "device_conv.hpp"
|
||||
#include "element_wise_operation.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
@@ -15,7 +16,10 @@ template <ck::index_t NDimSpatial,
|
||||
typename InLayout,
|
||||
typename WeiLayout,
|
||||
typename OutLayout>
|
||||
void add_device_conv_fwd_instance(std::vector<DeviceConvFwdPtr>&);
|
||||
void add_device_conv_fwd_instance(
|
||||
std::vector<DeviceConvFwdPtr<ck::tensor_operation::element_wise::PassThrough,
|
||||
ck::tensor_operation::element_wise::PassThrough,
|
||||
ck::tensor_operation::element_wise::PassThrough>>&);
|
||||
|
||||
template <ck::index_t NDimSpatial,
|
||||
typename InDataType,
|
||||
@@ -24,7 +28,10 @@ template <ck::index_t NDimSpatial,
|
||||
typename InLayout,
|
||||
typename WeiLayout,
|
||||
typename OutLayout>
|
||||
void add_device_conv_bwd_instance(std::vector<DeviceConvBwdPtr>&);
|
||||
void add_device_conv_bwd_instance(
|
||||
std::vector<DeviceConvBwdPtr<ck::tensor_operation::element_wise::PassThrough,
|
||||
ck::tensor_operation::element_wise::PassThrough,
|
||||
ck::tensor_operation::element_wise::PassThrough>>&);
|
||||
|
||||
template <ck::index_t NDimSpatial,
|
||||
typename InDataType,
|
||||
@@ -33,7 +40,10 @@ template <ck::index_t NDimSpatial,
|
||||
typename InLayout,
|
||||
typename WeiLayout,
|
||||
typename OutLayout>
|
||||
void add_device_conv_wrw_instance(std::vector<DeviceConvWrwPtr>&);
|
||||
void add_device_conv_wrw_instance(
|
||||
std::vector<DeviceConvWrwPtr<ck::tensor_operation::element_wise::PassThrough,
|
||||
ck::tensor_operation::element_wise::PassThrough,
|
||||
ck::tensor_operation::element_wise::PassThrough>>&);
|
||||
|
||||
} // namespace device_conv_instance
|
||||
} // namespace device
|
||||
|
||||
@@ -8,22 +8,33 @@ namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
|
||||
template <typename AElementwiseOperation,
|
||||
typename BElementwiseOperation,
|
||||
typename CElementwiseOperation>
|
||||
struct DeviceGemm : public BaseOperator
|
||||
{
|
||||
virtual std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a,
|
||||
const void* p_b,
|
||||
void* p_c,
|
||||
ck::index_t M,
|
||||
ck::index_t N,
|
||||
ck::index_t K,
|
||||
ck::index_t StrideA,
|
||||
ck::index_t StrideB,
|
||||
ck::index_t StrideC) = 0;
|
||||
virtual std::unique_ptr<BaseArgument>
|
||||
MakeArgumentPointer(const void* p_a,
|
||||
const void* p_b,
|
||||
void* p_c,
|
||||
ck::index_t M,
|
||||
ck::index_t N,
|
||||
ck::index_t K,
|
||||
ck::index_t StrideA,
|
||||
ck::index_t StrideB,
|
||||
ck::index_t StrideC,
|
||||
AElementwiseOperation a_element_op,
|
||||
BElementwiseOperation b_element_op,
|
||||
CElementwiseOperation c_element_op) = 0;
|
||||
|
||||
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
|
||||
};
|
||||
|
||||
using DeviceGemmPtr = std::unique_ptr<DeviceGemm>;
|
||||
template <typename AElementwiseOperation,
|
||||
typename BElementwiseOperation,
|
||||
typename CElementwiseOperation>
|
||||
using DeviceGemmPtr = std::unique_ptr<
|
||||
DeviceGemm<AElementwiseOperation, BElementwiseOperation, CElementwiseOperation>>;
|
||||
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
#define DEVICE_GEMM_INSTANTCE_HPP
|
||||
|
||||
#include "device_gemm.hpp"
|
||||
#include "element_wise_operation.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
@@ -14,7 +15,10 @@ template <typename ADataType,
|
||||
typename ALayout,
|
||||
typename BLayout,
|
||||
typename CLayout>
|
||||
void add_device_gemm_instance(std::vector<DeviceGemmPtr>&);
|
||||
void add_device_gemm_instance(
|
||||
std::vector<DeviceGemmPtr<ck::tensor_operation::element_wise::PassThrough,
|
||||
ck::tensor_operation::element_wise::PassThrough,
|
||||
ck::tensor_operation::element_wise::PassThrough>>&);
|
||||
|
||||
} // namespace device_gemm_instance
|
||||
} // namespace device
|
||||
|
||||
@@ -22,6 +22,9 @@ template <typename ADataType,
|
||||
typename ALayout,
|
||||
typename BLayout,
|
||||
typename CLayout,
|
||||
typename AElementwiseOperation,
|
||||
typename BElementwiseOperation,
|
||||
typename CElementwiseOperation,
|
||||
ck::index_t BlockSize,
|
||||
ck::index_t MPerBlock,
|
||||
ck::index_t NPerBlock,
|
||||
@@ -49,7 +52,8 @@ template <typename ADataType,
|
||||
ck::index_t CThreadTransferDstScalarPerVector,
|
||||
bool ABlockLdsAddExtraM,
|
||||
bool BBlockLdsAddExtraN>
|
||||
struct DeviceGemmXdl : public DeviceGemm
|
||||
struct DeviceGemmXdl
|
||||
: public DeviceGemm<AElementwiseOperation, BElementwiseOperation, CElementwiseOperation>
|
||||
{
|
||||
static constexpr auto I0 = Number<0>{};
|
||||
static constexpr auto I1 = Number<1>{};
|
||||
@@ -176,6 +180,9 @@ struct DeviceGemmXdl : public DeviceGemm
|
||||
AGridDesc_K0_M_K1,
|
||||
BGridDesc_K0_N_K1,
|
||||
CGridDesc_M_N,
|
||||
AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
CElementwiseOperation,
|
||||
MPerBlock,
|
||||
NPerBlock,
|
||||
K0PerBlock,
|
||||
@@ -230,7 +237,10 @@ struct DeviceGemmXdl : public DeviceGemm
|
||||
index_t StrideB,
|
||||
index_t StrideC,
|
||||
index_t M01,
|
||||
index_t N01)
|
||||
index_t N01,
|
||||
AElementwiseOperation a_element_op,
|
||||
BElementwiseOperation b_element_op,
|
||||
CElementwiseOperation c_element_op)
|
||||
: p_a_grid_{p_a_grid},
|
||||
p_b_grid_{p_b_grid},
|
||||
p_c_grid_{p_c_grid},
|
||||
@@ -240,7 +250,10 @@ struct DeviceGemmXdl : public DeviceGemm
|
||||
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_{},
|
||||
block_2_ctile_map_{},
|
||||
M01_{M01},
|
||||
N01_{N01}
|
||||
N01_{N01},
|
||||
a_element_op_{a_element_op},
|
||||
b_element_op_{b_element_op},
|
||||
c_element_op_{c_element_op}
|
||||
{
|
||||
a_grid_desc_k0_m_k1_ = DeviceGemmXdl::MakeAGridDescriptor_K0_M_K1(M, K, StrideA);
|
||||
b_grid_desc_k0_n_k1_ = DeviceGemmXdl::MakeBGridDescriptor_K0_N_K1(K, N, StrideB);
|
||||
@@ -267,6 +280,9 @@ struct DeviceGemmXdl : public DeviceGemm
|
||||
Block2CTileMap block_2_ctile_map_;
|
||||
index_t M01_;
|
||||
index_t N01_;
|
||||
AElementwiseOperation a_element_op_;
|
||||
BElementwiseOperation b_element_op_;
|
||||
CElementwiseOperation c_element_op_;
|
||||
};
|
||||
|
||||
// Invoker
|
||||
@@ -316,6 +332,9 @@ struct DeviceGemmXdl : public DeviceGemm
|
||||
remove_reference_t<DeviceGemmXdl::AGridDesc_K0_M_K1>,
|
||||
remove_reference_t<DeviceGemmXdl::BGridDesc_K0_N_K1>,
|
||||
remove_reference_t<DeviceGemmXdl::CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2>,
|
||||
AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
CElementwiseOperation,
|
||||
remove_reference_t<DeviceGemmXdl::Block2CTileMap>,
|
||||
true>;
|
||||
|
||||
@@ -330,6 +349,9 @@ struct DeviceGemmXdl : public DeviceGemm
|
||||
arg.a_grid_desc_k0_m_k1_,
|
||||
arg.b_grid_desc_k0_n_k1_,
|
||||
arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_,
|
||||
arg.a_element_op_,
|
||||
arg.b_element_op_,
|
||||
arg.c_element_op_,
|
||||
arg.block_2_ctile_map_);
|
||||
}
|
||||
else
|
||||
@@ -341,6 +363,9 @@ struct DeviceGemmXdl : public DeviceGemm
|
||||
remove_reference_t<DeviceGemmXdl::AGridDesc_K0_M_K1>,
|
||||
remove_reference_t<DeviceGemmXdl::BGridDesc_K0_N_K1>,
|
||||
remove_reference_t<DeviceGemmXdl::CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2>,
|
||||
AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
CElementwiseOperation,
|
||||
remove_reference_t<DeviceGemmXdl::Block2CTileMap>,
|
||||
false>;
|
||||
|
||||
@@ -355,6 +380,9 @@ struct DeviceGemmXdl : public DeviceGemm
|
||||
arg.a_grid_desc_k0_m_k1_,
|
||||
arg.b_grid_desc_k0_n_k1_,
|
||||
arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_,
|
||||
arg.a_element_op_,
|
||||
arg.b_element_op_,
|
||||
arg.c_element_op_,
|
||||
arg.block_2_ctile_map_);
|
||||
}
|
||||
|
||||
@@ -397,9 +425,25 @@ struct DeviceGemmXdl : public DeviceGemm
|
||||
index_t K,
|
||||
index_t StrideA,
|
||||
index_t StrideB,
|
||||
index_t StrideC)
|
||||
index_t StrideC,
|
||||
AElementwiseOperation a_element_op,
|
||||
BElementwiseOperation b_element_op,
|
||||
CElementwiseOperation c_element_op)
|
||||
{
|
||||
return Argument{p_a, p_b, p_c, M, N, K, StrideA, StrideB, StrideC, 1, 1};
|
||||
return Argument{p_a,
|
||||
p_b,
|
||||
p_c,
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
StrideA,
|
||||
StrideB,
|
||||
StrideC,
|
||||
1,
|
||||
1,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
c_element_op};
|
||||
}
|
||||
|
||||
static auto MakeInvoker() { return Invoker{}; }
|
||||
@@ -413,7 +457,10 @@ struct DeviceGemmXdl : public DeviceGemm
|
||||
index_t K,
|
||||
index_t StrideA,
|
||||
index_t StrideB,
|
||||
index_t StrideC) override
|
||||
index_t StrideC,
|
||||
AElementwiseOperation a_element_op,
|
||||
BElementwiseOperation b_element_op,
|
||||
CElementwiseOperation c_element_op) override
|
||||
{
|
||||
return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
|
||||
static_cast<const BDataType*>(p_b),
|
||||
@@ -425,7 +472,10 @@ struct DeviceGemmXdl : public DeviceGemm
|
||||
StrideB,
|
||||
StrideC,
|
||||
1,
|
||||
1);
|
||||
1,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
c_element_op);
|
||||
}
|
||||
|
||||
// polymorphic
|
||||
|
||||
20
device_operation/include/element_wise_operation.hpp
Normal file
20
device_operation/include/element_wise_operation.hpp
Normal file
@@ -0,0 +1,20 @@
|
||||
#ifndef ELEMENT_WISE_OPERATION_HPP
#define ELEMENT_WISE_OPERATION_HPP

namespace ck {
namespace tensor_operation {
namespace element_wise {

/// @brief Identity element-wise operation: hands back whatever it is given.
///
/// The functor is stateless, so it can be default-constructed and copied
/// freely. The call operator is a template, so one instance can be applied
/// to values of any copyable type.
struct PassThrough
{
    /// @brief Return the argument unchanged.
    /// @tparam T Type of the element; it is taken and returned by value.
    /// @param v  Element to forward.
    /// @return   A copy of @p v.
    ///
    /// Marked `__host__ __device__` so the same functor is callable from both
    /// host and device code, and `constexpr` so it can be evaluated at
    /// compile time when the argument is a constant expression.
    template <typename T>
    __host__ __device__ constexpr T operator()(T v) const
    {
        return v;
    }
};

} // namespace element_wise
} // namespace tensor_operation
} // namespace ck
#endif
|
||||
Reference in New Issue
Block a user