mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-03 21:21:22 +00:00
ckProfiler for layernorm (#330)
* Refine parameter * Add base class for layernorm * Add layernorm instance * Add layernorm to ckProfiler * Remove redundant * Add verification * Fix compile error due to merge
This commit is contained in:
@@ -7,7 +7,7 @@
|
||||
#include <sstream>
|
||||
|
||||
#include "ck/utility/reduction_operator.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/device_base.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/device_normalization.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/device_reduce.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/device_reduce_multiblock.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/device_reduce_common.hpp"
|
||||
@@ -39,7 +39,14 @@ template <typename XDataType,
|
||||
index_t GammaSrcVectorSize,
|
||||
index_t BetaSrcVectorSize,
|
||||
index_t YDstVectorSize>
|
||||
struct DeviceLayernorm : public BaseOperator
|
||||
struct DeviceLayernorm : public DeviceNormalization2<XDataType,
|
||||
GammaDataType,
|
||||
BetaDataType,
|
||||
AccDataType,
|
||||
YDataType,
|
||||
AccElementwiseOperation,
|
||||
Rank,
|
||||
NumReduceDim>
|
||||
{
|
||||
static_assert(
|
||||
(KThreadSliceSize % GammaSrcVectorSize == 0),
|
||||
@@ -297,17 +304,18 @@ struct DeviceLayernorm : public BaseOperator
|
||||
return true;
|
||||
};
|
||||
|
||||
std::unique_ptr<BaseArgument> MakeArgumentPointer(const std::vector<index_t> lengths,
|
||||
const std::vector<index_t> xStrides,
|
||||
const std::vector<index_t> gammaStrides,
|
||||
const std::vector<index_t> betaStrides,
|
||||
const std::vector<index_t> reduceDims,
|
||||
AccDataType epsilon,
|
||||
const void* p_x,
|
||||
const void* p_gamma,
|
||||
const void* p_beta,
|
||||
void* p_y,
|
||||
AccElementwiseOperation acc_elementwise_op)
|
||||
std::unique_ptr<BaseArgument>
|
||||
MakeArgumentPointer(const std::vector<index_t> lengths,
|
||||
const std::vector<index_t> xStrides,
|
||||
const std::vector<index_t> gammaStrides,
|
||||
const std::vector<index_t> betaStrides,
|
||||
const std::vector<index_t> reduceDims,
|
||||
AccDataType epsilon,
|
||||
const void* p_x,
|
||||
const void* p_gamma,
|
||||
const void* p_beta,
|
||||
void* p_y,
|
||||
AccElementwiseOperation acc_elementwise_op) override
|
||||
{
|
||||
return std::make_unique<Argument>(lengths,
|
||||
xStrides,
|
||||
@@ -322,7 +330,10 @@ struct DeviceLayernorm : public BaseOperator
|
||||
static_cast<YDataType*>(p_y));
|
||||
};
|
||||
|
||||
std::unique_ptr<BaseInvoker> MakeInvokerPointer() { return std::make_unique<Invoker>(); };
|
||||
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
|
||||
{
|
||||
return std::make_unique<Invoker>();
|
||||
};
|
||||
|
||||
std::string GetTypeString() const override
|
||||
{
|
||||
@@ -332,7 +343,6 @@ struct DeviceLayernorm : public BaseOperator
|
||||
str << "DeviceLayernorm<" << BlockSize << ",";
|
||||
str << "M_C" << MThreadClusterSize << "_S" << MThreadSliceSize << ",";
|
||||
str << "K_C" << KThreadClusterSize << "_S" << KThreadSliceSize << ",";
|
||||
str << "K_C" << KThreadClusterSize << "_S" << KThreadSliceSize << ",";
|
||||
str << "XYSrcVectorDim_" << XYSrcVectorDim << ",";
|
||||
str << "VectorSize_X" << XSrcVectorSize << "_Gamma" << GammaSrcVectorSize << "_Beta" << BetaSrcVectorSize << "_Y" << YDstVectorSize << ">";
|
||||
// clang-format on
|
||||
|
||||
@@ -38,6 +38,49 @@ struct DeviceNormalization : public BaseOperator
|
||||
|
||||
using DeviceNormalizationPtr = std::unique_ptr<DeviceNormalization>;
|
||||
|
||||
template <typename XDataType,
|
||||
typename GammaDataType,
|
||||
typename BetaDataType,
|
||||
typename AccDataType,
|
||||
typename YDataType,
|
||||
typename AccElementwiseOperation,
|
||||
index_t Rank,
|
||||
index_t NumReduceDim>
|
||||
struct DeviceNormalization2 : public BaseOperator
|
||||
{
|
||||
virtual std::unique_ptr<BaseArgument>
|
||||
MakeArgumentPointer(const std::vector<index_t> lengths,
|
||||
const std::vector<index_t> xStrides,
|
||||
const std::vector<index_t> gammaStrides,
|
||||
const std::vector<index_t> betaStrides,
|
||||
const std::vector<index_t> reduceDims,
|
||||
AccDataType epsilon,
|
||||
const void* p_x,
|
||||
const void* p_gamma,
|
||||
const void* p_beta,
|
||||
void* p_y,
|
||||
AccElementwiseOperation acc_elementwise_op) = 0;
|
||||
|
||||
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
|
||||
};
|
||||
|
||||
template <typename XDataType,
|
||||
typename GammaDataType,
|
||||
typename BetaDataType,
|
||||
typename AccDataType,
|
||||
typename YDataType,
|
||||
typename AccElementwiseOperation,
|
||||
index_t Rank,
|
||||
index_t NumReduceDim>
|
||||
using DeviceNormalization2Ptr = std::unique_ptr<DeviceNormalization2<XDataType,
|
||||
GammaDataType,
|
||||
BetaDataType,
|
||||
AccDataType,
|
||||
YDataType,
|
||||
AccElementwiseOperation,
|
||||
Rank,
|
||||
NumReduceDim>>;
|
||||
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
|
||||
Reference in New Issue
Block a user