mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-03 21:21:22 +00:00
[CK_TILE] layernorm support fused-quant/fused-add (#1604)
* add prenorm/postnorm support, refactor using generate.py * update README * update README * fix format * update some description and fix format * update format * format * use non-raw for loading * format and update n4096 * dynamic-quant ready * update readme * support fused dynamic-quant * update fused-quant, with smooth * update README * update args * update some based on comment
This commit is contained in:
@@ -9,23 +9,29 @@ namespace ck_tile {
|
||||
|
||||
// this epilogue just store out a M*N matrix, row major
|
||||
|
||||
template <typename AccDataType_, typename ODataType_, bool kPadM_, bool kPadN_>
|
||||
template <typename AccDataType_,
|
||||
typename ODataType_,
|
||||
bool kPadM_,
|
||||
bool kPadN_,
|
||||
bool UseRawStore_ = true>
|
||||
struct Default2DEpilogueProblem
|
||||
{
|
||||
using AccDataType = remove_cvref_t<AccDataType_>;
|
||||
using ODataType = remove_cvref_t<ODataType_>;
|
||||
static constexpr bool kPadM = kPadM_;
|
||||
static constexpr bool kPadN = kPadN_;
|
||||
using AccDataType = remove_cvref_t<AccDataType_>;
|
||||
using ODataType = remove_cvref_t<ODataType_>;
|
||||
static constexpr bool kPadM = kPadM_;
|
||||
static constexpr bool kPadN = kPadN_;
|
||||
static constexpr bool UseRawStore = UseRawStore_;
|
||||
};
|
||||
|
||||
template <typename Problem_, typename Policy_ = void>
|
||||
struct Default2DEpilogue
|
||||
{
|
||||
using Problem = remove_cvref_t<Problem_>;
|
||||
using AccDataType = remove_cvref_t<typename Problem::AccDataType>;
|
||||
using ODataType = remove_cvref_t<typename Problem::ODataType>;
|
||||
static constexpr bool kPadM = Problem::kPadM;
|
||||
static constexpr bool kPadN = Problem::kPadN;
|
||||
using Problem = remove_cvref_t<Problem_>;
|
||||
using AccDataType = remove_cvref_t<typename Problem::AccDataType>;
|
||||
using ODataType = remove_cvref_t<typename Problem::ODataType>;
|
||||
static constexpr bool kPadM = Problem::kPadM;
|
||||
static constexpr bool kPadN = Problem::kPadN;
|
||||
static constexpr bool UseRawStore = Problem::UseRawStore;
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() { return 0; }
|
||||
|
||||
@@ -36,7 +42,7 @@ struct Default2DEpilogue
|
||||
{
|
||||
|
||||
// TODO: this is ugly
|
||||
if constexpr(kPadM || kPadN)
|
||||
if constexpr(UseRawStore && (kPadM || kPadN))
|
||||
{
|
||||
store_tile_raw(o_dram_window_tmp, cast_tile<ODataType>(o_acc_tile));
|
||||
buffer_store_fence();
|
||||
|
||||
Reference in New Issue
Block a user