[CK_TILE] layernorm: support fused-quant/fused-add (#1604)

* add prenorm/postnorm support, refactor using generate.py

* update README

* update README

* fix format

* update some descriptions and fix formatting

* update format

* format

* use non-raw for loading

* format and update n4096

* dynamic-quant ready

* update readme

* support fused dynamic-quant

* update fused-quant, with smooth

* update README

* update args

* update based on review comments
carlushuang
2024-10-31 14:54:53 +08:00
committed by GitHub
parent 9a8a52130d
commit c3a4800c5f
61 changed files with 1790 additions and 766 deletions

@@ -8,20 +8,44 @@
 
 namespace ck_tile {
 
+// Note: for simplicity, each functor only cares about a single M
+struct reference_layernorm2d_default_epilogue
+{
+    template <typename OutDataType, typename AccDataType>
+    void operator()(int m, HostTensor<OutDataType>& o, const HostTensor<AccDataType>& acc)
+    {
+        const int N = acc.mDesc.get_lengths()[1];
+        for(int n = 0; n < N; ++n)
+        {
+            o(m, n) = ck_tile::type_convert<OutDataType>(acc(m, n));
+        }
+    }
+
+    template <typename OutDataType, typename AccDataType>
+    auto operator()(int m, const HostTensor<AccDataType>& acc)
+    {
+        HostTensor<OutDataType> o(acc.get_lengths(), acc.get_strides());
+        operator()(m, o, acc);
+        return o;
+    }
+};
+
 template <typename XDataType,
           typename GammaDataType,
           typename BetaDataType,
           typename ComputeDataType,
           typename YDataType,
           typename MeanDataType,
-          typename InvStdDataType>
+          typename InvStdDataType,
+          typename Epilogue = reference_layernorm2d_default_epilogue>
 void reference_layernorm2d_fwd(const HostTensor<XDataType>& x_m_n,
                                const HostTensor<GammaDataType>& gamma_n,
                                const HostTensor<BetaDataType>& beta_n,
                                HostTensor<YDataType>& y_m_n,
                                HostTensor<MeanDataType>& mean_m,
                                HostTensor<InvStdDataType>& invStd_m,
-                               ComputeDataType epsilon)
+                               ComputeDataType epsilon,
+                               Epilogue epilogue_functor = {})
 {
     auto layernorm2d_fwd_func = [&](auto m) {
         const int N = x_m_n.mDesc.get_lengths()[1];
@@ -51,16 +75,19 @@ void reference_layernorm2d_fwd(const HostTensor<XDataType>& x_m_n,
         if constexpr(!std::is_same_v<InvStdDataType, ck_tile::null_type>)
             invStd_m(m) = ck_tile::type_convert<InvStdDataType>(divisor);
 
+        HostTensor<ComputeDataType> acc(x_m_n.get_lengths(), x_m_n.get_strides());
         for(int n = 0; n < N; ++n)
         {
             ComputeDataType x     = ck_tile::type_convert<ComputeDataType>(x_m_n(m, n));
             ComputeDataType gamma = ck_tile::type_convert<ComputeDataType>(gamma_n(n));
             ComputeDataType beta  = ck_tile::type_convert<ComputeDataType>(beta_n(n));
-            auto y = (x - mean) * divisor;
-            y      = y * gamma + beta;
+            auto a_ = (x - mean) * divisor;
+            a_      = a_ * gamma + beta;
 
-            y_m_n(m, n) = ck_tile::type_convert<YDataType>(y);
+            acc(m, n) = a_;
         }
+
+        epilogue_functor(m, y_m_n, acc);
     };
 
     make_ParallelTensorFunctor(layernorm2d_fwd_func,
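
The refactor above turns the tail of the reference kernel into a pluggable epilogue: each normalized row is staged in the acc tensor and then handed to epilogue_functor, which is how fused behaviors such as the dynamic-quant and smooth-quant paths mentioned in the commit message can share one reference loop. As a rough illustration of the hook, below is a minimal per-row dynamic-quant epilogue; the struct name, the int8 target, and the absmax/127 scaling policy are assumptions made for this sketch, not the functors this commit actually ships.

#include <algorithm> // std::max
#include <cmath>     // std::abs, std::round

// Hypothetical per-row (per-token) dynamic-quant epilogue. It matches the
// operator()(m, o, acc) shape of reference_layernorm2d_default_epilogue, so
// it drops into the epilogue_functor(m, y_m_n, acc) call site unchanged.
template <typename ScaleDataType>
struct reference_dynamic_quant_epilogue_sketch
{
    ck_tile::HostTensor<ScaleDataType>* scale_m; // per-row scales, written as a side output

    template <typename OutDataType, typename AccDataType>
    void operator()(int m,
                    ck_tile::HostTensor<OutDataType>& o,
                    const ck_tile::HostTensor<AccDataType>& acc)
    {
        const int N = acc.mDesc.get_lengths()[1];

        // "dynamic" quant: the scale is derived from this row's absmax at run time
        AccDataType absmax = 0;
        for(int n = 0; n < N; ++n)
            absmax = std::max(absmax, std::abs(acc(m, n)));

        const AccDataType scale = absmax / AccDataType(127); // int8 range assumed
        (*scale_m)(m)           = ck_tile::type_convert<ScaleDataType>(scale);

        // rows run in parallel, but each m touches disjoint elements, so this is race-free
        for(int n = 0; n < N; ++n)
        {
            AccDataType q = scale > AccDataType(0) ? std::round(acc(m, n) / scale)
                                                   : AccDataType(0);
            o(m, n)       = ck_tile::type_convert<OutDataType>(q);
        }
    }
};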
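
Host-side wiring could then look like the sketch below; the include path, problem sizes, and epsilon are placeholders, and the seven explicit template arguments follow the XDataType through InvStdDataType order of the signature above, while Epilogue is deduced from the last function argument.

#include "ck_tile/host.hpp" // assumed aggregate host header; adjust to your checkout

int main()
{
    const int M = 4, N = 4096;

    ck_tile::HostTensor<float> x({M, N}), gamma({N}), beta({N});
    ck_tile::HostTensor<float> mean({M}), invstd({M}), scale({M});
    ck_tile::HostTensor<int8_t> y_q({M, N}); // quantized output takes YDataType's slot

    // ... fill x / gamma / beta with real data here ...

    reference_dynamic_quant_epilogue_sketch<float> epilogue{&scale};

    // XDataType, GammaDataType, BetaDataType, ComputeDataType, YDataType,
    // MeanDataType, InvStdDataType; Epilogue is deduced from `epilogue`
    ck_tile::reference_layernorm2d_fwd<float, float, float, float, int8_t, float, float>(
        x, gamma, beta, y_q, mean, invstd, 1e-5f, epilogue);

    return 0;
}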