mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 14:59:17 +00:00
update layernorm (#1570)
* port layernorm
* change warp_welford.hpp
* Update warpshuffle
* 1. Add save mean and save std back
2. Move construction of tensor_view and tile_window to operator()
* refine welford max count calculation
* unify layernorm api
* Rename file
* Remove save mean and inv std
* Revert "refine welford max count calculation"
This reverts commit 022365802b.
* Fix order of parameter
* refine welford max count calculation again
* Remove fp32 instances
* Fix bug of padding
* refactor api
* Support bf16
* Extract common function
* Refine arg of operator()
* Add kMThreadPerBlock to template parameter
* clang format
* Refine variable name
* Refine file name
* remove redundant line
* refactor layernorm2d pipeline and add block-per-block utility
* fix name
* rename more
* add more block-per-tile instance
* remove duplicated define
* update instance for 2048, 1024 case
* support up to 2048 now
* opt loading
* add n1536
* Add two pass pipeline
* format
* Fix incorrect type
* parallel compilation
* Use smaller N
* fix 2p pass
* Support Repeat_M in distribution
* Refine nameing
* Add reduce example
---------
Co-authored-by: letaoqin <letaoqin@amd.com>
Co-authored-by: aska-0096 <haocwang@amd.com>
Co-authored-by: rocking <ChunYu.Lai@amd.com>
Co-authored-by: carlushuang <carlus.huang@amd.com>
This commit is contained in:
@@ -1111,4 +1111,126 @@ CK_TILE_HOST_DEVICE constexpr auto generate_array(F&& f, number<N>)
|
||||
typename arithmetic_sequence_gen<0, N, 1>::type{});
|
||||
}
|
||||
|
||||
namespace impl {
|
||||
template <typename, typename, typename, index_t>
|
||||
struct reverse_slice_sequence_impl;
|
||||
|
||||
template <index_t x,
|
||||
index_t... xs,
|
||||
index_t m,
|
||||
index_t... ms,
|
||||
index_t id,
|
||||
index_t... ids,
|
||||
index_t SliceSize>
|
||||
struct reverse_slice_sequence_impl<sequence<x, xs...>,
|
||||
sequence<m, ms...>,
|
||||
sequence<id, ids...>,
|
||||
SliceSize>
|
||||
{
|
||||
using old_scan =
|
||||
reverse_slice_sequence_impl<sequence<xs...>, sequence<ms...>, sequence<ids...>, SliceSize>;
|
||||
|
||||
static constexpr auto slice_size = old_scan::remaining_slice_sizes::front().value;
|
||||
static constexpr auto slice_length =
|
||||
std::conditional_t<m, number<gcd(x, slice_size)>, number<x>>::value;
|
||||
|
||||
using dim_lengths =
|
||||
typename sequence_merge<sequence<slice_length>, typename old_scan::dim_lengths>::type;
|
||||
using dim_slices =
|
||||
typename sequence_merge<sequence<x / slice_length>, typename old_scan::dim_slices>::type;
|
||||
using remaining_slice_sizes = typename sequence_merge<
|
||||
std::conditional_t<m, sequence<slice_size / slice_length>, sequence<slice_size>>,
|
||||
typename old_scan::remaining_slice_sizes>::type;
|
||||
|
||||
// the first idx that sliced length not equal to original length
|
||||
static constexpr index_t _flag =
|
||||
slice_length != x && remaining_slice_sizes{}.front().value == 1;
|
||||
static constexpr index_t _split_flag = std::conditional_t<m, number<_flag>, number<0>>::value;
|
||||
static constexpr index_t _split_idx =
|
||||
std::conditional_t<_split_flag, number<id>, number<0>>::value;
|
||||
|
||||
static constexpr index_t split_flag = _split_flag || old_scan::split_flag;
|
||||
static constexpr index_t split_idx = std::
|
||||
conditional_t<old_scan::split_flag, number<old_scan::split_idx>, number<_split_idx>>::value;
|
||||
};
|
||||
|
||||
template <index_t x, index_t m, index_t id, index_t SliceSize>
|
||||
struct reverse_slice_sequence_impl<sequence<x>, sequence<m>, sequence<id>, SliceSize>
|
||||
{
|
||||
static constexpr auto slice_size = SliceSize;
|
||||
static constexpr auto slice_length =
|
||||
std::conditional_t<m, number<gcd(x, slice_size)>, number<x>>::value;
|
||||
|
||||
using dim_lengths = sequence<slice_length>;
|
||||
using dim_slices = sequence<x / slice_length>;
|
||||
using remaining_slice_sizes =
|
||||
std::conditional_t<m, sequence<slice_size / slice_length>, sequence<slice_size>>;
|
||||
|
||||
// the first idx that sliced length not equal to original length
|
||||
static constexpr index_t _flag =
|
||||
slice_length != x && remaining_slice_sizes{}.front().value == 1;
|
||||
static constexpr index_t split_flag = std::conditional_t<m, number<_flag>, number<0>>::value;
|
||||
static constexpr index_t split_idx =
|
||||
std::conditional_t<split_flag, number<id>, number<0>>::value;
|
||||
};
|
||||
} // namespace impl
|
||||
|
||||
// clang-format off
|
||||
// input a sequence(with optional mask), and the SliceSize : size per slice
|
||||
// output the sequence each slice, and number of slices
|
||||
//
|
||||
// e.g. <2, 1, 4, 2>, 8 -> lengths:<1, 1, 4, 2> , nums: <2, 1, 1, 1> : 2 slices , slice_idx: 0
|
||||
// <4, 2, 4, 1, 2>, 4 -> lengths:<1, 1, 2, 1, 2> , nums: <4, 2, 2, 1, 1> : 16 slices , slice_idx: 2
|
||||
// <4, 2, 4, 1, 6>, 4 -> lengths:<1, 1, 2, 1, 2> , nums: <4, 2, 2, 1, 3> : 48 slices , slice_idx: 2
|
||||
// <4, 2, 5, 1, 2>, 10 -> lengths:<1, 1, 5, 1, 2> , nums: <4, 2, 1, 1, 1> : 8 slices , slice_idx: 1
|
||||
//
|
||||
// <4, 2, 8>, 64 -> lengths:<4, 2, 8> , nums: <1, 1, 1> : 1 slices , slice_idx: 0
|
||||
// <4, 2, 8>, 32 -> lengths:<2, 2, 8> , nums: <2, 1, 1> : 2 slices , slice_idx: 0
|
||||
// <4, 2, 8>, 16 -> lengths:<1, 2, 8> , nums: <4, 1, 1> : 4 slices , slice_idx: 0
|
||||
// <4, 2, 8>, 8 -> lengths:<1, 1, 8> , nums: <4, 2, 1> : 8 slices , slice_idx: 1
|
||||
// <4, 2, 8>, 4 -> lengths:<1, 1, 4> , nums: <4, 2, 2> : 16 slices , slice_idx: 2
|
||||
// <4, 2, 8>, 2 -> lengths:<1, 1, 2> , nums: <4, 2, 4> : 32 slices , slice_idx: 2
|
||||
// <4, 2, 8>, 1 -> lengths:<1, 1, 1> , nums: <4, 2, 8> : 64 slices , slice_idx: 2
|
||||
//
|
||||
// <4, 2, 1, 4, 2> / 4 ->
|
||||
// mask:<1, 1, 1, 0, 1>, -> lengths:<1, 2, 1, 4, 2> , nums: <4, 1, 1, 1, 1> : 4 slices , slice_idx: 0
|
||||
//
|
||||
// return tuple<slice_lengths, slice_nums, slice_index>
// slice_index: scanning right -> left over the dims whose mask is non-zero, the index of the first
//              dim whose sliced length differs from its original length with no slice size left
//              over (remaining slice size == 1), i.e. where the split starts; 0 if no dim qualifies
|
||||
// clang-format on
|
||||
template <typename Seq,
|
||||
index_t SliceSize,
|
||||
typename Mask = typename uniform_sequence_gen<Seq::size(), 1>::type>
|
||||
constexpr auto reverse_slice_sequence(Seq,
|
||||
number<SliceSize>,
|
||||
Mask = typename uniform_sequence_gen<Seq::size(), 1>::type{})
|
||||
{
|
||||
static_assert(Seq::size() == Mask::size());
|
||||
using sliced_type =
|
||||
impl::reverse_slice_sequence_impl<Seq,
|
||||
Mask,
|
||||
typename arithmetic_sequence_gen<0, Seq::size(), 1>::type,
|
||||
SliceSize>;
|
||||
static_assert(sliced_type::remaining_slice_sizes::front().value == 1,
|
||||
"can not evenly divide this sequence, please check");
|
||||
return make_tuple(typename sliced_type::dim_lengths{},
|
||||
typename sliced_type::dim_slices{},
|
||||
number<sliced_type::split_idx>{});
|
||||
}
|
||||
|
||||
template <typename Seq,
|
||||
index_t SliceSize,
|
||||
typename Mask = typename uniform_sequence_gen<Seq::size(), 1>::type>
|
||||
constexpr auto slice_sequence(Seq,
|
||||
number<SliceSize>,
|
||||
Mask = typename uniform_sequence_gen<Seq::size(), 1>::type{})
|
||||
{
|
||||
constexpr auto r =
|
||||
reverse_slice_sequence(Seq{}.reverse(), number<SliceSize>{}, Mask{}.reverse());
|
||||
return make_tuple(r[number<0>{}].reverse(),
|
||||
r[number<1>{}].reverse(),
|
||||
number<Seq::size() - r[number<2>{}] - 1>{});
|
||||
}
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -488,6 +488,26 @@ CK_TILE_HOST_DEVICE constexpr auto transform_tuples(F f, const X& x, const Y& y,
|
||||
f, x, y, z, typename arithmetic_sequence_gen<0, X::size(), 1>::type{});
|
||||
}
|
||||
|
||||
namespace detail {
|
||||
|
||||
template <typename F, typename X, index_t... Is>
|
||||
CK_TILE_HOST_DEVICE constexpr auto embed_tuples_impl(F f, const X& x, sequence<Is...>)
|
||||
{
|
||||
return concat_tuple(f(x.at(number<Is>{}))...);
|
||||
}
|
||||
|
||||
} // namespace detail
|
||||
|
||||
// make sure F return at least a tuple
|
||||
// e.g. x : tuple<X, Y>, f will return tuple<Z, W>
|
||||
// this function will return
|
||||
template <typename F, typename X>
|
||||
CK_TILE_HOST_DEVICE constexpr auto embed_tuples(F f, const X& x)
|
||||
{
|
||||
return detail::embed_tuples_impl(
|
||||
f, x, typename arithmetic_sequence_gen<0, X::size(), 1>::type{});
|
||||
}
|
||||
|
||||
// By default unroll to the flatten
|
||||
template <index_t Depth = 0, index_t MaxDepth = -1>
|
||||
CK_TILE_HOST_DEVICE constexpr auto unroll_nested_tuple(const tuple<>& t)
|
||||
|
||||
Reference in New Issue
Block a user