mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 06:49:15 +00:00
[CK_TILE] Enhance RMSNorm Accuracy: New Pipeline Pass for Selectable Implementation (#2409)
* Add Rmsnorm2dFwdPipelineModelSensitiveT5Pass * Update rmsnorm2d_fwd_pipeline_model_sensitive_pass 1. Add BlockReduce2dTreeCrossWarpSync * Add Rmsnorm2dFusedModelSensitiveEnum * Update patch 1. Reverse generate.py 2. Remove comment in generate.py 3. Update tree cross warp reduce * Refactor RMSNorm model enum and introduce T5-like option * Update the n stage for cross warp reduce * Add new cmdline option in RMSNorm for new pipeline testing --------- Co-authored-by: Clement Lin <clement.lin@amd.com> Co-authored-by: ClementLinCF <162283536+ClementLinCF@users.noreply.github.com>
This commit is contained in:
@@ -15,13 +15,14 @@ auto create_args(int argc, char* argv[])
|
||||
.insert("v", "1", "cpu validation or not")
|
||||
.insert("prec", "fp16", "precision")
|
||||
.insert("warmup", "0", "cold iter")
|
||||
.insert("repeat", "1", "hot iter");
|
||||
.insert("repeat", "1", "hot iter")
|
||||
.insert("s", "0", "sensitive model mode, 0: for no specific model, 1: for T5-like model");
|
||||
|
||||
bool result = arg_parser.parse(argc, argv);
|
||||
return std::make_tuple(result, arg_parser);
|
||||
}
|
||||
|
||||
template <typename DataType>
|
||||
template <typename DataType, int USEModelSensitive>
|
||||
bool run(const ck_tile::ArgParser& arg_parser)
|
||||
{
|
||||
ck_tile::index_t m = arg_parser.get_int("m");
|
||||
@@ -81,8 +82,10 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
false, // kSaveInvRms
|
||||
false, // kSaveUnquant
|
||||
kTwoPass,
|
||||
ck_tile::Rmsnorm2dFusedAddEnum::NO_ADD, // fuse add
|
||||
ck_tile::Rmsnorm2dFusedQuantEnum::NO_SWEEP>; // fuse quant
|
||||
ck_tile::Rmsnorm2dFusedAddEnum::NO_ADD, // fuse add
|
||||
ck_tile::Rmsnorm2dFusedQuantEnum::NO_SWEEP, // fuse quant
|
||||
static_cast<ck_tile::Rmsnorm2dSensitiveEnum>(
|
||||
USEModelSensitive)>;
|
||||
|
||||
using Problem = ck_tile::Rmsnorm2dFwdPipelineProblem<XDataType,
|
||||
GammaDataType,
|
||||
@@ -97,7 +100,17 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
|
||||
using OnePassPipeline = ck_tile::Rmsnorm2dFwdPipelineOnePass<Problem>;
|
||||
using TwoPassPipeline = ck_tile::Rmsnorm2dFwdPipelineTwoPass<Problem>;
|
||||
using Pipeline = std::conditional_t<kTwoPass, TwoPassPipeline, OnePassPipeline>;
|
||||
using T5PassPipeline = ck_tile::Rmsnorm2dFwdPipelineModelSensitiveT5Pass<Problem>;
|
||||
|
||||
using Pipeline =
|
||||
std::conditional_t<(PipelineTraits::kUseModelSensitiveRMSNorm ==
|
||||
ck_tile::Rmsnorm2dSensitiveEnum::NO_SPECIFIC_MODEL ||
|
||||
PipelineTraits::kTwoPass), // TODO: consider TwoPass for T5PassPipeline
|
||||
std::conditional_t<PipelineTraits::kTwoPass,
|
||||
TwoPassPipeline,
|
||||
OnePassPipeline>, // kUseModelSensitiveRMSNorm
|
||||
// == 0
|
||||
T5PassPipeline>;
|
||||
|
||||
using Default2DEpilogueProblem = ck_tile::
|
||||
Default2DEpilogueProblem<ComputeDataType, YDataType, false, PipelineTraits::kPadN, false>;
|
||||
@@ -172,7 +185,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
|
||||
std::cout << "[" << data_type << "]"
|
||||
<< " m:" << m << ", n:" << n << ", stride:" << stride
|
||||
<< ", valid:" << (pass ? "y" : "n") << std::flush << std::endl;
|
||||
<< ", s:" << USEModelSensitive << ", valid:" << (pass ? "y" : "n") << std::flush
|
||||
<< std::endl;
|
||||
}
|
||||
|
||||
return pass;
|
||||
@@ -184,10 +198,19 @@ int main(int argc, char* argv[])
|
||||
if(!result)
|
||||
return -1;
|
||||
|
||||
const std::string data_type = arg_parser.get_str("prec");
|
||||
const std::string data_type = arg_parser.get_str("prec");
|
||||
const int use_model_sensitive_rmsnorm = arg_parser.get_int("s");
|
||||
|
||||
if(data_type == "fp16")
|
||||
{
|
||||
return run<ck_tile::half_t>(arg_parser) ? 0 : -2;
|
||||
if(use_model_sensitive_rmsnorm == 0) // 0: for no specific RMSNorm
|
||||
{
|
||||
return run<ck_tile::half_t, 0>(arg_parser) ? 0 : -2;
|
||||
}
|
||||
else if(use_model_sensitive_rmsnorm == 1) // 1: for T5-like RMSNorm
|
||||
{
|
||||
return run<ck_tile::half_t, 1>(arg_parser) ? 0 : -2;
|
||||
}
|
||||
}
|
||||
|
||||
return -3;
|
||||
|
||||
@@ -65,7 +65,8 @@ template <typename XDataType_,
|
||||
bool kSaveUnquant_,
|
||||
bool kTwoPass_,
|
||||
ck_tile::index_t kFusedAdd_ = 0,
|
||||
ck_tile::index_t kFusedQuant_ = 0>
|
||||
ck_tile::index_t kFusedQuant_ = 0,
|
||||
ck_tile::index_t kUseModelSensitiveRMSNorm_ = 0>
|
||||
struct rmsnorm2d_fwd_traits_
|
||||
{
|
||||
using XDataType = ck_tile::remove_cvref_t<XDataType_>;
|
||||
@@ -127,8 +128,9 @@ struct rmsnorm2d_fwd_traits_
|
||||
static constexpr bool kSaveInvRms = kSaveInvRms_;
|
||||
static constexpr bool kSaveUnquant = kSaveUnquant_;
|
||||
static constexpr bool kTwoPass = kTwoPass_;
|
||||
static constexpr ck_tile::index_t kFusedAdd = kFusedAdd_;
|
||||
static constexpr ck_tile::index_t kFusedQuant = kFusedQuant_;
|
||||
static constexpr ck_tile::index_t kFusedAdd = kFusedAdd_;
|
||||
static constexpr ck_tile::index_t kFusedQuant = kFusedQuant_;
|
||||
static constexpr ck_tile::index_t kUseModelSensitiveRMSNorm = kUseModelSensitiveRMSNorm_;
|
||||
};
|
||||
|
||||
template <typename XDataType_,
|
||||
@@ -146,7 +148,8 @@ template <typename XDataType_,
|
||||
bool kSaveUnquant_,
|
||||
bool kTwoPass_,
|
||||
int kFusedAdd_,
|
||||
int kFusedQuant_>
|
||||
int kFusedQuant_,
|
||||
int kUseModelSensitiveRMSNorm_>
|
||||
using traits_ = rmsnorm2d_fwd_traits_<XDataType_,
|
||||
YDataType_,
|
||||
SmoothScaleDataType_,
|
||||
@@ -162,7 +165,8 @@ using traits_ = rmsnorm2d_fwd_traits_<XDataType_,
|
||||
kSaveUnquant_,
|
||||
kTwoPass_,
|
||||
kFusedAdd_,
|
||||
kFusedQuant_>;
|
||||
kFusedQuant_,
|
||||
kUseModelSensitiveRMSNorm_>;
|
||||
"""
|
||||
|
||||
API_COMMON_HEADER = """
|
||||
@@ -197,7 +201,8 @@ float rmsnorm2d_fwd_(const S& s, A a)
|
||||
Traits_::kSaveUnquant,
|
||||
Traits_::kTwoPass,
|
||||
static_cast<ck_tile::Rmsnorm2dFusedAddEnum>(Traits_::kFusedAdd),
|
||||
static_cast<ck_tile::Rmsnorm2dFusedQuantEnum>(Traits_::kFusedQuant)>;
|
||||
static_cast<ck_tile::Rmsnorm2dFusedQuantEnum>(Traits_::kFusedQuant),
|
||||
static_cast<ck_tile::Rmsnorm2dSensitiveEnum>(Traits_::kUseModelSensitiveRMSNorm)>;
|
||||
|
||||
using PipelineProblem =
|
||||
ck_tile::Rmsnorm2dFwdPipelineProblem<typename RmsnormTypeConfig<XDataType, YDataType, SmoothScaleDataType, YScaleDataType>::XDataType,
|
||||
@@ -213,7 +218,13 @@ float rmsnorm2d_fwd_(const S& s, A a)
|
||||
|
||||
using OnePassPipeline = ck_tile::Rmsnorm2dFwdPipelineOnePass<PipelineProblem>;
|
||||
using TwoPassPipeline = ck_tile::Rmsnorm2dFwdPipelineTwoPass<PipelineProblem>;
|
||||
using Pipeline = std::conditional_t<Traits_::kTwoPass, TwoPassPipeline, OnePassPipeline>;
|
||||
using T5PassPipeline = ck_tile::Rmsnorm2dFwdPipelineModelSensitiveT5Pass<PipelineProblem>;
|
||||
|
||||
using Pipeline = std::conditional_t<
|
||||
(Traits_::kUseModelSensitiveRMSNorm == 0 || Traits_::kTwoPass), // TODO: consider TwoPass for T5PassPipeline
|
||||
std::conditional_t<Traits_::kTwoPass, TwoPassPipeline, OnePassPipeline>, // kUseModelSensitiveRMSNorm == 0
|
||||
T5PassPipeline
|
||||
>;
|
||||
|
||||
using Default2DEpilogueProblem = ck_tile::Default2DEpilogueProblem<ComputeDataType, YDataType, false, Traits_::kPadN, false>;
|
||||
using Default2DEpilogue = ck_tile::Default2DEpilogue<Default2DEpilogueProblem>;
|
||||
@@ -387,12 +398,13 @@ float rmsnorm2d_fwd(rmsnorm2d_fwd_traits t,
|
||||
F_kTwoPass : bool
|
||||
F_kFusedAdd : int
|
||||
F_kFusedQuant : int
|
||||
F_use_model_sensitive_rmsnorm : int
|
||||
|
||||
@property
|
||||
def trait_name(self) ->str:
|
||||
t_ = f'{DATA_TYPE_MAP[self.F_XDataType]}, {DATA_TYPE_MAP[self.F_YDataType]}, {DATA_TYPE_MAP[self.F_SmoothScaleDataType]}, {DATA_TYPE_MAP[self.F_YScaleDataType]}, {DATA_TYPE_MAP[self.F_UnquantYDataType]}, {self.F_Repeat_M:2}, {self.F_Repeat_N:2}, {self.F_ThreadPerBlock_M:2}, {self.F_ThreadPerBlock_N:4}'
|
||||
t_ += f', {self.F_Vector_N:2}, {BOOL_MAP(self.F_kPadN):5}, {BOOL_MAP(self.F_kSaveInvRms):5}, {BOOL_MAP(self.F_kSaveUnquant):5}'
|
||||
t_ += f', {BOOL_MAP(self.F_kTwoPass):5}, {self.F_kFusedAdd:4}, {self.F_kFusedQuant:4}'
|
||||
t_ += f', {BOOL_MAP(self.F_kTwoPass):5}, {self.F_kFusedAdd:4}, {self.F_kFusedQuant:4}, {self.F_use_model_sensitive_rmsnorm:4}'
|
||||
return t_
|
||||
|
||||
# string when calling this kernel
|
||||
@@ -413,6 +425,7 @@ float rmsnorm2d_fwd(rmsnorm2d_fwd_traits t,
|
||||
F_add : int
|
||||
F_sweep : int
|
||||
F_saveunquant : bool
|
||||
F_use_model_sensitive_rmsnorm : int
|
||||
instance_list : List[Any] # List[h_traits]
|
||||
|
||||
@property
|
||||
@@ -426,6 +439,10 @@ float rmsnorm2d_fwd(rmsnorm2d_fwd_traits t,
|
||||
nnn = nnn + '_' + FUSED_FUSED_SWEEP_STR_MAP[self.F_sweep]
|
||||
if self.F_saveunquant:
|
||||
nnn = nnn + '_saveunquant'
|
||||
if self.F_use_model_sensitive_rmsnorm == 0:
|
||||
nnn = nnn + '_nsm'
|
||||
elif self.F_use_model_sensitive_rmsnorm == 1:
|
||||
nnn = nnn + '_t5ml'
|
||||
return nnn
|
||||
|
||||
@property
|
||||
@@ -481,9 +498,9 @@ float rmsnorm2d_fwd(rmsnorm2d_fwd_traits t,
|
||||
elif ins.F_kFusedQuant == 2:
|
||||
_sweep_cond = 't.fused_quant == {f_fused_sweep} && (t.prec_sy == \"{f_sy_type}\" && t.save_unquant == {f_suq})'.format(
|
||||
f_fused_sweep = ins.F_kFusedQuant, f_sy_type=ins.F_YScaleDataType, f_suq=BOOL_MAP(ins.F_kSaveUnquant))
|
||||
_cond = '((a.n % {f_vec_n} == 0) && (t.fused_add == {f_fused_add}) && ({f_sweep_cond}))'.format(
|
||||
_cond = '((a.n % {f_vec_n} == 0) && (t.fused_add == {f_fused_add}) && ({f_sweep_cond}) && (t.use_model_sensitive_rmsnorm == {f_use_model_sensitive_rmsnorm}) )'.format(
|
||||
f_vec_n = ins.F_Vector_N, f_fused_add = ins.F_kFusedAdd,
|
||||
f_sweep_cond = _sweep_cond)
|
||||
f_sweep_cond = _sweep_cond, f_use_model_sensitive_rmsnorm = ins.F_use_model_sensitive_rmsnorm)
|
||||
inner_str += self.API_INNER_CASE.format(F_if = get_if_str(idx_in_n, len_in_n, False),
|
||||
F_VEC_COND = _cond, F_instance_func=ins.call_name)
|
||||
#inner_str = inner_str + vec_str
|
||||
@@ -516,85 +533,149 @@ float rmsnorm2d_fwd(rmsnorm2d_fwd_traits t,
|
||||
fused_sweep_list = [0, 1, 2] # NOTE: only single pass can use fused (smooth) dynamic quant
|
||||
bool_list = [False, True]
|
||||
|
||||
# rm rn tm tn vn pd mv unquant 2p add sweep
|
||||
h_trait_dict = {'64' : [ h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 1, 8, 8, 8, True, False, False, False, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 1, 4, 16, 4, True, False, False, False, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 1, 4, 64, 1, True, False, False, False, 0, 0)],
|
||||
'128' : [ h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 1, 4, 16, 8, True, False, False, False, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 1, 4, 64, 2, True, False, False, False, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 2, 4, 64, 1, True, False, False, False, 0, 0)],
|
||||
'256' : [ h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 1, 4, 64, 4, True, False, False, False, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 2, 4, 64, 2, True, False, False, False, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 4, 4, 64, 1, True, False, False, False, 0, 0)],
|
||||
'512' : [ h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 1, 4, 64, 8, True, False, False, False, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 2, 4, 64, 4, True, False, False, False, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 4, 4, 64, 2, True, False, False, False, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 8, 4, 64, 1, True, False, False, False, 0, 0)],
|
||||
'640' : [ h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 5, 4, 64, 2, True, False, False, False, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 5, 4, 128, 1, True, False, False, False, 0, 0)],
|
||||
'768' : [ h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 3, 4, 64, 4, True, False, False, False, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 6, 4, 64, 2, True, False, False, False, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 12, 4, 64, 1, True, False, False, False, 0, 0)],
|
||||
'1024' :[ h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 2, 2, 64, 8, True, False, False, False, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 4, 2, 64, 4, True, False, False, False, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 8, 2, 64, 2, True, False, False, False, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 4, 1, 256, 1, True, False, False, False, 0, 0)],
|
||||
'1536' :[ h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 3, 4, 64, 8, True, False, False, False, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 3, 2, 128, 4, True, False, False, False, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 3, 1, 256, 2, True, False, False, False, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 6, 1, 256, 1, True, False, False, False, 0, 0)],
|
||||
'2048' :[ h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 1, 1, 256, 8, True, False, False, False, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 2, 1, 256, 4, True, False, False, False, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 4, 1, 256, 2, True, False, False, False, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 8, 1, 256, 1, True, False, False, False, 0, 0)],
|
||||
'3072' :[ h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 3, 1, 128, 8, True, False, False, False, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 3, 1, 256, 4, True, False, False, False, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 6, 1, 256, 2, True, False, False, False, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 3, 1,1024, 1, True, False, False, False, 0, 0)],
|
||||
'4096' :[ h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 2, 1, 256, 8, True, False, False, False, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 4, 1, 256, 4, True, False, False, False, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 2, 1,1024, 2, True, False, False, False, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 4, 1,1024, 1, True, False, False, False, 0, 0)],
|
||||
'6144' :[ h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 3, 1, 256, 8, True, False, False, False, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 3, 1, 512, 4, True, False, False, False, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 3, 1,1024, 2, True, False, False, False, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 6, 1,1024, 1, True, False, False, False, 0, 0)],
|
||||
'8192' :[ h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 4, 1, 256, 8, True, False, False, False, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 4, 1, 512, 4, True, False, False, False, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 4, 1,1024, 2, True, False, False, False, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 8, 1,1024, 1, True, False, False, False, 0, 0)],
|
||||
'big' :[ h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 1, 1,1024, 8, True, False, False, True, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 4, 1, 256, 4, True, False, False, True, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 12, 1, 256, 2, True, False, False, True, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 4, 1,1024, 1, True, False, False, True, 0, 0)]}
|
||||
h_trait_dicts = {
|
||||
0: {
|
||||
# rm rn tm tn vn pd mv unquant 2p add sweep srm
|
||||
'64' :[ h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 1, 8, 8, 8, True, False, False, False, 0, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 1, 4, 16, 4, True, False, False, False, 0, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 1, 4, 64, 1, True, False, False, False, 0, 0, 0)],
|
||||
'128' :[ h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 1, 4, 16, 8, True, False, False, False, 0, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 1, 4, 64, 2, True, False, False, False, 0, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 2, 4, 64, 1, True, False, False, False, 0, 0, 0)],
|
||||
'256' :[ h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 1, 4, 64, 4, True, False, False, False, 0, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 2, 4, 64, 2, True, False, False, False, 0, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 4, 4, 64, 1, True, False, False, False, 0, 0, 0)],
|
||||
'512' :[ h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 1, 4, 64, 8, True, False, False, False, 0, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 2, 4, 64, 4, True, False, False, False, 0, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 4, 4, 64, 2, True, False, False, False, 0, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 8, 4, 64, 1, True, False, False, False, 0, 0, 0)],
|
||||
'640' :[ h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 5, 4, 64, 2, True, False, False, False, 0, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 5, 4, 128, 1, True, False, False, False, 0, 0, 0)],
|
||||
'768' :[ h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 3, 4, 64, 4, True, False, False, False, 0, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 6, 4, 64, 2, True, False, False, False, 0, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 12, 4, 64, 1, True, False, False, False, 0, 0, 0)],
|
||||
'1024' :[ h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 2, 2, 64, 8, True, False, False, False, 0, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 4, 2, 64, 4, True, False, False, False, 0, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 8, 2, 64, 2, True, False, False, False, 0, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 4, 1, 256, 1, True, False, False, False, 0, 0, 0)],
|
||||
'1536' :[ h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 3, 4, 64, 8, True, False, False, False, 0, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 3, 2, 128, 4, True, False, False, False, 0, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 3, 1, 256, 2, True, False, False, False, 0, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 6, 1, 256, 1, True, False, False, False, 0, 0, 0)],
|
||||
'2048' :[ h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 1, 1, 256, 8, True, False, False, False, 0, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 2, 1, 256, 4, True, False, False, False, 0, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 4, 1, 256, 2, True, False, False, False, 0, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 8, 1, 256, 1, True, False, False, False, 0, 0, 0)],
|
||||
'3072' :[ h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 3, 1, 128, 8, True, False, False, False, 0, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 3, 1, 256, 4, True, False, False, False, 0, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 6, 1, 256, 2, True, False, False, False, 0, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 3, 1,1024, 1, True, False, False, False, 0, 0, 0)],
|
||||
'4096' :[ h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 2, 1, 256, 8, True, False, False, False, 0, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 4, 1, 256, 4, True, False, False, False, 0, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 2, 1,1024, 2, True, False, False, False, 0, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 4, 1,1024, 1, True, False, False, False, 0, 0, 0)],
|
||||
'6144' :[ h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 3, 1, 256, 8, True, False, False, False, 0, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 3, 1, 512, 4, True, False, False, False, 0, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 3, 1,1024, 2, True, False, False, False, 0, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 6, 1,1024, 1, True, False, False, False, 0, 0, 0)],
|
||||
'8192' :[ h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 4, 1, 256, 8, True, False, False, False, 0, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 4, 1, 512, 4, True, False, False, False, 0, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 4, 1,1024, 2, True, False, False, False, 0, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 8, 1,1024, 1, True, False, False, False, 0, 0, 0)],
|
||||
'big' :[ h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 1, 1,1024, 8, True, False, False, True, 0, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 4, 1, 256, 4, True, False, False, True, 0, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 12, 1, 256, 2, True, False, False, True, 0, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 4, 1,1024, 1, True, False, False, True, 0, 0, 0)]
|
||||
},
|
||||
1: {
|
||||
# rm rn tm tn vn pd mv unquant 2p add sweep srm
|
||||
'64' :[ h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 1, 8, 8, 8, True, False, False, False, 0, 0, 1),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 1, 4, 16, 4, True, False, False, False, 0, 0, 1),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 1, 4, 64, 1, True, False, False, False, 0, 0, 1)],
|
||||
'128' :[ h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 1, 4, 16, 8, True, False, False, False, 0, 0, 1),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 1, 4, 64, 2, True, False, False, False, 0, 0, 1),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 2, 4, 64, 1, True, False, False, False, 0, 0, 1)],
|
||||
'256' :[ h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 1, 8, 32, 8, True, False, False, False, 0, 0, 1),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 1, 4, 64, 4, True, False, False, False, 0, 0, 1),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 2, 4, 64, 2, True, False, False, False, 0, 0, 1),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 4, 4, 64, 1, True, False, False, False, 0, 0, 1)],
|
||||
'512' :[ h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 1, 4, 64, 8, True, False, False, False, 0, 0, 1),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 2, 4, 64, 4, True, False, False, False, 0, 0, 1),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 4, 4, 64, 2, True, False, False, False, 0, 0, 1),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 8, 4, 64, 1, True, False, False, False, 0, 0, 1)],
|
||||
'640' :[ h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 1, 2, 128, 8, True, False, False, False, 0, 0, 1),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 5, 4, 64, 2, True, False, False, False, 0, 0, 1),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 5, 4, 128, 1, True, False, False, False, 0, 0, 1)],
|
||||
'768' :[ h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 1, 2, 128, 8, True, False, False, False, 0, 0, 1),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 3, 4, 64, 4, True, False, False, False, 0, 0, 1),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 6, 4, 64, 2, True, False, False, False, 0, 0, 1),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 12, 4, 64, 1, True, False, False, False, 0, 0, 1)],
|
||||
'1024' :[ h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 1, 2, 128, 8, True, False, False, False, 0, 0, 1),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 4, 2, 64, 4, True, False, False, False, 0, 0, 1),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 8, 2, 64, 2, True, False, False, False, 0, 0, 1),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 4, 1, 256, 1, True, False, False, False, 0, 0, 1)],
|
||||
'1536' :[ h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 1, 1, 256, 8, True, False, False, False, 0, 0, 1),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 3, 2, 128, 4, True, False, False, False, 0, 0, 1),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 3, 1, 256, 2, True, False, False, False, 0, 0, 1),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 6, 1, 256, 1, True, False, False, False, 0, 0, 1)],
|
||||
'2048' :[ h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 1, 1, 256, 8, True, False, False, False, 0, 0, 1),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 2, 1, 256, 4, True, False, False, False, 0, 0, 1),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 4, 1, 256, 2, True, False, False, False, 0, 0, 1),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 8, 1, 256, 1, True, False, False, False, 0, 0, 1)],
|
||||
'3072' :[ h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 2, 1, 256, 8, True, False, False, False, 0, 0, 1),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 3, 1, 256, 4, True, False, False, False, 0, 0, 1),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 6, 1, 256, 2, True, False, False, False, 0, 0, 1),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 3, 1,1024, 1, True, False, False, False, 0, 0, 1)],
|
||||
'4096' :[ h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 2, 1, 256, 8, True, False, False, False, 0, 0, 1),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 4, 1, 256, 4, True, False, False, False, 0, 0, 1),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 2, 1,1024, 2, True, False, False, False, 0, 0, 1),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 4, 1,1024, 1, True, False, False, False, 0, 0, 1)],
|
||||
'6144' :[ h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 3, 1, 256, 8, True, False, False, False, 0, 0, 1),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 3, 1, 512, 4, True, False, False, False, 0, 0, 1),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 3, 1,1024, 2, True, False, False, False, 0, 0, 1),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 6, 1,1024, 1, True, False, False, False, 0, 0, 1)],
|
||||
'8192' :[ h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 4, 1, 256, 8, True, False, False, False, 0, 0, 1),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 4, 1, 512, 4, True, False, False, False, 0, 0, 1),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 4, 1,1024, 2, True, False, False, False, 0, 0, 1),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 8, 1,1024, 1, True, False, False, False, 0, 0, 1)],
|
||||
'big' :[ h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 1, 1,1024, 8, True, False, False, True, 0, 0, 1),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 4, 1, 256, 4, True, False, False, True, 0, 0, 1),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 12, 1, 256, 2, True, False, False, True, 0, 0, 1),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 4, 1,1024, 1, True, False, False, True, 0, 0, 1)]
|
||||
}
|
||||
}
|
||||
|
||||
total_blob = list()
|
||||
for hs_key in h_trait_dict:
|
||||
hs = h_trait_dict[hs_key]
|
||||
current_n = hs[0].F_Repeat_N * hs[0].F_ThreadPerBlock_N * hs[0].F_Vector_N
|
||||
for dtype, scale_type, fused_add, fused_quant, save_unquant in itertools.product(dtype_list, scale_list, fused_add_list, fused_sweep_list, bool_list):
|
||||
prec_i, prec_o = dtype.split(',')
|
||||
scale_sm, scale_y = scale_type.split(',')
|
||||
if prec_o in dynamic_quant_out_dtype and fused_quant != 1 and fused_quant != 2:
|
||||
continue # skip non dynamic quant case
|
||||
if (fused_quant == 1 or fused_quant == 2) and hs_key == 'big':
|
||||
continue
|
||||
if (fused_quant == 0 and save_unquant == True):
|
||||
continue # save_unquant should always be false when there is no quant enabled
|
||||
current_hs = list()
|
||||
for chs_ in hs:
|
||||
h_ = copy.copy(chs_) # copy the base instance out
|
||||
h_.F_XDataType = prec_i
|
||||
h_.F_YDataType = prec_o
|
||||
h_.F_SmoothScaleDataType = scale_sm
|
||||
h_.F_YScaleDataType = scale_y
|
||||
h_.F_UnquantYDataType = prec_i
|
||||
h_.F_kFusedAdd = fused_add
|
||||
h_.F_kFusedQuant = fused_quant
|
||||
h_.F_kSaveUnquant = save_unquant
|
||||
current_hs.append(h_) # + "\n"
|
||||
#f.write(str(f.parent / GEN_DIR / (blobs.api_common_header_
|
||||
current_n_str = 'big' if hs_key == 'big' else current_n
|
||||
total_blob.append(h_instance(dtype, current_n_str, fused_add, fused_quant, save_unquant, current_hs))
|
||||
|
||||
for model_sensitive_flag in [0, 1]: # 0: default; 1: model sensitive
|
||||
current_trait_dict = h_trait_dicts[model_sensitive_flag]
|
||||
for hs_key in current_trait_dict:
|
||||
hs = current_trait_dict[hs_key]
|
||||
current_n = hs_key
|
||||
for dtype, scale_type, fused_add, fused_quant, save_unquant in itertools.product(dtype_list, scale_list, fused_add_list, fused_sweep_list, bool_list):
|
||||
prec_i, prec_o = dtype.split(',')
|
||||
scale_sm, scale_y = scale_type.split(',')
|
||||
if prec_o in dynamic_quant_out_dtype and fused_quant != 1 and fused_quant != 2:
|
||||
continue # skip non dynamic quant case
|
||||
if (fused_quant == 1 or fused_quant == 2) and hs_key == 'big':
|
||||
continue
|
||||
if (fused_quant == 0 and save_unquant == True):
|
||||
continue # save_unquant should always be false when there is no quant enabled
|
||||
current_hs = list()
|
||||
for chs_ in hs:
|
||||
h_ = copy.copy(chs_) # copy the base instance out
|
||||
h_.F_XDataType = prec_i
|
||||
h_.F_YDataType = prec_o
|
||||
h_.F_SmoothScaleDataType = scale_sm
|
||||
h_.F_YScaleDataType = scale_y
|
||||
h_.F_UnquantYDataType = prec_i
|
||||
h_.F_kFusedAdd = fused_add
|
||||
h_.F_kFusedQuant = fused_quant
|
||||
h_.F_kSaveUnquant = save_unquant
|
||||
current_hs.append(h_) # + "\n"
|
||||
#f.write(str(f.parent / GEN_DIR / (blobs.api_common_header_
|
||||
current_n_str = 'big' if hs_key == 'big' else current_n
|
||||
total_blob.append(h_instance(dtype, current_n_str, fused_add, fused_quant, save_unquant, h_.F_use_model_sensitive_rmsnorm, current_hs))
|
||||
return total_blob
|
||||
|
||||
def list_blobs(self) -> None:
|
||||
|
||||
@@ -52,7 +52,8 @@ auto create_args(int argc, char* argv[])
|
||||
.insert("fadd", "0", "fused-add, 0:no fused add, 1:preadd+store, 2:preadd only")
|
||||
.insert("fquant", "0", "fused-quant, 0:no, 1:smooth-dynamic-quant, 2:dynamic-quant")
|
||||
.insert("warmup", "5", "cold iter")
|
||||
.insert("repeat", "20", "hot iter");
|
||||
.insert("repeat", "20", "hot iter")
|
||||
.insert("s", "0", "sensitive model mode, 0: for no specific model, 1: for T5-like model");
|
||||
|
||||
bool result = arg_parser.parse(argc, argv);
|
||||
return std::make_tuple(result, arg_parser);
|
||||
@@ -66,15 +67,16 @@ template <typename InDataType,
|
||||
bool SaveUnquant>
|
||||
bool run(const ck_tile::ArgParser& arg_parser)
|
||||
{
|
||||
ck_tile::index_t m = arg_parser.get_int("m");
|
||||
ck_tile::index_t n = arg_parser.get_int("n");
|
||||
float epsilon = arg_parser.get_float("e");
|
||||
int kname = arg_parser.get_int("kname");
|
||||
int do_validation = arg_parser.get_int("v");
|
||||
int fused_add = arg_parser.get_int("fadd");
|
||||
int fused_quant = arg_parser.get_int("fquant");
|
||||
int warmup = arg_parser.get_int("warmup");
|
||||
int repeat = arg_parser.get_int("repeat");
|
||||
ck_tile::index_t m = arg_parser.get_int("m");
|
||||
ck_tile::index_t n = arg_parser.get_int("n");
|
||||
float epsilon = arg_parser.get_float("e");
|
||||
int kname = arg_parser.get_int("kname");
|
||||
int do_validation = arg_parser.get_int("v");
|
||||
int fused_add = arg_parser.get_int("fadd");
|
||||
int fused_quant = arg_parser.get_int("fquant");
|
||||
int warmup = arg_parser.get_int("warmup");
|
||||
int repeat = arg_parser.get_int("repeat");
|
||||
const int use_model_sensitive_rmsnorm = arg_parser.get_int("s");
|
||||
|
||||
ck_tile::index_t x_stride = arg_parser.get_int("x_stride");
|
||||
if(x_stride < 0)
|
||||
@@ -194,10 +196,17 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
std::cout << "[" << prec_str << "]"
|
||||
<< " m:" << m << ", n:" << n << ", x_stride:" << x_stride
|
||||
<< ", xr_stride:" << xr_stride << ", y_stride:" << y_stride
|
||||
<< ", yr_stride:" << yr_stride << std::flush;
|
||||
<< ", yr_stride:" << yr_stride << ", s:" << use_model_sensitive_rmsnorm << std::flush;
|
||||
|
||||
rmsnorm2d_fwd_traits traits{
|
||||
prec_i, prec_o, prec_sm, prec_sy, SaveRms, SaveUnquant, fused_add, fused_quant};
|
||||
rmsnorm2d_fwd_traits traits{prec_i,
|
||||
prec_o,
|
||||
prec_sm,
|
||||
prec_sy,
|
||||
SaveRms,
|
||||
SaveUnquant,
|
||||
fused_add,
|
||||
fused_quant,
|
||||
use_model_sensitive_rmsnorm};
|
||||
|
||||
rmsnorm2d_fwd_args args{x_buf.GetDeviceBuffer(),
|
||||
fused_add != 0 ? x_residual_buf.GetDeviceBuffer() : nullptr,
|
||||
|
||||
@@ -64,6 +64,8 @@ struct rmsnorm2d_fwd_traits
|
||||
bool save_unquant;
|
||||
int fused_add; // 0:no-add, 1:pre-add-store, 2:pre-add
|
||||
int fused_quant; // 0:no-sweep, 1:smooth-dynamic-quant, 2:dynamic-quant
|
||||
|
||||
int use_model_sensitive_rmsnorm = 0; // 0: Use default RMSNorm; 1: Use T5-like implementation
|
||||
};
|
||||
|
||||
float rmsnorm2d_fwd(rmsnorm2d_fwd_traits, rmsnorm2d_fwd_args, const ck_tile::stream_config&);
|
||||
|
||||
@@ -1,37 +1,74 @@
|
||||
#!/bin/sh
|
||||
EXE="$(find . -name tile_rmsnorm2d_fwd -type f | head -n 1)"
|
||||
|
||||
$EXE -m=1 -n=1 -e=1e-12 -v=1 -prec=bf16 -repeat=1000
|
||||
$EXE -m=700 -n=80 -e=1e-12 -v=1 -prec=bf16 -repeat=1000
|
||||
$EXE -m=700 -n=128 -e=1e-12 -v=1 -prec=bf16 -repeat=1000
|
||||
$EXE -m=700 -n=144 -e=1e-12 -v=1 -prec=bf16 -repeat=1000
|
||||
$EXE -m=700 -n=168 -e=1e-12 -v=1 -prec=bf16 -repeat=1000
|
||||
$EXE -m=700 -n=184 -e=1e-12 -v=1 -prec=bf16 -repeat=1000
|
||||
$EXE -m=700 -n=256 -e=1e-12 -v=1 -prec=bf16 -repeat=1000
|
||||
$EXE -m=700 -n=288 -e=1e-12 -v=1 -prec=bf16 -repeat=1000
|
||||
$EXE -m=700 -n=344 -e=1e-12 -v=1 -prec=bf16 -repeat=1000
|
||||
$EXE -m=700 -n=376 -e=1e-12 -v=1 -prec=bf16 -repeat=1000
|
||||
$EXE -m=700 -n=448 -e=1e-12 -v=1 -prec=bf16 -repeat=1000
|
||||
$EXE -m=700 -n=512 -e=1e-12 -v=1 -prec=bf16 -repeat=1000
|
||||
$EXE -m=700 -n=924 -e=1e-12 -v=1 -prec=bf16 -repeat=1000
|
||||
$EXE -m=700 -n=1024 -e=1e-12 -v=1 -prec=bf16 -repeat=1000
|
||||
$EXE -m=700 -n=1078 -e=1e-12 -v=1 -prec=bf16 -repeat=1000
|
||||
$EXE -m=700 -n=1996 -e=1e-12 -v=1 -prec=bf16 -repeat=1000
|
||||
$EXE -m=700 -n=4080 -e=1e-12 -v=1 -prec=bf16 -repeat=1000
|
||||
# 0: for no specific RMSNorm
|
||||
$EXE -m=1 -n=1 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000 -s=0
|
||||
$EXE -m=700 -n=80 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000 -s=0
|
||||
$EXE -m=700 -n=128 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000 -s=0
|
||||
$EXE -m=700 -n=144 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000 -s=0
|
||||
$EXE -m=700 -n=168 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000 -s=0
|
||||
$EXE -m=700 -n=184 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000 -s=0
|
||||
$EXE -m=700 -n=256 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000 -s=0
|
||||
$EXE -m=700 -n=288 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000 -s=0
|
||||
$EXE -m=700 -n=344 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000 -s=0
|
||||
$EXE -m=700 -n=376 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000 -s=0
|
||||
$EXE -m=700 -n=448 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000 -s=0
|
||||
$EXE -m=700 -n=512 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000 -s=0
|
||||
$EXE -m=700 -n=924 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000 -s=0
|
||||
$EXE -m=700 -n=1024 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000 -s=0
|
||||
$EXE -m=700 -n=1078 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000 -s=0
|
||||
$EXE -m=700 -n=1996 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000 -s=0
|
||||
$EXE -m=700 -n=4080 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000 -s=0
|
||||
|
||||
$EXE -m=700 -n=80 -e=1e-12 -v=1 -prec=fp16 -repeat=1000
|
||||
$EXE -m=700 -n=128 -e=1e-12 -v=1 -prec=fp16 -repeat=1000
|
||||
$EXE -m=700 -n=144 -e=1e-12 -v=1 -prec=fp16 -repeat=1000
|
||||
$EXE -m=700 -n=168 -e=1e-12 -v=1 -prec=fp16 -repeat=1000
|
||||
$EXE -m=700 -n=184 -e=1e-12 -v=1 -prec=fp16 -repeat=1000
|
||||
$EXE -m=700 -n=256 -e=1e-12 -v=1 -prec=fp16 -repeat=1000
|
||||
$EXE -m=700 -n=288 -e=1e-12 -v=1 -prec=fp16 -repeat=1000
|
||||
$EXE -m=700 -n=344 -e=1e-12 -v=1 -prec=fp16 -repeat=1000
|
||||
$EXE -m=700 -n=376 -e=1e-12 -v=1 -prec=fp16 -repeat=1000
|
||||
$EXE -m=700 -n=448 -e=1e-12 -v=1 -prec=fp16 -repeat=1000
|
||||
$EXE -m=700 -n=512 -e=1e-12 -v=1 -prec=fp16 -repeat=1000
|
||||
$EXE -m=700 -n=924 -e=1e-12 -v=1 -prec=fp16 -repeat=1000
|
||||
$EXE -m=700 -n=1024 -e=1e-12 -v=1 -prec=fp16 -repeat=1000
|
||||
$EXE -m=700 -n=1078 -e=1e-12 -v=1 -prec=fp16 -repeat=1000
|
||||
$EXE -m=700 -n=1996 -e=1e-12 -v=1 -prec=fp16 -repeat=1000
|
||||
$EXE -m=700 -n=4080 -e=1e-12 -v=1 -prec=fp16 -repeat=1000
|
||||
$EXE -m=700 -n=80 -e=1e-12 -v=1 -prec_i=fp16 -repeat=1000 -s=0
|
||||
$EXE -m=700 -n=128 -e=1e-12 -v=1 -prec_i=fp16 -repeat=1000 -s=0
|
||||
$EXE -m=700 -n=144 -e=1e-12 -v=1 -prec_i=fp16 -repeat=1000 -s=0
|
||||
$EXE -m=700 -n=168 -e=1e-12 -v=1 -prec_i=fp16 -repeat=1000 -s=0
|
||||
$EXE -m=700 -n=184 -e=1e-12 -v=1 -prec_i=fp16 -repeat=1000 -s=0
|
||||
$EXE -m=700 -n=256 -e=1e-12 -v=1 -prec_i=fp16 -repeat=1000 -s=0
|
||||
$EXE -m=700 -n=288 -e=1e-12 -v=1 -prec_i=fp16 -repeat=1000 -s=0
|
||||
$EXE -m=700 -n=344 -e=1e-12 -v=1 -prec_i=fp16 -repeat=1000 -s=0
|
||||
$EXE -m=700 -n=376 -e=1e-12 -v=1 -prec_i=fp16 -repeat=1000 -s=0
|
||||
$EXE -m=700 -n=448 -e=1e-12 -v=1 -prec_i=fp16 -repeat=1000 -s=0
|
||||
$EXE -m=700 -n=512 -e=1e-12 -v=1 -prec_i=fp16 -repeat=1000 -s=0
|
||||
$EXE -m=700 -n=924 -e=1e-12 -v=1 -prec_i=fp16 -repeat=1000 -s=0
|
||||
$EXE -m=700 -n=1024 -e=1e-12 -v=1 -prec_i=fp16 -repeat=1000 -s=0
|
||||
$EXE -m=700 -n=1078 -e=1e-12 -v=1 -prec_i=fp16 -repeat=1000 -s=0
|
||||
$EXE -m=700 -n=1996 -e=1e-12 -v=1 -prec_i=fp16 -repeat=1000 -s=0
|
||||
$EXE -m=700 -n=4080 -e=1e-12 -v=1 -prec_i=fp16 -repeat=1000 -s=0
|
||||
|
||||
# 1: for T5-like RMSNorm
|
||||
$EXE -m=1 -n=1 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000 -s=1
|
||||
$EXE -m=700 -n=80 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000 -s=1
|
||||
$EXE -m=700 -n=128 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000 -s=1
|
||||
$EXE -m=700 -n=144 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000 -s=1
|
||||
$EXE -m=700 -n=168 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000 -s=1
|
||||
$EXE -m=700 -n=184 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000 -s=1
|
||||
$EXE -m=700 -n=256 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000 -s=1
|
||||
$EXE -m=700 -n=288 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000 -s=1
|
||||
$EXE -m=700 -n=344 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000 -s=1
|
||||
$EXE -m=700 -n=376 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000 -s=1
|
||||
$EXE -m=700 -n=448 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000 -s=1
|
||||
$EXE -m=700 -n=512 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000 -s=1
|
||||
$EXE -m=700 -n=924 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000 -s=1
|
||||
$EXE -m=700 -n=1024 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000 -s=1
|
||||
$EXE -m=700 -n=1078 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000 -s=1
|
||||
$EXE -m=700 -n=1996 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000 -s=1
|
||||
$EXE -m=700 -n=4080 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000 -s=1
|
||||
|
||||
$EXE -m=700 -n=80 -e=1e-12 -v=1 -prec_i=fp16 -repeat=1000 -s=1
|
||||
$EXE -m=700 -n=128 -e=1e-12 -v=1 -prec_i=fp16 -repeat=1000 -s=1
|
||||
$EXE -m=700 -n=144 -e=1e-12 -v=1 -prec_i=fp16 -repeat=1000 -s=1
|
||||
$EXE -m=700 -n=168 -e=1e-12 -v=1 -prec_i=fp16 -repeat=1000 -s=1
|
||||
$EXE -m=700 -n=184 -e=1e-12 -v=1 -prec_i=fp16 -repeat=1000 -s=1
|
||||
$EXE -m=700 -n=256 -e=1e-12 -v=1 -prec_i=fp16 -repeat=1000 -s=1
|
||||
$EXE -m=700 -n=288 -e=1e-12 -v=1 -prec_i=fp16 -repeat=1000 -s=1
|
||||
$EXE -m=700 -n=344 -e=1e-12 -v=1 -prec_i=fp16 -repeat=1000 -s=1
|
||||
$EXE -m=700 -n=376 -e=1e-12 -v=1 -prec_i=fp16 -repeat=1000 -s=1
|
||||
$EXE -m=700 -n=448 -e=1e-12 -v=1 -prec_i=fp16 -repeat=1000 -s=1
|
||||
$EXE -m=700 -n=512 -e=1e-12 -v=1 -prec_i=fp16 -repeat=1000 -s=1
|
||||
$EXE -m=700 -n=924 -e=1e-12 -v=1 -prec_i=fp16 -repeat=1000 -s=1
|
||||
$EXE -m=700 -n=1024 -e=1e-12 -v=1 -prec_i=fp16 -repeat=1000 -s=1
|
||||
$EXE -m=700 -n=1078 -e=1e-12 -v=1 -prec_i=fp16 -repeat=1000 -s=1
|
||||
$EXE -m=700 -n=1996 -e=1e-12 -v=1 -prec_i=fp16 -repeat=1000 -s=1
|
||||
$EXE -m=700 -n=4080 -e=1e-12 -v=1 -prec_i=fp16 -repeat=1000 -s=1
|
||||
@@ -5,29 +5,32 @@ for fquant in "" "-fquant=1 -prec_o=int8" "-fquant=2 -prec_o=int8" "-fquant=1 -p
|
||||
"-fquant=1 -prec_o=int8 -save_unquant=1" "-fquant=2 -prec_o=int8 -save_unquant=1" "-fquant=1 -prec_o=fp8 -save_unquant=1" "-fquant=2 -prec_o=fp8 -save_unquant=1"; do
|
||||
for pr_i in "fp16" "bf16" ; do
|
||||
for fadd in "0" "1"; do
|
||||
$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=99 -n=13
|
||||
$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=17 -n=16
|
||||
$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=1 -n=100
|
||||
$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=4 -n=128
|
||||
$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=80 -n=127
|
||||
$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=22 -n=255 -stride=256
|
||||
$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=7 -n=599
|
||||
$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=19 -n=512
|
||||
$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=33 -n=313 -stride=1000
|
||||
$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=11 -n=510
|
||||
$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=171 -n=676 -stride=818
|
||||
$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=91 -n=636
|
||||
$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=12 -n=768 -stride=800
|
||||
$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=100 -n=766 -stride=812
|
||||
$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=31 -n=1024
|
||||
$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=64 -n=1000 -stride=1004
|
||||
$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=8 -n=1501
|
||||
$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=3 -n=1826
|
||||
$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=5 -n=2040
|
||||
$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=7 -n=2734
|
||||
$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=1 -n=3182
|
||||
$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=9 -n=4096
|
||||
$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=3 -n=8192
|
||||
# 0: for no specific RMSNorm; 1: for T-5 like RMSNorm
|
||||
for s in "0" "1"; do
|
||||
$EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=99 -n=13
|
||||
$EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=17 -n=16
|
||||
$EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=1 -n=100
|
||||
$EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=4 -n=128
|
||||
$EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=80 -n=127
|
||||
# $EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=22 -n=255 -stride=256
|
||||
$EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=7 -n=599
|
||||
$EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=19 -n=512
|
||||
# $EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=33 -n=313 -stride=1000
|
||||
$EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=11 -n=510
|
||||
# $EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=171 -n=676 -stride=818
|
||||
$EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=91 -n=636
|
||||
# $EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=12 -n=768 -stride=800
|
||||
# $EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=100 -n=766 -stride=812
|
||||
$EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=31 -n=1024
|
||||
# $EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=64 -n=1000 -stride=1004
|
||||
$EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=8 -n=1501
|
||||
$EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=3 -n=1826
|
||||
$EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=5 -n=2040
|
||||
$EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=7 -n=2734
|
||||
$EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=1 -n=3182
|
||||
$EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=9 -n=4096
|
||||
$EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=3 -n=8192
|
||||
done
|
||||
done
|
||||
done
|
||||
done
|
||||
@@ -36,8 +39,11 @@ done
|
||||
for fquant in ""
|
||||
for pr_i in "fp16" "bf16" ; do
|
||||
for fadd in "0" "1"; do
|
||||
$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=1 -n=10547
|
||||
# 0: for no specific RMSNorm; 1: for T-5 like RMSNorm
|
||||
for s in "0" "1"; do
|
||||
$EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=1 -n=10547
|
||||
#$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=3 -n=17134
|
||||
done
|
||||
done
|
||||
done
|
||||
done
|
||||
|
||||
Reference in New Issue
Block a user