Fix bug of layernorm ckProfiler and refine code (#448)

* Fix bug of profiler for layernorm

* 1. Rename layernorm into normalization
2. Decouple softmax from normalization

* clang-format
This commit is contained in:
rocking5566
2022-10-13 10:06:39 +08:00
committed by GitHub
parent a8236c1912
commit 1b62bfaa2a
29 changed files with 423 additions and 461 deletions

View File

@@ -11,33 +11,6 @@
namespace ck {
namespace tensor_operation {
namespace device {
struct DeviceNormalization : public BaseOperator
{
// inLengths: input tensor extent(s) from high to low dimension
// inStrides: input tensor stride(s) from high to low dimension
// reduceDims: the dimension(s) the normalization operation is applied
// alpha: typeless pointer in host memory storing the alpha scaling value of type AccDataType
// beta: typeless pointer in host memory storing the beta scaling value of type AccDataType
// in_dev: typeless const pointer in device memory storing the input tensor
// out_dev: typeless pointer in device memory storing the output tensor
virtual std::unique_ptr<BaseArgument> MakeArgumentPointer(const std::vector<index_t> inLengths,
const std::vector<index_t> inStrides,
const std::vector<int> reduceDims,
const void* alpha,
const void* beta,
const void* in_dev,
void* out_dev) = 0;
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
virtual index_t GetRank() const = 0;
virtual index_t GetNumReduceDim() const = 0;
};
using DeviceNormalizationPtr = std::unique_ptr<DeviceNormalization>;
template <typename XDataType,
typename GammaDataType,
typename BetaDataType,
@@ -46,7 +19,7 @@ template <typename XDataType,
typename AccElementwiseOperation,
index_t Rank,
index_t NumReduceDim>
struct DeviceLayernorm : public BaseOperator
struct DeviceNormalization : public BaseOperator
{
virtual std::unique_ptr<BaseArgument>
MakeArgumentPointer(const std::vector<index_t> lengths,
@@ -73,14 +46,14 @@ template <typename XDataType,
typename AccElementwiseOperation,
index_t Rank,
index_t NumReduceDim>
using DeviceLayernormPtr = std::unique_ptr<DeviceLayernorm<XDataType,
GammaDataType,
BetaDataType,
AccDataType,
YDataType,
AccElementwiseOperation,
Rank,
NumReduceDim>>;
using DeviceNormalizationPtr = std::unique_ptr<DeviceNormalization<XDataType,
GammaDataType,
BetaDataType,
AccDataType,
YDataType,
AccElementwiseOperation,
Rank,
NumReduceDim>>;
} // namespace device
} // namespace tensor_operation

View File

@@ -75,14 +75,14 @@ template <typename XDataType,
index_t BetaSrcVectorDim,
index_t BetaSrcVectorSize,
index_t YDstVectorSize>
struct DeviceLayernormImpl : public DeviceLayernorm<XDataType,
GammaDataType,
BetaDataType,
AccDataType,
YDataType,
AccElementwiseOperation,
Rank,
NumReduceDim>
struct DeviceNormalizationImpl : public DeviceNormalization<XDataType,
GammaDataType,
BetaDataType,
AccDataType,
YDataType,
AccElementwiseOperation,
Rank,
NumReduceDim>
{
static_assert(
((GammaSrcVectorDim == 0 && MThreadSliceSize % GammaSrcVectorSize == 0) ||
@@ -452,7 +452,7 @@ struct DeviceLayernormImpl : public DeviceLayernorm<XDataType,
auto str = std::stringstream();
// clang-format off
str << "DeviceLayernormImpl<" << BlockSize << ",";
str << "DeviceNormalizationImpl<" << BlockSize << ",";
str << "M_C" << MThreadClusterSize << "_S" << MThreadSliceSize << ",";
str << "K_C" << KThreadClusterSize << "_S" << KThreadSliceSize << ",";
str << "XYSrcVectorDim_" << XYSrcVectorDim << ",";