diff --git a/example/27_layernorm/CMakeLists.txt b/example/27_layernorm/CMakeLists.txt
index d96deae45e..94c23ce774 100644
--- a/example/27_layernorm/CMakeLists.txt
+++ b/example/27_layernorm/CMakeLists.txt
@@ -1 +1,2 @@
-add_example_executable(example_layernorm_blockwise layernorm_blockwise.cpp)
+add_example_executable(example_layernorm_fp16 layernorm_fp16.cpp)
+add_example_executable(example_layernorm_splitk_fp16 layernorm_splitk_fp16.cpp)
diff --git a/example/27_layernorm/common.hpp b/example/27_layernorm/common.hpp
new file mode 100644
index 0000000000..8d833a3ae9
--- /dev/null
+++ b/example/27_layernorm/common.hpp
@@ -0,0 +1,22 @@
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
+
+#pragma once
+
+#include
+#include
+#include
+#include
+#include
+
+#include "ck/ck.hpp"
+#include "ck/tensor_operation/gpu/device/impl/device_normalization_impl.hpp"
+#include "ck/tensor_operation/gpu/device/impl/device_normalization_splitk_impl.hpp"
+
+#include "ck/library/utility/check_err.hpp"
+#include "ck/library/utility/device_memory.hpp"
+#include "ck/library/utility/host_common_util.hpp"
+#include "ck/library/utility/host_tensor.hpp"
+#include "ck/library/utility/host_tensor_generator.hpp"
+#include "ck/library/utility/literals.hpp"
+#include "ck/library/reference_tensor_operation/cpu/reference_layernorm.hpp"
diff --git a/example/27_layernorm/layernorm_fp16.cpp b/example/27_layernorm/layernorm_fp16.cpp
new file mode 100644
index 0000000000..c15ffabf50
--- /dev/null
+++ b/example/27_layernorm/layernorm_fp16.cpp
@@ -0,0 +1,39 @@
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
+
+#include "common.hpp"
+
+using XDataType = ck::half_t;
+using GammaDataType = ck::half_t;
+using BetaDataType = ck::half_t;
+using YDataType = ck::half_t;
+using ComputeDataType = float;
+using PassThrough = ck::tensor_operation::element_wise::PassThrough;
+
+constexpr int Rank = 2;
+constexpr int NumReduceDim = 1;
+
+using DeviceInstance =
+    ck::tensor_operation::device::DeviceNormalizationImpl; // OutScalarPerVector
+#include "run_layernorm_example.inc"
+
+int main() { return run_layernorm_example<DeviceInstance>(); }
diff --git a/example/27_layernorm/layernorm_splitk_fp16.cpp b/example/27_layernorm/layernorm_splitk_fp16.cpp
new file mode 100644
index 0000000000..01ee7161eb
--- /dev/null
+++ b/example/27_layernorm/layernorm_splitk_fp16.cpp
@@ -0,0 +1,40 @@
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
+
+#include "common.hpp"
+
+using XDataType = ck::half_t;
+using GammaDataType = ck::half_t;
+using BetaDataType = ck::half_t;
+using YDataType = ck::half_t;
+using ComputeDataType = float;
+using PassThrough = ck::tensor_operation::element_wise::PassThrough;
+
+constexpr int Rank = 2;
+constexpr int NumReduceDim = 1;
+
+using DeviceInstance =
+    ck::tensor_operation::device::DeviceNormalizationSplitKImpl; // YScalarPerVector
+
+#include "run_layernorm_example.inc"
+
+int main() { return run_layernorm_example<DeviceInstance>(); }
diff --git a/example/27_layernorm/layernorm_blockwise.cpp b/example/27_layernorm/run_layernorm_example.inc
similarity index 58%
rename from example/27_layernorm/layernorm_blockwise.cpp
rename to example/27_layernorm/run_layernorm_example.inc
index 7d91b69d04..678d8df281 100644
--- a/example/27_layernorm/layernorm_blockwise.cpp
+++ b/example/27_layernorm/run_layernorm_example.inc
@@ -1,58 +1,10 @@
 // SPDX-License-Identifier: MIT
 // Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
 
-#include
-#include
-#include
-#include
-#include
+#pragma once
 
-#include "ck/ck.hpp"
-#include "ck/utility/reduction_enums.hpp"
-#include "ck/tensor_operation/gpu/device/impl/device_normalization_impl.hpp"
-#include "ck/tensor_operation/gpu/device/reduction_operator_mapping.hpp"
-
-#include "ck/library/utility/check_err.hpp"
-#include "ck/library/utility/device_memory.hpp"
-#include "ck/library/utility/host_common_util.hpp"
-#include "ck/library/utility/host_tensor.hpp"
-#include "ck/library/utility/host_tensor_generator.hpp"
-#include "ck/library/utility/literals.hpp"
-#include "ck/library/reference_tensor_operation/cpu/reference_layernorm.hpp"
-
-using XDataType = ck::half_t;
-using GammaDataType = ck::half_t;
-using BetaDataType = ck::half_t;
-using YDataType = ck::half_t;
-using ComputeDataType = float;
-using PassThrough = ck::tensor_operation::element_wise::PassThrough;
-
-constexpr int Rank = 2;
-constexpr int NumReduceDim = 1;
-
-using DeviceInstance =
-    ck::tensor_operation::device::DeviceNormalizationImpl; // OutScalarPerVector
-
-int main()
+template <typename DeviceInstance>
+int run_layernorm_example()
 {
     bool time_kernel = false;
@@ -111,6 +63,10 @@ int main()
         return 1;
     };
 
+    size_t workspace_sz = device_instance.GetWorkSpaceSize(argument_ptr.get());
+    DeviceMem workspace_dev(workspace_sz);
+    device_instance.SetWorkSpacePointer(argument_ptr.get(), workspace_dev.GetDeviceBuffer());
+
     auto invoker_ptr = device_instance.MakeInvokerPointer();
     invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel});
 
@@ -133,7 +89,8 @@ int main()
     ref_invoker.Run(ref_argument);
 
     y_dev.FromDevice(y.mData.data());
-    pass &= ck::utils::check_err(y, host_y, "Error: Incorrect results d1", 1e-3, 1e-3);
+    pass &= ck::utils::check_err(y, host_y, "Error: Incorrect results", 1e-3, 1e-3);
 }
+    return (pass ? 0 : 1);
 }
diff --git a/example/42_groupnorm/CMakeLists.txt b/example/42_groupnorm/CMakeLists.txt
index a9990c5d89..e8c306ac58 100644
--- a/example/42_groupnorm/CMakeLists.txt
+++ b/example/42_groupnorm/CMakeLists.txt
@@ -1,2 +1,3 @@
 add_example_executable(example_groupnorm_sigmoid_mul_fp16 groupnorm_sigmoid_mul_fp16.cpp)
+add_example_executable(example_groupnorm_splitk_fp16 groupnorm_splitk_fp16.cpp)
 add_example_executable(example_groupnorm_swish_fp16 groupnorm_swish_fp16.cpp)
diff --git a/example/42_groupnorm/common.hpp b/example/42_groupnorm/common.hpp
index e159abf3e9..780154b26c 100644
--- a/example/42_groupnorm/common.hpp
+++ b/example/42_groupnorm/common.hpp
@@ -12,6 +12,7 @@
 #include "ck/ck.hpp"
 #include "ck/utility/reduction_enums.hpp"
 #include "ck/tensor_operation/gpu/device/impl/device_normalization_impl.hpp"
+#include "ck/tensor_operation/gpu/device/impl/device_normalization_splitk_impl.hpp"
 #include "ck/tensor_operation/gpu/device/reduction_operator_mapping.hpp"
 
 #include "ck/library/utility/fill.hpp"
diff --git a/example/42_groupnorm/groupnorm_splitk_fp16.cpp b/example/42_groupnorm/groupnorm_splitk_fp16.cpp
new file mode 100644
index 0000000000..fd4bfe3807
--- /dev/null
+++ b/example/42_groupnorm/groupnorm_splitk_fp16.cpp
@@ -0,0 +1,40 @@
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
+
+#include "common.hpp"
+
+constexpr int Rank = 5;
+constexpr int NumReduceDim = 3;
+
+using XDataType = ck::half_t;
+using GammaDataType = ck::half_t;
+using BetaDataType = ck::half_t;
+using YDataType = ck::half_t;
+using ComputeDataType = float;
+using YElementOp = ck::tensor_operation::element_wise::Swish;
+
+using DeviceInstance =
+    ck::tensor_operation::device::DeviceNormalizationSplitKImpl; // OutScalarPerVector
+
+#include "run_groupnorm_example.inc"
+
+int main(int argc, char* argv[]) { return run_groupnorm_example(argc, argv); }
diff --git a/example/42_groupnorm/run_groupnorm_example.inc b/example/42_groupnorm/run_groupnorm_example.inc
index bd7eb98ca0..d1016a3b12 100644
--- a/example/42_groupnorm/run_groupnorm_example.inc
+++ b/example/42_groupnorm/run_groupnorm_example.inc
@@ -73,6 +73,10 @@ int run_groupnorm_example(int argc, char* argv[])
         return 1;
     };
 
+    size_t workspace_sz = device_instance.GetWorkSpaceSize(argument_ptr.get());
+    DeviceMem workspace_dev(workspace_sz);
+    device_instance.SetWorkSpacePointer(argument_ptr.get(), workspace_dev.GetDeviceBuffer());
+
     auto invoker_ptr = device_instance.MakeInvokerPointer();
 
     float ave_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, true, true});
diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_layernorm_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_layernorm_xdl_cshuffle.hpp
index a383b0bb7d..580087e00c 100644
--- a/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_layernorm_xdl_cshuffle.hpp
+++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_layernorm_xdl_cshuffle.hpp
@@ -807,7 +807,7 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle
         // workspace for welford intermediate mean
         workspace_size += gemm_welford_size * sizeof(EMeanVarDataType) + 64;
 
-        // workspace for welford intermediate mean
+        // workspace for welford intermediate variance
         workspace_size += gemm_welford_size * sizeof(EMeanVarDataType) + 64;
 
         // workspace for welford intermediate count
diff --git a/include/ck/tensor_operation/gpu/device/impl/device_normalization_impl.hpp b/include/ck/tensor_operation/gpu/device/impl/device_normalization_impl.hpp
index bb62332d1a..6a8037a324 100644
--- a/include/ck/tensor_operation/gpu/device/impl/device_normalization_impl.hpp
+++ b/include/ck/tensor_operation/gpu/device/impl/device_normalization_impl.hpp
@@ -10,8 +10,7 @@
 #include "ck/tensor_operation/gpu/device/device_normalization.hpp"
 #include "ck/tensor_operation/gpu/device/device_reduce.hpp"
 #include "ck/tensor_operation/gpu/device/impl/device_reduce_common.hpp"
-#include "ck/tensor_operation/gpu/grid/gridwise_normalization_selector.hpp"
-#include "ck/tensor_operation/gpu/grid/gridwise_set_buffer_value.hpp"
+#include "ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_selector.hpp"
 #include "ck/host_utility/device_prop.hpp"
 #include "ck/host_utility/kernel_launch.hpp"
@@ -20,6 +19,10 @@ namespace tensor_operation {
 namespace device {
 
 // Y = Normalization(X, Beta, Gamma)
+// M: Invariant length
+// K: Reduce length (Calculate mean and variance along K dimension)
+// e.g. Length = [N, C, H, W], reduce dim = [C, H, W]
+// Then, M = N, K = C * H * W
 template
     static auto MakeSrc2dDescriptor(const std::vector& inLengths,
                                     const std::vector& inStrides,
-                                    int blkGroupSize,
                                     int numBlockTileIteration)
     {
         constexpr index_t NumInvariantDim = Rank - NumReduceDim;
@@ -117,10 +119,9 @@ struct DeviceNormalizationImpl : public DeviceNormalization
         const auto invariantLength = in_grid_desc_m_k.GetLength(Number<0>{});
         const auto reduceLength = in_grid_desc_m_k.GetLength(Number<1>{});
 
-        const int reduceSizePerBlock = K_BlockTileSize * numBlockTileIteration;
         const auto inPad_M =
             math::integer_least_multiple(invariantLength, M_BlockTileSize) - invariantLength;
-        const auto inPad_K = reduceSizePerBlock * blkGroupSize - reduceLength;
+        const auto inPad_K = K_BlockTileSize * numBlockTileIteration - reduceLength;
 
         auto in_grid_desc_m_k_padded = transform_tensor_descriptor(
             in_grid_desc_m_k,
@@ -132,7 +133,7 @@ struct DeviceNormalizationImpl : public DeviceNormalization
         gammaStrides_ = shuffle_tensor_dimensions(gammaStrides, reduceDims);
         betaStrides_ = shuffle_tensor_dimensions(betaStrides, reduceDims);
 
-        long_index_t invariant_total_length;
-        long_index_t reduce_total_length;
-        std::tie(invariant_total_length, reduce_total_length) =
+        long_index_t invariant_length;
+        long_index_t reduce_length;
+        std::tie(invariant_length, reduce_length) =
             get_2d_lengths(Lengths_);
 
-        blkGroupSize_ = 1;
-        numBlockTileIteration_ = (reduce_total_length + K_BlockTileSize - 1) / K_BlockTileSize;
+        numBlockTileIteration_ = math::integer_divide_ceil(reduce_length, K_BlockTileSize);
 
-        gridSize_ = math::integer_least_multiple(invariant_total_length, M_BlockTileSize) /
-                    M_BlockTileSize * blkGroupSize_;
+        gridSize_ = math::integer_divide_ceil(invariant_length, M_BlockTileSize);
 
-        x_grid_desc_m_k_ =
-            MakeSrc2dDescriptor(Lengths_, xStrides_, blkGroupSize_, numBlockTileIteration_);
+        x_grid_desc_m_k_ = MakeSrc2dDescriptor(Lengths_, xStrides_, numBlockTileIteration_);
         gamma_grid_desc_m_k_ =
-            MakeSrc2dDescriptor(Lengths_, gammaStrides_, blkGroupSize_, numBlockTileIteration_);
+            MakeSrc2dDescriptor(Lengths_, gammaStrides_, numBlockTileIteration_);
         beta_grid_desc_m_k_ =
-            MakeSrc2dDescriptor(Lengths_, betaStrides_, blkGroupSize_, numBlockTileIteration_);
-        y_grid_desc_m_k_ =
-            MakeSrc2dDescriptor(Lengths_, yStrides_, blkGroupSize_, numBlockTileIteration_);
+            MakeSrc2dDescriptor(Lengths_, betaStrides_, numBlockTileIteration_);
+        y_grid_desc_m_k_ = MakeSrc2dDescriptor(Lengths_, yStrides_, numBlockTileIteration_);
 
         isSweeponce_ =
             x_grid_desc_m_k_.GetLength(Number<1>{}) <= KThreadClusterSize * KThreadSliceSize;
@@ -202,7 +199,6 @@ struct DeviceNormalizationImpl : public DeviceNormalization
                if(p_arg_->invariant_lowest_length % XSrcVectorSize != 0)
                    return false;
+
+                if(p_arg_->invariant_lowest_length % YDstVectorSize != 0)
+                    return false;
            };
        }
        else
@@ -295,12 +294,12 @@ struct DeviceNormalizationImpl : public DeviceNormalization
             if(p_arg_->Lengths_[Rank - 1] % XSrcVectorSize != 0)
                 return false;
-        };
-        if(p_arg_->Lengths_[Rank - 1] % YDstVectorSize != 0)
-        {
-            return false;
-        }
+            if(p_arg_->Lengths_[Rank - 1] % YDstVectorSize != 0)
+            {
+                return false;
+            }
+        };
 
         // if fastest dim is not reduced
         if constexpr(GammaSrcVectorDim == 0)
diff --git a/include/ck/tensor_operation/gpu/device/impl/device_normalization_splitk_impl.hpp b/include/ck/tensor_operation/gpu/device/impl/device_normalization_splitk_impl.hpp
new file mode 100644
index 0000000000..0026a87593
--- /dev/null
+++ b/include/ck/tensor_operation/gpu/device/impl/device_normalization_splitk_impl.hpp
@@ -0,0 +1,658 @@
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
+
+#pragma once
+
+#include
+#include
+
+#include "ck/utility/reduction_operator.hpp"
+#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
+#include "ck/tensor_operation/gpu/device/device_normalization.hpp"
+#include "ck/tensor_operation/gpu/device/device_reduce.hpp"
+#include "ck/tensor_operation/gpu/device/impl/device_reduce_common.hpp"
+#include "ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_splitk_1st.hpp"
+#include "ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_splitk_2nd.hpp"
+#include "ck/host_utility/device_prop.hpp"
+#include "ck/host_utility/kernel_launch.hpp"
+
+namespace ck {
+template
+__global__ void
+kernel_normalizationSplitK1st(const XGridDesc_M_K x_grid_desc_m_k,
+                              const MeanVarGridDesc_M_KBlock mean_var_grid_desc_m_kblock,
+                              index_t num_k_block_tile_iteration,
+                              const XDataType* const __restrict__ p_x_global,
+                              MeanVarDataType* const __restrict__ p_welford_mean,
+                              MeanVarDataType* const __restrict__ p_welford_variance,
+                              int32_t* const __restrict__ p_welford_count)
+{
+    GridwiseWelford::Run(x_grid_desc_m_k,
+                         mean_var_grid_desc_m_kblock,
+                         num_k_block_tile_iteration,
+                         p_x_global,
+                         p_welford_mean,
+                         p_welford_variance,
+                         p_welford_count);
+};
+
+template
+__global__ void
+kernel_normalizationSplitK2nd(const MeanVarGridDesc_M_KBlock mean_var_grid_desc_m_kblock,
+                              const CountGridDesc_M_KBlock count_grid_desc_m_kblock,
+                              const XYGammaBetaGridDesc_M_K x_grid_desc_m_k,
+                              const XYGammaBetaGridDesc_M_K gamma_grid_desc_m_k,
+                              const XYGammaBetaGridDesc_M_K beta_grid_desc_m_k,
+                              const XYGammaBetaGridDesc_M_K y_grid_desc_m_k,
+                              index_t num_k_mean_var_count_iteration,
+                              index_t num_k_block_tile_iteration,
+                              index_t k_grid_size,
+                              ComputeDataType epsilon,
+                              const MeanVarDataType* const p_mean_global,
+                              const MeanVarDataType* const p_variance_global,
+                              const int32_t* const p_welford_count_global,
+                              const XDataType* const __restrict__ p_x_global,
+                              const GammaDataType* const __restrict__ p_gamma_global,
+                              const BetaDataType* const __restrict__ p_beta_global,
+                              YDataType* const __restrict__ p_y_global,
+                              const YElementwiseOperation y_elementwise_op)
+{
+    GridwiseWelfordNormalization::Run(mean_var_grid_desc_m_kblock,
+                                      count_grid_desc_m_kblock,
+                                      x_grid_desc_m_k,
+                                      gamma_grid_desc_m_k,
+                                      beta_grid_desc_m_k,
+                                      y_grid_desc_m_k,
+                                      num_k_mean_var_count_iteration,
+                                      num_k_block_tile_iteration,
+                                      k_grid_size,
+                                      epsilon,
+                                      p_mean_global,
+                                      p_variance_global,
+                                      p_welford_count_global,
+                                      p_x_global,
+                                      p_gamma_global,
+                                      p_beta_global,
+                                      p_y_global,
+                                      y_elementwise_op);
+};
+} // namespace ck
+
+namespace ck {
+namespace tensor_operation {
+namespace device {
+
+// Y = Normalization(X, Beta, Gamma)
+// M: Invariant length
+// K: Reduce length (Calculate mean and variance along K dimension)
+// e.g. Length = [N, C, H, W], reduce dim = [C, H, W]
+// Then, M = N, K = C * H * W
+template
+struct DeviceNormalizationSplitKImpl : public DeviceNormalization
+{
+    using MeanVarDataType = ComputeDataType;
+
+    static_assert(BlockSize == MThreadClusterSize * KThreadClusterSize);
+    static_assert(
+        ((GammaSrcVectorDim == 0 && MThreadSliceSize % GammaSrcVectorSize == 0) ||
+         (GammaSrcVectorDim == 1 && KThreadSliceSize % GammaSrcVectorSize == 0)),
+        "Invalid thread slice sizes and/or gamma vector sizes configuration, please check!");
+
+    static_assert(
+        ((BetaSrcVectorDim == 0 && MThreadSliceSize % BetaSrcVectorSize == 0) ||
+         (BetaSrcVectorDim == 1 && KThreadSliceSize % BetaSrcVectorSize == 0)),
+        "Invalid thread slice sizes and/or beta vector sizes configuration, please check!");
+
+    using PassThrough = tensor_operation::element_wise::PassThrough;
+
+    static constexpr auto I0 = Number<0>{};
+    static constexpr auto I1 = Number<1>{};
+
+    static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize;
+    static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize;
+
+    static auto MakeSrc2dDescriptor(const std::vector& inLengths,
+                                    const std::vector& inStrides,
+                                    int kBlockSize,
+                                    int numBlockTileIteration)
+    {
+        constexpr index_t NumInvariantDim = Rank - NumReduceDim;
+        static constexpr index_t numSrcDim = Rank;
+        static constexpr bool reduceAllDim = (NumInvariantDim == 0);
+
+        const auto tupleSrcLengths = make_tuple_from_array(inLengths, Number{});
+        const auto tupleSrcStrides = make_tuple_from_array(inStrides, Number{});
+
+        const auto inDesc = make_naive_tensor_descriptor(tupleSrcLengths, tupleSrcStrides);
+
+        const auto in_grid_desc_m_k = [&]() {
+            if constexpr(reduceAllDim)
+            {
+                const auto one_dim_inDesc = transform_tensor_descriptor(
+                    inDesc,
+                    make_tuple(make_merge_transform(tupleSrcLengths)),
+                    make_tuple(typename arithmetic_sequence_gen<0, numSrcDim, 1>::type{}),
+                    make_tuple(Sequence<0>{}));
+
+                return transform_tensor_descriptor(one_dim_inDesc,
+                                                   make_tuple(make_unmerge_transform(make_tuple(
+                                                       1, one_dim_inDesc.GetLength(Number<0>{})))),
+                                                   make_tuple(Sequence<0>{}),
+                                                   make_tuple(Sequence<0, 1>{}));
+            }
+            else
+            {
+                using InvariantDims = typename arithmetic_sequence_gen<0, NumInvariantDim, 1>::type;
+                using ReduceDims = typename arithmetic_sequence_gen::type;
+
+                const auto reduceDimLengths =
+                    make_tuple_from_array_and_index_seq(inLengths, ReduceDims{});
+                const auto invariantDimLengths =
+                    make_tuple_from_array_and_index_seq(inLengths, InvariantDims{});
+
+                return transform_tensor_descriptor(
+                    inDesc,
+                    make_tuple(make_merge_transform(invariantDimLengths),
+                               make_merge_transform(reduceDimLengths)),
+                    make_tuple(InvariantDims{}, ReduceDims{}),
+                    make_tuple(Sequence<0>{}, Sequence<1>{}));
+            }
+        }();
+
+        const auto invariantLength = in_grid_desc_m_k.GetLength(Number<0>{});
+        const auto reduceLength = in_grid_desc_m_k.GetLength(Number<1>{});
+
+        const int reduceSizePerBlock = K_BlockTileSize * numBlockTileIteration;
+        const auto inPad_M =
+            math::integer_least_multiple(invariantLength, M_BlockTileSize) - invariantLength;
+        const auto inPad_K = reduceSizePerBlock * kBlockSize - reduceLength;
+
+        auto in_grid_desc_m_k_padded = transform_tensor_descriptor(
+            in_grid_desc_m_k,
+            make_tuple(make_right_pad_transform(invariantLength, inPad_M),
+                       make_right_pad_transform(reduceLength, inPad_K)),
+            make_tuple(Sequence<0>{}, Sequence<1>{}),
+            make_tuple(Sequence<0>{}, Sequence<1>{}));
+
+        return (in_grid_desc_m_k_padded);
+    };
+
+    template <typename DoPads, index_t MPerTile, index_t KPerTile>
+    static auto MakeMeanVarDescriptor_M_K(index_t M, index_t K)
+    {
+        const auto grid_desc_m_k =
+            make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(K, I1));
+        return PadTensorDescriptor(grid_desc_m_k, make_tuple(MPerTile, KPerTile), DoPads{});
+    }
+
+    template <typename DoPads, index_t MPerTile, index_t KPerTile>
+    static auto MakeCountDescriptor_M_K(index_t M, index_t K)
+    {
+        const auto grid_desc_m_k =
+            make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(I0, I1));
+        return PadTensorDescriptor(grid_desc_m_k, make_tuple(MPerTile, KPerTile), DoPads{});
+    }
+
+    using SrcGridDesc_M_K = decltype(MakeSrc2dDescriptor({1}, {1}, 1, 1));
+    using Kernel1MeanVarGridDesc_M_KBlock =
+        decltype(MakeMeanVarDescriptor_M_K<Sequence<true, false>, 1, 1>(1, 1));
+
+    using Kernel2MeanVarGridDesc_M_KBlock =
+        decltype(MakeMeanVarDescriptor_M_K<Sequence<true, true>, 1, 1>(1, 1));
+
+    using Kernel2CountGridDesc_M_KBlock =
+        decltype(MakeCountDescriptor_M_K<Sequence<true, true>, 1, 1>(1, 1));
+
+    using GridwiseWelford = GridwiseNormalizationSplitK1st;
+
+    using GridwiseWelfordNormalization =
+        GridwiseNormalizationSplitK2nd;
+
+    struct Argument : public BaseArgument
+    {
+        Argument(const std::vector lengths,
+                 const std::vector xStrides,
+                 const std::vector gammaStrides,
+                 const std::vector betaStrides,
+                 const std::vector yStrides,
+                 const std::vector reduceDims,
+                 YElementwiseOperation y_elementwise_op,
+                 double epsilon,
+                 const XDataType* p_x,
+                 const GammaDataType* p_gamma,
+                 const BetaDataType* p_beta,
+                 YDataType* p_y)
+            : p_x_(p_x),
+              p_gamma_(p_gamma),
+              p_beta_(p_beta),
+              p_y_(p_y),
+              p_workspace_mean_{nullptr},
+              p_workspace_var_{nullptr},
+              p_workspace_count_{nullptr},
+              y_elementwise_op_(y_elementwise_op)
+        {
+            epsilon_ = static_cast<ComputeDataType>(epsilon);
+
+            Lengths_ = shuffle_tensor_dimensions(lengths, reduceDims);
+            xStrides_ = shuffle_tensor_dimensions(xStrides, reduceDims);
+            yStrides_ = shuffle_tensor_dimensions(yStrides, reduceDims);
+            gammaStrides_ = shuffle_tensor_dimensions(gammaStrides, reduceDims);
+            betaStrides_ = shuffle_tensor_dimensions(betaStrides, reduceDims);
+
+            std::tie(MRaw_, KRaw_) = get_2d_lengths(Lengths_);
+
+            numBlockTileIteration_ = 1;
+            while(true)
+            {
+                int testKGridSize =
+                    math::integer_divide_ceil(KRaw_, K_BlockTileSize * numBlockTileIteration_);
+
+                // we want kGridSize_ to be no more than 128
+                if(testKGridSize <= 128)
+                    break;
+
+                ++numBlockTileIteration_;
+            };
+
+            kGridSize_ = math::integer_divide_ceil(KRaw_, K_BlockTileSize * numBlockTileIteration_);
+            gridSize_ = math::integer_divide_ceil(MRaw_, M_BlockTileSize) * kGridSize_;
+
+            // We do not use vector load for mean, var and count
+            static constexpr index_t K_MeanVarCountBlockTileSize = KThreadClusterSize;
+
+            numMeanVarCountIteration_ =
+                math::integer_divide_ceil(kGridSize_, K_MeanVarCountBlockTileSize);
+
+            x_grid_desc_m_k_ =
+                MakeSrc2dDescriptor(Lengths_, xStrides_, kGridSize_, numBlockTileIteration_);
+            gamma_grid_desc_m_k_ =
+                MakeSrc2dDescriptor(Lengths_, gammaStrides_, kGridSize_, numBlockTileIteration_);
+            beta_grid_desc_m_k_ =
+                MakeSrc2dDescriptor(Lengths_, betaStrides_, kGridSize_, numBlockTileIteration_);
+            y_grid_desc_m_k_ =
+                MakeSrc2dDescriptor(Lengths_, yStrides_, kGridSize_, numBlockTileIteration_);
+
+            // We don't need to pad in K dimension for Welford1. Set KPerTile to 1.
+            kernel1_mean_var_grid_desc_m_kblock_ =
+                MakeMeanVarDescriptor_M_K<Sequence<true, false>, M_BlockTileSize, 1>(MRaw_,
+                                                                                     kGridSize_);
+
+            kernel2_mean_var_grid_desc_m_kblock_ =
+                MakeMeanVarDescriptor_M_K<Sequence<true, true>,
+                                          M_BlockTileSize,
+                                          K_MeanVarCountBlockTileSize>(MRaw_, kGridSize_);
+
+            kernel2_count_grid_desc_m_kblock_ =
+                MakeCountDescriptor_M_K<Sequence<true, true>,
+                                        M_BlockTileSize,
+                                        K_MeanVarCountBlockTileSize>(MRaw_, kGridSize_);
+        }
+
+        ComputeDataType epsilon_;
+
+        const XDataType* p_x_;
+        const GammaDataType* p_gamma_;
+        const BetaDataType* p_beta_;
+        YDataType* p_y_;
+        void* p_workspace_mean_;
+        void* p_workspace_var_;
+        void* p_workspace_count_;
+
+        std::vector Lengths_;
+        std::vector xStrides_;
+        std::vector gammaStrides_;
+        std::vector betaStrides_;
+        std::vector yStrides_;
+
+        YElementwiseOperation y_elementwise_op_;
+
+        int kGridSize_;
+        int numMeanVarCountIteration_;
+        int numBlockTileIteration_;
+        size_t gridSize_;
+
+        SrcGridDesc_M_K x_grid_desc_m_k_;
+        SrcGridDesc_M_K gamma_grid_desc_m_k_;
+        SrcGridDesc_M_K beta_grid_desc_m_k_;
+        SrcGridDesc_M_K y_grid_desc_m_k_;
+
+        Kernel1MeanVarGridDesc_M_KBlock kernel1_mean_var_grid_desc_m_kblock_;
+        Kernel2MeanVarGridDesc_M_KBlock kernel2_mean_var_grid_desc_m_kblock_;
+        Kernel2CountGridDesc_M_KBlock kernel2_count_grid_desc_m_kblock_;
+
+        index_t MRaw_; // invariant length
+        index_t KRaw_; // reduce length
+    };
+
+    struct Invoker : public BaseInvoker
+    {
+        float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
+        {
+            if(arg.p_workspace_mean_ == nullptr || arg.p_workspace_var_ == nullptr ||
+               arg.p_workspace_count_ == nullptr)
+                throw std::runtime_error("wrong! WorkSpace pointer has not been set");
+
+            auto kernel1 = kernel_normalizationSplitK1st;
+
+            auto kernel2 = kernel_normalizationSplitK2nd;
+
+            float avg_time = 0;
+            avg_time += launch_and_time_kernel(stream_config,
+                                               kernel1,
+                                               dim3(arg.gridSize_),
+                                               dim3(BlockSize),
+                                               0,
+                                               arg.x_grid_desc_m_k_,
+                                               arg.kernel1_mean_var_grid_desc_m_kblock_,
+                                               arg.numBlockTileIteration_,
+                                               arg.p_x_,
+                                               static_cast<MeanVarDataType*>(arg.p_workspace_mean_),
+                                               static_cast<MeanVarDataType*>(arg.p_workspace_var_),
+                                               static_cast<int32_t*>(arg.p_workspace_count_));
+
+            avg_time += launch_and_time_kernel(stream_config,
+                                               kernel2,
+                                               dim3(arg.gridSize_),
+                                               dim3(BlockSize),
+                                               0,
+                                               arg.kernel2_mean_var_grid_desc_m_kblock_,
+                                               arg.kernel2_count_grid_desc_m_kblock_,
+                                               arg.x_grid_desc_m_k_,
+                                               arg.gamma_grid_desc_m_k_,
+                                               arg.beta_grid_desc_m_k_,
+                                               arg.y_grid_desc_m_k_,
+                                               arg.numMeanVarCountIteration_,
+                                               arg.numBlockTileIteration_,
+                                               arg.kGridSize_,
+                                               arg.epsilon_,
+                                               static_cast<const MeanVarDataType*>(arg.p_workspace_mean_),
+                                               static_cast<const MeanVarDataType*>(arg.p_workspace_var_),
+                                               static_cast<const int32_t*>(arg.p_workspace_count_),
+                                               arg.p_x_,
+                                               arg.p_gamma_,
+                                               arg.p_beta_,
+                                               arg.p_y_,
+                                               arg.y_elementwise_op_);
+
+            return avg_time;
+        };
+
+        float Run(const BaseArgument* p_arg,
+                  const StreamConfig& stream_config = StreamConfig{}) override
+        {
+            return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
+        };
+    };
+
+    size_t GetWorkSpaceSize(const BaseArgument* pArg) const override
+    {
+        const Argument* pArg_ = dynamic_cast<const Argument*>(pArg);
+
+        size_t workspace_size = 0;
+
+        int welford_size = pArg_->MRaw_ * pArg_->kGridSize_;
+
+        // workspace for welford intermediate mean
+        workspace_size += welford_size * sizeof(MeanVarDataType) + 64;
+
+        // workspace for welford intermediate variance
+        workspace_size += welford_size * sizeof(MeanVarDataType) + 64;
+
+        // workspace for welford intermediate count
+        workspace_size += pArg_->kGridSize_ * sizeof(int32_t) + 64;
+
+        return (workspace_size);
+    };
+
+    void SetWorkSpacePointer(BaseArgument* pArg, void* p_workspace) const override
+    {
+        Argument* pArg_ = dynamic_cast<Argument*>(pArg);
+
+        pArg_->p_workspace_ = p_workspace;
+
+        int welford_size = pArg_->MRaw_ * pArg_->kGridSize_;
+
+        // setup buffer used for intermediate welford mean
+        pArg_->p_workspace_mean_ = static_cast<void*>(pArg_->p_workspace_);
+
+        index_t mean_space_sz = welford_size * sizeof(MeanVarDataType);
+        mean_space_sz = math::integer_least_multiple(mean_space_sz, 64);
+
+        // setup buffer used for intermediate welford variance
+        pArg_->p_workspace_var_ = reinterpret_cast<char*>(pArg_->p_workspace_mean_) + mean_space_sz;
+
+        index_t variance_space_sz = welford_size * sizeof(MeanVarDataType);
+        variance_space_sz = math::integer_least_multiple(variance_space_sz, 64);
+
+        // setup buffer used for intermediate welford count
+        pArg_->p_workspace_count_ =
+            reinterpret_cast<char*>(pArg_->p_workspace_var_) + variance_space_sz;
+    };
+
+    bool IsSupportedArgument(const BaseArgument* p_arg) override
+    {
+        const Argument* p_arg_ = dynamic_cast<const Argument*>(p_arg);
+
+        constexpr index_t NumInvariantDim = Rank - NumReduceDim;
+
+        if constexpr(XYVectorDim == 0)
+        {
+            if constexpr(NumInvariantDim == 0)
+            {
+                return false;
+            }
+            else
+            {
+                if(p_arg_->xStrides_[NumInvariantDim - 1] != 1)
+                    return false;
+
+                if(p_arg_->invariant_lowest_length % XSrcVectorSize != 0)
+                    return false;
+
+                if(p_arg_->invariant_lowest_length % YDstVectorSize != 0)
+                    return false;
+            };
+        }
+        else
+        {
+            if(p_arg_->xStrides_[Rank - 1] != 1)
+                return false;
+
+            if(p_arg_->Lengths_[Rank - 1] % XSrcVectorSize != 0)
+                return false;
+
+            if(p_arg_->Lengths_[Rank - 1] % YDstVectorSize != 0)
+                return false;
+        };
+
+        // if fastest dim is not reduced
+        if constexpr(GammaSrcVectorDim == 0)
+        {
+            if(p_arg_->gammaStrides_[NumInvariantDim - 1] != 1)
+                return false;
+
+            if(p_arg_->Lengths_[Rank - 1] % GammaSrcVectorSize != 0)
+                return false;
+        }
+        else // if fastest dim is reduced
+        {
+            if(p_arg_->gammaStrides_[Rank - 1] != 1)
+                return false;
+
+            if(p_arg_->Lengths_[Rank - 1] % GammaSrcVectorSize != 0)
+                return false;
+        }
+
+        // if fastest dim is not reduced
+        if constexpr(BetaSrcVectorDim == 0)
+        {
+            if(p_arg_->betaStrides_[NumInvariantDim - 1] != 1)
+                return false;
+
+            if(p_arg_->invariant_lowest_length % BetaSrcVectorSize != 0)
+                return false;
+        }
+        else // if fastest dim is reduced
+        {
+            if(p_arg_->betaStrides_[Rank - 1] != 1)
+                return false;
+
+            if(p_arg_->Lengths_[Rank - 1] % BetaSrcVectorSize != 0)
+                return false;
+        }
+
+        if(p_arg_->kGridSize_ <= 1)
+            return false;
+
+        return true;
+    };
+
+    std::unique_ptr<BaseArgument>
+    MakeArgumentPointer(const std::vector lengths,
+                        const std::vector xStrides,
+                        const std::vector gammaStrides,
+                        const std::vector betaStrides,
+                        const std::vector yStrides,
+                        const std::vector reduceDims,
+                        double epsilon,
+                        const void* p_x,
+                        const void* p_gamma,
+                        const void* p_beta,
+                        void* p_y,
+                        void* p_saveMean,
+                        void* p_saveInvVar,
+                        YElementwiseOperation y_elementwise_op) override
+    {
+        // TODO
+        // Optional cache of the intermediate results (mean and InvVariance) during the
+        // forward pass could speed up the backward pass
+        ignore = p_saveMean;
+        ignore = p_saveInvVar;
+
+        return std::make_unique<Argument>(lengths,
+                                          xStrides,
+                                          gammaStrides,
+                                          betaStrides,
+                                          yStrides,
+                                          reduceDims,
+                                          y_elementwise_op,
+                                          epsilon,
+                                          static_cast<const XDataType*>(p_x),
+                                          static_cast<const GammaDataType*>(p_gamma),
+                                          static_cast<const BetaDataType*>(p_beta),
+                                          static_cast<YDataType*>(p_y));
+    };
+
+    std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
+    {
+        return std::make_unique<Invoker>();
+    };
+
+    std::string GetTypeString() const override
+    {
+        auto str = std::stringstream();
+
+        // clang-format off
+        str << "DeviceNormalizationSplitKImpl<" << BlockSize << ",";
"DeviceNormalizationSplitKImpl<" << BlockSize << ","; + str << "Cluster_MK_" << MThreadClusterSize << "_" << KThreadClusterSize << ","; + str << "Slice_MK_" << MThreadSliceSize << "_" << KThreadSliceSize << ","; + str << "XYSrcVectorDim_" << XYVectorDim << ","; + str << "VectorSize_X" << XSrcVectorSize << "_Gamma" << GammaSrcVectorSize << "_Beta" << BetaSrcVectorSize << "_Y" << YDstVectorSize << ">"; + // clang-format on + + return str.str(); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_normalization_naive_variance.hpp b/include/ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_naive_variance.hpp similarity index 100% rename from include/ck/tensor_operation/gpu/grid/gridwise_normalization_naive_variance.hpp rename to include/ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_naive_variance.hpp diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_normalization_selector.hpp b/include/ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_selector.hpp similarity index 98% rename from include/ck/tensor_operation/gpu/grid/gridwise_normalization_selector.hpp rename to include/ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_selector.hpp index 37795fa569..632690e1ef 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_normalization_selector.hpp +++ b/include/ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_selector.hpp @@ -3,8 +3,8 @@ #pragma once -#include "ck/tensor_operation/gpu/grid/gridwise_normalization_naive_variance.hpp" -#include "ck/tensor_operation/gpu/grid/gridwise_normalization_welford_variance.hpp" +#include "ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_naive_variance.hpp" +#include "ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_welford_variance.hpp" namespace ck { template +struct GridwiseNormalizationSplitK1st +{ + static_assert((XSrcVectorDim == 0 && MThreadSliceSize % XSrcVectorSize == 0) || + (XSrcVectorDim == 1 && KThreadSliceSize % XSrcVectorSize == 0), + "Invalid thread slice sizes and/or vector sizes configuration, please check!"); + + static constexpr bool reorder_thread_cluster = (XSrcVectorDim == 0); + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + + using ThreadClusterLengths_M_K = Sequence; + + using ThreadBufferDimAccessOrder = + typename conditional, Sequence<0, 1>>::type; + + using ThreadClusterArrangeOrder = + typename conditional, Sequence<0, 1>>::type; + + static constexpr auto thread_cluster_desc = + make_cluster_descriptor(ThreadClusterLengths_M_K{}, ThreadClusterArrangeOrder{}); + + using ThreadBufferLengths_M_K = Sequence; + static constexpr auto thread_buffer_desc_m_k = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, Number{})); + + using ThreadBufferLengths_M_1 = Sequence; + static constexpr auto thread_buffer_desc_m_1 = + make_naive_tensor_descriptor_packed(make_tuple(Number{}, I1)); + + using ThreadReduceSrcDesc_M_K = decltype(make_naive_tensor_descriptor_packed( + make_tuple(Number{}, Number{}))); + using ThreadReduceDstDesc_M = + decltype(make_naive_tensor_descriptor_packed(make_tuple(Number{}))); + + using ThreadwiseWelford = + ThreadwiseWelford; + + using BlockwiseWelford = BlockwiseWelford; + + using PassThroughOp = tensor_operation::element_wise::PassThrough; + + static constexpr index_t M_BlockTileSize = MThreadClusterSize * 
+    static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize;
+    static constexpr index_t K_BlockTileStepSize = KThreadClusterSize * XSrcVectorSize;
+
+    static constexpr auto ThreadBufferNumber = Number{};
+
+    __device__ static int
+    GetKPerThread(int kRaw, int kGridSize, int block_k_cluster_id, int thread_k_cluster_id)
+    {
+        bool is_rightmost_block = block_k_cluster_id == kGridSize - 1;
+
+        if(is_rightmost_block)
+        {
+            int left_kPerBlock = math::integer_divide_ceil(kRaw, kGridSize);
+            int kPerBlock = kRaw % kGridSize == 0 ? left_kPerBlock : kRaw % left_kPerBlock;
+            int kPerThread =
+                kPerBlock < K_BlockTileSize ? 0 : KThreadSliceSize * (kPerBlock / K_BlockTileSize);
+            int kPerBlockTail = kPerBlock - kPerThread * KThreadClusterSize;
+
+            if(kPerBlockTail > 0)
+            {
+                static_for<0, ThreadBufferNumber, 1>{}([&](auto i) {
+                    int thread_max_len =
+                        (thread_k_cluster_id + 1) * XSrcVectorSize + K_BlockTileStepSize * i;
+                    int delta = thread_max_len - kPerBlockTail;
+                    delta = math::clamp(thread_max_len - kPerBlockTail, 0, XSrcVectorSize);
+                    kPerThread += XSrcVectorSize - delta;
+                });
+            }
+
+            return kPerThread;
+        }
+        else
+        {
+            int kPerBlock = math::integer_divide_ceil(kRaw, kGridSize);
+            return KThreadSliceSize * (kPerBlock / K_BlockTileSize);
+        }
+    }
+
+    // Calculate mean and variance by welford along k dimension
+    __device__ static void Run(const XGridDesc_M_K& x_grid_desc_m_k,
+                               const MeanVarGridDesc_M_KBlock& mean_var_grid_desc_m_kblock,
+                               index_t num_k_block_tile_iteration,
+                               const XDataType* const __restrict__ p_x_global,
+                               MeanVarDataType* const p_mean_global,
+                               MeanVarDataType* const p_variance_global,
+                               int32_t* const p_welford_count_global)
+    {
+        auto x_thread_buf = generate_tuple(
+            [&](auto) {
+                return StaticBuffer{};
+            },
+            Number{});
+
+        StaticBuffer
+            mean_thread_buf;
+        StaticBuffer
+            var_thread_buf;
+
+        const index_t thread_local_id = get_thread_local_1d_id();
+        const index_t block_global_id = get_block_1d_id();
+
+        const index_t k_grid_size = mean_var_grid_desc_m_kblock.GetLength(I1);
+        const index_t block_m_cluster_id = block_global_id / k_grid_size;
+        const index_t block_k_cluster_id = block_global_id % k_grid_size;
+
+        const auto thread_cluster_idx =
+            thread_cluster_desc.CalculateBottomIndex(make_multi_index(thread_local_id));
+
+        const auto thread_m_cluster_id = thread_cluster_idx[I0];
+        const auto thread_k_cluster_id = thread_cluster_idx[I1];
+
+        const index_t reduceSizePerBlock = K_BlockTileSize * num_k_block_tile_iteration;
+
+        auto threadwise_x_load = ThreadwiseTensorSliceTransfer_v2(
+            x_grid_desc_m_k,
+            make_multi_index(
+                block_m_cluster_id * M_BlockTileSize + thread_m_cluster_id * MThreadSliceSize,
+                block_k_cluster_id * reduceSizePerBlock + thread_k_cluster_id * XSrcVectorSize));
+
+        auto mean_var_count_store_index = make_multi_index(
+            block_m_cluster_id * M_BlockTileSize + thread_m_cluster_id * MThreadSliceSize,
+            block_k_cluster_id);
+
+        auto threadwise_welford_mean_var_store =
+            ThreadwiseTensorSliceTransfer_v1r3,
+                                              1,
+                                              1,
+                                              InMemoryDataOperationEnum::Set,
+                                              1,
+                                              true>(
+                mean_var_grid_desc_m_kblock, mean_var_count_store_index, PassThroughOp{});
+
+        constexpr auto thread_copy_fwd_step_m_k = make_multi_index(0, K_BlockTileStepSize);
+
+        const auto x_global_val_buf = make_dynamic_buffer(
+            p_x_global, x_grid_desc_m_k.GetElementSpaceSize());
+
+        auto mean_global_val_buf = make_dynamic_buffer(
+            p_mean_global, mean_var_grid_desc_m_kblock.GetElementSpaceSize());
+
+        auto var_global_val_buf = make_dynamic_buffer(
+            p_variance_global, mean_var_grid_desc_m_kblock.GetElementSpaceSize());
+
+        auto threadwise_welford = ThreadwiseWelford();
+        int kRaw = x_grid_desc_m_k.GetTransforms()[I2].GetUpperLengths()[I0];
+        threadwise_welford.max_count_ =
+            GetKPerThread(kRaw, k_grid_size, block_k_cluster_id, thread_k_cluster_id);
+
+        static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
+            mean_thread_buf(I) = type_convert(0.0f);
+            var_thread_buf(I) = type_convert(0.0f);
+        });
+
+        for(index_t k = 0; k < num_k_block_tile_iteration; ++k)
+        {
+            static_for<0, ThreadBufferNumber, 1>{}([&](auto i) {
+                threadwise_x_load.Run(x_grid_desc_m_k,
+                                      x_global_val_buf,
+                                      thread_buffer_desc_m_k,
+                                      make_tuple(I0, I0),
+                                      x_thread_buf(i));
+                threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, thread_copy_fwd_step_m_k);
+                threadwise_welford.Run(x_thread_buf[i], mean_thread_buf, var_thread_buf);
+            });
+        }
+
+        int welford_count = 0;
+        static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
+            if constexpr(I > 0)
+                block_sync_lds();
+
+            int count = threadwise_welford.cur_count_;
+            BlockwiseWelford::Run(mean_thread_buf(I), var_thread_buf(I), count);
+
+            // The value of count is the same for all I
+            if constexpr(I == MThreadSliceSize - 1)
+                welford_count = count;
+        });
+
+        if(thread_k_cluster_id == 0)
+        {
+            threadwise_welford_mean_var_store.Run(thread_buffer_desc_m_1,
+                                                  make_tuple(I0, I0),
+                                                  mean_thread_buf,
+                                                  mean_var_grid_desc_m_kblock,
+                                                  mean_global_val_buf);
+
+            threadwise_welford_mean_var_store.Run(thread_buffer_desc_m_1,
+                                                  make_tuple(I0, I0),
+                                                  var_thread_buf,
+                                                  mean_var_grid_desc_m_kblock,
+                                                  var_global_val_buf);
+
+            if(block_m_cluster_id == 0 && thread_m_cluster_id == 0)
+                p_welford_count_global[block_k_cluster_id] = welford_count;
+        }
+    }
+};
+
+} // namespace ck
diff --git a/include/ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_splitk_2nd.hpp b/include/ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_splitk_2nd.hpp
new file mode 100644
index 0000000000..d796d1afc0
--- /dev/null
+++ b/include/ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_splitk_2nd.hpp
@@ -0,0 +1,418 @@
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
+
+#pragma once
+
+#include "ck/utility/data_type.hpp"
+#include "ck/utility/math.hpp"
+#include "ck/tensor_operation/gpu/block/blockwise_welford.hpp"
+#include "ck/tensor_operation/gpu/thread/threadwise_welford.hpp"
+#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
+#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
+
+namespace ck {
+
+template
+struct GridwiseNormalizationSplitK2nd
+{
+    static_assert((XSrcVectorDim == 0 && MThreadSliceSize % XSrcVectorSize == 0) ||
+                      (XSrcVectorDim == 1 && KThreadSliceSize % XSrcVectorSize == 0),
+                  "Invalid thread slice sizes and/or vector sizes configuration, please check!");
+
+    static_assert((YDstVectorDim == 0 && MThreadSliceSize % YDstVectorSize == 0) ||
+                      (YDstVectorDim == 1 && KThreadSliceSize % YDstVectorSize == 0),
+                  "Invalid thread slice sizes and/or vector sizes configuration, please check!");
+
+    static_assert(XSrcVectorSize == YDstVectorSize);
+    static_assert(XSrcVectorSize == GammaSrcVectorSize);
+    static_assert(XSrcVectorSize == BetaSrcVectorSize);
+
+    static constexpr bool reorder_thread_cluster = (XSrcVectorDim == 0);
+
+    static constexpr auto I0 = Number<0>{};
+    static constexpr auto I1 = Number<1>{};
+
+    using ThreadClusterLengths_M_K = Sequence;
+
+    using ThreadBufferDimAccessOrder =
+        typename conditional, Sequence<0, 1>>::type;
+
+    using ThreadClusterArrangeOrder =
+        typename conditional, Sequence<0, 1>>::type;
+
+    static constexpr auto thread_cluster_desc =
+        make_cluster_descriptor(ThreadClusterLengths_M_K{}, ThreadClusterArrangeOrder{});
+
+    using ThreadBufferLengths_M_K = Sequence;
+    static constexpr auto thread_buffer_desc_m_k = make_naive_tensor_descriptor_packed(
+        make_tuple(Number{}, Number{}));
+
+    using ThreadBufferLengths_M_1 = Sequence;
+    static constexpr auto thread_buffer_desc_m_1 =
+        make_naive_tensor_descriptor_packed(make_tuple(Number{}, I1));
+
+    using ThreadWelfordSrcDesc_M_1 = decltype(thread_buffer_desc_m_1);
+    using ThreadWelfordDstDesc_M =
+        decltype(make_naive_tensor_descriptor_packed(make_tuple(Number{})));
+
+    using ThreadwiseWelford =
+        ThreadwiseWelfordMerge;
+
+    using BlockwiseWelford = BlockwiseWelford;
+
+    using PassThroughOp = tensor_operation::element_wise::PassThrough;
+
+    static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize;
+    static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize;
+    static constexpr index_t K_BlockTileStepSize = KThreadClusterSize * XSrcVectorSize;
+
+    static constexpr auto ThreadBufferNumber = Number{};
+
+    __device__ static void Run(const MeanVarGridDesc_M_KBlock& mean_var_grid_desc_m_kblock,
+                               const CountGridDesc_M_KBlock& count_grid_desc_m_kblock,
+                               const XYGammaBetaGridDesc_M_K& x_grid_desc_m_k,
+                               const XYGammaBetaGridDesc_M_K& gamma_grid_desc_m_k,
+                               const XYGammaBetaGridDesc_M_K& beta_grid_desc_m_k,
+                               const XYGammaBetaGridDesc_M_K& y_grid_desc_m_k,
+                               index_t num_k_mean_var_count_iteration,
+                               index_t num_k_block_tile_iteration,
+                               index_t k_grid_size,
+                               ComputeDataType epsilon,
+                               const MeanVarDataType* const p_mean_global,
+                               const MeanVarDataType* const p_variance_global,
+                               const int32_t* const p_welford_count_global,
+                               const XDataType* const __restrict__ p_x_global,
+                               const GammaDataType* const __restrict__ p_gamma_global,
+                               const BetaDataType* const __restrict__ p_beta_global,
+                               YDataType* const __restrict__ p_y_global,
+                               const YElementwiseOperation y_elementwise_op)
+    {
+        // Thread/Block id
+        const index_t thread_local_id = get_thread_local_1d_id();
+        const index_t block_global_id = get_block_1d_id();
+        const index_t block_m_cluster_id = block_global_id / k_grid_size;
+        const index_t block_k_cluster_id = block_global_id % k_grid_size;
+        const auto thread_cluster_idx =
+            thread_cluster_desc.CalculateBottomIndex(make_multi_index(thread_local_id));
+
+        const auto thread_m_cluster_id = thread_cluster_idx[I0];
+        const auto thread_k_cluster_id = thread_cluster_idx[I1];
+
+        // Global Memory
+        const auto mean_global_val_buf = make_dynamic_buffer(
+            p_mean_global, mean_var_grid_desc_m_kblock.GetElementSpaceSize());
+
+        const auto var_global_val_buf = make_dynamic_buffer(
+            p_variance_global, mean_var_grid_desc_m_kblock.GetElementSpaceSize());
+
+        const auto welford_count_global_val_buf = make_dynamic_buffer(
+            p_welford_count_global, count_grid_desc_m_kblock.GetElementSpaceSize());
+
+        const auto x_global_val_buf = make_dynamic_buffer(
+            p_x_global, x_grid_desc_m_k.GetElementSpaceSize());
+
+        const auto gamma_global_val_buf = make_dynamic_buffer(
+            p_gamma_global, gamma_grid_desc_m_k.GetElementSpaceSize());
+
+        const auto beta_global_val_buf = make_dynamic_buffer(
+            p_beta_global, beta_grid_desc_m_k.GetElementSpaceSize());
+
+        auto y_global_val_buf = make_dynamic_buffer(
+            p_y_global, y_grid_desc_m_k.GetElementSpaceSize());
+
+        // VGPR
+        StaticBuffer
+            in_mean_thread_buf;
+        StaticBuffer
+            in_var_thread_buf;
+        StaticBuffer
+            in_welford_count_thread_buf;
+        StaticBuffer
+            mean_thread_buf;
+        StaticBuffer
+            var_thread_buf;
+        StaticBuffer
+            welford_count_thread_buf;
+
+        auto x_thread_buf = generate_tuple(
+            [&](auto) {
+                return StaticBuffer{};
+            },
+            Number{});
+
+        auto gamma_thread_buf = generate_tuple(
+            [&](auto) {
+                return StaticBuffer{};
+            },
+            Number{});
+
+        auto& beta_thread_buf = gamma_thread_buf;
+        auto& y_thread_buf = x_thread_buf;
+
+        // IO
+        auto threadwise_mean_var_load_m_kblock =
+            ThreadwiseTensorSliceTransfer_v2,
+                                            1,
+                                            1,
+                                            1,
+                                            true>(
+                mean_var_grid_desc_m_kblock,
+                make_multi_index(block_m_cluster_id * M_BlockTileSize +
+                                     thread_m_cluster_id * MThreadSliceSize,
+                                 thread_k_cluster_id));
+
+        auto threadwise_count_load_m_kblock =
+            ThreadwiseTensorSliceTransfer_v2,
+                                            1,
+                                            1,
+                                            1,
+                                            true>(
+                count_grid_desc_m_kblock,
+                make_multi_index(block_m_cluster_id * M_BlockTileSize +
+                                     thread_m_cluster_id * MThreadSliceSize,
+                                 thread_k_cluster_id));
+
+        auto threadwise_x_load = ThreadwiseTensorSliceTransfer_v2(
+            x_grid_desc_m_k,
+            make_multi_index(block_m_cluster_id * M_BlockTileSize +
+                                 thread_m_cluster_id * MThreadSliceSize,
+                             block_k_cluster_id * K_BlockTileSize * num_k_block_tile_iteration +
+                                 thread_k_cluster_id * XSrcVectorSize));
+
+        auto threadwise_gamma_load =
+            ThreadwiseTensorSliceTransfer_v2(
+                gamma_grid_desc_m_k,
+                make_multi_index(block_m_cluster_id * M_BlockTileSize +
+                                     thread_m_cluster_id * MThreadSliceSize,
+                                 block_k_cluster_id * K_BlockTileSize * num_k_block_tile_iteration +
+                                     thread_k_cluster_id * GammaSrcVectorSize));
+
+        auto threadwise_beta_load =
+            ThreadwiseTensorSliceTransfer_v2(
+                beta_grid_desc_m_k,
+                make_multi_index(block_m_cluster_id * M_BlockTileSize +
+                                     thread_m_cluster_id * MThreadSliceSize,
+                                 block_k_cluster_id * K_BlockTileSize * num_k_block_tile_iteration +
+                                     thread_k_cluster_id * BetaSrcVectorSize));
+
+        auto threadwise_y_store =
+            ThreadwiseTensorSliceTransfer_v1r3(
+                y_grid_desc_m_k,
+                make_multi_index(block_m_cluster_id * M_BlockTileSize +
+                                     thread_m_cluster_id * MThreadSliceSize,
+                                 block_k_cluster_id * K_BlockTileSize * num_k_block_tile_iteration +
+                                     thread_k_cluster_id * YDstVectorSize),
+                y_elementwise_op);
+
+        // step1: Merge mean and variance
+        constexpr auto mean_var_count_thread_copy_step_I0_k =
+            make_multi_index(I0, KThreadClusterSize);
+
+        static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
+            mean_thread_buf(I) = type_convert(0.0f);
+            var_thread_buf(I) = type_convert(0.0f);
+            welford_count_thread_buf(I) = 0;
+        });
+
+        for(index_t k = 0; k < num_k_mean_var_count_iteration; ++k)
+        {
+            threadwise_mean_var_load_m_kblock.Run(mean_var_grid_desc_m_kblock,
+                                                  mean_global_val_buf,
+                                                  thread_buffer_desc_m_1,
+                                                  make_tuple(I0, I0),
+                                                  in_mean_thread_buf);
+
+            threadwise_mean_var_load_m_kblock.Run(mean_var_grid_desc_m_kblock,
+                                                  var_global_val_buf,
+                                                  thread_buffer_desc_m_1,
+                                                  make_tuple(I0, I0),
+                                                  in_var_thread_buf);
+
+            threadwise_count_load_m_kblock.Run(count_grid_desc_m_kblock,
+                                               welford_count_global_val_buf,
+                                               thread_buffer_desc_m_1,
+                                               make_tuple(I0, I0),
+                                               in_welford_count_thread_buf);
+
+            ThreadwiseWelford::Run(in_mean_thread_buf,
+                                   in_var_thread_buf,
+                                   in_welford_count_thread_buf,
+                                   mean_thread_buf,
+                                   var_thread_buf,
+                                   welford_count_thread_buf);
+
+            threadwise_mean_var_load_m_kblock.MoveSrcSliceWindow(
+                mean_var_grid_desc_m_kblock, mean_var_count_thread_copy_step_I0_k);
+            threadwise_count_load_m_kblock.MoveSrcSliceWindow(count_grid_desc_m_kblock,
+                                                              mean_var_count_thread_copy_step_I0_k);
+        }
+
+        static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
+            if constexpr(I > 0)
+                block_sync_lds();
+
+            BlockwiseWelford::Run(
+                mean_thread_buf(I), var_thread_buf(I), welford_count_thread_buf(I));
+        });
+
+        // step2: normalization
+        constexpr auto thread_copy_fwd_step_m_k = make_multi_index(0, K_BlockTileStepSize);
+
+        for(index_t k = 0; k < num_k_block_tile_iteration; ++k)
+        {
+            static_for<0, ThreadBufferNumber, 1>{}([&](auto i) {
+                threadwise_x_load.Run(x_grid_desc_m_k,
+                                      x_global_val_buf,
+                                      thread_buffer_desc_m_k,
+                                      make_tuple(I0, I0),
+                                      x_thread_buf(i));
+                threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, thread_copy_fwd_step_m_k);
+            });
+
+            static_for<0, ThreadBufferNumber, 1>{}([&](auto i) {
+                threadwise_gamma_load.Run(gamma_grid_desc_m_k,
+                                          gamma_global_val_buf,
+                                          thread_buffer_desc_m_k,
+                                          make_tuple(I0, I0),
+                                          gamma_thread_buf(i));
+
+                threadwise_gamma_load.MoveSrcSliceWindow(gamma_grid_desc_m_k,
+                                                         thread_copy_fwd_step_m_k);
+            });
+
+            static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
+                auto divisor = 1 / ck::math::sqrt(var_thread_buf(iM) + epsilon);
+                static_for<0, ThreadBufferNumber, 1>{}([&](auto iK0) {
+                    static_for<0, XSrcVectorSize, 1>{}([&](auto iK1) {
+                        constexpr auto offset_m_k =
+                            thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK1));
+
+                        // normalize
+                        y_thread_buf(iK0)(Number{}) =
+                            (x_thread_buf(iK0)(Number{}) - mean_thread_buf(iM)) *
+                            divisor;
+
+                        // gamma
+                        y_thread_buf(iK0)(Number{}) =
+                            y_thread_buf(iK0)(Number{}) *
+                            gamma_thread_buf(iK0)(Number{});
+                    });
+                });
+            });
+
+            static_for<0, ThreadBufferNumber, 1>{}([&](auto i) {
+                threadwise_beta_load.Run(beta_grid_desc_m_k,
+                                         beta_global_val_buf,
+                                         thread_buffer_desc_m_k,
+                                         make_tuple(I0, I0),
+                                         beta_thread_buf(i));
+                threadwise_beta_load.MoveSrcSliceWindow(beta_grid_desc_m_k,
+                                                        thread_copy_fwd_step_m_k);
+            });
+
+            static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
+                static_for<0, ThreadBufferNumber, 1>{}([&](auto iK0) {
+                    static_for<0, XSrcVectorSize, 1>{}([&](auto iK1) {
+                        constexpr auto offset_m_k =
+                            thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK1));
+
+                        // beta
+                        y_thread_buf(iK0)(Number{}) =
+                            y_thread_buf(iK0)(Number{}) +
+                            beta_thread_buf(iK0)(Number{});
+                    });
+                });
+            });
+
+            static_for<0, ThreadBufferNumber, 1>{}([&](auto i) {
+                threadwise_y_store.Run(thread_buffer_desc_m_k,
+                                       make_tuple(I0, I0),
+                                       y_thread_buf(i),
+                                       y_grid_desc_m_k,
+                                       y_global_val_buf);
+                threadwise_y_store.MoveDstSliceWindow(y_grid_desc_m_k, thread_copy_fwd_step_m_k);
+            });
+        } // end for (normalization)
+    }
+};
+
+} // namespace ck
diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_normalization_welford_variance.hpp b/include/ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_welford_variance.hpp
similarity index 100%
rename from include/ck/tensor_operation/gpu/grid/gridwise_normalization_welford_variance.hpp
rename to include/ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_welford_variance.hpp
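
Note on the host-side contract introduced by this diff: unlike DeviceNormalizationImpl, the split-K instance needs device workspace for the partial Welford results before Run() is invoked, and Invoker::Run throws if the pointer was never set. A condensed sketch of the calling sequence the updated examples follow (tensor setup and the MakeArgumentPointer argument list are elided; the names match run_layernorm_example.inc / run_groupnorm_example.inc):

// Query the workspace requirement for this argument, allocate it, and attach
// it to the argument before running the invoker.
size_t workspace_sz = device_instance.GetWorkSpaceSize(argument_ptr.get());
DeviceMem workspace_dev(workspace_sz);
device_instance.SetWorkSpacePointer(argument_ptr.get(), workspace_dev.GetDeviceBuffer());

auto invoker_ptr = device_instance.MakeInvokerPointer();
float ave_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel});

Also note that IsSupportedArgument rejects arguments with kGridSize_ <= 1, so the split-K instance only accepts problems whose reduce length actually spans more than one K block; smaller reductions are meant for the existing DeviceNormalizationImpl path.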
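The two-kernel structure works because Welford partial results compose: kernel 1 writes one (mean, variance, count) triple per K block into the workspace, and kernel 2's ThreadwiseWelfordMerge/BlockwiseWelford stages combine them before normalizing. The combination rule is the standard parallel-variance merge (Chan et al.); below is a host-side sketch for two partials, with var the biased (population) variance actually used for normalization. This is a hypothetical illustrative helper, not CK's internal code:

#include <cstdint>

struct WelfordPartial
{
    float mean;
    float var; // biased (population) variance of the partial range
    int32_t count;
};

// Merge two disjoint partial results into one, as if Welford had run over
// the concatenation of both ranges.
inline WelfordPartial welford_merge(const WelfordPartial& a, const WelfordPartial& b)
{
    int32_t count = a.count + b.count;
    if(count == 0)
        return {0.0f, 0.0f, 0};

    float delta = b.mean - a.mean;
    float mean = a.mean + delta * b.count / count;

    // Work in terms of M2 = var * count, the sum of squared deviations,
    // then renormalize back to a variance.
    float m2 = a.var * a.count + b.var * b.count + delta * delta * a.count * b.count / count;

    return {mean, m2 / count, count};
}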
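For reference, the arithmetic in kernel 2's "step2: normalization" loop reduces, per M row, to the usual layernorm/groupnorm formula followed by the user-supplied Y elementwise op (PassThrough or Swish in the examples). A scalar restatement, assuming a row-major M x K layout and float throughout (illustrative only; the function name and parameters are not part of CK):

#include <cmath>
#include <cstddef>

// y[k] = ((x[k] - mean) / sqrt(var + epsilon)) * gamma[k] + beta[k], then y_op.
// CK elementwise ops take an (output, input) pair, i.e. op(y, value).
template <typename YElementOp>
void normalize_row_reference(const float* x, const float* gamma, const float* beta, float* y,
                             std::size_t K, float mean, float var, float epsilon, YElementOp y_op)
{
    const float inv_std = 1.0f / std::sqrt(var + epsilon);
    for(std::size_t k = 0; k < K; ++k)
    {
        const float normalized = (x[k] - mean) * inv_std * gamma[k] + beta[k];
        y_op(y[k], normalized);
    }
}

The kernel computes the same thing with the multiply by inv_std (called divisor there), the gamma multiply, and the beta add split into separate vectorized passes over the thread buffers.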