mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-14 10:09:41 +00:00
Enable bias feature that adds bias before adding the residual (for rtpllm project) (#1741)
* 1. Enable bias feature that adds bias before adding the residual; 2. Change block size from 128 to 64 when m < 64 in fp16
* delete comment
* 1. Remove fmha change; 2. Change buffer name from bias to xbias
* Now bias can be used independently from fadd
* change kbias to kxbias
---------
Co-authored-by: feli <felix.li@amd.com>
[ROCm/composable_kernel commit: d5c8a334ca]
This commit is contained in:
@@ -23,6 +23,10 @@ def get_if_str(idx, total, lase_else = True):
|
||||
else:
|
||||
return 'else if'
|
||||
|
||||
XBIAS_ENUM_STR_MAP = [
|
||||
'no',
|
||||
'xbias'] # pre-norm add bias
|
||||
|
||||
FUSED_ADD_ENUM_STR_MAP = [
|
||||
'no',
|
||||
'pras', # pre-norm
|
||||
@@ -60,6 +64,7 @@ template <typename XDataType_,
|
||||
bool kFastFDiv_,
|
||||
bool kWelford_,
|
||||
bool kTwoPass_,
|
||||
ck_tile::index_t kXbias_ = 0,
|
||||
ck_tile::index_t kFusedAdd_ = 0,
|
||||
ck_tile::index_t kFusedQuant_ = 0>
|
||||
struct layernorm2d_fwd_traits_
|
||||
@@ -123,6 +128,7 @@ struct layernorm2d_fwd_traits_
|
||||
static constexpr bool kFastFDiv = kFastFDiv_;
|
||||
static constexpr bool kWelford = kWelford_;
|
||||
static constexpr bool kTwoPass = kTwoPass_;
|
||||
static constexpr ck_tile::index_t kXbias = kXbias_;
|
||||
static constexpr ck_tile::index_t kFusedAdd = kFusedAdd_;
|
||||
static constexpr ck_tile::index_t kFusedQuant = kFusedQuant_;
|
||||
};
|
||||
@@ -141,6 +147,7 @@ template <typename XDataType_,
|
||||
bool kFastFDiv_,
|
||||
bool kWelford_,
|
||||
bool kTwoPass_,
|
||||
int kXbias_,
|
||||
int kFusedAdd_,
|
||||
int kFusedQuant_>
|
||||
using traits_ = layernorm2d_fwd_traits_<XDataType_,
|
||||
@@ -157,6 +164,7 @@ using traits_ = layernorm2d_fwd_traits_<XDataType_,
|
||||
kFastFDiv_,
|
||||
kWelford_,
|
||||
kTwoPass_,
|
||||
kXbias_,
|
||||
kFusedAdd_,
|
||||
kFusedQuant_>;
|
||||
"""
|
||||
@@ -190,10 +198,12 @@ float layernorm2d_fwd_(const S& s, A a)
|
||||
Traits_::kFastFDiv,
|
||||
Traits_::kWelford,
|
||||
Traits_::kTwoPass,
|
||||
static_cast<ck_tile::Layernorm2dXBiasEnum>(Traits_::kXbias),
|
||||
static_cast<ck_tile::Layernorm2dFusedAddEnum>(Traits_::kFusedAdd),
|
||||
static_cast<ck_tile::Layernorm2dFusedQuantEnum>(Traits_::kFusedQuant)>;
|
||||
using PipelineProblem = ck_tile::Layernorm2dFwdPipelineProblem<
|
||||
typename LayerNormTypeConfig<XDataType, YDataType, XScaleDataType, YScaleDataType>::XDataType,
|
||||
typename LayerNormTypeConfig<XDataType, YDataType, XScaleDataType, YScaleDataType>::XBiasDataType,
|
||||
typename LayerNormTypeConfig<XDataType, YDataType, XScaleDataType, YScaleDataType>::GammaDataType,
|
||||
typename LayerNormTypeConfig<XDataType, YDataType, XScaleDataType, YScaleDataType>::BetaDataType,
|
||||
typename LayerNormTypeConfig<XDataType, YDataType, XScaleDataType, YScaleDataType>::ComputeDataType,
|
||||
@@ -280,7 +290,7 @@ float layernorm2d_fwd(layernorm2d_fwd_traits t,
|
||||
#include "layernorm2d_fwd_api_common.hpp"
|
||||
|
||||
// clang-format off
|
||||
// prec_i prec_o prec_sy rm rn tm tn vn pd mv rpcf welford 2p add sweep
|
||||
// prec_i prec_o prec_sy rm rn tm tn vn pd mv rpcf welford 2p xbias add sweep
|
||||
{F_instance_def}
|
||||
// clang-format on
|
||||
|
||||
@@ -290,6 +300,10 @@ float layernorm2d_fwd(layernorm2d_fwd_traits t,
|
||||
self.working_path = working_path
|
||||
self.kernel_filter = kernel_filter
|
||||
|
||||
class k_xbias_enum(IntEnum):
|
||||
F_NO_XBIAS = 0
|
||||
F_ADD_XBIAS = 1
|
||||
|
||||
class k_fuesd_add_enum(IntEnum):
|
||||
F_NO_ADD = 0
|
||||
F_PRE_ADD = 1
|
||||
@@ -305,6 +319,7 @@ float layernorm2d_fwd(layernorm2d_fwd_traits t,
|
||||
F_kPadN : bool
|
||||
F_kSaveMeanInvStd : bool
|
||||
F_kTwoPass : bool
|
||||
F_kXbias : Any #: layernorm_fwd_codegen.k_xbias_enum
|
||||
F_kFusedAdd : Any #: layernorm_fwd_codegen.k_fuesd_add_enum
|
||||
F_kFusedQuant : Any #: layernorm_fwd_codegen.k_fused_sweep_enum
|
||||
|
||||
@@ -321,6 +336,7 @@ float layernorm2d_fwd(layernorm2d_fwd_traits t,
|
||||
@dataclass
|
||||
class k_problem:
|
||||
F_XDataType : str
|
||||
F_XBiasDataType : str
|
||||
F_GammaDataType : str
|
||||
F_BetaDataType : str
|
||||
F_ComputeDataType : str
|
||||
@@ -370,6 +386,7 @@ float layernorm2d_fwd(layernorm2d_fwd_traits t,
|
||||
F_kFastFDiv_ : bool
|
||||
F_kWelford_ : bool
|
||||
F_kTwoPass_ : bool
|
||||
F_kXbias_ : int
|
||||
F_kFusedAdd : int
|
||||
F_kFusedQuant : int
|
||||
|
||||
@@ -377,7 +394,7 @@ float layernorm2d_fwd(layernorm2d_fwd_traits t,
|
||||
def trait_name(self) ->str:
|
||||
t_ = f'{DATA_TYPE_MAP[self.F_XDataType]}, {DATA_TYPE_MAP[self.F_YDataType]}, {DATA_TYPE_MAP[self.F_XScaleDataType]}, {DATA_TYPE_MAP[self.F_YScaleDataType]}, {self.F_Repeat_M:2}, {self.F_Repeat_N:2}, {self.F_ThreadPerBlock_M:2}, {self.F_ThreadPerBlock_N:4}'
|
||||
t_ += f', {self.F_Vector_N:2}, {BOOL_MAP(self.F_kPadN):5}, {BOOL_MAP(self.F_kSaveMeanInvStd_):5}, {BOOL_MAP(self.F_kFastFDiv_):5}, {BOOL_MAP(self.F_kWelford_):5}'
|
||||
t_ += f', {BOOL_MAP(self.F_kTwoPass_):5}, {self.F_kFusedAdd:4}, {self.F_kFusedQuant:4}'
|
||||
t_ += f', {BOOL_MAP(self.F_kTwoPass_):5}, {self.F_kXbias:4}, {self.F_kFusedAdd:4}, {self.F_kFusedQuant:4}'
|
||||
return t_
|
||||
|
||||
# string when calling this kernel
|
||||
@@ -395,6 +412,7 @@ float layernorm2d_fwd(layernorm2d_fwd_traits t,
|
||||
class h_instance:
|
||||
F_DataTypePair : str
|
||||
F_N : str
|
||||
F_xbias : int
|
||||
F_add : int
|
||||
F_sweep : int
|
||||
instance_list : List[Any] # List[h_traits]
|
||||
@@ -404,6 +422,8 @@ float layernorm2d_fwd(layernorm2d_fwd_traits t,
|
||||
prec_i, prec_o = self.F_DataTypePair.split(',')
|
||||
dtype_str = f'{prec_i}' if prec_i == prec_o else f'{prec_i}_{prec_o}'
|
||||
nnn = f'layernorm2d_fwd_{dtype_str}_n{self.F_N}'
|
||||
if self.F_xbias != 0:
|
||||
nnn = nnn + '_' + XBIAS_ENUM_STR_MAP[self.F_xbias]
|
||||
if self.F_add != 0:
|
||||
nnn = nnn + '_' + FUSED_ADD_ENUM_STR_MAP[self.F_add]
|
||||
if self.F_sweep != 0:
|
||||
@@ -462,8 +482,8 @@ float layernorm2d_fwd(layernorm2d_fwd_traits t,
|
||||
elif ins.F_kFusedQuant == 2:
|
||||
_sweep_cond = 't.fused_quant == {f_fused_sweep} && (t.prec_sy == \"{f_sy_type}\")'.format(
|
||||
f_fused_sweep = ins.F_kFusedQuant, f_sy_type=ins.F_YScaleDataType)
|
||||
_cond = '((a.n % {f_vec_n} == 0) && (t.fused_add == {f_fused_add}) && ({f_sweep_cond}))'.format(
|
||||
f_vec_n = ins.F_Vector_N, f_fused_add = ins.F_kFusedAdd,
|
||||
_cond = '((a.n % {f_vec_n} == 0) && (t.xbias == {f_xbias}) && (t.fused_add == {f_fused_add}) && ({f_sweep_cond}))'.format(
|
||||
f_vec_n = ins.F_Vector_N, f_xbias = ins.F_kXbias, f_fused_add = ins.F_kFusedAdd,
|
||||
f_sweep_cond = _sweep_cond)
|
||||
inner_str += self.API_INNER_CASE.format(F_if = get_if_str(idx_in_n, len_in_n, False),
|
||||
F_VEC_COND = _cond, F_instance_func=ins.call_name)
|
||||
@@ -494,62 +514,63 @@ float layernorm2d_fwd(layernorm2d_fwd_traits t,
|
||||
types_16bit = ('int16', 'fp16', 'bf16')
|
||||
#fused_add_list = [0, 1, 2]
|
||||
#fused_sweep_list = [0, 1, 2] # NOTE: only single pass can use fused dynamic quant
|
||||
xbias_list = [0, 1]
|
||||
fused_add_list = [0, 1]
|
||||
fused_sweep_list = [0, 1] # NOTE: only single pass can use fused dynamic quant
|
||||
# rm rn tm tn vn pd mv fdiv welford 2p add sweep
|
||||
h_trait_dict = {'64' : [ h_traits('x', 'y', 'xs', 'ys', 1, 1, 8, 8, 8, True, False, True, True, False, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 1, 1, 4, 16, 4, True, False, True, True, False, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 1, 1, 4, 64, 1, True, False, True, True, False, 0, 0)],
|
||||
'128' : [ h_traits('x', 'y', 'xs', 'ys', 1, 1, 4, 16, 8, True, False, True, True, False, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 1, 1, 4, 64, 2, True, False, True, True, False, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 1, 2, 4, 64, 1, True, False, True, True, False, 0, 0)],
|
||||
'256' : [ h_traits('x', 'y', 'xs', 'ys', 1, 1, 4, 64, 4, True, False, True, True, False, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 1, 2, 4, 64, 2, True, False, True, True, False, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 1, 4, 4, 64, 1, True, False, True, True, False, 0, 0)],
|
||||
'512' : [ h_traits('x', 'y', 'xs', 'ys', 1, 1, 4, 64, 8, True, False, True, True, False, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 1, 2, 4, 64, 4, True, False, True, True, False, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 1, 4, 4, 64, 2, True, False, True, True, False, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 1, 8, 4, 64, 1, True, False, True, True, False, 0, 0)],
|
||||
'768' : [ h_traits('x', 'y', 'xs', 'ys', 1, 3, 4, 64, 4, True, False, True, True, False, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 1, 6, 4, 64, 2, True, False, True, True, False, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 1, 12, 4, 64, 1, True, False, True, True, False, 0, 0)],
|
||||
'1024' :[ h_traits('x', 'y', 'xs', 'ys', 1, 1, 2, 128, 8, True, False, True, True, False, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 1, 2, 2, 128, 4, True, False, True, True, False, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 1, 4, 2, 128, 2, True, False, True, True, False, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 1, 4, 1, 256, 1, True, False, True, True, False, 0, 0)],
|
||||
'1536' :[ h_traits('x', 'y', 'xs', 'ys', 1, 3, 4, 64, 8, True, False, True, True, False, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 1, 3, 2, 128, 4, True, False, True, True, False, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 1, 3, 1, 256, 2, True, False, True, True, False, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 1, 6, 1, 256, 1, True, False, True, True, False, 0, 0)],
|
||||
'2048' :[ h_traits('x', 'y', 'xs', 'ys', 1, 1, 1, 256, 8, True, False, True, True, False, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 1, 2, 1, 256, 4, True, False, True, True, False, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 1, 4, 1, 256, 2, True, False, True, True, False, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 1, 8, 1, 256, 1, True, False, True, True, False, 0, 0)],
|
||||
'3072' :[ h_traits('x', 'y', 'xs', 'ys', 1, 3, 1, 128, 8, True, False, True, True, False, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 1, 3, 1, 256, 4, True, False, True, True, False, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 1, 6, 1, 256, 2, True, False, True, True, False, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 1, 3, 1,1024, 1, True, False, True, True, False, 0, 0)],
|
||||
'4096' :[ h_traits('x', 'y', 'xs', 'ys', 1, 2, 1, 256, 8, True, False, True, True, False, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 1, 4, 1, 256, 4, True, False, True, True, False, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 1, 2, 1,1024, 2, True, False, True, True, False, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 1, 4, 1,1024, 1, True, False, True, True, False, 0, 0)],
|
||||
'6144' :[ h_traits('x', 'y', 'xs', 'ys', 1, 3, 1, 256, 8, True, False, True, True, False, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 1, 3, 1, 512, 4, True, False, True, True, False, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 1, 3, 1,1024, 2, True, False, True, True, False, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 1, 6, 1,1024, 1, True, False, True, True, False, 0, 0)],
|
||||
'8192' :[ h_traits('x', 'y', 'xs', 'ys', 1, 4, 1, 256, 8, True, False, True, True, False, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 1, 4, 1, 512, 4, True, False, True, True, False, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 1, 4, 1,1024, 2, True, False, True, True, False, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 1, 8, 1,1024, 1, True, False, True, True, False, 0, 0)],
|
||||
'big' :[ h_traits('x', 'y', 'xs', 'ys', 1, 2, 1, 256, 8, True, False, True, True, True, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 1, 4, 1, 256, 4, True, False, True, True, True, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 1, 2, 1,1024, 2, True, False, True, True, True, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 1, 4, 1,1024, 1, True, False, True, True, True, 0, 0)]}
|
||||
# rm rn tm tn vn pd mv fdiv welford 2p xbias add sweep
|
||||
h_trait_dict = {'64' : [ h_traits('x', 'y', 'xs', 'ys', 1, 1, 8, 8, 8, True, False, True, True, False, 0, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 1, 1, 4, 16, 4, True, False, True, True, False, 0, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 1, 1, 4, 64, 1, True, False, True, True, False, 0, 0, 0)],
|
||||
'128' : [ h_traits('x', 'y', 'xs', 'ys', 1, 1, 4, 16, 8, True, False, True, True, False, 0, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 1, 1, 4, 64, 2, True, False, True, True, False, 0, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 1, 2, 4, 64, 1, True, False, True, True, False, 0, 0, 0)],
|
||||
'256' : [ h_traits('x', 'y', 'xs', 'ys', 1, 1, 4, 64, 4, True, False, True, True, False, 0, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 1, 2, 4, 64, 2, True, False, True, True, False, 0, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 1, 4, 4, 64, 1, True, False, True, True, False, 0, 0, 0)],
|
||||
'512' : [ h_traits('x', 'y', 'xs', 'ys', 1, 1, 4, 64, 8, True, False, True, True, False, 0, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 1, 2, 4, 64, 4, True, False, True, True, False, 0, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 1, 4, 4, 64, 2, True, False, True, True, False, 0, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 1, 8, 4, 64, 1, True, False, True, True, False, 0, 0, 0)],
|
||||
'768' : [ h_traits('x', 'y', 'xs', 'ys', 1, 3, 4, 64, 4, True, False, True, True, False, 0, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 1, 6, 4, 64, 2, True, False, True, True, False, 0, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 1, 12, 4, 64, 1, True, False, True, True, False, 0, 0, 0)],
|
||||
'1024' :[ h_traits('x', 'y', 'xs', 'ys', 1, 1, 2, 128, 8, True, False, True, True, False, 0, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 1, 2, 2, 128, 4, True, False, True, True, False, 0, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 1, 4, 2, 128, 2, True, False, True, True, False, 0, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 1, 4, 1, 256, 1, True, False, True, True, False, 0, 0, 0)],
|
||||
'1536' :[ h_traits('x', 'y', 'xs', 'ys', 1, 3, 4, 64, 8, True, False, True, True, False, 0, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 1, 3, 2, 128, 4, True, False, True, True, False, 0, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 1, 3, 1, 256, 2, True, False, True, True, False, 0, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 1, 6, 1, 256, 1, True, False, True, True, False, 0, 0, 0)],
|
||||
'2048' :[ h_traits('x', 'y', 'xs', 'ys', 1, 1, 1, 256, 8, True, False, True, True, False, 0, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 1, 2, 1, 256, 4, True, False, True, True, False, 0, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 1, 4, 1, 256, 2, True, False, True, True, False, 0, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 1, 8, 1, 256, 1, True, False, True, True, False, 0, 0, 0)],
|
||||
'3072' :[ h_traits('x', 'y', 'xs', 'ys', 1, 3, 1, 128, 8, True, False, True, True, False, 0, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 1, 3, 1, 256, 4, True, False, True, True, False, 0, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 1, 6, 1, 256, 2, True, False, True, True, False, 0, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 1, 3, 1,1024, 1, True, False, True, True, False, 0, 0, 0)],
|
||||
'4096' :[ h_traits('x', 'y', 'xs', 'ys', 1, 2, 1, 256, 8, True, False, True, True, False, 0, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 1, 4, 1, 256, 4, True, False, True, True, False, 0, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 1, 2, 1,1024, 2, True, False, True, True, False, 0, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 1, 4, 1,1024, 1, True, False, True, True, False, 0, 0, 0)],
|
||||
'6144' :[ h_traits('x', 'y', 'xs', 'ys', 1, 3, 1, 256, 8, True, False, True, True, False, 0, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 1, 3, 1, 512, 4, True, False, True, True, False, 0, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 1, 3, 1,1024, 2, True, False, True, True, False, 0, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 1, 6, 1,1024, 1, True, False, True, True, False, 0, 0, 0)],
|
||||
'8192' :[ h_traits('x', 'y', 'xs', 'ys', 1, 4, 1, 256, 8, True, False, True, True, False, 0, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 1, 4, 1, 512, 4, True, False, True, True, False, 0, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 1, 4, 1,1024, 2, True, False, True, True, False, 0, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 1, 8, 1,1024, 1, True, False, True, True, False, 0, 0, 0)],
|
||||
'big' :[ h_traits('x', 'y', 'xs', 'ys', 1, 2, 1, 256, 8, True, False, True, True, True, 0, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 1, 4, 1, 256, 4, True, False, True, True, True, 0, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 1, 2, 1,1024, 2, True, False, True, True, True, 0, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 1, 4, 1,1024, 1, True, False, True, True, True, 0, 0, 0)]}
|
||||
total_blob = list()
|
||||
for hs_key in h_trait_dict:
|
||||
hs = h_trait_dict[hs_key]
|
||||
current_n = hs[0].F_Repeat_N * hs[0].F_ThreadPerBlock_N * hs[0].F_Vector_N
|
||||
for dtype, scale_type, fused_add, fused_quant in itertools.product(dtype_list, scale_list, fused_add_list, fused_sweep_list):
|
||||
for dtype, scale_type, xbias, fused_add, fused_quant in itertools.product(dtype_list, scale_list, xbias_list, fused_add_list, fused_sweep_list):
|
||||
prec_i, prec_o = dtype.split(',')
|
||||
scale_x, scale_y = scale_type.split(',')
|
||||
if prec_o in dynamic_quant_out_dtype and fused_quant != 1:
|
||||
@@ -563,6 +584,7 @@ float layernorm2d_fwd(layernorm2d_fwd_traits t,
|
||||
h_.F_YDataType = prec_o
|
||||
h_.F_XScaleDataType = scale_y
|
||||
h_.F_YScaleDataType = scale_x
|
||||
h_.F_kXbias = xbias
|
||||
h_.F_kFusedAdd = fused_add
|
||||
h_.F_kFusedQuant = fused_quant
|
||||
# disable welford update for 8bit and 16 bit smallN
|
||||
@@ -579,7 +601,7 @@ float layernorm2d_fwd(layernorm2d_fwd_traits t,
|
||||
current_hs.append(h_) # + "\n"
|
||||
#f.write(str(f.parent / GEN_DIR / (blobs.api_common_header_
|
||||
current_n_str = 'big' if hs_key == 'big' else current_n
|
||||
total_blob.append(h_instance(dtype, current_n_str, fused_add, fused_quant, current_hs))
|
||||
total_blob.append(h_instance(dtype, current_n_str, xbias, fused_add, fused_quant, current_hs))
|
||||
return total_blob
|
||||
|
||||
def list_blobs(self, args) -> None:
|
||||
|
||||
@@ -41,6 +41,7 @@ auto create_args(int argc, char* argv[])
|
||||
.insert("prec_sy",
|
||||
"auto",
|
||||
"output quant scale type, set auto will use fp32. used when fquant=1 or 2")
|
||||
.insert("xbias", "0", "add bias, 0:no add, 1:add bias before fadd")
|
||||
.insert("fadd", "0", "fused-add, 0:no fused add, 1:preadd+store, 2:preadd only")
|
||||
.insert("fquant", "0", "fused-quant, 0:no, 1:smooth-dynamic-quant, 2:dynamic-quant")
|
||||
.insert("warmup", "5", "cold iter")
|
||||
@@ -93,6 +94,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
int do_validation = arg_parser.get_int("v");
|
||||
int warmup = arg_parser.get_int("warmup");
|
||||
int repeat = arg_parser.get_int("repeat");
|
||||
int xbias = arg_parser.get_int("xbias");
|
||||
int fused_add = arg_parser.get_int("fadd");
|
||||
int fused_quant = arg_parser.get_int("fquant");
|
||||
if(fused_quant == 1 && prec_o != "int8")
|
||||
@@ -107,6 +109,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
|
||||
using XDataType = typename TypeConfig::XDataType;
|
||||
using YDataType = typename TypeConfig::YDataType;
|
||||
using XBiasDataType = typename TypeConfig::XBiasDataType;
|
||||
using GammaDataType = typename TypeConfig::GammaDataType;
|
||||
using BetaDataType = typename TypeConfig::BetaDataType;
|
||||
using XResidualDataType = XDataType;
|
||||
@@ -121,6 +124,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
|
||||
// host verify
|
||||
ck_tile::HostTensor<XDataType> x_host({m, n}, {x_stride, 1});
|
||||
ck_tile::HostTensor<XBiasDataType> x_bias_host({n});
|
||||
ck_tile::HostTensor<GammaDataType> gamma_host({n});
|
||||
ck_tile::HostTensor<BetaDataType> beta_host({n});
|
||||
|
||||
@@ -141,10 +145,12 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
ck_tile::FillUniformDistribution<XDataType>{-.5f, .5f}(x_host);
|
||||
ck_tile::FillUniformDistribution<XResidualDataType>{-.5f, .5f}(x_residual_host);
|
||||
ck_tile::FillUniformDistribution<XScaleDataType>{-1.f, 1.f}(x_scale_host);
|
||||
ck_tile::FillUniformDistribution<XBiasDataType>{-.5f, .5f}(x_bias_host);
|
||||
ck_tile::FillUniformDistribution<GammaDataType>{-.5f, .5f}(gamma_host);
|
||||
ck_tile::FillUniformDistribution<BetaDataType>{-.5f, .5f}(beta_host);
|
||||
|
||||
ck_tile::DeviceMem x_buf(x_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem x_bias_buf(x_bias_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem gamma_buf(gamma_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem beta_buf(beta_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem y_buf(y_host_dev.get_element_space_size_in_bytes());
|
||||
@@ -155,6 +161,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
ck_tile::DeviceMem y_residual_buf(y_residual_host.get_element_space_size_in_bytes());
|
||||
|
||||
x_buf.ToDevice(x_host.data());
|
||||
x_bias_buf.ToDevice(x_bias_host.data());
|
||||
gamma_buf.ToDevice(gamma_host.data());
|
||||
beta_buf.ToDevice(beta_host.data());
|
||||
x_residual_buf.ToDevice(x_residual_host.data());
|
||||
@@ -179,11 +186,12 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
<< ", yr_stride:" << yr_stride << std::flush;
|
||||
|
||||
layernorm2d_fwd_traits traits{
|
||||
prec_i, prec_o, prec_sx, prec_sy, SaveMeanVar, fused_add, fused_quant};
|
||||
prec_i, prec_o, prec_sx, prec_sy, SaveMeanVar, xbias, fused_add, fused_quant};
|
||||
|
||||
layernorm2d_fwd_args args{x_buf.GetDeviceBuffer(),
|
||||
fused_add != 0 ? x_residual_buf.GetDeviceBuffer() : nullptr,
|
||||
fused_quant == 1 ? x_scale_buf.GetDeviceBuffer() : nullptr,
|
||||
x_bias_buf.GetDeviceBuffer(),
|
||||
gamma_buf.GetDeviceBuffer(),
|
||||
beta_buf.GetDeviceBuffer(),
|
||||
|
||||
@@ -210,8 +218,9 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
return false;
|
||||
}
|
||||
|
||||
std::size_t num_byte = sizeof(XDataType) * m * n + sizeof(GammaDataType) * n +
|
||||
sizeof(BetaDataType) * n + sizeof(YDataType) * m * n;
|
||||
std::size_t num_byte = sizeof(XDataType) * m * n + sizeof(XBiasDataType) * n +
|
||||
sizeof(GammaDataType) * n + sizeof(BetaDataType) * n +
|
||||
sizeof(YDataType) * m * n;
|
||||
|
||||
float gb_per_sec = num_byte / 1.E6 / ave_time;
|
||||
std::cout << ", " << ave_time * 1.E3 << " us, " << gb_per_sec << " GB/s" << std::flush;
|
||||
@@ -221,6 +230,22 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
if(do_validation)
|
||||
{
|
||||
// reference
|
||||
// Host-side reference for the xbias option: only runs when xbias is non-zero.
if(xbias != 0)
|
||||
{
|
||||
// add bias before fadd
// x_host is overwritten in place with (x + bias).
|
||||
// M and N are the first and second lengths of x_host's descriptor.
int M = x_host.mDesc.get_lengths()[0];
|
||||
int N = x_host.mDesc.get_lengths()[1];
|
||||
for(int idx_m = 0; idx_m < M; ++idx_m)
|
||||
{
|
||||
for(int idx_n = 0; idx_n < N; ++idx_n)
|
||||
{
// Convert both operands to ComputeDataType, add them, then convert the
// sum back to XDataType. The bias is indexed by idx_n only, so the same
// bias value is applied to every idx_m.
x_host(idx_m, idx_n) = ck_tile::type_convert<XDataType>(
|
||||
ck_tile::type_convert<ComputeDataType>(x_host(idx_m, idx_n)) +
|
||||
ck_tile::type_convert<ComputeDataType>(x_bias_host(idx_n)));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if(fused_add != 0)
|
||||
{
|
||||
// fused pre_add/pre_add_store
|
||||
|
||||
@@ -16,6 +16,7 @@ struct LayerNormTypeConfig<ck_tile::half_t, OutType, XScaleDataType_, YScaleData
|
||||
{
|
||||
using XDataType = ck_tile::half_t;
|
||||
using YDataType = OutType;
|
||||
using XBiasDataType = ck_tile::half_t;
|
||||
using GammaDataType = ck_tile::half_t;
|
||||
using BetaDataType = ck_tile::half_t;
|
||||
using MeanDataType = ck_tile::half_t;
|
||||
@@ -30,6 +31,7 @@ struct LayerNormTypeConfig<ck_tile::bf16_t, OutType, XScaleDataType_, YScaleData
|
||||
{
|
||||
using XDataType = ck_tile::bf16_t;
|
||||
using YDataType = OutType;
|
||||
using XBiasDataType = ck_tile::bf16_t;
|
||||
using GammaDataType = ck_tile::bf16_t;
|
||||
using BetaDataType = ck_tile::bf16_t;
|
||||
using MeanDataType = ck_tile::bf16_t;
|
||||
@@ -57,6 +59,7 @@ struct layernorm2d_fwd_traits
|
||||
std::string prec_sy; // y-scale, used for [M*1] output for next layer
|
||||
|
||||
bool save_mean_var; //
|
||||
int xbias; // 0:no-bias, 1:add bias
|
||||
int fused_add; // 0:no-add, 1:pre-add-store, 2:pre-add
|
||||
int fused_quant; // 0:no-sweep, 1:smooth-dynamic-quant, 2:dynamic-quant
|
||||
};
|
||||
|
||||
@@ -15,6 +15,7 @@ struct Layernorm2dFwdHostArgs
|
||||
const void* p_x; // [m ,n], input, fp16/bf16
|
||||
const void* p_x_residual; // [m ,n], shortcut input, prec same as input, nullptr if not used
|
||||
const void* p_x_scale; // [1 ,n], smooth scale input, fp32, nullptr if not used
|
||||
const void* p_x_bias; // [1, n], bias, prec same as input
|
||||
const void* p_gamma; // [1, n], gamma, prec same as input
|
||||
const void* p_beta; // [1, n], beta, prec same as input
|
||||
|
||||
@@ -43,6 +44,7 @@ struct Layernorm2dFwd
|
||||
using Problem = typename Pipeline::Problem;
|
||||
|
||||
using XDataType = remove_cvref_t<typename Problem::XDataType>;
|
||||
using XBiasDataType = remove_cvref_t<typename Problem::XBiasDataType>;
|
||||
using GammaDataType = remove_cvref_t<typename Problem::GammaDataType>;
|
||||
using BetaDataType = remove_cvref_t<typename Problem::BetaDataType>;
|
||||
using ComputeDataType = remove_cvref_t<typename Problem::ComputeDataType>;
|
||||
@@ -67,6 +69,7 @@ struct Layernorm2dFwd
|
||||
static constexpr bool kPadM = false; // always no need to pad along M
|
||||
static constexpr bool kPadN = Problem::Traits::kPadN;
|
||||
static constexpr bool kTwoPass = Problem::Traits::kTwoPass;
|
||||
static constexpr auto kXbias = Problem::Traits::kXbias;
|
||||
static constexpr auto kFusedAdd = Problem::Traits::kFusedAdd;
|
||||
static constexpr auto kFusedQuant = Problem::Traits::kFusedQuant;
|
||||
|
||||
@@ -82,6 +85,7 @@ struct Layernorm2dFwd
|
||||
const void* p_x; // [m ,n], input, fp16/bf16
|
||||
const void* p_x_residual; // [m ,n], shortcut input, prec same as input, nullptr if not used
|
||||
const void* p_x_scale; // [1 ,n], smooth scale input, fp32, nullptr if not used
|
||||
const void* p_x_bias; // [1, n], bias, prec same as input
|
||||
const void* p_gamma; // [1, n], gamma, prec same as input
|
||||
const void* p_beta; // [1, n], beta, prec same as input
|
||||
|
||||
@@ -108,6 +112,7 @@ struct Layernorm2dFwd
|
||||
return Kargs{hargs.p_x,
|
||||
hargs.p_x_residual,
|
||||
hargs.p_x_scale,
|
||||
hargs.p_x_bias,
|
||||
hargs.p_gamma,
|
||||
hargs.p_beta,
|
||||
hargs.p_y,
|
||||
@@ -152,6 +157,7 @@ struct Layernorm2dFwd
|
||||
using S_ = typename Problem::BlockShape;
|
||||
auto surfix = [&] () {
|
||||
std::string n;
|
||||
if (kXbias != Layernorm2dXBiasEnum::NO_BIAS) n += _SS_("_") + Layernorm2dXBiasEnumName<kXbias>::name;
|
||||
if (kFusedAdd != Layernorm2dFusedAddEnum::NO_ADD) n += _SS_("_") + Layernorm2dFusedAddEnumName<kFusedAdd>::name;
|
||||
if (kFusedQuant != Layernorm2dFusedQuantEnum::NO_SWEEP) n += _SS_("_") + Layernorm2dFusedQuantEnumName<kFusedQuant>::name;
|
||||
if (kPadN) n += "_pn";
|
||||
@@ -228,6 +234,27 @@ struct Layernorm2dFwd
|
||||
}
|
||||
}();
|
||||
|
||||
// Builds the tile window used to read the bias vector.
// The branch is selected at compile time from kXbias: ADD_BIAS yields a window
// over kargs.p_x_bias, any other value yields a null tile window.
const auto x_bias_window = [&]() {
|
||||
if constexpr(kXbias == Layernorm2dXBiasEnum::ADD_BIAS)
|
||||
{
|
||||
// 1-D view over global memory: kargs.n elements of XBiasDataType, stride 1.
// NOTE(review): the two trailing number<> arguments look like vector-access
// hints (Vector_N and 1) -- confirm against make_naive_tensor_view.
const auto tmp_ = make_naive_tensor_view<address_space_enum::global>(
|
||||
static_cast<const XBiasDataType*>(kargs.p_x_bias),
|
||||
make_tuple(kargs.n),
|
||||
make_tuple(1),
|
||||
number<Vector_N>{},
|
||||
number<1>{});
|
||||
|
||||
// NOTE(review): sequence<false> presumably disables padding along this
// dimension -- confirm against pad_tensor_view.
const auto tmp2_ =
|
||||
pad_tensor_view(tmp_, make_tuple(number<Block_N>{}), sequence<false>{});
|
||||
|
||||
// Window of Block_N elements with origin {0}.
return make_tile_window(tmp2_, make_tuple(number<Block_N>{}), {0});
|
||||
}
|
||||
else
|
||||
{
|
||||
// Bias not enabled: return a null window with the same Block_N length.
return make_null_tile_window(make_tuple(number<Block_N>{}));
|
||||
}
|
||||
}();
|
||||
|
||||
const auto gamma_window = [&]() {
|
||||
const auto tmp_ = make_naive_tensor_view<address_space_enum::global>(
|
||||
static_cast<const GammaDataType*>(kargs.p_gamma),
|
||||
@@ -371,6 +398,7 @@ struct Layernorm2dFwd
|
||||
|
||||
Pipeline{}(x_window,
|
||||
x_residual_window,
|
||||
x_bias_window,
|
||||
gamma_window,
|
||||
beta_window,
|
||||
y_window,
|
||||
|
||||
@@ -18,6 +18,7 @@ struct Layernorm2dFwdPipelineOnePass
|
||||
using Policy = ck_tile::remove_cvref_t<Policy_>;
|
||||
|
||||
using XDataType = ck_tile::remove_cvref_t<typename Problem::XDataType>;
|
||||
using XBiasDataType = ck_tile::remove_cvref_t<typename Problem::XBiasDataType>;
|
||||
using GammaDataType = ck_tile::remove_cvref_t<typename Problem::GammaDataType>;
|
||||
using BetaDataType = ck_tile::remove_cvref_t<typename Problem::BetaDataType>;
|
||||
using ComputeDataType = ck_tile::remove_cvref_t<typename Problem::ComputeDataType>;
|
||||
@@ -38,6 +39,7 @@ struct Layernorm2dFwdPipelineOnePass
|
||||
static constexpr bool kPadN = Problem::Traits::kPadN;
|
||||
static constexpr bool kFastFDiv = Problem::Traits::kFastFDiv;
|
||||
static constexpr bool kWelford = Problem::Traits::kWelford;
|
||||
static constexpr auto kXbias = Problem::Traits::kXbias;
|
||||
static constexpr auto kFusedAdd = Problem::Traits::kFusedAdd;
|
||||
static constexpr auto kFusedQuant = Problem::Traits::kFusedQuant;
|
||||
|
||||
@@ -55,6 +57,7 @@ struct Layernorm2dFwdPipelineOnePass
|
||||
|
||||
template <typename XWindow,
|
||||
typename XResidualWindow,
|
||||
typename XBiasWindow,
|
||||
typename GammaWindow,
|
||||
typename BetaWindow,
|
||||
typename YWindow,
|
||||
@@ -66,6 +69,7 @@ struct Layernorm2dFwdPipelineOnePass
|
||||
typename Epilogue>
|
||||
CK_TILE_DEVICE auto operator()(const XWindow& x_window_,
|
||||
const XResidualWindow& x_residual_window_,
|
||||
const XBiasWindow& x_bias_window_,
|
||||
const GammaWindow& gamma_window_,
|
||||
const BetaWindow& beta_window_,
|
||||
YWindow& y_window_,
|
||||
@@ -81,6 +85,8 @@ struct Layernorm2dFwdPipelineOnePass
|
||||
{
|
||||
const auto x_window =
|
||||
make_tile_window(x_window_, Policy::template MakeXBlockTileDistribution<Problem>());
|
||||
const auto x_bias_window = make_tile_window(
|
||||
x_bias_window_, Policy::template MakeGammaBetaBlockTileDistribution<Problem>());
|
||||
const auto gamma_window = make_tile_window(
|
||||
gamma_window_, Policy::template MakeGammaBetaBlockTileDistribution<Problem>());
|
||||
const auto beta_window = make_tile_window(
|
||||
@@ -90,8 +96,9 @@ struct Layernorm2dFwdPipelineOnePass
|
||||
auto y_residual_window = make_tile_window(
|
||||
y_residual_window_, Policy::template MakeXBlockTileDistribution<Problem>());
|
||||
|
||||
auto x = load_tile(x_window);
|
||||
auto x_resi = load_tile(x_residual_window);
|
||||
auto x = load_tile(x_window);
|
||||
auto x_resi = load_tile(x_residual_window);
|
||||
const auto x_bias = load_tile(x_bias_window);
|
||||
|
||||
int cur_count = 0;
|
||||
int max_count =
|
||||
@@ -112,6 +119,15 @@ struct Layernorm2dFwdPipelineOnePass
|
||||
|
||||
auto acc = cast_tile<ComputeDataType>(x);
|
||||
|
||||
// Compile-time branch: the bias is only added when kXbias is ADD_BIAS.
if constexpr(kXbias == Layernorm2dXBiasEnum::ADD_BIAS)
|
||||
{
|
||||
// Sweep over x's tile indices; x is only used for iteration here.
sweep_tile(x, [&](auto idx) {
|
||||
// compute acc = bias + acc (the result is stored in acc; x is not modified)
|
||||
// The bias is indexed by the second component of idx only.
constexpr auto j_idx = make_tuple(idx[number<1>{}]);
|
||||
// Convert the bias element to ComputeDataType before accumulating.
acc(idx) = type_convert<ComputeDataType>(x_bias[j_idx]) + acc(idx);
|
||||
});
|
||||
}
|
||||
|
||||
if constexpr(kFusedAdd == Layernorm2dFusedAddEnum::PRE_ADD_STORE ||
|
||||
kFusedAdd == Layernorm2dFusedAddEnum::PRE_ADD)
|
||||
{
|
||||
|
||||
@@ -8,6 +8,7 @@
|
||||
namespace ck_tile {
|
||||
|
||||
template <typename XDataType_,
|
||||
typename XBiasDataType_,
|
||||
typename GammaDataType_,
|
||||
typename BetaDataType_,
|
||||
typename ComputeDataType_,
|
||||
@@ -21,6 +22,7 @@ template <typename XDataType_,
|
||||
struct Layernorm2dFwdPipelineProblem
|
||||
{
|
||||
using XDataType = remove_cvref_t<XDataType_>;
|
||||
using XBiasDataType = remove_cvref_t<XBiasDataType_>;
|
||||
using GammaDataType = remove_cvref_t<GammaDataType_>;
|
||||
using BetaDataType = remove_cvref_t<BetaDataType_>;
|
||||
using ComputeDataType = remove_cvref_t<ComputeDataType_>;
|
||||
|
||||
@@ -17,6 +17,7 @@ struct Layernorm2dFwdPipelineTwoPass
|
||||
using Policy = ck_tile::remove_cvref_t<Policy_>;
|
||||
|
||||
using XDataType = ck_tile::remove_cvref_t<typename Problem::XDataType>;
|
||||
using XBiasDataType = ck_tile::remove_cvref_t<typename Problem::XBiasDataType>;
|
||||
using GammaDataType = ck_tile::remove_cvref_t<typename Problem::GammaDataType>;
|
||||
using BetaDataType = ck_tile::remove_cvref_t<typename Problem::BetaDataType>;
|
||||
using ComputeDataType = ck_tile::remove_cvref_t<typename Problem::ComputeDataType>;
|
||||
@@ -37,6 +38,7 @@ struct Layernorm2dFwdPipelineTwoPass
|
||||
static constexpr bool kPadN = Problem::Traits::kPadN;
|
||||
static constexpr bool kFastFDiv = Problem::Traits::kFastFDiv;
|
||||
static constexpr bool kWelford = Problem::Traits::kWelford;
|
||||
static constexpr auto kXbias = Problem::Traits::kXbias;
|
||||
static constexpr auto kFusedAdd = Problem::Traits::kFusedAdd;
|
||||
static constexpr auto kFusedQuant = Problem::Traits::kFusedQuant;
|
||||
|
||||
@@ -54,6 +56,7 @@ struct Layernorm2dFwdPipelineTwoPass
|
||||
|
||||
template <typename XWindow,
|
||||
typename XResidualWindow,
|
||||
typename XBiasWindow,
|
||||
typename GammaWindow,
|
||||
typename BetaWindow,
|
||||
typename YWindow,
|
||||
@@ -65,6 +68,7 @@ struct Layernorm2dFwdPipelineTwoPass
|
||||
typename Epilogue>
|
||||
CK_TILE_DEVICE auto operator()(const XWindow& x_window_,
|
||||
const XResidualWindow& x_residual_window_,
|
||||
const XBiasWindow& x_bias_window_,
|
||||
const GammaWindow& gamma_window_,
|
||||
const BetaWindow& beta_window_,
|
||||
YWindow& y_window,
|
||||
@@ -81,6 +85,8 @@ struct Layernorm2dFwdPipelineTwoPass
|
||||
static_assert(kWelford == true, "2 pass only supports welford merge");
|
||||
auto x_window =
|
||||
make_tile_window(x_window_, Policy::template MakeXBlockTileDistribution<Problem>());
|
||||
auto x_bias_window = make_tile_window(
|
||||
x_bias_window_, Policy::template MakeGammaBetaBlockTileDistribution<Problem>());
|
||||
auto gamma_window = make_tile_window(
|
||||
gamma_window_, Policy::template MakeGammaBetaBlockTileDistribution<Problem>());
|
||||
auto beta_window = make_tile_window(
|
||||
@@ -115,13 +121,24 @@ struct Layernorm2dFwdPipelineTwoPass
|
||||
|
||||
for(int iN = __builtin_amdgcn_readfirstlane(0); iN < num_n_tile_iteration; ++iN)
|
||||
{
|
||||
auto x = load_tile(x_window);
|
||||
auto x_resi = load_tile(x_residual_window);
|
||||
auto x = load_tile(x_window);
|
||||
auto x_resi = load_tile(x_residual_window);
|
||||
const auto x_bias = load_tile(x_bias_window);
|
||||
|
||||
move_tile_window(x_window, {0, Block_N});
|
||||
move_tile_window(x_residual_window, {0, Block_N});
|
||||
move_tile_window(x_bias_window, {Block_N});
|
||||
auto acc = cast_tile<ComputeDataType>(x);
|
||||
|
||||
if constexpr(kXbias == Layernorm2dXBiasEnum::ADD_BIAS)
|
||||
{
|
||||
sweep_tile(x, [&](auto idx) {
|
||||
// compute x = bias + x
|
||||
constexpr auto j_idx = make_tuple(idx[number<1>{}]);
|
||||
acc(idx) = type_convert<ComputeDataType>(x_bias[j_idx]) + acc(idx);
|
||||
});
|
||||
}
|
||||
|
||||
if constexpr(kFusedAdd == Layernorm2dFusedAddEnum::PRE_ADD_STORE ||
|
||||
kFusedAdd == Layernorm2dFusedAddEnum::PRE_ADD)
|
||||
{
|
||||
@@ -167,6 +184,7 @@ struct Layernorm2dFwdPipelineTwoPass
|
||||
|
||||
move_tile_window(x_window, {0, -Block_N});
|
||||
move_tile_window(x_residual_window, {0, -Block_N});
|
||||
move_tile_window(x_bias_window, {-Block_N});
|
||||
move_tile_window(gamma_window, {stride_to_right_most_window});
|
||||
move_tile_window(beta_window, {stride_to_right_most_window});
|
||||
move_tile_window(y_window, {0, stride_to_right_most_window});
|
||||
@@ -174,9 +192,19 @@ struct Layernorm2dFwdPipelineTwoPass
|
||||
// layernorm computation
|
||||
for(int iN = __builtin_amdgcn_readfirstlane(0); iN < num_n_tile_iteration; ++iN)
|
||||
{
|
||||
auto x = load_tile(x_window);
|
||||
auto x_resi = load_tile(x_residual_window);
|
||||
auto acc = cast_tile<ComputeDataType>(x);
|
||||
auto x = load_tile(x_window);
|
||||
auto x_resi = load_tile(x_residual_window);
|
||||
const auto x_bias = load_tile(x_bias_window);
|
||||
auto acc = cast_tile<ComputeDataType>(x);
|
||||
|
||||
if constexpr(kXbias == Layernorm2dXBiasEnum::ADD_BIAS)
|
||||
{
|
||||
sweep_tile(x, [&](auto idx) {
|
||||
// compute x = bias + x
|
||||
constexpr auto j_idx = make_tuple(idx[number<1>{}]);
|
||||
acc(idx) = type_convert<ComputeDataType>(x_bias[j_idx]) + acc(idx);
|
||||
});
|
||||
}
|
||||
|
||||
if constexpr(kFusedAdd == Layernorm2dFusedAddEnum::PRE_ADD_STORE ||
|
||||
kFusedAdd == Layernorm2dFusedAddEnum::PRE_ADD)
|
||||
@@ -209,6 +237,7 @@ struct Layernorm2dFwdPipelineTwoPass
|
||||
|
||||
move_tile_window(x_window, {0, -Block_N});
|
||||
move_tile_window(x_residual_window, {0, -Block_N});
|
||||
move_tile_window(x_bias_window, {-Block_N});
|
||||
move_tile_window(gamma_window, {-Block_N});
|
||||
move_tile_window(beta_window, {-Block_N});
|
||||
move_tile_window(y_window, {0, -Block_N});
|
||||
|
||||
@@ -7,6 +7,19 @@
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
// Compile-time switch for the optional x-bias step of the layernorm2d forward
// pipelines: it selects whether a bias value is added to x.
enum class Layernorm2dXBiasEnum
|
||||
{
|
||||
// no bias is added to x
NO_BIAS = 0,
|
||||
// add the bias to x before the fused (residual) add
|
||||
ADD_BIAS = 1,
|
||||
};
|
||||
|
||||
// clang-format off
|
||||
// Compile-time lookup from a Layernorm2dXBiasEnum value to a short string tag
// exposed as the static member `name`. The primary template is only declared,
// so `name` is available solely for the values specialized below.
// NOTE(review): presumably these tags feed generated kernel/instance names — confirm against callers.
template<Layernorm2dXBiasEnum> struct Layernorm2dXBiasEnumName;
|
||||
template<> struct Layernorm2dXBiasEnumName<Layernorm2dXBiasEnum::NO_BIAS> { static constexpr const char * name = "no"; };
|
||||
template<> struct Layernorm2dXBiasEnumName<Layernorm2dXBiasEnum::ADD_BIAS> { static constexpr const char * name = "xbias"; };
|
||||
// clang-format on
|
||||
|
||||
enum class Layernorm2dFusedAddEnum
|
||||
{
|
||||
NO_ADD = 0,
|
||||
@@ -42,6 +55,7 @@ template <bool kPadN_,
|
||||
bool kFastFDiv_,
|
||||
bool kWelford_,
|
||||
bool kTwoPass_,
|
||||
Layernorm2dXBiasEnum kXbias_,
|
||||
Layernorm2dFusedAddEnum kFusedAdd_,
|
||||
Layernorm2dFusedQuantEnum kFusedQuant_>
|
||||
struct Layernorm2dFwdTraits
|
||||
@@ -51,6 +65,7 @@ struct Layernorm2dFwdTraits
|
||||
static constexpr bool kFastFDiv = kFastFDiv_;
|
||||
static constexpr bool kWelford = kWelford_;
|
||||
static constexpr bool kTwoPass = kTwoPass_;
|
||||
static constexpr Layernorm2dXBiasEnum kXbias = kXbias_;
|
||||
static constexpr Layernorm2dFusedAddEnum kFusedAdd = kFusedAdd_;
|
||||
static constexpr Layernorm2dFusedQuantEnum kFusedQuant = kFusedQuant_;
|
||||
};
|
||||
|
||||
Reference in New Issue
Block a user