mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-05 14:11:29 +00:00
Fix bug of layernorm ckProfiler and refine code (#448)
* Fix bug of profiler for layernorm
* 1. Rename layernorm into normalization
  2. Decouple softmax from normalization
* clang-format
This commit is contained in:
@@ -11,33 +11,6 @@
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
|
||||
// Abstract, non-template device-operation interface derived from BaseOperator.
// Every member declared here is pure virtual; concrete operators supply the bodies.
struct DeviceNormalization : public BaseOperator
|
||||
{
|
||||
// Parameters of MakeArgumentPointer (declared below):
// inLengths: input tensor extent(s) from high to low dimension
|
||||
// inStrides: input tensor stride(s) from high to low dimension
|
||||
// reduceDims: the dimension(s) the normalization operation is applied
|
||||
// alpha: typeless pointer in host memory storing the alpha scaling value of type AccDataType
|
||||
// beta: typeless pointer in host memory storing the beta scaling value of type AccDataType
|
||||
// in_dev: typeless const pointer in device memory storing the input tensor
|
||||
// out_dev: typeless pointer in device memory storing the output tensor
|
||||
// Returns an owning pointer to a BaseArgument built from the parameters listed above.
// Note: the vectors are taken by value, so each call copies them.
virtual std::unique_ptr<BaseArgument> MakeArgumentPointer(const std::vector<index_t> inLengths,
|
||||
const std::vector<index_t> inStrides,
|
||||
const std::vector<int> reduceDims,
|
||||
const void* alpha,
|
||||
const void* beta,
|
||||
const void* in_dev,
|
||||
void* out_dev) = 0;
|
||||
|
||||
// Returns an owning pointer to a BaseInvoker; takes no arguments.
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
|
||||
|
||||
// Presumably the tensor rank handled by the concrete operator — TODO confirm against implementations.
virtual index_t GetRank() const = 0;
|
||||
|
||||
// Presumably the number of reduced dimensions (cf. reduceDims) — TODO confirm against implementations.
virtual index_t GetNumReduceDim() const = 0;
|
||||
};
|
||||
|
||||
// Owning (std::unique_ptr) handle to a DeviceNormalization object.
using DeviceNormalizationPtr = std::unique_ptr<DeviceNormalization>;
|
||||
|
||||
template <typename XDataType,
|
||||
typename GammaDataType,
|
||||
typename BetaDataType,
|
||||
@@ -46,7 +19,7 @@ template <typename XDataType,
|
||||
typename AccElementwiseOperation,
|
||||
index_t Rank,
|
||||
index_t NumReduceDim>
|
||||
struct DeviceLayernorm : public BaseOperator
|
||||
struct DeviceNormalization : public BaseOperator
|
||||
{
|
||||
virtual std::unique_ptr<BaseArgument>
|
||||
MakeArgumentPointer(const std::vector<index_t> lengths,
|
||||
@@ -73,14 +46,14 @@ template <typename XDataType,
|
||||
typename AccElementwiseOperation,
|
||||
index_t Rank,
|
||||
index_t NumReduceDim>
|
||||
using DeviceLayernormPtr = std::unique_ptr<DeviceLayernorm<XDataType,
|
||||
GammaDataType,
|
||||
BetaDataType,
|
||||
AccDataType,
|
||||
YDataType,
|
||||
AccElementwiseOperation,
|
||||
Rank,
|
||||
NumReduceDim>>;
|
||||
using DeviceNormalizationPtr = std::unique_ptr<DeviceNormalization<XDataType,
|
||||
GammaDataType,
|
||||
BetaDataType,
|
||||
AccDataType,
|
||||
YDataType,
|
||||
AccElementwiseOperation,
|
||||
Rank,
|
||||
NumReduceDim>>;
|
||||
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
|
||||
@@ -75,14 +75,14 @@ template <typename XDataType,
|
||||
index_t BetaSrcVectorDim,
|
||||
index_t BetaSrcVectorSize,
|
||||
index_t YDstVectorSize>
|
||||
struct DeviceLayernormImpl : public DeviceLayernorm<XDataType,
|
||||
GammaDataType,
|
||||
BetaDataType,
|
||||
AccDataType,
|
||||
YDataType,
|
||||
AccElementwiseOperation,
|
||||
Rank,
|
||||
NumReduceDim>
|
||||
struct DeviceNormalizationImpl : public DeviceNormalization<XDataType,
|
||||
GammaDataType,
|
||||
BetaDataType,
|
||||
AccDataType,
|
||||
YDataType,
|
||||
AccElementwiseOperation,
|
||||
Rank,
|
||||
NumReduceDim>
|
||||
{
|
||||
static_assert(
|
||||
((GammaSrcVectorDim == 0 && MThreadSliceSize % GammaSrcVectorSize == 0) ||
|
||||
@@ -452,7 +452,7 @@ struct DeviceLayernormImpl : public DeviceLayernorm<XDataType,
|
||||
auto str = std::stringstream();
|
||||
|
||||
// clang-format off
|
||||
str << "DeviceLayernormImpl<" << BlockSize << ",";
|
||||
str << "DeviceNormalizationImpl<" << BlockSize << ",";
|
||||
str << "M_C" << MThreadClusterSize << "_S" << MThreadSliceSize << ",";
|
||||
str << "K_C" << KThreadClusterSize << "_S" << KThreadSliceSize << ",";
|
||||
str << "XYSrcVectorDim_" << XYSrcVectorDim << ",";
|
||||
Reference in New Issue
Block a user