Mirror of https://github.com/ROCm/composable_kernel.git, synced 2026-05-05 14:11:29 +00:00
Backward of gamma and beta for layernorm and groupnorm (#1013)
* Add layernorm backward reference code
* Add groupnorm backward reference code
* Add example
* clang format
* Fix bug of reference layernorm and groupnorm
* Fix naming
* Refine naming
* Add device op for normalization bwd gamma and beta
* Refine template parameter
* Add bwd gamma & beta kernel
* Add groupnorm example; refine layernorm naming
* Narrow down the static check for performance
* Refine variable name
@@ -0,0 +1,61 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include <iostream>
#include <vector>

#include "ck/tensor_operation/gpu/device/device_base.hpp"

namespace ck {
namespace tensor_operation {
namespace device {

template <typename DYDataType,
          typename XDataType,
          typename MeanInvStdDataType,
          typename DGammaDataType,
          typename DBetaDataType,
          index_t Rank,
          index_t NumReduceDim>
struct DeviceNormalizationBwdGammaBeta : public BaseOperator
{
    virtual std::unique_ptr<BaseArgument>
    MakeArgumentPointer(const std::vector<index_t> inLengths,
                        const std::vector<index_t> dyStrides,
                        const std::vector<index_t> xStrides,
                        const std::vector<index_t> meanStrides,
                        const std::vector<index_t> invStdStrides,
                        const std::vector<index_t> outLengths,
                        const std::vector<index_t> dgammaStrides,
                        const std::vector<index_t> dbetaStrides,
                        const std::vector<index_t> reduceDims,
                        const void* p_dy,
                        const void* p_x,
                        const void* p_mean,
                        const void* p_invStd,
                        void* p_dgamma,
                        void* p_dbeta) = 0;

    virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
};

template <typename DYDataType,
          typename XDataType,
          typename MeanInvStdDataType,
          typename DGammaDataType,
          typename DBetaDataType,
          index_t Rank,
          index_t NumReduceDim>
using DeviceNormalizationBwdGammaBetaPtr =
    std::unique_ptr<DeviceNormalizationBwdGammaBeta<DYDataType,
                                                    XDataType,
                                                    MeanInvStdDataType,
                                                    DGammaDataType,
                                                    DBetaDataType,
                                                    Rank,
                                                    NumReduceDim>>;

} // namespace device
} // namespace tensor_operation
} // namespace ck
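Editor's usage sketch (not part of the diff): a concrete instance of the interface above would typically be driven as follows; `device_instance`, the 2D layernorm shapes, and the raw device pointers are assumptions for illustration.

// Assumed setup: x/dy of lengths {N, C}, per-channel dgamma/dbeta of length
// {C}; the batch dimension (dim 0) is the one being reduced, and mean/inv_std
// are viewed as {N, C} tensors broadcast along C via a zero stride.
std::vector<ck::index_t> inLengths   = {N, C};
std::vector<ck::index_t> xyStrides   = {C, 1}; // used for both x and dy here
std::vector<ck::index_t> meanStrides = {1, 0};
std::vector<ck::index_t> outLengths  = {C};
std::vector<ck::index_t> outStrides  = {1};

auto argument_ptr = device_instance.MakeArgumentPointer(
    inLengths, xyStrides, xyStrides, meanStrides, meanStrides,
    outLengths, outStrides, outStrides,
    {0}, // reduceDims: reduce over the batch dimension
    p_dy, p_x, p_mean, p_inv_std, p_dgamma, p_dbeta);

auto invoker_ptr = device_instance.MakeInvokerPointer();

if(device_instance.IsSupportedArgument(argument_ptr.get()))
    invoker_ptr->Run(argument_ptr.get(), StreamConfig{});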
@@ -0,0 +1,464 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include <iostream>
#include <vector>

#include "ck/tensor_operation/gpu/device/device_normalization_bwd_gamma_beta.hpp"
#include "ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_bwd_gamma_beta.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_reduce_common.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"

// M is the invariant dimension, K is the reduced dimension
namespace ck {
namespace tensor_operation {
namespace device {

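// Editor's note (illustrative, not part of the diff): for groupnorm backward
// with input lengths [N, H, W, G, C] and per-channel dgamma/dbeta of lengths
// [G, C], reduceDims = {0, 1, 2}; the invariant dimensions (G, C) merge into
// M and the reduced dimensions (N, H, W) merge into K, making dgamma/dbeta
// length-M results of a K-wise reduction.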
template <typename GridwiseReduction,
          typename DYDataType,
          typename XDataType,
          typename MeanInvStdDataType,
          typename DGammaDataType,
          typename DBetaDataType,
          typename GridDesc_M_K,
          typename GridDesc_M>
__global__ void
kernel_normalization_bwd_gamma_beta(const GridDesc_M_K dy_grid_desc_m_k,
                                    const GridDesc_M_K x_grid_desc_m_k,
                                    const GridDesc_M_K mean_grid_desc_m_k,
                                    const GridDesc_M_K inv_std_grid_desc_m_k,
                                    const GridDesc_M dgamma_grid_desc_m,
                                    const GridDesc_M dbeta_grid_desc_m,
                                    index_t num_k_block_tile_iteration,
                                    const DYDataType* const __restrict__ p_dy_global,
                                    const XDataType* const __restrict__ p_x_global,
                                    const MeanInvStdDataType* const __restrict__ p_mean_global,
                                    const MeanInvStdDataType* const __restrict__ p_inv_std_global,
                                    DGammaDataType* const __restrict__ p_dgamma_global,
                                    DBetaDataType* const __restrict__ p_dbeta_global)
{
    GridwiseReduction::Run(dy_grid_desc_m_k,
                           x_grid_desc_m_k,
                           mean_grid_desc_m_k,
                           inv_std_grid_desc_m_k,
                           dgamma_grid_desc_m,
                           dbeta_grid_desc_m,
                           num_k_block_tile_iteration,
                           p_dy_global,
                           p_x_global,
                           p_mean_global,
                           p_inv_std_global,
                           p_dgamma_global,
                           p_dbeta_global);
}

template <typename DYDataType,
          typename XDataType,
          typename MeanInvStdDataType,
          typename ComputeDataType,
          typename DGammaDataType,
          typename DBetaDataType,
          index_t Rank,
          index_t NumReduceDim,
          index_t BlockSize,
          index_t MThreadClusterSize,
          index_t KThreadClusterSize,
          index_t MThreadSliceSize,
          index_t KThreadSliceSize,
          bool IsDYFastestDimReduced,
          index_t DYSrcVectorSize,
          bool IsXFastestDimReduced,
          index_t XSrcVectorSize,
          bool IsMeanInvStdFastestDimReduced,
          index_t MeanInvStdSrcVectorSize,
          index_t DGammaDstVectorSize,
          index_t DBetaDstVectorSize>
struct DeviceNormalizationBwdGammaBetaImpl
    : public DeviceNormalizationBwdGammaBeta<DYDataType,
                                             XDataType,
                                             MeanInvStdDataType,
                                             DGammaDataType,
                                             DBetaDataType,
                                             Rank,
                                             NumReduceDim>
{
    static constexpr index_t DYSrcVectorDim         = IsDYFastestDimReduced ? 1 : 0;
    static constexpr index_t XSrcVectorDim          = IsXFastestDimReduced ? 1 : 0;
    static constexpr index_t MeanInvStdSrcVectorDim = IsMeanInvStdFastestDimReduced ? 1 : 0;

    static_assert(BlockSize == MThreadClusterSize * KThreadClusterSize);

    static_assert(((DYSrcVectorDim == 0 && MThreadSliceSize % DYSrcVectorSize == 0) ||
                   (DYSrcVectorDim == 1 && KThreadSliceSize % DYSrcVectorSize == 0)),
                  "Invalid thread slice sizes and/or dy vector sizes configuration, please check!");

    static_assert(((XSrcVectorDim == 0 && MThreadSliceSize % XSrcVectorSize == 0) ||
                   (XSrcVectorDim == 1 && KThreadSliceSize % XSrcVectorSize == 0)),
                  "Invalid thread slice sizes and/or x vector sizes configuration, please check!");

    // Both destination vector sizes must divide the M thread slice.
    static_assert(
        ((MThreadSliceSize % DGammaDstVectorSize == 0) &&
         (MThreadSliceSize % DBetaDstVectorSize == 0)),
        "Invalid thread slice sizes and/or dgamma and dbeta vector sizes configuration, please "
        "check!");

    static_assert(
        (MeanInvStdSrcVectorDim == 0 && MThreadSliceSize % MeanInvStdSrcVectorSize == 0) ||
            (MeanInvStdSrcVectorDim == 1 && KThreadSliceSize % MeanInvStdSrcVectorSize == 0),
        "Invalid thread slice sizes and/or mean and inverse std vector sizes configuration, please "
        "check!");

    static constexpr index_t NumInvariantDim = Rank - NumReduceDim;
    static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize;
    static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize;

    static constexpr bool reduceAllDim = (NumInvariantDim == 0);
    static_assert(!reduceAllDim);

    static auto MakeSrc2dDescriptor(const std::vector<index_t>& inLengths,
                                    const std::vector<index_t>& inStrides,
                                    int numBlockTileIteration)
    {
        const auto tupleSrcLengths = make_tuple_from_array(inLengths, Number<Rank>{});
        const auto tupleSrcStrides = make_tuple_from_array(inStrides, Number<Rank>{});

        const auto inDesc = make_naive_tensor_descriptor(tupleSrcLengths, tupleSrcStrides);

        const auto in_grid_desc_m_k = [&]() {
            using InvariantDims = typename arithmetic_sequence_gen<0, NumInvariantDim, 1>::type;
            using ReduceDims    = typename arithmetic_sequence_gen<NumInvariantDim, Rank, 1>::type;

            const auto reduceDimLengths =
                make_tuple_from_array_and_index_seq(inLengths, ReduceDims{});
            const auto invariantDimLengths =
                make_tuple_from_array_and_index_seq(inLengths, InvariantDims{});

            return transform_tensor_descriptor(
                inDesc,
                make_tuple(make_merge_transform(invariantDimLengths),
                           make_merge_transform(reduceDimLengths)),
                make_tuple(InvariantDims{}, ReduceDims{}),
                make_tuple(Sequence<0>{}, Sequence<1>{}));
        }();

        const auto invariantLength = in_grid_desc_m_k.GetLength(Number<0>{});
        const auto reduceLength    = in_grid_desc_m_k.GetLength(Number<1>{});

        const auto inPad_M =
            math::integer_least_multiple(invariantLength, M_BlockTileSize) - invariantLength;
        const auto inPad_K = K_BlockTileSize * numBlockTileIteration - reduceLength;

        auto in_grid_desc_m_k_padded = transform_tensor_descriptor(
            in_grid_desc_m_k,
            make_tuple(make_right_pad_transform(invariantLength, inPad_M),
                       make_right_pad_transform(reduceLength, inPad_K)),
            make_tuple(Sequence<0>{}, Sequence<1>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}));

        return in_grid_desc_m_k_padded;
    }

    static auto MakeDst1dDescriptor(const std::vector<index_t>& outLengths,
                                    const std::vector<index_t>& outStrides)
    {
        const auto tupleDstLengths =
            generate_tuple([&](auto I) { return outLengths[I]; }, Number<NumInvariantDim>{});
        const auto tupleDstStrides =
            generate_tuple([&](auto I) { return outStrides[I]; }, Number<NumInvariantDim>{});

        auto outDesc = make_naive_tensor_descriptor(tupleDstLengths, tupleDstStrides);

        auto out_grid_desc_m = transform_tensor_descriptor(
            outDesc,
            make_tuple(make_merge_transform(tupleDstLengths)),
            make_tuple(typename arithmetic_sequence_gen<0, NumInvariantDim, 1>::type{}),
            make_tuple(Sequence<0>{}));

        const auto invariantLength = out_grid_desc_m.GetLength(Number<0>{});

        const auto outPad =
            math::integer_least_multiple(invariantLength, M_BlockTileSize) - invariantLength;

        auto out_grid_desc_m_padded = transform_tensor_descriptor(
            out_grid_desc_m,
            make_tuple(make_right_pad_transform(invariantLength, outPad)),
            make_tuple(Sequence<0>{}),
            make_tuple(Sequence<0>{}));
        return out_grid_desc_m_padded;
    }

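    // Editor's note with assumed numbers: for invariantLength = 1000 and
    // M_BlockTileSize = 128, inPad_M = integer_least_multiple(1000, 128) -
    // 1000 = 24; for reduceLength = 500, K_BlockTileSize = 64 and
    // numBlockTileIteration = ceil(500 / 64) = 8, inPad_K = 64 * 8 - 500 = 12.
    // The right-pads let every block sweep whole tiles without bounds checks.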
    using GridDesc_M_K = decltype(MakeSrc2dDescriptor({1}, {1}, 1));
    using GridDesc_M   = decltype(MakeDst1dDescriptor({1}, {1}));

    using GridwiseNormalizationBwdGammaBeta =
        GridwiseNormalizationBwdGammaBeta_mk_to_k<DYDataType,
                                                  XDataType,
                                                  MeanInvStdDataType,
                                                  ComputeDataType,
                                                  DGammaDataType,
                                                  DBetaDataType,
                                                  GridDesc_M_K,
                                                  GridDesc_M,
                                                  BlockSize,
                                                  MThreadClusterSize,
                                                  KThreadClusterSize,
                                                  MThreadSliceSize,
                                                  KThreadSliceSize,
                                                  DYSrcVectorDim,
                                                  DYSrcVectorSize,
                                                  XSrcVectorDim,
                                                  XSrcVectorSize,
                                                  MeanInvStdSrcVectorDim,
                                                  MeanInvStdSrcVectorSize,
                                                  DGammaDstVectorSize,
                                                  DBetaDstVectorSize>;

    struct Argument : public BaseArgument
    {
        Argument(const std::vector<index_t> inLengths,
                 const std::vector<index_t> dyStrides,
                 const std::vector<index_t> xStrides,
                 const std::vector<index_t> meanStrides,
                 const std::vector<index_t> invStdStrides,
                 const std::vector<index_t> outLengths,
                 const std::vector<index_t> dgammaStrides,
                 const std::vector<index_t> dbetaStrides,
                 const std::vector<index_t> reduceDims,
                 const DYDataType* p_dy,
                 const XDataType* p_x,
                 const MeanInvStdDataType* p_mean,
                 const MeanInvStdDataType* p_invStd,
                 DGammaDataType* p_dgamma,
                 DBetaDataType* p_dbeta)
            : p_dy_(p_dy),
              p_x_(p_x),
              p_mean_(p_mean),
              p_invStd_(p_invStd),
              p_dgamma_(p_dgamma),
              p_dbeta_(p_dbeta),
              outLengths_{outLengths},
              dgammaStrides_{dgammaStrides},
              dbetaStrides_{dbetaStrides}
        {
            inLengths_   = shuffle_tensor_dimensions<Rank, NumReduceDim>(inLengths, reduceDims);
            dyStrides_   = shuffle_tensor_dimensions<Rank, NumReduceDim>(dyStrides, reduceDims);
            xStrides_    = shuffle_tensor_dimensions<Rank, NumReduceDim>(xStrides, reduceDims);
            meanStrides_ = shuffle_tensor_dimensions<Rank, NumReduceDim>(meanStrides, reduceDims);
            invStdStrides_ =
                shuffle_tensor_dimensions<Rank, NumReduceDim>(invStdStrides, reduceDims);

            std::tie(MRaw_, KRaw_) = get_2d_lengths<Rank, NumReduceDim>(inLengths_);

            numBlockTileIteration_ = math::integer_divide_ceil(KRaw_, K_BlockTileSize);

            gridSize_ = math::integer_divide_ceil(MRaw_, M_BlockTileSize);

            dy_grid_desc_m_k_ = MakeSrc2dDescriptor(inLengths_, dyStrides_, numBlockTileIteration_);
            x_grid_desc_m_k_  = MakeSrc2dDescriptor(inLengths_, xStrides_, numBlockTileIteration_);
            mean_grid_desc_m_k_ =
                MakeSrc2dDescriptor(inLengths_, meanStrides_, numBlockTileIteration_);
            inv_std_grid_desc_m_k_ =
                MakeSrc2dDescriptor(inLengths_, invStdStrides_, numBlockTileIteration_);

            dgamma_grid_desc_m_ = MakeDst1dDescriptor(outLengths_, dgammaStrides_);
            dbeta_grid_desc_m_  = MakeDst1dDescriptor(outLengths_, dbetaStrides_);
        }

        const DYDataType* p_dy_;
        const XDataType* p_x_;
        const MeanInvStdDataType* p_mean_;
        const MeanInvStdDataType* p_invStd_;
        DGammaDataType* p_dgamma_;
        DBetaDataType* p_dbeta_;

        std::vector<index_t> inLengths_;
        std::vector<index_t> dyStrides_;
        std::vector<index_t> xStrides_;
        std::vector<index_t> meanStrides_;
        std::vector<index_t> invStdStrides_;
        std::vector<index_t> outLengths_;
        std::vector<index_t> dgammaStrides_;
        std::vector<index_t> dbetaStrides_;

        int numBlockTileIteration_;
        size_t gridSize_;

        // Source descriptors
        GridDesc_M_K dy_grid_desc_m_k_;
        GridDesc_M_K x_grid_desc_m_k_;
        GridDesc_M_K mean_grid_desc_m_k_;
        GridDesc_M_K inv_std_grid_desc_m_k_;

        // Destination descriptors
        GridDesc_M dgamma_grid_desc_m_;
        GridDesc_M dbeta_grid_desc_m_;

        index_t MRaw_; // invariant length
        index_t KRaw_; // reduce length
    };

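    // Editor's note with assumed numbers: for groupnorm with
    // inLengths = {2, 16, 16, 4, 8} and reduceDims = {0, 1, 2}, the shuffled
    // lengths are {4, 8, 2, 16, 16}, giving MRaw_ = 4 * 8 = 32 and
    // KRaw_ = 2 * 16 * 16 = 512; with M_BlockTileSize = 32 and
    // K_BlockTileSize = 64 this yields gridSize_ = 1 and
    // numBlockTileIteration_ = 8.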
    struct Invoker : public BaseInvoker
    {
        float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
        {
            const auto kernel_main =
                kernel_normalization_bwd_gamma_beta<GridwiseNormalizationBwdGammaBeta,
                                                    DYDataType,
                                                    XDataType,
                                                    MeanInvStdDataType,
                                                    DGammaDataType,
                                                    DBetaDataType,
                                                    GridDesc_M_K,
                                                    GridDesc_M>;

            return launch_and_time_kernel(stream_config,
                                          kernel_main,
                                          dim3(arg.gridSize_),
                                          dim3(BlockSize),
                                          0,
                                          arg.dy_grid_desc_m_k_,
                                          arg.x_grid_desc_m_k_,
                                          arg.mean_grid_desc_m_k_,
                                          arg.inv_std_grid_desc_m_k_,
                                          arg.dgamma_grid_desc_m_,
                                          arg.dbeta_grid_desc_m_,
                                          arg.numBlockTileIteration_,
                                          arg.p_dy_,
                                          arg.p_x_,
                                          arg.p_mean_,
                                          arg.p_invStd_,
                                          arg.p_dgamma_,
                                          arg.p_dbeta_);
        }

        float Run(const BaseArgument* p_arg,
                  const StreamConfig& stream_config = StreamConfig{}) override
        {
            return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
        }
    };

    template <index_t SrcVectorDim, index_t SrcVectorSize>
    bool IsSrcVectorDimSizeValid(const std::vector<index_t>& lengths,
                                 const std::vector<index_t>& strides)
    {
        if constexpr(SrcVectorSize == 1)
            return true;

        // Fastest dimension is not reduced
        if constexpr(SrcVectorDim == 0)
        {
            if constexpr(NumInvariantDim == 0)
                return false;

            if(strides[NumInvariantDim - 1] != 1)
                return false;

            if(lengths[NumInvariantDim - 1] % SrcVectorSize != 0)
                return false;
        }
        else // Fastest dimension is reduced
        {
            if(strides[Rank - 1] != 1)
                return false;

            if(lengths[Rank - 1] % SrcVectorSize != 0)
                return false;
        }

        return true;
    }

    template <index_t DstVectorSize>
    bool IsDstVectorSizeValid(const std::vector<index_t>& lengths,
                              const std::vector<index_t>& strides)
    {
        if constexpr(DstVectorSize == 1)
            return true;

        if(strides[NumInvariantDim - 1] != 1)
            return false;

        if(lengths[NumInvariantDim - 1] % DstVectorSize != 0)
            return false;

        return true;
    }

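    // Editor's note with assumed numbers: for shuffled lengths {32, 64, 512}
    // (Rank = 3, NumInvariantDim = 2) and xStrides {32768, 512, 1},
    // XSrcVectorDim = 1 vectorizes along the reduced fastest dimension;
    // strides[Rank - 1] == 1 holds and lengths[Rank - 1] = 512 is divisible
    // by any power-of-two XSrcVectorSize up to 512, so the check passes.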
    bool IsSupportedArgument(const BaseArgument* p_arg) override
    {
        const Argument* p_arg_ = dynamic_cast<const Argument*>(p_arg);

        bool pass = true;
        pass &= IsSrcVectorDimSizeValid<DYSrcVectorDim, DYSrcVectorSize>(p_arg_->inLengths_,
                                                                         p_arg_->dyStrides_);
        pass &= IsSrcVectorDimSizeValid<XSrcVectorDim, XSrcVectorSize>(p_arg_->inLengths_,
                                                                       p_arg_->xStrides_);
        pass &= IsSrcVectorDimSizeValid<MeanInvStdSrcVectorDim, MeanInvStdSrcVectorSize>(
            p_arg_->inLengths_, p_arg_->meanStrides_);
        pass &= IsSrcVectorDimSizeValid<MeanInvStdSrcVectorDim, MeanInvStdSrcVectorSize>(
            p_arg_->inLengths_, p_arg_->invStdStrides_);

        pass &=
            IsDstVectorSizeValid<DGammaDstVectorSize>(p_arg_->outLengths_, p_arg_->dgammaStrides_);
        pass &=
            IsDstVectorSizeValid<DBetaDstVectorSize>(p_arg_->outLengths_, p_arg_->dbetaStrides_);

        return pass;
    }

    std::unique_ptr<BaseArgument> MakeArgumentPointer(const std::vector<index_t> inLengths,
                                                      const std::vector<index_t> dyStrides,
                                                      const std::vector<index_t> xStrides,
                                                      const std::vector<index_t> meanStrides,
                                                      const std::vector<index_t> invStdStrides,
                                                      const std::vector<index_t> outLengths,
                                                      const std::vector<index_t> dgammaStrides,
                                                      const std::vector<index_t> dbetaStrides,
                                                      const std::vector<index_t> reduceDims,
                                                      const void* p_dy,
                                                      const void* p_x,
                                                      const void* p_mean,
                                                      const void* p_invStd,
                                                      void* p_dgamma,
                                                      void* p_dbeta) override
    {
        if(inLengths.size() != Rank || dyStrides.size() != Rank || xStrides.size() != Rank ||
           meanStrides.size() != Rank || invStdStrides.size() != Rank)
            throw std::runtime_error(
                "input dimension is incorrect: lengths and strides must match Rank");

        if(outLengths.size() != NumInvariantDim || dgammaStrides.size() != NumInvariantDim ||
           dbetaStrides.size() != NumInvariantDim)
            throw std::runtime_error(
                "output dimension is incorrect: lengths and strides must match NumInvariantDim");

        return std::make_unique<Argument>(inLengths,
                                          dyStrides,
                                          xStrides,
                                          meanStrides,
                                          invStdStrides,
                                          outLengths,
                                          dgammaStrides,
                                          dbetaStrides,
                                          reduceDims,
                                          static_cast<const DYDataType*>(p_dy),
                                          static_cast<const XDataType*>(p_x),
                                          static_cast<const MeanInvStdDataType*>(p_mean),
                                          static_cast<const MeanInvStdDataType*>(p_invStd),
                                          static_cast<DGammaDataType*>(p_dgamma),
                                          static_cast<DBetaDataType*>(p_dbeta));
    }

    std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
    {
        return std::make_unique<Invoker>();
    }
};

} // namespace device
} // namespace tensor_operation
} // namespace ck
@@ -0,0 +1,343 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck/utility/data_type.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/block/reduction_functions_blockwise.hpp"

namespace ck {

// dgamma = reduce_sum(dy * (x - mean) * inv_std)
// dbeta  = reduce_sum(dy)
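// Editor's reference sketch (not part of the diff): with dy, x, mean and
// inv_std viewed as [M, K] tensors (mean/inv_std being broadcast views), the
// kernel computes, per output index m:
//
//   for(index_t m = 0; m < M; ++m)
//   {
//       ComputeDataType dgamma = 0;
//       ComputeDataType dbeta  = 0;
//       for(index_t k = 0; k < K; ++k)
//       {
//           dgamma += dy(m, k) * (x(m, k) - mean(m, k)) * inv_std(m, k);
//           dbeta += dy(m, k);
//       }
//       p_dgamma[m] = dgamma;
//       p_dbeta[m]  = dbeta;
//   }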
template <typename DYDataType,
          typename XDataType,
          typename MeanInvStdDataType,
          typename ComputeDataType,
          typename DGammaDataType,
          typename DBetaDataType,
          typename GridDesc_M_K,
          typename GridDesc_M,
          index_t BlockSize,
          index_t MThreadClusterSize,
          index_t KThreadClusterSize,
          index_t MThreadSliceSize,
          index_t KThreadSliceSize,
          index_t DYSrcVectorDim,
          index_t DYSrcVectorSize,
          index_t XSrcVectorDim,
          index_t XSrcVectorSize,
          index_t MeanInvStdSrcVectorDim,
          index_t MeanInvStdSrcVectorSize,
          index_t DGammaDstVectorSize,
          index_t DBetaDstVectorSize>
struct GridwiseNormalizationBwdGammaBeta_mk_to_k
{
    // If we only required ThreadSliceSize % VectorSize == 0, performance could
    // be poor, so the check is narrowed to require equality.
    static_assert(((DYSrcVectorDim == 0 && MThreadSliceSize == DYSrcVectorSize) ||
                   (DYSrcVectorDim == 1 && KThreadSliceSize == DYSrcVectorSize)),
                  "Invalid thread slice sizes and/or dy vector sizes configuration, please check!");

    static_assert(((XSrcVectorDim == 0 && MThreadSliceSize == XSrcVectorSize) ||
                   (XSrcVectorDim == 1 && KThreadSliceSize == XSrcVectorSize)),
                  "Invalid thread slice sizes and/or x vector sizes configuration, please check!");

    using ThreadClusterLengths_M_K = Sequence<MThreadClusterSize, KThreadClusterSize>;

    using DYThreadBufferDimAccessOrder =
        typename conditional<DYSrcVectorDim == 0, Sequence<1, 0>, Sequence<0, 1>>::type;

    using XThreadBufferDimAccessOrder =
        typename conditional<XSrcVectorDim == 0, Sequence<1, 0>, Sequence<0, 1>>::type;

    using MeanInvStdThreadBufferDimAccessOrder =
        typename conditional<MeanInvStdSrcVectorDim == 0, Sequence<1, 0>, Sequence<0, 1>>::type;

    using ThreadClusterArrangeOrder = DYThreadBufferDimAccessOrder;

    static constexpr auto thread_cluster_desc =
        make_cluster_descriptor(ThreadClusterLengths_M_K{}, ThreadClusterArrangeOrder{});

    using ThreadBufferLengths_M_K = Sequence<MThreadSliceSize, KThreadSliceSize>;
    using ThreadBufferLengths_M   = Sequence<MThreadSliceSize>;

    static constexpr auto thread_buffer_desc_m_k = make_naive_tensor_descriptor_packed(
        make_tuple(Number<MThreadSliceSize>{}, Number<KThreadSliceSize>{}));

    static constexpr auto thread_buffer_desc_m =
        make_naive_tensor_descriptor_packed(make_tuple(Number<MThreadSliceSize>{}));

    using PassThroughOp = tensor_operation::element_wise::PassThrough;

    using BlockwiseSumReduce = PartitionedBlockwiseReduction<ComputeDataType,
                                                             BlockSize,
                                                             ThreadClusterLengths_M_K,
                                                             ThreadClusterArrangeOrder,
                                                             reduce::Add,
                                                             true>;

    static constexpr auto I0 = Number<0>{};
    static constexpr auto I1 = Number<1>{};

    static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize;
    static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize;

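    // Editor's note with assumed numbers: BlockSize = 256 split as
    // MThreadClusterSize = 8 by KThreadClusterSize = 32 arranges the block as
    // an 8 x 32 thread grid; with MThreadSliceSize = 2 and KThreadSliceSize = 8
    // each thread owns a 2 x 8 slice, so one block covers a 16 x 256
    // (M_BlockTileSize x K_BlockTileSize) tile per K iteration.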
    __device__ static void Run(const GridDesc_M_K& dy_grid_desc_m_k,
                               const GridDesc_M_K& x_grid_desc_m_k,
                               const GridDesc_M_K& mean_grid_desc_m_k,
                               const GridDesc_M_K& inv_std_grid_desc_m_k,
                               const GridDesc_M& dgamma_grid_desc_m,
                               const GridDesc_M& dbeta_grid_desc_m,
                               index_t num_k_block_tile_iteration,
                               const DYDataType* const __restrict__ p_dy_global,
                               const XDataType* const __restrict__ p_x_global,
                               const MeanInvStdDataType* const __restrict__ p_mean_global,
                               const MeanInvStdDataType* const __restrict__ p_inv_std_global,
                               DGammaDataType* const __restrict__ p_dgamma_global,
                               DBetaDataType* const __restrict__ p_dbeta_global)
    {
        // LDS
        __shared__ ComputeDataType p_reduce_work_buffer[BlockSize];

        auto reduce_work_buf =
            make_dynamic_buffer<AddressSpaceEnum::Lds>(p_reduce_work_buffer, BlockSize);

        // Global
        const auto dy_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_dy_global, dy_grid_desc_m_k.GetElementSpaceSize());

        const auto x_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_x_global, x_grid_desc_m_k.GetElementSpaceSize());

        const auto mean_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_mean_global, mean_grid_desc_m_k.GetElementSpaceSize());

        const auto inv_std_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_inv_std_global, inv_std_grid_desc_m_k.GetElementSpaceSize());

        auto dgamma_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_dgamma_global, dgamma_grid_desc_m.GetElementSpaceSize());

        auto dbeta_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_dbeta_global, dbeta_grid_desc_m.GetElementSpaceSize());

        // VGPR
        auto dy_thread_buf = StaticBuffer<AddressSpaceEnum::Vgpr,
                                          ComputeDataType,
                                          MThreadSliceSize * KThreadSliceSize,
                                          true>{};

        auto x_thread_buf = StaticBuffer<AddressSpaceEnum::Vgpr,
                                         ComputeDataType,
                                         MThreadSliceSize * KThreadSliceSize,
                                         true>{};

        auto mean_thread_buf = StaticBuffer<AddressSpaceEnum::Vgpr,
                                            ComputeDataType,
                                            MThreadSliceSize * KThreadSliceSize,
                                            true>{};

        auto inv_std_thread_buf = StaticBuffer<AddressSpaceEnum::Vgpr,
                                               ComputeDataType,
                                               MThreadSliceSize * KThreadSliceSize,
                                               true>{};

        auto dgamma_thread_buf =
            StaticBuffer<AddressSpaceEnum::Vgpr, ComputeDataType, MThreadSliceSize, true>{};

        auto dbeta_thread_buf =
            StaticBuffer<AddressSpaceEnum::Vgpr, ComputeDataType, MThreadSliceSize, true>{};

        const index_t thread_local_id = get_thread_local_1d_id();
        const index_t block_global_id = get_block_1d_id();

        const auto thread_cluster_idx =
            thread_cluster_desc.CalculateBottomIndex(make_multi_index(thread_local_id));

        const auto thread_m_cluster_id = thread_cluster_idx[I0];
        const auto thread_k_cluster_id = thread_cluster_idx[I1];

        // IO
        auto threadwise_dy_load = ThreadwiseTensorSliceTransfer_v2<DYDataType,
                                                                   ComputeDataType,
                                                                   GridDesc_M_K,
                                                                   decltype(thread_buffer_desc_m_k),
                                                                   ThreadBufferLengths_M_K,
                                                                   DYThreadBufferDimAccessOrder,
                                                                   DYSrcVectorDim,
                                                                   DYSrcVectorSize,
                                                                   1,
                                                                   true>(
            dy_grid_desc_m_k,
            make_multi_index(block_global_id * M_BlockTileSize +
                                 thread_m_cluster_id * MThreadSliceSize,
                             thread_k_cluster_id * KThreadSliceSize));

        auto threadwise_x_load = ThreadwiseTensorSliceTransfer_v2<XDataType,
                                                                  ComputeDataType,
                                                                  GridDesc_M_K,
                                                                  decltype(thread_buffer_desc_m_k),
                                                                  ThreadBufferLengths_M_K,
                                                                  XThreadBufferDimAccessOrder,
                                                                  XSrcVectorDim,
                                                                  XSrcVectorSize,
                                                                  1,
                                                                  true>(
            x_grid_desc_m_k,
            make_multi_index(block_global_id * M_BlockTileSize +
                                 thread_m_cluster_id * MThreadSliceSize,
                             thread_k_cluster_id * KThreadSliceSize));

        auto threadwise_mean_load =
            ThreadwiseTensorSliceTransfer_v2<MeanInvStdDataType,
                                             ComputeDataType,
                                             GridDesc_M_K,
                                             decltype(thread_buffer_desc_m_k),
                                             ThreadBufferLengths_M_K,
                                             MeanInvStdThreadBufferDimAccessOrder,
                                             MeanInvStdSrcVectorDim,
                                             MeanInvStdSrcVectorSize,
                                             1,
                                             true>(
                mean_grid_desc_m_k,
                make_multi_index(block_global_id * M_BlockTileSize +
                                     thread_m_cluster_id * MThreadSliceSize,
                                 thread_k_cluster_id * KThreadSliceSize));

        auto threadwise_inv_std_load =
            ThreadwiseTensorSliceTransfer_v2<MeanInvStdDataType,
                                             ComputeDataType,
                                             GridDesc_M_K,
                                             decltype(thread_buffer_desc_m_k),
                                             ThreadBufferLengths_M_K,
                                             MeanInvStdThreadBufferDimAccessOrder,
                                             MeanInvStdSrcVectorDim,
                                             MeanInvStdSrcVectorSize,
                                             1,
                                             true>(
                inv_std_grid_desc_m_k,
                make_multi_index(block_global_id * M_BlockTileSize +
                                     thread_m_cluster_id * MThreadSliceSize,
                                 thread_k_cluster_id * KThreadSliceSize));

        auto threadwise_dgamma_store =
            ThreadwiseTensorSliceTransfer_v1r3<ComputeDataType,
                                               DGammaDataType,
                                               decltype(thread_buffer_desc_m),
                                               GridDesc_M,
                                               PassThroughOp,
                                               ThreadBufferLengths_M,
                                               Sequence<0>,
                                               0,
                                               DGammaDstVectorSize,
                                               InMemoryDataOperationEnum::Set,
                                               1,
                                               true>(
                dgamma_grid_desc_m,
                make_multi_index(block_global_id * M_BlockTileSize +
                                 thread_m_cluster_id * MThreadSliceSize),
                PassThroughOp{});

        auto threadwise_dbeta_store =
            ThreadwiseTensorSliceTransfer_v1r3<ComputeDataType,
                                               DBetaDataType,
                                               decltype(thread_buffer_desc_m),
                                               GridDesc_M,
                                               PassThroughOp,
                                               ThreadBufferLengths_M,
                                               Sequence<0>,
                                               0,
                                               DBetaDstVectorSize,
                                               InMemoryDataOperationEnum::Set,
                                               1,
                                               true>(
                dbeta_grid_desc_m,
                make_multi_index(block_global_id * M_BlockTileSize +
                                 thread_m_cluster_id * MThreadSliceSize),
                PassThroughOp{});

        static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
            dgamma_thread_buf(I) = type_convert<ComputeDataType>(0.0f);
            dbeta_thread_buf(I)  = type_convert<ComputeDataType>(0.0f);
        });

        constexpr auto thread_copy_fwd_step_m_k = make_multi_index(0, K_BlockTileSize);

        for(index_t reducedTiles = 0; reducedTiles < num_k_block_tile_iteration; ++reducedTiles)
        {
            threadwise_dy_load.Run(dy_grid_desc_m_k,
                                   dy_global_val_buf,
                                   thread_buffer_desc_m_k,
                                   make_tuple(I0, I0),
                                   dy_thread_buf);

            threadwise_x_load.Run(x_grid_desc_m_k,
                                  x_global_val_buf,
                                  thread_buffer_desc_m_k,
                                  make_tuple(I0, I0),
                                  x_thread_buf);

            threadwise_mean_load.Run(mean_grid_desc_m_k,
                                     mean_global_val_buf,
                                     thread_buffer_desc_m_k,
                                     make_tuple(I0, I0),
                                     mean_thread_buf);

            threadwise_inv_std_load.Run(inv_std_grid_desc_m_k,
                                        inv_std_global_val_buf,
                                        thread_buffer_desc_m_k,
                                        make_tuple(I0, I0),
                                        inv_std_thread_buf);

            threadwise_dy_load.MoveSrcSliceWindow(dy_grid_desc_m_k, thread_copy_fwd_step_m_k);
            threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, thread_copy_fwd_step_m_k);
            threadwise_mean_load.MoveSrcSliceWindow(mean_grid_desc_m_k, thread_copy_fwd_step_m_k);
            threadwise_inv_std_load.MoveSrcSliceWindow(inv_std_grid_desc_m_k,
                                                       thread_copy_fwd_step_m_k);

            static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
                constexpr auto offset_m =
                    Number<thread_buffer_desc_m.CalculateOffset(make_tuple(iM))>{};

                static_for<0, KThreadSliceSize, 1>{}([&](auto iK) {
                    constexpr auto offset_m_k =
                        Number<thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK))>{};

                    dgamma_thread_buf(offset_m) +=
                        dy_thread_buf[offset_m_k] * inv_std_thread_buf[offset_m_k] *
                        (x_thread_buf[offset_m_k] - mean_thread_buf[offset_m_k]);

                    dbeta_thread_buf(offset_m) += dy_thread_buf[offset_m_k];
                });
            });
        }

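        // The per-thread partials now cover every K tile this thread visited.
        // The blockwise reductions below share the single LDS buffer
        // reduce_work_buf for both dbeta and dgamma, hence the
        // block_sync_lds() between consecutive uses.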
        static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
            if constexpr(I > 0)
                block_sync_lds();

            BlockwiseSumReduce::Reduce(reduce_work_buf, dbeta_thread_buf(I));
            block_sync_lds();
            BlockwiseSumReduce::Reduce(reduce_work_buf, dgamma_thread_buf(I));
        });

        if(thread_k_cluster_id == 0)
        {
            threadwise_dgamma_store.Run(thread_buffer_desc_m,
                                        make_tuple(I0),
                                        dgamma_thread_buf,
                                        dgamma_grid_desc_m,
                                        dgamma_global_val_buf);

            threadwise_dbeta_store.Run(thread_buffer_desc_m,
                                       make_tuple(I0),
                                       dbeta_thread_buf,
                                       dbeta_grid_desc_m,
                                       dbeta_global_val_buf);
        }
    }
};

} // namespace ck