mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-02 20:51:23 +00:00
Batchnorm-forward implemented using welford method to calculate variance (#403)
* Update to the batchnorm-forward API and base class * Fix leaked header including in gridwise_set_buffer_value.hpp * Add kernels and device file for batchnorm-forward welford supporting both blockwise and multi-block reduction * Update to the batchnorm-forward example to use the new batchnorm-forward device interface * Change the batchnorm-forward reference to use sequential welford method * Change to assign the workspace into four buffers in the host layer * Use GetReduceCountPerThread functor to replace the initial count for Blockwise and Multiblock welford * Tiny correction and remove unused file under example/34_batchnorm * Renaming in the kernel arguments * Explicitly use ck::math::sqrt in batchnorm-forward kernels * Add some comments to some kernels * Tiny fix * Generalize the data types in reference_batchnorm_forward_nhwc_c * Use ck::ignore to mark unused parameters * Move GetReduceCountPerThread functor code from kernel to device * Remove some unused code in device_batchnorm_forward_impl.hpp * Tiny fix in batchnorm_forward example * Move GetReduceCountPerThread() to welford_helper.hpp * Use separate data type for Scale and Bias * Renaming in device Op * Tiny fix in forward example * Update to batchnorm-infer (type splitting, renaming) * Add time and bandwidth measurement to the batchnorm-forward example * Add support of elementwise operation for batchnorm forward output * Reduce object copying by passing object as reference type * Tiny change for performance * Updates for performance again * Some Renamings * Add GetActualVariance template parameter for ThreadwiseWelfordMerge * Tiny update in reference batchnorm forward nhwc/c * Move batchnorm multiblock kernel files to grid/batchnorm_multiblock sub-directory * Fuse mean and bias in the normalization calculation Co-authored-by: root <root@dc-smc-18.amd.com> Co-authored-by: rocking5566 <ChunYu.Lai@amd.com>
This commit is contained in:
@@ -75,4 +75,63 @@ struct ThreadwiseWelford
|
||||
int max_count_;
|
||||
};
|
||||
|
||||
template <typename T,
|
||||
typename SrcMeanVarCountThreadDesc_M_K,
|
||||
typename DstMeanVarThreadDesc_M,
|
||||
bool GetActualVariance = false>
|
||||
struct ThreadwiseWelfordMerge
|
||||
{
|
||||
static constexpr auto src_thread_desc_m_k = SrcMeanVarCountThreadDesc_M_K{};
|
||||
static constexpr auto dst_thread_desc_m = DstMeanVarThreadDesc_M{};
|
||||
|
||||
static constexpr auto src_length_m = src_thread_desc_m_k.GetLength(Number<0>{});
|
||||
static constexpr auto src_length_k = src_thread_desc_m_k.GetLength(Number<1>{});
|
||||
static constexpr auto dst_length_m = dst_thread_desc_m.GetLength(Number<0>{});
|
||||
|
||||
static_assert(src_length_m == dst_length_m, "lengths of source and dst buffer must match!");
|
||||
|
||||
__device__ static void
|
||||
Merge(T& mean_a, T& var_a, int32_t& count_a, T mean_b, T var_b, int32_t count_b)
|
||||
{
|
||||
int count = count_a + count_b;
|
||||
T count_b_over_count = count == 0 ? type_convert<T>(0) : type_convert<T>(count_b) / count;
|
||||
T delta = mean_b - mean_a;
|
||||
mean_a += delta * count_b_over_count;
|
||||
var_a += var_b + delta * delta * count_a * count_b_over_count;
|
||||
count_a = count;
|
||||
}
|
||||
|
||||
template <typename SrcMeanBufferType,
|
||||
typename SrcVarBufferType,
|
||||
typename SrcCountBufferType,
|
||||
typename DstMeanBufferType,
|
||||
typename DstVarBufferType,
|
||||
typename DstCountBufferType>
|
||||
__device__ static void Run(const SrcMeanBufferType& src_mean_buf,
|
||||
const SrcVarBufferType& src_var_buf,
|
||||
const SrcCountBufferType& src_count_buf,
|
||||
DstMeanBufferType& dst_mean_buf,
|
||||
DstVarBufferType& dst_var_buf,
|
||||
DstCountBufferType& dst_count_buf)
|
||||
{
|
||||
static_for<0, src_length_m, 1>{}([&](auto iM) {
|
||||
static_for<0, src_length_k, 1>{}([&](auto iK) {
|
||||
constexpr auto src_offset = src_thread_desc_m_k.CalculateOffset(make_tuple(iM, iK));
|
||||
|
||||
Merge(dst_mean_buf(iM),
|
||||
dst_var_buf(iM),
|
||||
dst_count_buf(iM),
|
||||
src_mean_buf[Number<src_offset>{}],
|
||||
src_var_buf[Number<src_offset>{}],
|
||||
src_count_buf[Number<src_offset>{}]);
|
||||
});
|
||||
|
||||
if constexpr(GetActualVariance)
|
||||
{
|
||||
dst_var_buf(iM) = dst_var_buf[iM] / dst_count_buf[iM];
|
||||
};
|
||||
});
|
||||
};
|
||||
};
|
||||
|
||||
} // namespace ck
|
||||
|
||||
Reference in New Issue
Block a user