From c3bb3db252f2eb6061094296f936982b5be139b3 Mon Sep 17 00:00:00 2001 From: Qianfeng Date: Thu, 1 Dec 2022 03:32:20 +0800 Subject: [PATCH] BatchNorm backward instance/external API/profiler/tests (#519) * Refine the device batchnorm-backward base API templates and data type assignments * Remove duplicated kernel file * Add batchnorm backward instances and external API * Add batchnorm-backward profiler and tests * Add client example which uses batchnorm backward external API * Merge test/batchnorm_fwd and test/batchnorm_bwd into one directory * Loose the threshold for batchnorm-backward check_err() [ROCm/composable_kernel commit: 63af525c06363f398b851967da2740a2ace382b5] --- client_example/13_batchnorm/CMakeLists.txt | 2 + .../13_batchnorm/batchnorm_bwd_nhwc.cpp | 201 +++++++++ .../34_batchnorm/batchnorm_backward_nhwc.cpp | 84 ++-- .../gpu/device/device_batchnorm_backward.hpp | 36 +- .../impl/device_batchnorm_backward_impl.hpp | 96 ++-- ...e_second_half_batchnorm_backward_final.hpp | 124 ++---- ...cond_half_multiblock_reduce_first_half.hpp | 55 +-- ...e_batchnorm_backward_blockwise_welford.hpp | 78 ++-- ...gridwise_multiblock_welford_first_half.hpp | 258 ----------- .../cpu/reference_batchnorm_backward.hpp | 412 ++++++++++++++++++ .../reference_batchnorm_backward_nhwc_c.hpp | 319 -------------- .../gpu/batchnorm_backward.hpp | 124 ++++++ .../gpu/batchnorm/CMakeLists.txt | 4 + ...evice_batchnorm_backward_bf16_instance.cpp | 146 +++++++ ...device_batchnorm_backward_f16_instance.cpp | 147 +++++++ ...device_batchnorm_backward_f32_instance.cpp | 145 ++++++ ...device_batchnorm_backward_f64_instance.cpp | 145 ++++++ profiler/CMakeLists.txt | 1 + .../profile_batchnorm_backward_impl.hpp | 390 +++++++++++++++++ profiler/src/profile_batchnorm_bwd.cpp | 204 +++++++++ profiler/src/profiler.cpp | 5 + test/CMakeLists.txt | 2 +- .../CMakeLists.txt | 2 + test/batchnorm/batchnorm_bwd_rank_4.cpp | 92 ++++ .../batchnorm_fwd_rank_4.cpp | 0 25 files changed, 2240 insertions(+), 832 deletions(-) create mode 100644 client_example/13_batchnorm/batchnorm_bwd_nhwc.cpp delete mode 100644 include/ck/tensor_operation/gpu/grid/gridwise_multiblock_welford_first_half.hpp create mode 100644 library/include/ck/library/reference_tensor_operation/cpu/reference_batchnorm_backward.hpp delete mode 100644 library/include/ck/library/reference_tensor_operation/cpu/reference_batchnorm_backward_nhwc_c.hpp create mode 100644 library/include/ck/library/tensor_operation_instance/gpu/batchnorm_backward.hpp create mode 100644 library/src/tensor_operation_instance/gpu/batchnorm/device_batchnorm_backward_bf16_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/batchnorm/device_batchnorm_backward_f16_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/batchnorm/device_batchnorm_backward_f32_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/batchnorm/device_batchnorm_backward_f64_instance.cpp create mode 100644 profiler/include/profile_batchnorm_backward_impl.hpp create mode 100644 profiler/src/profile_batchnorm_bwd.cpp rename test/{batchnorm_fwd => batchnorm}/CMakeLists.txt (50%) create mode 100644 test/batchnorm/batchnorm_bwd_rank_4.cpp rename test/{batchnorm_fwd => batchnorm}/batchnorm_fwd_rank_4.cpp (100%) diff --git a/client_example/13_batchnorm/CMakeLists.txt b/client_example/13_batchnorm/CMakeLists.txt index 0ddea1a8f1..54669678ae 100644 --- a/client_example/13_batchnorm/CMakeLists.txt +++ b/client_example/13_batchnorm/CMakeLists.txt @@ -1,2 +1,4 @@ 
add_executable(client_batchnorm_fwd_nhwc batchnorm_fwd_nhwc.cpp) +add_executable(client_batchnorm_bwd_nhwc batchnorm_bwd_nhwc.cpp) target_link_libraries(client_batchnorm_fwd_nhwc PRIVATE composable_kernel::device_operations) +target_link_libraries(client_batchnorm_bwd_nhwc PRIVATE composable_kernel::device_operations) diff --git a/client_example/13_batchnorm/batchnorm_bwd_nhwc.cpp b/client_example/13_batchnorm/batchnorm_bwd_nhwc.cpp new file mode 100644 index 0000000000..8ef21986a4 --- /dev/null +++ b/client_example/13_batchnorm/batchnorm_bwd_nhwc.cpp @@ -0,0 +1,201 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/device_reduce.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/gpu/batchnorm_backward.hpp" + +using XDataType = ck::half_t; +using DxDataType = float; +using DyDataType = float; +using AccDataType = float; +using ScaleDataType = ck::half_t; +using DscaleDbiasDataType = float; +using MeanVarDataType = float; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +constexpr int Rank = 4; +constexpr int NumBatchNormReduceDim = 3; + +const double epsilon = std::numeric_limits::epsilon(); + +struct SimpleDeviceMem +{ + SimpleDeviceMem() = delete; + + SimpleDeviceMem(std::size_t mem_size) : p_mem_{} + { + (void)hipMalloc(static_cast(&p_mem_), mem_size); + } + + void* GetDeviceBuffer() { return p_mem_; } + + ~SimpleDeviceMem() { (void)hipFree(p_mem_); } + + void* p_mem_; +}; + +int main(int argc, char* argv[]) +{ + std::array xyLengths{16, 8, 128, 256}; + std::array xyStrides{8 * 128 * 256, 128 * 256, 256, 1}; + std::array scaleBiasMeanVarLengths{256}; + std::array scaleBiasMeanVarStrides{1}; + std::array reduceDims{0, 1, 2}; + + ck::index_t numXYElement = + std::accumulate(xyLengths.begin(), xyLengths.end(), 1, std::multiplies()); + + ck::index_t numScaleBiasMeanVarElement = std::accumulate(scaleBiasMeanVarLengths.begin(), + scaleBiasMeanVarLengths.end(), + 1, + std::multiplies()); + + SimpleDeviceMem x(sizeof(XDataType) * numXYElement); + SimpleDeviceMem dy(sizeof(DyDataType) * numXYElement); + SimpleDeviceMem scale(sizeof(ScaleDataType) * numScaleBiasMeanVarElement); + SimpleDeviceMem mean(sizeof(MeanVarDataType) * numScaleBiasMeanVarElement); + SimpleDeviceMem invVariance(sizeof(MeanVarDataType) * numScaleBiasMeanVarElement); + SimpleDeviceMem dx(sizeof(DxDataType) * numXYElement); + SimpleDeviceMem dscale(sizeof(DscaleDbiasDataType) * numScaleBiasMeanVarElement); + SimpleDeviceMem dbias(sizeof(DscaleDbiasDataType) * numScaleBiasMeanVarElement); + + using DeviceOp = ck::tensor_operation::device::DeviceBatchNormBwd; + + const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< + DeviceOp>::GetInstances(); + + std::cout << "found " << op_ptrs.size() << " instances" << std::endl; + + std::string best_op_name; + bool found = false; + int best_op_id = -1; + float best_ave_time = std::numeric_limits::max(); + float best_gb_per_sec = 0; + + // profile device operation instances + std::cout << "Run all instances and do timing" << std::endl; + + for(int i = 0; i < op_ptrs.size(); ++i) + { + auto& op_ptr = op_ptrs[i]; + + auto argument_ptr = op_ptr->MakeArgumentPointer(xyLengths, + xyStrides, + xyStrides, + xyStrides, + reduceDims, + scaleBiasMeanVarLengths, + 
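                                                        // Note: the three xyStrides arguments above are the stride arrays for the
                                                        // x, dx and dy tensors (all packed NHWC in this example), and the three
                                                        // scaleBiasMeanVarStrides arguments below are the stride arrays for the
                                                        // per-channel scale, dscale/dbias and mean/inv-variance tensors.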
scaleBiasMeanVarStrides, + scaleBiasMeanVarStrides, + scaleBiasMeanVarStrides, + x.GetDeviceBuffer(), + dy.GetDeviceBuffer(), + scale.GetDeviceBuffer(), + mean.GetDeviceBuffer(), + invVariance.GetDeviceBuffer(), + epsilon, + PassThrough{}, + dx.GetDeviceBuffer(), + dscale.GetDeviceBuffer(), + dbias.GetDeviceBuffer()); + + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + std::string op_name = op_ptr->GetTypeString(); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + size_t workspace_sz = op_ptr->GetWorkSpaceSize(argument_ptr.get()); + + SimpleDeviceMem workspace(workspace_sz); + + op_ptr->SetWorkSpacePointer(argument_ptr.get(), workspace.GetDeviceBuffer()); + + float ave_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, true}); + + std::size_t num_bytes = + numXYElement * (sizeof(XDataType) + sizeof(DyDataType) + sizeof(DxDataType)) + + numScaleBiasMeanVarElement * + (sizeof(ScaleDataType) + sizeof(DscaleDbiasDataType) * 2 + + sizeof(MeanVarDataType) * 2); + + float gb_per_sec = num_bytes / 1.E6 / ave_time; + + std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << gb_per_sec << " GB/s, " + << op_name << std::endl; + + if(ave_time < best_ave_time) + { + found = true; + best_op_id = i; + best_op_name = op_name; + best_ave_time = ave_time; + best_gb_per_sec = gb_per_sec; + } + } + else + { + std::cout << op_name << " does not support this problem" << std::endl; + } + } + + if(found) + { + std::cout << "Best Perf: " << best_ave_time << " ms, " << best_gb_per_sec << " GB/s, " + << best_op_name << std::endl; + + // run the best intance + auto& op_ptr = op_ptrs[best_op_id]; + std::cout << "Run the best instance without timing: " << op_ptr->GetTypeString() + << std::endl; + + auto argument_ptr = op_ptr->MakeArgumentPointer(xyLengths, + xyStrides, + xyStrides, + xyStrides, + reduceDims, + scaleBiasMeanVarLengths, + scaleBiasMeanVarStrides, + scaleBiasMeanVarStrides, + scaleBiasMeanVarStrides, + x.GetDeviceBuffer(), + dy.GetDeviceBuffer(), + scale.GetDeviceBuffer(), + mean.GetDeviceBuffer(), + invVariance.GetDeviceBuffer(), + epsilon, + PassThrough{}, + dx.GetDeviceBuffer(), + dscale.GetDeviceBuffer(), + dbias.GetDeviceBuffer()); + + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, false}); + } + + std::cout << "Done" << std::endl; + } + + return 0; +} diff --git a/example/34_batchnorm/batchnorm_backward_nhwc.cpp b/example/34_batchnorm/batchnorm_backward_nhwc.cpp index 90e3718441..a6ca9d150b 100644 --- a/example/34_batchnorm/batchnorm_backward_nhwc.cpp +++ b/example/34_batchnorm/batchnorm_backward_nhwc.cpp @@ -11,7 +11,7 @@ #include "ck/library/utility/host_tensor.hpp" #include "ck/library/utility/host_tensor_generator.hpp" #include "ck/library/utility/host_common_util.hpp" -#include "ck/library/reference_tensor_operation/cpu/reference_batchnorm_backward_nhwc_c.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_batchnorm_backward.hpp" #include "ck/tensor_operation/gpu/device/impl/device_batchnorm_backward_impl.hpp" static struct option long_options[] = {{"inOutLengths", required_argument, nullptr, 'D'}, @@ -106,7 +106,7 @@ class BatchNormBwdArg using namespace ck; -template +template bool bnorm_bwd_nhwc_test(bool do_verification, int init_method, bool time_kernel, @@ -118,13 +118,15 @@ bool bnorm_bwd_nhwc_test(bool do_verification, constexpr index_t Rank = 4; constexpr index_t NumReduceDim = 3; + using ScaleDataType = 
XDataType; + const std::vector scaleBiasMeanVarLengths = {inOutLengths[3]}; // input data of the batchnorm backward algorithm - Tensor x(inOutLengths); - Tensor dy(inOutLengths); + Tensor x(inOutLengths); + Tensor dy(inOutLengths); - Tensor bnScale(scaleBiasMeanVarLengths); + Tensor bnScale(scaleBiasMeanVarLengths); Tensor savedMean(scaleBiasMeanVarLengths); Tensor savedInvVar(scaleBiasMeanVarLengths); @@ -132,8 +134,8 @@ bool bnorm_bwd_nhwc_test(bool do_verification, Tensor savedVariance(scaleBiasMeanVarLengths); // output data of the batchnorm backward algorithm - Tensor dx_ref(inOutLengths); - Tensor dx(inOutLengths); + Tensor dx_ref(inOutLengths); + Tensor dx(inOutLengths); Tensor dscale(scaleBiasMeanVarLengths); Tensor dbias(scaleBiasMeanVarLengths); @@ -153,7 +155,7 @@ bool bnorm_bwd_nhwc_test(bool do_verification, const float noise_stddev = 0.0001f; // input data in normal distribution - x.GenerateTensorValue(GeneratorTensor_4{x_mean, x_stddev}, num_thread); + x.GenerateTensorValue(GeneratorTensor_4{x_mean, x_stddev}, num_thread); // initialize the savedMean to be values with tiny variation to the mean of the x values savedMean.GenerateTensorValue(GeneratorTensor_4{x_mean, noise_stddev}, @@ -182,7 +184,7 @@ bool bnorm_bwd_nhwc_test(bool do_verification, const float x_stddev = 1.0f; // input data in normal distribution - x.GenerateTensorValue(GeneratorTensor_4{x_mean, x_stddev}, num_thread); + x.GenerateTensorValue(GeneratorTensor_4{x_mean, x_stddev}, num_thread); }; if(do_verification) @@ -190,34 +192,34 @@ bool bnorm_bwd_nhwc_test(bool do_verification, switch(init_method) { case 0: - dy.GenerateTensorValue(GeneratorTensor_0{}, num_thread); - bnScale.GenerateTensorValue(GeneratorTensor_0{}, num_thread); + dy.GenerateTensorValue(GeneratorTensor_0{}, num_thread); + bnScale.GenerateTensorValue(GeneratorTensor_0{}, num_thread); break; case 1: - dy.GenerateTensorValue(GeneratorTensor_1{1}, num_thread); - bnScale.GenerateTensorValue(GeneratorTensor_1{1}, num_thread); + dy.GenerateTensorValue(GeneratorTensor_1{1}, num_thread); + bnScale.GenerateTensorValue(GeneratorTensor_1{1}, num_thread); break; case 2: - dy.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); - bnScale.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); + dy.GenerateTensorValue(GeneratorTensor_2{-2, 2}, num_thread); + bnScale.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); break; default: - dy.GenerateTensorValue(GeneratorTensor_3{-0.2f, 0.2f}, num_thread); - bnScale.GenerateTensorValue(GeneratorTensor_3{-0.5f, 0.5f}, num_thread); + dy.GenerateTensorValue(GeneratorTensor_3{-0.2f, 0.2f}, num_thread); + bnScale.GenerateTensorValue(GeneratorTensor_3{-0.5f, 0.5f}, num_thread); } }; // input data of the batchnorm backward algorithm - DeviceMem x_dev(sizeof(InOutDataType) * x.mDesc.GetElementSpaceSize()); - DeviceMem dy_dev(sizeof(InOutDataType) * dy.mDesc.GetElementSpaceSize()); + DeviceMem x_dev(sizeof(XDataType) * x.mDesc.GetElementSpaceSize()); + DeviceMem dy_dev(sizeof(AccDataType) * dy.mDesc.GetElementSpaceSize()); - DeviceMem bnScale_dev(sizeof(AccDataType) * bnScale.mDesc.GetElementSpaceSize()); + DeviceMem bnScale_dev(sizeof(ScaleDataType) * bnScale.mDesc.GetElementSpaceSize()); DeviceMem savedMean_dev(sizeof(AccDataType) * savedMean.mDesc.GetElementSpaceSize()); DeviceMem savedInvVar_dev(sizeof(AccDataType) * savedInvVar.mDesc.GetElementSpaceSize()); // output data of the batchnorm backward algorithm - DeviceMem dx_dev(sizeof(InOutDataType) * dx.mDesc.GetElementSpaceSize()); + DeviceMem 
dx_dev(sizeof(AccDataType) * dx.mDesc.GetElementSpaceSize()); DeviceMem dscale_dev(sizeof(AccDataType) * dscale.mDesc.GetElementSpaceSize()); DeviceMem dbias_dev(sizeof(AccDataType) * dbias.mDesc.GetElementSpaceSize()); @@ -249,13 +251,13 @@ bool bnorm_bwd_nhwc_test(bool do_verification, using PassThroughOp = ck::tensor_operation::element_wise::PassThrough; using DeviceBatchNormBwdInstance = - ck::tensor_operation::device::DeviceBatchNormBwdImpl; // MeanVarSrcVectorSize auto batchnorm_bwd = DeviceBatchNormBwdInstance{}; @@ -324,7 +326,7 @@ bool bnorm_bwd_nhwc_test(bool do_verification, // inputing of x, dy, scale, outputing of dx, dscale, dbias num_bytes += - total_length * sizeof(InOutDataType) * 3 + invariant_length * sizeof(AccDataType) * 3; + total_length * sizeof(XDataType) * 3 + invariant_length * sizeof(AccDataType) * 3; // outputing of mean, inv-variance num_bytes += haveSavedMeanInvVar ? invariant_length * sizeof(AccDataType) * 2 : 0; @@ -341,14 +343,16 @@ bool bnorm_bwd_nhwc_test(bool do_verification, if(do_verification) { using ReferenceBatchNormBwdInstance = - ck::tensor_operation::host::ReferenceBatchNormBwd_Input_N_H_W_C_Output_C; + ck::tensor_operation::host::ReferenceBatchNormBwd; auto batchNormBwd_ref = ReferenceBatchNormBwdInstance{}; @@ -390,8 +394,8 @@ bool bnorm_bwd_nhwc_test(bool do_verification, dbias_dev.FromDevice(dbias.data()); // clang-format off - pass = pass && ck::utils::check_err(dbias.mData, dbias_ref.mData, "dBias result:", 1e-5, 1e-5); - pass = pass && ck::utils::check_err(dscale.mData, dscale_ref.mData, "dScale result:", 1e-5, 2e-4); + pass = pass && ck::utils::check_err(dbias.mData, dbias_ref.mData, "dBias result:", 2e-4, 2e-4); + pass = pass && ck::utils::check_err(dscale.mData, dscale_ref.mData, "dScale result:", 2e-4, 2e-4); pass = pass && ck::utils::check_err(dx.mData, dx_ref.mData, "dx result:"); // clang-format on }; diff --git a/include/ck/tensor_operation/gpu/device/device_batchnorm_backward.hpp b/include/ck/tensor_operation/gpu/device/device_batchnorm_backward.hpp index e969fd0be7..d39f3b7cbc 100644 --- a/include/ck/tensor_operation/gpu/device/device_batchnorm_backward.hpp +++ b/include/ck/tensor_operation/gpu/device/device_batchnorm_backward.hpp @@ -13,7 +13,16 @@ namespace ck { namespace tensor_operation { namespace device { -template +template struct DeviceBatchNormBwd : public BaseOperator { static constexpr index_t NumInvariantDim = Rank - NumBatchNormReduceDim; @@ -26,7 +35,7 @@ struct DeviceBatchNormBwd : public BaseOperator const std::array reduceDims, const std::array bnScaleBiasMeanVarLengths, const std::array bnScaleStrides, - const std::array bnBiasStrides, + const std::array bnDscaleDbiasStrides, const std::array bnMeanVarStrides, const void* p_x, const void* p_dy, @@ -42,9 +51,26 @@ struct DeviceBatchNormBwd : public BaseOperator virtual std::unique_ptr MakeInvokerPointer() = 0; }; -template -using DeviceBatchNormBwdPtr = - std::unique_ptr>; +template +using DeviceBatchNormBwdPtr = std::unique_ptr>; } // namespace device } // namespace tensor_operation diff --git a/include/ck/tensor_operation/gpu/device/impl/device_batchnorm_backward_impl.hpp b/include/ck/tensor_operation/gpu/device/impl/device_batchnorm_backward_impl.hpp index d61dbd0010..ab16a757f6 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_batchnorm_backward_impl.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_batchnorm_backward_impl.hpp @@ -27,7 +27,7 @@ template -struct DeviceBatchNormBwdImpl - : public DeviceBatchNormBwd +struct 
DeviceBatchNormBwdImpl : public DeviceBatchNormBwd { static_assert(Rank <= 6, "Bigger Rank size is not supported!"); static_assert(BlockSize == MThreadClusterSize * KThreadClusterSize, @@ -194,7 +202,7 @@ struct DeviceBatchNormBwdImpl const std::array reduceDims, const std::array bnScaleBiasMeanVarLengths, const std::array bnScaleStrides, - const std::array bnBiasStrides, + const std::array bnDscaleDbiasStrides, const std::array bnMeanVarStrides, const XDataType* p_x, const DyDataType* p_dy, @@ -204,11 +212,11 @@ struct DeviceBatchNormBwdImpl const DyElementwiseOp dy_elementwise_op, double epsilon, DxDataType* p_dx, - ScaleDataType* p_dscale, - BiasDataType* p_dbias) + DscaleDbiasDataType* p_dscale, + DscaleDbiasDataType* p_dbias) : bnScaleBiasMeanVarLengths_(bnScaleBiasMeanVarLengths), bnScaleStrides_(bnScaleStrides), - bnBiasStrides_(bnBiasStrides), + bnDscaleDbiasStrides_(bnDscaleDbiasStrides), bnMeanVarStrides_(bnMeanVarStrides), p_x_(p_x), p_dy_(p_dy), @@ -272,8 +280,8 @@ struct DeviceBatchNormBwdImpl MakeXY2dDescriptor(xyLengths_, dxStrides_, blkGroupSize, numBlockTileIteration); scale_grid_desc_m = MakeScaleBiasMeanVar1dDescriptor(bnScaleBiasMeanVarLengths, bnScaleStrides); - bias_grid_desc_m = - MakeScaleBiasMeanVar1dDescriptor(bnScaleBiasMeanVarLengths, bnBiasStrides); + dscale_dbias_grid_desc_m = + MakeScaleBiasMeanVar1dDescriptor(bnScaleBiasMeanVarLengths, bnDscaleDbiasStrides); mean_var_grid_desc_m = MakeScaleBiasMeanVar1dDescriptor(bnScaleBiasMeanVarLengths, bnMeanVarStrides); } @@ -289,7 +297,7 @@ struct DeviceBatchNormBwdImpl std::array bnScaleBiasMeanVarLengths_; std::array bnScaleStrides_; - std::array bnBiasStrides_; + std::array bnDscaleDbiasStrides_; std::array bnMeanVarStrides_; const XDataType* p_x_; @@ -299,8 +307,8 @@ struct DeviceBatchNormBwdImpl const MeanVarDataType* p_savedInvVar_; const DyElementwiseOp dy_elementwise_op_; DxDataType* p_dx_; - ScaleDataType* p_dscale_; - BiasDataType* p_dbias_; + DscaleDbiasDataType* p_dscale_; + DscaleDbiasDataType* p_dbias_; long_index_t invariant_length; long_index_t reduce_length; @@ -313,7 +321,7 @@ struct DeviceBatchNormBwdImpl XYGridDesc_M_K dy_grid_desc_m_k; XYGridDesc_M_K dx_grid_desc_m_k; ScaleBiasGridDesc_M scale_grid_desc_m; - ScaleBiasGridDesc_M bias_grid_desc_m; + ScaleBiasGridDesc_M dscale_dbias_grid_desc_m; MeanVarGridDesc_M mean_var_grid_desc_m; void* workspace_mean; @@ -337,11 +345,11 @@ struct DeviceBatchNormBwdImpl { // workspace for the partial reduced result for dscale workspace_size += - pArg_->invariant_length * pArg_->blkGroupSize * sizeof(ScaleDataType) + 64; + pArg_->invariant_length * pArg_->blkGroupSize * sizeof(DscaleDbiasDataType) + 64; // workspace for the partial reduced result for dbias workspace_size += - pArg_->invariant_length * pArg_->blkGroupSize * sizeof(BiasDataType) + 64; + pArg_->invariant_length * pArg_->blkGroupSize * sizeof(DscaleDbiasDataType) + 64; if(!pArg_->haveSavedMeanInvVar_) { @@ -379,7 +387,7 @@ struct DeviceBatchNormBwdImpl // setup buffer for the partial reduced result for dscale pArg_->workspace_reduce_dscale = pArg_->p_workspace_; - space_sz = pArg_->invariant_length * pArg_->blkGroupSize * sizeof(ScaleDataType); + space_sz = pArg_->invariant_length * pArg_->blkGroupSize * sizeof(DscaleDbiasDataType); space_sz = math::integer_least_multiple(space_sz, 64); // setup buffer for the partial reduced result for dbias @@ -388,7 +396,7 @@ struct DeviceBatchNormBwdImpl if(UseMultiblockInK && pArg_->blkGroupSize > 1) { - space_sz = pArg_->invariant_length * pArg_->blkGroupSize * 
sizeof(BiasDataType); + space_sz = pArg_->invariant_length * pArg_->blkGroupSize * sizeof(DscaleDbiasDataType); space_sz = math::integer_least_multiple(space_sz, 64); // setup buffer for welford intermediate mean @@ -454,7 +462,7 @@ struct DeviceBatchNormBwdImpl DyDataType, AccDataType, ScaleDataType, - BiasDataType, + DscaleDbiasDataType, MeanVarDataType, DyElementwiseOp, XYGridDesc_M_K, @@ -477,7 +485,7 @@ struct DeviceBatchNormBwdImpl DxDataType, AccDataType, ScaleDataType, - BiasDataType, + DscaleDbiasDataType, MeanVarDataType, DyElementwiseOp, XYGridDesc_M_K, @@ -493,8 +501,8 @@ struct DeviceBatchNormBwdImpl XSrcVectorSize, DySrcVectorSize, DxDstVectorSize, - ScaleSrcDstVectorSize, - BiasDstVectorSize, + ScaleSrcVectorSize, + DscaleDbiasDstVectorSize, MeanVarSrcVectorSize>; if(UseMultiblockInK && arg.blkGroupSize > 1) @@ -553,7 +561,7 @@ struct DeviceBatchNormBwdImpl DyDataType, AccDataType, ScaleDataType, - BiasDataType, + DscaleDbiasDataType, MeanVarDataType, DyElementwiseOp, XYGridDesc_M_K, @@ -568,7 +576,7 @@ struct DeviceBatchNormBwdImpl DyDataType, DxDataType, ScaleDataType, - BiasDataType, + DscaleDbiasDataType, MeanVarDataType, DyElementwiseOp, XYGridDesc_M_K, @@ -614,8 +622,8 @@ struct DeviceBatchNormBwdImpl : static_cast(arg.workspace_savedInvVar), arg.p_x_, arg.p_dy_, - static_cast(arg.workspace_reduce_dscale), - static_cast(arg.workspace_reduce_dbias)); + static_cast(arg.workspace_reduce_dscale), + static_cast(arg.workspace_reduce_dbias)); avg_time += launch_and_time_kernel( stream_config, @@ -629,13 +637,13 @@ struct DeviceBatchNormBwdImpl dscale_dbias_grid_desc_m_k, arg.mean_var_grid_desc_m, arg.scale_grid_desc_m, - arg.bias_grid_desc_m, + arg.dscale_dbias_grid_desc_m, arg.blkGroupSize, arg.reduce_length, arg.numBlockTileIteration, numDscaleDbiasBlockTileIteration, - static_cast(arg.workspace_reduce_dscale), - static_cast(arg.workspace_reduce_dbias), + static_cast(arg.workspace_reduce_dscale), + static_cast(arg.workspace_reduce_dbias), arg.haveSavedMeanInvVar_ ? 
arg.p_savedMean_ : static_cast(arg.workspace_savedMean), @@ -664,7 +672,7 @@ struct DeviceBatchNormBwdImpl DxDataType, AccDataType, ScaleDataType, - BiasDataType, + DscaleDbiasDataType, MeanVarDataType, DyElementwiseOp, XYGridDesc_M_K, @@ -680,8 +688,8 @@ struct DeviceBatchNormBwdImpl XSrcVectorSize, DySrcVectorSize, DxDstVectorSize, - ScaleSrcDstVectorSize, - BiasDstVectorSize, + ScaleSrcVectorSize, + DscaleDbiasDstVectorSize, MeanVarSrcVectorSize>; const auto kern_batchnorm_bwd = kernel_batchnorm_backward_with_blockwise_welford< @@ -691,7 +699,7 @@ struct DeviceBatchNormBwdImpl DxDataType, AccDataType, ScaleDataType, - BiasDataType, + DscaleDbiasDataType, MeanVarDataType, DyElementwiseOp, XYGridDesc_M_K, @@ -708,7 +716,7 @@ struct DeviceBatchNormBwdImpl arg.dy_grid_desc_m_k, arg.dx_grid_desc_m_k, arg.scale_grid_desc_m, - arg.bias_grid_desc_m, + arg.dscale_dbias_grid_desc_m, arg.mean_var_grid_desc_m, get_reduce_count_per_thread, arg.reduce_length, @@ -764,16 +772,16 @@ struct DeviceBatchNormBwdImpl return false; }; - if(pArg_->bnScaleStrides_[NumInvariantDim - 1] != 1 && ScaleSrcDstVectorSize != 1) + if(pArg_->bnScaleStrides_[NumInvariantDim - 1] != 1 && ScaleSrcVectorSize != 1) return false; - if(pArg_->bnBiasStrides_[NumInvariantDim - 1] != 1 && BiasDstVectorSize != 1) + if(pArg_->bnDscaleDbiasStrides_[NumInvariantDim - 1] != 1 && DscaleDbiasDstVectorSize != 1) return false; - if(pArg_->bnScaleBiasMeanVarLengths_[NumInvariantDim - 1] % ScaleSrcDstVectorSize != 0) + if(pArg_->bnScaleBiasMeanVarLengths_[NumInvariantDim - 1] % ScaleSrcVectorSize != 0) return false; - if(pArg_->bnScaleBiasMeanVarLengths_[NumInvariantDim - 1] % BiasDstVectorSize != 0) + if(pArg_->bnScaleBiasMeanVarLengths_[NumInvariantDim - 1] % DscaleDbiasDstVectorSize != 0) return false; if(pArg_->haveSavedMeanInvVar_) @@ -806,7 +814,7 @@ struct DeviceBatchNormBwdImpl const std::array reduceDims, const std::array bnScaleBiasMeanVarLengths, const std::array bnScaleStrides, - const std::array bnBiasStrides, + const std::array bnDscaleDbiasStrides, const std::array bnMeanVarStrides, const void* p_x, const void* p_dy, @@ -826,7 +834,7 @@ struct DeviceBatchNormBwdImpl reduceDims, bnScaleBiasMeanVarLengths, bnScaleStrides, - bnBiasStrides, + bnDscaleDbiasStrides, bnMeanVarStrides, static_cast(p_x), static_cast(p_dy), @@ -836,8 +844,8 @@ struct DeviceBatchNormBwdImpl dy_elementwise_op, epsilon, static_cast(p_dx), - static_cast(p_dscale), - static_cast(p_dbias)); + static_cast(p_dscale), + static_cast(p_dbias)); }; std::unique_ptr MakeInvokerPointer() override @@ -854,7 +862,7 @@ struct DeviceBatchNormBwdImpl str << "M_C" << MThreadClusterSize << "_S" << MThreadSliceSize << ","; str << "K_C" << KThreadClusterSize << "_S" << KThreadSliceSize << ","; str << "XDyDxVectorDim_" << XDyDxVectorDim << ","; - str << "VectorSize_X" << XSrcVectorSize << "_scale_" << ScaleSrcDstVectorSize << "_bias_" << BiasDstVectorSize << "_mean_var_" << MeanVarSrcVectorSize << "_Dx_" << DxDstVectorSize << ">"; + str << "VectorSize_X" << XSrcVectorSize << "_scale_" << ScaleSrcVectorSize << "_bias_" << DscaleDbiasDstVectorSize << "_mean_var_" << MeanVarSrcVectorSize << "_Dx_" << DxDstVectorSize << ">"; // clang-format on return str.str(); diff --git a/include/ck/tensor_operation/gpu/grid/batchnorm_multiblock/gridwise_multiblock_reduce_second_half_batchnorm_backward_final.hpp b/include/ck/tensor_operation/gpu/grid/batchnorm_multiblock/gridwise_multiblock_reduce_second_half_batchnorm_backward_final.hpp index 9638755565..a72a4ee068 100644 --- 
a/include/ck/tensor_operation/gpu/grid/batchnorm_multiblock/gridwise_multiblock_reduce_second_half_batchnorm_backward_final.hpp +++ b/include/ck/tensor_operation/gpu/grid/batchnorm_multiblock/gridwise_multiblock_reduce_second_half_batchnorm_backward_final.hpp @@ -16,7 +16,7 @@ template struct GridwiseReduceSecondHalfBatchNormBackwardFinal { @@ -155,13 +155,13 @@ struct GridwiseReduceSecondHalfBatchNormBackwardFinal const DscaleDbiasGridDesc_M_K& dscale_dbias_grid_desc_m_k, const MeanVarGridDesc_M& mean_var_grid_desc_m, const ScaleBiasGridDesc_M& scale_grid_desc_m, - const ScaleBiasGridDesc_M& bias_grid_desc_m, + const ScaleBiasGridDesc_M& dscale_dbias_grid_desc_m, index_t blkgroup_size, long_index_t reduce_size, index_t num_xy_k_block_tile_iteration, index_t num_dscale_dbias_k_block_tile_iteration, - const ScaleDataType* const __restrict__ p_reduce_dscale, - const BiasDataType* const __restrict__ p_reduce_dbias, + const DscaleDbiasDataType* const __restrict__ p_reduce_dscale, + const DscaleDbiasDataType* const __restrict__ p_reduce_dbias, const MeanVarDataType* const __restrict__ p_mean, const MeanVarDataType* const __restrict__ p_inv_var, const XDataType* const __restrict__ p_x, @@ -169,8 +169,8 @@ struct GridwiseReduceSecondHalfBatchNormBackwardFinal const ScaleDataType* const __restrict__ p_scale, const DyElementwiseOp dy_elementwise_op, DxDataType* const __restrict__ p_dx, - ScaleDataType* const __restrict__ p_dscale, - BiasDataType* const __restrict__ p_dbias) + DscaleDbiasDataType* const __restrict__ p_dscale, + DscaleDbiasDataType* const __restrict__ p_dbias) { __shared__ AccDataType p_reduce_work_buffer[BlockSize]; @@ -222,8 +222,8 @@ struct GridwiseReduceSecondHalfBatchNormBackwardFinal // Step 1: do final reduction of dbias = sum(dy), dscale = sum(dy * (x-mean) * inv-variance) // clang-format on - auto threadwise_dscale_load_m_k = - ThreadwiseTensorSliceTransfer_v2, - 1, - 1, - 1, - true>( - dscale_dbias_grid_desc_m_k, - make_multi_index(blkgroup_id * M_BlockTileSize + - thread_m_cluster_id * MThreadSliceSize, - thread_k_cluster_id * 1)); - - auto threadwise_dscale_store_m = + auto threadwise_dscale_dbias_store_m = ThreadwiseTensorSliceTransfer_v1r3, 0, - ScaleSrcDstVectorSize, + DscaleDbiasDstVectorSize, InMemoryDataOperationEnum::Set, 1, true>( - scale_grid_desc_m, - make_multi_index(blkgroup_id * M_BlockTileSize + - thread_m_cluster_id * MThreadSliceSize), - PassThroughOp{}); - - auto threadwise_dbias_store_m = - ThreadwiseTensorSliceTransfer_v1r3, - 0, - BiasDstVectorSize, - InMemoryDataOperationEnum::Set, - 1, - true>( - bias_grid_desc_m, + dscale_dbias_grid_desc_m, make_multi_index(blkgroup_id * M_BlockTileSize + thread_m_cluster_id * MThreadSliceSize), PassThroughOp{}); @@ -297,10 +263,10 @@ struct GridwiseReduceSecondHalfBatchNormBackwardFinal p_reduce_dbias, dscale_dbias_grid_desc_m_k.GetElementSpaceSize()); auto dscale_global_buf = make_dynamic_buffer( - p_dscale, scale_grid_desc_m.GetElementSpaceSize()); + p_dscale, dscale_dbias_grid_desc_m.GetElementSpaceSize()); auto dbias_global_buf = make_dynamic_buffer( - p_dbias, bias_grid_desc_m.GetElementSpaceSize()); + p_dbias, dscale_dbias_grid_desc_m.GetElementSpaceSize()); constexpr auto dscale_dbias_thread_copy_step_m_k = make_multi_index(0, KThreadClusterSize * 1); @@ -313,25 +279,23 @@ struct GridwiseReduceSecondHalfBatchNormBackwardFinal for(index_t reducedTiles = 0; reducedTiles < num_dscale_dbias_k_block_tile_iteration; ++reducedTiles) { - threadwise_dscale_load_m_k.Run(dscale_dbias_grid_desc_m_k, - 
reduce_dscale_global_buf, - thread_buffer_desc_m_1, - make_tuple(I0, I0), - reduce_dscale_thread_buf); + threadwise_dscale_dbias_load_m_k.Run(dscale_dbias_grid_desc_m_k, + reduce_dscale_global_buf, + thread_buffer_desc_m_1, + make_tuple(I0, I0), + reduce_dscale_thread_buf); - threadwise_dbias_load_m_k.Run(dscale_dbias_grid_desc_m_k, - reduce_dbias_global_buf, - thread_buffer_desc_m_1, - make_tuple(I0, I0), - reduce_dbias_thread_buf); + threadwise_dscale_dbias_load_m_k.Run(dscale_dbias_grid_desc_m_k, + reduce_dbias_global_buf, + thread_buffer_desc_m_1, + make_tuple(I0, I0), + reduce_dbias_thread_buf); ThreadwiseReduce::Reduce(reduce_dscale_thread_buf, dscale_thread_buf); ThreadwiseReduce::Reduce(reduce_dbias_thread_buf, dbias_thread_buf); - threadwise_dscale_load_m_k.MoveSrcSliceWindow(dscale_dbias_grid_desc_m_k, - dscale_dbias_thread_copy_step_m_k); - threadwise_dbias_load_m_k.MoveSrcSliceWindow(dscale_dbias_grid_desc_m_k, - dscale_dbias_thread_copy_step_m_k); + threadwise_dscale_dbias_load_m_k.MoveSrcSliceWindow(dscale_dbias_grid_desc_m_k, + dscale_dbias_thread_copy_step_m_k); } static_for<0, MThreadSliceSize, 1>{}([&](auto I) { @@ -343,17 +307,17 @@ struct GridwiseReduceSecondHalfBatchNormBackwardFinal BlockwiseReduce::Reduce(reduce_work_buf, dbias_thread_buf(I)); }); - threadwise_dscale_store_m.Run(thread_buffer_desc_m, - make_tuple(I0), - dscale_thread_buf, - scale_grid_desc_m, - dscale_global_buf); + threadwise_dscale_dbias_store_m.Run(thread_buffer_desc_m, + make_tuple(I0), + dscale_thread_buf, + dscale_dbias_grid_desc_m, + dscale_global_buf); - threadwise_dbias_store_m.Run(thread_buffer_desc_m, - make_tuple(I0), - dbias_thread_buf, - bias_grid_desc_m, - dbias_global_buf); + threadwise_dscale_dbias_store_m.Run(thread_buffer_desc_m, + make_tuple(I0), + dbias_thread_buf, + dscale_dbias_grid_desc_m, + dbias_global_buf); // clang-format off // Step 2: calculate dx = 1/N * inv-variance * scale * (N * dy - dbias - dscale * (x - mean) * inv-variance) @@ -418,7 +382,7 @@ struct GridwiseReduceSecondHalfBatchNormBackwardFinal ThreadBufferLengths_M, Sequence<0>, 0, - ScaleSrcDstVectorSize, + ScaleSrcVectorSize, 1, true>( scale_grid_desc_m, diff --git a/include/ck/tensor_operation/gpu/grid/batchnorm_multiblock/gridwise_multiblock_welford_second_half_multiblock_reduce_first_half.hpp b/include/ck/tensor_operation/gpu/grid/batchnorm_multiblock/gridwise_multiblock_welford_second_half_multiblock_reduce_first_half.hpp index a4de9b7e6c..42b7e172b2 100644 --- a/include/ck/tensor_operation/gpu/grid/batchnorm_multiblock/gridwise_multiblock_welford_second_half_multiblock_reduce_first_half.hpp +++ b/include/ck/tensor_operation/gpu/grid/batchnorm_multiblock/gridwise_multiblock_welford_second_half_multiblock_reduce_first_half.hpp @@ -17,7 +17,7 @@ template , - 1, - 1, - InMemoryDataOperationEnum::Set, - 1, - true>( - dscale_dbias_grid_desc_m_g, - make_multi_index(blkgroup_id * M_BlockTileSize + - thread_m_cluster_id * MThreadSliceSize, - block_local_id), - PassThroughOp{}); - - auto threadwise_dbias_store = - ThreadwiseTensorSliceTransfer_v1r3 struct GridwiseBatchNormBackwardWithBlockwiseWelford { @@ -165,7 +165,7 @@ struct GridwiseBatchNormBackwardWithBlockwiseWelford const XYGridDesc_M_K dy_grid_desc_m_k, const XYGridDesc_M_K dx_grid_desc_m_k, const ScaleBiasGridDesc_M scale_grid_desc_m, - const ScaleBiasGridDesc_M bias_grid_desc_m, + const ScaleBiasGridDesc_M dscale_dbias_grid_desc_m, const MeanVarGridDesc_M mean_var_grid_desc_m, const GetReduceCountPerThreadFunctor get_reduce_count_per_thread, 
long_index_t reduce_size, @@ -179,8 +179,8 @@ struct GridwiseBatchNormBackwardWithBlockwiseWelford const MeanVarDataType* const __restrict__ p_savedInvVar, const DyElementwiseOp dy_elementwise_op, DxDataType* const __restrict__ p_dx, - ScaleDataType* const __restrict__ p_dscale, - BiasDataType* const __restrict__ p_dbias) + DscaleDbiasDataType* const __restrict__ p_dscale, + DscaleDbiasDataType* const __restrict__ p_dbias) { using ck::math::sqrt; @@ -253,7 +253,7 @@ struct GridwiseBatchNormBackwardWithBlockwiseWelford XSrcVectorSize, 1, true>( - x_grid_desc_m_k, + dy_grid_desc_m_k, make_multi_index(block_global_id * M_BlockTileSize + thread_m_cluster_id * MThreadSliceSize, thread_k_cluster_id * KThreadSliceSize)); @@ -271,7 +271,7 @@ struct GridwiseBatchNormBackwardWithBlockwiseWelford InMemoryDataOperationEnum::Set, 1, true>( - dy_grid_desc_m_k, + dx_grid_desc_m_k, make_multi_index(block_global_id * M_BlockTileSize + thread_m_cluster_id * MThreadSliceSize, thread_k_cluster_id * KThreadSliceSize), @@ -285,45 +285,27 @@ struct GridwiseBatchNormBackwardWithBlockwiseWelford ThreadBufferLengths_M, Sequence<0>, 0, - ScaleSrcDstVectorSize, + ScaleSrcVectorSize, 1, true>( scale_grid_desc_m, make_multi_index(block_global_id * M_BlockTileSize + thread_m_cluster_id * MThreadSliceSize)); - auto threadwise_dscale_store = + auto threadwise_dscale_dbias_store = ThreadwiseTensorSliceTransfer_v1r3, 0, - ScaleSrcDstVectorSize, + DscaleDbiasDstVectorSize, InMemoryDataOperationEnum::Set, 1, true>( - scale_grid_desc_m, - make_multi_index(block_global_id * M_BlockTileSize + - thread_m_cluster_id * MThreadSliceSize), - PassThroughOp{}); - - auto threadwise_dbias_store = - ThreadwiseTensorSliceTransfer_v1r3, - 0, - BiasDstVectorSize, - InMemoryDataOperationEnum::Set, - 1, - true>( - bias_grid_desc_m, + dscale_dbias_grid_desc_m, make_multi_index(block_global_id * M_BlockTileSize + thread_m_cluster_id * MThreadSliceSize), PassThroughOp{}); @@ -344,10 +326,10 @@ struct GridwiseBatchNormBackwardWithBlockwiseWelford p_scale, scale_grid_desc_m.GetElementSpaceSize()); auto dscale_global_buf = make_dynamic_buffer( - p_dscale, scale_grid_desc_m.GetElementSpaceSize()); + p_dscale, dscale_dbias_grid_desc_m.GetElementSpaceSize()); auto dbias_global_buf = make_dynamic_buffer( - p_dbias, bias_grid_desc_m.GetElementSpaceSize()); + p_dbias, dscale_dbias_grid_desc_m.GetElementSpaceSize()); // clang-format off // Step 1: calculating mean and inv-variance using welford method (if savedMean/savedInvVar not available), where inv-variance = 1/sqrt(epsilon+variance) @@ -487,17 +469,17 @@ struct GridwiseBatchNormBackwardWithBlockwiseWelford if(thread_k_cluster_id == 0) { - threadwise_dscale_store.Run(thread_buffer_desc_m, - make_tuple(I0), - dscale_thread_buf, - scale_grid_desc_m, - dscale_global_buf); + threadwise_dscale_dbias_store.Run(thread_buffer_desc_m, + make_tuple(I0), + dscale_thread_buf, + dscale_dbias_grid_desc_m, + dscale_global_buf); - threadwise_dbias_store.Run(thread_buffer_desc_m, - make_tuple(I0), - dbias_thread_buf, - bias_grid_desc_m, - dbias_global_buf); + threadwise_dscale_dbias_store.Run(thread_buffer_desc_m, + make_tuple(I0), + dbias_thread_buf, + dscale_dbias_grid_desc_m, + dbias_global_buf); }; // clang-format off diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_multiblock_welford_first_half.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_multiblock_welford_first_half.hpp deleted file mode 100644 index 1afe9f9752..0000000000 --- 
a/include/ck/tensor_operation/gpu/grid/gridwise_multiblock_welford_first_half.hpp +++ /dev/null @@ -1,258 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. - -#pragma once - -#include "ck/utility/data_type.hpp" -#include "ck/utility/math.hpp" -#include "ck/tensor_operation/gpu/block/blockwise_welford.hpp" -#include "ck/tensor_operation/gpu/thread/threadwise_welford.hpp" -#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" -#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" - -namespace ck { - -template -__global__ void kernel_multiblock_welford_first_half( - const XGridDesc_M_K x_grid_desc_m_k, - const MeanVarCountGridDesc_M_G mean_var_count_grid_desc_m_g, - const GetReduceCountPerThreadFunctor get_reduce_count_per_thread, - index_t num_k_block_tile_iteration, - const XDataType* const __restrict__ p_x, - MeanVarDataType* const p_welford_mean, - MeanVarDataType* const p_welford_variance, - int32_t* const p_welford_count) -{ - GridwiseMultiblockWelfordFirstHalf_::Run(x_grid_desc_m_k, - mean_var_count_grid_desc_m_g, - get_reduce_count_per_thread, - num_k_block_tile_iteration, - p_x, - p_welford_mean, - p_welford_variance, - p_welford_count); -}; - -template -struct GridwiseMultiblockWelfordFirstHalf -{ - static_assert((XSrcCountSrcVectorDim == 0 && MThreadSliceSize % XSrcCountSrcVectorSize == 0) || - (XSrcCountSrcVectorDim == 1 && - KThreadSliceSize % XSrcCountSrcVectorSize == 0), - "Invalid thread slice sizes and/or vector sizes configuration, please check!"); - - static constexpr bool reorder_thread_cluster = (XSrcCountSrcVectorDim == 0); - - using ThreadClusterLengths_M_K = Sequence; - - using ThreadBufferDimAccessOrder = - typename conditional, Sequence<0, 1>>::type; - - using ThreadClusterArrangeOrder = - typename conditional, Sequence<0, 1>>::type; - - static constexpr auto thread_cluster_desc = - make_cluster_descriptor(ThreadClusterLengths_M_K{}, ThreadClusterArrangeOrder{}); - - using ThreadReduceSrcDesc_M_K = decltype(make_naive_tensor_descriptor_packed( - make_tuple(Number{}, Number{}))); - using ThreadReduceDstDesc_M = - decltype(make_naive_tensor_descriptor_packed(make_tuple(Number{}))); - - using ThreadwiseWelford = - ThreadwiseWelford; - - using BlockwiseWelford = BlockwiseWelford; - - using PassThroughOp = tensor_operation::element_wise::PassThrough; - - static constexpr auto I0 = Number<0>{}; - static constexpr auto I1 = Number<1>{}; - - static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize; - static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize; - - __device__ static void Run(const XGridDesc_M_K& x_grid_desc_m_k, - const MeanVarCountGridDesc_M_G& mean_var_count_grid_desc_m_g, - const GetReduceCountPerThreadFunctor& get_reduce_count_per_thread, - index_t num_k_block_tile_iteration, - const XDataType* const __restrict__ p_x, - MeanVarDataType* const p_welford_mean, - MeanVarDataType* const p_welford_variance, - int32_t* const p_welford_count) - { - StaticBuffer - x_thread_buf; - - StaticBuffer - welford_mean_thread_buf; - StaticBuffer - welford_var_thread_buf; - StaticBuffer - welford_count_thread_buf; - - const index_t blkgroup_size = mean_var_count_grid_desc_m_g.GetLength(I1); - - const index_t thread_local_id = get_thread_local_1d_id(); - const index_t block_global_id = get_block_1d_id(); - const index_t blkgroup_id = block_global_id / blkgroup_size; - const index_t block_local_id = block_global_id % 
blkgroup_size; - - const auto thread_cluster_idx = - thread_cluster_desc.CalculateBottomIndex(make_multi_index(thread_local_id)); - - const auto thread_m_cluster_id = thread_cluster_idx[I0]; - const auto thread_k_cluster_id = thread_cluster_idx[I1]; - - using ThreadBufferLengths_M_K = Sequence; - using ThreadBufferLengths_M_1 = Sequence; - - constexpr auto thread_buffer_desc_m_k = make_naive_tensor_descriptor_packed( - make_tuple(Number{}, Number{})); - constexpr auto thread_buffer_desc_m_1 = make_naive_tensor_descriptor_packed( - make_tuple(Number{}, Number<1>{})); - - const index_t reduceSizePerBlock = K_BlockTileSize * num_k_block_tile_iteration; - - auto threadwise_x_load = ThreadwiseTensorSliceTransfer_v2( - x_grid_desc_m_k, - make_multi_index(blkgroup_id * M_BlockTileSize + thread_m_cluster_id * MThreadSliceSize, - block_local_id * reduceSizePerBlock + - thread_k_cluster_id * KThreadSliceSize)); - - auto threadwise_welford_mean_var_store = - ThreadwiseTensorSliceTransfer_v1r3, - 1, - 1, - InMemoryDataOperationEnum::Set, - 1, - true>( - mean_var_count_grid_desc_m_g, - make_multi_index(blkgroup_id * M_BlockTileSize + - thread_m_cluster_id * MThreadSliceSize, - block_local_id), - PassThroughOp{}); - - auto threadwise_welford_count_store = - ThreadwiseTensorSliceTransfer_v1r3, - 1, - 1, - InMemoryDataOperationEnum::Set, - 1, - true>( - mean_var_count_grid_desc_m_g, - make_multi_index(blkgroup_id * M_BlockTileSize + - thread_m_cluster_id * MThreadSliceSize, - block_local_id), - PassThroughOp{}); - - constexpr auto thread_copy_fwd_step_m_k = make_multi_index(0, K_BlockTileSize); - - const auto x_global_val_buf = make_dynamic_buffer( - p_x, x_grid_desc_m_k.GetElementSpaceSize()); - - auto welford_mean_global_val_buf = make_dynamic_buffer( - p_welford_mean, mean_var_count_grid_desc_m_g.GetElementSpaceSize()); - - auto welford_var_global_val_buf = make_dynamic_buffer( - p_welford_variance, mean_var_count_grid_desc_m_g.GetElementSpaceSize()); - - auto welford_count_global_val_buf = make_dynamic_buffer( - p_welford_count, mean_var_count_grid_desc_m_g.GetElementSpaceSize()); - - auto threadwise_welford = ThreadwiseWelford(); - threadwise_welford.max_count_ = - get_reduce_count_per_thread(block_local_id, thread_k_cluster_id); - - static_for<0, MThreadSliceSize, 1>{}([&](auto I) { - welford_mean_thread_buf(I) = type_convert(0.0f); - welford_var_thread_buf(I) = type_convert(0.0f); - }); - - for(index_t reducedTiles = 0; reducedTiles < num_k_block_tile_iteration; ++reducedTiles) - { - threadwise_x_load.Run(x_grid_desc_m_k, - x_global_val_buf, - thread_buffer_desc_m_k, - make_tuple(I0, I0), - x_thread_buf); - - threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, thread_copy_fwd_step_m_k); - threadwise_welford.Run(x_thread_buf, welford_mean_thread_buf, welford_var_thread_buf); - } - - static_for<0, MThreadSliceSize, 1>{}([&](auto I) { - if constexpr(I > 0) - block_sync_lds(); - - welford_count_thread_buf(I) = threadwise_welford.cur_count_; - BlockwiseWelford::Run( - welford_mean_thread_buf(I), welford_var_thread_buf(I), welford_count_thread_buf(I)); - }); - - if(thread_k_cluster_id == 0) - { - threadwise_welford_mean_var_store.Run(thread_buffer_desc_m_1, - make_tuple(I0, I0), - welford_mean_thread_buf, - mean_var_count_grid_desc_m_g, - welford_mean_global_val_buf); - - threadwise_welford_mean_var_store.Run(thread_buffer_desc_m_1, - make_tuple(I0, I0), - welford_var_thread_buf, - mean_var_count_grid_desc_m_g, - welford_var_global_val_buf); - - 
threadwise_welford_count_store.Run(thread_buffer_desc_m_1, - make_tuple(I0, I0), - welford_count_thread_buf, - mean_var_count_grid_desc_m_g, - welford_count_global_val_buf); - }; - } -}; - -} // namespace ck diff --git a/library/include/ck/library/reference_tensor_operation/cpu/reference_batchnorm_backward.hpp b/library/include/ck/library/reference_tensor_operation/cpu/reference_batchnorm_backward.hpp new file mode 100644 index 0000000000..0b621e88a0 --- /dev/null +++ b/library/include/ck/library/reference_tensor_operation/cpu/reference_batchnorm_backward.hpp @@ -0,0 +1,412 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include +#include +#include + +#include "ck/utility/math_v2.hpp" +#include "ck/utility/ignore.hpp" +#include "ck/library/utility/host_common_util.hpp" +#include "ck/tensor_operation/gpu/device/device_batchnorm_backward.hpp" + +namespace ck { +namespace tensor_operation { +namespace host { + +template +struct ReferenceBatchNormBwd : public device::DeviceBatchNormBwd +{ + static_assert(Rank <= 6, "Bigger Rank size is not supported!"); + + static constexpr index_t NumInvariantDim = Rank - NumBatchNormReduceDim; + + struct Argument : public device::BaseArgument + { + Argument(const std::array xyLengths, + const std::array xStrides, + const std::array dxStrides, + const std::array dyStrides, + const std::array reduceDims, + const std::array bnScaleBiasMeanVarLengths, + const std::array bnScaleStrides, + const std::array bnDscaleDbiasStrides, + const std::array bnMeanVarStrides, + const XDataType* p_x, + const DyDataType* p_dy, + const ScaleDataType* p_scale, + const MeanVarDataType* p_savedMean, + const MeanVarDataType* p_savedInvVar, + double epsilon, + const DyElementwiseOp dy_elementwise_op, + DxDataType* p_dx, + DscaleDbiasDataType* p_dscale, + DscaleDbiasDataType* p_dbias) + : reduceDims_(reduceDims), + bnScaleBiasMeanVarLengths_(bnScaleBiasMeanVarLengths), + bnScaleStrides_(bnScaleStrides), + bnDscaleDbiasStrides_(bnDscaleDbiasStrides), + bnMeanVarStrides_(bnMeanVarStrides), + p_x_(p_x), + p_dy_(p_dy), + p_scale_(p_scale), + p_savedMean_(p_savedMean), + p_savedInvVar_(p_savedInvVar), + dy_elementwise_op_(dy_elementwise_op), + p_dx_(p_dx), + p_dscale_(p_dscale), + p_dbias_(p_dbias) + { + using ck::host_common::get_index_set; + + if(std::any_of( + reduceDims.begin(), reduceDims.end(), [](int d) { return d < 0 || d >= Rank; })) + throw std::runtime_error("Invalid reduce dimensions!"); + + // get invariant_dims[] and invariant_lengths[] + for(int dim = 0, i = 0; dim < Rank; dim++) + if(std::none_of( + reduceDims.begin(), reduceDims.end(), [&](int d) { return d == dim; })) + { + invariantDims_[i] = dim; + invariant_lengths_[i] = xyLengths[dim]; + i++; + }; + + // get reduce_lengths_[] + for(int j = 0, i = 0; j < NumBatchNormReduceDim; j++) + { + int dim = reduceDims[j]; + reduce_lengths_[i++] = xyLengths[dim]; + }; + + for(int i = 0; i < NumInvariantDim; i++) + if(invariant_lengths_[i] != bnScaleBiasMeanVarLengths_[i]) + throw std::runtime_error("Invalid lengths parameters!"); + + for(int j = 0, i = 0; j < NumInvariantDim; j++) + { + int dim = invariantDims_[j]; + x_invariant_strides_[i] = xStrides[dim]; + dy_invariant_strides_[i] = dyStrides[dim]; + dx_invariant_strides_[i] = dxStrides[dim]; + i++; + }; + + for(int j = 0, i = 0; j < NumBatchNormReduceDim; j++) + { + int dim = reduceDims_[j]; + x_reduce_strides_[i] = xStrides[dim]; + dy_reduce_strides_[i] = dyStrides[dim]; + 
dx_reduce_strides_[i] = dxStrides[dim]; + i++; + }; + + reduceSize_ = std::accumulate( + reduce_lengths_.begin(), reduce_lengths_.end(), 1, std::multiplies{}); + + invariant_index_set_ = get_index_set(invariant_lengths_); + reduce_index_set_ = get_index_set(reduce_lengths_); + + epsilon_ = type_convert(epsilon); + + haveSavedMeanInvVar_ = (p_savedMean != nullptr && p_savedInvVar != nullptr); + } + + std::array reduceDims_; + std::array invariantDims_; + std::array invariant_lengths_; + std::array reduce_lengths_; + + const std::array bnScaleBiasMeanVarLengths_; + const std::array bnScaleStrides_; + const std::array bnDscaleDbiasStrides_; + const std::array bnMeanVarStrides_; + + std::array x_invariant_strides_; + std::array dy_invariant_strides_; + std::array dx_invariant_strides_; + std::array x_reduce_strides_; + std::array dy_reduce_strides_; + std::array dx_reduce_strides_; + + const XDataType* p_x_; + const DyDataType* p_dy_; + const ScaleDataType* p_scale_; + const MeanVarDataType* p_savedMean_; + const MeanVarDataType* p_savedInvVar_; + const DyElementwiseOp dy_elementwise_op_; + + DxDataType* p_dx_; + DscaleDbiasDataType* p_dscale_; + DscaleDbiasDataType* p_dbias_; + + bool haveSavedMeanInvVar_; + + std::vector> invariant_index_set_; + std::vector> reduce_index_set_; + + AccDataType epsilon_; + size_t reduceSize_; + }; + + struct Invoker : public device::BaseInvoker + { + float Run(const Argument& arg) + { + using ck::host_common::get_offset_from_index; + + auto thread_reduce_func = [&](auto invariant_index) { + size_t x_invariant_offset = get_offset_from_index( + arg.x_invariant_strides_, invariant_index); + size_t dy_invariant_offset = get_offset_from_index( + arg.dy_invariant_strides_, invariant_index); + size_t dx_invariant_offset = get_offset_from_index( + arg.dx_invariant_strides_, invariant_index); + + AccDataType mean = type_convert(0.0f); + AccDataType variance = type_convert(0.0f); + AccDataType invVar; + int32_t curr_count = 0; + + if(arg.haveSavedMeanInvVar_) + { + size_t mean_invVar_invariant_offset = get_offset_from_index( + arg.bnMeanVarStrides_, invariant_index); + + mean = + type_convert(arg.p_savedMean_[mean_invVar_invariant_offset]); + invVar = + type_convert(arg.p_savedInvVar_[mean_invVar_invariant_offset]); + } + else + { + // compute mean, variance using welford method + for(const auto& reduce_index : arg.reduce_index_set_) + { + size_t x_reduce_offset = get_offset_from_index( + arg.x_reduce_strides_, reduce_index); + + auto x_offset = x_invariant_offset + x_reduce_offset; + + curr_count++; + + AccDataType x = type_convert(arg.p_x_[x_offset]); + + AccDataType delta = x - mean; + + mean += delta / curr_count; + + AccDataType delta2 = x - mean; + + variance += delta * delta2; + }; + + // actual variance + variance = variance / curr_count; + + // inv-variance defined as 1/sqrt(epsilon+variance) + invVar = + type_convert(1.0f) / ck::math::sqrt(arg.epsilon_ + variance); + }; + + AccDataType dbias = + type_convert(0.0f); // Sum on reduced dimensions of dy + AccDataType dscale = + type_convert(0.0f); // Sum on reduced dimensions of dy * norm_x + + // 1) calculate dy * (x - mean) * inv-variance + // 2) calculate sum(dy) on reduced dimensions + // 3) calculate sum(dy * norm_x) on reduced dimensions + for(const auto& reduce_index : arg.reduce_index_set_) + { + size_t x_reduce_offset = get_offset_from_index( + arg.x_reduce_strides_, reduce_index); + size_t dy_reduce_offset = get_offset_from_index( + arg.dy_reduce_strides_, reduce_index); + + auto x_offset = 
x_invariant_offset + x_reduce_offset; + auto dy_offset = dy_invariant_offset + dy_reduce_offset; + + AccDataType x = type_convert(arg.p_x_[x_offset]); + + AccDataType norm_x = (x - mean) * invVar; + AccDataType dy = type_convert(arg.p_dy_[dy_offset]); + + arg.dy_elementwise_op_(dy, dy); + + dbias += dy; + dscale += norm_x * dy; + }; + + size_t dscale_offset = get_offset_from_index( + arg.bnDscaleDbiasStrides_, invariant_index); + size_t dbias_offset = get_offset_from_index( + arg.bnDscaleDbiasStrides_, invariant_index); + + arg.p_dscale_[dscale_offset] = type_convert(dscale); + arg.p_dbias_[dbias_offset] = type_convert(dbias); + + size_t scale_offset = + get_offset_from_index(arg.bnScaleStrides_, invariant_index); + + AccDataType scale = type_convert(arg.p_scale_[scale_offset]); + + AccDataType multiplier = type_convert(1.0f) / + type_convert(arg.reduceSize_) * invVar * + scale; + + // 1) calculate tmp = dscale * (x - mean) * inv-variance + // 2) calculate dx = 1/reduceSize * inv-variance * scale * (reduceSize * dy - dbias + // - tmp) + for(const auto& reduce_index : arg.reduce_index_set_) + { + size_t x_reduce_offset = get_offset_from_index( + arg.x_reduce_strides_, reduce_index); + size_t dy_reduce_offset = get_offset_from_index( + arg.dy_reduce_strides_, reduce_index); + size_t dx_reduce_offset = get_offset_from_index( + arg.dx_reduce_strides_, reduce_index); + + auto x_offset = x_invariant_offset + x_reduce_offset; + auto dy_offset = dy_invariant_offset + dy_reduce_offset; + auto dx_offset = dx_invariant_offset + dx_reduce_offset; + + AccDataType x = type_convert(arg.p_x_[x_offset]); + + AccDataType norm_x = (x - mean) * invVar; + AccDataType dy = type_convert(arg.p_dy_[dy_offset]); + + arg.dy_elementwise_op_(dy, dy); + + AccDataType tmpVal = norm_x * dscale; + + AccDataType dx = multiplier * (type_convert(arg.reduceSize_) * dy - + dbias - tmpVal); + + arg.p_dx_[dx_offset] = type_convert(dx); + }; + }; + + std::size_t num_thread = std::thread::hardware_concurrency(); + std::size_t work_per_thread = + (arg.invariant_index_set_.size() + num_thread - 1) / num_thread; + + std::vector threads(num_thread); + + for(std::size_t it = 0; it < num_thread; ++it) + { + std::size_t i_begin = it * work_per_thread; + std::size_t i_end = std::min(static_cast((it + 1) * work_per_thread), + arg.invariant_index_set_.size()); + + auto f = [=] { + for(std::size_t i = i_begin; i < i_end; ++i) + { + thread_reduce_func(arg.invariant_index_set_[i]); + } + }; + + threads[it] = joinable_thread(f); + } + + return (0.0f); + }; + + float Run(const device::BaseArgument* p_arg, + const StreamConfig& /*stream_config*/ = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg)); + }; + }; + + bool IsSupportedArgument(const device::BaseArgument* p_arg) override + { + (void)p_arg; + + return (true); + }; + + std::unique_ptr + MakeArgumentPointer(const std::array xyLengths, + const std::array xStrides, + const std::array dxStrides, + const std::array dyStrides, + const std::array reduceDims, + const std::array bnScaleBiasMeanVarLengths, + const std::array bnScaleStrides, + const std::array bnDscaleDbiasStrides, + const std::array bnMeanVarStrides, + const void* p_x, + const void* p_dy, + const void* p_scale, + const void* p_savedMean, + const void* p_savedInvVar, + double epsilon, + const DyElementwiseOp dy_elementwise_op, + void* p_dx, + void* p_dscale, + void* p_dbias) override + { + return std::make_unique(xyLengths, + xStrides, + dxStrides, + dyStrides, + reduceDims, + bnScaleBiasMeanVarLengths, + bnScaleStrides, 
+ bnDscaleDbiasStrides, + bnMeanVarStrides, + static_cast(p_x), + static_cast(p_dy), + static_cast(p_scale), + static_cast(p_savedMean), + static_cast(p_savedInvVar), + epsilon, + dy_elementwise_op, + static_cast(p_dx), + static_cast(p_dscale), + static_cast(p_dbias)); + }; + + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(); + }; + + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "Reference_BatchNorm_Backward" << std::endl; + // clang-format on + + return str.str(); + } +}; + +} // namespace host +} // namespace tensor_operation +} // namespace ck diff --git a/library/include/ck/library/reference_tensor_operation/cpu/reference_batchnorm_backward_nhwc_c.hpp b/library/include/ck/library/reference_tensor_operation/cpu/reference_batchnorm_backward_nhwc_c.hpp deleted file mode 100644 index 64eb06a441..0000000000 --- a/library/include/ck/library/reference_tensor_operation/cpu/reference_batchnorm_backward_nhwc_c.hpp +++ /dev/null @@ -1,319 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. - -#pragma once - -#include -#include -#include - -#include "ck/tensor_operation/gpu/device/device_batchnorm_backward.hpp" - -namespace ck { -namespace tensor_operation { -namespace host { - -template -struct ReferenceBatchNormBwd_Input_N_H_W_C_Output_C - : public device::DeviceBatchNormBwd<4, 3, DyElementwiseOp> -{ - struct Argument : public device::BaseArgument - { - Argument(const std::array xyLengths, - const std::array xStrides, - const std::array dyStrides, - const std::array dxStrides, - const std::array reduceDims, - const std::array bnScaleBiasMeanVarLengths, - const std::array bnScaleStrides, - const std::array bnBiasStrides, - const std::array bnMeanVarStrides, - const XDataType* p_x, - const DyDataType* p_dy, - const ScaleDataType* p_scale, - const MeanVarDataType* p_savedMean, - const MeanVarDataType* p_savedInvVar, - double epsilon, - const DyElementwiseOp dy_elementwise_op, - DxDataType* p_dx, - ScaleDataType* p_dscale, - BiasDataType* p_dbias) - : p_x_(p_x), - p_dy_(p_dy), - p_scale_(p_scale), - p_savedMean_(p_savedMean), - p_savedInvVar_(p_savedInvVar), - epsilon_(epsilon), - dy_elementwise_op_(dy_elementwise_op), - p_dx_(p_dx), - p_dscale_(p_dscale), - p_dbias_(p_dbias) - { - ignore = xStrides; - ignore = dyStrides; - ignore = dxStrides; - ignore = bnScaleStrides; - ignore = bnBiasStrides; - ignore = bnMeanVarStrides; - - if(xyLengths.size() != 4 || bnScaleBiasMeanVarLengths.size() != 1 || - bnScaleBiasMeanVarLengths[0] != xyLengths[3]) - throw std::runtime_error("Invalid tensor dimensions!"); - - if(reduceDims[0] != 0 || reduceDims[1] != 1 || reduceDims[2] != 2) - throw std::runtime_error("Invalid reduce dimensions!"); - - n_ = xyLengths[0]; - h_ = xyLengths[1]; - w_ = xyLengths[2]; - c_ = xyLengths[3]; - - haveSavedMeanInvVar_ = (p_savedMean != nullptr && p_savedInvVar != nullptr); - } - - const XDataType* p_x_; - const DyDataType* p_dy_; - const ScaleDataType* p_scale_; - const MeanVarDataType* p_savedMean_; - const MeanVarDataType* p_savedInvVar_; - - double epsilon_; - const DyElementwiseOp dy_elementwise_op_; - - DxDataType* p_dx_; - ScaleDataType* p_dscale_; - BiasDataType* p_dbias_; - - bool haveSavedMeanInvVar_; - - index_t n_, h_, w_, c_; - }; - - struct Invoker : public device::BaseInvoker - { - float Run(const Argument& arg) - { - auto thread_reduce_func = [&](auto iC) { - AccDataType reduceSize = type_convert(arg.n_) 
* - type_convert(arg.h_) * - type_convert(arg.w_); - index_t offset_C = iC; - AccDataType mean; - AccDataType invVar; - - if(arg.haveSavedMeanInvVar_) - { - mean = arg.p_savedMean_[offset_C]; - invVar = arg.p_savedInvVar_[offset_C]; - } - else - { - AccDataType meansquare; - - meansquare = type_convert(0.0f); - mean = type_convert(0.0f); - - // compute mean, meanquare, variance, inv-variance - for(index_t iN = 0; iN < arg.n_; iN++) - { - index_t offset_N = iN * arg.h_ * arg.w_ * arg.c_; - for(index_t iH = 0; iH < arg.h_; iH++) - { - index_t offset_H = iH * arg.w_ * arg.c_; - for(index_t iW = 0; iW < arg.w_; iW++) - { - index_t offset_W = iW * arg.c_; - - auto offset = offset_N + offset_H + offset_W + offset_C; - - AccDataType x = type_convert(arg.p_x_[offset]); - - mean += x; - meansquare += x * x; - }; - } - }; - - mean = mean / reduceSize; - meansquare = meansquare / reduceSize; - - AccDataType variance = meansquare - mean * mean; - invVar = type_convert(1.0f) / - std::sqrt(type_convert(arg.epsilon_) + variance); - }; - - AccDataType dbias = type_convert(0.0f); // Sum on NHW of dy - AccDataType dscale = type_convert(0.0f); // Sum on NHW of dy * norm_x - - // 1) calculate dy * (x - mean) * inv-variance - // 2) calculate sum(dy) on NHW dimensions - // 3) calculate sum(dy * norm_x) on NHW dimensions - for(index_t iN = 0; iN < arg.n_; iN++) - { - index_t offset_N = iN * arg.h_ * arg.w_ * arg.c_; - for(index_t iH = 0; iH < arg.h_; iH++) - { - index_t offset_H = iH * arg.w_ * arg.c_; - for(index_t iW = 0; iW < arg.w_; iW++) - { - index_t offset_W = iW * arg.c_; - - auto offset = offset_N + offset_H + offset_W + offset_C; - - AccDataType x = type_convert(arg.p_x_[offset]); - - AccDataType norm_x = (x - mean) * invVar; - AccDataType dy = type_convert(arg.p_dy_[offset]); - - arg.dy_elementwise_op_(dy, dy); - - dbias += dy; - dscale += norm_x * dy; - }; - } - }; - - arg.p_dscale_[offset_C] = type_convert(dscale); - arg.p_dbias_[offset_C] = type_convert(dbias); - - AccDataType scale = type_convert(arg.p_scale_[offset_C]); - AccDataType multiplier = - type_convert(1.0f) / reduceSize * invVar * scale; - - // 1) calculate tmp = dscale * (x - mean) * inv-variance - // 2) calculate dx = 1/nhw * inv-variance * scale * (nhw * dy - dbias - tmp) - for(index_t iN = 0; iN < arg.n_; iN++) - { - index_t offset_N = iN * arg.h_ * arg.w_ * arg.c_; - for(index_t iH = 0; iH < arg.h_; iH++) - { - index_t offset_H = iH * arg.w_ * arg.c_; - for(index_t iW = 0; iW < arg.w_; iW++) - { - index_t offset_W = iW * arg.c_; - - auto offset = offset_N + offset_H + offset_W + offset_C; - - AccDataType x = type_convert(arg.p_x_[offset]); - - AccDataType norm_x = (x - mean) * invVar; - AccDataType dy = type_convert(arg.p_dy_[offset]); - - arg.dy_elementwise_op_(dy, dy); - - AccDataType tmpVal = norm_x * dscale; - - AccDataType dx = multiplier * (reduceSize * dy - dbias - tmpVal); - - arg.p_dx_[offset] = type_convert(dx); - }; - } - }; - }; - - std::size_t num_thread = std::thread::hardware_concurrency(); - std::size_t work_per_thread = (arg.c_ + num_thread - 1) / num_thread; - - std::vector threads(num_thread); - - for(std::size_t it = 0; it < num_thread; ++it) - { - std::size_t ic_begin = it * work_per_thread; - std::size_t ic_end = std::min(static_cast((it + 1) * work_per_thread), arg.c_); - - auto f = [=] { - for(std::size_t ic = ic_begin; ic < ic_end; ++ic) - { - thread_reduce_func(ic); - } - }; - - threads[it] = joinable_thread(f); - } - - return (0.0f); - }; - - float Run(const device::BaseArgument* p_arg, - const 
StreamConfig& /*stream_config*/ = StreamConfig{}) override - { - return Run(*dynamic_cast(p_arg)); - }; - }; - - bool IsSupportedArgument(const device::BaseArgument* p_arg) override - { - (void)p_arg; - - return (true); - }; - - std::unique_ptr - MakeArgumentPointer(const std::array xyLengths, - const std::array xStrides, - const std::array dyStrides, - const std::array dxStrides, - const std::array reduceDims, - const std::array bnScaleBiasMeanVarLengths, - const std::array bnScaleStrides, - const std::array bnBiasStrides, - const std::array bnMeanVarStrides, - const void* p_x, - const void* p_dy, - const void* p_scale, - const void* p_savedMean, - const void* p_savedInvVar, - double epsilon, - const DyElementwiseOp dy_elementwise_op, - void* p_dx, - void* p_dscale, - void* p_dbias) override - { - return std::make_unique(xyLengths, - xStrides, - dyStrides, - dxStrides, - reduceDims, - bnScaleBiasMeanVarLengths, - bnScaleStrides, - bnBiasStrides, - bnMeanVarStrides, - static_cast(p_x), - static_cast(p_dy), - static_cast(p_scale), - static_cast(p_savedMean), - static_cast(p_savedInvVar), - epsilon, - dy_elementwise_op, - static_cast(p_dx), - static_cast(p_dscale), - static_cast(p_dbias)); - }; - - std::unique_ptr MakeInvokerPointer() override - { - return std::make_unique(); - }; - - std::string GetTypeString() const override - { - auto str = std::stringstream(); - - // clang-format off - str << "Reference_BatchNorm_Backward_NHWC_C<" << std::endl; - // clang-format on - - return str.str(); - } -}; - -} // namespace host -} // namespace tensor_operation -} // namespace ck diff --git a/library/include/ck/library/tensor_operation_instance/gpu/batchnorm_backward.hpp b/library/include/ck/library/tensor_operation_instance/gpu/batchnorm_backward.hpp new file mode 100644 index 0000000000..c84ffcff8c --- /dev/null +++ b/library/include/ck/library/tensor_operation_instance/gpu/batchnorm_backward.hpp @@ -0,0 +1,124 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
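+
+// Declares the per-precision registration functions (fp16, fp32, bf16, fp64) for rank-4,
+// 3-reduce-dim batchnorm backward and specializes DeviceOperationInstanceFactory so that
+// clients can enumerate every available DeviceBatchNormBwd instance for a given type set.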
+ +#pragma once + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/device_batchnorm_backward.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// FP16 +void add_device_batchnorm_backward_rank_4_3_f16_instances( + std::vector>>&); + +// FP32 +void add_device_batchnorm_backward_rank_4_3_f32_instances( + std::vector>>&); + +// BF16 +void add_device_batchnorm_backward_rank_4_3_bf16_instances( + std::vector>>&); + +// FP64 +void add_device_batchnorm_backward_rank_4_3_f64_instances( + std::vector>>&); + +template +struct DeviceOperationInstanceFactory< + ck::tensor_operation::device::DeviceBatchNormBwd> +{ + using DeviceOp = DeviceBatchNormBwd; + + static auto GetInstances() + { + std::vector> op_ptrs; + + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v && + is_same_v && is_same_v && + is_same_v) + { + if constexpr(Rank == 4 && NumReduceDim == 3 && is_same_v) + { + add_device_batchnorm_backward_rank_4_3_f16_instances(op_ptrs); + } + } + else if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v && + is_same_v && is_same_v && + is_same_v) + { + if constexpr(Rank == 4 && NumReduceDim == 3 && is_same_v) + { + add_device_batchnorm_backward_rank_4_3_f32_instances(op_ptrs); + } + } + else if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v && + is_same_v && is_same_v && + is_same_v) + { + if constexpr(Rank == 4 && NumReduceDim == 3 && is_same_v) + { + add_device_batchnorm_backward_rank_4_3_bf16_instances(op_ptrs); + } + } + else if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v && + is_same_v && is_same_v && + is_same_v) + { + if constexpr(Rank == 4 && NumReduceDim == 3 && is_same_v) + { + add_device_batchnorm_backward_rank_4_3_f64_instances(op_ptrs); + } + } + + return op_ptrs; + } +}; + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/batchnorm/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/batchnorm/CMakeLists.txt index 8947ecb9c1..d12a2f244f 100644 --- a/library/src/tensor_operation_instance/gpu/batchnorm/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/batchnorm/CMakeLists.txt @@ -3,4 +3,8 @@ add_instance_library(device_batchnorm_instance device_batchnorm_forward_f32_instance.cpp device_batchnorm_forward_bf16_instance.cpp device_batchnorm_forward_f64_instance.cpp + device_batchnorm_backward_f16_instance.cpp + device_batchnorm_backward_f32_instance.cpp + device_batchnorm_backward_bf16_instance.cpp + device_batchnorm_backward_f64_instance.cpp ) diff --git a/library/src/tensor_operation_instance/gpu/batchnorm/device_batchnorm_backward_bf16_instance.cpp b/library/src/tensor_operation_instance/gpu/batchnorm/device_batchnorm_backward_bf16_instance.cpp new file mode 100644 index 0000000000..b62c8b99cb --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/batchnorm/device_batchnorm_backward_bf16_instance.cpp @@ -0,0 +1,146 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
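+
+// Registers the bf16 (ck::bhalf_t) batchnorm-backward instances: one tuple of blockwise-reduction
+// configurations and one tuple of multiblock (UseMultiBlockInK) configurations, both added via
+// add_device_batchnorm_backward_rank_4_3_bf16_instances().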
+ +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_batchnorm_backward_impl.hpp" +#include "ck/utility/data_type.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using BF16 = ck::bhalf_t; +using F32 = float; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +// clang-format off +template +using device_batchnorm_backward_bf16_blockwise_instances = + std::tuple < + // XDataType, DxDataType, DyDataType, AccDataType, ScaleDataType, DscaleDbiasDataType, MeanVarDataType, DyElementwiseOp, Rank, NumReduceDim, UseMultiBlockInK, BLockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, XDyDxVectorDim, XSrcVectorSize, DySrcVectorSize, DxDstVectorSize, ScaleSrcVectorSize, DscaleDbiasDstVectorSize, MeanVarSrcVectorSize + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl + >; +// clang-format on + +// clang-format off +template +using device_batchnorm_backward_bf16_multiblock_instances = + std::tuple < + // XDataType, DxDataType, DyDataType, AccDataType, ScaleDataType, BiasDataType, MeanVarDataType, DyElementwiseOp, Rank, NumReduceDim, UseMultiBlockInK, BLockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, XDyDxVectorDim, XSrcVectorSize, DySrcVectorSize, DxDstVectorSize, ScaleSrcDstVectorSize, BiasDstVectorSize, MeanVarSrcVectorSize + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + 
DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl + >; +// clang-format on + +void add_device_batchnorm_backward_rank_4_3_bf16_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, device_batchnorm_backward_bf16_blockwise_instances<4, 3, PassThrough>{}); + add_device_operation_instances( + instances, device_batchnorm_backward_bf16_multiblock_instances<4, 3, PassThrough>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/batchnorm/device_batchnorm_backward_f16_instance.cpp b/library/src/tensor_operation_instance/gpu/batchnorm/device_batchnorm_backward_f16_instance.cpp new file mode 100644 index 0000000000..d05b8b592c --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/batchnorm/device_batchnorm_backward_f16_instance.cpp @@ -0,0 +1,147 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_batchnorm_backward_impl.hpp" +#include "ck/utility/data_type.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F16 = ck::half_t; +using F32 = float; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +// clang-format off +template +using device_batchnorm_backward_f16_blockwise_instances = + std::tuple < + // XDataType, DxDataType, DyDataType, AccDataType, ScaleDataType, DscaleDbiasDataType, MeanVarDataType, DyElementwiseOp, Rank, NumReduceDim, UseMultiBlockInK, BLockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, XDyDxVectorDim, XSrcVectorSize, DySrcVectorSize, DxDstVectorSize, ScaleSrcVectorSize, DscaleDbiasDstVectorSize, MeanVarSrcVectorSize + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + 
DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl + >; +// clang-format on + +// clang-format off +template +using device_batchnorm_backward_f16_multiblock_instances = + std::tuple < + // XDataType, DxDataType, DyDataType, AccDataType, ScaleDataType, BiasDataType, MeanVarDataType, DyElementwiseOp, Rank, NumReduceDim, UseMultiBlockInK, BLockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, XDyDxVectorDim, XSrcVectorSize, DySrcVectorSize, DxDstVectorSize, ScaleSrcDstVectorSize, BiasDstVectorSize, MeanVarSrcVectorSize + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl + >; +// clang-format on + +void add_device_batchnorm_backward_rank_4_3_f16_instances( + std::vector< + std::unique_ptr>>& + instances) +{ + add_device_operation_instances( + instances, device_batchnorm_backward_f16_blockwise_instances<4, 3, PassThrough>{}); + add_device_operation_instances( + instances, device_batchnorm_backward_f16_multiblock_instances<4, 3, PassThrough>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/batchnorm/device_batchnorm_backward_f32_instance.cpp b/library/src/tensor_operation_instance/gpu/batchnorm/device_batchnorm_backward_f32_instance.cpp new file mode 100644 index 0000000000..e3ef95d12e --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/batchnorm/device_batchnorm_backward_f32_instance.cpp @@ -0,0 +1,145 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
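+
+// Registers the fp32 batchnorm-backward instances (blockwise and multiblock variants) via
+// add_device_batchnorm_backward_rank_4_3_f32_instances().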
+ +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_batchnorm_backward_impl.hpp" +#include "ck/utility/data_type.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F32 = float; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +// clang-format off +template +using device_batchnorm_backward_f32_blockwise_instances = std::tuple< + // XDataType, DxDataType, DyDataType, AccDataType, ScaleDataType, DscaleDbiasDataType, MeanVarDataType, DyElementwiseOp, Rank, NumReduceDim, UseMultiBlockInK, BLockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, XDyDxVectorDim, XSrcVectorSize, DySrcVectorSize, DxDstVectorSize, ScaleSrcVectorSize, DscaleDbiasDstVectorSize, MeanVarSrcVectorSize + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl + >; +// clang-format on + +// clang-format off +template +using device_batchnorm_backward_f32_multiblock_instances = + std::tuple < + // XDataType, DxDataType, DyDataType, AccDataType, ScaleDataType, BiasDataType, MeanVarDataType, DyElementwiseOp, Rank, NumReduceDim, UseMultiBlockInK, BLockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, XDyDxVectorDim, XSrcVectorSize, DySrcVectorSize, DxDstVectorSize, ScaleSrcDstVectorSize, BiasDstVectorSize, MeanVarSrcVectorSize + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + 
DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl + >; +// clang-format on + +void add_device_batchnorm_backward_rank_4_3_f32_instances( + std::vector< + std::unique_ptr>>& + instances) +{ + add_device_operation_instances( + instances, device_batchnorm_backward_f32_blockwise_instances<4, 3, PassThrough>{}); + add_device_operation_instances( + instances, device_batchnorm_backward_f32_multiblock_instances<4, 3, PassThrough>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/batchnorm/device_batchnorm_backward_f64_instance.cpp b/library/src/tensor_operation_instance/gpu/batchnorm/device_batchnorm_backward_f64_instance.cpp new file mode 100644 index 0000000000..41be396c24 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/batchnorm/device_batchnorm_backward_f64_instance.cpp @@ -0,0 +1,145 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_batchnorm_backward_impl.hpp" +#include "ck/utility/data_type.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F64 = double; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +// clang-format off +template +using device_batchnorm_backward_f64_blockwise_instances = std::tuple< + // XDataType, DxDataType, DyDataType, AccDataType, ScaleDataType, DscaleDbiasDataType, MeanVarDataType, DyElementwiseOp, Rank, NumReduceDim, UseMultiBlockInK, BLockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, XDyDxVectorDim, XSrcVectorSize, DySrcVectorSize, DxDstVectorSize, ScaleSrcVectorSize, DscaleDbiasDstVectorSize, MeanVarSrcVectorSize + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + 
DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl + >; +// clang-format on + +// clang-format off +template +using device_batchnorm_backward_f64_multiblock_instances = + std::tuple < + // XDataType, DxDataType, DyDataType, AccDataType, ScaleDataType, BiasDataType, MeanVarDataType, DyElementwiseOp, Rank, NumReduceDim, UseMultiBlockInK, BLockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, XDyDxVectorDim, XSrcVectorSize, DySrcVectorSize, DxDstVectorSize, ScaleSrcDstVectorSize, BiasDstVectorSize, MeanVarSrcVectorSize + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl, + DeviceBatchNormBwdImpl + >; +// clang-format on + +void add_device_batchnorm_backward_rank_4_3_f64_instances( + std::vector< + std::unique_ptr>>& + instances) +{ + add_device_operation_instances( + instances, device_batchnorm_backward_f64_blockwise_instances<4, 3, PassThrough>{}); + add_device_operation_instances( + instances, device_batchnorm_backward_f64_multiblock_instances<4, 3, PassThrough>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/profiler/CMakeLists.txt b/profiler/CMakeLists.txt index aad40cc79f..0dccfff476 100644 --- a/profiler/CMakeLists.txt +++ b/profiler/CMakeLists.txt @@ -27,6 +27,7 @@ set(PROFILER_SOURCE src/profile_layernorm.cpp src/profile_softmax.cpp src/profile_batchnorm_fwd.cpp + src/profile_batchnorm_bwd.cpp ) add_executable(ckProfiler ${PROFILER_SOURCE}) diff --git a/profiler/include/profile_batchnorm_backward_impl.hpp b/profiler/include/profile_batchnorm_backward_impl.hpp new file mode 100644 index 0000000000..79d8862081 --- /dev/null +++ b/profiler/include/profile_batchnorm_backward_impl.hpp @@ -0,0 +1,390 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
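+
+// profile_batchnorm_backward_impl() enumerates all registered DeviceBatchNormBwd instances for the
+// given lengths/reduce dims, optionally verifies dx/dscale/dbias against the host reference
+// (ReferenceBatchNormBwd), times each supported instance, and reports the best bandwidth in GB/s.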
+ +#pragma once + +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/tensor_operation_instance/gpu/batchnorm_backward.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_batchnorm_backward.hpp" + +namespace ck { +namespace profiler { + +template +bool profile_batchnorm_backward_impl(bool do_verification, + int init_method, + bool do_dumpout, + bool time_kernel, + const std::vector inOutLengths, + const std::vector reduceDims, + bool haveSavedMeanInvVar, + double epsilon) +{ + if(inOutLengths.size() != Rank || reduceDims.size() != NumBatchNormReduceDim) + { + throw std::runtime_error("Invalid tensor lengths or number of reduce dimensions!"); + }; + + std::vector scaleBiasMeanVarLengths; + + // used for calculating the effective transferred bytes by each operation + size_t total_length; + size_t invariant_length = 1; + + total_length = + std::accumulate(inOutLengths.begin(), inOutLengths.end(), 1, std::multiplies{}); + + if(std::any_of(reduceDims.begin(), reduceDims.end(), [](int d) { return d < 0 || d >= Rank; })) + throw std::runtime_error("Invalid reduce dimensions!"); + + for(int dim = 0; dim < Rank; dim++) + { + if(std::none_of(reduceDims.begin(), reduceDims.end(), [&](int d) { return dim == d; })) + { + scaleBiasMeanVarLengths.push_back(inOutLengths[dim]); + invariant_length *= inOutLengths[dim]; + }; + } + + // input data of the batchnorm backward algorithm + Tensor x(inOutLengths); + Tensor dy(inOutLengths); + Tensor bnScale(scaleBiasMeanVarLengths); + + Tensor savedMean(scaleBiasMeanVarLengths); + Tensor savedInvVar(scaleBiasMeanVarLengths); + // savedVariance is only used for initializing savedInvVar + Tensor savedVariance(scaleBiasMeanVarLengths); + + // output data of the batchnorm backward algorithm + Tensor dx_ref(inOutLengths); + Tensor dx(inOutLengths); + + Tensor dscale(scaleBiasMeanVarLengths); + Tensor dbias(scaleBiasMeanVarLengths); + + Tensor dscale_ref(scaleBiasMeanVarLengths); + Tensor dbias_ref(scaleBiasMeanVarLengths); + + auto inOutStrides = x.mDesc.GetStrides(); + auto scaleBiasMeanVarStrides = bnScale.mDesc.GetStrides(); + + std::size_t num_thread = std::thread::hardware_concurrency(); + + if(haveSavedMeanInvVar) + { + const float x_mean = 0.0f; + const float x_stddev = 1.0f; + const float noise_stddev = 0.0001f; + + // input data in normal distribution + x.GenerateTensorValue(GeneratorTensor_4{x_mean, x_stddev}, num_thread); + + // initialize the savedMean to be values with tiny variation to the mean of the x values + savedMean.GenerateTensorValue(GeneratorTensor_4{x_mean, noise_stddev}, + num_thread); + + // initialize the variance to be values with tiny variation to the variance of the x values + savedVariance.GenerateTensorValue( + GeneratorTensor_4{x_stddev * x_stddev, noise_stddev}, num_thread); + + auto it_src = savedVariance.mData.begin(); + auto it_dst = savedInvVar.mData.begin(); + float tmp_epsilon = std::numeric_limits::epsilon(); + + while(it_src != savedVariance.mData.end()) + { + *it_dst = type_convert( + 1.0f / std::sqrtf(type_convert(*it_src) + tmp_epsilon)); + + it_src++; + it_dst++; + }; + } + else + { + const float x_mean = 0.0f; + const float x_stddev = 1.0f; + + // input data in normal distribution + x.GenerateTensorValue(GeneratorTensor_4{x_mean, 
x_stddev}, num_thread); + }; + + if(do_verification) + { + switch(init_method) + { + case 0: + dy.GenerateTensorValue(GeneratorTensor_0{}, num_thread); + bnScale.GenerateTensorValue(GeneratorTensor_0{}, num_thread); + break; + case 1: + dy.GenerateTensorValue(GeneratorTensor_1{1}, num_thread); + bnScale.GenerateTensorValue(GeneratorTensor_1{1}, num_thread); + break; + case 2: + dy.GenerateTensorValue(GeneratorTensor_2{-2, 2}, num_thread); + bnScale.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); + break; + default: + dy.GenerateTensorValue(GeneratorTensor_3{-0.2f, 0.2f}, num_thread); + bnScale.GenerateTensorValue(GeneratorTensor_3{-0.5f, 0.5f}, num_thread); + } + }; + + // input data of the batchnorm backward algorithm + DeviceMem x_dev(sizeof(XDataType) * x.mDesc.GetElementSpaceSize()); + DeviceMem dy_dev(sizeof(DyDataType) * dy.mDesc.GetElementSpaceSize()); + + DeviceMem bnScale_dev(sizeof(ScaleDataType) * bnScale.mDesc.GetElementSpaceSize()); + + DeviceMem savedMean_dev(sizeof(MeanVarDataType) * savedMean.mDesc.GetElementSpaceSize()); + DeviceMem savedInvVar_dev(sizeof(MeanVarDataType) * savedInvVar.mDesc.GetElementSpaceSize()); + + // output data of the batchnorm backward algorithm + DeviceMem dx_dev(sizeof(DxDataType) * dx.mDesc.GetElementSpaceSize()); + + DeviceMem dscale_dev(sizeof(DscaleDbiasDataType) * dscale.mDesc.GetElementSpaceSize()); + DeviceMem dbias_dev(sizeof(DscaleDbiasDataType) * dbias.mDesc.GetElementSpaceSize()); + + x_dev.ToDevice(x.mData.data()); + dy_dev.ToDevice(dy.mData.data()); + bnScale_dev.ToDevice(bnScale.mData.data()); + + if(haveSavedMeanInvVar) + { + savedMean_dev.ToDevice(savedMean.mData.data()); + savedInvVar_dev.ToDevice(savedInvVar.mData.data()); + }; + + std::array arrInOutLengths; + std::array arrInOutStrides; + std::array arrScaleBiasMeanVarLengths; + std::array arrScaleBiasMeanVarStrides; + std::array arrReduceDims; + + std::copy(inOutLengths.begin(), inOutLengths.end(), arrInOutLengths.begin()); + std::copy(inOutStrides.begin(), inOutStrides.end(), arrInOutStrides.begin()); + std::copy(scaleBiasMeanVarLengths.begin(), + scaleBiasMeanVarLengths.end(), + arrScaleBiasMeanVarLengths.begin()); + std::copy(scaleBiasMeanVarStrides.begin(), + scaleBiasMeanVarStrides.end(), + arrScaleBiasMeanVarStrides.begin()); + + std::copy(reduceDims.begin(), reduceDims.end(), arrReduceDims.begin()); + + using PassThroughOp = ck::tensor_operation::element_wise::PassThrough; + + // add device batchnorm-backward instances + using DeviceOp = ck::tensor_operation::device::DeviceBatchNormBwd; + + // get device op instances + const auto instance_ptrs = + ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< + DeviceOp>::GetInstances(); + + std::cout << "found " << instance_ptrs.size() << " instances" << std::endl; + + std::string best_instance_name; + float best_avg_time = std::numeric_limits::max(); + float best_gb_per_sec = 0; + + if(do_verification) + { + using ReferenceBatchNormBwdInstance = + ck::tensor_operation::host::ReferenceBatchNormBwd; + + auto batchNormBwd_ref = ReferenceBatchNormBwdInstance{}; + + auto argument_ptr_ref = batchNormBwd_ref.MakeArgumentPointer( + arrInOutLengths, + arrInOutStrides, + arrInOutStrides, + arrInOutStrides, + arrReduceDims, + arrScaleBiasMeanVarLengths, + arrScaleBiasMeanVarStrides, + arrScaleBiasMeanVarStrides, + arrScaleBiasMeanVarStrides, + x.mData.data(), + dy.mData.data(), + bnScale.mData.data(), + haveSavedMeanInvVar ? savedMean.mData.data() : nullptr, + haveSavedMeanInvVar ? 
savedInvVar.mData.data() : nullptr, + epsilon, + PassThroughOp{}, + dx_ref.mData.data(), + dscale_ref.mData.data(), + dbias_ref.mData.data()); + + if(!batchNormBwd_ref.IsSupportedArgument(argument_ptr_ref.get())) + { + std::cout << "The runtime parameters not supported by the reference instance, exiting!" + << std::endl; + return (false); + }; + + auto invoker_ptr_ref = batchNormBwd_ref.MakeInvokerPointer(); + + (void)invoker_ptr_ref->Run(argument_ptr_ref.get()); + } + + int num_kernel = 0; + bool pass = true; + + for(auto& inst_ptr : instance_ptrs) + { + auto argument_ptr = inst_ptr->MakeArgumentPointer( + arrInOutLengths, + arrInOutStrides, + arrInOutStrides, + arrInOutStrides, + arrReduceDims, + arrScaleBiasMeanVarLengths, + arrScaleBiasMeanVarStrides, + arrScaleBiasMeanVarStrides, + arrScaleBiasMeanVarStrides, + x_dev.GetDeviceBuffer(), + dy_dev.GetDeviceBuffer(), + bnScale_dev.GetDeviceBuffer(), + haveSavedMeanInvVar ? savedMean_dev.GetDeviceBuffer() : nullptr, + haveSavedMeanInvVar ? savedInvVar_dev.GetDeviceBuffer() : nullptr, + epsilon, + PassThroughOp{}, + dx_dev.GetDeviceBuffer(), + dscale_dev.GetDeviceBuffer(), + dbias_dev.GetDeviceBuffer()); + + if(inst_ptr->IsSupportedArgument(argument_ptr.get())) + { + num_kernel++; + } + else + { + if(time_kernel) + { + std::cout << inst_ptr->GetTypeString() + << " skipped due to unsupported argument: " << std::endl; + } + + continue; + }; + + size_t workspace_sz = inst_ptr->GetWorkSpaceSize(argument_ptr.get()); + + DeviceMem workspace_dev(workspace_sz); + + inst_ptr->SetWorkSpacePointer(argument_ptr.get(), workspace_dev.GetDeviceBuffer()); + + auto invoker_ptr = inst_ptr->MakeInvokerPointer(); + + float avg_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel}); + + size_t num_bytes = 0; + + // inputing of x, dy, scale, outputing of dx, dscale, dbias + num_bytes += total_length * (sizeof(XDataType) + sizeof(DyDataType) + sizeof(DxDataType)) + + invariant_length * sizeof(DscaleDbiasDataType) * 2; + + // inputting of savedMean, savedInvVariance + if(haveSavedMeanInvVar) + num_bytes += invariant_length * sizeof(MeanVarDataType) * 2; + + float gb_per_sec = num_bytes / 1.E6 / avg_time; + + if(time_kernel) + std::cout << "Perf: " << avg_time << " ms, " << gb_per_sec << " GB/s, " + << inst_ptr->GetTypeString() << std::endl; + + if(avg_time < best_avg_time) + { + best_instance_name = inst_ptr->GetTypeString(); + best_avg_time = avg_time; + best_gb_per_sec = gb_per_sec; + } + + if(do_verification) + { + using ck::utils::check_err; + bool single_pass = true; + + dx_dev.FromDevice(dx.mData.data()); + dscale_dev.FromDevice(dscale.data()); + dbias_dev.FromDevice(dbias.data()); + + // clang-format off + single_pass = single_pass && ck::utils::check_err(dx.mData, dx_ref.mData, "dx result:", 5e-4, 5e-4); + single_pass = single_pass && ck::utils::check_err(dscale.mData, dscale_ref.mData, "dScale result:", 3e-3, 3e-3); + single_pass = single_pass && ck::utils::check_err(dbias.mData, dbias_ref.mData, "dBias result:", 3e-3, 3e-3); + // clang-format on + + pass = pass && single_pass; + }; + + if(do_dumpout) + { + using ck::host_common::dumpBufferToFile; + + // clang-format off + dumpBufferToFile("dump_x.bin", x.mData.data(), x.mDesc.GetElementSize()); + dumpBufferToFile("dump_dy.bin", dy.mData.data(), dy.mDesc.GetElementSize()); + dumpBufferToFile("dump_dx.bin", dx.mData.data(), dx.mDesc.GetElementSize()); + dumpBufferToFile("dump_dx_ref.bin", dx_ref.mData.data(), dx_ref.mDesc.GetElementSize()); + dumpBufferToFile("dump_dscale.bin", 
dscale.mData.data(), dscale.mDesc.GetElementSize()); + dumpBufferToFile("dump_dscale_ref.bin", dscale_ref.mData.data(), dscale_ref.mDesc.GetElementSize()); + // clang-format off + }; + } + + if(time_kernel) + { + std::cout << "best perf = " << best_avg_time << " ms, " << best_gb_per_sec << " GB/s, " + << best_instance_name << std::endl; + } + + if(num_kernel == 0) + { + std::cout << "Error: No kernel is applicable" << std::endl; + return false; + } + + return pass; +} + +} // namespace profiler +} // namespace ck diff --git a/profiler/src/profile_batchnorm_bwd.cpp b/profiler/src/profile_batchnorm_bwd.cpp new file mode 100644 index 0000000000..d5938a1e6b --- /dev/null +++ b/profiler/src/profile_batchnorm_bwd.cpp @@ -0,0 +1,204 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include + +#include "ck/library/utility/host_common_util.hpp" +#include "profiler/include/profile_batchnorm_backward_impl.hpp" + +using ck::index_t; + +using namespace std; + +static const struct option long_options[] = {{"inOutLengths", required_argument, nullptr, 'D'}, + {"reduceDims", required_argument, nullptr, 'R'}, + {"dumpout", required_argument, nullptr, 'o'}, + {"verify", required_argument, nullptr, 'v'}, + {"help", no_argument, nullptr, '?'}, + {nullptr, 0, nullptr, 0}}; + +class BatchnormBwdArgParser +{ + private: + int option_index = 0; + + public: + std::vector inLengths; + std::vector reduceDims; + + bool do_verification = false; + bool do_dumpout = false; + + bool haveSavedMeanInvVar; + + int data_type = 0; + int init_method = 2; + bool time_kernel = false; + + BatchnormBwdArgParser() = default; + ~BatchnormBwdArgParser() = default; + + void show_usage(const char* cmd) + { + // clang-format off + std::cout << "Usage of " << cmd << std::endl; + std::cout << "--inOutLengths or -D, comma separated list of input tensor dimension lengths, must have 4 integers for nhwc" << std::endl; + std::cout << "--reduceDims or -R, comma separated list of dimensions to reduce on" << std::endl; + std::cout << "--verify or -v, 1/0 to indicate whether to verify the result by comparing with the host-based batch-normalization" << std::endl; + std::cout << "Arg1: data type (0: fp16, 1: fp32, 5: bp16, 6: fp64)" << std::endl; + std::cout << "Arg2 -- 1/0 to indicate whether to use saved mean and invVariance" << std::endl; + std::cout << "Arg3 -- init method used for dy and bnScale (0=no init, 1=single integer value, 2=scope integer value, 3=decimal value)" << std::endl; + std::cout << "Arg4 -- time kernel (0=no, 1=yes)" << std::endl; + // clang-format on + }; + + int operator()(int argc, char* argv[]) + { + using ck::host_common::getTypeValuesFromString; + + int ch; + + optind++; // to skip the module name + + while(1) + { + ch = getopt_long(argc, argv, "D:R:v:o:", long_options, &option_index); + if(ch == -1) + break; + switch(ch) + { + case 'D': + if(!optarg) + throw std::runtime_error("Invalid option format!"); + + inLengths = getTypeValuesFromString(optarg); + break; + case 'R': + if(!optarg) + throw std::runtime_error("Invalid option format!"); + + reduceDims = getTypeValuesFromString(optarg); + break; + case 'v': + if(!optarg) + throw std::runtime_error("Invalid option format!"); + + do_verification = static_cast(std::atoi(optarg)); + break; + case 'o': + if(!optarg) + throw std::runtime_error("Invalid option format!"); + + do_dumpout = static_cast(std::atoi(optarg)); + break; + case '?': + if(std::string(long_options[option_index].name) 
== "help") + { + show_usage(argv[0]); + return -1; + }; + break; + + default: + show_usage(argv[0]); + std::cerr << "Invalid cmd-line options!" << std::endl; + return -1; + }; + }; + + if(optind + 4 > argc) + throw std::runtime_error("Invalid cmd-line arguments, more argumetns are needed!"); + + data_type = std::atoi(argv[optind++]); + haveSavedMeanInvVar = std::atoi(argv[optind++]); + init_method = std::atoi(argv[optind++]); + time_kernel = static_cast(std::atoi(argv[optind++])); + + if(data_type != 0 && data_type != 1 && data_type != 3 && data_type != 5 && data_type != 6) + return -1; + + return 0; + }; +}; // end of class AppArgs + +static const double epsilon = std::numeric_limits::epsilon(); + +int profile_batchnorm_backward(int argc, char* argv[]) +{ + using ck::profiler::profile_batchnorm_backward_impl; + + BatchnormBwdArgParser arg_parser; + + if(arg_parser(argc, argv) != 0) + return -1; + + using F16 = ck::half_t; + using F32 = float; + using BF16 = ck::bhalf_t; + using F64 = double; + + if(arg_parser.data_type == 0) + { + if(arg_parser.inLengths.size() == 4 && arg_parser.reduceDims.size() == 3) + { + profile_batchnorm_backward_impl( + arg_parser.do_verification, + arg_parser.init_method, + arg_parser.do_dumpout, + arg_parser.time_kernel, + arg_parser.inLengths, + arg_parser.reduceDims, + arg_parser.haveSavedMeanInvVar, + epsilon); + }; + } + else if(arg_parser.data_type == 1) + { + if(arg_parser.inLengths.size() == 4 && arg_parser.reduceDims.size() == 3) + { + profile_batchnorm_backward_impl( + arg_parser.do_verification, + arg_parser.init_method, + arg_parser.do_dumpout, + arg_parser.time_kernel, + arg_parser.inLengths, + arg_parser.reduceDims, + arg_parser.haveSavedMeanInvVar, + epsilon); + }; + } + else if(arg_parser.data_type == 5) + { + if(arg_parser.inLengths.size() == 4 && arg_parser.reduceDims.size() == 3) + { + profile_batchnorm_backward_impl( + arg_parser.do_verification, + arg_parser.init_method, + arg_parser.do_dumpout, + arg_parser.time_kernel, + arg_parser.inLengths, + arg_parser.reduceDims, + arg_parser.haveSavedMeanInvVar, + epsilon); + }; + } + else if(arg_parser.data_type == 6) + { + if(arg_parser.inLengths.size() == 4 && arg_parser.reduceDims.size() == 3) + { + profile_batchnorm_backward_impl( + arg_parser.do_verification, + arg_parser.init_method, + arg_parser.do_dumpout, + arg_parser.time_kernel, + arg_parser.inLengths, + arg_parser.reduceDims, + arg_parser.haveSavedMeanInvVar, + epsilon); + }; + } + + return 0; +} diff --git a/profiler/src/profiler.cpp b/profiler/src/profiler.cpp index 4942d3c558..34d0f5409f 100644 --- a/profiler/src/profiler.cpp +++ b/profiler/src/profiler.cpp @@ -25,6 +25,7 @@ int profile_layernorm(int, char*[]); int profile_groupnorm(int, char*[]); int profile_reduce(int, char*[]); int profile_batchnorm_forward(int, char*[]); +int profile_batchnorm_backward(int, char*[]); static void print_helper_message() { @@ -148,6 +149,10 @@ int main(int argc, char* argv[]) { return profile_batchnorm_forward(argc, argv); } + else if(strcmp(argv[1], "bnorm_bwd") == 0) + { + return profile_batchnorm_backward(argc, argv); + } else { print_helper_message(); diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index 57c11b55aa..a3d2bcdc82 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -53,4 +53,4 @@ add_subdirectory(softmax) add_subdirectory(normalization) add_subdirectory(data_type) add_subdirectory(elementwise_normalization) -add_subdirectory(batchnorm_fwd) +add_subdirectory(batchnorm) diff --git a/test/batchnorm_fwd/CMakeLists.txt 
b/test/batchnorm/CMakeLists.txt similarity index 50% rename from test/batchnorm_fwd/CMakeLists.txt rename to test/batchnorm/CMakeLists.txt index 87361f9d0a..52f1508682 100644 --- a/test/batchnorm_fwd/CMakeLists.txt +++ b/test/batchnorm/CMakeLists.txt @@ -1,2 +1,4 @@ add_gtest_executable(test_batchnorm_fwd_rank_4 batchnorm_fwd_rank_4.cpp) +add_gtest_executable(test_batchnorm_bwd_rank_4 batchnorm_bwd_rank_4.cpp) target_link_libraries(test_batchnorm_fwd_rank_4 PRIVATE utility device_batchnorm_instance) +target_link_libraries(test_batchnorm_bwd_rank_4 PRIVATE utility device_batchnorm_instance) diff --git a/test/batchnorm/batchnorm_bwd_rank_4.cpp b/test/batchnorm/batchnorm_bwd_rank_4.cpp new file mode 100644 index 0000000000..77590626dc --- /dev/null +++ b/test/batchnorm/batchnorm_bwd_rank_4.cpp @@ -0,0 +1,92 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include +#include +#include + +#include "profiler/include/profile_batchnorm_backward_impl.hpp" + +using F16 = ck::half_t; +using F32 = float; +using BF16 = ck::bhalf_t; +using F64 = double; + +template +class TestBatchNormBwdRank4 : public ::testing::Test +{ + private: + const double epsilon = std::numeric_limits::epsilon(); + + protected: + using XDataType = std::tuple_element_t<0, Tuple>; + using DxDataType = std::tuple_element_t<1, Tuple>; + using DyDataType = std::tuple_element_t<2, Tuple>; + using AccDataType = std::tuple_element_t<3, Tuple>; + using ScaleDataType = std::tuple_element_t<4, Tuple>; + using BiasDataType = std::tuple_element_t<5, Tuple>; + using MeanVarDataType = std::tuple_element_t<6, Tuple>; + + std::vector> list_of_lengths = { + {128, 16, 3, 1024}, {128, 16, 6, 512}, {1, 1, 1, 1}, {4, 4, 4, 4}, {32, 32, 32, 32}}; + std::vector reduceDims; + + template + void Run() + { + for(auto& inOutLengths : list_of_lengths) + { + bool pass = true; + + EXPECT_FALSE(reduceDims.size() != NumReduceDim); + + pass = pass && ck::profiler::profile_batchnorm_backward_impl( + true, 3, false, false, inOutLengths, reduceDims, true, epsilon); + + pass = pass && ck::profiler::profile_batchnorm_backward_impl( + true, 3, false, false, inOutLengths, reduceDims, false, epsilon); + + EXPECT_TRUE(pass); + } + } +}; + +using KernelTypes = ::testing::Types, + std::tuple, + std::tuple, + std::tuple>; + +TYPED_TEST_SUITE(TestBatchNormBwdRank4, KernelTypes); + +// nhwc +TYPED_TEST(TestBatchNormBwdRank4, nhwc) +{ + this->reduceDims = {0, 1, 2}; + this->template Run<3>(); +} + +// nchw +TYPED_TEST(TestBatchNormBwdRank4, nchw) +{ + this->reduceDims = {0, 2, 3}; + this->template Run<3>(); +} diff --git a/test/batchnorm_fwd/batchnorm_fwd_rank_4.cpp b/test/batchnorm/batchnorm_fwd_rank_4.cpp similarity index 100% rename from test/batchnorm_fwd/batchnorm_fwd_rank_4.cpp rename to test/batchnorm/batchnorm_fwd_rank_4.cpp
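With the patch applied, the new backward path can be exercised through the profiler entry registered above (options and positional arguments as parsed by BatchnormBwdArgParser: -D/-R/-v/-o followed by data type, use-saved-mean/invVariance, init method, and time-kernel flags). A possible invocation, assuming the ckProfiler binary ends up under bin/ (the path is an assumption, not part of this patch):

    ./bin/ckProfiler bnorm_bwd -D 128,16,6,512 -R 0,1,2 -v 1 1 1 2 1

This profiles the fp32 instances on a 128x16x6x512 NHWC tensor using saved mean/invVariance, with verification against the host reference and kernel timing enabled. The gtest target test_batchnorm_bwd_rank_4 added under test/batchnorm calls the same profile_batchnorm_backward_impl routine.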