diff --git a/example/53_layernorm_bwd/CMakeLists.txt b/example/53_layernorm_bwd/CMakeLists.txt new file mode 100644 index 0000000000..24db221523 --- /dev/null +++ b/example/53_layernorm_bwd/CMakeLists.txt @@ -0,0 +1 @@ +add_example_executable(example_layernorm2d_bwd_fp16 layernorm2d_bwd_fp16.cpp) diff --git a/example/53_layernorm_bwd/layernorm2d_bwd_fp16.cpp b/example/53_layernorm_bwd/layernorm2d_bwd_fp16.cpp new file mode 100644 index 0000000000..f2e6bfb44d --- /dev/null +++ b/example/53_layernorm_bwd/layernorm2d_bwd_fp16.cpp @@ -0,0 +1,165 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include +#include + +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_common_util.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/literals.hpp" + +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_normalization_bwd_gamma_beta_impl.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_layernorm_bwd.hpp" + +using DYDataType = ck::half_t; +using XDataType = ck::half_t; +using GammaDataType = ck::half_t; +using MeanInvStdDataType = float; +using DGammaDataType = ck::half_t; +using DBetaDataType = ck::half_t; +using DXDataType = ck::half_t; +using ComputeDataType = float; + +constexpr int Rank = 2; +constexpr int NumReduceDim = 1; + +// Layernorm backward: + +// Input shape +// dy: [M, N] +// x: [M, N] +// mean: [M, 1] +// inv_std: [M, 1] + +// Output shape +// dgamma: [1, N] +// dbeta: [1, N] + +// dgamma = reduce_sum(dy * (x - mean) * inv_std, axis=0) +// dbeta = reduce_sum(dy, axis=0) + +// [CAUTION] +// In DeviceNormalizationBwdGammaBetaImpl, M is the invariant dimension and K is the reduced dimension. +// Hence, M in this example and M in DeviceNormalizationBwdGammaBetaImpl refer to different dimensions. +using GammaBetaDeviceInstance = ck::tensor_operation::device::DeviceNormalizationBwdGammaBetaImpl< + DYDataType, + XDataType, + MeanInvStdDataType, + ComputeDataType, + DGammaDataType, + DBetaDataType, + Rank, + NumReduceDim, + 256, // BlockSize + 8, // ClusterInvariant + 32, // ClusterReduce + 8, // SliceInvariant + 1, // SliceReduce + false, // IsDYFastestDimReduced + 8, // DYSrcVectorSize + false, // IsXFastestDimReduced + 8, // XSrcVectorSize + true, // IsMeanInvStdFastestDimReduced + 1, // MeanInvStdSrcVectorSize + 1, // DGammaDstVectorSize + 1>; // DBetaDstVectorSize + +int main() +{ + bool time_kernel = false; + + ck::index_t M = 1024; + ck::index_t N = 512; + + Tensor<DYDataType> dy({M, N}); + Tensor<XDataType> x({M, N}); + Tensor<GammaDataType> gamma({N}); + Tensor<MeanInvStdDataType> mean({M}); + Tensor<MeanInvStdDataType> inv_std({M}); + + Tensor<DGammaDataType> dgamma({N}); + Tensor<DBetaDataType> dbeta({N}); + Tensor<DXDataType> dx({M, N}); + + dy.GenerateTensorValue(GeneratorTensor_3<DYDataType>{0.0, 1.0}); + x.GenerateTensorValue(GeneratorTensor_3<XDataType>{0.0, 1.0}); + gamma.GenerateTensorValue(GeneratorTensor_3<GammaDataType>{0.0, 1.0}); + mean.GenerateTensorValue(GeneratorTensor_3<MeanInvStdDataType>{0.0, 1.0}); + inv_std.GenerateTensorValue(GeneratorTensor_3<MeanInvStdDataType>{0.0, 1.0}); + + DeviceMem dy_dev(sizeof(DYDataType) * dy.mDesc.GetElementSpaceSize()); + DeviceMem x_dev(sizeof(XDataType) * x.mDesc.GetElementSpaceSize()); + DeviceMem mean_dev(sizeof(MeanInvStdDataType) * mean.mDesc.GetElementSpaceSize()); + DeviceMem inv_std_dev(sizeof(MeanInvStdDataType) * inv_std.mDesc.GetElementSpaceSize()); + DeviceMem dgamma_dev(sizeof(DGammaDataType) * 
dgamma.mDesc.GetElementSpaceSize()); + DeviceMem dbeta_dev(sizeof(DBetaDataType) * dbeta.mDesc.GetElementSpaceSize()); + + dy_dev.ToDevice(dy.mData.data()); + x_dev.ToDevice(x.mData.data()); + mean_dev.ToDevice(mean.mData.data()); + inv_std_dev.ToDevice(inv_std.mData.data()); + + auto gamma_beta_device_instance = GammaBetaDeviceInstance{}; + auto gamma_beta_argument_ptr = + gamma_beta_device_instance.MakeArgumentPointer({M, N}, // inLengths + {N, 1}, // dyStrides + {N, 1}, // xStrides + {1, 0}, // meanStrides + {1, 0}, // invStdStrides + {N}, // outLengths + {1}, // dgammaStrides + {1}, // dbetaStrides + {0}, // reduceDims + dy_dev.GetDeviceBuffer(), + x_dev.GetDeviceBuffer(), + mean_dev.GetDeviceBuffer(), + inv_std_dev.GetDeviceBuffer(), + dgamma_dev.GetDeviceBuffer(), + dbeta_dev.GetDeviceBuffer()); + + if(!gamma_beta_device_instance.IsSupportedArgument(gamma_beta_argument_ptr.get())) + { + std::cout << "The runtime parameters are not supported" << std::endl; + return 1; + }; + + auto gamma_beta_invoker_ptr = gamma_beta_device_instance.MakeInvokerPointer(); + gamma_beta_invoker_ptr->Run(gamma_beta_argument_ptr.get(), StreamConfig{nullptr, time_kernel}); + + bool pass = true; + { + Tensor host_dgamma({N}); + Tensor host_dbeta({N}); + Tensor host_dx({M, N}); + using ReferenceInstance = + ck::tensor_operation::host::ReferenceLayernormBwd; + + ReferenceInstance ref; + auto ref_argument = + ref.MakeArgument(dy, x, gamma, mean, inv_std, host_dgamma, host_dbeta, host_dx, {M, N}); + auto ref_invoker = ref.MakeInvoker(); + ref_invoker.Run(ref_argument); + + dgamma_dev.FromDevice(dgamma.mData.data()); + dbeta_dev.FromDevice(dbeta.mData.data()); + + pass &= ck::utils::check_err(dgamma, host_dgamma, "Error: Incorrect dgamma", 1e-3, 1e-3); + pass &= ck::utils::check_err(dbeta, host_dbeta, "Error: Incorrect dbeta", 1e-3, 1e-3); + } + + return (pass ? 0 : 1); +} diff --git a/example/54_groupnorm_bwd/CMakeLists.txt b/example/54_groupnorm_bwd/CMakeLists.txt new file mode 100644 index 0000000000..ac548cbc79 --- /dev/null +++ b/example/54_groupnorm_bwd/CMakeLists.txt @@ -0,0 +1 @@ +add_example_executable(example_groupnorm_bwd_fp16 groupnorm_bwd_fp16.cpp) diff --git a/example/54_groupnorm_bwd/groupnorm_bwd_fp16.cpp b/example/54_groupnorm_bwd/groupnorm_bwd_fp16.cpp new file mode 100644 index 0000000000..1537a014d4 --- /dev/null +++ b/example/54_groupnorm_bwd/groupnorm_bwd_fp16.cpp @@ -0,0 +1,167 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. 
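Both examples above run the kernel once with time_kernel = false and only check correctness. Below is a hedged sketch (not in the diff) of how the same invoker call is typically timed and converted into an effective bandwidth figure; variable names refer to the layernorm example above, the byte count is an estimate of what the kernel reads and writes, and the exact timing semantics are assumed to follow CK's StreamConfig/launch_and_time_kernel utilities.

    // Sketch only: enable timing and report an estimated effective bandwidth.
    float ave_time = gamma_beta_invoker_ptr->Run(gamma_beta_argument_ptr.get(),
                                                 StreamConfig{nullptr, /*time_kernel=*/true});

    std::size_t num_bytes = (sizeof(DYDataType) + sizeof(XDataType)) * M * N        // dy, x
                            + sizeof(MeanInvStdDataType) * M * 2                    // mean, inv_std
                            + (sizeof(DGammaDataType) + sizeof(DBetaDataType)) * N; // dgamma, dbeta

    float gb_per_sec = num_bytes / 1.E6 / ave_time; // bytes / 1e6 / ms == GB/s
    std::cout << "Perf: " << ave_time << " ms, " << gb_per_sec << " GB/s" << std::endl;
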
+ +#include +#include +#include +#include +#include + +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_common_util.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/literals.hpp" + +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_normalization_bwd_gamma_beta_impl.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_groupnorm_bwd.hpp" + +using DYDataType = ck::half_t; +using XDataType = ck::half_t; +using GammaDataType = ck::half_t; +using MeanInvStdDataType = float; +using DGammaDataType = ck::half_t; +using DBetaDataType = ck::half_t; +using DXDataType = ck::half_t; +using ComputeDataType = float; + +constexpr int Rank = 5; +constexpr int NumReduceDim = 3; + +// Groupnorm backward +// kernel: M, K +// dy: N, H, W, G, C -> G * C, N * H * W +// x: N, H, W, G, C -> G * C, N * H * W +// mean: N, 1, 1, G, 1 -> G * 1, N * 1 * 1 +// rstd: N, 1, 1, G, 1 -> G * 1, N * 1 * 1 + +// dgamma: 1, 1, 1, G, C -> G * C +// dbeta: 1, 1, 1, G, C -> G * C + +// reduced axis: 0, 1, 2 + +using GammaBetaDeviceInstance = ck::tensor_operation::device::DeviceNormalizationBwdGammaBetaImpl< + DYDataType, + XDataType, + MeanInvStdDataType, + ComputeDataType, + DGammaDataType, + DBetaDataType, + Rank, + NumReduceDim, + 256, // BlockSize + 8, // ClusterInvariant + 32, // ClusterReduce + 8, // SliceInvariant + 1, // SliceReduce + false, // IsDYFastestDimReduced + 8, // DYSrcVectorSize + false, // IsXFastestDimReduced + 8, // XSrcVectorSize + false, // IsMeanInvStdFastestDimReduced + 1, // MeanInvStdSrcVectorSize + 1, // DGammaDstVectorSize + 1>; // DBetaDstVectorSize + +int main() +{ + bool time_kernel = false; + + ck::index_t N = 16; + ck::index_t H = 16; + ck::index_t W = 16; + ck::index_t G = 32; + ck::index_t C = 64; + + Tensor<DYDataType> dy({N, H, W, G, C}); + Tensor<XDataType> x({N, H, W, G, C}); + Tensor<GammaDataType> gamma({G, C}); + Tensor<MeanInvStdDataType> mean({N, G}); + Tensor<MeanInvStdDataType> inv_std({N, G}); + + Tensor<DGammaDataType> dgamma({G, C}); + Tensor<DBetaDataType> dbeta({G, C}); + Tensor<DXDataType> dx({N, H, W, G, C}); + + dy.GenerateTensorValue(GeneratorTensor_3<DYDataType>{0.0, 1.0}); + x.GenerateTensorValue(GeneratorTensor_3<XDataType>{0.0, 1.0}); + gamma.GenerateTensorValue(GeneratorTensor_3<GammaDataType>{0.0, 1.0}); + mean.GenerateTensorValue(GeneratorTensor_3<MeanInvStdDataType>{0.0, 1.0}); + inv_std.GenerateTensorValue(GeneratorTensor_3<MeanInvStdDataType>{0.0, 1.0}); + + DeviceMem dy_dev(sizeof(DYDataType) * dy.mDesc.GetElementSpaceSize()); + DeviceMem x_dev(sizeof(XDataType) * x.mDesc.GetElementSpaceSize()); + DeviceMem mean_dev(sizeof(MeanInvStdDataType) * mean.mDesc.GetElementSpaceSize()); + DeviceMem inv_std_dev(sizeof(MeanInvStdDataType) * inv_std.mDesc.GetElementSpaceSize()); + DeviceMem dgamma_dev(sizeof(DGammaDataType) * dgamma.mDesc.GetElementSpaceSize()); + DeviceMem dbeta_dev(sizeof(DBetaDataType) * dbeta.mDesc.GetElementSpaceSize()); + + dy_dev.ToDevice(dy.mData.data()); + x_dev.ToDevice(x.mData.data()); + mean_dev.ToDevice(mean.mData.data()); + inv_std_dev.ToDevice(inv_std.mData.data()); + + std::vector<ck::index_t> dyStrides{dy.mDesc.GetStrides().begin(), dy.mDesc.GetStrides().end()}; + std::vector<ck::index_t> xStrides{x.mDesc.GetStrides().begin(), x.mDesc.GetStrides().end()}; + std::vector<ck::index_t> meanStrides = {G, 0, 0, 1, 0}; + std::vector<ck::index_t> invStdStrides = {G, 0, 0, 1, 0}; + + auto gamma_beta_device_instance = GammaBetaDeviceInstance{}; + auto gamma_beta_argument_ptr = + gamma_beta_device_instance.MakeArgumentPointer({N, H, W, G, C}, // 
inLengths + dyStrides, // dyStrides + xStrides, // xStrides + meanStrides, // meanStrides + invStdStrides, // invStdStrides + {G, C}, // outLengths + {C, 1}, // dgammaStrides + {C, 1}, // dbetaStrides + {0, 1, 2}, // reduceDims + dy_dev.GetDeviceBuffer(), + x_dev.GetDeviceBuffer(), + mean_dev.GetDeviceBuffer(), + inv_std_dev.GetDeviceBuffer(), + dgamma_dev.GetDeviceBuffer(), + dbeta_dev.GetDeviceBuffer()); + + if(!gamma_beta_device_instance.IsSupportedArgument(gamma_beta_argument_ptr.get())) + { + std::cout << "The runtime parameters are not supported" << std::endl; + return 1; + }; + + auto gamma_beta_invoker_ptr = gamma_beta_device_instance.MakeInvokerPointer(); + gamma_beta_invoker_ptr->Run(gamma_beta_argument_ptr.get(), StreamConfig{nullptr, time_kernel}); + + bool pass = true; + { + Tensor host_dgamma({G, C}); + Tensor host_dbeta({G, C}); + Tensor host_dx({N, H, W, G, C}); + using ReferenceInstance = + ck::tensor_operation::host::ReferenceGroupnormBwd; + + ReferenceInstance ref; + auto ref_argument = ref.MakeArgument( + dy, x, gamma, mean, inv_std, host_dgamma, host_dbeta, host_dx, {N, H, W, G, C}); + auto ref_invoker = ref.MakeInvoker(); + ref_invoker.Run(ref_argument); + + dgamma_dev.FromDevice(dgamma.mData.data()); + dbeta_dev.FromDevice(dbeta.mData.data()); + + pass &= ck::utils::check_err(dgamma, host_dgamma, "Error: Incorrect dgamma", 1e-3, 1e-3); + pass &= ck::utils::check_err(dbeta, host_dbeta, "Error: Incorrect dbeta", 1e-3, 1e-3); + } + + return (pass ? 0 : 1); +} diff --git a/include/ck/tensor_operation/gpu/device/device_normalization_bwd_gamma_beta.hpp b/include/ck/tensor_operation/gpu/device/device_normalization_bwd_gamma_beta.hpp new file mode 100644 index 0000000000..5a8804cbf0 --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/device_normalization_bwd_gamma_beta.hpp @@ -0,0 +1,61 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include + +#include "ck/tensor_operation/gpu/device/device_base.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +template +struct DeviceNormalizationBwdGammaBeta : public BaseOperator +{ + virtual std::unique_ptr + MakeArgumentPointer(const std::vector inLengths, + const std::vector dyStrides, + const std::vector xStrides, + const std::vector meanStrides, + const std::vector invStdStrides, + const std::vector outLengths, + const std::vector dgammaStrides, + const std::vector dbetaStrides, + const std::vector reduceDims, + const void* p_dy, + const void* p_x, + const void* p_mean, + const void* p_invStd, + void* p_dgamma, + void* p_dbeta) = 0; + + virtual std::unique_ptr MakeInvokerPointer() = 0; +}; + +template +using DeviceNormalizationBwdGammaBetaPtr = + std::unique_ptr>; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/impl/device_normalization_bwd_gamma_beta_impl.hpp b/include/ck/tensor_operation/gpu/device/impl/device_normalization_bwd_gamma_beta_impl.hpp new file mode 100644 index 0000000000..43d2db4624 --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/impl/device_normalization_bwd_gamma_beta_impl.hpp @@ -0,0 +1,464 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. 
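As a usage sketch of the interface above: the helper name and parameter list below are illustrative, only MakeArgumentPointer, IsSupportedArgument and MakeInvokerPointer come from DeviceNormalizationBwdGammaBeta, and StreamConfig is assumed to be CK's usual launch configuration struct. This is the call sequence both examples follow.

    // Sketch: driving any DeviceNormalizationBwdGammaBeta instance through its type-erased interface.
    #include <memory>
    #include <vector>

    template <typename DeviceOp>
    bool RunNormalizationBwdGammaBeta(DeviceOp& op,
                                      const std::vector<ck::index_t>& inLengths,
                                      const std::vector<ck::index_t>& dyStrides,
                                      const std::vector<ck::index_t>& xStrides,
                                      const std::vector<ck::index_t>& meanStrides,
                                      const std::vector<ck::index_t>& invStdStrides,
                                      const std::vector<ck::index_t>& outLengths,
                                      const std::vector<ck::index_t>& dgammaStrides,
                                      const std::vector<ck::index_t>& dbetaStrides,
                                      const std::vector<ck::index_t>& reduceDims,
                                      const void* p_dy,
                                      const void* p_x,
                                      const void* p_mean,
                                      const void* p_invStd,
                                      void* p_dgamma,
                                      void* p_dbeta)
    {
        // 1. Describe the problem and bind the device pointers.
        auto arg = op.MakeArgumentPointer(inLengths, dyStrides, xStrides, meanStrides, invStdStrides,
                                          outLengths, dgammaStrides, dbetaStrides, reduceDims,
                                          p_dy, p_x, p_mean, p_invStd, p_dgamma, p_dbeta);

        // 2. Check that this instance supports the runtime shapes and strides.
        if(!op.IsSupportedArgument(arg.get()))
            return false;

        // 3. Launch the kernel (no timing here).
        op.MakeInvokerPointer()->Run(arg.get(), StreamConfig{nullptr, false});
        return true;
    }
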
+ +#pragma once + +#include +#include + +#include "ck/tensor_operation/gpu/device/device_normalization_bwd_gamma_beta.hpp" +#include "ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_bwd_gamma_beta.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_reduce_common.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" + +// M is invarient dimension, K is reduced dimension +namespace ck { +namespace tensor_operation { +namespace device { + +template +__global__ void +kernel_normalization_bwd_gamma_beta(const GridDesc_M_K dy_grid_desc_m_k, + const GridDesc_M_K x_grid_desc_m_k, + const GridDesc_M_K mean_grid_desc_m_k, + const GridDesc_M_K inv_std_grid_desc_m_k, + const GridDesc_M dgamma_grid_desc_m, + const GridDesc_M dbeta_grid_desc_m, + index_t num_k_block_tile_iteration, + const DYDataType* const __restrict__ p_dy_global, + const XDataType* const __restrict__ p_x_global, + const MeanInvStdDataType* const __restrict__ p_mean_global, + const MeanInvStdDataType* const __restrict__ p_inv_std_global, + DGammaDataType* const __restrict__ p_dgamma_global, + DBetaDataType* const __restrict__ p_dbeta_global) +{ + GridwiseReduction::Run(dy_grid_desc_m_k, + x_grid_desc_m_k, + mean_grid_desc_m_k, + inv_std_grid_desc_m_k, + dgamma_grid_desc_m, + dbeta_grid_desc_m, + num_k_block_tile_iteration, + p_dy_global, + p_x_global, + p_mean_global, + p_inv_std_global, + p_dgamma_global, + p_dbeta_global); +}; + +template +struct DeviceNormalizationBwdGammaBetaImpl + : public DeviceNormalizationBwdGammaBeta +{ + + static constexpr index_t DYSrcVectorDim = IsDYFastestDimReduced ? 1 : 0; + static constexpr index_t XSrcVectorDim = IsXFastestDimReduced ? 1 : 0; + static constexpr index_t MeanInvStdSrcVectorDim = IsMeanInvStdFastestDimReduced ? 
1 : 0; + + static_assert(BlockSize == MThreadClusterSize * KThreadClusterSize); + + static_assert(((DYSrcVectorDim == 0 && MThreadSliceSize % DYSrcVectorSize == 0) || + (DYSrcVectorDim == 1 && KThreadSliceSize % DYSrcVectorSize == 0)), + "Invalid thread slice sizes and/or dy vector sizes configuration, please check!"); + + static_assert(((XSrcVectorDim == 0 && MThreadSliceSize % XSrcVectorSize == 0) || + (XSrcVectorDim == 1 && KThreadSliceSize % XSrcVectorSize == 0)), + "Invalid thread slice sizes and/or x vector sizes configuration, please check!"); + + static_assert( + ((MThreadSliceSize % DGammaDstVectorSize == 0) || + (MThreadSliceSize % DBetaDstVectorSize == 0)), + "Invalid thread slice sizes and/or Gamma and beta vector sizes configuration, please " + "check!"); + + static_assert( + (MeanInvStdSrcVectorDim == 0 && MThreadSliceSize % MeanInvStdSrcVectorSize == 0) || + (MeanInvStdSrcVectorDim == 1 && KThreadSliceSize % MeanInvStdSrcVectorSize == 0), + "Invalid thread slice sizes and/or mean and inverse std vector sizes configuration, please " + "check!"); + + static constexpr index_t NumInvariantDim = Rank - NumReduceDim; + static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize; + static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize; + + static constexpr bool reduceAllDim = (NumInvariantDim == 0); + static_assert(!reduceAllDim); + + static auto MakeSrc2dDescriptor(const std::vector& inLengths, + const std::vector& inStrides, + int numBlockTileIteration) + { + const auto tupleSrcLengths = make_tuple_from_array(inLengths, Number{}); + const auto tupleSrcStrides = make_tuple_from_array(inStrides, Number{}); + + const auto inDesc = make_naive_tensor_descriptor(tupleSrcLengths, tupleSrcStrides); + + const auto in_grid_desc_m_k = [&]() { + using InvariantDims = typename arithmetic_sequence_gen<0, NumInvariantDim, 1>::type; + using ReduceDims = typename arithmetic_sequence_gen::type; + + const auto reduceDimLengths = + make_tuple_from_array_and_index_seq(inLengths, ReduceDims{}); + const auto invariantDimLengths = + make_tuple_from_array_and_index_seq(inLengths, InvariantDims{}); + + return transform_tensor_descriptor(inDesc, + make_tuple(make_merge_transform(invariantDimLengths), + make_merge_transform(reduceDimLengths)), + make_tuple(InvariantDims{}, ReduceDims{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + }(); + + const auto invariantLength = in_grid_desc_m_k.GetLength(Number<0>{}); + const auto reduceLength = in_grid_desc_m_k.GetLength(Number<1>{}); + + const auto inPad_M = + math::integer_least_multiple(invariantLength, M_BlockTileSize) - invariantLength; + const auto inPad_K = K_BlockTileSize * numBlockTileIteration - reduceLength; + + auto in_grid_desc_m_k_padded = transform_tensor_descriptor( + in_grid_desc_m_k, + make_tuple(make_right_pad_transform(invariantLength, inPad_M), + make_right_pad_transform(reduceLength, inPad_K)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + return in_grid_desc_m_k_padded; + } + + static auto MakeDst1dDescriptor(const std::vector& outLengths, + const std::vector& outStrides) + { + const auto tupleDstLengths = + generate_tuple([&](auto I) { return outLengths[I]; }, Number{}); + const auto tupleDstStrides = + generate_tuple([&](auto I) { return outStrides[I]; }, Number{}); + + auto outDesc = make_naive_tensor_descriptor(tupleDstLengths, tupleDstStrides); + + auto out_grid_desc_m = transform_tensor_descriptor( + outDesc, + 
make_tuple(make_merge_transform(tupleDstLengths)), + make_tuple(typename arithmetic_sequence_gen<0, NumInvariantDim, 1>::type{}), + make_tuple(Sequence<0>{})); + + const auto invariantLength = out_grid_desc_m.GetLength(Number<0>{}); + + const auto outPad = + math::integer_least_multiple(invariantLength, M_BlockTileSize) - invariantLength; + + auto out_grid_desc_m_padded = transform_tensor_descriptor( + out_grid_desc_m, + make_tuple(make_right_pad_transform(invariantLength, outPad)), + make_tuple(Sequence<0>{}), + make_tuple(Sequence<0>{})); + return (out_grid_desc_m_padded); + }; + + using GridDesc_M_K = decltype(MakeSrc2dDescriptor({1}, {1}, 1)); + using GridDesc_M = decltype(MakeDst1dDescriptor({1}, {1})); + + using GridwiseNormalizationBwdGammaBeta = + GridwiseNormalizationBwdGammaBeta_mk_to_k; + + struct Argument : public BaseArgument + { + Argument(const std::vector inLengths, + const std::vector dyStrides, + const std::vector xStrides, + const std::vector meanStrides, + const std::vector invStdStrides, + const std::vector outLengths, + const std::vector dgammaStrides, + const std::vector dbetaStrides, + const std::vector reduceDims, + const DYDataType* p_dy, + const XDataType* p_x, + const MeanInvStdDataType* p_mean, + const MeanInvStdDataType* p_invStd, + DGammaDataType* p_dgamma, + DBetaDataType* p_dbeta) + : p_dy_(p_dy), + p_x_(p_x), + p_mean_(p_mean), + p_invStd_(p_invStd), + p_dgamma_(p_dgamma), + p_dbeta_(p_dbeta), + outLengths_{outLengths}, + dgammaStrides_{dgammaStrides}, + dbetaStrides_{dbetaStrides} + { + inLengths_ = shuffle_tensor_dimensions(inLengths, reduceDims); + dyStrides_ = shuffle_tensor_dimensions(dyStrides, reduceDims); + xStrides_ = shuffle_tensor_dimensions(xStrides, reduceDims); + meanStrides_ = shuffle_tensor_dimensions(meanStrides, reduceDims); + invStdStrides_ = + shuffle_tensor_dimensions(invStdStrides, reduceDims); + + std::tie(MRaw_, KRaw_) = get_2d_lengths(inLengths_); + + numBlockTileIteration_ = math::integer_divide_ceil(KRaw_, K_BlockTileSize); + + gridSize_ = math::integer_divide_ceil(MRaw_, M_BlockTileSize); + + dy_grid_desc_m_k_ = MakeSrc2dDescriptor(inLengths_, dyStrides_, numBlockTileIteration_); + x_grid_desc_m_k_ = MakeSrc2dDescriptor(inLengths_, xStrides_, numBlockTileIteration_); + mean_grid_desc_m_k_ = + MakeSrc2dDescriptor(inLengths_, meanStrides_, numBlockTileIteration_); + inv_std_grid_desc_m_k_ = + MakeSrc2dDescriptor(inLengths_, invStdStrides_, numBlockTileIteration_); + + dgamma_grid_desc_m_ = MakeDst1dDescriptor(outLengths_, dgammaStrides_); + dbeta_grid_desc_m_ = MakeDst1dDescriptor(outLengths_, dbetaStrides_); + } + + const DYDataType* p_dy_; + const XDataType* p_x_; + const MeanInvStdDataType* p_mean_; + const MeanInvStdDataType* p_invStd_; + DGammaDataType* p_dgamma_; + DBetaDataType* p_dbeta_; + + std::vector inLengths_; + std::vector dyStrides_; + std::vector xStrides_; + std::vector meanStrides_; + std::vector invStdStrides_; + std::vector outLengths_; + std::vector dgammaStrides_; + std::vector dbetaStrides_; + + int numBlockTileIteration_; + size_t gridSize_; + + // Source descriptor + GridDesc_M_K dy_grid_desc_m_k_; + GridDesc_M_K x_grid_desc_m_k_; + GridDesc_M_K mean_grid_desc_m_k_; + GridDesc_M_K inv_std_grid_desc_m_k_; + + // Destination descriptor + GridDesc_M dgamma_grid_desc_m_; + GridDesc_M dbeta_grid_desc_m_; + + index_t MRaw_; // invarient length + index_t KRaw_; // reduce length + }; + + struct Invoker : public BaseInvoker + { + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { 
+ const auto kernel_main = + kernel_normalization_bwd_gamma_beta; + + return launch_and_time_kernel(stream_config, + kernel_main, + dim3(arg.gridSize_), + dim3(BlockSize), + 0, + arg.dy_grid_desc_m_k_, + arg.x_grid_desc_m_k_, + arg.mean_grid_desc_m_k_, + arg.inv_std_grid_desc_m_k_, + arg.dgamma_grid_desc_m_, + arg.dbeta_grid_desc_m_, + arg.numBlockTileIteration_, + arg.p_dy_, + arg.p_x_, + arg.p_mean_, + arg.p_invStd_, + arg.p_dgamma_, + arg.p_dbeta_); + } + + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + template + bool IsSrcVectorDimSizeValid(const std::vector& lengths, + const std::vector& strides) + { + if constexpr(SrcVectorSize == 1) + return true; + + // Fastest dimension is not reduced + if constexpr(SrcVectorDim == 0) + { + if constexpr(NumInvariantDim == 0) + return false; + + if(strides[NumInvariantDim - 1] != 1) + return false; + + if(lengths[NumInvariantDim - 1] % SrcVectorSize != 0) + return false; + } + else // Fastest dimension is reduced + { + if(strides[Rank - 1] != 1) + return false; + + if(lengths[Rank - 1] % SrcVectorSize != 0) + return false; + }; + + return true; + } + + template + bool IsDstVectorSizeValid(const std::vector& lengths, + const std::vector& strides) + { + if constexpr(DstVectorSize == 1) + return true; + + if(strides[NumInvariantDim - 1] != 1) + return false; + + if(lengths[NumInvariantDim - 1] % DstVectorSize != 0) + return false; + + return true; + } + + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + const Argument* p_arg_ = dynamic_cast(p_arg); + + bool pass = true; + pass &= IsSrcVectorDimSizeValid(p_arg_->inLengths_, + p_arg_->dyStrides_); + pass &= IsSrcVectorDimSizeValid(p_arg_->inLengths_, + p_arg_->xStrides_); + pass &= IsSrcVectorDimSizeValid( + p_arg_->inLengths_, p_arg_->meanStrides_); + pass &= IsSrcVectorDimSizeValid( + p_arg_->inLengths_, p_arg_->invStdStrides_); + + pass &= + IsDstVectorSizeValid(p_arg_->outLengths_, p_arg_->dgammaStrides_); + pass &= + IsDstVectorSizeValid(p_arg_->outLengths_, p_arg_->dbetaStrides_); + + return pass; + } + + std::unique_ptr MakeArgumentPointer(const std::vector inLengths, + const std::vector dyStrides, + const std::vector xStrides, + const std::vector meanStrides, + const std::vector invStdStrides, + const std::vector outLengths, + const std::vector dgammaStrides, + const std::vector dbetaStrides, + const std::vector reduceDims, + const void* p_dy, + const void* p_x, + const void* p_mean, + const void* p_invStd, + void* p_dgamma, + void* p_dbeta) override + { + if(inLengths.size() != Rank || dyStrides.size() != Rank || xStrides.size() != Rank || + meanStrides.size() != Rank || invStdStrides.size() != Rank) + throw std::runtime_error("dimension is incorrect"); + + if(outLengths.size() != NumInvariantDim || dgammaStrides.size() != NumInvariantDim || + dbetaStrides.size() != NumInvariantDim) + throw std::runtime_error("dimension is incorrect"); + + return std::make_unique(inLengths, + dyStrides, + xStrides, + meanStrides, + invStdStrides, + outLengths, + dgammaStrides, + dbetaStrides, + reduceDims, + static_cast(p_dy), + static_cast(p_x), + static_cast(p_mean), + static_cast(p_invStd), + static_cast(p_dgamma), + static_cast(p_dbeta)); + } + + virtual std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git 
a/include/ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_bwd_gamma_beta.hpp b/include/ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_bwd_gamma_beta.hpp new file mode 100644 index 0000000000..e80c360bb2 --- /dev/null +++ b/include/ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_bwd_gamma_beta.hpp @@ -0,0 +1,343 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/data_type.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/block/reduction_functions_blockwise.hpp" + +namespace ck { + +// dgamma = reduce_sum(dy * (x - mean) * inv_std) +// dbeta = reduce_sum(dy) +template +struct GridwiseNormalizationBwdGammaBeta_mk_to_k +{ + // if we just check ThreadSliceSize & VectorSize == 0, the performance may be poor + static_assert(((DYSrcVectorDim == 0 && MThreadSliceSize == DYSrcVectorSize) || + (DYSrcVectorDim == 1 && KThreadSliceSize == DYSrcVectorSize)), + "Invalid thread slice sizes and/or dy vector sizes configuration, please check!"); + + static_assert(((XSrcVectorDim == 0 && MThreadSliceSize == XSrcVectorSize) || + (XSrcVectorDim == 1 && KThreadSliceSize == XSrcVectorSize)), + "Invalid thread slice sizes and/or x vector sizes configuration, please check!"); + + using ThreadClusterLengths_M_K = Sequence; + + using DYThreadBufferDimAccessOrder = + typename conditional, Sequence<0, 1>>::type; + + using XThreadBufferDimAccessOrder = + typename conditional, Sequence<0, 1>>::type; + + using MeanInvStdThreadBufferDimAccessOrder = + typename conditional, Sequence<0, 1>>::type; + + using ThreadClusterArrangeOrder = DYThreadBufferDimAccessOrder; + + static constexpr auto thread_cluster_desc = + make_cluster_descriptor(ThreadClusterLengths_M_K{}, ThreadClusterArrangeOrder{}); + + using ThreadBufferLengths_M_K = Sequence; + using ThreadBufferLengths_M = Sequence; + + static constexpr auto thread_buffer_desc_m_k = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, Number{})); + + static constexpr auto thread_buffer_desc_m = + make_naive_tensor_descriptor_packed(make_tuple(Number{})); + + using PassThroughOp = tensor_operation::element_wise::PassThrough; + + using BlockwiseSumReduce = PartitionedBlockwiseReduction; + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + + static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize; + static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize; + + __device__ static void Run(const GridDesc_M_K& dy_grid_desc_m_k, + const GridDesc_M_K& x_grid_desc_m_k, + const GridDesc_M_K& mean_grid_desc_m_k, + const GridDesc_M_K& inv_std_grid_desc_m_k, + const GridDesc_M& dgamma_grid_desc_m, + const GridDesc_M& dbeta_grid_desc_m, + index_t num_k_block_tile_iteration, + const DYDataType* const __restrict__ p_dy_global, + const XDataType* const __restrict__ p_x_global, + const MeanInvStdDataType* const __restrict__ p_mean_global, + const MeanInvStdDataType* const __restrict__ p_inv_std_global, + DGammaDataType* const __restrict__ p_dgamma_global, + DBetaDataType* const __restrict__ p_dbeta_global) + { + // LDS + __shared__ ComputeDataType p_reduce_work_buffer[BlockSize]; + + auto reduce_work_buf = + make_dynamic_buffer(p_reduce_work_buffer, BlockSize); + + // Global + const auto dy_global_val_buf = 
make_dynamic_buffer( + p_dy_global, dy_grid_desc_m_k.GetElementSpaceSize()); + + const auto x_global_val_buf = make_dynamic_buffer( + p_x_global, x_grid_desc_m_k.GetElementSpaceSize()); + + const auto mean_global_val_buf = make_dynamic_buffer( + p_mean_global, mean_grid_desc_m_k.GetElementSpaceSize()); + + const auto inv_std_global_val_buf = make_dynamic_buffer( + p_inv_std_global, inv_std_grid_desc_m_k.GetElementSpaceSize()); + + auto dgamma_global_val_buf = make_dynamic_buffer( + p_dgamma_global, dgamma_grid_desc_m.GetElementSpaceSize()); + + auto dbeta_global_val_buf = make_dynamic_buffer( + p_dbeta_global, dbeta_grid_desc_m.GetElementSpaceSize()); + + // VGPR + auto dy_thread_buf = StaticBuffer{}; + + auto x_thread_buf = StaticBuffer{}; + + auto mean_thread_buf = StaticBuffer{}; + + auto inv_std_thread_buf = StaticBuffer{}; + + auto dgamma_thread_buf = + StaticBuffer{}; + + auto dbeta_thread_buf = + StaticBuffer{}; + + const index_t thread_local_id = get_thread_local_1d_id(); + const index_t block_global_id = get_block_1d_id(); + + const auto thread_cluster_idx = + thread_cluster_desc.CalculateBottomIndex(make_multi_index(thread_local_id)); + + const auto thread_m_cluster_id = thread_cluster_idx[I0]; + const auto thread_k_cluster_id = thread_cluster_idx[I1]; + + // IO + auto threadwise_dy_load = ThreadwiseTensorSliceTransfer_v2( + dy_grid_desc_m_k, + make_multi_index(block_global_id * M_BlockTileSize + + thread_m_cluster_id * MThreadSliceSize, + thread_k_cluster_id * KThreadSliceSize)); + + auto threadwise_x_load = ThreadwiseTensorSliceTransfer_v2( + x_grid_desc_m_k, + make_multi_index(block_global_id * M_BlockTileSize + + thread_m_cluster_id * MThreadSliceSize, + thread_k_cluster_id * KThreadSliceSize)); + + auto threadwise_mean_load = + ThreadwiseTensorSliceTransfer_v2( + mean_grid_desc_m_k, + make_multi_index(block_global_id * M_BlockTileSize + + thread_m_cluster_id * MThreadSliceSize, + thread_k_cluster_id * KThreadSliceSize)); + + auto threadwise_inv_std_load = + ThreadwiseTensorSliceTransfer_v2( + inv_std_grid_desc_m_k, + make_multi_index(block_global_id * M_BlockTileSize + + thread_m_cluster_id * MThreadSliceSize, + thread_k_cluster_id * KThreadSliceSize)); + + auto threadwise_dgamma_store = + ThreadwiseTensorSliceTransfer_v1r3, + 0, + DGammaDstVectorSize, + InMemoryDataOperationEnum::Set, + 1, + true>( + dgamma_grid_desc_m, + make_multi_index(block_global_id * M_BlockTileSize + + thread_m_cluster_id * MThreadSliceSize), + PassThroughOp{}); + + auto threadwise_dbeta_store = + ThreadwiseTensorSliceTransfer_v1r3, + 0, + DBetaDstVectorSize, + InMemoryDataOperationEnum::Set, + 1, + true>( + dbeta_grid_desc_m, + make_multi_index(block_global_id * M_BlockTileSize + + thread_m_cluster_id * MThreadSliceSize), + PassThroughOp{}); + + static_for<0, MThreadSliceSize, 1>{}([&](auto I) { + dgamma_thread_buf(I) = type_convert(0.0f); + dbeta_thread_buf(I) = type_convert(0.0f); + }); + + constexpr auto thread_copy_fwd_step_m_k = make_multi_index(0, K_BlockTileSize); + + for(index_t reducedTiles = 0; reducedTiles < num_k_block_tile_iteration; ++reducedTiles) + { + threadwise_dy_load.Run(dy_grid_desc_m_k, + dy_global_val_buf, + thread_buffer_desc_m_k, + make_tuple(I0, I0), + dy_thread_buf); + + threadwise_x_load.Run(x_grid_desc_m_k, + x_global_val_buf, + thread_buffer_desc_m_k, + make_tuple(I0, I0), + x_thread_buf); + + threadwise_mean_load.Run(mean_grid_desc_m_k, + mean_global_val_buf, + thread_buffer_desc_m_k, + make_tuple(I0, I0), + mean_thread_buf); + + 
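// Remaining steps for this K tile (below): load inv_std, advance all four source
// slice windows by K_BlockTileSize, then accumulate the per-thread partial sums
// dgamma += dy * inv_std * (x - mean) and dbeta += dy.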
threadwise_inv_std_load.Run(inv_std_grid_desc_m_k, + inv_std_global_val_buf, + thread_buffer_desc_m_k, + make_tuple(I0, I0), + inv_std_thread_buf); + + threadwise_dy_load.MoveSrcSliceWindow(dy_grid_desc_m_k, thread_copy_fwd_step_m_k); + threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, thread_copy_fwd_step_m_k); + threadwise_mean_load.MoveSrcSliceWindow(mean_grid_desc_m_k, thread_copy_fwd_step_m_k); + threadwise_inv_std_load.MoveSrcSliceWindow(inv_std_grid_desc_m_k, + thread_copy_fwd_step_m_k); + + static_for<0, MThreadSliceSize, 1>{}([&](auto iM) { + constexpr auto offset_m = + Number{}; + + static_for<0, KThreadSliceSize, 1>{}([&](auto iK) { + constexpr auto offset_m_k = + Number{}; + + dgamma_thread_buf(offset_m) += + dy_thread_buf[offset_m_k] * inv_std_thread_buf[offset_m_k] * + (x_thread_buf[offset_m_k] - mean_thread_buf[offset_m_k]); + + dbeta_thread_buf(offset_m) += dy_thread_buf[offset_m_k]; + }); + }); + } + + static_for<0, MThreadSliceSize, 1>{}([&](auto I) { + if constexpr(I > 0) + block_sync_lds(); + + BlockwiseSumReduce::Reduce(reduce_work_buf, dbeta_thread_buf(I)); + block_sync_lds(); + BlockwiseSumReduce::Reduce(reduce_work_buf, dgamma_thread_buf(I)); + }); + + if(thread_k_cluster_id == 0) + { + threadwise_dgamma_store.Run(thread_buffer_desc_m, + make_tuple(I0), + dgamma_thread_buf, + dgamma_grid_desc_m, + dgamma_global_val_buf); + + threadwise_dbeta_store.Run(thread_buffer_desc_m, + make_tuple(I0), + dbeta_thread_buf, + dbeta_grid_desc_m, + dbeta_global_val_buf); + } + } +}; + +} // namespace ck diff --git a/library/include/ck/library/reference_tensor_operation/cpu/reference_groupnorm_bwd.hpp b/library/include/ck/library/reference_tensor_operation/cpu/reference_groupnorm_bwd.hpp new file mode 100644 index 0000000000..37f875d07a --- /dev/null +++ b/library/include/ck/library/reference_tensor_operation/cpu/reference_groupnorm_bwd.hpp @@ -0,0 +1,207 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. 
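The gridwise kernel above tiles the reduced (K) dimension, keeps per-thread partial sums in registers, combines them with a blockwise reduction in LDS, and lets only the threads with thread_k_cluster_id == 0 store dgamma/dbeta. For orientation, a much simpler HIP kernel with the same structure is sketched here for the 2D layernorm case; it is illustrative only (not the CK kernel): one block per output column n, a power-of-two block size up to 256, and row-major [M, N] float buffers are assumed.

    #include <hip/hip_runtime.h>

    __global__ void naive_layernorm_bwd_gamma_beta(const float* dy, const float* x,
                                                   const float* mean, const float* inv_std,
                                                   float* dgamma, float* dbeta, int M, int N)
    {
        __shared__ float s_dgamma[256];
        __shared__ float s_dbeta[256];

        const int n   = blockIdx.x; // one output column per block
        const int tid = threadIdx.x;

        // Per-thread partial sums over a strided range of rows.
        float pg = 0.f;
        float pb = 0.f;
        for(int m = tid; m < M; m += blockDim.x)
        {
            const float dy_v = dy[m * N + n];
            pg += dy_v * inv_std[m] * (x[m * N + n] - mean[m]);
            pb += dy_v;
        }
        s_dgamma[tid] = pg;
        s_dbeta[tid]  = pb;
        __syncthreads();

        // Blockwise tree reduction in shared memory.
        for(int stride = blockDim.x / 2; stride > 0; stride /= 2)
        {
            if(tid < stride)
            {
                s_dgamma[tid] += s_dgamma[tid + stride];
                s_dbeta[tid] += s_dbeta[tid + stride];
            }
            __syncthreads();
        }

        // Only one lane per block writes the reduced results.
        if(tid == 0)
        {
            dgamma[n] = s_dgamma[0];
            dbeta[n]  = s_dbeta[0];
        }
    }
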
+ +#pragma once + +#include +#include +#include +#include + +#include "ck/tensor_operation/gpu/device/device_base.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" + +namespace ck { +namespace tensor_operation { +namespace host { + +template +struct ReferenceGroupnormBwd : public device::BaseOperator +{ + // Argument + struct Argument : public device::BaseArgument + { + Argument(const Tensor& dy_nhwgc, + const Tensor& x_nhwgc, + const Tensor& gamma_gc, + const Tensor& mean_ng, + const Tensor& inv_std_ng, + Tensor& dgamma_gc, + Tensor& dbeta_gc, + Tensor& dx_nhwgc, + const std::vector lengths) + : dy_nhwgc_(dy_nhwgc), + x_nhwgc_(x_nhwgc), + gamma_gc_(gamma_gc), + mean_ng_(mean_ng), + inv_std_ng_(inv_std_ng), + dgamma_gc_(dgamma_gc), + dbeta_gc_(dbeta_gc), + dx_nhwgc_(dx_nhwgc), + lengths_(lengths) + { + } + + const Tensor& dy_nhwgc_; + const Tensor& x_nhwgc_; + const Tensor& gamma_gc_; + const Tensor& mean_ng_; + const Tensor& inv_std_ng_; + Tensor& dgamma_gc_; + Tensor& dbeta_gc_; + Tensor& dx_nhwgc_; + std::vector lengths_; + }; + + // Invoker + struct Invoker : public device::BaseInvoker + { + float Run(const Argument& arg) + { + int N = arg.lengths_[0]; + int H = arg.lengths_[1]; + int W = arg.lengths_[2]; + int G = arg.lengths_[3]; + int C = arg.lengths_[4]; + + // Calculate dgamma and dbeta + for(int g = 0; g < G; ++g) + for(int c = 0; c < C; ++c) + { + ComputeDataType dgamma = 0; + ComputeDataType dbeta = 0; + + for(int n = 0; n < N; ++n) + for(int h = 0; h < H; ++h) + for(int w = 0; w < W; ++w) + { + ComputeDataType dy = + ck::type_convert(arg.dy_nhwgc_(n, h, w, g, c)); + ComputeDataType x = + ck::type_convert(arg.x_nhwgc_(n, h, w, g, c)); + ComputeDataType mean = + ck::type_convert(arg.mean_ng_(n, g)); + ComputeDataType rstd = + ck::type_convert(arg.inv_std_ng_(n, g)); + dgamma += dy * rstd * (x - mean); + dbeta += dy; + } + arg.dgamma_gc_(g, c) = ck::type_convert(dgamma); + arg.dbeta_gc_(g, c) = ck::type_convert(dbeta); + } + + // Calculate dx + int reduce_size = H * W * C; + for(int n = 0; n < N; ++n) + for(int g = 0; g < G; ++g) + { + ComputeDataType ds = 0; + ComputeDataType db = 0; + + ComputeDataType mean = ck::type_convert(arg.mean_ng_(n, g)); + ComputeDataType rstd = ck::type_convert(arg.inv_std_ng_(n, g)); + + for(int h = 0; h < H; ++h) + for(int w = 0; w < W; ++w) + for(int c = 0; c < C; ++c) + { + ComputeDataType dy = + ck::type_convert(arg.dy_nhwgc_(n, h, w, g, c)); + ComputeDataType x = + ck::type_convert(arg.x_nhwgc_(n, h, w, g, c)); + ComputeDataType gamma = + ck::type_convert(arg.gamma_gc_(g, c)); + + ds += dy * gamma * x; + db += dy * gamma; + } + + for(int h = 0; h < H; ++h) + for(int w = 0; w < W; ++w) + for(int c = 0; c < C; ++c) + { + ComputeDataType dy = + ck::type_convert(arg.dy_nhwgc_(n, h, w, g, c)); + ComputeDataType x = + ck::type_convert(arg.x_nhwgc_(n, h, w, g, c)); + ComputeDataType gamma = + ck::type_convert(arg.gamma_gc_(g, c)); + + ComputeDataType b = + (db * mean - ds) * rstd * rstd * rstd / reduce_size; + ComputeDataType c1 = -b * mean - db * rstd / reduce_size; + arg.dx_nhwgc_(n, h, w, g, c) = + ck::type_convert(dy * gamma * rstd + b * x + c1); + } + } + + return 0; + } + + float Run(const device::BaseArgument* p_arg, + const StreamConfig& /* stream_config */ = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg)); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } + + bool IsSupportedArgument(const 
device::BaseArgument*) override { return true; } + + static auto MakeArgument(const Tensor& dy_nhwgc, + const Tensor& x_nhwgc, + const Tensor& gamma_gc, + const Tensor& mean_ng, + const Tensor& inv_std_ng, + Tensor& dgamma_gc, + Tensor& dbeta_gc, + Tensor& dx_nhwgc, + const std::vector lengths) + { + return Argument{dy_nhwgc, + x_nhwgc, + gamma_gc, + mean_ng, + inv_std_ng, + dgamma_gc, + dbeta_gc, + dx_nhwgc, + lengths}; + } + + static auto MakeInvoker() { return Invoker{}; } + + virtual std::unique_ptr MakeInvokerPointer() + { + return std::make_unique(Invoker{}); + } + + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "ReferenceGroupnormBwd" + << std::endl; + // clang-format on + + return str.str(); + } +}; + +} // namespace host +} // namespace tensor_operation +} // namespace ck diff --git a/library/include/ck/library/reference_tensor_operation/cpu/reference_layernorm_bwd.hpp b/library/include/ck/library/reference_tensor_operation/cpu/reference_layernorm_bwd.hpp new file mode 100644 index 0000000000..102002d1a2 --- /dev/null +++ b/library/include/ck/library/reference_tensor_operation/cpu/reference_layernorm_bwd.hpp @@ -0,0 +1,177 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include +#include +#include + +#include "ck/tensor_operation/gpu/device/device_base.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" + +namespace ck { +namespace tensor_operation { +namespace host { + +template +struct ReferenceLayernormBwd : public device::BaseOperator +{ + // Argument + struct Argument : public device::BaseArgument + { + Argument(const Tensor& dy_m_n, + const Tensor& x_m_n, + const Tensor& gamma_n, + const Tensor& mean_m, + const Tensor& inv_std_m, + Tensor& dgamma_n, + Tensor& dbeta_n, + Tensor& dx_m_n, + const std::vector lengths) + : dy_m_n_(dy_m_n), + x_m_n_(x_m_n), + gamma_n_(gamma_n), + mean_m_(mean_m), + inv_std_m_(inv_std_m), + dgamma_n_(dgamma_n), + dbeta_n_(dbeta_n), + dx_m_n_(dx_m_n), + lengths_(lengths) + { + } + + const Tensor& dy_m_n_; + const Tensor& x_m_n_; + const Tensor& gamma_n_; + const Tensor& mean_m_; + const Tensor& inv_std_m_; + Tensor& dgamma_n_; + Tensor& dbeta_n_; + Tensor& dx_m_n_; + std::vector lengths_; + }; + + // Invoker + struct Invoker : public device::BaseInvoker + { + float Run(const Argument& arg) + { + int M = arg.lengths_[0]; + int N = arg.lengths_[1]; + + // Calculate dgamma and dbeta + for(int n = 0; n < N; ++n) + { + ComputeDataType dgamma = 0; + ComputeDataType dbeta = 0; + + for(int m = 0; m < M; ++m) + { + ComputeDataType dy = ck::type_convert(arg.dy_m_n_(m, n)); + ComputeDataType x = ck::type_convert(arg.x_m_n_(m, n)); + ComputeDataType mean = ck::type_convert(arg.mean_m_(m)); + ComputeDataType rstd = ck::type_convert(arg.inv_std_m_(m)); + dgamma += dy * rstd * (x - mean); + dbeta += dy; + } + arg.dgamma_n_(n) = ck::type_convert(dgamma); + arg.dbeta_n_(n) = ck::type_convert(dbeta); + } + + // Calculate dx + for(int m = 0; m < M; ++m) + { + ComputeDataType ds = 0; + ComputeDataType db = 0; + + ComputeDataType mean = ck::type_convert(arg.mean_m_(m)); + ComputeDataType rstd = ck::type_convert(arg.inv_std_m_(m)); + + for(int n = 0; n < N; ++n) + { + ComputeDataType dy = ck::type_convert(arg.dy_m_n_(m, n)); + ComputeDataType x = ck::type_convert(arg.x_m_n_(m, n)); + ComputeDataType gamma = ck::type_convert(arg.gamma_n_(n)); + + ds += 
dy * gamma * x; + db += dy * gamma; + } + + for(int n = 0; n < N; ++n) + { + ComputeDataType dy = ck::type_convert(arg.dy_m_n_(m, n)); + ComputeDataType x = ck::type_convert(arg.x_m_n_(m, n)); + ComputeDataType gamma = ck::type_convert(arg.gamma_n_(n)); + + ComputeDataType b = (db * mean - ds) * rstd * rstd * rstd / N; + ComputeDataType c = -b * mean - db * rstd / N; + + arg.dx_m_n_(m, n) = ck::type_convert(dy * gamma * rstd + b * x + c); + } + } + + return 0; + } + + float Run(const device::BaseArgument* p_arg, + const StreamConfig& /* stream_config */ = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg)); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } + + bool IsSupportedArgument(const device::BaseArgument*) override { return true; } + + static auto MakeArgument(const Tensor& dy_m_n, + const Tensor& x_m_n, + const Tensor& gamma_n, + const Tensor& mean_m, + const Tensor& inv_std_m, + Tensor& dgamma_n, + Tensor& dbeta_n, + Tensor& dx_m_n, + const std::vector lengths) + { + return Argument{ + dy_m_n, x_m_n, gamma_n, mean_m, inv_std_m, dgamma_n, dbeta_n, dx_m_n, lengths}; + } + + static auto MakeInvoker() { return Invoker{}; } + + virtual std::unique_ptr MakeInvokerPointer() + { + return std::make_unique(Invoker{}); + } + + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "ReferenceLayernormBwd" + << std::endl; + // clang-format on + + return str.str(); + } +}; + +} // namespace host +} // namespace tensor_operation +} // namespace ck
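For reference, the dx path implemented by both host reference operators above (the ds, db, b and c terms in the code) corresponds to the standard normalization input-gradient identities. With $\mu$ the saved mean, $r$ the saved inverse standard deviation, and $K$ the reduction size ($N$ for layernorm, $H \cdot W \cdot C$ per group for groupnorm), for every element $k$ in the reduced set:

$$
\begin{aligned}
d_s &= \sum_{k} dy_k \, \gamma_k \, x_k, \qquad
d_b = \sum_{k} dy_k \, \gamma_k, \\
b &= \frac{(d_b \, \mu - d_s) \, r^{3}}{K}, \qquad
c = -\,b \, \mu - \frac{d_b \, r}{K}, \\
dx_k &= dy_k \, \gamma_k \, r + b \, x_k + c .
\end{aligned}
$$

The dgamma/dbeta outputs verified by the device examples are the complementary reductions over the invariant dimension, $d\gamma_k = \sum_m dy_{m,k}\,(x_{m,k}-\mu_m)\,r_m$ and $d\beta_k = \sum_m dy_{m,k}$.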