update layernorm (#1570)

* port layernorm

* change warp_welford.hpp

* Update warpshuffle

* 1. Add save mean and save std back
2. Move construction of tensor_view and tile_window to operator()

* refine welford max count calculation

* unify layernorm api

* Rename file

* Remove save mean and inv std

* Revert "refine welford max count calculation"

This reverts commit 022365802b.

* Fix order of parameter

* refine welford max count calculation again

* Remove fp32 instances

* Fix bug of padding

* refactor api

* Support bf16

* Extract common function

* Refine arg of operator()

* Add kMThreadPerBlock to template parameter

* clang format

* Refine variable name

* Refine file name

* remove redundant line

* refactor layernorm2d pipeline and add block-per-block utility

* fix name

* rename more

* add more block-per-tile instance

* remove duplicated define

* update instance for 2048, 1024 case

* support up to 2048 now

* opt loading

* add n1536

* Add two pass pipeline

* format

* Fix incorrect type

* parallel compilation

* Use smaller N

* fix 2p pass

* Support Repeat_M in distribution

* Refine nameing

* Add reduce example

---------

Co-authored-by: letaoqin <letaoqin@amd.com>
Co-authored-by: aska-0096 <haocwang@amd.com>
Co-authored-by: rocking <ChunYu.Lai@amd.com>
Co-authored-by: carlushuang <carlus.huang@amd.com>
This commit is contained in:
ltqin
2024-10-22 09:26:18 +08:00
committed by GitHub
parent 3f710930f6
commit 0394f8a713
59 changed files with 2917 additions and 1042 deletions

View File

@@ -7,95 +7,30 @@
namespace ck_tile {
template <typename ComputeDataType_, typename XDataType_>
struct ThreadWelford
template <typename T>
CK_TILE_DEVICE void welford_update(T& mean, T& var, T x, int count)
{
using XDataType = remove_cvref_t<XDataType_>;
using ComputeDataType = remove_cvref_t<ComputeDataType_>;
// TODO: check nan? maybe no
T delta = x - mean;
mean += delta / count;
T delta2 = x - mean;
var += delta * delta2;
}
template <typename T>
CK_TILE_DEVICE void Update(T& mean, T& var, T x)
{
if(ck_tile::isnan(x))
{
mean = x;
var = x;
}
else
{
T delta = x - mean;
mean += delta / cur_count_;
T delta2 = x - mean;
var += delta * delta2;
}
}
template <typename T>
CK_TILE_DEVICE static void
welford_merge(T& mean_a, T& var_a, int& count_a, T mean_b, T var_b, int count_b)
{
int count = count_a + count_b;
T count_ = type_convert<T>(count);
T count_a_ = type_convert<T>(count_a);
T count_b_ = type_convert<T>(count_b);
T count_b_over_count = count == 0 ? type_convert<T>(0) : count_b_ / count_;
// [CAUSION] - max_count_ is to deal with the padding problem
// max_count_ is depend on caller, eg: naive and splitN welford will have different
// calculation of max_count_
CK_TILE_DEVICE constexpr ThreadWelford(int max_count) : cur_count_(0), max_count_(max_count) {}
template <typename XDistributedTensor_,
typename MeanDistributedTensor_,
typename VarDistributedTensor_>
CK_TILE_DEVICE void operator()(const XDistributedTensor_& x_tensor,
MeanDistributedTensor_& mean_tensor,
VarDistributedTensor_& var_tensor)
{
constexpr auto I0 = number<0>{};
constexpr auto I1 = number<1>{};
constexpr auto spans = XDistributedTensor_::get_distributed_spans();
sweep_tile_span(spans[I1], [&](auto dstr_idx_i1) {
if(cur_count_ < max_count_)
{
++cur_count_;
sweep_tile_span(spans[I0], [&](auto dstr_idx_i0) {
constexpr auto in_dstr_idx = make_tuple(dstr_idx_i0, dstr_idx_i1);
constexpr auto out_dstr_idx = make_tuple(dstr_idx_i0);
auto x = ck_tile::type_convert<ComputeDataType>(x_tensor[in_dstr_idx]);
Update(mean_tensor(out_dstr_idx), var_tensor(out_dstr_idx), x);
});
}
});
}
template <typename XDistributedTensor_>
CK_TILE_DEVICE static auto MakeInitialMeanVarDistributedTensor()
{
static_assert(std::is_same_v<XDataType, typename XDistributedTensor_::DataType>, "wrong!");
constexpr auto reduce_dims = sequence<1>{};
constexpr auto dstr =
make_static_tile_distribution(detail::make_reduce_tile_distribution_encoding(
XDistributedTensor_::get_tile_distribution()
.get_static_tile_distribution_encoding(),
reduce_dims));
auto tensor = make_static_distributed_tensor<ComputeDataType>(dstr);
clear_tile(tensor);
return tensor;
}
template <typename XDistributedTensor_>
CK_TILE_DEVICE auto operator()(const XDistributedTensor_& x_tensor)
{
auto mean_tensor = MakeInitialMeanVarDistributedTensor<XDistributedTensor_>();
auto var_tensor = MakeInitialMeanVarDistributedTensor<XDistributedTensor_>();
(*this)(x_tensor, mean_tensor, var_tensor);
return ck_tile::make_tuple(mean_tensor, var_tensor);
}
int cur_count_;
int max_count_;
};
T delta = mean_b - mean_a;
mean_a += delta * count_b_over_count;
var_a += var_b + delta * delta * count_a_ * count_b_over_count;
count_a = count;
}
} // namespace ck_tile