diff --git a/example/27_layernorm/layernorm_blockwise.cpp b/example/27_layernorm/layernorm_blockwise.cpp index 38a2a63663..7166cae5d3 100644 --- a/example/27_layernorm/layernorm_blockwise.cpp +++ b/example/27_layernorm/layernorm_blockwise.cpp @@ -9,7 +9,7 @@ #include "ck/ck.hpp" #include "ck/utility/reduction_enums.hpp" -#include "ck/tensor_operation/gpu/device/device_layernorm.hpp" +#include "ck/tensor_operation/gpu/device/device_layernorm_impl.hpp" #include "ck/tensor_operation/gpu/device/reduction_operator_mapping.hpp" #include "ck/library/utility/check_err.hpp" @@ -29,24 +29,24 @@ using PassThrough = ck::tensor_operation::element_wise::PassThrough; constexpr int Rank = 2; constexpr int NumReduceDim = 1; -using DeviceInstance = ck::tensor_operation::device::DeviceLayernorm; // OutScalarPerVector +using DeviceInstance = ck::tensor_operation::device::DeviceLayernormImpl; // OutScalarPerVector int main() { @@ -90,6 +90,7 @@ int main() std::vector{x.mDesc.GetStrides().begin(), x.mDesc.GetStrides().end()}, std::vector{gamma.mDesc.GetStrides().begin(), gamma.mDesc.GetStrides().end()}, std::vector{beta.mDesc.GetStrides().begin(), beta.mDesc.GetStrides().end()}, + std::vector{y.mDesc.GetStrides().begin(), y.mDesc.GetStrides().end()}, {1}, 1e-4, x_dev.GetDeviceBuffer(), diff --git a/include/ck/tensor_operation/gpu/block/blockwise_welford.hpp b/include/ck/tensor_operation/gpu/block/blockwise_welford.hpp new file mode 100644 index 0000000000..316508651e --- /dev/null +++ b/include/ck/tensor_operation/gpu/block/blockwise_welford.hpp @@ -0,0 +1,108 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/tensor_description/cluster_descriptor.hpp" +#include "ck/utility/reduction_common.hpp" + +namespace ck { + +// clang-format off +// Assume: +// 1) work_buffer is buffer (typically LDS) allocated outside as workspace +// 2) work_buffer has T elements, and space size is no less than 3*BlockSize +// 3) mean_value, var_value and count is the input data in vgpr from each thread +// 4) mean_value, var_value and count is the over-written reduced output in vgpr for each thread +// 5) Merge mean and M from ThreadwiseWelford +// clang-format on +template +struct BlockwiseWelford +{ + static_assert(BlockSize == ThreadClusterLengths_M_K::At(0) * ThreadClusterLengths_M_K::At(1), + "The product of cluster lengths should be same as BlockSize!"); + + static constexpr auto BufferLength_M = ThreadClusterLengths_M_K::At(0); + static constexpr auto BufferLength_K = ThreadClusterLengths_M_K::At(1); + + static constexpr auto block_buf_desc_m_k = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, Number{})); + + static constexpr auto thread_cluster_desc = + make_cluster_descriptor(ThreadClusterLengths_M_K{}, ThreadClusterArrangeOrder{}); + + __device__ static inline void + Merge(T& mean_a, T& var_a, int& count_a, T mean_b, T var_b, int count_b) + { + int count = count_a + count_b; + T count_b_over_count = count == 0 ? 
type_convert(0) : type_convert(count_b) / count; + T delta = mean_b - mean_a; + mean_a += delta * count_b_over_count; + var_a += var_b + delta * delta * count_a * count_b_over_count; + count_a = count; + } + + __device__ static void Run(T& mean_value, T& var_value, int& count) + { + __shared__ T mean_block_buf[BlockSize]; + __shared__ T var_block_buf[BlockSize]; + __shared__ int count_block_buf[BlockSize]; + + constexpr auto cluster_len_shift = get_shift(); + + const auto thread_cluster_idx = + thread_cluster_desc.CalculateBottomIndex(make_multi_index(get_thread_local_1d_id())); + + const auto thread_m_cluster_id = thread_cluster_idx[Number<0>{}]; + const auto thread_k_cluster_id = thread_cluster_idx[Number<1>{}]; + + index_t offset1 = block_buf_desc_m_k.CalculateOffset(thread_cluster_idx); + + mean_block_buf[offset1] = mean_value; + var_block_buf[offset1] = var_value; + count_block_buf[offset1] = count; + + block_sync_lds(); + + static_for<0, cluster_len_shift, 1>{}([&](auto I) { + constexpr index_t indOffset = 1 << (cluster_len_shift - 1 - I()); + + if(thread_k_cluster_id < indOffset) + { + index_t offset2 = block_buf_desc_m_k.CalculateOffset(thread_cluster_idx + + make_tuple(0, indOffset)); + + T mean1 = mean_block_buf[offset1]; + T var1 = var_block_buf[offset1]; + int count1 = count_block_buf[offset1]; + + T mean2 = mean_block_buf[offset2]; + T var2 = var_block_buf[offset2]; + int count2 = count_block_buf[offset2]; + + Merge(mean1, var1, count1, mean2, var2, count2); + + mean_block_buf[offset1] = mean1; + var_block_buf[offset1] = var1; + count_block_buf[offset1] = count1; + } + + block_sync_lds(); + }); + + index_t offset = block_buf_desc_m_k.CalculateOffset(make_tuple(thread_m_cluster_id, 0)); + + count = count_block_buf[offset]; + mean_value = mean_block_buf[offset]; + + if constexpr(GetActualVariance) + var_value = var_block_buf[offset] / count; + else + var_value = var_block_buf[offset]; + }; +}; +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/device_layernorm.hpp b/include/ck/tensor_operation/gpu/device/device_layernorm.hpp deleted file mode 100644 index 464ac8c549..0000000000 --- a/include/ck/tensor_operation/gpu/device/device_layernorm.hpp +++ /dev/null @@ -1,356 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
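// Host-side sketch of the pairwise combination that BlockwiseWelford::Merge above
// performs for each thread pair during the LDS tree reduction. Standalone and
// simplified for illustration (plain double, no LDS, out-of-place instead of
// in-place); "var" in the device code actually holds M2, the running sum of
// squared deviations, until the final division in Run:
//   count = count_a + count_b
//   delta = mean_b - mean_a
//   mean  = mean_a + delta * count_b / count
//   M2    = M2_a + M2_b + delta^2 * count_a * count_b / count
//   variance = M2 / count
struct WelfordStat
{
    double mean = 0.0;
    double m2   = 0.0; // sum of squared deviations from the running mean
    int count   = 0;
};

inline WelfordStat welford_merge(const WelfordStat& a, const WelfordStat& b)
{
    WelfordStat out;
    out.count = a.count + b.count;

    // Guard against merging two empty partitions, as the device code does.
    const double b_over_total =
        out.count == 0 ? 0.0 : static_cast<double>(b.count) / out.count;

    const double delta = b.mean - a.mean;
    out.mean = a.mean + delta * b_over_total;
    out.m2   = a.m2 + b.m2 + delta * delta * a.count * b_over_total;
    return out;
}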
- -#pragma once - -#include -#include - -#include "ck/utility/reduction_operator.hpp" -#include "ck/tensor_operation/gpu/device/device_normalization.hpp" -#include "ck/tensor_operation/gpu/device/device_reduce.hpp" -#include "ck/tensor_operation/gpu/device/device_reduce_multiblock.hpp" -#include "ck/tensor_operation/gpu/device/device_reduce_common.hpp" -#include "ck/tensor_operation/gpu/grid/gridwise_layernorm.hpp" -#include "ck/tensor_operation/gpu/grid/gridwise_set_buffer_value.hpp" -#include "ck/host_utility/device_prop.hpp" -#include "ck/host_utility/kernel_launch.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { - -// Y = LayerNorm(X, Beta, Gamma) -template -struct DeviceLayernorm : public DeviceNormalization2 -{ - static_assert( - (KThreadSliceSize % GammaSrcVectorSize == 0), - "Invalid thread slice sizes and/or gamma vector sizes configuration, please check!"); - - static_assert( - (KThreadSliceSize % BetaSrcVectorSize == 0), - "Invalid thread slice sizes and/or beta vector sizes configuration, please check!"); - - using PassThrough = tensor_operation::element_wise::PassThrough; - - // Used for freeloading of some handy functions from DeviceReduceMultiBlock - using Reduction = DeviceReduceMultiBlock; // YDstVectorSize - - static auto MakeAffine1dDescriptor(const std::vector& Lengths, - const std::vector& Strides, - int blkGroupSize, - int numBlockTileIteration) - { - const auto tupleLengths = make_tuple_from_array(Lengths, Number{}); - const auto tupleStrides = make_tuple_from_array(Strides, Number{}); - - auto desc = make_naive_tensor_descriptor(tupleLengths, tupleStrides); - - auto grid_desc_k = transform_tensor_descriptor( - desc, - make_tuple(make_merge_transform(tupleLengths)), - make_tuple(typename arithmetic_sequence_gen<0, NumReduceDim, 1>::type{}), - make_tuple(Sequence<0>{})); - - const auto reduceTotalLength = grid_desc_k.GetLength(Number<0>{}); - const int reduceSizePerBlock = Reduction::K_BlockTileSize * numBlockTileIteration; - - const auto Pad_K = reduceSizePerBlock * blkGroupSize - reduceTotalLength; - - auto grid_desc_k_padded = transform_tensor_descriptor( - grid_desc_k, - make_tuple(make_right_pad_transform(reduceTotalLength, Pad_K)), - make_tuple(Sequence<0>{}), - make_tuple(Sequence<0>{})); - - return (grid_desc_k_padded); - }; - - using GridDesc_M_K = decltype(Reduction::MakeSrc2dDescriptor({1}, {1}, 1, 1)); - using GridDesc_K = decltype(MakeAffine1dDescriptor({1}, {1}, 1, 1)); - - using GridwiseReduceLayernormGeneric = GridwiseLayernorm_mk_to_mk; - - using GridwiseReduceLayernormSweepOnce = GridwiseLayernorm_mk_to_mk; - - struct Argument : public Reduction::Argument - { - Argument(const std::vector lengths, - const std::vector xStrides, - const std::vector gammaStrides, - const std::vector betaStrides, - const std::vector reduceDims, - AccElementwiseOperation acc_elementwise_op, - AccDataType epsilon, - const XDataType* p_x, - const GammaDataType* p_gamma, - const BetaDataType* p_beta, - YDataType* p_y) - : Reduction::Argument(lengths, - xStrides, - {}, - {}, - reduceDims, - 0.0f, // alpha - 0.0f, // beta - p_x, - nullptr, - p_y, - nullptr, - acc_elementwise_op, - PassThrough{}), - epsilon_(epsilon), - p_gamma_(p_gamma), - p_beta_(p_beta), - gammaStrides_(gammaStrides), - betaStrides_(betaStrides) - { - reduceLength_.resize(NumReduceDim); - - for(int i = 0; i < NumReduceDim; ++i) - { - reduceLength_[i] = lengths[reduceDims[i]]; - } - } - - AccDataType epsilon_; - const GammaDataType* p_gamma_; - const BetaDataType* p_beta_; - std::vector 
reduceLength_; - std::vector gammaStrides_; - std::vector betaStrides_; - }; - - struct Invoker : public BaseInvoker - { - float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) - { - const auto x_grid_desc_m_k = Reduction::MakeSrc2dDescriptor( - arg.inLengths_, arg.inStrides_, arg.blkGroupSize, arg.numBlockTileIteration); - const auto gamma_grid_desc_k = MakeAffine1dDescriptor( - arg.reduceLength_, arg.gammaStrides_, arg.blkGroupSize, arg.numBlockTileIteration); - const auto beta_grid_desc_k = MakeAffine1dDescriptor( - arg.reduceLength_, arg.betaStrides_, arg.blkGroupSize, arg.numBlockTileIteration); - const auto y_grid_desc_m_k = Reduction::MakeSrc2dDescriptor( - arg.inLengths_, arg.inStrides_, arg.blkGroupSize, arg.numBlockTileIteration); - - bool sweep_once = - x_grid_desc_m_k.GetLength(Number<1>{}) <= KThreadClusterSize * KThreadSliceSize; - - const auto kernel_main = sweep_once ? kernel_layernorm - : kernel_layernorm; - - float avg_time = 0; - avg_time += launch_and_time_kernel(stream_config, - kernel_main, - dim3(arg.gridSize), - dim3(BlockSize), - 0, - x_grid_desc_m_k, - gamma_grid_desc_k, - beta_grid_desc_k, - y_grid_desc_m_k, - arg.numBlockTileIteration, - arg.epsilon_, - arg.in_dev_, - arg.p_gamma_, - arg.p_beta_, - arg.out_dev_, - arg.acc_elementwise_op_); - - return (avg_time); - }; - - float Run(const BaseArgument* p_arg, - const StreamConfig& stream_config = StreamConfig{}) override - { - return Run(*dynamic_cast(p_arg), stream_config); - }; - }; - - bool IsSupportedArgument(const BaseArgument* p_arg) override - { - const Argument* p_arg_ = dynamic_cast(p_arg); - - if(!Reduction::IsSupportedArgument(p_arg_)) - { - return false; - } - - if(p_arg_->inLengths_[Rank - 1] % YDstVectorSize != 0) - { - return false; - } - - if(p_arg_->gammaStrides_.size() != NumReduceDim || - p_arg_->betaStrides_.size() != NumReduceDim) - return false; - - auto IsScalarPerVectorValid = [](bool isLastDimensionCoalesced, int scalarPerVector) { - bool ret = true; - - if(!isLastDimensionCoalesced) - ret = scalarPerVector == 1; - else - ret = KThreadSliceSize % scalarPerVector == 0; - - return ret; - }; - - if(!IsScalarPerVectorValid(p_arg_->gammaStrides_.back() == 1, GammaSrcVectorSize)) - return false; - - if(!IsScalarPerVectorValid(p_arg_->betaStrides_.back() == 1, BetaSrcVectorSize)) - return false; - - return true; - }; - - std::unique_ptr - MakeArgumentPointer(const std::vector lengths, - const std::vector xStrides, - const std::vector gammaStrides, - const std::vector betaStrides, - const std::vector reduceDims, - AccDataType epsilon, - const void* p_x, - const void* p_gamma, - const void* p_beta, - void* p_y, - AccElementwiseOperation acc_elementwise_op) override - { - return std::make_unique(lengths, - xStrides, - gammaStrides, - betaStrides, - reduceDims, - acc_elementwise_op, - epsilon, - static_cast(p_x), - static_cast(p_gamma), - static_cast(p_beta), - static_cast(p_y)); - }; - - std::unique_ptr MakeInvokerPointer() override - { - return std::make_unique(); - }; - - std::string GetTypeString() const override - { - auto str = std::stringstream(); - - // clang-format off - str << "DeviceLayernorm<" << BlockSize << ","; - str << "M_C" << MThreadClusterSize << "_S" << MThreadSliceSize << ","; - str << "K_C" << KThreadClusterSize << "_S" << KThreadSliceSize << ","; - str << "XYSrcVectorDim_" << XYSrcVectorDim << ","; - str << "VectorSize_X" << XSrcVectorSize << "_Gamma" << GammaSrcVectorSize << "_Beta" << BetaSrcVectorSize << "_Y" << YDstVectorSize << ">"; - // 
clang-format on - - return str.str(); - } -}; - -} // namespace device -} // namespace tensor_operation -} // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/device_layernorm_impl.hpp b/include/ck/tensor_operation/gpu/device/device_layernorm_impl.hpp new file mode 100644 index 0000000000..7852209c3a --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/device_layernorm_impl.hpp @@ -0,0 +1,487 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include + +#include "ck/utility/reduction_operator.hpp" +#include "ck/tensor_operation/gpu/device/device_normalization.hpp" +#include "ck/tensor_operation/gpu/device/device_reduce.hpp" +#include "ck/tensor_operation/gpu/device/device_reduce_common.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_layernorm_welford_variance.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_set_buffer_value.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" + +namespace ck { +template +__global__ void kernel_layernorm(const GridDesc_M_K x_grid_desc_m_k, + const GridDesc_K gamma_grid_desc_k, + const GridDesc_K beta_grid_desc_k, + const GridDesc_M_K y_grid_desc_m_k, + index_t num_k_block_tile_iteration, + AccDataType epsilon, + const XDataType* const __restrict__ p_x_global, + const GammaDataType* const __restrict__ p_gamma_global, + const BetaDataType* const __restrict__ p_beta_global, + YDataType* const __restrict__ p_y_global, + const AccElementwiseOperation acc_elementwise_op) +{ + GridwiseReduction::Run(x_grid_desc_m_k, + gamma_grid_desc_k, + beta_grid_desc_k, + y_grid_desc_m_k, + num_k_block_tile_iteration, + epsilon, + p_x_global, + p_gamma_global, + p_beta_global, + p_y_global, + acc_elementwise_op); +}; +} // namespace ck + +namespace ck { +namespace tensor_operation { +namespace device { + +// Y = LayerNorm(X, Beta, Gamma) +template +struct DeviceLayernormImpl : public DeviceLayernorm +{ + static_assert( + (KThreadSliceSize % GammaSrcVectorSize == 0), + "Invalid thread slice sizes and/or gamma vector sizes configuration, please check!"); + + static_assert( + (KThreadSliceSize % BetaSrcVectorSize == 0), + "Invalid thread slice sizes and/or beta vector sizes configuration, please check!"); + + using PassThrough = tensor_operation::element_wise::PassThrough; + + static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize; + static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize; + + static auto MakeSrc2dDescriptor(const std::vector& inLengths, + const std::vector& inStrides, + int blkGroupSize, + int numBlockTileIteration) + { + constexpr index_t NumInvariantDim = Rank - NumReduceDim; + static constexpr index_t numSrcDim = Rank; + static constexpr bool reduceAllDim = (NumInvariantDim == 0); + + const auto tupleSrcLengths = make_tuple_from_array(inLengths, Number{}); + const auto tupleSrcStrides = make_tuple_from_array(inStrides, Number{}); + + const auto inDesc = make_naive_tensor_descriptor(tupleSrcLengths, tupleSrcStrides); + + const auto in_grid_desc_m_k = [&]() { + if constexpr(reduceAllDim) + { + const auto one_dim_inDesc = transform_tensor_descriptor( + inDesc, + make_tuple(make_merge_transform(tupleSrcLengths)), + make_tuple(typename arithmetic_sequence_gen<0, numSrcDim, 1>::type{}), + make_tuple(Sequence<0>{})); + + return transform_tensor_descriptor(one_dim_inDesc, + make_tuple(make_unmerge_transform(make_tuple( + 1, 
one_dim_inDesc.GetLength(Number<0>{})))), + make_tuple(Sequence<0>{}), + make_tuple(Sequence<0, 1>{})); + } + else + { + using InvariantDims = typename arithmetic_sequence_gen<0, NumInvariantDim, 1>::type; + using ReduceDims = typename arithmetic_sequence_gen::type; + + const auto reduceDimLengths = + make_tuple_from_array_and_index_seq(inLengths, ReduceDims{}); + const auto invariantDimLengths = + make_tuple_from_array_and_index_seq(inLengths, InvariantDims{}); + + return transform_tensor_descriptor( + inDesc, + make_tuple(make_merge_transform(invariantDimLengths), + make_merge_transform(reduceDimLengths)), + make_tuple(InvariantDims{}, ReduceDims{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + }(); + + const auto invariantLength = in_grid_desc_m_k.GetLength(Number<0>{}); + const auto reduceLength = in_grid_desc_m_k.GetLength(Number<1>{}); + + const int reduceSizePerBlock = K_BlockTileSize * numBlockTileIteration; + const auto inPad_M = + math::integer_least_multiple(invariantLength, M_BlockTileSize) - invariantLength; + const auto inPad_K = reduceSizePerBlock * blkGroupSize - reduceLength; + + auto in_grid_desc_m_k_padded = transform_tensor_descriptor( + in_grid_desc_m_k, + make_tuple(make_right_pad_transform(invariantLength, inPad_M), + make_right_pad_transform(reduceLength, inPad_K)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + return (in_grid_desc_m_k_padded); + }; + + static auto MakeAffine1dDescriptor(const std::vector& Lengths, + const std::vector& Strides, + int blkGroupSize, + int numBlockTileIteration) + { + const auto tupleLengths = make_tuple_from_array(Lengths, Number{}); + const auto tupleStrides = make_tuple_from_array(Strides, Number{}); + + auto desc = make_naive_tensor_descriptor(tupleLengths, tupleStrides); + + auto grid_desc_k = transform_tensor_descriptor( + desc, + make_tuple(make_merge_transform(tupleLengths)), + make_tuple(typename arithmetic_sequence_gen<0, NumReduceDim, 1>::type{}), + make_tuple(Sequence<0>{})); + + const auto reduceTotalLength = grid_desc_k.GetLength(Number<0>{}); + const int reduceSizePerBlock = K_BlockTileSize * numBlockTileIteration; + + const auto Pad_K = reduceSizePerBlock * blkGroupSize - reduceTotalLength; + + auto grid_desc_k_padded = transform_tensor_descriptor( + grid_desc_k, + make_tuple(make_right_pad_transform(reduceTotalLength, Pad_K)), + make_tuple(Sequence<0>{}), + make_tuple(Sequence<0>{})); + + return (grid_desc_k_padded); + }; + + using GridDesc_M_K = decltype(MakeSrc2dDescriptor({1}, {1}, 1, 1)); + using GridDesc_K = decltype(MakeAffine1dDescriptor({1}, {1}, 1, 1)); + + using GridwiseReduceLayernormGeneric = + GridwiseLayernormWelfordVariance_mk_to_mk; + + using GridwiseReduceLayernormSweepOnce = + GridwiseLayernormWelfordVariance_mk_to_mk; + + struct Argument : public BaseArgument + { + Argument(const std::vector lengths, + const std::vector xStrides, + const std::vector gammaStrides, + const std::vector betaStrides, + const std::vector yStrides, + const std::vector reduceDims, + AccElementwiseOperation acc_elementwise_op, + AccDataType epsilon, + const XDataType* p_x, + const GammaDataType* p_gamma, + const BetaDataType* p_beta, + YDataType* p_y) + : epsilon_(epsilon), + p_x_(p_x), + p_gamma_(p_gamma), + p_beta_(p_beta), + p_y_(p_y), + gammaStrides_(gammaStrides), + betaStrides_(betaStrides), + acc_elementwise_op_(acc_elementwise_op) + { + Lengths_ = shuffle_tensor_dimensions(lengths, reduceDims); + xStrides_ = shuffle_tensor_dimensions(xStrides, reduceDims); + 
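// Host-side sketch of the padding and tile-count arithmetic behind
// MakeSrc2dDescriptor / MakeAffine1dDescriptor above: the invariant dimensions
// are merged into M, the reduced dimensions into K, K is covered by
// numBlockTileIteration tiles of K_BlockTileSize, and both axes are right-padded
// up to tile multiples; the padded M length also yields the grid size used by
// the Argument constructor (blkGroupSize is fixed to 1 there). Tile sizes in the
// example at the bottom are hypothetical, not the tuned instance values.
#include <cstdint>

struct TileShape
{
    std::int64_t num_block_tile_iteration;
    std::int64_t pad_m;
    std::int64_t pad_k;
    std::int64_t grid_size;
};

inline TileShape layernorm_tile_shape(std::int64_t invariant_total_length, // merged M
                                      std::int64_t reduce_total_length,    // merged K
                                      std::int64_t m_block_tile_size,      // MThreadClusterSize * MThreadSliceSize
                                      std::int64_t k_block_tile_size)      // KThreadClusterSize * KThreadSliceSize
{
    TileShape s{};
    const std::int64_t blk_group_size = 1;

    s.num_block_tile_iteration =
        (reduce_total_length + k_block_tile_size - 1) / k_block_tile_size;

    const std::int64_t padded_m =
        (invariant_total_length + m_block_tile_size - 1) / m_block_tile_size * m_block_tile_size;

    s.pad_m     = padded_m - invariant_total_length;
    s.pad_k     = k_block_tile_size * s.num_block_tile_iteration * blk_group_size - reduce_total_length;
    s.grid_size = padded_m / m_block_tile_size * blk_group_size;
    return s;
}

// e.g. a 1024 x 1000 row-wise layernorm with hypothetical 2 x 128 tiles:
//   layernorm_tile_shape(1024, 1000, 2, 128)
//   -> 8 K iterations, pad_m = 0, pad_k = 24, grid of 512 blocks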
yStrides_ = shuffle_tensor_dimensions(yStrides, reduceDims); + + long_index_t invariant_total_length; + long_index_t reduce_total_length; + + std::tie(invariant_total_length, reduce_total_length) = + get_2d_lengths(Lengths_); + + blkGroupSize_ = 1; + numBlockTileIteration_ = (reduce_total_length + K_BlockTileSize - 1) / K_BlockTileSize; + + gridSize_ = math::integer_least_multiple(invariant_total_length, M_BlockTileSize) / + M_BlockTileSize * blkGroupSize_; + + reduceLengths_.resize(NumReduceDim); + + for(int i = 0; i < NumReduceDim; ++i) + { + reduceLengths_[i] = lengths[reduceDims[i]]; + } + } + + AccDataType epsilon_; + + const XDataType* p_x_; + const GammaDataType* p_gamma_; + const BetaDataType* p_beta_; + YDataType* p_y_; + + std::vector Lengths_; + std::vector xStrides_; + std::vector reduceLengths_; + std::vector gammaStrides_; + std::vector betaStrides_; + std::vector yStrides_; + + AccElementwiseOperation acc_elementwise_op_; + + int blkGroupSize_; + int numBlockTileIteration_; + size_t gridSize_; + }; + + struct Invoker : public BaseInvoker + { + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + const auto x_grid_desc_m_k = MakeSrc2dDescriptor( + arg.Lengths_, arg.xStrides_, arg.blkGroupSize_, arg.numBlockTileIteration_); + const auto gamma_grid_desc_k = MakeAffine1dDescriptor(arg.reduceLengths_, + arg.gammaStrides_, + arg.blkGroupSize_, + arg.numBlockTileIteration_); + const auto beta_grid_desc_k = MakeAffine1dDescriptor(arg.reduceLengths_, + arg.betaStrides_, + arg.blkGroupSize_, + arg.numBlockTileIteration_); + const auto y_grid_desc_m_k = MakeSrc2dDescriptor( + arg.Lengths_, arg.yStrides_, arg.blkGroupSize_, arg.numBlockTileIteration_); + + bool sweep_once = + x_grid_desc_m_k.GetLength(Number<1>{}) <= KThreadClusterSize * KThreadSliceSize; + + const auto kernel_main = sweep_once ? 
kernel_layernorm + : kernel_layernorm; + + float avg_time = 0; + avg_time += launch_and_time_kernel(stream_config, + kernel_main, + dim3(arg.gridSize_), + dim3(BlockSize), + 0, + x_grid_desc_m_k, + gamma_grid_desc_k, + beta_grid_desc_k, + y_grid_desc_m_k, + arg.numBlockTileIteration_, + arg.epsilon_, + arg.p_x_, + arg.p_gamma_, + arg.p_beta_, + arg.p_y_, + arg.acc_elementwise_op_); + + return (avg_time); + }; + + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + }; + }; + + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + const Argument* p_arg_ = dynamic_cast(p_arg); + + constexpr index_t NumInvariantDim = Rank - NumReduceDim; + + if constexpr(XYSrcVectorDim == 0) + { + if constexpr(NumInvariantDim == 0) + { + return false; + } + else + { + if(p_arg_->xStrides_[NumInvariantDim - 1] != 1) + return false; + + if(p_arg_->invariant_lowest_length % XSrcVectorSize != 0) + return false; + }; + } + else + { + if(p_arg_->xStrides_[Rank - 1] != 1) + return false; + + if(p_arg_->Lengths_[Rank - 1] % XSrcVectorSize != 0) + return false; + }; + + if(p_arg_->Lengths_[Rank - 1] % YDstVectorSize != 0) + { + return false; + } + + if(p_arg_->gammaStrides_.size() != NumReduceDim || + p_arg_->betaStrides_.size() != NumReduceDim) + return false; + + auto IsScalarPerVectorValid = [](bool isLastDimensionCoalesced, int scalarPerVector) { + bool ret = true; + + if(!isLastDimensionCoalesced) + ret = scalarPerVector == 1; + else + ret = KThreadSliceSize % scalarPerVector == 0; + + return ret; + }; + + if(!IsScalarPerVectorValid(p_arg_->gammaStrides_.back() == 1, GammaSrcVectorSize)) + return false; + + if(!IsScalarPerVectorValid(p_arg_->betaStrides_.back() == 1, BetaSrcVectorSize)) + return false; + + return true; + }; + + std::unique_ptr + MakeArgumentPointer(const std::vector lengths, + const std::vector xStrides, + const std::vector gammaStrides, + const std::vector betaStrides, + const std::vector yStrides, + const std::vector reduceDims, + AccDataType epsilon, + const void* p_x, + const void* p_gamma, + const void* p_beta, + void* p_y, + AccElementwiseOperation acc_elementwise_op) override + { + return std::make_unique(lengths, + xStrides, + gammaStrides, + betaStrides, + yStrides, + reduceDims, + acc_elementwise_op, + epsilon, + static_cast(p_x), + static_cast(p_gamma), + static_cast(p_beta), + static_cast(p_y)); + }; + + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(); + }; + + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "DeviceLayernormImpl<" << BlockSize << ","; + str << "M_C" << MThreadClusterSize << "_S" << MThreadSliceSize << ","; + str << "K_C" << KThreadClusterSize << "_S" << KThreadSliceSize << ","; + str << "XYSrcVectorDim_" << XYSrcVectorDim << ","; + str << "VectorSize_X" << XSrcVectorSize << "_Gamma" << GammaSrcVectorSize << "_Beta" << BetaSrcVectorSize << "_Y" << YDstVectorSize << ">"; + // clang-format on + + return str.str(); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/device_normalization.hpp b/include/ck/tensor_operation/gpu/device/device_normalization.hpp index 2ca66c5d82..7032b2858b 100644 --- a/include/ck/tensor_operation/gpu/device/device_normalization.hpp +++ b/include/ck/tensor_operation/gpu/device/device_normalization.hpp @@ -46,13 +46,14 @@ template -struct 
DeviceNormalization2 : public BaseOperator +struct DeviceLayernorm : public BaseOperator { virtual std::unique_ptr MakeArgumentPointer(const std::vector lengths, const std::vector xStrides, const std::vector gammaStrides, const std::vector betaStrides, + const std::vector yStrides, const std::vector reduceDims, AccDataType epsilon, const void* p_x, @@ -72,14 +73,14 @@ template -using DeviceNormalization2Ptr = std::unique_ptr>; +using DeviceLayernormPtr = std::unique_ptr>; } // namespace device } // namespace tensor_operation diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_layernorm.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_layernorm_naive_variance.hpp similarity index 91% rename from include/ck/tensor_operation/gpu/grid/gridwise_layernorm.hpp rename to include/ck/tensor_operation/gpu/grid/gridwise_layernorm_naive_variance.hpp index 597b164788..99061328b6 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_layernorm.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_layernorm_naive_variance.hpp @@ -14,40 +14,6 @@ namespace ck { -template -__global__ void kernel_layernorm(const GridDesc_M_K x_grid_desc_m_k, - const GridDesc_K gamma_grid_desc_k, - const GridDesc_K beta_grid_desc_k, - const GridDesc_M_K y_grid_desc_m_k, - index_t num_k_block_tile_iteration, - AccDataType epsilon, - const XDataType* const __restrict__ p_x_global, - const GammaDataType* const __restrict__ p_gamma_global, - const BetaDataType* const __restrict__ p_beta_global, - YDataType* const __restrict__ p_y_global, - const AccElementwiseOperation acc_elementwise_op) -{ - GridwiseReduction::Run(x_grid_desc_m_k, - gamma_grid_desc_k, - beta_grid_desc_k, - y_grid_desc_m_k, - num_k_block_tile_iteration, - epsilon, - p_x_global, - p_gamma_global, - p_beta_global, - p_y_global, - acc_elementwise_op); -}; - // Y = LayerNorm(X, Beta, Gamma) template -struct GridwiseLayernorm_mk_to_mk +struct GridwiseLayernormNaiveVariance_mk_to_mk { static_assert((XSrcVectorDim == 0 && MThreadSliceSize % XSrcVectorSize == 0) || (XSrcVectorDim == 1 && KThreadSliceSize % XSrcVectorSize == 0), diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_layernorm_welford_variance.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_layernorm_welford_variance.hpp new file mode 100644 index 0000000000..a81c501e61 --- /dev/null +++ b/include/ck/tensor_operation/gpu/grid/gridwise_layernorm_welford_variance.hpp @@ -0,0 +1,328 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
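// Host-side sketch contrasting the two variance schemes named by the gridwise
// files: the renamed "naive variance" path corresponds to accumulating sum and
// sum of squares (var = E[x^2] - E[x]^2), while the new file uses Welford's
// single-pass update, which avoids the cancellation the naive form can hit when
// the mean is large relative to the variance. Plain doubles for illustration
// only; the device code applies the same updates to per-thread tiles.
#include <cstddef>

inline double variance_naive(const double* x, std::size_t n)
{
    double sum = 0.0, sum_sq = 0.0;
    for(std::size_t i = 0; i < n; ++i)
    {
        sum    += x[i];
        sum_sq += x[i] * x[i];
    }
    const double mean = sum / n;
    return sum_sq / n - mean * mean; // can cancel badly when mean >> stddev
}

inline double variance_welford(const double* x, std::size_t n)
{
    double mean = 0.0, m2 = 0.0;
    for(std::size_t i = 0; i < n; ++i)
    {
        const double delta = x[i] - mean;
        mean += delta / static_cast<double>(i + 1);
        m2   += delta * (x[i] - mean); // delta * delta2, as in ThreadwiseWelford::Update
    }
    return m2 / n;
}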
+ +#pragma once + +#include "ck/utility/data_type.hpp" +#include "ck/tensor_operation/gpu/block/blockwise_welford.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_welford.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +namespace ck { + +// Y = LayerNorm(X, Beta, Gamma) +template +struct GridwiseLayernormWelfordVariance_mk_to_mk +{ + static_assert((XSrcVectorDim == 0 && MThreadSliceSize % XSrcVectorSize == 0) || + (XSrcVectorDim == 1 && KThreadSliceSize % XSrcVectorSize == 0), + "Invalid thread slice sizes and/or vector sizes configuration, please check!"); + + static_assert((YDstVectorDim == 0 && MThreadSliceSize % YDstVectorSize == 0) || + (YDstVectorDim == 1 && KThreadSliceSize % YDstVectorSize == 0), + "Invalid thread slice sizes and/or vector sizes configuration, please check!"); + + static constexpr bool reorder_thread_cluster = (XSrcVectorDim == 0); + + using ThreadClusterLengths_M_K = Sequence; + + using ThreadBufferDimAccessOrder = + typename conditional, Sequence<0, 1>>::type; + + using ThreadClusterArrangeOrder = + typename conditional, Sequence<0, 1>>::type; + + static constexpr auto thread_cluster_desc = + make_cluster_descriptor(ThreadClusterLengths_M_K{}, ThreadClusterArrangeOrder{}); + + using ThreadReduceSrcDesc_M_K = decltype(make_naive_tensor_descriptor_packed( + make_tuple(Number{}, Number{}))); + using ThreadReduceDstDesc_M = + decltype(make_naive_tensor_descriptor_packed(make_tuple(Number{}))); + + using ThreadwiseWelford = + ThreadwiseWelford; + + using BlockwiseWelford = BlockwiseWelford; + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + + static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize; + static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize; + + __device__ static int GetKPerThread(const GridDesc_M_K& x_grid_desc_m_k, + int thread_k_cluster_id) + { + int kPerBlock = x_grid_desc_m_k.GetTransforms()[I0].GetUpperLengths()[I1]; + int kPerThread = + kPerBlock < K_BlockTileSize ? 
0 : KThreadSliceSize * (kPerBlock / K_BlockTileSize); + int kPerBlockTail = kPerBlock - kPerThread * KThreadClusterSize; + + if(kPerBlockTail > 0) + { + int thread_max_len = (thread_k_cluster_id + 1) * KThreadSliceSize; + int delta = thread_max_len - kPerBlockTail; + delta = math::clamp(thread_max_len - kPerBlockTail, 0, KThreadSliceSize); + kPerThread += KThreadSliceSize - delta; + } + + return kPerThread; + } + + __device__ static void Run(const GridDesc_M_K& x_grid_desc_m_k, + const GridDesc_K& gamma_grid_desc_k, + const GridDesc_K& beta_grid_desc_k, + const GridDesc_M_K& y_grid_desc_m_k, + index_t num_k_block_tile_iteration, + AccDataType epsilon, + const XDataType* const __restrict__ p_x_global, + const GammaDataType* const __restrict__ p_gamma_global, + const BetaDataType* const __restrict__ p_beta_global, + YDataType* const __restrict__ p_y_global, + const AccElementwiseOperation acc_elementwise_op) + { + if constexpr(SweepOnce) + { + num_k_block_tile_iteration = 1; + } + + auto y_global_val_buf = make_dynamic_buffer( + p_y_global, y_grid_desc_m_k.GetElementSpaceSize()); + + StaticBuffer + x_thread_buf; + + StaticBuffer gamma_thread_buf; + + StaticBuffer& beta_thread_buf = + gamma_thread_buf; + + StaticBuffer + y_thread_buf; + + StaticBuffer mean_thread_buf; + StaticBuffer var_thread_buf; + + const index_t thread_local_id = get_thread_local_1d_id(); + const index_t block_global_id = get_block_1d_id(); + + const auto thread_cluster_idx = + thread_cluster_desc.CalculateBottomIndex(make_multi_index(thread_local_id)); + + const auto thread_m_cluster_id = thread_cluster_idx[I0]; + const auto thread_k_cluster_id = thread_cluster_idx[I1]; + + using ThreadBufferLengths_M_K = Sequence; + using ThreadBufferLengths_K = Sequence; + constexpr auto thread_buffer_desc_m_k = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, Number{})); + constexpr auto thread_buffer_desc_k = + make_naive_tensor_descriptor_packed(make_tuple(Number{})); + + auto threadwise_x_load = ThreadwiseTensorSliceTransfer_v2( + x_grid_desc_m_k, + make_multi_index(block_global_id * M_BlockTileSize + + thread_m_cluster_id * MThreadSliceSize, + thread_k_cluster_id * KThreadSliceSize)); + + auto threadwise_gamma_load = + ThreadwiseTensorSliceTransfer_v2, + 0, + GammaSrcVectorSize, + 1, + true>( + gamma_grid_desc_k, make_multi_index(thread_k_cluster_id * KThreadSliceSize)); + + auto threadwise_beta_load = ThreadwiseTensorSliceTransfer_v2, + 0, + BetaSrcVectorSize, + 1, + true>( + beta_grid_desc_k, make_multi_index(thread_k_cluster_id * KThreadSliceSize)); + + auto threadwise_y_store = + ThreadwiseTensorSliceTransfer_v1r3( + y_grid_desc_m_k, + make_multi_index(block_global_id * M_BlockTileSize + + thread_m_cluster_id * MThreadSliceSize, + thread_k_cluster_id * KThreadSliceSize), + acc_elementwise_op); + + // Copy x from Cache + // one pass: fwd, second pass: bwd + constexpr auto thread_copy_fwd_step_k = make_multi_index(SweepOnce ? 0 : K_BlockTileSize); + constexpr auto thread_copy_bwd_step_k = make_multi_index(SweepOnce ? 0 : -K_BlockTileSize); + + constexpr auto thread_copy_fwd_step_m_k = + make_multi_index(0, SweepOnce ? 0 : K_BlockTileSize); + constexpr auto thread_copy_bwd_step_m_k = + make_multi_index(0, SweepOnce ? 
0 : -K_BlockTileSize); + + const auto x_global_val_buf = make_dynamic_buffer( + p_x_global, x_grid_desc_m_k.GetElementSpaceSize()); + + const auto gamma_global_val_buf = make_dynamic_buffer( + p_gamma_global, gamma_grid_desc_k.GetElementSpaceSize()); + + const auto beta_global_val_buf = make_dynamic_buffer( + p_beta_global, beta_grid_desc_k.GetElementSpaceSize()); + + auto threadwise_welford = ThreadwiseWelford(); + threadwise_welford.max_count_ = GetKPerThread(x_grid_desc_m_k, thread_k_cluster_id); + + static_for<0, MThreadSliceSize, 1>{}([&](auto I) { + mean_thread_buf(I) = type_convert(0.0f); + var_thread_buf(I) = type_convert(0.0f); + }); + + for(index_t reducedTiles = 0; reducedTiles < num_k_block_tile_iteration; ++reducedTiles) + { + + threadwise_x_load.Run(x_grid_desc_m_k, + x_global_val_buf, + thread_buffer_desc_m_k, + make_tuple(I0, I0), + x_thread_buf); + threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, thread_copy_fwd_step_m_k); + threadwise_welford.Run(x_thread_buf, mean_thread_buf, var_thread_buf); + } + + static_for<0, MThreadSliceSize, 1>{}([&](auto I) { + if constexpr(I > 0) + block_sync_lds(); + + int count = threadwise_welford.cur_count_; + BlockwiseWelford::Run(mean_thread_buf(I), var_thread_buf(I), count); + }); + + auto thread_copy_tail_m_k = (num_k_block_tile_iteration - 1) * thread_copy_fwd_step_m_k; + auto thread_copy_tail_k = (num_k_block_tile_iteration - 1) * thread_copy_fwd_step_k; + + threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, thread_copy_bwd_step_m_k); + threadwise_gamma_load.MoveSrcSliceWindow(gamma_grid_desc_k, thread_copy_tail_k); + threadwise_beta_load.MoveSrcSliceWindow(beta_grid_desc_k, thread_copy_tail_k); + threadwise_y_store.MoveDstSliceWindow(y_grid_desc_m_k, thread_copy_tail_m_k); + + for(index_t reducedTiles = 0; reducedTiles < num_k_block_tile_iteration; ++reducedTiles) + { + if constexpr(!SweepOnce) + { + threadwise_x_load.Run(x_grid_desc_m_k, + x_global_val_buf, + thread_buffer_desc_m_k, + make_tuple(I0, I0), + x_thread_buf); + } + + threadwise_gamma_load.Run(gamma_grid_desc_k, + gamma_global_val_buf, + thread_buffer_desc_k, + make_tuple(I0), + gamma_thread_buf); + + static_for<0, MThreadSliceSize, 1>{}([&](auto iM) { + static_for<0, KThreadSliceSize, 1>{}([&](auto iK) { + constexpr auto offset_m_k = + thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK)); + + constexpr auto offset_k = thread_buffer_desc_k.CalculateOffset(make_tuple(iK)); + + // normalize + y_thread_buf(Number{}) = + (x_thread_buf(Number{}) - mean_thread_buf(iM)) / + sqrt(var_thread_buf(iM) + epsilon); + + // gamma + y_thread_buf(Number{}) = + y_thread_buf(Number{}) * gamma_thread_buf(Number{}); + }); + }); + + threadwise_beta_load.Run(beta_grid_desc_k, + beta_global_val_buf, + thread_buffer_desc_k, + make_tuple(I0), + beta_thread_buf); + + static_for<0, MThreadSliceSize, 1>{}([&](auto iM) { + static_for<0, KThreadSliceSize, 1>{}([&](auto iK) { + constexpr auto offset_m_k = + thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK)); + + constexpr auto offset_k = thread_buffer_desc_k.CalculateOffset(make_tuple(iK)); + + // beta + y_thread_buf(Number{}) = + y_thread_buf(Number{}) + beta_thread_buf(Number{}); + }); + }); + + threadwise_y_store.Run(thread_buffer_desc_m_k, + make_tuple(I0, I0), + y_thread_buf, + y_grid_desc_m_k, + y_global_val_buf); + + threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, thread_copy_bwd_step_m_k); + threadwise_gamma_load.MoveSrcSliceWindow(gamma_grid_desc_k, thread_copy_bwd_step_k); + 
threadwise_beta_load.MoveSrcSliceWindow(beta_grid_desc_k, thread_copy_bwd_step_k); + threadwise_y_store.MoveDstSliceWindow(y_grid_desc_m_k, thread_copy_bwd_step_m_k); + } + } +}; + +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/thread/threadwise_welford.hpp b/include/ck/tensor_operation/gpu/thread/threadwise_welford.hpp new file mode 100644 index 0000000000..3e224ae664 --- /dev/null +++ b/include/ck/tensor_operation/gpu/thread/threadwise_welford.hpp @@ -0,0 +1,78 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/math_v2.hpp" + +namespace ck { + +// Assume +// 1) XDesc is known at compile-time +// 2) MeanVarDesc is known at compile-time +// 3) XBuffer is static buffer +// 4) MeanBuffer is static buffer +// 5) VarBuffer is static buffer +template +struct ThreadwiseWelford +{ + static constexpr auto x_thread_desc_m_k = XThreadDesc_M_K{}; + static constexpr auto mean_var_thread_desc_m = MeanVarThreadDesc_M{}; + + static constexpr auto thread_x_length_m = x_thread_desc_m_k.GetLength(Number<0>{}); + static constexpr auto thread_x_length_k = x_thread_desc_m_k.GetLength(Number<1>{}); + static constexpr auto thread_mean_var_length_m = mean_var_thread_desc_m.GetLength(Number<0>{}); + + static_assert(thread_x_length_m == thread_mean_var_length_m, + "lengths of source and mean/var buffer must match!"); + + __device__ constexpr ThreadwiseWelford() : cur_count_(0), max_count_(0) {} + + __device__ inline void Update(T& mean, T& var, T x) + { + using ck::math::isnan; + + if(isnan(x)) + { + mean = x; + var = x; + } + else + { + T delta = x - mean; + mean += delta / cur_count_; + T delta2 = x - mean; + var += delta * delta2; + } + } + + template + __device__ void + Run(const XBufferType& x_buf_m_k, MeanBufferType& mean_buf_m, VarBufferType& var_buf_m) + { + // FIXME - Better naming for var_buf_m + + static_for<0, thread_x_length_k, 1>{}([&](auto iK) { + if(cur_count_ < max_count_) + { + ++cur_count_; + + static_for<0, thread_x_length_m, 1>{}([&](auto iM) { + constexpr index_t out_offset = + mean_var_thread_desc_m.CalculateOffset(make_tuple(iM)); + + constexpr auto in_offset = + x_thread_desc_m_k.CalculateOffset(make_tuple(iM, iK)); + Update(mean_buf_m(Number{}), + var_buf_m(Number{}), + x_buf_m_k[Number{}]); + }); + } + }); + }; + + int cur_count_; + int max_count_; +}; + +} // namespace ck diff --git a/include/ck/utility/math.hpp b/include/ck/utility/math.hpp index 0cfc2f7da4..12203bd7f3 100644 --- a/include/ck/utility/math.hpp +++ b/include/ck/utility/math.hpp @@ -144,6 +144,12 @@ __host__ __device__ constexpr auto min(X x, Ys... ys) return min(x, min(ys...)); } +template +__host__ __device__ constexpr T clamp(const T& x, const T& lowerbound, const T& upperbound) +{ + return min(max(x, lowerbound), upperbound); +} + // disallow implicit type casting template __device__ T exp(T x); diff --git a/library/src/tensor_operation_instance/gpu/normalization/device_layernorm_f16_instance.cpp b/library/src/tensor_operation_instance/gpu/normalization/device_layernorm_f16_instance.cpp index b880d648dd..ddcde996f7 100644 --- a/library/src/tensor_operation_instance/gpu/normalization/device_layernorm_f16_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/normalization/device_layernorm_f16_instance.cpp @@ -2,7 +2,7 @@ // Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
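// Host-side sketch of the per-thread element count that GetKPerThread computes
// and that ThreadwiseWelford consumes as max_count_: padded K columns must not
// enter the mean/variance, so each thread only counts the elements it really
// owns. It relies on the clamp added to ck/utility/math.hpp
// (clamp(x, lo, hi) == min(max(x, lo), hi)); std::clamp stands in for it here,
// and the tile sizes in the example are hypothetical.
#include <algorithm>

inline int k_per_thread(int k_per_block,           // unpadded K length handled by the block
                        int k_thread_cluster_size, // threads along K
                        int k_thread_slice_size,   // elements per thread per pass
                        int thread_k_cluster_id)   // this thread's position along K
{
    const int k_block_tile_size = k_thread_cluster_size * k_thread_slice_size;

    // Full K tiles are shared evenly by every thread in the K cluster.
    int count = k_per_block < k_block_tile_size
                    ? 0
                    : k_thread_slice_size * (k_per_block / k_block_tile_size);

    // Distribute the tail tile: a thread whose slice starts inside the tail gets
    // a partial (clamped) share, threads entirely past it get nothing extra.
    const int tail = k_per_block - count * k_thread_cluster_size;
    if(tail > 0)
    {
        const int thread_max_len = (thread_k_cluster_id + 1) * k_thread_slice_size;
        const int delta          = std::clamp(thread_max_len - tail, 0, k_thread_slice_size);
        count += k_thread_slice_size - delta;
    }
    return count;
}

// e.g. K = 1004 with 32 threads x 8 elements (tile of 256): every thread owns
// 24 elements from the three full tiles; in the tail, threads 0..28 count 8
// more, thread 29 counts 4, and threads 30..31 count 0.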
#include "ck/ck.hpp" -#include "ck/tensor_operation/gpu/device/device_layernorm.hpp" +#include "ck/tensor_operation/gpu/device/device_layernorm_impl.hpp" #include "ck/utility/data_type.hpp" #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" @@ -21,28 +21,28 @@ template using device_layernorm_f16_instances = std::tuple< // clang-format off // XDataType, GammaDataType, BetaDataType, AccDataType, YDataType, Rank, NumReduceDim, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, XYSrcVectorDim, XSrcVectorSize, GammaSrcVectorSize, BetaSrcVectorSize, YDstVectorSize> - DeviceLayernorm, // fallback kernel - DeviceLayernorm, // fallback kernel - DeviceLayernorm, // fallback kernel - DeviceLayernorm, - DeviceLayernorm, - DeviceLayernorm, - DeviceLayernorm, - DeviceLayernorm, - DeviceLayernorm, - DeviceLayernorm, - DeviceLayernorm + DeviceLayernormImpl, // fallback kernel + DeviceLayernormImpl, // fallback kernel + DeviceLayernormImpl, // fallback kernel + DeviceLayernormImpl, + DeviceLayernormImpl, + DeviceLayernormImpl, + DeviceLayernormImpl, + DeviceLayernormImpl, + DeviceLayernormImpl, + DeviceLayernormImpl, + DeviceLayernormImpl // clang-format on >; void add_device_layernorm_f16_rank2_instances( - std::vector>& instances) + std::vector>& instances) { add_device_operation_instances(instances, device_layernorm_f16_instances<2, 1>{}); } void add_device_layernorm_f16_rank4_instances( - std::vector>& instances) + std::vector>& instances) { add_device_operation_instances(instances, device_layernorm_f16_instances<4, 3>{}); } diff --git a/library/src/tensor_operation_instance/gpu/normalization/device_layernorm_f32_instance.cpp b/library/src/tensor_operation_instance/gpu/normalization/device_layernorm_f32_instance.cpp index e30f76b514..313d876807 100644 --- a/library/src/tensor_operation_instance/gpu/normalization/device_layernorm_f32_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/normalization/device_layernorm_f32_instance.cpp @@ -2,7 +2,7 @@ // Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
#include "ck/ck.hpp" -#include "ck/tensor_operation/gpu/device/device_layernorm.hpp" +#include "ck/tensor_operation/gpu/device/device_layernorm_impl.hpp" #include "ck/utility/data_type.hpp" #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" @@ -20,27 +20,27 @@ template using device_layernorm_f32_instances = std::tuple< // clang-format off // XDataType, GammaDataType, BetaDataType, AccDataType, YDataType, Rank, NumReduceDim, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, XYSrcVectorDim, XSrcVectorSize, GammaSrcVectorSize, BetaSrcVectorSize, YDstVectorSize> - DeviceLayernorm, // fallback kernel - DeviceLayernorm, // fallback kernel - DeviceLayernorm, - DeviceLayernorm, - DeviceLayernorm, - DeviceLayernorm, - DeviceLayernorm, - DeviceLayernorm, - DeviceLayernorm, - DeviceLayernorm + DeviceLayernormImpl, // fallback kernel + DeviceLayernormImpl, // fallback kernel + DeviceLayernormImpl, + DeviceLayernormImpl, + DeviceLayernormImpl, + DeviceLayernormImpl, + DeviceLayernormImpl, + DeviceLayernormImpl, + DeviceLayernormImpl, + DeviceLayernormImpl // clang-format on >; void add_device_layernorm_f32_rank2_instances( - std::vector>& instances) + std::vector>& instances) { add_device_operation_instances(instances, device_layernorm_f32_instances<2, 1>{}); } void add_device_layernorm_f32_rank4_instances( - std::vector>& instances) + std::vector>& instances) { add_device_operation_instances(instances, device_layernorm_f32_instances<4, 3>{}); } diff --git a/profiler/include/profile_layernorm_impl.hpp b/profiler/include/profile_layernorm_impl.hpp index 0f26050b95..b5d994c129 100644 --- a/profiler/include/profile_layernorm_impl.hpp +++ b/profiler/include/profile_layernorm_impl.hpp @@ -7,7 +7,7 @@ #include "ck/ck.hpp" #include "profiler/include/data_type_enum.hpp" -#include "ck/tensor_operation/gpu/device/device_layernorm.hpp" +#include "ck/tensor_operation/gpu/device/device_layernorm_impl.hpp" #include "ck/library/utility/check_err.hpp" #include "ck/library/utility/device_memory.hpp" @@ -25,10 +25,10 @@ using F32 = float; using PassThrough = ck::tensor_operation::element_wise::PassThrough; void add_device_layernorm_f16_rank2_instances( - std::vector>&); + std::vector>&); void add_device_layernorm_f32_rank2_instances( - std::vector>&); + std::vector>&); } // namespace instance } // namespace device @@ -105,14 +105,14 @@ void profile_layernorm_impl(int do_verification, // add device normalization instances constexpr int NumReduceDim = Rank - 1; - std::vector> + std::vector> instances; if constexpr(is_same::value && is_same::value && @@ -163,6 +163,7 @@ void profile_layernorm_impl(int do_verification, strideXY, strideGamma, strideBeta, + strideXY, reduce_dim, 1e-4, x_dev.GetDeviceBuffer(), diff --git a/test/layernorm/test_layernorm_util.hpp b/test/layernorm/test_layernorm_util.hpp index 37374839c5..707fe36f86 100644 --- a/test/layernorm/test_layernorm_util.hpp +++ b/test/layernorm/test_layernorm_util.hpp @@ -9,7 +9,7 @@ #include "ck/ck.hpp" #include "ck/utility/number.hpp" -#include "ck/tensor_operation/gpu/device/device_layernorm.hpp" +#include "ck/tensor_operation/gpu/device/device_layernorm_impl.hpp" #include "ck/library/utility/check_err.hpp" #include "ck/library/utility/host_tensor.hpp" @@ -63,24 +63,24 @@ class TestLayernorm : public ::testing::Test Rank, NumReduceDim>; - using DeviceInstance = tensor_operation::device::DeviceLayernorm; + using DeviceInstance = tensor_operation::device::DeviceLayernormImpl; TestLayernorm() : 
ref_instance_invoker_(ReferenceInstance{}.MakeInvoker()) {} @@ -119,6 +119,7 @@ class TestLayernorm : public ::testing::Test gamma.mDesc.GetStrides().end()}, std::vector{beta.mDesc.GetStrides().begin(), beta.mDesc.GetStrides().end()}, + std::vector{y.mDesc.GetStrides().begin(), y.mDesc.GetStrides().end()}, reduceDims, 1e-4, x_dev.GetDeviceBuffer(),
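// Host-side reference of what the kernels above compute, per invariant row m
// over the reduced length K:
//   mean_m  = (1/K) * sum_k x[m][k]
//   var_m   = (1/K) * sum_k (x[m][k] - mean_m)^2
//   y[m][k] = (x[m][k] - mean_m) / sqrt(var_m + eps) * gamma[k] + beta[k]
// Written with plain std::vector<float> for illustration; the test itself uses
// the CK ReferenceLayernorm instance instead.
#include <cmath>
#include <vector>

inline void layernorm_reference(const std::vector<float>& x,     // M * K, row-major
                                const std::vector<float>& gamma, // K
                                const std::vector<float>& beta,  // K
                                std::vector<float>& y,           // M * K, row-major
                                int M, int K, float eps = 1e-4f)
{
    for(int m = 0; m < M; ++m)
    {
        double mean = 0.0, var = 0.0;

        for(int k = 0; k < K; ++k)
            mean += x[m * K + k];
        mean /= K;

        for(int k = 0; k < K; ++k)
        {
            const double d = x[m * K + k] - mean;
            var += d * d;
        }
        var /= K; // biased variance, matching var = M2 / count in the kernel

        const double inv_std = 1.0 / std::sqrt(var + eps);
        for(int k = 0; k < K; ++k)
            y[m * K + k] =
                static_cast<float>((x[m * K + k] - mean) * inv_std) * gamma[k] + beta[k];
    }
}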