gemm + layernorm (#261)

* Implement reduction meand and reduction square mean

* Refine file name

* Add reduce mean and square mean

* Fix parameter name

* Add normalize device op (not implement invoker::run())

* Remove epislon

* Refine deviceop

* Add 5ary elementwise for normalization

* Add layernorm example

* layerNorm verication

* Fix compiler error due to merge from develop

* Fix typo

* Fix compile error

* Refine naming

* [What] Suport non pointer for invoker and argument
[Why] Snyc coding style with gemm

* Refine folder name

* Refine class name

* Evaluate perf of the kernel

* Fix compile error

* [What] Refine perf evaluation in example of gemm + reduction
[Why] evaluation of gemm + reduction may cause verification fail. Because evaluation will not initial global memory

* clang-format
This commit is contained in:
rocking5566
2022-05-31 05:36:55 +08:00
committed by GitHub
parent 91d8b7d67a
commit d32a67a9b6
23 changed files with 1130 additions and 99 deletions

View File

@@ -0,0 +1,333 @@
#pragma once
#include <iostream>
#include <sstream>
#include "device.hpp"
#include "device_base.hpp"
#include "common_header.hpp"
#include "gridwise_5ary_Elementwise_1d.hpp"
#include "tensor_layout.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
template <typename ADataType,
typename BDataType,
typename CDataType,
typename DDataType,
typename EDataType,
typename FDataType,
typename ComputeDataType,
typename ElementwiseFunctor,
index_t NDim,
index_t MPerThread,
index_t AScalarPerVector,
index_t BScalarPerVector,
index_t CScalarPerVector,
index_t DScalarPerVector,
index_t EScalarPerVector,
index_t FScalarPerVector>
struct Device5AryElementwise : public BaseOperator
{
static constexpr auto I0 = Number<0>{};
template <typename Desc_M>
static auto PadDescriptor_M_1d(Desc_M desc_m, index_t gridSize, index_t blockSize)
{
const auto m = desc_m.GetLength(I0);
const index_t loop_step = gridSize * blockSize * MPerThread;
const auto pad = math::integer_least_multiple(m, loop_step) - m;
const auto desc_m_pad =
transform_tensor_descriptor(desc_m,
make_tuple(make_right_pad_transform(m, pad)),
make_tuple(Sequence<0>{}),
make_tuple(Sequence<0>{}));
return desc_m_pad;
}
static auto MakeDescriptor_M(const std::vector<index_t>& lengths,
const std::vector<index_t>& stride,
index_t gridSize,
index_t blockSize)
{
auto tupleOfShape = generate_tuple([&](auto I) { return lengths[I]; }, Number<NDim>{});
auto tupleOfStride = generate_tuple([&](auto I) { return stride[I]; }, Number<NDim>{});
// nd desc - [s0, s1, s2, ...]
const auto desc = make_naive_tensor_descriptor(tupleOfShape, tupleOfStride);
// merge nd to 1d desc - [s0 * s1 * ...]
if constexpr(NDim > 1)
{
const auto desc_m = transform_tensor_descriptor(
desc,
make_tuple(make_merge_transform(tupleOfShape)),
make_tuple(generate_sequence_v2([&](auto I) { return I; }, Number<NDim>{})),
make_tuple(Sequence<0>{}));
return PadDescriptor_M_1d(desc_m, gridSize, blockSize);
}
else
return PadDescriptor_M_1d(desc, gridSize, blockSize);
}
using AGridDesc_M = decltype(MakeDescriptor_M({1, 1}, {1, 1}, 1, 1));
using BGridDesc_M = decltype(MakeDescriptor_M({1, 1}, {1, 1}, 1, 1));
using CGridDesc_M = decltype(MakeDescriptor_M({1, 1}, {1, 1}, 1, 1));
using DGridDesc_M = decltype(MakeDescriptor_M({1, 1}, {1, 1}, 1, 1));
using EGridDesc_M = decltype(MakeDescriptor_M({1, 1}, {1, 1}, 1, 1));
using FGridDesc_M = decltype(MakeDescriptor_M({1, 1}, {1, 1}, 1, 1));
using Gridwise5AryEltwise = Gridwise5AryElementwise_1D<ADataType,
BDataType,
CDataType,
DDataType,
EDataType,
FDataType,
ComputeDataType,
AGridDesc_M,
BGridDesc_M,
CGridDesc_M,
DGridDesc_M,
EGridDesc_M,
FGridDesc_M,
ElementwiseFunctor,
MPerThread,
AScalarPerVector,
BScalarPerVector,
CScalarPerVector,
DScalarPerVector,
EScalarPerVector,
FScalarPerVector>;
struct Argument : public BaseArgument
{
Argument(const ADataType* p_a,
const BDataType* p_b,
const CDataType* p_c,
const DDataType* p_d,
const EDataType* p_e,
FDataType* p_f,
const std::vector<index_t>& lengths,
const std::vector<index_t>& a_strides,
const std::vector<index_t>& b_strides,
const std::vector<index_t>& c_strides,
const std::vector<index_t>& d_strides,
const std::vector<index_t>& e_strides,
const std::vector<index_t>& f_strides,
ElementwiseFunctor functor)
: p_a_(p_a),
p_b_(p_b),
p_c_(p_c),
p_d_(p_d),
p_e_(p_e),
p_f_(p_f),
lengths_(lengths),
a_strides_(a_strides),
b_strides_(b_strides),
c_strides_(c_strides),
d_strides_(d_strides),
e_strides_(e_strides),
f_strides_(f_strides),
functor_(functor),
blockSize_(256),
gridSize_(120) // FIXME - Calculate the grid size by number of CU in the future
{
a_grid_desc_m_ = MakeDescriptor_M(lengths, a_strides, gridSize_, blockSize_);
b_grid_desc_m_ = MakeDescriptor_M(lengths, b_strides, gridSize_, blockSize_);
c_grid_desc_m_ = MakeDescriptor_M(lengths, c_strides, gridSize_, blockSize_);
d_grid_desc_m_ = MakeDescriptor_M(lengths, d_strides, gridSize_, blockSize_);
e_grid_desc_m_ = MakeDescriptor_M(lengths, e_strides, gridSize_, blockSize_);
f_grid_desc_m_ = MakeDescriptor_M(lengths, f_strides, gridSize_, blockSize_);
}
const ADataType* p_a_;
const BDataType* p_b_;
const CDataType* p_c_;
const DDataType* p_d_;
const EDataType* p_e_;
FDataType* p_f_;
std::vector<index_t> lengths_;
AGridDesc_M a_grid_desc_m_;
BGridDesc_M b_grid_desc_m_;
CGridDesc_M c_grid_desc_m_;
DGridDesc_M d_grid_desc_m_;
EGridDesc_M e_grid_desc_m_;
FGridDesc_M f_grid_desc_m_;
std::vector<index_t> a_strides_;
std::vector<index_t> b_strides_;
std::vector<index_t> c_strides_;
std::vector<index_t> d_strides_;
std::vector<index_t> e_strides_;
std::vector<index_t> f_strides_;
ElementwiseFunctor functor_;
index_t blockSize_;
index_t gridSize_;
};
struct Invoker : public BaseInvoker
{
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{
const auto kernel = kernel_5ary_elementwise_1d<Gridwise5AryEltwise,
ADataType,
BDataType,
CDataType,
DDataType,
EDataType,
FDataType,
AGridDesc_M,
BGridDesc_M,
CGridDesc_M,
DGridDesc_M,
EGridDesc_M,
FGridDesc_M,
ElementwiseFunctor>;
float elapsed_time = launch_and_time_kernel(stream_config,
kernel,
dim3(arg.gridSize_),
dim3(arg.blockSize_),
0,
arg.p_a_,
arg.p_b_,
arg.p_c_,
arg.p_d_,
arg.p_e_,
arg.p_f_,
arg.a_grid_desc_m_,
arg.b_grid_desc_m_,
arg.c_grid_desc_m_,
arg.d_grid_desc_m_,
arg.e_grid_desc_m_,
arg.f_grid_desc_m_,
arg.functor_);
return elapsed_time;
}
// polymorphic
float Run(const BaseArgument* p_arg,
const StreamConfig& stream_config = StreamConfig{}) override
{
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
}
};
bool IsSupportedArgument(const BaseArgument& p_arg) { return IsSupportedArgument(&p_arg); }
bool IsSupportedArgument(const BaseArgument* p_arg) override
{
const Argument* pArg = dynamic_cast<const Argument*>(p_arg);
if(pArg == nullptr)
return false;
if(pArg->lengths_.size() != NDim)
return false;
if(pArg->lengths_.back() % MPerThread != 0)
return false;
auto IsScalarPerVectorValid = [](bool isLastDimensionCoalesced, int scalarPerVector) {
bool ret = true;
if(!isLastDimensionCoalesced)
ret = scalarPerVector == 1;
else
ret = MPerThread % scalarPerVector == 0;
return ret;
};
if(!IsScalarPerVectorValid(pArg->a_strides_.back() == 1, AScalarPerVector))
return false;
if(!IsScalarPerVectorValid(pArg->b_strides_.back() == 1, BScalarPerVector))
return false;
if(!IsScalarPerVectorValid(pArg->c_strides_.back() == 1, CScalarPerVector))
return false;
if(!IsScalarPerVectorValid(pArg->d_strides_.back() == 1, DScalarPerVector))
return false;
if(!IsScalarPerVectorValid(pArg->e_strides_.back() == 1, EScalarPerVector))
return false;
if(!IsScalarPerVectorValid(pArg->f_strides_.back() == 1, FScalarPerVector))
return false;
return true;
};
static auto MakeArgument(const ADataType* p_a,
const BDataType* p_b,
const CDataType* p_c,
const DDataType* p_d,
const EDataType* p_e,
FDataType* p_f,
std::vector<index_t> lengths,
std::vector<index_t> a_strides,
std::vector<index_t> b_strides,
std::vector<index_t> c_strides,
std::vector<index_t> d_strides,
std::vector<index_t> e_strides,
std::vector<index_t> f_strides,
ElementwiseFunctor functor)
{
return Argument{p_a,
p_b,
p_c,
p_d,
p_e,
p_f,
lengths,
a_strides,
b_strides,
c_strides,
d_strides,
e_strides,
f_strides,
functor};
}
std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a,
const void* p_b,
const void* p_c,
const void* p_d,
const void* p_e,
void* p_f,
std::vector<index_t> lengths,
std::vector<index_t> a_strides,
std::vector<index_t> b_strides,
std::vector<index_t> c_strides,
std::vector<index_t> d_strides,
std::vector<index_t> e_strides,
std::vector<index_t> f_strides,
ElementwiseFunctor functor)
{
return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
static_cast<const BDataType*>(p_b),
static_cast<const CDataType*>(p_c),
static_cast<const DDataType*>(p_d),
static_cast<const EDataType*>(p_e),
static_cast<FDataType*>(p_f),
lengths,
a_strides,
b_strides,
c_strides,
d_strides,
e_strides,
f_strides,
functor);
}
static auto MakeInvoker() { return Invoker{}; }
std::unique_ptr<BaseInvoker> MakeInvokerPointer() { return std::make_unique<Invoker>(); }
}; // namespace device
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -22,7 +22,7 @@ template <typename GridwiseGemm,
typename BElementwiseOperation,
typename CElementwiseOperation,
typename DxsInElementwiseOperation,
typename DxsOutElementwiseOperation,
typename DxsAccElementwiseOperation,
typename AGridDesc_AK0_M_AK1,
typename BGridDesc_BK0_N_BK1,
typename CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
@@ -44,7 +44,7 @@ __global__ void
const BElementwiseOperation b_element_op,
const CElementwiseOperation c_element_op,
const DxsInElementwiseOperation dxs_in_element_op,
const DxsOutElementwiseOperation dxs_out_element_op,
const DxsAccElementwiseOperation dxs_out_element_op,
const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1,
const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1,
const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
@@ -126,7 +126,7 @@ template <typename ALayout,
typename CElementwiseOperation,
typename DxsReduceOperation,
typename DxsInElementwiseOperation,
typename DxsOutElementwiseOperation,
typename DxsAccElementwiseOperation,
typename DGlobalMemoryDataOperation,
GemmSpecialization GemmSpec,
index_t NumGemmKPrefetchStage,
@@ -167,7 +167,7 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGloba
BElementwiseOperation,
CElementwiseOperation,
DxsInElementwiseOperation,
DxsOutElementwiseOperation>
DxsAccElementwiseOperation>
{
using DeviceOp = DeviceBatchedGemmReduce_Xdl_CShuffle;
@@ -527,7 +527,7 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGloba
CElementwiseOperation,
DxsReduceOperation,
DxsInElementwiseOperation,
DxsOutElementwiseOperation,
DxsAccElementwiseOperation,
InMemoryDataOperationEnum::Set,
DGlobalMemoryDataOperation,
AGridDesc_AK0_M_AK1,
@@ -587,7 +587,7 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGloba
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op,
DxsInElementwiseOperation dxs_in_element_op,
DxsOutElementwiseOperation dxs_out_element_op,
DxsAccElementwiseOperation dxs_out_element_op,
index_t BatchCount)
: p_a_grid_{p_a_grid},
p_b_grid_{p_b_grid},
@@ -645,7 +645,7 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGloba
BElementwiseOperation b_element_op_;
CElementwiseOperation c_element_op_;
DxsInElementwiseOperation dxs_in_element_op_;
DxsOutElementwiseOperation dxs_out_element_op_;
DxsAccElementwiseOperation dxs_out_element_op_;
};
// Invoker
@@ -703,7 +703,7 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGloba
BElementwiseOperation,
CElementwiseOperation,
DxsInElementwiseOperation,
DxsOutElementwiseOperation,
DxsAccElementwiseOperation,
DeviceOp::AGridDesc_AK0_M_AK1,
DeviceOp::BGridDesc_BK0_N_BK1,
typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
@@ -746,7 +746,7 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGloba
BElementwiseOperation,
CElementwiseOperation,
DxsInElementwiseOperation,
DxsOutElementwiseOperation,
DxsAccElementwiseOperation,
DeviceOp::AGridDesc_AK0_M_AK1,
DeviceOp::BGridDesc_BK0_N_BK1,
typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
@@ -832,7 +832,7 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGloba
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op,
DxsInElementwiseOperation dxs_in_element_op,
DxsOutElementwiseOperation dxs_out_element_op,
DxsAccElementwiseOperation dxs_out_element_op,
index_t BatchCount)
{
return Argument{p_a,
@@ -870,7 +870,7 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGloba
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op,
DxsInElementwiseOperation dxs_in_element_op,
DxsOutElementwiseOperation dxs_out_element_op,
DxsAccElementwiseOperation dxs_out_element_op,
index_t BatchCount) override
{
return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),

View File

@@ -11,7 +11,7 @@ template <typename DPtrsGlobal,
typename BElementwiseOperation,
typename CElementwiseOperation,
typename DxsInElementwiseOperation,
typename DxsOutElementwiseOperation>
typename DxsAccElementwiseOperation>
struct DeviceGemmReduce : public BaseOperator
{
virtual std::unique_ptr<BaseArgument>
@@ -29,7 +29,7 @@ struct DeviceGemmReduce : public BaseOperator
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op,
DxsInElementwiseOperation dxs_in_element_op,
DxsOutElementwiseOperation dxs_out_element_op,
DxsAccElementwiseOperation dxs_out_element_op,
ck::index_t BatchCount = 1) = 0;
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
@@ -40,13 +40,13 @@ template <typename DPtrsGlobal,
typename BElementwiseOperation,
typename CElementwiseOperation,
typename DxsInElementwiseOperation,
typename DxsOutElementwiseOperation>
typename DxsAccElementwiseOperation>
using DeviceGemmReducePtr = std::unique_ptr<DeviceGemmReduce<DPtrsGlobal,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation,
DxsInElementwiseOperation,
DxsOutElementwiseOperation>>;
DxsAccElementwiseOperation>>;
} // namespace device
} // namespace tensor_operation

