[CK_TILE] Add Various Fusion Functions to RMSNorm (#1802)

* Add shortcut to RMSNorm

* Modify test for adding shortcut for RMSNorm

* Add fused parameter into tests

* 1. Add YDataType. 2. rmsnorm2d_fwd_traits_ from rmsnorm2d_fwd.hpp to rmsnorm2d_fwd_api.cpp and rmsnorm2d_fwd_instance_common.hpp

* 1. Supports various stride and percisions.

* Add support of Epilogue

* Add fuse and epilogue support to rmsnorm ref

* Modify rmsnorm example

* Refactor tests/examples

* Bug fix for newly added tests/examples

* Bug fix for new tests 2

* Modify smoke test scripts

remove dbg code

* Supports non-smooth dyanmic quant

* Update Rmsnorm2dFwd::GetName()

* rename xscale and prec_sx to smoothscale and prec_sm

Bug fix after rename

Remove files

* change example_rmsnorm2d_fwd.cpp

* update performance calculator

* Fix issue in two-pass when fuse add is enabled

* Remove comment of beta

---------

Co-authored-by: rocking <ChunYu.Lai@amd.com>
This commit is contained in:
ruanjm
2025-01-15 10:23:48 +08:00
committed by GitHub
parent c0b90f130f
commit 04dd314883
58 changed files with 1823 additions and 1045 deletions

View File

@@ -8,16 +8,40 @@
namespace ck_tile {
// Note: for simplicity, each functor only care about single M
struct reference_rmsnorm2d_default_epilogue
{
template <typename OutDataType, typename AccDataType>
void operator()(int m, HostTensor<OutDataType>& o, const HostTensor<AccDataType>& acc)
{
const int N = acc.mDesc.get_lengths()[1];
for(int n = 0; n < N; ++n)
{
o(m, n) = ck_tile::type_convert<OutDataType>(acc(m, n));
}
}
template <typename OutDataType, typename AccDataType>
auto operator()(int m, const HostTensor<AccDataType>& acc)
{
HostTensor<OutDataType> o(acc.get_lengths(), acc.get_strides());
operator()(m, o, acc);
return o;
}
};
template <typename XDataType,
typename GammaDataType,
typename ComputeDataType,
typename YDataType,
typename InvRmsDataType>
typename InvRmsDataType,
typename Epilogue = reference_rmsnorm2d_default_epilogue>
void reference_rmsnorm2d_fwd(const HostTensor<XDataType>& x_m_n,
const HostTensor<GammaDataType>& gamma_n,
HostTensor<YDataType>& y_m_n,
HostTensor<InvRmsDataType>& invRms_m,
ComputeDataType epsilon)
ComputeDataType epsilon,
Epilogue epilogue_functor = {})
{
auto rmsnorm2d_fwd_func = [&](auto m) {
const int N = x_m_n.mDesc.get_lengths()[1];
@@ -37,13 +61,15 @@ void reference_rmsnorm2d_fwd(const HostTensor<XDataType>& x_m_n,
if constexpr(!std::is_same_v<InvRmsDataType, ck_tile::null_type>)
invRms_m(m) = ck_tile::type_convert<InvRmsDataType>(divisor);
HostTensor<ComputeDataType> acc(x_m_n.get_lengths(), x_m_n.get_strides());
for(int n = 0; n < N; ++n)
{
ComputeDataType x = ck_tile::type_convert<ComputeDataType>(x_m_n(m, n));
ComputeDataType gamma = ck_tile::type_convert<ComputeDataType>(gamma_n(n));
auto y = x * divisor * gamma;
y_m_n(m, n) = ck_tile::type_convert<YDataType>(y);
acc(m, n) = x * divisor * gamma;
}
epilogue_functor(m, y_m_n, acc);
};
make_ParallelTensorFunctor(rmsnorm2d_fwd_func, invRms_m.mDesc.get_lengths()[0])(