Layernorm welford (#346)

* Add threadwise and blockwise welford

* Rename gridwise op, prepare to add welford version

* implement welford and integrate welford into layernorm

* Take care of tail loop

* Fix buf when ThreadSliceK > 1

* Fix bug of merging of two empty set

* Rename clip to clamp

* 1. Fix type of count
2. Remove useless static_assert

* Do not inherit Reduction::Argument

* [What] replace __syncthreads() with block_sync_lds()
[Why] __syncthreads might wait both lgkmcnt(0) and vmcnt(0)

* Add y stride

* Rename.
DeviceLayernorm -> DeviceLayernormImpl
DeviceNormalization2 -> DeviceLayernorm

* Move literal ""_uz & ""_zu into namespace 'literals'

* Move namespace 'literals' as 'ck::literals'

Co-authored-by: Po-Yen, Chen <PoYen.Chen@amd.com>
Co-authored-by: Chao Liu <chao.liu2@amd.com>
This commit is contained in:
rocking5566
2022-08-13 22:43:18 +08:00
committed by GitHub
parent c20a75b07d
commit 0bd6b842b9
13 changed files with 1097 additions and 476 deletions

View File

@@ -0,0 +1,78 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/math_v2.hpp"
namespace ck {
// Assume
// 1) XDesc is known at compile-time
// 2) MeanVarDesc is known at compile-time
// 3) XBuffer is static buffer
// 4) MeanBuffer is static buffer
// 5) VarBuffer is static buffer
template <typename T, typename XThreadDesc_M_K, typename MeanVarThreadDesc_M>
struct ThreadwiseWelford
{
static constexpr auto x_thread_desc_m_k = XThreadDesc_M_K{};
static constexpr auto mean_var_thread_desc_m = MeanVarThreadDesc_M{};
static constexpr auto thread_x_length_m = x_thread_desc_m_k.GetLength(Number<0>{});
static constexpr auto thread_x_length_k = x_thread_desc_m_k.GetLength(Number<1>{});
static constexpr auto thread_mean_var_length_m = mean_var_thread_desc_m.GetLength(Number<0>{});
static_assert(thread_x_length_m == thread_mean_var_length_m,
"lengths of source and mean/var buffer must match!");
__device__ constexpr ThreadwiseWelford() : cur_count_(0), max_count_(0) {}
__device__ inline void Update(T& mean, T& var, T x)
{
using ck::math::isnan;
if(isnan(x))
{
mean = x;
var = x;
}
else
{
T delta = x - mean;
mean += delta / cur_count_;
T delta2 = x - mean;
var += delta * delta2;
}
}
template <typename XBufferType, typename MeanBufferType, typename VarBufferType>
__device__ void
Run(const XBufferType& x_buf_m_k, MeanBufferType& mean_buf_m, VarBufferType& var_buf_m)
{
// FIXME - Better naming for var_buf_m
static_for<0, thread_x_length_k, 1>{}([&](auto iK) {
if(cur_count_ < max_count_)
{
++cur_count_;
static_for<0, thread_x_length_m, 1>{}([&](auto iM) {
constexpr index_t out_offset =
mean_var_thread_desc_m.CalculateOffset(make_tuple(iM));
constexpr auto in_offset =
x_thread_desc_m_k.CalculateOffset(make_tuple(iM, iK));
Update(mean_buf_m(Number<out_offset>{}),
var_buf_m(Number<out_offset>{}),
x_buf_m_k[Number<in_offset>{}]);
});
}
});
};
int cur_count_;
int max_count_;
};
} // namespace ck