mirror of https://github.com/ROCm/composable_kernel.git
BatchNorm backward instance/external API/profiler/tests (#519)
* Refine the device batchnorm-backward base API templates and data type assignments
* Remove duplicated kernel file
* Add batchnorm backward instances and external API
* Add batchnorm-backward profiler and tests
* Add client example which uses batchnorm backward external API
* Merge test/batchnorm_fwd and test/batchnorm_bwd into one directory
* Loosen the threshold for batchnorm-backward check_err()
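For orientation before the hunks: the tests verify the standard batchnorm-backward gradients, reduced over the N, H, W dimensions of an NHWC tensor. Below is a minimal host-side sketch of that math, assuming saved mean and inverse variance are supplied (as in the test); the names and layout handling are illustrative, not CK's reference implementation.

    #include <cstddef>
    #include <vector>

    // Minimal NHWC batchnorm-backward reference (illustrative only).
    // For y = scale * (x - mean) * invVar + bias, with M = N * H * W:
    //   dbias[c]  = sum_i dy[i]
    //   dscale[c] = sum_i dy[i] * (x[i] - mean[c]) * invVar[c]
    //   dx[i]     = scale[c] * invVar[c]
    //             * (dy[i] - dbias[c]/M - (x[i] - mean[c]) * invVar[c] * dscale[c]/M)
    void batchnorm_bwd_nhwc_ref(const std::vector<float>& x,
                                const std::vector<float>& dy,
                                const std::vector<float>& scale,
                                const std::vector<float>& mean,
                                const std::vector<float>& invVar,
                                std::size_t reduce_size, // M = N * H * W
                                std::size_t C,
                                std::vector<float>& dx,
                                std::vector<float>& dscale,
                                std::vector<float>& dbias)
    {
        const float M = static_cast<float>(reduce_size);
        for(std::size_t c = 0; c < C; ++c)
        {
            float db = 0.f, ds = 0.f;
            for(std::size_t m = 0; m < reduce_size; ++m)
            {
                const std::size_t i = m * C + c; // NHWC: channel is innermost
                db += dy[i];
                ds += dy[i] * (x[i] - mean[c]) * invVar[c];
            }
            dbias[c]  = db;
            dscale[c] = ds;

            for(std::size_t m = 0; m < reduce_size; ++m)
            {
                const std::size_t i = m * C + c;
                dx[i] = scale[c] * invVar[c] *
                        (dy[i] - db / M - (x[i] - mean[c]) * invVar[c] * ds / M);
            }
        }
    }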
@@ -11,7 +11,7 @@
 #include "ck/library/utility/host_tensor.hpp"
 #include "ck/library/utility/host_tensor_generator.hpp"
 #include "ck/library/utility/host_common_util.hpp"
-#include "ck/library/reference_tensor_operation/cpu/reference_batchnorm_backward_nhwc_c.hpp"
+#include "ck/library/reference_tensor_operation/cpu/reference_batchnorm_backward.hpp"
 #include "ck/tensor_operation/gpu/device/impl/device_batchnorm_backward_impl.hpp"
 
 static struct option long_options[] = {{"inOutLengths", required_argument, nullptr, 'D'},
@@ -106,7 +106,7 @@ class BatchNormBwdArg
 
 using namespace ck;
 
-template <typename InOutDataType, typename AccDataType, bool UseMultiblockInK>
+template <typename XDataType, typename AccDataType, bool UseMultiblockInK>
 bool bnorm_bwd_nhwc_test(bool do_verification,
                          int init_method,
                          bool time_kernel,
@@ -118,13 +118,15 @@ bool bnorm_bwd_nhwc_test(bool do_verification,
     constexpr index_t Rank         = 4;
     constexpr index_t NumReduceDim = 3;
 
+    using ScaleDataType = XDataType;
+
     const std::vector<size_t> scaleBiasMeanVarLengths = {inOutLengths[3]};
 
     // input data of the batchnorm backward algorithm
-    Tensor<InOutDataType> x(inOutLengths);
-    Tensor<InOutDataType> dy(inOutLengths);
+    Tensor<XDataType> x(inOutLengths);
+    Tensor<AccDataType> dy(inOutLengths);
 
-    Tensor<AccDataType> bnScale(scaleBiasMeanVarLengths);
+    Tensor<ScaleDataType> bnScale(scaleBiasMeanVarLengths);
 
     Tensor<AccDataType> savedMean(scaleBiasMeanVarLengths);
     Tensor<AccDataType> savedInvVar(scaleBiasMeanVarLengths);
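The new ScaleDataType alias ties the scale tensor's precision to the input x, while dy/dx and the dscale/dbias outputs stay in AccDataType. A compile-time illustration of the resulting assignment for a half-precision run (ck::half_t and the header path are assumptions about CK's fp16 type, not lines from this commit):

    #include <type_traits>
    #include "ck/utility/data_type.hpp" // assumed header for ck::half_t

    using XDataType     = ck::half_t; // storage precision of x
    using AccDataType   = float;      // accumulation/gradient precision
    using ScaleDataType = XDataType;  // scale now follows the input precision

    static_assert(std::is_same_v<ScaleDataType, ck::half_t>);
    static_assert(std::is_same_v<AccDataType, float>);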
@@ -132,8 +134,8 @@ bool bnorm_bwd_nhwc_test(bool do_verification,
     Tensor<AccDataType> savedVariance(scaleBiasMeanVarLengths);
 
     // output data of the batchnorm backward algorithm
-    Tensor<InOutDataType> dx_ref(inOutLengths);
-    Tensor<InOutDataType> dx(inOutLengths);
+    Tensor<AccDataType> dx_ref(inOutLengths);
+    Tensor<AccDataType> dx(inOutLengths);
 
     Tensor<AccDataType> dscale(scaleBiasMeanVarLengths);
     Tensor<AccDataType> dbias(scaleBiasMeanVarLengths);
@@ -153,7 +155,7 @@ bool bnorm_bwd_nhwc_test(bool do_verification,
        const float noise_stddev = 0.0001f;
 
        // input data in normal distribution
-       x.GenerateTensorValue(GeneratorTensor_4<InOutDataType>{x_mean, x_stddev}, num_thread);
+       x.GenerateTensorValue(GeneratorTensor_4<XDataType>{x_mean, x_stddev}, num_thread);
 
        // initialize the savedMean to be values with tiny variation to the mean of the x values
        savedMean.GenerateTensorValue(GeneratorTensor_4<AccDataType>{x_mean, noise_stddev},
@@ -182,7 +184,7 @@ bool bnorm_bwd_nhwc_test(bool do_verification,
        const float x_stddev = 1.0f;
 
        // input data in normal distribution
-       x.GenerateTensorValue(GeneratorTensor_4<InOutDataType>{x_mean, x_stddev}, num_thread);
+       x.GenerateTensorValue(GeneratorTensor_4<XDataType>{x_mean, x_stddev}, num_thread);
    };
 
    if(do_verification)
@@ -190,34 +192,34 @@ bool bnorm_bwd_nhwc_test(bool do_verification,
        switch(init_method)
        {
        case 0:
-           dy.GenerateTensorValue(GeneratorTensor_0<InOutDataType>{}, num_thread);
-           bnScale.GenerateTensorValue(GeneratorTensor_0<InOutDataType>{}, num_thread);
+           dy.GenerateTensorValue(GeneratorTensor_0<AccDataType>{}, num_thread);
+           bnScale.GenerateTensorValue(GeneratorTensor_0<ScaleDataType>{}, num_thread);
            break;
        case 1:
-           dy.GenerateTensorValue(GeneratorTensor_1<InOutDataType>{1}, num_thread);
-           bnScale.GenerateTensorValue(GeneratorTensor_1<InOutDataType>{1}, num_thread);
+           dy.GenerateTensorValue(GeneratorTensor_1<AccDataType>{1}, num_thread);
+           bnScale.GenerateTensorValue(GeneratorTensor_1<ScaleDataType>{1}, num_thread);
            break;
        case 2:
-           dy.GenerateTensorValue(GeneratorTensor_2<InOutDataType>{-5, 5}, num_thread);
-           bnScale.GenerateTensorValue(GeneratorTensor_2<InOutDataType>{-5, 5}, num_thread);
+           dy.GenerateTensorValue(GeneratorTensor_2<AccDataType>{-2, 2}, num_thread);
+           bnScale.GenerateTensorValue(GeneratorTensor_2<ScaleDataType>{-5, 5}, num_thread);
            break;
        default:
-           dy.GenerateTensorValue(GeneratorTensor_3<InOutDataType>{-0.2f, 0.2f}, num_thread);
-           bnScale.GenerateTensorValue(GeneratorTensor_3<InOutDataType>{-0.5f, 0.5f}, num_thread);
+           dy.GenerateTensorValue(GeneratorTensor_3<AccDataType>{-0.2f, 0.2f}, num_thread);
+           bnScale.GenerateTensorValue(GeneratorTensor_3<ScaleDataType>{-0.5f, 0.5f}, num_thread);
        }
    };
 
    // input data of the batchnorm backward algorithm
-   DeviceMem x_dev(sizeof(InOutDataType) * x.mDesc.GetElementSpaceSize());
-   DeviceMem dy_dev(sizeof(InOutDataType) * dy.mDesc.GetElementSpaceSize());
+   DeviceMem x_dev(sizeof(XDataType) * x.mDesc.GetElementSpaceSize());
+   DeviceMem dy_dev(sizeof(AccDataType) * dy.mDesc.GetElementSpaceSize());
 
-   DeviceMem bnScale_dev(sizeof(AccDataType) * bnScale.mDesc.GetElementSpaceSize());
+   DeviceMem bnScale_dev(sizeof(ScaleDataType) * bnScale.mDesc.GetElementSpaceSize());
 
    DeviceMem savedMean_dev(sizeof(AccDataType) * savedMean.mDesc.GetElementSpaceSize());
    DeviceMem savedInvVar_dev(sizeof(AccDataType) * savedInvVar.mDesc.GetElementSpaceSize());
 
    // output data of the batchnorm backward algorithm
-   DeviceMem dx_dev(sizeof(InOutDataType) * dx.mDesc.GetElementSpaceSize());
+   DeviceMem dx_dev(sizeof(AccDataType) * dx.mDesc.GetElementSpaceSize());
 
    DeviceMem dscale_dev(sizeof(AccDataType) * dscale.mDesc.GetElementSpaceSize());
    DeviceMem dbias_dev(sizeof(AccDataType) * dbias.mDesc.GetElementSpaceSize());
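The DeviceMem buffers above are sized in bytes from the host tensor descriptors; the usual CK example flow then stages data through them around the kernel launch. A sketch of that staging, under the assumption that the commit follows the Tensor::data() / DeviceMem::ToDevice / FromDevice pattern visible elsewhere in this diff (the helper name is hypothetical):

    #include "ck/library/utility/host_tensor.hpp"
    #include "ck/library/utility/device_memory.hpp"

    // Illustrative staging helper: upload inputs, (kernel runs elsewhere),
    // then download outputs for host-side verification.
    template <typename XDataType, typename AccDataType>
    void stage_io(Tensor<XDataType>& x, Tensor<AccDataType>& dy,
                  DeviceMem& x_dev, DeviceMem& dy_dev,
                  Tensor<AccDataType>& dx, DeviceMem& dx_dev)
    {
        x_dev.ToDevice(x.data());     // inputs go to device before Run()
        dy_dev.ToDevice(dy.data());
        // ... the batchnorm backward kernel runs here ...
        dx_dev.FromDevice(dx.data()); // outputs come back for check_err()
    }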
@@ -249,13 +251,13 @@ bool bnorm_bwd_nhwc_test(bool do_verification,
    using PassThroughOp = ck::tensor_operation::element_wise::PassThrough;
 
    using DeviceBatchNormBwdInstance =
-       ck::tensor_operation::device::DeviceBatchNormBwdImpl<InOutDataType,
-                                                            InOutDataType,
-                                                            InOutDataType,
+       ck::tensor_operation::device::DeviceBatchNormBwdImpl<XDataType,
                                                             AccDataType,
-                                                            AccDataType, // ScaleDataType
-                                                            AccDataType, // BiasDataType
-                                                            AccDataType, // MeanVarDataType
+                                                            AccDataType,
+                                                            AccDataType,
+                                                            ScaleDataType, // ScaleDataType
+                                                            AccDataType,   // DscaleDbiasDataType
+                                                            AccDataType,   // MeanVarDataType
                                                             PassThroughOp,
                                                             Rank,
                                                             NumReduceDim,
@@ -269,8 +271,8 @@ bool bnorm_bwd_nhwc_test(bool do_verification,
                                                             1, // XSrcVectorSize
                                                             1, // DySrcVectorSize
                                                             1, // DxDstVectorSize
-                                                            1, // ScaleSrcDstVectorSize
-                                                            1, // BiasDstVectorSize
+                                                            1, // ScaleSrcVectorSize
+                                                            1, // DscaleDbiasDstVectorSize
                                                             1>; // MeanVarSrcVectorSize
 
    auto batchnorm_bwd = DeviceBatchNormBwdInstance{};
@@ -324,7 +326,7 @@ bool bnorm_bwd_nhwc_test(bool do_verification,
 
    // inputing of x, dy, scale, outputing of dx, dscale, dbias
    num_bytes +=
-       total_length * sizeof(InOutDataType) * 3 + invariant_length * sizeof(AccDataType) * 3;
+       total_length * sizeof(XDataType) * 3 + invariant_length * sizeof(AccDataType) * 3;
 
    // outputing of mean, inv-variance
    num_bytes += haveSavedMeanInvVar ? invariant_length * sizeof(AccDataType) * 2 : 0;
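This byte count feeds the profiler's bandwidth estimate. A worked example of the accounting (the shape, the fp16 element size, the timing value, and the bytes/1e6/ms-to-GB/s conversion are all assumptions for illustration):

    #include <cstdio>
    #include <cstddef>

    int main()
    {
        // Hypothetical problem: inOutLengths = {256, 14, 14, 1024}, fp16 XDataType.
        const std::size_t total_length     = 256ull * 14 * 14 * 1024; // N*H*W*C
        const std::size_t invariant_length = 1024;                    // C
        const std::size_t sizeof_x   = 2; // sizeof(XDataType) for fp16
        const std::size_t sizeof_acc = 4; // sizeof(AccDataType) for float

        // Same accounting as the hunk above: x/dy read + dx written,
        // scale read + dscale/dbias written.
        const std::size_t num_bytes =
            total_length * sizeof_x * 3 + invariant_length * sizeof_acc * 3;

        const double avg_time_ms = 1.0; // a measured kernel time would go here
        std::printf("%.2f GB/s\n", num_bytes / 1.0e6 / avg_time_ms); // MB/ms == GB/s
        return 0;
    }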
@@ -341,14 +343,16 @@ bool bnorm_bwd_nhwc_test(bool do_verification,
    if(do_verification)
    {
        using ReferenceBatchNormBwdInstance =
-           ck::tensor_operation::host::ReferenceBatchNormBwd_Input_N_H_W_C_Output_C<InOutDataType,
-                                                                                    InOutDataType,
-                                                                                    InOutDataType,
-                                                                                    AccDataType,
-                                                                                    AccDataType,
-                                                                                    AccDataType,
-                                                                                    AccDataType,
-                                                                                    PassThroughOp>;
+           ck::tensor_operation::host::ReferenceBatchNormBwd<XDataType,
+                                                             AccDataType,
+                                                             AccDataType,
+                                                             AccDataType,
+                                                             ScaleDataType, // ScaleDataType
+                                                             AccDataType,
+                                                             AccDataType,
+                                                             PassThroughOp,
+                                                             Rank,
+                                                             NumReduceDim>;
 
        auto batchNormBwd_ref = ReferenceBatchNormBwdInstance{};
 
@@ -390,8 +394,8 @@ bool bnorm_bwd_nhwc_test(bool do_verification,
    dbias_dev.FromDevice(dbias.data());
 
    // clang-format off
-   pass = pass && ck::utils::check_err(dbias.mData, dbias_ref.mData, "dBias result:", 1e-5, 1e-5);
-   pass = pass && ck::utils::check_err(dscale.mData, dscale_ref.mData, "dScale result:", 1e-5, 2e-4);
+   pass = pass && ck::utils::check_err(dbias.mData, dbias_ref.mData, "dBias result:", 2e-4, 2e-4);
+   pass = pass && ck::utils::check_err(dscale.mData, dscale_ref.mData, "dScale result:", 2e-4, 2e-4);
    pass = pass && ck::utils::check_err(dx.mData, dx_ref.mData, "dx result:");
    // clang-format on
 };
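The dScale/dBias tolerances move from 1e-5 to 2e-4 because both gradients are reductions over N*H*W elements, so floating-point rounding error grows with the reduce length and the old threshold is too strict for large shapes. check_err itself lives in CK's utility headers; the stand-in below only illustrates the rtol/atol style of comparison those thresholds parameterize (names assumed, not CK's implementation):

    #include <cmath>
    #include <cstddef>
    #include <vector>

    // Simplified rtol/atol comparison in the spirit of ck::utils::check_err
    // (illustrative stand-in, not the actual CK utility).
    bool all_close(const std::vector<float>& out, const std::vector<float>& ref,
                   double rtol, double atol)
    {
        if(out.size() != ref.size())
            return false;
        for(std::size_t i = 0; i < out.size(); ++i)
        {
            const double err = std::fabs(double(out[i]) - double(ref[i]));
            if(err > atol + rtol * std::fabs(double(ref[i])))
                return false;
        }
        return true;
    }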