Layernorm welford (#346)

* Add threadwise and blockwise welford

* Rename gridwise op, prepare to add welford version

* implement welford and integrate welford into layernorm

* Take care of tail loop

* Fix bug when ThreadSliceK > 1

* Fix bug when merging two empty sets

* Rename clip to clamp

* 1. Fix type of count
2. Remove useless static_assert

* Do not inherit Reduction::Argument

* [What] replace __syncthreads() with block_sync_lds()
[Why] __syncthreads might wait on both lgkmcnt(0) and vmcnt(0)

* Add y stride

* Rename.
DeviceLayernorm -> DeviceLayernormImpl
DeviceNormalization2 -> DeviceLayernorm

* Move literal ""_uz & ""_zu into namespace 'literals'

* Move namespace 'literals' as 'ck::literals'

Co-authored-by: Po-Yen, Chen <PoYen.Chen@amd.com>
Co-authored-by: Chao Liu <chao.liu2@amd.com>
This commit is contained in:
rocking5566
2022-08-13 22:43:18 +08:00
committed by GitHub
parent c20a75b07d
commit 0bd6b842b9
13 changed files with 1097 additions and 476 deletions

View File

@@ -0,0 +1,108 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/tensor_description/cluster_descriptor.hpp"
#include "ck/utility/reduction_common.hpp"
namespace ck {
// clang-format off
// Assume:
// 1) the workspace is three __shared__ arrays (mean, var, count) declared inside Run()
// 2) the workspace totals 3*BlockSize elements: BlockSize each for mean (T), var (T) and count (int)
// 3) mean_value, var_value and count are the input data in vgpr from each thread
// 4) mean_value, var_value and count are over-written with the reduced output in vgpr for each thread
// 5) Merge mean and M from ThreadwiseWelford
// clang-format on
// Block-level merge of per-thread Welford partial results (mean, var, count).
//
// The threads of a block are viewed as an M x K cluster (ThreadClusterLengths_M_K, laid out
// according to ThreadClusterArrangeOrder). Run() merges the partial results of all threads that
// share the same M cluster index (i.e. it reduces along K) and returns the merged result to
// every thread of that row.
//
// Template parameters:
//   T                         - type of the mean / var values
//   BlockSize                 - number of threads; must equal M length * K length
//   ThreadClusterLengths_M_K  - compile-time cluster lengths, queried through At(0) / At(1)
//   ThreadClusterArrangeOrder - forwarded to make_cluster_descriptor()
//   GetActualVariance         - true: Run() returns merged var divided by merged count;
//                               false: Run() returns the merged var unchanged
template <typename T,
          index_t BlockSize,
          typename ThreadClusterLengths_M_K,
          typename ThreadClusterArrangeOrder,
          bool GetActualVariance = true>
struct BlockwiseWelford
{
    static_assert(BlockSize == ThreadClusterLengths_M_K::At(0) * ThreadClusterLengths_M_K::At(1),
                  "The product of cluster lengths should be same as BlockSize!");

    static constexpr auto BufferLength_M = ThreadClusterLengths_M_K::At(0);
    static constexpr auto BufferLength_K = ThreadClusterLengths_M_K::At(1);

    // Run() reduces along K with a halving tree whose first step pairs slot k with slot
    // k + 2^(shift-1). That only visits every slot of a row when the row length is a power of
    // two; reject anything else at compile time instead of silently producing a wrong result.
    static_assert(BufferLength_K > 0 && (BufferLength_K & (BufferLength_K - 1)) == 0,
                  "The K cluster length should be a power of 2 for the tree reduction in Run()!");

    // Packed M x K layout used to index the shared buffers in Run().
    static constexpr auto block_buf_desc_m_k = make_naive_tensor_descriptor_packed(
        make_tuple(Number<BufferLength_M>{}, Number<BufferLength_K>{}));

    // Maps a 1-D thread id to its (m, k) position in the thread cluster.
    static constexpr auto thread_cluster_desc =
        make_cluster_descriptor(ThreadClusterLengths_M_K{}, ThreadClusterArrangeOrder{});

    // Folds partial result b into partial result a, in place.
    //
    //   mean_a, var_a, count_a - [in/out] partial result a; overwritten with the merged result
    //   mean_b, var_b, count_b - [in]     partial result b
    //
    // The var update adds delta^2 * count_a * count_b / (count_a + count_b), which has the form
    // of the pairwise combination of un-normalized sums of squared deviations; var_a / var_b are
    // therefore expected to be un-normalized (Run() normalizes at the end when
    // GetActualVariance is set).
    __device__ static inline void
    Merge(T& mean_a, T& var_a, int& count_a, T mean_b, T var_b, int count_b)
    {
        int count = count_a + count_b;

        // Merging two empty sets: use a zero weight rather than dividing by zero.
        T count_b_over_count = count == 0 ? type_convert<T>(0) : type_convert<T>(count_b) / count;

        T delta = mean_b - mean_a;

        mean_a += delta * count_b_over_count;
        var_a += var_b + delta * delta * count_a * count_b_over_count;
        count_a = count;
    }

    // Merges the (mean_value, var_value, count) partial results of all threads that share this
    // thread's M cluster index.
    //
    //   mean_value - [in/out] this thread's mean; overwritten with the merged mean of its row
    //   var_value  - [in/out] this thread's var; overwritten with the merged var, divided by the
    //                merged count when GetActualVariance is true
    //   count      - [in/out] this thread's count; overwritten with the merged count
    //
    // block_sync_lds() is invoked unconditionally (outside the per-thread `if`), so every thread
    // of the block is expected to call Run().
    //
    // No guard is applied to the final division: a merged count of 0 divides by zero when
    // GetActualVariance is true.
    //
    // NOTE(review): the __shared__ buffers below are reused by every call and there is no
    // barrier between the final reads here and the first writes of a following call, so callers
    // invoking Run() back-to-back presumably need a block_sync_lds() in between - confirm
    // against the call sites.
    __device__ static void Run(T& mean_value, T& var_value, int& count)
    {
        // One slot per thread for each of the three quantities.
        __shared__ T mean_block_buf[BlockSize];
        __shared__ T var_block_buf[BlockSize];
        __shared__ int count_block_buf[BlockSize];

        // Number of halving steps of the tree reduction along K.
        constexpr auto cluster_len_shift = get_shift<BufferLength_K>();

        const auto thread_cluster_idx =
            thread_cluster_desc.CalculateBottomIndex(make_multi_index(get_thread_local_1d_id()));

        const auto thread_m_cluster_id = thread_cluster_idx[Number<0>{}];
        const auto thread_k_cluster_id = thread_cluster_idx[Number<1>{}];

        // Slot owned by this thread.
        index_t offset1 = block_buf_desc_m_k.CalculateOffset(thread_cluster_idx);

        mean_block_buf[offset1]  = mean_value;
        var_block_buf[offset1]   = var_value;
        count_block_buf[offset1] = count;

        // Make every thread's partial result visible before any thread reads a partner slot.
        block_sync_lds();

        static_for<0, cluster_len_shift, 1>{}([&](auto I) {
            // Partner distance for this step: 2^(shift-1), 2^(shift-2), ..., 1.
            constexpr index_t indOffset = 1 << (cluster_len_shift - 1 - I());

            // Only the first indOffset threads along K stay active; each folds in the slot
            // indOffset positions further along K within the same row.
            if(thread_k_cluster_id < indOffset)
            {
                index_t offset2 = block_buf_desc_m_k.CalculateOffset(thread_cluster_idx +
                                                                     make_tuple(0, indOffset));

                T mean1    = mean_block_buf[offset1];
                T var1     = var_block_buf[offset1];
                int count1 = count_block_buf[offset1];

                T mean2    = mean_block_buf[offset2];
                T var2     = var_block_buf[offset2];
                int count2 = count_block_buf[offset2];

                Merge(mean1, var1, count1, mean2, var2, count2);

                mean_block_buf[offset1]  = mean1;
                var_block_buf[offset1]   = var1;
                count_block_buf[offset1] = count1;
            }

            // Finish this step's writes before the next step (or the final read) consumes them.
            block_sync_lds();
        });

        // Slot (m, 0) holds the merged result of row m; every thread of the row reads it back.
        index_t offset = block_buf_desc_m_k.CalculateOffset(make_tuple(thread_m_cluster_id, 0));

        count      = count_block_buf[offset];
        mean_value = mean_block_buf[offset];

        if constexpr(GetActualVariance)
            var_value = var_block_buf[offset] / count;
        else
            var_value = var_block_buf[offset];
    }
};
} // namespace ck