mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-03 05:01:25 +00:00
update layernorm (#1570)
* port layernorm
* change warp_welford.hpp
* Update warpshuffle
* 1. Add save mean and save std back
2. Move construction of tensor_view and tile_window to operator()
* refine welford max count calculation
* unify layernorm api
* Rename file
* Remove save mean and inv std
* Revert "refine welford max count calculation"
This reverts commit 022365802b.
* Fix order of parameter
* refine welford max count calculation again
* Remove fp32 instances
* Fix bug of padding
* refactor api
* Support bf16
* Extract common function
* Refine arg of operator()
* Add kMThreadPerBlock to template parameter
* clang format
* Refine variable name
* Refine file name
* remove redundant line
* refactor layernorm2d pipeline and add block-per-block utility
* fix name
* rename more
* add more block-per-tile instance
* remove duplicated define
* update instance for 2048, 1024 case
* support up to 2048 now
* opt loading
* add n1536
* Add two pass pipeline
* format
* Fix incorrect type
* parallel compilation
* Use smaller N
* fix 2p pass
* Support Repeat_M in distribution
* Refine naming
* Add reduce example
---------
Co-authored-by: letaoqin <letaoqin@amd.com>
Co-authored-by: aska-0096 <haocwang@amd.com>
Co-authored-by: rocking <ChunYu.Lai@amd.com>
Co-authored-by: carlushuang <carlus.huang@amd.com>
This commit is contained in:
@@ -7,95 +7,30 @@
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <typename ComputeDataType_, typename XDataType_>
|
||||
struct ThreadWelford
|
||||
// Folds one sample `x` into a running (mean, var) pair in place.
//   mean  - running mean; moved toward x by (x - mean) / count.
//   var   - accumulator; grows by (x - old_mean) * (x - new_mean). No division
//           by count happens here, so this is an un-normalized sum.
//   x     - the new sample.
//   count - divisor for the mean step; presumably the number of samples seen
//           including x (TODO confirm with callers). There is no guard
//           against count == 0.
template <typename T>
|
||||
CK_TILE_DEVICE void welford_update(T& mean, T& var, T x, int count)
|
||||
{
|
||||
// NOTE(review): the two alias declarations below refer to XDataType_ /
// ComputeDataType_, which are template parameters of ThreadWelford, not of
// this function. They look like removed-side diff lines interleaved into this
// body - confirm against the upstream file.
using XDataType = remove_cvref_t<XDataType_>;
|
||||
using ComputeDataType = remove_cvref_t<ComputeDataType_>;
|
||||
// TODO: check nan? maybe no
|
||||
// Deviation of x from the mean before the update.
T delta = x - mean;
|
||||
mean += delta / count;
|
||||
// Deviation of x from the mean after the update.
T delta2 = x - mean;
|
||||
var += delta * delta2;
|
||||
}
|
||||
|
||||
// Folds one sample `x` into the running (mean, var) pair in place, using the
// member cur_count_ as the divisor for the mean step.
//   mean - running mean, updated in place.
//   var  - accumulator, grows by (x - old_mean) * (x - new_mean); it is not
//          divided by the count here.
//   x    - the new sample.
// A NaN sample is not accumulated: it overwrites both outputs instead.
template <typename T>
|
||||
CK_TILE_DEVICE void Update(T& mean, T& var, T x)
|
||||
{
|
||||
if(ck_tile::isnan(x))
|
||||
{
|
||||
// Propagate the NaN: both outputs become x.
mean = x;
|
||||
var = x;
|
||||
}
|
||||
else
|
||||
{
|
||||
// Deviation of x from the mean before the update.
T delta = x - mean;
|
||||
// cur_count_ is read as-is; this function does not modify it and does not
// guard against it being zero.
mean += delta / cur_count_;
|
||||
// Deviation of x from the mean after the update.
T delta2 = x - mean;
|
||||
var += delta * delta2;
|
||||
}
|
||||
}
|
||||
template <typename T>
|
||||
CK_TILE_DEVICE static void
|
||||
welford_merge(T& mean_a, T& var_a, int& count_a, T mean_b, T var_b, int count_b)
|
||||
{
|
||||
int count = count_a + count_b;
|
||||
T count_ = type_convert<T>(count);
|
||||
T count_a_ = type_convert<T>(count_a);
|
||||
T count_b_ = type_convert<T>(count_b);
|
||||
T count_b_over_count = count == 0 ? type_convert<T>(0) : count_b_ / count_;
|
||||
|
||||
// [CAUTION] - max_count_ exists to deal with the padding problem.
|
||||
// max_count_ depends on the caller, e.g. naive and splitN welford compute
|
||||
// max_count_ differently.
|
||||
// Starts with cur_count_ at 0 and stores the caller-supplied max_count.
CK_TILE_DEVICE constexpr ThreadWelford(int max_count) : cur_count_(0), max_count_(max_count) {}
|
||||
|
||||
// Accumulates the elements of `x_tensor` into `mean_tensor` / `var_tensor`.
//   x_tensor    - input distributed tensor, indexed by (span 0, span 1).
//   mean_tensor - running-mean tensor, indexed by span 0 only; updated in place.
//   var_tensor  - accumulator tensor, indexed by span 0 only; updated in place.
// The outer sweep runs over span 1 and the inner sweep over span 0. An outer
// step is processed only while cur_count_ < max_count_; later outer steps are
// skipped entirely. cur_count_ is a member and is not reset here, so it
// carries over between calls on the same object.
template <typename XDistributedTensor_,
|
||||
typename MeanDistributedTensor_,
|
||||
typename VarDistributedTensor_>
|
||||
CK_TILE_DEVICE void operator()(const XDistributedTensor_& x_tensor,
|
||||
MeanDistributedTensor_& mean_tensor,
|
||||
VarDistributedTensor_& var_tensor)
|
||||
{
|
||||
constexpr auto I0 = number<0>{};
|
||||
constexpr auto I1 = number<1>{};
|
||||
|
||||
constexpr auto spans = XDistributedTensor_::get_distributed_spans();
|
||||
|
||||
sweep_tile_span(spans[I1], [&](auto dstr_idx_i1) {
|
||||
if(cur_count_ < max_count_)
|
||||
{
|
||||
// One count per outer (span 1) index, taken before the inner sweep so that
// every Update call in that sweep sees the same cur_count_.
++cur_count_;
|
||||
|
||||
sweep_tile_span(spans[I0], [&](auto dstr_idx_i0) {
|
||||
constexpr auto in_dstr_idx = make_tuple(dstr_idx_i0, dstr_idx_i1);
|
||||
constexpr auto out_dstr_idx = make_tuple(dstr_idx_i0);
|
||||
|
||||
// Convert the input element to ComputeDataType before accumulating.
auto x = ck_tile::type_convert<ComputeDataType>(x_tensor[in_dstr_idx]);
|
||||
|
||||
Update(mean_tensor(out_dstr_idx), var_tensor(out_dstr_idx), x);
|
||||
});
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
// Builds and returns a cleared ComputeDataType distributed tensor whose tile
// distribution is derived from XDistributedTensor_'s distribution with
// dimension index 1 passed as the reduce dimension.
// Compilation fails unless XDistributedTensor_::DataType is exactly XDataType.
template <typename XDistributedTensor_>
|
||||
CK_TILE_DEVICE static auto MakeInitialMeanVarDistributedTensor()
|
||||
{
|
||||
static_assert(std::is_same_v<XDataType, typename XDistributedTensor_::DataType>, "wrong!");
|
||||
|
||||
constexpr auto reduce_dims = sequence<1>{};
|
||||
|
||||
// Derive the output distribution from the input tensor's encoding and
// reduce_dims.
constexpr auto dstr =
|
||||
make_static_tile_distribution(detail::make_reduce_tile_distribution_encoding(
|
||||
XDistributedTensor_::get_tile_distribution()
|
||||
.get_static_tile_distribution_encoding(),
|
||||
reduce_dims));
|
||||
|
||||
auto tensor = make_static_distributed_tensor<ComputeDataType>(dstr);
|
||||
// NOTE(review): clear_tile presumably zero-fills the tensor - confirm.
clear_tile(tensor);
|
||||
|
||||
return tensor;
|
||||
}
|
||||
|
||||
// Convenience overload: creates fresh mean and var tensors with
// MakeInitialMeanVarDistributedTensor, accumulates `x_tensor` into them through
// the three-argument operator(), and returns them as a (mean, var) tuple.
template <typename XDistributedTensor_>
|
||||
CK_TILE_DEVICE auto operator()(const XDistributedTensor_& x_tensor)
|
||||
{
|
||||
auto mean_tensor = MakeInitialMeanVarDistributedTensor<XDistributedTensor_>();
|
||||
auto var_tensor = MakeInitialMeanVarDistributedTensor<XDistributedTensor_>();
|
||||
|
||||
// Delegate the accumulation; this advances the member cur_count_.
(*this)(x_tensor, mean_tensor, var_tensor);
|
||||
|
||||
return ck_tile::make_tuple(mean_tensor, var_tensor);
|
||||
}
|
||||
|
||||
int cur_count_;
|
||||
int max_count_;
|
||||
};
|
||||
T delta = mean_b - mean_a;
|
||||
mean_a += delta * count_b_over_count;
|
||||
var_a += var_b + delta * delta * count_a_ * count_b_over_count;
|
||||
count_a = count;
|
||||
}
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
Reference in New Issue
Block a user