View File

@@ -32,7 +32,7 @@ template <typename ALayout,
typename CElementwiseOperation,
typename DxsReduceOperation,
typename DxsInElementwiseOperation,
typename DxsOutElementwiseOperation,
typename DxsAccElementwiseOperation,
typename DGlobalMemoryDataOperation,
GemmSpecialization GemmSpec,
index_t NumGemmKPrefetchStage,
@@ -73,7 +73,7 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGlobal,
BElementwiseOperation,
CElementwiseOperation,
DxsInElementwiseOperation,
DxsOutElementwiseOperation>
DxsAccElementwiseOperation>
{
using DeviceOp = DeviceGemmReduce_Xdl_CShuffle;
@@ -389,7 +389,7 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGlobal,
CElementwiseOperation,
DxsReduceOperation,
DxsInElementwiseOperation,
DxsOutElementwiseOperation,
DxsAccElementwiseOperation,
InMemoryDataOperationEnum::Set,
DGlobalMemoryDataOperation,
AGridDesc_AK0_M_AK1,
@@ -449,7 +449,7 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGlobal,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op,
DxsInElementwiseOperation dxs_in_element_op,
DxsOutElementwiseOperation dxs_out_element_op)
DxsAccElementwiseOperation dxs_out_element_op)
: p_a_grid_{p_a_grid},
p_b_grid_{p_b_grid},
p_c_grid_{p_c_grid},
@@ -498,7 +498,7 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGlobal,
BElementwiseOperation b_element_op_;
CElementwiseOperation c_element_op_;
DxsInElementwiseOperation dxs_in_element_op_;
DxsOutElementwiseOperation dxs_out_element_op_;
DxsAccElementwiseOperation dxs_out_element_op_;
};
// Invoker
@@ -554,7 +554,7 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGlobal,
BElementwiseOperation,
CElementwiseOperation,
DxsInElementwiseOperation,
DxsOutElementwiseOperation,
DxsAccElementwiseOperation,
DeviceOp::AGridDesc_AK0_M_AK1,
DeviceOp::BGridDesc_BK0_N_BK1,
typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
@@ -594,7 +594,7 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGlobal,
BElementwiseOperation,
CElementwiseOperation,
DxsInElementwiseOperation,
DxsOutElementwiseOperation,
DxsAccElementwiseOperation,
DeviceOp::AGridDesc_AK0_M_AK1,
DeviceOp::BGridDesc_BK0_N_BK1,
typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
@@ -669,7 +669,7 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGlobal,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op,
DxsInElementwiseOperation dxs_in_element_op,
DxsOutElementwiseOperation dxs_out_element_op)
DxsAccElementwiseOperation dxs_out_element_op)
{
return Argument{p_a,
p_b,
@@ -705,7 +705,7 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGlobal,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op,
DxsInElementwiseOperation dxs_in_element_op,
DxsOutElementwiseOperation dxs_out_element_op,
DxsAccElementwiseOperation dxs_out_element_op,
index_t /* KBatch */ = 1) override
{
return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),