[CK_TILE] layernorm support fused-quant/fused-add (#1604)

* add prenorm/postnorm support, refactor using generate.py

* update README

* update README

* fix format

* update some description and fix format

* update format

* format

* use non-raw for loading

* format and update n4096

* dynamic-quant ready

* update readme

* support fused dynamic-quant

* update fused-quant, with smooth

* update README

* update args

* update some based on comment
This commit is contained in:
carlushuang
2024-10-31 14:54:53 +08:00
committed by GitHub
parent 9a8a52130d
commit c3a4800c5f
61 changed files with 1790 additions and 766 deletions

View File

@@ -9,23 +9,29 @@ namespace ck_tile {
// this epilogue just store out a M*N matrix, row major
template <typename AccDataType_, typename ODataType_, bool kPadM_, bool kPadN_>
template <typename AccDataType_,
typename ODataType_,
bool kPadM_,
bool kPadN_,
bool UseRawStore_ = true>
struct Default2DEpilogueProblem
{
using AccDataType = remove_cvref_t<AccDataType_>;
using ODataType = remove_cvref_t<ODataType_>;
static constexpr bool kPadM = kPadM_;
static constexpr bool kPadN = kPadN_;
using AccDataType = remove_cvref_t<AccDataType_>;
using ODataType = remove_cvref_t<ODataType_>;
static constexpr bool kPadM = kPadM_;
static constexpr bool kPadN = kPadN_;
static constexpr bool UseRawStore = UseRawStore_;
};
template <typename Problem_, typename Policy_ = void>
struct Default2DEpilogue
{
using Problem = remove_cvref_t<Problem_>;
using AccDataType = remove_cvref_t<typename Problem::AccDataType>;
using ODataType = remove_cvref_t<typename Problem::ODataType>;
static constexpr bool kPadM = Problem::kPadM;
static constexpr bool kPadN = Problem::kPadN;
using Problem = remove_cvref_t<Problem_>;
using AccDataType = remove_cvref_t<typename Problem::AccDataType>;
using ODataType = remove_cvref_t<typename Problem::ODataType>;
static constexpr bool kPadM = Problem::kPadM;
static constexpr bool kPadN = Problem::kPadN;
static constexpr bool UseRawStore = Problem::UseRawStore;
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() { return 0; }
@@ -36,7 +42,7 @@ struct Default2DEpilogue
{
// TODO: this is ugly
if constexpr(kPadM || kPadN)
if constexpr(UseRawStore && (kPadM || kPadN))
{
store_tile_raw(o_dram_window_tmp, cast_tile<ODataType>(o_acc_tile));
buffer_store_fence();