mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-03 21:21:22 +00:00
Improve normalization (#580)
* Sync the order of the type string with the template parameters
* Add more instances
* Check the vector size and remove a redundant variable
* Extract variables to static, in preparation for separating the sweep-once kernel
* Separate the sweep-once flow and optimize it
* 1. Rename AccDataType in normalization to ComputeDataType 2. Rename AccElementwiseOperation to YElementwiseOperation in normalization
* Remove useless code
* Update naive variance kernel
* Refine string
* Fix typo
* Support naive variance for device_normalization
* Check the block size
* Share the VGPR of x and y
* Share the VGPR of gamma and beta
* Add more instances
* Support fp16 sqrt for experiment
* Add CHANGELOG
* Fix typo
* clang-format
This commit is contained in:
@@ -14,9 +14,9 @@ namespace device {
|
||||
template <typename XDataType,
|
||||
typename GammaDataType,
|
||||
typename BetaDataType,
|
||||
typename AccDataType,
|
||||
typename ComputeDataType,
|
||||
typename YDataType,
|
||||
typename AccElementwiseOperation,
|
||||
typename YElementwiseOperation,
|
||||
index_t Rank,
|
||||
index_t NumReduceDim>
|
||||
struct DeviceNormalization : public BaseOperator
|
||||
@@ -35,7 +35,7 @@ struct DeviceNormalization : public BaseOperator
|
||||
void* p_y,
|
||||
void* p_savedMean,
|
||||
void* p_savedInvVar,
|
||||
AccElementwiseOperation acc_elementwise_op) = 0;
|
||||
YElementwiseOperation y_elementwise_op) = 0;
|
||||
|
||||
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
|
||||
};
|
||||
@@ -43,17 +43,17 @@ struct DeviceNormalization : public BaseOperator
|
||||
template <typename XDataType,
|
||||
typename GammaDataType,
|
||||
typename BetaDataType,
|
||||
typename AccDataType,
|
||||
typename ComputeDataType,
|
||||
typename YDataType,
|
||||
typename AccElementwiseOperation,
|
||||
typename YElementwiseOperation,
|
||||
index_t Rank,
|
||||
index_t NumReduceDim>
|
||||
using DeviceNormalizationPtr = std::unique_ptr<DeviceNormalization<XDataType,
|
||||
GammaDataType,
|
||||
BetaDataType,
|
||||
AccDataType,
|
||||
ComputeDataType,
|
||||
YDataType,
|
||||
AccElementwiseOperation,
|
||||
YElementwiseOperation,
|
||||
Rank,
|
||||
NumReduceDim>>;
|
||||
|
||||
|
||||
@@ -10,46 +10,11 @@
|
||||
#include "ck/tensor_operation/gpu/device/device_normalization.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/device_reduce.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_reduce_common.hpp"
|
||||
#include "ck/tensor_operation/gpu/grid/gridwise_normalization_welford_variance.hpp"
|
||||
#include "ck/tensor_operation/gpu/grid/gridwise_normalization_selector.hpp"
|
||||
#include "ck/tensor_operation/gpu/grid/gridwise_set_buffer_value.hpp"
|
||||
#include "ck/host_utility/device_prop.hpp"
|
||||
#include "ck/host_utility/kernel_launch.hpp"
|
||||
|
||||
namespace ck {
|
||||
// Kernel entry point (`__global__`) for normalization.
//
// The body does no work of its own: it forwards every argument, unchanged and
// in the same order, to the static member `GridwiseReduction::Run`. All of the
// actual computation therefore lives in whichever `GridwiseReduction` type the
// caller instantiates this template with.
//
// Template parameters:
//   GridwiseReduction       - type providing the static `Run(...)` that is called.
//   XDataType, GammaDataType,
//   BetaDataType, YDataType - element types of the four global pointers.
//   AccDataType             - type of `epsilon`.
//   AccElementwiseOperation - type of the functor passed through as the last argument.
//   GridDesc_M_K            - type shared by all four grid descriptors.
//
// Parameters:
//   x/gamma/beta/y_grid_desc_m_k - descriptors for the four tensors, passed by value.
//   num_k_block_tile_iteration   - iteration count handed to `Run`
//                                  (NOTE(review): presumably the number of tiles along K
//                                  each block walks - confirm in GridwiseReduction::Run).
//   epsilon                      - NOTE(review): presumably the variance stabilizer - confirm.
//   p_x_global, p_gamma_global,
//   p_beta_global                - read-only input pointers (declared `const T* const __restrict__`).
//   p_y_global                   - output pointer (non-const pointee, `__restrict__`).
//   acc_elementwise_op           - functor forwarded to `Run`.
//
// Note: the `;` after the closing brace is an empty declaration and has no effect.
template <typename GridwiseReduction,
|
||||
typename XDataType,
|
||||
typename GammaDataType,
|
||||
typename BetaDataType,
|
||||
typename YDataType,
|
||||
typename AccDataType,
|
||||
typename AccElementwiseOperation,
|
||||
typename GridDesc_M_K>
|
||||
__global__ void kernel_normalization(const GridDesc_M_K x_grid_desc_m_k,
|
||||
const GridDesc_M_K gamma_grid_desc_m_k,
|
||||
const GridDesc_M_K beta_grid_desc_m_k,
|
||||
const GridDesc_M_K y_grid_desc_m_k,
|
||||
index_t num_k_block_tile_iteration,
|
||||
AccDataType epsilon,
|
||||
const XDataType* const __restrict__ p_x_global,
|
||||
const GammaDataType* const __restrict__ p_gamma_global,
|
||||
const BetaDataType* const __restrict__ p_beta_global,
|
||||
YDataType* const __restrict__ p_y_global,
|
||||
const AccElementwiseOperation acc_elementwise_op)
|
||||
{
|
||||
GridwiseReduction::Run(x_grid_desc_m_k,
|
||||
gamma_grid_desc_m_k,
|
||||
beta_grid_desc_m_k,
|
||||
y_grid_desc_m_k,
|
||||
num_k_block_tile_iteration,
|
||||
epsilon,
|
||||
p_x_global,
|
||||
p_gamma_global,
|
||||
p_beta_global,
|
||||
p_y_global,
|
||||
acc_elementwise_op);
|
||||
};
|
||||
} // namespace ck
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
@@ -58,9 +23,9 @@ namespace device {
|
||||
template <typename XDataType,
|
||||
typename GammaDataType,
|
||||
typename BetaDataType,
|
||||
typename AccDataType,
|
||||
typename ComputeDataType,
|
||||
typename YDataType,
|
||||
typename AccElementwiseOperation,
|
||||
typename YElementwiseOperation,
|
||||
index_t Rank,
|
||||
index_t NumReduceDim,
|
||||
index_t BlockSize,
|
||||
@@ -74,16 +39,18 @@ template <typename XDataType,
|
||||
index_t GammaSrcVectorSize,
|
||||
index_t BetaSrcVectorDim,
|
||||
index_t BetaSrcVectorSize,
|
||||
index_t YDstVectorSize>
|
||||
index_t YDstVectorSize,
|
||||
bool UseWelford = true>
|
||||
struct DeviceNormalizationImpl : public DeviceNormalization<XDataType,
|
||||
GammaDataType,
|
||||
BetaDataType,
|
||||
AccDataType,
|
||||
ComputeDataType,
|
||||
YDataType,
|
||||
AccElementwiseOperation,
|
||||
YElementwiseOperation,
|
||||
Rank,
|
||||
NumReduceDim>
|
||||
{
|
||||
static_assert(BlockSize == MThreadClusterSize * KThreadClusterSize);
|
||||
static_assert(
|
||||
((GammaSrcVectorDim == 0 && MThreadSliceSize % GammaSrcVectorSize == 0) ||
|
||||
(GammaSrcVectorDim == 1 && KThreadSliceSize % GammaSrcVectorSize == 0)),
|
||||
@@ -167,51 +134,6 @@ struct DeviceNormalizationImpl : public DeviceNormalization<XDataType,
|
||||
|
||||
using GridDesc_M_K = decltype(MakeSrc2dDescriptor({1}, {1}, 1, 1));
|
||||
|
||||
// Alias for `GridwiseNormalizationWelfordVariance_mk_to_mk` instantiated with the
// enclosing struct's data types, elementwise op, `GridDesc_M_K`, block/cluster/slice
// sizes and vector dims/sizes, with the final template argument set to `false`.
//
// `XYSrcVectorDim` is passed twice on purpose: once directly before `XSrcVectorSize`
// and again directly before `YDstVectorSize`, so X and Y use the same vector dimension.
//
// This is the instantiation `Invoker::Run` picks when `arg.isSweeponce_` is false.
// NOTE(review): the trailing bool looks like a "sweep once" switch (the sibling alias
// `GridwiseNormalizationSweepOnce` differs only by passing `true`) - confirm against the
// declaration of `GridwiseNormalizationWelfordVariance_mk_to_mk`.
using GridwiseReduceLayernormGeneric =
|
||||
GridwiseNormalizationWelfordVariance_mk_to_mk<XDataType,
|
||||
GammaDataType,
|
||||
BetaDataType,
|
||||
YDataType,
|
||||
AccDataType,
|
||||
AccElementwiseOperation,
|
||||
GridDesc_M_K,
|
||||
BlockSize,
|
||||
MThreadClusterSize,
|
||||
KThreadClusterSize,
|
||||
MThreadSliceSize,
|
||||
KThreadSliceSize,
|
||||
XYSrcVectorDim,
|
||||
XSrcVectorSize,
|
||||
GammaSrcVectorDim,
|
||||
GammaSrcVectorSize,
|
||||
BetaSrcVectorDim,
|
||||
BetaSrcVectorSize,
|
||||
XYSrcVectorDim,
|
||||
YDstVectorSize,
|
||||
false>;
|
||||
// Alias for `GridwiseNormalizationWelfordVariance_mk_to_mk` instantiated with the
// enclosing struct's data types, elementwise op, `GridDesc_M_K`, block/cluster/slice
// sizes and vector dims/sizes, with the final template argument set to `true`.
//
// `XYSrcVectorDim` is passed twice on purpose: once directly before `XSrcVectorSize`
// and again directly before `YDstVectorSize`, so X and Y use the same vector dimension.
//
// This is the instantiation `Invoker::Run` picks when `arg.isSweeponce_` is true.
// NOTE(review): the trailing bool looks like a "sweep once" switch (the sibling alias
// `GridwiseReduceLayernormGeneric` differs only by passing `false`) - confirm against the
// declaration of `GridwiseNormalizationWelfordVariance_mk_to_mk`.
using GridwiseNormalizationSweepOnce =
|
||||
GridwiseNormalizationWelfordVariance_mk_to_mk<XDataType,
|
||||
GammaDataType,
|
||||
BetaDataType,
|
||||
YDataType,
|
||||
AccDataType,
|
||||
AccElementwiseOperation,
|
||||
GridDesc_M_K,
|
||||
BlockSize,
|
||||
MThreadClusterSize,
|
||||
KThreadClusterSize,
|
||||
MThreadSliceSize,
|
||||
KThreadSliceSize,
|
||||
XYSrcVectorDim,
|
||||
XSrcVectorSize,
|
||||
GammaSrcVectorDim,
|
||||
GammaSrcVectorSize,
|
||||
BetaSrcVectorDim,
|
||||
BetaSrcVectorSize,
|
||||
XYSrcVectorDim,
|
||||
YDstVectorSize,
|
||||
true>;
|
||||
|
||||
struct Argument : public BaseArgument
|
||||
{
|
||||
Argument(const std::vector<index_t> lengths,
|
||||
@@ -220,7 +142,7 @@ struct DeviceNormalizationImpl : public DeviceNormalization<XDataType,
|
||||
const std::vector<index_t> betaStrides,
|
||||
const std::vector<index_t> yStrides,
|
||||
const std::vector<index_t> reduceDims,
|
||||
AccElementwiseOperation acc_elementwise_op,
|
||||
YElementwiseOperation y_elementwise_op,
|
||||
double epsilon,
|
||||
const XDataType* p_x,
|
||||
const GammaDataType* p_gamma,
|
||||
@@ -230,9 +152,9 @@ struct DeviceNormalizationImpl : public DeviceNormalization<XDataType,
|
||||
p_gamma_(p_gamma),
|
||||
p_beta_(p_beta),
|
||||
p_y_(p_y),
|
||||
acc_elementwise_op_(acc_elementwise_op)
|
||||
y_elementwise_op_(y_elementwise_op)
|
||||
{
|
||||
epsilon_ = static_cast<AccDataType>(epsilon);
|
||||
epsilon_ = static_cast<ComputeDataType>(epsilon);
|
||||
|
||||
Lengths_ = shuffle_tensor_dimensions<Rank, NumReduceDim>(lengths, reduceDims);
|
||||
xStrides_ = shuffle_tensor_dimensions<Rank, NumReduceDim>(xStrides, reduceDims);
|
||||
@@ -265,7 +187,7 @@ struct DeviceNormalizationImpl : public DeviceNormalization<XDataType,
|
||||
x_grid_desc_m_k_.GetLength(Number<1>{}) <= KThreadClusterSize * KThreadSliceSize;
|
||||
}
|
||||
|
||||
AccDataType epsilon_;
|
||||
ComputeDataType epsilon_;
|
||||
|
||||
const XDataType* p_x_;
|
||||
const GammaDataType* p_gamma_;
|
||||
@@ -278,7 +200,7 @@ struct DeviceNormalizationImpl : public DeviceNormalization<XDataType,
|
||||
std::vector<index_t> betaStrides_;
|
||||
std::vector<index_t> yStrides_;
|
||||
|
||||
AccElementwiseOperation acc_elementwise_op_;
|
||||
YElementwiseOperation y_elementwise_op_;
|
||||
|
||||
int blkGroupSize_;
|
||||
int numBlockTileIteration_;
|
||||
@@ -295,23 +217,27 @@ struct DeviceNormalizationImpl : public DeviceNormalization<XDataType,
|
||||
{
|
||||
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
|
||||
{
|
||||
const auto kernel_main = arg.isSweeponce_
|
||||
? kernel_normalization<GridwiseNormalizationSweepOnce,
|
||||
XDataType,
|
||||
GammaDataType,
|
||||
BetaDataType,
|
||||
YDataType,
|
||||
AccDataType,
|
||||
AccElementwiseOperation,
|
||||
GridDesc_M_K>
|
||||
: kernel_normalization<GridwiseReduceLayernormGeneric,
|
||||
XDataType,
|
||||
GammaDataType,
|
||||
BetaDataType,
|
||||
YDataType,
|
||||
AccDataType,
|
||||
AccElementwiseOperation,
|
||||
GridDesc_M_K>;
|
||||
auto kernel_main = NormalizationKernelSelector<XDataType,
|
||||
GammaDataType,
|
||||
BetaDataType,
|
||||
YDataType,
|
||||
ComputeDataType,
|
||||
YElementwiseOperation,
|
||||
GridDesc_M_K,
|
||||
BlockSize,
|
||||
MThreadClusterSize,
|
||||
KThreadClusterSize,
|
||||
MThreadSliceSize,
|
||||
KThreadSliceSize,
|
||||
XYSrcVectorDim,
|
||||
XSrcVectorSize,
|
||||
GammaSrcVectorDim,
|
||||
GammaSrcVectorSize,
|
||||
BetaSrcVectorDim,
|
||||
BetaSrcVectorSize,
|
||||
XYSrcVectorDim,
|
||||
YDstVectorSize,
|
||||
UseWelford>(arg.isSweeponce_);
|
||||
|
||||
float avg_time = 0;
|
||||
avg_time += launch_and_time_kernel(stream_config,
|
||||
@@ -329,7 +255,7 @@ struct DeviceNormalizationImpl : public DeviceNormalization<XDataType,
|
||||
arg.p_gamma_,
|
||||
arg.p_beta_,
|
||||
arg.p_y_,
|
||||
arg.acc_elementwise_op_);
|
||||
arg.y_elementwise_op_);
|
||||
|
||||
return (avg_time);
|
||||
};
|
||||
@@ -429,7 +355,7 @@ struct DeviceNormalizationImpl : public DeviceNormalization<XDataType,
|
||||
void* p_y,
|
||||
void* p_saveMean,
|
||||
void* p_saveInvVar,
|
||||
AccElementwiseOperation acc_elementwise_op) override
|
||||
YElementwiseOperation y_elementwise_op) override
|
||||
{
|
||||
// TODO
|
||||
// Optional cache of the intermediate results (mean and InvVariance) during the
|
||||
@@ -443,7 +369,7 @@ struct DeviceNormalizationImpl : public DeviceNormalization<XDataType,
|
||||
betaStrides,
|
||||
yStrides,
|
||||
reduceDims,
|
||||
acc_elementwise_op,
|
||||
y_elementwise_op,
|
||||
epsilon,
|
||||
static_cast<const XDataType*>(p_x),
|
||||
static_cast<const GammaDataType*>(p_gamma),
|
||||
@@ -462,8 +388,8 @@ struct DeviceNormalizationImpl : public DeviceNormalization<XDataType,
|
||||
|
||||
// clang-format off
|
||||
str << "DeviceNormalizationImpl<" << BlockSize << ",";
|
||||
str << "M_C" << MThreadClusterSize << "_S" << MThreadSliceSize << ",";
|
||||
str << "K_C" << KThreadClusterSize << "_S" << KThreadSliceSize << ",";
|
||||
str << "Cluster_MK_" << MThreadClusterSize << "_" << KThreadClusterSize << ",";
|
||||
str << "Slice_MK_" << MThreadSliceSize << "_" << KThreadSliceSize << ",";
|
||||
str << "XYSrcVectorDim_" << XYSrcVectorDim << ",";
|
||||
str << "VectorSize_X" << XSrcVectorSize << "_Gamma" << GammaSrcVectorSize << "_Beta" << BetaSrcVectorSize << "_Y" << YDstVectorSize << ">";
|
||||
// clang-format on
|
||||
|
||||
Reference in New Issue
Block a user