mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-19 22:39:03 +00:00
[CK_TILE] Improve RMS/Layer Normalization 2 Pass Pipeline Performance (#1861)
* 50ms -> 28ms * Fix bug in non fuse_add_store cases * Fine tuned setting for 2 pass pipeline * adjust workload * remove unnecessary change * add layernorm * Adding output quant and unquant results at the same time. * fix test * fix format * tune for cases 128x640 and 128x1024 * bug ifx
This commit is contained in:
@@ -564,9 +564,9 @@ float layernorm2d_fwd(layernorm2d_fwd_traits t,
|
||||
h_traits('x', 'y', 'xs', 'ys', 1, 4, 1, 512, 4, True, False, True, True, False, 0, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 1, 4, 1,1024, 2, True, False, True, True, False, 0, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 1, 8, 1,1024, 1, True, False, True, True, False, 0, 0, 0)],
|
||||
'big' :[ h_traits('x', 'y', 'xs', 'ys', 1, 2, 1, 256, 8, True, False, True, True, True, 0, 0, 0),
|
||||
'big' :[ h_traits('x', 'y', 'xs', 'ys', 1, 1, 1,1024, 8, True, False, True, True, True, 0, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 1, 4, 1, 256, 4, True, False, True, True, True, 0, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 1, 2, 1,1024, 2, True, False, True, True, True, 0, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 1, 12, 1, 256, 2, True, False, True, True, True, 0, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 1, 4, 1,1024, 1, True, False, True, True, True, 0, 0, 0)]}
|
||||
total_blob = list()
|
||||
for hs_key in h_trait_dict:
|
||||
|
||||
@@ -41,6 +41,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
using YDataType = DataType;
|
||||
using GammaDataType = DataType;
|
||||
using InvRmsDataType = ck_tile::null_type;
|
||||
using UnquantYDataType = ck_tile::null_type;
|
||||
using SmoothScaleDataType = ck_tile::null_type;
|
||||
using YScaleDataType = ck_tile::null_type;
|
||||
|
||||
@@ -55,6 +56,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
|
||||
ck_tile::HostTensor<InvRmsDataType> invRms_host_ref({m});
|
||||
|
||||
ck_tile::HostTensor<UnquantYDataType> unquant_y_host_ref({m, n}, {stride, 1});
|
||||
|
||||
ck_tile::FillUniformDistribution<XDataType>{-.5f, .5f}(x_host);
|
||||
ck_tile::FillUniformDistribution<GammaDataType>{-.5f, .5f}(gamma_host);
|
||||
|
||||
@@ -76,6 +79,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
using PipelineTraits =
|
||||
ck_tile::Rmsnorm2dFwdTraits<true, // kPadN
|
||||
false, // kSaveInvRms
|
||||
false, // kSaveUnquant
|
||||
kTwoPass,
|
||||
ck_tile::Rmsnorm2dFusedAddEnum::NO_ADD, // fuse add
|
||||
ck_tile::Rmsnorm2dFusedQuantEnum::NO_SWEEP>; // fuse quant
|
||||
@@ -85,6 +89,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
ComputeDataType,
|
||||
YDataType,
|
||||
InvRmsDataType,
|
||||
UnquantYDataType,
|
||||
SmoothScaleDataType,
|
||||
YScaleDataType,
|
||||
Shape,
|
||||
@@ -108,6 +113,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
nullptr,
|
||||
nullptr,
|
||||
nullptr,
|
||||
nullptr,
|
||||
epsilon,
|
||||
m,
|
||||
n,
|
||||
@@ -135,8 +141,9 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
GammaDataType,
|
||||
ComputeDataType,
|
||||
YDataType,
|
||||
InvRmsDataType>(
|
||||
x_host, gamma_host, y_host_ref, invRms_host_ref, epsilon);
|
||||
InvRmsDataType,
|
||||
UnquantYDataType>(
|
||||
x_host, gamma_host, y_host_ref, invRms_host_ref, unquant_y_host_ref, epsilon);
|
||||
|
||||
y_buf.FromDevice(y_host_dev.data());
|
||||
|
||||
|
||||
@@ -54,6 +54,7 @@ template <typename XDataType_,
|
||||
typename YDataType_,
|
||||
typename SmoothScaleDataType_,
|
||||
typename YScaleDataType_,
|
||||
typename UnquantYDataType_,
|
||||
ck_tile::index_t Repeat_M_, // each thread repeat along M
|
||||
ck_tile::index_t Repeat_N_, // each thread repeat along N
|
||||
ck_tile::index_t ThreadPerBlock_M_, // num threads along M
|
||||
@@ -61,6 +62,7 @@ template <typename XDataType_,
|
||||
ck_tile::index_t Vector_N_, // vector size along N
|
||||
bool kPadN_,
|
||||
bool kSaveInvRms_,
|
||||
bool kSaveUnquant_,
|
||||
bool kTwoPass_,
|
||||
ck_tile::index_t kFusedAdd_ = 0,
|
||||
ck_tile::index_t kFusedQuant_ = 0>
|
||||
@@ -70,6 +72,7 @@ struct rmsnorm2d_fwd_traits_
|
||||
using YDataType = ck_tile::remove_cvref_t<YDataType_>;
|
||||
using SmoothScaleDataType = ck_tile::remove_cvref_t<SmoothScaleDataType_>;
|
||||
using YScaleDataType = ck_tile::remove_cvref_t<YScaleDataType_>;
|
||||
using UnquantYDataType = ck_tile::remove_cvref_t<UnquantYDataType_>;
|
||||
|
||||
static constexpr bool is_warp_per_row = ThreadPerBlock_N_ <= warpSize;
|
||||
static_assert((ThreadPerBlock_M_ * ThreadPerBlock_N_) % warpSize == 0);
|
||||
@@ -120,9 +123,10 @@ struct rmsnorm2d_fwd_traits_
|
||||
|
||||
using Shape = ck_tile::Generic2dBlockShape<BlockTile, BlockWarps, WarpTile, Vector>;
|
||||
|
||||
static constexpr bool kPadN = kPadN_;
|
||||
static constexpr bool kSaveInvRms = kSaveInvRms_;
|
||||
static constexpr bool kTwoPass = kTwoPass_;
|
||||
static constexpr bool kPadN = kPadN_;
|
||||
static constexpr bool kSaveInvRms = kSaveInvRms_;
|
||||
static constexpr bool kSaveUnquant = kSaveUnquant_;
|
||||
static constexpr bool kTwoPass = kTwoPass_;
|
||||
static constexpr ck_tile::index_t kFusedAdd = kFusedAdd_;
|
||||
static constexpr ck_tile::index_t kFusedQuant = kFusedQuant_;
|
||||
};
|
||||
@@ -131,6 +135,7 @@ template <typename XDataType_,
|
||||
typename YDataType_,
|
||||
typename SmoothScaleDataType_,
|
||||
typename YScaleDataType_,
|
||||
typename UnquantYDataType_,
|
||||
ck_tile::index_t Repeat_M_, // each thread repeat along M
|
||||
ck_tile::index_t Repeat_N_, // each thread repeat along N
|
||||
ck_tile::index_t ThreadPerBlock_M_, // num threads along M
|
||||
@@ -138,6 +143,7 @@ template <typename XDataType_,
|
||||
ck_tile::index_t Vector_N_, // vector size along N
|
||||
bool kPadN_,
|
||||
bool kSaveInvRms_,
|
||||
bool kSaveUnquant_,
|
||||
bool kTwoPass_,
|
||||
int kFusedAdd_,
|
||||
int kFusedQuant_>
|
||||
@@ -145,6 +151,7 @@ using traits_ = rmsnorm2d_fwd_traits_<XDataType_,
|
||||
YDataType_,
|
||||
SmoothScaleDataType_,
|
||||
YScaleDataType_,
|
||||
UnquantYDataType_,
|
||||
Repeat_M_,
|
||||
Repeat_N_,
|
||||
ThreadPerBlock_M_,
|
||||
@@ -152,6 +159,7 @@ using traits_ = rmsnorm2d_fwd_traits_<XDataType_,
|
||||
Vector_N_,
|
||||
kPadN_,
|
||||
kSaveInvRms_,
|
||||
kSaveUnquant_,
|
||||
kTwoPass_,
|
||||
kFusedAdd_,
|
||||
kFusedQuant_>;
|
||||
@@ -180,11 +188,13 @@ float rmsnorm2d_fwd_(const S& s, A a)
|
||||
using YDataType = typename Traits_::YDataType;
|
||||
using SmoothScaleDataType = typename Traits_::SmoothScaleDataType;
|
||||
using YScaleDataType = typename Traits_::YScaleDataType;
|
||||
using UnquantYDataType = typename Traits_::UnquantYDataType;
|
||||
using ComputeDataType = typename RmsnormTypeConfig<XDataType, YDataType, SmoothScaleDataType, YScaleDataType>::ComputeDataType;
|
||||
|
||||
using PipelineTraits =
|
||||
ck_tile::Rmsnorm2dFwdTraits<Traits_::kPadN,
|
||||
Traits_::kSaveInvRms,
|
||||
Traits_::kSaveUnquant,
|
||||
Traits_::kTwoPass,
|
||||
static_cast<ck_tile::Rmsnorm2dFusedAddEnum>(Traits_::kFusedAdd),
|
||||
static_cast<ck_tile::Rmsnorm2dFusedQuantEnum>(Traits_::kFusedQuant)>;
|
||||
@@ -195,6 +205,7 @@ float rmsnorm2d_fwd_(const S& s, A a)
|
||||
typename RmsnormTypeConfig<XDataType, YDataType, SmoothScaleDataType, YScaleDataType>::ComputeDataType,
|
||||
typename RmsnormTypeConfig<XDataType, YDataType, SmoothScaleDataType, YScaleDataType>::YDataType,
|
||||
typename RmsnormTypeConfig<XDataType, YDataType, SmoothScaleDataType, YScaleDataType>::InvRmsDataType,
|
||||
typename RmsnormTypeConfig<XDataType, YDataType, SmoothScaleDataType, YScaleDataType>::UnquantYDataType,
|
||||
typename RmsnormTypeConfig<XDataType, YDataType, SmoothScaleDataType, YScaleDataType>::SmoothScaleDataType,
|
||||
typename RmsnormTypeConfig<XDataType, YDataType, SmoothScaleDataType, YScaleDataType>::YScaleDataType,
|
||||
typename Traits_::Shape,
|
||||
@@ -213,7 +224,16 @@ float rmsnorm2d_fwd_(const S& s, A a)
|
||||
|
||||
using DynamicQuantEpilogue = ck_tile::DynamicQuantEpilogue<DynamicQuantEpilogueProblem>;
|
||||
|
||||
using Epilogue = std::conditional_t<Traits_::kFusedQuant != 0, DynamicQuantEpilogue, Default2DEpilogue>;
|
||||
using Default2DAndDynamicQuantEpilogueProblem = ck_tile::Default2DAndDynamicQuantEpilogueProblem<
|
||||
ComputeDataType, SmoothScaleDataType, YScaleDataType, YDataType, UnquantYDataType, typename Traits_::Shape,
|
||||
ck_tile::Default2DAndDynamicQuantEpilogueTraits<false, Traits_::kPadN, UseSmoothInputScale, false, true/*max3*/>>;
|
||||
using Default2DAndDynamicQuantEpilogue = ck_tile::Default2DAndDynamicQuantEpilogue<Default2DAndDynamicQuantEpilogueProblem>;
|
||||
|
||||
using Epilogue = std::conditional_t<Traits_::kFusedQuant != 0,
|
||||
std::conditional_t<Traits_::kSaveUnquant,
|
||||
Default2DAndDynamicQuantEpilogue,
|
||||
DynamicQuantEpilogue>,
|
||||
Default2DEpilogue>;
|
||||
|
||||
using Kernel = ck_tile::Rmsnorm2dFwd<Pipeline, Epilogue>;
|
||||
|
||||
@@ -355,6 +375,7 @@ float rmsnorm2d_fwd(rmsnorm2d_fwd_traits t,
|
||||
F_YDataType : str
|
||||
F_SmoothScaleDataType : str
|
||||
F_YScaleDataType : str
|
||||
F_UnquantYDataType : str
|
||||
F_Repeat_M : int
|
||||
F_Repeat_N : int
|
||||
F_ThreadPerBlock_M : int
|
||||
@@ -362,14 +383,15 @@ float rmsnorm2d_fwd(rmsnorm2d_fwd_traits t,
|
||||
F_Vector_N : int
|
||||
F_kPadN : bool
|
||||
F_kSaveInvRms : bool
|
||||
F_kSaveUnquant: bool
|
||||
F_kTwoPass : bool
|
||||
F_kFusedAdd : int
|
||||
F_kFusedQuant : int
|
||||
|
||||
@property
|
||||
def trait_name(self) ->str:
|
||||
t_ = f'{DATA_TYPE_MAP[self.F_XDataType]}, {DATA_TYPE_MAP[self.F_YDataType]}, {DATA_TYPE_MAP[self.F_SmoothScaleDataType]}, {DATA_TYPE_MAP[self.F_YScaleDataType]}, {self.F_Repeat_M:2}, {self.F_Repeat_N:2}, {self.F_ThreadPerBlock_M:2}, {self.F_ThreadPerBlock_N:4}'
|
||||
t_ += f', {self.F_Vector_N:2}, {BOOL_MAP(self.F_kPadN):5}, {BOOL_MAP(self.F_kSaveInvRms):5}'
|
||||
t_ = f'{DATA_TYPE_MAP[self.F_XDataType]}, {DATA_TYPE_MAP[self.F_YDataType]}, {DATA_TYPE_MAP[self.F_SmoothScaleDataType]}, {DATA_TYPE_MAP[self.F_YScaleDataType]}, {DATA_TYPE_MAP[self.F_UnquantYDataType]}, {self.F_Repeat_M:2}, {self.F_Repeat_N:2}, {self.F_ThreadPerBlock_M:2}, {self.F_ThreadPerBlock_N:4}'
|
||||
t_ += f', {self.F_Vector_N:2}, {BOOL_MAP(self.F_kPadN):5}, {BOOL_MAP(self.F_kSaveInvRms):5}, {BOOL_MAP(self.F_kSaveUnquant):5}'
|
||||
t_ += f', {BOOL_MAP(self.F_kTwoPass):5}, {self.F_kFusedAdd:4}, {self.F_kFusedQuant:4}'
|
||||
return t_
|
||||
|
||||
@@ -390,6 +412,7 @@ float rmsnorm2d_fwd(rmsnorm2d_fwd_traits t,
|
||||
F_N : str
|
||||
F_add : int
|
||||
F_sweep : int
|
||||
F_saveunquant : bool
|
||||
instance_list : List[Any] # List[h_traits]
|
||||
|
||||
@property
|
||||
@@ -401,6 +424,8 @@ float rmsnorm2d_fwd(rmsnorm2d_fwd_traits t,
|
||||
nnn = nnn + '_' + FUSED_ADD_ENUM_STR_MAP[self.F_add]
|
||||
if self.F_sweep != 0:
|
||||
nnn = nnn + '_' + FUSED_FUSED_SWEEP_STR_MAP[self.F_sweep]
|
||||
if self.F_saveunquant:
|
||||
nnn = nnn + '_saveunquant'
|
||||
return nnn
|
||||
|
||||
@property
|
||||
@@ -451,11 +476,11 @@ float rmsnorm2d_fwd(rmsnorm2d_fwd_traits t,
|
||||
if ins.F_kFusedQuant == 0:
|
||||
_sweep_cond = 't.fused_quant == {f_fused_sweep}'.format(f_fused_sweep = ins.F_kFusedQuant)
|
||||
elif ins.F_kFusedQuant == 1:
|
||||
_sweep_cond = 't.fused_quant == {f_fused_sweep} && (t.prec_sm == \"{f_sx_type}\" && t.prec_sy == \"{f_sy_type}\")'.format(
|
||||
f_fused_sweep = ins.F_kFusedQuant, f_sx_type=ins.F_SmoothScaleDataType, f_sy_type=ins.F_YScaleDataType)
|
||||
_sweep_cond = 't.fused_quant == {f_fused_sweep} && (t.prec_sm == \"{f_sx_type}\" && t.prec_sy == \"{f_sy_type}\" && t.save_unquant == {f_suq})'.format(
|
||||
f_fused_sweep = ins.F_kFusedQuant, f_sx_type=ins.F_SmoothScaleDataType, f_sy_type=ins.F_YScaleDataType, f_suq=BOOL_MAP(ins.F_kSaveUnquant))
|
||||
elif ins.F_kFusedQuant == 2:
|
||||
_sweep_cond = 't.fused_quant == {f_fused_sweep} && (t.prec_sy == \"{f_sy_type}\")'.format(
|
||||
f_fused_sweep = ins.F_kFusedQuant, f_sy_type=ins.F_YScaleDataType)
|
||||
_sweep_cond = 't.fused_quant == {f_fused_sweep} && (t.prec_sy == \"{f_sy_type}\" && t.save_unquant == {f_suq})'.format(
|
||||
f_fused_sweep = ins.F_kFusedQuant, f_sy_type=ins.F_YScaleDataType, f_suq=BOOL_MAP(ins.F_kSaveUnquant))
|
||||
_cond = '((a.n % {f_vec_n} == 0) && (t.fused_add == {f_fused_add}) && ({f_sweep_cond}))'.format(
|
||||
f_vec_n = ins.F_Vector_N, f_fused_add = ins.F_kFusedAdd,
|
||||
f_sweep_cond = _sweep_cond)
|
||||
@@ -489,67 +514,72 @@ float rmsnorm2d_fwd(rmsnorm2d_fwd_traits t,
|
||||
#fused_sweep_list = [0, 1, 2] # NOTE: only single pass can use fused (smooth) dynamic quant
|
||||
fused_add_list = [0, 1]
|
||||
fused_sweep_list = [0, 1, 2] # NOTE: only single pass can use fused (smooth) dynamic quant
|
||||
bool_list = [False, True]
|
||||
|
||||
# rm rn tm tn vn pd mv 2p add sweep
|
||||
h_trait_dict = {'64' : [ h_traits('x', 'y', 'xs', 'ys', 1, 1, 8, 8, 8, True, False, False, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 1, 1, 4, 16, 4, True, False, False, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 1, 1, 4, 64, 1, True, False, False, 0, 0)],
|
||||
'128' : [ h_traits('x', 'y', 'xs', 'ys', 1, 1, 4, 16, 8, True, False, False, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 1, 1, 4, 64, 2, True, False, False, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 1, 2, 4, 64, 1, True, False, False, 0, 0)],
|
||||
'256' : [ h_traits('x', 'y', 'xs', 'ys', 1, 1, 4, 64, 4, True, False, False, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 1, 2, 4, 64, 2, True, False, False, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 1, 4, 4, 64, 1, True, False, False, 0, 0)],
|
||||
'512' : [ h_traits('x', 'y', 'xs', 'ys', 1, 1, 4, 64, 8, True, False, False, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 1, 2, 4, 64, 4, True, False, False, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 1, 4, 4, 64, 2, True, False, False, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 1, 8, 4, 64, 1, True, False, False, 0, 0)],
|
||||
'768' : [ h_traits('x', 'y', 'xs', 'ys', 1, 3, 4, 64, 4, True, False, False, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 1, 6, 4, 64, 2, True, False, False, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 1, 12, 4, 64, 1, True, False, False, 0, 0)],
|
||||
'1024' :[ h_traits('x', 'y', 'xs', 'ys', 1, 1, 2, 128, 8, True, False, False, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 1, 2, 2, 128, 4, True, False, False, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 1, 4, 2, 128, 2, True, False, False, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 1, 4, 1, 256, 1, True, False, False, 0, 0)],
|
||||
'1536' :[ h_traits('x', 'y', 'xs', 'ys', 1, 3, 4, 64, 8, True, False, False, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 1, 3, 2, 128, 4, True, False, False, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 1, 3, 1, 256, 2, True, False, False, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 1, 6, 1, 256, 1, True, False, False, 0, 0)],
|
||||
'2048' :[ h_traits('x', 'y', 'xs', 'ys', 1, 1, 1, 256, 8, True, False, False, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 1, 2, 1, 256, 4, True, False, False, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 1, 4, 1, 256, 2, True, False, False, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 1, 8, 1, 256, 1, True, False, False, 0, 0)],
|
||||
'3072' :[ h_traits('x', 'y', 'xs', 'ys', 1, 3, 1, 128, 8, True, False, False, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 1, 3, 1, 256, 4, True, False, False, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 1, 6, 1, 256, 2, True, False, False, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 1, 3, 1,1024, 1, True, False, False, 0, 0)],
|
||||
'4096' :[ h_traits('x', 'y', 'xs', 'ys', 1, 2, 1, 256, 8, True, False, False, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 1, 4, 1, 256, 4, True, False, False, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 1, 2, 1,1024, 2, True, False, False, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 1, 4, 1,1024, 1, True, False, False, 0, 0)],
|
||||
'6144' :[ h_traits('x', 'y', 'xs', 'ys', 1, 3, 1, 256, 8, True, False, False, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 1, 3, 1, 512, 4, True, False, False, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 1, 3, 1,1024, 2, True, False, False, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 1, 6, 1,1024, 1, True, False, False, 0, 0)],
|
||||
'8192' :[ h_traits('x', 'y', 'xs', 'ys', 1, 4, 1, 256, 8, True, False, False, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 1, 4, 1, 512, 4, True, False, False, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 1, 4, 1,1024, 2, True, False, False, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 1, 8, 1,1024, 1, True, False, False, 0, 0)],
|
||||
'big' :[ h_traits('x', 'y', 'xs', 'ys', 1, 2, 1, 256, 8, True, False, True, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 1, 4, 1, 256, 4, True, False, True, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 1, 2, 1,1024, 2, True, False, True, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 1, 4, 1,1024, 1, True, False, True, 0, 0)]}
|
||||
# rm rn tm tn vn pd mv unquant 2p add sweep
|
||||
h_trait_dict = {'64' : [ h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 1, 8, 8, 8, True, False, False, False, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 1, 4, 16, 4, True, False, False, False, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 1, 4, 64, 1, True, False, False, False, 0, 0)],
|
||||
'128' : [ h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 1, 4, 16, 8, True, False, False, False, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 1, 4, 64, 2, True, False, False, False, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 2, 4, 64, 1, True, False, False, False, 0, 0)],
|
||||
'256' : [ h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 1, 4, 64, 4, True, False, False, False, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 2, 4, 64, 2, True, False, False, False, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 4, 4, 64, 1, True, False, False, False, 0, 0)],
|
||||
'512' : [ h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 1, 4, 64, 8, True, False, False, False, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 2, 4, 64, 4, True, False, False, False, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 4, 4, 64, 2, True, False, False, False, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 8, 4, 64, 1, True, False, False, False, 0, 0)],
|
||||
'640' : [ h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 5, 4, 64, 2, True, False, False, False, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 5, 4, 128, 1, True, False, False, False, 0, 0)],
|
||||
'768' : [ h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 3, 4, 64, 4, True, False, False, False, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 6, 4, 64, 2, True, False, False, False, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 12, 4, 64, 1, True, False, False, False, 0, 0)],
|
||||
'1024' :[ h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 2, 2, 64, 8, True, False, False, False, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 4, 2, 64, 4, True, False, False, False, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 8, 2, 64, 2, True, False, False, False, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 4, 1, 256, 1, True, False, False, False, 0, 0)],
|
||||
'1536' :[ h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 3, 4, 64, 8, True, False, False, False, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 3, 2, 128, 4, True, False, False, False, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 3, 1, 256, 2, True, False, False, False, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 6, 1, 256, 1, True, False, False, False, 0, 0)],
|
||||
'2048' :[ h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 1, 1, 256, 8, True, False, False, False, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 2, 1, 256, 4, True, False, False, False, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 4, 1, 256, 2, True, False, False, False, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 8, 1, 256, 1, True, False, False, False, 0, 0)],
|
||||
'3072' :[ h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 3, 1, 128, 8, True, False, False, False, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 3, 1, 256, 4, True, False, False, False, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 6, 1, 256, 2, True, False, False, False, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 3, 1,1024, 1, True, False, False, False, 0, 0)],
|
||||
'4096' :[ h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 2, 1, 256, 8, True, False, False, False, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 4, 1, 256, 4, True, False, False, False, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 2, 1,1024, 2, True, False, False, False, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 4, 1,1024, 1, True, False, False, False, 0, 0)],
|
||||
'6144' :[ h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 3, 1, 256, 8, True, False, False, False, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 3, 1, 512, 4, True, False, False, False, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 3, 1,1024, 2, True, False, False, False, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 6, 1,1024, 1, True, False, False, False, 0, 0)],
|
||||
'8192' :[ h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 4, 1, 256, 8, True, False, False, False, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 4, 1, 512, 4, True, False, False, False, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 4, 1,1024, 2, True, False, False, False, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 8, 1,1024, 1, True, False, False, False, 0, 0)],
|
||||
'big' :[ h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 1, 1,1024, 8, True, False, False, True, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 4, 1, 256, 4, True, False, False, True, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 12, 1, 256, 2, True, False, False, True, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 4, 1,1024, 1, True, False, False, True, 0, 0)]}
|
||||
total_blob = list()
|
||||
for hs_key in h_trait_dict:
|
||||
hs = h_trait_dict[hs_key]
|
||||
current_n = hs[0].F_Repeat_N * hs[0].F_ThreadPerBlock_N * hs[0].F_Vector_N
|
||||
for dtype, scale_type, fused_add, fused_quant in itertools.product(dtype_list, scale_list, fused_add_list, fused_sweep_list):
|
||||
for dtype, scale_type, fused_add, fused_quant, save_unquant in itertools.product(dtype_list, scale_list, fused_add_list, fused_sweep_list, bool_list):
|
||||
prec_i, prec_o = dtype.split(',')
|
||||
scale_sm, scale_y = scale_type.split(',')
|
||||
if prec_o in dynamic_quant_out_dtype and fused_quant != 1 and fused_quant != 2:
|
||||
continue # skip non dynamic quant case
|
||||
if (fused_quant == 1 or fused_quant == 2) and hs_key == 'big':
|
||||
continue
|
||||
if (fused_quant == 0 and save_unquant == True):
|
||||
continue # save_unquant should always be false when there is no quant enabled
|
||||
current_hs = list()
|
||||
for chs_ in hs:
|
||||
h_ = copy.copy(chs_) # copy the base instance out
|
||||
@@ -557,12 +587,14 @@ float rmsnorm2d_fwd(rmsnorm2d_fwd_traits t,
|
||||
h_.F_YDataType = prec_o
|
||||
h_.F_SmoothScaleDataType = scale_sm
|
||||
h_.F_YScaleDataType = scale_y
|
||||
h_.F_UnquantYDataType = prec_i
|
||||
h_.F_kFusedAdd = fused_add
|
||||
h_.F_kFusedQuant = fused_quant
|
||||
h_.F_kSaveUnquant = save_unquant
|
||||
current_hs.append(h_) # + "\n"
|
||||
#f.write(str(f.parent / GEN_DIR / (blobs.api_common_header_
|
||||
current_n_str = 'big' if hs_key == 'big' else current_n
|
||||
total_blob.append(h_instance(dtype, current_n_str, fused_add, fused_quant, current_hs))
|
||||
total_blob.append(h_instance(dtype, current_n_str, fused_add, fused_quant, save_unquant, current_hs))
|
||||
return total_blob
|
||||
|
||||
def list_blobs(self) -> None:
|
||||
|
||||
@@ -38,6 +38,7 @@ auto create_args(int argc, char* argv[])
|
||||
.insert("yr_stride", "-1", "y residule row_stride, if -1 then equal to n")
|
||||
.insert("e", "1e-5", "epsilon")
|
||||
.insert("save_rms", "0", "save rms(invrms) or not. set to 1 in training case")
|
||||
.insert("save_unquant", "0", "save result before quant")
|
||||
.insert("v", "1", "cpu validation or not")
|
||||
.insert("kname", "1", "print kernel name or not")
|
||||
.insert("prec_i", "fp16", "input precision")
|
||||
@@ -61,7 +62,8 @@ template <typename InDataType,
|
||||
typename OutDataType,
|
||||
typename SmoothScaleDataType,
|
||||
typename YScaleDataType,
|
||||
bool SaveRms>
|
||||
bool SaveRms,
|
||||
bool SaveUnquant>
|
||||
bool run(const ck_tile::ArgParser& arg_parser)
|
||||
{
|
||||
ck_tile::index_t m = arg_parser.get_int("m");
|
||||
@@ -113,6 +115,14 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
return false;
|
||||
}
|
||||
|
||||
if((fused_quant == 0) && SaveUnquant)
|
||||
{
|
||||
std::cout
|
||||
<< "save_unquant should be 0 if quant output is not enabled because it is meaningless. "
|
||||
<< "Output Y is what wanted." << std::endl;
|
||||
return false;
|
||||
}
|
||||
|
||||
using TypeConfig =
|
||||
RmsnormTypeConfig<InDataType, OutDataType, SmoothScaleDataType, YScaleDataType>;
|
||||
|
||||
@@ -124,6 +134,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
|
||||
using InvRmsDataType =
|
||||
std::conditional_t<SaveRms, typename TypeConfig::InvRmsDataType, ck_tile::null_type>;
|
||||
using UnquantYDataType =
|
||||
std::conditional_t<SaveUnquant, typename TypeConfig::UnquantYDataType, ck_tile::null_type>;
|
||||
|
||||
using ComputeDataType = typename TypeConfig::ComputeDataType;
|
||||
|
||||
@@ -143,6 +155,10 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
|
||||
ck_tile::HostTensor<InvRmsDataType> invRms_host_ref({m});
|
||||
|
||||
ck_tile::HostTensor<UnquantYDataType> unquant_y_host_ref({m, n}, {y_stride, 1});
|
||||
ck_tile::HostTensor<UnquantYDataType> unquant_y_host_dev({m, n}, {y_stride, 1});
|
||||
ck_tile::HostTensor<ck_tile::null_type> unquant_y_null({1});
|
||||
|
||||
ck_tile::FillUniformDistribution<XDataType>{-.5f, .5f}(x_host);
|
||||
ck_tile::FillUniformDistribution<XResidualDataType>{-.5f, .5f}(x_residual_host);
|
||||
ck_tile::FillUniformDistribution<SmoothScaleDataType>{-1.f, 1.f}(sm_scale_host);
|
||||
@@ -155,6 +171,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
ck_tile::DeviceMem sm_scale_buf(sm_scale_host_dev.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem x_residual_buf(x_residual_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem y_residual_buf(y_residual_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem unquant_y_buf(unquant_y_host_dev.get_element_space_size_in_bytes());
|
||||
|
||||
x_buf.ToDevice(x_host.data());
|
||||
gamma_buf.ToDevice(gamma_host.data());
|
||||
@@ -179,7 +196,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
<< ", xr_stride:" << xr_stride << ", y_stride:" << y_stride
|
||||
<< ", yr_stride:" << yr_stride << std::flush;
|
||||
|
||||
rmsnorm2d_fwd_traits traits{prec_i, prec_o, prec_sm, prec_sy, SaveRms, fused_add, fused_quant};
|
||||
rmsnorm2d_fwd_traits traits{
|
||||
prec_i, prec_o, prec_sm, prec_sy, SaveRms, SaveUnquant, fused_add, fused_quant};
|
||||
|
||||
rmsnorm2d_fwd_args args{x_buf.GetDeviceBuffer(),
|
||||
fused_add != 0 ? x_residual_buf.GetDeviceBuffer() : nullptr,
|
||||
@@ -189,6 +207,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
fused_add == 1 ? y_residual_buf.GetDeviceBuffer() : nullptr,
|
||||
fused_quant != 0 ? y_scale_buf.GetDeviceBuffer() : nullptr,
|
||||
nullptr, // p_invRms, unsupported yet
|
||||
SaveUnquant ? unquant_y_buf.GetDeviceBuffer() : nullptr,
|
||||
epsilon,
|
||||
m,
|
||||
n,
|
||||
@@ -203,6 +222,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
std::size_t num_byte =
|
||||
sizeof(XDataType) * m * n + sizeof(GammaDataType) * n + sizeof(YDataType) * m * n;
|
||||
num_byte += SaveRms ? sizeof(InvRmsDataType) * m * n : 0;
|
||||
num_byte += SaveUnquant ? sizeof(UnquantYDataType) * m * n : 0;
|
||||
num_byte += fused_add ? sizeof(XResidualDataType) * m * n : 0;
|
||||
num_byte += ((fused_quant == 1) || (fused_quant == 2)) ? sizeof(YScaleDataType) * m : 0;
|
||||
num_byte += (fused_quant == 1) ? sizeof(SmoothScaleDataType) * n : 0;
|
||||
@@ -262,21 +282,57 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
}
|
||||
};
|
||||
|
||||
ck_tile::reference_rmsnorm2d_fwd<XDataType,
|
||||
GammaDataType,
|
||||
ComputeDataType,
|
||||
YDataType,
|
||||
InvRmsDataType>(
|
||||
x_host, gamma_host, y_host_ref, invRms_host_ref, epsilon, dquant_functor);
|
||||
auto default_and_dquant_functor = [&](int m_, auto& o_unquant_, auto& o_, auto& acc_) {
|
||||
const int N = acc_.mDesc.get_lengths()[1];
|
||||
for(int n_ = 0; n_ < N; ++n_)
|
||||
{
|
||||
o_unquant_(m_, n_) = ck_tile::type_convert<OutDataType>(acc_(m_, n_));
|
||||
}
|
||||
|
||||
dquant_functor(m_, o_, acc_);
|
||||
};
|
||||
|
||||
if constexpr(SaveUnquant)
|
||||
{
|
||||
ck_tile::reference_rmsnorm2d_fwd<XDataType,
|
||||
GammaDataType,
|
||||
ComputeDataType,
|
||||
YDataType,
|
||||
InvRmsDataType,
|
||||
UnquantYDataType>(x_host,
|
||||
gamma_host,
|
||||
y_host_ref,
|
||||
invRms_host_ref,
|
||||
unquant_y_host_ref,
|
||||
epsilon,
|
||||
default_and_dquant_functor);
|
||||
}
|
||||
else
|
||||
{
|
||||
ck_tile::reference_rmsnorm2d_fwd<XDataType,
|
||||
GammaDataType,
|
||||
ComputeDataType,
|
||||
YDataType,
|
||||
InvRmsDataType,
|
||||
UnquantYDataType>(x_host,
|
||||
gamma_host,
|
||||
y_host_ref,
|
||||
invRms_host_ref,
|
||||
unquant_y_host_ref,
|
||||
epsilon,
|
||||
dquant_functor);
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
assert(SaveUnquant == false);
|
||||
ck_tile::reference_rmsnorm2d_fwd<XDataType,
|
||||
GammaDataType,
|
||||
ComputeDataType,
|
||||
YDataType,
|
||||
InvRmsDataType>(
|
||||
x_host, gamma_host, y_host_ref, invRms_host_ref, epsilon);
|
||||
InvRmsDataType,
|
||||
ck_tile::null_type>(
|
||||
x_host, gamma_host, y_host_ref, invRms_host_ref, unquant_y_null, epsilon);
|
||||
}
|
||||
|
||||
y_buf.FromDevice(y_host_dev.data());
|
||||
@@ -293,6 +349,15 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
pass = ck_tile::check_err(
|
||||
y_host_dev, y_host_ref, std::string("\nOUT Error: Incorrect results!"), rtol, atol);
|
||||
|
||||
if constexpr(SaveUnquant)
|
||||
{
|
||||
pass &= ck_tile::check_err(unquant_y_host_dev,
|
||||
unquant_y_host_ref,
|
||||
std::string("\n OUT ERROR: Incorrect unquant results!"),
|
||||
rtol,
|
||||
atol);
|
||||
}
|
||||
|
||||
if(fused_add == 1)
|
||||
{
|
||||
pass &= ck_tile::check_err(y_residual_host_dev,
|
||||
@@ -331,6 +396,23 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
rtol,
|
||||
atol);
|
||||
}
|
||||
|
||||
if constexpr(SaveUnquant)
|
||||
{
|
||||
std::vector<UnquantYDataType> unquant_y_host_dev_row(
|
||||
unquant_y_host_dev.begin() + i_r * y_stride,
|
||||
unquant_y_host_dev.begin() + i_r * y_stride + n);
|
||||
std::vector<UnquantYDataType> unquant_y_host_ref_row(
|
||||
unquant_y_host_ref.begin() + i_r * y_stride,
|
||||
unquant_y_host_ref.begin() + i_r * y_stride + n);
|
||||
pass &=
|
||||
ck_tile::check_err(unquant_y_host_dev_row,
|
||||
unquant_y_host_ref_row,
|
||||
std::string("\nOUT[") + std::to_string(i_r) +
|
||||
std::string("] Error: Incorrect unquant y results!"),
|
||||
rtol,
|
||||
atol);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -350,6 +432,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
return pass;
|
||||
}
|
||||
|
||||
bool is_quant_data_type(const std::string& prec) { return (prec == "int8") || (prec == "fp8"); }
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
auto [result, arg_parser] = create_args(argc, argv);
|
||||
@@ -373,48 +457,79 @@ int main(int argc, char* argv[])
|
||||
prec_sy = "fp32";
|
||||
}
|
||||
|
||||
int save_rms = arg_parser.get_int("save_rms");
|
||||
int save_rms = arg_parser.get_int("save_rms");
|
||||
int fused_quant = arg_parser.get_int("fquant");
|
||||
int save_unquant =
|
||||
arg_parser.get_int("save_unquant") && is_quant_data_type(prec_o) && (fused_quant != 0);
|
||||
|
||||
if(prec_i == "fp16" && prec_o == "fp16" && prec_sm == "fp32" && prec_sy == "fp32" && save_rms)
|
||||
{
|
||||
return run<ck_tile::half_t, ck_tile::half_t, float, float, true>(arg_parser) ? 0 : -2;
|
||||
return run<ck_tile::half_t, ck_tile::half_t, float, float, true, false>(arg_parser) ? 0
|
||||
: -2;
|
||||
}
|
||||
else if(prec_i == "fp16" && prec_o == "fp16" && prec_sm == "fp32" && prec_sy == "fp32" &&
|
||||
!save_rms)
|
||||
{
|
||||
return run<ck_tile::half_t, ck_tile::half_t, float, float, false>(arg_parser) ? 0 : -2;
|
||||
return run<ck_tile::half_t, ck_tile::half_t, float, float, false, false>(arg_parser) ? 0
|
||||
: -2;
|
||||
}
|
||||
else if(prec_i == "bf16" && prec_o == "bf16" && prec_sm == "fp32" && prec_sy == "fp32" &&
|
||||
save_rms)
|
||||
{
|
||||
return run<ck_tile::bf16_t, ck_tile::bf16_t, float, float, true>(arg_parser) ? 0 : -2;
|
||||
return run<ck_tile::bf16_t, ck_tile::bf16_t, float, float, true, false>(arg_parser) ? 0
|
||||
: -2;
|
||||
}
|
||||
else if(prec_i == "bf16" && prec_o == "bf16" && prec_sm == "fp32" && prec_sy == "fp32" &&
|
||||
!save_rms)
|
||||
{
|
||||
return run<ck_tile::bf16_t, ck_tile::bf16_t, float, float, true>(arg_parser) ? 0 : -2;
|
||||
return run<ck_tile::bf16_t, ck_tile::bf16_t, float, float, false, false>(arg_parser) ? 0
|
||||
: -2;
|
||||
}
|
||||
|
||||
// dynamic quant case, only in inference
|
||||
else if(prec_i == "fp16" && prec_o == "int8" && prec_sm == "fp32" && prec_sy == "fp32" &&
|
||||
!save_rms)
|
||||
!save_rms && !save_unquant)
|
||||
{
|
||||
return run<ck_tile::half_t, ck_tile::int8_t, float, float, true>(arg_parser) ? 0 : -2;
|
||||
return run<ck_tile::half_t, ck_tile::int8_t, float, float, true, false>(arg_parser) ? 0
|
||||
: -2;
|
||||
}
|
||||
else if(prec_i == "bf16" && prec_o == "int8" && prec_sm == "fp32" && prec_sy == "fp32" &&
|
||||
!save_rms)
|
||||
!save_rms && !save_unquant)
|
||||
{
|
||||
return run<ck_tile::bf16_t, ck_tile::int8_t, float, float, true>(arg_parser) ? 0 : -2;
|
||||
return run<ck_tile::bf16_t, ck_tile::int8_t, float, float, true, false>(arg_parser) ? 0
|
||||
: -2;
|
||||
}
|
||||
else if(prec_i == "fp16" && prec_o == "fp8" && prec_sm == "fp32" && prec_sy == "fp32" &&
|
||||
!save_rms)
|
||||
!save_rms && !save_unquant)
|
||||
{
|
||||
return run<ck_tile::half_t, ck_tile::fp8_t, float, float, false>(arg_parser) ? 0 : -2;
|
||||
return run<ck_tile::half_t, ck_tile::fp8_t, float, float, false, false>(arg_parser) ? 0
|
||||
: -2;
|
||||
}
|
||||
else if(prec_i == "bf16" && prec_o == "fp8" && prec_sm == "fp32" && prec_sy == "fp32" &&
|
||||
!save_rms)
|
||||
!save_rms && !save_unquant)
|
||||
{
|
||||
return run<ck_tile::bf16_t, ck_tile::fp8_t, float, float, false>(arg_parser) ? 0 : -2;
|
||||
return run<ck_tile::bf16_t, ck_tile::fp8_t, float, float, false, false>(arg_parser) ? 0
|
||||
: -2;
|
||||
}
|
||||
else if(prec_i == "fp16" && prec_o == "int8" && prec_sm == "fp32" && prec_sy == "fp32" &&
|
||||
!save_rms && save_unquant)
|
||||
{
|
||||
return run<ck_tile::half_t, ck_tile::int8_t, float, float, true, true>(arg_parser) ? 0 : -2;
|
||||
}
|
||||
else if(prec_i == "bf16" && prec_o == "int8" && prec_sm == "fp32" && prec_sy == "fp32" &&
|
||||
!save_rms && save_unquant)
|
||||
{
|
||||
return run<ck_tile::bf16_t, ck_tile::int8_t, float, float, true, true>(arg_parser) ? 0 : -2;
|
||||
}
|
||||
else if(prec_i == "fp16" && prec_o == "fp8" && prec_sm == "fp32" && prec_sy == "fp32" &&
|
||||
!save_rms && save_unquant)
|
||||
{
|
||||
return run<ck_tile::half_t, ck_tile::fp8_t, float, float, false, true>(arg_parser) ? 0 : -2;
|
||||
}
|
||||
else if(prec_i == "bf16" && prec_o == "fp8" && prec_sm == "fp32" && prec_sy == "fp32" &&
|
||||
!save_rms && save_unquant)
|
||||
{
|
||||
return run<ck_tile::bf16_t, ck_tile::fp8_t, float, float, false, true>(arg_parser) ? 0 : -2;
|
||||
}
|
||||
|
||||
return -3;
|
||||
|
||||
@@ -21,6 +21,7 @@ struct RmsnormTypeConfig<ck_tile::half_t, OutType, SmoothScaleDataType_, YScaleD
|
||||
using YDataType = OutType;
|
||||
using GammaDataType = ck_tile::half_t;
|
||||
using InvRmsDataType = ck_tile::half_t;
|
||||
using UnquantYDataType = ck_tile::half_t;
|
||||
using ComputeDataType = float;
|
||||
using SmoothScaleDataType = SmoothScaleDataType_;
|
||||
using YScaleDataType = YScaleDataType_;
|
||||
@@ -33,6 +34,7 @@ struct RmsnormTypeConfig<ck_tile::bf16_t, OutType, SmoothScaleDataType_, YScaleD
|
||||
using YDataType = OutType;
|
||||
using GammaDataType = ck_tile::bf16_t;
|
||||
using InvRmsDataType = ck_tile::bf16_t;
|
||||
using UnquantYDataType = ck_tile::bf16_t;
|
||||
using ComputeDataType = float;
|
||||
using SmoothScaleDataType = SmoothScaleDataType_;
|
||||
using YScaleDataType = YScaleDataType_;
|
||||
@@ -59,6 +61,7 @@ struct rmsnorm2d_fwd_traits
|
||||
std::string prec_sy; // y-scale, used for [M*1] output for next layer
|
||||
|
||||
bool save_rms;
|
||||
bool save_unquant;
|
||||
int fused_add; // 0:no-add, 1:pre-add-store, 2:pre-add
|
||||
int fused_quant; // 0:no-sweep, 1:smooth-dynamic-quant, 2:dynamic-quant
|
||||
};
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
#!/bin/sh
|
||||
EXE="$(find . -name tile_rmsnorm2d_fwd -type f | head -n 1)"
|
||||
|
||||
for fquant in "" "-fquant=1 -prec_o=int8" "-fquant=2 -prec_o=int8" "-fquant=1 -prec_o=fp8" "-fquant=2 -prec_o=fp8"; do
|
||||
for fquant in "" "-fquant=1 -prec_o=int8" "-fquant=2 -prec_o=int8" "-fquant=1 -prec_o=fp8" "-fquant=2 -prec_o=fp8"\
|
||||
"-fquant=1 -prec_o=int8 -save_unquant=1" "-fquant=2 -prec_o=int8 -save_unquant=1" "-fquant=1 -prec_o=fp8 -save_unquant=1" "-fquant=2 -prec_o=fp8 -save_unquant=1"; do
|
||||
for pr_i in "fp16" "bf16" ; do
|
||||
for fadd in "0" "1"; do
|
||||
$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=99 -n=13
|
||||
@@ -27,6 +28,14 @@ $EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=7 -n=2734
|
||||
$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=1 -n=3182
|
||||
$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=9 -n=4096
|
||||
$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=3 -n=8192
|
||||
done
|
||||
done
|
||||
done
|
||||
|
||||
# The following cases uses two pass pipeline which doesn't support quant epilogue.
|
||||
for fquant in ""
|
||||
for pr_i in "fp16" "bf16" ; do
|
||||
for fadd in "0" "1"; do
|
||||
$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=1 -n=10547
|
||||
#$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=3 -n=17134
|
||||
done
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -35,11 +35,13 @@ template <typename XDataType,
|
||||
typename ComputeDataType,
|
||||
typename YDataType,
|
||||
typename InvRmsDataType,
|
||||
typename UnquantYDataType,
|
||||
typename Epilogue = reference_rmsnorm2d_default_epilogue>
|
||||
void reference_rmsnorm2d_fwd(const HostTensor<XDataType>& x_m_n,
|
||||
const HostTensor<GammaDataType>& gamma_n,
|
||||
HostTensor<YDataType>& y_m_n,
|
||||
HostTensor<InvRmsDataType>& invRms_m,
|
||||
HostTensor<UnquantYDataType>& unquant_y_m_n,
|
||||
ComputeDataType epsilon,
|
||||
Epilogue epilogue_functor = {})
|
||||
{
|
||||
@@ -69,7 +71,14 @@ void reference_rmsnorm2d_fwd(const HostTensor<XDataType>& x_m_n,
|
||||
acc(m, n) = x * divisor * gamma;
|
||||
}
|
||||
|
||||
epilogue_functor(m, y_m_n, acc);
|
||||
if constexpr(!std::is_same_v<UnquantYDataType, ck_tile::null_type>)
|
||||
{
|
||||
epilogue_functor(m, unquant_y_m_n, y_m_n, acc);
|
||||
}
|
||||
else
|
||||
{
|
||||
epilogue_functor(m, y_m_n, acc);
|
||||
}
|
||||
};
|
||||
|
||||
make_ParallelTensorFunctor(rmsnorm2d_fwd_func, invRms_m.mDesc.get_lengths()[0])(
|
||||
|
||||
@@ -6,6 +6,7 @@
|
||||
#include "ck_tile/ops/epilogue/cshuffle_epilogue.hpp"
|
||||
#include "ck_tile/ops/epilogue/default_2d_epilogue.hpp"
|
||||
#include "ck_tile/ops/epilogue/dynamic_quant_epilogue.hpp"
|
||||
#include "ck_tile/ops/epilogue/default_2d_and_dynamic_quant_epilogue.hpp"
|
||||
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
|
||||
#include "ck_tile/ops/common/tensor_layout.hpp"
|
||||
#include "ck_tile/ops/common/utils.hpp"
|
||||
|
||||
@@ -0,0 +1,91 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "default_2d_epilogue.hpp"
|
||||
#include "dynamic_quant_epilogue.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
// User can reuse DynamicQuantEpilogueTraits with this epilogue
|
||||
template <bool kPadM_,
|
||||
bool kPadN_,
|
||||
bool UseSmoothInputScale_,
|
||||
bool UseRawStore_ = true,
|
||||
bool UseMax3_ = false>
|
||||
using Default2DAndDynamicQuantEpilogueTraits =
|
||||
DynamicQuantEpilogueTraits<kPadM_, kPadN_, UseSmoothInputScale_, UseRawStore_, UseMax3_>;
|
||||
|
||||
// This epilogue just store out a M*N matrix, row major
|
||||
template <typename AccDataType_,
|
||||
typename SmoothScaleDataType_,
|
||||
typename YScaleDataType_,
|
||||
typename ODataType_,
|
||||
typename UnquantYDataType_,
|
||||
typename BlockShape_,
|
||||
typename Traits_>
|
||||
struct Default2DAndDynamicQuantEpilogueProblem
|
||||
{
|
||||
using AccDataType = remove_cvref_t<AccDataType_>;
|
||||
using SmoothScaleDataType = remove_cvref_t<SmoothScaleDataType_>;
|
||||
using YScaleDataType = remove_cvref_t<YScaleDataType_>;
|
||||
using ODataType = remove_cvref_t<ODataType_>;
|
||||
using UnquantYDataType = remove_cvref_t<UnquantYDataType_>;
|
||||
using BlockShape = remove_cvref_t<BlockShape_>; // can consum generic 2d shape
|
||||
using Traits = remove_cvref_t<Traits_>;
|
||||
};
|
||||
|
||||
template <typename Problem_, typename Policy_ = void>
|
||||
struct Default2DAndDynamicQuantEpilogue
|
||||
{
|
||||
using Problem = remove_cvref_t<Problem_>;
|
||||
using AccDataType = remove_cvref_t<typename Problem::AccDataType>;
|
||||
using UnquantYDataType = remove_cvref_t<typename Problem::UnquantYDataType>;
|
||||
|
||||
static constexpr bool kPadM = Problem::Traits::kPadM;
|
||||
static constexpr bool kPadN = Problem::Traits::kPadN;
|
||||
static constexpr bool UseRawStore = Problem::Traits::UseRawStore;
|
||||
|
||||
using Default2DProblem =
|
||||
Default2DEpilogueProblem<AccDataType, UnquantYDataType, kPadM, kPadN, UseRawStore>;
|
||||
using Default2D = Default2DEpilogue<Default2DProblem>;
|
||||
using DynamicQuant = DynamicQuantEpilogue<Problem>;
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
|
||||
{
|
||||
return max(Default2D::GetSmemSize(), DynamicQuant::GetSmemSize());
|
||||
}
|
||||
|
||||
template <typename ODramWindowTmpD,
|
||||
typename ODramWindowTmpQ,
|
||||
typename SmoothScaleWindow,
|
||||
typename YScaleWindow,
|
||||
typename OAccTile>
|
||||
CK_TILE_DEVICE auto operator()(ODramWindowTmpD& o_direct_dram_window_tmp,
|
||||
ODramWindowTmpQ& o_quant_dram_window_tmp,
|
||||
const SmoothScaleWindow& sm_scale_window_,
|
||||
YScaleWindow& y_scale_window,
|
||||
const OAccTile& o_acc_tile,
|
||||
void* smem)
|
||||
{
|
||||
Default2D{}(o_direct_dram_window_tmp, o_acc_tile, smem);
|
||||
DynamicQuant{}(o_quant_dram_window_tmp, sm_scale_window_, y_scale_window, o_acc_tile, smem);
|
||||
}
|
||||
|
||||
template <typename ODramWindowTmpD,
|
||||
typename ODramWindowTmpQ,
|
||||
typename YScaleWindow,
|
||||
typename OAccTile>
|
||||
CK_TILE_DEVICE auto operator()(ODramWindowTmpD& o_direct_dram_window_tmp,
|
||||
ODramWindowTmpQ& o_quant_dram_window_tmp,
|
||||
YScaleWindow& y_scale_window,
|
||||
const OAccTile& o_acc_tile,
|
||||
void* smem)
|
||||
{
|
||||
Default2D{}(o_direct_dram_window_tmp, o_acc_tile, smem);
|
||||
DynamicQuant{}(o_quant_dram_window_tmp, y_scale_window, o_acc_tile, smem);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -182,9 +182,16 @@ struct Layernorm2dFwdPipelineTwoPass
|
||||
ck_tile::index_t stride_to_right_most_window =
|
||||
row_size % Block_N == 0 ? row_size - Block_N : row_size - row_size % Block_N;
|
||||
|
||||
move_tile_window(x_window, {0, -Block_N});
|
||||
move_tile_window(x_residual_window, {0, -Block_N});
|
||||
move_tile_window(x_bias_window, {-Block_N});
|
||||
if constexpr(kFusedAdd == Layernorm2dFusedAddEnum::PRE_ADD_STORE)
|
||||
{
|
||||
move_tile_window(y_residual_window, {0, -Block_N});
|
||||
}
|
||||
else
|
||||
{
|
||||
move_tile_window(x_window, {0, -Block_N});
|
||||
move_tile_window(x_residual_window, {0, -Block_N});
|
||||
move_tile_window(x_bias_window, {-Block_N});
|
||||
}
|
||||
move_tile_window(gamma_window, {stride_to_right_most_window});
|
||||
move_tile_window(beta_window, {stride_to_right_most_window});
|
||||
move_tile_window(y_window, {0, stride_to_right_most_window});
|
||||
@@ -192,28 +199,43 @@ struct Layernorm2dFwdPipelineTwoPass
|
||||
// layernorm computation
|
||||
for(int iN = __builtin_amdgcn_readfirstlane(0); iN < num_n_tile_iteration; ++iN)
|
||||
{
|
||||
auto x = load_tile(x_window);
|
||||
auto x_resi = load_tile(x_residual_window);
|
||||
const auto x_bias = load_tile(x_bias_window);
|
||||
auto acc = cast_tile<ComputeDataType>(x);
|
||||
auto acc = make_static_distributed_tensor<ComputeDataType>(
|
||||
decltype(load_tile(x_window))::get_tile_distribution());
|
||||
|
||||
if constexpr(kXbias == Layernorm2dXBiasEnum::ADD_BIAS)
|
||||
if constexpr(kFusedAdd == Layernorm2dFusedAddEnum::PRE_ADD_STORE)
|
||||
{
|
||||
sweep_tile(x, [&](auto idx) {
|
||||
// compute x = bias + x
|
||||
constexpr auto j_idx = make_tuple(idx[number<1>{}]);
|
||||
acc(idx) = type_convert<ComputeDataType>(x_bias[j_idx]) + acc(idx);
|
||||
});
|
||||
acc = cast_tile<ComputeDataType>(load_tile(y_residual_window));
|
||||
move_tile_window(y_residual_window, {0, -Block_N});
|
||||
}
|
||||
else
|
||||
{
|
||||
acc = cast_tile<ComputeDataType>(load_tile(x_window));
|
||||
move_tile_window(x_window, {0, -Block_N});
|
||||
|
||||
if constexpr(kXbias == Layernorm2dXBiasEnum::ADD_BIAS)
|
||||
{
|
||||
const auto x_bias = load_tile(x_bias_window);
|
||||
move_tile_window(x_bias_window, {-Block_N});
|
||||
|
||||
sweep_tile(acc, [&](auto idx) {
|
||||
// compute x = bias + x
|
||||
constexpr auto j_idx = make_tuple(idx[number<1>{}]);
|
||||
acc(idx) = type_convert<ComputeDataType>(x_bias[j_idx]) + acc(idx);
|
||||
});
|
||||
}
|
||||
|
||||
if constexpr(kFusedAdd == Layernorm2dFusedAddEnum::PRE_ADD)
|
||||
{
|
||||
auto x_resi = load_tile(x_residual_window);
|
||||
move_tile_window(x_residual_window, {0, -Block_N});
|
||||
|
||||
sweep_tile(x_resi, [&](auto idx) {
|
||||
// compute x = x_resi + x
|
||||
acc(idx) = type_convert<ComputeDataType>(x_resi(idx)) + acc(idx);
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
if constexpr(kFusedAdd == Layernorm2dFusedAddEnum::PRE_ADD_STORE ||
|
||||
kFusedAdd == Layernorm2dFusedAddEnum::PRE_ADD)
|
||||
{
|
||||
sweep_tile(x_resi, [&](auto idx) {
|
||||
// compute x = x_resi + x
|
||||
acc(idx) = type_convert<ComputeDataType>(x_resi(idx)) + acc(idx);
|
||||
});
|
||||
}
|
||||
// load gamma/beta (TODO: support no gamma/beta?)
|
||||
const auto gamma = load_tile(gamma_window);
|
||||
const auto beta = load_tile(beta_window);
|
||||
@@ -235,9 +257,6 @@ struct Layernorm2dFwdPipelineTwoPass
|
||||
static_assert(kFusedQuant != Layernorm2dFusedQuantEnum::DYNAMIC_QUANT);
|
||||
Epilogue{}(y_window, ln);
|
||||
|
||||
move_tile_window(x_window, {0, -Block_N});
|
||||
move_tile_window(x_residual_window, {0, -Block_N});
|
||||
move_tile_window(x_bias_window, {-Block_N});
|
||||
move_tile_window(gamma_window, {-Block_N});
|
||||
move_tile_window(beta_window, {-Block_N});
|
||||
move_tile_window(y_window, {0, -Block_N});
|
||||
|
||||
@@ -21,6 +21,7 @@ struct Rmsnorm2dFwdHostArgs
|
||||
void* p_y_residual; // [m, n], shortcut output, prec same as input, nullptr if not used
|
||||
void* p_y_scale; // [m, 1], output a dynamic quant per row, nullptr if not used
|
||||
void* p_invRms; // [m, 1], output inv-rms, prec same as input, nullptr if not used
|
||||
void* p_y_unquant; // [m, n], output result before quant, nullptr if not used
|
||||
|
||||
float epsilon;
|
||||
|
||||
@@ -47,13 +48,15 @@ struct Rmsnorm2dFwd
|
||||
using InvRmsDataType = remove_cvref_t<typename Problem::InvRmsDataType>;
|
||||
using SmoothScaleDataType = remove_cvref_t<typename Problem::SmoothScaleDataType>;
|
||||
using YScaleDataType = remove_cvref_t<typename Problem::YScaleDataType>;
|
||||
using UnquantYDataType = remove_cvref_t<typename Problem::UnquantYDataType>;
|
||||
|
||||
// for simplicity, shortcut input/output type is same as X
|
||||
using XResidualDataType = XDataType;
|
||||
using YResidualDataType = XDataType;
|
||||
|
||||
static constexpr bool kHasGamma = !std::is_same_v<GammaDataType, null_type>;
|
||||
static constexpr bool kSaveInvRms = Problem::Traits::kSaveInvRms;
|
||||
static constexpr bool kHasGamma = !std::is_same_v<GammaDataType, null_type>;
|
||||
static constexpr bool kSaveInvRms = Problem::Traits::kSaveInvRms;
|
||||
static constexpr bool kSaveUnquant = Problem::Traits::kSaveUnquant;
|
||||
|
||||
static constexpr index_t Block_M = Problem::BlockShape::Block_M;
|
||||
static constexpr index_t Block_N = Problem::BlockShape::Block_N;
|
||||
@@ -81,6 +84,7 @@ struct Rmsnorm2dFwd
|
||||
void* p_y_residual;
|
||||
void* p_y_scale;
|
||||
void* p_invRms;
|
||||
void* p_y_unquant;
|
||||
|
||||
float epsilon;
|
||||
|
||||
@@ -103,6 +107,7 @@ struct Rmsnorm2dFwd
|
||||
hargs.p_y_residual,
|
||||
hargs.p_y_scale,
|
||||
hargs.p_invRms,
|
||||
hargs.p_y_unquant,
|
||||
hargs.epsilon,
|
||||
hargs.m,
|
||||
hargs.n,
|
||||
@@ -323,6 +328,30 @@ struct Rmsnorm2dFwd
|
||||
}
|
||||
}();
|
||||
|
||||
auto unquant_y_window = [&]() {
|
||||
if constexpr((kFusedQuant == Rmsnorm2dFusedQuantEnum::SMOOTH_DYNAMIC_QUANT ||
|
||||
kFusedQuant == Rmsnorm2dFusedQuantEnum::DYNAMIC_QUANT) &&
|
||||
kSaveUnquant)
|
||||
{
|
||||
auto tmp_ = make_naive_tensor_view<address_space_enum::global>(
|
||||
static_cast<UnquantYDataType*>(kargs.p_y_unquant),
|
||||
make_tuple(kargs.m, kargs.n),
|
||||
make_tuple(kargs.y_stride, 1),
|
||||
number<Vector_N>{},
|
||||
number<1>{});
|
||||
|
||||
auto tmp2_ = pad_tensor_view(tmp_,
|
||||
make_tuple(number<Block_M>{}, number<Block_N>{}),
|
||||
sequence<kPadM, kPadN>{});
|
||||
return make_tile_window(
|
||||
tmp2_, make_tuple(number<Block_M>{}, number<Block_N>{}), {iM, 0});
|
||||
}
|
||||
else
|
||||
{
|
||||
return make_null_tile_window(make_tuple(number<Block_M>{}, number<Block_N>{}));
|
||||
}
|
||||
}();
|
||||
|
||||
__shared__ char smem[GetSmemSize()];
|
||||
|
||||
Pipeline{}(x_window,
|
||||
@@ -333,6 +362,7 @@ struct Rmsnorm2dFwd
|
||||
inv_rms_window,
|
||||
sm_scale_window,
|
||||
y_scale_window,
|
||||
unquant_y_window,
|
||||
static_cast<const ComputeDataType>(kargs.epsilon),
|
||||
kargs.n,
|
||||
smem,
|
||||
|
||||
@@ -25,8 +25,9 @@ struct Rmsnorm2dFwdPipelineOnePass
|
||||
using XResidualDataType = XDataType;
|
||||
using YResidualDataType = XDataType;
|
||||
|
||||
static constexpr bool kHasGamma = !std::is_same_v<GammaDataType, ck_tile::null_type>;
|
||||
static constexpr bool kSaveInvRms = Problem::Traits::kSaveInvRms;
|
||||
static constexpr bool kHasGamma = !std::is_same_v<GammaDataType, ck_tile::null_type>;
|
||||
static constexpr bool kSaveInvRms = Problem::Traits::kSaveInvRms;
|
||||
static constexpr bool kSaveUnquant = Problem::Traits::kSaveUnquant;
|
||||
|
||||
static constexpr bool kNeedCrossWarpSync = Problem::kNeedCrossWarpSync;
|
||||
static constexpr bool kPadM = false; // TODO - BlockRmsnorm2dFwdProblem::kPadM
|
||||
@@ -54,6 +55,7 @@ struct Rmsnorm2dFwdPipelineOnePass
|
||||
typename InvRmsWindow,
|
||||
typename SmoothScaleWindow,
|
||||
typename YScaleWindow,
|
||||
typename UnquantYWindow,
|
||||
typename Epilogue>
|
||||
CK_TILE_DEVICE auto operator()(const XWindow& x_window_,
|
||||
const XResidualWindow& x_residual_window_,
|
||||
@@ -63,6 +65,7 @@ struct Rmsnorm2dFwdPipelineOnePass
|
||||
InvRmsWindow& inv_rms_window,
|
||||
const SmoothScaleWindow& sm_scale_window_,
|
||||
YScaleWindow& y_scale_window_,
|
||||
UnquantYWindow& unquant_y_window,
|
||||
ComputeDataType epsilon,
|
||||
ck_tile::index_t row_size,
|
||||
void* smem,
|
||||
@@ -137,11 +140,26 @@ struct Rmsnorm2dFwdPipelineOnePass
|
||||
|
||||
if constexpr(kFusedQuant == Rmsnorm2dFusedQuantEnum::SMOOTH_DYNAMIC_QUANT)
|
||||
{
|
||||
Epilogue{}(y_window_, sm_scale_window_, y_scale_window_, rmsn, smem);
|
||||
if constexpr(kSaveUnquant)
|
||||
{
|
||||
Epilogue{}(
|
||||
unquant_y_window, y_window_, sm_scale_window_, y_scale_window_, rmsn, smem);
|
||||
}
|
||||
else
|
||||
{
|
||||
Epilogue{}(y_window_, sm_scale_window_, y_scale_window_, rmsn, smem);
|
||||
}
|
||||
}
|
||||
else if constexpr(kFusedQuant == Rmsnorm2dFusedQuantEnum::DYNAMIC_QUANT)
|
||||
{
|
||||
Epilogue{}(y_window_, y_scale_window_, rmsn, smem);
|
||||
if constexpr(kSaveUnquant)
|
||||
{
|
||||
Epilogue{}(unquant_y_window, y_window_, y_scale_window_, rmsn, smem);
|
||||
}
|
||||
else
|
||||
{
|
||||
Epilogue{}(y_window_, y_scale_window_, rmsn, smem);
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
|
||||
@@ -12,6 +12,7 @@ template <typename XDataType_,
|
||||
typename ComputeDataType_,
|
||||
typename YDataType_,
|
||||
typename InvRmsDataType_,
|
||||
typename UnquantYDataType_,
|
||||
typename SmoothScaleDataType_,
|
||||
typename YScaleDataType_,
|
||||
typename BlockShape_,
|
||||
@@ -23,6 +24,7 @@ struct Rmsnorm2dFwdPipelineProblem
|
||||
using ComputeDataType = remove_cvref_t<ComputeDataType_>;
|
||||
using YDataType = remove_cvref_t<YDataType_>;
|
||||
using InvRmsDataType = remove_cvref_t<InvRmsDataType_>;
|
||||
using UnquantYDataType = remove_cvref_t<UnquantYDataType_>;
|
||||
using SmoothScaleDataType = remove_cvref_t<SmoothScaleDataType_>;
|
||||
using YScaleDataType = remove_cvref_t<YScaleDataType_>;
|
||||
using BlockShape = remove_cvref_t<BlockShape_>;
|
||||
|
||||
@@ -54,6 +54,7 @@ struct Rmsnorm2dFwdPipelineTwoPass
|
||||
typename InvRmsWindow,
|
||||
typename SmoothScaleWindow,
|
||||
typename YScaleWindow,
|
||||
typename UnquantYWindow,
|
||||
typename Epilogue>
|
||||
CK_TILE_DEVICE auto operator()(const XWindow& x_window_,
|
||||
const XResidualWindow& x_residual_window_,
|
||||
@@ -63,6 +64,7 @@ struct Rmsnorm2dFwdPipelineTwoPass
|
||||
InvRmsWindow& inv_rms_window,
|
||||
const SmoothScaleWindow& /*sm_scale_window_*/,
|
||||
YScaleWindow& /*y_scale_window*/,
|
||||
UnquantYWindow& /*unquant_y_window*/,
|
||||
ComputeDataType epsilon,
|
||||
ck_tile::index_t row_size,
|
||||
void* smem,
|
||||
@@ -136,32 +138,51 @@ struct Rmsnorm2dFwdPipelineTwoPass
|
||||
ck_tile::index_t stride_to_right_most_window =
|
||||
row_size % Block_N == 0 ? row_size - Block_N : row_size - row_size % Block_N;
|
||||
|
||||
move_tile_window(x_window, {0, -Block_N});
|
||||
move_tile_window(x_residual_window, {0, -Block_N});
|
||||
if constexpr(kFusedAdd == Rmsnorm2dFusedAddEnum::PRE_ADD_STORE)
|
||||
{
|
||||
move_tile_window(y_residual_window, {0, -Block_N});
|
||||
}
|
||||
else
|
||||
{
|
||||
move_tile_window(x_window, {0, -Block_N});
|
||||
move_tile_window(x_residual_window, {0, -Block_N});
|
||||
}
|
||||
move_tile_window(gamma_window, {stride_to_right_most_window});
|
||||
move_tile_window(y_window, {0, stride_to_right_most_window});
|
||||
|
||||
// rmsnorm computation
|
||||
for(int iN = __builtin_amdgcn_readfirstlane(0); iN < num_n_tile_iteration; ++iN)
|
||||
{
|
||||
auto x = load_tile(x_window);
|
||||
auto x_resi = load_tile(x_residual_window);
|
||||
auto acc = cast_tile<ComputeDataType>(x);
|
||||
auto acc = make_static_distributed_tensor<ComputeDataType>(
|
||||
decltype(load_tile(x_window))::get_tile_distribution());
|
||||
|
||||
if constexpr(kFusedAdd == Rmsnorm2dFusedAddEnum::PRE_ADD_STORE ||
|
||||
kFusedAdd == Rmsnorm2dFusedAddEnum::PRE_ADD)
|
||||
if constexpr(kFusedAdd == Rmsnorm2dFusedAddEnum::PRE_ADD_STORE)
|
||||
{
|
||||
sweep_tile(x_resi, [&](auto idx) {
|
||||
// compute x = x_resi + x
|
||||
acc(idx) = type_convert<ComputeDataType>(x_resi(idx)) + acc(idx);
|
||||
});
|
||||
acc = cast_tile<ComputeDataType>(load_tile(y_residual_window));
|
||||
move_tile_window(y_residual_window, {0, -Block_N});
|
||||
}
|
||||
else
|
||||
{
|
||||
acc = cast_tile<ComputeDataType>(load_tile(x_window));
|
||||
move_tile_window(x_window, {0, -Block_N});
|
||||
|
||||
if constexpr(kFusedAdd == Rmsnorm2dFusedAddEnum::PRE_ADD)
|
||||
{
|
||||
auto x_resi = load_tile(x_residual_window);
|
||||
sweep_tile(x_resi, [&](auto idx) {
|
||||
// compute x = x_resi + x
|
||||
acc(idx) = type_convert<ComputeDataType>(x_resi(idx)) + acc(idx);
|
||||
});
|
||||
move_tile_window(x_residual_window, {0, -Block_N});
|
||||
}
|
||||
}
|
||||
|
||||
// load gamma (TODO: support no gamma?)
|
||||
const auto gamma = load_tile(gamma_window);
|
||||
|
||||
// rmsnorm computation
|
||||
auto rmsn = make_static_distributed_tensor<ComputeDataType>(x.get_tile_distribution());
|
||||
auto rmsn = make_static_distributed_tensor<ComputeDataType>(
|
||||
decltype(load_tile(x_window))::get_tile_distribution());
|
||||
sweep_tile(rmsn, [&, inv_rms_ = inv_rms](auto idx) {
|
||||
constexpr auto i_idx = make_tuple(idx[number<0>{}]);
|
||||
constexpr auto j_idx = make_tuple(idx[number<1>{}]);
|
||||
@@ -176,8 +197,6 @@ struct Rmsnorm2dFwdPipelineTwoPass
|
||||
static_assert(kFusedQuant == Rmsnorm2dFusedQuantEnum::NO_SWEEP);
|
||||
Epilogue{}(y_window, rmsn);
|
||||
|
||||
move_tile_window(x_window, {0, -Block_N});
|
||||
move_tile_window(x_residual_window, {0, -Block_N});
|
||||
move_tile_window(gamma_window, {-Block_N});
|
||||
move_tile_window(y_window, {0, -Block_N});
|
||||
}
|
||||
|
||||
@@ -39,6 +39,7 @@ template<> struct Rmsnorm2dFusedQuantEnumName<Rmsnorm2dFusedQuantEnum::SMOOTH_DY
|
||||
|
||||
template <bool kPadN_,
|
||||
bool kSaveInvRms_,
|
||||
bool kSaveUnquant_,
|
||||
bool kTwoPass_,
|
||||
Rmsnorm2dFusedAddEnum kFusedAdd_,
|
||||
Rmsnorm2dFusedQuantEnum kFusedQuant_>
|
||||
@@ -46,6 +47,7 @@ struct Rmsnorm2dFwdTraits
|
||||
{
|
||||
static constexpr bool kPadN = kPadN_;
|
||||
static constexpr bool kSaveInvRms = kSaveInvRms_;
|
||||
static constexpr bool kSaveUnquant = kSaveUnquant_;
|
||||
static constexpr bool kTwoPass = kTwoPass_;
|
||||
static constexpr Rmsnorm2dFusedAddEnum kFusedAdd = kFusedAdd_;
|
||||
static constexpr Rmsnorm2dFusedQuantEnum kFusedQuant = kFusedQuant_;
|
||||
|
||||
Reference in New Issue
Block a user