diff --git a/CHANGELOG.md b/CHANGELOG.md index 2c3215ae44..23e7fb6274 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,7 +9,7 @@ Full documentation for Composable Kernel is not yet available. - Fixed grouped ConvBwdWeight test case failure (#524). ### Optimizations -- Optimized ... +- Improved normalization kernel performance. ### Added - Added user tutorial (#563). diff --git a/client_example/05_layernorm/layernorm2d.cpp b/client_example/05_layernorm/layernorm2d.cpp index adb41171e1..856a4cc219 100644 --- a/client_example/05_layernorm/layernorm2d.cpp +++ b/client_example/05_layernorm/layernorm2d.cpp @@ -12,12 +12,12 @@ #include "ck/library/tensor_operation_instance/gpu/normalization.hpp" -using XDataType = ck::half_t; -using GammaDataType = ck::half_t; -using BetaDataType = ck::half_t; -using YDataType = ck::half_t; -using AccDataType = float; -using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using XDataType = ck::half_t; +using GammaDataType = ck::half_t; +using BetaDataType = ck::half_t; +using YDataType = ck::half_t; +using ComputeDataType = float; +using PassThrough = ck::tensor_operation::element_wise::PassThrough; constexpr int Rank = 2; constexpr int NumReduceDim = 1; @@ -54,7 +54,7 @@ int main(int argc, char* argv[]) using DeviceOp = ck::tensor_operation::device::DeviceNormalization; diff --git a/example/42_groupnorm/groupnorm_sigmoid_fp16.cpp b/example/42_groupnorm/groupnorm_sigmoid_fp16.cpp index e62001d669..35c7c054e0 100644 --- a/example/42_groupnorm/groupnorm_sigmoid_fp16.cpp +++ b/example/42_groupnorm/groupnorm_sigmoid_fp16.cpp @@ -23,11 +23,11 @@ constexpr int Rank = 5; constexpr int NumReduceDim = 3; -using XDataType = ck::half_t; -using GammaDataType = ck::half_t; -using BetaDataType = ck::half_t; -using YDataType = ck::half_t; -using AccDataType = float; +using XDataType = ck::half_t; +using GammaDataType = ck::half_t; +using BetaDataType = ck::half_t; +using YDataType = ck::half_t; +using ComputeDataType = float; struct YElementOp { @@ -50,7 +50,7 @@ using DeviceInstance = ck::tensor_operation::device::DeviceNormalizationImpl; ReferenceInstance ref; diff --git a/include/ck/tensor_operation/gpu/device/device_normalization.hpp b/include/ck/tensor_operation/gpu/device/device_normalization.hpp index ec17ec3d18..03601ce831 100644 --- a/include/ck/tensor_operation/gpu/device/device_normalization.hpp +++ b/include/ck/tensor_operation/gpu/device/device_normalization.hpp @@ -14,9 +14,9 @@ namespace device { template struct DeviceNormalization : public BaseOperator @@ -35,7 +35,7 @@ struct DeviceNormalization : public BaseOperator void* p_y, void* p_savedMean, void* p_savedInvVar, - AccElementwiseOperation acc_elementwise_op) = 0; + YElementwiseOperation y_elementwise_op) = 0; virtual std::unique_ptr MakeInvokerPointer() = 0; }; @@ -43,17 +43,17 @@ struct DeviceNormalization : public BaseOperator template using DeviceNormalizationPtr = std::unique_ptr>; diff --git a/include/ck/tensor_operation/gpu/device/impl/device_normalization_impl.hpp b/include/ck/tensor_operation/gpu/device/impl/device_normalization_impl.hpp index 8cc223a886..bb62332d1a 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_normalization_impl.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_normalization_impl.hpp @@ -10,46 +10,11 @@ #include "ck/tensor_operation/gpu/device/device_normalization.hpp" #include "ck/tensor_operation/gpu/device/device_reduce.hpp" #include "ck/tensor_operation/gpu/device/impl/device_reduce_common.hpp" -#include 
"ck/tensor_operation/gpu/grid/gridwise_normalization_welford_variance.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_normalization_selector.hpp" #include "ck/tensor_operation/gpu/grid/gridwise_set_buffer_value.hpp" #include "ck/host_utility/device_prop.hpp" #include "ck/host_utility/kernel_launch.hpp" -namespace ck { -template -__global__ void kernel_normalization(const GridDesc_M_K x_grid_desc_m_k, - const GridDesc_M_K gamma_grid_desc_m_k, - const GridDesc_M_K beta_grid_desc_m_k, - const GridDesc_M_K y_grid_desc_m_k, - index_t num_k_block_tile_iteration, - AccDataType epsilon, - const XDataType* const __restrict__ p_x_global, - const GammaDataType* const __restrict__ p_gamma_global, - const BetaDataType* const __restrict__ p_beta_global, - YDataType* const __restrict__ p_y_global, - const AccElementwiseOperation acc_elementwise_op) -{ - GridwiseReduction::Run(x_grid_desc_m_k, - gamma_grid_desc_m_k, - beta_grid_desc_m_k, - y_grid_desc_m_k, - num_k_block_tile_iteration, - epsilon, - p_x_global, - p_gamma_global, - p_beta_global, - p_y_global, - acc_elementwise_op); -}; -} // namespace ck - namespace ck { namespace tensor_operation { namespace device { @@ -58,9 +23,9 @@ namespace device { template + index_t YDstVectorSize, + bool UseWelford = true> struct DeviceNormalizationImpl : public DeviceNormalization { + static_assert(BlockSize == MThreadClusterSize * KThreadClusterSize); static_assert( ((GammaSrcVectorDim == 0 && MThreadSliceSize % GammaSrcVectorSize == 0) || (GammaSrcVectorDim == 1 && KThreadSliceSize % GammaSrcVectorSize == 0)), @@ -167,51 +134,6 @@ struct DeviceNormalizationImpl : public DeviceNormalization; - using GridwiseNormalizationSweepOnce = - GridwiseNormalizationWelfordVariance_mk_to_mk; - struct Argument : public BaseArgument { Argument(const std::vector lengths, @@ -220,7 +142,7 @@ struct DeviceNormalizationImpl : public DeviceNormalization betaStrides, const std::vector yStrides, const std::vector reduceDims, - AccElementwiseOperation acc_elementwise_op, + YElementwiseOperation y_elementwise_op, double epsilon, const XDataType* p_x, const GammaDataType* p_gamma, @@ -230,9 +152,9 @@ struct DeviceNormalizationImpl : public DeviceNormalization(epsilon); + epsilon_ = static_cast(epsilon); Lengths_ = shuffle_tensor_dimensions(lengths, reduceDims); xStrides_ = shuffle_tensor_dimensions(xStrides, reduceDims); @@ -265,7 +187,7 @@ struct DeviceNormalizationImpl : public DeviceNormalization{}) <= KThreadClusterSize * KThreadSliceSize; } - AccDataType epsilon_; + ComputeDataType epsilon_; const XDataType* p_x_; const GammaDataType* p_gamma_; @@ -278,7 +200,7 @@ struct DeviceNormalizationImpl : public DeviceNormalization betaStrides_; std::vector yStrides_; - AccElementwiseOperation acc_elementwise_op_; + YElementwiseOperation y_elementwise_op_; int blkGroupSize_; int numBlockTileIteration_; @@ -295,23 +217,27 @@ struct DeviceNormalizationImpl : public DeviceNormalization - : kernel_normalization; + auto kernel_main = NormalizationKernelSelector(arg.isSweeponce_); float avg_time = 0; avg_time += launch_and_time_kernel(stream_config, @@ -329,7 +255,7 @@ struct DeviceNormalizationImpl : public DeviceNormalization(p_x), static_cast(p_gamma), @@ -462,8 +388,8 @@ struct DeviceNormalizationImpl : public DeviceNormalization"; // clang-format on diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_normalization_naive_variance.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_normalization_naive_variance.hpp index 89efea4d6c..792ffabcb9 100644 --- 
a/include/ck/tensor_operation/gpu/grid/gridwise_normalization_naive_variance.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_normalization_naive_variance.hpp @@ -4,9 +4,8 @@ #pragma once #include "ck/utility/data_type.hpp" -#include "ck/utility/reduction_common.hpp" + #include "ck/utility/reduction_operator.hpp" -#include "ck/utility/reduction_functions_accumulate.hpp" #include "ck/tensor_operation/gpu/block/reduction_functions_blockwise.hpp" #include "ck/tensor_operation/gpu/thread/reduction_functions_threadwise.hpp" #include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" @@ -19,8 +18,8 @@ template ; @@ -59,19 +62,23 @@ struct GridwiseNormalizationNaiveVariance_mk_to_mk static constexpr auto thread_cluster_desc = make_cluster_descriptor(ThreadClusterLengths_M_K{}, ThreadClusterArrangeOrder{}); + using ThreadBufferLengths_M_K = Sequence; + static constexpr auto thread_buffer_desc_m_k = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, Number{})); + using ThreadReduceSrcDesc_M_K = decltype(make_naive_tensor_descriptor_packed( - make_tuple(Number{}, Number{}))); + make_tuple(Number{}, Number{}))); using ThreadReduceDstDesc_M = decltype(make_naive_tensor_descriptor_packed(make_tuple(Number{}))); - using BlockwiseSumReduce = PartitionedBlockwiseReduction; - using ThreadwiseSumReduce = ThreadwiseReduction{}; static constexpr auto I2 = Number<2>{}; - static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize; - static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize; + static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize; + static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize; + static constexpr index_t K_BlockTileStepSize = KThreadClusterSize * XSrcVectorSize; + + static constexpr auto ThreadBufferNumber = Number{}; __device__ static void Run(const GridDesc_M_K& x_grid_desc_m_k, const GridDesc_M_K& gamma_grid_desc_m_k, const GridDesc_M_K& beta_grid_desc_m_k, const GridDesc_M_K& y_grid_desc_m_k, index_t num_k_block_tile_iteration, - AccDataType epsilon, + ComputeDataType epsilon, const XDataType* const __restrict__ p_x_global, const GammaDataType* const __restrict__ p_gamma_global, const BetaDataType* const __restrict__ p_beta_global, YDataType* const __restrict__ p_y_global, - const AccElementwiseOperation acc_elementwise_op) + const YElementwiseOperation y_elementwise_op) { - if constexpr(SweepOnce) - { - num_k_block_tile_iteration = 1; - } - // LDS - __shared__ AccDataType p_reduce_work_buffer[BlockSize]; - - auto y_global_val_buf = make_dynamic_buffer( - p_y_global, y_grid_desc_m_k.GetElementSpaceSize()); + __shared__ ComputeDataType p_reduce_work_buffer[BlockSize]; auto reduce_work_buf = make_dynamic_buffer(p_reduce_work_buffer, BlockSize); - StaticBuffer - x_thread_buf; + auto y_global_val_buf = make_dynamic_buffer( + p_y_global, y_grid_desc_m_k.GetElementSpaceSize()); - StaticBuffer - gamma_thread_buf; + auto x_thread_buf = generate_tuple( + [&](auto) { + return StaticBuffer{}; + }, + Number{}); - StaticBuffer& beta_thread_buf = gamma_thread_buf; + auto gamma_thread_buf = generate_tuple( + [&](auto) { + return StaticBuffer{}; + }, + Number{}); - StaticBuffer - y_thread_buf; + auto& beta_thread_buf = gamma_thread_buf; - StaticBuffer& x_square_thread_buf = y_thread_buf; + auto y_thread_buf = generate_tuple( + [&](auto) { + return StaticBuffer{}; + }, + Number{}); - StaticBuffer mean_thread_buf; - StaticBuffer + auto& x_square_thread_buf = 
y_thread_buf; + + StaticBuffer + mean_thread_buf; + StaticBuffer mean_square_thread_buf; - StaticBuffer& var_thread_buf = - mean_square_thread_buf; - - static_for<0, MThreadSliceSize, 1>{}([&](auto I) { - mean_thread_buf(I) = reduce::Add::template GetIdentityValue(); - mean_square_thread_buf(I) = reduce::Add::template GetIdentityValue(); - }); + StaticBuffer& + var_thread_buf = mean_square_thread_buf; const index_t thread_local_id = get_thread_local_1d_id(); const index_t block_global_id = get_block_1d_id(); @@ -149,12 +162,8 @@ struct GridwiseNormalizationNaiveVariance_mk_to_mk const auto thread_m_cluster_id = thread_cluster_idx[I0]; const auto thread_k_cluster_id = thread_cluster_idx[I1]; - using ThreadBufferLengths_M_K = Sequence; - constexpr auto thread_buffer_desc_m_k = make_naive_tensor_descriptor_packed( - make_tuple(Number{}, Number{})); - auto threadwise_x_load = ThreadwiseTensorSliceTransfer_v2{}([&](auto iM) { - static_for<0, KThreadSliceSize, 1>{}([&](auto iK) { - constexpr auto offset_m_k = - thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK)); - x_square_thread_buf(Number{}) = - x_thread_buf(Number{}) * x_thread_buf(Number{}); - }); - }); - - ThreadwiseSumReduce::Reduce(x_thread_buf, mean_thread_buf); - ThreadwiseSumReduce::Reduce(x_square_thread_buf, mean_square_thread_buf); - - threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, thread_copy_fwd_step_m_k); - - ++reducedTiles; - } while(reducedTiles < num_k_block_tile_iteration); - static_for<0, MThreadSliceSize, 1>{}([&](auto I) { - if constexpr(I > 0) - block_sync_lds(); - - BlockwiseSumReduce::Reduce(reduce_work_buf, mean_thread_buf(I)); - mean_thread_buf(I) = mean_thread_buf(I) / reduce_length; - - block_sync_lds(); - - BlockwiseSumReduce::Reduce(reduce_work_buf, mean_square_thread_buf(I)); - mean_square_thread_buf(I) = mean_square_thread_buf(I) / reduce_length; - - // var(x) = E[x^2] - E[x]^2 - var_thread_buf(I) = - mean_square_thread_buf(I) - (mean_thread_buf(I) * mean_thread_buf(I)); + mean_thread_buf(I) = reduce::Add::template GetIdentityValue(); + mean_square_thread_buf(I) = reduce::Add::template GetIdentityValue(); }); - // y = (x - E[x]) / sqrt(var[x] + epsilon) - auto thread_copy_tail_m_k = (num_k_block_tile_iteration - 1) * thread_copy_fwd_step_m_k; - - threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, thread_copy_bwd_step_m_k); - threadwise_gamma_load.MoveSrcSliceWindow(gamma_grid_desc_m_k, thread_copy_tail_m_k); - threadwise_beta_load.MoveSrcSliceWindow(beta_grid_desc_m_k, thread_copy_tail_m_k); - threadwise_y_store.MoveDstSliceWindow(y_grid_desc_m_k, thread_copy_tail_m_k); - - reducedTiles = 0; - do + // Separate sweep once and sweep twice pipeline + if constexpr(SweepOnce) { - if constexpr(!SweepOnce) - { + static_for<0, ThreadBufferNumber, 1>{}([&](auto i) { threadwise_x_load.Run(x_grid_desc_m_k, x_global_val_buf, thread_buffer_desc_m_k, make_tuple(I0, I0), - x_thread_buf); + x_thread_buf(i)); + + threadwise_gamma_load.Run(gamma_grid_desc_m_k, + gamma_global_val_buf, + thread_buffer_desc_m_k, + make_tuple(I0, I0), + gamma_thread_buf(i)); + + static_for<0, MThreadSliceSize, 1>{}([&](auto iM) { + static_for<0, XSrcVectorSize, 1>{}([&](auto iK) { + constexpr auto offset_m_k = + thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK)); + x_square_thread_buf(i)(Number{}) = + x_thread_buf(i)(Number{}) * + x_thread_buf(i)(Number{}); + }); + }); + + ThreadwiseSumReduce::Reduce(x_thread_buf[i], mean_thread_buf); + ThreadwiseSumReduce::Reduce(x_square_thread_buf[i], mean_square_thread_buf); + + if 
constexpr(i != ThreadBufferNumber - 1) + { + threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, thread_copy_fwd_step_m_k); + threadwise_gamma_load.MoveSrcSliceWindow(gamma_grid_desc_m_k, + thread_copy_fwd_step_m_k); + } + }); + + static_for<0, MThreadSliceSize, 1>{}([&](auto I) { + if constexpr(I > 0) + block_sync_lds(); + + BlockwiseSumReduce::Reduce(reduce_work_buf, mean_thread_buf(I)); + mean_thread_buf(I) = mean_thread_buf(I) / reduce_length; + + block_sync_lds(); + + BlockwiseSumReduce::Reduce(reduce_work_buf, mean_square_thread_buf(I)); + mean_square_thread_buf(I) = mean_square_thread_buf(I) / reduce_length; + + // var(x) = E[x^2] - E[x]^2 + var_thread_buf(I) = + mean_square_thread_buf(I) - (mean_thread_buf(I) * mean_thread_buf(I)); + }); + + static_for<0, MThreadSliceSize, 1>{}([&](auto iM) { + auto divisor = 1 / ck::math::sqrt(var_thread_buf(iM) + epsilon); + static_for<0, ThreadBufferNumber, 1>{}([&](auto iK0) { + static_for<0, XSrcVectorSize, 1>{}([&](auto iK1) { + constexpr auto offset_m_k = + thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK1)); + + // normalize + y_thread_buf(iK0)(Number{}) = + (x_thread_buf(iK0)(Number{}) - mean_thread_buf(iM)) * + divisor; + + // gamma & beta + y_thread_buf(iK0)(Number{}) = + y_thread_buf(iK0)(Number{}) * + gamma_thread_buf(iK0)(Number{}); + }); + }); + }); + + static_for<0, ThreadBufferNumber, 1>{}([&](auto i) { + threadwise_beta_load.Run(beta_grid_desc_m_k, + beta_global_val_buf, + thread_buffer_desc_m_k, + make_tuple(I0, I0), + beta_thread_buf(i)); + + if constexpr(i != ThreadBufferNumber - 1) + threadwise_beta_load.MoveSrcSliceWindow(beta_grid_desc_m_k, + thread_copy_fwd_step_m_k); + }); + + static_for<0, MThreadSliceSize, 1>{}([&](auto iM) { + static_for<0, ThreadBufferNumber, 1>{}([&](auto iK0) { + static_for<0, XSrcVectorSize, 1>{}([&](auto iK1) { + constexpr auto offset_m_k = + thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK1)); + + // beta + y_thread_buf(iK0)(Number{}) = + y_thread_buf(iK0)(Number{}) + + beta_thread_buf(iK0)(Number{}); + }); + }); + }); + + static_for<0, ThreadBufferNumber, 1>{}([&](auto i) { + threadwise_y_store.Run(thread_buffer_desc_m_k, + make_tuple(I0, I0), + y_thread_buf(i), + y_grid_desc_m_k, + y_global_val_buf); + + if constexpr(i != ThreadBufferNumber - 1) + threadwise_y_store.MoveDstSliceWindow(y_grid_desc_m_k, + thread_copy_fwd_step_m_k); + }); + } // end of sweep once + else + { + for(index_t reducedTiles = 0; reducedTiles < num_k_block_tile_iteration; ++reducedTiles) + { + static_for<0, ThreadBufferNumber, 1>{}([&](auto i) { + threadwise_x_load.Run(x_grid_desc_m_k, + x_global_val_buf, + thread_buffer_desc_m_k, + make_tuple(I0, I0), + x_thread_buf(i)); + threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, thread_copy_fwd_step_m_k); + + static_for<0, MThreadSliceSize, 1>{}([&](auto iM) { + static_for<0, XSrcVectorSize, 1>{}([&](auto iK) { + constexpr auto offset_m_k = + thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK)); + x_square_thread_buf(i)(Number{}) = + x_thread_buf(i)(Number{}) * + x_thread_buf(i)(Number{}); + }); + }); + + ThreadwiseSumReduce::Reduce(x_thread_buf[i], mean_thread_buf); + ThreadwiseSumReduce::Reduce(x_square_thread_buf[i], mean_square_thread_buf); + }); } - threadwise_gamma_load.Run(gamma_grid_desc_m_k, - gamma_global_val_buf, - thread_buffer_desc_m_k, - make_tuple(I0, I0), - gamma_thread_buf); + static_for<0, MThreadSliceSize, 1>{}([&](auto I) { + if constexpr(I > 0) + block_sync_lds(); - static_for<0, MThreadSliceSize, 1>{}([&](auto iM) { - 
static_for<0, KThreadSliceSize, 1>{}([&](auto iK) { - constexpr auto offset_m_k = - thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK)); + BlockwiseSumReduce::Reduce(reduce_work_buf, mean_thread_buf(I)); + mean_thread_buf(I) = mean_thread_buf(I) / reduce_length; - // normalize - y_thread_buf(Number{}) = - (x_thread_buf(Number{}) - mean_thread_buf(iM)) / - sqrt(var_thread_buf(iM) + epsilon); + block_sync_lds(); - // gamma - y_thread_buf(Number{}) = - y_thread_buf(Number{}) * gamma_thread_buf(Number{}); - }); + BlockwiseSumReduce::Reduce(reduce_work_buf, mean_square_thread_buf(I)); + mean_square_thread_buf(I) = mean_square_thread_buf(I) / reduce_length; + + // var(x) = E[x^2] - E[x]^2 + var_thread_buf(I) = + mean_square_thread_buf(I) - (mean_thread_buf(I) * mean_thread_buf(I)); }); - threadwise_beta_load.Run(beta_grid_desc_m_k, - beta_global_val_buf, - thread_buffer_desc_m_k, - make_tuple(I0, I0), - beta_thread_buf); - - static_for<0, MThreadSliceSize, 1>{}([&](auto iM) { - static_for<0, KThreadSliceSize, 1>{}([&](auto iK) { - constexpr auto offset_m_k = - thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK)); - - // beta - y_thread_buf(Number{}) = - y_thread_buf(Number{}) + beta_thread_buf(Number{}); - }); - }); - - threadwise_y_store.Run(thread_buffer_desc_m_k, - make_tuple(I0, I0), - y_thread_buf, - y_grid_desc_m_k, - y_global_val_buf); + auto thread_copy_tail_m_k = + (num_k_block_tile_iteration - 1) * ThreadBufferNumber * thread_copy_fwd_step_m_k; threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, thread_copy_bwd_step_m_k); - threadwise_gamma_load.MoveSrcSliceWindow(gamma_grid_desc_m_k, thread_copy_bwd_step_m_k); - threadwise_beta_load.MoveSrcSliceWindow(beta_grid_desc_m_k, thread_copy_bwd_step_m_k); - threadwise_y_store.MoveDstSliceWindow(y_grid_desc_m_k, thread_copy_bwd_step_m_k); + threadwise_gamma_load.MoveSrcSliceWindow(gamma_grid_desc_m_k, thread_copy_tail_m_k); + threadwise_beta_load.MoveSrcSliceWindow(beta_grid_desc_m_k, thread_copy_tail_m_k); + threadwise_y_store.MoveDstSliceWindow(y_grid_desc_m_k, thread_copy_tail_m_k); - ++reducedTiles; - } while(reducedTiles < num_k_block_tile_iteration); + for(index_t reducedTiles = 0; reducedTiles < num_k_block_tile_iteration; ++reducedTiles) + { + static_for<0, ThreadBufferNumber, 1>{}([&](auto i) { + threadwise_x_load.Run(x_grid_desc_m_k, + x_global_val_buf, + thread_buffer_desc_m_k, + make_tuple(I0, I0), + x_thread_buf(i)); + threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, thread_copy_fwd_step_m_k); + }); + + static_for<0, ThreadBufferNumber, 1>{}([&](auto i) { + threadwise_gamma_load.Run(gamma_grid_desc_m_k, + gamma_global_val_buf, + thread_buffer_desc_m_k, + make_tuple(I0, I0), + gamma_thread_buf(i)); + + threadwise_gamma_load.MoveSrcSliceWindow(gamma_grid_desc_m_k, + thread_copy_fwd_step_m_k); + }); + + static_for<0, MThreadSliceSize, 1>{}([&](auto iM) { + auto divisor = 1 / ck::math::sqrt(var_thread_buf(iM) + epsilon); + static_for<0, ThreadBufferNumber, 1>{}([&](auto iK0) { + static_for<0, XSrcVectorSize, 1>{}([&](auto iK1) { + constexpr auto offset_m_k = + thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK1)); + + // normalize + y_thread_buf(iK0)(Number{}) = + (x_thread_buf(iK0)(Number{}) - mean_thread_buf(iM)) * + divisor; + + // gamma + y_thread_buf(iK0)(Number{}) = + y_thread_buf(iK0)(Number{}) * + gamma_thread_buf(iK0)(Number{}); + }); + }); + }); + + static_for<0, ThreadBufferNumber, 1>{}([&](auto i) { + threadwise_beta_load.Run(beta_grid_desc_m_k, + beta_global_val_buf, + thread_buffer_desc_m_k, + 
make_tuple(I0, I0), + beta_thread_buf(i)); + threadwise_beta_load.MoveSrcSliceWindow(beta_grid_desc_m_k, + thread_copy_fwd_step_m_k); + }); + + static_for<0, MThreadSliceSize, 1>{}([&](auto iM) { + static_for<0, ThreadBufferNumber, 1>{}([&](auto iK0) { + static_for<0, XSrcVectorSize, 1>{}([&](auto iK1) { + constexpr auto offset_m_k = + thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK1)); + + // beta + y_thread_buf(iK0)(Number{}) = + y_thread_buf(iK0)(Number{}) + + beta_thread_buf(iK0)(Number{}); + }); + }); + }); + + static_for<0, ThreadBufferNumber, 1>{}([&](auto i) { + threadwise_y_store.Run(thread_buffer_desc_m_k, + make_tuple(I0, I0), + y_thread_buf(i), + y_grid_desc_m_k, + y_global_val_buf); + threadwise_y_store.MoveDstSliceWindow(y_grid_desc_m_k, + thread_copy_fwd_step_m_k); + }); + + threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, 2 * thread_copy_bwd_step_m_k); + threadwise_gamma_load.MoveSrcSliceWindow(gamma_grid_desc_m_k, + 2 * thread_copy_bwd_step_m_k); + threadwise_beta_load.MoveSrcSliceWindow(beta_grid_desc_m_k, + 2 * thread_copy_bwd_step_m_k); + threadwise_y_store.MoveDstSliceWindow(y_grid_desc_m_k, + 2 * thread_copy_bwd_step_m_k); + } + } // end of sweep twice } }; diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_normalization_selector.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_normalization_selector.hpp new file mode 100644 index 0000000000..37795fa569 --- /dev/null +++ b/include/ck/tensor_operation/gpu/grid/gridwise_normalization_selector.hpp @@ -0,0 +1,195 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/tensor_operation/gpu/grid/gridwise_normalization_naive_variance.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_normalization_welford_variance.hpp" + +namespace ck { +template +__global__ void kernel_normalization(const GridDesc_M_K x_grid_desc_m_k, + const GridDesc_M_K gamma_grid_desc_m_k, + const GridDesc_M_K beta_grid_desc_m_k, + const GridDesc_M_K y_grid_desc_m_k, + index_t num_k_block_tile_iteration, + ComputeDataType epsilon, + const XDataType* const __restrict__ p_x_global, + const GammaDataType* const __restrict__ p_gamma_global, + const BetaDataType* const __restrict__ p_beta_global, + YDataType* const __restrict__ p_y_global, + const YElementwiseOperation y_elementwise_op) +{ + GridwiseReduction::Run(x_grid_desc_m_k, + gamma_grid_desc_m_k, + beta_grid_desc_m_k, + y_grid_desc_m_k, + num_k_block_tile_iteration, + epsilon, + p_x_global, + p_gamma_global, + p_beta_global, + p_y_global, + y_elementwise_op); +}; + +template +auto NormalizationKernelSelector(bool isSweepOnce) +{ + using GridwiseNormalizationGenericNaive = + GridwiseNormalizationNaiveVariance_mk_to_mk; + using GridwiseNormalizationSweepOnceNaive = + GridwiseNormalizationNaiveVariance_mk_to_mk; + using GridwiseNormalizationGenericWelford = + GridwiseNormalizationWelfordVariance_mk_to_mk; + using GridwiseNormalizationSweepOnceWelford = + GridwiseNormalizationWelfordVariance_mk_to_mk; + + if constexpr(UseWelford) + { + return isSweepOnce ? kernel_normalization + : kernel_normalization; + } + else + { + return isSweepOnce ? 
kernel_normalization + : kernel_normalization; + } +} + +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_normalization_welford_variance.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_normalization_welford_variance.hpp index 70a8c020dd..3a7ae459e5 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_normalization_welford_variance.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_normalization_welford_variance.hpp @@ -16,8 +16,8 @@ template ; @@ -56,15 +60,19 @@ struct GridwiseNormalizationWelfordVariance_mk_to_mk static constexpr auto thread_cluster_desc = make_cluster_descriptor(ThreadClusterLengths_M_K{}, ThreadClusterArrangeOrder{}); + using ThreadBufferLengths_M_K = Sequence; + static constexpr auto thread_buffer_desc_m_k = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, Number{})); + using ThreadReduceSrcDesc_M_K = decltype(make_naive_tensor_descriptor_packed( make_tuple(Number{}, Number{}))); using ThreadReduceDstDesc_M = decltype(make_naive_tensor_descriptor_packed(make_tuple(Number{}))); using ThreadwiseWelford = - ThreadwiseWelford; + ThreadwiseWelford; - using BlockwiseWelford = BlockwiseWelford; @@ -77,10 +85,7 @@ struct GridwiseNormalizationWelfordVariance_mk_to_mk static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize; static constexpr index_t K_BlockTileStepSize = KThreadClusterSize * XSrcVectorSize; - static constexpr auto XThreadBufferNumber = Number{}; - static constexpr auto GammaThreadBufferNumber = Number{}; - static constexpr auto BetaThreadBufferNumber = Number{}; - static constexpr auto YThreadBufferNumber = Number{}; + static constexpr auto ThreadBufferNumber = Number{}; __device__ static int GetKPerThread(const GridDesc_M_K& x_grid_desc_m_k, int thread_k_cluster_id) @@ -93,7 +98,7 @@ struct GridwiseNormalizationWelfordVariance_mk_to_mk if(kPerBlockTail > 0) { - static_for<0, XThreadBufferNumber, 1>{}([&](auto i) { + static_for<0, ThreadBufferNumber, 1>{}([&](auto i) { int thread_max_len = (thread_k_cluster_id + 1) * XSrcVectorSize + K_BlockTileStepSize * i; int delta = thread_max_len - kPerBlockTail; @@ -110,59 +115,41 @@ struct GridwiseNormalizationWelfordVariance_mk_to_mk const GridDesc_M_K& beta_grid_desc_m_k, const GridDesc_M_K& y_grid_desc_m_k, index_t num_k_block_tile_iteration, - AccDataType epsilon, + ComputeDataType epsilon, const XDataType* const __restrict__ p_x_global, const GammaDataType* const __restrict__ p_gamma_global, const BetaDataType* const __restrict__ p_beta_global, YDataType* const __restrict__ p_y_global, - const AccElementwiseOperation acc_elementwise_op) + const YElementwiseOperation y_elementwise_op) { - if constexpr(SweepOnce) - { - num_k_block_tile_iteration = 1; - } - auto y_global_val_buf = make_dynamic_buffer( p_y_global, y_grid_desc_m_k.GetElementSpaceSize()); auto x_thread_buf = generate_tuple( [&](auto) { return StaticBuffer{}; }, - Number{}); + Number{}); auto gamma_thread_buf = generate_tuple( [&](auto) { return StaticBuffer{}; }, - Number{}); + Number{}); - auto beta_thread_buf = generate_tuple( - [&](auto) { - return StaticBuffer{}; - }, - Number{}); + auto& beta_thread_buf = gamma_thread_buf; + auto& y_thread_buf = x_thread_buf; - auto y_thread_buf = generate_tuple( - [&](auto) { - return StaticBuffer{}; - }, - Number{}); - - StaticBuffer mean_thread_buf; - StaticBuffer var_thread_buf; + StaticBuffer + mean_thread_buf; + StaticBuffer + var_thread_buf; const index_t thread_local_id = get_thread_local_1d_id(); const index_t block_global_id = 
get_block_1d_id(); @@ -173,12 +160,8 @@ struct GridwiseNormalizationWelfordVariance_mk_to_mk const auto thread_m_cluster_id = thread_cluster_idx[I0]; const auto thread_k_cluster_id = thread_cluster_idx[I1]; - using ThreadBufferLengths_M_K = Sequence; - constexpr auto thread_buffer_desc_m_k = make_naive_tensor_descriptor_packed( - make_tuple(Number{}, Number{})); - auto threadwise_x_load = ThreadwiseTensorSliceTransfer_v2{}([&](auto I) { - mean_thread_buf(I) = type_convert(0.0f); - var_thread_buf(I) = type_convert(0.0f); + mean_thread_buf(I) = type_convert(0.0f); + var_thread_buf(I) = type_convert(0.0f); }); - for(index_t reducedTiles = 0; reducedTiles < num_k_block_tile_iteration; ++reducedTiles) + // Separate sweep once and sweep twice pipeline + if constexpr(SweepOnce) { - static_for<0, XThreadBufferNumber, 1>{}([&](auto i) { + static_for<0, ThreadBufferNumber, 1>{}([&](auto i) { threadwise_x_load.Run(x_grid_desc_m_k, x_global_val_buf, thread_buffer_desc_m_k, make_tuple(I0, I0), x_thread_buf(i)); - threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, thread_copy_fwd_step_m_k); - threadwise_welford.Run(x_thread_buf[i], mean_thread_buf, var_thread_buf); - }); - } - static_for<0, MThreadSliceSize, 1>{}([&](auto I) { - if constexpr(I > 0) - block_sync_lds(); - - int count = threadwise_welford.cur_count_; - BlockwiseWelford::Run(mean_thread_buf(I), var_thread_buf(I), count); - }); - - auto thread_copy_tail_m_k = - (num_k_block_tile_iteration - 1) * XThreadBufferNumber * thread_copy_fwd_step_m_k; - - threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, thread_copy_bwd_step_m_k); - threadwise_gamma_load.MoveSrcSliceWindow(gamma_grid_desc_m_k, thread_copy_tail_m_k); - threadwise_beta_load.MoveSrcSliceWindow(beta_grid_desc_m_k, thread_copy_tail_m_k); - threadwise_y_store.MoveDstSliceWindow(y_grid_desc_m_k, thread_copy_tail_m_k); - - for(index_t reducedTiles = 0; reducedTiles < num_k_block_tile_iteration; ++reducedTiles) - { - if constexpr(!SweepOnce) - { - static_for<0, XThreadBufferNumber, 1>{}([&](auto i) { - threadwise_x_load.Run(x_grid_desc_m_k, - x_global_val_buf, - thread_buffer_desc_m_k, - make_tuple(I0, I0), - x_thread_buf(i)); - threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, thread_copy_fwd_step_m_k); - }); - } - - static_for<0, GammaThreadBufferNumber, 1>{}([&](auto i) { threadwise_gamma_load.Run(gamma_grid_desc_m_k, gamma_global_val_buf, thread_buffer_desc_m_k, make_tuple(I0, I0), gamma_thread_buf(i)); - threadwise_gamma_load.MoveSrcSliceWindow(gamma_grid_desc_m_k, - thread_copy_fwd_step_m_k); + threadwise_welford.Run(x_thread_buf[i], mean_thread_buf, var_thread_buf); + + if constexpr(i != ThreadBufferNumber - 1) + { + threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, thread_copy_fwd_step_m_k); + threadwise_gamma_load.MoveSrcSliceWindow(gamma_grid_desc_m_k, + thread_copy_fwd_step_m_k); + } + }); + + static_for<0, MThreadSliceSize, 1>{}([&](auto I) { + if constexpr(I > 0) + block_sync_lds(); + + int count = threadwise_welford.cur_count_; + BlockwiseWelford::Run(mean_thread_buf(I), var_thread_buf(I), count); }); static_for<0, MThreadSliceSize, 1>{}([&](auto iM) { auto divisor = 1 / ck::math::sqrt(var_thread_buf(iM) + epsilon); - static_for<0, XThreadBufferNumber, 1>{}([&](auto iK0) { + static_for<0, ThreadBufferNumber, 1>{}([&](auto iK0) { static_for<0, XSrcVectorSize, 1>{}([&](auto iK1) { constexpr auto offset_m_k = thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK1)); @@ -330,7 +293,7 @@ struct GridwiseNormalizationWelfordVariance_mk_to_mk 
(x_thread_buf(iK0)(Number{}) - mean_thread_buf(iM)) * divisor; - // gamma + // gamma & beta y_thread_buf(iK0)(Number{}) = y_thread_buf(iK0)(Number{}) * gamma_thread_buf(iK0)(Number{}); @@ -338,18 +301,20 @@ struct GridwiseNormalizationWelfordVariance_mk_to_mk }); }); - static_for<0, BetaThreadBufferNumber, 1>{}([&](auto i) { + static_for<0, ThreadBufferNumber, 1>{}([&](auto i) { threadwise_beta_load.Run(beta_grid_desc_m_k, beta_global_val_buf, thread_buffer_desc_m_k, make_tuple(I0, I0), beta_thread_buf(i)); - threadwise_beta_load.MoveSrcSliceWindow(beta_grid_desc_m_k, - thread_copy_fwd_step_m_k); + + if constexpr(i != ThreadBufferNumber - 1) + threadwise_beta_load.MoveSrcSliceWindow(beta_grid_desc_m_k, + thread_copy_fwd_step_m_k); }); static_for<0, MThreadSliceSize, 1>{}([&](auto iM) { - static_for<0, XThreadBufferNumber, 1>{}([&](auto iK0) { + static_for<0, ThreadBufferNumber, 1>{}([&](auto iK0) { static_for<0, XSrcVectorSize, 1>{}([&](auto iK1) { constexpr auto offset_m_k = thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK1)); @@ -362,22 +327,134 @@ struct GridwiseNormalizationWelfordVariance_mk_to_mk }); }); - static_for<0, YThreadBufferNumber, 1>{}([&](auto i) { + static_for<0, ThreadBufferNumber, 1>{}([&](auto i) { threadwise_y_store.Run(thread_buffer_desc_m_k, make_tuple(I0, I0), y_thread_buf(i), y_grid_desc_m_k, y_global_val_buf); - threadwise_y_store.MoveDstSliceWindow(y_grid_desc_m_k, thread_copy_fwd_step_m_k); + + if constexpr(i != ThreadBufferNumber - 1) + threadwise_y_store.MoveDstSliceWindow(y_grid_desc_m_k, + thread_copy_fwd_step_m_k); + }); + } // end of sweep once + else + { + for(index_t reducedTiles = 0; reducedTiles < num_k_block_tile_iteration; ++reducedTiles) + { + static_for<0, ThreadBufferNumber, 1>{}([&](auto i) { + threadwise_x_load.Run(x_grid_desc_m_k, + x_global_val_buf, + thread_buffer_desc_m_k, + make_tuple(I0, I0), + x_thread_buf(i)); + threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, thread_copy_fwd_step_m_k); + threadwise_welford.Run(x_thread_buf[i], mean_thread_buf, var_thread_buf); + }); + } + + static_for<0, MThreadSliceSize, 1>{}([&](auto I) { + if constexpr(I > 0) + block_sync_lds(); + + int count = threadwise_welford.cur_count_; + BlockwiseWelford::Run(mean_thread_buf(I), var_thread_buf(I), count); }); - threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, 2 * thread_copy_bwd_step_m_k); - threadwise_gamma_load.MoveSrcSliceWindow(gamma_grid_desc_m_k, - 2 * thread_copy_bwd_step_m_k); - threadwise_beta_load.MoveSrcSliceWindow(beta_grid_desc_m_k, - 2 * thread_copy_bwd_step_m_k); - threadwise_y_store.MoveDstSliceWindow(y_grid_desc_m_k, 2 * thread_copy_bwd_step_m_k); - } + auto thread_copy_tail_m_k = + (num_k_block_tile_iteration - 1) * ThreadBufferNumber * thread_copy_fwd_step_m_k; + + threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, thread_copy_bwd_step_m_k); + threadwise_gamma_load.MoveSrcSliceWindow(gamma_grid_desc_m_k, thread_copy_tail_m_k); + threadwise_beta_load.MoveSrcSliceWindow(beta_grid_desc_m_k, thread_copy_tail_m_k); + threadwise_y_store.MoveDstSliceWindow(y_grid_desc_m_k, thread_copy_tail_m_k); + + for(index_t reducedTiles = 0; reducedTiles < num_k_block_tile_iteration; ++reducedTiles) + { + static_for<0, ThreadBufferNumber, 1>{}([&](auto i) { + threadwise_x_load.Run(x_grid_desc_m_k, + x_global_val_buf, + thread_buffer_desc_m_k, + make_tuple(I0, I0), + x_thread_buf(i)); + threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, thread_copy_fwd_step_m_k); + }); + + static_for<0, ThreadBufferNumber, 1>{}([&](auto i) { + 
threadwise_gamma_load.Run(gamma_grid_desc_m_k, + gamma_global_val_buf, + thread_buffer_desc_m_k, + make_tuple(I0, I0), + gamma_thread_buf(i)); + + threadwise_gamma_load.MoveSrcSliceWindow(gamma_grid_desc_m_k, + thread_copy_fwd_step_m_k); + }); + + static_for<0, MThreadSliceSize, 1>{}([&](auto iM) { + auto divisor = 1 / ck::math::sqrt(var_thread_buf(iM) + epsilon); + static_for<0, ThreadBufferNumber, 1>{}([&](auto iK0) { + static_for<0, XSrcVectorSize, 1>{}([&](auto iK1) { + constexpr auto offset_m_k = + thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK1)); + + // normalize + y_thread_buf(iK0)(Number{}) = + (x_thread_buf(iK0)(Number{}) - mean_thread_buf(iM)) * + divisor; + + // gamma + y_thread_buf(iK0)(Number{}) = + y_thread_buf(iK0)(Number{}) * + gamma_thread_buf(iK0)(Number{}); + }); + }); + }); + + static_for<0, ThreadBufferNumber, 1>{}([&](auto i) { + threadwise_beta_load.Run(beta_grid_desc_m_k, + beta_global_val_buf, + thread_buffer_desc_m_k, + make_tuple(I0, I0), + beta_thread_buf(i)); + threadwise_beta_load.MoveSrcSliceWindow(beta_grid_desc_m_k, + thread_copy_fwd_step_m_k); + }); + + static_for<0, MThreadSliceSize, 1>{}([&](auto iM) { + static_for<0, ThreadBufferNumber, 1>{}([&](auto iK0) { + static_for<0, XSrcVectorSize, 1>{}([&](auto iK1) { + constexpr auto offset_m_k = + thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK1)); + + // beta + y_thread_buf(iK0)(Number{}) = + y_thread_buf(iK0)(Number{}) + + beta_thread_buf(iK0)(Number{}); + }); + }); + }); + + static_for<0, ThreadBufferNumber, 1>{}([&](auto i) { + threadwise_y_store.Run(thread_buffer_desc_m_k, + make_tuple(I0, I0), + y_thread_buf(i), + y_grid_desc_m_k, + y_global_val_buf); + threadwise_y_store.MoveDstSliceWindow(y_grid_desc_m_k, + thread_copy_fwd_step_m_k); + }); + + threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, 2 * thread_copy_bwd_step_m_k); + threadwise_gamma_load.MoveSrcSliceWindow(gamma_grid_desc_m_k, + 2 * thread_copy_bwd_step_m_k); + threadwise_beta_load.MoveSrcSliceWindow(beta_grid_desc_m_k, + 2 * thread_copy_bwd_step_m_k); + threadwise_y_store.MoveDstSliceWindow(y_grid_desc_m_k, + 2 * thread_copy_bwd_step_m_k); + } + } // end of sweep twice } }; diff --git a/include/ck/utility/math_v2.hpp b/include/ck/utility/math_v2.hpp index 4aba0b1192..4febace0b8 100644 --- a/include/ck/utility/math_v2.hpp +++ b/include/ck/utility/math_v2.hpp @@ -83,6 +83,11 @@ static inline __host__ bool isnan(int4_t x) }; #endif +static inline __host__ half_t sqrt(half_t x) +{ + return static_cast(std::sqrt(static_cast(x))); +}; + static inline __host__ float sqrt(float x) { return std::sqrt(x); }; static inline __host__ double sqrt(double x) { return std::sqrt(x); }; @@ -158,6 +163,11 @@ static inline __device__ bool isnan(half_t x) return (xx & 0x7FFF) > 0x7C00; }; +static inline __device__ half_t sqrt(half_t x) +{ + return static_cast(__builtin_amdgcn_sqrtf(static_cast(x))); +}; + static inline __device__ float sqrt(float x) { return __builtin_amdgcn_sqrtf(x); }; static inline __device__ double sqrt(double x) { return __builtin_amdgcn_sqrt(x); }; diff --git a/library/src/tensor_operation_instance/gpu/normalization/device_normalization_f16_instance.cpp b/library/src/tensor_operation_instance/gpu/normalization/device_normalization_f16_instance.cpp index 8994d9dcb6..beeaa3aa22 100644 --- a/library/src/tensor_operation_instance/gpu/normalization/device_normalization_f16_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/normalization/device_normalization_f16_instance.cpp @@ -21,20 +21,25 @@ template // 
clang-format off using device_normalization_f16_instances = std::tuple < - // XDataType, GammaDataType, BetaDataType, AccDataType, YDataType, Rank, NumReduceDim, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, XYSrcVectorDim, XSrcVectorSize, GammaSrcVectorDim, GammaSrcVectorSize, BetaSrcVectorDim, BetaSrcVectorSize, YDstVectorSize> - DeviceNormalizationImpl, // fallback kernel - DeviceNormalizationImpl, // fallback kernel - DeviceNormalizationImpl, // fallback kernel - DeviceNormalizationImpl, - DeviceNormalizationImpl, - DeviceNormalizationImpl, - DeviceNormalizationImpl, - DeviceNormalizationImpl, + // XDataType, GammaDataType, BetaDataType, ComputeDataType, YDataType, Rank, NumReduceDim, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, XYSrcVectorDim, XSrcVectorSize, GammaSrcVectorDim, GammaSrcVectorSize, BetaSrcVectorDim, BetaSrcVectorSize, YDstVectorSize> + DeviceNormalizationImpl, // irregular size + DeviceNormalizationImpl, // irregular size + DeviceNormalizationImpl, // irregular size + DeviceNormalizationImpl, // irregular size + DeviceNormalizationImpl, // irregular size + DeviceNormalizationImpl, // irregular size + DeviceNormalizationImpl, + DeviceNormalizationImpl, + DeviceNormalizationImpl, + DeviceNormalizationImpl, DeviceNormalizationImpl, DeviceNormalizationImpl, + DeviceNormalizationImpl, DeviceNormalizationImpl, - DeviceNormalizationImpl, - DeviceNormalizationImpl + DeviceNormalizationImpl, + DeviceNormalizationImpl, + DeviceNormalizationImpl, + DeviceNormalizationImpl >; // clang-format on diff --git a/library/src/tensor_operation_instance/gpu/normalization/device_normalization_f32_instance.cpp b/library/src/tensor_operation_instance/gpu/normalization/device_normalization_f32_instance.cpp index 4a7e1fd0b9..4d236fb633 100644 --- a/library/src/tensor_operation_instance/gpu/normalization/device_normalization_f32_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/normalization/device_normalization_f32_instance.cpp @@ -19,17 +19,26 @@ using Pass = ck::tensor_operation::element_wise::PassThrough; template using device_layernorm_f32_instances = std::tuple< // clang-format off - // XDataType, GammaDataType, BetaDataType, AccDataType, YDataType, Rank, NumReduceDim, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, XYSrcVectorDim, XSrcVectorSize, GammaSrcVectorSize, BetaSrcVectorSize, YDstVectorSize> - DeviceNormalizationImpl, // fallback kernel - DeviceNormalizationImpl, // fallback kernel - DeviceNormalizationImpl, - DeviceNormalizationImpl, - DeviceNormalizationImpl, - DeviceNormalizationImpl, - DeviceNormalizationImpl, + // XDataType, GammaDataType, BetaDataType, ComputeDataType, YDataType, Rank, NumReduceDim, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, XYSrcVectorDim, XSrcVectorSize, GammaSrcVectorSize, BetaSrcVectorSize, YDstVectorSize> + DeviceNormalizationImpl, // irregular size + DeviceNormalizationImpl, // irregular size + DeviceNormalizationImpl, // irregular size + DeviceNormalizationImpl, // irregular size + DeviceNormalizationImpl, // irregular size + DeviceNormalizationImpl, + DeviceNormalizationImpl, + DeviceNormalizationImpl, + DeviceNormalizationImpl, + DeviceNormalizationImpl, DeviceNormalizationImpl, DeviceNormalizationImpl, - DeviceNormalizationImpl + DeviceNormalizationImpl, + DeviceNormalizationImpl, + DeviceNormalizationImpl, + DeviceNormalizationImpl, + DeviceNormalizationImpl, + 
DeviceNormalizationImpl, + DeviceNormalizationImpl // clang-format on >; diff --git a/profiler/include/profiler/profile_layernorm_impl.hpp b/profiler/include/profiler/profile_layernorm_impl.hpp index eb21d4a586..7dd90d0797 100644 --- a/profiler/include/profiler/profile_layernorm_impl.hpp +++ b/profiler/include/profiler/profile_layernorm_impl.hpp @@ -19,7 +19,7 @@ namespace profiler { template bool profile_layernorm_impl(int do_verification, @@ -86,7 +86,7 @@ bool profile_layernorm_impl(int do_verification, using DeviceOp = ck::tensor_operation::device::DeviceNormalization; @@ -181,8 +181,8 @@ bool profile_layernorm_impl(int do_verification, { y_dev.FromDevice(y.mData.data()); - bool pass = ck::utils::check_err( - y.mData, host_y.mData, "Error: Incorrect results d1", 1e-3, 1e-3); + bool pass = + ck::utils::check_err(y.mData, host_y.mData, "Error: Incorrect results", 1e-3, 1e-3); if(do_log) { diff --git a/test/normalization/test_groupnorm_fp16.cpp b/test/normalization/test_groupnorm_fp16.cpp index 636e522dce..60d3b13959 100644 --- a/test/normalization/test_groupnorm_fp16.cpp +++ b/test/normalization/test_groupnorm_fp16.cpp @@ -12,11 +12,11 @@ template class TestGroupnorm : public ::testing::Test { protected: - using XDataType = std::tuple_element_t<0, Tuple>; - using GammaDataType = std::tuple_element_t<1, Tuple>; - using BetaDataType = std::tuple_element_t<2, Tuple>; - using AccDataType = std::tuple_element_t<3, Tuple>; - using YDataType = std::tuple_element_t<4, Tuple>; + using XDataType = std::tuple_element_t<0, Tuple>; + using GammaDataType = std::tuple_element_t<1, Tuple>; + using BetaDataType = std::tuple_element_t<2, Tuple>; + using ComputeDataType = std::tuple_element_t<3, Tuple>; + using YDataType = std::tuple_element_t<4, Tuple>; void Run() { @@ -36,7 +36,7 @@ class TestGroupnorm : public ::testing::Test ck::profiler::profile_groupnorm_impl(true, 2, false, false, length); EXPECT_TRUE(success); } @@ -44,7 +44,7 @@ class TestGroupnorm : public ::testing::Test }; using KernelTypes = ::testing::Types< - // XDataType, GammaDataType, BetaDataType, AccDataType, YDataType> + // XDataType, GammaDataType, BetaDataType, ComputeDataType, YDataType> std::tuple>; TYPED_TEST_SUITE(TestGroupnorm, KernelTypes); diff --git a/test/normalization/test_groupnorm_fp32.cpp b/test/normalization/test_groupnorm_fp32.cpp index ef492664bf..3542f73a62 100644 --- a/test/normalization/test_groupnorm_fp32.cpp +++ b/test/normalization/test_groupnorm_fp32.cpp @@ -12,11 +12,11 @@ template class TestGroupnorm : public ::testing::Test { protected: - using XDataType = std::tuple_element_t<0, Tuple>; - using GammaDataType = std::tuple_element_t<1, Tuple>; - using BetaDataType = std::tuple_element_t<2, Tuple>; - using AccDataType = std::tuple_element_t<3, Tuple>; - using YDataType = std::tuple_element_t<4, Tuple>; + using XDataType = std::tuple_element_t<0, Tuple>; + using GammaDataType = std::tuple_element_t<1, Tuple>; + using BetaDataType = std::tuple_element_t<2, Tuple>; + using ComputeDataType = std::tuple_element_t<3, Tuple>; + using YDataType = std::tuple_element_t<4, Tuple>; void Run() { @@ -34,7 +34,7 @@ class TestGroupnorm : public ::testing::Test ck::profiler::profile_groupnorm_impl(true, 2, false, false, length); EXPECT_TRUE(success); } @@ -42,7 +42,7 @@ class TestGroupnorm : public ::testing::Test }; using KernelTypes = ::testing::Types< - // XDataType, GammaDataType, BetaDataType, AccDataType, YDataType> + // XDataType, GammaDataType, BetaDataType, ComputeDataType, YDataType> std::tuple>; 
TYPED_TEST_SUITE(TestGroupnorm, KernelTypes); diff --git a/test/normalization/test_layernorm2d_fp16.cpp b/test/normalization/test_layernorm2d_fp16.cpp index eeb8ec150a..d627cbe7f1 100644 --- a/test/normalization/test_layernorm2d_fp16.cpp +++ b/test/normalization/test_layernorm2d_fp16.cpp @@ -12,11 +12,11 @@ template class TestLayernorm2d : public ::testing::Test { protected: - using XDataType = std::tuple_element_t<0, Tuple>; - using GammaDataType = std::tuple_element_t<1, Tuple>; - using BetaDataType = std::tuple_element_t<2, Tuple>; - using AccDataType = std::tuple_element_t<3, Tuple>; - using YDataType = std::tuple_element_t<4, Tuple>; + using XDataType = std::tuple_element_t<0, Tuple>; + using GammaDataType = std::tuple_element_t<1, Tuple>; + using BetaDataType = std::tuple_element_t<2, Tuple>; + using ComputeDataType = std::tuple_element_t<3, Tuple>; + using YDataType = std::tuple_element_t<4, Tuple>; void Run() { @@ -29,7 +29,7 @@ class TestLayernorm2d : public ::testing::Test bool success = ck::profiler::profile_layernorm_impl(true, 2, false, false, length); EXPECT_TRUE(success); @@ -38,7 +38,7 @@ class TestLayernorm2d : public ::testing::Test }; using KernelTypes = ::testing::Types< - // XDataType, GammaDataType, BetaDataType, AccDataType, YDataType> + // XDataType, GammaDataType, BetaDataType, ComputeDataType, YDataType> std::tuple>; TYPED_TEST_SUITE(TestLayernorm2d, KernelTypes); diff --git a/test/normalization/test_layernorm2d_fp32.cpp b/test/normalization/test_layernorm2d_fp32.cpp index f555b42592..de4133aa83 100644 --- a/test/normalization/test_layernorm2d_fp32.cpp +++ b/test/normalization/test_layernorm2d_fp32.cpp @@ -12,11 +12,11 @@ template class TestLayernorm2d : public ::testing::Test { protected: - using XDataType = std::tuple_element_t<0, Tuple>; - using GammaDataType = std::tuple_element_t<1, Tuple>; - using BetaDataType = std::tuple_element_t<2, Tuple>; - using AccDataType = std::tuple_element_t<3, Tuple>; - using YDataType = std::tuple_element_t<4, Tuple>; + using XDataType = std::tuple_element_t<0, Tuple>; + using GammaDataType = std::tuple_element_t<1, Tuple>; + using BetaDataType = std::tuple_element_t<2, Tuple>; + using ComputeDataType = std::tuple_element_t<3, Tuple>; + using YDataType = std::tuple_element_t<4, Tuple>; void Run() { @@ -29,7 +29,7 @@ class TestLayernorm2d : public ::testing::Test bool success = ck::profiler::profile_layernorm_impl(true, 2, false, false, length); EXPECT_TRUE(success); @@ -38,7 +38,7 @@ class TestLayernorm2d : public ::testing::Test }; using KernelTypes = ::testing::Types< - // XDataType, GammaDataType, BetaDataType, AccDataType, YDataType> + // XDataType, GammaDataType, BetaDataType, ComputeDataType, YDataType> std::tuple>; TYPED_TEST_SUITE(TestLayernorm2d, KernelTypes);
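Note on the variance math above: the PR keeps the naive-variance gridwise kernel (`var(x) = E[x^2] - E[x]^2`, as the inline comments show) and threads a new `UseWelford` template switch, defaulting to `true`, through `DeviceNormalizationImpl`, so `NormalizationKernelSelector` can pick between the Welford and naive kernels, each in a sweep-once and a generic flavor. Both pipelines also hoist `divisor = 1 / ck::math::sqrt(var + epsilon)` out of the per-element loop, replacing a per-element divide with a multiply. The host-side sketch below is illustrative only (plain C++, not CK code; the function names are ours) and shows why the Welford path is the safer default: the naive form subtracts two nearly equal accumulators and can lose precision when the mean is large relative to the variance, while Welford's one-pass update stays stable and its partial results merge across threads, which is what the blockwise variant in `gridwise_normalization_welford_variance.hpp` relies on.

```cpp
// Minimal scalar sketch of the two variance formulations, assuming a
// non-empty input; these are stand-ins for the threadwise/blockwise
// reductions in the CK kernels, not the kernels themselves.
#include <vector>

// Naive two-accumulator form (gridwise_normalization_naive_variance.hpp):
// accumulate sum(x) and sum(x^2), then var(x) = E[x^2] - E[x]^2.
float naive_variance(const std::vector<float>& x)
{
    float sum = 0.f, sum_sq = 0.f;
    for(float v : x)
    {
        sum += v;
        sum_sq += v * v;
    }
    const float mean        = sum / x.size();
    const float mean_square = sum_sq / x.size();
    return mean_square - mean * mean; // cancellation-prone for large means
}

// Welford's online form (gridwise_normalization_welford_variance.hpp):
// one pass, numerically stable, partial results mergeable across threads.
float welford_variance(const std::vector<float>& x)
{
    float mean  = 0.f, m2 = 0.f;
    int   count = 0;
    for(float v : x)
    {
        ++count;
        const float delta = v - mean;
        mean += delta / count;
        m2 += delta * (v - mean); // note: uses the updated mean
    }
    return m2 / count; // population variance, matching the naive form
}
```

With fp16 tensors and `ComputeDataType = float` the difference is usually in the noise, but defaulting `UseWelford` to `true` keeps the stable path unless an instance opts out.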
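The other structural change in both gridwise kernels is the explicit `if constexpr(SweepOnce)` split: instead of clamping `num_k_block_tile_iteration` to 1, the sweep-once pipeline loads x and gamma a single time and reuses the register buffers for both the statistics and the normalize/store phases (the `i != ThreadBufferNumber - 1` guards skip the final, unused slice-window move), while the generic pipeline reads the input twice. On the host side, `Argument` selects this via `isSweeponce_`, comparing the product of the reduce-dimension lengths against one K block tile. A hedged restatement of that predicate follows; the cluster/slice constants are example values, not the tuned instances from the instance files.

```cpp
// Illustrative re-statement of DeviceNormalizationImpl's sweep-once test.
#include <functional>
#include <numeric>
#include <vector>

constexpr int KThreadClusterSize = 8;
constexpr int KThreadSliceSize   = 32;
constexpr int K_BlockTileSize    = KThreadClusterSize * KThreadSliceSize; // 256

// 'lengths' already has the reduce dims shuffled to the back, as
// shuffle_tensor_dimensions does in Argument's constructor.
bool is_sweep_once(const std::vector<long>& lengths, int num_reduce_dim)
{
    const long reduce_len = std::accumulate(lengths.end() - num_reduce_dim,
                                            lengths.end(),
                                            1L,
                                            std::multiplies<long>());
    // One block tile covers the whole reduced row, so statistics and
    // normalization can both run from a single global read.
    return reduce_len <= K_BlockTileSize;
}
```

This is also why the selector takes `isSweepOnce` as a runtime bool: the same `DeviceNormalizationImpl` instance serves both shapes, and only the kernel entry point changes.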