mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-02 20:51:23 +00:00
Batchnorm splitk single kernel (#771)
* Use dim 0 as faster dim for writing mean/var/count workspace in batchnorm multiblock method [performance]
* Add CountDataType as template parameter in blockwise_welford
* Add utility/get_shift.hpp
* Add BatchNorm multiblock single-kernel implementation
* Add smem inline assembly based implementation of gms_init/gms_barrier/gms_reset for gfx90a
* Renaming in device_batchnorm_forward_impl.hpp
* Tiny fix in the batchnorm_fwd profiler
* Revert "Add smem inline assembly based implementation of gms_init/gms_barrier/gms_reset for gfx90a"
This reverts commit d16d00919c.
* Use the old two-kernel batchnorm multiblock method for gfx1030
* Use the old two-kernel batchnorm multiblock method for gfx908
* use the single-kernel batchnorm multiblock method only for gfx90a
* Remove get_wave_id() from utility/get_id.hpp since it is not used
* Set true for testing running mean/variance and saving mean/invvariance in the examples
* Fix to copy-right words
* Remove un-needed including in utility/get_id.hpp
* Add comments to workgroup_synchronization.hpp
* Remove un-used codes in gridwise_multiblock_batchnorm_forward.hpp
* Renaming in the kernels
* Remove un-used kernel file
This commit is contained in:
@@ -4,7 +4,7 @@
|
||||
#pragma once
|
||||
|
||||
#include "ck/tensor_description/cluster_descriptor.hpp"
|
||||
#include "ck/utility/reduction_common.hpp"
|
||||
#include "ck/utility/get_shift.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
@@ -35,10 +35,11 @@ struct BlockwiseWelford
|
||||
static constexpr auto thread_cluster_desc =
|
||||
make_cluster_descriptor(ThreadClusterLengths_M_K{}, ThreadClusterArrangeOrder{});
|
||||
|
||||
template <typename CountDataType>
|
||||
__device__ static inline void
|
||||
Merge(T& mean_a, T& var_a, int& count_a, T mean_b, T var_b, int count_b)
|
||||
Merge(T& mean_a, T& var_a, CountDataType& count_a, T mean_b, T var_b, CountDataType count_b)
|
||||
{
|
||||
int count = count_a + count_b;
|
||||
CountDataType count = count_a + count_b;
|
||||
T count_b_over_count = count == 0 ? type_convert<T>(0) : type_convert<T>(count_b) / count;
|
||||
T delta = mean_b - mean_a;
|
||||
mean_a += delta * count_b_over_count;
|
||||
@@ -46,11 +47,12 @@ struct BlockwiseWelford
|
||||
count_a = count;
|
||||
}
|
||||
|
||||
__device__ static void Run(T& mean_value, T& var_value, int& count)
|
||||
template <typename CountDataType>
|
||||
__device__ static void Run(T& mean_value, T& var_value, CountDataType& count)
|
||||
{
|
||||
__shared__ T mean_block_buf[BlockSize];
|
||||
__shared__ T var_block_buf[BlockSize];
|
||||
__shared__ int count_block_buf[BlockSize];
|
||||
__shared__ CountDataType count_block_buf[BlockSize];
|
||||
|
||||
constexpr auto cluster_len_shift = get_shift<BufferLength_K>();
|
||||
|
||||
@@ -76,13 +78,13 @@ struct BlockwiseWelford
|
||||
index_t offset2 = block_buf_desc_m_k.CalculateOffset(thread_cluster_idx +
|
||||
make_tuple(0, indOffset));
|
||||
|
||||
T mean1 = mean_block_buf[offset1];
|
||||
T var1 = var_block_buf[offset1];
|
||||
int count1 = count_block_buf[offset1];
|
||||
T mean1 = mean_block_buf[offset1];
|
||||
T var1 = var_block_buf[offset1];
|
||||
CountDataType count1 = count_block_buf[offset1];
|
||||
|
||||
T mean2 = mean_block_buf[offset2];
|
||||
T var2 = var_block_buf[offset2];
|
||||
int count2 = count_block_buf[offset2];
|
||||
T mean2 = mean_block_buf[offset2];
|
||||
T var2 = var_block_buf[offset2];
|
||||
CountDataType count2 = count_block_buf[offset2];
|
||||
|
||||
Merge(mean1, var1, count1, mean2, var2, count2);
|
||||
|
||||
|
||||
@@ -4,7 +4,7 @@
|
||||
#pragma once
|
||||
|
||||
#include "ck/tensor_description/cluster_descriptor.hpp"
|
||||
#include "ck/utility/reduction_common.hpp"
|
||||
#include "ck/utility/get_shift.hpp"
|
||||
#include "ck/utility/reduction_functions_accumulate.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
Reference in New Issue
Block a user