mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-14 02:02:46 +00:00
Batchnorm-forward implemented using welford method to calculate variance (#403)
* Update to the batchnorm-forward API and base class
* Fix leaked header inclusion in gridwise_set_buffer_value.hpp
* Add kernels and device file for batchnorm-forward welford, supporting both blockwise and multi-block reduction
* Update the batchnorm-forward example to use the new batchnorm-forward device interface
* Change the batchnorm-forward reference to use the sequential welford method
* Assign the workspace to four buffers in the host layer
* Use the GetReduceCountPerThread functor to replace the initial count for blockwise and multi-block welford
* Tiny correction and removal of an unused file under example/34_batchnorm
* Renaming in the kernel arguments
* Explicitly use ck::math::sqrt in batchnorm-forward kernels
* Add some comments to some kernels
* Tiny fix
* Generalize the data types in reference_batchnorm_forward_nhwc_c
* Use ck::ignore to mark unused parameters
* Move the GetReduceCountPerThread functor code from kernel to device
* Remove some unused code in device_batchnorm_forward_impl.hpp
* Tiny fix in the batchnorm_forward example
* Move GetReduceCountPerThread() to welford_helper.hpp
* Use separate data types for Scale and Bias
* Renaming in the device Op
* Tiny fix in the forward example
* Update to batchnorm-infer (type splitting, renaming)
* Add time and bandwidth measurement to the batchnorm-forward example
* Add support for an elementwise operation on the batchnorm-forward output
* Reduce object copying by passing objects by reference
* Tiny change for performance
* Further updates for performance
* Some renamings
* Add a GetActualVariance template parameter to ThreadwiseWelfordMerge
* Tiny update in reference batchnorm forward nhwc/c
* Move the batchnorm multiblock kernel files to the grid/batchnorm_multiblock sub-directory
* Fuse mean and bias in the normalization calculation
Co-authored-by: root <root@dc-smc-18.amd.com>
Co-authored-by: rocking5566 <ChunYu.Lai@amd.com>
[ROCm/composable_kernel commit: 7fa892e63e]
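For context on the title: Welford's method computes mean and variance in a single streaming pass without ever forming E[x^2], which is what makes it numerically safer than the previous mean/meansquare approach. A minimal sequential sketch of the recurrence the kernels parallelize (illustrative only, not the CK kernel code):

// Sequential Welford recurrence (illustrative sketch, not part of the commit).
// After n updates: mean == E[x] and m2 / n == the population variance.
struct WelfordState
{
    double mean = 0.0;
    double m2   = 0.0;
    int count   = 0;
};

inline void welford_update(WelfordState& s, double x)
{
    s.count += 1;
    double delta = x - s.mean;
    s.mean += delta / s.count;
    s.m2 += delta * (x - s.mean); // note: uses the already-updated mean
}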
@@ -10,102 +10,17 @@

#include "ck/utility/data_type.hpp"

// binary operation used to calculate invVariance from mean and meansquare
struct InvVariance
{
    InvVariance(double epsilon) : epsilon_(epsilon){};

    template <typename T>
    __host__ __device__ constexpr void operator()(T& y, const T& mean, const T& meansquare) const
    {
        static_assert(std::is_same<T, float>::value || std::is_same<T, double>::value,
                      "Data type is not supported by this operation!");

        using ck::type_convert;
        using ck::math::sqrt;

        T tmp_epsilon = type_convert<T>(epsilon_);

        y = meansquare - mean * mean;
        y = 1.0f / sqrt(tmp_epsilon + y);
    };

    double epsilon_;
};

// (4-in, 2-out) element-wise operation used to update the moving average of mean and variance
struct MovingAverage
{
    MovingAverage(double factor) : factor_(factor){};

    template <typename T>
    __host__ __device__ constexpr void operator()(T& y0,
                                                  T& y1,
                                                  const T& mean,
                                                  const T& runningMean,
                                                  const T& meansquare,
                                                  const T& runningVariance) const
    {
        static_assert(std::is_same<T, float>::value || std::is_same<T, double>::value,
                      "Data type is not supported by this operation!");

        using ck::type_convert;

        T tmp_factor = type_convert<T>(factor_);
        T variance   = meansquare - mean * mean;

        y0 = runningMean * (type_convert<T>(1.0f) - tmp_factor) + mean * tmp_factor;
        y1 = runningVariance * (type_convert<T>(1.0f) - tmp_factor) + variance * tmp_factor;
    };

    double factor_;
};

struct MovingAverageAndInvVariance
{
    MovingAverageAndInvVariance(double epsilon, double factor)
        : epsilon_(epsilon), factor_(factor){};

    template <typename T>
    __host__ __device__ constexpr void operator()(T& y0, // resultRunningMean
                                                  T& y1, // resultRunningVariance
                                                  T& y2, // saveInvVariance
                                                  const T& mean,
                                                  const T& runningMean,
                                                  const T& meansquare,
                                                  const T& runningVariance) const
    {
        static_assert(std::is_same<T, float>::value || std::is_same<T, double>::value,
                      "Data type is not supported by this operation!");

        using ck::type_convert;
        using ck::math::sqrt;

        T tmp_epsilon = type_convert<T>(epsilon_);
        T tmp_factor  = type_convert<T>(factor_);
        T variance    = meansquare - mean * mean;

        y0 = runningMean * (type_convert<T>(1.0f) - tmp_factor) + mean * tmp_factor;
        y1 = runningVariance * (type_convert<T>(1.0f) - tmp_factor) + variance * tmp_factor;

        y2 = 1.0f / sqrt(tmp_epsilon + variance);
    };

    double epsilon_;
    double factor_;
};

struct NormalizeInInfer
{
    NormalizeInInfer(double epsilon = 1e-4) : epsilon_(epsilon) {}

    template <typename T1, typename T2>
    template <typename T1, typename T2, typename T3, typename T4>
    __host__ __device__ constexpr void operator()(T1& y,
                                                  const T1& x,
                                                  const T2& mean,
                                                  const T2& variance,
                                                  const T2& gamma,
                                                  const T2& beta) const
                                                  const T3& gamma,
                                                  const T4& beta) const
    {
        static_assert(std::is_same<T2, float>::value || std::is_same<T2, double>::value,
                      "Data type is not supported by this operation!");
@@ -117,38 +32,10 @@ struct NormalizeInInfer

        tmp_x = type_convert<T2>(x);

        tmp_y = ((tmp_x - mean) / sqrt(variance + type_convert<T2>(epsilon_))) * gamma + beta;
        y     = type_convert<T1>(tmp_y);
    };

    double epsilon_;
};

struct NormalizeInForward
{
    NormalizeInForward(double epsilon = 1e-4) : epsilon_(epsilon) {}

    template <typename T1, typename T2>
    __host__ __device__ constexpr void operator()(T1& y,
                                                  const T1& x,
                                                  const T2& mean,
                                                  const T2& meansquare,
                                                  const T2& gamma,
                                                  const T2& beta) const
    {
        static_assert(std::is_same<T2, float>::value || std::is_same<T2, double>::value,
                      "Data type is not supported by this operation!");

        using ck::type_convert;
        using ck::math::sqrt;

        T2 tmp_x, tmp_y;
        T2 variance = meansquare - mean * mean;

        tmp_x = type_convert<T2>(x);

        tmp_y = ((tmp_x - mean) / sqrt(variance + type_convert<T2>(epsilon_))) * gamma + beta;
        y = type_convert<T1>(tmp_y);
        tmp_y = ((tmp_x - mean) / sqrt(variance + type_convert<T2>(epsilon_))) *
                    type_convert<T2>(gamma) +
                type_convert<T2>(beta);
        y = type_convert<T1>(tmp_y);
    };

    double epsilon_;
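To make the functor semantics above concrete, here is a small host-side check with hypothetical values (a sketch to be placed inside e.g. a test main(); it assumes the functors above are visible and is not part of the commit):

// Hypothetical usage of the functors above (illustrative, not part of the commit).
// With mean = E[x] = 2 and meansquare = E[x*x] = 5, variance = 5 - 2*2 = 1.
float mean = 2.0f, meansquare = 5.0f, invVar = 0.0f;
InvVariance{0.0}(invVar, mean, meansquare); // invVar == 1 / sqrt(1 + 0) == 1

// MovingAverage blends the running statistics with the new batch statistics:
//   updated = running * (1 - factor) + new * factor
float newMean = 0.0f, newVar = 0.0f;
MovingAverage{0.1}(newMean, newVar, 2.0f, 1.0f, 5.0f, 0.5f);
// newMean == 1.0 * 0.9 + 2.0 * 0.1 == 1.1
// newVar  == 0.5 * 0.9 + 1.0 * 0.1 == 0.55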
@@ -1,295 +0,0 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include <cassert>
#include <vector>

#include "ck/ck.hpp"
#include "ck/utility/reduction_operator.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_multiple_reduce_multiblock.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_elementwise.hpp"

#include "batchnorm_common.hpp"

template <typename InOutDataType,
          typename AccDataType,
          ck::index_t Rank,
          ck::index_t NumBatchNormReduceDim,
          bool fastest_dim_is_reduced = false>
int bnorm_fwd(bool time_kernel,
              bool updateMovingAverage,
              bool saveMeanAndInvVariance,
              const std::array<int, NumBatchNormReduceDim> reduceDims,
              const std::array<ck::index_t, Rank> xyLengths,
              const std::array<ck::index_t, Rank> xStrides,
              const std::array<ck::index_t, Rank> yStrides,
              const std::array<ck::index_t, Rank - NumBatchNormReduceDim> bnScaleBiasMeanVarLengths,
              const std::array<ck::index_t, Rank - NumBatchNormReduceDim> bnScaleBiasMeanVarStrides,
              const void* p_x,
              const void* p_scale,
              const void* p_bias,
              void* p_y,
              double exponentialAverageFactor,
              void* p_runningMean,
              void* p_runningVariance,
              double epsilon,
              void* p_saveMean,
              void* p_saveInvVariance,
              void* p_tmp_mean,
              void* p_tmp_meansquare)
{
    static_assert(NumBatchNormReduceDim < Rank,
                  "Invalid number of reduced dimensions for batchnorm!");

    constexpr ck::index_t NumScaleBiasMeanVarDim = Rank - NumBatchNormReduceDim;

    using InElementwiseOperation_Mean  = ck::tensor_operation::element_wise::PassThrough;
    using AccElementwiseOperation_Mean = ck::tensor_operation::element_wise::UnaryDivide;

    using InElementwiseOperation_Meansquare  = ck::tensor_operation::element_wise::UnarySquare;
    using AccElementwiseOperation_Meansquare = ck::tensor_operation::element_wise::UnaryDivide;

    using DeviceMeanAndMeansquareInstance =
        ck::tensor_operation::device::DeviceMultipleReduceMultiBlock<
            2,
            InOutDataType,
            AccDataType,
            ck::Tuple<AccDataType, AccDataType>,
            Rank,
            NumBatchNormReduceDim,
            ck::reduce::Add,
            ck::Tuple<InElementwiseOperation_Mean, InElementwiseOperation_Meansquare>,
            ck::Tuple<AccElementwiseOperation_Mean, AccElementwiseOperation_Meansquare>,
            ck::InMemoryDataOperationEnum::Set,
            false, // PropagateNan
            256,
            16,
            16,
            1,
            1,
            fastest_dim_is_reduced ? 1 : 0,
            1,
            ck::Sequence<1, 1>>;

    using DeviceNormalizeInstance = ck::tensor_operation::device::DeviceElementwise<
        ck::Tuple<InOutDataType, AccDataType, AccDataType, AccDataType, AccDataType>, // x, mean,
                                                                                      // meansquare,
                                                                                      // scale, bias
        ck::Tuple<InOutDataType>, // y
        NormalizeInForward,
        Rank,
        2, // MPerthread
        ck::Sequence<1, 1, 1, 1, 1>, // scalarPerVector: x, mean, meansquare, scale, bias
        ck::Sequence<1>>;            // scalarPerVector: y

    using DeviceInvVarianceInstance = ck::tensor_operation::device::DeviceElementwise<
        ck::Tuple<AccDataType, AccDataType>, // mean, meansquare
        ck::Tuple<AccDataType>,              // invVariance
        InvVariance,
        NumScaleBiasMeanVarDim,
        2, // MPerthread
        ck::Sequence<1, 1>, // scalarPerVector: mean, meansquare
        ck::Sequence<1>>;   // scalarPerVector: invVariance

    using DeviceMovingAverageInstance = ck::tensor_operation::device::DeviceElementwise<
        ck::Tuple<AccDataType, AccDataType, AccDataType, AccDataType>, // old moving mean, new mean,
                                                                       // old moving variance, new
                                                                       // meansquare
        ck::Tuple<AccDataType, AccDataType>, // updated moving mean, updated moving variance
        MovingAverage,
        NumScaleBiasMeanVarDim,
        4, // MPerthread
        ck::Sequence<1, 1, 1, 1>, // scalarPerVector: old moving mean, new mean, old moving
                                  // variance, new meansquare
        ck::Sequence<1, 1>>; // scalarPerVector: updated moving mean, updated moving variance

    using DeviceMovingAverageAndInvVarianceInstance =
        ck::tensor_operation::device::DeviceElementwise<
            ck::Tuple<AccDataType, AccDataType, AccDataType, AccDataType>, // old moving mean, new
                                                                           // mean, old moving
                                                                           // variance, new
                                                                           // meansquare
            ck::Tuple<AccDataType, AccDataType, AccDataType>, // updated moving mean, updated moving
                                                              // variance, invVariance
            MovingAverageAndInvVariance,
            NumScaleBiasMeanVarDim,
            4, // MPerthread
            ck::Sequence<1, 1, 1, 1>, // scalarPerVector: old moving mean, new mean, old moving
                                      // variance, new meansquare
            ck::Sequence<1, 1, 1>>; // scalarPerVector: updated moving mean, updated moving variance

    auto invariantDims = get_invariant_dims<Rank, NumBatchNormReduceDim>(reduceDims);
    std::array<ck::index_t, Rank> aligned_scaleBiasMeanVarStrides{0};

    int i = 0;
    for(auto dim : invariantDims)
    {
        assert(xyLengths[dim] == bnScaleBiasMeanVarLengths[i]);

        aligned_scaleBiasMeanVarStrides[dim] = bnScaleBiasMeanVarStrides[i];
        i++;
    };

    int32_t reduceLength = 1;

    for(auto dim : reduceDims)
        reduceLength *= xyLengths[dim];

    int32_t invariantLength = 1;

    for(auto dim : invariantDims)
        invariantLength *= xyLengths[dim];

    size_t total_length = static_cast<size_t>(invariantLength) * reduceLength;

    float avg_time        = 0.0f;
    std::size_t num_bytes = 0;

    auto dev_mean_and_meansquare = DeviceMeanAndMeansquareInstance{};

    void* p_mean = saveMeanAndInvVariance ? p_saveMean : p_tmp_mean;

    const AccDataType alpha = ck::type_convert<AccDataType>(1.0f);
    const AccDataType beta  = ck::type_convert<AccDataType>(0.0f);

    auto argument_ptr1 = dev_mean_and_meansquare.MakeArgumentPointer(
        xyLengths,
        xStrides,
        bnScaleBiasMeanVarLengths,
        {bnScaleBiasMeanVarStrides, bnScaleBiasMeanVarStrides},
        reduceDims,
        {&alpha, &alpha},
        {&beta, &beta},
        p_x,
        {p_mean, p_tmp_meansquare},
        ck::make_tuple(InElementwiseOperation_Mean{}, InElementwiseOperation_Meansquare{}),
        ck::make_tuple(AccElementwiseOperation_Mean{reduceLength},
                       AccElementwiseOperation_Meansquare{reduceLength}));

    auto dev_normalize = DeviceNormalizeInstance{};

    auto argument_ptr2 =
        dev_normalize.MakeArgumentPointer(xyLengths,
                                          {xStrides,
                                           aligned_scaleBiasMeanVarStrides,
                                           aligned_scaleBiasMeanVarStrides,
                                           aligned_scaleBiasMeanVarStrides,
                                           aligned_scaleBiasMeanVarStrides},
                                          {yStrides},
                                          {p_x, p_mean, p_tmp_meansquare, p_scale, p_bias},
                                          {p_y},
                                          NormalizeInForward{epsilon});

    if(!dev_mean_and_meansquare.IsSupportedArgument(argument_ptr1.get()) ||
       !dev_normalize.IsSupportedArgument(argument_ptr2.get()))
    {
        std::cout << "The runtime parameters seem not to be supported by the Device, exiting!"
                  << std::endl;

        return (-1);
    };

    auto invoker_ptr1 = dev_mean_and_meansquare.MakeInvokerPointer();
    auto invoker_ptr2 = dev_normalize.MakeInvokerPointer();

    avg_time += invoker_ptr1->Run(argument_ptr1.get(), StreamConfig{nullptr, time_kernel});
    avg_time += invoker_ptr2->Run(argument_ptr2.get(), StreamConfig{nullptr, time_kernel});

    num_bytes +=
        (total_length * sizeof(InOutDataType) + invariantLength * 2 * sizeof(AccDataType)) + // No.1
        (total_length * (1 * sizeof(InOutDataType) + 4 * sizeof(AccDataType)) +
         total_length * sizeof(InOutDataType)); // No.2

    if(saveMeanAndInvVariance && updateMovingAverage)
    {
        auto dev_moving_average_inv_variance = DeviceMovingAverageAndInvVarianceInstance{};

        auto argument_ptr3 = dev_moving_average_inv_variance.MakeArgumentPointer(
            bnScaleBiasMeanVarLengths,
            {bnScaleBiasMeanVarStrides,
             bnScaleBiasMeanVarStrides,
             bnScaleBiasMeanVarStrides,
             bnScaleBiasMeanVarStrides},
            {bnScaleBiasMeanVarStrides, bnScaleBiasMeanVarStrides, bnScaleBiasMeanVarStrides},
            {p_mean, p_runningMean, p_tmp_meansquare, p_runningVariance},
            {p_runningMean, p_runningVariance, p_saveInvVariance},
            MovingAverageAndInvVariance{epsilon, exponentialAverageFactor});

        if(!dev_moving_average_inv_variance.IsSupportedArgument(argument_ptr3.get()))
        {
            std::cout << "Runtime parameters not supported by the Device, exiting!" << std::endl;

            return (-1);
        };

        auto invoker_ptr3 = dev_moving_average_inv_variance.MakeInvokerPointer();

        avg_time += invoker_ptr3->Run(argument_ptr3.get(), StreamConfig{nullptr, time_kernel});

        num_bytes += invariantLength * (4 + 3) * sizeof(AccDataType) * 2; // No.5
    }
    else if(saveMeanAndInvVariance)
    {
        auto dev_inv_variance = DeviceInvVarianceInstance{};
        auto argument_ptr3    = dev_inv_variance.MakeArgumentPointer(
            bnScaleBiasMeanVarLengths,
            {bnScaleBiasMeanVarStrides, bnScaleBiasMeanVarStrides},
            {bnScaleBiasMeanVarStrides},
            {p_mean, p_tmp_meansquare},
            {p_saveInvVariance},
            InvVariance{epsilon});

        if(!dev_inv_variance.IsSupportedArgument(argument_ptr3.get()))
        {
            std::cout << "Runtime parameters not supported by the Device, exiting!" << std::endl;

            return (-1);
        };

        auto invoker_ptr3 = dev_inv_variance.MakeInvokerPointer();

        avg_time += invoker_ptr3->Run(argument_ptr3.get(), StreamConfig{nullptr, time_kernel});

        num_bytes += invariantLength * (2 + 1) * sizeof(AccDataType);
    }
    else if(updateMovingAverage)
    {
        auto dev_moving_average = DeviceMovingAverageInstance{};

        auto argument_ptr3 = dev_moving_average.MakeArgumentPointer(
            bnScaleBiasMeanVarLengths,
            {bnScaleBiasMeanVarStrides,
             bnScaleBiasMeanVarStrides,
             bnScaleBiasMeanVarStrides,
             bnScaleBiasMeanVarStrides},
            {bnScaleBiasMeanVarStrides, bnScaleBiasMeanVarStrides},
            {p_mean, p_runningMean, p_tmp_meansquare, p_runningVariance},
            {p_runningMean, p_runningVariance},
            MovingAverage{exponentialAverageFactor});

        if(!dev_moving_average.IsSupportedArgument(argument_ptr3.get()))
        {
            std::cout << "Runtime parameters not supported by the Device, exiting!" << std::endl;

            return (-1);
        };

        auto invoker_ptr3 = dev_moving_average.MakeInvokerPointer();

        avg_time += invoker_ptr3->Run(argument_ptr3.get(), StreamConfig{nullptr, time_kernel});

        num_bytes += invariantLength * (4 + 2) * sizeof(AccDataType) * 2; // No.5
    };

    if(time_kernel)
    {
        float gb_per_sec = num_bytes / 1.E6 / avg_time;

        std::cout << "Perf: " << avg_time << " ms, " << gb_per_sec << " GB/s" << std::endl;
    };

    return (0);
};
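The whole host-side file above is deleted by this commit. Its variance path forms E[x^2] - E[x]^2, which cancels catastrophically in low precision when the variance is small relative to the mean, and it needs a two-output reduction plus separate elementwise kernels for the inv-variance and moving-average updates. A brief illustration of the numerical issue (editor's sketch, not CK code):

// For x in {10000.0f, 10000.01f, 10000.02f} in float:
//   meansquare - mean * mean   // ~1e8 - ~1e8: the true variance (~6.7e-5) is
//                              // below float's resolution at that magnitude,
//                              // so the result can even come out negative.
// The Welford recurrence instead accumulates small (x - mean) increments and
// keeps the variance at full working precision in a single pass.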
@@ -15,13 +15,9 @@
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/utility/host_common_util.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_batchnorm_forward_nhwc_c.hpp"

#include "batchnorm_forward_impl.hpp"

template <typename InOutDataType, typename AccDataType>
using ReferenceBatchNormFwdInstance =
    ck::tensor_operation::host::ReferenceBatchNormFwd_Input_N_H_W_C_Output_C<InOutDataType,
                                                                             AccDataType>;
#include "ck/tensor_operation/gpu/device/impl/device_batchnorm_forward_impl.hpp"
#include "ck/library/utility/host_common_util.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"

static struct option long_options[] = {{"inOutLengths", required_argument, nullptr, 'D'},
                                       {"verify", required_argument, nullptr, 'v'},
@@ -41,9 +37,10 @@ class BatchNormFwdArg
    bool updateMovingAverage;
    bool saveMeanAndInvVariance;

    int data_type    = 0;
    int init_method  = 2;
    bool time_kernel = false;
    int data_type    = 0;
    int init_method  = 2;
    bool time_kernel = false;
    bool use_multiblock_welford = false;

    public:
    void show_usage(const char* cmd)
@@ -68,6 +65,7 @@ class BatchNormFwdArg
                     "value, 3=decimal value)"
                  << std::endl;
        std::cout << "Arg5: time kernel (0=no, 1=yes)" << std::endl;
        std::cout << "Arg6: use multi-block welford (0=no, 1=yes)" << std::endl;
    };

    int processArgs(int argc, char* argv[])
@@ -110,14 +108,15 @@ class BatchNormFwdArg
        };
        };

        if(optind + 5 > argc)
        if(optind + 6 > argc)
            throw std::runtime_error("Invalid cmd-line arguments, more arguments are needed!");

        data_type              = std::atoi(argv[optind++]);
        updateMovingAverage    = std::atoi(argv[optind++]);
        saveMeanAndInvVariance = std::atoi(argv[optind++]);
        init_method            = std::atoi(argv[optind++]);
        time_kernel            = static_cast<bool>(std::atoi(argv[optind]));
        time_kernel            = static_cast<bool>(std::atoi(argv[optind++]));
        use_multiblock_welford = static_cast<bool>(std::atoi(argv[optind]));

        if(data_type != 0 && data_type != 1 && data_type != 3 && data_type != 5 && data_type != 6)
            return (-1);
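With this parsing in place, a run of the example might look like the following (hypothetical binary name and illustrative values, not taken from the commit):

// Editor's note (hypothetical invocation, for illustration only):
//   ./example_batchnorm_forward -D 128,16,16,1024 -v 1  0 1 1 3 1 1
// i.e. data_type=0 (fp16), updateMovingAverage=1, saveMeanAndInvVariance=1,
// init_method=3, time_kernel=1, use_multiblock_welford=1.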
@@ -128,7 +127,7 @@ class BatchNormFwdArg

using namespace ck;

template <typename InOutDataType, typename AccDataType>
template <typename InOutDataType, typename AccDataType, bool UseMultiblockInK>
bool bnorm_fwd_nhwc_test(bool do_verification,
                         int init_method,
                         bool time_kernel,
@@ -273,73 +272,140 @@ bool bnorm_fwd_nhwc_test(bool do_verification,
              scaleBiasMeanVarStrides.end(),
              i_scaleBiasMeanVarStrides.begin());

    int result = 0;
    using PassThroughOp = ck::tensor_operation::element_wise::PassThrough;

    // used for saving meansquare
    DeviceMem workspace(sizeof(AccDataType) * 2 * resultSaveMean_ref.mDesc.GetElementSpaceSize() +
                        128);
    using DeviceBatchNormFwdInstance =
        ck::tensor_operation::device::DeviceBatchNormFwdImpl<InOutDataType,
                                                             InOutDataType,
                                                             AccDataType,
                                                             AccDataType,   // ScaleDataType
                                                             AccDataType,   // BiasDataType
                                                             AccDataType,   // MeanVarDataType
                                                             PassThroughOp, // YElementwiseOp
                                                             Rank,
                                                             NumReduceDim,
                                                             UseMultiblockInK,
                                                             256,
                                                             16,
                                                             16,
                                                             1,
                                                             2,
                                                             0,
                                                             1,
                                                             1,
                                                             1,
                                                             1,
                                                             1>;
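    // Editor's note: the numeric tail of the instance above corresponds, in
    // order, to the tuning parameters of DeviceBatchNormFwdImpl declared later
    // in this commit: 256 = BlockSize, 16 = MThreadClusterSize,
    // 16 = KThreadClusterSize, 1 = MThreadSliceSize, 2 = KThreadSliceSize,
    // 0 = XSrcYDstVectorDim, and the final five 1s are XSrcVectorSize,
    // YDstVectorSize, ScaleSrcVectorSize, BiasSrcVectorSize and
    // MeanVarSrcDstVectorSize.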

    void* p_tmp_mean = workspace.GetDeviceBuffer();
    void* p_tmp_meansquare =
        static_cast<char*>(p_tmp_mean) +
        (sizeof(AccDataType) * resultSaveMean_ref.mDesc.GetElementSpaceSize() + 63) / 64 * 64;
    auto batchnorm_fwd = DeviceBatchNormFwdInstance{};

    result = bnorm_fwd<InOutDataType, AccDataType, Rank, NumReduceDim, false>(
        time_kernel,
        updateMovingAverage,
        saveMeanAndInvVariance,
        {0, 1, 2},
    auto argument_ptr = batchnorm_fwd.MakeArgumentPointer(
        i_inOutLengths,
        i_inOutStrides,
        i_inOutStrides,
        {0, 1, 2},
        i_scaleBiasMeanVarLengths,
        i_scaleBiasMeanVarStrides,
        i_scaleBiasMeanVarStrides,
        i_scaleBiasMeanVarStrides,
        x_dev.GetDeviceBuffer(),
        bnScale_dev.GetDeviceBuffer(),
        bnBias_dev.GetDeviceBuffer(),
        y_dev.GetDeviceBuffer(),
        averageFactor,
        updateMovingAverage ? resultRunningMean_dev.GetDeviceBuffer() : nullptr,
        updateMovingAverage ? resultRunningVariance_dev.GetDeviceBuffer() : nullptr,
        epsilon,
        PassThroughOp{},
        y_dev.GetDeviceBuffer(),
        saveMeanAndInvVariance ? resultSaveMean_dev.GetDeviceBuffer() : nullptr,
        saveMeanAndInvVariance ? resultSaveInvVariance_dev.GetDeviceBuffer() : nullptr,
        p_tmp_mean,
        p_tmp_meansquare);
        averageFactor,
        updateMovingAverage ? resultRunningMean_dev.GetDeviceBuffer() : nullptr,
        updateMovingAverage ? resultRunningVariance_dev.GetDeviceBuffer() : nullptr);

    if(result < 0)
    if(!batchnorm_fwd.IsSupportedArgument(argument_ptr.get()))
    {
        std::cout << "The runtime parameters seem not to be supported by the BatchNorm device "
                     "instance, exiting!"
                  << std::endl;
        return (false);
    };

    size_t workspace_sz = batchnorm_fwd.GetWorkSpaceSize(argument_ptr.get());

    DeviceMem workspace_dev(workspace_sz);

    batchnorm_fwd.SetWorkSpacePointer(argument_ptr.get(), workspace_dev.GetDeviceBuffer());

    auto invoker_ptr = batchnorm_fwd.MakeInvokerPointer();

    if(time_kernel)
    {
        float avg_time   = 0.0f;
        size_t num_bytes = 0;

        size_t total_length = inOutLengths[0] * inOutLengths[1] * inOutLengths[2] * inOutLengths[3];
        size_t invariant_length = inOutLengths[3];

        avg_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel});

        // input of x, scale, bias; output of y
        num_bytes +=
            total_length * sizeof(InOutDataType) * 2 + invariant_length * sizeof(AccDataType) * 2;

        // output of mean, inv-variance
        num_bytes += saveMeanAndInvVariance ? invariant_length * sizeof(AccDataType) * 2 : 0;

        // update of moving mean, variance
        num_bytes += updateMovingAverage ? invariant_length * sizeof(AccDataType) * 4 : 0;

        float gb_per_sec = num_bytes / 1.E6 / avg_time;

        std::cout << "Perf: " << avg_time << " ms, " << gb_per_sec << " GB/s" << std::endl;
    }
    else
        (void)invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel});

    bool pass = true;

    if(do_verification)
    {
        auto batchNormFwd_ref = ReferenceBatchNormFwdInstance<InOutDataType, AccDataType>{};

        using ReferenceBatchNormFwdInstance =
            ck::tensor_operation::host::ReferenceBatchNormFwd_Input_N_H_W_C_Output_C<InOutDataType,
                                                                                     InOutDataType,
                                                                                     AccDataType,
                                                                                     AccDataType,
                                                                                     AccDataType,
                                                                                     AccDataType,
                                                                                     PassThroughOp>;

        auto batchNormFwd_ref = ReferenceBatchNormFwdInstance{};

        auto argument_ptr_ref = batchNormFwd_ref.MakeArgumentPointer(
            i_inOutLengths,
            i_inOutStrides,
            i_inOutStrides,
            {0, 1, 2},
            i_scaleBiasMeanVarLengths,
            i_scaleBiasMeanVarStrides,
            i_scaleBiasMeanVarStrides,
            i_scaleBiasMeanVarStrides,
            x.mData.data(),
            bnScale.mData.data(),
            bnBias.mData.data(),
            y_ref.mData.data(),
            0.1, // exponentialAverageFactor
            updateMovingAverage ? resultRunningMean_ref.mData.data() : nullptr, // resultRunningMean
            updateMovingAverage ? resultRunningVariance_ref.mData.data()
                                : nullptr, // resultRunningVariance
            epsilon,
            PassThroughOp{},
            y_ref.mData.data(),
            saveMeanAndInvVariance ? resultSaveMean_ref.mData.data() : nullptr,
            saveMeanAndInvVariance ? resultSaveInvVariance_ref.mData.data() : nullptr);
            saveMeanAndInvVariance ? resultSaveInvVariance_ref.mData.data() : nullptr,
            averageFactor,
            updateMovingAverage ? resultRunningMean_ref.mData.data() : nullptr,
            updateMovingAverage ? resultRunningVariance_ref.mData.data() : nullptr);

        if(!batchNormFwd_ref.IsSupportedArgument(argument_ptr_ref.get()))
        {
            std::cout
                << "The runtime parameters seem not to be supported by the BatchNorm instance, exiting!"
                << std::endl;
            return (-2);
            std::cout << "The runtime parameters seem not to be supported by the BatchNorm "
                         "reference instance, exiting!"
                      << std::endl;
            return (false);
        };

        auto invoker_ptr_ref = batchNormFwd_ref.MakeInvokerPointer();
@@ -365,6 +431,8 @@ bool bnorm_fwd_nhwc_test(bool do_verification,

    if(saveMeanAndInvVariance)
    {
        using ck::host_common::dumpBufferToFile;

        Tensor<AccDataType> resultSaveMean(scaleBiasMeanVarLengths);
        Tensor<AccDataType> resultSaveInvVariance(scaleBiasMeanVarLengths);

@@ -396,70 +464,129 @@ int main(int argc, char* argv[])

    if(arg.data_type == 0)
    {
        pass = bnorm_fwd_nhwc_test<ck::half_t, float>(
            arg.do_verification, arg.init_method, arg.time_kernel, arg.inOutLengths,
            arg.updateMovingAverage, arg.saveMeanAndInvVariance, averageFactor, epsilon);
        if(arg.use_multiblock_welford)
            pass = bnorm_fwd_nhwc_test<ck::half_t, float, true>(
                arg.do_verification, arg.init_method, arg.time_kernel, arg.inOutLengths,
                arg.updateMovingAverage, arg.saveMeanAndInvVariance, averageFactor, epsilon);
        else
            pass = bnorm_fwd_nhwc_test<ck::half_t, float, false>(
                arg.do_verification, arg.init_method, arg.time_kernel, arg.inOutLengths,
                arg.updateMovingAverage, arg.saveMeanAndInvVariance, averageFactor, epsilon);
    }
    else if(arg.data_type == 1)
    {
        pass = bnorm_fwd_nhwc_test<float, float>(
            arg.do_verification, arg.init_method, arg.time_kernel, arg.inOutLengths,
            arg.updateMovingAverage, arg.saveMeanAndInvVariance, averageFactor, epsilon);
        if(arg.use_multiblock_welford)
            pass = bnorm_fwd_nhwc_test<float, float, true>(
                arg.do_verification, arg.init_method, arg.time_kernel, arg.inOutLengths,
                arg.updateMovingAverage, arg.saveMeanAndInvVariance, averageFactor, epsilon);
        else
            pass = bnorm_fwd_nhwc_test<float, float, false>(
                arg.do_verification, arg.init_method, arg.time_kernel, arg.inOutLengths,
                arg.updateMovingAverage, arg.saveMeanAndInvVariance, averageFactor, epsilon);
    }
    else if(arg.data_type == 3)
    {
        pass = bnorm_fwd_nhwc_test<int8_t, float>(
            arg.do_verification, arg.init_method, arg.time_kernel, arg.inOutLengths,
            arg.updateMovingAverage, arg.saveMeanAndInvVariance, averageFactor, epsilon);
        if(arg.use_multiblock_welford)
            pass = bnorm_fwd_nhwc_test<int8_t, float, true>(
                arg.do_verification, arg.init_method, arg.time_kernel, arg.inOutLengths,
                arg.updateMovingAverage, arg.saveMeanAndInvVariance, averageFactor, epsilon);
        else
            pass = bnorm_fwd_nhwc_test<int8_t, float, false>(
                arg.do_verification, arg.init_method, arg.time_kernel, arg.inOutLengths,
                arg.updateMovingAverage, arg.saveMeanAndInvVariance, averageFactor, epsilon);
    }
    else if(arg.data_type == 5)
    {
        pass = bnorm_fwd_nhwc_test<ck::bhalf_t, float>(
            arg.do_verification, arg.init_method, arg.time_kernel, arg.inOutLengths,
            arg.updateMovingAverage, arg.saveMeanAndInvVariance, averageFactor, epsilon);
        if(arg.use_multiblock_welford)
            pass = bnorm_fwd_nhwc_test<ck::bhalf_t, float, true>(
                arg.do_verification, arg.init_method, arg.time_kernel, arg.inOutLengths,
                arg.updateMovingAverage, arg.saveMeanAndInvVariance, averageFactor, epsilon);
        else
            pass = bnorm_fwd_nhwc_test<ck::bhalf_t, float, false>(
                arg.do_verification, arg.init_method, arg.time_kernel, arg.inOutLengths,
                arg.updateMovingAverage, arg.saveMeanAndInvVariance, averageFactor, epsilon);
    }
    else if(arg.data_type == 6)
    {
        pass = bnorm_fwd_nhwc_test<double, double>(
            arg.do_verification, arg.init_method, arg.time_kernel, arg.inOutLengths,
            arg.updateMovingAverage, arg.saveMeanAndInvVariance, averageFactor, epsilon);
        if(arg.use_multiblock_welford)
            pass = bnorm_fwd_nhwc_test<double, double, true>(
                arg.do_verification, arg.init_method, arg.time_kernel, arg.inOutLengths,
                arg.updateMovingAverage, arg.saveMeanAndInvVariance, averageFactor, epsilon);
        else
            pass = bnorm_fwd_nhwc_test<double, double, false>(
                arg.do_verification, arg.init_method, arg.time_kernel, arg.inOutLengths,
                arg.updateMovingAverage, arg.saveMeanAndInvVariance, averageFactor, epsilon);
    }
}
else
{
    pass = bnorm_fwd_nhwc_test<ck::half_t, float>(
        true, 2, false, // don't time kernel
        {128, 16, 16, 1024}, true, false, averageFactor, epsilon);
    pass = bnorm_fwd_nhwc_test<ck::half_t, float, true>(
        true, 2, false, // don't time kernel
        {128, 16, 6, 512}, true, true, averageFactor, epsilon);

    pass = pass && bnorm_fwd_nhwc_test<ck::half_t, float, false>(
        true, 2, false, // don't time kernel
        {128, 16, 3, 1024}, true, true, averageFactor, epsilon);
};

return (pass ? 0 : 1);
@@ -14,8 +14,12 @@

#include "batchnorm_common.hpp"

template <typename InOutDataType,
template <typename XDataType,
          typename YDataType,
          typename AccDataType,
          typename ScaleDataType,
          typename BiasDataType,
          typename MeanVarDataType,
          ck::index_t Rank,
          ck::index_t NumBatchNormReduceDim,
          bool fastest_dim_is_reduced = false>
@@ -26,7 +30,9 @@ int bnorm_infer(
    const std::array<ck::index_t, Rank> xStrides,
    const std::array<ck::index_t, Rank> yStrides,
    const std::array<ck::index_t, Rank - NumBatchNormReduceDim> bnScaleBiasMeanVarLengths,
    const std::array<ck::index_t, Rank - NumBatchNormReduceDim> bnScaleBiasMeanVarStrides,
    const std::array<ck::index_t, Rank - NumBatchNormReduceDim> bnScaleStrides,
    const std::array<ck::index_t, Rank - NumBatchNormReduceDim> bnBiasStrides,
    const std::array<ck::index_t, Rank - NumBatchNormReduceDim> bnMeanVarStrides,
    const void* p_x,
    const void* p_scale,
    const void* p_bias,
@@ -41,11 +47,11 @@ int bnorm_infer(
                  "Invalid number of reduced dimensions for batchnorm!");

    using DeviceNormalizeInstance = ck::tensor_operation::device::DeviceElementwise<
        ck::Tuple<InOutDataType, AccDataType, AccDataType, AccDataType, AccDataType>, // x, mean,
                                                                                      // variance,
                                                                                      // scale,
                                                                                      // bias,
        ck::Tuple<InOutDataType>, // y
        ck::Tuple<XDataType, AccDataType, AccDataType, AccDataType, AccDataType>, // x, mean,
                                                                                  // variance,
                                                                                  // scale,
                                                                                  // bias,
        ck::Tuple<YDataType>, // y
        NormalizeInInfer,
        Rank,
        2, // MPerthread
@@ -53,14 +59,18 @@ int bnorm_infer(
        ck::Sequence<1>>; // scalarPerVector: y

    auto invariantDims = get_invariant_dims<Rank, NumBatchNormReduceDim>(reduceDims);
    std::array<ck::index_t, Rank> aligned_scaleBiasMeanVarStrides{0};
    std::array<ck::index_t, Rank> aligned_bnScaleStrides{0};
    std::array<ck::index_t, Rank> aligned_bnBiasStrides{0};
    std::array<ck::index_t, Rank> aligned_bnMeanVarStrides{0};

    int i = 0;
    for(auto dim : invariantDims)
    {
        assert(xyLengths[dim] == bnScaleBiasMeanVarLengths[i]);

        aligned_scaleBiasMeanVarStrides[dim] = bnScaleBiasMeanVarStrides[i];
        aligned_bnScaleStrides[dim]   = bnScaleStrides[i];
        aligned_bnBiasStrides[dim]    = bnBiasStrides[i];
        aligned_bnMeanVarStrides[dim] = bnMeanVarStrides[i];
        i++;
    };

@@ -84,10 +94,10 @@ int bnorm_infer(
    auto argument_ptr1 = dev_normalize.MakeArgumentPointer(
        xyLengths,
        {xStrides,
         aligned_scaleBiasMeanVarStrides,
         aligned_scaleBiasMeanVarStrides,
         aligned_scaleBiasMeanVarStrides,
         aligned_scaleBiasMeanVarStrides},
         aligned_bnMeanVarStrides,
         aligned_bnMeanVarStrides,
         aligned_bnScaleStrides,
         aligned_bnBiasStrides},
        {yStrides},
        {p_x, p_estimatedMean, p_estimatedVariance, p_scale, p_bias},
        {p_y},
@@ -105,8 +115,10 @@ int bnorm_infer(

    avg_time += invoker_ptr1->Run(argument_ptr1.get(), StreamConfig{nullptr, time_kernel});

    num_bytes += (total_length * (1 * sizeof(InOutDataType) + 4 * sizeof(AccDataType)) +
                  total_length * sizeof(InOutDataType));
    num_bytes += total_length * sizeof(XDataType) +
                 invariantLength *
                     (sizeof(ScaleDataType) + sizeof(BiasDataType) + 2 * sizeof(MeanVarDataType)) +
                 total_length * sizeof(YDataType);

    if(time_kernel)
    {

@@ -18,11 +18,6 @@

#include "batchnorm_infer_impl.hpp"

template <typename InOutDataType, typename AccDataType>
using ReferenceBatchNormInferInstance =
    ck::tensor_operation::host::ReferenceBatchNormInfer_Input_N_H_W_C_Output_C<InOutDataType,
                                                                               AccDataType>;

static struct option long_options[] = {{"inOutLengths", required_argument, nullptr, 'D'},
                                       {"verify", required_argument, nullptr, 'v'},
                                       {"help", no_argument, nullptr, '?'},
@@ -236,21 +231,30 @@ bool bnorm_infer_nhwc_test(bool do_verification,

    int result = 0;

    result = bnorm_infer<InOutDataType, AccDataType, Rank, NumReduceDim, false>(
        time_kernel,
        {0, 1, 2},
        i_inOutLengths,
        i_inOutStrides,
        i_inOutStrides,
        i_scaleBiasMeanVarLengths,
        i_scaleBiasMeanVarStrides,
        x_dev.GetDeviceBuffer(),
        bnScale_dev.GetDeviceBuffer(),
        bnBias_dev.GetDeviceBuffer(),
        epsilon,
        estimatedMean_dev.GetDeviceBuffer(),
        estimatedVariance_dev.GetDeviceBuffer(),
        y_dev.GetDeviceBuffer());
    result = bnorm_infer<InOutDataType,
                         InOutDataType,
                         AccDataType,
                         AccDataType,
                         AccDataType,
                         AccDataType,
                         Rank,
                         NumReduceDim,
                         false>(time_kernel,
                                {0, 1, 2},
                                i_inOutLengths,
                                i_inOutStrides,
                                i_inOutStrides,
                                i_scaleBiasMeanVarLengths,
                                i_scaleBiasMeanVarStrides,
                                i_scaleBiasMeanVarStrides,
                                i_scaleBiasMeanVarStrides,
                                x_dev.GetDeviceBuffer(),
                                bnScale_dev.GetDeviceBuffer(),
                                bnBias_dev.GetDeviceBuffer(),
                                epsilon,
                                estimatedMean_dev.GetDeviceBuffer(),
                                estimatedVariance_dev.GetDeviceBuffer(),
                                y_dev.GetDeviceBuffer());

    if(result < 0)
        return (false);
@@ -259,7 +263,15 @@ bool bnorm_infer_nhwc_test(bool do_verification,

    if(do_verification)
    {
        auto batchNormInfer_ref = ReferenceBatchNormInferInstance<InOutDataType, AccDataType>{};
        using ReferenceBatchNormInferInstance =
            ck::tensor_operation::host::ReferenceBatchNormInfer_Input_N_H_W_C_Output_C<
                InOutDataType,
                InOutDataType,
                AccDataType,
                AccDataType,
                AccDataType,
                AccDataType>;
        auto batchNormInfer_ref = ReferenceBatchNormInferInstance{};

        auto argument_ptr_ref =
            batchNormInfer_ref.MakeArgumentPointer(i_inOutLengths,
@@ -267,6 +279,8 @@ bool bnorm_infer_nhwc_test(bool do_verification,
                                                   i_inOutStrides,
                                                   i_scaleBiasMeanVarLengths,
                                                   i_scaleBiasMeanVarStrides,
                                                   i_scaleBiasMeanVarStrides,
                                                   i_scaleBiasMeanVarStrides,
                                                   x.mData.data(),
                                                   bnScale.mData.data(),
                                                   bnBias.mData.data(),
@@ -13,31 +13,36 @@ namespace ck {
namespace tensor_operation {
namespace device {

template <index_t Rank, index_t NumBatchNormReduceDim>
template <index_t Rank, index_t NumBatchNormReduceDim, typename YElementwiseOp>
struct DeviceBatchNormFwd : public BaseOperator
{
    virtual std::unique_ptr<BaseArgument> MakeArgumentPointer(
        const std::array<index_t, Rank> xyLengths,
        const std::array<index_t, Rank> xStrides,
        const std::array<index_t, Rank> yStrides,
        const std::array<int, NumBatchNormReduceDim> reduceDims,
        const std::array<index_t, Rank - NumBatchNormReduceDim> bnScaleBiasMeanVarLengths,
        const std::array<index_t, Rank - NumBatchNormReduceDim> bnScaleBiasMeanVarStrides,
        const std::array<index_t, Rank - NumBatchNormReduceDim> bnScaleStrides,
        const std::array<index_t, Rank - NumBatchNormReduceDim> bnBiasStrides,
        const std::array<index_t, Rank - NumBatchNormReduceDim> bnMeanVarStrides,
        const void* p_x,
        const void* bnScale,
        const void* bnBias,
        double epsilon,
        const YElementwiseOp y_elementwise_op,
        void* p_y,
        void* resultSaveMean,
        void* resultSaveInvVariance,
        double exponentialAverageFactor,
        void* resultRunningMean,
        void* resultRunningVariance,
        double epsilon,
        void* resultSaveMean,
        void* resultSaveInvVariance) = 0;
        void* resultRunningVariance) = 0;

    virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
};

template <index_t Rank, index_t NumBatchNormReduceDim>
using DeviceBatchNormFwdPtr = std::unique_ptr<DeviceBatchNormFwd<Rank, NumBatchNormReduceDim>>;
template <index_t Rank, index_t NumBatchNormReduceDim, typename YElementwiseOp>
using DeviceBatchNormFwdPtr =
    std::unique_ptr<DeviceBatchNormFwd<Rank, NumBatchNormReduceDim, YElementwiseOp>>;

} // namespace device
} // namespace tensor_operation

@@ -21,7 +21,9 @@ struct DeviceBatchNormInfer : public BaseOperator
    const std::array<index_t, Rank> xStrides,
    const std::array<index_t, Rank> yStrides,
    const std::array<index_t, Rank - NumBatchNormReduceDim> bnScaleBiasMeanVarLengths,
    const std::array<index_t, Rank - NumBatchNormReduceDim> bnScaleBiasMeanVarStrides,
    const std::array<index_t, Rank - NumBatchNormReduceDim> bnScaleStrides,
    const std::array<index_t, Rank - NumBatchNormReduceDim> bnBiasStrides,
    const std::array<index_t, Rank - NumBatchNormReduceDim> bnMeanVarStrides,
    const void* p_x,
    const void* bnScale,
    const void* bnBias,
@@ -0,0 +1,711 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include <iostream>
#include <sstream>

#include "ck/utility/reduction_operator.hpp"
#include "ck/tensor_operation/gpu/device/device_batchnorm_forward.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_reduce_common.hpp"
#include "ck/tensor_operation/gpu/device/welford_helper.hpp"
#include "ck/tensor_operation/gpu/grid/batchnorm_multiblock/gridwise_multiblock_welford_first_half.hpp"
#include "ck/tensor_operation/gpu/grid/batchnorm_multiblock/gridwise_multiblock_welford_second_half_batchnorm_forward_final.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_batchnorm_forward_blockwise_welford.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"

namespace ck {
namespace tensor_operation {
namespace device {

template <typename XDataType,
          typename YDataType,
          typename AccDataType,
          typename ScaleDataType,
          typename BiasDataType,
          typename MeanVarDataType,
          typename YElementwiseOp,
          index_t Rank,
          index_t NumBatchNormReduceDim,
          bool UseMultiblockInK,
          index_t BlockSize,
          index_t MThreadClusterSize,
          index_t KThreadClusterSize,
          index_t MThreadSliceSize,
          index_t KThreadSliceSize,
          index_t XSrcYDstVectorDim,
          index_t XSrcVectorSize,
          index_t YDstVectorSize,
          index_t ScaleSrcVectorSize,
          index_t BiasSrcVectorSize,
          index_t MeanVarSrcDstVectorSize>
struct DeviceBatchNormFwdImpl
    : public DeviceBatchNormFwd<Rank, NumBatchNormReduceDim, YElementwiseOp>
{
    static_assert(Rank <= 6, "Bigger Rank size is not supported!");
    static_assert(BlockSize == MThreadClusterSize * KThreadClusterSize,
                  "Invalid thread cluster size assignments!");

    static_assert((XSrcYDstVectorDim == 0 && MThreadSliceSize % XSrcVectorSize == 0) ||
                      (XSrcYDstVectorDim == 1 && KThreadSliceSize % XSrcVectorSize == 0),
                  "Invalid thread slice sizes and/or vector sizes configuration, please check!");

    static constexpr index_t NumInvariantDim = Rank - NumBatchNormReduceDim;

    static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize;
    static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize;

    static auto MakeXY2dDescriptor(const std::array<index_t, Rank>& xyLengths,
                                   const std::array<index_t, Rank>& xyStrides,
                                   int blkGroupSize,
                                   int numBlockTileIteration)
    {
        const auto tupleXYLengths =
            generate_tuple([&](auto I) { return xyLengths[I]; }, Number<Rank>{});
        const auto tupleXYStrides =
            generate_tuple([&](auto I) { return xyStrides[I]; }, Number<Rank>{});

        const auto raw_grid_desc = make_naive_tensor_descriptor(tupleXYLengths, tupleXYStrides);

        const auto grid_desc_m_k = [&]() {
            using InvariantDims = typename arithmetic_sequence_gen<0, NumInvariantDim, 1>::type;
            using ReduceDims = typename arithmetic_sequence_gen<NumInvariantDim, Rank, 1>::type;

            const auto reduceDimLengths =
                generate_tuple([&](auto I) { return xyLengths[NumInvariantDim + I]; },
                               Number<NumBatchNormReduceDim>{});
            const auto invariantDimLengths =
                generate_tuple([&](auto I) { return xyLengths[I]; }, Number<NumInvariantDim>{});

            return transform_tensor_descriptor(raw_grid_desc,
                                               make_tuple(make_merge_transform(invariantDimLengths),
                                                          make_merge_transform(reduceDimLengths)),
                                               make_tuple(InvariantDims{}, ReduceDims{}),
                                               make_tuple(Sequence<0>{}, Sequence<1>{}));
        }();

        const auto invariantLength = grid_desc_m_k.GetLength(Number<0>{});
        const auto reduceLength    = grid_desc_m_k.GetLength(Number<1>{});

        const int workSizePerBlock = K_BlockTileSize * numBlockTileIteration;
        const auto mPad =
            math::integer_least_multiple(invariantLength, M_BlockTileSize) - invariantLength;
        const auto kPad = workSizePerBlock * blkGroupSize - reduceLength;

        auto grid_desc_m_k_padded =
            transform_tensor_descriptor(grid_desc_m_k,
                                        make_tuple(make_right_pad_transform(invariantLength, mPad),
                                                   make_right_pad_transform(reduceLength, kPad)),
                                        make_tuple(Sequence<0>{}, Sequence<1>{}),
                                        make_tuple(Sequence<0>{}, Sequence<1>{}));

        return (grid_desc_m_k_padded);
    };
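    // Editor's note (illustrative, hypothetical sizes): for an NHWC input with
    // N=128, H=16, W=16, C=1024 and reduce dims (N, H, W), the merge transforms
    // above produce a 2-D view with M = invariantLength = 1024 (the C
    // dimension) and K = reduceLength = 128*16*16 = 32768. mPad and kPad then
    // right-pad M to a multiple of M_BlockTileSize and K up to
    // workSizePerBlock * blkGroupSize, so every block iterates over full tiles.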

    static auto MakeMeanVarCountOutputMG2dDescriptor(int invariantLength, int blkGroupSize)
    {
        const auto grid_desc_m_g =
            make_naive_tensor_descriptor_packed(make_tuple(invariantLength, blkGroupSize));

        const auto mPad =
            math::integer_least_multiple(invariantLength, M_BlockTileSize) - invariantLength;

        auto grid_desc_m_g_padded =
            transform_tensor_descriptor(grid_desc_m_g,
                                        make_tuple(make_right_pad_transform(invariantLength, mPad),
                                                   make_pass_through_transform(blkGroupSize)),
                                        make_tuple(Sequence<0>{}, Sequence<1>{}),
                                        make_tuple(Sequence<0>{}, Sequence<1>{}));

        return (grid_desc_m_g_padded);
    };

    static auto MakeMeanVarCountInputMK2dDescriptor(int invariantLength, int blkGroupSize)
    {
        const auto reduceLength = blkGroupSize;
        const auto grid_desc_m_k =
            make_naive_tensor_descriptor_packed(make_tuple(invariantLength, reduceLength));

        const auto mPad =
            math::integer_least_multiple(invariantLength, M_BlockTileSize) - invariantLength;
        const auto kPad =
            math::integer_least_multiple(reduceLength, KThreadClusterSize) - reduceLength;

        auto grid_desc_m_k_padded =
            transform_tensor_descriptor(grid_desc_m_k,
                                        make_tuple(make_right_pad_transform(invariantLength, mPad),
                                                   make_right_pad_transform(reduceLength, kPad)),
                                        make_tuple(Sequence<0>{}, Sequence<1>{}),
                                        make_tuple(Sequence<0>{}, Sequence<1>{}));

        return (grid_desc_m_k_padded);
    };

    static auto
    MakeScaleBiasMeanVar1dDescriptor(const std::array<index_t, NumInvariantDim>& lengths,
                                     const std::array<index_t, NumInvariantDim>& strides)
    {
        const auto tupleLengths =
            generate_tuple([&](auto I) { return lengths[I]; }, Number<NumInvariantDim>{});
        const auto tupleStrides =
            generate_tuple([&](auto I) { return strides[I]; }, Number<NumInvariantDim>{});

        auto raw_grid_desc = make_naive_tensor_descriptor(tupleLengths, tupleStrides);

        auto grid_desc_m = transform_tensor_descriptor(
            raw_grid_desc,
            make_tuple(make_merge_transform(tupleLengths)),
            make_tuple(typename arithmetic_sequence_gen<0, NumInvariantDim, 1>::type{}),
            make_tuple(Sequence<0>{}));

        const auto invariantLength = grid_desc_m.GetLength(Number<0>{});

        const auto mPad =
            math::integer_least_multiple(invariantLength, M_BlockTileSize) - invariantLength;

        auto grid_desc_m_padded =
            transform_tensor_descriptor(grid_desc_m,
                                        make_tuple(make_right_pad_transform(invariantLength, mPad)),
                                        make_tuple(Sequence<0>{}),
                                        make_tuple(Sequence<0>{}));
        return (grid_desc_m_padded);
    };

    using XYGridDesc_M_K = decltype(MakeXY2dDescriptor({1}, {1}, 1, 1));
    using ScaleBiasMeanVarGridDesc_M = decltype(MakeScaleBiasMeanVar1dDescriptor({1}, {1}));

    struct Argument : public BaseArgument
    {
        Argument(const std::array<index_t, Rank> xyLengths,
                 const std::array<index_t, Rank> xStrides,
                 const std::array<index_t, Rank> yStrides,
                 const std::array<int, NumBatchNormReduceDim> reduceDims,
                 const std::array<index_t, Rank - NumBatchNormReduceDim> bnScaleBiasMeanVarLengths,
                 const std::array<index_t, Rank - NumBatchNormReduceDim> bnScaleStrides,
                 const std::array<index_t, Rank - NumBatchNormReduceDim> bnBiasStrides,
                 const std::array<index_t, Rank - NumBatchNormReduceDim> bnMeanVarStrides,
                 const XDataType* p_x,
                 const ScaleDataType* p_scale,
                 const BiasDataType* p_bias,
                 const YElementwiseOp y_elementwise_op,
                 double epsilon,
                 YDataType* p_y,
                 MeanVarDataType* resultSaveMean,
                 MeanVarDataType* resultSaveInvVariance,
                 double averageFactor,
                 MeanVarDataType* resultRunningMean,
                 MeanVarDataType* resultRunningVariance)
            : bnScaleBiasMeanVarLengths_(bnScaleBiasMeanVarLengths),
              bnScaleStrides_(bnScaleStrides),
              bnBiasStrides_(bnBiasStrides),
              bnMeanVarStrides_(bnMeanVarStrides),
              p_x_(p_x),
              p_scale_(p_scale),
              p_bias_(p_bias),
              y_elementwise_op_(y_elementwise_op),
              p_y_(p_y),
              resultSaveMean_(resultSaveMean),
              resultSaveInvVariance_(resultSaveInvVariance),
              resultRunningMean_(resultRunningMean),
              resultRunningVariance_(resultRunningVariance)
        {
            xyLengths_ =
                shuffle_tensor_dimensions<Rank, NumBatchNormReduceDim>(xyLengths, reduceDims);
            xStrides_ =
                shuffle_tensor_dimensions<Rank, NumBatchNormReduceDim>(xStrides, reduceDims);
            yStrides_ =
                shuffle_tensor_dimensions<Rank, NumBatchNormReduceDim>(yStrides, reduceDims);

            std::tie(invariant_length_, reduce_length_) =
                get_2d_lengths<Rank, NumBatchNormReduceDim>(xyLengths_);

            epsilon_       = type_convert<AccDataType>(epsilon);
            averageFactor_ = type_convert<AccDataType>(averageFactor);

            updateMovingAverage_ =
                (resultRunningMean != nullptr && resultRunningVariance != nullptr);
            saveMeanInvVariance_ = (resultSaveMean != nullptr && resultSaveInvVariance_ != nullptr);

            if(UseMultiblockInK)
            {
                int iterations = 1;
                while(true)
                {
                    int testBlkGroupSize = (reduce_length_ + (K_BlockTileSize * iterations) - 1) /
                                           (K_BlockTileSize * iterations);

                    // we want the blkGroupSize to be no more than 128
                    if(testBlkGroupSize <= 128)
                        break;

                    iterations++;
                };

                blkGroupSize_ = (reduce_length_ + (K_BlockTileSize * iterations) - 1) /
                                (K_BlockTileSize * iterations);

                numBlockTileIteration_ = iterations;
            }
            else
            {
                blkGroupSize_          = 1;
                numBlockTileIteration_ = (reduce_length_ + K_BlockTileSize - 1) / K_BlockTileSize;
            };
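            // Editor's note (hypothetical numbers): with KThreadClusterSize = 16
            // and KThreadSliceSize = 2, K_BlockTileSize = 32. For
            // reduce_length_ = 32768:
            //   iterations = 1 -> testBlkGroupSize = ceil(32768 / 32)  = 1024 (> 128)
            //   ...
            //   iterations = 8 -> testBlkGroupSize = ceil(32768 / 256) = 128  (<= 128)
            // so blkGroupSize_ = 128 and numBlockTileIteration_ = 8: each M-tile
            // is reduced cooperatively by 128 blocks, each walking 8 K-tiles.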
gridSize_ = (invariant_length_ + M_BlockTileSize - 1) / M_BlockTileSize * blkGroupSize_;
|
||||
|
||||
x_grid_desc_m_k_ =
|
||||
MakeXY2dDescriptor(xyLengths_, xStrides_, blkGroupSize_, numBlockTileIteration_);
|
||||
y_grid_desc_m_k_ =
|
||||
MakeXY2dDescriptor(xyLengths_, yStrides_, blkGroupSize_, numBlockTileIteration_);
|
||||
scale_grid_desc_m_ =
|
||||
MakeScaleBiasMeanVar1dDescriptor(bnScaleBiasMeanVarLengths, bnScaleStrides_);
|
||||
bias_grid_desc_m_ =
|
||||
MakeScaleBiasMeanVar1dDescriptor(bnScaleBiasMeanVarLengths, bnBiasStrides_);
|
||||
mean_var_grid_desc_m_ =
|
||||
MakeScaleBiasMeanVar1dDescriptor(bnScaleBiasMeanVarLengths, bnMeanVarStrides_);
|
||||
}
|
||||
|
||||
AccDataType epsilon_;
|
||||
AccDataType averageFactor_;
|
||||
|
||||
bool updateMovingAverage_;
|
||||
bool saveMeanInvVariance_;
|
||||
|
||||
std::array<index_t, Rank> xyLengths_;
|
||||
std::array<index_t, Rank> xStrides_;
|
||||
std::array<index_t, Rank> yStrides_;
|
||||
|
||||
std::array<index_t, Rank - NumBatchNormReduceDim> bnScaleBiasMeanVarLengths_;
|
||||
std::array<index_t, Rank - NumBatchNormReduceDim> bnScaleStrides_;
|
||||
std::array<index_t, Rank - NumBatchNormReduceDim> bnBiasStrides_;
|
||||
std::array<index_t, Rank - NumBatchNormReduceDim> bnMeanVarStrides_;
|
||||
|
||||
const XDataType* p_x_;
|
||||
const ScaleDataType* p_scale_;
|
||||
const BiasDataType* p_bias_;
|
||||
const YElementwiseOp y_elementwise_op_;
|
||||
YDataType* p_y_;
|
||||
|
||||
MeanVarDataType* resultSaveMean_;
|
||||
MeanVarDataType* resultSaveInvVariance_;
|
||||
|
||||
MeanVarDataType* resultRunningMean_;
|
||||
MeanVarDataType* resultRunningVariance_;
|
||||
|
||||
long_index_t invariant_length_;
|
||||
long_index_t reduce_length_;
|
||||
|
||||
int blkGroupSize_;
|
||||
int numBlockTileIteration_;
|
||||
size_t gridSize_;
|
||||
|
||||
XYGridDesc_M_K x_grid_desc_m_k_;
|
||||
XYGridDesc_M_K y_grid_desc_m_k_;
|
||||
ScaleBiasMeanVarGridDesc_M scale_grid_desc_m_;
|
||||
ScaleBiasMeanVarGridDesc_M bias_grid_desc_m_;
|
||||
ScaleBiasMeanVarGridDesc_M mean_var_grid_desc_m_;
|
||||
|
||||
void* workspace_mean_;
|
||||
void* workspace_variance_;
|
||||
void* workspace_count_;
|
||||
};
|
||||
|
||||
size_t GetWorkSpaceSize(const BaseArgument* pArg) const override
|
||||
{
|
||||
const Argument* pArg_ = dynamic_cast<const Argument*>(pArg);
|
||||
|
||||
size_t workspace_size = 0;
|
||||
|
||||
if(UseMultiblockInK && pArg_->blkGroupSize_ > 1)
|
||||
{
|
||||
// workspace for welford intermediate mean
|
||||
workspace_size +=
|
||||
pArg_->invariant_length_ * pArg_->blkGroupSize_ * sizeof(MeanVarDataType) + 64;
|
||||
|
||||
// workspace for welford intermediate variance
|
||||
workspace_size +=
|
||||
pArg_->invariant_length_ * pArg_->blkGroupSize_ * sizeof(MeanVarDataType) + 64;
|
||||
|
||||
// workspace for welford intermediate count
|
||||
workspace_size +=
|
||||
pArg_->invariant_length_ * pArg_->blkGroupSize_ * sizeof(int32_t) + 64;
|
||||
}
|
||||
|
||||
return (workspace_size);
|
||||
};
|
||||
|
||||
    void SetWorkSpacePointer(BaseArgument* pArg, void* p_workspace) const override
    {
        Argument* pArg_ = dynamic_cast<Argument*>(pArg);

        pArg_->p_workspace_ = p_workspace;

        if(UseMultiblockInK && pArg_->blkGroupSize_ > 1)
        {
            // setup buffer used for intermediate welford mean
            pArg_->workspace_mean_ = static_cast<char*>(pArg_->p_workspace_);

            index_t mean_space_sz =
                pArg_->invariant_length_ * pArg_->blkGroupSize_ * sizeof(MeanVarDataType);

            mean_space_sz = math::integer_least_multiple(mean_space_sz, 64);

            // setup buffer used for intermediate welford variance
            pArg_->workspace_variance_ =
                reinterpret_cast<char*>(pArg_->workspace_mean_) + mean_space_sz;

            index_t variance_space_sz =
                pArg_->invariant_length_ * pArg_->blkGroupSize_ * sizeof(MeanVarDataType);

            variance_space_sz = math::integer_least_multiple(variance_space_sz, 64);

            // setup buffer used for intermediate welford count
            pArg_->workspace_count_ =
                reinterpret_cast<char*>(pArg_->workspace_variance_) + variance_space_sz;
        };
    };
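
    // Illustrative sketch (not part of the upstream file): combining GetWorkSpaceSize() and
    // SetWorkSpacePointer() above, the single workspace allocation is carved into three
    // 64-byte-aligned regions, one per welford partial result:
    //
    //   char* base = static_cast<char*>(p_workspace);
    //   size_t n   = invariant_length * blkGroupSize;     // elements per region
    //   auto mean  = reinterpret_cast<MeanVarDataType*>(base);
    //   auto var   = reinterpret_cast<MeanVarDataType*>(
    //       base + math::integer_least_multiple(n * sizeof(MeanVarDataType), 64));
    //   auto count = /* int32_t region directly after the variance region */;
    //
    // Each "+ 64" in GetWorkSpaceSize() pads the allocation so that the rounded-up offsets
    // computed here always stay inside it.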

    struct Invoker : public BaseInvoker
    {
        float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
        {
            float avg_time = 0;

            if(UseMultiblockInK && arg.blkGroupSize_ > 1)
            {
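                // Multi-block path: the reduce dimension is split across blkGroupSize_ blocks
                // per output channel. The first kernel writes one partial (mean, variance,
                // count) triple per block to the workspace; the second kernel merges those
                // partials with welford and then normalizes x to produce y.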
                using GetReduceCountPerThreadFunctor =
                    GetReduceCountPerThreadForMultiblockWelford<K_BlockTileSize, KThreadSliceSize>;

                GetReduceCountPerThreadFunctor get_reduce_count_per_thread(
                    arg.blkGroupSize_, arg.numBlockTileIteration_, arg.reduce_length_);

                const auto mean_var_count_grid_desc_m_g =
                    DeviceBatchNormFwdImpl::MakeMeanVarCountOutputMG2dDescriptor(
                        arg.invariant_length_, arg.blkGroupSize_);

                const auto mean_var_count_grid_desc_m_k =
                    DeviceBatchNormFwdImpl::MakeMeanVarCountInputMK2dDescriptor(
                        arg.invariant_length_, arg.blkGroupSize_);

                using MeanVarCountGridDesc_M_G = decltype(mean_var_count_grid_desc_m_g);
                using MeanVarCountGridDesc_M_K = decltype(mean_var_count_grid_desc_m_k);

                using GridwiseMultiblockWelfordFirstHalf_ =
                    GridwiseMultiblockWelfordFirstHalf<XDataType,
                                                       AccDataType,
                                                       MeanVarDataType,
                                                       XYGridDesc_M_K,
                                                       MeanVarCountGridDesc_M_G,
                                                       GetReduceCountPerThreadFunctor,
                                                       BlockSize,
                                                       MThreadClusterSize,
                                                       KThreadClusterSize,
                                                       MThreadSliceSize,
                                                       KThreadSliceSize,
                                                       XSrcYDstVectorDim,
                                                       XSrcVectorSize>;

                using GridwiseWelfordSecondHalfBatchNormForwardFinal_ =
                    GridwiseWelfordSecondHalfBatchNormForwardFinal<XDataType,
                                                                   YDataType,
                                                                   AccDataType,
                                                                   ScaleDataType,
                                                                   BiasDataType,
                                                                   MeanVarDataType,
                                                                   YElementwiseOp,
                                                                   XYGridDesc_M_K,
                                                                   MeanVarCountGridDesc_M_K,
                                                                   ScaleBiasMeanVarGridDesc_M,
                                                                   ScaleBiasMeanVarGridDesc_M,
                                                                   BlockSize,
                                                                   MThreadClusterSize,
                                                                   KThreadClusterSize,
                                                                   MThreadSliceSize,
                                                                   KThreadSliceSize,
                                                                   XSrcYDstVectorDim,
                                                                   XSrcVectorSize,
                                                                   YDstVectorSize,
                                                                   ScaleSrcVectorSize,
                                                                   BiasSrcVectorSize,
                                                                   MeanVarSrcDstVectorSize>;

                index_t numMeanVarCountBlockTileIteration =
                    (arg.blkGroupSize_ + KThreadClusterSize - 1) / KThreadClusterSize;

                const auto kern_multiblock_welford_first_half =
                    kernel_multiblock_welford_first_half<GridwiseMultiblockWelfordFirstHalf_,
                                                         XDataType,
                                                         MeanVarDataType,
                                                         XYGridDesc_M_K,
                                                         MeanVarCountGridDesc_M_G,
                                                         GetReduceCountPerThreadFunctor>;

                const auto kern_welford_second_half_batchnorm_forward_final =
                    kernel_welford_second_half_batchnorm_forward_final<
                        GridwiseWelfordSecondHalfBatchNormForwardFinal_,
                        XDataType,
                        YDataType,
                        AccDataType,
                        ScaleDataType,
                        BiasDataType,
                        MeanVarDataType,
                        YElementwiseOp,
                        XYGridDesc_M_K,
                        MeanVarCountGridDesc_M_K,
                        ScaleBiasMeanVarGridDesc_M,
                        ScaleBiasMeanVarGridDesc_M>;

                avg_time +=
                    launch_and_time_kernel(stream_config,
                                           kern_multiblock_welford_first_half,
                                           dim3(arg.gridSize_),
                                           dim3(BlockSize),
                                           0,
                                           arg.x_grid_desc_m_k_,
                                           mean_var_count_grid_desc_m_g,
                                           get_reduce_count_per_thread,
                                           arg.numBlockTileIteration_,
                                           arg.p_x_,
                                           static_cast<MeanVarDataType*>(arg.workspace_mean_),
                                           static_cast<MeanVarDataType*>(arg.workspace_variance_),
                                           static_cast<int32_t*>(arg.workspace_count_));

                avg_time +=
                    launch_and_time_kernel(stream_config,
                                           kern_welford_second_half_batchnorm_forward_final,
                                           dim3(arg.gridSize_),
                                           dim3(BlockSize),
                                           0,
                                           arg.x_grid_desc_m_k_,
                                           arg.y_grid_desc_m_k_,
                                           mean_var_count_grid_desc_m_k,
                                           arg.scale_grid_desc_m_,
                                           arg.bias_grid_desc_m_,
                                           arg.mean_var_grid_desc_m_,
                                           arg.blkGroupSize_,
                                           arg.numBlockTileIteration_,
                                           numMeanVarCountBlockTileIteration,
                                           arg.epsilon_,
                                           static_cast<MeanVarDataType*>(arg.workspace_mean_),
                                           static_cast<MeanVarDataType*>(arg.workspace_variance_),
                                           static_cast<int32_t*>(arg.workspace_count_),
                                           arg.p_x_,
                                           arg.p_scale_,
                                           arg.p_bias_,
                                           arg.y_elementwise_op_,
                                           arg.p_y_,
                                           arg.updateMovingAverage_,
                                           arg.averageFactor_,
                                           arg.resultRunningMean_,
                                           arg.resultRunningVariance_,
                                           arg.saveMeanInvVariance_,
                                           arg.resultSaveMean_,
                                           arg.resultSaveInvVariance_);
            }
            else
            {
                using GetReduceCountPerThreadFunctor =
                    GetReduceCountPerThreadForBlockwiseWelford<K_BlockTileSize, KThreadSliceSize>;

                GetReduceCountPerThreadFunctor get_reduce_count_per_thread(
                    arg.numBlockTileIteration_, arg.reduce_length_);

                using GridwiseBatchNormForwardWithBlockwiseWelford_ =
                    GridwiseBatchNormForwardWithBlockwiseWelford<XDataType,
                                                                 YDataType,
                                                                 AccDataType,
                                                                 ScaleDataType,
                                                                 BiasDataType,
                                                                 MeanVarDataType,
                                                                 YElementwiseOp,
                                                                 XYGridDesc_M_K,
                                                                 ScaleBiasMeanVarGridDesc_M,
                                                                 ScaleBiasMeanVarGridDesc_M,
                                                                 GetReduceCountPerThreadFunctor,
                                                                 BlockSize,
                                                                 MThreadClusterSize,
                                                                 KThreadClusterSize,
                                                                 MThreadSliceSize,
                                                                 KThreadSliceSize,
                                                                 XSrcYDstVectorDim,
                                                                 XSrcVectorSize,
                                                                 YDstVectorSize,
                                                                 ScaleSrcVectorSize,
                                                                 BiasSrcVectorSize,
                                                                 MeanVarSrcDstVectorSize>;

                const auto kern_batchnorm_fwd = kernel_batchnorm_forward_with_blockwise_welford<
                    GridwiseBatchNormForwardWithBlockwiseWelford_,
                    XDataType,
                    YDataType,
                    AccDataType,
                    ScaleDataType,
                    BiasDataType,
                    MeanVarDataType,
                    YElementwiseOp,
                    XYGridDesc_M_K,
                    ScaleBiasMeanVarGridDesc_M,
                    ScaleBiasMeanVarGridDesc_M,
                    GetReduceCountPerThreadFunctor>;

                avg_time += launch_and_time_kernel(stream_config,
                                                   kern_batchnorm_fwd,
                                                   dim3(arg.gridSize_),
                                                   dim3(BlockSize),
                                                   0,
                                                   arg.x_grid_desc_m_k_,
                                                   arg.y_grid_desc_m_k_,
                                                   arg.scale_grid_desc_m_,
                                                   arg.bias_grid_desc_m_,
                                                   arg.mean_var_grid_desc_m_,
                                                   get_reduce_count_per_thread,
                                                   arg.numBlockTileIteration_,
                                                   arg.epsilon_,
                                                   arg.p_x_,
                                                   arg.p_scale_,
                                                   arg.p_bias_,
                                                   arg.y_elementwise_op_,
                                                   arg.p_y_,
                                                   arg.updateMovingAverage_, // true or false
                                                   arg.averageFactor_,
                                                   arg.resultRunningMean_,
                                                   arg.resultRunningVariance_,
                                                   arg.saveMeanInvVariance_, // true or false
                                                   arg.resultSaveMean_,
                                                   arg.resultSaveInvVariance_);
            };

            return (avg_time);
        };

        float Run(const BaseArgument* pArg,
                  const StreamConfig& stream_config = StreamConfig{}) override
        {
            return Run(*dynamic_cast<const Argument*>(pArg), stream_config);
        };
    };

    bool IsSupportedArgument(const BaseArgument* pArg) override
    {
        const Argument* pArg_ = dynamic_cast<const Argument*>(pArg);

        if constexpr(XSrcYDstVectorDim == 0)
        {
            if(pArg_->xStrides_[NumInvariantDim - 1] != 1 ||
               pArg_->yStrides_[NumInvariantDim - 1] != 1)
                return false;

            if(pArg_->xyLengths_[NumInvariantDim - 1] % XSrcVectorSize != 0 ||
               pArg_->xyLengths_[NumInvariantDim - 1] % YDstVectorSize != 0)
                return false;
        }
        else
        {
            if(pArg_->xStrides_[Rank - 1] != 1 || pArg_->yStrides_[Rank - 1] != 1)
                return false;

            if(pArg_->xyLengths_[Rank - 1] % XSrcVectorSize != 0 ||
               pArg_->xyLengths_[Rank - 1] % YDstVectorSize != 0)
                return false;
        };

        if(pArg_->bnScaleStrides_[NumInvariantDim - 1] != 1 && ScaleSrcVectorSize != 1)
            return false;
        if(pArg_->bnBiasStrides_[NumInvariantDim - 1] != 1 && BiasSrcVectorSize != 1)
            return false;

        if(pArg_->bnScaleBiasMeanVarLengths_[NumInvariantDim - 1] % ScaleSrcVectorSize != 0)
            return false;
        if(pArg_->bnScaleBiasMeanVarLengths_[NumInvariantDim - 1] % BiasSrcVectorSize != 0)
            return false;

        if(pArg_->bnMeanVarStrides_[NumInvariantDim - 1] != 1 && MeanVarSrcDstVectorSize != 1)
            return false;

        if(pArg_->bnScaleBiasMeanVarLengths_[NumInvariantDim - 1] % MeanVarSrcDstVectorSize != 0)
            return false;

        bool is_valid = true;

        static_for<0, NumInvariantDim, 1>{}([&](auto I) {
            if(pArg_->xyLengths_[I] != pArg_->bnScaleBiasMeanVarLengths_[I])
                is_valid = false;
        });

        if(!is_valid)
            return false;

        return true;
    };

    std::unique_ptr<BaseArgument> MakeArgumentPointer(
        const std::array<index_t, Rank> xyLengths,
        const std::array<index_t, Rank> xStrides,
        const std::array<index_t, Rank> yStrides,
        const std::array<int, NumBatchNormReduceDim> reduceDims,
        const std::array<index_t, Rank - NumBatchNormReduceDim> bnScaleBiasMeanVarLengths,
        const std::array<index_t, Rank - NumBatchNormReduceDim> bnScaleStrides,
        const std::array<index_t, Rank - NumBatchNormReduceDim> bnBiasStrides,
        const std::array<index_t, Rank - NumBatchNormReduceDim> bnMeanVarStrides,
        const void* p_x,
        const void* p_scale,
        const void* p_bias,
        double epsilon,
        const YElementwiseOp y_elementwise_op,
        void* p_y,
        void* resultSaveMean,
        void* resultSaveInvVariance,
        double averageFactor,
        void* resultRunningMean,
        void* resultRunningVariance) override
    {
        return std::make_unique<Argument>(xyLengths,
                                          xStrides,
                                          yStrides,
                                          reduceDims,
                                          bnScaleBiasMeanVarLengths,
                                          bnScaleStrides,
                                          bnBiasStrides,
                                          bnMeanVarStrides,
                                          static_cast<const XDataType*>(p_x),
                                          static_cast<const ScaleDataType*>(p_scale),
                                          static_cast<const BiasDataType*>(p_bias),
                                          y_elementwise_op,
                                          epsilon,
                                          static_cast<YDataType*>(p_y),
                                          static_cast<MeanVarDataType*>(resultSaveMean),
                                          static_cast<MeanVarDataType*>(resultSaveInvVariance),
                                          averageFactor,
                                          static_cast<MeanVarDataType*>(resultRunningMean),
                                          static_cast<MeanVarDataType*>(resultRunningVariance));
    };

    std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
    {
        return std::make_unique<Invoker>();
    };

    std::string GetTypeString() const override
    {
        auto str = std::stringstream();

        // clang-format off
        str << "DeviceBatchNormFwdImpl<" << BlockSize << ",";
        str << "M_C" << MThreadClusterSize << "_S" << MThreadSliceSize << ",";
        str << "K_C" << KThreadClusterSize << "_S" << KThreadSliceSize << ",";
        str << "XSrcYDstVectorDim_" << XSrcYDstVectorDim << ",";
        str << "VectorSize_X" << XSrcVectorSize << "_scale_" << ScaleSrcVectorSize << "_bias_" << BiasSrcVectorSize << "_mean_var_" << MeanVarSrcDstVectorSize << "_Y" << YDstVectorSize << ">";
        // clang-format on

        return str.str();
    }
};

} // namespace device
} // namespace tensor_operation
} // namespace ck
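
A minimal host-side sketch of driving this device op (not from the commit; shapes, pointer
names, and the `DeviceOp` alias are hypothetical, and only the interface methods shown above
are used). It assumes a rank-4 NHWC tensor reduced over {N, H, W}, a concrete instantiation
`DeviceOp` of DeviceBatchNormFwdImpl whose YElementwiseOp is element_wise::PassThrough, and
device buffers p_x/p_scale/p_bias/p_y/p_save_mean/p_save_inv_var/p_running_mean/p_running_var/
p_workspace allocated elsewhere:

    using PassThrough = ck::tensor_operation::element_wise::PassThrough;

    DeviceOp device_op;

    const ck::index_t N = 128, H = 14, W = 14, C = 1024; // example shape

    std::array<ck::index_t, 4> lengths  = {N, H, W, C};
    std::array<ck::index_t, 4> strides  = {H * W * C, W * C, C, 1}; // packed NHWC
    std::array<int, 3> reduce_dims      = {0, 1, 2};
    std::array<ck::index_t, 1> c_length = {C};
    std::array<ck::index_t, 1> c_stride = {1};

    auto arg_ptr = device_op.MakeArgumentPointer(lengths, strides, strides, reduce_dims,
                                                 c_length, c_stride, c_stride, c_stride,
                                                 p_x, p_scale, p_bias, 1e-5,
                                                 PassThrough{}, p_y,
                                                 p_save_mean, p_save_inv_var,
                                                 0.1, p_running_mean, p_running_var);

    if(device_op.IsSupportedArgument(arg_ptr.get()))
    {
        // The multi-block path needs scratch space for the partial welford results.
        if(size_t ws = device_op.GetWorkSpaceSize(arg_ptr.get()); ws > 0)
            device_op.SetWorkSpacePointer(arg_ptr.get(), p_workspace); // >= ws bytes

        float time_ms =
            device_op.MakeInvokerPointer()->Run(arg_ptr.get(), StreamConfig{nullptr, true});
    }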

include/ck/tensor_operation/gpu/device/welford_helper.hpp (new file, 89 lines)
@@ -0,0 +1,89 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

namespace ck {
namespace tensor_operation {
namespace device {

template <index_t K_BlockTileSize, index_t KThreadSliceSize>
struct GetReduceCountPerThreadForBlockwiseWelford
{
    GetReduceCountPerThreadForBlockwiseWelford(index_t numBlockTileIteration,
                                               long_index_t reduce_length)
        : numBlockTileIteration_{numBlockTileIteration}
    {
        count_in_last_tile_ = reduce_length % K_BlockTileSize;
    };

    __device__ index_t operator()(index_t thread_k_cluster_id) const
    {
        if(count_in_last_tile_ == 0)
            return (KThreadSliceSize * numBlockTileIteration_);
        else
        {
            index_t num_complete_slice  = count_in_last_tile_ / KThreadSliceSize;
            index_t count_in_last_slice = count_in_last_tile_ % KThreadSliceSize;

            if(thread_k_cluster_id < num_complete_slice)
                return (KThreadSliceSize * numBlockTileIteration_);
            else if(thread_k_cluster_id == num_complete_slice)
                return (KThreadSliceSize * (numBlockTileIteration_ - 1) + count_in_last_slice);
            else
                return (KThreadSliceSize * (numBlockTileIteration_ - 1));
        };
    };

    index_t numBlockTileIteration_;
    index_t count_in_last_tile_;
};
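
// Worked example (illustrative numbers, not from the commit): with K_BlockTileSize = 256,
// KThreadSliceSize = 8, KThreadClusterSize = 32, reduce_length = 1000 and
// numBlockTileIteration_ = 4: count_in_last_tile_ = 1000 % 256 = 232,
// num_complete_slice = 29, count_in_last_slice = 0. Threads 0..28 reduce 8 * 4 = 32 values
// each and threads 29..31 reduce 8 * 3 = 24 each, so the cluster covers
// 29 * 32 + 3 * 24 = 1000 elements -- exactly the counts welford needs for an unbiased result.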

template <index_t K_BlockTileSize, index_t KThreadSliceSize>
struct GetReduceCountPerThreadForMultiblockWelford
{
    GetReduceCountPerThreadForMultiblockWelford(index_t blkGroupSize,
                                                index_t numBlockTileIteration,
                                                long_index_t reduce_length)
        : blkGroupSize_(blkGroupSize), numBlockTileIteration_{numBlockTileIteration}
    {
        last_block_reduce_length_ =
            reduce_length - K_BlockTileSize * numBlockTileIteration_ * (blkGroupSize_ - 1);
        numBlockTileIterationByLastBlock_ =
            (last_block_reduce_length_ + K_BlockTileSize - 1) / K_BlockTileSize;
    };

    __device__ index_t operator()(index_t block_local_id, index_t thread_k_cluster_id) const
    {
        if(last_block_reduce_length_ == K_BlockTileSize * numBlockTileIteration_ ||
           block_local_id < blkGroupSize_ - 1)
            return (KThreadSliceSize * numBlockTileIteration_);

        index_t count_in_last_tile = last_block_reduce_length_ % K_BlockTileSize;

        if(count_in_last_tile == 0)
            return (KThreadSliceSize * numBlockTileIterationByLastBlock_);
        else
        {
            index_t num_complete_slice  = count_in_last_tile / KThreadSliceSize;
            index_t count_in_last_slice = count_in_last_tile % KThreadSliceSize;

            if(thread_k_cluster_id < num_complete_slice)
                return (KThreadSliceSize * numBlockTileIterationByLastBlock_);
            else if(thread_k_cluster_id == num_complete_slice)
                return (KThreadSliceSize * (numBlockTileIterationByLastBlock_ - 1) +
                        count_in_last_slice);
            else
                return (KThreadSliceSize * (numBlockTileIterationByLastBlock_ - 1));
        };
    };

    index_t blkGroupSize_;
    index_t numBlockTileIteration_;

    index_t last_block_reduce_length_;
    index_t numBlockTileIterationByLastBlock_;
};
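
// The multi-block variant applies the same slicing logic, but only inside the last block of a
// group (block_local_id == blkGroupSize_ - 1); every earlier block reduces exactly
// KThreadSliceSize * numBlockTileIteration_ values, since only the final block can see a
// partially filled K tile.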

} // namespace device
} // namespace tensor_operation
} // namespace ck
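
The kernels in the following files rely on the standard Welford recurrences and the pairwise
merge rule of Chan et al., stated here for reference (the CK welford classes track the running
variance rather than the raw sum of squared deviations M2, but the algebra is the same up to a
division by the count):

    mean_n = mean_{n-1} + (x_n - mean_{n-1}) / n
    M2_n   = M2_{n-1} + (x_n - mean_{n-1}) * (x_n - mean_n)

and, for merging two partial results (n_a, mean_a, M2_a) and (n_b, mean_b, M2_b):

    n     = n_a + n_b,    delta = mean_b - mean_a
    mean  = mean_a + delta * n_b / n
    M2    = M2_a + M2_b + delta^2 * n_a * n_b / n

This is why the first-half kernel must store a count alongside each partial mean/variance pair:
the merge weights depend on how many elements each block actually reduced.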
@@ -0,0 +1,258 @@

// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck/utility/data_type.hpp"
#include "ck/utility/math.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_welford.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_welford.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"

namespace ck {

template <typename GridwiseMultiblockWelfordFirstHalf_,
          typename XDataType,
          typename MeanVarDataType,
          typename XGridDesc_M_K,
          typename MeanVarCountGridDesc_M_G,
          typename GetReduceCountPerThreadFunctor>
__global__ void kernel_multiblock_welford_first_half(
    const XGridDesc_M_K x_grid_desc_m_k,
    const MeanVarCountGridDesc_M_G mean_var_count_grid_desc_m_g,
    const GetReduceCountPerThreadFunctor get_reduce_count_per_thread,
    index_t num_k_block_tile_iteration,
    const XDataType* const __restrict__ p_x,
    MeanVarDataType* const p_welford_mean,
    MeanVarDataType* const p_welford_variance,
    int32_t* const p_welford_count)
{
    GridwiseMultiblockWelfordFirstHalf_::Run(x_grid_desc_m_k,
                                             mean_var_count_grid_desc_m_g,
                                             get_reduce_count_per_thread,
                                             num_k_block_tile_iteration,
                                             p_x,
                                             p_welford_mean,
                                             p_welford_variance,
                                             p_welford_count);
};

template <typename XDataType,
          typename AccDataType,
          typename MeanVarDataType,
          typename XGridDesc_M_K,
          typename MeanVarCountGridDesc_M_G,
          typename GetReduceCountPerThreadFunctor,
          index_t BlockSize,
          index_t MThreadClusterSize,
          index_t KThreadClusterSize,
          index_t MThreadSliceSize,
          index_t KThreadSliceSize,
          index_t XSrcCountSrcVectorDim,
          index_t XSrcCountSrcVectorSize>
struct GridwiseMultiblockWelfordFirstHalf
{
    static_assert((XSrcCountSrcVectorDim == 0 && MThreadSliceSize % XSrcCountSrcVectorSize == 0) ||
                      (XSrcCountSrcVectorDim == 1 &&
                       KThreadSliceSize % XSrcCountSrcVectorSize == 0),
                  "Invalid thread slice sizes and/or vector sizes configuration, please check!");

    static constexpr bool reorder_thread_cluster = (XSrcCountSrcVectorDim == 0);

    using ThreadClusterLengths_M_K = Sequence<MThreadClusterSize, KThreadClusterSize>;

    using ThreadBufferDimAccessOrder =
        typename conditional<reorder_thread_cluster, Sequence<1, 0>, Sequence<0, 1>>::type;

    using ThreadClusterArrangeOrder =
        typename conditional<reorder_thread_cluster, Sequence<1, 0>, Sequence<0, 1>>::type;

    static constexpr auto thread_cluster_desc =
        make_cluster_descriptor(ThreadClusterLengths_M_K{}, ThreadClusterArrangeOrder{});

    using ThreadReduceSrcDesc_M_K = decltype(make_naive_tensor_descriptor_packed(
        make_tuple(Number<MThreadSliceSize>{}, Number<KThreadSliceSize>{})));
    using ThreadReduceDstDesc_M =
        decltype(make_naive_tensor_descriptor_packed(make_tuple(Number<MThreadSliceSize>{})));

    using ThreadwiseWelford =
        ThreadwiseWelford<AccDataType, ThreadReduceSrcDesc_M_K, ThreadReduceDstDesc_M>;

    using BlockwiseWelford = BlockwiseWelford<AccDataType,
                                              BlockSize,
                                              ThreadClusterLengths_M_K,
                                              ThreadClusterArrangeOrder,
                                              false>;

    using PassThroughOp = tensor_operation::element_wise::PassThrough;

    static constexpr auto I0 = Number<0>{};
    static constexpr auto I1 = Number<1>{};

    static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize;
    static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize;

    __device__ static void Run(const XGridDesc_M_K& x_grid_desc_m_k,
                               const MeanVarCountGridDesc_M_G& mean_var_count_grid_desc_m_g,
                               const GetReduceCountPerThreadFunctor& get_reduce_count_per_thread,
                               index_t num_k_block_tile_iteration,
                               const XDataType* const __restrict__ p_x,
                               MeanVarDataType* const p_welford_mean,
                               MeanVarDataType* const p_welford_variance,
                               int32_t* const p_welford_count)
    {
        StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize * KThreadSliceSize, true>
            x_thread_buf;

        StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true>
            welford_mean_thread_buf;
        StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true>
            welford_var_thread_buf;
        StaticBuffer<AddressSpaceEnum::Vgpr, int32_t, MThreadSliceSize, true>
            welford_count_thread_buf;

        const index_t blkgroup_size = mean_var_count_grid_desc_m_g.GetLength(I1);

        const index_t thread_local_id = get_thread_local_1d_id();
        const index_t block_global_id = get_block_1d_id();
        const index_t blkgroup_id     = block_global_id / blkgroup_size;
        const index_t block_local_id  = block_global_id % blkgroup_size;

        const auto thread_cluster_idx =
            thread_cluster_desc.CalculateBottomIndex(make_multi_index(thread_local_id));

        const auto thread_m_cluster_id = thread_cluster_idx[I0];
        const auto thread_k_cluster_id = thread_cluster_idx[I1];

        using ThreadBufferLengths_M_K = Sequence<MThreadSliceSize, KThreadSliceSize>;
        using ThreadBufferLengths_M_1 = Sequence<MThreadSliceSize, 1>;

        constexpr auto thread_buffer_desc_m_k = make_naive_tensor_descriptor_packed(
            make_tuple(Number<MThreadSliceSize>{}, Number<KThreadSliceSize>{}));
        constexpr auto thread_buffer_desc_m_1 = make_naive_tensor_descriptor_packed(
            make_tuple(Number<MThreadSliceSize>{}, Number<1>{}));

        const index_t reduceSizePerBlock = K_BlockTileSize * num_k_block_tile_iteration;

        auto threadwise_x_load = ThreadwiseTensorSliceTransfer_v2<XDataType,
                                                                  AccDataType,
                                                                  XGridDesc_M_K,
                                                                  decltype(thread_buffer_desc_m_k),
                                                                  ThreadBufferLengths_M_K,
                                                                  ThreadBufferDimAccessOrder,
                                                                  XSrcCountSrcVectorDim,
                                                                  XSrcCountSrcVectorSize,
                                                                  1,
                                                                  true>(
            x_grid_desc_m_k,
            make_multi_index(blkgroup_id * M_BlockTileSize + thread_m_cluster_id * MThreadSliceSize,
                             block_local_id * reduceSizePerBlock +
                                 thread_k_cluster_id * KThreadSliceSize));

        auto threadwise_welford_mean_var_store =
            ThreadwiseTensorSliceTransfer_v1r3<AccDataType,
                                               MeanVarDataType,
                                               decltype(thread_buffer_desc_m_1),
                                               MeanVarCountGridDesc_M_G,
                                               PassThroughOp,
                                               ThreadBufferLengths_M_1,
                                               Sequence<0, 1>,
                                               1,
                                               1,
                                               InMemoryDataOperationEnum::Set,
                                               1,
                                               true>(
                mean_var_count_grid_desc_m_g,
                make_multi_index(blkgroup_id * M_BlockTileSize +
                                     thread_m_cluster_id * MThreadSliceSize,
                                 block_local_id),
                PassThroughOp{});

        auto threadwise_welford_count_store =
            ThreadwiseTensorSliceTransfer_v1r3<int32_t,
                                               int32_t,
                                               decltype(thread_buffer_desc_m_1),
                                               MeanVarCountGridDesc_M_G,
                                               PassThroughOp,
                                               ThreadBufferLengths_M_1,
                                               Sequence<0, 1>,
                                               1,
                                               1,
                                               InMemoryDataOperationEnum::Set,
                                               1,
                                               true>(
                mean_var_count_grid_desc_m_g,
                make_multi_index(blkgroup_id * M_BlockTileSize +
                                     thread_m_cluster_id * MThreadSliceSize,
                                 block_local_id),
                PassThroughOp{});

        constexpr auto thread_copy_fwd_step_m_k = make_multi_index(0, K_BlockTileSize);

        const auto x_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_x, x_grid_desc_m_k.GetElementSpaceSize());

        auto welford_mean_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_welford_mean, mean_var_count_grid_desc_m_g.GetElementSpaceSize());

        auto welford_var_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_welford_variance, mean_var_count_grid_desc_m_g.GetElementSpaceSize());

        auto welford_count_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_welford_count, mean_var_count_grid_desc_m_g.GetElementSpaceSize());

        auto threadwise_welford = ThreadwiseWelford();
        threadwise_welford.max_count_ =
            get_reduce_count_per_thread(block_local_id, thread_k_cluster_id);

        static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
            welford_mean_thread_buf(I) = type_convert<AccDataType>(0.0f);
            welford_var_thread_buf(I)  = type_convert<AccDataType>(0.0f);
        });

        for(index_t reducedTiles = 0; reducedTiles < num_k_block_tile_iteration; ++reducedTiles)
        {
            threadwise_x_load.Run(x_grid_desc_m_k,
                                  x_global_val_buf,
                                  thread_buffer_desc_m_k,
                                  make_tuple(I0, I0),
                                  x_thread_buf);

            threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, thread_copy_fwd_step_m_k);
            threadwise_welford.Run(x_thread_buf, welford_mean_thread_buf, welford_var_thread_buf);
        }
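
        // Each thread has now accumulated a partial (mean, variance) over its share of the K
        // tiles; BlockwiseWelford below merges those partials across the K thread cluster, and
        // only the thread_k_cluster_id == 0 lane writes the per-block result to the workspace.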
        static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
            if constexpr(I > 0)
                block_sync_lds();

            welford_count_thread_buf(I) = threadwise_welford.cur_count_;
            BlockwiseWelford::Run(
                welford_mean_thread_buf(I), welford_var_thread_buf(I), welford_count_thread_buf(I));
        });

        if(thread_k_cluster_id == 0)
        {
            threadwise_welford_mean_var_store.Run(thread_buffer_desc_m_1,
                                                  make_tuple(I0, I0),
                                                  welford_mean_thread_buf,
                                                  mean_var_count_grid_desc_m_g,
                                                  welford_mean_global_val_buf);

            threadwise_welford_mean_var_store.Run(thread_buffer_desc_m_1,
                                                  make_tuple(I0, I0),
                                                  welford_var_thread_buf,
                                                  mean_var_count_grid_desc_m_g,
                                                  welford_var_global_val_buf);

            threadwise_welford_count_store.Run(thread_buffer_desc_m_1,
                                               make_tuple(I0, I0),
                                               welford_count_thread_buf,
                                               mean_var_count_grid_desc_m_g,
                                               welford_count_global_val_buf);
        };
    }
};

} // namespace ck
@@ -0,0 +1,570 @@

// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck/utility/data_type.hpp"
#include "ck/utility/math_v2.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_welford.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_welford.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"

namespace ck {

template <typename GridwiseWelfordSecondHalfBatchNormForwardFinal_,
          typename XDataType,
          typename YDataType,
          typename AccDataType,
          typename ScaleDataType,
          typename BiasDataType,
          typename MeanVarDataType,
          typename YElementwiseOp,
          typename XYGridDesc_M_K,
          typename MeanVarCountGridDesc_M_K,
          typename ScaleBiasGridDesc_M,
          typename MeanVarGridDesc_M>
__global__ void kernel_welford_second_half_batchnorm_forward_final(
    const XYGridDesc_M_K x_grid_desc_m_k,
    const XYGridDesc_M_K y_grid_desc_m_k,
    const MeanVarCountGridDesc_M_K mean_var_count_grid_desc_m_k,
    const ScaleBiasGridDesc_M scale_grid_desc_m,
    const ScaleBiasGridDesc_M bias_grid_desc_m,
    const MeanVarGridDesc_M mean_var_grid_desc_m,
    index_t blkgroup_size,
    index_t num_xy_k_block_tile_iteration,
    index_t num_mean_var_count_k_block_tile_iteration,
    AccDataType epsilon,
    const MeanVarDataType* const __restrict__ p_in_welford_mean,
    const MeanVarDataType* const __restrict__ p_in_welford_variance,
    const int32_t* const __restrict__ p_in_welford_count,
    const XDataType* const __restrict__ p_x,
    const ScaleDataType* const __restrict__ p_scale,
    const BiasDataType* const __restrict__ p_bias,
    const YElementwiseOp y_elementwise_op,
    YDataType* const __restrict__ p_y,
    bool updateMovingAverage,
    AccDataType averageFactor,
    MeanVarDataType* const __restrict__ resultRunningMean,
    MeanVarDataType* const __restrict__ resultRunningVariance,
    bool saveMeanInvVariance,
    MeanVarDataType* const __restrict__ resultSaveMean,
    MeanVarDataType* const __restrict__ resultSaveInvVariance)
{
    GridwiseWelfordSecondHalfBatchNormForwardFinal_::Run(x_grid_desc_m_k,
                                                         y_grid_desc_m_k,
                                                         mean_var_count_grid_desc_m_k,
                                                         scale_grid_desc_m,
                                                         bias_grid_desc_m,
                                                         mean_var_grid_desc_m,
                                                         blkgroup_size,
                                                         num_xy_k_block_tile_iteration,
                                                         num_mean_var_count_k_block_tile_iteration,
                                                         epsilon,
                                                         p_in_welford_mean,
                                                         p_in_welford_variance,
                                                         p_in_welford_count,
                                                         p_x,
                                                         p_scale,
                                                         p_bias,
                                                         y_elementwise_op,
                                                         p_y,
                                                         updateMovingAverage,
                                                         averageFactor,
                                                         resultRunningMean,
                                                         resultRunningVariance,
                                                         saveMeanInvVariance,
                                                         resultSaveMean,
                                                         resultSaveInvVariance);
};

template <typename XDataType,
          typename YDataType,
          typename AccDataType,
          typename ScaleDataType,
          typename BiasDataType,
          typename MeanVarDataType,
          typename YElementwiseOp,
          typename XYGridDesc_M_K,
          typename MeanVarCountGridDesc_M_K,
          typename ScaleBiasGridDesc_M,
          typename MeanVarGridDesc_M,
          index_t BlockSize,
          index_t MThreadClusterSize,
          index_t KThreadClusterSize,
          index_t MThreadSliceSize,
          index_t KThreadSliceSize,
          index_t XSrcYDstVectorDim,
          index_t XSrcVectorSize,
          index_t YDstVectorSize,
          index_t ScaleSrcVectorSize,
          index_t BiasSrcVectorSize,
          index_t MeanVarSrcDstVectorSize>
struct GridwiseWelfordSecondHalfBatchNormForwardFinal
{
    static_assert((XSrcYDstVectorDim == 0 && MThreadSliceSize % XSrcVectorSize == 0) ||
                      (XSrcYDstVectorDim == 1 && KThreadSliceSize % XSrcVectorSize == 0),
                  "Invalid thread slice sizes and/or vector sizes configuration, please check!");

    static_assert((XSrcYDstVectorDim == 0 && MThreadSliceSize % YDstVectorSize == 0) ||
                      (XSrcYDstVectorDim == 1 && KThreadSliceSize % YDstVectorSize == 0),
                  "Invalid thread slice sizes and/or vector sizes configuration, please check!");

    static constexpr bool reorder_thread_cluster = (XSrcYDstVectorDim == 0);

    using ThreadClusterLengths_M_K = Sequence<MThreadClusterSize, KThreadClusterSize>;

    using ThreadBufferDimAccessOrder =
        typename conditional<reorder_thread_cluster, Sequence<1, 0>, Sequence<0, 1>>::type;

    using ThreadClusterArrangeOrder =
        typename conditional<reorder_thread_cluster, Sequence<1, 0>, Sequence<0, 1>>::type;

    static constexpr auto thread_cluster_desc =
        make_cluster_descriptor(ThreadClusterLengths_M_K{}, ThreadClusterArrangeOrder{});

    using ThreadReduceSrcDesc_M_1 = decltype(
        make_naive_tensor_descriptor_packed(make_tuple(Number<MThreadSliceSize>{}, Number<1>{})));
    using ThreadReduceDstDesc_M =
        decltype(make_naive_tensor_descriptor_packed(make_tuple(Number<MThreadSliceSize>{})));

    using ThreadwiseWelford =
        ThreadwiseWelfordMerge<AccDataType, ThreadReduceSrcDesc_M_1, ThreadReduceDstDesc_M>;

    using BlockwiseWelford = BlockwiseWelford<AccDataType,
                                              BlockSize,
                                              ThreadClusterLengths_M_K,
                                              ThreadClusterArrangeOrder>;

    using PassThroughOp = tensor_operation::element_wise::PassThrough;

    static constexpr auto I0 = Number<0>{};
    static constexpr auto I1 = Number<1>{};

    static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize;
    static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize;

    __device__ static void Run(const XYGridDesc_M_K& x_grid_desc_m_k,
                               const XYGridDesc_M_K& y_grid_desc_m_k,
                               const MeanVarCountGridDesc_M_K& mean_var_count_grid_desc_m_k,
                               const ScaleBiasGridDesc_M& scale_grid_desc_m,
                               const ScaleBiasGridDesc_M& bias_grid_desc_m,
                               const MeanVarGridDesc_M& mean_var_grid_desc_m,
                               index_t blkgroup_size,
                               index_t num_xy_k_block_tile_iteration,
                               index_t num_mean_var_count_k_block_tile_iteration,
                               AccDataType epsilon,
                               const MeanVarDataType* const __restrict__ p_in_welford_mean,
                               const MeanVarDataType* const __restrict__ p_in_welford_variance,
                               const int32_t* const __restrict__ p_in_welford_count,
                               const XDataType* const __restrict__ p_x,
                               const ScaleDataType* const __restrict__ p_scale,
                               const BiasDataType* const __restrict__ p_bias,
                               const YElementwiseOp y_elementwise_op,
                               YDataType* const __restrict__ p_y,
                               bool updateMovingAverage,
                               AccDataType averageFactor,
                               MeanVarDataType* const __restrict__ resultRunningMean,
                               MeanVarDataType* const __restrict__ resultRunningVariance,
                               bool saveMeanInvVariance,
                               MeanVarDataType* const __restrict__ resultSaveMean,
                               MeanVarDataType* const __restrict__ resultSaveInvVariance)
    {
        using ck::math::sqrt;

        StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize * 1, true>
            in_welford_mean_thread_buf;
        StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize * 1, true>
            in_welford_var_thread_buf;
        StaticBuffer<AddressSpaceEnum::Vgpr, int32_t, MThreadSliceSize * 1, true>
            in_welford_count_thread_buf;

        StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true>
            welford_mean_thread_buf;
        StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true>
            welford_var_thread_buf;
        StaticBuffer<AddressSpaceEnum::Vgpr, int32_t, MThreadSliceSize, true>
            welford_count_thread_buf;

        StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize * KThreadSliceSize, true>
            x_thread_buf;
        StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize * KThreadSliceSize, true>
            y_thread_buf;

        StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true> scale_thread_buf;
        StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true> bias_thread_buf;

        const index_t thread_local_id = get_thread_local_1d_id();
        const index_t block_global_id = get_block_1d_id();
        const index_t blkgroup_id     = block_global_id / blkgroup_size;
        const index_t block_local_id  = block_global_id % blkgroup_size;

        const auto thread_cluster_idx =
            thread_cluster_desc.CalculateBottomIndex(make_multi_index(thread_local_id));

        const auto thread_m_cluster_id = thread_cluster_idx[I0];
        const auto thread_k_cluster_id = thread_cluster_idx[I1];

        using ThreadBufferLengths_M_K = Sequence<MThreadSliceSize, KThreadSliceSize>;
        using ThreadBufferLengths_M   = Sequence<MThreadSliceSize>;
        using ThreadBufferLengths_M_1 = Sequence<MThreadSliceSize, 1>;
        constexpr auto thread_buffer_desc_m_k = make_naive_tensor_descriptor_packed(
            make_tuple(Number<MThreadSliceSize>{}, Number<KThreadSliceSize>{}));
        constexpr auto thread_buffer_desc_m =
            make_naive_tensor_descriptor_packed(make_tuple(Number<MThreadSliceSize>{}));
        constexpr auto thread_buffer_desc_m_1 = make_naive_tensor_descriptor_packed(
            make_tuple(Number<MThreadSliceSize>{}, Number<1>{}));

        auto threadwise_mean_var_load_m_k =
            ThreadwiseTensorSliceTransfer_v2<MeanVarDataType,
                                             AccDataType,
                                             MeanVarCountGridDesc_M_K,
                                             decltype(thread_buffer_desc_m_1),
                                             ThreadBufferLengths_M_1,
                                             Sequence<0, 1>,
                                             1,
                                             1,
                                             1,
                                             true>(
                mean_var_count_grid_desc_m_k,
                make_multi_index(blkgroup_id * M_BlockTileSize +
                                     thread_m_cluster_id * MThreadSliceSize,
                                 thread_k_cluster_id * 1));

        auto threadwise_count_load_m_k =
            ThreadwiseTensorSliceTransfer_v2<int32_t,
                                             int32_t,
                                             MeanVarCountGridDesc_M_K,
                                             decltype(thread_buffer_desc_m_1),
                                             ThreadBufferLengths_M_1,
                                             Sequence<0, 1>,
                                             1,
                                             1,
                                             1,
                                             true>(
                mean_var_count_grid_desc_m_k,
                make_multi_index(blkgroup_id * M_BlockTileSize +
                                     thread_m_cluster_id * MThreadSliceSize,
                                 thread_k_cluster_id * 1));

        const auto welford_mean_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_in_welford_mean, mean_var_count_grid_desc_m_k.GetElementSpaceSize());

        const auto welford_var_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_in_welford_variance, mean_var_count_grid_desc_m_k.GetElementSpaceSize());

        const auto welford_count_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_in_welford_count, mean_var_count_grid_desc_m_k.GetElementSpaceSize());

        constexpr auto mean_var_count_thread_copy_step_m_k =
            make_multi_index(0, KThreadClusterSize * 1);

        // Step 1: do final welford reduction to get mean and variance

        static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
            welford_mean_thread_buf(I)  = type_convert<AccDataType>(0.0f);
            welford_var_thread_buf(I)   = type_convert<AccDataType>(0.0f);
            welford_count_thread_buf(I) = 0;
        });

        for(index_t reducedTiles = 0; reducedTiles < num_mean_var_count_k_block_tile_iteration;
            ++reducedTiles)
        {
            threadwise_mean_var_load_m_k.Run(mean_var_count_grid_desc_m_k,
                                             welford_mean_global_val_buf,
                                             thread_buffer_desc_m_1,
                                             make_tuple(I0, I0),
                                             in_welford_mean_thread_buf);

            threadwise_mean_var_load_m_k.Run(mean_var_count_grid_desc_m_k,
                                             welford_var_global_val_buf,
                                             thread_buffer_desc_m_1,
                                             make_tuple(I0, I0),
                                             in_welford_var_thread_buf);

            threadwise_count_load_m_k.Run(mean_var_count_grid_desc_m_k,
                                          welford_count_global_val_buf,
                                          thread_buffer_desc_m_1,
                                          make_tuple(I0, I0),
                                          in_welford_count_thread_buf);

            ThreadwiseWelford::Run(in_welford_mean_thread_buf,
                                   in_welford_var_thread_buf,
                                   in_welford_count_thread_buf,
                                   welford_mean_thread_buf,
                                   welford_var_thread_buf,
                                   welford_count_thread_buf);

            threadwise_mean_var_load_m_k.MoveSrcSliceWindow(mean_var_count_grid_desc_m_k,
                                                            mean_var_count_thread_copy_step_m_k);
            threadwise_count_load_m_k.MoveSrcSliceWindow(mean_var_count_grid_desc_m_k,
                                                         mean_var_count_thread_copy_step_m_k);
        }

        static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
            if constexpr(I > 0)
                block_sync_lds();

            BlockwiseWelford::Run(
                welford_mean_thread_buf(I), welford_var_thread_buf(I), welford_count_thread_buf(I));
        });
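
        // At this point each channel's blkgroup_size partial (mean, variance, count) triples
        // from the first-half kernel have been merged, first per-thread with
        // ThreadwiseWelfordMerge and then across the K cluster with BlockwiseWelford, so
        // welford_mean/var_thread_buf now hold the final statistics over the full reduce length.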

        // Step 2: do normalization and output y

        const index_t workSizePerBlock = K_BlockTileSize * num_xy_k_block_tile_iteration;

        auto threadwise_x_load = ThreadwiseTensorSliceTransfer_v2<XDataType,
                                                                  AccDataType,
                                                                  XYGridDesc_M_K,
                                                                  decltype(thread_buffer_desc_m_k),
                                                                  ThreadBufferLengths_M_K,
                                                                  ThreadBufferDimAccessOrder,
                                                                  XSrcYDstVectorDim,
                                                                  XSrcVectorSize,
                                                                  1,
                                                                  true>(
            x_grid_desc_m_k,
            make_multi_index(blkgroup_id * M_BlockTileSize + thread_m_cluster_id * MThreadSliceSize,
                             workSizePerBlock * block_local_id +
                                 thread_k_cluster_id * KThreadSliceSize));

        auto threadwise_y_store =
            ThreadwiseTensorSliceTransfer_v1r3<AccDataType,
                                               YDataType,
                                               decltype(thread_buffer_desc_m_k),
                                               XYGridDesc_M_K,
                                               YElementwiseOp,
                                               ThreadBufferLengths_M_K,
                                               ThreadBufferDimAccessOrder,
                                               XSrcYDstVectorDim,
                                               YDstVectorSize,
                                               InMemoryDataOperationEnum::Set,
                                               1,
                                               true>(
                y_grid_desc_m_k,
                make_multi_index(
                    blkgroup_id * M_BlockTileSize + thread_m_cluster_id * MThreadSliceSize,
                    workSizePerBlock * block_local_id + thread_k_cluster_id * KThreadSliceSize),
                y_elementwise_op);

        auto threadwise_scale_load =
            ThreadwiseTensorSliceTransfer_v2<ScaleDataType,
                                             AccDataType,
                                             ScaleBiasGridDesc_M,
                                             decltype(thread_buffer_desc_m),
                                             ThreadBufferLengths_M,
                                             Sequence<0>,
                                             0,
                                             ScaleSrcVectorSize,
                                             1,
                                             true>(
                scale_grid_desc_m,
                make_multi_index(blkgroup_id * M_BlockTileSize +
                                 thread_m_cluster_id * MThreadSliceSize));

        auto threadwise_bias_load =
            ThreadwiseTensorSliceTransfer_v2<BiasDataType,
                                             AccDataType,
                                             ScaleBiasGridDesc_M,
                                             decltype(thread_buffer_desc_m),
                                             ThreadBufferLengths_M,
                                             Sequence<0>,
                                             0,
                                             BiasSrcVectorSize,
                                             1,
                                             true>(
                bias_grid_desc_m,
                make_multi_index(blkgroup_id * M_BlockTileSize +
                                 thread_m_cluster_id * MThreadSliceSize));

        const auto x_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_x, x_grid_desc_m_k.GetElementSpaceSize());

        const auto scale_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_scale, scale_grid_desc_m.GetElementSpaceSize());

        const auto bias_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_bias, bias_grid_desc_m.GetElementSpaceSize());

        auto y_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_y, y_grid_desc_m_k.GetElementSpaceSize());

        threadwise_scale_load.Run(scale_grid_desc_m,
                                  scale_global_val_buf,
                                  thread_buffer_desc_m,
                                  make_tuple(I0),
                                  scale_thread_buf);

        threadwise_bias_load.Run(bias_grid_desc_m,
                                 bias_global_val_buf,
                                 thread_buffer_desc_m,
                                 make_tuple(I0),
                                 bias_thread_buf);

        constexpr auto xy_thread_copy_step_m_k = make_multi_index(0, K_BlockTileSize);

        for(index_t workTiles = 0; workTiles < num_xy_k_block_tile_iteration; ++workTiles)
        {
            threadwise_x_load.Run(x_grid_desc_m_k,
                                  x_global_val_buf,
                                  thread_buffer_desc_m_k,
                                  make_tuple(I0, I0),
                                  x_thread_buf);
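
            // Fused normalization: instead of computing
            // scale * (x - mean) / sqrt(var + eps) + bias per element, the per-channel terms
            // are folded into
            //   multiplier      = scale / sqrt(var + eps)
            //   fused_mean_bias = bias - mean * multiplier
            // so the inner loop needs only one multiply-add per element:
            //   y = x * multiplier + fused_mean_bias.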
            static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
                AccDataType multiplier =
                    scale_thread_buf[iM] / sqrt(welford_var_thread_buf[iM] + epsilon);

                AccDataType fused_mean_bias =
                    bias_thread_buf[iM] - welford_mean_thread_buf[iM] * multiplier;

                static_for<0, KThreadSliceSize, 1>{}([&](auto iK) {
                    constexpr auto offset =
                        thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK));

                    y_thread_buf(Number<offset>{}) =
                        x_thread_buf[Number<offset>{}] * multiplier + fused_mean_bias;
                });
            });

            threadwise_y_store.Run(thread_buffer_desc_m_k,
                                   make_tuple(I0, I0),
                                   y_thread_buf,
                                   y_grid_desc_m_k,
                                   y_global_val_buf);

            threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, xy_thread_copy_step_m_k);
            threadwise_y_store.MoveDstSliceWindow(y_grid_desc_m_k, xy_thread_copy_step_m_k);
        }

        // Step 3: update the moving average of mean and variance (optional)

        if(updateMovingAverage && block_local_id == 0 && thread_k_cluster_id == 0)
        {
            StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true>
                running_mean_thread_buf;
            StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true>
                running_var_thread_buf;

            auto running_mean_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
                resultRunningMean, mean_var_grid_desc_m.GetElementSpaceSize());

            auto running_var_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
                resultRunningVariance, mean_var_grid_desc_m.GetElementSpaceSize());

            auto threadwise_mean_var_load_m =
                ThreadwiseTensorSliceTransfer_v2<MeanVarDataType,
                                                 AccDataType,
                                                 MeanVarGridDesc_M,
                                                 decltype(thread_buffer_desc_m),
                                                 ThreadBufferLengths_M,
                                                 Sequence<0>,
                                                 0,
                                                 MeanVarSrcDstVectorSize,
                                                 1,
                                                 true>(
                    mean_var_grid_desc_m,
                    make_multi_index(blkgroup_id * M_BlockTileSize +
                                     thread_m_cluster_id * MThreadSliceSize));

            threadwise_mean_var_load_m.Run(mean_var_grid_desc_m,
                                           running_mean_global_buf,
                                           thread_buffer_desc_m,
                                           make_tuple(I0),
                                           running_mean_thread_buf);

            threadwise_mean_var_load_m.Run(mean_var_grid_desc_m,
                                           running_var_global_buf,
                                           thread_buffer_desc_m,
                                           make_tuple(I0),
                                           running_var_thread_buf);

            AccDataType oneMinusAverageFactor = type_convert<AccDataType>(1.0) - averageFactor;

            static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
                running_mean_thread_buf(I) = running_mean_thread_buf[I] * oneMinusAverageFactor +
                                             welford_mean_thread_buf[I] * averageFactor;
                running_var_thread_buf(I) = running_var_thread_buf[I] * oneMinusAverageFactor +
                                            welford_var_thread_buf[I] * averageFactor;
            });
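
            // The running statistics follow the usual exponential-moving-average update,
            //   running = (1 - averageFactor) * running + averageFactor * batch_statistic,
            // performed by a single block (block_local_id == 0) and a single K lane per channel
            // so each running value is read and written exactly once.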

            auto threadwise_mean_var_store =
                ThreadwiseTensorSliceTransfer_v1r3<AccDataType,
                                                   MeanVarDataType,
                                                   decltype(thread_buffer_desc_m),
                                                   MeanVarGridDesc_M,
                                                   PassThroughOp,
                                                   ThreadBufferLengths_M,
                                                   Sequence<0>,
                                                   0,
                                                   MeanVarSrcDstVectorSize,
                                                   InMemoryDataOperationEnum::Set,
                                                   1,
                                                   true>(
                    mean_var_grid_desc_m,
                    make_multi_index(blkgroup_id * M_BlockTileSize +
                                     thread_m_cluster_id * MThreadSliceSize),
                    PassThroughOp{});

            threadwise_mean_var_store.Run(thread_buffer_desc_m,
                                          make_tuple(I0),
                                          running_mean_thread_buf,
                                          mean_var_grid_desc_m,
                                          running_mean_global_buf);

            threadwise_mean_var_store.Run(thread_buffer_desc_m,
                                          make_tuple(I0),
                                          running_var_thread_buf,
                                          mean_var_grid_desc_m,
                                          running_var_global_buf);
        };

        // Step 4: save mean and inv-variance (optional)

        if(saveMeanInvVariance && block_local_id == 0 && thread_k_cluster_id == 0)
        {
            auto result_mean_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
                resultSaveMean, mean_var_grid_desc_m.GetElementSpaceSize());

            auto result_inv_var_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
                resultSaveInvVariance, mean_var_grid_desc_m.GetElementSpaceSize());

            static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
                welford_var_thread_buf(I) =
                    type_convert<AccDataType>(1.0f) / sqrt(epsilon + welford_var_thread_buf[I]);
            });
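
            // welford_var_thread_buf is overwritten in place with the inverse standard
            // deviation 1 / sqrt(epsilon + var); saving invVariance instead of the raw variance
            // lets a later pass (e.g. batchnorm backward) reuse the reciprocal square root
            // without recomputing it.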

            auto threadwise_mean_inv_var_store =
                ThreadwiseTensorSliceTransfer_v1r3<AccDataType,
                                                   MeanVarDataType,
                                                   decltype(thread_buffer_desc_m),
                                                   MeanVarGridDesc_M,
                                                   PassThroughOp,
                                                   ThreadBufferLengths_M,
                                                   Sequence<0>,
                                                   0,
                                                   MeanVarSrcDstVectorSize,
                                                   InMemoryDataOperationEnum::Set,
                                                   1,
                                                   true>(
                    mean_var_grid_desc_m,
                    make_multi_index(blkgroup_id * M_BlockTileSize +
                                     thread_m_cluster_id * MThreadSliceSize),
                    PassThroughOp{});

            threadwise_mean_inv_var_store.Run(thread_buffer_desc_m,
                                              make_tuple(I0),
                                              welford_mean_thread_buf,
                                              mean_var_grid_desc_m,
                                              result_mean_global_buf);

            threadwise_mean_inv_var_store.Run(thread_buffer_desc_m,
                                              make_tuple(I0),
                                              welford_var_thread_buf,
                                              mean_var_grid_desc_m,
                                              result_inv_var_global_buf);
        };
    }
};

} // namespace ck
@@ -0,0 +1,482 @@

// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck/utility/data_type.hpp"
#include "ck/utility/math_v2.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_welford.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_welford.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"

namespace ck {

template <typename GridwiseBatchNormForwardWithBlockwiseWelford_,
          typename XDataType,
          typename YDataType,
          typename AccDataType,
          typename ScaleDataType,
          typename BiasDataType,
          typename MeanVarDataType,
          typename YElementwiseOp,
          typename XYGridDesc_M_K,
          typename ScaleBiasGridDesc_M,
          typename MeanVarGridDesc_M,
          typename GetReduceCountPerThreadFunctor>
__global__ void kernel_batchnorm_forward_with_blockwise_welford(
    const XYGridDesc_M_K x_grid_desc_m_k,
    const XYGridDesc_M_K y_grid_desc_m_k,
    const ScaleBiasGridDesc_M scale_grid_desc_m,
    const ScaleBiasGridDesc_M bias_grid_desc_m,
    const MeanVarGridDesc_M mean_var_grid_desc_m,
    const GetReduceCountPerThreadFunctor get_reduce_count_per_thread,
    index_t num_k_block_tile_iteration,
    AccDataType epsilon,
    const XDataType* const __restrict__ p_x,
    const ScaleDataType* const __restrict__ p_scale,
    const BiasDataType* const __restrict__ p_bias,
    const YElementwiseOp y_elementwise_op,
    YDataType* const __restrict__ p_y,
    bool updateMovingAverage,
    AccDataType averageFactor,
    MeanVarDataType* const __restrict__ resultRunningMean,
    MeanVarDataType* const __restrict__ resultRunningVariance,
    bool saveMeanInvVariance,
    MeanVarDataType* const __restrict__ resultSaveMean,
    MeanVarDataType* const __restrict__ resultSaveInvVariance)
{
    GridwiseBatchNormForwardWithBlockwiseWelford_::Run(x_grid_desc_m_k,
                                                       y_grid_desc_m_k,
                                                       scale_grid_desc_m,
                                                       bias_grid_desc_m,
                                                       mean_var_grid_desc_m,
                                                       get_reduce_count_per_thread,
                                                       num_k_block_tile_iteration,
                                                       epsilon,
                                                       p_x,
                                                       p_scale,
                                                       p_bias,
                                                       y_elementwise_op,
                                                       p_y,
                                                       updateMovingAverage,
                                                       averageFactor,
                                                       resultRunningMean,
                                                       resultRunningVariance,
                                                       saveMeanInvVariance,
                                                       resultSaveMean,
                                                       resultSaveInvVariance);
};
template <typename XDataType,
|
||||
typename YDataType,
|
||||
typename AccDataType,
|
||||
typename ScaleDataType,
|
||||
typename BiasDataType,
|
||||
typename MeanVarDataType,
|
||||
typename YElementwiseOp,
|
||||
typename XYGridDesc_M_K,
|
||||
typename ScaleBiasGridDesc_M,
|
||||
typename MeanVarGridDesc_M,
|
||||
typename GetReduceCountPerThreadFunctor,
|
||||
index_t BlockSize,
|
||||
index_t MThreadClusterSize,
|
||||
index_t KThreadClusterSize,
|
||||
index_t MThreadSliceSize,
|
||||
index_t KThreadSliceSize,
|
||||
index_t XSrcYDstVectorDim,
|
||||
index_t XSrcVectorSize,
|
||||
index_t YDstVectorSize,
|
||||
index_t ScaleSrcVectorSize,
|
||||
index_t BiasSrcVectorSize,
|
||||
index_t MeanVarSrcDstVectorSize>
|
||||
struct GridwiseBatchNormForwardWithBlockwiseWelford
|
||||
{
|
||||
static_assert((XSrcYDstVectorDim == 0 && MThreadSliceSize % XSrcVectorSize == 0) ||
|
||||
(XSrcYDstVectorDim == 1 && KThreadSliceSize % XSrcVectorSize == 0),
|
||||
"Invalid thread slice sizes and/or vector sizes configuration, please check!");
|
||||
|
||||
static_assert((XSrcYDstVectorDim == 0 && MThreadSliceSize % YDstVectorSize == 0) ||
|
||||
(XSrcYDstVectorDim == 1 && KThreadSliceSize % YDstVectorSize == 0),
|
||||
"Invalid thread slice sizes and/or vector sizes configuration, please check!");
|
||||
|
||||
static constexpr bool reorder_thread_cluster = (XSrcYDstVectorDim == 0);
|
||||
|
||||
using ThreadClusterLengths_M_K = Sequence<MThreadClusterSize, KThreadClusterSize>;
|
||||
|
||||
using ThreadBufferDimAccessOrder =
|
||||
typename conditional<reorder_thread_cluster, Sequence<1, 0>, Sequence<0, 1>>::type;
|
||||
|
||||
using ThreadClusterArrangeOrder =
|
||||
typename conditional<reorder_thread_cluster, Sequence<1, 0>, Sequence<0, 1>>::type;
|
||||
|
||||
static constexpr auto thread_cluster_desc =
|
||||
make_cluster_descriptor(ThreadClusterLengths_M_K{}, ThreadClusterArrangeOrder{});
|
||||
|
||||
using ThreadReduceSrcDesc_M_K = decltype(make_naive_tensor_descriptor_packed(
|
||||
make_tuple(Number<MThreadSliceSize>{}, Number<KThreadSliceSize>{})));
|
||||
using ThreadReduceDstDesc_M =
|
||||
decltype(make_naive_tensor_descriptor_packed(make_tuple(Number<MThreadSliceSize>{})));
|
||||
|
||||
using ThreadwiseWelford =
|
||||
ThreadwiseWelford<AccDataType, ThreadReduceSrcDesc_M_K, ThreadReduceDstDesc_M>;
|
||||
|
||||
using BlockwiseWelford = BlockwiseWelford<AccDataType,
|
||||
BlockSize,
|
||||
ThreadClusterLengths_M_K,
|
||||
ThreadClusterArrangeOrder>;
|
||||
|
||||
using PassThroughOp = tensor_operation::element_wise::PassThrough;
|
||||
|
||||
static constexpr auto I0 = Number<0>{};
|
||||
static constexpr auto I1 = Number<1>{};
|
||||
|
||||
static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize;
|
||||
static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize;
|
||||
|
||||
__device__ static void Run(const XYGridDesc_M_K& x_grid_desc_m_k,
|
||||
const XYGridDesc_M_K& y_grid_desc_m_k,
|
||||
const ScaleBiasGridDesc_M& scale_grid_desc_m,
|
||||
const ScaleBiasGridDesc_M& bias_grid_desc_m,
|
||||
const MeanVarGridDesc_M& mean_var_grid_desc_m,
|
||||
const GetReduceCountPerThreadFunctor& get_reduce_count_per_thread,
|
||||
index_t num_k_block_tile_iteration,
|
||||
AccDataType epsilon,
|
||||
const XDataType* const __restrict__ p_x,
|
||||
const ScaleDataType* const __restrict__ p_scale,
|
||||
const BiasDataType* const __restrict__ p_bias,
|
||||
const YElementwiseOp y_elementwise_op,
|
||||
YDataType* const __restrict__ p_y,
|
||||
bool updateMovingAverage,
|
||||
AccDataType averageFactor,
|
||||
MeanVarDataType* const __restrict__ resultRunningMean,
|
||||
MeanVarDataType* const __restrict__ resultRunningVariance,
|
||||
bool saveMeanInvVariance,
|
||||
MeanVarDataType* const __restrict__ resultSaveMean,
|
||||
MeanVarDataType* const __restrict__ resultSaveInvVariance)
|
||||
{
|
||||
using ck::math::sqrt;
|
||||
|
||||
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize * KThreadSliceSize, true>
|
||||
x_thread_buf;
|
||||
|
||||
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true> scale_thread_buf;
|
||||
|
||||
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true> bias_thread_buf;
|
||||
|
||||
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize * KThreadSliceSize, true>
|
||||
y_thread_buf;
|
||||
|
||||
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true> mean_thread_buf;
|
||||
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true> var_thread_buf;
|
||||
|
||||
const index_t thread_local_id = get_thread_local_1d_id();
|
||||
const index_t block_global_id = get_block_1d_id();
|
||||
|
||||
const auto thread_cluster_idx =
|
||||
thread_cluster_desc.CalculateBottomIndex(make_multi_index(thread_local_id));
|
||||
|
||||
const auto thread_m_cluster_id = thread_cluster_idx[I0];
|
||||
const auto thread_k_cluster_id = thread_cluster_idx[I1];
|
||||
|
||||
using ThreadBufferLengths_M_K = Sequence<MThreadSliceSize, KThreadSliceSize>;
|
||||
using ThreadBufferLengths_M = Sequence<MThreadSliceSize>;
|
||||
constexpr auto thread_buffer_desc_m_k = make_naive_tensor_descriptor_packed(
|
||||
make_tuple(Number<MThreadSliceSize>{}, Number<KThreadSliceSize>{}));
|
||||
constexpr auto thread_buffer_desc_m =
|
||||
make_naive_tensor_descriptor_packed(make_tuple(Number<MThreadSliceSize>{}));
|
||||
|
||||

        auto threadwise_x_load = ThreadwiseTensorSliceTransfer_v2<XDataType,
                                                                  AccDataType,
                                                                  XYGridDesc_M_K,
                                                                  decltype(thread_buffer_desc_m_k),
                                                                  ThreadBufferLengths_M_K,
                                                                  ThreadBufferDimAccessOrder,
                                                                  XSrcYDstVectorDim,
                                                                  XSrcVectorSize,
                                                                  1,
                                                                  true>(
            x_grid_desc_m_k,
            make_multi_index(block_global_id * M_BlockTileSize +
                                 thread_m_cluster_id * MThreadSliceSize,
                             thread_k_cluster_id * KThreadSliceSize));

        auto threadwise_y_store =
            ThreadwiseTensorSliceTransfer_v1r3<AccDataType,
                                               YDataType,
                                               decltype(thread_buffer_desc_m_k),
                                               XYGridDesc_M_K,
                                               YElementwiseOp,
                                               ThreadBufferLengths_M_K,
                                               ThreadBufferDimAccessOrder,
                                               XSrcYDstVectorDim,
                                               YDstVectorSize,
                                               InMemoryDataOperationEnum::Set,
                                               1,
                                               true>(
                y_grid_desc_m_k,
                make_multi_index(block_global_id * M_BlockTileSize +
                                     thread_m_cluster_id * MThreadSliceSize,
                                 thread_k_cluster_id * KThreadSliceSize),
                y_elementwise_op);

        auto threadwise_scale_load =
            ThreadwiseTensorSliceTransfer_v2<ScaleDataType,
                                             AccDataType,
                                             ScaleBiasGridDesc_M,
                                             decltype(thread_buffer_desc_m),
                                             ThreadBufferLengths_M,
                                             Sequence<0>,
                                             0,
                                             ScaleSrcVectorSize,
                                             1,
                                             true>(
                scale_grid_desc_m,
                make_multi_index(block_global_id * M_BlockTileSize +
                                 thread_m_cluster_id * MThreadSliceSize));

        auto threadwise_bias_load = ThreadwiseTensorSliceTransfer_v2<BiasDataType,
                                                                     AccDataType,
                                                                     ScaleBiasGridDesc_M,
                                                                     decltype(thread_buffer_desc_m),
                                                                     ThreadBufferLengths_M,
                                                                     Sequence<0>,
                                                                     0,
                                                                     BiasSrcVectorSize,
                                                                     1,
                                                                     true>(
            bias_grid_desc_m,
            make_multi_index(block_global_id * M_BlockTileSize +
                             thread_m_cluster_id * MThreadSliceSize));

        constexpr auto thread_copy_fwd_step_m_k = make_multi_index(0, K_BlockTileSize);
        constexpr auto thread_copy_bwd_step_m_k = make_multi_index(0, -K_BlockTileSize);
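
        // Note on the two step constants above: the x window is swept forward over the K tiles
        // during the Welford pass (Step 1) and then walked backward over the same tiles during
        // the normalization pass (Step 2), so x is read twice without recomputing a start index.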

        const auto x_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_x, x_grid_desc_m_k.GetElementSpaceSize());

        const auto scale_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_scale, scale_grid_desc_m.GetElementSpaceSize());

        const auto bias_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_bias, bias_grid_desc_m.GetElementSpaceSize());

        auto y_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_y, y_grid_desc_m_k.GetElementSpaceSize());

        // Step 1: do welford reduction to get mean and variance

        auto threadwise_welford       = ThreadwiseWelford();
        threadwise_welford.max_count_ = get_reduce_count_per_thread(thread_k_cluster_id);

        static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
            mean_thread_buf(I) = type_convert<AccDataType>(0.0f);
            var_thread_buf(I)  = type_convert<AccDataType>(0.0f);
        });

        for(index_t reducedTiles = 0; reducedTiles < num_k_block_tile_iteration; ++reducedTiles)
        {
            threadwise_x_load.Run(x_grid_desc_m_k,
                                  x_global_val_buf,
                                  thread_buffer_desc_m_k,
                                  make_tuple(I0, I0),
                                  x_thread_buf);

            threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, thread_copy_fwd_step_m_k);
            threadwise_welford.Run(x_thread_buf, mean_thread_buf, var_thread_buf);
        }
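
        // For reference, the threadwise pass applies the standard sequential Welford
        // recurrence per element x (the same recurrence as the host reference further below):
        //   count <- count + 1
        //   delta <- x - mean;   mean <- mean + delta / count
        //   m2    <- m2 + delta * (x - mean)
        // so at this point var_thread_buf holds a running M2 (sum of squared deviations),
        // which the blockwise merge below combines and finalizes into the actual variance.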

        static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
            if constexpr(I > 0)
                block_sync_lds();

            int count = threadwise_welford.cur_count_;
            BlockwiseWelford::Run(mean_thread_buf(I), var_thread_buf(I), count);
        });

        // Step 2: do normalization and output y

        threadwise_scale_load.Run(scale_grid_desc_m,
                                  scale_global_val_buf,
                                  thread_buffer_desc_m,
                                  make_tuple(I0),
                                  scale_thread_buf);

        threadwise_bias_load.Run(bias_grid_desc_m,
                                 bias_global_val_buf,
                                 thread_buffer_desc_m,
                                 make_tuple(I0),
                                 bias_thread_buf);

        auto thread_copy_tail_m_k = (num_k_block_tile_iteration - 1) * thread_copy_fwd_step_m_k;

        threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, thread_copy_bwd_step_m_k);
        threadwise_y_store.MoveDstSliceWindow(y_grid_desc_m_k, thread_copy_tail_m_k);
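
        // The loop below folds scale, mean and bias into two per-channel constants so each
        // element costs a single fused multiply-add:
        //   multiplier      = scale / sqrt(var + eps)
        //   fused_mean_bias = bias - mean * multiplier
        //   y               = x * multiplier + fused_mean_bias
        // which is algebraically identical to y = scale * (x - mean) / sqrt(var + eps) + bias.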

        for(index_t reducedTiles = 0; reducedTiles < num_k_block_tile_iteration; ++reducedTiles)
        {
            threadwise_x_load.Run(x_grid_desc_m_k,
                                  x_global_val_buf,
                                  thread_buffer_desc_m_k,
                                  make_tuple(I0, I0),
                                  x_thread_buf);

            static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
                AccDataType multiplier =
                    scale_thread_buf[Number<iM>{}] / sqrt(var_thread_buf[iM] + epsilon);

                AccDataType fused_mean_bias =
                    bias_thread_buf[Number<iM>{}] - mean_thread_buf[iM] * multiplier;

                static_for<0, KThreadSliceSize, 1>{}([&](auto iK) {
                    constexpr auto offset =
                        thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK));

                    // normalize
                    y_thread_buf(Number<offset>{}) =
                        x_thread_buf[Number<offset>{}] * multiplier + fused_mean_bias;
                });
            });

            threadwise_y_store.Run(thread_buffer_desc_m_k,
                                   make_tuple(I0, I0),
                                   y_thread_buf,
                                   y_grid_desc_m_k,
                                   y_global_val_buf);

            threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, thread_copy_bwd_step_m_k);
            threadwise_y_store.MoveDstSliceWindow(y_grid_desc_m_k, thread_copy_bwd_step_m_k);
        }

        // Step 3: update the moving average of mean and variance (optional)

        if(updateMovingAverage && thread_k_cluster_id == 0)
        {
            StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true>
                running_mean_thread_buf;
            StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true>
                running_var_thread_buf;

            auto running_mean_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
                resultRunningMean, mean_var_grid_desc_m.GetElementSpaceSize());

            auto running_var_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
                resultRunningVariance, mean_var_grid_desc_m.GetElementSpaceSize());

            auto threadwise_mean_var_load =
                ThreadwiseTensorSliceTransfer_v2<MeanVarDataType,
                                                 AccDataType,
                                                 MeanVarGridDesc_M,
                                                 decltype(thread_buffer_desc_m),
                                                 ThreadBufferLengths_M,
                                                 Sequence<0>,
                                                 0,
                                                 MeanVarSrcDstVectorSize,
                                                 1,
                                                 true>(
                    mean_var_grid_desc_m,
                    make_multi_index(block_global_id * M_BlockTileSize +
                                     thread_m_cluster_id * MThreadSliceSize));

            threadwise_mean_var_load.Run(mean_var_grid_desc_m,
                                         running_mean_global_buf,
                                         thread_buffer_desc_m,
                                         make_tuple(I0),
                                         running_mean_thread_buf);

            threadwise_mean_var_load.Run(mean_var_grid_desc_m,
                                         running_var_global_buf,
                                         thread_buffer_desc_m,
                                         make_tuple(I0),
                                         running_var_thread_buf);

            AccDataType oneMinusAverageFactor = type_convert<AccDataType>(1.0) - averageFactor;
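
            // Standard exponential moving average of the batch statistics:
            //   running <- (1 - averageFactor) * running + averageFactor * batch_statistic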

            static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
                running_mean_thread_buf(I) = running_mean_thread_buf[I] * oneMinusAverageFactor +
                                             mean_thread_buf[I] * averageFactor;
                running_var_thread_buf(I) = running_var_thread_buf[I] * oneMinusAverageFactor +
                                            var_thread_buf[I] * averageFactor;
            });

            auto threadwise_mean_var_store =
                ThreadwiseTensorSliceTransfer_v1r3<AccDataType,
                                                   MeanVarDataType,
                                                   decltype(thread_buffer_desc_m),
                                                   MeanVarGridDesc_M,
                                                   PassThroughOp,
                                                   ThreadBufferLengths_M,
                                                   Sequence<0>,
                                                   0,
                                                   MeanVarSrcDstVectorSize,
                                                   InMemoryDataOperationEnum::Set,
                                                   1,
                                                   true>(
                    mean_var_grid_desc_m,
                    make_multi_index(block_global_id * M_BlockTileSize +
                                     thread_m_cluster_id * MThreadSliceSize),
                    PassThroughOp{});

            threadwise_mean_var_store.Run(thread_buffer_desc_m,
                                          make_tuple(I0),
                                          running_mean_thread_buf,
                                          mean_var_grid_desc_m,
                                          running_mean_global_buf);

            threadwise_mean_var_store.Run(thread_buffer_desc_m,
                                          make_tuple(I0),
                                          running_var_thread_buf,
                                          mean_var_grid_desc_m,
                                          running_var_global_buf);
        };

        // Step 4: save mean and inv-variance (optional)

        if(saveMeanInvVariance && thread_k_cluster_id == 0)
        {
            auto result_mean_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
                resultSaveMean, mean_var_grid_desc_m.GetElementSpaceSize());

            auto result_inv_var_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
                resultSaveInvVariance, mean_var_grid_desc_m.GetElementSpaceSize());
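
            // Convert the variance buffer in place to the inverse standard deviation,
            // invVariance = 1 / sqrt(var + eps), the form usually consumed by a subsequent
            // batchnorm-backward pass.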

            static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
                var_thread_buf(I) =
                    type_convert<AccDataType>(1.0f) / sqrt(epsilon + var_thread_buf[I]);
            });

            auto threadwise_mean_inv_var_store =
                ThreadwiseTensorSliceTransfer_v1r3<AccDataType,
                                                   MeanVarDataType,
                                                   decltype(thread_buffer_desc_m),
                                                   MeanVarGridDesc_M,
                                                   PassThroughOp,
                                                   ThreadBufferLengths_M,
                                                   Sequence<0>,
                                                   0,
                                                   MeanVarSrcDstVectorSize,
                                                   InMemoryDataOperationEnum::Set,
                                                   1,
                                                   true>(
                    mean_var_grid_desc_m,
                    make_multi_index(block_global_id * M_BlockTileSize +
                                     thread_m_cluster_id * MThreadSliceSize),
                    PassThroughOp{});

            threadwise_mean_inv_var_store.Run(thread_buffer_desc_m,
                                              make_tuple(I0),
                                              mean_thread_buf,
                                              mean_var_grid_desc_m,
                                              result_mean_global_buf);

            threadwise_mean_inv_var_store.Run(thread_buffer_desc_m,
                                              make_tuple(I0),
                                              var_thread_buf,
                                              mean_var_grid_desc_m,
                                              result_inv_var_global_buf);
        };
    }
};

} // namespace ck

@@ -3,6 +3,7 @@

#pragma once

#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"

namespace ck {

@@ -75,4 +75,63 @@ struct ThreadwiseWelford
    int max_count_;
};

template <typename T,
          typename SrcMeanVarCountThreadDesc_M_K,
          typename DstMeanVarThreadDesc_M,
          bool GetActualVariance = false>
struct ThreadwiseWelfordMerge
{
    static constexpr auto src_thread_desc_m_k = SrcMeanVarCountThreadDesc_M_K{};
    static constexpr auto dst_thread_desc_m   = DstMeanVarThreadDesc_M{};

    static constexpr auto src_length_m = src_thread_desc_m_k.GetLength(Number<0>{});
    static constexpr auto src_length_k = src_thread_desc_m_k.GetLength(Number<1>{});
    static constexpr auto dst_length_m = dst_thread_desc_m.GetLength(Number<0>{});

    static_assert(src_length_m == dst_length_m, "lengths of source and dst buffer must match!");

    __device__ static void
    Merge(T& mean_a, T& var_a, int32_t& count_a, T mean_b, T var_b, int32_t count_b)
    {
        int count            = count_a + count_b;
        T count_b_over_count = count == 0 ? type_convert<T>(0) : type_convert<T>(count_b) / count;
        T delta              = mean_b - mean_a;
        mean_a += delta * count_b_over_count;
        var_a += var_b + delta * delta * count_a * count_b_over_count;
        count_a = count;
    }
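
    // Merge() implements Chan et al.'s pairwise combination of two Welford partial results,
    // where var_a/var_b hold M2 sums (not variances). With n = n_a + n_b and
    // delta = mean_b - mean_a:
    //   mean = mean_a + delta * n_b / n
    //   m2   = m2_a + m2_b + delta^2 * n_a * n_b / n
    // e.g. merging {mean = 1, m2 = 0, n = 1} with {mean = 3, m2 = 0, n = 1} yields
    // {mean = 2, m2 = 2, n = 2}, i.e. variance 1 after the optional division by the count
    // in Run() below.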

    template <typename SrcMeanBufferType,
              typename SrcVarBufferType,
              typename SrcCountBufferType,
              typename DstMeanBufferType,
              typename DstVarBufferType,
              typename DstCountBufferType>
    __device__ static void Run(const SrcMeanBufferType& src_mean_buf,
                               const SrcVarBufferType& src_var_buf,
                               const SrcCountBufferType& src_count_buf,
                               DstMeanBufferType& dst_mean_buf,
                               DstVarBufferType& dst_var_buf,
                               DstCountBufferType& dst_count_buf)
    {
        static_for<0, src_length_m, 1>{}([&](auto iM) {
            static_for<0, src_length_k, 1>{}([&](auto iK) {
                constexpr auto src_offset = src_thread_desc_m_k.CalculateOffset(make_tuple(iM, iK));

                Merge(dst_mean_buf(iM),
                      dst_var_buf(iM),
                      dst_count_buf(iM),
                      src_mean_buf[Number<src_offset>{}],
                      src_var_buf[Number<src_offset>{}],
                      src_count_buf[Number<src_offset>{}]);
            });

            if constexpr(GetActualVariance)
            {
                dst_var_buf(iM) = dst_var_buf[iM] / dst_count_buf[iM];
            };
        });
    };
};

} // namespace ck

@@ -9,46 +9,61 @@
#include <algorithm>
#include <thread>

+#include "ck/utility/math_v2.hpp"
+#include "ck/utility/ignore.hpp"
#include "ck/tensor_operation/gpu/device/device_batchnorm_forward.hpp"

namespace ck {
namespace tensor_operation {
namespace host {

-template <typename InOutDataType, typename AccDataType>
-struct ReferenceBatchNormFwd_Input_N_H_W_C_Output_C : public device::DeviceBatchNormFwd<4, 3>
+template <typename XDataType,
+          typename YDataType,
+          typename AccDataType,
+          typename ScaleDataType,
+          typename BiasDataType,
+          typename MeanVarDataType,
+          typename YElementwiseOp>
+struct ReferenceBatchNormFwd_Input_N_H_W_C_Output_C
+    : public device::DeviceBatchNormFwd<4, 3, YElementwiseOp>
{
    struct Argument : public device::BaseArgument
    {
        Argument(const std::array<index_t, 4> xyLengths,
                 const std::array<index_t, 4> xStrides,
                 const std::array<index_t, 4> yStrides,
                 const std::array<int, 3> reduceDims,
                 const std::array<index_t, 1> bnScaleBiasMeanVarLengths,
-                 const std::array<index_t, 1> bnScaleBiasMeanVarStrides,
-                 const InOutDataType* p_x,
-                 const AccDataType* bnScale,
-                 const AccDataType* bnBias,
-                 InOutDataType* p_y,
-                 double exponentialAverageFactor,
-                 AccDataType* resultRunningMean,
-                 AccDataType* resultRunningVariance,
+                 const std::array<index_t, 1> bnScaleStrides,
+                 const std::array<index_t, 1> bnBiasStrides,
+                 const std::array<index_t, 1> bnMeanVarStrides,
+                 const XDataType* p_x,
+                 const ScaleDataType* bnScale,
+                 const BiasDataType* bnBias,
                 double epsilon,
-                 AccDataType* resultSaveMean,
-                 AccDataType* resultSaveInvVariance)
+                 const YElementwiseOp y_elementwise_op,
+                 YDataType* p_y,
+                 MeanVarDataType* resultSaveMean,
+                 MeanVarDataType* resultSaveInvVariance,
+                 double averageFactor,
+                 MeanVarDataType* resultRunningMean,
+                 MeanVarDataType* resultRunningVariance)
            : p_x_(p_x),
              bnScale_(bnScale),
              bnBias_(bnBias),
+              y_elementwise_op_(y_elementwise_op),
              p_y_(p_y),
-              resultRunningMean_(resultRunningMean),
-              resultRunningVariance_(resultRunningVariance),
              resultSaveMean_(resultSaveMean),
              resultSaveInvVariance_(resultSaveInvVariance),
-              exponentialAverageFactor_(exponentialAverageFactor),
-              epsilon_(epsilon)
+              resultRunningMean_(resultRunningMean),
+              resultRunningVariance_(resultRunningVariance)
        {
-            (void)xStrides;
-            (void)yStrides;
-            (void)bnScaleBiasMeanVarStrides;
+            ignore = xStrides;
+            ignore = yStrides;
+            ignore = bnScaleStrides;
+            ignore = bnBiasStrides;
+            ignore = bnMeanVarStrides;
+            ignore = reduceDims;

            if(xyLengths.size() != 4 || bnScaleBiasMeanVarLengths.size() != 1 ||
               bnScaleBiasMeanVarLengths[0] != xyLengths[3])
@@ -59,26 +74,30 @@ struct ReferenceBatchNormFwd_Input_N_H_W_C_Output_C : public device::DeviceBatch
            w = xyLengths[2];
            c = xyLengths[3];

+            epsilon_       = type_convert<AccDataType>(epsilon);
+            averageFactor_ = type_convert<AccDataType>(averageFactor);

            resultSave    = (resultSaveMean != nullptr && resultSaveInvVariance != nullptr);
            resultRunning = (resultRunningMean != nullptr && resultRunningVariance != nullptr);
        }

-        const InOutDataType* p_x_;
-        const AccDataType* bnScale_;
-        const AccDataType* bnBias_;
-        InOutDataType* p_y_;
+        const XDataType* p_x_;
+        const ScaleDataType* bnScale_;
+        const BiasDataType* bnBias_;
+        const YElementwiseOp y_elementwise_op_;
+        YDataType* p_y_;

-        AccDataType* resultRunningMean_;
-        AccDataType* resultRunningVariance_;
-        AccDataType* resultSaveMean_;
-        AccDataType* resultSaveInvVariance_;
+        MeanVarDataType* resultSaveMean_;
+        MeanVarDataType* resultSaveInvVariance_;
+        MeanVarDataType* resultRunningMean_;
+        MeanVarDataType* resultRunningVariance_;

        bool resultSave, resultRunning;

        index_t n, h, w, c;

-        double exponentialAverageFactor_;
-        double epsilon_;
+        AccDataType averageFactor_;
+        AccDataType epsilon_;
    };

    struct Invoker : public device::BaseInvoker
@@ -86,14 +105,12 @@ struct ReferenceBatchNormFwd_Input_N_H_W_C_Output_C : public device::DeviceBatch
        float Run(const Argument& arg)
        {
            auto thread_reduce_func = [&](auto iC) {
-                AccDataType reduceSize = type_convert<AccDataType>(arg.n) *
-                                         type_convert<AccDataType>(arg.h) *
-                                         type_convert<AccDataType>(arg.w);
-                index_t offset_C       = iC;
-                AccDataType mean       = type_convert<AccDataType>(0.0f);
-                AccDataType meansquare = type_convert<AccDataType>(0.0f);
+                index_t offset_C     = iC;
+                AccDataType mean     = type_convert<AccDataType>(0.0f);
+                AccDataType variance = type_convert<AccDataType>(0.0f);
+                int32_t curr_count   = 0;

-                // compute mean, meanquare, variance, invVariance
+                // compute mean, variance using welford method
                for(index_t iN = 0; iN < arg.n; iN++)
                {
                    index_t offset_N = iN * arg.h * arg.w * arg.c;
@@ -106,40 +123,46 @@ struct ReferenceBatchNormFwd_Input_N_H_W_C_Output_C : public device::DeviceBatch

                        auto offset = offset_N + offset_H + offset_W + offset_C;

+                        curr_count++;

                        AccDataType x = type_convert<AccDataType>(arg.p_x_[offset]);

-                        mean += x;
-                        meansquare += x * x;
+                        AccDataType delta = x - mean;

+                        mean += delta / curr_count;

+                        AccDataType delta2 = x - mean;

+                        variance += delta * delta2;
                    };
                }
            };
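
                // At this point `mean` holds the running mean and `variance` still holds the
                // M2 sum of squared deviations; the division by curr_count below turns it into
                // the (biased) population variance, matching what the Welford-based device
                // kernels produce.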

-                mean       = mean / reduceSize;
-                meansquare = meansquare / reduceSize;
+                // actual variance
+                variance = variance / curr_count;

-                AccDataType variance = meansquare - mean * mean;
                AccDataType invVariance =
-                    type_convert<AccDataType>(1.0f) /
-                    std::sqrt(type_convert<AccDataType>(arg.epsilon_) + variance);
+                    type_convert<AccDataType>(1.0f) / ck::math::sqrt(arg.epsilon_ + variance);

                // save the mean/invVariance if required
                if(arg.resultSave)
                {
-                    arg.resultSaveMean_[iC]        = mean;
-                    arg.resultSaveInvVariance_[iC] = invVariance;
+                    arg.resultSaveMean_[iC]        = type_convert<MeanVarDataType>(mean);
+                    arg.resultSaveInvVariance_[iC] = type_convert<MeanVarDataType>(invVariance);
                };

                // update the moving average if required
                if(arg.resultRunning)
                {
-                    arg.resultRunningMean_[iC] =
-                        arg.resultRunningMean_[iC] *
-                            type_convert<AccDataType>(1.0 - arg.exponentialAverageFactor_) +
-                        mean * arg.exponentialAverageFactor_;
-                    arg.resultRunningVariance_[iC] =
-                        arg.resultRunningVariance_[iC] *
-                            type_convert<AccDataType>(1.0 - arg.exponentialAverageFactor_) +
-                        variance * arg.exponentialAverageFactor_;
+                    AccDataType oneMinusAverageFactor =
+                        type_convert<AccDataType>(1.0) - arg.averageFactor_;
+                    arg.resultRunningMean_[iC] = type_convert<MeanVarDataType>(
+                        type_convert<AccDataType>(arg.resultRunningMean_[iC]) *
+                            oneMinusAverageFactor +
+                        mean * arg.averageFactor_);
+                    arg.resultRunningVariance_[iC] = type_convert<MeanVarDataType>(
+                        arg.resultRunningVariance_[iC] * oneMinusAverageFactor +
+                        variance * arg.averageFactor_);
                };

                // Normalization
@@ -160,7 +183,7 @@ struct ReferenceBatchNormFwd_Input_N_H_W_C_Output_C : public device::DeviceBatch
                            AccDataType norm_x =
                                arg.bnScale_[iC] * (x - mean) * invVariance + arg.bnBias_[iC];

-                            arg.p_y_[offset] = type_convert<InOutDataType>(norm_x);
+                            arg.p_y_[offset] = type_convert<YDataType>(norm_x);
                        };
                    }
                };
@@ -207,34 +230,42 @@ struct ReferenceBatchNormFwd_Input_N_H_W_C_Output_C : public device::DeviceBatch
    MakeArgumentPointer(const std::array<index_t, 4> xyLengths,
                        const std::array<index_t, 4> xStrides,
                        const std::array<index_t, 4> yStrides,
                        const std::array<int, 3> reduceDims,
                        const std::array<index_t, 1> bnScaleBiasMeanVarLengths,
-                        const std::array<index_t, 1> bnScaleBiasMeanVarStrides,
+                        const std::array<index_t, 1> bnScaleStrides,
+                        const std::array<index_t, 1> bnBiasStrides,
+                        const std::array<index_t, 1> bnMeanVarStrides,
                        const void* p_x,
                        const void* bnScale,
                        const void* bnBias,
-                        void* p_y,
-                        double exponentialAverageFactor,
-                        void* resultRunningMean,
-                        void* resultRunningVariance,
                        double epsilon,
+                        const YElementwiseOp y_elementwise_op,
+                        void* p_y,
                        void* resultSaveMean,
-                        void* resultSaveInvVariance) override
+                        void* resultSaveInvVariance,
+                        double averageFactor,
+                        void* resultRunningMean,
+                        void* resultRunningVariance) override
    {
        return std::make_unique<Argument>(xyLengths,
                                          xStrides,
                                          yStrides,
                                          reduceDims,
                                          bnScaleBiasMeanVarLengths,
-                                          bnScaleBiasMeanVarStrides,
-                                          static_cast<const InOutDataType*>(p_x),
-                                          static_cast<const AccDataType*>(bnScale),
-                                          static_cast<const AccDataType*>(bnBias),
-                                          static_cast<InOutDataType*>(p_y),
-                                          exponentialAverageFactor,
-                                          static_cast<AccDataType*>(resultRunningMean),
-                                          static_cast<AccDataType*>(resultRunningVariance),
+                                          bnScaleStrides,
+                                          bnBiasStrides,
+                                          bnMeanVarStrides,
+                                          static_cast<const XDataType*>(p_x),
+                                          static_cast<const ScaleDataType*>(bnScale),
+                                          static_cast<const BiasDataType*>(bnBias),
                                          epsilon,
-                                          static_cast<AccDataType*>(resultSaveMean),
-                                          static_cast<AccDataType*>(resultSaveInvVariance));
+                                          y_elementwise_op,
+                                          static_cast<YDataType*>(p_y),
+                                          static_cast<MeanVarDataType*>(resultSaveMean),
+                                          static_cast<MeanVarDataType*>(resultSaveInvVariance),
+                                          averageFactor,
+                                          static_cast<MeanVarDataType*>(resultRunningMean),
+                                          static_cast<MeanVarDataType*>(resultRunningVariance));
    };

    std::unique_ptr<device::BaseInvoker> MakeInvokerPointer() override

@@ -14,7 +14,12 @@ namespace ck {
namespace tensor_operation {
namespace host {

-template <typename InOutDataType, typename AccDataType>
+template <typename XDataType,
+          typename YDataType,
+          typename AccDataType,
+          typename ScaleDataType,
+          typename BiasDataType,
+          typename MeanVarDataType>
struct ReferenceBatchNormInfer_Input_N_H_W_C_Output_C : public device::DeviceBatchNormInfer<4, 3>
{
    struct Argument : public device::BaseArgument
@@ -23,14 +28,16 @@ struct ReferenceBatchNormInfer_Input_N_H_W_C_Output_C : public device::DeviceBat
                 const std::array<index_t, 4> xStrides,
                 const std::array<index_t, 4> yStrides,
                 const std::array<index_t, 1> bnScaleBiasMeanVarLengths,
-                 const std::array<index_t, 1> bnScaleBiasMeanVarStrides,
-                 const InOutDataType* p_x,
-                 const AccDataType* bnScale,
-                 const AccDataType* bnBias,
+                 const std::array<index_t, 1> bnScaleStrides,
+                 const std::array<index_t, 1> bnBiasStrides,
+                 const std::array<index_t, 1> bnMeanVarStrides,
+                 const XDataType* p_x,
+                 const ScaleDataType* bnScale,
+                 const BiasDataType* bnBias,
                 double epsilon,
-                 const AccDataType* estimatedMean,
-                 const AccDataType* estimatedVariance,
-                 InOutDataType* p_y)
+                 const MeanVarDataType* estimatedMean,
+                 const MeanVarDataType* estimatedVariance,
+                 YDataType* p_y)
            : p_x_(p_x),
              bnScale_(bnScale),
              bnBias_(bnBias),
@@ -39,32 +46,34 @@ struct ReferenceBatchNormInfer_Input_N_H_W_C_Output_C : public device::DeviceBat
              estimatedVariance_(estimatedVariance),
              p_y_(p_y)
        {
-            (void)xStrides;
-            (void)yStrides;
-            (void)bnScaleBiasMeanVarStrides;
+            ignore = xStrides;
+            ignore = yStrides;
+            ignore = bnScaleStrides;
+            ignore = bnBiasStrides;
+            ignore = bnMeanVarStrides;

            if(xyLengths.size() != 4 || bnScaleBiasMeanVarLengths.size() != 1 ||
               bnScaleBiasMeanVarLengths[0] != xyLengths[3])
                throw std::runtime_error("Invalid tensor dimensions!");

-            n = xyLengths[0];
-            h = xyLengths[1];
-            w = xyLengths[2];
-            c = xyLengths[3];
+            n_ = xyLengths[0];
+            h_ = xyLengths[1];
+            w_ = xyLengths[2];
+            c_ = xyLengths[3];
        }

-        const InOutDataType* p_x_;
-        const AccDataType* bnScale_;
-        const AccDataType* bnBias_;
+        const XDataType* p_x_;
+        const ScaleDataType* bnScale_;
+        const BiasDataType* bnBias_;

        double epsilon_;

-        const AccDataType* estimatedMean_;
-        const AccDataType* estimatedVariance_;
+        const MeanVarDataType* estimatedMean_;
+        const MeanVarDataType* estimatedVariance_;

-        InOutDataType* p_y_;
+        YDataType* p_y_;

-        index_t n, h, w, c;
+        index_t n_, h_, w_, c_;
    };

    struct Invoker : public device::BaseInvoker
@@ -81,15 +90,15 @@ struct ReferenceBatchNormInfer_Input_N_H_W_C_Output_C : public device::DeviceBat
                    std::sqrt(type_convert<AccDataType>(arg.epsilon_) + variance);

                // Normalization
-                for(index_t iN = 0; iN < arg.n; iN++)
+                for(index_t iN = 0; iN < arg.n_; iN++)
                {
-                    index_t offset_N = iN * arg.h * arg.w * arg.c;
-                    for(index_t iH = 0; iH < arg.h; iH++)
+                    index_t offset_N = iN * arg.h_ * arg.w_ * arg.c_;
+                    for(index_t iH = 0; iH < arg.h_; iH++)
                    {
-                        index_t offset_H = iH * arg.w * arg.c;
-                        for(index_t iW = 0; iW < arg.w; iW++)
+                        index_t offset_H = iH * arg.w_ * arg.c_;
+                        for(index_t iW = 0; iW < arg.w_; iW++)
                        {
-                            index_t offset_W = iW * arg.c;
+                            index_t offset_W = iW * arg.c_;

                            auto offset = offset_N + offset_H + offset_W + offset_C;

@@ -98,21 +107,21 @@ struct ReferenceBatchNormInfer_Input_N_H_W_C_Output_C : public device::DeviceBat
                            AccDataType norm_x =
                                arg.bnScale_[iC] * (x - mean) * invVariance + arg.bnBias_[iC];

-                            arg.p_y_[offset] = type_convert<InOutDataType>(norm_x);
+                            arg.p_y_[offset] = type_convert<YDataType>(norm_x);
                        };
                    }
                };
            };

            std::size_t num_thread = std::thread::hardware_concurrency();
-            std::size_t work_per_thread = (arg.c + num_thread - 1) / num_thread;
+            std::size_t work_per_thread = (arg.c_ + num_thread - 1) / num_thread;

            std::vector<joinable_thread> threads(num_thread);

            for(std::size_t it = 0; it < num_thread; ++it)
            {
                std::size_t ic_begin = it * work_per_thread;
-                std::size_t ic_end = std::min(static_cast<int>((it + 1) * work_per_thread), arg.c);
+                std::size_t ic_end = std::min(static_cast<int>((it + 1) * work_per_thread), arg.c_);

                auto f = [=] {
                    for(std::size_t ic = ic_begin; ic < ic_end; ++ic)
@@ -146,7 +155,9 @@ struct ReferenceBatchNormInfer_Input_N_H_W_C_Output_C : public device::DeviceBat
                        const std::array<index_t, 4> xStrides,
                        const std::array<index_t, 4> yStrides,
                        const std::array<index_t, 1> bnScaleBiasMeanVarLengths,
-                        const std::array<index_t, 1> bnScaleBiasMeanVarStrides,
+                        const std::array<index_t, 1> bnScaleStrides,
+                        const std::array<index_t, 1> bnBiasStrides,
+                        const std::array<index_t, 1> bnMeanVarStrides,
                        const void* p_x,
                        const void* bnScale,
                        const void* bnBias,
@@ -159,14 +170,16 @@ struct ReferenceBatchNormInfer_Input_N_H_W_C_Output_C : public device::DeviceBat
                                          xStrides,
                                          yStrides,
                                          bnScaleBiasMeanVarLengths,
-                                          bnScaleBiasMeanVarStrides,
-                                          static_cast<const InOutDataType*>(p_x),
-                                          static_cast<const AccDataType*>(bnScale),
-                                          static_cast<const AccDataType*>(bnBias),
+                                          bnScaleStrides,
+                                          bnBiasStrides,
+                                          bnMeanVarStrides,
+                                          static_cast<const XDataType*>(p_x),
+                                          static_cast<const ScaleDataType*>(bnScale),
+                                          static_cast<const BiasDataType*>(bnBias),
                                          epsilon,
-                                          static_cast<const AccDataType*>(estimatedMean),
-                                          static_cast<const AccDataType*>(estimatedVariance),
-                                          static_cast<InOutDataType*>(p_y));
+                                          static_cast<const MeanVarDataType*>(estimatedMean),
+                                          static_cast<const MeanVarDataType*>(estimatedVariance),
+                                          static_cast<YDataType*>(p_y));
    };

    std::unique_ptr<device::BaseInvoker> MakeInvokerPointer() override