mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-05 06:01:23 +00:00
Batchnorm-forward and Batchnorm-infer Implemented using generic kernels (#320)
* Implement multiple-reduction in one kernel (kernels, device ops, examples) * Add generic elementwise kernel and device interface * Add generator for normal-distributed data initialization * Add host reference implementation of batchnorm-forward and batchnorm-infer * Add examples for implementing batchnorm-forward and batchnorm-infer using generic kernels * Remove unneeded include in batchnorm example * Renaming generic_elementwise to elementwise in kernel and device classes/functions * Change in gemm_layernorm examples to use DeviceElementwise instead of Device5AryElementwise * Change in example 19_binary_elementwise to use DeviceElementwise instead of DeviceBinaryElementwise * Change in device_cgemm_4gemm_xdl_cshuffle.hpp to use kernel_elementwise instead of kernel_binary_elementwise * Add DeviceElementwiseBase and use it in device_normalize_instance.cpp * Removing and renaming files * Update to synchronize gemm_layernorm client example with the generic element-wise device op API * Update to synchronize with the latest headers directory and HostTensorDescriptor interface renaming * Merge two static member functions in device_elementwise.hpp * Remove unary_elementwise_1d kernel and device
This commit is contained in:
@@ -198,17 +198,44 @@ struct Normalize
|
||||
// FIXME: is double absolutely necessary?
|
||||
Normalize(double epsilon = 1e-4) : epsilon_(epsilon) {}
|
||||
|
||||
template <typename T>
|
||||
__host__ __device__ constexpr void operator()(
|
||||
T& y, const T& x, const T& mean, const T& mean_square, const T& gamma, const T& beta) const;
|
||||
template <typename T1, typename T2, typename T3>
|
||||
__host__ __device__ constexpr void operator()(T1& y,
|
||||
const T1& x,
|
||||
const T2& mean,
|
||||
const T2& mean_square,
|
||||
const T3& gamma,
|
||||
const T3& beta) const;
|
||||
|
||||
template <>
|
||||
__host__ __device__ constexpr void operator()<float>(float& y,
|
||||
const float& x,
|
||||
const float& mean,
|
||||
const float& mean_square,
|
||||
const float& gamma,
|
||||
const float& beta) const
|
||||
__host__ __device__ constexpr void operator()<half_t, float, half_t>(half_t& y,
|
||||
const half_t& x,
|
||||
const float& mean,
|
||||
const float& mean_square,
|
||||
const half_t& gamma,
|
||||
const half_t& beta) const
|
||||
{
|
||||
using ck::math::sqrt;
|
||||
|
||||
float variance = mean_square - (mean * mean);
|
||||
|
||||
float tmp_x = type_convert<float>(x);
|
||||
float tmp_gamma = type_convert<float>(gamma);
|
||||
float tmp_beta = type_convert<float>(beta);
|
||||
|
||||
float tmp_y =
|
||||
((tmp_x - mean) / sqrt(variance + type_convert<float>(epsilon_))) * tmp_gamma +
|
||||
tmp_beta;
|
||||
|
||||
y = type_convert<half_t>(tmp_y);
|
||||
};
|
||||
|
||||
template <>
|
||||
__host__ __device__ constexpr void operator()<float, float, float>(float& y,
|
||||
const float& x,
|
||||
const float& mean,
|
||||
const float& mean_square,
|
||||
const float& gamma,
|
||||
const float& beta) const
|
||||
{
|
||||
using ck::math::sqrt;
|
||||
|
||||
@@ -217,12 +244,12 @@ struct Normalize
|
||||
};
|
||||
|
||||
template <>
|
||||
__host__ __device__ constexpr void operator()<double>(double& y,
|
||||
const double& x,
|
||||
const double& mean,
|
||||
const double& mean_square,
|
||||
const double& gamma,
|
||||
const double& beta) const
|
||||
__host__ __device__ constexpr void operator()<double, double, double>(double& y,
|
||||
const double& x,
|
||||
const double& mean,
|
||||
const double& mean_square,
|
||||
const double& gamma,
|
||||
const double& beta) const
|
||||
{
|
||||
using ck::math::sqrt;
|
||||
|
||||
|
||||
Reference in New Issue
Block a user