mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-14 02:02:46 +00:00
BatchNorm backward instance/external API/profiler/tests (#519)
* Refine the device batchnorm-backward base API templates and data type assignments
* Remove duplicated kernel file
* Add batchnorm backward instances and external API
* Add batchnorm-backward profiler and tests
* Add client example which uses batchnorm backward external API
* Merge test/batchnorm_fwd and test/batchnorm_bwd into one directory
* Loose the threshold for batchnorm-backward check_err()
[ROCm/composable_kernel commit: 63af525c06]
This commit is contained in:
@@ -1,2 +1,4 @@
|
||||
add_executable(client_batchnorm_fwd_nhwc batchnorm_fwd_nhwc.cpp)
|
||||
add_executable(client_batchnorm_bwd_nhwc batchnorm_bwd_nhwc.cpp)
|
||||
target_link_libraries(client_batchnorm_fwd_nhwc PRIVATE composable_kernel::device_operations)
|
||||
target_link_libraries(client_batchnorm_bwd_nhwc PRIVATE composable_kernel::device_operations)
|
||||
|
||||
201
client_example/13_batchnorm/batchnorm_bwd_nhwc.cpp
Normal file
201
client_example/13_batchnorm/batchnorm_bwd_nhwc.cpp
Normal file
@@ -0,0 +1,201 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <functional>
|
||||
#include <numeric>
|
||||
#include <iomanip>
|
||||
#include <iostream>
|
||||
#include <vector>
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/device_reduce.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
|
||||
|
||||
#include "ck/library/tensor_operation_instance/gpu/batchnorm_backward.hpp"
|
||||
|
||||
using XDataType = ck::half_t;
|
||||
using DxDataType = float;
|
||||
using DyDataType = float;
|
||||
using AccDataType = float;
|
||||
using ScaleDataType = ck::half_t;
|
||||
using DscaleDbiasDataType = float;
|
||||
using MeanVarDataType = float;
|
||||
|
||||
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
|
||||
|
||||
constexpr int Rank = 4;
|
||||
constexpr int NumBatchNormReduceDim = 3;
|
||||
|
||||
const double epsilon = std::numeric_limits<float>::epsilon();
|
||||
|
||||
struct SimpleDeviceMem
|
||||
{
|
||||
SimpleDeviceMem() = delete;
|
||||
|
||||
SimpleDeviceMem(std::size_t mem_size) : p_mem_{}
|
||||
{
|
||||
(void)hipMalloc(static_cast<void**>(&p_mem_), mem_size);
|
||||
}
|
||||
|
||||
void* GetDeviceBuffer() { return p_mem_; }
|
||||
|
||||
~SimpleDeviceMem() { (void)hipFree(p_mem_); }
|
||||
|
||||
void* p_mem_;
|
||||
};
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
std::array<ck::index_t, Rank> xyLengths{16, 8, 128, 256};
|
||||
std::array<ck::index_t, Rank> xyStrides{8 * 128 * 256, 128 * 256, 256, 1};
|
||||
std::array<ck::index_t, Rank - NumBatchNormReduceDim> scaleBiasMeanVarLengths{256};
|
||||
std::array<ck::index_t, Rank - NumBatchNormReduceDim> scaleBiasMeanVarStrides{1};
|
||||
std::array<int, NumBatchNormReduceDim> reduceDims{0, 1, 2};
|
||||
|
||||
ck::index_t numXYElement =
|
||||
std::accumulate(xyLengths.begin(), xyLengths.end(), 1, std::multiplies<ck::index_t>());
|
||||
|
||||
ck::index_t numScaleBiasMeanVarElement = std::accumulate(scaleBiasMeanVarLengths.begin(),
|
||||
scaleBiasMeanVarLengths.end(),
|
||||
1,
|
||||
std::multiplies<ck::index_t>());
|
||||
|
||||
SimpleDeviceMem x(sizeof(XDataType) * numXYElement);
|
||||
SimpleDeviceMem dy(sizeof(DyDataType) * numXYElement);
|
||||
SimpleDeviceMem scale(sizeof(ScaleDataType) * numScaleBiasMeanVarElement);
|
||||
SimpleDeviceMem mean(sizeof(MeanVarDataType) * numScaleBiasMeanVarElement);
|
||||
SimpleDeviceMem invVariance(sizeof(MeanVarDataType) * numScaleBiasMeanVarElement);
|
||||
SimpleDeviceMem dx(sizeof(DxDataType) * numXYElement);
|
||||
SimpleDeviceMem dscale(sizeof(DscaleDbiasDataType) * numScaleBiasMeanVarElement);
|
||||
SimpleDeviceMem dbias(sizeof(DscaleDbiasDataType) * numScaleBiasMeanVarElement);
|
||||
|
||||
using DeviceOp = ck::tensor_operation::device::DeviceBatchNormBwd<XDataType,
|
||||
DxDataType,
|
||||
DyDataType,
|
||||
AccDataType,
|
||||
ScaleDataType,
|
||||
DscaleDbiasDataType,
|
||||
MeanVarDataType,
|
||||
PassThrough,
|
||||
Rank,
|
||||
NumBatchNormReduceDim>;
|
||||
|
||||
const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory<
|
||||
DeviceOp>::GetInstances();
|
||||
|
||||
std::cout << "found " << op_ptrs.size() << " instances" << std::endl;
|
||||
|
||||
std::string best_op_name;
|
||||
bool found = false;
|
||||
int best_op_id = -1;
|
||||
float best_ave_time = std::numeric_limits<float>::max();
|
||||
float best_gb_per_sec = 0;
|
||||
|
||||
// profile device operation instances
|
||||
std::cout << "Run all instances and do timing" << std::endl;
|
||||
|
||||
for(int i = 0; i < op_ptrs.size(); ++i)
|
||||
{
|
||||
auto& op_ptr = op_ptrs[i];
|
||||
|
||||
auto argument_ptr = op_ptr->MakeArgumentPointer(xyLengths,
|
||||
xyStrides,
|
||||
xyStrides,
|
||||
xyStrides,
|
||||
reduceDims,
|
||||
scaleBiasMeanVarLengths,
|
||||
scaleBiasMeanVarStrides,
|
||||
scaleBiasMeanVarStrides,
|
||||
scaleBiasMeanVarStrides,
|
||||
x.GetDeviceBuffer(),
|
||||
dy.GetDeviceBuffer(),
|
||||
scale.GetDeviceBuffer(),
|
||||
mean.GetDeviceBuffer(),
|
||||
invVariance.GetDeviceBuffer(),
|
||||
epsilon,
|
||||
PassThrough{},
|
||||
dx.GetDeviceBuffer(),
|
||||
dscale.GetDeviceBuffer(),
|
||||
dbias.GetDeviceBuffer());
|
||||
|
||||
auto invoker_ptr = op_ptr->MakeInvokerPointer();
|
||||
std::string op_name = op_ptr->GetTypeString();
|
||||
|
||||
if(op_ptr->IsSupportedArgument(argument_ptr.get()))
|
||||
{
|
||||
size_t workspace_sz = op_ptr->GetWorkSpaceSize(argument_ptr.get());
|
||||
|
||||
SimpleDeviceMem workspace(workspace_sz);
|
||||
|
||||
op_ptr->SetWorkSpacePointer(argument_ptr.get(), workspace.GetDeviceBuffer());
|
||||
|
||||
float ave_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, true});
|
||||
|
||||
std::size_t num_bytes =
|
||||
numXYElement * (sizeof(XDataType) + sizeof(DyDataType) + sizeof(DxDataType)) +
|
||||
numScaleBiasMeanVarElement *
|
||||
(sizeof(ScaleDataType) + sizeof(DscaleDbiasDataType) * 2 +
|
||||
sizeof(MeanVarDataType) * 2);
|
||||
|
||||
float gb_per_sec = num_bytes / 1.E6 / ave_time;
|
||||
|
||||
std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << gb_per_sec << " GB/s, "
|
||||
<< op_name << std::endl;
|
||||
|
||||
if(ave_time < best_ave_time)
|
||||
{
|
||||
found = true;
|
||||
best_op_id = i;
|
||||
best_op_name = op_name;
|
||||
best_ave_time = ave_time;
|
||||
best_gb_per_sec = gb_per_sec;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
std::cout << op_name << " does not support this problem" << std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
if(found)
|
||||
{
|
||||
std::cout << "Best Perf: " << best_ave_time << " ms, " << best_gb_per_sec << " GB/s, "
|
||||
<< best_op_name << std::endl;
|
||||
|
||||
// run the best intance
|
||||
auto& op_ptr = op_ptrs[best_op_id];
|
||||
std::cout << "Run the best instance without timing: " << op_ptr->GetTypeString()
|
||||
<< std::endl;
|
||||
|
||||
auto argument_ptr = op_ptr->MakeArgumentPointer(xyLengths,
|
||||
xyStrides,
|
||||
xyStrides,
|
||||
xyStrides,
|
||||
reduceDims,
|
||||
scaleBiasMeanVarLengths,
|
||||
scaleBiasMeanVarStrides,
|
||||
scaleBiasMeanVarStrides,
|
||||
scaleBiasMeanVarStrides,
|
||||
x.GetDeviceBuffer(),
|
||||
dy.GetDeviceBuffer(),
|
||||
scale.GetDeviceBuffer(),
|
||||
mean.GetDeviceBuffer(),
|
||||
invVariance.GetDeviceBuffer(),
|
||||
epsilon,
|
||||
PassThrough{},
|
||||
dx.GetDeviceBuffer(),
|
||||
dscale.GetDeviceBuffer(),
|
||||
dbias.GetDeviceBuffer());
|
||||
|
||||
auto invoker_ptr = op_ptr->MakeInvokerPointer();
|
||||
|
||||
if(op_ptr->IsSupportedArgument(argument_ptr.get()))
|
||||
{
|
||||
invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, false});
|
||||
}
|
||||
|
||||
std::cout << "Done" << std::endl;
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
||||
@@ -11,7 +11,7 @@
|
||||
#include "ck/library/utility/host_tensor.hpp"
|
||||
#include "ck/library/utility/host_tensor_generator.hpp"
|
||||
#include "ck/library/utility/host_common_util.hpp"
|
||||
#include "ck/library/reference_tensor_operation/cpu/reference_batchnorm_backward_nhwc_c.hpp"
|
||||
#include "ck/library/reference_tensor_operation/cpu/reference_batchnorm_backward.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_batchnorm_backward_impl.hpp"
|
||||
|
||||
static struct option long_options[] = {{"inOutLengths", required_argument, nullptr, 'D'},
|
||||
@@ -106,7 +106,7 @@ class BatchNormBwdArg
|
||||
|
||||
using namespace ck;
|
||||
|
||||
template <typename InOutDataType, typename AccDataType, bool UseMultiblockInK>
|
||||
template <typename XDataType, typename AccDataType, bool UseMultiblockInK>
|
||||
bool bnorm_bwd_nhwc_test(bool do_verification,
|
||||
int init_method,
|
||||
bool time_kernel,
|
||||
@@ -118,13 +118,15 @@ bool bnorm_bwd_nhwc_test(bool do_verification,
|
||||
constexpr index_t Rank = 4;
|
||||
constexpr index_t NumReduceDim = 3;
|
||||
|
||||
using ScaleDataType = XDataType;
|
||||
|
||||
const std::vector<size_t> scaleBiasMeanVarLengths = {inOutLengths[3]};
|
||||
|
||||
// input data of the batchnorm backward algorithm
|
||||
Tensor<InOutDataType> x(inOutLengths);
|
||||
Tensor<InOutDataType> dy(inOutLengths);
|
||||
Tensor<XDataType> x(inOutLengths);
|
||||
Tensor<AccDataType> dy(inOutLengths);
|
||||
|
||||
Tensor<AccDataType> bnScale(scaleBiasMeanVarLengths);
|
||||
Tensor<ScaleDataType> bnScale(scaleBiasMeanVarLengths);
|
||||
|
||||
Tensor<AccDataType> savedMean(scaleBiasMeanVarLengths);
|
||||
Tensor<AccDataType> savedInvVar(scaleBiasMeanVarLengths);
|
||||
@@ -132,8 +134,8 @@ bool bnorm_bwd_nhwc_test(bool do_verification,
|
||||
Tensor<AccDataType> savedVariance(scaleBiasMeanVarLengths);
|
||||
|
||||
// output data of the batchnorm backward algorithm
|
||||
Tensor<InOutDataType> dx_ref(inOutLengths);
|
||||
Tensor<InOutDataType> dx(inOutLengths);
|
||||
Tensor<AccDataType> dx_ref(inOutLengths);
|
||||
Tensor<AccDataType> dx(inOutLengths);
|
||||
|
||||
Tensor<AccDataType> dscale(scaleBiasMeanVarLengths);
|
||||
Tensor<AccDataType> dbias(scaleBiasMeanVarLengths);
|
||||
@@ -153,7 +155,7 @@ bool bnorm_bwd_nhwc_test(bool do_verification,
|
||||
const float noise_stddev = 0.0001f;
|
||||
|
||||
// input data in normal distribution
|
||||
x.GenerateTensorValue(GeneratorTensor_4<InOutDataType>{x_mean, x_stddev}, num_thread);
|
||||
x.GenerateTensorValue(GeneratorTensor_4<XDataType>{x_mean, x_stddev}, num_thread);
|
||||
|
||||
// initialize the savedMean to be values with tiny variation to the mean of the x values
|
||||
savedMean.GenerateTensorValue(GeneratorTensor_4<AccDataType>{x_mean, noise_stddev},
|
||||
@@ -182,7 +184,7 @@ bool bnorm_bwd_nhwc_test(bool do_verification,
|
||||
const float x_stddev = 1.0f;
|
||||
|
||||
// input data in normal distribution
|
||||
x.GenerateTensorValue(GeneratorTensor_4<InOutDataType>{x_mean, x_stddev}, num_thread);
|
||||
x.GenerateTensorValue(GeneratorTensor_4<XDataType>{x_mean, x_stddev}, num_thread);
|
||||
};
|
||||
|
||||
if(do_verification)
|
||||
@@ -190,34 +192,34 @@ bool bnorm_bwd_nhwc_test(bool do_verification,
|
||||
switch(init_method)
|
||||
{
|
||||
case 0:
|
||||
dy.GenerateTensorValue(GeneratorTensor_0<InOutDataType>{}, num_thread);
|
||||
bnScale.GenerateTensorValue(GeneratorTensor_0<InOutDataType>{}, num_thread);
|
||||
dy.GenerateTensorValue(GeneratorTensor_0<AccDataType>{}, num_thread);
|
||||
bnScale.GenerateTensorValue(GeneratorTensor_0<ScaleDataType>{}, num_thread);
|
||||
break;
|
||||
case 1:
|
||||
dy.GenerateTensorValue(GeneratorTensor_1<InOutDataType>{1}, num_thread);
|
||||
bnScale.GenerateTensorValue(GeneratorTensor_1<InOutDataType>{1}, num_thread);
|
||||
dy.GenerateTensorValue(GeneratorTensor_1<AccDataType>{1}, num_thread);
|
||||
bnScale.GenerateTensorValue(GeneratorTensor_1<ScaleDataType>{1}, num_thread);
|
||||
break;
|
||||
case 2:
|
||||
dy.GenerateTensorValue(GeneratorTensor_2<InOutDataType>{-5, 5}, num_thread);
|
||||
bnScale.GenerateTensorValue(GeneratorTensor_2<InOutDataType>{-5, 5}, num_thread);
|
||||
dy.GenerateTensorValue(GeneratorTensor_2<AccDataType>{-2, 2}, num_thread);
|
||||
bnScale.GenerateTensorValue(GeneratorTensor_2<ScaleDataType>{-5, 5}, num_thread);
|
||||
break;
|
||||
default:
|
||||
dy.GenerateTensorValue(GeneratorTensor_3<InOutDataType>{-0.2f, 0.2f}, num_thread);
|
||||
bnScale.GenerateTensorValue(GeneratorTensor_3<InOutDataType>{-0.5f, 0.5f}, num_thread);
|
||||
dy.GenerateTensorValue(GeneratorTensor_3<AccDataType>{-0.2f, 0.2f}, num_thread);
|
||||
bnScale.GenerateTensorValue(GeneratorTensor_3<ScaleDataType>{-0.5f, 0.5f}, num_thread);
|
||||
}
|
||||
};
|
||||
|
||||
// input data of the batchnorm backward algorithm
|
||||
DeviceMem x_dev(sizeof(InOutDataType) * x.mDesc.GetElementSpaceSize());
|
||||
DeviceMem dy_dev(sizeof(InOutDataType) * dy.mDesc.GetElementSpaceSize());
|
||||
DeviceMem x_dev(sizeof(XDataType) * x.mDesc.GetElementSpaceSize());
|
||||
DeviceMem dy_dev(sizeof(AccDataType) * dy.mDesc.GetElementSpaceSize());
|
||||
|
||||
DeviceMem bnScale_dev(sizeof(AccDataType) * bnScale.mDesc.GetElementSpaceSize());
|
||||
DeviceMem bnScale_dev(sizeof(ScaleDataType) * bnScale.mDesc.GetElementSpaceSize());
|
||||
|
||||
DeviceMem savedMean_dev(sizeof(AccDataType) * savedMean.mDesc.GetElementSpaceSize());
|
||||
DeviceMem savedInvVar_dev(sizeof(AccDataType) * savedInvVar.mDesc.GetElementSpaceSize());
|
||||
|
||||
// output data of the batchnorm backward algorithm
|
||||
DeviceMem dx_dev(sizeof(InOutDataType) * dx.mDesc.GetElementSpaceSize());
|
||||
DeviceMem dx_dev(sizeof(AccDataType) * dx.mDesc.GetElementSpaceSize());
|
||||
|
||||
DeviceMem dscale_dev(sizeof(AccDataType) * dscale.mDesc.GetElementSpaceSize());
|
||||
DeviceMem dbias_dev(sizeof(AccDataType) * dbias.mDesc.GetElementSpaceSize());
|
||||
@@ -249,13 +251,13 @@ bool bnorm_bwd_nhwc_test(bool do_verification,
|
||||
using PassThroughOp = ck::tensor_operation::element_wise::PassThrough;
|
||||
|
||||
using DeviceBatchNormBwdInstance =
|
||||
ck::tensor_operation::device::DeviceBatchNormBwdImpl<InOutDataType,
|
||||
InOutDataType,
|
||||
InOutDataType,
|
||||
ck::tensor_operation::device::DeviceBatchNormBwdImpl<XDataType,
|
||||
AccDataType,
|
||||
AccDataType, // ScaleDataType
|
||||
AccDataType, // BiasDataType
|
||||
AccDataType, // MeanVarDataType
|
||||
AccDataType,
|
||||
AccDataType,
|
||||
ScaleDataType, // ScaleDataType
|
||||
AccDataType, // DscaleDbiasDataType
|
||||
AccDataType, // MeanVarDataType
|
||||
PassThroughOp,
|
||||
Rank,
|
||||
NumReduceDim,
|
||||
@@ -269,8 +271,8 @@ bool bnorm_bwd_nhwc_test(bool do_verification,
|
||||
1, // XSrcVectorSize
|
||||
1, // DySrcVectorSize
|
||||
1, // DxDstVectorSize
|
||||
1, // ScaleSrcDstVectorSize
|
||||
1, // BiasDstVectorSize
|
||||
1, // ScaleSrcVectorSize
|
||||
1, // DscaleDbiasDstVectorSize
|
||||
1>; // MeanVarSrcVectorSize
|
||||
|
||||
auto batchnorm_bwd = DeviceBatchNormBwdInstance{};
|
||||
@@ -324,7 +326,7 @@ bool bnorm_bwd_nhwc_test(bool do_verification,
|
||||
|
||||
// inputing of x, dy, scale, outputing of dx, dscale, dbias
|
||||
num_bytes +=
|
||||
total_length * sizeof(InOutDataType) * 3 + invariant_length * sizeof(AccDataType) * 3;
|
||||
total_length * sizeof(XDataType) * 3 + invariant_length * sizeof(AccDataType) * 3;
|
||||
|
||||
// outputing of mean, inv-variance
|
||||
num_bytes += haveSavedMeanInvVar ? invariant_length * sizeof(AccDataType) * 2 : 0;
|
||||
@@ -341,14 +343,16 @@ bool bnorm_bwd_nhwc_test(bool do_verification,
|
||||
if(do_verification)
|
||||
{
|
||||
using ReferenceBatchNormBwdInstance =
|
||||
ck::tensor_operation::host::ReferenceBatchNormBwd_Input_N_H_W_C_Output_C<InOutDataType,
|
||||
InOutDataType,
|
||||
InOutDataType,
|
||||
AccDataType,
|
||||
AccDataType,
|
||||
AccDataType,
|
||||
AccDataType,
|
||||
PassThroughOp>;
|
||||
ck::tensor_operation::host::ReferenceBatchNormBwd<XDataType,
|
||||
AccDataType,
|
||||
AccDataType,
|
||||
AccDataType,
|
||||
ScaleDataType, // ScaleDataType
|
||||
AccDataType,
|
||||
AccDataType,
|
||||
PassThroughOp,
|
||||
Rank,
|
||||
NumReduceDim>;
|
||||
|
||||
auto batchNormBwd_ref = ReferenceBatchNormBwdInstance{};
|
||||
|
||||
@@ -390,8 +394,8 @@ bool bnorm_bwd_nhwc_test(bool do_verification,
|
||||
dbias_dev.FromDevice(dbias.data());
|
||||
|
||||
// clang-format off
|
||||
pass = pass && ck::utils::check_err(dbias.mData, dbias_ref.mData, "dBias result:", 1e-5, 1e-5);
|
||||
pass = pass && ck::utils::check_err(dscale.mData, dscale_ref.mData, "dScale result:", 1e-5, 2e-4);
|
||||
pass = pass && ck::utils::check_err(dbias.mData, dbias_ref.mData, "dBias result:", 2e-4, 2e-4);
|
||||
pass = pass && ck::utils::check_err(dscale.mData, dscale_ref.mData, "dScale result:", 2e-4, 2e-4);
|
||||
pass = pass && ck::utils::check_err(dx.mData, dx_ref.mData, "dx result:");
|
||||
// clang-format on
|
||||
};
|
||||
|
||||
@@ -13,7 +13,16 @@ namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
|
||||
template <index_t Rank, index_t NumBatchNormReduceDim, typename DyElementwiseOp>
|
||||
template <typename XDataType,
|
||||
typename DxDataType,
|
||||
typename DyDataType,
|
||||
typename AccDataType,
|
||||
typename ScaleDataType,
|
||||
typename DscaleDbiasDataType,
|
||||
typename MeanVarDataType,
|
||||
typename DyElementwiseOp,
|
||||
index_t Rank,
|
||||
index_t NumBatchNormReduceDim>
|
||||
struct DeviceBatchNormBwd : public BaseOperator
|
||||
{
|
||||
static constexpr index_t NumInvariantDim = Rank - NumBatchNormReduceDim;
|
||||
@@ -26,7 +35,7 @@ struct DeviceBatchNormBwd : public BaseOperator
|
||||
const std::array<int, NumBatchNormReduceDim> reduceDims,
|
||||
const std::array<ck::index_t, NumInvariantDim> bnScaleBiasMeanVarLengths,
|
||||
const std::array<ck::index_t, NumInvariantDim> bnScaleStrides,
|
||||
const std::array<ck::index_t, NumInvariantDim> bnBiasStrides,
|
||||
const std::array<ck::index_t, NumInvariantDim> bnDscaleDbiasStrides,
|
||||
const std::array<ck::index_t, NumInvariantDim> bnMeanVarStrides,
|
||||
const void* p_x,
|
||||
const void* p_dy,
|
||||
@@ -42,9 +51,26 @@ struct DeviceBatchNormBwd : public BaseOperator
|
||||
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
|
||||
};
|
||||
|
||||
template <index_t Rank, index_t NumBatchNormReduceDim, typename DyElementwiseOp>
|
||||
using DeviceBatchNormBwdPtr =
|
||||
std::unique_ptr<DeviceBatchNormBwd<Rank, NumBatchNormReduceDim, DyElementwiseOp>>;
|
||||
template <typename XDataType,
|
||||
typename DxDataType,
|
||||
typename DyDataType,
|
||||
typename AccDataType,
|
||||
typename ScaleDataType,
|
||||
typename DscaleDbiasDataType,
|
||||
typename MeanVarDataType,
|
||||
typename DyElementwiseOp,
|
||||
index_t Rank,
|
||||
index_t NumBatchNormReduceDim>
|
||||
using DeviceBatchNormBwdPtr = std::unique_ptr<DeviceBatchNormBwd<XDataType,
|
||||
DxDataType,
|
||||
DyDataType,
|
||||
AccDataType,
|
||||
ScaleDataType,
|
||||
DscaleDbiasDataType,
|
||||
MeanVarDataType,
|
||||
DyElementwiseOp,
|
||||
Rank,
|
||||
NumBatchNormReduceDim>>;
|
||||
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
|
||||
@@ -27,7 +27,7 @@ template <typename XDataType,
|
||||
typename DyDataType,
|
||||
typename AccDataType,
|
||||
typename ScaleDataType,
|
||||
typename BiasDataType,
|
||||
typename DscaleDbiasDataType,
|
||||
typename MeanVarDataType,
|
||||
typename DyElementwiseOp,
|
||||
index_t Rank,
|
||||
@@ -42,11 +42,19 @@ template <typename XDataType,
|
||||
index_t XSrcVectorSize,
|
||||
index_t DySrcVectorSize,
|
||||
index_t DxDstVectorSize,
|
||||
index_t ScaleSrcDstVectorSize,
|
||||
index_t BiasDstVectorSize,
|
||||
index_t ScaleSrcVectorSize,
|
||||
index_t DscaleDbiasDstVectorSize,
|
||||
index_t MeanVarSrcVectorSize>
|
||||
struct DeviceBatchNormBwdImpl
|
||||
: public DeviceBatchNormBwd<Rank, NumBatchNormReduceDim, DyElementwiseOp>
|
||||
struct DeviceBatchNormBwdImpl : public DeviceBatchNormBwd<XDataType,
|
||||
DxDataType,
|
||||
DyDataType,
|
||||
AccDataType,
|
||||
ScaleDataType,
|
||||
DscaleDbiasDataType,
|
||||
MeanVarDataType,
|
||||
DyElementwiseOp,
|
||||
Rank,
|
||||
NumBatchNormReduceDim>
|
||||
{
|
||||
static_assert(Rank <= 6, "Bigger Rank size is not supported!");
|
||||
static_assert(BlockSize == MThreadClusterSize * KThreadClusterSize,
|
||||
@@ -194,7 +202,7 @@ struct DeviceBatchNormBwdImpl
|
||||
const std::array<int, NumBatchNormReduceDim> reduceDims,
|
||||
const std::array<ck::index_t, NumInvariantDim> bnScaleBiasMeanVarLengths,
|
||||
const std::array<ck::index_t, NumInvariantDim> bnScaleStrides,
|
||||
const std::array<ck::index_t, NumInvariantDim> bnBiasStrides,
|
||||
const std::array<ck::index_t, NumInvariantDim> bnDscaleDbiasStrides,
|
||||
const std::array<ck::index_t, NumInvariantDim> bnMeanVarStrides,
|
||||
const XDataType* p_x,
|
||||
const DyDataType* p_dy,
|
||||
@@ -204,11 +212,11 @@ struct DeviceBatchNormBwdImpl
|
||||
const DyElementwiseOp dy_elementwise_op,
|
||||
double epsilon,
|
||||
DxDataType* p_dx,
|
||||
ScaleDataType* p_dscale,
|
||||
BiasDataType* p_dbias)
|
||||
DscaleDbiasDataType* p_dscale,
|
||||
DscaleDbiasDataType* p_dbias)
|
||||
: bnScaleBiasMeanVarLengths_(bnScaleBiasMeanVarLengths),
|
||||
bnScaleStrides_(bnScaleStrides),
|
||||
bnBiasStrides_(bnBiasStrides),
|
||||
bnDscaleDbiasStrides_(bnDscaleDbiasStrides),
|
||||
bnMeanVarStrides_(bnMeanVarStrides),
|
||||
p_x_(p_x),
|
||||
p_dy_(p_dy),
|
||||
@@ -272,8 +280,8 @@ struct DeviceBatchNormBwdImpl
|
||||
MakeXY2dDescriptor(xyLengths_, dxStrides_, blkGroupSize, numBlockTileIteration);
|
||||
scale_grid_desc_m =
|
||||
MakeScaleBiasMeanVar1dDescriptor(bnScaleBiasMeanVarLengths, bnScaleStrides);
|
||||
bias_grid_desc_m =
|
||||
MakeScaleBiasMeanVar1dDescriptor(bnScaleBiasMeanVarLengths, bnBiasStrides);
|
||||
dscale_dbias_grid_desc_m =
|
||||
MakeScaleBiasMeanVar1dDescriptor(bnScaleBiasMeanVarLengths, bnDscaleDbiasStrides);
|
||||
mean_var_grid_desc_m =
|
||||
MakeScaleBiasMeanVar1dDescriptor(bnScaleBiasMeanVarLengths, bnMeanVarStrides);
|
||||
}
|
||||
@@ -289,7 +297,7 @@ struct DeviceBatchNormBwdImpl
|
||||
|
||||
std::array<index_t, Rank - NumBatchNormReduceDim> bnScaleBiasMeanVarLengths_;
|
||||
std::array<index_t, Rank - NumBatchNormReduceDim> bnScaleStrides_;
|
||||
std::array<index_t, Rank - NumBatchNormReduceDim> bnBiasStrides_;
|
||||
std::array<index_t, Rank - NumBatchNormReduceDim> bnDscaleDbiasStrides_;
|
||||
std::array<index_t, Rank - NumBatchNormReduceDim> bnMeanVarStrides_;
|
||||
|
||||
const XDataType* p_x_;
|
||||
@@ -299,8 +307,8 @@ struct DeviceBatchNormBwdImpl
|
||||
const MeanVarDataType* p_savedInvVar_;
|
||||
const DyElementwiseOp dy_elementwise_op_;
|
||||
DxDataType* p_dx_;
|
||||
ScaleDataType* p_dscale_;
|
||||
BiasDataType* p_dbias_;
|
||||
DscaleDbiasDataType* p_dscale_;
|
||||
DscaleDbiasDataType* p_dbias_;
|
||||
|
||||
long_index_t invariant_length;
|
||||
long_index_t reduce_length;
|
||||
@@ -313,7 +321,7 @@ struct DeviceBatchNormBwdImpl
|
||||
XYGridDesc_M_K dy_grid_desc_m_k;
|
||||
XYGridDesc_M_K dx_grid_desc_m_k;
|
||||
ScaleBiasGridDesc_M scale_grid_desc_m;
|
||||
ScaleBiasGridDesc_M bias_grid_desc_m;
|
||||
ScaleBiasGridDesc_M dscale_dbias_grid_desc_m;
|
||||
MeanVarGridDesc_M mean_var_grid_desc_m;
|
||||
|
||||
void* workspace_mean;
|
||||
@@ -337,11 +345,11 @@ struct DeviceBatchNormBwdImpl
|
||||
{
|
||||
// workspace for the partial reduced result for dscale
|
||||
workspace_size +=
|
||||
pArg_->invariant_length * pArg_->blkGroupSize * sizeof(ScaleDataType) + 64;
|
||||
pArg_->invariant_length * pArg_->blkGroupSize * sizeof(DscaleDbiasDataType) + 64;
|
||||
|
||||
// workspace for the partial reduced result for dbias
|
||||
workspace_size +=
|
||||
pArg_->invariant_length * pArg_->blkGroupSize * sizeof(BiasDataType) + 64;
|
||||
pArg_->invariant_length * pArg_->blkGroupSize * sizeof(DscaleDbiasDataType) + 64;
|
||||
|
||||
if(!pArg_->haveSavedMeanInvVar_)
|
||||
{
|
||||
@@ -379,7 +387,7 @@ struct DeviceBatchNormBwdImpl
|
||||
// setup buffer for the partial reduced result for dscale
|
||||
pArg_->workspace_reduce_dscale = pArg_->p_workspace_;
|
||||
|
||||
space_sz = pArg_->invariant_length * pArg_->blkGroupSize * sizeof(ScaleDataType);
|
||||
space_sz = pArg_->invariant_length * pArg_->blkGroupSize * sizeof(DscaleDbiasDataType);
|
||||
space_sz = math::integer_least_multiple(space_sz, 64);
|
||||
|
||||
// setup buffer for the partial reduced result for dbias
|
||||
@@ -388,7 +396,7 @@ struct DeviceBatchNormBwdImpl
|
||||
|
||||
if(UseMultiblockInK && pArg_->blkGroupSize > 1)
|
||||
{
|
||||
space_sz = pArg_->invariant_length * pArg_->blkGroupSize * sizeof(BiasDataType);
|
||||
space_sz = pArg_->invariant_length * pArg_->blkGroupSize * sizeof(DscaleDbiasDataType);
|
||||
space_sz = math::integer_least_multiple(space_sz, 64);
|
||||
|
||||
// setup buffer for welford intermediate mean
|
||||
@@ -454,7 +462,7 @@ struct DeviceBatchNormBwdImpl
|
||||
DyDataType,
|
||||
AccDataType,
|
||||
ScaleDataType,
|
||||
BiasDataType,
|
||||
DscaleDbiasDataType,
|
||||
MeanVarDataType,
|
||||
DyElementwiseOp,
|
||||
XYGridDesc_M_K,
|
||||
@@ -477,7 +485,7 @@ struct DeviceBatchNormBwdImpl
|
||||
DxDataType,
|
||||
AccDataType,
|
||||
ScaleDataType,
|
||||
BiasDataType,
|
||||
DscaleDbiasDataType,
|
||||
MeanVarDataType,
|
||||
DyElementwiseOp,
|
||||
XYGridDesc_M_K,
|
||||
@@ -493,8 +501,8 @@ struct DeviceBatchNormBwdImpl
|
||||
XSrcVectorSize,
|
||||
DySrcVectorSize,
|
||||
DxDstVectorSize,
|
||||
ScaleSrcDstVectorSize,
|
||||
BiasDstVectorSize,
|
||||
ScaleSrcVectorSize,
|
||||
DscaleDbiasDstVectorSize,
|
||||
MeanVarSrcVectorSize>;
|
||||
|
||||
if(UseMultiblockInK && arg.blkGroupSize > 1)
|
||||
@@ -553,7 +561,7 @@ struct DeviceBatchNormBwdImpl
|
||||
DyDataType,
|
||||
AccDataType,
|
||||
ScaleDataType,
|
||||
BiasDataType,
|
||||
DscaleDbiasDataType,
|
||||
MeanVarDataType,
|
||||
DyElementwiseOp,
|
||||
XYGridDesc_M_K,
|
||||
@@ -568,7 +576,7 @@ struct DeviceBatchNormBwdImpl
|
||||
DyDataType,
|
||||
DxDataType,
|
||||
ScaleDataType,
|
||||
BiasDataType,
|
||||
DscaleDbiasDataType,
|
||||
MeanVarDataType,
|
||||
DyElementwiseOp,
|
||||
XYGridDesc_M_K,
|
||||
@@ -614,8 +622,8 @@ struct DeviceBatchNormBwdImpl
|
||||
: static_cast<MeanVarDataType*>(arg.workspace_savedInvVar),
|
||||
arg.p_x_,
|
||||
arg.p_dy_,
|
||||
static_cast<ScaleDataType*>(arg.workspace_reduce_dscale),
|
||||
static_cast<BiasDataType*>(arg.workspace_reduce_dbias));
|
||||
static_cast<DscaleDbiasDataType*>(arg.workspace_reduce_dscale),
|
||||
static_cast<DscaleDbiasDataType*>(arg.workspace_reduce_dbias));
|
||||
|
||||
avg_time += launch_and_time_kernel(
|
||||
stream_config,
|
||||
@@ -629,13 +637,13 @@ struct DeviceBatchNormBwdImpl
|
||||
dscale_dbias_grid_desc_m_k,
|
||||
arg.mean_var_grid_desc_m,
|
||||
arg.scale_grid_desc_m,
|
||||
arg.bias_grid_desc_m,
|
||||
arg.dscale_dbias_grid_desc_m,
|
||||
arg.blkGroupSize,
|
||||
arg.reduce_length,
|
||||
arg.numBlockTileIteration,
|
||||
numDscaleDbiasBlockTileIteration,
|
||||
static_cast<const ScaleDataType*>(arg.workspace_reduce_dscale),
|
||||
static_cast<const BiasDataType*>(arg.workspace_reduce_dbias),
|
||||
static_cast<const DscaleDbiasDataType*>(arg.workspace_reduce_dscale),
|
||||
static_cast<const DscaleDbiasDataType*>(arg.workspace_reduce_dbias),
|
||||
arg.haveSavedMeanInvVar_
|
||||
? arg.p_savedMean_
|
||||
: static_cast<const MeanVarDataType*>(arg.workspace_savedMean),
|
||||
@@ -664,7 +672,7 @@ struct DeviceBatchNormBwdImpl
|
||||
DxDataType,
|
||||
AccDataType,
|
||||
ScaleDataType,
|
||||
BiasDataType,
|
||||
DscaleDbiasDataType,
|
||||
MeanVarDataType,
|
||||
DyElementwiseOp,
|
||||
XYGridDesc_M_K,
|
||||
@@ -680,8 +688,8 @@ struct DeviceBatchNormBwdImpl
|
||||
XSrcVectorSize,
|
||||
DySrcVectorSize,
|
||||
DxDstVectorSize,
|
||||
ScaleSrcDstVectorSize,
|
||||
BiasDstVectorSize,
|
||||
ScaleSrcVectorSize,
|
||||
DscaleDbiasDstVectorSize,
|
||||
MeanVarSrcVectorSize>;
|
||||
|
||||
const auto kern_batchnorm_bwd = kernel_batchnorm_backward_with_blockwise_welford<
|
||||
@@ -691,7 +699,7 @@ struct DeviceBatchNormBwdImpl
|
||||
DxDataType,
|
||||
AccDataType,
|
||||
ScaleDataType,
|
||||
BiasDataType,
|
||||
DscaleDbiasDataType,
|
||||
MeanVarDataType,
|
||||
DyElementwiseOp,
|
||||
XYGridDesc_M_K,
|
||||
@@ -708,7 +716,7 @@ struct DeviceBatchNormBwdImpl
|
||||
arg.dy_grid_desc_m_k,
|
||||
arg.dx_grid_desc_m_k,
|
||||
arg.scale_grid_desc_m,
|
||||
arg.bias_grid_desc_m,
|
||||
arg.dscale_dbias_grid_desc_m,
|
||||
arg.mean_var_grid_desc_m,
|
||||
get_reduce_count_per_thread,
|
||||
arg.reduce_length,
|
||||
@@ -764,16 +772,16 @@ struct DeviceBatchNormBwdImpl
|
||||
return false;
|
||||
};
|
||||
|
||||
if(pArg_->bnScaleStrides_[NumInvariantDim - 1] != 1 && ScaleSrcDstVectorSize != 1)
|
||||
if(pArg_->bnScaleStrides_[NumInvariantDim - 1] != 1 && ScaleSrcVectorSize != 1)
|
||||
return false;
|
||||
|
||||
if(pArg_->bnBiasStrides_[NumInvariantDim - 1] != 1 && BiasDstVectorSize != 1)
|
||||
if(pArg_->bnDscaleDbiasStrides_[NumInvariantDim - 1] != 1 && DscaleDbiasDstVectorSize != 1)
|
||||
return false;
|
||||
|
||||
if(pArg_->bnScaleBiasMeanVarLengths_[NumInvariantDim - 1] % ScaleSrcDstVectorSize != 0)
|
||||
if(pArg_->bnScaleBiasMeanVarLengths_[NumInvariantDim - 1] % ScaleSrcVectorSize != 0)
|
||||
return false;
|
||||
|
||||
if(pArg_->bnScaleBiasMeanVarLengths_[NumInvariantDim - 1] % BiasDstVectorSize != 0)
|
||||
if(pArg_->bnScaleBiasMeanVarLengths_[NumInvariantDim - 1] % DscaleDbiasDstVectorSize != 0)
|
||||
return false;
|
||||
|
||||
if(pArg_->haveSavedMeanInvVar_)
|
||||
@@ -806,7 +814,7 @@ struct DeviceBatchNormBwdImpl
|
||||
const std::array<int, NumBatchNormReduceDim> reduceDims,
|
||||
const std::array<ck::index_t, NumInvariantDim> bnScaleBiasMeanVarLengths,
|
||||
const std::array<ck::index_t, NumInvariantDim> bnScaleStrides,
|
||||
const std::array<ck::index_t, NumInvariantDim> bnBiasStrides,
|
||||
const std::array<ck::index_t, NumInvariantDim> bnDscaleDbiasStrides,
|
||||
const std::array<ck::index_t, NumInvariantDim> bnMeanVarStrides,
|
||||
const void* p_x,
|
||||
const void* p_dy,
|
||||
@@ -826,7 +834,7 @@ struct DeviceBatchNormBwdImpl
|
||||
reduceDims,
|
||||
bnScaleBiasMeanVarLengths,
|
||||
bnScaleStrides,
|
||||
bnBiasStrides,
|
||||
bnDscaleDbiasStrides,
|
||||
bnMeanVarStrides,
|
||||
static_cast<const XDataType*>(p_x),
|
||||
static_cast<const DyDataType*>(p_dy),
|
||||
@@ -836,8 +844,8 @@ struct DeviceBatchNormBwdImpl
|
||||
dy_elementwise_op,
|
||||
epsilon,
|
||||
static_cast<DxDataType*>(p_dx),
|
||||
static_cast<ScaleDataType*>(p_dscale),
|
||||
static_cast<BiasDataType*>(p_dbias));
|
||||
static_cast<DscaleDbiasDataType*>(p_dscale),
|
||||
static_cast<DscaleDbiasDataType*>(p_dbias));
|
||||
};
|
||||
|
||||
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
|
||||
@@ -854,7 +862,7 @@ struct DeviceBatchNormBwdImpl
|
||||
str << "M_C" << MThreadClusterSize << "_S" << MThreadSliceSize << ",";
|
||||
str << "K_C" << KThreadClusterSize << "_S" << KThreadSliceSize << ",";
|
||||
str << "XDyDxVectorDim_" << XDyDxVectorDim << ",";
|
||||
str << "VectorSize_X" << XSrcVectorSize << "_scale_" << ScaleSrcDstVectorSize << "_bias_" << BiasDstVectorSize << "_mean_var_" << MeanVarSrcVectorSize << "_Dx_" << DxDstVectorSize << ">";
|
||||
str << "VectorSize_X" << XSrcVectorSize << "_scale_" << ScaleSrcVectorSize << "_bias_" << DscaleDbiasDstVectorSize << "_mean_var_" << MeanVarSrcVectorSize << "_Dx_" << DxDstVectorSize << ">";
|
||||
// clang-format on
|
||||
|
||||
return str.str();
|
||||
|
||||
@@ -16,7 +16,7 @@ template <typename GridwiseReduceSecondHalfBatchNormBackwardFinal_,
|
||||
typename DyDataType,
|
||||
typename DxDataType,
|
||||
typename ScaleDataType,
|
||||
typename BiasDataType,
|
||||
typename DscaleDbiasDataType,
|
||||
typename MeanVarDataType,
|
||||
typename DyElementwiseOp,
|
||||
typename XYGridDesc_M_K,
|
||||
@@ -35,8 +35,8 @@ __global__ void kernel_reduce_second_half_batchnorm_backward_final(
|
||||
long_index_t reduce_size,
|
||||
index_t num_xy_k_block_tile_iteration,
|
||||
index_t num_dscale_dbias_k_block_tile_iteration,
|
||||
const ScaleDataType* const __restrict__ p_reduce_dscale,
|
||||
const BiasDataType* const __restrict__ p_reduce_dbias,
|
||||
const DscaleDbiasDataType* const __restrict__ p_reduce_dscale,
|
||||
const DscaleDbiasDataType* const __restrict__ p_reduce_dbias,
|
||||
const MeanVarDataType* const __restrict__ p_mean,
|
||||
const MeanVarDataType* const __restrict__ p_inv_var,
|
||||
const XDataType* const __restrict__ p_x,
|
||||
@@ -44,8 +44,8 @@ __global__ void kernel_reduce_second_half_batchnorm_backward_final(
|
||||
const ScaleDataType* const __restrict__ p_scale,
|
||||
const DyElementwiseOp dy_elementwise_op,
|
||||
DxDataType* const __restrict__ p_dx,
|
||||
ScaleDataType* const __restrict__ p_dscale,
|
||||
BiasDataType* const __restrict__ p_dbias)
|
||||
DscaleDbiasDataType* const __restrict__ p_dscale,
|
||||
DscaleDbiasDataType* const __restrict__ p_dbias)
|
||||
{
|
||||
GridwiseReduceSecondHalfBatchNormBackwardFinal_::Run(x_grid_desc_m_k,
|
||||
dy_grid_desc_m_k,
|
||||
@@ -76,7 +76,7 @@ template <typename XDataType,
|
||||
typename DxDataType,
|
||||
typename AccDataType,
|
||||
typename ScaleDataType,
|
||||
typename BiasDataType,
|
||||
typename DscaleDbiasDataType,
|
||||
typename MeanVarDataType,
|
||||
typename DyElementwiseOp,
|
||||
typename XYGridDesc_M_K,
|
||||
@@ -92,8 +92,8 @@ template <typename XDataType,
|
||||
index_t XSrcVectorSize,
|
||||
index_t DySrcVectorSize,
|
||||
index_t DxDstVectorSize,
|
||||
index_t ScaleSrcDstVectorSize,
|
||||
index_t BiasDstVectorSize,
|
||||
index_t ScaleSrcVectorSize,
|
||||
index_t DscaleDbiasDstVectorSize,
|
||||
index_t MeanVarSrcVectorSize>
|
||||
struct GridwiseReduceSecondHalfBatchNormBackwardFinal
|
||||
{
|
||||
@@ -155,13 +155,13 @@ struct GridwiseReduceSecondHalfBatchNormBackwardFinal
|
||||
const DscaleDbiasGridDesc_M_K& dscale_dbias_grid_desc_m_k,
|
||||
const MeanVarGridDesc_M& mean_var_grid_desc_m,
|
||||
const ScaleBiasGridDesc_M& scale_grid_desc_m,
|
||||
const ScaleBiasGridDesc_M& bias_grid_desc_m,
|
||||
const ScaleBiasGridDesc_M& dscale_dbias_grid_desc_m,
|
||||
index_t blkgroup_size,
|
||||
long_index_t reduce_size,
|
||||
index_t num_xy_k_block_tile_iteration,
|
||||
index_t num_dscale_dbias_k_block_tile_iteration,
|
||||
const ScaleDataType* const __restrict__ p_reduce_dscale,
|
||||
const BiasDataType* const __restrict__ p_reduce_dbias,
|
||||
const DscaleDbiasDataType* const __restrict__ p_reduce_dscale,
|
||||
const DscaleDbiasDataType* const __restrict__ p_reduce_dbias,
|
||||
const MeanVarDataType* const __restrict__ p_mean,
|
||||
const MeanVarDataType* const __restrict__ p_inv_var,
|
||||
const XDataType* const __restrict__ p_x,
|
||||
@@ -169,8 +169,8 @@ struct GridwiseReduceSecondHalfBatchNormBackwardFinal
|
||||
const ScaleDataType* const __restrict__ p_scale,
|
||||
const DyElementwiseOp dy_elementwise_op,
|
||||
DxDataType* const __restrict__ p_dx,
|
||||
ScaleDataType* const __restrict__ p_dscale,
|
||||
BiasDataType* const __restrict__ p_dbias)
|
||||
DscaleDbiasDataType* const __restrict__ p_dscale,
|
||||
DscaleDbiasDataType* const __restrict__ p_dbias)
|
||||
{
|
||||
__shared__ AccDataType p_reduce_work_buffer[BlockSize];
|
||||
|
||||
@@ -222,8 +222,8 @@ struct GridwiseReduceSecondHalfBatchNormBackwardFinal
|
||||
// Step 1: do final reduction of dbias = sum(dy), dscale = sum(dy * (x-mean) * inv-variance)
|
||||
// clang-format on
|
||||
|
||||
auto threadwise_dscale_load_m_k =
|
||||
ThreadwiseTensorSliceTransfer_v2<ScaleDataType,
|
||||
auto threadwise_dscale_dbias_load_m_k =
|
||||
ThreadwiseTensorSliceTransfer_v2<DscaleDbiasDataType,
|
||||
AccDataType,
|
||||
DscaleDbiasGridDesc_M_K,
|
||||
decltype(thread_buffer_desc_m_1),
|
||||
@@ -238,54 +238,20 @@ struct GridwiseReduceSecondHalfBatchNormBackwardFinal
|
||||
thread_m_cluster_id * MThreadSliceSize,
|
||||
thread_k_cluster_id * 1));
|
||||
|
||||
auto threadwise_dbias_load_m_k =
|
||||
ThreadwiseTensorSliceTransfer_v2<BiasDataType,
|
||||
AccDataType,
|
||||
DscaleDbiasGridDesc_M_K,
|
||||
decltype(thread_buffer_desc_m_1),
|
||||
ThreadBufferLengths_M_1,
|
||||
Sequence<0, 1>,
|
||||
1,
|
||||
1,
|
||||
1,
|
||||
true>(
|
||||
dscale_dbias_grid_desc_m_k,
|
||||
make_multi_index(blkgroup_id * M_BlockTileSize +
|
||||
thread_m_cluster_id * MThreadSliceSize,
|
||||
thread_k_cluster_id * 1));
|
||||
|
||||
auto threadwise_dscale_store_m =
|
||||
auto threadwise_dscale_dbias_store_m =
|
||||
ThreadwiseTensorSliceTransfer_v1r3<AccDataType,
|
||||
ScaleDataType,
|
||||
DscaleDbiasDataType,
|
||||
decltype(thread_buffer_desc_m),
|
||||
ScaleBiasGridDesc_M,
|
||||
PassThroughOp,
|
||||
ThreadBufferLengths_M,
|
||||
Sequence<0>,
|
||||
0,
|
||||
ScaleSrcDstVectorSize,
|
||||
DscaleDbiasDstVectorSize,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
1,
|
||||
true>(
|
||||
scale_grid_desc_m,
|
||||
make_multi_index(blkgroup_id * M_BlockTileSize +
|
||||
thread_m_cluster_id * MThreadSliceSize),
|
||||
PassThroughOp{});
|
||||
|
||||
auto threadwise_dbias_store_m =
|
||||
ThreadwiseTensorSliceTransfer_v1r3<AccDataType,
|
||||
BiasDataType,
|
||||
decltype(thread_buffer_desc_m),
|
||||
ScaleBiasGridDesc_M,
|
||||
PassThroughOp,
|
||||
ThreadBufferLengths_M,
|
||||
Sequence<0>,
|
||||
0,
|
||||
BiasDstVectorSize,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
1,
|
||||
true>(
|
||||
bias_grid_desc_m,
|
||||
dscale_dbias_grid_desc_m,
|
||||
make_multi_index(blkgroup_id * M_BlockTileSize +
|
||||
thread_m_cluster_id * MThreadSliceSize),
|
||||
PassThroughOp{});
|
||||
@@ -297,10 +263,10 @@ struct GridwiseReduceSecondHalfBatchNormBackwardFinal
|
||||
p_reduce_dbias, dscale_dbias_grid_desc_m_k.GetElementSpaceSize());
|
||||
|
||||
auto dscale_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
p_dscale, scale_grid_desc_m.GetElementSpaceSize());
|
||||
p_dscale, dscale_dbias_grid_desc_m.GetElementSpaceSize());
|
||||
|
||||
auto dbias_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
p_dbias, bias_grid_desc_m.GetElementSpaceSize());
|
||||
p_dbias, dscale_dbias_grid_desc_m.GetElementSpaceSize());
|
||||
|
||||
constexpr auto dscale_dbias_thread_copy_step_m_k =
|
||||
make_multi_index(0, KThreadClusterSize * 1);
|
||||
@@ -313,25 +279,23 @@ struct GridwiseReduceSecondHalfBatchNormBackwardFinal
|
||||
for(index_t reducedTiles = 0; reducedTiles < num_dscale_dbias_k_block_tile_iteration;
|
||||
++reducedTiles)
|
||||
{
|
||||
threadwise_dscale_load_m_k.Run(dscale_dbias_grid_desc_m_k,
|
||||
reduce_dscale_global_buf,
|
||||
thread_buffer_desc_m_1,
|
||||
make_tuple(I0, I0),
|
||||
reduce_dscale_thread_buf);
|
||||
threadwise_dscale_dbias_load_m_k.Run(dscale_dbias_grid_desc_m_k,
|
||||
reduce_dscale_global_buf,
|
||||
thread_buffer_desc_m_1,
|
||||
make_tuple(I0, I0),
|
||||
reduce_dscale_thread_buf);
|
||||
|
||||
threadwise_dbias_load_m_k.Run(dscale_dbias_grid_desc_m_k,
|
||||
reduce_dbias_global_buf,
|
||||
thread_buffer_desc_m_1,
|
||||
make_tuple(I0, I0),
|
||||
reduce_dbias_thread_buf);
|
||||
threadwise_dscale_dbias_load_m_k.Run(dscale_dbias_grid_desc_m_k,
|
||||
reduce_dbias_global_buf,
|
||||
thread_buffer_desc_m_1,
|
||||
make_tuple(I0, I0),
|
||||
reduce_dbias_thread_buf);
|
||||
|
||||
ThreadwiseReduce::Reduce(reduce_dscale_thread_buf, dscale_thread_buf);
|
||||
ThreadwiseReduce::Reduce(reduce_dbias_thread_buf, dbias_thread_buf);
|
||||
|
||||
threadwise_dscale_load_m_k.MoveSrcSliceWindow(dscale_dbias_grid_desc_m_k,
|
||||
dscale_dbias_thread_copy_step_m_k);
|
||||
threadwise_dbias_load_m_k.MoveSrcSliceWindow(dscale_dbias_grid_desc_m_k,
|
||||
dscale_dbias_thread_copy_step_m_k);
|
||||
threadwise_dscale_dbias_load_m_k.MoveSrcSliceWindow(dscale_dbias_grid_desc_m_k,
|
||||
dscale_dbias_thread_copy_step_m_k);
|
||||
}
|
||||
|
||||
static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
|
||||
@@ -343,17 +307,17 @@ struct GridwiseReduceSecondHalfBatchNormBackwardFinal
|
||||
BlockwiseReduce::Reduce(reduce_work_buf, dbias_thread_buf(I));
|
||||
});
|
||||
|
||||
threadwise_dscale_store_m.Run(thread_buffer_desc_m,
|
||||
make_tuple(I0),
|
||||
dscale_thread_buf,
|
||||
scale_grid_desc_m,
|
||||
dscale_global_buf);
|
||||
threadwise_dscale_dbias_store_m.Run(thread_buffer_desc_m,
|
||||
make_tuple(I0),
|
||||
dscale_thread_buf,
|
||||
dscale_dbias_grid_desc_m,
|
||||
dscale_global_buf);
|
||||
|
||||
threadwise_dbias_store_m.Run(thread_buffer_desc_m,
|
||||
make_tuple(I0),
|
||||
dbias_thread_buf,
|
||||
bias_grid_desc_m,
|
||||
dbias_global_buf);
|
||||
threadwise_dscale_dbias_store_m.Run(thread_buffer_desc_m,
|
||||
make_tuple(I0),
|
||||
dbias_thread_buf,
|
||||
dscale_dbias_grid_desc_m,
|
||||
dbias_global_buf);
|
||||
|
||||
// clang-format off
|
||||
// Step 2: calculate dx = 1/N * inv-variance * scale * (N * dy - dbias - dscale * (x - mean) * inv-variance)
|
||||
@@ -418,7 +382,7 @@ struct GridwiseReduceSecondHalfBatchNormBackwardFinal
|
||||
ThreadBufferLengths_M,
|
||||
Sequence<0>,
|
||||
0,
|
||||
ScaleSrcDstVectorSize,
|
||||
ScaleSrcVectorSize,
|
||||
1,
|
||||
true>(
|
||||
scale_grid_desc_m,
|
||||
|
||||
@@ -17,7 +17,7 @@ template <typename GridwiseWelfordSecondHalfReduceFirstHalf_,
|
||||
typename DyDataType,
|
||||
typename AccDataType,
|
||||
typename ScaleDataType,
|
||||
typename BiasDataType,
|
||||
typename DscaleDbiasDataType,
|
||||
typename MeanVarDataType,
|
||||
typename DyElementwiseOp,
|
||||
typename XYGridDesc_M_K,
|
||||
@@ -45,8 +45,8 @@ __global__ void kernel_welford_second_half_reduce_first_half(
|
||||
MeanVarDataType* const __restrict__ p_out_welford_inv_variance,
|
||||
const XDataType* const __restrict__ p_x,
|
||||
const DyDataType* const __restrict__ p_dy,
|
||||
ScaleDataType* const __restrict__ p_reduce_dscale,
|
||||
BiasDataType* const __restrict__ p_reduce_dbias)
|
||||
DscaleDbiasDataType* const __restrict__ p_reduce_dscale,
|
||||
DscaleDbiasDataType* const __restrict__ p_reduce_dbias)
|
||||
{
|
||||
GridwiseWelfordSecondHalfReduceFirstHalf_::Run(x_grid_desc_m_k,
|
||||
dy_grid_desc_m_k,
|
||||
@@ -76,7 +76,7 @@ template <typename XDataType,
|
||||
typename DyDataType,
|
||||
typename AccDataType,
|
||||
typename ScaleDataType,
|
||||
typename BiasDataType,
|
||||
typename DscaleDbiasDataType,
|
||||
typename MeanVarDataType,
|
||||
typename DyElementwiseOp,
|
||||
typename XYGridDesc_M_K,
|
||||
@@ -174,8 +174,8 @@ struct GridwiseWelfordSecondHalfReduceFirstHalf
|
||||
MeanVarDataType* const __restrict__ p_out_welford_inv_variance,
|
||||
const XDataType* const __restrict__ p_x,
|
||||
const DyDataType* const __restrict__ p_dy,
|
||||
ScaleDataType* const __restrict__ p_reduce_dscale,
|
||||
BiasDataType* const __restrict__ p_reduce_dbias)
|
||||
DscaleDbiasDataType* const __restrict__ p_reduce_dscale,
|
||||
DscaleDbiasDataType* const __restrict__ p_reduce_dbias)
|
||||
{
|
||||
__shared__ AccDataType p_reduce_work_buffer[BlockSize];
|
||||
|
||||
@@ -511,28 +511,9 @@ struct GridwiseWelfordSecondHalfReduceFirstHalf
|
||||
BlockwiseReduce::Reduce(reduce_work_buf, reduce_dbias_thread_buf(I));
|
||||
});
|
||||
|
||||
auto threadwise_dscale_store =
|
||||
auto threadwise_dscale_dbias_store =
|
||||
ThreadwiseTensorSliceTransfer_v1r3<AccDataType,
|
||||
ScaleDataType,
|
||||
decltype(thread_buffer_desc_m_1),
|
||||
DscaleDbiasGridDesc_M_G,
|
||||
PassThroughOp,
|
||||
ThreadBufferLengths_M_1,
|
||||
Sequence<0, 1>,
|
||||
1,
|
||||
1,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
1,
|
||||
true>(
|
||||
dscale_dbias_grid_desc_m_g,
|
||||
make_multi_index(blkgroup_id * M_BlockTileSize +
|
||||
thread_m_cluster_id * MThreadSliceSize,
|
||||
block_local_id),
|
||||
PassThroughOp{});
|
||||
|
||||
auto threadwise_dbias_store =
|
||||
ThreadwiseTensorSliceTransfer_v1r3<AccDataType,
|
||||
BiasDataType,
|
||||
DscaleDbiasDataType,
|
||||
decltype(thread_buffer_desc_m_1),
|
||||
DscaleDbiasGridDesc_M_G,
|
||||
PassThroughOp,
|
||||
@@ -557,17 +538,17 @@ struct GridwiseWelfordSecondHalfReduceFirstHalf
|
||||
|
||||
if(thread_k_cluster_id == 0)
|
||||
{
|
||||
threadwise_dscale_store.Run(thread_buffer_desc_m_1,
|
||||
make_tuple(I0, I0),
|
||||
reduce_dscale_thread_buf,
|
||||
dscale_dbias_grid_desc_m_g,
|
||||
reduce_dscale_global_buf);
|
||||
threadwise_dscale_dbias_store.Run(thread_buffer_desc_m_1,
|
||||
make_tuple(I0, I0),
|
||||
reduce_dscale_thread_buf,
|
||||
dscale_dbias_grid_desc_m_g,
|
||||
reduce_dscale_global_buf);
|
||||
|
||||
threadwise_dbias_store.Run(thread_buffer_desc_m_1,
|
||||
make_tuple(I0, I0),
|
||||
reduce_dbias_thread_buf,
|
||||
dscale_dbias_grid_desc_m_g,
|
||||
reduce_dbias_global_buf);
|
||||
threadwise_dscale_dbias_store.Run(thread_buffer_desc_m_1,
|
||||
make_tuple(I0, I0),
|
||||
reduce_dbias_thread_buf,
|
||||
dscale_dbias_grid_desc_m_g,
|
||||
reduce_dbias_global_buf);
|
||||
};
|
||||
};
|
||||
};
|
||||
|
||||
@@ -21,7 +21,7 @@ template <typename GridwiseBatchrNormBackwardWithBlockwiseWelford_,
|
||||
typename DxDataType,
|
||||
typename AccDataType,
|
||||
typename ScaleDataType,
|
||||
typename BiasDataType,
|
||||
typename DscaleDbiasDataType,
|
||||
typename MeanVarDataType,
|
||||
typename DyElementwiseOp,
|
||||
typename XYGridDesc_M_K,
|
||||
@@ -33,7 +33,7 @@ __global__ void kernel_batchnorm_backward_with_blockwise_welford(
|
||||
const XYGridDesc_M_K dy_grid_desc_m_k,
|
||||
const XYGridDesc_M_K dx_grid_desc_m_k,
|
||||
const ScaleBiasGridDesc_M scale_grid_desc_m,
|
||||
const ScaleBiasGridDesc_M bias_grid_desc_m,
|
||||
const ScaleBiasGridDesc_M dscale_dbias_grid_desc_m,
|
||||
const MeanVarGridDesc_M mean_var_grid_desc_m,
|
||||
const GetReduceCountPerThreadFunctor get_reduce_count_per_thread,
|
||||
long_index_t reduce_size,
|
||||
@@ -47,14 +47,14 @@ __global__ void kernel_batchnorm_backward_with_blockwise_welford(
|
||||
const MeanVarDataType* const __restrict__ p_savedInvVar,
|
||||
const DyElementwiseOp dy_elementwise_op,
|
||||
DxDataType* const __restrict__ p_dx,
|
||||
ScaleDataType* const __restrict__ p_dscale,
|
||||
BiasDataType* const __restrict__ p_dbias)
|
||||
DscaleDbiasDataType* const __restrict__ p_dscale,
|
||||
DscaleDbiasDataType* const __restrict__ p_dbias)
|
||||
{
|
||||
GridwiseBatchrNormBackwardWithBlockwiseWelford_::Run(x_grid_desc_m_k,
|
||||
dy_grid_desc_m_k,
|
||||
dx_grid_desc_m_k,
|
||||
scale_grid_desc_m,
|
||||
bias_grid_desc_m,
|
||||
dscale_dbias_grid_desc_m,
|
||||
mean_var_grid_desc_m,
|
||||
get_reduce_count_per_thread,
|
||||
reduce_size,
|
||||
@@ -77,7 +77,7 @@ template <typename XDataType,
|
||||
typename DxDataType,
|
||||
typename AccDataType,
|
||||
typename ScaleDataType,
|
||||
typename BiasDataType,
|
||||
typename DscaleDbiasDataType,
|
||||
typename MeanVarDataType,
|
||||
typename DyElementwiseOp,
|
||||
typename XYGridDesc_M_K,
|
||||
@@ -93,8 +93,8 @@ template <typename XDataType,
|
||||
index_t XSrcVectorSize,
|
||||
index_t DySrcVectorSize,
|
||||
index_t DxDstVectorSize,
|
||||
index_t ScaleSrcDstVectorSize,
|
||||
index_t BiasDstVectorSize,
|
||||
index_t ScaleSrcVectorSize,
|
||||
index_t DscaleDbiasDstVectorSize,
|
||||
index_t MeanVarSrcVectorSize>
|
||||
struct GridwiseBatchNormBackwardWithBlockwiseWelford
|
||||
{
|
||||
@@ -165,7 +165,7 @@ struct GridwiseBatchNormBackwardWithBlockwiseWelford
|
||||
const XYGridDesc_M_K dy_grid_desc_m_k,
|
||||
const XYGridDesc_M_K dx_grid_desc_m_k,
|
||||
const ScaleBiasGridDesc_M scale_grid_desc_m,
|
||||
const ScaleBiasGridDesc_M bias_grid_desc_m,
|
||||
const ScaleBiasGridDesc_M dscale_dbias_grid_desc_m,
|
||||
const MeanVarGridDesc_M mean_var_grid_desc_m,
|
||||
const GetReduceCountPerThreadFunctor get_reduce_count_per_thread,
|
||||
long_index_t reduce_size,
|
||||
@@ -179,8 +179,8 @@ struct GridwiseBatchNormBackwardWithBlockwiseWelford
|
||||
const MeanVarDataType* const __restrict__ p_savedInvVar,
|
||||
const DyElementwiseOp dy_elementwise_op,
|
||||
DxDataType* const __restrict__ p_dx,
|
||||
ScaleDataType* const __restrict__ p_dscale,
|
||||
BiasDataType* const __restrict__ p_dbias)
|
||||
DscaleDbiasDataType* const __restrict__ p_dscale,
|
||||
DscaleDbiasDataType* const __restrict__ p_dbias)
|
||||
{
|
||||
using ck::math::sqrt;
|
||||
|
||||
@@ -253,7 +253,7 @@ struct GridwiseBatchNormBackwardWithBlockwiseWelford
|
||||
XSrcVectorSize,
|
||||
1,
|
||||
true>(
|
||||
x_grid_desc_m_k,
|
||||
dy_grid_desc_m_k,
|
||||
make_multi_index(block_global_id * M_BlockTileSize +
|
||||
thread_m_cluster_id * MThreadSliceSize,
|
||||
thread_k_cluster_id * KThreadSliceSize));
|
||||
@@ -271,7 +271,7 @@ struct GridwiseBatchNormBackwardWithBlockwiseWelford
|
||||
InMemoryDataOperationEnum::Set,
|
||||
1,
|
||||
true>(
|
||||
dy_grid_desc_m_k,
|
||||
dx_grid_desc_m_k,
|
||||
make_multi_index(block_global_id * M_BlockTileSize +
|
||||
thread_m_cluster_id * MThreadSliceSize,
|
||||
thread_k_cluster_id * KThreadSliceSize),
|
||||
@@ -285,45 +285,27 @@ struct GridwiseBatchNormBackwardWithBlockwiseWelford
|
||||
ThreadBufferLengths_M,
|
||||
Sequence<0>,
|
||||
0,
|
||||
ScaleSrcDstVectorSize,
|
||||
ScaleSrcVectorSize,
|
||||
1,
|
||||
true>(
|
||||
scale_grid_desc_m,
|
||||
make_multi_index(block_global_id * M_BlockTileSize +
|
||||
thread_m_cluster_id * MThreadSliceSize));
|
||||
|
||||
auto threadwise_dscale_store =
|
||||
auto threadwise_dscale_dbias_store =
|
||||
ThreadwiseTensorSliceTransfer_v1r3<AccDataType,
|
||||
ScaleDataType,
|
||||
DscaleDbiasDataType,
|
||||
decltype(thread_buffer_desc_m),
|
||||
ScaleBiasGridDesc_M,
|
||||
PassThroughOp,
|
||||
ThreadBufferLengths_M,
|
||||
Sequence<0>,
|
||||
0,
|
||||
ScaleSrcDstVectorSize,
|
||||
DscaleDbiasDstVectorSize,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
1,
|
||||
true>(
|
||||
scale_grid_desc_m,
|
||||
make_multi_index(block_global_id * M_BlockTileSize +
|
||||
thread_m_cluster_id * MThreadSliceSize),
|
||||
PassThroughOp{});
|
||||
|
||||
auto threadwise_dbias_store =
|
||||
ThreadwiseTensorSliceTransfer_v1r3<AccDataType,
|
||||
BiasDataType,
|
||||
decltype(thread_buffer_desc_m),
|
||||
ScaleBiasGridDesc_M,
|
||||
PassThroughOp,
|
||||
ThreadBufferLengths_M,
|
||||
Sequence<0>,
|
||||
0,
|
||||
BiasDstVectorSize,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
1,
|
||||
true>(
|
||||
bias_grid_desc_m,
|
||||
dscale_dbias_grid_desc_m,
|
||||
make_multi_index(block_global_id * M_BlockTileSize +
|
||||
thread_m_cluster_id * MThreadSliceSize),
|
||||
PassThroughOp{});
|
||||
@@ -344,10 +326,10 @@ struct GridwiseBatchNormBackwardWithBlockwiseWelford
|
||||
p_scale, scale_grid_desc_m.GetElementSpaceSize());
|
||||
|
||||
auto dscale_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
p_dscale, scale_grid_desc_m.GetElementSpaceSize());
|
||||
p_dscale, dscale_dbias_grid_desc_m.GetElementSpaceSize());
|
||||
|
||||
auto dbias_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
p_dbias, bias_grid_desc_m.GetElementSpaceSize());
|
||||
p_dbias, dscale_dbias_grid_desc_m.GetElementSpaceSize());
|
||||
|
||||
// clang-format off
|
||||
// Step 1: calculating mean and inv-variance using welford method (if savedMean/savedInvVar not available), where inv-variance = 1/sqrt(epsilon+variance)
|
||||
@@ -487,17 +469,17 @@ struct GridwiseBatchNormBackwardWithBlockwiseWelford
|
||||
|
||||
if(thread_k_cluster_id == 0)
|
||||
{
|
||||
threadwise_dscale_store.Run(thread_buffer_desc_m,
|
||||
make_tuple(I0),
|
||||
dscale_thread_buf,
|
||||
scale_grid_desc_m,
|
||||
dscale_global_buf);
|
||||
threadwise_dscale_dbias_store.Run(thread_buffer_desc_m,
|
||||
make_tuple(I0),
|
||||
dscale_thread_buf,
|
||||
dscale_dbias_grid_desc_m,
|
||||
dscale_global_buf);
|
||||
|
||||
threadwise_dbias_store.Run(thread_buffer_desc_m,
|
||||
make_tuple(I0),
|
||||
dbias_thread_buf,
|
||||
bias_grid_desc_m,
|
||||
dbias_global_buf);
|
||||
threadwise_dscale_dbias_store.Run(thread_buffer_desc_m,
|
||||
make_tuple(I0),
|
||||
dbias_thread_buf,
|
||||
dscale_dbias_grid_desc_m,
|
||||
dbias_global_buf);
|
||||
};
|
||||
|
||||
// clang-format off
|
||||
|
||||
@@ -1,258 +0,0 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck/utility/data_type.hpp"
|
||||
#include "ck/utility/math.hpp"
|
||||
#include "ck/tensor_operation/gpu/block/blockwise_welford.hpp"
|
||||
#include "ck/tensor_operation/gpu/thread/threadwise_welford.hpp"
|
||||
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
template <typename GridwiseMultiblockWelfordFirstHalf_,
|
||||
typename XDataType,
|
||||
typename MeanVarDataType,
|
||||
typename XGridDesc_M_K,
|
||||
typename MeanVarCountGridDesc_M_G,
|
||||
typename GetReduceCountPerThreadFunctor>
|
||||
__global__ void kernel_multiblock_welford_first_half(
|
||||
const XGridDesc_M_K x_grid_desc_m_k,
|
||||
const MeanVarCountGridDesc_M_G mean_var_count_grid_desc_m_g,
|
||||
const GetReduceCountPerThreadFunctor get_reduce_count_per_thread,
|
||||
index_t num_k_block_tile_iteration,
|
||||
const XDataType* const __restrict__ p_x,
|
||||
MeanVarDataType* const p_welford_mean,
|
||||
MeanVarDataType* const p_welford_variance,
|
||||
int32_t* const p_welford_count)
|
||||
{
|
||||
GridwiseMultiblockWelfordFirstHalf_::Run(x_grid_desc_m_k,
|
||||
mean_var_count_grid_desc_m_g,
|
||||
get_reduce_count_per_thread,
|
||||
num_k_block_tile_iteration,
|
||||
p_x,
|
||||
p_welford_mean,
|
||||
p_welford_variance,
|
||||
p_welford_count);
|
||||
};
|
||||
|
||||
template <typename XDataType,
|
||||
typename AccDataType,
|
||||
typename MeanVarDataType,
|
||||
typename XGridDesc_M_K,
|
||||
typename MeanVarCountGridDesc_M_G,
|
||||
typename GetReduceCountPerThreadFunctor,
|
||||
index_t BlockSize,
|
||||
index_t MThreadClusterSize,
|
||||
index_t KThreadClusterSize,
|
||||
index_t MThreadSliceSize,
|
||||
index_t KThreadSliceSize,
|
||||
index_t XSrcCountSrcVectorDim,
|
||||
index_t XSrcCountSrcVectorSize>
|
||||
struct GridwiseMultiblockWelfordFirstHalf
|
||||
{
|
||||
static_assert((XSrcCountSrcVectorDim == 0 && MThreadSliceSize % XSrcCountSrcVectorSize == 0) ||
|
||||
(XSrcCountSrcVectorDim == 1 &&
|
||||
KThreadSliceSize % XSrcCountSrcVectorSize == 0),
|
||||
"Invalid thread slice sizes and/or vector sizes configuration, please check!");
|
||||
|
||||
static constexpr bool reorder_thread_cluster = (XSrcCountSrcVectorDim == 0);
|
||||
|
||||
using ThreadClusterLengths_M_K = Sequence<MThreadClusterSize, KThreadClusterSize>;
|
||||
|
||||
using ThreadBufferDimAccessOrder =
|
||||
typename conditional<reorder_thread_cluster, Sequence<1, 0>, Sequence<0, 1>>::type;
|
||||
|
||||
using ThreadClusterArrangeOrder =
|
||||
typename conditional<reorder_thread_cluster, Sequence<1, 0>, Sequence<0, 1>>::type;
|
||||
|
||||
static constexpr auto thread_cluster_desc =
|
||||
make_cluster_descriptor(ThreadClusterLengths_M_K{}, ThreadClusterArrangeOrder{});
|
||||
|
||||
using ThreadReduceSrcDesc_M_K = decltype(make_naive_tensor_descriptor_packed(
|
||||
make_tuple(Number<MThreadSliceSize>{}, Number<KThreadSliceSize>{})));
|
||||
using ThreadReduceDstDesc_M =
|
||||
decltype(make_naive_tensor_descriptor_packed(make_tuple(Number<MThreadSliceSize>{})));
|
||||
|
||||
using ThreadwiseWelford =
|
||||
ThreadwiseWelford<AccDataType, ThreadReduceSrcDesc_M_K, ThreadReduceDstDesc_M>;
|
||||
|
||||
using BlockwiseWelford = BlockwiseWelford<AccDataType,
|
||||
BlockSize,
|
||||
ThreadClusterLengths_M_K,
|
||||
ThreadClusterArrangeOrder,
|
||||
false>;
|
||||
|
||||
using PassThroughOp = tensor_operation::element_wise::PassThrough;
|
||||
|
||||
static constexpr auto I0 = Number<0>{};
|
||||
static constexpr auto I1 = Number<1>{};
|
||||
|
||||
static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize;
|
||||
static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize;
|
||||
|
||||
__device__ static void Run(const XGridDesc_M_K& x_grid_desc_m_k,
|
||||
const MeanVarCountGridDesc_M_G& mean_var_count_grid_desc_m_g,
|
||||
const GetReduceCountPerThreadFunctor& get_reduce_count_per_thread,
|
||||
index_t num_k_block_tile_iteration,
|
||||
const XDataType* const __restrict__ p_x,
|
||||
MeanVarDataType* const p_welford_mean,
|
||||
MeanVarDataType* const p_welford_variance,
|
||||
int32_t* const p_welford_count)
|
||||
{
|
||||
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize * KThreadSliceSize, true>
|
||||
x_thread_buf;
|
||||
|
||||
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true>
|
||||
welford_mean_thread_buf;
|
||||
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true>
|
||||
welford_var_thread_buf;
|
||||
StaticBuffer<AddressSpaceEnum::Vgpr, int32_t, MThreadSliceSize, true>
|
||||
welford_count_thread_buf;
|
||||
|
||||
const index_t blkgroup_size = mean_var_count_grid_desc_m_g.GetLength(I1);
|
||||
|
||||
const index_t thread_local_id = get_thread_local_1d_id();
|
||||
const index_t block_global_id = get_block_1d_id();
|
||||
const index_t blkgroup_id = block_global_id / blkgroup_size;
|
||||
const index_t block_local_id = block_global_id % blkgroup_size;
|
||||
|
||||
const auto thread_cluster_idx =
|
||||
thread_cluster_desc.CalculateBottomIndex(make_multi_index(thread_local_id));
|
||||
|
||||
const auto thread_m_cluster_id = thread_cluster_idx[I0];
|
||||
const auto thread_k_cluster_id = thread_cluster_idx[I1];
|
||||
|
||||
using ThreadBufferLengths_M_K = Sequence<MThreadSliceSize, KThreadSliceSize>;
|
||||
using ThreadBufferLengths_M_1 = Sequence<MThreadSliceSize, 1>;
|
||||
|
||||
constexpr auto thread_buffer_desc_m_k = make_naive_tensor_descriptor_packed(
|
||||
make_tuple(Number<MThreadSliceSize>{}, Number<KThreadSliceSize>{}));
|
||||
constexpr auto thread_buffer_desc_m_1 = make_naive_tensor_descriptor_packed(
|
||||
make_tuple(Number<MThreadSliceSize>{}, Number<1>{}));
|
||||
|
||||
const index_t reduceSizePerBlock = K_BlockTileSize * num_k_block_tile_iteration;
|
||||
|
||||
auto threadwise_x_load = ThreadwiseTensorSliceTransfer_v2<XDataType,
|
||||
AccDataType,
|
||||
XGridDesc_M_K,
|
||||
decltype(thread_buffer_desc_m_k),
|
||||
ThreadBufferLengths_M_K,
|
||||
ThreadBufferDimAccessOrder,
|
||||
XSrcCountSrcVectorDim,
|
||||
XSrcCountSrcVectorSize,
|
||||
1,
|
||||
true>(
|
||||
x_grid_desc_m_k,
|
||||
make_multi_index(blkgroup_id * M_BlockTileSize + thread_m_cluster_id * MThreadSliceSize,
|
||||
block_local_id * reduceSizePerBlock +
|
||||
thread_k_cluster_id * KThreadSliceSize));
|
||||
|
||||
auto threadwise_welford_mean_var_store =
|
||||
ThreadwiseTensorSliceTransfer_v1r3<AccDataType,
|
||||
MeanVarDataType,
|
||||
decltype(thread_buffer_desc_m_1),
|
||||
MeanVarCountGridDesc_M_G,
|
||||
PassThroughOp,
|
||||
ThreadBufferLengths_M_1,
|
||||
Sequence<0, 1>,
|
||||
1,
|
||||
1,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
1,
|
||||
true>(
|
||||
mean_var_count_grid_desc_m_g,
|
||||
make_multi_index(blkgroup_id * M_BlockTileSize +
|
||||
thread_m_cluster_id * MThreadSliceSize,
|
||||
block_local_id),
|
||||
PassThroughOp{});
|
||||
|
||||
auto threadwise_welford_count_store =
|
||||
ThreadwiseTensorSliceTransfer_v1r3<int32_t,
|
||||
int32_t,
|
||||
decltype(thread_buffer_desc_m_1),
|
||||
MeanVarCountGridDesc_M_G,
|
||||
PassThroughOp,
|
||||
ThreadBufferLengths_M_1,
|
||||
Sequence<0, 1>,
|
||||
1,
|
||||
1,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
1,
|
||||
true>(
|
||||
mean_var_count_grid_desc_m_g,
|
||||
make_multi_index(blkgroup_id * M_BlockTileSize +
|
||||
thread_m_cluster_id * MThreadSliceSize,
|
||||
block_local_id),
|
||||
PassThroughOp{});
|
||||
|
||||
constexpr auto thread_copy_fwd_step_m_k = make_multi_index(0, K_BlockTileSize);
|
||||
|
||||
const auto x_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
p_x, x_grid_desc_m_k.GetElementSpaceSize());
|
||||
|
||||
auto welford_mean_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
p_welford_mean, mean_var_count_grid_desc_m_g.GetElementSpaceSize());
|
||||
|
||||
auto welford_var_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
p_welford_variance, mean_var_count_grid_desc_m_g.GetElementSpaceSize());
|
||||
|
||||
auto welford_count_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
p_welford_count, mean_var_count_grid_desc_m_g.GetElementSpaceSize());
|
||||
|
||||
auto threadwise_welford = ThreadwiseWelford();
|
||||
threadwise_welford.max_count_ =
|
||||
get_reduce_count_per_thread(block_local_id, thread_k_cluster_id);
|
||||
|
||||
static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
|
||||
welford_mean_thread_buf(I) = type_convert<AccDataType>(0.0f);
|
||||
welford_var_thread_buf(I) = type_convert<AccDataType>(0.0f);
|
||||
});
|
||||
|
||||
for(index_t reducedTiles = 0; reducedTiles < num_k_block_tile_iteration; ++reducedTiles)
|
||||
{
|
||||
threadwise_x_load.Run(x_grid_desc_m_k,
|
||||
x_global_val_buf,
|
||||
thread_buffer_desc_m_k,
|
||||
make_tuple(I0, I0),
|
||||
x_thread_buf);
|
||||
|
||||
threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, thread_copy_fwd_step_m_k);
|
||||
threadwise_welford.Run(x_thread_buf, welford_mean_thread_buf, welford_var_thread_buf);
|
||||
}
|
||||
|
||||
static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
|
||||
if constexpr(I > 0)
|
||||
block_sync_lds();
|
||||
|
||||
welford_count_thread_buf(I) = threadwise_welford.cur_count_;
|
||||
BlockwiseWelford::Run(
|
||||
welford_mean_thread_buf(I), welford_var_thread_buf(I), welford_count_thread_buf(I));
|
||||
});
|
||||
|
||||
if(thread_k_cluster_id == 0)
|
||||
{
|
||||
threadwise_welford_mean_var_store.Run(thread_buffer_desc_m_1,
|
||||
make_tuple(I0, I0),
|
||||
welford_mean_thread_buf,
|
||||
mean_var_count_grid_desc_m_g,
|
||||
welford_mean_global_val_buf);
|
||||
|
||||
threadwise_welford_mean_var_store.Run(thread_buffer_desc_m_1,
|
||||
make_tuple(I0, I0),
|
||||
welford_var_thread_buf,
|
||||
mean_var_count_grid_desc_m_g,
|
||||
welford_var_global_val_buf);
|
||||
|
||||
threadwise_welford_count_store.Run(thread_buffer_desc_m_1,
|
||||
make_tuple(I0, I0),
|
||||
welford_count_thread_buf,
|
||||
mean_var_count_grid_desc_m_g,
|
||||
welford_count_global_val_buf);
|
||||
};
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,412 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <iostream>
|
||||
#include <array>
|
||||
#include <algorithm>
|
||||
#include <thread>
|
||||
|
||||
#include "ck/utility/math_v2.hpp"
|
||||
#include "ck/utility/ignore.hpp"
|
||||
#include "ck/library/utility/host_common_util.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/device_batchnorm_backward.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace host {
|
||||
|
||||
template <typename XDataType,
|
||||
typename DxDataType,
|
||||
typename DyDataType,
|
||||
typename AccDataType,
|
||||
typename ScaleDataType,
|
||||
typename DscaleDbiasDataType,
|
||||
typename MeanVarDataType,
|
||||
typename DyElementwiseOp,
|
||||
index_t Rank,
|
||||
index_t NumBatchNormReduceDim>
|
||||
struct ReferenceBatchNormBwd : public device::DeviceBatchNormBwd<XDataType,
|
||||
DxDataType,
|
||||
DyDataType,
|
||||
AccDataType,
|
||||
ScaleDataType,
|
||||
DscaleDbiasDataType,
|
||||
MeanVarDataType,
|
||||
DyElementwiseOp,
|
||||
Rank,
|
||||
NumBatchNormReduceDim>
|
||||
{
|
||||
static_assert(Rank <= 6, "Bigger Rank size is not supported!");
|
||||
|
||||
static constexpr index_t NumInvariantDim = Rank - NumBatchNormReduceDim;
|
||||
|
||||
struct Argument : public device::BaseArgument
|
||||
{
|
||||
Argument(const std::array<index_t, Rank> xyLengths,
|
||||
const std::array<index_t, Rank> xStrides,
|
||||
const std::array<index_t, Rank> dxStrides,
|
||||
const std::array<index_t, Rank> dyStrides,
|
||||
const std::array<int, NumBatchNormReduceDim> reduceDims,
|
||||
const std::array<index_t, NumInvariantDim> bnScaleBiasMeanVarLengths,
|
||||
const std::array<index_t, NumInvariantDim> bnScaleStrides,
|
||||
const std::array<index_t, NumInvariantDim> bnDscaleDbiasStrides,
|
||||
const std::array<index_t, NumInvariantDim> bnMeanVarStrides,
|
||||
const XDataType* p_x,
|
||||
const DyDataType* p_dy,
|
||||
const ScaleDataType* p_scale,
|
||||
const MeanVarDataType* p_savedMean,
|
||||
const MeanVarDataType* p_savedInvVar,
|
||||
double epsilon,
|
||||
const DyElementwiseOp dy_elementwise_op,
|
||||
DxDataType* p_dx,
|
||||
DscaleDbiasDataType* p_dscale,
|
||||
DscaleDbiasDataType* p_dbias)
|
||||
: reduceDims_(reduceDims),
|
||||
bnScaleBiasMeanVarLengths_(bnScaleBiasMeanVarLengths),
|
||||
bnScaleStrides_(bnScaleStrides),
|
||||
bnDscaleDbiasStrides_(bnDscaleDbiasStrides),
|
||||
bnMeanVarStrides_(bnMeanVarStrides),
|
||||
p_x_(p_x),
|
||||
p_dy_(p_dy),
|
||||
p_scale_(p_scale),
|
||||
p_savedMean_(p_savedMean),
|
||||
p_savedInvVar_(p_savedInvVar),
|
||||
dy_elementwise_op_(dy_elementwise_op),
|
||||
p_dx_(p_dx),
|
||||
p_dscale_(p_dscale),
|
||||
p_dbias_(p_dbias)
|
||||
{
|
||||
using ck::host_common::get_index_set;
|
||||
|
||||
if(std::any_of(
|
||||
reduceDims.begin(), reduceDims.end(), [](int d) { return d < 0 || d >= Rank; }))
|
||||
throw std::runtime_error("Invalid reduce dimensions!");
|
||||
|
||||
// get invariant_dims[] and invariant_lengths[]
|
||||
for(int dim = 0, i = 0; dim < Rank; dim++)
|
||||
if(std::none_of(
|
||||
reduceDims.begin(), reduceDims.end(), [&](int d) { return d == dim; }))
|
||||
{
|
||||
invariantDims_[i] = dim;
|
||||
invariant_lengths_[i] = xyLengths[dim];
|
||||
i++;
|
||||
};
|
||||
|
||||
// get reduce_lengths_[]
|
||||
for(int j = 0, i = 0; j < NumBatchNormReduceDim; j++)
|
||||
{
|
||||
int dim = reduceDims[j];
|
||||
reduce_lengths_[i++] = xyLengths[dim];
|
||||
};
|
||||
|
||||
for(int i = 0; i < NumInvariantDim; i++)
|
||||
if(invariant_lengths_[i] != bnScaleBiasMeanVarLengths_[i])
|
||||
throw std::runtime_error("Invalid lengths parameters!");
|
||||
|
||||
for(int j = 0, i = 0; j < NumInvariantDim; j++)
|
||||
{
|
||||
int dim = invariantDims_[j];
|
||||
x_invariant_strides_[i] = xStrides[dim];
|
||||
dy_invariant_strides_[i] = dyStrides[dim];
|
||||
dx_invariant_strides_[i] = dxStrides[dim];
|
||||
i++;
|
||||
};
|
||||
|
||||
for(int j = 0, i = 0; j < NumBatchNormReduceDim; j++)
|
||||
{
|
||||
int dim = reduceDims_[j];
|
||||
x_reduce_strides_[i] = xStrides[dim];
|
||||
dy_reduce_strides_[i] = dyStrides[dim];
|
||||
dx_reduce_strides_[i] = dxStrides[dim];
|
||||
i++;
|
||||
};
|
||||
|
||||
reduceSize_ = std::accumulate(
|
||||
reduce_lengths_.begin(), reduce_lengths_.end(), 1, std::multiplies<size_t>{});
|
||||
|
||||
invariant_index_set_ = get_index_set<NumInvariantDim>(invariant_lengths_);
|
||||
reduce_index_set_ = get_index_set<NumBatchNormReduceDim>(reduce_lengths_);
|
||||
|
||||
epsilon_ = type_convert<AccDataType>(epsilon);
|
||||
|
||||
haveSavedMeanInvVar_ = (p_savedMean != nullptr && p_savedInvVar != nullptr);
|
||||
}
|
||||
|
||||
std::array<int, NumBatchNormReduceDim> reduceDims_;
|
||||
std::array<int, NumInvariantDim> invariantDims_;
|
||||
std::array<index_t, NumInvariantDim> invariant_lengths_;
|
||||
std::array<index_t, NumBatchNormReduceDim> reduce_lengths_;
|
||||
|
||||
const std::array<index_t, NumInvariantDim> bnScaleBiasMeanVarLengths_;
|
||||
const std::array<index_t, NumInvariantDim> bnScaleStrides_;
|
||||
const std::array<index_t, NumInvariantDim> bnDscaleDbiasStrides_;
|
||||
const std::array<index_t, NumInvariantDim> bnMeanVarStrides_;
|
||||
|
||||
std::array<index_t, NumInvariantDim> x_invariant_strides_;
|
||||
std::array<index_t, NumInvariantDim> dy_invariant_strides_;
|
||||
std::array<index_t, NumInvariantDim> dx_invariant_strides_;
|
||||
std::array<index_t, NumBatchNormReduceDim> x_reduce_strides_;
|
||||
std::array<index_t, NumBatchNormReduceDim> dy_reduce_strides_;
|
||||
std::array<index_t, NumBatchNormReduceDim> dx_reduce_strides_;
|
||||
|
||||
const XDataType* p_x_;
|
||||
const DyDataType* p_dy_;
|
||||
const ScaleDataType* p_scale_;
|
||||
const MeanVarDataType* p_savedMean_;
|
||||
const MeanVarDataType* p_savedInvVar_;
|
||||
const DyElementwiseOp dy_elementwise_op_;
|
||||
|
||||
DxDataType* p_dx_;
|
||||
DscaleDbiasDataType* p_dscale_;
|
||||
DscaleDbiasDataType* p_dbias_;
|
||||
|
||||
bool haveSavedMeanInvVar_;
|
||||
|
||||
std::vector<std::array<index_t, NumInvariantDim>> invariant_index_set_;
|
||||
std::vector<std::array<index_t, NumBatchNormReduceDim>> reduce_index_set_;
|
||||
|
||||
AccDataType epsilon_;
|
||||
size_t reduceSize_;
|
||||
};
|
||||
|
||||
struct Invoker : public device::BaseInvoker
|
||||
{
|
||||
float Run(const Argument& arg)
|
||||
{
|
||||
using ck::host_common::get_offset_from_index;
|
||||
|
||||
auto thread_reduce_func = [&](auto invariant_index) {
|
||||
size_t x_invariant_offset = get_offset_from_index<NumInvariantDim>(
|
||||
arg.x_invariant_strides_, invariant_index);
|
||||
size_t dy_invariant_offset = get_offset_from_index<NumInvariantDim>(
|
||||
arg.dy_invariant_strides_, invariant_index);
|
||||
size_t dx_invariant_offset = get_offset_from_index<NumInvariantDim>(
|
||||
arg.dx_invariant_strides_, invariant_index);
|
||||
|
||||
AccDataType mean = type_convert<AccDataType>(0.0f);
|
||||
AccDataType variance = type_convert<AccDataType>(0.0f);
|
||||
AccDataType invVar;
|
||||
int32_t curr_count = 0;
|
||||
|
||||
if(arg.haveSavedMeanInvVar_)
|
||||
{
|
||||
size_t mean_invVar_invariant_offset = get_offset_from_index<NumInvariantDim>(
|
||||
arg.bnMeanVarStrides_, invariant_index);
|
||||
|
||||
mean =
|
||||
type_convert<AccDataType>(arg.p_savedMean_[mean_invVar_invariant_offset]);
|
||||
invVar =
|
||||
type_convert<AccDataType>(arg.p_savedInvVar_[mean_invVar_invariant_offset]);
|
||||
}
|
||||
else
|
||||
{
|
||||
// compute mean, variance using welford method
|
||||
for(const auto& reduce_index : arg.reduce_index_set_)
|
||||
{
|
||||
size_t x_reduce_offset = get_offset_from_index<NumBatchNormReduceDim>(
|
||||
arg.x_reduce_strides_, reduce_index);
|
||||
|
||||
auto x_offset = x_invariant_offset + x_reduce_offset;
|
||||
|
||||
curr_count++;
|
||||
|
||||
AccDataType x = type_convert<AccDataType>(arg.p_x_[x_offset]);
|
||||
|
||||
AccDataType delta = x - mean;
|
||||
|
||||
mean += delta / curr_count;
|
||||
|
||||
AccDataType delta2 = x - mean;
|
||||
|
||||
variance += delta * delta2;
|
||||
};
|
||||
|
||||
// actual variance
|
||||
variance = variance / curr_count;
|
||||
|
||||
// inv-variance defined as 1/sqrt(epsilon+variance)
|
||||
invVar =
|
||||
type_convert<AccDataType>(1.0f) / ck::math::sqrt(arg.epsilon_ + variance);
|
||||
};
|
||||
|
||||
AccDataType dbias =
|
||||
type_convert<AccDataType>(0.0f); // Sum on reduced dimensions of dy
|
||||
AccDataType dscale =
|
||||
type_convert<AccDataType>(0.0f); // Sum on reduced dimensions of dy * norm_x
|
||||
|
||||
// 1) calculate dy * (x - mean) * inv-variance
|
||||
// 2) calculate sum(dy) on reduced dimensions
|
||||
// 3) calculate sum(dy * norm_x) on reduced dimensions
|
||||
for(const auto& reduce_index : arg.reduce_index_set_)
|
||||
{
|
||||
size_t x_reduce_offset = get_offset_from_index<NumBatchNormReduceDim>(
|
||||
arg.x_reduce_strides_, reduce_index);
|
||||
size_t dy_reduce_offset = get_offset_from_index<NumBatchNormReduceDim>(
|
||||
arg.dy_reduce_strides_, reduce_index);
|
||||
|
||||
auto x_offset = x_invariant_offset + x_reduce_offset;
|
||||
auto dy_offset = dy_invariant_offset + dy_reduce_offset;
|
||||
|
||||
AccDataType x = type_convert<AccDataType>(arg.p_x_[x_offset]);
|
||||
|
||||
AccDataType norm_x = (x - mean) * invVar;
|
||||
AccDataType dy = type_convert<AccDataType>(arg.p_dy_[dy_offset]);
|
||||
|
||||
arg.dy_elementwise_op_(dy, dy);
|
||||
|
||||
dbias += dy;
|
||||
dscale += norm_x * dy;
|
||||
};
|
||||
|
||||
size_t dscale_offset = get_offset_from_index<NumInvariantDim>(
|
||||
arg.bnDscaleDbiasStrides_, invariant_index);
|
||||
size_t dbias_offset = get_offset_from_index<NumInvariantDim>(
|
||||
arg.bnDscaleDbiasStrides_, invariant_index);
|
||||
|
||||
arg.p_dscale_[dscale_offset] = type_convert<DscaleDbiasDataType>(dscale);
|
||||
arg.p_dbias_[dbias_offset] = type_convert<DscaleDbiasDataType>(dbias);
|
||||
|
||||
size_t scale_offset =
|
||||
get_offset_from_index<NumInvariantDim>(arg.bnScaleStrides_, invariant_index);
|
||||
|
||||
AccDataType scale = type_convert<AccDataType>(arg.p_scale_[scale_offset]);
|
||||
|
||||
AccDataType multiplier = type_convert<AccDataType>(1.0f) /
|
||||
type_convert<AccDataType>(arg.reduceSize_) * invVar *
|
||||
scale;
|
||||
|
||||
// 1) calculate tmp = dscale * (x - mean) * inv-variance
|
||||
// 2) calculate dx = 1/reduceSize * inv-variance * scale * (reduceSize * dy - dbias
|
||||
// - tmp)
|
||||
for(const auto& reduce_index : arg.reduce_index_set_)
|
||||
{
|
||||
size_t x_reduce_offset = get_offset_from_index<NumBatchNormReduceDim>(
|
||||
arg.x_reduce_strides_, reduce_index);
|
||||
size_t dy_reduce_offset = get_offset_from_index<NumBatchNormReduceDim>(
|
||||
arg.dy_reduce_strides_, reduce_index);
|
||||
size_t dx_reduce_offset = get_offset_from_index<NumBatchNormReduceDim>(
|
||||
arg.dx_reduce_strides_, reduce_index);
|
||||
|
||||
auto x_offset = x_invariant_offset + x_reduce_offset;
|
||||
auto dy_offset = dy_invariant_offset + dy_reduce_offset;
|
||||
auto dx_offset = dx_invariant_offset + dx_reduce_offset;
|
||||
|
||||
AccDataType x = type_convert<AccDataType>(arg.p_x_[x_offset]);
|
||||
|
||||
AccDataType norm_x = (x - mean) * invVar;
|
||||
AccDataType dy = type_convert<AccDataType>(arg.p_dy_[dy_offset]);
|
||||
|
||||
arg.dy_elementwise_op_(dy, dy);
|
||||
|
||||
AccDataType tmpVal = norm_x * dscale;
|
||||
|
||||
AccDataType dx = multiplier * (type_convert<AccDataType>(arg.reduceSize_) * dy -
|
||||
dbias - tmpVal);
|
||||
|
||||
arg.p_dx_[dx_offset] = type_convert<DxDataType>(dx);
|
||||
};
|
||||
};
|
||||
|
||||
std::size_t num_thread = std::thread::hardware_concurrency();
|
||||
std::size_t work_per_thread =
|
||||
(arg.invariant_index_set_.size() + num_thread - 1) / num_thread;
|
||||
|
||||
std::vector<joinable_thread> threads(num_thread);
|
||||
|
||||
for(std::size_t it = 0; it < num_thread; ++it)
|
||||
{
|
||||
std::size_t i_begin = it * work_per_thread;
|
||||
std::size_t i_end = std::min(static_cast<size_t>((it + 1) * work_per_thread),
|
||||
arg.invariant_index_set_.size());
|
||||
|
||||
auto f = [=] {
|
||||
for(std::size_t i = i_begin; i < i_end; ++i)
|
||||
{
|
||||
thread_reduce_func(arg.invariant_index_set_[i]);
|
||||
}
|
||||
};
|
||||
|
||||
threads[it] = joinable_thread(f);
|
||||
}
|
||||
|
||||
return (0.0f);
|
||||
};
|
||||
|
||||
float Run(const device::BaseArgument* p_arg,
|
||||
const StreamConfig& /*stream_config*/ = StreamConfig{}) override
|
||||
{
|
||||
return Run(*dynamic_cast<const Argument*>(p_arg));
|
||||
};
|
||||
};
|
||||
|
||||
bool IsSupportedArgument(const device::BaseArgument* p_arg) override
|
||||
{
|
||||
(void)p_arg;
|
||||
|
||||
return (true);
|
||||
};
|
||||
|
||||
std::unique_ptr<device::BaseArgument>
|
||||
MakeArgumentPointer(const std::array<index_t, Rank> xyLengths,
|
||||
const std::array<index_t, Rank> xStrides,
|
||||
const std::array<index_t, Rank> dxStrides,
|
||||
const std::array<index_t, Rank> dyStrides,
|
||||
const std::array<int, NumBatchNormReduceDim> reduceDims,
|
||||
const std::array<index_t, NumInvariantDim> bnScaleBiasMeanVarLengths,
|
||||
const std::array<index_t, NumInvariantDim> bnScaleStrides,
|
||||
const std::array<index_t, NumInvariantDim> bnDscaleDbiasStrides,
|
||||
const std::array<index_t, NumInvariantDim> bnMeanVarStrides,
|
||||
const void* p_x,
|
||||
const void* p_dy,
|
||||
const void* p_scale,
|
||||
const void* p_savedMean,
|
||||
const void* p_savedInvVar,
|
||||
double epsilon,
|
||||
const DyElementwiseOp dy_elementwise_op,
|
||||
void* p_dx,
|
||||
void* p_dscale,
|
||||
void* p_dbias) override
|
||||
{
|
||||
return std::make_unique<Argument>(xyLengths,
|
||||
xStrides,
|
||||
dxStrides,
|
||||
dyStrides,
|
||||
reduceDims,
|
||||
bnScaleBiasMeanVarLengths,
|
||||
bnScaleStrides,
|
||||
bnDscaleDbiasStrides,
|
||||
bnMeanVarStrides,
|
||||
static_cast<const XDataType*>(p_x),
|
||||
static_cast<const DyDataType*>(p_dy),
|
||||
static_cast<const ScaleDataType*>(p_scale),
|
||||
static_cast<const MeanVarDataType*>(p_savedMean),
|
||||
static_cast<const MeanVarDataType*>(p_savedInvVar),
|
||||
epsilon,
|
||||
dy_elementwise_op,
|
||||
static_cast<DxDataType*>(p_dx),
|
||||
static_cast<DscaleDbiasDataType*>(p_dscale),
|
||||
static_cast<DscaleDbiasDataType*>(p_dbias));
|
||||
};
|
||||
|
||||
std::unique_ptr<device::BaseInvoker> MakeInvokerPointer() override
|
||||
{
|
||||
return std::make_unique<Invoker>();
|
||||
};
|
||||
|
||||
std::string GetTypeString() const override
|
||||
{
|
||||
auto str = std::stringstream();
|
||||
|
||||
// clang-format off
|
||||
str << "Reference_BatchNorm_Backward" << std::endl;
|
||||
// clang-format on
|
||||
|
||||
return str.str();
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace host
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -1,319 +0,0 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <iostream>
|
||||
#include <sstream>
|
||||
#include <algorithm>
|
||||
|
||||
#include "ck/tensor_operation/gpu/device/device_batchnorm_backward.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace host {
|
||||
|
||||
template <typename XDataType,
|
||||
typename DyDataType,
|
||||
typename DxDataType,
|
||||
typename AccDataType,
|
||||
typename ScaleDataType,
|
||||
typename BiasDataType,
|
||||
typename MeanVarDataType,
|
||||
typename DyElementwiseOp>
|
||||
struct ReferenceBatchNormBwd_Input_N_H_W_C_Output_C
|
||||
: public device::DeviceBatchNormBwd<4, 3, DyElementwiseOp>
|
||||
{
|
||||
struct Argument : public device::BaseArgument
|
||||
{
|
||||
Argument(const std::array<index_t, 4> xyLengths,
|
||||
const std::array<index_t, 4> xStrides,
|
||||
const std::array<index_t, 4> dyStrides,
|
||||
const std::array<index_t, 4> dxStrides,
|
||||
const std::array<int, 3> reduceDims,
|
||||
const std::array<ck::index_t, 1> bnScaleBiasMeanVarLengths,
|
||||
const std::array<ck::index_t, 1> bnScaleStrides,
|
||||
const std::array<ck::index_t, 1> bnBiasStrides,
|
||||
const std::array<ck::index_t, 1> bnMeanVarStrides,
|
||||
const XDataType* p_x,
|
||||
const DyDataType* p_dy,
|
||||
const ScaleDataType* p_scale,
|
||||
const MeanVarDataType* p_savedMean,
|
||||
const MeanVarDataType* p_savedInvVar,
|
||||
double epsilon,
|
||||
const DyElementwiseOp dy_elementwise_op,
|
||||
DxDataType* p_dx,
|
||||
ScaleDataType* p_dscale,
|
||||
BiasDataType* p_dbias)
|
||||
: p_x_(p_x),
|
||||
p_dy_(p_dy),
|
||||
p_scale_(p_scale),
|
||||
p_savedMean_(p_savedMean),
|
||||
p_savedInvVar_(p_savedInvVar),
|
||||
epsilon_(epsilon),
|
||||
dy_elementwise_op_(dy_elementwise_op),
|
||||
p_dx_(p_dx),
|
||||
p_dscale_(p_dscale),
|
||||
p_dbias_(p_dbias)
|
||||
{
|
||||
ignore = xStrides;
|
||||
ignore = dyStrides;
|
||||
ignore = dxStrides;
|
||||
ignore = bnScaleStrides;
|
||||
ignore = bnBiasStrides;
|
||||
ignore = bnMeanVarStrides;
|
||||
|
||||
if(xyLengths.size() != 4 || bnScaleBiasMeanVarLengths.size() != 1 ||
|
||||
bnScaleBiasMeanVarLengths[0] != xyLengths[3])
|
||||
throw std::runtime_error("Invalid tensor dimensions!");
|
||||
|
||||
if(reduceDims[0] != 0 || reduceDims[1] != 1 || reduceDims[2] != 2)
|
||||
throw std::runtime_error("Invalid reduce dimensions!");
|
||||
|
||||
n_ = xyLengths[0];
|
||||
h_ = xyLengths[1];
|
||||
w_ = xyLengths[2];
|
||||
c_ = xyLengths[3];
|
||||
|
||||
haveSavedMeanInvVar_ = (p_savedMean != nullptr && p_savedInvVar != nullptr);
|
||||
}
|
||||
|
||||
const XDataType* p_x_;
|
||||
const DyDataType* p_dy_;
|
||||
const ScaleDataType* p_scale_;
|
||||
const MeanVarDataType* p_savedMean_;
|
||||
const MeanVarDataType* p_savedInvVar_;
|
||||
|
||||
double epsilon_;
|
||||
const DyElementwiseOp dy_elementwise_op_;
|
||||
|
||||
DxDataType* p_dx_;
|
||||
ScaleDataType* p_dscale_;
|
||||
BiasDataType* p_dbias_;
|
||||
|
||||
bool haveSavedMeanInvVar_;
|
||||
|
||||
index_t n_, h_, w_, c_;
|
||||
};
|
||||
|
||||
struct Invoker : public device::BaseInvoker
|
||||
{
|
||||
float Run(const Argument& arg)
|
||||
{
|
||||
auto thread_reduce_func = [&](auto iC) {
|
||||
AccDataType reduceSize = type_convert<AccDataType>(arg.n_) *
|
||||
type_convert<AccDataType>(arg.h_) *
|
||||
type_convert<AccDataType>(arg.w_);
|
||||
index_t offset_C = iC;
|
||||
AccDataType mean;
|
||||
AccDataType invVar;
|
||||
|
||||
if(arg.haveSavedMeanInvVar_)
|
||||
{
|
||||
mean = arg.p_savedMean_[offset_C];
|
||||
invVar = arg.p_savedInvVar_[offset_C];
|
||||
}
|
||||
else
|
||||
{
|
||||
AccDataType meansquare;
|
||||
|
||||
meansquare = type_convert<AccDataType>(0.0f);
|
||||
mean = type_convert<AccDataType>(0.0f);
|
||||
|
||||
// compute mean, meanquare, variance, inv-variance
|
||||
for(index_t iN = 0; iN < arg.n_; iN++)
|
||||
{
|
||||
index_t offset_N = iN * arg.h_ * arg.w_ * arg.c_;
|
||||
for(index_t iH = 0; iH < arg.h_; iH++)
|
||||
{
|
||||
index_t offset_H = iH * arg.w_ * arg.c_;
|
||||
for(index_t iW = 0; iW < arg.w_; iW++)
|
||||
{
|
||||
index_t offset_W = iW * arg.c_;
|
||||
|
||||
auto offset = offset_N + offset_H + offset_W + offset_C;
|
||||
|
||||
AccDataType x = type_convert<AccDataType>(arg.p_x_[offset]);
|
||||
|
||||
mean += x;
|
||||
meansquare += x * x;
|
||||
};
|
||||
}
|
||||
};
|
||||
|
||||
mean = mean / reduceSize;
|
||||
meansquare = meansquare / reduceSize;
|
||||
|
||||
AccDataType variance = meansquare - mean * mean;
|
||||
invVar = type_convert<AccDataType>(1.0f) /
|
||||
std::sqrt(type_convert<AccDataType>(arg.epsilon_) + variance);
|
||||
};
|
||||
|
||||
AccDataType dbias = type_convert<AccDataType>(0.0f); // Sum on NHW of dy
|
||||
AccDataType dscale = type_convert<AccDataType>(0.0f); // Sum on NHW of dy * norm_x
|
||||
|
||||
// 1) calculate dy * (x - mean) * inv-variance
|
||||
// 2) calculate sum(dy) on NHW dimensions
|
||||
// 3) calculate sum(dy * norm_x) on NHW dimensions
|
||||
for(index_t iN = 0; iN < arg.n_; iN++)
|
||||
{
|
||||
index_t offset_N = iN * arg.h_ * arg.w_ * arg.c_;
|
||||
for(index_t iH = 0; iH < arg.h_; iH++)
|
||||
{
|
||||
index_t offset_H = iH * arg.w_ * arg.c_;
|
||||
for(index_t iW = 0; iW < arg.w_; iW++)
|
||||
{
|
||||
index_t offset_W = iW * arg.c_;
|
||||
|
||||
auto offset = offset_N + offset_H + offset_W + offset_C;
|
||||
|
||||
AccDataType x = type_convert<AccDataType>(arg.p_x_[offset]);
|
||||
|
||||
AccDataType norm_x = (x - mean) * invVar;
|
||||
AccDataType dy = type_convert<AccDataType>(arg.p_dy_[offset]);
|
||||
|
||||
arg.dy_elementwise_op_(dy, dy);
|
||||
|
||||
dbias += dy;
|
||||
dscale += norm_x * dy;
|
||||
};
|
||||
}
|
||||
};
|
||||
|
||||
arg.p_dscale_[offset_C] = type_convert<ScaleDataType>(dscale);
|
||||
arg.p_dbias_[offset_C] = type_convert<BiasDataType>(dbias);
|
||||
|
||||
AccDataType scale = type_convert<AccDataType>(arg.p_scale_[offset_C]);
|
||||
AccDataType multiplier =
|
||||
type_convert<AccDataType>(1.0f) / reduceSize * invVar * scale;
|
||||
|
||||
// 1) calculate tmp = dscale * (x - mean) * inv-variance
|
||||
// 2) calculate dx = 1/nhw * inv-variance * scale * (nhw * dy - dbias - tmp)
|
||||
for(index_t iN = 0; iN < arg.n_; iN++)
|
||||
{
|
||||
index_t offset_N = iN * arg.h_ * arg.w_ * arg.c_;
|
||||
for(index_t iH = 0; iH < arg.h_; iH++)
|
||||
{
|
||||
index_t offset_H = iH * arg.w_ * arg.c_;
|
||||
for(index_t iW = 0; iW < arg.w_; iW++)
|
||||
{
|
||||
index_t offset_W = iW * arg.c_;
|
||||
|
||||
auto offset = offset_N + offset_H + offset_W + offset_C;
|
||||
|
||||
AccDataType x = type_convert<AccDataType>(arg.p_x_[offset]);
|
||||
|
||||
AccDataType norm_x = (x - mean) * invVar;
|
||||
AccDataType dy = type_convert<AccDataType>(arg.p_dy_[offset]);
|
||||
|
||||
arg.dy_elementwise_op_(dy, dy);
|
||||
|
||||
AccDataType tmpVal = norm_x * dscale;
|
||||
|
||||
AccDataType dx = multiplier * (reduceSize * dy - dbias - tmpVal);
|
||||
|
||||
arg.p_dx_[offset] = type_convert<XDataType>(dx);
|
||||
};
|
||||
}
|
||||
};
|
||||
};
|
||||
|
||||
std::size_t num_thread = std::thread::hardware_concurrency();
|
||||
std::size_t work_per_thread = (arg.c_ + num_thread - 1) / num_thread;
|
||||
|
||||
std::vector<joinable_thread> threads(num_thread);
|
||||
|
||||
for(std::size_t it = 0; it < num_thread; ++it)
|
||||
{
|
||||
std::size_t ic_begin = it * work_per_thread;
|
||||
std::size_t ic_end = std::min(static_cast<int>((it + 1) * work_per_thread), arg.c_);
|
||||
|
||||
auto f = [=] {
|
||||
for(std::size_t ic = ic_begin; ic < ic_end; ++ic)
|
||||
{
|
||||
thread_reduce_func(ic);
|
||||
}
|
||||
};
|
||||
|
||||
threads[it] = joinable_thread(f);
|
||||
}
|
||||
|
||||
return (0.0f);
|
||||
};
|
||||
|
||||
float Run(const device::BaseArgument* p_arg,
|
||||
const StreamConfig& /*stream_config*/ = StreamConfig{}) override
|
||||
{
|
||||
return Run(*dynamic_cast<const Argument*>(p_arg));
|
||||
};
|
||||
};
|
||||
|
||||
bool IsSupportedArgument(const device::BaseArgument* p_arg) override
|
||||
{
|
||||
(void)p_arg;
|
||||
|
||||
return (true);
|
||||
};
|
||||
|
||||
std::unique_ptr<device::BaseArgument>
|
||||
MakeArgumentPointer(const std::array<index_t, 4> xyLengths,
|
||||
const std::array<index_t, 4> xStrides,
|
||||
const std::array<index_t, 4> dyStrides,
|
||||
const std::array<index_t, 4> dxStrides,
|
||||
const std::array<int, 3> reduceDims,
|
||||
const std::array<ck::index_t, 1> bnScaleBiasMeanVarLengths,
|
||||
const std::array<ck::index_t, 1> bnScaleStrides,
|
||||
const std::array<ck::index_t, 1> bnBiasStrides,
|
||||
const std::array<ck::index_t, 1> bnMeanVarStrides,
|
||||
const void* p_x,
|
||||
const void* p_dy,
|
||||
const void* p_scale,
|
||||
const void* p_savedMean,
|
||||
const void* p_savedInvVar,
|
||||
double epsilon,
|
||||
const DyElementwiseOp dy_elementwise_op,
|
||||
void* p_dx,
|
||||
void* p_dscale,
|
||||
void* p_dbias) override
|
||||
{
|
||||
return std::make_unique<Argument>(xyLengths,
|
||||
xStrides,
|
||||
dyStrides,
|
||||
dxStrides,
|
||||
reduceDims,
|
||||
bnScaleBiasMeanVarLengths,
|
||||
bnScaleStrides,
|
||||
bnBiasStrides,
|
||||
bnMeanVarStrides,
|
||||
static_cast<const XDataType*>(p_x),
|
||||
static_cast<const DyDataType*>(p_dy),
|
||||
static_cast<const ScaleDataType*>(p_scale),
|
||||
static_cast<const MeanVarDataType*>(p_savedMean),
|
||||
static_cast<const MeanVarDataType*>(p_savedInvVar),
|
||||
epsilon,
|
||||
dy_elementwise_op,
|
||||
static_cast<DxDataType*>(p_dx),
|
||||
static_cast<ScaleDataType*>(p_dscale),
|
||||
static_cast<BiasDataType*>(p_dbias));
|
||||
};
|
||||
|
||||
std::unique_ptr<device::BaseInvoker> MakeInvokerPointer() override
|
||||
{
|
||||
return std::make_unique<Invoker>();
|
||||
};
|
||||
|
||||
std::string GetTypeString() const override
|
||||
{
|
||||
auto str = std::stringstream();
|
||||
|
||||
// clang-format off
|
||||
str << "Reference_BatchNorm_Backward_NHWC_C<" << std::endl;
|
||||
// clang-format on
|
||||
|
||||
return str.str();
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace host
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,124 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <cstdlib>
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/device_batchnorm_backward.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
|
||||
|
||||
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
// FP16
|
||||
void add_device_batchnorm_backward_rank_4_3_f16_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceBatchNormBwd<F16, F32, F32, F32, F16, F32, F32, PassThrough, 4, 3>>>&);
|
||||
|
||||
// FP32
|
||||
void add_device_batchnorm_backward_rank_4_3_f32_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceBatchNormBwd<F32, F32, F32, F32, F32, F32, F32, PassThrough, 4, 3>>>&);
|
||||
|
||||
// BF16
|
||||
void add_device_batchnorm_backward_rank_4_3_bf16_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceBatchNormBwd<BF16, F32, F32, F32, BF16, F32, F32, PassThrough, 4, 3>>>&);
|
||||
|
||||
// FP64
|
||||
void add_device_batchnorm_backward_rank_4_3_f64_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceBatchNormBwd<F64, F64, F64, F64, F64, F64, F64, PassThrough, 4, 3>>>&);
|
||||
|
||||
template <typename XDataType,
|
||||
typename DxDataType,
|
||||
typename DyDataType,
|
||||
typename AccDataType,
|
||||
typename ScaleDataType,
|
||||
typename DscaleDbiasDataType,
|
||||
typename MeanVarDataType,
|
||||
typename DyElementwiseOp,
|
||||
index_t Rank,
|
||||
index_t NumReduceDim>
|
||||
struct DeviceOperationInstanceFactory<
|
||||
ck::tensor_operation::device::DeviceBatchNormBwd<XDataType,
|
||||
DxDataType,
|
||||
DyDataType,
|
||||
AccDataType,
|
||||
ScaleDataType,
|
||||
DscaleDbiasDataType,
|
||||
MeanVarDataType,
|
||||
DyElementwiseOp,
|
||||
Rank,
|
||||
NumReduceDim>>
|
||||
{
|
||||
using DeviceOp = DeviceBatchNormBwd<XDataType,
|
||||
DxDataType,
|
||||
DyDataType,
|
||||
AccDataType,
|
||||
ScaleDataType,
|
||||
DscaleDbiasDataType,
|
||||
MeanVarDataType,
|
||||
DyElementwiseOp,
|
||||
Rank,
|
||||
NumReduceDim>;
|
||||
|
||||
static auto GetInstances()
|
||||
{
|
||||
std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
|
||||
|
||||
if constexpr(is_same_v<XDataType, F16> && is_same_v<DxDataType, F32> &&
|
||||
is_same_v<DyDataType, F32> && is_same_v<AccDataType, F32> &&
|
||||
is_same_v<ScaleDataType, F16> && is_same_v<DscaleDbiasDataType, F32> &&
|
||||
is_same_v<MeanVarDataType, F32>)
|
||||
{
|
||||
if constexpr(Rank == 4 && NumReduceDim == 3 && is_same_v<DyElementwiseOp, PassThrough>)
|
||||
{
|
||||
add_device_batchnorm_backward_rank_4_3_f16_instances(op_ptrs);
|
||||
}
|
||||
}
|
||||
else if constexpr(is_same_v<XDataType, F32> && is_same_v<DxDataType, F32> &&
|
||||
is_same_v<DyDataType, F32> && is_same_v<AccDataType, F32> &&
|
||||
is_same_v<ScaleDataType, F32> && is_same_v<DscaleDbiasDataType, F32> &&
|
||||
is_same_v<MeanVarDataType, F32>)
|
||||
{
|
||||
if constexpr(Rank == 4 && NumReduceDim == 3 && is_same_v<DyElementwiseOp, PassThrough>)
|
||||
{
|
||||
add_device_batchnorm_backward_rank_4_3_f32_instances(op_ptrs);
|
||||
}
|
||||
}
|
||||
else if constexpr(is_same_v<XDataType, BF16> && is_same_v<DxDataType, F32> &&
|
||||
is_same_v<DyDataType, F32> && is_same_v<AccDataType, F32> &&
|
||||
is_same_v<ScaleDataType, BF16> && is_same_v<DscaleDbiasDataType, F32> &&
|
||||
is_same_v<MeanVarDataType, F32>)
|
||||
{
|
||||
if constexpr(Rank == 4 && NumReduceDim == 3 && is_same_v<DyElementwiseOp, PassThrough>)
|
||||
{
|
||||
add_device_batchnorm_backward_rank_4_3_bf16_instances(op_ptrs);
|
||||
}
|
||||
}
|
||||
else if constexpr(is_same_v<XDataType, F64> && is_same_v<DxDataType, F64> &&
|
||||
is_same_v<DyDataType, F64> && is_same_v<AccDataType, F64> &&
|
||||
is_same_v<ScaleDataType, F64> && is_same_v<DscaleDbiasDataType, F64> &&
|
||||
is_same_v<MeanVarDataType, F64>)
|
||||
{
|
||||
if constexpr(Rank == 4 && NumReduceDim == 3 && is_same_v<DyElementwiseOp, PassThrough>)
|
||||
{
|
||||
add_device_batchnorm_backward_rank_4_3_f64_instances(op_ptrs);
|
||||
}
|
||||
}
|
||||
|
||||
return op_ptrs;
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -3,4 +3,8 @@ add_instance_library(device_batchnorm_instance
|
||||
device_batchnorm_forward_f32_instance.cpp
|
||||
device_batchnorm_forward_bf16_instance.cpp
|
||||
device_batchnorm_forward_f64_instance.cpp
|
||||
device_batchnorm_backward_f16_instance.cpp
|
||||
device_batchnorm_backward_f32_instance.cpp
|
||||
device_batchnorm_backward_bf16_instance.cpp
|
||||
device_batchnorm_backward_f64_instance.cpp
|
||||
)
|
||||
|
||||
@@ -0,0 +1,146 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_batchnorm_backward_impl.hpp"
|
||||
#include "ck/utility/data_type.hpp"
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
using BF16 = ck::bhalf_t;
|
||||
using F32 = float;
|
||||
|
||||
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
|
||||
|
||||
// clang-format off
|
||||
template <index_t Rank, index_t NumReduceDim, typename DyElementwiseOp>
|
||||
using device_batchnorm_backward_bf16_blockwise_instances =
|
||||
std::tuple <
|
||||
// XDataType, DxDataType, DyDataType, AccDataType, ScaleDataType, DscaleDbiasDataType, MeanVarDataType, DyElementwiseOp, Rank, NumReduceDim, UseMultiBlockInK, BLockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, XDyDxVectorDim, XSrcVectorSize, DySrcVectorSize, DxDstVectorSize, ScaleSrcVectorSize, DscaleDbiasDstVectorSize, MeanVarSrcVectorSize
|
||||
DeviceBatchNormBwdImpl<BF16, F32, F32, F32, BF16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 128, 2, 2, 2, 0, 2, 2, 2, 2, 2, 2>,
|
||||
DeviceBatchNormBwdImpl<BF16, F32, F32, F32, BF16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 128, 2, 2, 2, 0, 1, 1, 1, 2, 2, 2>,
|
||||
DeviceBatchNormBwdImpl<BF16, F32, F32, F32, BF16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 128, 2, 2, 2, 0, 1, 1, 1, 1, 1, 1>,
|
||||
DeviceBatchNormBwdImpl<BF16, F32, F32, F32, BF16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 128, 2, 2, 2, 0, 2, 2, 2, 1, 1, 1>,
|
||||
DeviceBatchNormBwdImpl<BF16, F32, F32, F32, BF16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 128, 2, 2, 2, 1, 1, 1, 1, 2, 2, 2>,
|
||||
DeviceBatchNormBwdImpl<BF16, F32, F32, F32, BF16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 128, 2, 2, 2, 1, 1, 1, 1, 1, 1, 1>,
|
||||
DeviceBatchNormBwdImpl<BF16, F32, F32, F32, BF16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 64, 4, 2, 2, 0, 2, 2, 2, 2, 2, 2>,
|
||||
DeviceBatchNormBwdImpl<BF16, F32, F32, F32, BF16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 64, 4, 2, 2, 0, 1, 1, 1, 2, 2, 2>,
|
||||
DeviceBatchNormBwdImpl<BF16, F32, F32, F32, BF16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 64, 4, 2, 2, 0, 1, 1, 1, 1, 1, 1>,
|
||||
DeviceBatchNormBwdImpl<BF16, F32, F32, F32, BF16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 64, 4, 2, 2, 0, 2, 2, 2, 1, 1, 1>,
|
||||
DeviceBatchNormBwdImpl<BF16, F32, F32, F32, BF16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 64, 4, 2, 2, 1, 1, 1, 1, 2, 2, 2>,
|
||||
DeviceBatchNormBwdImpl<BF16, F32, F32, F32, BF16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 64, 4, 2, 2, 1, 1, 1, 1, 1, 1, 1>,
|
||||
DeviceBatchNormBwdImpl<BF16, F32, F32, F32, BF16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 32, 8, 2, 2, 0, 2, 2, 2, 2, 2, 2>,
|
||||
DeviceBatchNormBwdImpl<BF16, F32, F32, F32, BF16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 32, 8, 2, 2, 0, 1, 1, 1, 2, 2, 2>,
|
||||
DeviceBatchNormBwdImpl<BF16, F32, F32, F32, BF16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 32, 8, 2, 2, 0, 1, 1, 1, 1, 1, 1>,
|
||||
DeviceBatchNormBwdImpl<BF16, F32, F32, F32, BF16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 32, 8, 2, 2, 0, 2, 2, 2, 1, 1, 1>,
|
||||
DeviceBatchNormBwdImpl<BF16, F32, F32, F32, BF16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 32, 8, 2, 2, 1, 1, 1, 1, 2, 2, 2>,
|
||||
DeviceBatchNormBwdImpl<BF16, F32, F32, F32, BF16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 32, 8, 2, 2, 1, 1, 1, 1, 1, 1, 1>,
|
||||
DeviceBatchNormBwdImpl<BF16, F32, F32, F32, BF16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 16, 16, 2, 2, 0, 2, 2, 2, 2, 2, 2>,
|
||||
DeviceBatchNormBwdImpl<BF16, F32, F32, F32, BF16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 16, 16, 2, 2, 0, 1, 1, 1, 2, 2, 2>,
|
||||
DeviceBatchNormBwdImpl<BF16, F32, F32, F32, BF16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 16, 16, 2, 2, 0, 1, 1, 1, 1, 1, 1>,
|
||||
DeviceBatchNormBwdImpl<BF16, F32, F32, F32, BF16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 16, 16, 2, 2, 0, 2, 2, 2, 1, 1, 1>,
|
||||
DeviceBatchNormBwdImpl<BF16, F32, F32, F32, BF16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 16, 16, 2, 2, 1, 1, 1, 1, 2, 2, 2>,
|
||||
DeviceBatchNormBwdImpl<BF16, F32, F32, F32, BF16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 16, 16, 2, 2, 1, 1, 1, 1, 1, 1, 1>,
|
||||
DeviceBatchNormBwdImpl<BF16, F32, F32, F32, BF16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 8, 32, 2, 2, 0, 2, 2, 2, 2, 2, 2>,
|
||||
DeviceBatchNormBwdImpl<BF16, F32, F32, F32, BF16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 8, 32, 2, 2, 0, 1, 1, 1, 2, 2, 2>,
|
||||
DeviceBatchNormBwdImpl<BF16, F32, F32, F32, BF16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 8, 32, 2, 2, 0, 1, 1, 1, 1, 1, 1>,
|
||||
DeviceBatchNormBwdImpl<BF16, F32, F32, F32, BF16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 8, 32, 2, 2, 0, 2, 2, 2, 1, 1, 1>,
|
||||
DeviceBatchNormBwdImpl<BF16, F32, F32, F32, BF16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 8, 32, 2, 2, 1, 1, 1, 1, 2, 2, 2>,
|
||||
DeviceBatchNormBwdImpl<BF16, F32, F32, F32, BF16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 8, 32, 2, 2, 1, 1, 1, 1, 1, 1, 1>,
|
||||
DeviceBatchNormBwdImpl<BF16, F32, F32, F32, BF16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 4, 64, 2, 2, 0, 2, 2, 2, 2, 2, 2>,
|
||||
DeviceBatchNormBwdImpl<BF16, F32, F32, F32, BF16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 4, 64, 2, 2, 0, 1, 1, 1, 2, 2, 2>,
|
||||
DeviceBatchNormBwdImpl<BF16, F32, F32, F32, BF16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 4, 64, 2, 2, 0, 1, 1, 1, 1, 1, 1>,
|
||||
DeviceBatchNormBwdImpl<BF16, F32, F32, F32, BF16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 4, 64, 2, 2, 0, 2, 2, 2, 1, 1, 1>,
|
||||
DeviceBatchNormBwdImpl<BF16, F32, F32, F32, BF16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 4, 64, 2, 2, 1, 1, 1, 1, 2, 2, 2>,
|
||||
DeviceBatchNormBwdImpl<BF16, F32, F32, F32, BF16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 4, 64, 2, 2, 1, 1, 1, 1, 1, 1, 1>,
|
||||
DeviceBatchNormBwdImpl<BF16, F32, F32, F32, BF16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 2, 128, 2, 2, 0, 2, 2, 2, 2, 2, 2>,
|
||||
DeviceBatchNormBwdImpl<BF16, F32, F32, F32, BF16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 2, 128, 2, 2, 0, 1, 1, 1, 2, 2, 2>,
|
||||
DeviceBatchNormBwdImpl<BF16, F32, F32, F32, BF16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 2, 128, 2, 2, 0, 1, 1, 1, 1, 1, 1>,
|
||||
DeviceBatchNormBwdImpl<BF16, F32, F32, F32, BF16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 2, 128, 2, 2, 0, 2, 2, 2, 1, 1, 1>,
|
||||
DeviceBatchNormBwdImpl<BF16, F32, F32, F32, BF16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 2, 128, 2, 2, 1, 1, 1, 1, 2, 2, 2>,
|
||||
DeviceBatchNormBwdImpl<BF16, F32, F32, F32, BF16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 2, 128, 2, 2, 1, 1, 1, 1, 1, 1, 1>,
|
||||
DeviceBatchNormBwdImpl<BF16, F32, F32, F32, BF16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 1, 256, 2, 2, 0, 2, 2, 2, 2, 2, 2>,
|
||||
DeviceBatchNormBwdImpl<BF16, F32, F32, F32, BF16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 1, 256, 2, 2, 0, 1, 1, 1, 2, 2, 2>,
|
||||
DeviceBatchNormBwdImpl<BF16, F32, F32, F32, BF16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 1, 256, 2, 2, 0, 1, 1, 1, 1, 1, 1>,
|
||||
DeviceBatchNormBwdImpl<BF16, F32, F32, F32, BF16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 1, 256, 2, 2, 0, 2, 2, 2, 1, 1, 1>,
|
||||
DeviceBatchNormBwdImpl<BF16, F32, F32, F32, BF16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 1, 256, 2, 2, 1, 1, 1, 1, 2, 2, 2>,
|
||||
DeviceBatchNormBwdImpl<BF16, F32, F32, F32, BF16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 1, 256, 2, 2, 1, 1, 1, 1, 1, 1, 1>
|
||||
>;
|
||||
// clang-format on
|
||||
|
||||
// clang-format off
|
||||
template <index_t Rank, index_t NumReduceDim, typename DyElementwiseOp>
|
||||
using device_batchnorm_backward_bf16_multiblock_instances =
|
||||
std::tuple <
|
||||
// XDataType, DxDataType, DyDataType, AccDataType, ScaleDataType, BiasDataType, MeanVarDataType, DyElementwiseOp, Rank, NumReduceDim, UseMultiBlockInK, BLockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, XDyDxVectorDim, XSrcVectorSize, DySrcVectorSize, DxDstVectorSize, ScaleSrcDstVectorSize, BiasDstVectorSize, MeanVarSrcVectorSize
|
||||
DeviceBatchNormBwdImpl<BF16, F32, F32, F32, BF16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 128, 2, 2, 2, 0, 2, 2, 2, 2, 2, 2>,
|
||||
DeviceBatchNormBwdImpl<BF16, F32, F32, F32, BF16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 128, 2, 2, 2, 0, 1, 1, 1, 2, 2, 2>,
|
||||
DeviceBatchNormBwdImpl<BF16, F32, F32, F32, BF16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 128, 2, 2, 2, 0, 1, 1, 1, 1, 1, 1>,
|
||||
DeviceBatchNormBwdImpl<BF16, F32, F32, F32, BF16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 128, 2, 2, 2, 0, 2, 2, 2, 1, 1, 1>,
|
||||
DeviceBatchNormBwdImpl<BF16, F32, F32, F32, BF16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 128, 2, 2, 2, 1, 1, 1, 1, 2, 2, 2>,
|
||||
DeviceBatchNormBwdImpl<BF16, F32, F32, F32, BF16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 128, 2, 2, 2, 1, 1, 1, 1, 1, 1, 1>,
|
||||
DeviceBatchNormBwdImpl<BF16, F32, F32, F32, BF16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 64, 4, 2, 2, 0, 2, 2, 2, 2, 2, 2>,
|
||||
DeviceBatchNormBwdImpl<BF16, F32, F32, F32, BF16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 64, 4, 2, 2, 0, 1, 1, 1, 2, 2, 2>,
|
||||
DeviceBatchNormBwdImpl<BF16, F32, F32, F32, BF16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 64, 4, 2, 2, 0, 1, 1, 1, 1, 1, 1>,
|
||||
DeviceBatchNormBwdImpl<BF16, F32, F32, F32, BF16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 64, 4, 2, 2, 0, 2, 2, 2, 1, 1, 1>,
|
||||
DeviceBatchNormBwdImpl<BF16, F32, F32, F32, BF16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 64, 4, 2, 2, 1, 1, 1, 1, 2, 2, 2>,
|
||||
DeviceBatchNormBwdImpl<BF16, F32, F32, F32, BF16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 64, 4, 2, 2, 1, 1, 1, 1, 1, 1, 1>,
|
||||
DeviceBatchNormBwdImpl<BF16, F32, F32, F32, BF16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 32, 8, 2, 2, 0, 2, 2, 2, 2, 2, 2>,
|
||||
DeviceBatchNormBwdImpl<BF16, F32, F32, F32, BF16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 32, 8, 2, 2, 0, 1, 1, 1, 2, 2, 2>,
|
||||
DeviceBatchNormBwdImpl<BF16, F32, F32, F32, BF16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 32, 8, 2, 2, 0, 1, 1, 1, 1, 1, 1>,
|
||||
DeviceBatchNormBwdImpl<BF16, F32, F32, F32, BF16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 32, 8, 2, 2, 0, 2, 2, 2, 1, 1, 1>,
|
||||
DeviceBatchNormBwdImpl<BF16, F32, F32, F32, BF16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 32, 8, 2, 2, 1, 1, 1, 1, 2, 2, 2>,
|
||||
DeviceBatchNormBwdImpl<BF16, F32, F32, F32, BF16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 32, 8, 2, 2, 1, 1, 1, 1, 1, 1, 1>,
|
||||
DeviceBatchNormBwdImpl<BF16, F32, F32, F32, BF16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 16, 16, 2, 2, 0, 2, 2, 2, 2, 2, 2>,
|
||||
DeviceBatchNormBwdImpl<BF16, F32, F32, F32, BF16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 16, 16, 2, 2, 0, 1, 1, 1, 2, 2, 2>,
|
||||
DeviceBatchNormBwdImpl<BF16, F32, F32, F32, BF16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 16, 16, 2, 2, 0, 1, 1, 1, 1, 1, 1>,
|
||||
DeviceBatchNormBwdImpl<BF16, F32, F32, F32, BF16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 16, 16, 2, 2, 0, 2, 2, 2, 1, 1, 1>,
|
||||
DeviceBatchNormBwdImpl<BF16, F32, F32, F32, BF16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 16, 16, 2, 2, 1, 1, 1, 1, 2, 2, 2>,
|
||||
DeviceBatchNormBwdImpl<BF16, F32, F32, F32, BF16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 16, 16, 2, 2, 1, 1, 1, 1, 1, 1, 1>,
|
||||
DeviceBatchNormBwdImpl<BF16, F32, F32, F32, BF16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 8, 32, 2, 2, 0, 2, 2, 2, 2, 2, 2>,
|
||||
DeviceBatchNormBwdImpl<BF16, F32, F32, F32, BF16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 8, 32, 2, 2, 0, 1, 1, 1, 2, 2, 2>,
|
||||
DeviceBatchNormBwdImpl<BF16, F32, F32, F32, BF16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 8, 32, 2, 2, 0, 1, 1, 1, 1, 1, 1>,
|
||||
DeviceBatchNormBwdImpl<BF16, F32, F32, F32, BF16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 8, 32, 2, 2, 0, 2, 2, 2, 1, 1, 1>,
|
||||
DeviceBatchNormBwdImpl<BF16, F32, F32, F32, BF16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 8, 32, 2, 2, 1, 1, 1, 1, 2, 2, 2>,
|
||||
DeviceBatchNormBwdImpl<BF16, F32, F32, F32, BF16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 8, 32, 2, 2, 1, 1, 1, 1, 1, 1, 1>,
|
||||
DeviceBatchNormBwdImpl<BF16, F32, F32, F32, BF16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 4, 64, 2, 2, 0, 2, 2, 2, 2, 2, 2>,
|
||||
DeviceBatchNormBwdImpl<BF16, F32, F32, F32, BF16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 4, 64, 2, 2, 0, 1, 1, 1, 2, 2, 2>,
|
||||
DeviceBatchNormBwdImpl<BF16, F32, F32, F32, BF16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 4, 64, 2, 2, 0, 1, 1, 1, 1, 1, 1>,
|
||||
DeviceBatchNormBwdImpl<BF16, F32, F32, F32, BF16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 4, 64, 2, 2, 0, 2, 2, 2, 1, 1, 1>,
|
||||
DeviceBatchNormBwdImpl<BF16, F32, F32, F32, BF16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 4, 64, 2, 2, 1, 1, 1, 1, 2, 2, 2>,
|
||||
DeviceBatchNormBwdImpl<BF16, F32, F32, F32, BF16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 4, 64, 2, 2, 1, 1, 1, 1, 1, 1, 1>,
|
||||
DeviceBatchNormBwdImpl<BF16, F32, F32, F32, BF16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 2, 128, 2, 2, 0, 2, 2, 2, 2, 2, 2>,
|
||||
DeviceBatchNormBwdImpl<BF16, F32, F32, F32, BF16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 2, 128, 2, 2, 0, 1, 1, 1, 2, 2, 2>,
|
||||
DeviceBatchNormBwdImpl<BF16, F32, F32, F32, BF16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 2, 128, 2, 2, 0, 1, 1, 1, 1, 1, 1>,
|
||||
DeviceBatchNormBwdImpl<BF16, F32, F32, F32, BF16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 2, 128, 2, 2, 0, 2, 2, 2, 1, 1, 1>,
|
||||
DeviceBatchNormBwdImpl<BF16, F32, F32, F32, BF16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 2, 128, 2, 2, 1, 1, 1, 1, 2, 2, 2>,
|
||||
DeviceBatchNormBwdImpl<BF16, F32, F32, F32, BF16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 2, 128, 2, 2, 1, 1, 1, 1, 1, 1, 1>,
|
||||
DeviceBatchNormBwdImpl<BF16, F32, F32, F32, BF16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 1, 256, 2, 2, 0, 2, 2, 2, 2, 2, 2>,
|
||||
DeviceBatchNormBwdImpl<BF16, F32, F32, F32, BF16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 1, 256, 2, 2, 0, 1, 1, 1, 2, 2, 2>,
|
||||
DeviceBatchNormBwdImpl<BF16, F32, F32, F32, BF16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 1, 256, 2, 2, 0, 1, 1, 1, 1, 1, 1>,
|
||||
DeviceBatchNormBwdImpl<BF16, F32, F32, F32, BF16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 1, 256, 2, 2, 0, 2, 2, 2, 1, 1, 1>,
|
||||
DeviceBatchNormBwdImpl<BF16, F32, F32, F32, BF16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 1, 256, 2, 2, 1, 1, 1, 1, 2, 2, 2>,
|
||||
DeviceBatchNormBwdImpl<BF16, F32, F32, F32, BF16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 1, 256, 2, 2, 1, 1, 1, 1, 1, 1, 1>
|
||||
>;
|
||||
// clang-format on
|
||||
|
||||
void add_device_batchnorm_backward_rank_4_3_bf16_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceBatchNormBwd<BF16, F32, F32, F32, BF16, F32, F32, PassThrough, 4, 3>>>& instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances, device_batchnorm_backward_bf16_blockwise_instances<4, 3, PassThrough>{});
|
||||
add_device_operation_instances(
|
||||
instances, device_batchnorm_backward_bf16_multiblock_instances<4, 3, PassThrough>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,147 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_batchnorm_backward_impl.hpp"
|
||||
#include "ck/utility/data_type.hpp"
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
using F16 = ck::half_t;
|
||||
using F32 = float;
|
||||
|
||||
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
|
||||
|
||||
// clang-format off
|
||||
template <index_t Rank, index_t NumReduceDim, typename DyElementwiseOp>
|
||||
using device_batchnorm_backward_f16_blockwise_instances =
|
||||
std::tuple <
|
||||
// XDataType, DxDataType, DyDataType, AccDataType, ScaleDataType, DscaleDbiasDataType, MeanVarDataType, DyElementwiseOp, Rank, NumReduceDim, UseMultiBlockInK, BLockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, XDyDxVectorDim, XSrcVectorSize, DySrcVectorSize, DxDstVectorSize, ScaleSrcVectorSize, DscaleDbiasDstVectorSize, MeanVarSrcVectorSize
|
||||
DeviceBatchNormBwdImpl<F16, F32, F32, F32, F16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 128, 2, 2, 2, 0, 2, 2, 2, 2, 2, 2>,
|
||||
DeviceBatchNormBwdImpl<F16, F32, F32, F32, F16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 128, 2, 2, 2, 0, 1, 1, 1, 2, 2, 2>,
|
||||
DeviceBatchNormBwdImpl<F16, F32, F32, F32, F16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 128, 2, 2, 2, 0, 1, 1, 1, 1, 1, 1>,
|
||||
DeviceBatchNormBwdImpl<F16, F32, F32, F32, F16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 128, 2, 2, 2, 0, 2, 2, 2, 1, 1, 1>,
|
||||
DeviceBatchNormBwdImpl<F16, F32, F32, F32, F16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 128, 2, 2, 2, 1, 1, 1, 1, 2, 2, 2>,
|
||||
DeviceBatchNormBwdImpl<F16, F32, F32, F32, F16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 128, 2, 2, 2, 1, 1, 1, 1, 1, 1, 1>,
|
||||
DeviceBatchNormBwdImpl<F16, F32, F32, F32, F16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 64, 4, 2, 2, 0, 2, 2, 2, 2, 2, 2>,
|
||||
DeviceBatchNormBwdImpl<F16, F32, F32, F32, F16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 64, 4, 2, 2, 0, 1, 1, 1, 2, 2, 2>,
|
||||
DeviceBatchNormBwdImpl<F16, F32, F32, F32, F16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 64, 4, 2, 2, 0, 1, 1, 1, 1, 1, 1>,
|
||||
DeviceBatchNormBwdImpl<F16, F32, F32, F32, F16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 64, 4, 2, 2, 0, 2, 2, 2, 1, 1, 1>,
|
||||
DeviceBatchNormBwdImpl<F16, F32, F32, F32, F16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 64, 4, 2, 2, 1, 1, 1, 1, 2, 2, 2>,
|
||||
DeviceBatchNormBwdImpl<F16, F32, F32, F32, F16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 64, 4, 2, 2, 1, 1, 1, 1, 1, 1, 1>,
|
||||
DeviceBatchNormBwdImpl<F16, F32, F32, F32, F16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 32, 8, 2, 2, 0, 2, 2, 2, 2, 2, 2>,
|
||||
DeviceBatchNormBwdImpl<F16, F32, F32, F32, F16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 32, 8, 2, 2, 0, 1, 1, 1, 2, 2, 2>,
|
||||
DeviceBatchNormBwdImpl<F16, F32, F32, F32, F16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 32, 8, 2, 2, 0, 1, 1, 1, 1, 1, 1>,
|
||||
DeviceBatchNormBwdImpl<F16, F32, F32, F32, F16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 32, 8, 2, 2, 0, 2, 2, 2, 1, 1, 1>,
|
||||
DeviceBatchNormBwdImpl<F16, F32, F32, F32, F16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 32, 8, 2, 2, 1, 1, 1, 1, 2, 2, 2>,
|
||||
DeviceBatchNormBwdImpl<F16, F32, F32, F32, F16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 32, 8, 2, 2, 1, 1, 1, 1, 1, 1, 1>,
|
||||
DeviceBatchNormBwdImpl<F16, F32, F32, F32, F16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 16, 16, 2, 2, 0, 2, 2, 2, 2, 2, 2>,
|
||||
DeviceBatchNormBwdImpl<F16, F32, F32, F32, F16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 16, 16, 2, 2, 0, 1, 1, 1, 2, 2, 2>,
|
||||
DeviceBatchNormBwdImpl<F16, F32, F32, F32, F16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 16, 16, 2, 2, 0, 1, 1, 1, 1, 1, 1>,
|
||||
DeviceBatchNormBwdImpl<F16, F32, F32, F32, F16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 16, 16, 2, 2, 0, 2, 2, 2, 1, 1, 1>,
|
||||
DeviceBatchNormBwdImpl<F16, F32, F32, F32, F16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 16, 16, 2, 2, 1, 1, 1, 1, 2, 2, 2>,
|
||||
DeviceBatchNormBwdImpl<F16, F32, F32, F32, F16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 16, 16, 2, 2, 1, 1, 1, 1, 1, 1, 1>,
|
||||
DeviceBatchNormBwdImpl<F16, F32, F32, F32, F16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 8, 32, 2, 2, 0, 2, 2, 2, 2, 2, 2>,
|
||||
DeviceBatchNormBwdImpl<F16, F32, F32, F32, F16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 8, 32, 2, 2, 0, 1, 1, 1, 2, 2, 2>,
|
||||
DeviceBatchNormBwdImpl<F16, F32, F32, F32, F16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 8, 32, 2, 2, 0, 1, 1, 1, 1, 1, 1>,
|
||||
DeviceBatchNormBwdImpl<F16, F32, F32, F32, F16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 8, 32, 2, 2, 0, 2, 2, 2, 1, 1, 1>,
|
||||
DeviceBatchNormBwdImpl<F16, F32, F32, F32, F16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 8, 32, 2, 2, 1, 1, 1, 1, 2, 2, 2>,
|
||||
DeviceBatchNormBwdImpl<F16, F32, F32, F32, F16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 8, 32, 2, 2, 1, 1, 1, 1, 1, 1, 1>,
|
||||
DeviceBatchNormBwdImpl<F16, F32, F32, F32, F16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 4, 64, 2, 2, 0, 2, 2, 2, 2, 2, 2>,
|
||||
DeviceBatchNormBwdImpl<F16, F32, F32, F32, F16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 4, 64, 2, 2, 0, 1, 1, 1, 2, 2, 2>,
|
||||
DeviceBatchNormBwdImpl<F16, F32, F32, F32, F16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 4, 64, 2, 2, 0, 1, 1, 1, 1, 1, 1>,
|
||||
DeviceBatchNormBwdImpl<F16, F32, F32, F32, F16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 4, 64, 2, 2, 0, 2, 2, 2, 1, 1, 1>,
|
||||
DeviceBatchNormBwdImpl<F16, F32, F32, F32, F16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 4, 64, 2, 2, 1, 1, 1, 1, 2, 2, 2>,
|
||||
DeviceBatchNormBwdImpl<F16, F32, F32, F32, F16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 4, 64, 2, 2, 1, 1, 1, 1, 1, 1, 1>,
|
||||
DeviceBatchNormBwdImpl<F16, F32, F32, F32, F16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 2, 128, 2, 2, 0, 2, 2, 2, 2, 2, 2>,
|
||||
DeviceBatchNormBwdImpl<F16, F32, F32, F32, F16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 2, 128, 2, 2, 0, 1, 1, 1, 2, 2, 2>,
|
||||
DeviceBatchNormBwdImpl<F16, F32, F32, F32, F16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 2, 128, 2, 2, 0, 1, 1, 1, 1, 1, 1>,
|
||||
DeviceBatchNormBwdImpl<F16, F32, F32, F32, F16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 2, 128, 2, 2, 0, 2, 2, 2, 1, 1, 1>,
|
||||
DeviceBatchNormBwdImpl<F16, F32, F32, F32, F16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 2, 128, 2, 2, 1, 1, 1, 1, 2, 2, 2>,
|
||||
DeviceBatchNormBwdImpl<F16, F32, F32, F32, F16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 2, 128, 2, 2, 1, 1, 1, 1, 1, 1, 1>,
|
||||
DeviceBatchNormBwdImpl<F16, F32, F32, F32, F16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 1, 256, 2, 2, 0, 2, 2, 2, 2, 2, 2>,
|
||||
DeviceBatchNormBwdImpl<F16, F32, F32, F32, F16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 1, 256, 2, 2, 0, 1, 1, 1, 2, 2, 2>,
|
||||
DeviceBatchNormBwdImpl<F16, F32, F32, F32, F16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 1, 256, 2, 2, 0, 1, 1, 1, 1, 1, 1>,
|
||||
DeviceBatchNormBwdImpl<F16, F32, F32, F32, F16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 1, 256, 2, 2, 0, 2, 2, 2, 1, 1, 1>,
|
||||
DeviceBatchNormBwdImpl<F16, F32, F32, F32, F16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 1, 256, 2, 2, 1, 1, 1, 1, 2, 2, 2>,
|
||||
DeviceBatchNormBwdImpl<F16, F32, F32, F32, F16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 1, 256, 2, 2, 1, 1, 1, 1, 1, 1, 1>
|
||||
>;
|
||||
// clang-format on
|
||||
|
||||
// clang-format off
|
||||
template <index_t Rank, index_t NumReduceDim, typename DyElementwiseOp>
|
||||
using device_batchnorm_backward_f16_multiblock_instances =
|
||||
std::tuple <
|
||||
// XDataType, DxDataType, DyDataType, AccDataType, ScaleDataType, BiasDataType, MeanVarDataType, DyElementwiseOp, Rank, NumReduceDim, UseMultiBlockInK, BLockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, XDyDxVectorDim, XSrcVectorSize, DySrcVectorSize, DxDstVectorSize, ScaleSrcDstVectorSize, BiasDstVectorSize, MeanVarSrcVectorSize
|
||||
DeviceBatchNormBwdImpl<F16, F32, F32, F32, F16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 128, 2, 2, 2, 0, 2, 2, 2, 2, 2, 2>,
|
||||
DeviceBatchNormBwdImpl<F16, F32, F32, F32, F16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 128, 2, 2, 2, 0, 1, 1, 1, 2, 2, 2>,
|
||||
DeviceBatchNormBwdImpl<F16, F32, F32, F32, F16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 128, 2, 2, 2, 0, 1, 1, 1, 1, 1, 1>,
|
||||
DeviceBatchNormBwdImpl<F16, F32, F32, F32, F16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 128, 2, 2, 2, 0, 2, 2, 2, 1, 1, 1>,
|
||||
DeviceBatchNormBwdImpl<F16, F32, F32, F32, F16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 128, 2, 2, 2, 1, 1, 1, 1, 2, 2, 2>,
|
||||
DeviceBatchNormBwdImpl<F16, F32, F32, F32, F16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 128, 2, 2, 2, 1, 1, 1, 1, 1, 1, 1>,
|
||||
DeviceBatchNormBwdImpl<F16, F32, F32, F32, F16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 64, 4, 2, 2, 0, 2, 2, 2, 2, 2, 2>,
|
||||
DeviceBatchNormBwdImpl<F16, F32, F32, F32, F16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 64, 4, 2, 2, 0, 1, 1, 1, 2, 2, 2>,
|
||||
DeviceBatchNormBwdImpl<F16, F32, F32, F32, F16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 64, 4, 2, 2, 0, 1, 1, 1, 1, 1, 1>,
|
||||
DeviceBatchNormBwdImpl<F16, F32, F32, F32, F16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 64, 4, 2, 2, 0, 2, 2, 2, 1, 1, 1>,
|
||||
DeviceBatchNormBwdImpl<F16, F32, F32, F32, F16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 64, 4, 2, 2, 1, 1, 1, 1, 2, 2, 2>,
|
||||
DeviceBatchNormBwdImpl<F16, F32, F32, F32, F16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 64, 4, 2, 2, 1, 1, 1, 1, 1, 1, 1>,
|
||||
DeviceBatchNormBwdImpl<F16, F32, F32, F32, F16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 32, 8, 2, 2, 0, 2, 2, 2, 2, 2, 2>,
|
||||
DeviceBatchNormBwdImpl<F16, F32, F32, F32, F16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 32, 8, 2, 2, 0, 1, 1, 1, 2, 2, 2>,
|
||||
DeviceBatchNormBwdImpl<F16, F32, F32, F32, F16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 32, 8, 2, 2, 0, 1, 1, 1, 1, 1, 1>,
|
||||
DeviceBatchNormBwdImpl<F16, F32, F32, F32, F16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 32, 8, 2, 2, 0, 2, 2, 2, 1, 1, 1>,
|
||||
DeviceBatchNormBwdImpl<F16, F32, F32, F32, F16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 32, 8, 2, 2, 1, 1, 1, 1, 2, 2, 2>,
|
||||
DeviceBatchNormBwdImpl<F16, F32, F32, F32, F16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 32, 8, 2, 2, 1, 1, 1, 1, 1, 1, 1>,
|
||||
DeviceBatchNormBwdImpl<F16, F32, F32, F32, F16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 16, 16, 2, 2, 0, 2, 2, 2, 2, 2, 2>,
|
||||
DeviceBatchNormBwdImpl<F16, F32, F32, F32, F16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 16, 16, 2, 2, 0, 1, 1, 1, 2, 2, 2>,
|
||||
DeviceBatchNormBwdImpl<F16, F32, F32, F32, F16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 16, 16, 2, 2, 0, 1, 1, 1, 1, 1, 1>,
|
||||
DeviceBatchNormBwdImpl<F16, F32, F32, F32, F16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 16, 16, 2, 2, 0, 2, 2, 2, 1, 1, 1>,
|
||||
DeviceBatchNormBwdImpl<F16, F32, F32, F32, F16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 16, 16, 2, 2, 1, 1, 1, 1, 2, 2, 2>,
|
||||
DeviceBatchNormBwdImpl<F16, F32, F32, F32, F16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 16, 16, 2, 2, 1, 1, 1, 1, 1, 1, 1>,
|
||||
DeviceBatchNormBwdImpl<F16, F32, F32, F32, F16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 8, 32, 2, 2, 0, 2, 2, 2, 2, 2, 2>,
|
||||
DeviceBatchNormBwdImpl<F16, F32, F32, F32, F16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 8, 32, 2, 2, 0, 1, 1, 1, 2, 2, 2>,
|
||||
DeviceBatchNormBwdImpl<F16, F32, F32, F32, F16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 8, 32, 2, 2, 0, 1, 1, 1, 1, 1, 1>,
|
||||
DeviceBatchNormBwdImpl<F16, F32, F32, F32, F16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 8, 32, 2, 2, 0, 2, 2, 2, 1, 1, 1>,
|
||||
DeviceBatchNormBwdImpl<F16, F32, F32, F32, F16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 8, 32, 2, 2, 1, 1, 1, 1, 2, 2, 2>,
|
||||
DeviceBatchNormBwdImpl<F16, F32, F32, F32, F16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 8, 32, 2, 2, 1, 1, 1, 1, 1, 1, 1>,
|
||||
DeviceBatchNormBwdImpl<F16, F32, F32, F32, F16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 4, 64, 2, 2, 0, 2, 2, 2, 2, 2, 2>,
|
||||
DeviceBatchNormBwdImpl<F16, F32, F32, F32, F16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 4, 64, 2, 2, 0, 1, 1, 1, 2, 2, 2>,
|
||||
DeviceBatchNormBwdImpl<F16, F32, F32, F32, F16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 4, 64, 2, 2, 0, 1, 1, 1, 1, 1, 1>,
|
||||
DeviceBatchNormBwdImpl<F16, F32, F32, F32, F16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 4, 64, 2, 2, 0, 2, 2, 2, 1, 1, 1>,
|
||||
DeviceBatchNormBwdImpl<F16, F32, F32, F32, F16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 4, 64, 2, 2, 1, 1, 1, 1, 2, 2, 2>,
|
||||
DeviceBatchNormBwdImpl<F16, F32, F32, F32, F16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 4, 64, 2, 2, 1, 1, 1, 1, 1, 1, 1>,
|
||||
DeviceBatchNormBwdImpl<F16, F32, F32, F32, F16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 2, 128, 2, 2, 0, 2, 2, 2, 2, 2, 2>,
|
||||
DeviceBatchNormBwdImpl<F16, F32, F32, F32, F16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 2, 128, 2, 2, 0, 1, 1, 1, 2, 2, 2>,
|
||||
DeviceBatchNormBwdImpl<F16, F32, F32, F32, F16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 2, 128, 2, 2, 0, 1, 1, 1, 1, 1, 1>,
|
||||
DeviceBatchNormBwdImpl<F16, F32, F32, F32, F16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 2, 128, 2, 2, 0, 2, 2, 2, 1, 1, 1>,
|
||||
DeviceBatchNormBwdImpl<F16, F32, F32, F32, F16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 2, 128, 2, 2, 1, 1, 1, 1, 2, 2, 2>,
|
||||
DeviceBatchNormBwdImpl<F16, F32, F32, F32, F16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 2, 128, 2, 2, 1, 1, 1, 1, 1, 1, 1>,
|
||||
DeviceBatchNormBwdImpl<F16, F32, F32, F32, F16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 1, 256, 2, 2, 0, 2, 2, 2, 2, 2, 2>,
|
||||
DeviceBatchNormBwdImpl<F16, F32, F32, F32, F16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 1, 256, 2, 2, 0, 1, 1, 1, 2, 2, 2>,
|
||||
DeviceBatchNormBwdImpl<F16, F32, F32, F32, F16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 1, 256, 2, 2, 0, 1, 1, 1, 1, 1, 1>,
|
||||
DeviceBatchNormBwdImpl<F16, F32, F32, F32, F16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 1, 256, 2, 2, 0, 2, 2, 2, 1, 1, 1>,
|
||||
DeviceBatchNormBwdImpl<F16, F32, F32, F32, F16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 1, 256, 2, 2, 1, 1, 1, 1, 2, 2, 2>,
|
||||
DeviceBatchNormBwdImpl<F16, F32, F32, F32, F16, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 1, 256, 2, 2, 1, 1, 1, 1, 1, 1, 1>
|
||||
>;
|
||||
// clang-format on
|
||||
|
||||
void add_device_batchnorm_backward_rank_4_3_f16_instances(
|
||||
std::vector<
|
||||
std::unique_ptr<DeviceBatchNormBwd<F16, F32, F32, F32, F16, F32, F32, PassThrough, 4, 3>>>&
|
||||
instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances, device_batchnorm_backward_f16_blockwise_instances<4, 3, PassThrough>{});
|
||||
add_device_operation_instances(
|
||||
instances, device_batchnorm_backward_f16_multiblock_instances<4, 3, PassThrough>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,145 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_batchnorm_backward_impl.hpp"
|
||||
#include "ck/utility/data_type.hpp"
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
using F32 = float;
|
||||
|
||||
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
|
||||
|
||||
// clang-format off
|
||||
template <index_t Rank, index_t NumReduceDim, typename DyElementwiseOp>
|
||||
using device_batchnorm_backward_f32_blockwise_instances = std::tuple<
|
||||
// XDataType, DxDataType, DyDataType, AccDataType, ScaleDataType, DscaleDbiasDataType, MeanVarDataType, DyElementwiseOp, Rank, NumReduceDim, UseMultiBlockInK, BLockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, XDyDxVectorDim, XSrcVectorSize, DySrcVectorSize, DxDstVectorSize, ScaleSrcVectorSize, DscaleDbiasDstVectorSize, MeanVarSrcVectorSize
|
||||
DeviceBatchNormBwdImpl<F32, F32, F32, F32, F32, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 128, 2, 2, 2, 0, 2, 2, 2, 2, 2, 2>,
|
||||
DeviceBatchNormBwdImpl<F32, F32, F32, F32, F32, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 128, 2, 2, 2, 0, 1, 1, 1, 2, 2, 2>,
|
||||
DeviceBatchNormBwdImpl<F32, F32, F32, F32, F32, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 128, 2, 2, 2, 0, 1, 1, 1, 1, 1, 1>,
|
||||
DeviceBatchNormBwdImpl<F32, F32, F32, F32, F32, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 128, 2, 2, 2, 0, 2, 2, 2, 1, 1, 1>,
|
||||
DeviceBatchNormBwdImpl<F32, F32, F32, F32, F32, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 128, 2, 2, 2, 1, 1, 1, 1, 2, 2, 2>,
|
||||
DeviceBatchNormBwdImpl<F32, F32, F32, F32, F32, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 128, 2, 2, 2, 1, 1, 1, 1, 1, 1, 1>,
|
||||
DeviceBatchNormBwdImpl<F32, F32, F32, F32, F32, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 64, 4, 2, 2, 0, 2, 2, 2, 2, 2, 2>,
|
||||
DeviceBatchNormBwdImpl<F32, F32, F32, F32, F32, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 64, 4, 2, 2, 0, 1, 1, 1, 2, 2, 2>,
|
||||
DeviceBatchNormBwdImpl<F32, F32, F32, F32, F32, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 64, 4, 2, 2, 0, 1, 1, 1, 1, 1, 1>,
|
||||
DeviceBatchNormBwdImpl<F32, F32, F32, F32, F32, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 64, 4, 2, 2, 0, 2, 2, 2, 1, 1, 1>,
|
||||
DeviceBatchNormBwdImpl<F32, F32, F32, F32, F32, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 64, 4, 2, 2, 1, 1, 1, 1, 2, 2, 2>,
|
||||
DeviceBatchNormBwdImpl<F32, F32, F32, F32, F32, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 64, 4, 2, 2, 1, 1, 1, 1, 1, 1, 1>,
|
||||
DeviceBatchNormBwdImpl<F32, F32, F32, F32, F32, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 32, 8, 2, 2, 0, 2, 2, 2, 2, 2, 2>,
|
||||
DeviceBatchNormBwdImpl<F32, F32, F32, F32, F32, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 32, 8, 2, 2, 0, 1, 1, 1, 2, 2, 2>,
|
||||
DeviceBatchNormBwdImpl<F32, F32, F32, F32, F32, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 32, 8, 2, 2, 0, 1, 1, 1, 1, 1, 1>,
|
||||
DeviceBatchNormBwdImpl<F32, F32, F32, F32, F32, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 32, 8, 2, 2, 0, 2, 2, 2, 1, 1, 1>,
|
||||
DeviceBatchNormBwdImpl<F32, F32, F32, F32, F32, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 32, 8, 2, 2, 1, 1, 1, 1, 2, 2, 2>,
|
||||
DeviceBatchNormBwdImpl<F32, F32, F32, F32, F32, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 32, 8, 2, 2, 1, 1, 1, 1, 1, 1, 1>,
|
||||
DeviceBatchNormBwdImpl<F32, F32, F32, F32, F32, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 16, 16, 2, 2, 0, 2, 2, 2, 2, 2, 2>,
|
||||
DeviceBatchNormBwdImpl<F32, F32, F32, F32, F32, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 16, 16, 2, 2, 0, 1, 1, 1, 2, 2, 2>,
|
||||
DeviceBatchNormBwdImpl<F32, F32, F32, F32, F32, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 16, 16, 2, 2, 0, 1, 1, 1, 1, 1, 1>,
|
||||
DeviceBatchNormBwdImpl<F32, F32, F32, F32, F32, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 16, 16, 2, 2, 0, 2, 2, 2, 1, 1, 1>,
|
||||
DeviceBatchNormBwdImpl<F32, F32, F32, F32, F32, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 16, 16, 2, 2, 1, 1, 1, 1, 2, 2, 2>,
|
||||
DeviceBatchNormBwdImpl<F32, F32, F32, F32, F32, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 16, 16, 2, 2, 1, 1, 1, 1, 1, 1, 1>,
|
||||
DeviceBatchNormBwdImpl<F32, F32, F32, F32, F32, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 8, 32, 2, 2, 0, 2, 2, 2, 2, 2, 2>,
|
||||
DeviceBatchNormBwdImpl<F32, F32, F32, F32, F32, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 8, 32, 2, 2, 0, 1, 1, 1, 2, 2, 2>,
|
||||
DeviceBatchNormBwdImpl<F32, F32, F32, F32, F32, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 8, 32, 2, 2, 0, 1, 1, 1, 1, 1, 1>,
|
||||
DeviceBatchNormBwdImpl<F32, F32, F32, F32, F32, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 8, 32, 2, 2, 0, 2, 2, 2, 1, 1, 1>,
|
||||
DeviceBatchNormBwdImpl<F32, F32, F32, F32, F32, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 8, 32, 2, 2, 1, 1, 1, 1, 2, 2, 2>,
|
||||
DeviceBatchNormBwdImpl<F32, F32, F32, F32, F32, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 8, 32, 2, 2, 1, 1, 1, 1, 1, 1, 1>,
|
||||
DeviceBatchNormBwdImpl<F32, F32, F32, F32, F32, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 4, 64, 2, 2, 0, 2, 2, 2, 2, 2, 2>,
|
||||
DeviceBatchNormBwdImpl<F32, F32, F32, F32, F32, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 4, 64, 2, 2, 0, 1, 1, 1, 2, 2, 2>,
|
||||
DeviceBatchNormBwdImpl<F32, F32, F32, F32, F32, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 4, 64, 2, 2, 0, 1, 1, 1, 1, 1, 1>,
|
||||
DeviceBatchNormBwdImpl<F32, F32, F32, F32, F32, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 4, 64, 2, 2, 0, 2, 2, 2, 1, 1, 1>,
|
||||
DeviceBatchNormBwdImpl<F32, F32, F32, F32, F32, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 4, 64, 2, 2, 1, 1, 1, 1, 2, 2, 2>,
|
||||
DeviceBatchNormBwdImpl<F32, F32, F32, F32, F32, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 4, 64, 2, 2, 1, 1, 1, 1, 1, 1, 1>,
|
||||
DeviceBatchNormBwdImpl<F32, F32, F32, F32, F32, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 2, 128, 2, 2, 0, 2, 2, 2, 2, 2, 2>,
|
||||
DeviceBatchNormBwdImpl<F32, F32, F32, F32, F32, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 2, 128, 2, 2, 0, 1, 1, 1, 2, 2, 2>,
|
||||
DeviceBatchNormBwdImpl<F32, F32, F32, F32, F32, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 2, 128, 2, 2, 0, 1, 1, 1, 1, 1, 1>,
|
||||
DeviceBatchNormBwdImpl<F32, F32, F32, F32, F32, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 2, 128, 2, 2, 0, 2, 2, 2, 1, 1, 1>,
|
||||
DeviceBatchNormBwdImpl<F32, F32, F32, F32, F32, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 2, 128, 2, 2, 1, 1, 1, 1, 2, 2, 2>,
|
||||
DeviceBatchNormBwdImpl<F32, F32, F32, F32, F32, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 2, 128, 2, 2, 1, 1, 1, 1, 1, 1, 1>,
|
||||
DeviceBatchNormBwdImpl<F32, F32, F32, F32, F32, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 1, 256, 2, 2, 0, 2, 2, 2, 2, 2, 2>,
|
||||
DeviceBatchNormBwdImpl<F32, F32, F32, F32, F32, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 1, 256, 2, 2, 0, 1, 1, 1, 2, 2, 2>,
|
||||
DeviceBatchNormBwdImpl<F32, F32, F32, F32, F32, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 1, 256, 2, 2, 0, 1, 1, 1, 1, 1, 1>,
|
||||
DeviceBatchNormBwdImpl<F32, F32, F32, F32, F32, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 1, 256, 2, 2, 0, 2, 2, 2, 1, 1, 1>,
|
||||
DeviceBatchNormBwdImpl<F32, F32, F32, F32, F32, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 1, 256, 2, 2, 1, 1, 1, 1, 2, 2, 2>,
|
||||
DeviceBatchNormBwdImpl<F32, F32, F32, F32, F32, F32, F32, DyElementwiseOp, Rank, NumReduceDim, false, 256, 1, 256, 2, 2, 1, 1, 1, 1, 1, 1, 1>
|
||||
>;
|
||||
// clang-format on
|
||||
|
||||
// clang-format off
|
||||
template <index_t Rank, index_t NumReduceDim, typename DyElementwiseOp>
|
||||
using device_batchnorm_backward_f32_multiblock_instances =
|
||||
std::tuple <
|
||||
// XDataType, DxDataType, DyDataType, AccDataType, ScaleDataType, BiasDataType, MeanVarDataType, DyElementwiseOp, Rank, NumReduceDim, UseMultiBlockInK, BLockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, XDyDxVectorDim, XSrcVectorSize, DySrcVectorSize, DxDstVectorSize, ScaleSrcDstVectorSize, BiasDstVectorSize, MeanVarSrcVectorSize
|
||||
DeviceBatchNormBwdImpl<F32, F32, F32, F32, F32, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 128, 2, 2, 2, 0, 2, 2, 2, 2, 2, 2>,
|
||||
DeviceBatchNormBwdImpl<F32, F32, F32, F32, F32, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 128, 2, 2, 2, 0, 1, 1, 1, 2, 2, 2>,
|
||||
DeviceBatchNormBwdImpl<F32, F32, F32, F32, F32, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 128, 2, 2, 2, 0, 1, 1, 1, 1, 1, 1>,
|
||||
DeviceBatchNormBwdImpl<F32, F32, F32, F32, F32, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 128, 2, 2, 2, 0, 2, 2, 2, 1, 1, 1>,
|
||||
DeviceBatchNormBwdImpl<F32, F32, F32, F32, F32, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 128, 2, 2, 2, 1, 1, 1, 1, 2, 2, 2>,
|
||||
DeviceBatchNormBwdImpl<F32, F32, F32, F32, F32, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 128, 2, 2, 2, 1, 1, 1, 1, 1, 1, 1>,
|
||||
DeviceBatchNormBwdImpl<F32, F32, F32, F32, F32, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 64, 4, 2, 2, 0, 2, 2, 2, 2, 2, 2>,
|
||||
DeviceBatchNormBwdImpl<F32, F32, F32, F32, F32, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 64, 4, 2, 2, 0, 1, 1, 1, 2, 2, 2>,
|
||||
DeviceBatchNormBwdImpl<F32, F32, F32, F32, F32, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 64, 4, 2, 2, 0, 1, 1, 1, 1, 1, 1>,
|
||||
DeviceBatchNormBwdImpl<F32, F32, F32, F32, F32, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 64, 4, 2, 2, 0, 2, 2, 2, 1, 1, 1>,
|
||||
DeviceBatchNormBwdImpl<F32, F32, F32, F32, F32, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 64, 4, 2, 2, 1, 1, 1, 1, 2, 2, 2>,
|
||||
DeviceBatchNormBwdImpl<F32, F32, F32, F32, F32, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 64, 4, 2, 2, 1, 1, 1, 1, 1, 1, 1>,
|
||||
DeviceBatchNormBwdImpl<F32, F32, F32, F32, F32, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 32, 8, 2, 2, 0, 2, 2, 2, 2, 2, 2>,
|
||||
DeviceBatchNormBwdImpl<F32, F32, F32, F32, F32, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 32, 8, 2, 2, 0, 1, 1, 1, 2, 2, 2>,
|
||||
DeviceBatchNormBwdImpl<F32, F32, F32, F32, F32, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 32, 8, 2, 2, 0, 1, 1, 1, 1, 1, 1>,
|
||||
DeviceBatchNormBwdImpl<F32, F32, F32, F32, F32, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 32, 8, 2, 2, 0, 2, 2, 2, 1, 1, 1>,
|
||||
DeviceBatchNormBwdImpl<F32, F32, F32, F32, F32, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 32, 8, 2, 2, 1, 1, 1, 1, 2, 2, 2>,
|
||||
DeviceBatchNormBwdImpl<F32, F32, F32, F32, F32, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 32, 8, 2, 2, 1, 1, 1, 1, 1, 1, 1>,
|
||||
DeviceBatchNormBwdImpl<F32, F32, F32, F32, F32, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 16, 16, 2, 2, 0, 2, 2, 2, 2, 2, 2>,
|
||||
DeviceBatchNormBwdImpl<F32, F32, F32, F32, F32, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 16, 16, 2, 2, 0, 1, 1, 1, 2, 2, 2>,
|
||||
DeviceBatchNormBwdImpl<F32, F32, F32, F32, F32, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 16, 16, 2, 2, 0, 1, 1, 1, 1, 1, 1>,
|
||||
DeviceBatchNormBwdImpl<F32, F32, F32, F32, F32, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 16, 16, 2, 2, 0, 2, 2, 2, 1, 1, 1>,
|
||||
DeviceBatchNormBwdImpl<F32, F32, F32, F32, F32, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 16, 16, 2, 2, 1, 1, 1, 1, 2, 2, 2>,
|
||||
DeviceBatchNormBwdImpl<F32, F32, F32, F32, F32, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 16, 16, 2, 2, 1, 1, 1, 1, 1, 1, 1>,
|
||||
DeviceBatchNormBwdImpl<F32, F32, F32, F32, F32, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 8, 32, 2, 2, 0, 2, 2, 2, 2, 2, 2>,
|
||||
DeviceBatchNormBwdImpl<F32, F32, F32, F32, F32, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 8, 32, 2, 2, 0, 1, 1, 1, 2, 2, 2>,
|
||||
DeviceBatchNormBwdImpl<F32, F32, F32, F32, F32, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 8, 32, 2, 2, 0, 1, 1, 1, 1, 1, 1>,
|
||||
DeviceBatchNormBwdImpl<F32, F32, F32, F32, F32, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 8, 32, 2, 2, 0, 2, 2, 2, 1, 1, 1>,
|
||||
DeviceBatchNormBwdImpl<F32, F32, F32, F32, F32, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 8, 32, 2, 2, 1, 1, 1, 1, 2, 2, 2>,
|
||||
DeviceBatchNormBwdImpl<F32, F32, F32, F32, F32, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 8, 32, 2, 2, 1, 1, 1, 1, 1, 1, 1>,
|
||||
DeviceBatchNormBwdImpl<F32, F32, F32, F32, F32, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 4, 64, 2, 2, 0, 2, 2, 2, 2, 2, 2>,
|
||||
DeviceBatchNormBwdImpl<F32, F32, F32, F32, F32, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 4, 64, 2, 2, 0, 1, 1, 1, 2, 2, 2>,
|
||||
DeviceBatchNormBwdImpl<F32, F32, F32, F32, F32, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 4, 64, 2, 2, 0, 1, 1, 1, 1, 1, 1>,
|
||||
DeviceBatchNormBwdImpl<F32, F32, F32, F32, F32, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 4, 64, 2, 2, 0, 2, 2, 2, 1, 1, 1>,
|
||||
DeviceBatchNormBwdImpl<F32, F32, F32, F32, F32, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 4, 64, 2, 2, 1, 1, 1, 1, 2, 2, 2>,
|
||||
DeviceBatchNormBwdImpl<F32, F32, F32, F32, F32, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 4, 64, 2, 2, 1, 1, 1, 1, 1, 1, 1>,
|
||||
DeviceBatchNormBwdImpl<F32, F32, F32, F32, F32, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 2, 128, 2, 2, 0, 2, 2, 2, 2, 2, 2>,
|
||||
DeviceBatchNormBwdImpl<F32, F32, F32, F32, F32, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 2, 128, 2, 2, 0, 1, 1, 1, 2, 2, 2>,
|
||||
DeviceBatchNormBwdImpl<F32, F32, F32, F32, F32, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 2, 128, 2, 2, 0, 1, 1, 1, 1, 1, 1>,
|
||||
DeviceBatchNormBwdImpl<F32, F32, F32, F32, F32, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 2, 128, 2, 2, 0, 2, 2, 2, 1, 1, 1>,
|
||||
DeviceBatchNormBwdImpl<F32, F32, F32, F32, F32, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 2, 128, 2, 2, 1, 1, 1, 1, 2, 2, 2>,
|
||||
DeviceBatchNormBwdImpl<F32, F32, F32, F32, F32, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 2, 128, 2, 2, 1, 1, 1, 1, 1, 1, 1>,
|
||||
DeviceBatchNormBwdImpl<F32, F32, F32, F32, F32, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 1, 256, 2, 2, 0, 2, 2, 2, 2, 2, 2>,
|
||||
DeviceBatchNormBwdImpl<F32, F32, F32, F32, F32, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 1, 256, 2, 2, 0, 1, 1, 1, 2, 2, 2>,
|
||||
DeviceBatchNormBwdImpl<F32, F32, F32, F32, F32, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 1, 256, 2, 2, 0, 1, 1, 1, 1, 1, 1>,
|
||||
DeviceBatchNormBwdImpl<F32, F32, F32, F32, F32, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 1, 256, 2, 2, 0, 2, 2, 2, 1, 1, 1>,
|
||||
DeviceBatchNormBwdImpl<F32, F32, F32, F32, F32, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 1, 256, 2, 2, 1, 1, 1, 1, 2, 2, 2>,
|
||||
DeviceBatchNormBwdImpl<F32, F32, F32, F32, F32, F32, F32, DyElementwiseOp, Rank, NumReduceDim, true, 256, 1, 256, 2, 2, 1, 1, 1, 1, 1, 1, 1>
|
||||
>;
|
||||
// clang-format on
|
||||
|
||||
void add_device_batchnorm_backward_rank_4_3_f32_instances(
|
||||
std::vector<
|
||||
std::unique_ptr<DeviceBatchNormBwd<F32, F32, F32, F32, F32, F32, F32, PassThrough, 4, 3>>>&
|
||||
instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances, device_batchnorm_backward_f32_blockwise_instances<4, 3, PassThrough>{});
|
||||
add_device_operation_instances(
|
||||
instances, device_batchnorm_backward_f32_multiblock_instances<4, 3, PassThrough>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,145 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_batchnorm_backward_impl.hpp"
|
||||
#include "ck/utility/data_type.hpp"
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
using F64 = double;
|
||||
|
||||
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
|
||||
|
||||
// clang-format off
|
||||
template <index_t Rank, index_t NumReduceDim, typename DyElementwiseOp>
|
||||
using device_batchnorm_backward_f64_blockwise_instances = std::tuple<
|
||||
// XDataType, DxDataType, DyDataType, AccDataType, ScaleDataType, DscaleDbiasDataType, MeanVarDataType, DyElementwiseOp, Rank, NumReduceDim, UseMultiBlockInK, BLockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, XDyDxVectorDim, XSrcVectorSize, DySrcVectorSize, DxDstVectorSize, ScaleSrcVectorSize, DscaleDbiasDstVectorSize, MeanVarSrcVectorSize
|
||||
DeviceBatchNormBwdImpl<F64, F64, F64, F64, F64, F64, F64, DyElementwiseOp, Rank, NumReduceDim, false, 256, 128, 2, 2, 2, 0, 2, 2, 2, 2, 2, 2>,
|
||||
DeviceBatchNormBwdImpl<F64, F64, F64, F64, F64, F64, F64, DyElementwiseOp, Rank, NumReduceDim, false, 256, 128, 2, 2, 2, 0, 1, 1, 1, 2, 2, 2>,
|
||||
DeviceBatchNormBwdImpl<F64, F64, F64, F64, F64, F64, F64, DyElementwiseOp, Rank, NumReduceDim, false, 256, 128, 2, 2, 2, 0, 1, 1, 1, 1, 1, 1>,
|
||||
DeviceBatchNormBwdImpl<F64, F64, F64, F64, F64, F64, F64, DyElementwiseOp, Rank, NumReduceDim, false, 256, 128, 2, 2, 2, 0, 2, 2, 2, 1, 1, 1>,
|
||||
DeviceBatchNormBwdImpl<F64, F64, F64, F64, F64, F64, F64, DyElementwiseOp, Rank, NumReduceDim, false, 256, 128, 2, 2, 2, 1, 1, 1, 1, 2, 2, 2>,
|
||||
DeviceBatchNormBwdImpl<F64, F64, F64, F64, F64, F64, F64, DyElementwiseOp, Rank, NumReduceDim, false, 256, 128, 2, 2, 2, 1, 1, 1, 1, 1, 1, 1>,
|
||||
DeviceBatchNormBwdImpl<F64, F64, F64, F64, F64, F64, F64, DyElementwiseOp, Rank, NumReduceDim, false, 256, 64, 4, 2, 2, 0, 2, 2, 2, 2, 2, 2>,
|
||||
DeviceBatchNormBwdImpl<F64, F64, F64, F64, F64, F64, F64, DyElementwiseOp, Rank, NumReduceDim, false, 256, 64, 4, 2, 2, 0, 1, 1, 1, 2, 2, 2>,
|
||||
DeviceBatchNormBwdImpl<F64, F64, F64, F64, F64, F64, F64, DyElementwiseOp, Rank, NumReduceDim, false, 256, 64, 4, 2, 2, 0, 1, 1, 1, 1, 1, 1>,
|
||||
DeviceBatchNormBwdImpl<F64, F64, F64, F64, F64, F64, F64, DyElementwiseOp, Rank, NumReduceDim, false, 256, 64, 4, 2, 2, 0, 2, 2, 2, 1, 1, 1>,
|
||||
DeviceBatchNormBwdImpl<F64, F64, F64, F64, F64, F64, F64, DyElementwiseOp, Rank, NumReduceDim, false, 256, 64, 4, 2, 2, 1, 1, 1, 1, 2, 2, 2>,
|
||||
DeviceBatchNormBwdImpl<F64, F64, F64, F64, F64, F64, F64, DyElementwiseOp, Rank, NumReduceDim, false, 256, 64, 4, 2, 2, 1, 1, 1, 1, 1, 1, 1>,
|
||||
DeviceBatchNormBwdImpl<F64, F64, F64, F64, F64, F64, F64, DyElementwiseOp, Rank, NumReduceDim, false, 256, 32, 8, 2, 2, 0, 2, 2, 2, 2, 2, 2>,
|
||||
DeviceBatchNormBwdImpl<F64, F64, F64, F64, F64, F64, F64, DyElementwiseOp, Rank, NumReduceDim, false, 256, 32, 8, 2, 2, 0, 1, 1, 1, 2, 2, 2>,
|
||||
DeviceBatchNormBwdImpl<F64, F64, F64, F64, F64, F64, F64, DyElementwiseOp, Rank, NumReduceDim, false, 256, 32, 8, 2, 2, 0, 1, 1, 1, 1, 1, 1>,
|
||||
DeviceBatchNormBwdImpl<F64, F64, F64, F64, F64, F64, F64, DyElementwiseOp, Rank, NumReduceDim, false, 256, 32, 8, 2, 2, 0, 2, 2, 2, 1, 1, 1>,
|
||||
DeviceBatchNormBwdImpl<F64, F64, F64, F64, F64, F64, F64, DyElementwiseOp, Rank, NumReduceDim, false, 256, 32, 8, 2, 2, 1, 1, 1, 1, 2, 2, 2>,
|
||||
DeviceBatchNormBwdImpl<F64, F64, F64, F64, F64, F64, F64, DyElementwiseOp, Rank, NumReduceDim, false, 256, 32, 8, 2, 2, 1, 1, 1, 1, 1, 1, 1>,
|
||||
DeviceBatchNormBwdImpl<F64, F64, F64, F64, F64, F64, F64, DyElementwiseOp, Rank, NumReduceDim, false, 256, 16, 16, 2, 2, 0, 2, 2, 2, 2, 2, 2>,
|
||||
DeviceBatchNormBwdImpl<F64, F64, F64, F64, F64, F64, F64, DyElementwiseOp, Rank, NumReduceDim, false, 256, 16, 16, 2, 2, 0, 1, 1, 1, 2, 2, 2>,
|
||||
DeviceBatchNormBwdImpl<F64, F64, F64, F64, F64, F64, F64, DyElementwiseOp, Rank, NumReduceDim, false, 256, 16, 16, 2, 2, 0, 1, 1, 1, 1, 1, 1>,
|
||||
DeviceBatchNormBwdImpl<F64, F64, F64, F64, F64, F64, F64, DyElementwiseOp, Rank, NumReduceDim, false, 256, 16, 16, 2, 2, 0, 2, 2, 2, 1, 1, 1>,
|
||||
DeviceBatchNormBwdImpl<F64, F64, F64, F64, F64, F64, F64, DyElementwiseOp, Rank, NumReduceDim, false, 256, 16, 16, 2, 2, 1, 1, 1, 1, 2, 2, 2>,
|
||||
DeviceBatchNormBwdImpl<F64, F64, F64, F64, F64, F64, F64, DyElementwiseOp, Rank, NumReduceDim, false, 256, 16, 16, 2, 2, 1, 1, 1, 1, 1, 1, 1>,
|
||||
DeviceBatchNormBwdImpl<F64, F64, F64, F64, F64, F64, F64, DyElementwiseOp, Rank, NumReduceDim, false, 256, 8, 32, 2, 2, 0, 2, 2, 2, 2, 2, 2>,
|
||||
DeviceBatchNormBwdImpl<F64, F64, F64, F64, F64, F64, F64, DyElementwiseOp, Rank, NumReduceDim, false, 256, 8, 32, 2, 2, 0, 1, 1, 1, 2, 2, 2>,
|
||||
DeviceBatchNormBwdImpl<F64, F64, F64, F64, F64, F64, F64, DyElementwiseOp, Rank, NumReduceDim, false, 256, 8, 32, 2, 2, 0, 1, 1, 1, 1, 1, 1>,
|
||||
DeviceBatchNormBwdImpl<F64, F64, F64, F64, F64, F64, F64, DyElementwiseOp, Rank, NumReduceDim, false, 256, 8, 32, 2, 2, 0, 2, 2, 2, 1, 1, 1>,
|
||||
DeviceBatchNormBwdImpl<F64, F64, F64, F64, F64, F64, F64, DyElementwiseOp, Rank, NumReduceDim, false, 256, 8, 32, 2, 2, 1, 1, 1, 1, 2, 2, 2>,
|
||||
DeviceBatchNormBwdImpl<F64, F64, F64, F64, F64, F64, F64, DyElementwiseOp, Rank, NumReduceDim, false, 256, 8, 32, 2, 2, 1, 1, 1, 1, 1, 1, 1>,
|
||||
DeviceBatchNormBwdImpl<F64, F64, F64, F64, F64, F64, F64, DyElementwiseOp, Rank, NumReduceDim, false, 256, 4, 64, 2, 2, 0, 2, 2, 2, 2, 2, 2>,
|
||||
DeviceBatchNormBwdImpl<F64, F64, F64, F64, F64, F64, F64, DyElementwiseOp, Rank, NumReduceDim, false, 256, 4, 64, 2, 2, 0, 1, 1, 1, 2, 2, 2>,
|
||||
DeviceBatchNormBwdImpl<F64, F64, F64, F64, F64, F64, F64, DyElementwiseOp, Rank, NumReduceDim, false, 256, 4, 64, 2, 2, 0, 1, 1, 1, 1, 1, 1>,
|
||||
DeviceBatchNormBwdImpl<F64, F64, F64, F64, F64, F64, F64, DyElementwiseOp, Rank, NumReduceDim, false, 256, 4, 64, 2, 2, 0, 2, 2, 2, 1, 1, 1>,
|
||||
DeviceBatchNormBwdImpl<F64, F64, F64, F64, F64, F64, F64, DyElementwiseOp, Rank, NumReduceDim, false, 256, 4, 64, 2, 2, 1, 1, 1, 1, 2, 2, 2>,
|
||||
DeviceBatchNormBwdImpl<F64, F64, F64, F64, F64, F64, F64, DyElementwiseOp, Rank, NumReduceDim, false, 256, 4, 64, 2, 2, 1, 1, 1, 1, 1, 1, 1>,
|
||||
DeviceBatchNormBwdImpl<F64, F64, F64, F64, F64, F64, F64, DyElementwiseOp, Rank, NumReduceDim, false, 256, 2, 128, 2, 2, 0, 2, 2, 2, 2, 2, 2>,
|
||||
DeviceBatchNormBwdImpl<F64, F64, F64, F64, F64, F64, F64, DyElementwiseOp, Rank, NumReduceDim, false, 256, 2, 128, 2, 2, 0, 1, 1, 1, 2, 2, 2>,
|
||||
DeviceBatchNormBwdImpl<F64, F64, F64, F64, F64, F64, F64, DyElementwiseOp, Rank, NumReduceDim, false, 256, 2, 128, 2, 2, 0, 1, 1, 1, 1, 1, 1>,
|
||||
DeviceBatchNormBwdImpl<F64, F64, F64, F64, F64, F64, F64, DyElementwiseOp, Rank, NumReduceDim, false, 256, 2, 128, 2, 2, 0, 2, 2, 2, 1, 1, 1>,
|
||||
DeviceBatchNormBwdImpl<F64, F64, F64, F64, F64, F64, F64, DyElementwiseOp, Rank, NumReduceDim, false, 256, 2, 128, 2, 2, 1, 1, 1, 1, 2, 2, 2>,
|
||||
DeviceBatchNormBwdImpl<F64, F64, F64, F64, F64, F64, F64, DyElementwiseOp, Rank, NumReduceDim, false, 256, 2, 128, 2, 2, 1, 1, 1, 1, 1, 1, 1>,
|
||||
DeviceBatchNormBwdImpl<F64, F64, F64, F64, F64, F64, F64, DyElementwiseOp, Rank, NumReduceDim, false, 256, 1, 256, 2, 2, 0, 2, 2, 2, 2, 2, 2>,
|
||||
DeviceBatchNormBwdImpl<F64, F64, F64, F64, F64, F64, F64, DyElementwiseOp, Rank, NumReduceDim, false, 256, 1, 256, 2, 2, 0, 1, 1, 1, 2, 2, 2>,
|
||||
DeviceBatchNormBwdImpl<F64, F64, F64, F64, F64, F64, F64, DyElementwiseOp, Rank, NumReduceDim, false, 256, 1, 256, 2, 2, 0, 1, 1, 1, 1, 1, 1>,
|
||||
DeviceBatchNormBwdImpl<F64, F64, F64, F64, F64, F64, F64, DyElementwiseOp, Rank, NumReduceDim, false, 256, 1, 256, 2, 2, 0, 2, 2, 2, 1, 1, 1>,
|
||||
DeviceBatchNormBwdImpl<F64, F64, F64, F64, F64, F64, F64, DyElementwiseOp, Rank, NumReduceDim, false, 256, 1, 256, 2, 2, 1, 1, 1, 1, 2, 2, 2>,
|
||||
DeviceBatchNormBwdImpl<F64, F64, F64, F64, F64, F64, F64, DyElementwiseOp, Rank, NumReduceDim, false, 256, 1, 256, 2, 2, 1, 1, 1, 1, 1, 1, 1>
|
||||
>;
|
||||
// clang-format on
|
||||
|
||||
// clang-format off
|
||||
template <index_t Rank, index_t NumReduceDim, typename DyElementwiseOp>
|
||||
using device_batchnorm_backward_f64_multiblock_instances =
|
||||
std::tuple <
|
||||
// XDataType, DxDataType, DyDataType, AccDataType, ScaleDataType, BiasDataType, MeanVarDataType, DyElementwiseOp, Rank, NumReduceDim, UseMultiBlockInK, BLockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, XDyDxVectorDim, XSrcVectorSize, DySrcVectorSize, DxDstVectorSize, ScaleSrcDstVectorSize, BiasDstVectorSize, MeanVarSrcVectorSize
|
||||
DeviceBatchNormBwdImpl<F64, F64, F64, F64, F64, F64, F64, DyElementwiseOp, Rank, NumReduceDim, true, 256, 128, 2, 2, 2, 0, 2, 2, 2, 2, 2, 2>,
|
||||
DeviceBatchNormBwdImpl<F64, F64, F64, F64, F64, F64, F64, DyElementwiseOp, Rank, NumReduceDim, true, 256, 128, 2, 2, 2, 0, 1, 1, 1, 2, 2, 2>,
|
||||
DeviceBatchNormBwdImpl<F64, F64, F64, F64, F64, F64, F64, DyElementwiseOp, Rank, NumReduceDim, true, 256, 128, 2, 2, 2, 0, 1, 1, 1, 1, 1, 1>,
|
||||
DeviceBatchNormBwdImpl<F64, F64, F64, F64, F64, F64, F64, DyElementwiseOp, Rank, NumReduceDim, true, 256, 128, 2, 2, 2, 0, 2, 2, 2, 1, 1, 1>,
|
||||
DeviceBatchNormBwdImpl<F64, F64, F64, F64, F64, F64, F64, DyElementwiseOp, Rank, NumReduceDim, true, 256, 128, 2, 2, 2, 1, 1, 1, 1, 2, 2, 2>,
|
||||
DeviceBatchNormBwdImpl<F64, F64, F64, F64, F64, F64, F64, DyElementwiseOp, Rank, NumReduceDim, true, 256, 128, 2, 2, 2, 1, 1, 1, 1, 1, 1, 1>,
|
||||
DeviceBatchNormBwdImpl<F64, F64, F64, F64, F64, F64, F64, DyElementwiseOp, Rank, NumReduceDim, true, 256, 64, 4, 2, 2, 0, 2, 2, 2, 2, 2, 2>,
|
||||
DeviceBatchNormBwdImpl<F64, F64, F64, F64, F64, F64, F64, DyElementwiseOp, Rank, NumReduceDim, true, 256, 64, 4, 2, 2, 0, 1, 1, 1, 2, 2, 2>,
|
||||
DeviceBatchNormBwdImpl<F64, F64, F64, F64, F64, F64, F64, DyElementwiseOp, Rank, NumReduceDim, true, 256, 64, 4, 2, 2, 0, 1, 1, 1, 1, 1, 1>,
|
||||
DeviceBatchNormBwdImpl<F64, F64, F64, F64, F64, F64, F64, DyElementwiseOp, Rank, NumReduceDim, true, 256, 64, 4, 2, 2, 0, 2, 2, 2, 1, 1, 1>,
|
||||
DeviceBatchNormBwdImpl<F64, F64, F64, F64, F64, F64, F64, DyElementwiseOp, Rank, NumReduceDim, true, 256, 64, 4, 2, 2, 1, 1, 1, 1, 2, 2, 2>,
|
||||
DeviceBatchNormBwdImpl<F64, F64, F64, F64, F64, F64, F64, DyElementwiseOp, Rank, NumReduceDim, true, 256, 64, 4, 2, 2, 1, 1, 1, 1, 1, 1, 1>,
|
||||
DeviceBatchNormBwdImpl<F64, F64, F64, F64, F64, F64, F64, DyElementwiseOp, Rank, NumReduceDim, true, 256, 32, 8, 2, 2, 0, 2, 2, 2, 2, 2, 2>,
|
||||
DeviceBatchNormBwdImpl<F64, F64, F64, F64, F64, F64, F64, DyElementwiseOp, Rank, NumReduceDim, true, 256, 32, 8, 2, 2, 0, 1, 1, 1, 2, 2, 2>,
|
||||
DeviceBatchNormBwdImpl<F64, F64, F64, F64, F64, F64, F64, DyElementwiseOp, Rank, NumReduceDim, true, 256, 32, 8, 2, 2, 0, 1, 1, 1, 1, 1, 1>,
|
||||
DeviceBatchNormBwdImpl<F64, F64, F64, F64, F64, F64, F64, DyElementwiseOp, Rank, NumReduceDim, true, 256, 32, 8, 2, 2, 0, 2, 2, 2, 1, 1, 1>,
|
||||
DeviceBatchNormBwdImpl<F64, F64, F64, F64, F64, F64, F64, DyElementwiseOp, Rank, NumReduceDim, true, 256, 32, 8, 2, 2, 1, 1, 1, 1, 2, 2, 2>,
|
||||
DeviceBatchNormBwdImpl<F64, F64, F64, F64, F64, F64, F64, DyElementwiseOp, Rank, NumReduceDim, true, 256, 32, 8, 2, 2, 1, 1, 1, 1, 1, 1, 1>,
|
||||
DeviceBatchNormBwdImpl<F64, F64, F64, F64, F64, F64, F64, DyElementwiseOp, Rank, NumReduceDim, true, 256, 16, 16, 2, 2, 0, 2, 2, 2, 2, 2, 2>,
|
||||
DeviceBatchNormBwdImpl<F64, F64, F64, F64, F64, F64, F64, DyElementwiseOp, Rank, NumReduceDim, true, 256, 16, 16, 2, 2, 0, 1, 1, 1, 2, 2, 2>,
|
||||
DeviceBatchNormBwdImpl<F64, F64, F64, F64, F64, F64, F64, DyElementwiseOp, Rank, NumReduceDim, true, 256, 16, 16, 2, 2, 0, 1, 1, 1, 1, 1, 1>,
|
||||
DeviceBatchNormBwdImpl<F64, F64, F64, F64, F64, F64, F64, DyElementwiseOp, Rank, NumReduceDim, true, 256, 16, 16, 2, 2, 0, 2, 2, 2, 1, 1, 1>,
|
||||
DeviceBatchNormBwdImpl<F64, F64, F64, F64, F64, F64, F64, DyElementwiseOp, Rank, NumReduceDim, true, 256, 16, 16, 2, 2, 1, 1, 1, 1, 2, 2, 2>,
|
||||
DeviceBatchNormBwdImpl<F64, F64, F64, F64, F64, F64, F64, DyElementwiseOp, Rank, NumReduceDim, true, 256, 16, 16, 2, 2, 1, 1, 1, 1, 1, 1, 1>,
|
||||
DeviceBatchNormBwdImpl<F64, F64, F64, F64, F64, F64, F64, DyElementwiseOp, Rank, NumReduceDim, true, 256, 8, 32, 2, 2, 0, 2, 2, 2, 2, 2, 2>,
|
||||
DeviceBatchNormBwdImpl<F64, F64, F64, F64, F64, F64, F64, DyElementwiseOp, Rank, NumReduceDim, true, 256, 8, 32, 2, 2, 0, 1, 1, 1, 2, 2, 2>,
|
||||
DeviceBatchNormBwdImpl<F64, F64, F64, F64, F64, F64, F64, DyElementwiseOp, Rank, NumReduceDim, true, 256, 8, 32, 2, 2, 0, 1, 1, 1, 1, 1, 1>,
|
||||
DeviceBatchNormBwdImpl<F64, F64, F64, F64, F64, F64, F64, DyElementwiseOp, Rank, NumReduceDim, true, 256, 8, 32, 2, 2, 0, 2, 2, 2, 1, 1, 1>,
|
||||
DeviceBatchNormBwdImpl<F64, F64, F64, F64, F64, F64, F64, DyElementwiseOp, Rank, NumReduceDim, true, 256, 8, 32, 2, 2, 1, 1, 1, 1, 2, 2, 2>,
|
||||
DeviceBatchNormBwdImpl<F64, F64, F64, F64, F64, F64, F64, DyElementwiseOp, Rank, NumReduceDim, true, 256, 8, 32, 2, 2, 1, 1, 1, 1, 1, 1, 1>,
|
||||
DeviceBatchNormBwdImpl<F64, F64, F64, F64, F64, F64, F64, DyElementwiseOp, Rank, NumReduceDim, true, 256, 4, 64, 2, 2, 0, 2, 2, 2, 2, 2, 2>,
|
||||
DeviceBatchNormBwdImpl<F64, F64, F64, F64, F64, F64, F64, DyElementwiseOp, Rank, NumReduceDim, true, 256, 4, 64, 2, 2, 0, 1, 1, 1, 2, 2, 2>,
|
||||
DeviceBatchNormBwdImpl<F64, F64, F64, F64, F64, F64, F64, DyElementwiseOp, Rank, NumReduceDim, true, 256, 4, 64, 2, 2, 0, 1, 1, 1, 1, 1, 1>,
|
||||
DeviceBatchNormBwdImpl<F64, F64, F64, F64, F64, F64, F64, DyElementwiseOp, Rank, NumReduceDim, true, 256, 4, 64, 2, 2, 0, 2, 2, 2, 1, 1, 1>,
|
||||
DeviceBatchNormBwdImpl<F64, F64, F64, F64, F64, F64, F64, DyElementwiseOp, Rank, NumReduceDim, true, 256, 4, 64, 2, 2, 1, 1, 1, 1, 2, 2, 2>,
|
||||
DeviceBatchNormBwdImpl<F64, F64, F64, F64, F64, F64, F64, DyElementwiseOp, Rank, NumReduceDim, true, 256, 4, 64, 2, 2, 1, 1, 1, 1, 1, 1, 1>,
|
||||
DeviceBatchNormBwdImpl<F64, F64, F64, F64, F64, F64, F64, DyElementwiseOp, Rank, NumReduceDim, true, 256, 2, 128, 2, 2, 0, 2, 2, 2, 2, 2, 2>,
|
||||
DeviceBatchNormBwdImpl<F64, F64, F64, F64, F64, F64, F64, DyElementwiseOp, Rank, NumReduceDim, true, 256, 2, 128, 2, 2, 0, 1, 1, 1, 2, 2, 2>,
|
||||
DeviceBatchNormBwdImpl<F64, F64, F64, F64, F64, F64, F64, DyElementwiseOp, Rank, NumReduceDim, true, 256, 2, 128, 2, 2, 0, 1, 1, 1, 1, 1, 1>,
|
||||
DeviceBatchNormBwdImpl<F64, F64, F64, F64, F64, F64, F64, DyElementwiseOp, Rank, NumReduceDim, true, 256, 2, 128, 2, 2, 0, 2, 2, 2, 1, 1, 1>,
|
||||
DeviceBatchNormBwdImpl<F64, F64, F64, F64, F64, F64, F64, DyElementwiseOp, Rank, NumReduceDim, true, 256, 2, 128, 2, 2, 1, 1, 1, 1, 2, 2, 2>,
|
||||
DeviceBatchNormBwdImpl<F64, F64, F64, F64, F64, F64, F64, DyElementwiseOp, Rank, NumReduceDim, true, 256, 2, 128, 2, 2, 1, 1, 1, 1, 1, 1, 1>,
|
||||
DeviceBatchNormBwdImpl<F64, F64, F64, F64, F64, F64, F64, DyElementwiseOp, Rank, NumReduceDim, true, 256, 1, 256, 2, 2, 0, 2, 2, 2, 2, 2, 2>,
|
||||
DeviceBatchNormBwdImpl<F64, F64, F64, F64, F64, F64, F64, DyElementwiseOp, Rank, NumReduceDim, true, 256, 1, 256, 2, 2, 0, 1, 1, 1, 2, 2, 2>,
|
||||
DeviceBatchNormBwdImpl<F64, F64, F64, F64, F64, F64, F64, DyElementwiseOp, Rank, NumReduceDim, true, 256, 1, 256, 2, 2, 0, 1, 1, 1, 1, 1, 1>,
|
||||
DeviceBatchNormBwdImpl<F64, F64, F64, F64, F64, F64, F64, DyElementwiseOp, Rank, NumReduceDim, true, 256, 1, 256, 2, 2, 0, 2, 2, 2, 1, 1, 1>,
|
||||
DeviceBatchNormBwdImpl<F64, F64, F64, F64, F64, F64, F64, DyElementwiseOp, Rank, NumReduceDim, true, 256, 1, 256, 2, 2, 1, 1, 1, 1, 2, 2, 2>,
|
||||
DeviceBatchNormBwdImpl<F64, F64, F64, F64, F64, F64, F64, DyElementwiseOp, Rank, NumReduceDim, true, 256, 1, 256, 2, 2, 1, 1, 1, 1, 1, 1, 1>
|
||||
>;
|
||||
// clang-format on
|
||||
|
||||
void add_device_batchnorm_backward_rank_4_3_f64_instances(
|
||||
std::vector<
|
||||
std::unique_ptr<DeviceBatchNormBwd<F64, F64, F64, F64, F64, F64, F64, PassThrough, 4, 3>>>&
|
||||
instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances, device_batchnorm_backward_f64_blockwise_instances<4, 3, PassThrough>{});
|
||||
add_device_operation_instances(
|
||||
instances, device_batchnorm_backward_f64_multiblock_instances<4, 3, PassThrough>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -27,6 +27,7 @@ set(PROFILER_SOURCE
|
||||
src/profile_layernorm.cpp
|
||||
src/profile_softmax.cpp
|
||||
src/profile_batchnorm_fwd.cpp
|
||||
src/profile_batchnorm_bwd.cpp
|
||||
)
|
||||
|
||||
add_executable(ckProfiler ${PROFILER_SOURCE})
|
||||
|
||||
390
profiler/include/profile_batchnorm_backward_impl.hpp
Normal file
390
profiler/include/profile_batchnorm_backward_impl.hpp
Normal file
@@ -0,0 +1,390 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <iomanip>
|
||||
#include <stdexcept>
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
|
||||
#include "ck/library/utility/check_err.hpp"
|
||||
#include "ck/library/utility/device_memory.hpp"
|
||||
#include "ck/library/utility/host_tensor.hpp"
|
||||
#include "ck/library/utility/host_tensor_generator.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/batchnorm_backward.hpp"
|
||||
#include "ck/library/reference_tensor_operation/cpu/reference_batchnorm_backward.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace profiler {
|
||||
|
||||
template <typename XDataType,
|
||||
typename DxDataType,
|
||||
typename DyDataType,
|
||||
typename AccDataType,
|
||||
typename ScaleDataType,
|
||||
typename DscaleDbiasDataType,
|
||||
typename MeanVarDataType,
|
||||
index_t Rank,
|
||||
index_t NumBatchNormReduceDim>
|
||||
bool profile_batchnorm_backward_impl(bool do_verification,
|
||||
int init_method,
|
||||
bool do_dumpout,
|
||||
bool time_kernel,
|
||||
const std::vector<size_t> inOutLengths,
|
||||
const std::vector<int> reduceDims,
|
||||
bool haveSavedMeanInvVar,
|
||||
double epsilon)
|
||||
{
|
||||
if(inOutLengths.size() != Rank || reduceDims.size() != NumBatchNormReduceDim)
|
||||
{
|
||||
throw std::runtime_error("Invalid tensor lengths or number of reduce dimensions!");
|
||||
};
|
||||
|
||||
std::vector<size_t> scaleBiasMeanVarLengths;
|
||||
|
||||
// used for calculating the effective transferred bytes by each operation
|
||||
size_t total_length;
|
||||
size_t invariant_length = 1;
|
||||
|
||||
total_length =
|
||||
std::accumulate(inOutLengths.begin(), inOutLengths.end(), 1, std::multiplies<size_t>{});
|
||||
|
||||
if(std::any_of(reduceDims.begin(), reduceDims.end(), [](int d) { return d < 0 || d >= Rank; }))
|
||||
throw std::runtime_error("Invalid reduce dimensions!");
|
||||
|
||||
for(int dim = 0; dim < Rank; dim++)
|
||||
{
|
||||
if(std::none_of(reduceDims.begin(), reduceDims.end(), [&](int d) { return dim == d; }))
|
||||
{
|
||||
scaleBiasMeanVarLengths.push_back(inOutLengths[dim]);
|
||||
invariant_length *= inOutLengths[dim];
|
||||
};
|
||||
}
|
||||
|
||||
// input data of the batchnorm backward algorithm
|
||||
Tensor<XDataType> x(inOutLengths);
|
||||
Tensor<DyDataType> dy(inOutLengths);
|
||||
Tensor<ScaleDataType> bnScale(scaleBiasMeanVarLengths);
|
||||
|
||||
Tensor<MeanVarDataType> savedMean(scaleBiasMeanVarLengths);
|
||||
Tensor<MeanVarDataType> savedInvVar(scaleBiasMeanVarLengths);
|
||||
// savedVariance is only used for initializing savedInvVar
|
||||
Tensor<MeanVarDataType> savedVariance(scaleBiasMeanVarLengths);
|
||||
|
||||
// output data of the batchnorm backward algorithm
|
||||
Tensor<DxDataType> dx_ref(inOutLengths);
|
||||
Tensor<DxDataType> dx(inOutLengths);
|
||||
|
||||
Tensor<DscaleDbiasDataType> dscale(scaleBiasMeanVarLengths);
|
||||
Tensor<DscaleDbiasDataType> dbias(scaleBiasMeanVarLengths);
|
||||
|
||||
Tensor<DscaleDbiasDataType> dscale_ref(scaleBiasMeanVarLengths);
|
||||
Tensor<DscaleDbiasDataType> dbias_ref(scaleBiasMeanVarLengths);
|
||||
|
||||
auto inOutStrides = x.mDesc.GetStrides();
|
||||
auto scaleBiasMeanVarStrides = bnScale.mDesc.GetStrides();
|
||||
|
||||
std::size_t num_thread = std::thread::hardware_concurrency();
|
||||
|
||||
if(haveSavedMeanInvVar)
|
||||
{
|
||||
const float x_mean = 0.0f;
|
||||
const float x_stddev = 1.0f;
|
||||
const float noise_stddev = 0.0001f;
|
||||
|
||||
// input data in normal distribution
|
||||
x.GenerateTensorValue(GeneratorTensor_4<XDataType>{x_mean, x_stddev}, num_thread);
|
||||
|
||||
// initialize the savedMean to be values with tiny variation to the mean of the x values
|
||||
savedMean.GenerateTensorValue(GeneratorTensor_4<MeanVarDataType>{x_mean, noise_stddev},
|
||||
num_thread);
|
||||
|
||||
// initialize the variance to be values with tiny variation to the variance of the x values
|
||||
savedVariance.GenerateTensorValue(
|
||||
GeneratorTensor_4<MeanVarDataType>{x_stddev * x_stddev, noise_stddev}, num_thread);
|
||||
|
||||
auto it_src = savedVariance.mData.begin();
|
||||
auto it_dst = savedInvVar.mData.begin();
|
||||
float tmp_epsilon = std::numeric_limits<float>::epsilon();
|
||||
|
||||
while(it_src != savedVariance.mData.end())
|
||||
{
|
||||
*it_dst = type_convert<AccDataType>(
|
||||
1.0f / std::sqrtf(type_convert<float>(*it_src) + tmp_epsilon));
|
||||
|
||||
it_src++;
|
||||
it_dst++;
|
||||
};
|
||||
}
|
||||
else
|
||||
{
|
||||
const float x_mean = 0.0f;
|
||||
const float x_stddev = 1.0f;
|
||||
|
||||
// input data in normal distribution
|
||||
x.GenerateTensorValue(GeneratorTensor_4<XDataType>{x_mean, x_stddev}, num_thread);
|
||||
};
|
||||
|
||||
if(do_verification)
|
||||
{
|
||||
switch(init_method)
|
||||
{
|
||||
case 0:
|
||||
dy.GenerateTensorValue(GeneratorTensor_0<DyDataType>{}, num_thread);
|
||||
bnScale.GenerateTensorValue(GeneratorTensor_0<ScaleDataType>{}, num_thread);
|
||||
break;
|
||||
case 1:
|
||||
dy.GenerateTensorValue(GeneratorTensor_1<DyDataType>{1}, num_thread);
|
||||
bnScale.GenerateTensorValue(GeneratorTensor_1<ScaleDataType>{1}, num_thread);
|
||||
break;
|
||||
case 2:
|
||||
dy.GenerateTensorValue(GeneratorTensor_2<DyDataType>{-2, 2}, num_thread);
|
||||
bnScale.GenerateTensorValue(GeneratorTensor_2<ScaleDataType>{-5, 5}, num_thread);
|
||||
break;
|
||||
default:
|
||||
dy.GenerateTensorValue(GeneratorTensor_3<DyDataType>{-0.2f, 0.2f}, num_thread);
|
||||
bnScale.GenerateTensorValue(GeneratorTensor_3<ScaleDataType>{-0.5f, 0.5f}, num_thread);
|
||||
}
|
||||
};
|
||||
|
||||
// input data of the batchnorm backward algorithm
|
||||
DeviceMem x_dev(sizeof(XDataType) * x.mDesc.GetElementSpaceSize());
|
||||
DeviceMem dy_dev(sizeof(DyDataType) * dy.mDesc.GetElementSpaceSize());
|
||||
|
||||
DeviceMem bnScale_dev(sizeof(ScaleDataType) * bnScale.mDesc.GetElementSpaceSize());
|
||||
|
||||
DeviceMem savedMean_dev(sizeof(MeanVarDataType) * savedMean.mDesc.GetElementSpaceSize());
|
||||
DeviceMem savedInvVar_dev(sizeof(MeanVarDataType) * savedInvVar.mDesc.GetElementSpaceSize());
|
||||
|
||||
// output data of the batchnorm backward algorithm
|
||||
DeviceMem dx_dev(sizeof(DxDataType) * dx.mDesc.GetElementSpaceSize());
|
||||
|
||||
DeviceMem dscale_dev(sizeof(DscaleDbiasDataType) * dscale.mDesc.GetElementSpaceSize());
|
||||
DeviceMem dbias_dev(sizeof(DscaleDbiasDataType) * dbias.mDesc.GetElementSpaceSize());
|
||||
|
||||
x_dev.ToDevice(x.mData.data());
|
||||
dy_dev.ToDevice(dy.mData.data());
|
||||
bnScale_dev.ToDevice(bnScale.mData.data());
|
||||
|
||||
if(haveSavedMeanInvVar)
|
||||
{
|
||||
savedMean_dev.ToDevice(savedMean.mData.data());
|
||||
savedInvVar_dev.ToDevice(savedInvVar.mData.data());
|
||||
};
|
||||
|
||||
std::array<index_t, Rank> arrInOutLengths;
|
||||
std::array<index_t, Rank> arrInOutStrides;
|
||||
std::array<index_t, Rank - NumBatchNormReduceDim> arrScaleBiasMeanVarLengths;
|
||||
std::array<index_t, Rank - NumBatchNormReduceDim> arrScaleBiasMeanVarStrides;
|
||||
std::array<int, NumBatchNormReduceDim> arrReduceDims;
|
||||
|
||||
std::copy(inOutLengths.begin(), inOutLengths.end(), arrInOutLengths.begin());
|
||||
std::copy(inOutStrides.begin(), inOutStrides.end(), arrInOutStrides.begin());
|
||||
std::copy(scaleBiasMeanVarLengths.begin(),
|
||||
scaleBiasMeanVarLengths.end(),
|
||||
arrScaleBiasMeanVarLengths.begin());
|
||||
std::copy(scaleBiasMeanVarStrides.begin(),
|
||||
scaleBiasMeanVarStrides.end(),
|
||||
arrScaleBiasMeanVarStrides.begin());
|
||||
|
||||
std::copy(reduceDims.begin(), reduceDims.end(), arrReduceDims.begin());
|
||||
|
||||
using PassThroughOp = ck::tensor_operation::element_wise::PassThrough;
|
||||
|
||||
// add device batchnorm-backward instances
|
||||
using DeviceOp = ck::tensor_operation::device::DeviceBatchNormBwd<XDataType,
|
||||
DxDataType,
|
||||
DxDataType,
|
||||
AccDataType,
|
||||
ScaleDataType,
|
||||
DscaleDbiasDataType,
|
||||
MeanVarDataType,
|
||||
PassThroughOp,
|
||||
Rank,
|
||||
NumBatchNormReduceDim>;
|
||||
|
||||
// get device op instances
|
||||
const auto instance_ptrs =
|
||||
ck::tensor_operation::device::instance::DeviceOperationInstanceFactory<
|
||||
DeviceOp>::GetInstances();
|
||||
|
||||
std::cout << "found " << instance_ptrs.size() << " instances" << std::endl;
|
||||
|
||||
std::string best_instance_name;
|
||||
float best_avg_time = std::numeric_limits<float>::max();
|
||||
float best_gb_per_sec = 0;
|
||||
|
||||
if(do_verification)
|
||||
{
|
||||
using ReferenceBatchNormBwdInstance =
|
||||
ck::tensor_operation::host::ReferenceBatchNormBwd<XDataType,
|
||||
DxDataType,
|
||||
DyDataType,
|
||||
AccDataType,
|
||||
ScaleDataType,
|
||||
DscaleDbiasDataType,
|
||||
MeanVarDataType,
|
||||
PassThroughOp,
|
||||
Rank,
|
||||
NumBatchNormReduceDim>;
|
||||
|
||||
auto batchNormBwd_ref = ReferenceBatchNormBwdInstance{};
|
||||
|
||||
auto argument_ptr_ref = batchNormBwd_ref.MakeArgumentPointer(
|
||||
arrInOutLengths,
|
||||
arrInOutStrides,
|
||||
arrInOutStrides,
|
||||
arrInOutStrides,
|
||||
arrReduceDims,
|
||||
arrScaleBiasMeanVarLengths,
|
||||
arrScaleBiasMeanVarStrides,
|
||||
arrScaleBiasMeanVarStrides,
|
||||
arrScaleBiasMeanVarStrides,
|
||||
x.mData.data(),
|
||||
dy.mData.data(),
|
||||
bnScale.mData.data(),
|
||||
haveSavedMeanInvVar ? savedMean.mData.data() : nullptr,
|
||||
haveSavedMeanInvVar ? savedInvVar.mData.data() : nullptr,
|
||||
epsilon,
|
||||
PassThroughOp{},
|
||||
dx_ref.mData.data(),
|
||||
dscale_ref.mData.data(),
|
||||
dbias_ref.mData.data());
|
||||
|
||||
if(!batchNormBwd_ref.IsSupportedArgument(argument_ptr_ref.get()))
|
||||
{
|
||||
std::cout << "The runtime parameters not supported by the reference instance, exiting!"
|
||||
<< std::endl;
|
||||
return (false);
|
||||
};
|
||||
|
||||
auto invoker_ptr_ref = batchNormBwd_ref.MakeInvokerPointer();
|
||||
|
||||
(void)invoker_ptr_ref->Run(argument_ptr_ref.get());
|
||||
}
|
||||
|
||||
int num_kernel = 0;
|
||||
bool pass = true;
|
||||
|
||||
for(auto& inst_ptr : instance_ptrs)
|
||||
{
|
||||
auto argument_ptr = inst_ptr->MakeArgumentPointer(
|
||||
arrInOutLengths,
|
||||
arrInOutStrides,
|
||||
arrInOutStrides,
|
||||
arrInOutStrides,
|
||||
arrReduceDims,
|
||||
arrScaleBiasMeanVarLengths,
|
||||
arrScaleBiasMeanVarStrides,
|
||||
arrScaleBiasMeanVarStrides,
|
||||
arrScaleBiasMeanVarStrides,
|
||||
x_dev.GetDeviceBuffer(),
|
||||
dy_dev.GetDeviceBuffer(),
|
||||
bnScale_dev.GetDeviceBuffer(),
|
||||
haveSavedMeanInvVar ? savedMean_dev.GetDeviceBuffer() : nullptr,
|
||||
haveSavedMeanInvVar ? savedInvVar_dev.GetDeviceBuffer() : nullptr,
|
||||
epsilon,
|
||||
PassThroughOp{},
|
||||
dx_dev.GetDeviceBuffer(),
|
||||
dscale_dev.GetDeviceBuffer(),
|
||||
dbias_dev.GetDeviceBuffer());
|
||||
|
||||
if(inst_ptr->IsSupportedArgument(argument_ptr.get()))
|
||||
{
|
||||
num_kernel++;
|
||||
}
|
||||
else
|
||||
{
|
||||
if(time_kernel)
|
||||
{
|
||||
std::cout << inst_ptr->GetTypeString()
|
||||
<< " skipped due to unsupported argument: " << std::endl;
|
||||
}
|
||||
|
||||
continue;
|
||||
};
|
||||
|
||||
size_t workspace_sz = inst_ptr->GetWorkSpaceSize(argument_ptr.get());
|
||||
|
||||
DeviceMem workspace_dev(workspace_sz);
|
||||
|
||||
inst_ptr->SetWorkSpacePointer(argument_ptr.get(), workspace_dev.GetDeviceBuffer());
|
||||
|
||||
auto invoker_ptr = inst_ptr->MakeInvokerPointer();
|
||||
|
||||
float avg_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel});
|
||||
|
||||
size_t num_bytes = 0;
|
||||
|
||||
// inputing of x, dy, scale, outputing of dx, dscale, dbias
|
||||
num_bytes += total_length * (sizeof(XDataType) + sizeof(DyDataType) + sizeof(DxDataType)) +
|
||||
invariant_length * sizeof(DscaleDbiasDataType) * 2;
|
||||
|
||||
// inputting of savedMean, savedInvVariance
|
||||
if(haveSavedMeanInvVar)
|
||||
num_bytes += invariant_length * sizeof(MeanVarDataType) * 2;
|
||||
|
||||
float gb_per_sec = num_bytes / 1.E6 / avg_time;
|
||||
|
||||
if(time_kernel)
|
||||
std::cout << "Perf: " << avg_time << " ms, " << gb_per_sec << " GB/s, "
|
||||
<< inst_ptr->GetTypeString() << std::endl;
|
||||
|
||||
if(avg_time < best_avg_time)
|
||||
{
|
||||
best_instance_name = inst_ptr->GetTypeString();
|
||||
best_avg_time = avg_time;
|
||||
best_gb_per_sec = gb_per_sec;
|
||||
}
|
||||
|
||||
if(do_verification)
|
||||
{
|
||||
using ck::utils::check_err;
|
||||
bool single_pass = true;
|
||||
|
||||
dx_dev.FromDevice(dx.mData.data());
|
||||
dscale_dev.FromDevice(dscale.data());
|
||||
dbias_dev.FromDevice(dbias.data());
|
||||
|
||||
// clang-format off
|
||||
single_pass = single_pass && ck::utils::check_err(dx.mData, dx_ref.mData, "dx result:", 5e-4, 5e-4);
|
||||
single_pass = single_pass && ck::utils::check_err(dscale.mData, dscale_ref.mData, "dScale result:", 3e-3, 3e-3);
|
||||
single_pass = single_pass && ck::utils::check_err(dbias.mData, dbias_ref.mData, "dBias result:", 3e-3, 3e-3);
|
||||
// clang-format on
|
||||
|
||||
pass = pass && single_pass;
|
||||
};
|
||||
|
||||
if(do_dumpout)
|
||||
{
|
||||
using ck::host_common::dumpBufferToFile;
|
||||
|
||||
// clang-format off
|
||||
dumpBufferToFile("dump_x.bin", x.mData.data(), x.mDesc.GetElementSize());
|
||||
dumpBufferToFile("dump_dy.bin", dy.mData.data(), dy.mDesc.GetElementSize());
|
||||
dumpBufferToFile("dump_dx.bin", dx.mData.data(), dx.mDesc.GetElementSize());
|
||||
dumpBufferToFile("dump_dx_ref.bin", dx_ref.mData.data(), dx_ref.mDesc.GetElementSize());
|
||||
dumpBufferToFile("dump_dscale.bin", dscale.mData.data(), dscale.mDesc.GetElementSize());
|
||||
dumpBufferToFile("dump_dscale_ref.bin", dscale_ref.mData.data(), dscale_ref.mDesc.GetElementSize());
|
||||
// clang-format off
|
||||
};
|
||||
}
|
||||
|
||||
if(time_kernel)
|
||||
{
|
||||
std::cout << "best perf = " << best_avg_time << " ms, " << best_gb_per_sec << " GB/s, "
|
||||
<< best_instance_name << std::endl;
|
||||
}
|
||||
|
||||
if(num_kernel == 0)
|
||||
{
|
||||
std::cout << "Error: No kernel is applicable" << std::endl;
|
||||
return false;
|
||||
}
|
||||
|
||||
return pass;
|
||||
}
|
||||
|
||||
} // namespace profiler
|
||||
} // namespace ck
|
||||
204
profiler/src/profile_batchnorm_bwd.cpp
Normal file
204
profiler/src/profile_batchnorm_bwd.cpp
Normal file
@@ -0,0 +1,204 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <iostream>
|
||||
#include <vector>
|
||||
#include <getopt.h>
|
||||
|
||||
#include "ck/library/utility/host_common_util.hpp"
|
||||
#include "profiler/include/profile_batchnorm_backward_impl.hpp"
|
||||
|
||||
using ck::index_t;
|
||||
|
||||
using namespace std;
|
||||
|
||||
// getopt_long() option table for the batchnorm-backward profiler.
// The trailing all-null entry terminates the table, as getopt_long requires.
static const struct option long_options[] = {
    {"inOutLengths", required_argument, nullptr, 'D'}, // comma separated tensor lengths
    {"reduceDims", required_argument, nullptr, 'R'},   // comma separated reduced dimensions
    {"dumpout", required_argument, nullptr, 'o'},      // 1/0: dump buffers to files
    {"verify", required_argument, nullptr, 'v'},       // 1/0: verify against host reference
    {"help", no_argument, nullptr, '?'},
    {nullptr, 0, nullptr, 0}};
|
||||
|
||||
class BatchnormBwdArgParser
|
||||
{
|
||||
private:
|
||||
int option_index = 0;
|
||||
|
||||
public:
|
||||
std::vector<size_t> inLengths;
|
||||
std::vector<int> reduceDims;
|
||||
|
||||
bool do_verification = false;
|
||||
bool do_dumpout = false;
|
||||
|
||||
bool haveSavedMeanInvVar;
|
||||
|
||||
int data_type = 0;
|
||||
int init_method = 2;
|
||||
bool time_kernel = false;
|
||||
|
||||
BatchnormBwdArgParser() = default;
|
||||
~BatchnormBwdArgParser() = default;
|
||||
|
||||
void show_usage(const char* cmd)
|
||||
{
|
||||
// clang-format off
|
||||
std::cout << "Usage of " << cmd << std::endl;
|
||||
std::cout << "--inOutLengths or -D, comma separated list of input tensor dimension lengths, must have 4 integers for nhwc" << std::endl;
|
||||
std::cout << "--reduceDims or -R, comma separated list of dimensions to reduce on" << std::endl;
|
||||
std::cout << "--verify or -v, 1/0 to indicate whether to verify the result by comparing with the host-based batch-normalization" << std::endl;
|
||||
std::cout << "Arg1: data type (0: fp16, 1: fp32, 5: bp16, 6: fp64)" << std::endl;
|
||||
std::cout << "Arg2 -- 1/0 to indicate whether to use saved mean and invVariance" << std::endl;
|
||||
std::cout << "Arg3 -- init method used for dy and bnScale (0=no init, 1=single integer value, 2=scope integer value, 3=decimal value)" << std::endl;
|
||||
std::cout << "Arg4 -- time kernel (0=no, 1=yes)" << std::endl;
|
||||
// clang-format on
|
||||
};
|
||||
|
||||
int operator()(int argc, char* argv[])
|
||||
{
|
||||
using ck::host_common::getTypeValuesFromString;
|
||||
|
||||
int ch;
|
||||
|
||||
optind++; // to skip the module name
|
||||
|
||||
while(1)
|
||||
{
|
||||
ch = getopt_long(argc, argv, "D:R:v:o:", long_options, &option_index);
|
||||
if(ch == -1)
|
||||
break;
|
||||
switch(ch)
|
||||
{
|
||||
case 'D':
|
||||
if(!optarg)
|
||||
throw std::runtime_error("Invalid option format!");
|
||||
|
||||
inLengths = getTypeValuesFromString<size_t>(optarg);
|
||||
break;
|
||||
case 'R':
|
||||
if(!optarg)
|
||||
throw std::runtime_error("Invalid option format!");
|
||||
|
||||
reduceDims = getTypeValuesFromString<int>(optarg);
|
||||
break;
|
||||
case 'v':
|
||||
if(!optarg)
|
||||
throw std::runtime_error("Invalid option format!");
|
||||
|
||||
do_verification = static_cast<bool>(std::atoi(optarg));
|
||||
break;
|
||||
case 'o':
|
||||
if(!optarg)
|
||||
throw std::runtime_error("Invalid option format!");
|
||||
|
||||
do_dumpout = static_cast<bool>(std::atoi(optarg));
|
||||
break;
|
||||
case '?':
|
||||
if(std::string(long_options[option_index].name) == "help")
|
||||
{
|
||||
show_usage(argv[0]);
|
||||
return -1;
|
||||
};
|
||||
break;
|
||||
|
||||
default:
|
||||
show_usage(argv[0]);
|
||||
std::cerr << "Invalid cmd-line options!" << std::endl;
|
||||
return -1;
|
||||
};
|
||||
};
|
||||
|
||||
if(optind + 4 > argc)
|
||||
throw std::runtime_error("Invalid cmd-line arguments, more argumetns are needed!");
|
||||
|
||||
data_type = std::atoi(argv[optind++]);
|
||||
haveSavedMeanInvVar = std::atoi(argv[optind++]);
|
||||
init_method = std::atoi(argv[optind++]);
|
||||
time_kernel = static_cast<bool>(std::atoi(argv[optind++]));
|
||||
|
||||
if(data_type != 0 && data_type != 1 && data_type != 3 && data_type != 5 && data_type != 6)
|
||||
return -1;
|
||||
|
||||
return 0;
|
||||
};
|
||||
}; // end of class AppArgs
|
||||
|
||||
static const double epsilon = std::numeric_limits<float>::epsilon();
|
||||
|
||||
int profile_batchnorm_backward(int argc, char* argv[])
|
||||
{
|
||||
using ck::profiler::profile_batchnorm_backward_impl;
|
||||
|
||||
BatchnormBwdArgParser arg_parser;
|
||||
|
||||
if(arg_parser(argc, argv) != 0)
|
||||
return -1;
|
||||
|
||||
using F16 = ck::half_t;
|
||||
using F32 = float;
|
||||
using BF16 = ck::bhalf_t;
|
||||
using F64 = double;
|
||||
|
||||
if(arg_parser.data_type == 0)
|
||||
{
|
||||
if(arg_parser.inLengths.size() == 4 && arg_parser.reduceDims.size() == 3)
|
||||
{
|
||||
profile_batchnorm_backward_impl<F16, F32, F32, F32, F16, F32, F32, 4, 3>(
|
||||
arg_parser.do_verification,
|
||||
arg_parser.init_method,
|
||||
arg_parser.do_dumpout,
|
||||
arg_parser.time_kernel,
|
||||
arg_parser.inLengths,
|
||||
arg_parser.reduceDims,
|
||||
arg_parser.haveSavedMeanInvVar,
|
||||
epsilon);
|
||||
};
|
||||
}
|
||||
else if(arg_parser.data_type == 1)
|
||||
{
|
||||
if(arg_parser.inLengths.size() == 4 && arg_parser.reduceDims.size() == 3)
|
||||
{
|
||||
profile_batchnorm_backward_impl<F32, F32, F32, F32, F32, F32, F32, 4, 3>(
|
||||
arg_parser.do_verification,
|
||||
arg_parser.init_method,
|
||||
arg_parser.do_dumpout,
|
||||
arg_parser.time_kernel,
|
||||
arg_parser.inLengths,
|
||||
arg_parser.reduceDims,
|
||||
arg_parser.haveSavedMeanInvVar,
|
||||
epsilon);
|
||||
};
|
||||
}
|
||||
else if(arg_parser.data_type == 5)
|
||||
{
|
||||
if(arg_parser.inLengths.size() == 4 && arg_parser.reduceDims.size() == 3)
|
||||
{
|
||||
profile_batchnorm_backward_impl<BF16, F32, F32, F32, BF16, F32, F32, 4, 3>(
|
||||
arg_parser.do_verification,
|
||||
arg_parser.init_method,
|
||||
arg_parser.do_dumpout,
|
||||
arg_parser.time_kernel,
|
||||
arg_parser.inLengths,
|
||||
arg_parser.reduceDims,
|
||||
arg_parser.haveSavedMeanInvVar,
|
||||
epsilon);
|
||||
};
|
||||
}
|
||||
else if(arg_parser.data_type == 6)
|
||||
{
|
||||
if(arg_parser.inLengths.size() == 4 && arg_parser.reduceDims.size() == 3)
|
||||
{
|
||||
profile_batchnorm_backward_impl<F64, F64, F64, F64, F64, F64, F64, 4, 3>(
|
||||
arg_parser.do_verification,
|
||||
arg_parser.init_method,
|
||||
arg_parser.do_dumpout,
|
||||
arg_parser.time_kernel,
|
||||
arg_parser.inLengths,
|
||||
arg_parser.reduceDims,
|
||||
arg_parser.haveSavedMeanInvVar,
|
||||
epsilon);
|
||||
};
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
||||
@@ -25,6 +25,7 @@ int profile_layernorm(int, char*[]);
|
||||
int profile_groupnorm(int, char*[]);
|
||||
int profile_reduce(int, char*[]);
|
||||
int profile_batchnorm_forward(int, char*[]);
|
||||
int profile_batchnorm_backward(int, char*[]);
|
||||
|
||||
static void print_helper_message()
|
||||
{
|
||||
@@ -148,6 +149,10 @@ int main(int argc, char* argv[])
|
||||
{
|
||||
return profile_batchnorm_forward(argc, argv);
|
||||
}
|
||||
else if(strcmp(argv[1], "bnorm_bwd") == 0)
|
||||
{
|
||||
return profile_batchnorm_backward(argc, argv);
|
||||
}
|
||||
else
|
||||
{
|
||||
print_helper_message();
|
||||
|
||||
@@ -53,4 +53,4 @@ add_subdirectory(softmax)
|
||||
add_subdirectory(normalization)
|
||||
add_subdirectory(data_type)
|
||||
add_subdirectory(elementwise_normalization)
|
||||
add_subdirectory(batchnorm_fwd)
|
||||
add_subdirectory(batchnorm)
|
||||
|
||||
@@ -1,2 +1,4 @@
|
||||
add_gtest_executable(test_batchnorm_fwd_rank_4 batchnorm_fwd_rank_4.cpp)
|
||||
add_gtest_executable(test_batchnorm_bwd_rank_4 batchnorm_bwd_rank_4.cpp)
|
||||
target_link_libraries(test_batchnorm_fwd_rank_4 PRIVATE utility device_batchnorm_instance)
|
||||
target_link_libraries(test_batchnorm_bwd_rank_4 PRIVATE utility device_batchnorm_instance)
|
||||
92
test/batchnorm/batchnorm_bwd_rank_4.cpp
Normal file
92
test/batchnorm/batchnorm_bwd_rank_4.cpp
Normal file
@@ -0,0 +1,92 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <cstdlib>
|
||||
#include <iostream>
|
||||
#include <initializer_list>
|
||||
#include <vector>
|
||||
#include <tuple>
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
#include "profiler/include/profile_batchnorm_backward_impl.hpp"
|
||||
|
||||
using F16 = ck::half_t;
|
||||
using F32 = float;
|
||||
using BF16 = ck::bhalf_t;
|
||||
using F64 = double;
|
||||
|
||||
template <typename Tuple>
|
||||
class TestBatchNormBwdRank4 : public ::testing::Test
|
||||
{
|
||||
private:
|
||||
const double epsilon = std::numeric_limits<float>::epsilon();
|
||||
|
||||
protected:
|
||||
using XDataType = std::tuple_element_t<0, Tuple>;
|
||||
using DxDataType = std::tuple_element_t<1, Tuple>;
|
||||
using DyDataType = std::tuple_element_t<2, Tuple>;
|
||||
using AccDataType = std::tuple_element_t<3, Tuple>;
|
||||
using ScaleDataType = std::tuple_element_t<4, Tuple>;
|
||||
using BiasDataType = std::tuple_element_t<5, Tuple>;
|
||||
using MeanVarDataType = std::tuple_element_t<6, Tuple>;
|
||||
|
||||
std::vector<std::vector<size_t>> list_of_lengths = {
|
||||
{128, 16, 3, 1024}, {128, 16, 6, 512}, {1, 1, 1, 1}, {4, 4, 4, 4}, {32, 32, 32, 32}};
|
||||
std::vector<int> reduceDims;
|
||||
|
||||
template <int NumReduceDim>
|
||||
void Run()
|
||||
{
|
||||
for(auto& inOutLengths : list_of_lengths)
|
||||
{
|
||||
bool pass = true;
|
||||
|
||||
EXPECT_FALSE(reduceDims.size() != NumReduceDim);
|
||||
|
||||
pass = pass && ck::profiler::profile_batchnorm_backward_impl<XDataType,
|
||||
DxDataType,
|
||||
DyDataType,
|
||||
AccDataType,
|
||||
ScaleDataType,
|
||||
BiasDataType,
|
||||
MeanVarDataType,
|
||||
4,
|
||||
NumReduceDim>(
|
||||
true, 3, false, false, inOutLengths, reduceDims, true, epsilon);
|
||||
|
||||
pass = pass && ck::profiler::profile_batchnorm_backward_impl<XDataType,
|
||||
DxDataType,
|
||||
DyDataType,
|
||||
AccDataType,
|
||||
ScaleDataType,
|
||||
BiasDataType,
|
||||
MeanVarDataType,
|
||||
4,
|
||||
NumReduceDim>(
|
||||
true, 3, false, false, inOutLengths, reduceDims, false, epsilon);
|
||||
|
||||
EXPECT_TRUE(pass);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
using KernelTypes = ::testing::Types<std::tuple<F16, F32, F32, F32, F16, F32, F32>,
|
||||
std::tuple<F32, F32, F32, F32, F32, F32, F32>,
|
||||
std::tuple<BF16, F32, F32, F32, BF16, F32, F32>,
|
||||
std::tuple<F64, F64, F64, F64, F64, F64, F64>>;
|
||||
|
||||
TYPED_TEST_SUITE(TestBatchNormBwdRank4, KernelTypes);
|
||||
|
||||
// nhwc
|
||||
TYPED_TEST(TestBatchNormBwdRank4, nhwc)
|
||||
{
|
||||
this->reduceDims = {0, 1, 2};
|
||||
this->template Run<3>();
|
||||
}
|
||||
|
||||
// nchw
|
||||
TYPED_TEST(TestBatchNormBwdRank4, nchw)
|
||||
{
|
||||
this->reduceDims = {0, 2, 3};
|
||||
this->template Run<3>();
|
||||
}
|
||||
Reference in New Issue
Block a user