From 9abc62c4826da16885e558310dcdddd053c2e0c3 Mon Sep 17 00:00:00 2001
From: Qianfeng
Date: Fri, 28 Oct 2022 08:52:54 +0800
Subject: [PATCH] Batchnorm-forward implemented using welford method to calculate variance (#403)

* Update to the batchnorm-forward API and base class
* Fix leaked header include in gridwise_set_buffer_value.hpp
* Add kernels and device file for batchnorm-forward welford supporting both blockwise and multi-block reduction
* Update to the batchnorm-forward example to use the new batchnorm-forward device interface
* Change the batchnorm-forward reference to use the sequential welford method
* Change to assign the workspace into four buffers in the host layer
* Use GetReduceCountPerThread functor to replace the initial count for Blockwise and Multiblock welford
* Tiny correction and remove unused file under example/34_batchnorm
* Renaming in the kernel arguments
* Explicitly use ck::math::sqrt in batchnorm-forward kernels
* Add some comments to some kernels
* Tiny fix
* Generalize the data types in reference_batchnorm_forward_nhwc_c
* Use ck::ignore to mark unused parameters
* Move GetReduceCountPerThread functor codes from kernel to device
* Remove some unused codes in device_batchnorm_forward_impl.hpp
* Tiny fix in batchnorm_forward example
* Move GetReduceCountPerThread() to welford_helper.hpp
* Use separate data types for Scale and Bias
* Renaming in device Op
* Tiny fix in forward example
* Update to batchnorm-infer (type splitting, renaming)
* Add time and bandwidth measurement to the batchnorm-forward example
* Add support of elementwise operation for batchnorm forward output
* Reduce object copying by passing objects by reference
* Tiny change for performance
* Updates for performance again
* Some renamings
* Add GetActualVariance template parameter for ThreadwiseWelfordMerge
* Tiny update in reference batchnorm forward nhwc/c
* Move batchnorm multiblock kernel files to grid/batchnorm_multiblock sub-directory
* Fuse mean and bias in the normalization calculation

Co-authored-by: root
Co-authored-by: rocking5566

[ROCm/composable_kernel commit: 7fa892e63e63c541756d299ab00ea7e6a6d51c39]
---
 example/34_batchnorm/batchnorm_common.hpp | 127 +---
 .../34_batchnorm/batchnorm_forward_impl.hpp | 295 --------
 .../34_batchnorm/batchnorm_forward_nhwc.cpp | 311 +++++---
 example/34_batchnorm/batchnorm_infer_impl.hpp | 42 +-
 example/34_batchnorm/batchnorm_infer_nhwc.cpp | 56 +-
 .../gpu/device/device_batchnorm_forward.hpp | 21 +-
 .../gpu/device/device_batchnorm_infer.hpp | 4 +-
 .../impl/device_batchnorm_forward_impl.hpp | 711 ++++++++++++++++++
 .../gpu/device/welford_helper.hpp | 89 +++
 ...gridwise_multiblock_welford_first_half.hpp | 258 +++++++
 ...rd_second_half_batchnorm_forward_final.hpp | 570 ++++++++++++++
 ...se_batchnorm_forward_blockwise_welford.hpp | 482 ++++++++++++
 .../gpu/grid/gridwise_set_buffer_value.hpp | 1 +
 .../gpu/thread/threadwise_welford.hpp | 59 ++
 .../reference_batchnorm_forward_nhwc_c.hpp | 171 +++--
 .../cpu/reference_batchnorm_infer_nhwc_c.hpp | 91 ++-
 16 files changed, 2627 insertions(+), 661 deletions(-)
 delete mode 100644 example/34_batchnorm/batchnorm_forward_impl.hpp
 create mode 100644 include/ck/tensor_operation/gpu/device/impl/device_batchnorm_forward_impl.hpp
 create mode 100644 include/ck/tensor_operation/gpu/device/welford_helper.hpp
 create mode 100644 include/ck/tensor_operation/gpu/grid/batchnorm_multiblock/gridwise_multiblock_welford_first_half.hpp
 create mode 100644
include/ck/tensor_operation/gpu/grid/batchnorm_multiblock/gridwise_multiblock_welford_second_half_batchnorm_forward_final.hpp create mode 100644 include/ck/tensor_operation/gpu/grid/gridwise_batchnorm_forward_blockwise_welford.hpp diff --git a/example/34_batchnorm/batchnorm_common.hpp b/example/34_batchnorm/batchnorm_common.hpp index 6eac5dd838..bdbc8ea8b8 100644 --- a/example/34_batchnorm/batchnorm_common.hpp +++ b/example/34_batchnorm/batchnorm_common.hpp @@ -10,102 +10,17 @@ #include "ck/utility/data_type.hpp" -// binary operation used to calculate invVariance from mean and meansquare -struct InvVariance -{ - InvVariance(double epsilon) : epsilon_(epsilon){}; - - template - __host__ __device__ constexpr void operator()(T& y, const T& mean, const T& meansquare) const - { - static_assert(std::is_same::value || std::is_same::value, - "Data type is not supported by this operation!"); - - using ck::type_convert; - using ck::math::sqrt; - - T tmp_epsilon = type_convert(epsilon_); - - y = meansquare - mean * mean; - y = 1.0f / sqrt(tmp_epsilon + y); - }; - - double epsilon_; -}; - -// (4-in, 2-out) element-wise operation used to update the moving average of mean and variance -struct MovingAverage -{ - MovingAverage(double factor) : factor_(factor){}; - - template - __host__ __device__ constexpr void operator()(T& y0, - T& y1, - const T& mean, - const T& runningMean, - const T& meansquare, - const T& runningVariance) const - { - static_assert(std::is_same::value || std::is_same::value, - "Data type is not supported by this operation!"); - - using ck::type_convert; - - T tmp_factor = type_convert(factor_); - T variance = meansquare - mean * mean; - - y0 = runningMean * (type_convert(1.0f) - tmp_factor) + mean * tmp_factor; - y1 = runningVariance * (type_convert(1.0f) - tmp_factor) + variance * tmp_factor; - }; - - double factor_; -}; - -struct MovingAverageAndInvVariance -{ - MovingAverageAndInvVariance(double epsilon, double factor) - : epsilon_(epsilon), factor_(factor){}; - - template - __host__ __device__ constexpr void operator()(T& y0, // resultRunningMean - T& y1, // resultRunningVariance - T& y2, // saveInvVariance - const T& mean, - const T& runningMean, - const T& meansquare, - const T& runningVariance) const - { - static_assert(std::is_same::value || std::is_same::value, - "Data type is not supported by this operation!"); - - using ck::type_convert; - using ck::math::sqrt; - - T tmp_epsilon = type_convert(epsilon_); - T tmp_factor = type_convert(factor_); - T variance = meansquare - mean * mean; - - y0 = runningMean * (type_convert(1.0f) - tmp_factor) + mean * tmp_factor; - y1 = runningVariance * (type_convert(1.0f) - tmp_factor) + variance * tmp_factor; - - y2 = 1.0f / sqrt(tmp_epsilon + variance); - }; - - double epsilon_; - double factor_; -}; - struct NormalizeInInfer { NormalizeInInfer(double epsilon = 1e-4) : epsilon_(epsilon) {} - template + template __host__ __device__ constexpr void operator()(T1& y, const T1& x, const T2& mean, const T2& variance, - const T2& gamma, - const T2& beta) const + const T3& gamma, + const T4& beta) const { static_assert(std::is_same::value || std::is_same::value, "Data type is not supported by this operation!"); @@ -117,38 +32,10 @@ struct NormalizeInInfer tmp_x = type_convert(x); - tmp_y = ((tmp_x - mean) / sqrt(variance + type_convert(epsilon_))) * gamma + beta; - y = type_convert(tmp_y); - }; - - double epsilon_; -}; - -struct NormalizeInForward -{ - NormalizeInForward(double epsilon = 1e-4) : epsilon_(epsilon) {} - - template - __host__ 
__device__ constexpr void operator()(T1& y, - const T1& x, - const T2& mean, - const T2& meansquare, - const T2& gamma, - const T2& beta) const - { - static_assert(std::is_same::value || std::is_same::value, - "Data type is not supported by this operation!"); - - using ck::type_convert; - using ck::math::sqrt; - - T2 tmp_x, tmp_y; - T2 variance = meansquare - mean * mean; - - tmp_x = type_convert(x); - - tmp_y = ((tmp_x - mean) / sqrt(variance + type_convert(epsilon_))) * gamma + beta; - y = type_convert(tmp_y); + tmp_y = ((tmp_x - mean) / sqrt(variance + type_convert(epsilon_))) * + type_convert(gamma) + + type_convert(beta); + y = type_convert(tmp_y); }; double epsilon_; diff --git a/example/34_batchnorm/batchnorm_forward_impl.hpp b/example/34_batchnorm/batchnorm_forward_impl.hpp deleted file mode 100644 index 6fb7987e97..0000000000 --- a/example/34_batchnorm/batchnorm_forward_impl.hpp +++ /dev/null @@ -1,295 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. - -#pragma once - -#include -#include - -#include "ck/ck.hpp" -#include "ck/utility/reduction_operator.hpp" -#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" -#include "ck/tensor_operation/gpu/device/impl/device_multiple_reduce_multiblock.hpp" -#include "ck/tensor_operation/gpu/device/impl/device_elementwise.hpp" - -#include "batchnorm_common.hpp" - -template -int bnorm_fwd(bool time_kernel, - bool updateMovingAverage, - bool saveMeanAndInvVariance, - const std::array reduceDims, - const std::array xyLengths, - const std::array xStrides, - const std::array yStrides, - const std::array bnScaleBiasMeanVarLengths, - const std::array bnScaleBiasMeanVarStrides, - const void* p_x, - const void* p_scale, - const void* p_bias, - void* p_y, - double exponentialAverageFactor, - void* p_runningMean, - void* p_runningVariance, - double epsilon, - void* p_saveMean, - void* p_saveInvVariance, - void* p_tmp_mean, - void* p_tmp_meansquare) -{ - static_assert(NumBatchNormReduceDim < Rank, - "Invalid number of reduced dimensions for batchnorm!"); - - constexpr ck::index_t NumScaleBiasMeanVarDim = Rank - NumBatchNormReduceDim; - - using InElementwiseOperation_Mean = ck::tensor_operation::element_wise::PassThrough; - using AccElementwiseOperation_Mean = ck::tensor_operation::element_wise::UnaryDivide; - - using InElementwiseOperation_Meansquare = ck::tensor_operation::element_wise::UnarySquare; - using AccElementwiseOperation_Meansquare = ck::tensor_operation::element_wise::UnaryDivide; - - using DeviceMeanAndMeansquareInstance = - ck::tensor_operation::device::DeviceMultipleReduceMultiBlock< - 2, - InOutDataType, - AccDataType, - ck::Tuple, - Rank, - NumBatchNormReduceDim, - ck::reduce::Add, - ck::Tuple, - ck::Tuple, - ck::InMemoryDataOperationEnum::Set, - false, // PropagateNan - 256, - 16, - 16, - 1, - 1, - fastest_dim_is_reduced ? 
1 : 0, - 1, - ck::Sequence<1, 1>>; - - using DeviceNormalizeInstance = ck::tensor_operation::device::DeviceElementwise< - ck::Tuple, // x, mean, - // meansquare, - // scale, bias - ck::Tuple, // y - NormalizeInForward, - Rank, - 2, // MPerthread - ck::Sequence<1, 1, 1, 1, 1>, // scalarPerVector: x, mean, meansquare, scale, bias - ck::Sequence<1>>; // scalarPerVector: y - - using DeviceInvVarianceInstance = ck::tensor_operation::device::DeviceElementwise< - ck::Tuple, // mean, meansquare - ck::Tuple, // invVariance - InvVariance, - NumScaleBiasMeanVarDim, - 2, // MPerthread - ck::Sequence<1, 1>, // scalarPerVector: mean, meansquare - ck::Sequence<1>>; // scalarPerVector: invVariance - - using DeviceMovingAverageInstance = ck::tensor_operation::device::DeviceElementwise< - ck::Tuple, // old moving mean, new mean, - // old moving variance, new - // meansquare - ck::Tuple, // updated moving mean, updated moving variance - MovingAverage, - NumScaleBiasMeanVarDim, - 4, // MPerthread - ck::Sequence<1, 1, 1, 1>, // scalarPerVector: old moving mean, new mean, old moving - // variance, new meansquare - ck::Sequence<1, 1>>; // scalarPerVector: updated moving mean, updated moving variance - - using DeviceMovingAverageAndInvVarianceInstance = - ck::tensor_operation::device::DeviceElementwise< - ck::Tuple, // old moving mean, new - // mean, old moving - // variance, new - // meansquare - ck::Tuple, // updated moving mean, updated moving - // variancem, invVariance - MovingAverageAndInvVariance, - NumScaleBiasMeanVarDim, - 4, // MPerthread - ck::Sequence<1, 1, 1, 1>, // scalarPerVector: old moving mean, new mean, old moving - // variance, new meansquare - ck::Sequence<1, 1, 1>>; // scalarPerVector: updated moving mean, updated moving variance - - auto invariantDims = get_invariant_dims(reduceDims); - std::array aligned_scaleBiasMeanVarStrides{0}; - - int i = 0; - for(auto dim : invariantDims) - { - assert(xyLengths[dim] == bnScaleBiasMeanVarLengths[i]); - - aligned_scaleBiasMeanVarStrides[dim] = bnScaleBiasMeanVarStrides[i]; - i++; - }; - - int32_t reduceLength = 1; - - for(auto dim : reduceDims) - reduceLength *= xyLengths[dim]; - - int32_t invariantLength = 1; - - for(auto dim : invariantDims) - invariantLength *= xyLengths[dim]; - - size_t total_length = static_cast(invariantLength) * reduceLength; - - float avg_time = 0.0f; - std::size_t num_bytes = 0; - - auto dev_mean_and_meansquare = DeviceMeanAndMeansquareInstance{}; - - void* p_mean = saveMeanAndInvVariance ? 
p_saveMean : p_tmp_mean; - - const AccDataType alpha = ck::type_convert(1.0f); - const AccDataType beta = ck::type_convert(0.0f); - - auto argument_ptr1 = dev_mean_and_meansquare.MakeArgumentPointer( - xyLengths, - xStrides, - bnScaleBiasMeanVarLengths, - {bnScaleBiasMeanVarStrides, bnScaleBiasMeanVarStrides}, - reduceDims, - {&alpha, &alpha}, - {&beta, &beta}, - p_x, - {p_mean, p_tmp_meansquare}, - ck::make_tuple(InElementwiseOperation_Mean{}, InElementwiseOperation_Meansquare{}), - ck::make_tuple(AccElementwiseOperation_Mean{reduceLength}, - AccElementwiseOperation_Meansquare{reduceLength})); - - auto dev_normalize = DeviceNormalizeInstance{}; - - auto argument_ptr2 = - dev_normalize.MakeArgumentPointer(xyLengths, - {xStrides, - aligned_scaleBiasMeanVarStrides, - aligned_scaleBiasMeanVarStrides, - aligned_scaleBiasMeanVarStrides, - aligned_scaleBiasMeanVarStrides}, - {yStrides}, - {p_x, p_mean, p_tmp_meansquare, p_scale, p_bias}, - {p_y}, - NormalizeInForward{epsilon}); - - if(!dev_mean_and_meansquare.IsSupportedArgument(argument_ptr1.get()) || - !dev_normalize.IsSupportedArgument(argument_ptr2.get())) - { - std::cout << "The runtime parameters seems not supported by the Devic, exiting!" - << std::endl; - - return (-1); - }; - - auto invoker_ptr1 = dev_mean_and_meansquare.MakeInvokerPointer(); - auto invoker_ptr2 = dev_normalize.MakeInvokerPointer(); - - avg_time += invoker_ptr1->Run(argument_ptr1.get(), StreamConfig{nullptr, time_kernel}); - avg_time += invoker_ptr2->Run(argument_ptr2.get(), StreamConfig{nullptr, time_kernel}); - - num_bytes += - (total_length * sizeof(InOutDataType) + invariantLength * 2 * sizeof(AccDataType)) + // No.1 - (total_length * (1 * sizeof(InOutDataType) + 4 * sizeof(AccDataType)) + - total_length * sizeof(InOutDataType)); // No.2 - - if(saveMeanAndInvVariance && updateMovingAverage) - { - auto dev_moving_average_inv_variance = DeviceMovingAverageAndInvVarianceInstance{}; - - auto argument_ptr3 = dev_moving_average_inv_variance.MakeArgumentPointer( - bnScaleBiasMeanVarLengths, - {bnScaleBiasMeanVarStrides, - bnScaleBiasMeanVarStrides, - bnScaleBiasMeanVarStrides, - bnScaleBiasMeanVarStrides}, - {bnScaleBiasMeanVarStrides, bnScaleBiasMeanVarStrides, bnScaleBiasMeanVarStrides}, - {p_mean, p_runningMean, p_tmp_meansquare, p_runningVariance}, - {p_runningMean, p_runningVariance, p_saveInvVariance}, - MovingAverageAndInvVariance{epsilon, exponentialAverageFactor}); - - if(!dev_moving_average_inv_variance.IsSupportedArgument(argument_ptr3.get())) - { - std::cout << "Runtime parameters not supported by the Device, exiting!" << std::endl; - - return (-1); - }; - - auto invoker_ptr3 = dev_moving_average_inv_variance.MakeInvokerPointer(); - - avg_time += invoker_ptr3->Run(argument_ptr3.get(), StreamConfig{nullptr, time_kernel}); - - num_bytes += invariantLength * (4 + 3) * sizeof(AccDataType) * 2; // No.5 - } - else if(saveMeanAndInvVariance) - { - auto dev_inv_variance = DeviceInvVarianceInstance{}; - auto argument_ptr3 = dev_inv_variance.MakeArgumentPointer( - bnScaleBiasMeanVarLengths, - {bnScaleBiasMeanVarStrides, bnScaleBiasMeanVarStrides}, - {bnScaleBiasMeanVarStrides}, - {p_mean, p_tmp_meansquare}, - {p_saveInvVariance}, - InvVariance{epsilon}); - - if(!dev_inv_variance.IsSupportedArgument(argument_ptr3.get())) - { - std::cout << "Runtime parameters not supported by the Device, exiting!" 
<< std::endl; - - return (-1); - }; - - auto invoker_ptr3 = dev_inv_variance.MakeInvokerPointer(); - - avg_time += invoker_ptr3->Run(argument_ptr3.get(), StreamConfig{nullptr, time_kernel}); - - num_bytes += invariantLength * (2 + 1) * sizeof(AccDataType); - } - else if(updateMovingAverage) - { - auto dev_moving_average = DeviceMovingAverageInstance{}; - - auto argument_ptr3 = dev_moving_average.MakeArgumentPointer( - bnScaleBiasMeanVarLengths, - {bnScaleBiasMeanVarStrides, - bnScaleBiasMeanVarStrides, - bnScaleBiasMeanVarStrides, - bnScaleBiasMeanVarStrides}, - {bnScaleBiasMeanVarStrides, bnScaleBiasMeanVarStrides}, - {p_mean, p_runningMean, p_tmp_meansquare, p_runningVariance}, - {p_runningMean, p_runningVariance}, - MovingAverage{exponentialAverageFactor}); - - if(!dev_moving_average.IsSupportedArgument(argument_ptr3.get())) - { - std::cout << "Runtime parameters not supported by the Device, exiting!" << std::endl; - - return (-1); - }; - - auto invoker_ptr3 = dev_moving_average.MakeInvokerPointer(); - - avg_time += invoker_ptr3->Run(argument_ptr3.get(), StreamConfig{nullptr, time_kernel}); - - num_bytes += invariantLength * (4 + 2) * sizeof(AccDataType) * 2; // No.5 - }; - - if(time_kernel) - { - float gb_per_sec = num_bytes / 1.E6 / avg_time; - - std::cout << "Perf: " << avg_time << " ms, " << gb_per_sec << " GB/s" << std::endl; - }; - - return (0); -}; diff --git a/example/34_batchnorm/batchnorm_forward_nhwc.cpp b/example/34_batchnorm/batchnorm_forward_nhwc.cpp index 0b916c838a..13e408cab8 100644 --- a/example/34_batchnorm/batchnorm_forward_nhwc.cpp +++ b/example/34_batchnorm/batchnorm_forward_nhwc.cpp @@ -15,13 +15,9 @@ #include "ck/library/utility/host_tensor_generator.hpp" #include "ck/library/utility/host_common_util.hpp" #include "ck/library/reference_tensor_operation/cpu/reference_batchnorm_forward_nhwc_c.hpp" - -#include "batchnorm_forward_impl.hpp" - -template -using ReferenceBatchNormFwdInstance = - ck::tensor_operation::host::ReferenceBatchNormFwd_Input_N_H_W_C_Output_C; +#include "ck/tensor_operation/gpu/device/impl/device_batchnorm_forward_impl.hpp" +#include "ck/library/utility/host_common_util.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" static struct option long_options[] = {{"inOutLengths", required_argument, nullptr, 'D'}, {"verify", required_argument, nullptr, 'v'}, @@ -41,9 +37,10 @@ class BatchNormFwdArg bool updateMovingAverage; bool saveMeanAndInvVariance; - int data_type = 0; - int init_method = 2; - bool time_kernel = false; + int data_type = 0; + int init_method = 2; + bool time_kernel = false; + bool use_multiblock_welford = false; public: void show_usage(const char* cmd) @@ -68,6 +65,7 @@ class BatchNormFwdArg "value, 3=decimal value)" << std::endl; std::cout << "Arg5: time kernel (0=no, 1=yes)" << std::endl; + std::cout << "Arg6: use multi-block welford (0=n0, 1=yes)" << std::endl; }; int processArgs(int argc, char* argv[]) @@ -110,14 +108,15 @@ class BatchNormFwdArg }; }; - if(optind + 5 > argc) + if(optind + 6 > argc) throw std::runtime_error("Invalid cmd-line arguments, more argumetns are needed!"); data_type = std::atoi(argv[optind++]); updateMovingAverage = std::atoi(argv[optind++]); saveMeanAndInvVariance = std::atoi(argv[optind++]); init_method = std::atoi(argv[optind++]); - time_kernel = static_cast(std::atoi(argv[optind])); + time_kernel = static_cast(std::atoi(argv[optind++])); + use_multiblock_welford = static_cast(std::atoi(argv[optind])); if(data_type != 0 && data_type != 1 && data_type != 3 && data_type != 5 && 
data_type != 6) return (-1); @@ -128,7 +127,7 @@ class BatchNormFwdArg using namespace ck; -template +template bool bnorm_fwd_nhwc_test(bool do_verification, int init_method, bool time_kernel, @@ -273,73 +272,140 @@ bool bnorm_fwd_nhwc_test(bool do_verification, scaleBiasMeanVarStrides.end(), i_scaleBiasMeanVarStrides.begin()); - int result = 0; + using PassThroughOp = ck::tensor_operation::element_wise::PassThrough; - // used for saving meansquare - DeviceMem workspace(sizeof(AccDataType) * 2 * resultSaveMean_ref.mDesc.GetElementSpaceSize() + - 128); + using DeviceBatchNormFwdInstance = + ck::tensor_operation::device::DeviceBatchNormFwdImpl; - void* p_tmp_mean = workspace.GetDeviceBuffer(); - void* p_tmp_meansquare = - static_cast(p_tmp_mean) + - (sizeof(AccDataType) * resultSaveMean_ref.mDesc.GetElementSpaceSize() + 63) / 64 * 64; + auto batchnorm_fwd = DeviceBatchNormFwdInstance{}; - result = bnorm_fwd( - time_kernel, - updateMovingAverage, - saveMeanAndInvVariance, - {0, 1, 2}, + auto argument_ptr = batchnorm_fwd.MakeArgumentPointer( i_inOutLengths, i_inOutStrides, i_inOutStrides, + {0, 1, 2}, i_scaleBiasMeanVarLengths, i_scaleBiasMeanVarStrides, + i_scaleBiasMeanVarStrides, + i_scaleBiasMeanVarStrides, x_dev.GetDeviceBuffer(), bnScale_dev.GetDeviceBuffer(), bnBias_dev.GetDeviceBuffer(), - y_dev.GetDeviceBuffer(), - averageFactor, - updateMovingAverage ? resultRunningMean_dev.GetDeviceBuffer() : nullptr, - updateMovingAverage ? resultRunningVariance_dev.GetDeviceBuffer() : nullptr, epsilon, + PassThroughOp{}, + y_dev.GetDeviceBuffer(), saveMeanAndInvVariance ? resultSaveMean_dev.GetDeviceBuffer() : nullptr, saveMeanAndInvVariance ? resultSaveInvVariance_dev.GetDeviceBuffer() : nullptr, - p_tmp_mean, - p_tmp_meansquare); + averageFactor, + updateMovingAverage ? resultRunningMean_dev.GetDeviceBuffer() : nullptr, + updateMovingAverage ? resultRunningVariance_dev.GetDeviceBuffer() : nullptr); - if(result < 0) + if(!batchnorm_fwd.IsSupportedArgument(argument_ptr.get())) + { + std::cout << "The runtime parameters seems not supported by the BatchNorm device instance, " + "exiting!" + << std::endl; return (false); + }; + + size_t workspace_sz = batchnorm_fwd.GetWorkSpaceSize(argument_ptr.get()); + + DeviceMem workspace_dev(workspace_sz); + + batchnorm_fwd.SetWorkSpacePointer(argument_ptr.get(), workspace_dev.GetDeviceBuffer()); + + auto invoker_ptr = batchnorm_fwd.MakeInvokerPointer(); + + if(time_kernel) + { + float avg_time = 0.0f; + size_t num_bytes = 0; + + size_t total_length = inOutLengths[0] * inOutLengths[1] * inOutLengths[2] * inOutLengths[3]; + size_t invariant_length = inOutLengths[3]; + + avg_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel}); + + // inputing of x, scale, bias, outputing of y + num_bytes += + total_length * sizeof(InOutDataType) * 2 + invariant_length * sizeof(AccDataType) * 2; + + // outputing of mean, inv-variance + num_bytes += saveMeanAndInvVariance ? invariant_length * sizeof(AccDataType) * 2 : 0; + + // updating of moving mean, variance + num_bytes += updateMovingAverage ? 
invariant_length * sizeof(AccDataType) * 4 : 0; + + float gb_per_sec = num_bytes / 1.E6 / avg_time; + + std::cout << "Perf: " << avg_time << " ms, " << gb_per_sec << " GB/s" << std::endl; + } + else + (void)invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel}); bool pass = true; if(do_verification) { - auto batchNormFwd_ref = ReferenceBatchNormFwdInstance{}; + + using ReferenceBatchNormFwdInstance = + ck::tensor_operation::host::ReferenceBatchNormFwd_Input_N_H_W_C_Output_C; + + auto batchNormFwd_ref = ReferenceBatchNormFwdInstance{}; auto argument_ptr_ref = batchNormFwd_ref.MakeArgumentPointer( i_inOutLengths, i_inOutStrides, i_inOutStrides, + {0, 1, 2}, i_scaleBiasMeanVarLengths, i_scaleBiasMeanVarStrides, + i_scaleBiasMeanVarStrides, + i_scaleBiasMeanVarStrides, x.mData.data(), bnScale.mData.data(), bnBias.mData.data(), - y_ref.mData.data(), - 0.1, // exponentialAverageFactor - updateMovingAverage ? resultRunningMean_ref.mData.data() : nullptr, // resultRunningMean - updateMovingAverage ? resultRunningVariance_ref.mData.data() - : nullptr, // resultRunningVariance epsilon, + PassThroughOp{}, + y_ref.mData.data(), saveMeanAndInvVariance ? resultSaveMean_ref.mData.data() : nullptr, - saveMeanAndInvVariance ? resultSaveInvVariance_ref.mData.data() : nullptr); + saveMeanAndInvVariance ? resultSaveInvVariance_ref.mData.data() : nullptr, + averageFactor, + updateMovingAverage ? resultRunningMean_ref.mData.data() : nullptr, + updateMovingAverage ? resultRunningVariance_ref.mData.data() : nullptr); if(!batchNormFwd_ref.IsSupportedArgument(argument_ptr_ref.get())) { - std::cout - << "The runtime parameters seems not supported by the BatchNorm instance, exiting!" - << std::endl; - return (-2); + std::cout << "The runtime parameters seems not supported by the BatchNorm reference " + "instance, exiting!" 
+ << std::endl; + return (false); }; auto invoker_ptr_ref = batchNormFwd_ref.MakeInvokerPointer(); @@ -365,6 +431,8 @@ bool bnorm_fwd_nhwc_test(bool do_verification, if(saveMeanAndInvVariance) { + using ck::host_common::dumpBufferToFile; + Tensor resultSaveMean(scaleBiasMeanVarLengths); Tensor resultSaveInvVariance(scaleBiasMeanVarLengths); @@ -396,70 +464,129 @@ int main(int argc, char* argv[]) if(arg.data_type == 0) { - pass = bnorm_fwd_nhwc_test(arg.do_verification, - arg.init_method, - arg.time_kernel, - arg.inOutLengths, - arg.updateMovingAverage, - arg.saveMeanAndInvVariance, - averageFactor, - epsilon); + if(arg.use_multiblock_welford) + pass = bnorm_fwd_nhwc_test(arg.do_verification, + arg.init_method, + arg.time_kernel, + arg.inOutLengths, + arg.updateMovingAverage, + arg.saveMeanAndInvVariance, + averageFactor, + epsilon); + else + pass = bnorm_fwd_nhwc_test(arg.do_verification, + arg.init_method, + arg.time_kernel, + arg.inOutLengths, + arg.updateMovingAverage, + arg.saveMeanAndInvVariance, + averageFactor, + epsilon); } else if(arg.data_type == 1) { - pass = bnorm_fwd_nhwc_test(arg.do_verification, - arg.init_method, - arg.time_kernel, - arg.inOutLengths, - arg.updateMovingAverage, - arg.saveMeanAndInvVariance, - averageFactor, - epsilon); + if(arg.use_multiblock_welford) + pass = bnorm_fwd_nhwc_test(arg.do_verification, + arg.init_method, + arg.time_kernel, + arg.inOutLengths, + arg.updateMovingAverage, + arg.saveMeanAndInvVariance, + averageFactor, + epsilon); + else + pass = bnorm_fwd_nhwc_test(arg.do_verification, + arg.init_method, + arg.time_kernel, + arg.inOutLengths, + arg.updateMovingAverage, + arg.saveMeanAndInvVariance, + averageFactor, + epsilon); } else if(arg.data_type == 3) { - pass = bnorm_fwd_nhwc_test(arg.do_verification, - arg.init_method, - arg.time_kernel, - arg.inOutLengths, - arg.updateMovingAverage, - arg.saveMeanAndInvVariance, - averageFactor, - epsilon); + if(arg.use_multiblock_welford) + pass = bnorm_fwd_nhwc_test(arg.do_verification, + arg.init_method, + arg.time_kernel, + arg.inOutLengths, + arg.updateMovingAverage, + arg.saveMeanAndInvVariance, + averageFactor, + epsilon); + else + pass = bnorm_fwd_nhwc_test(arg.do_verification, + arg.init_method, + arg.time_kernel, + arg.inOutLengths, + arg.updateMovingAverage, + arg.saveMeanAndInvVariance, + averageFactor, + epsilon); } else if(arg.data_type == 5) { - pass = bnorm_fwd_nhwc_test(arg.do_verification, - arg.init_method, - arg.time_kernel, - arg.inOutLengths, - arg.updateMovingAverage, - arg.saveMeanAndInvVariance, - averageFactor, - epsilon); + if(arg.use_multiblock_welford) + pass = bnorm_fwd_nhwc_test(arg.do_verification, + arg.init_method, + arg.time_kernel, + arg.inOutLengths, + arg.updateMovingAverage, + arg.saveMeanAndInvVariance, + averageFactor, + epsilon); + else + pass = bnorm_fwd_nhwc_test(arg.do_verification, + arg.init_method, + arg.time_kernel, + arg.inOutLengths, + arg.updateMovingAverage, + arg.saveMeanAndInvVariance, + averageFactor, + epsilon); } else if(arg.data_type == 6) { - pass = bnorm_fwd_nhwc_test(arg.do_verification, - arg.init_method, - arg.time_kernel, - arg.inOutLengths, - arg.updateMovingAverage, - arg.saveMeanAndInvVariance, - averageFactor, - epsilon); + if(arg.use_multiblock_welford) + pass = bnorm_fwd_nhwc_test(arg.do_verification, + arg.init_method, + arg.time_kernel, + arg.inOutLengths, + arg.updateMovingAverage, + arg.saveMeanAndInvVariance, + averageFactor, + epsilon); + else + pass = bnorm_fwd_nhwc_test(arg.do_verification, + arg.init_method, + 
arg.time_kernel, + arg.inOutLengths, + arg.updateMovingAverage, + arg.saveMeanAndInvVariance, + averageFactor, + epsilon); } } else { - pass = bnorm_fwd_nhwc_test(true, - 2, - false, // don't time kernel - {128, 16, 16, 1024}, - true, - false, - averageFactor, - epsilon); + pass = bnorm_fwd_nhwc_test(true, + 2, + false, // don't time kernel + {128, 16, 6, 512}, + true, + true, + averageFactor, + epsilon); + + pass = pass && bnorm_fwd_nhwc_test(true, + 2, + false, // don't time kernel + {128, 16, 3, 1024}, + true, + true, + averageFactor, + epsilon); }; return (pass ? 0 : 1); diff --git a/example/34_batchnorm/batchnorm_infer_impl.hpp b/example/34_batchnorm/batchnorm_infer_impl.hpp index 23c4978d7f..e457df81da 100644 --- a/example/34_batchnorm/batchnorm_infer_impl.hpp +++ b/example/34_batchnorm/batchnorm_infer_impl.hpp @@ -14,8 +14,12 @@ #include "batchnorm_common.hpp" -template @@ -26,7 +30,9 @@ int bnorm_infer( const std::array xStrides, const std::array yStrides, const std::array bnScaleBiasMeanVarLengths, - const std::array bnScaleBiasMeanVarStrides, + const std::array bnScaleStrides, + const std::array bnBiasStrides, + const std::array bnMeanVarStrides, const void* p_x, const void* p_scale, const void* p_bias, @@ -41,11 +47,11 @@ int bnorm_infer( "Invalid number of reduced dimensions for batchnorm!"); using DeviceNormalizeInstance = ck::tensor_operation::device::DeviceElementwise< - ck::Tuple, // x, mean, - // variance, - // scale, - // bias, - ck::Tuple, // y + ck::Tuple, // x, mean, + // variance, + // scale, + // bias, + ck::Tuple, // y NormalizeInInfer, Rank, 2, // MPerthread @@ -53,14 +59,18 @@ int bnorm_infer( ck::Sequence<1>>; // scalarPerVector: y auto invariantDims = get_invariant_dims(reduceDims); - std::array aligned_scaleBiasMeanVarStrides{0}; + std::array aligned_bnScaleStrides{0}; + std::array aligned_bnBiasStrides{0}; + std::array aligned_bnMeanVarStrides{0}; int i = 0; for(auto dim : invariantDims) { assert(xyLengths[dim] == bnScaleBiasMeanVarLengths[i]); - aligned_scaleBiasMeanVarStrides[dim] = bnScaleBiasMeanVarStrides[i]; + aligned_bnScaleStrides[dim] = bnScaleStrides[i]; + aligned_bnBiasStrides[dim] = bnBiasStrides[i]; + aligned_bnMeanVarStrides[dim] = bnMeanVarStrides[i]; i++; }; @@ -84,10 +94,10 @@ int bnorm_infer( auto argument_ptr1 = dev_normalize.MakeArgumentPointer( xyLengths, {xStrides, - aligned_scaleBiasMeanVarStrides, - aligned_scaleBiasMeanVarStrides, - aligned_scaleBiasMeanVarStrides, - aligned_scaleBiasMeanVarStrides}, + aligned_bnMeanVarStrides, + aligned_bnMeanVarStrides, + aligned_bnScaleStrides, + aligned_bnBiasStrides}, {yStrides}, {p_x, p_estimatedMean, p_estimatedVariance, p_scale, p_bias}, {p_y}, @@ -105,8 +115,10 @@ int bnorm_infer( avg_time += invoker_ptr1->Run(argument_ptr1.get(), StreamConfig{nullptr, time_kernel}); - num_bytes += (total_length * (1 * sizeof(InOutDataType) + 4 * sizeof(AccDataType)) + - total_length * sizeof(InOutDataType)); + num_bytes += total_length * sizeof(XDataType) + + invariantLength * + (sizeof(ScaleDataType) + sizeof(BiasDataType) + 2 * sizeof(MeanVarDataType)) + + total_length * sizeof(YDataType); if(time_kernel) { diff --git a/example/34_batchnorm/batchnorm_infer_nhwc.cpp b/example/34_batchnorm/batchnorm_infer_nhwc.cpp index 247fae6d30..d6c5dc1001 100644 --- a/example/34_batchnorm/batchnorm_infer_nhwc.cpp +++ b/example/34_batchnorm/batchnorm_infer_nhwc.cpp @@ -18,11 +18,6 @@ #include "batchnorm_infer_impl.hpp" -template -using ReferenceBatchNormInferInstance = - 
ck::tensor_operation::host::ReferenceBatchNormInfer_Input_N_H_W_C_Output_C; - static struct option long_options[] = {{"inOutLengths", required_argument, nullptr, 'D'}, {"verify", required_argument, nullptr, 'v'}, {"help", no_argument, nullptr, '?'}, @@ -236,21 +231,30 @@ bool bnorm_infer_nhwc_test(bool do_verification, int result = 0; - result = bnorm_infer( - time_kernel, - {0, 1, 2}, - i_inOutLengths, - i_inOutStrides, - i_inOutStrides, - i_scaleBiasMeanVarLengths, - i_scaleBiasMeanVarStrides, - x_dev.GetDeviceBuffer(), - bnScale_dev.GetDeviceBuffer(), - bnBias_dev.GetDeviceBuffer(), - epsilon, - estimatedMean_dev.GetDeviceBuffer(), - estimatedVariance_dev.GetDeviceBuffer(), - y_dev.GetDeviceBuffer()); + result = bnorm_infer(time_kernel, + {0, 1, 2}, + i_inOutLengths, + i_inOutStrides, + i_inOutStrides, + i_scaleBiasMeanVarLengths, + i_scaleBiasMeanVarStrides, + i_scaleBiasMeanVarStrides, + i_scaleBiasMeanVarStrides, + x_dev.GetDeviceBuffer(), + bnScale_dev.GetDeviceBuffer(), + bnBias_dev.GetDeviceBuffer(), + epsilon, + estimatedMean_dev.GetDeviceBuffer(), + estimatedVariance_dev.GetDeviceBuffer(), + y_dev.GetDeviceBuffer()); if(result < 0) return (false); @@ -259,7 +263,15 @@ bool bnorm_infer_nhwc_test(bool do_verification, if(do_verification) { - auto batchNormInfer_ref = ReferenceBatchNormInferInstance{}; + using ReferenceBatchNormInferInstance = + ck::tensor_operation::host::ReferenceBatchNormInfer_Input_N_H_W_C_Output_C< + InOutDataType, + InOutDataType, + AccDataType, + AccDataType, + AccDataType, + AccDataType>; + auto batchNormInfer_ref = ReferenceBatchNormInferInstance{}; auto argument_ptr_ref = batchNormInfer_ref.MakeArgumentPointer(i_inOutLengths, @@ -267,6 +279,8 @@ bool bnorm_infer_nhwc_test(bool do_verification, i_inOutStrides, i_scaleBiasMeanVarLengths, i_scaleBiasMeanVarStrides, + i_scaleBiasMeanVarStrides, + i_scaleBiasMeanVarStrides, x.mData.data(), bnScale.mData.data(), bnBias.mData.data(), diff --git a/include/ck/tensor_operation/gpu/device/device_batchnorm_forward.hpp b/include/ck/tensor_operation/gpu/device/device_batchnorm_forward.hpp index 842ad5d459..019f377a5c 100644 --- a/include/ck/tensor_operation/gpu/device/device_batchnorm_forward.hpp +++ b/include/ck/tensor_operation/gpu/device/device_batchnorm_forward.hpp @@ -13,31 +13,36 @@ namespace ck { namespace tensor_operation { namespace device { -template +template struct DeviceBatchNormFwd : public BaseOperator { virtual std::unique_ptr MakeArgumentPointer( const std::array xyLengths, const std::array xStrides, const std::array yStrides, + const std::array reduceDims, const std::array bnScaleBiasMeanVarLengths, - const std::array bnScaleBiasMeanVarStrides, + const std::array bnScaleStrides, + const std::array bnBiasStrides, + const std::array bnMeanVarStrides, const void* p_x, const void* bnScale, const void* bnBias, + double epsilon, + const YElementwiseOp y_elementwise_op, void* p_y, + void* resultSaveMean, + void* resultSaveInvVariance, double exponentialAverageFactor, void* resultRunningMean, - void* resultRunningVariance, - double epsilon, - void* resultSaveMean, - void* resultSaveInvVariance) = 0; + void* resultRunningVariance) = 0; virtual std::unique_ptr MakeInvokerPointer() = 0; }; -template -using DeviceBatchNormFwdPtr = std::unique_ptr>; +template +using DeviceBatchNormFwdPtr = + std::unique_ptr>; } // namespace device } // namespace tensor_operation diff --git a/include/ck/tensor_operation/gpu/device/device_batchnorm_infer.hpp b/include/ck/tensor_operation/gpu/device/device_batchnorm_infer.hpp index 
785d64bf14..fabb2394c5 100644 --- a/include/ck/tensor_operation/gpu/device/device_batchnorm_infer.hpp +++ b/include/ck/tensor_operation/gpu/device/device_batchnorm_infer.hpp @@ -21,7 +21,9 @@ struct DeviceBatchNormInfer : public BaseOperator const std::array xStrides, const std::array yStrides, const std::array bnScaleBiasMeanVarLengths, - const std::array bnScaleBiasMeanVarStrides, + const std::array bnScaleStrides, + const std::array bnBiasStrides, + const std::array bnMeanVarStrides, const void* p_x, const void* bnScale, const void* bnBias, diff --git a/include/ck/tensor_operation/gpu/device/impl/device_batchnorm_forward_impl.hpp b/include/ck/tensor_operation/gpu/device/impl/device_batchnorm_forward_impl.hpp new file mode 100644 index 0000000000..220456955d --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/impl/device_batchnorm_forward_impl.hpp @@ -0,0 +1,711 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include + +#include "ck/utility/reduction_operator.hpp" +#include "ck/tensor_operation/gpu/device/device_batchnorm_forward.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_reduce_common.hpp" +#include "ck/tensor_operation/gpu/device/welford_helper.hpp" +#include "ck/tensor_operation/gpu/grid/batchnorm_multiblock/gridwise_multiblock_welford_first_half.hpp" +#include "ck/tensor_operation/gpu/grid/batchnorm_multiblock/gridwise_multiblock_welford_second_half_batchnorm_forward_final.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_batchnorm_forward_blockwise_welford.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +template +struct DeviceBatchNormFwdImpl + : public DeviceBatchNormFwd +{ + static_assert(Rank <= 6, "Bigger Rank size is not supported!"); + static_assert(BlockSize == MThreadClusterSize * KThreadClusterSize, + "Invalid thread cluster size assignments!"); + + static_assert((XSrcYDstVectorDim == 0 && MThreadSliceSize % XSrcVectorSize == 0) || + (XSrcYDstVectorDim == 1 && KThreadSliceSize % XSrcVectorSize == 0), + "Invalid thread slice sizes and/or vector sizes configuration, please check!"); + + static constexpr index_t NumInvariantDim = Rank - NumBatchNormReduceDim; + + static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize; + static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize; + + static auto MakeXY2dDescriptor(const std::array& xyLengths, + const std::array& xyStrides, + int blkGroupSize, + int numBlockTileIteration) + { + const auto tupleXYLengths = + generate_tuple([&](auto I) { return xyLengths[I]; }, Number{}); + const auto tupleXYStrides = + generate_tuple([&](auto I) { return xyStrides[I]; }, Number{}); + + const auto raw_grid_desc = make_naive_tensor_descriptor(tupleXYLengths, tupleXYStrides); + + const auto grid_desc_m_k = [&]() { + using InvariantDims = typename arithmetic_sequence_gen<0, NumInvariantDim, 1>::type; + using ReduceDims = typename arithmetic_sequence_gen::type; + + const auto reduceDimLengths = + generate_tuple([&](auto I) { return xyLengths[NumInvariantDim + I]; }, + Number{}); + const auto invariantDimLengths = + generate_tuple([&](auto I) { return xyLengths[I]; }, Number{}); + + return transform_tensor_descriptor(raw_grid_desc, + make_tuple(make_merge_transform(invariantDimLengths), + 
make_merge_transform(reduceDimLengths)), + make_tuple(InvariantDims{}, ReduceDims{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + }(); + + const auto invariantLength = grid_desc_m_k.GetLength(Number<0>{}); + const auto reduceLength = grid_desc_m_k.GetLength(Number<1>{}); + + const int workSizePerBlock = K_BlockTileSize * numBlockTileIteration; + const auto mPad = + math::integer_least_multiple(invariantLength, M_BlockTileSize) - invariantLength; + const auto kPad = workSizePerBlock * blkGroupSize - reduceLength; + + auto grid_desc_m_k_padded = + transform_tensor_descriptor(grid_desc_m_k, + make_tuple(make_right_pad_transform(invariantLength, mPad), + make_right_pad_transform(reduceLength, kPad)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + return (grid_desc_m_k_padded); + }; + + static auto MakeMeanVarCountOutputMG2dDescriptor(int invariantLength, int blkGroupSize) + { + const auto grid_desc_m_g = + make_naive_tensor_descriptor_packed(make_tuple(invariantLength, blkGroupSize)); + + const auto mPad = + math::integer_least_multiple(invariantLength, M_BlockTileSize) - invariantLength; + + auto grid_desc_m_g_padded = + transform_tensor_descriptor(grid_desc_m_g, + make_tuple(make_right_pad_transform(invariantLength, mPad), + make_pass_through_transform(blkGroupSize)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + return (grid_desc_m_g_padded); + }; + + static auto MakeMeanVarCountInputMK2dDescriptor(int invariantLength, int blkGroupSize) + { + const auto reduceLength = blkGroupSize; + const auto grid_desc_m_k = + make_naive_tensor_descriptor_packed(make_tuple(invariantLength, reduceLength)); + + const auto mPad = + math::integer_least_multiple(invariantLength, M_BlockTileSize) - invariantLength; + const auto kPad = + math::integer_least_multiple(reduceLength, KThreadClusterSize) - reduceLength; + + auto grid_desc_m_k_padded = + transform_tensor_descriptor(grid_desc_m_k, + make_tuple(make_right_pad_transform(invariantLength, mPad), + make_right_pad_transform(reduceLength, kPad)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + return (grid_desc_m_k_padded); + }; + + static auto + MakeScaleBiasMeanVar1dDescriptor(const std::array& lengths, + const std::array& strides) + { + const auto tupleLengths = + generate_tuple([&](auto I) { return lengths[I]; }, Number{}); + const auto tupleStrides = + generate_tuple([&](auto I) { return strides[I]; }, Number{}); + + auto raw_grid_desc = make_naive_tensor_descriptor(tupleLengths, tupleStrides); + + auto grid_desc_m = transform_tensor_descriptor( + raw_grid_desc, + make_tuple(make_merge_transform(tupleLengths)), + make_tuple(typename arithmetic_sequence_gen<0, NumInvariantDim, 1>::type{}), + make_tuple(Sequence<0>{})); + + const auto invariantLength = grid_desc_m.GetLength(Number<0>{}); + + const auto mPad = + math::integer_least_multiple(invariantLength, M_BlockTileSize) - invariantLength; + + auto grid_desc_m_padded = + transform_tensor_descriptor(grid_desc_m, + make_tuple(make_right_pad_transform(invariantLength, mPad)), + make_tuple(Sequence<0>{}), + make_tuple(Sequence<0>{})); + return (grid_desc_m_padded); + }; + + using XYGridDesc_M_K = decltype(MakeXY2dDescriptor({1}, {1}, 1, 1)); + using ScaleBiasMeanVarGridDesc_M = decltype(MakeScaleBiasMeanVar1dDescriptor({1}, {1})); + + struct Argument : public BaseArgument + { + Argument(const std::array xyLengths, + const std::array xStrides, + const std::array 
yStrides, + const std::array reduceDims, + const std::array bnScaleBiasMeanVarLengths, + const std::array bnScaleStrides, + const std::array bnBiasStrides, + const std::array bnMeanVarStrides, + const XDataType* p_x, + const ScaleDataType* p_scale, + const BiasDataType* p_bias, + const YElementwiseOp y_elementwise_op, + double epsilon, + YDataType* p_y, + MeanVarDataType* resultSaveMean, + MeanVarDataType* resultSaveInvVariance, + double averageFactor, + MeanVarDataType* resultRunningMean, + MeanVarDataType* resultRunningVariance) + : bnScaleBiasMeanVarLengths_(bnScaleBiasMeanVarLengths), + bnScaleStrides_(bnScaleStrides), + bnBiasStrides_(bnBiasStrides), + bnMeanVarStrides_(bnMeanVarStrides), + p_x_(p_x), + p_scale_(p_scale), + p_bias_(p_bias), + y_elementwise_op_(y_elementwise_op), + p_y_(p_y), + resultSaveMean_(resultSaveMean), + resultSaveInvVariance_(resultSaveInvVariance), + resultRunningMean_(resultRunningMean), + resultRunningVariance_(resultRunningVariance) + { + xyLengths_ = + shuffle_tensor_dimensions(xyLengths, reduceDims); + xStrides_ = + shuffle_tensor_dimensions(xStrides, reduceDims); + yStrides_ = + shuffle_tensor_dimensions(yStrides, reduceDims); + + std::tie(invariant_length_, reduce_length_) = + get_2d_lengths(xyLengths_); + + epsilon_ = type_convert(epsilon); + averageFactor_ = type_convert(averageFactor); + + updateMovingAverage_ = + (resultRunningMean != nullptr && resultRunningVariance != nullptr); + saveMeanInvVariance_ = (resultSaveMean != nullptr && resultSaveInvVariance_ != nullptr); + + if(UseMultiblockInK) + { + int iterations = 1; + while(true) + { + int testBlkGroupSize = (reduce_length_ + (K_BlockTileSize * iterations) - 1) / + (K_BlockTileSize * iterations); + + // we want the blkGroupSize be not more than 128 + if(testBlkGroupSize <= 128) + break; + + iterations++; + }; + + blkGroupSize_ = (reduce_length_ + (K_BlockTileSize * iterations) - 1) / + (K_BlockTileSize * iterations); + + numBlockTileIteration_ = iterations; + } + else + { + blkGroupSize_ = 1; + numBlockTileIteration_ = (reduce_length_ + K_BlockTileSize - 1) / K_BlockTileSize; + }; + + gridSize_ = (invariant_length_ + M_BlockTileSize - 1) / M_BlockTileSize * blkGroupSize_; + + x_grid_desc_m_k_ = + MakeXY2dDescriptor(xyLengths_, xStrides_, blkGroupSize_, numBlockTileIteration_); + y_grid_desc_m_k_ = + MakeXY2dDescriptor(xyLengths_, yStrides_, blkGroupSize_, numBlockTileIteration_); + scale_grid_desc_m_ = + MakeScaleBiasMeanVar1dDescriptor(bnScaleBiasMeanVarLengths, bnScaleStrides_); + bias_grid_desc_m_ = + MakeScaleBiasMeanVar1dDescriptor(bnScaleBiasMeanVarLengths, bnBiasStrides_); + mean_var_grid_desc_m_ = + MakeScaleBiasMeanVar1dDescriptor(bnScaleBiasMeanVarLengths, bnMeanVarStrides_); + } + + AccDataType epsilon_; + AccDataType averageFactor_; + + bool updateMovingAverage_; + bool saveMeanInvVariance_; + + std::array xyLengths_; + std::array xStrides_; + std::array yStrides_; + + std::array bnScaleBiasMeanVarLengths_; + std::array bnScaleStrides_; + std::array bnBiasStrides_; + std::array bnMeanVarStrides_; + + const XDataType* p_x_; + const ScaleDataType* p_scale_; + const BiasDataType* p_bias_; + const YElementwiseOp y_elementwise_op_; + YDataType* p_y_; + + MeanVarDataType* resultSaveMean_; + MeanVarDataType* resultSaveInvVariance_; + + MeanVarDataType* resultRunningMean_; + MeanVarDataType* resultRunningVariance_; + + long_index_t invariant_length_; + long_index_t reduce_length_; + + int blkGroupSize_; + int numBlockTileIteration_; + size_t gridSize_; + + XYGridDesc_M_K x_grid_desc_m_k_; + 
XYGridDesc_M_K y_grid_desc_m_k_; + ScaleBiasMeanVarGridDesc_M scale_grid_desc_m_; + ScaleBiasMeanVarGridDesc_M bias_grid_desc_m_; + ScaleBiasMeanVarGridDesc_M mean_var_grid_desc_m_; + + void* workspace_mean_; + void* workspace_variance_; + void* workspace_count_; + }; + + size_t GetWorkSpaceSize(const BaseArgument* pArg) const override + { + const Argument* pArg_ = dynamic_cast(pArg); + + size_t workspace_size = 0; + + if(UseMultiblockInK && pArg_->blkGroupSize_ > 1) + { + // workspace for welford intermediate mean + workspace_size += + pArg_->invariant_length_ * pArg_->blkGroupSize_ * sizeof(MeanVarDataType) + 64; + + // workspace for welford intermediate variance + workspace_size += + pArg_->invariant_length_ * pArg_->blkGroupSize_ * sizeof(MeanVarDataType) + 64; + + // workspace for welford intermediate count + workspace_size += + pArg_->invariant_length_ * pArg_->blkGroupSize_ * sizeof(int32_t) + 64; + } + + return (workspace_size); + }; + + void SetWorkSpacePointer(BaseArgument* pArg, void* p_workspace) const override + { + Argument* pArg_ = dynamic_cast(pArg); + + pArg_->p_workspace_ = p_workspace; + + if(UseMultiblockInK && pArg_->blkGroupSize_ > 1) + { + + // setup buffer used for intermediate welford mean + pArg_->workspace_mean_ = static_cast(pArg_->p_workspace_); + + index_t mean_space_sz = + pArg_->invariant_length_ * pArg_->blkGroupSize_ * sizeof(MeanVarDataType); + + mean_space_sz = math::integer_least_multiple(mean_space_sz, 64); + + // setup buffer used for intermediate welford varirance + pArg_->workspace_variance_ = + reinterpret_cast(pArg_->workspace_mean_) + mean_space_sz; + + index_t variance_space_sz = + pArg_->invariant_length_ * pArg_->blkGroupSize_ * sizeof(MeanVarDataType); + + variance_space_sz = math::integer_least_multiple(variance_space_sz, 64); + + // setup buffer used for intermediate welfor count + pArg_->workspace_count_ = + reinterpret_cast(pArg_->workspace_variance_) + variance_space_sz; + }; + }; + + struct Invoker : public BaseInvoker + { + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + float avg_time = 0; + + if(UseMultiblockInK && arg.blkGroupSize_ > 1) + { + using GetReduceCountPerThreadFunctor = + GetReduceCountPerThreadForMultiblockWelford; + + GetReduceCountPerThreadFunctor get_reduce_count_per_thread( + arg.blkGroupSize_, arg.numBlockTileIteration_, arg.reduce_length_); + + const auto mean_var_count_grid_desc_m_g = + DeviceBatchNormFwdImpl::MakeMeanVarCountOutputMG2dDescriptor( + arg.invariant_length_, arg.blkGroupSize_); + + const auto mean_var_count_grid_desc_m_k = + DeviceBatchNormFwdImpl::MakeMeanVarCountInputMK2dDescriptor( + arg.invariant_length_, arg.blkGroupSize_); + + using MeanVarCountGridDesc_M_G = decltype(mean_var_count_grid_desc_m_g); + using MeanVarCountGridDesc_M_K = decltype(mean_var_count_grid_desc_m_k); + + using GridwiseMultiblockWelfordFirstHalf_ = + GridwiseMultiblockWelfordFirstHalf; + + using GridwiseWelfordSecondHalfBatchNormForwardFinal_ = + GridwiseWelfordSecondHalfBatchNormForwardFinal; + + index_t numMeanVarCountBlockTileIteration = + (arg.blkGroupSize_ + KThreadClusterSize - 1) / KThreadClusterSize; + + const auto kern_multiblock_welford_first_half = + kernel_multiblock_welford_first_half; + + const auto kern_welford_second_half_batchnorm_forward_final = + kernel_welford_second_half_batchnorm_forward_final< + GridwiseWelfordSecondHalfBatchNormForwardFinal_, + XDataType, + YDataType, + AccDataType, + ScaleDataType, + BiasDataType, + MeanVarDataType, + YElementwiseOp, + 
XYGridDesc_M_K, + MeanVarCountGridDesc_M_K, + ScaleBiasMeanVarGridDesc_M, + ScaleBiasMeanVarGridDesc_M>; + + avg_time += + launch_and_time_kernel(stream_config, + kern_multiblock_welford_first_half, + dim3(arg.gridSize_), + dim3(BlockSize), + 0, + arg.x_grid_desc_m_k_, + mean_var_count_grid_desc_m_g, + get_reduce_count_per_thread, + arg.numBlockTileIteration_, + arg.p_x_, + static_cast(arg.workspace_mean_), + static_cast(arg.workspace_variance_), + static_cast(arg.workspace_count_)); + + avg_time += + launch_and_time_kernel(stream_config, + kern_welford_second_half_batchnorm_forward_final, + dim3(arg.gridSize_), + dim3(BlockSize), + 0, + arg.x_grid_desc_m_k_, + arg.y_grid_desc_m_k_, + mean_var_count_grid_desc_m_k, + arg.scale_grid_desc_m_, + arg.bias_grid_desc_m_, + arg.mean_var_grid_desc_m_, + arg.blkGroupSize_, + arg.numBlockTileIteration_, + numMeanVarCountBlockTileIteration, + arg.epsilon_, + static_cast(arg.workspace_mean_), + static_cast(arg.workspace_variance_), + static_cast(arg.workspace_count_), + arg.p_x_, + arg.p_scale_, + arg.p_bias_, + arg.y_elementwise_op_, + arg.p_y_, + arg.updateMovingAverage_, + arg.averageFactor_, + arg.resultRunningMean_, + arg.resultRunningVariance_, + arg.saveMeanInvVariance_, + arg.resultSaveMean_, + arg.resultSaveInvVariance_); + } + else + { + using GetReduceCountPerThreadFunctor = + GetReduceCountPerThreadForBlockwiseWelford; + + GetReduceCountPerThreadFunctor get_reduce_count_per_thread( + arg.numBlockTileIteration_, arg.reduce_length_); + + using GridwiseBatchNormForwardWithBlockwiseWelford_ = + GridwiseBatchNormForwardWithBlockwiseWelford; + + const auto kern_batchnorm_fwd = kernel_batchnorm_forward_with_blockwise_welford< + GridwiseBatchNormForwardWithBlockwiseWelford_, + XDataType, + YDataType, + AccDataType, + ScaleDataType, + BiasDataType, + MeanVarDataType, + YElementwiseOp, + XYGridDesc_M_K, + ScaleBiasMeanVarGridDesc_M, + ScaleBiasMeanVarGridDesc_M, + GetReduceCountPerThreadFunctor>; + + avg_time += launch_and_time_kernel(stream_config, + kern_batchnorm_fwd, + dim3(arg.gridSize_), + dim3(BlockSize), + 0, + arg.x_grid_desc_m_k_, + arg.y_grid_desc_m_k_, + arg.scale_grid_desc_m_, + arg.bias_grid_desc_m_, + arg.mean_var_grid_desc_m_, + get_reduce_count_per_thread, + arg.numBlockTileIteration_, + arg.epsilon_, + arg.p_x_, + arg.p_scale_, + arg.p_bias_, + arg.y_elementwise_op_, + arg.p_y_, + arg.updateMovingAverage_, // true or false + arg.averageFactor_, + arg.resultRunningMean_, + arg.resultRunningVariance_, + arg.saveMeanInvVariance_, // true or false + arg.resultSaveMean_, + arg.resultSaveInvVariance_); + }; + + return (avg_time); + }; + + float Run(const BaseArgument* pArg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(pArg), stream_config); + }; + }; + + bool IsSupportedArgument(const BaseArgument* pArg) override + { + const Argument* pArg_ = dynamic_cast(pArg); + + if constexpr(XSrcYDstVectorDim == 0) + { + if(pArg_->xStrides_[NumInvariantDim - 1] != 1 || + pArg_->yStrides_[NumInvariantDim - 1] != 1) + return false; + + if(pArg_->xyLengths_[NumInvariantDim - 1] % XSrcVectorSize != 0 || + pArg_->xyLengths_[NumInvariantDim - 1] % YDstVectorSize != 0) + return false; + } + else + { + if(pArg_->xStrides_[Rank - 1] != 1 || pArg_->yStrides_[Rank - 1] != 1) + return false; + + if(pArg_->xyLengths_[Rank - 1] % XSrcVectorSize != 0 || + pArg_->xyLengths_[Rank - 1] % YDstVectorSize != 0) + return false; + }; + + if(pArg_->bnScaleStrides_[NumInvariantDim - 1] != 1 && ScaleSrcVectorSize != 1) + return 
false; + if(pArg_->bnBiasStrides_[NumInvariantDim - 1] != 1 && BiasSrcVectorSize != 1) + return false; + + if(pArg_->bnScaleBiasMeanVarLengths_[NumInvariantDim - 1] % ScaleSrcVectorSize != 0) + return false; + if(pArg_->bnScaleBiasMeanVarLengths_[NumInvariantDim - 1] % BiasSrcVectorSize != 0) + return false; + + if(pArg_->bnMeanVarStrides_[NumInvariantDim - 1] != 1 && MeanVarSrcDstVectorSize != 1) + return false; + + if(pArg_->bnScaleBiasMeanVarLengths_[NumInvariantDim - 1] % MeanVarSrcDstVectorSize != 0) + return false; + + bool is_valid = true; + + static_for<0, NumInvariantDim, 1>{}([&](auto I) { + if(pArg_->xyLengths_[I] != pArg_->bnScaleBiasMeanVarLengths_[I]) + is_valid = false; + }); + + if(!is_valid) + return false; + + return true; + }; + + std::unique_ptr MakeArgumentPointer( + const std::array xyLengths, + const std::array xStrides, + const std::array yStrides, + const std::array reduceDims, + const std::array bnScaleBiasMeanVarLengths, + const std::array bnScaleStrides, + const std::array bnBiasStrides, + const std::array bnMeanVarStrides, + const void* p_x, + const void* p_scale, + const void* p_bias, + double epsilon, + const YElementwiseOp y_elementwise_op, + void* p_y, + void* resultSaveMean, + void* resultSaveInvVariance, + double averageFactor, + void* resultRunningMean, + void* resultRunningVariance) override + { + return std::make_unique(xyLengths, + xStrides, + yStrides, + reduceDims, + bnScaleBiasMeanVarLengths, + bnScaleStrides, + bnBiasStrides, + bnMeanVarStrides, + static_cast(p_x), + static_cast(p_scale), + static_cast(p_bias), + y_elementwise_op, + epsilon, + static_cast(p_y), + static_cast(resultSaveMean), + static_cast(resultSaveInvVariance), + averageFactor, + static_cast(resultRunningMean), + static_cast(resultRunningVariance)); + }; + + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(); + }; + + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "DeviceBatchNormFwdImpl<" << BlockSize << ","; + str << "M_C" << MThreadClusterSize << "_S" << MThreadSliceSize << ","; + str << "K_C" << KThreadClusterSize << "_S" << KThreadSliceSize << ","; + str << "XSrcYDstVectorDim_" << XSrcYDstVectorDim << ","; + str << "VectorSize_X" << XSrcVectorSize << "_scale_" << ScaleSrcVectorSize << "_bias_" << BiasSrcVectorSize << "_mean_var_" << MeanVarSrcDstVectorSize << "_Y" << YDstVectorSize << ">"; + // clang-format on + + return str.str(); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/welford_helper.hpp b/include/ck/tensor_operation/gpu/device/welford_helper.hpp new file mode 100644 index 0000000000..6c909b767d --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/welford_helper.hpp @@ -0,0 +1,89 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
+ +#pragma once + +namespace ck { +namespace tensor_operation { +namespace device { + +template +struct GetReduceCountPerThreadForBlockwiseWelford +{ + GetReduceCountPerThreadForBlockwiseWelford(index_t numBlockTileIteration, + long_index_t reduce_length) + : numBlockTileIteration_{numBlockTileIteration} + { + count_in_last_tile_ = reduce_length % K_BlockTileSize; + }; + + __device__ index_t operator()(index_t thread_k_cluster_id) const + { + if(count_in_last_tile_ == 0) + return (KThreadSliceSize * numBlockTileIteration_); + else + { + index_t num_complete_slice = count_in_last_tile_ / KThreadSliceSize; + index_t count_in_last_slice = count_in_last_tile_ % KThreadSliceSize; + + if(thread_k_cluster_id < num_complete_slice) + return (KThreadSliceSize * numBlockTileIteration_); + else if(thread_k_cluster_id == num_complete_slice) + return (KThreadSliceSize * (numBlockTileIteration_ - 1) + count_in_last_slice); + else + return (KThreadSliceSize * (numBlockTileIteration_ - 1)); + }; + }; + + index_t numBlockTileIteration_; + index_t count_in_last_tile_; +}; + +template +struct GetReduceCountPerThreadForMultiblockWelford +{ + GetReduceCountPerThreadForMultiblockWelford(index_t blkGroupSize, + index_t numBlockTileIteration, + long_index_t reduce_length) + : blkGroupSize_(blkGroupSize), numBlockTileIteration_{numBlockTileIteration} + { + last_block_reduce_length_ = + reduce_length - K_BlockTileSize * numBlockTileIteration_ * (blkGroupSize_ - 1); + numBlockTileIterationByLastBlock_ = + (last_block_reduce_length_ + K_BlockTileSize - 1) / K_BlockTileSize; + }; + + __device__ index_t operator()(index_t block_local_id, index_t thread_k_cluster_id) const + { + if(last_block_reduce_length_ == K_BlockTileSize * numBlockTileIteration_ || + block_local_id < blkGroupSize_ - 1) + return (KThreadSliceSize * numBlockTileIteration_); + + index_t count_in_last_tile = last_block_reduce_length_ % K_BlockTileSize; + + if(count_in_last_tile == 0) + return (KThreadSliceSize * numBlockTileIterationByLastBlock_); + else + { + index_t num_complete_slice = count_in_last_tile / KThreadSliceSize; + + if(thread_k_cluster_id < num_complete_slice) + return (KThreadSliceSize * numBlockTileIterationByLastBlock_); + else if(thread_k_cluster_id == num_complete_slice) + return (KThreadSliceSize * (numBlockTileIterationByLastBlock_ - 1) + + count_in_last_tile); + else + return (KThreadSliceSize * (numBlockTileIterationByLastBlock_ - 1)); + }; + }; + + index_t blkGroupSize_; + index_t numBlockTileIteration_; + + index_t last_block_reduce_length_; + index_t numBlockTileIterationByLastBlock_; +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/grid/batchnorm_multiblock/gridwise_multiblock_welford_first_half.hpp b/include/ck/tensor_operation/gpu/grid/batchnorm_multiblock/gridwise_multiblock_welford_first_half.hpp new file mode 100644 index 0000000000..1afe9f9752 --- /dev/null +++ b/include/ck/tensor_operation/gpu/grid/batchnorm_multiblock/gridwise_multiblock_welford_first_half.hpp @@ -0,0 +1,258 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
+ +#pragma once + +#include "ck/utility/data_type.hpp" +#include "ck/utility/math.hpp" +#include "ck/tensor_operation/gpu/block/blockwise_welford.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_welford.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +namespace ck { + +template +__global__ void kernel_multiblock_welford_first_half( + const XGridDesc_M_K x_grid_desc_m_k, + const MeanVarCountGridDesc_M_G mean_var_count_grid_desc_m_g, + const GetReduceCountPerThreadFunctor get_reduce_count_per_thread, + index_t num_k_block_tile_iteration, + const XDataType* const __restrict__ p_x, + MeanVarDataType* const p_welford_mean, + MeanVarDataType* const p_welford_variance, + int32_t* const p_welford_count) +{ + GridwiseMultiblockWelfordFirstHalf_::Run(x_grid_desc_m_k, + mean_var_count_grid_desc_m_g, + get_reduce_count_per_thread, + num_k_block_tile_iteration, + p_x, + p_welford_mean, + p_welford_variance, + p_welford_count); +}; + +template +struct GridwiseMultiblockWelfordFirstHalf +{ + static_assert((XSrcCountSrcVectorDim == 0 && MThreadSliceSize % XSrcCountSrcVectorSize == 0) || + (XSrcCountSrcVectorDim == 1 && + KThreadSliceSize % XSrcCountSrcVectorSize == 0), + "Invalid thread slice sizes and/or vector sizes configuration, please check!"); + + static constexpr bool reorder_thread_cluster = (XSrcCountSrcVectorDim == 0); + + using ThreadClusterLengths_M_K = Sequence; + + using ThreadBufferDimAccessOrder = + typename conditional, Sequence<0, 1>>::type; + + using ThreadClusterArrangeOrder = + typename conditional, Sequence<0, 1>>::type; + + static constexpr auto thread_cluster_desc = + make_cluster_descriptor(ThreadClusterLengths_M_K{}, ThreadClusterArrangeOrder{}); + + using ThreadReduceSrcDesc_M_K = decltype(make_naive_tensor_descriptor_packed( + make_tuple(Number{}, Number{}))); + using ThreadReduceDstDesc_M = + decltype(make_naive_tensor_descriptor_packed(make_tuple(Number{}))); + + using ThreadwiseWelford = + ThreadwiseWelford; + + using BlockwiseWelford = BlockwiseWelford; + + using PassThroughOp = tensor_operation::element_wise::PassThrough; + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + + static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize; + static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize; + + __device__ static void Run(const XGridDesc_M_K& x_grid_desc_m_k, + const MeanVarCountGridDesc_M_G& mean_var_count_grid_desc_m_g, + const GetReduceCountPerThreadFunctor& get_reduce_count_per_thread, + index_t num_k_block_tile_iteration, + const XDataType* const __restrict__ p_x, + MeanVarDataType* const p_welford_mean, + MeanVarDataType* const p_welford_variance, + int32_t* const p_welford_count) + { + StaticBuffer + x_thread_buf; + + StaticBuffer + welford_mean_thread_buf; + StaticBuffer + welford_var_thread_buf; + StaticBuffer + welford_count_thread_buf; + + const index_t blkgroup_size = mean_var_count_grid_desc_m_g.GetLength(I1); + + const index_t thread_local_id = get_thread_local_1d_id(); + const index_t block_global_id = get_block_1d_id(); + const index_t blkgroup_id = block_global_id / blkgroup_size; + const index_t block_local_id = block_global_id % blkgroup_size; + + const auto thread_cluster_idx = + thread_cluster_desc.CalculateBottomIndex(make_multi_index(thread_local_id)); + + const auto thread_m_cluster_id = thread_cluster_idx[I0]; + const auto thread_k_cluster_id = 
thread_cluster_idx[I1]; + + using ThreadBufferLengths_M_K = Sequence; + using ThreadBufferLengths_M_1 = Sequence; + + constexpr auto thread_buffer_desc_m_k = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, Number{})); + constexpr auto thread_buffer_desc_m_1 = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, Number<1>{})); + + const index_t reduceSizePerBlock = K_BlockTileSize * num_k_block_tile_iteration; + + auto threadwise_x_load = ThreadwiseTensorSliceTransfer_v2( + x_grid_desc_m_k, + make_multi_index(blkgroup_id * M_BlockTileSize + thread_m_cluster_id * MThreadSliceSize, + block_local_id * reduceSizePerBlock + + thread_k_cluster_id * KThreadSliceSize)); + + auto threadwise_welford_mean_var_store = + ThreadwiseTensorSliceTransfer_v1r3, + 1, + 1, + InMemoryDataOperationEnum::Set, + 1, + true>( + mean_var_count_grid_desc_m_g, + make_multi_index(blkgroup_id * M_BlockTileSize + + thread_m_cluster_id * MThreadSliceSize, + block_local_id), + PassThroughOp{}); + + auto threadwise_welford_count_store = + ThreadwiseTensorSliceTransfer_v1r3, + 1, + 1, + InMemoryDataOperationEnum::Set, + 1, + true>( + mean_var_count_grid_desc_m_g, + make_multi_index(blkgroup_id * M_BlockTileSize + + thread_m_cluster_id * MThreadSliceSize, + block_local_id), + PassThroughOp{}); + + constexpr auto thread_copy_fwd_step_m_k = make_multi_index(0, K_BlockTileSize); + + const auto x_global_val_buf = make_dynamic_buffer( + p_x, x_grid_desc_m_k.GetElementSpaceSize()); + + auto welford_mean_global_val_buf = make_dynamic_buffer( + p_welford_mean, mean_var_count_grid_desc_m_g.GetElementSpaceSize()); + + auto welford_var_global_val_buf = make_dynamic_buffer( + p_welford_variance, mean_var_count_grid_desc_m_g.GetElementSpaceSize()); + + auto welford_count_global_val_buf = make_dynamic_buffer( + p_welford_count, mean_var_count_grid_desc_m_g.GetElementSpaceSize()); + + auto threadwise_welford = ThreadwiseWelford(); + threadwise_welford.max_count_ = + get_reduce_count_per_thread(block_local_id, thread_k_cluster_id); + + static_for<0, MThreadSliceSize, 1>{}([&](auto I) { + welford_mean_thread_buf(I) = type_convert(0.0f); + welford_var_thread_buf(I) = type_convert(0.0f); + }); + + for(index_t reducedTiles = 0; reducedTiles < num_k_block_tile_iteration; ++reducedTiles) + { + threadwise_x_load.Run(x_grid_desc_m_k, + x_global_val_buf, + thread_buffer_desc_m_k, + make_tuple(I0, I0), + x_thread_buf); + + threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, thread_copy_fwd_step_m_k); + threadwise_welford.Run(x_thread_buf, welford_mean_thread_buf, welford_var_thread_buf); + } + + static_for<0, MThreadSliceSize, 1>{}([&](auto I) { + if constexpr(I > 0) + block_sync_lds(); + + welford_count_thread_buf(I) = threadwise_welford.cur_count_; + BlockwiseWelford::Run( + welford_mean_thread_buf(I), welford_var_thread_buf(I), welford_count_thread_buf(I)); + }); + + if(thread_k_cluster_id == 0) + { + threadwise_welford_mean_var_store.Run(thread_buffer_desc_m_1, + make_tuple(I0, I0), + welford_mean_thread_buf, + mean_var_count_grid_desc_m_g, + welford_mean_global_val_buf); + + threadwise_welford_mean_var_store.Run(thread_buffer_desc_m_1, + make_tuple(I0, I0), + welford_var_thread_buf, + mean_var_count_grid_desc_m_g, + welford_var_global_val_buf); + + threadwise_welford_count_store.Run(thread_buffer_desc_m_1, + make_tuple(I0, I0), + welford_count_thread_buf, + mean_var_count_grid_desc_m_g, + welford_count_global_val_buf); + }; + } +}; + +} // namespace ck diff --git 
a/include/ck/tensor_operation/gpu/grid/batchnorm_multiblock/gridwise_multiblock_welford_second_half_batchnorm_forward_final.hpp b/include/ck/tensor_operation/gpu/grid/batchnorm_multiblock/gridwise_multiblock_welford_second_half_batchnorm_forward_final.hpp new file mode 100644 index 0000000000..53d3e8aee7 --- /dev/null +++ b/include/ck/tensor_operation/gpu/grid/batchnorm_multiblock/gridwise_multiblock_welford_second_half_batchnorm_forward_final.hpp @@ -0,0 +1,570 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/data_type.hpp" +#include "ck/utility/math_v2.hpp" +#include "ck/tensor_operation/gpu/block/blockwise_welford.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_welford.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +namespace ck { + +template +__global__ void kernel_welford_second_half_batchnorm_forward_final( + const XYGridDesc_M_K x_grid_desc_m_k, + const XYGridDesc_M_K y_grid_desc_m_k, + const MeanVarCountGridDesc_M_K mean_var_count_grid_desc_m_k, + const ScaleBiasGridDesc_M scale_grid_desc_m, + const ScaleBiasGridDesc_M bias_grid_desc_m, + const MeanVarGridDesc_M mean_var_grid_desc_m, + index_t blkgroup_size, + index_t num_xy_k_block_tile_iteration, + index_t num_mean_var_count_k_block_tile_iteration, + AccDataType epsilon, + const MeanVarDataType* const __restrict__ p_in_welford_mean, + const MeanVarDataType* const __restrict__ p_in_welford_variance, + const int32_t* const __restrict__ p_in_welford_count, + const XDataType* const __restrict__ p_x, + const ScaleDataType* const __restrict__ p_scale, + const BiasDataType* const __restrict__ p_bias, + const YElementwiseOp y_elementwise_op, + YDataType* const __restrict__ p_y, + bool updateMovingAverage, + AccDataType averageFactor, + MeanVarDataType* const __restrict__ resultRunningMean, + MeanVarDataType* const __restrict__ resultRunningVariance, + bool saveMeanInvVariance, + MeanVarDataType* const __restrict__ resultSaveMean, + MeanVarDataType* const __restrict__ resultSaveInvVariance) +{ + GridwiseWelfordSecondHalfBatchNormForwardFinal_::Run(x_grid_desc_m_k, + y_grid_desc_m_k, + mean_var_count_grid_desc_m_k, + scale_grid_desc_m, + bias_grid_desc_m, + mean_var_grid_desc_m, + blkgroup_size, + num_xy_k_block_tile_iteration, + num_mean_var_count_k_block_tile_iteration, + epsilon, + p_in_welford_mean, + p_in_welford_variance, + p_in_welford_count, + p_x, + p_scale, + p_bias, + y_elementwise_op, + p_y, + updateMovingAverage, + averageFactor, + resultRunningMean, + resultRunningVariance, + saveMeanInvVariance, + resultSaveMean, + resultSaveInvVariance); +}; + +template +struct GridwiseWelfordSecondHalfBatchNormForwardFinal +{ + static_assert((XSrcYDstVectorDim == 0 && MThreadSliceSize % XSrcVectorSize == 0) || + (XSrcYDstVectorDim == 1 && KThreadSliceSize % XSrcVectorSize == 0), + "Invalid thread slice sizes and/or vector sizes configuration, please check!"); + + static_assert((XSrcYDstVectorDim == 0 && MThreadSliceSize % YDstVectorSize == 0) || + (XSrcYDstVectorDim == 1 && KThreadSliceSize % YDstVectorSize == 0), + "Invalid thread slice sizes and/or vector sizes configuration, please check!"); + + static constexpr bool reorder_thread_cluster = (XSrcYDstVectorDim == 0); + + using ThreadClusterLengths_M_K = Sequence; + + using ThreadBufferDimAccessOrder = + typename conditional, Sequence<0, 1>>::type; + + using 
ThreadClusterArrangeOrder = + typename conditional, Sequence<0, 1>>::type; + + static constexpr auto thread_cluster_desc = + make_cluster_descriptor(ThreadClusterLengths_M_K{}, ThreadClusterArrangeOrder{}); + + using ThreadReduceSrcDesc_M_1 = decltype( + make_naive_tensor_descriptor_packed(make_tuple(Number{}, Number<1>{}))); + using ThreadReduceDstDesc_M = + decltype(make_naive_tensor_descriptor_packed(make_tuple(Number{}))); + + using ThreadwiseWelford = + ThreadwiseWelfordMerge; + + using BlockwiseWelford = BlockwiseWelford; + + using PassThroughOp = tensor_operation::element_wise::PassThrough; + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + + static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize; + static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize; + + __device__ static void Run(const XYGridDesc_M_K& x_grid_desc_m_k, + const XYGridDesc_M_K& y_grid_desc_m_k, + const MeanVarCountGridDesc_M_K& mean_var_count_grid_desc_m_k, + const ScaleBiasGridDesc_M& scale_grid_desc_m, + const ScaleBiasGridDesc_M& bias_grid_desc_m, + const MeanVarGridDesc_M& mean_var_grid_desc_m, + index_t blkgroup_size, + index_t num_xy_k_block_tile_iteration, + index_t num_mean_var_count_k_block_tile_iteration, + AccDataType epsilon, + const MeanVarDataType* const __restrict__ p_in_welford_mean, + const MeanVarDataType* const __restrict__ p_in_welford_variance, + const int32_t* const __restrict__ p_in_welford_count, + const XDataType* const __restrict__ p_x, + const ScaleDataType* const __restrict__ p_scale, + const BiasDataType* const __restrict__ p_bias, + const YElementwiseOp y_elementwise_op, + YDataType* const __restrict__ p_y, + bool updateMovingAverage, + AccDataType averageFactor, + MeanVarDataType* const __restrict__ resultRunningMean, + MeanVarDataType* const __restrict__ resultRunningVariance, + bool saveMeanInvVariance, + MeanVarDataType* const __restrict__ resultSaveMean, + MeanVarDataType* const __restrict__ resultSaveInvVariance) + + { + using ck::math::sqrt; + + StaticBuffer + in_welford_mean_thread_buf; + StaticBuffer + in_welford_var_thread_buf; + StaticBuffer + in_welford_count_thread_buf; + + StaticBuffer + welford_mean_thread_buf; + StaticBuffer + welford_var_thread_buf; + StaticBuffer + welford_count_thread_buf; + + StaticBuffer + x_thread_buf; + StaticBuffer + y_thread_buf; + + StaticBuffer scale_thread_buf; + StaticBuffer bias_thread_buf; + + const index_t thread_local_id = get_thread_local_1d_id(); + const index_t block_global_id = get_block_1d_id(); + const index_t blkgroup_id = block_global_id / blkgroup_size; + const index_t block_local_id = block_global_id % blkgroup_size; + + const auto thread_cluster_idx = + thread_cluster_desc.CalculateBottomIndex(make_multi_index(thread_local_id)); + + const auto thread_m_cluster_id = thread_cluster_idx[I0]; + const auto thread_k_cluster_id = thread_cluster_idx[I1]; + + using ThreadBufferLengths_M_K = Sequence; + using ThreadBufferLengths_M = Sequence; + using ThreadBufferLengths_M_1 = Sequence; + constexpr auto thread_buffer_desc_m_k = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, Number{})); + constexpr auto thread_buffer_desc_m = + make_naive_tensor_descriptor_packed(make_tuple(Number{})); + constexpr auto thread_buffer_desc_m_1 = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, Number<1>{})); + + auto threadwise_mean_var_load_m_k = + ThreadwiseTensorSliceTransfer_v2, + 1, + 1, + 1, + true>( + mean_var_count_grid_desc_m_k, + 
make_multi_index(blkgroup_id * M_BlockTileSize + + thread_m_cluster_id * MThreadSliceSize, + thread_k_cluster_id * 1)); + + auto threadwise_count_load_m_k = + ThreadwiseTensorSliceTransfer_v2, + 1, + 1, + 1, + true>( + mean_var_count_grid_desc_m_k, + make_multi_index(blkgroup_id * M_BlockTileSize + + thread_m_cluster_id * MThreadSliceSize, + thread_k_cluster_id * 1)); + + const auto welford_mean_global_val_buf = make_dynamic_buffer( + p_in_welford_mean, mean_var_count_grid_desc_m_k.GetElementSpaceSize()); + + const auto welford_var_global_val_buf = make_dynamic_buffer( + p_in_welford_variance, mean_var_count_grid_desc_m_k.GetElementSpaceSize()); + + const auto welford_count_global_val_buf = make_dynamic_buffer( + p_in_welford_count, mean_var_count_grid_desc_m_k.GetElementSpaceSize()); + + constexpr auto mean_var_count_thread_copy_step_m_k = + make_multi_index(0, KThreadClusterSize * 1); + + // Step 1: do final welford reduction to get mean and variance + + static_for<0, MThreadSliceSize, 1>{}([&](auto I) { + welford_mean_thread_buf(I) = type_convert(0.0f); + welford_var_thread_buf(I) = type_convert(0.0f); + welford_count_thread_buf(I) = 0; + }); + + for(index_t reducedTiles = 0; reducedTiles < num_mean_var_count_k_block_tile_iteration; + ++reducedTiles) + { + threadwise_mean_var_load_m_k.Run(mean_var_count_grid_desc_m_k, + welford_mean_global_val_buf, + thread_buffer_desc_m_1, + make_tuple(I0, I0), + in_welford_mean_thread_buf); + + threadwise_mean_var_load_m_k.Run(mean_var_count_grid_desc_m_k, + welford_var_global_val_buf, + thread_buffer_desc_m_1, + make_tuple(I0, I0), + in_welford_var_thread_buf); + + threadwise_count_load_m_k.Run(mean_var_count_grid_desc_m_k, + welford_count_global_val_buf, + thread_buffer_desc_m_1, + make_tuple(I0, I0), + in_welford_count_thread_buf); + + ThreadwiseWelford::Run(in_welford_mean_thread_buf, + in_welford_var_thread_buf, + in_welford_count_thread_buf, + welford_mean_thread_buf, + welford_var_thread_buf, + welford_count_thread_buf); + + threadwise_mean_var_load_m_k.MoveSrcSliceWindow(mean_var_count_grid_desc_m_k, + mean_var_count_thread_copy_step_m_k); + threadwise_count_load_m_k.MoveSrcSliceWindow(mean_var_count_grid_desc_m_k, + mean_var_count_thread_copy_step_m_k); + } + + static_for<0, MThreadSliceSize, 1>{}([&](auto I) { + if constexpr(I > 0) + block_sync_lds(); + + BlockwiseWelford::Run( + welford_mean_thread_buf(I), welford_var_thread_buf(I), welford_count_thread_buf(I)); + }); + + // Step 2: do normalization and output y + + const index_t workSizePerBlock = K_BlockTileSize * num_xy_k_block_tile_iteration; + + auto threadwise_x_load = ThreadwiseTensorSliceTransfer_v2( + x_grid_desc_m_k, + make_multi_index(blkgroup_id * M_BlockTileSize + thread_m_cluster_id * MThreadSliceSize, + workSizePerBlock * block_local_id + + thread_k_cluster_id * KThreadSliceSize)); + + auto threadwise_y_store = + ThreadwiseTensorSliceTransfer_v1r3( + y_grid_desc_m_k, + make_multi_index( + blkgroup_id * M_BlockTileSize + thread_m_cluster_id * MThreadSliceSize, + workSizePerBlock * block_local_id + thread_k_cluster_id * KThreadSliceSize), + y_elementwise_op); + + auto threadwise_scale_load = + ThreadwiseTensorSliceTransfer_v2, + 0, + ScaleSrcVectorSize, + 1, + true>( + scale_grid_desc_m, + make_multi_index(blkgroup_id * M_BlockTileSize + + thread_m_cluster_id * MThreadSliceSize)); + + auto threadwise_bias_load = ThreadwiseTensorSliceTransfer_v2, + 0, + BiasSrcVectorSize, + 1, + true>( + bias_grid_desc_m, + make_multi_index(blkgroup_id * M_BlockTileSize + + thread_m_cluster_id * 
MThreadSliceSize)); + + const auto x_global_val_buf = make_dynamic_buffer( + p_x, x_grid_desc_m_k.GetElementSpaceSize()); + + const auto scale_global_val_buf = make_dynamic_buffer( + p_scale, scale_grid_desc_m.GetElementSpaceSize()); + + const auto bias_global_val_buf = make_dynamic_buffer( + p_bias, bias_grid_desc_m.GetElementSpaceSize()); + + auto y_global_val_buf = make_dynamic_buffer( + p_y, y_grid_desc_m_k.GetElementSpaceSize()); + + threadwise_scale_load.Run(scale_grid_desc_m, + scale_global_val_buf, + thread_buffer_desc_m, + make_tuple(I0), + scale_thread_buf); + + threadwise_bias_load.Run(bias_grid_desc_m, + bias_global_val_buf, + thread_buffer_desc_m, + make_tuple(I0), + bias_thread_buf); + + constexpr auto xy_thread_copy_step_m_k = make_multi_index(0, K_BlockTileSize); + + for(index_t workTiles = 0; workTiles < num_xy_k_block_tile_iteration; ++workTiles) + { + threadwise_x_load.Run(x_grid_desc_m_k, + x_global_val_buf, + thread_buffer_desc_m_k, + make_tuple(I0, I0), + x_thread_buf); + + static_for<0, MThreadSliceSize, 1>{}([&](auto iM) { + AccDataType multiplier = + scale_thread_buf[iM] / sqrt(welford_var_thread_buf[iM] + epsilon); + + AccDataType fused_mean_bias = + bias_thread_buf[iM] - welford_mean_thread_buf[iM] * multiplier; + + static_for<0, KThreadSliceSize, 1>{}([&](auto iK) { + constexpr auto offset = + thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK)); + + y_thread_buf(Number{}) = + x_thread_buf[Number{}] * multiplier + fused_mean_bias; + }); + }); + + threadwise_y_store.Run(thread_buffer_desc_m_k, + make_tuple(I0, I0), + y_thread_buf, + y_grid_desc_m_k, + y_global_val_buf); + + threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, xy_thread_copy_step_m_k); + threadwise_y_store.MoveDstSliceWindow(y_grid_desc_m_k, xy_thread_copy_step_m_k); + } + + // Step 3: update the moving average of mean and variance (optional) + + if(updateMovingAverage && block_local_id == 0 && thread_k_cluster_id == 0) + { + StaticBuffer + running_mean_thread_buf; + StaticBuffer + running_var_thread_buf; + + auto running_mean_global_buf = make_dynamic_buffer( + resultRunningMean, mean_var_grid_desc_m.GetElementSpaceSize()); + + auto running_var_global_buf = make_dynamic_buffer( + resultRunningVariance, mean_var_grid_desc_m.GetElementSpaceSize()); + + auto threadwise_mean_var_load_m = + ThreadwiseTensorSliceTransfer_v2, + 0, + MeanVarSrcDstVectorSize, + 1, + true>( + mean_var_grid_desc_m, + make_multi_index(blkgroup_id * M_BlockTileSize + + thread_m_cluster_id * MThreadSliceSize)); + + threadwise_mean_var_load_m.Run(mean_var_grid_desc_m, + running_mean_global_buf, + thread_buffer_desc_m, + make_tuple(I0), + running_mean_thread_buf); + + threadwise_mean_var_load_m.Run(mean_var_grid_desc_m, + running_var_global_buf, + thread_buffer_desc_m, + make_tuple(I0), + running_var_thread_buf); + + AccDataType oneMinusAverageFactor = type_convert(1.0) - averageFactor; + + static_for<0, MThreadSliceSize, 1>{}([&](auto I) { + running_mean_thread_buf(I) = running_mean_thread_buf[I] * oneMinusAverageFactor + + welford_mean_thread_buf[I] * averageFactor; + running_var_thread_buf(I) = running_var_thread_buf[I] * oneMinusAverageFactor + + welford_var_thread_buf[I] * averageFactor; + }); + + auto threadwise_mean_var_store = + ThreadwiseTensorSliceTransfer_v1r3, + 0, + MeanVarSrcDstVectorSize, + InMemoryDataOperationEnum::Set, + 1, + true>( + mean_var_grid_desc_m, + make_multi_index(blkgroup_id * M_BlockTileSize + + thread_m_cluster_id * MThreadSliceSize), + PassThroughOp{}); + + 
threadwise_mean_var_store.Run(thread_buffer_desc_m, + make_tuple(I0), + running_mean_thread_buf, + mean_var_grid_desc_m, + running_mean_global_buf); + + threadwise_mean_var_store.Run(thread_buffer_desc_m, + make_tuple(I0), + running_var_thread_buf, + mean_var_grid_desc_m, + running_var_global_buf); + }; + + // Step 4: save mean and inv-variance (optional) + + if(saveMeanInvVariance && block_local_id == 0 && thread_k_cluster_id == 0) + { + auto result_mean_global_buf = make_dynamic_buffer( + resultSaveMean, mean_var_grid_desc_m.GetElementSpaceSize()); + + auto result_inv_var_global_buf = make_dynamic_buffer( + resultSaveInvVariance, mean_var_grid_desc_m.GetElementSpaceSize()); + + static_for<0, MThreadSliceSize, 1>{}([&](auto I) { + welford_var_thread_buf(I) = + type_convert(1.0f) / sqrt(epsilon + welford_var_thread_buf[I]); + }); + + auto threadwise_mean_inv_var_store = + ThreadwiseTensorSliceTransfer_v1r3, + 0, + MeanVarSrcDstVectorSize, + InMemoryDataOperationEnum::Set, + 1, + true>( + mean_var_grid_desc_m, + make_multi_index(blkgroup_id * M_BlockTileSize + + thread_m_cluster_id * MThreadSliceSize), + PassThroughOp{}); + + threadwise_mean_inv_var_store.Run(thread_buffer_desc_m, + make_tuple(I0), + welford_mean_thread_buf, + mean_var_grid_desc_m, + result_mean_global_buf); + + threadwise_mean_inv_var_store.Run(thread_buffer_desc_m, + make_tuple(I0), + welford_var_thread_buf, + mean_var_grid_desc_m, + result_inv_var_global_buf); + }; + } +}; + +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_batchnorm_forward_blockwise_welford.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_batchnorm_forward_blockwise_welford.hpp new file mode 100644 index 0000000000..b0c9ceb3da --- /dev/null +++ b/include/ck/tensor_operation/gpu/grid/gridwise_batchnorm_forward_blockwise_welford.hpp @@ -0,0 +1,482 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
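
[Note on the normalization step in the kernel above] The inner loop precomputes multiplier = scale / sqrt(var + eps) and fused_mean_bias = bias - mean * multiplier, so each element needs only one multiply-add; this is the "fuse mean and bias in the normalization calculation" change from the commit log. A minimal host-side sketch with made-up values (illustration only) showing that the fused form matches the plain batchnorm formula:

// Host-side sketch (illustration only): the fused form used in the kernels,
//   y = x * multiplier + fused_mean_bias,
// is algebraically identical to the textbook batchnorm
//   y = scale * (x - mean) / sqrt(var + eps) + bias.
#include <cassert>
#include <cmath>

int main()
{
    const float scale = 0.8f, bias = -0.3f, mean = 1.5f, var = 2.0f, eps = 1e-5f;

    const float multiplier      = scale / std::sqrt(var + eps);
    const float fused_mean_bias = bias - mean * multiplier;

    for(float x : {-2.0f, 0.0f, 0.5f, 3.25f})
    {
        const float y_reference = scale * (x - mean) / std::sqrt(var + eps) + bias;
        const float y_fused     = x * multiplier + fused_mean_bias;
        assert(std::fabs(y_reference - y_fused) < 1e-5f);
    }
    return 0;
}
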
+ +#pragma once + +#include "ck/utility/data_type.hpp" +#include "ck/utility/math_v2.hpp" +#include "ck/tensor_operation/gpu/block/blockwise_welford.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_welford.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +namespace ck { + +template +__global__ void kernel_batchnorm_forward_with_blockwise_welford( + const XYGridDesc_M_K x_grid_desc_m_k, + const XYGridDesc_M_K y_grid_desc_m_k, + const ScaleBiasGridDesc_M scale_grid_desc_m, + const ScaleBiasGridDesc_M bias_grid_desc_m, + const MeanVarGridDesc_M mean_var_grid_desc_m, + const GetReduceCountPerThreadFunctor get_reduce_count_per_thread, + index_t num_k_block_tile_iteration, + AccDataType epsilon, + const XDataType* const __restrict__ p_x, + const ScaleDataType* const __restrict__ p_scale, + const BiasDataType* const __restrict__ p_bias, + const YElementwiseOp y_elementwise_op, + YDataType* const __restrict__ p_y, + bool updateMovingAverage, + AccDataType averageFactor, + MeanVarDataType* const __restrict__ resultRunningMean, + MeanVarDataType* const __restrict__ resultRunningVariance, + bool saveMeanInvVariance, + MeanVarDataType* const __restrict__ resultSaveMean, + MeanVarDataType* const __restrict__ resultSaveInvVariance) +{ + GridwiseBatchrNormForwardWithBlockwiseWelford_::Run(x_grid_desc_m_k, + y_grid_desc_m_k, + scale_grid_desc_m, + bias_grid_desc_m, + mean_var_grid_desc_m, + get_reduce_count_per_thread, + num_k_block_tile_iteration, + epsilon, + p_x, + p_scale, + p_bias, + y_elementwise_op, + p_y, + updateMovingAverage, + averageFactor, + resultRunningMean, + resultRunningVariance, + saveMeanInvVariance, + resultSaveMean, + resultSaveInvVariance); +}; + +template +struct GridwiseBatchNormForwardWithBlockwiseWelford +{ + static_assert((XSrcYDstVectorDim == 0 && MThreadSliceSize % XSrcVectorSize == 0) || + (XSrcYDstVectorDim == 1 && KThreadSliceSize % XSrcVectorSize == 0), + "Invalid thread slice sizes and/or vector sizes configuration, please check!"); + + static_assert((XSrcYDstVectorDim == 0 && MThreadSliceSize % YDstVectorSize == 0) || + (XSrcYDstVectorDim == 1 && KThreadSliceSize % YDstVectorSize == 0), + "Invalid thread slice sizes and/or vector sizes configuration, please check!"); + + static constexpr bool reorder_thread_cluster = (XSrcYDstVectorDim == 0); + + using ThreadClusterLengths_M_K = Sequence; + + using ThreadBufferDimAccessOrder = + typename conditional, Sequence<0, 1>>::type; + + using ThreadClusterArrangeOrder = + typename conditional, Sequence<0, 1>>::type; + + static constexpr auto thread_cluster_desc = + make_cluster_descriptor(ThreadClusterLengths_M_K{}, ThreadClusterArrangeOrder{}); + + using ThreadReduceSrcDesc_M_K = decltype(make_naive_tensor_descriptor_packed( + make_tuple(Number{}, Number{}))); + using ThreadReduceDstDesc_M = + decltype(make_naive_tensor_descriptor_packed(make_tuple(Number{}))); + + using ThreadwiseWelford = + ThreadwiseWelford; + + using BlockwiseWelford = BlockwiseWelford; + + using PassThroughOp = tensor_operation::element_wise::PassThrough; + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + + static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize; + static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize; + + __device__ static void Run(const XYGridDesc_M_K& x_grid_desc_m_k, + const XYGridDesc_M_K& y_grid_desc_m_k, + const 
ScaleBiasGridDesc_M& scale_grid_desc_m, + const ScaleBiasGridDesc_M& bias_grid_desc_m, + const MeanVarGridDesc_M& mean_var_grid_desc_m, + const GetReduceCountPerThreadFunctor& get_reduce_count_per_thread, + index_t num_k_block_tile_iteration, + AccDataType epsilon, + const XDataType* const __restrict__ p_x, + const ScaleDataType* const __restrict__ p_scale, + const BiasDataType* const __restrict__ p_bias, + const YElementwiseOp y_elementwise_op, + YDataType* const __restrict__ p_y, + bool updateMovingAverage, + AccDataType averageFactor, + MeanVarDataType* const __restrict__ resultRunningMean, + MeanVarDataType* const __restrict__ resultRunningVariance, + bool saveMeanInvVariance, + MeanVarDataType* const __restrict__ resultSaveMean, + MeanVarDataType* const __restrict__ resultSaveInvVariance) + { + using ck::math::sqrt; + + StaticBuffer + x_thread_buf; + + StaticBuffer scale_thread_buf; + + StaticBuffer bias_thread_buf; + + StaticBuffer + y_thread_buf; + + StaticBuffer mean_thread_buf; + StaticBuffer var_thread_buf; + + const index_t thread_local_id = get_thread_local_1d_id(); + const index_t block_global_id = get_block_1d_id(); + + const auto thread_cluster_idx = + thread_cluster_desc.CalculateBottomIndex(make_multi_index(thread_local_id)); + + const auto thread_m_cluster_id = thread_cluster_idx[I0]; + const auto thread_k_cluster_id = thread_cluster_idx[I1]; + + using ThreadBufferLengths_M_K = Sequence; + using ThreadBufferLengths_M = Sequence; + constexpr auto thread_buffer_desc_m_k = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, Number{})); + constexpr auto thread_buffer_desc_m = + make_naive_tensor_descriptor_packed(make_tuple(Number{})); + + auto threadwise_x_load = ThreadwiseTensorSliceTransfer_v2( + x_grid_desc_m_k, + make_multi_index(block_global_id * M_BlockTileSize + + thread_m_cluster_id * MThreadSliceSize, + thread_k_cluster_id * KThreadSliceSize)); + + auto threadwise_y_store = + ThreadwiseTensorSliceTransfer_v1r3( + y_grid_desc_m_k, + make_multi_index(block_global_id * M_BlockTileSize + + thread_m_cluster_id * MThreadSliceSize, + thread_k_cluster_id * KThreadSliceSize), + y_elementwise_op); + + auto threadwise_scale_load = + ThreadwiseTensorSliceTransfer_v2, + 0, + ScaleSrcVectorSize, + 1, + true>( + scale_grid_desc_m, + make_multi_index(block_global_id * M_BlockTileSize + + thread_m_cluster_id * MThreadSliceSize)); + + auto threadwise_bias_load = ThreadwiseTensorSliceTransfer_v2, + 0, + BiasSrcVectorSize, + 1, + true>( + bias_grid_desc_m, + make_multi_index(block_global_id * M_BlockTileSize + + thread_m_cluster_id * MThreadSliceSize)); + + constexpr auto thread_copy_fwd_step_m_k = make_multi_index(0, K_BlockTileSize); + constexpr auto thread_copy_bwd_step_m_k = make_multi_index(0, -K_BlockTileSize); + + const auto x_global_val_buf = make_dynamic_buffer( + p_x, x_grid_desc_m_k.GetElementSpaceSize()); + + const auto scale_global_val_buf = make_dynamic_buffer( + p_scale, scale_grid_desc_m.GetElementSpaceSize()); + + const auto bias_global_val_buf = make_dynamic_buffer( + p_bias, bias_grid_desc_m.GetElementSpaceSize()); + + auto y_global_val_buf = make_dynamic_buffer( + p_y, y_grid_desc_m_k.GetElementSpaceSize()); + + // Step 1: do welford reduction to get mean and variance + + auto threadwise_welford = ThreadwiseWelford(); + threadwise_welford.max_count_ = get_reduce_count_per_thread(thread_k_cluster_id); + + static_for<0, MThreadSliceSize, 1>{}([&](auto I) { + mean_thread_buf(I) = type_convert(0.0f); + var_thread_buf(I) = type_convert(0.0f); + }); + + 
for(index_t reducedTiles = 0; reducedTiles < num_k_block_tile_iteration; ++reducedTiles) + { + + threadwise_x_load.Run(x_grid_desc_m_k, + x_global_val_buf, + thread_buffer_desc_m_k, + make_tuple(I0, I0), + x_thread_buf); + + threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, thread_copy_fwd_step_m_k); + threadwise_welford.Run(x_thread_buf, mean_thread_buf, var_thread_buf); + } + + static_for<0, MThreadSliceSize, 1>{}([&](auto I) { + if constexpr(I > 0) + block_sync_lds(); + + int count = threadwise_welford.cur_count_; + BlockwiseWelford::Run(mean_thread_buf(I), var_thread_buf(I), count); + }); + + // Step 2: do normalization and output y + + threadwise_scale_load.Run(scale_grid_desc_m, + scale_global_val_buf, + thread_buffer_desc_m, + make_tuple(I0), + scale_thread_buf); + + threadwise_bias_load.Run(bias_grid_desc_m, + bias_global_val_buf, + thread_buffer_desc_m, + make_tuple(I0), + bias_thread_buf); + + auto thread_copy_tail_m_k = (num_k_block_tile_iteration - 1) * thread_copy_fwd_step_m_k; + + threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, thread_copy_bwd_step_m_k); + threadwise_y_store.MoveDstSliceWindow(y_grid_desc_m_k, thread_copy_tail_m_k); + + for(index_t reducedTiles = 0; reducedTiles < num_k_block_tile_iteration; ++reducedTiles) + { + threadwise_x_load.Run(x_grid_desc_m_k, + x_global_val_buf, + thread_buffer_desc_m_k, + make_tuple(I0, I0), + x_thread_buf); + + static_for<0, MThreadSliceSize, 1>{}([&](auto iM) { + AccDataType multiplier = + scale_thread_buf[Number{}] / sqrt(var_thread_buf[iM] + epsilon); + + AccDataType fused_mean_bias = + bias_thread_buf[Number{}] - mean_thread_buf[iM] * multiplier; + + static_for<0, KThreadSliceSize, 1>{}([&](auto iK) { + constexpr auto offset = + thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK)); + + // normalize + y_thread_buf(Number{}) = + x_thread_buf[Number{}] * multiplier + fused_mean_bias; + }); + }); + + threadwise_y_store.Run(thread_buffer_desc_m_k, + make_tuple(I0, I0), + y_thread_buf, + y_grid_desc_m_k, + y_global_val_buf); + + threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, thread_copy_bwd_step_m_k); + threadwise_y_store.MoveDstSliceWindow(y_grid_desc_m_k, thread_copy_bwd_step_m_k); + } + + // Step 3: update the moving average of mean and variance (optional) + + if(updateMovingAverage && thread_k_cluster_id == 0) + { + StaticBuffer + running_mean_thread_buf; + StaticBuffer + running_var_thread_buf; + + auto running_mean_global_buf = make_dynamic_buffer( + resultRunningMean, mean_var_grid_desc_m.GetElementSpaceSize()); + + auto running_var_global_buf = make_dynamic_buffer( + resultRunningVariance, mean_var_grid_desc_m.GetElementSpaceSize()); + + auto threadwise_mean_var_load = + ThreadwiseTensorSliceTransfer_v2, + 0, + MeanVarSrcDstVectorSize, + 1, + true>( + mean_var_grid_desc_m, + make_multi_index(block_global_id * M_BlockTileSize + + thread_m_cluster_id * MThreadSliceSize)); + + threadwise_mean_var_load.Run(mean_var_grid_desc_m, + running_mean_global_buf, + thread_buffer_desc_m, + make_tuple(I0), + running_mean_thread_buf); + + threadwise_mean_var_load.Run(mean_var_grid_desc_m, + running_var_global_buf, + thread_buffer_desc_m, + make_tuple(I0), + running_var_thread_buf); + + AccDataType oneMinusAverageFactor = type_convert(1.0) - averageFactor; + + static_for<0, MThreadSliceSize, 1>{}([&](auto I) { + running_mean_thread_buf(I) = running_mean_thread_buf[I] * oneMinusAverageFactor + + mean_thread_buf[I] * averageFactor; + running_var_thread_buf(I) = running_var_thread_buf[I] * oneMinusAverageFactor + + 
var_thread_buf[I] * averageFactor; + }); + + auto threadwise_mean_var_store = + ThreadwiseTensorSliceTransfer_v1r3, + 0, + MeanVarSrcDstVectorSize, + InMemoryDataOperationEnum::Set, + 1, + true>( + mean_var_grid_desc_m, + make_multi_index(block_global_id * M_BlockTileSize + + thread_m_cluster_id * MThreadSliceSize), + PassThroughOp{}); + + threadwise_mean_var_store.Run(thread_buffer_desc_m, + make_tuple(I0), + running_mean_thread_buf, + mean_var_grid_desc_m, + running_mean_global_buf); + + threadwise_mean_var_store.Run(thread_buffer_desc_m, + make_tuple(I0), + running_var_thread_buf, + mean_var_grid_desc_m, + running_var_global_buf); + }; + + // Step 4: save mean and inv-variance (optional) + + if(saveMeanInvVariance && thread_k_cluster_id == 0) + { + auto result_mean_global_buf = make_dynamic_buffer( + resultSaveMean, mean_var_grid_desc_m.GetElementSpaceSize()); + + auto result_inv_var_global_buf = make_dynamic_buffer( + resultSaveInvVariance, mean_var_grid_desc_m.GetElementSpaceSize()); + + static_for<0, MThreadSliceSize, 1>{}([&](auto I) { + var_thread_buf(I) = + type_convert(1.0f) / sqrt(epsilon + var_thread_buf[I]); + }); + + auto threadwise_mean_inv_var_store = + ThreadwiseTensorSliceTransfer_v1r3, + 0, + MeanVarSrcDstVectorSize, + InMemoryDataOperationEnum::Set, + 1, + true>( + mean_var_grid_desc_m, + make_multi_index(block_global_id * M_BlockTileSize + + thread_m_cluster_id * MThreadSliceSize), + PassThroughOp{}); + + threadwise_mean_inv_var_store.Run(thread_buffer_desc_m, + make_tuple(I0), + mean_thread_buf, + mean_var_grid_desc_m, + result_mean_global_buf); + + threadwise_mean_inv_var_store.Run(thread_buffer_desc_m, + make_tuple(I0), + var_thread_buf, + mean_var_grid_desc_m, + result_inv_var_global_buf); + }; + } +}; + +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_set_buffer_value.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_set_buffer_value.hpp index 1e52b4057c..901e7aee98 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_set_buffer_value.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_set_buffer_value.hpp @@ -3,6 +3,7 @@ #pragma once +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" #include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" namespace ck { diff --git a/include/ck/tensor_operation/gpu/thread/threadwise_welford.hpp b/include/ck/tensor_operation/gpu/thread/threadwise_welford.hpp index 3e224ae664..12ba2c5381 100644 --- a/include/ck/tensor_operation/gpu/thread/threadwise_welford.hpp +++ b/include/ck/tensor_operation/gpu/thread/threadwise_welford.hpp @@ -75,4 +75,63 @@ struct ThreadwiseWelford int max_count_; }; +template +struct ThreadwiseWelfordMerge +{ + static constexpr auto src_thread_desc_m_k = SrcMeanVarCountThreadDesc_M_K{}; + static constexpr auto dst_thread_desc_m = DstMeanVarThreadDesc_M{}; + + static constexpr auto src_length_m = src_thread_desc_m_k.GetLength(Number<0>{}); + static constexpr auto src_length_k = src_thread_desc_m_k.GetLength(Number<1>{}); + static constexpr auto dst_length_m = dst_thread_desc_m.GetLength(Number<0>{}); + + static_assert(src_length_m == dst_length_m, "lengths of source and dst buffer must match!"); + + __device__ static void + Merge(T& mean_a, T& var_a, int32_t& count_a, T mean_b, T var_b, int32_t count_b) + { + int count = count_a + count_b; + T count_b_over_count = count == 0 ? 
type_convert(0) : type_convert(count_b) / count; + T delta = mean_b - mean_a; + mean_a += delta * count_b_over_count; + var_a += var_b + delta * delta * count_a * count_b_over_count; + count_a = count; + } + + template + __device__ static void Run(const SrcMeanBufferType& src_mean_buf, + const SrcVarBufferType& src_var_buf, + const SrcCountBufferType& src_count_buf, + DstMeanBufferType& dst_mean_buf, + DstVarBufferType& dst_var_buf, + DstCountBufferType& dst_count_buf) + { + static_for<0, src_length_m, 1>{}([&](auto iM) { + static_for<0, src_length_k, 1>{}([&](auto iK) { + constexpr auto src_offset = src_thread_desc_m_k.CalculateOffset(make_tuple(iM, iK)); + + Merge(dst_mean_buf(iM), + dst_var_buf(iM), + dst_count_buf(iM), + src_mean_buf[Number{}], + src_var_buf[Number{}], + src_count_buf[Number{}]); + }); + + if constexpr(GetActualVariance) + { + dst_var_buf(iM) = dst_var_buf[iM] / dst_count_buf[iM]; + }; + }); + }; +}; + } // namespace ck diff --git a/library/include/ck/library/reference_tensor_operation/cpu/reference_batchnorm_forward_nhwc_c.hpp b/library/include/ck/library/reference_tensor_operation/cpu/reference_batchnorm_forward_nhwc_c.hpp index fa45af4997..c54766b6a0 100644 --- a/library/include/ck/library/reference_tensor_operation/cpu/reference_batchnorm_forward_nhwc_c.hpp +++ b/library/include/ck/library/reference_tensor_operation/cpu/reference_batchnorm_forward_nhwc_c.hpp @@ -9,46 +9,61 @@ #include #include +#include "ck/utility/math_v2.hpp" +#include "ck/utility/ignore.hpp" #include "ck/tensor_operation/gpu/device/device_batchnorm_forward.hpp" namespace ck { namespace tensor_operation { namespace host { -template -struct ReferenceBatchNormFwd_Input_N_H_W_C_Output_C : public device::DeviceBatchNormFwd<4, 3> +template +struct ReferenceBatchNormFwd_Input_N_H_W_C_Output_C + : public device::DeviceBatchNormFwd<4, 3, YElementwiseOp> { struct Argument : public device::BaseArgument { Argument(const std::array xyLengths, const std::array xStrides, const std::array yStrides, + const std::array reduceDims, const std::array bnScaleBiasMeanVarLengths, - const std::array bnScaleBiasMeanVarStrides, - const InOutDataType* p_x, - const AccDataType* bnScale, - const AccDataType* bnBias, - InOutDataType* p_y, - double exponentialAverageFactor, - AccDataType* resultRunningMean, - AccDataType* resultRunningVariance, + const std::array bnScaleStrides, + const std::array bnBiasStrides, + const std::array bnMeanVarStrides, + const XDataType* p_x, + const ScaleDataType* bnScale, + const BiasDataType* bnBias, double epsilon, - AccDataType* resultSaveMean, - AccDataType* resultSaveInvVariance) + const YElementwiseOp y_elementwise_op, + YDataType* p_y, + MeanVarDataType* resultSaveMean, + MeanVarDataType* resultSaveInvVariance, + double averageFactor, + MeanVarDataType* resultRunningMean, + MeanVarDataType* resultRunningVariance) : p_x_(p_x), bnScale_(bnScale), bnBias_(bnBias), + y_elementwise_op_(y_elementwise_op), p_y_(p_y), - resultRunningMean_(resultRunningMean), - resultRunningVariance_(resultRunningVariance), resultSaveMean_(resultSaveMean), resultSaveInvVariance_(resultSaveInvVariance), - exponentialAverageFactor_(exponentialAverageFactor), - epsilon_(epsilon) + resultRunningMean_(resultRunningMean), + resultRunningVariance_(resultRunningVariance) { - (void)xStrides; - (void)yStrides; - (void)bnScaleBiasMeanVarStrides; + ignore = xStrides; + ignore = yStrides; + ignore = bnScaleStrides; + ignore = bnBiasStrides; + ignore = bnMeanVarStrides; + ignore = reduceDims; if(xyLengths.size() != 4 || 
bnScaleBiasMeanVarLengths.size() != 1 || bnScaleBiasMeanVarLengths[0] != xyLengths[3]) @@ -59,26 +74,30 @@ struct ReferenceBatchNormFwd_Input_N_H_W_C_Output_C : public device::DeviceBatch w = xyLengths[2]; c = xyLengths[3]; + epsilon_ = type_convert(epsilon); + averageFactor_ = type_convert(averageFactor); + resultSave = (resultSaveMean != nullptr && resultSaveInvVariance != nullptr); resultRunning = (resultRunningMean != nullptr && resultRunningVariance != nullptr); } - const InOutDataType* p_x_; - const AccDataType* bnScale_; - const AccDataType* bnBias_; - InOutDataType* p_y_; + const XDataType* p_x_; + const ScaleDataType* bnScale_; + const BiasDataType* bnBias_; + const YElementwiseOp y_elementwise_op_; + YDataType* p_y_; - AccDataType* resultRunningMean_; - AccDataType* resultRunningVariance_; - AccDataType* resultSaveMean_; - AccDataType* resultSaveInvVariance_; + MeanVarDataType* resultSaveMean_; + MeanVarDataType* resultSaveInvVariance_; + MeanVarDataType* resultRunningMean_; + MeanVarDataType* resultRunningVariance_; bool resultSave, resultRunning; index_t n, h, w, c; - double exponentialAverageFactor_; - double epsilon_; + AccDataType averageFactor_; + AccDataType epsilon_; }; struct Invoker : public device::BaseInvoker @@ -86,14 +105,12 @@ struct ReferenceBatchNormFwd_Input_N_H_W_C_Output_C : public device::DeviceBatch float Run(const Argument& arg) { auto thread_reduce_func = [&](auto iC) { - AccDataType reduceSize = type_convert(arg.n) * - type_convert(arg.h) * - type_convert(arg.w); - index_t offset_C = iC; - AccDataType mean = type_convert(0.0f); - AccDataType meansquare = type_convert(0.0f); + index_t offset_C = iC; + AccDataType mean = type_convert(0.0f); + AccDataType variance = type_convert(0.0f); + int32_t curr_count = 0; - // compute mean, meanquare, variance, invVariance + // compute mean, variance using welford method for(index_t iN = 0; iN < arg.n; iN++) { index_t offset_N = iN * arg.h * arg.w * arg.c; @@ -106,40 +123,46 @@ struct ReferenceBatchNormFwd_Input_N_H_W_C_Output_C : public device::DeviceBatch auto offset = offset_N + offset_H + offset_W + offset_C; + curr_count++; + AccDataType x = type_convert(arg.p_x_[offset]); - mean += x; - meansquare += x * x; + AccDataType delta = x - mean; + + mean += delta / curr_count; + + AccDataType delta2 = x - mean; + + variance += delta * delta2; }; } }; - mean = mean / reduceSize; - meansquare = meansquare / reduceSize; + // actual variance + variance = variance / curr_count; - AccDataType variance = meansquare - mean * mean; AccDataType invVariance = - type_convert(1.0f) / - std::sqrt(type_convert(arg.epsilon_) + variance); + type_convert(1.0f) / ck::math::sqrt(arg.epsilon_ + variance); // save the mean/invVariance if required if(arg.resultSave) { - arg.resultSaveMean_[iC] = mean; - arg.resultSaveInvVariance_[iC] = invVariance; + arg.resultSaveMean_[iC] = type_convert(mean); + arg.resultSaveInvVariance_[iC] = type_convert(invVariance); }; // update the moving average if required if(arg.resultRunning) { - arg.resultRunningMean_[iC] = - arg.resultRunningMean_[iC] * - type_convert(1.0 - arg.exponentialAverageFactor_) + - mean * arg.exponentialAverageFactor_; - arg.resultRunningVariance_[iC] = - arg.resultRunningVariance_[iC] * - type_convert(1.0 - arg.exponentialAverageFactor_) + - variance * arg.exponentialAverageFactor_; + AccDataType oneMinusAverageFactor = + type_convert(1.0) - arg.averageFactor_; + arg.resultRunningMean_[iC] = type_convert( + type_convert(arg.resultRunningMean_[iC]) * + oneMinusAverageFactor + + mean * 
arg.averageFactor_); + arg.resultRunningVariance_[iC] = type_convert( + arg.resultRunningVariance_[iC] * oneMinusAverageFactor + + variance * arg.averageFactor_); }; // Normalization @@ -160,7 +183,7 @@ struct ReferenceBatchNormFwd_Input_N_H_W_C_Output_C : public device::DeviceBatch AccDataType norm_x = arg.bnScale_[iC] * (x - mean) * invVariance + arg.bnBias_[iC]; - arg.p_y_[offset] = type_convert(norm_x); + arg.p_y_[offset] = type_convert(norm_x); }; } }; @@ -207,34 +230,42 @@ struct ReferenceBatchNormFwd_Input_N_H_W_C_Output_C : public device::DeviceBatch MakeArgumentPointer(const std::array xyLengths, const std::array xStrides, const std::array yStrides, + const std::array reduceDims, const std::array bnScaleBiasMeanVarLengths, - const std::array bnScaleBiasMeanVarStrides, + const std::array bnScaleStrides, + const std::array bnBiasStrides, + const std::array bnMeanVarStrides, const void* p_x, const void* bnScale, const void* bnBias, - void* p_y, - double exponentialAverageFactor, - void* resultRunningMean, - void* resultRunningVariance, double epsilon, + const YElementwiseOp y_elementwise_op, + void* p_y, void* resultSaveMean, - void* resultSaveInvVariance) override + void* resultSaveInvVariance, + double averageFactor, + void* resultRunningMean, + void* resultRunningVariance) override { return std::make_unique(xyLengths, xStrides, yStrides, + reduceDims, bnScaleBiasMeanVarLengths, - bnScaleBiasMeanVarStrides, - static_cast(p_x), - static_cast(bnScale), - static_cast(bnBias), - static_cast(p_y), - exponentialAverageFactor, - static_cast(resultRunningMean), - static_cast(resultRunningVariance), + bnScaleStrides, + bnBiasStrides, + bnMeanVarStrides, + static_cast(p_x), + static_cast(bnScale), + static_cast(bnBias), epsilon, - static_cast(resultSaveMean), - static_cast(resultSaveInvVariance)); + y_elementwise_op, + static_cast(p_y), + static_cast(resultSaveMean), + static_cast(resultSaveInvVariance), + averageFactor, + static_cast(resultRunningMean), + static_cast(resultRunningVariance)); }; std::unique_ptr MakeInvokerPointer() override diff --git a/library/include/ck/library/reference_tensor_operation/cpu/reference_batchnorm_infer_nhwc_c.hpp b/library/include/ck/library/reference_tensor_operation/cpu/reference_batchnorm_infer_nhwc_c.hpp index 45092861f2..01e9572740 100644 --- a/library/include/ck/library/reference_tensor_operation/cpu/reference_batchnorm_infer_nhwc_c.hpp +++ b/library/include/ck/library/reference_tensor_operation/cpu/reference_batchnorm_infer_nhwc_c.hpp @@ -14,7 +14,12 @@ namespace ck { namespace tensor_operation { namespace host { -template +template struct ReferenceBatchNormInfer_Input_N_H_W_C_Output_C : public device::DeviceBatchNormInfer<4, 3> { struct Argument : public device::BaseArgument @@ -23,14 +28,16 @@ struct ReferenceBatchNormInfer_Input_N_H_W_C_Output_C : public device::DeviceBat const std::array xStrides, const std::array yStrides, const std::array bnScaleBiasMeanVarLengths, - const std::array bnScaleBiasMeanVarStrides, - const InOutDataType* p_x, - const AccDataType* bnScale, - const AccDataType* bnBias, + const std::array bnScaleStrides, + const std::array bnBiasStrides, + const std::array bnMeanVarStrides, + const XDataType* p_x, + const ScaleDataType* bnScale, + const BiasDataType* bnBias, double epsilon, - const AccDataType* estimatedMean, - const AccDataType* estimatedVariance, - InOutDataType* p_y) + const MeanVarDataType* estimatedMean, + const MeanVarDataType* estimatedVariance, + YDataType* p_y) : p_x_(p_x), bnScale_(bnScale), bnBias_(bnBias), 
@@ -39,32 +46,34 @@ struct ReferenceBatchNormInfer_Input_N_H_W_C_Output_C : public device::DeviceBat estimatedVariance_(estimatedVariance), p_y_(p_y) { - (void)xStrides; - (void)yStrides; - (void)bnScaleBiasMeanVarStrides; + ignore = xStrides; + ignore = yStrides; + ignore = bnScaleStrides; + ignore = bnBiasStrides; + ignore = bnMeanVarStrides; if(xyLengths.size() != 4 || bnScaleBiasMeanVarLengths.size() != 1 || bnScaleBiasMeanVarLengths[0] != xyLengths[3]) throw std::runtime_error("Invalid tensor dimensions!"); - n = xyLengths[0]; - h = xyLengths[1]; - w = xyLengths[2]; - c = xyLengths[3]; + n_ = xyLengths[0]; + h_ = xyLengths[1]; + w_ = xyLengths[2]; + c_ = xyLengths[3]; } - const InOutDataType* p_x_; - const AccDataType* bnScale_; - const AccDataType* bnBias_; + const XDataType* p_x_; + const ScaleDataType* bnScale_; + const BiasDataType* bnBias_; double epsilon_; - const AccDataType* estimatedMean_; - const AccDataType* estimatedVariance_; + const MeanVarDataType* estimatedMean_; + const MeanVarDataType* estimatedVariance_; - InOutDataType* p_y_; + YDataType* p_y_; - index_t n, h, w, c; + index_t n_, h_, w_, c_; }; struct Invoker : public device::BaseInvoker @@ -81,15 +90,15 @@ struct ReferenceBatchNormInfer_Input_N_H_W_C_Output_C : public device::DeviceBat std::sqrt(type_convert(arg.epsilon_) + variance); // Normalization - for(index_t iN = 0; iN < arg.n; iN++) + for(index_t iN = 0; iN < arg.n_; iN++) { - index_t offset_N = iN * arg.h * arg.w * arg.c; - for(index_t iH = 0; iH < arg.h; iH++) + index_t offset_N = iN * arg.h_ * arg.w_ * arg.c_; + for(index_t iH = 0; iH < arg.h_; iH++) { - index_t offset_H = iH * arg.w * arg.c; - for(index_t iW = 0; iW < arg.w; iW++) + index_t offset_H = iH * arg.w_ * arg.c_; + for(index_t iW = 0; iW < arg.w_; iW++) { - index_t offset_W = iW * arg.c; + index_t offset_W = iW * arg.c_; auto offset = offset_N + offset_H + offset_W + offset_C; @@ -98,21 +107,21 @@ struct ReferenceBatchNormInfer_Input_N_H_W_C_Output_C : public device::DeviceBat AccDataType norm_x = arg.bnScale_[iC] * (x - mean) * invVariance + arg.bnBias_[iC]; - arg.p_y_[offset] = type_convert(norm_x); + arg.p_y_[offset] = type_convert(norm_x); }; } }; }; std::size_t num_thread = std::thread::hardware_concurrency(); - std::size_t work_per_thread = (arg.c + num_thread - 1) / num_thread; + std::size_t work_per_thread = (arg.c_ + num_thread - 1) / num_thread; std::vector threads(num_thread); for(std::size_t it = 0; it < num_thread; ++it) { std::size_t ic_begin = it * work_per_thread; - std::size_t ic_end = std::min(static_cast((it + 1) * work_per_thread), arg.c); + std::size_t ic_end = std::min(static_cast((it + 1) * work_per_thread), arg.c_); auto f = [=] { for(std::size_t ic = ic_begin; ic < ic_end; ++ic) @@ -146,7 +155,9 @@ struct ReferenceBatchNormInfer_Input_N_H_W_C_Output_C : public device::DeviceBat const std::array xStrides, const std::array yStrides, const std::array bnScaleBiasMeanVarLengths, - const std::array bnScaleBiasMeanVarStrides, + const std::array bnScaleStrides, + const std::array bnBiasStrides, + const std::array bnMeanVarStrides, const void* p_x, const void* bnScale, const void* bnBias, @@ -159,14 +170,16 @@ struct ReferenceBatchNormInfer_Input_N_H_W_C_Output_C : public device::DeviceBat xStrides, yStrides, bnScaleBiasMeanVarLengths, - bnScaleBiasMeanVarStrides, - static_cast(p_x), - static_cast(bnScale), - static_cast(bnBias), + bnScaleStrides, + bnBiasStrides, + bnMeanVarStrides, + static_cast(p_x), + static_cast(bnScale), + static_cast(bnBias), epsilon, - 
static_cast<const AccDataType*>(estimatedMean), - static_cast<const AccDataType*>(estimatedVariance), - static_cast<InOutDataType*>(p_y)); + static_cast<const MeanVarDataType*>(estimatedMean), + static_cast<const MeanVarDataType*>(estimatedVariance), + static_cast<YDataType*>(p_y)); }; std::unique_ptr MakeInvokerPointer() override
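
[Note on the Welford updates used throughout this patch] Both the first-pass kernels (ThreadwiseWelford, and the sequential loop in the CPU reference) and ThreadwiseWelfordMerge rely on the standard Welford/Chan update rules: a per-element running update and a merge of two partial (mean, M2, count) triples. A minimal standalone sketch of both updates (hypothetical host helpers, not the CK classes), validated against a naive two-pass variance:

// Host-side sketch (illustration only) of the Welford updates the kernels rely on.
// 'm2' is the running sum of squared deviations; the biased variance is m2 / count.
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

struct WelfordState
{
    double mean  = 0.0;
    double m2    = 0.0;
    int    count = 0;
};

// Per-element update, as in ThreadwiseWelford and the CPU reference loop.
void welford_update(WelfordState& s, double x)
{
    s.count += 1;
    const double delta = x - s.mean;
    s.mean += delta / s.count;
    s.m2 += delta * (x - s.mean);
}

// Merge of two partial results, as in ThreadwiseWelfordMerge::Merge.
void welford_merge(WelfordState& a, const WelfordState& b)
{
    const int    count               = a.count + b.count;
    const double count_b_over_count  = count == 0 ? 0.0 : static_cast<double>(b.count) / count;
    const double delta               = b.mean - a.mean;
    a.mean += delta * count_b_over_count;
    a.m2 += b.m2 + delta * delta * a.count * count_b_over_count;
    a.count = count;
}

int main()
{
    const std::vector<double> x = {1.0, 2.0, 4.0, 8.0, 16.0, 32.0, 64.0};

    // Two partial reductions merged together, mimicking the multi-block path.
    WelfordState a, b;
    for(std::size_t i = 0; i < 4; ++i) welford_update(a, x[i]);
    for(std::size_t i = 4; i < x.size(); ++i) welford_update(b, x[i]);
    welford_merge(a, b);

    // Naive two-pass mean / biased variance for comparison.
    double mean = 0.0;
    for(double v : x) mean += v;
    mean /= x.size();
    double var = 0.0;
    for(double v : x) var += (v - mean) * (v - mean);
    var /= x.size();

    assert(std::fabs(a.mean - mean) < 1e-9);
    assert(std::fabs(a.m2 / a.count - var) < 1e-9);
    return 0;
}
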