mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-19 22:39:03 +00:00
Batchnorm-forward implemented using welford method to calculate variance (#403)
* Update to the batchnorm-forward API and base class * Fix leeked header including in gridwise_set_buffer_value.hpp * Add kernels and device file for batchnorm-forward welford supporting both blockwise and multi-block reduction * Update to the batchnorm-forward example to use the new batchnorm-forward device interface * Change the batchnorm-forward reference to use sequential welford method * Change to assign the workspace into four buffers in the host layer * Use GetReduceCountPerThread functor to replace the initial count for Blockwise and Multiblock welford * Tiny correction and remove un-used file under example/34_batchnorm * Renaming in the kernel arguments * Explicitly use ck::math::sqrt in batchnorm-forward kernels * Add some comments to some kernels * Tiny fix * Generalize the data types in reference_batchnorm_forward_nhwc_c * Use ck::ignore to mark un-used parameters * Move GetReduceCountPerThread functor codes from kernel to device * Remove some un-used codes in device_batchnorm_forward_impl.hpp * Tiny fix in batchnorm_forward example * Move GetReduceCountPerThread() to welford_helper.hpp * Use seperate data type for Scale and Bias * Renaming in device Op * Tiny fix in forward example * Updata to batchnorm-infer (type spliting, renaming) * Add time and bandwidth measurement to the batchnorm-forward example * Add support of elementwise operation for batchnorm forward output * Reduce object copying by passing object as reference type * Tiny change for performance * Updates for performance again * Some Renamings * Add GetActualVariance template parameter for ThreadwiseWelfordMerge * Tiny update in reference batchnorm forward nhwc/c * Move batchnorm multiblock kernel files to grid/batchnorm_multiblock sub-directory * Fuse mean and bias in the normalization calculation Co-authored-by: root <root@dc-smc-18.amd.com> Co-authored-by: rocking5566 <ChunYu.Lai@amd.com>
This commit is contained in:
@@ -10,102 +10,17 @@
|
||||
|
||||
#include "ck/utility/data_type.hpp"
|
||||
|
||||
// binary operation used to calculate invVariance from mean and meansquare
|
||||
struct InvVariance
|
||||
{
|
||||
InvVariance(double epsilon) : epsilon_(epsilon){};
|
||||
|
||||
template <typename T>
|
||||
__host__ __device__ constexpr void operator()(T& y, const T& mean, const T& meansquare) const
|
||||
{
|
||||
static_assert(std::is_same<T, float>::value || std::is_same<T, double>::value,
|
||||
"Data type is not supported by this operation!");
|
||||
|
||||
using ck::type_convert;
|
||||
using ck::math::sqrt;
|
||||
|
||||
T tmp_epsilon = type_convert<T>(epsilon_);
|
||||
|
||||
y = meansquare - mean * mean;
|
||||
y = 1.0f / sqrt(tmp_epsilon + y);
|
||||
};
|
||||
|
||||
double epsilon_;
|
||||
};
|
||||
|
||||
// (4-in, 2-out) element-wise operation used to update the moving average of mean and variance
|
||||
struct MovingAverage
|
||||
{
|
||||
MovingAverage(double factor) : factor_(factor){};
|
||||
|
||||
template <typename T>
|
||||
__host__ __device__ constexpr void operator()(T& y0,
|
||||
T& y1,
|
||||
const T& mean,
|
||||
const T& runningMean,
|
||||
const T& meansquare,
|
||||
const T& runningVariance) const
|
||||
{
|
||||
static_assert(std::is_same<T, float>::value || std::is_same<T, double>::value,
|
||||
"Data type is not supported by this operation!");
|
||||
|
||||
using ck::type_convert;
|
||||
|
||||
T tmp_factor = type_convert<T>(factor_);
|
||||
T variance = meansquare - mean * mean;
|
||||
|
||||
y0 = runningMean * (type_convert<T>(1.0f) - tmp_factor) + mean * tmp_factor;
|
||||
y1 = runningVariance * (type_convert<T>(1.0f) - tmp_factor) + variance * tmp_factor;
|
||||
};
|
||||
|
||||
double factor_;
|
||||
};
|
||||
|
||||
struct MovingAverageAndInvVariance
|
||||
{
|
||||
MovingAverageAndInvVariance(double epsilon, double factor)
|
||||
: epsilon_(epsilon), factor_(factor){};
|
||||
|
||||
template <typename T>
|
||||
__host__ __device__ constexpr void operator()(T& y0, // resultRunningMean
|
||||
T& y1, // resultRunningVariance
|
||||
T& y2, // saveInvVariance
|
||||
const T& mean,
|
||||
const T& runningMean,
|
||||
const T& meansquare,
|
||||
const T& runningVariance) const
|
||||
{
|
||||
static_assert(std::is_same<T, float>::value || std::is_same<T, double>::value,
|
||||
"Data type is not supported by this operation!");
|
||||
|
||||
using ck::type_convert;
|
||||
using ck::math::sqrt;
|
||||
|
||||
T tmp_epsilon = type_convert<T>(epsilon_);
|
||||
T tmp_factor = type_convert<T>(factor_);
|
||||
T variance = meansquare - mean * mean;
|
||||
|
||||
y0 = runningMean * (type_convert<T>(1.0f) - tmp_factor) + mean * tmp_factor;
|
||||
y1 = runningVariance * (type_convert<T>(1.0f) - tmp_factor) + variance * tmp_factor;
|
||||
|
||||
y2 = 1.0f / sqrt(tmp_epsilon + variance);
|
||||
};
|
||||
|
||||
double epsilon_;
|
||||
double factor_;
|
||||
};
|
||||
|
||||
struct NormalizeInInfer
|
||||
{
|
||||
NormalizeInInfer(double epsilon = 1e-4) : epsilon_(epsilon) {}
|
||||
|
||||
template <typename T1, typename T2>
|
||||
template <typename T1, typename T2, typename T3, typename T4>
|
||||
__host__ __device__ constexpr void operator()(T1& y,
|
||||
const T1& x,
|
||||
const T2& mean,
|
||||
const T2& variance,
|
||||
const T2& gamma,
|
||||
const T2& beta) const
|
||||
const T3& gamma,
|
||||
const T4& beta) const
|
||||
{
|
||||
static_assert(std::is_same<T2, float>::value || std::is_same<T2, double>::value,
|
||||
"Data type is not supported by this operation!");
|
||||
@@ -117,38 +32,10 @@ struct NormalizeInInfer
|
||||
|
||||
tmp_x = type_convert<T2>(x);
|
||||
|
||||
tmp_y = ((tmp_x - mean) / sqrt(variance + type_convert<T2>(epsilon_))) * gamma + beta;
|
||||
y = type_convert<T1>(tmp_y);
|
||||
};
|
||||
|
||||
double epsilon_;
|
||||
};
|
||||
|
||||
struct NormalizeInForward
|
||||
{
|
||||
NormalizeInForward(double epsilon = 1e-4) : epsilon_(epsilon) {}
|
||||
|
||||
template <typename T1, typename T2>
|
||||
__host__ __device__ constexpr void operator()(T1& y,
|
||||
const T1& x,
|
||||
const T2& mean,
|
||||
const T2& meansquare,
|
||||
const T2& gamma,
|
||||
const T2& beta) const
|
||||
{
|
||||
static_assert(std::is_same<T2, float>::value || std::is_same<T2, double>::value,
|
||||
"Data type is not supported by this operation!");
|
||||
|
||||
using ck::type_convert;
|
||||
using ck::math::sqrt;
|
||||
|
||||
T2 tmp_x, tmp_y;
|
||||
T2 variance = meansquare - mean * mean;
|
||||
|
||||
tmp_x = type_convert<T2>(x);
|
||||
|
||||
tmp_y = ((tmp_x - mean) / sqrt(variance + type_convert<T2>(epsilon_))) * gamma + beta;
|
||||
y = type_convert<T1>(tmp_y);
|
||||
tmp_y = ((tmp_x - mean) / sqrt(variance + type_convert<T2>(epsilon_))) *
|
||||
type_convert<T2>(gamma) +
|
||||
type_convert<T2>(beta);
|
||||
y = type_convert<T1>(tmp_y);
|
||||
};
|
||||
|
||||
double epsilon_;
|
||||
|
||||
@@ -1,295 +0,0 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <cassert>
|
||||
#include <vector>
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/utility/reduction_operator.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_multiple_reduce_multiblock.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_elementwise.hpp"
|
||||
|
||||
#include "batchnorm_common.hpp"
|
||||
|
||||
template <typename InOutDataType,
|
||||
typename AccDataType,
|
||||
ck::index_t Rank,
|
||||
ck::index_t NumBatchNormReduceDim,
|
||||
bool fastest_dim_is_reduced = false>
|
||||
int bnorm_fwd(bool time_kernel,
|
||||
bool updateMovingAverage,
|
||||
bool saveMeanAndInvVariance,
|
||||
const std::array<int, NumBatchNormReduceDim> reduceDims,
|
||||
const std::array<ck::index_t, Rank> xyLengths,
|
||||
const std::array<ck::index_t, Rank> xStrides,
|
||||
const std::array<ck::index_t, Rank> yStrides,
|
||||
const std::array<ck::index_t, Rank - NumBatchNormReduceDim> bnScaleBiasMeanVarLengths,
|
||||
const std::array<ck::index_t, Rank - NumBatchNormReduceDim> bnScaleBiasMeanVarStrides,
|
||||
const void* p_x,
|
||||
const void* p_scale,
|
||||
const void* p_bias,
|
||||
void* p_y,
|
||||
double exponentialAverageFactor,
|
||||
void* p_runningMean,
|
||||
void* p_runningVariance,
|
||||
double epsilon,
|
||||
void* p_saveMean,
|
||||
void* p_saveInvVariance,
|
||||
void* p_tmp_mean,
|
||||
void* p_tmp_meansquare)
|
||||
{
|
||||
static_assert(NumBatchNormReduceDim < Rank,
|
||||
"Invalid number of reduced dimensions for batchnorm!");
|
||||
|
||||
constexpr ck::index_t NumScaleBiasMeanVarDim = Rank - NumBatchNormReduceDim;
|
||||
|
||||
using InElementwiseOperation_Mean = ck::tensor_operation::element_wise::PassThrough;
|
||||
using AccElementwiseOperation_Mean = ck::tensor_operation::element_wise::UnaryDivide;
|
||||
|
||||
using InElementwiseOperation_Meansquare = ck::tensor_operation::element_wise::UnarySquare;
|
||||
using AccElementwiseOperation_Meansquare = ck::tensor_operation::element_wise::UnaryDivide;
|
||||
|
||||
using DeviceMeanAndMeansquareInstance =
|
||||
ck::tensor_operation::device::DeviceMultipleReduceMultiBlock<
|
||||
2,
|
||||
InOutDataType,
|
||||
AccDataType,
|
||||
ck::Tuple<AccDataType, AccDataType>,
|
||||
Rank,
|
||||
NumBatchNormReduceDim,
|
||||
ck::reduce::Add,
|
||||
ck::Tuple<InElementwiseOperation_Mean, InElementwiseOperation_Meansquare>,
|
||||
ck::Tuple<AccElementwiseOperation_Mean, AccElementwiseOperation_Meansquare>,
|
||||
ck::InMemoryDataOperationEnum::Set,
|
||||
false, // PropagateNan
|
||||
256,
|
||||
16,
|
||||
16,
|
||||
1,
|
||||
1,
|
||||
fastest_dim_is_reduced ? 1 : 0,
|
||||
1,
|
||||
ck::Sequence<1, 1>>;
|
||||
|
||||
using DeviceNormalizeInstance = ck::tensor_operation::device::DeviceElementwise<
|
||||
ck::Tuple<InOutDataType, AccDataType, AccDataType, AccDataType, AccDataType>, // x, mean,
|
||||
// meansquare,
|
||||
// scale, bias
|
||||
ck::Tuple<InOutDataType>, // y
|
||||
NormalizeInForward,
|
||||
Rank,
|
||||
2, // MPerthread
|
||||
ck::Sequence<1, 1, 1, 1, 1>, // scalarPerVector: x, mean, meansquare, scale, bias
|
||||
ck::Sequence<1>>; // scalarPerVector: y
|
||||
|
||||
using DeviceInvVarianceInstance = ck::tensor_operation::device::DeviceElementwise<
|
||||
ck::Tuple<AccDataType, AccDataType>, // mean, meansquare
|
||||
ck::Tuple<AccDataType>, // invVariance
|
||||
InvVariance,
|
||||
NumScaleBiasMeanVarDim,
|
||||
2, // MPerthread
|
||||
ck::Sequence<1, 1>, // scalarPerVector: mean, meansquare
|
||||
ck::Sequence<1>>; // scalarPerVector: invVariance
|
||||
|
||||
using DeviceMovingAverageInstance = ck::tensor_operation::device::DeviceElementwise<
|
||||
ck::Tuple<AccDataType, AccDataType, AccDataType, AccDataType>, // old moving mean, new mean,
|
||||
// old moving variance, new
|
||||
// meansquare
|
||||
ck::Tuple<AccDataType, AccDataType>, // updated moving mean, updated moving variance
|
||||
MovingAverage,
|
||||
NumScaleBiasMeanVarDim,
|
||||
4, // MPerthread
|
||||
ck::Sequence<1, 1, 1, 1>, // scalarPerVector: old moving mean, new mean, old moving
|
||||
// variance, new meansquare
|
||||
ck::Sequence<1, 1>>; // scalarPerVector: updated moving mean, updated moving variance
|
||||
|
||||
using DeviceMovingAverageAndInvVarianceInstance =
|
||||
ck::tensor_operation::device::DeviceElementwise<
|
||||
ck::Tuple<AccDataType, AccDataType, AccDataType, AccDataType>, // old moving mean, new
|
||||
// mean, old moving
|
||||
// variance, new
|
||||
// meansquare
|
||||
ck::Tuple<AccDataType, AccDataType, AccDataType>, // updated moving mean, updated moving
|
||||
// variancem, invVariance
|
||||
MovingAverageAndInvVariance,
|
||||
NumScaleBiasMeanVarDim,
|
||||
4, // MPerthread
|
||||
ck::Sequence<1, 1, 1, 1>, // scalarPerVector: old moving mean, new mean, old moving
|
||||
// variance, new meansquare
|
||||
ck::Sequence<1, 1, 1>>; // scalarPerVector: updated moving mean, updated moving variance
|
||||
|
||||
auto invariantDims = get_invariant_dims<Rank, NumBatchNormReduceDim>(reduceDims);
|
||||
std::array<ck::index_t, Rank> aligned_scaleBiasMeanVarStrides{0};
|
||||
|
||||
int i = 0;
|
||||
for(auto dim : invariantDims)
|
||||
{
|
||||
assert(xyLengths[dim] == bnScaleBiasMeanVarLengths[i]);
|
||||
|
||||
aligned_scaleBiasMeanVarStrides[dim] = bnScaleBiasMeanVarStrides[i];
|
||||
i++;
|
||||
};
|
||||
|
||||
int32_t reduceLength = 1;
|
||||
|
||||
for(auto dim : reduceDims)
|
||||
reduceLength *= xyLengths[dim];
|
||||
|
||||
int32_t invariantLength = 1;
|
||||
|
||||
for(auto dim : invariantDims)
|
||||
invariantLength *= xyLengths[dim];
|
||||
|
||||
size_t total_length = static_cast<size_t>(invariantLength) * reduceLength;
|
||||
|
||||
float avg_time = 0.0f;
|
||||
std::size_t num_bytes = 0;
|
||||
|
||||
auto dev_mean_and_meansquare = DeviceMeanAndMeansquareInstance{};
|
||||
|
||||
void* p_mean = saveMeanAndInvVariance ? p_saveMean : p_tmp_mean;
|
||||
|
||||
const AccDataType alpha = ck::type_convert<AccDataType>(1.0f);
|
||||
const AccDataType beta = ck::type_convert<AccDataType>(0.0f);
|
||||
|
||||
auto argument_ptr1 = dev_mean_and_meansquare.MakeArgumentPointer(
|
||||
xyLengths,
|
||||
xStrides,
|
||||
bnScaleBiasMeanVarLengths,
|
||||
{bnScaleBiasMeanVarStrides, bnScaleBiasMeanVarStrides},
|
||||
reduceDims,
|
||||
{&alpha, &alpha},
|
||||
{&beta, &beta},
|
||||
p_x,
|
||||
{p_mean, p_tmp_meansquare},
|
||||
ck::make_tuple(InElementwiseOperation_Mean{}, InElementwiseOperation_Meansquare{}),
|
||||
ck::make_tuple(AccElementwiseOperation_Mean{reduceLength},
|
||||
AccElementwiseOperation_Meansquare{reduceLength}));
|
||||
|
||||
auto dev_normalize = DeviceNormalizeInstance{};
|
||||
|
||||
auto argument_ptr2 =
|
||||
dev_normalize.MakeArgumentPointer(xyLengths,
|
||||
{xStrides,
|
||||
aligned_scaleBiasMeanVarStrides,
|
||||
aligned_scaleBiasMeanVarStrides,
|
||||
aligned_scaleBiasMeanVarStrides,
|
||||
aligned_scaleBiasMeanVarStrides},
|
||||
{yStrides},
|
||||
{p_x, p_mean, p_tmp_meansquare, p_scale, p_bias},
|
||||
{p_y},
|
||||
NormalizeInForward{epsilon});
|
||||
|
||||
if(!dev_mean_and_meansquare.IsSupportedArgument(argument_ptr1.get()) ||
|
||||
!dev_normalize.IsSupportedArgument(argument_ptr2.get()))
|
||||
{
|
||||
std::cout << "The runtime parameters seems not supported by the Devic, exiting!"
|
||||
<< std::endl;
|
||||
|
||||
return (-1);
|
||||
};
|
||||
|
||||
auto invoker_ptr1 = dev_mean_and_meansquare.MakeInvokerPointer();
|
||||
auto invoker_ptr2 = dev_normalize.MakeInvokerPointer();
|
||||
|
||||
avg_time += invoker_ptr1->Run(argument_ptr1.get(), StreamConfig{nullptr, time_kernel});
|
||||
avg_time += invoker_ptr2->Run(argument_ptr2.get(), StreamConfig{nullptr, time_kernel});
|
||||
|
||||
num_bytes +=
|
||||
(total_length * sizeof(InOutDataType) + invariantLength * 2 * sizeof(AccDataType)) + // No.1
|
||||
(total_length * (1 * sizeof(InOutDataType) + 4 * sizeof(AccDataType)) +
|
||||
total_length * sizeof(InOutDataType)); // No.2
|
||||
|
||||
if(saveMeanAndInvVariance && updateMovingAverage)
|
||||
{
|
||||
auto dev_moving_average_inv_variance = DeviceMovingAverageAndInvVarianceInstance{};
|
||||
|
||||
auto argument_ptr3 = dev_moving_average_inv_variance.MakeArgumentPointer(
|
||||
bnScaleBiasMeanVarLengths,
|
||||
{bnScaleBiasMeanVarStrides,
|
||||
bnScaleBiasMeanVarStrides,
|
||||
bnScaleBiasMeanVarStrides,
|
||||
bnScaleBiasMeanVarStrides},
|
||||
{bnScaleBiasMeanVarStrides, bnScaleBiasMeanVarStrides, bnScaleBiasMeanVarStrides},
|
||||
{p_mean, p_runningMean, p_tmp_meansquare, p_runningVariance},
|
||||
{p_runningMean, p_runningVariance, p_saveInvVariance},
|
||||
MovingAverageAndInvVariance{epsilon, exponentialAverageFactor});
|
||||
|
||||
if(!dev_moving_average_inv_variance.IsSupportedArgument(argument_ptr3.get()))
|
||||
{
|
||||
std::cout << "Runtime parameters not supported by the Device, exiting!" << std::endl;
|
||||
|
||||
return (-1);
|
||||
};
|
||||
|
||||
auto invoker_ptr3 = dev_moving_average_inv_variance.MakeInvokerPointer();
|
||||
|
||||
avg_time += invoker_ptr3->Run(argument_ptr3.get(), StreamConfig{nullptr, time_kernel});
|
||||
|
||||
num_bytes += invariantLength * (4 + 3) * sizeof(AccDataType) * 2; // No.5
|
||||
}
|
||||
else if(saveMeanAndInvVariance)
|
||||
{
|
||||
auto dev_inv_variance = DeviceInvVarianceInstance{};
|
||||
auto argument_ptr3 = dev_inv_variance.MakeArgumentPointer(
|
||||
bnScaleBiasMeanVarLengths,
|
||||
{bnScaleBiasMeanVarStrides, bnScaleBiasMeanVarStrides},
|
||||
{bnScaleBiasMeanVarStrides},
|
||||
{p_mean, p_tmp_meansquare},
|
||||
{p_saveInvVariance},
|
||||
InvVariance{epsilon});
|
||||
|
||||
if(!dev_inv_variance.IsSupportedArgument(argument_ptr3.get()))
|
||||
{
|
||||
std::cout << "Runtime parameters not supported by the Device, exiting!" << std::endl;
|
||||
|
||||
return (-1);
|
||||
};
|
||||
|
||||
auto invoker_ptr3 = dev_inv_variance.MakeInvokerPointer();
|
||||
|
||||
avg_time += invoker_ptr3->Run(argument_ptr3.get(), StreamConfig{nullptr, time_kernel});
|
||||
|
||||
num_bytes += invariantLength * (2 + 1) * sizeof(AccDataType);
|
||||
}
|
||||
else if(updateMovingAverage)
|
||||
{
|
||||
auto dev_moving_average = DeviceMovingAverageInstance{};
|
||||
|
||||
auto argument_ptr3 = dev_moving_average.MakeArgumentPointer(
|
||||
bnScaleBiasMeanVarLengths,
|
||||
{bnScaleBiasMeanVarStrides,
|
||||
bnScaleBiasMeanVarStrides,
|
||||
bnScaleBiasMeanVarStrides,
|
||||
bnScaleBiasMeanVarStrides},
|
||||
{bnScaleBiasMeanVarStrides, bnScaleBiasMeanVarStrides},
|
||||
{p_mean, p_runningMean, p_tmp_meansquare, p_runningVariance},
|
||||
{p_runningMean, p_runningVariance},
|
||||
MovingAverage{exponentialAverageFactor});
|
||||
|
||||
if(!dev_moving_average.IsSupportedArgument(argument_ptr3.get()))
|
||||
{
|
||||
std::cout << "Runtime parameters not supported by the Device, exiting!" << std::endl;
|
||||
|
||||
return (-1);
|
||||
};
|
||||
|
||||
auto invoker_ptr3 = dev_moving_average.MakeInvokerPointer();
|
||||
|
||||
avg_time += invoker_ptr3->Run(argument_ptr3.get(), StreamConfig{nullptr, time_kernel});
|
||||
|
||||
num_bytes += invariantLength * (4 + 2) * sizeof(AccDataType) * 2; // No.5
|
||||
};
|
||||
|
||||
if(time_kernel)
|
||||
{
|
||||
float gb_per_sec = num_bytes / 1.E6 / avg_time;
|
||||
|
||||
std::cout << "Perf: " << avg_time << " ms, " << gb_per_sec << " GB/s" << std::endl;
|
||||
};
|
||||
|
||||
return (0);
|
||||
};
|
||||
@@ -15,13 +15,9 @@
|
||||
#include "ck/library/utility/host_tensor_generator.hpp"
|
||||
#include "ck/library/utility/host_common_util.hpp"
|
||||
#include "ck/library/reference_tensor_operation/cpu/reference_batchnorm_forward_nhwc_c.hpp"
|
||||
|
||||
#include "batchnorm_forward_impl.hpp"
|
||||
|
||||
template <typename InOutDataType, typename AccDataType>
|
||||
using ReferenceBatchNormFwdInstance =
|
||||
ck::tensor_operation::host::ReferenceBatchNormFwd_Input_N_H_W_C_Output_C<InOutDataType,
|
||||
AccDataType>;
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_batchnorm_forward_impl.hpp"
|
||||
#include "ck/library/utility/host_common_util.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
|
||||
|
||||
static struct option long_options[] = {{"inOutLengths", required_argument, nullptr, 'D'},
|
||||
{"verify", required_argument, nullptr, 'v'},
|
||||
@@ -41,9 +37,10 @@ class BatchNormFwdArg
|
||||
bool updateMovingAverage;
|
||||
bool saveMeanAndInvVariance;
|
||||
|
||||
int data_type = 0;
|
||||
int init_method = 2;
|
||||
bool time_kernel = false;
|
||||
int data_type = 0;
|
||||
int init_method = 2;
|
||||
bool time_kernel = false;
|
||||
bool use_multiblock_welford = false;
|
||||
|
||||
public:
|
||||
void show_usage(const char* cmd)
|
||||
@@ -68,6 +65,7 @@ class BatchNormFwdArg
|
||||
"value, 3=decimal value)"
|
||||
<< std::endl;
|
||||
std::cout << "Arg5: time kernel (0=no, 1=yes)" << std::endl;
|
||||
std::cout << "Arg6: use multi-block welford (0=n0, 1=yes)" << std::endl;
|
||||
};
|
||||
|
||||
int processArgs(int argc, char* argv[])
|
||||
@@ -110,14 +108,15 @@ class BatchNormFwdArg
|
||||
};
|
||||
};
|
||||
|
||||
if(optind + 5 > argc)
|
||||
if(optind + 6 > argc)
|
||||
throw std::runtime_error("Invalid cmd-line arguments, more argumetns are needed!");
|
||||
|
||||
data_type = std::atoi(argv[optind++]);
|
||||
updateMovingAverage = std::atoi(argv[optind++]);
|
||||
saveMeanAndInvVariance = std::atoi(argv[optind++]);
|
||||
init_method = std::atoi(argv[optind++]);
|
||||
time_kernel = static_cast<bool>(std::atoi(argv[optind]));
|
||||
time_kernel = static_cast<bool>(std::atoi(argv[optind++]));
|
||||
use_multiblock_welford = static_cast<bool>(std::atoi(argv[optind]));
|
||||
|
||||
if(data_type != 0 && data_type != 1 && data_type != 3 && data_type != 5 && data_type != 6)
|
||||
return (-1);
|
||||
@@ -128,7 +127,7 @@ class BatchNormFwdArg
|
||||
|
||||
using namespace ck;
|
||||
|
||||
template <typename InOutDataType, typename AccDataType>
|
||||
template <typename InOutDataType, typename AccDataType, bool UseMultiblockInK>
|
||||
bool bnorm_fwd_nhwc_test(bool do_verification,
|
||||
int init_method,
|
||||
bool time_kernel,
|
||||
@@ -273,73 +272,140 @@ bool bnorm_fwd_nhwc_test(bool do_verification,
|
||||
scaleBiasMeanVarStrides.end(),
|
||||
i_scaleBiasMeanVarStrides.begin());
|
||||
|
||||
int result = 0;
|
||||
using PassThroughOp = ck::tensor_operation::element_wise::PassThrough;
|
||||
|
||||
// used for saving meansquare
|
||||
DeviceMem workspace(sizeof(AccDataType) * 2 * resultSaveMean_ref.mDesc.GetElementSpaceSize() +
|
||||
128);
|
||||
using DeviceBatchNormFwdInstance =
|
||||
ck::tensor_operation::device::DeviceBatchNormFwdImpl<InOutDataType,
|
||||
InOutDataType,
|
||||
AccDataType,
|
||||
AccDataType, // ScaleDataType
|
||||
AccDataType, // BiasDataType
|
||||
AccDataType, // MeanVarDataType
|
||||
PassThroughOp, // YElementwiseOp
|
||||
Rank,
|
||||
NumReduceDim,
|
||||
UseMultiblockInK,
|
||||
256,
|
||||
16,
|
||||
16,
|
||||
1,
|
||||
2,
|
||||
0,
|
||||
1,
|
||||
1,
|
||||
1,
|
||||
1,
|
||||
1>;
|
||||
|
||||
void* p_tmp_mean = workspace.GetDeviceBuffer();
|
||||
void* p_tmp_meansquare =
|
||||
static_cast<char*>(p_tmp_mean) +
|
||||
(sizeof(AccDataType) * resultSaveMean_ref.mDesc.GetElementSpaceSize() + 63) / 64 * 64;
|
||||
auto batchnorm_fwd = DeviceBatchNormFwdInstance{};
|
||||
|
||||
result = bnorm_fwd<InOutDataType, AccDataType, Rank, NumReduceDim, false>(
|
||||
time_kernel,
|
||||
updateMovingAverage,
|
||||
saveMeanAndInvVariance,
|
||||
{0, 1, 2},
|
||||
auto argument_ptr = batchnorm_fwd.MakeArgumentPointer(
|
||||
i_inOutLengths,
|
||||
i_inOutStrides,
|
||||
i_inOutStrides,
|
||||
{0, 1, 2},
|
||||
i_scaleBiasMeanVarLengths,
|
||||
i_scaleBiasMeanVarStrides,
|
||||
i_scaleBiasMeanVarStrides,
|
||||
i_scaleBiasMeanVarStrides,
|
||||
x_dev.GetDeviceBuffer(),
|
||||
bnScale_dev.GetDeviceBuffer(),
|
||||
bnBias_dev.GetDeviceBuffer(),
|
||||
y_dev.GetDeviceBuffer(),
|
||||
averageFactor,
|
||||
updateMovingAverage ? resultRunningMean_dev.GetDeviceBuffer() : nullptr,
|
||||
updateMovingAverage ? resultRunningVariance_dev.GetDeviceBuffer() : nullptr,
|
||||
epsilon,
|
||||
PassThroughOp{},
|
||||
y_dev.GetDeviceBuffer(),
|
||||
saveMeanAndInvVariance ? resultSaveMean_dev.GetDeviceBuffer() : nullptr,
|
||||
saveMeanAndInvVariance ? resultSaveInvVariance_dev.GetDeviceBuffer() : nullptr,
|
||||
p_tmp_mean,
|
||||
p_tmp_meansquare);
|
||||
averageFactor,
|
||||
updateMovingAverage ? resultRunningMean_dev.GetDeviceBuffer() : nullptr,
|
||||
updateMovingAverage ? resultRunningVariance_dev.GetDeviceBuffer() : nullptr);
|
||||
|
||||
if(result < 0)
|
||||
if(!batchnorm_fwd.IsSupportedArgument(argument_ptr.get()))
|
||||
{
|
||||
std::cout << "The runtime parameters seems not supported by the BatchNorm device instance, "
|
||||
"exiting!"
|
||||
<< std::endl;
|
||||
return (false);
|
||||
};
|
||||
|
||||
size_t workspace_sz = batchnorm_fwd.GetWorkSpaceSize(argument_ptr.get());
|
||||
|
||||
DeviceMem workspace_dev(workspace_sz);
|
||||
|
||||
batchnorm_fwd.SetWorkSpacePointer(argument_ptr.get(), workspace_dev.GetDeviceBuffer());
|
||||
|
||||
auto invoker_ptr = batchnorm_fwd.MakeInvokerPointer();
|
||||
|
||||
if(time_kernel)
|
||||
{
|
||||
float avg_time = 0.0f;
|
||||
size_t num_bytes = 0;
|
||||
|
||||
size_t total_length = inOutLengths[0] * inOutLengths[1] * inOutLengths[2] * inOutLengths[3];
|
||||
size_t invariant_length = inOutLengths[3];
|
||||
|
||||
avg_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel});
|
||||
|
||||
// inputing of x, scale, bias, outputing of y
|
||||
num_bytes +=
|
||||
total_length * sizeof(InOutDataType) * 2 + invariant_length * sizeof(AccDataType) * 2;
|
||||
|
||||
// outputing of mean, inv-variance
|
||||
num_bytes += saveMeanAndInvVariance ? invariant_length * sizeof(AccDataType) * 2 : 0;
|
||||
|
||||
// updating of moving mean, variance
|
||||
num_bytes += updateMovingAverage ? invariant_length * sizeof(AccDataType) * 4 : 0;
|
||||
|
||||
float gb_per_sec = num_bytes / 1.E6 / avg_time;
|
||||
|
||||
std::cout << "Perf: " << avg_time << " ms, " << gb_per_sec << " GB/s" << std::endl;
|
||||
}
|
||||
else
|
||||
(void)invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel});
|
||||
|
||||
bool pass = true;
|
||||
|
||||
if(do_verification)
|
||||
{
|
||||
auto batchNormFwd_ref = ReferenceBatchNormFwdInstance<InOutDataType, AccDataType>{};
|
||||
|
||||
using ReferenceBatchNormFwdInstance =
|
||||
ck::tensor_operation::host::ReferenceBatchNormFwd_Input_N_H_W_C_Output_C<InOutDataType,
|
||||
InOutDataType,
|
||||
AccDataType,
|
||||
AccDataType,
|
||||
AccDataType,
|
||||
AccDataType,
|
||||
PassThroughOp>;
|
||||
|
||||
auto batchNormFwd_ref = ReferenceBatchNormFwdInstance{};
|
||||
|
||||
auto argument_ptr_ref = batchNormFwd_ref.MakeArgumentPointer(
|
||||
i_inOutLengths,
|
||||
i_inOutStrides,
|
||||
i_inOutStrides,
|
||||
{0, 1, 2},
|
||||
i_scaleBiasMeanVarLengths,
|
||||
i_scaleBiasMeanVarStrides,
|
||||
i_scaleBiasMeanVarStrides,
|
||||
i_scaleBiasMeanVarStrides,
|
||||
x.mData.data(),
|
||||
bnScale.mData.data(),
|
||||
bnBias.mData.data(),
|
||||
y_ref.mData.data(),
|
||||
0.1, // exponentialAverageFactor
|
||||
updateMovingAverage ? resultRunningMean_ref.mData.data() : nullptr, // resultRunningMean
|
||||
updateMovingAverage ? resultRunningVariance_ref.mData.data()
|
||||
: nullptr, // resultRunningVariance
|
||||
epsilon,
|
||||
PassThroughOp{},
|
||||
y_ref.mData.data(),
|
||||
saveMeanAndInvVariance ? resultSaveMean_ref.mData.data() : nullptr,
|
||||
saveMeanAndInvVariance ? resultSaveInvVariance_ref.mData.data() : nullptr);
|
||||
saveMeanAndInvVariance ? resultSaveInvVariance_ref.mData.data() : nullptr,
|
||||
averageFactor,
|
||||
updateMovingAverage ? resultRunningMean_ref.mData.data() : nullptr,
|
||||
updateMovingAverage ? resultRunningVariance_ref.mData.data() : nullptr);
|
||||
|
||||
if(!batchNormFwd_ref.IsSupportedArgument(argument_ptr_ref.get()))
|
||||
{
|
||||
std::cout
|
||||
<< "The runtime parameters seems not supported by the BatchNorm instance, exiting!"
|
||||
<< std::endl;
|
||||
return (-2);
|
||||
std::cout << "The runtime parameters seems not supported by the BatchNorm reference "
|
||||
"instance, exiting!"
|
||||
<< std::endl;
|
||||
return (false);
|
||||
};
|
||||
|
||||
auto invoker_ptr_ref = batchNormFwd_ref.MakeInvokerPointer();
|
||||
@@ -365,6 +431,8 @@ bool bnorm_fwd_nhwc_test(bool do_verification,
|
||||
|
||||
if(saveMeanAndInvVariance)
|
||||
{
|
||||
using ck::host_common::dumpBufferToFile;
|
||||
|
||||
Tensor<AccDataType> resultSaveMean(scaleBiasMeanVarLengths);
|
||||
Tensor<AccDataType> resultSaveInvVariance(scaleBiasMeanVarLengths);
|
||||
|
||||
@@ -396,70 +464,129 @@ int main(int argc, char* argv[])
|
||||
|
||||
if(arg.data_type == 0)
|
||||
{
|
||||
pass = bnorm_fwd_nhwc_test<ck::half_t, float>(arg.do_verification,
|
||||
arg.init_method,
|
||||
arg.time_kernel,
|
||||
arg.inOutLengths,
|
||||
arg.updateMovingAverage,
|
||||
arg.saveMeanAndInvVariance,
|
||||
averageFactor,
|
||||
epsilon);
|
||||
if(arg.use_multiblock_welford)
|
||||
pass = bnorm_fwd_nhwc_test<ck::half_t, float, true>(arg.do_verification,
|
||||
arg.init_method,
|
||||
arg.time_kernel,
|
||||
arg.inOutLengths,
|
||||
arg.updateMovingAverage,
|
||||
arg.saveMeanAndInvVariance,
|
||||
averageFactor,
|
||||
epsilon);
|
||||
else
|
||||
pass = bnorm_fwd_nhwc_test<ck::half_t, float, false>(arg.do_verification,
|
||||
arg.init_method,
|
||||
arg.time_kernel,
|
||||
arg.inOutLengths,
|
||||
arg.updateMovingAverage,
|
||||
arg.saveMeanAndInvVariance,
|
||||
averageFactor,
|
||||
epsilon);
|
||||
}
|
||||
else if(arg.data_type == 1)
|
||||
{
|
||||
pass = bnorm_fwd_nhwc_test<float, float>(arg.do_verification,
|
||||
arg.init_method,
|
||||
arg.time_kernel,
|
||||
arg.inOutLengths,
|
||||
arg.updateMovingAverage,
|
||||
arg.saveMeanAndInvVariance,
|
||||
averageFactor,
|
||||
epsilon);
|
||||
if(arg.use_multiblock_welford)
|
||||
pass = bnorm_fwd_nhwc_test<float, float, true>(arg.do_verification,
|
||||
arg.init_method,
|
||||
arg.time_kernel,
|
||||
arg.inOutLengths,
|
||||
arg.updateMovingAverage,
|
||||
arg.saveMeanAndInvVariance,
|
||||
averageFactor,
|
||||
epsilon);
|
||||
else
|
||||
pass = bnorm_fwd_nhwc_test<float, float, false>(arg.do_verification,
|
||||
arg.init_method,
|
||||
arg.time_kernel,
|
||||
arg.inOutLengths,
|
||||
arg.updateMovingAverage,
|
||||
arg.saveMeanAndInvVariance,
|
||||
averageFactor,
|
||||
epsilon);
|
||||
}
|
||||
else if(arg.data_type == 3)
|
||||
{
|
||||
pass = bnorm_fwd_nhwc_test<int8_t, float>(arg.do_verification,
|
||||
arg.init_method,
|
||||
arg.time_kernel,
|
||||
arg.inOutLengths,
|
||||
arg.updateMovingAverage,
|
||||
arg.saveMeanAndInvVariance,
|
||||
averageFactor,
|
||||
epsilon);
|
||||
if(arg.use_multiblock_welford)
|
||||
pass = bnorm_fwd_nhwc_test<int8_t, float, true>(arg.do_verification,
|
||||
arg.init_method,
|
||||
arg.time_kernel,
|
||||
arg.inOutLengths,
|
||||
arg.updateMovingAverage,
|
||||
arg.saveMeanAndInvVariance,
|
||||
averageFactor,
|
||||
epsilon);
|
||||
else
|
||||
pass = bnorm_fwd_nhwc_test<int8_t, float, false>(arg.do_verification,
|
||||
arg.init_method,
|
||||
arg.time_kernel,
|
||||
arg.inOutLengths,
|
||||
arg.updateMovingAverage,
|
||||
arg.saveMeanAndInvVariance,
|
||||
averageFactor,
|
||||
epsilon);
|
||||
}
|
||||
else if(arg.data_type == 5)
|
||||
{
|
||||
pass = bnorm_fwd_nhwc_test<ck::bhalf_t, float>(arg.do_verification,
|
||||
arg.init_method,
|
||||
arg.time_kernel,
|
||||
arg.inOutLengths,
|
||||
arg.updateMovingAverage,
|
||||
arg.saveMeanAndInvVariance,
|
||||
averageFactor,
|
||||
epsilon);
|
||||
if(arg.use_multiblock_welford)
|
||||
pass = bnorm_fwd_nhwc_test<ck::bhalf_t, float, true>(arg.do_verification,
|
||||
arg.init_method,
|
||||
arg.time_kernel,
|
||||
arg.inOutLengths,
|
||||
arg.updateMovingAverage,
|
||||
arg.saveMeanAndInvVariance,
|
||||
averageFactor,
|
||||
epsilon);
|
||||
else
|
||||
pass = bnorm_fwd_nhwc_test<ck::bhalf_t, float, false>(arg.do_verification,
|
||||
arg.init_method,
|
||||
arg.time_kernel,
|
||||
arg.inOutLengths,
|
||||
arg.updateMovingAverage,
|
||||
arg.saveMeanAndInvVariance,
|
||||
averageFactor,
|
||||
epsilon);
|
||||
}
|
||||
else if(arg.data_type == 6)
|
||||
{
|
||||
pass = bnorm_fwd_nhwc_test<double, double>(arg.do_verification,
|
||||
arg.init_method,
|
||||
arg.time_kernel,
|
||||
arg.inOutLengths,
|
||||
arg.updateMovingAverage,
|
||||
arg.saveMeanAndInvVariance,
|
||||
averageFactor,
|
||||
epsilon);
|
||||
if(arg.use_multiblock_welford)
|
||||
pass = bnorm_fwd_nhwc_test<double, double, true>(arg.do_verification,
|
||||
arg.init_method,
|
||||
arg.time_kernel,
|
||||
arg.inOutLengths,
|
||||
arg.updateMovingAverage,
|
||||
arg.saveMeanAndInvVariance,
|
||||
averageFactor,
|
||||
epsilon);
|
||||
else
|
||||
pass = bnorm_fwd_nhwc_test<double, double, false>(arg.do_verification,
|
||||
arg.init_method,
|
||||
arg.time_kernel,
|
||||
arg.inOutLengths,
|
||||
arg.updateMovingAverage,
|
||||
arg.saveMeanAndInvVariance,
|
||||
averageFactor,
|
||||
epsilon);
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
pass = bnorm_fwd_nhwc_test<ck::half_t, float>(true,
|
||||
2,
|
||||
false, // don't time kernel
|
||||
{128, 16, 16, 1024},
|
||||
true,
|
||||
false,
|
||||
averageFactor,
|
||||
epsilon);
|
||||
pass = bnorm_fwd_nhwc_test<ck::half_t, float, true>(true,
|
||||
2,
|
||||
false, // don't time kernel
|
||||
{128, 16, 6, 512},
|
||||
true,
|
||||
true,
|
||||
averageFactor,
|
||||
epsilon);
|
||||
|
||||
pass = pass && bnorm_fwd_nhwc_test<ck::half_t, float, false>(true,
|
||||
2,
|
||||
false, // don't time kernel
|
||||
{128, 16, 3, 1024},
|
||||
true,
|
||||
true,
|
||||
averageFactor,
|
||||
epsilon);
|
||||
};
|
||||
|
||||
return (pass ? 0 : 1);
|
||||
|
||||
@@ -14,8 +14,12 @@
|
||||
|
||||
#include "batchnorm_common.hpp"
|
||||
|
||||
template <typename InOutDataType,
|
||||
template <typename XDataType,
|
||||
typename YDataType,
|
||||
typename AccDataType,
|
||||
typename ScaleDataType,
|
||||
typename BiasDataType,
|
||||
typename MeanVarDataType,
|
||||
ck::index_t Rank,
|
||||
ck::index_t NumBatchNormReduceDim,
|
||||
bool fastest_dim_is_reduced = false>
|
||||
@@ -26,7 +30,9 @@ int bnorm_infer(
|
||||
const std::array<ck::index_t, Rank> xStrides,
|
||||
const std::array<ck::index_t, Rank> yStrides,
|
||||
const std::array<ck::index_t, Rank - NumBatchNormReduceDim> bnScaleBiasMeanVarLengths,
|
||||
const std::array<ck::index_t, Rank - NumBatchNormReduceDim> bnScaleBiasMeanVarStrides,
|
||||
const std::array<ck::index_t, Rank - NumBatchNormReduceDim> bnScaleStrides,
|
||||
const std::array<ck::index_t, Rank - NumBatchNormReduceDim> bnBiasStrides,
|
||||
const std::array<ck::index_t, Rank - NumBatchNormReduceDim> bnMeanVarStrides,
|
||||
const void* p_x,
|
||||
const void* p_scale,
|
||||
const void* p_bias,
|
||||
@@ -41,11 +47,11 @@ int bnorm_infer(
|
||||
"Invalid number of reduced dimensions for batchnorm!");
|
||||
|
||||
using DeviceNormalizeInstance = ck::tensor_operation::device::DeviceElementwise<
|
||||
ck::Tuple<InOutDataType, AccDataType, AccDataType, AccDataType, AccDataType>, // x, mean,
|
||||
// variance,
|
||||
// scale,
|
||||
// bias,
|
||||
ck::Tuple<InOutDataType>, // y
|
||||
ck::Tuple<XDataType, AccDataType, AccDataType, AccDataType, AccDataType>, // x, mean,
|
||||
// variance,
|
||||
// scale,
|
||||
// bias,
|
||||
ck::Tuple<YDataType>, // y
|
||||
NormalizeInInfer,
|
||||
Rank,
|
||||
2, // MPerthread
|
||||
@@ -53,14 +59,18 @@ int bnorm_infer(
|
||||
ck::Sequence<1>>; // scalarPerVector: y
|
||||
|
||||
auto invariantDims = get_invariant_dims<Rank, NumBatchNormReduceDim>(reduceDims);
|
||||
std::array<ck::index_t, Rank> aligned_scaleBiasMeanVarStrides{0};
|
||||
std::array<ck::index_t, Rank> aligned_bnScaleStrides{0};
|
||||
std::array<ck::index_t, Rank> aligned_bnBiasStrides{0};
|
||||
std::array<ck::index_t, Rank> aligned_bnMeanVarStrides{0};
|
||||
|
||||
int i = 0;
|
||||
for(auto dim : invariantDims)
|
||||
{
|
||||
assert(xyLengths[dim] == bnScaleBiasMeanVarLengths[i]);
|
||||
|
||||
aligned_scaleBiasMeanVarStrides[dim] = bnScaleBiasMeanVarStrides[i];
|
||||
aligned_bnScaleStrides[dim] = bnScaleStrides[i];
|
||||
aligned_bnBiasStrides[dim] = bnBiasStrides[i];
|
||||
aligned_bnMeanVarStrides[dim] = bnMeanVarStrides[i];
|
||||
i++;
|
||||
};
|
||||
|
||||
@@ -84,10 +94,10 @@ int bnorm_infer(
|
||||
auto argument_ptr1 = dev_normalize.MakeArgumentPointer(
|
||||
xyLengths,
|
||||
{xStrides,
|
||||
aligned_scaleBiasMeanVarStrides,
|
||||
aligned_scaleBiasMeanVarStrides,
|
||||
aligned_scaleBiasMeanVarStrides,
|
||||
aligned_scaleBiasMeanVarStrides},
|
||||
aligned_bnMeanVarStrides,
|
||||
aligned_bnMeanVarStrides,
|
||||
aligned_bnScaleStrides,
|
||||
aligned_bnBiasStrides},
|
||||
{yStrides},
|
||||
{p_x, p_estimatedMean, p_estimatedVariance, p_scale, p_bias},
|
||||
{p_y},
|
||||
@@ -105,8 +115,10 @@ int bnorm_infer(
|
||||
|
||||
avg_time += invoker_ptr1->Run(argument_ptr1.get(), StreamConfig{nullptr, time_kernel});
|
||||
|
||||
num_bytes += (total_length * (1 * sizeof(InOutDataType) + 4 * sizeof(AccDataType)) +
|
||||
total_length * sizeof(InOutDataType));
|
||||
num_bytes += total_length * sizeof(XDataType) +
|
||||
invariantLength *
|
||||
(sizeof(ScaleDataType) + sizeof(BiasDataType) + 2 * sizeof(MeanVarDataType)) +
|
||||
total_length * sizeof(YDataType);
|
||||
|
||||
if(time_kernel)
|
||||
{
|
||||
|
||||
@@ -18,11 +18,6 @@
|
||||
|
||||
#include "batchnorm_infer_impl.hpp"
|
||||
|
||||
template <typename InOutDataType, typename AccDataType>
|
||||
using ReferenceBatchNormInferInstance =
|
||||
ck::tensor_operation::host::ReferenceBatchNormInfer_Input_N_H_W_C_Output_C<InOutDataType,
|
||||
AccDataType>;
|
||||
|
||||
static struct option long_options[] = {{"inOutLengths", required_argument, nullptr, 'D'},
|
||||
{"verify", required_argument, nullptr, 'v'},
|
||||
{"help", no_argument, nullptr, '?'},
|
||||
@@ -236,21 +231,30 @@ bool bnorm_infer_nhwc_test(bool do_verification,
|
||||
|
||||
int result = 0;
|
||||
|
||||
result = bnorm_infer<InOutDataType, AccDataType, Rank, NumReduceDim, false>(
|
||||
time_kernel,
|
||||
{0, 1, 2},
|
||||
i_inOutLengths,
|
||||
i_inOutStrides,
|
||||
i_inOutStrides,
|
||||
i_scaleBiasMeanVarLengths,
|
||||
i_scaleBiasMeanVarStrides,
|
||||
x_dev.GetDeviceBuffer(),
|
||||
bnScale_dev.GetDeviceBuffer(),
|
||||
bnBias_dev.GetDeviceBuffer(),
|
||||
epsilon,
|
||||
estimatedMean_dev.GetDeviceBuffer(),
|
||||
estimatedVariance_dev.GetDeviceBuffer(),
|
||||
y_dev.GetDeviceBuffer());
|
||||
result = bnorm_infer<InOutDataType,
|
||||
InOutDataType,
|
||||
AccDataType,
|
||||
AccDataType,
|
||||
AccDataType,
|
||||
AccDataType,
|
||||
Rank,
|
||||
NumReduceDim,
|
||||
false>(time_kernel,
|
||||
{0, 1, 2},
|
||||
i_inOutLengths,
|
||||
i_inOutStrides,
|
||||
i_inOutStrides,
|
||||
i_scaleBiasMeanVarLengths,
|
||||
i_scaleBiasMeanVarStrides,
|
||||
i_scaleBiasMeanVarStrides,
|
||||
i_scaleBiasMeanVarStrides,
|
||||
x_dev.GetDeviceBuffer(),
|
||||
bnScale_dev.GetDeviceBuffer(),
|
||||
bnBias_dev.GetDeviceBuffer(),
|
||||
epsilon,
|
||||
estimatedMean_dev.GetDeviceBuffer(),
|
||||
estimatedVariance_dev.GetDeviceBuffer(),
|
||||
y_dev.GetDeviceBuffer());
|
||||
|
||||
if(result < 0)
|
||||
return (false);
|
||||
@@ -259,7 +263,15 @@ bool bnorm_infer_nhwc_test(bool do_verification,
|
||||
|
||||
if(do_verification)
|
||||
{
|
||||
auto batchNormInfer_ref = ReferenceBatchNormInferInstance<InOutDataType, AccDataType>{};
|
||||
using ReferenceBatchNormInferInstance =
|
||||
ck::tensor_operation::host::ReferenceBatchNormInfer_Input_N_H_W_C_Output_C<
|
||||
InOutDataType,
|
||||
InOutDataType,
|
||||
AccDataType,
|
||||
AccDataType,
|
||||
AccDataType,
|
||||
AccDataType>;
|
||||
auto batchNormInfer_ref = ReferenceBatchNormInferInstance{};
|
||||
|
||||
auto argument_ptr_ref =
|
||||
batchNormInfer_ref.MakeArgumentPointer(i_inOutLengths,
|
||||
@@ -267,6 +279,8 @@ bool bnorm_infer_nhwc_test(bool do_verification,
|
||||
i_inOutStrides,
|
||||
i_scaleBiasMeanVarLengths,
|
||||
i_scaleBiasMeanVarStrides,
|
||||
i_scaleBiasMeanVarStrides,
|
||||
i_scaleBiasMeanVarStrides,
|
||||
x.mData.data(),
|
||||
bnScale.mData.data(),
|
||||
bnBias.mData.data(),
|
||||
|
||||
Reference in New Issue
Block a user