mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 06:49:15 +00:00
Batchnorm-forward implemented using welford method to calculate variance (#403)
* Update to the batchnorm-forward API and base class * Fix leeked header including in gridwise_set_buffer_value.hpp * Add kernels and device file for batchnorm-forward welford supporting both blockwise and multi-block reduction * Update to the batchnorm-forward example to use the new batchnorm-forward device interface * Change the batchnorm-forward reference to use sequential welford method * Change to assign the workspace into four buffers in the host layer * Use GetReduceCountPerThread functor to replace the initial count for Blockwise and Multiblock welford * Tiny correction and remove un-used file under example/34_batchnorm * Renaming in the kernel arguments * Explicitly use ck::math::sqrt in batchnorm-forward kernels * Add some comments to some kernels * Tiny fix * Generalize the data types in reference_batchnorm_forward_nhwc_c * Use ck::ignore to mark un-used parameters * Move GetReduceCountPerThread functor codes from kernel to device * Remove some un-used codes in device_batchnorm_forward_impl.hpp * Tiny fix in batchnorm_forward example * Move GetReduceCountPerThread() to welford_helper.hpp * Use seperate data type for Scale and Bias * Renaming in device Op * Tiny fix in forward example * Updata to batchnorm-infer (type spliting, renaming) * Add time and bandwidth measurement to the batchnorm-forward example * Add support of elementwise operation for batchnorm forward output * Reduce object copying by passing object as reference type * Tiny change for performance * Updates for performance again * Some Renamings * Add GetActualVariance template parameter for ThreadwiseWelfordMerge * Tiny update in reference batchnorm forward nhwc/c * Move batchnorm multiblock kernel files to grid/batchnorm_multiblock sub-directory * Fuse mean and bias in the normalization calculation Co-authored-by: root <root@dc-smc-18.amd.com> Co-authored-by: rocking5566 <ChunYu.Lai@amd.com>
This commit is contained in:
@@ -10,102 +10,17 @@
|
||||
|
||||
#include "ck/utility/data_type.hpp"
|
||||
|
||||
// binary operation used to calculate invVariance from mean and meansquare
|
||||
struct InvVariance
|
||||
{
|
||||
InvVariance(double epsilon) : epsilon_(epsilon){};
|
||||
|
||||
template <typename T>
|
||||
__host__ __device__ constexpr void operator()(T& y, const T& mean, const T& meansquare) const
|
||||
{
|
||||
static_assert(std::is_same<T, float>::value || std::is_same<T, double>::value,
|
||||
"Data type is not supported by this operation!");
|
||||
|
||||
using ck::type_convert;
|
||||
using ck::math::sqrt;
|
||||
|
||||
T tmp_epsilon = type_convert<T>(epsilon_);
|
||||
|
||||
y = meansquare - mean * mean;
|
||||
y = 1.0f / sqrt(tmp_epsilon + y);
|
||||
};
|
||||
|
||||
double epsilon_;
|
||||
};
|
||||
|
||||
// (4-in, 2-out) element-wise operation used to update the moving average of mean and variance
|
||||
struct MovingAverage
|
||||
{
|
||||
MovingAverage(double factor) : factor_(factor){};
|
||||
|
||||
template <typename T>
|
||||
__host__ __device__ constexpr void operator()(T& y0,
|
||||
T& y1,
|
||||
const T& mean,
|
||||
const T& runningMean,
|
||||
const T& meansquare,
|
||||
const T& runningVariance) const
|
||||
{
|
||||
static_assert(std::is_same<T, float>::value || std::is_same<T, double>::value,
|
||||
"Data type is not supported by this operation!");
|
||||
|
||||
using ck::type_convert;
|
||||
|
||||
T tmp_factor = type_convert<T>(factor_);
|
||||
T variance = meansquare - mean * mean;
|
||||
|
||||
y0 = runningMean * (type_convert<T>(1.0f) - tmp_factor) + mean * tmp_factor;
|
||||
y1 = runningVariance * (type_convert<T>(1.0f) - tmp_factor) + variance * tmp_factor;
|
||||
};
|
||||
|
||||
double factor_;
|
||||
};
|
||||
|
||||
struct MovingAverageAndInvVariance
|
||||
{
|
||||
MovingAverageAndInvVariance(double epsilon, double factor)
|
||||
: epsilon_(epsilon), factor_(factor){};
|
||||
|
||||
template <typename T>
|
||||
__host__ __device__ constexpr void operator()(T& y0, // resultRunningMean
|
||||
T& y1, // resultRunningVariance
|
||||
T& y2, // saveInvVariance
|
||||
const T& mean,
|
||||
const T& runningMean,
|
||||
const T& meansquare,
|
||||
const T& runningVariance) const
|
||||
{
|
||||
static_assert(std::is_same<T, float>::value || std::is_same<T, double>::value,
|
||||
"Data type is not supported by this operation!");
|
||||
|
||||
using ck::type_convert;
|
||||
using ck::math::sqrt;
|
||||
|
||||
T tmp_epsilon = type_convert<T>(epsilon_);
|
||||
T tmp_factor = type_convert<T>(factor_);
|
||||
T variance = meansquare - mean * mean;
|
||||
|
||||
y0 = runningMean * (type_convert<T>(1.0f) - tmp_factor) + mean * tmp_factor;
|
||||
y1 = runningVariance * (type_convert<T>(1.0f) - tmp_factor) + variance * tmp_factor;
|
||||
|
||||
y2 = 1.0f / sqrt(tmp_epsilon + variance);
|
||||
};
|
||||
|
||||
double epsilon_;
|
||||
double factor_;
|
||||
};
|
||||
|
||||
struct NormalizeInInfer
|
||||
{
|
||||
NormalizeInInfer(double epsilon = 1e-4) : epsilon_(epsilon) {}
|
||||
|
||||
template <typename T1, typename T2>
|
||||
template <typename T1, typename T2, typename T3, typename T4>
|
||||
__host__ __device__ constexpr void operator()(T1& y,
|
||||
const T1& x,
|
||||
const T2& mean,
|
||||
const T2& variance,
|
||||
const T2& gamma,
|
||||
const T2& beta) const
|
||||
const T3& gamma,
|
||||
const T4& beta) const
|
||||
{
|
||||
static_assert(std::is_same<T2, float>::value || std::is_same<T2, double>::value,
|
||||
"Data type is not supported by this operation!");
|
||||
@@ -117,38 +32,10 @@ struct NormalizeInInfer
|
||||
|
||||
tmp_x = type_convert<T2>(x);
|
||||
|
||||
tmp_y = ((tmp_x - mean) / sqrt(variance + type_convert<T2>(epsilon_))) * gamma + beta;
|
||||
y = type_convert<T1>(tmp_y);
|
||||
};
|
||||
|
||||
double epsilon_;
|
||||
};
|
||||
|
||||
struct NormalizeInForward
|
||||
{
|
||||
NormalizeInForward(double epsilon = 1e-4) : epsilon_(epsilon) {}
|
||||
|
||||
template <typename T1, typename T2>
|
||||
__host__ __device__ constexpr void operator()(T1& y,
|
||||
const T1& x,
|
||||
const T2& mean,
|
||||
const T2& meansquare,
|
||||
const T2& gamma,
|
||||
const T2& beta) const
|
||||
{
|
||||
static_assert(std::is_same<T2, float>::value || std::is_same<T2, double>::value,
|
||||
"Data type is not supported by this operation!");
|
||||
|
||||
using ck::type_convert;
|
||||
using ck::math::sqrt;
|
||||
|
||||
T2 tmp_x, tmp_y;
|
||||
T2 variance = meansquare - mean * mean;
|
||||
|
||||
tmp_x = type_convert<T2>(x);
|
||||
|
||||
tmp_y = ((tmp_x - mean) / sqrt(variance + type_convert<T2>(epsilon_))) * gamma + beta;
|
||||
y = type_convert<T1>(tmp_y);
|
||||
tmp_y = ((tmp_x - mean) / sqrt(variance + type_convert<T2>(epsilon_))) *
|
||||
type_convert<T2>(gamma) +
|
||||
type_convert<T2>(beta);
|
||||
y = type_convert<T1>(tmp_y);
|
||||
};
|
||||
|
||||
double epsilon_;
|
||||
|
||||
Reference in New Issue
Block a user