From 3499fe67ff24f9e3610b208d14589fac645e0ea7 Mon Sep 17 00:00:00 2001 From: MHYangAMD Date: Wed, 16 Jul 2025 14:05:26 +0800 Subject: [PATCH 1/5] [CK_TILE] Enhance RMSNorm Accuracy: New Pipeline Pass for Selectable Implementation (#2409) * Add Rmsnorm2dFwdPipelineModelSensitiveT5Pass * Update rmsnorm2d_fwd_pipeline_model_sensitive_pass 1. Add BlockReduce2dTreeCrossWarpSync * Add Rmsnorm2dFusedModelSensitiveEnum * Update patch 1. Reverse generate.py 2. Remove comment in generate.py 3. Update tree cross warp reduce * Refactor RMSNorm model enum and introduce T5-like option * Update the n stage for cross warp reduce * Add new cmdline option in RMSNorm for new pipeline testing --------- Co-authored-by: Clement Lin Co-authored-by: ClementLinCF <162283536+ClementLinCF@users.noreply.github.com> --- .../10_rmsnorm2d/example_rmsnorm2d_fwd.cpp | 39 ++- example/ck_tile/10_rmsnorm2d/generate.py | 257 ++++++++++++------ .../ck_tile/10_rmsnorm2d/rmsnorm2d_fwd.cpp | 35 ++- .../ck_tile/10_rmsnorm2d/rmsnorm2d_fwd.hpp | 2 + .../ck_tile/10_rmsnorm2d/script/perf_test.sh | 103 ++++--- .../ck_tile/10_rmsnorm2d/script/smoke_test.sh | 54 ++-- .../ops/reduce/block/block_reduce2d.hpp | 133 +++++++++ include/ck_tile/ops/rmsnorm2d.hpp | 1 + .../rmsnorm2d/kernel/rmsnorm2d_fwd_kernel.hpp | 17 +- .../rmsnorm2d_fwd_pipeline_default_policy.hpp | 9 + ...rm2d_fwd_pipeline_model_sensitive_pass.hpp | 228 ++++++++++++++++ .../rmsnorm2d_fwd_pipeline_one_pass.hpp | 5 +- .../pipeline/rmsnorm2d_fwd_traits.hpp | 31 ++- 13 files changed, 730 insertions(+), 184 deletions(-) create mode 100644 include/ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_model_sensitive_pass.hpp diff --git a/example/ck_tile/10_rmsnorm2d/example_rmsnorm2d_fwd.cpp b/example/ck_tile/10_rmsnorm2d/example_rmsnorm2d_fwd.cpp index 25598282e3..13924f5fe9 100644 --- a/example/ck_tile/10_rmsnorm2d/example_rmsnorm2d_fwd.cpp +++ b/example/ck_tile/10_rmsnorm2d/example_rmsnorm2d_fwd.cpp @@ -15,13 +15,14 @@ auto create_args(int argc, char* argv[]) .insert("v", "1", "cpu validation or not") .insert("prec", "fp16", "precision") .insert("warmup", "0", "cold iter") - .insert("repeat", "1", "hot iter"); + .insert("repeat", "1", "hot iter") + .insert("s", "0", "sensitive model mode, 0: for no specific model, 1: for T5-like model"); bool result = arg_parser.parse(argc, argv); return std::make_tuple(result, arg_parser); } -template +template bool run(const ck_tile::ArgParser& arg_parser) { ck_tile::index_t m = arg_parser.get_int("m"); @@ -81,8 +82,10 @@ bool run(const ck_tile::ArgParser& arg_parser) false, // kSaveInvRms false, // kSaveUnquant kTwoPass, - ck_tile::Rmsnorm2dFusedAddEnum::NO_ADD, // fuse add - ck_tile::Rmsnorm2dFusedQuantEnum::NO_SWEEP>; // fuse quant + ck_tile::Rmsnorm2dFusedAddEnum::NO_ADD, // fuse add + ck_tile::Rmsnorm2dFusedQuantEnum::NO_SWEEP, // fuse quant + static_cast( + USEModelSensitive)>; using Problem = ck_tile::Rmsnorm2dFwdPipelineProblem; using TwoPassPipeline = ck_tile::Rmsnorm2dFwdPipelineTwoPass; - using Pipeline = std::conditional_t; + using T5PassPipeline = ck_tile::Rmsnorm2dFwdPipelineModelSensitiveT5Pass; + + using Pipeline = + std::conditional_t<(PipelineTraits::kUseModelSensitiveRMSNorm == + ck_tile::Rmsnorm2dSensitiveEnum::NO_SPECIFIC_MODEL || + PipelineTraits::kTwoPass), // TODO: consider TwoPass for T5PassPipeline + std::conditional_t, // kUseModelSensitiveRMSNorm + // == 0 + T5PassPipeline>; using Default2DEpilogueProblem = ck_tile:: Default2DEpilogueProblem; @@ -172,7 +185,8 @@ bool run(const ck_tile::ArgParser& arg_parser) std::cout << "[" << data_type << "]" << " m:" << m << ", n:" << n << ", stride:" << stride - << ", valid:" << (pass ? "y" : "n") << std::flush << std::endl; + << ", s:" << USEModelSensitive << ", valid:" << (pass ? "y" : "n") << std::flush + << std::endl; } return pass; @@ -184,10 +198,19 @@ int main(int argc, char* argv[]) if(!result) return -1; - const std::string data_type = arg_parser.get_str("prec"); + const std::string data_type = arg_parser.get_str("prec"); + const int use_model_sensitive_rmsnorm = arg_parser.get_int("s"); + if(data_type == "fp16") { - return run(arg_parser) ? 0 : -2; + if(use_model_sensitive_rmsnorm == 0) // 0: for no specific RMSNorm + { + return run(arg_parser) ? 0 : -2; + } + else if(use_model_sensitive_rmsnorm == 1) // 1: for T5-like RMSNorm + { + return run(arg_parser) ? 0 : -2; + } } return -3; diff --git a/example/ck_tile/10_rmsnorm2d/generate.py b/example/ck_tile/10_rmsnorm2d/generate.py index 4296b7373e..b0ba400af1 100644 --- a/example/ck_tile/10_rmsnorm2d/generate.py +++ b/example/ck_tile/10_rmsnorm2d/generate.py @@ -65,7 +65,8 @@ template + ck_tile::index_t kFusedQuant_ = 0, + ck_tile::index_t kUseModelSensitiveRMSNorm_ = 0> struct rmsnorm2d_fwd_traits_ { using XDataType = ck_tile::remove_cvref_t; @@ -127,8 +128,9 @@ struct rmsnorm2d_fwd_traits_ static constexpr bool kSaveInvRms = kSaveInvRms_; static constexpr bool kSaveUnquant = kSaveUnquant_; static constexpr bool kTwoPass = kTwoPass_; - static constexpr ck_tile::index_t kFusedAdd = kFusedAdd_; - static constexpr ck_tile::index_t kFusedQuant = kFusedQuant_; + static constexpr ck_tile::index_t kFusedAdd = kFusedAdd_; + static constexpr ck_tile::index_t kFusedQuant = kFusedQuant_; + static constexpr ck_tile::index_t kUseModelSensitiveRMSNorm = kUseModelSensitiveRMSNorm_; }; template + int kFusedQuant_, + int kUseModelSensitiveRMSNorm_> using traits_ = rmsnorm2d_fwd_traits_; + kFusedQuant_, + kUseModelSensitiveRMSNorm_>; """ API_COMMON_HEADER = """ @@ -197,7 +201,8 @@ float rmsnorm2d_fwd_(const S& s, A a) Traits_::kSaveUnquant, Traits_::kTwoPass, static_cast(Traits_::kFusedAdd), - static_cast(Traits_::kFusedQuant)>; + static_cast(Traits_::kFusedQuant), + static_cast(Traits_::kUseModelSensitiveRMSNorm)>; using PipelineProblem = ck_tile::Rmsnorm2dFwdPipelineProblem::XDataType, @@ -213,7 +218,13 @@ float rmsnorm2d_fwd_(const S& s, A a) using OnePassPipeline = ck_tile::Rmsnorm2dFwdPipelineOnePass; using TwoPassPipeline = ck_tile::Rmsnorm2dFwdPipelineTwoPass; - using Pipeline = std::conditional_t; + using T5PassPipeline = ck_tile::Rmsnorm2dFwdPipelineModelSensitiveT5Pass; + + using Pipeline = std::conditional_t< + (Traits_::kUseModelSensitiveRMSNorm == 0 || Traits_::kTwoPass), // TODO: consider TwoPass for T5PassPipeline + std::conditional_t, // kUseModelSensitiveRMSNorm == 0 + T5PassPipeline + >; using Default2DEpilogueProblem = ck_tile::Default2DEpilogueProblem; using Default2DEpilogue = ck_tile::Default2DEpilogue; @@ -387,12 +398,13 @@ float rmsnorm2d_fwd(rmsnorm2d_fwd_traits t, F_kTwoPass : bool F_kFusedAdd : int F_kFusedQuant : int + F_use_model_sensitive_rmsnorm : int @property def trait_name(self) ->str: t_ = f'{DATA_TYPE_MAP[self.F_XDataType]}, {DATA_TYPE_MAP[self.F_YDataType]}, {DATA_TYPE_MAP[self.F_SmoothScaleDataType]}, {DATA_TYPE_MAP[self.F_YScaleDataType]}, {DATA_TYPE_MAP[self.F_UnquantYDataType]}, {self.F_Repeat_M:2}, {self.F_Repeat_N:2}, {self.F_ThreadPerBlock_M:2}, {self.F_ThreadPerBlock_N:4}' t_ += f', {self.F_Vector_N:2}, {BOOL_MAP(self.F_kPadN):5}, {BOOL_MAP(self.F_kSaveInvRms):5}, {BOOL_MAP(self.F_kSaveUnquant):5}' - t_ += f', {BOOL_MAP(self.F_kTwoPass):5}, {self.F_kFusedAdd:4}, {self.F_kFusedQuant:4}' + t_ += f', {BOOL_MAP(self.F_kTwoPass):5}, {self.F_kFusedAdd:4}, {self.F_kFusedQuant:4}, {self.F_use_model_sensitive_rmsnorm:4}' return t_ # string when calling this kernel @@ -413,6 +425,7 @@ float rmsnorm2d_fwd(rmsnorm2d_fwd_traits t, F_add : int F_sweep : int F_saveunquant : bool + F_use_model_sensitive_rmsnorm : int instance_list : List[Any] # List[h_traits] @property @@ -426,6 +439,10 @@ float rmsnorm2d_fwd(rmsnorm2d_fwd_traits t, nnn = nnn + '_' + FUSED_FUSED_SWEEP_STR_MAP[self.F_sweep] if self.F_saveunquant: nnn = nnn + '_saveunquant' + if self.F_use_model_sensitive_rmsnorm == 0: + nnn = nnn + '_nsm' + elif self.F_use_model_sensitive_rmsnorm == 1: + nnn = nnn + '_t5ml' return nnn @property @@ -481,9 +498,9 @@ float rmsnorm2d_fwd(rmsnorm2d_fwd_traits t, elif ins.F_kFusedQuant == 2: _sweep_cond = 't.fused_quant == {f_fused_sweep} && (t.prec_sy == \"{f_sy_type}\" && t.save_unquant == {f_suq})'.format( f_fused_sweep = ins.F_kFusedQuant, f_sy_type=ins.F_YScaleDataType, f_suq=BOOL_MAP(ins.F_kSaveUnquant)) - _cond = '((a.n % {f_vec_n} == 0) && (t.fused_add == {f_fused_add}) && ({f_sweep_cond}))'.format( + _cond = '((a.n % {f_vec_n} == 0) && (t.fused_add == {f_fused_add}) && ({f_sweep_cond}) && (t.use_model_sensitive_rmsnorm == {f_use_model_sensitive_rmsnorm}) )'.format( f_vec_n = ins.F_Vector_N, f_fused_add = ins.F_kFusedAdd, - f_sweep_cond = _sweep_cond) + f_sweep_cond = _sweep_cond, f_use_model_sensitive_rmsnorm = ins.F_use_model_sensitive_rmsnorm) inner_str += self.API_INNER_CASE.format(F_if = get_if_str(idx_in_n, len_in_n, False), F_VEC_COND = _cond, F_instance_func=ins.call_name) #inner_str = inner_str + vec_str @@ -516,85 +533,149 @@ float rmsnorm2d_fwd(rmsnorm2d_fwd_traits t, fused_sweep_list = [0, 1, 2] # NOTE: only single pass can use fused (smooth) dynamic quant bool_list = [False, True] - # rm rn tm tn vn pd mv unquant 2p add sweep - h_trait_dict = {'64' : [ h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 1, 8, 8, 8, True, False, False, False, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 1, 4, 16, 4, True, False, False, False, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 1, 4, 64, 1, True, False, False, False, 0, 0)], - '128' : [ h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 1, 4, 16, 8, True, False, False, False, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 1, 4, 64, 2, True, False, False, False, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 2, 4, 64, 1, True, False, False, False, 0, 0)], - '256' : [ h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 1, 4, 64, 4, True, False, False, False, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 2, 4, 64, 2, True, False, False, False, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 4, 4, 64, 1, True, False, False, False, 0, 0)], - '512' : [ h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 1, 4, 64, 8, True, False, False, False, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 2, 4, 64, 4, True, False, False, False, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 4, 4, 64, 2, True, False, False, False, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 8, 4, 64, 1, True, False, False, False, 0, 0)], - '640' : [ h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 5, 4, 64, 2, True, False, False, False, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 5, 4, 128, 1, True, False, False, False, 0, 0)], - '768' : [ h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 3, 4, 64, 4, True, False, False, False, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 6, 4, 64, 2, True, False, False, False, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 12, 4, 64, 1, True, False, False, False, 0, 0)], - '1024' :[ h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 2, 2, 64, 8, True, False, False, False, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 4, 2, 64, 4, True, False, False, False, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 8, 2, 64, 2, True, False, False, False, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 4, 1, 256, 1, True, False, False, False, 0, 0)], - '1536' :[ h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 3, 4, 64, 8, True, False, False, False, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 3, 2, 128, 4, True, False, False, False, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 3, 1, 256, 2, True, False, False, False, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 6, 1, 256, 1, True, False, False, False, 0, 0)], - '2048' :[ h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 1, 1, 256, 8, True, False, False, False, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 2, 1, 256, 4, True, False, False, False, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 4, 1, 256, 2, True, False, False, False, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 8, 1, 256, 1, True, False, False, False, 0, 0)], - '3072' :[ h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 3, 1, 128, 8, True, False, False, False, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 3, 1, 256, 4, True, False, False, False, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 6, 1, 256, 2, True, False, False, False, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 3, 1,1024, 1, True, False, False, False, 0, 0)], - '4096' :[ h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 2, 1, 256, 8, True, False, False, False, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 4, 1, 256, 4, True, False, False, False, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 2, 1,1024, 2, True, False, False, False, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 4, 1,1024, 1, True, False, False, False, 0, 0)], - '6144' :[ h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 3, 1, 256, 8, True, False, False, False, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 3, 1, 512, 4, True, False, False, False, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 3, 1,1024, 2, True, False, False, False, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 6, 1,1024, 1, True, False, False, False, 0, 0)], - '8192' :[ h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 4, 1, 256, 8, True, False, False, False, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 4, 1, 512, 4, True, False, False, False, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 4, 1,1024, 2, True, False, False, False, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 8, 1,1024, 1, True, False, False, False, 0, 0)], - 'big' :[ h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 1, 1,1024, 8, True, False, False, True, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 4, 1, 256, 4, True, False, False, True, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 12, 1, 256, 2, True, False, False, True, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 4, 1,1024, 1, True, False, False, True, 0, 0)]} + h_trait_dicts = { + 0: { + # rm rn tm tn vn pd mv unquant 2p add sweep srm + '64' :[ h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 1, 8, 8, 8, True, False, False, False, 0, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 1, 4, 16, 4, True, False, False, False, 0, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 1, 4, 64, 1, True, False, False, False, 0, 0, 0)], + '128' :[ h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 1, 4, 16, 8, True, False, False, False, 0, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 1, 4, 64, 2, True, False, False, False, 0, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 2, 4, 64, 1, True, False, False, False, 0, 0, 0)], + '256' :[ h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 1, 4, 64, 4, True, False, False, False, 0, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 2, 4, 64, 2, True, False, False, False, 0, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 4, 4, 64, 1, True, False, False, False, 0, 0, 0)], + '512' :[ h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 1, 4, 64, 8, True, False, False, False, 0, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 2, 4, 64, 4, True, False, False, False, 0, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 4, 4, 64, 2, True, False, False, False, 0, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 8, 4, 64, 1, True, False, False, False, 0, 0, 0)], + '640' :[ h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 5, 4, 64, 2, True, False, False, False, 0, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 5, 4, 128, 1, True, False, False, False, 0, 0, 0)], + '768' :[ h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 3, 4, 64, 4, True, False, False, False, 0, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 6, 4, 64, 2, True, False, False, False, 0, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 12, 4, 64, 1, True, False, False, False, 0, 0, 0)], + '1024' :[ h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 2, 2, 64, 8, True, False, False, False, 0, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 4, 2, 64, 4, True, False, False, False, 0, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 8, 2, 64, 2, True, False, False, False, 0, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 4, 1, 256, 1, True, False, False, False, 0, 0, 0)], + '1536' :[ h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 3, 4, 64, 8, True, False, False, False, 0, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 3, 2, 128, 4, True, False, False, False, 0, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 3, 1, 256, 2, True, False, False, False, 0, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 6, 1, 256, 1, True, False, False, False, 0, 0, 0)], + '2048' :[ h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 1, 1, 256, 8, True, False, False, False, 0, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 2, 1, 256, 4, True, False, False, False, 0, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 4, 1, 256, 2, True, False, False, False, 0, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 8, 1, 256, 1, True, False, False, False, 0, 0, 0)], + '3072' :[ h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 3, 1, 128, 8, True, False, False, False, 0, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 3, 1, 256, 4, True, False, False, False, 0, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 6, 1, 256, 2, True, False, False, False, 0, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 3, 1,1024, 1, True, False, False, False, 0, 0, 0)], + '4096' :[ h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 2, 1, 256, 8, True, False, False, False, 0, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 4, 1, 256, 4, True, False, False, False, 0, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 2, 1,1024, 2, True, False, False, False, 0, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 4, 1,1024, 1, True, False, False, False, 0, 0, 0)], + '6144' :[ h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 3, 1, 256, 8, True, False, False, False, 0, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 3, 1, 512, 4, True, False, False, False, 0, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 3, 1,1024, 2, True, False, False, False, 0, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 6, 1,1024, 1, True, False, False, False, 0, 0, 0)], + '8192' :[ h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 4, 1, 256, 8, True, False, False, False, 0, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 4, 1, 512, 4, True, False, False, False, 0, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 4, 1,1024, 2, True, False, False, False, 0, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 8, 1,1024, 1, True, False, False, False, 0, 0, 0)], + 'big' :[ h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 1, 1,1024, 8, True, False, False, True, 0, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 4, 1, 256, 4, True, False, False, True, 0, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 12, 1, 256, 2, True, False, False, True, 0, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 4, 1,1024, 1, True, False, False, True, 0, 0, 0)] + }, + 1: { + # rm rn tm tn vn pd mv unquant 2p add sweep srm + '64' :[ h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 1, 8, 8, 8, True, False, False, False, 0, 0, 1), + h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 1, 4, 16, 4, True, False, False, False, 0, 0, 1), + h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 1, 4, 64, 1, True, False, False, False, 0, 0, 1)], + '128' :[ h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 1, 4, 16, 8, True, False, False, False, 0, 0, 1), + h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 1, 4, 64, 2, True, False, False, False, 0, 0, 1), + h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 2, 4, 64, 1, True, False, False, False, 0, 0, 1)], + '256' :[ h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 1, 8, 32, 8, True, False, False, False, 0, 0, 1), + h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 1, 4, 64, 4, True, False, False, False, 0, 0, 1), + h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 2, 4, 64, 2, True, False, False, False, 0, 0, 1), + h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 4, 4, 64, 1, True, False, False, False, 0, 0, 1)], + '512' :[ h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 1, 4, 64, 8, True, False, False, False, 0, 0, 1), + h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 2, 4, 64, 4, True, False, False, False, 0, 0, 1), + h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 4, 4, 64, 2, True, False, False, False, 0, 0, 1), + h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 8, 4, 64, 1, True, False, False, False, 0, 0, 1)], + '640' :[ h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 1, 2, 128, 8, True, False, False, False, 0, 0, 1), + h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 5, 4, 64, 2, True, False, False, False, 0, 0, 1), + h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 5, 4, 128, 1, True, False, False, False, 0, 0, 1)], + '768' :[ h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 1, 2, 128, 8, True, False, False, False, 0, 0, 1), + h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 3, 4, 64, 4, True, False, False, False, 0, 0, 1), + h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 6, 4, 64, 2, True, False, False, False, 0, 0, 1), + h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 12, 4, 64, 1, True, False, False, False, 0, 0, 1)], + '1024' :[ h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 1, 2, 128, 8, True, False, False, False, 0, 0, 1), + h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 4, 2, 64, 4, True, False, False, False, 0, 0, 1), + h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 8, 2, 64, 2, True, False, False, False, 0, 0, 1), + h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 4, 1, 256, 1, True, False, False, False, 0, 0, 1)], + '1536' :[ h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 1, 1, 256, 8, True, False, False, False, 0, 0, 1), + h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 3, 2, 128, 4, True, False, False, False, 0, 0, 1), + h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 3, 1, 256, 2, True, False, False, False, 0, 0, 1), + h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 6, 1, 256, 1, True, False, False, False, 0, 0, 1)], + '2048' :[ h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 1, 1, 256, 8, True, False, False, False, 0, 0, 1), + h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 2, 1, 256, 4, True, False, False, False, 0, 0, 1), + h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 4, 1, 256, 2, True, False, False, False, 0, 0, 1), + h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 8, 1, 256, 1, True, False, False, False, 0, 0, 1)], + '3072' :[ h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 2, 1, 256, 8, True, False, False, False, 0, 0, 1), + h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 3, 1, 256, 4, True, False, False, False, 0, 0, 1), + h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 6, 1, 256, 2, True, False, False, False, 0, 0, 1), + h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 3, 1,1024, 1, True, False, False, False, 0, 0, 1)], + '4096' :[ h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 2, 1, 256, 8, True, False, False, False, 0, 0, 1), + h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 4, 1, 256, 4, True, False, False, False, 0, 0, 1), + h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 2, 1,1024, 2, True, False, False, False, 0, 0, 1), + h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 4, 1,1024, 1, True, False, False, False, 0, 0, 1)], + '6144' :[ h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 3, 1, 256, 8, True, False, False, False, 0, 0, 1), + h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 3, 1, 512, 4, True, False, False, False, 0, 0, 1), + h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 3, 1,1024, 2, True, False, False, False, 0, 0, 1), + h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 6, 1,1024, 1, True, False, False, False, 0, 0, 1)], + '8192' :[ h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 4, 1, 256, 8, True, False, False, False, 0, 0, 1), + h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 4, 1, 512, 4, True, False, False, False, 0, 0, 1), + h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 4, 1,1024, 2, True, False, False, False, 0, 0, 1), + h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 8, 1,1024, 1, True, False, False, False, 0, 0, 1)], + 'big' :[ h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 1, 1,1024, 8, True, False, False, True, 0, 0, 1), + h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 4, 1, 256, 4, True, False, False, True, 0, 0, 1), + h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 12, 1, 256, 2, True, False, False, True, 0, 0, 1), + h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 4, 1,1024, 1, True, False, False, True, 0, 0, 1)] + } + } + total_blob = list() - for hs_key in h_trait_dict: - hs = h_trait_dict[hs_key] - current_n = hs[0].F_Repeat_N * hs[0].F_ThreadPerBlock_N * hs[0].F_Vector_N - for dtype, scale_type, fused_add, fused_quant, save_unquant in itertools.product(dtype_list, scale_list, fused_add_list, fused_sweep_list, bool_list): - prec_i, prec_o = dtype.split(',') - scale_sm, scale_y = scale_type.split(',') - if prec_o in dynamic_quant_out_dtype and fused_quant != 1 and fused_quant != 2: - continue # skip non dynamic quant case - if (fused_quant == 1 or fused_quant == 2) and hs_key == 'big': - continue - if (fused_quant == 0 and save_unquant == True): - continue # save_unquant should always be false when there is no quant enabled - current_hs = list() - for chs_ in hs: - h_ = copy.copy(chs_) # copy the base instance out - h_.F_XDataType = prec_i - h_.F_YDataType = prec_o - h_.F_SmoothScaleDataType = scale_sm - h_.F_YScaleDataType = scale_y - h_.F_UnquantYDataType = prec_i - h_.F_kFusedAdd = fused_add - h_.F_kFusedQuant = fused_quant - h_.F_kSaveUnquant = save_unquant - current_hs.append(h_) # + "\n" - #f.write(str(f.parent / GEN_DIR / (blobs.api_common_header_ - current_n_str = 'big' if hs_key == 'big' else current_n - total_blob.append(h_instance(dtype, current_n_str, fused_add, fused_quant, save_unquant, current_hs)) + + for model_sensitive_flag in [0, 1]: # 0: default; 1: model sensitive + current_trait_dict = h_trait_dicts[model_sensitive_flag] + for hs_key in current_trait_dict: + hs = current_trait_dict[hs_key] + current_n = hs_key + for dtype, scale_type, fused_add, fused_quant, save_unquant in itertools.product(dtype_list, scale_list, fused_add_list, fused_sweep_list, bool_list): + prec_i, prec_o = dtype.split(',') + scale_sm, scale_y = scale_type.split(',') + if prec_o in dynamic_quant_out_dtype and fused_quant != 1 and fused_quant != 2: + continue # skip non dynamic quant case + if (fused_quant == 1 or fused_quant == 2) and hs_key == 'big': + continue + if (fused_quant == 0 and save_unquant == True): + continue # save_unquant should always be false when there is no quant enabled + current_hs = list() + for chs_ in hs: + h_ = copy.copy(chs_) # copy the base instance out + h_.F_XDataType = prec_i + h_.F_YDataType = prec_o + h_.F_SmoothScaleDataType = scale_sm + h_.F_YScaleDataType = scale_y + h_.F_UnquantYDataType = prec_i + h_.F_kFusedAdd = fused_add + h_.F_kFusedQuant = fused_quant + h_.F_kSaveUnquant = save_unquant + current_hs.append(h_) # + "\n" + #f.write(str(f.parent / GEN_DIR / (blobs.api_common_header_ + current_n_str = 'big' if hs_key == 'big' else current_n + total_blob.append(h_instance(dtype, current_n_str, fused_add, fused_quant, save_unquant, h_.F_use_model_sensitive_rmsnorm, current_hs)) return total_blob def list_blobs(self) -> None: diff --git a/example/ck_tile/10_rmsnorm2d/rmsnorm2d_fwd.cpp b/example/ck_tile/10_rmsnorm2d/rmsnorm2d_fwd.cpp index d5be4384ab..049a0cad41 100644 --- a/example/ck_tile/10_rmsnorm2d/rmsnorm2d_fwd.cpp +++ b/example/ck_tile/10_rmsnorm2d/rmsnorm2d_fwd.cpp @@ -52,7 +52,8 @@ auto create_args(int argc, char* argv[]) .insert("fadd", "0", "fused-add, 0:no fused add, 1:preadd+store, 2:preadd only") .insert("fquant", "0", "fused-quant, 0:no, 1:smooth-dynamic-quant, 2:dynamic-quant") .insert("warmup", "5", "cold iter") - .insert("repeat", "20", "hot iter"); + .insert("repeat", "20", "hot iter") + .insert("s", "0", "sensitive model mode, 0: for no specific model, 1: for T5-like model"); bool result = arg_parser.parse(argc, argv); return std::make_tuple(result, arg_parser); @@ -66,15 +67,16 @@ template bool run(const ck_tile::ArgParser& arg_parser) { - ck_tile::index_t m = arg_parser.get_int("m"); - ck_tile::index_t n = arg_parser.get_int("n"); - float epsilon = arg_parser.get_float("e"); - int kname = arg_parser.get_int("kname"); - int do_validation = arg_parser.get_int("v"); - int fused_add = arg_parser.get_int("fadd"); - int fused_quant = arg_parser.get_int("fquant"); - int warmup = arg_parser.get_int("warmup"); - int repeat = arg_parser.get_int("repeat"); + ck_tile::index_t m = arg_parser.get_int("m"); + ck_tile::index_t n = arg_parser.get_int("n"); + float epsilon = arg_parser.get_float("e"); + int kname = arg_parser.get_int("kname"); + int do_validation = arg_parser.get_int("v"); + int fused_add = arg_parser.get_int("fadd"); + int fused_quant = arg_parser.get_int("fquant"); + int warmup = arg_parser.get_int("warmup"); + int repeat = arg_parser.get_int("repeat"); + const int use_model_sensitive_rmsnorm = arg_parser.get_int("s"); ck_tile::index_t x_stride = arg_parser.get_int("x_stride"); if(x_stride < 0) @@ -194,10 +196,17 @@ bool run(const ck_tile::ArgParser& arg_parser) std::cout << "[" << prec_str << "]" << " m:" << m << ", n:" << n << ", x_stride:" << x_stride << ", xr_stride:" << xr_stride << ", y_stride:" << y_stride - << ", yr_stride:" << yr_stride << std::flush; + << ", yr_stride:" << yr_stride << ", s:" << use_model_sensitive_rmsnorm << std::flush; - rmsnorm2d_fwd_traits traits{ - prec_i, prec_o, prec_sm, prec_sy, SaveRms, SaveUnquant, fused_add, fused_quant}; + rmsnorm2d_fwd_traits traits{prec_i, + prec_o, + prec_sm, + prec_sy, + SaveRms, + SaveUnquant, + fused_add, + fused_quant, + use_model_sensitive_rmsnorm}; rmsnorm2d_fwd_args args{x_buf.GetDeviceBuffer(), fused_add != 0 ? x_residual_buf.GetDeviceBuffer() : nullptr, diff --git a/example/ck_tile/10_rmsnorm2d/rmsnorm2d_fwd.hpp b/example/ck_tile/10_rmsnorm2d/rmsnorm2d_fwd.hpp index bb4a2f5ef4..c1090ed28b 100644 --- a/example/ck_tile/10_rmsnorm2d/rmsnorm2d_fwd.hpp +++ b/example/ck_tile/10_rmsnorm2d/rmsnorm2d_fwd.hpp @@ -64,6 +64,8 @@ struct rmsnorm2d_fwd_traits bool save_unquant; int fused_add; // 0:no-add, 1:pre-add-store, 2:pre-add int fused_quant; // 0:no-sweep, 1:smooth-dynamic-quant, 2:dynamic-quant + + int use_model_sensitive_rmsnorm = 0; // 0: Use default RMSNorm; 1: Use T5-like implementation }; float rmsnorm2d_fwd(rmsnorm2d_fwd_traits, rmsnorm2d_fwd_args, const ck_tile::stream_config&); diff --git a/example/ck_tile/10_rmsnorm2d/script/perf_test.sh b/example/ck_tile/10_rmsnorm2d/script/perf_test.sh index 7b9d0820fd..bc4362c105 100755 --- a/example/ck_tile/10_rmsnorm2d/script/perf_test.sh +++ b/example/ck_tile/10_rmsnorm2d/script/perf_test.sh @@ -1,37 +1,74 @@ #!/bin/sh EXE="$(find . -name tile_rmsnorm2d_fwd -type f | head -n 1)" -$EXE -m=1 -n=1 -e=1e-12 -v=1 -prec=bf16 -repeat=1000 -$EXE -m=700 -n=80 -e=1e-12 -v=1 -prec=bf16 -repeat=1000 -$EXE -m=700 -n=128 -e=1e-12 -v=1 -prec=bf16 -repeat=1000 -$EXE -m=700 -n=144 -e=1e-12 -v=1 -prec=bf16 -repeat=1000 -$EXE -m=700 -n=168 -e=1e-12 -v=1 -prec=bf16 -repeat=1000 -$EXE -m=700 -n=184 -e=1e-12 -v=1 -prec=bf16 -repeat=1000 -$EXE -m=700 -n=256 -e=1e-12 -v=1 -prec=bf16 -repeat=1000 -$EXE -m=700 -n=288 -e=1e-12 -v=1 -prec=bf16 -repeat=1000 -$EXE -m=700 -n=344 -e=1e-12 -v=1 -prec=bf16 -repeat=1000 -$EXE -m=700 -n=376 -e=1e-12 -v=1 -prec=bf16 -repeat=1000 -$EXE -m=700 -n=448 -e=1e-12 -v=1 -prec=bf16 -repeat=1000 -$EXE -m=700 -n=512 -e=1e-12 -v=1 -prec=bf16 -repeat=1000 -$EXE -m=700 -n=924 -e=1e-12 -v=1 -prec=bf16 -repeat=1000 -$EXE -m=700 -n=1024 -e=1e-12 -v=1 -prec=bf16 -repeat=1000 -$EXE -m=700 -n=1078 -e=1e-12 -v=1 -prec=bf16 -repeat=1000 -$EXE -m=700 -n=1996 -e=1e-12 -v=1 -prec=bf16 -repeat=1000 -$EXE -m=700 -n=4080 -e=1e-12 -v=1 -prec=bf16 -repeat=1000 +# 0: for no specific RMSNorm +$EXE -m=1 -n=1 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000 -s=0 +$EXE -m=700 -n=80 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000 -s=0 +$EXE -m=700 -n=128 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000 -s=0 +$EXE -m=700 -n=144 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000 -s=0 +$EXE -m=700 -n=168 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000 -s=0 +$EXE -m=700 -n=184 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000 -s=0 +$EXE -m=700 -n=256 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000 -s=0 +$EXE -m=700 -n=288 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000 -s=0 +$EXE -m=700 -n=344 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000 -s=0 +$EXE -m=700 -n=376 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000 -s=0 +$EXE -m=700 -n=448 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000 -s=0 +$EXE -m=700 -n=512 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000 -s=0 +$EXE -m=700 -n=924 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000 -s=0 +$EXE -m=700 -n=1024 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000 -s=0 +$EXE -m=700 -n=1078 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000 -s=0 +$EXE -m=700 -n=1996 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000 -s=0 +$EXE -m=700 -n=4080 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000 -s=0 -$EXE -m=700 -n=80 -e=1e-12 -v=1 -prec=fp16 -repeat=1000 -$EXE -m=700 -n=128 -e=1e-12 -v=1 -prec=fp16 -repeat=1000 -$EXE -m=700 -n=144 -e=1e-12 -v=1 -prec=fp16 -repeat=1000 -$EXE -m=700 -n=168 -e=1e-12 -v=1 -prec=fp16 -repeat=1000 -$EXE -m=700 -n=184 -e=1e-12 -v=1 -prec=fp16 -repeat=1000 -$EXE -m=700 -n=256 -e=1e-12 -v=1 -prec=fp16 -repeat=1000 -$EXE -m=700 -n=288 -e=1e-12 -v=1 -prec=fp16 -repeat=1000 -$EXE -m=700 -n=344 -e=1e-12 -v=1 -prec=fp16 -repeat=1000 -$EXE -m=700 -n=376 -e=1e-12 -v=1 -prec=fp16 -repeat=1000 -$EXE -m=700 -n=448 -e=1e-12 -v=1 -prec=fp16 -repeat=1000 -$EXE -m=700 -n=512 -e=1e-12 -v=1 -prec=fp16 -repeat=1000 -$EXE -m=700 -n=924 -e=1e-12 -v=1 -prec=fp16 -repeat=1000 -$EXE -m=700 -n=1024 -e=1e-12 -v=1 -prec=fp16 -repeat=1000 -$EXE -m=700 -n=1078 -e=1e-12 -v=1 -prec=fp16 -repeat=1000 -$EXE -m=700 -n=1996 -e=1e-12 -v=1 -prec=fp16 -repeat=1000 -$EXE -m=700 -n=4080 -e=1e-12 -v=1 -prec=fp16 -repeat=1000 \ No newline at end of file +$EXE -m=700 -n=80 -e=1e-12 -v=1 -prec_i=fp16 -repeat=1000 -s=0 +$EXE -m=700 -n=128 -e=1e-12 -v=1 -prec_i=fp16 -repeat=1000 -s=0 +$EXE -m=700 -n=144 -e=1e-12 -v=1 -prec_i=fp16 -repeat=1000 -s=0 +$EXE -m=700 -n=168 -e=1e-12 -v=1 -prec_i=fp16 -repeat=1000 -s=0 +$EXE -m=700 -n=184 -e=1e-12 -v=1 -prec_i=fp16 -repeat=1000 -s=0 +$EXE -m=700 -n=256 -e=1e-12 -v=1 -prec_i=fp16 -repeat=1000 -s=0 +$EXE -m=700 -n=288 -e=1e-12 -v=1 -prec_i=fp16 -repeat=1000 -s=0 +$EXE -m=700 -n=344 -e=1e-12 -v=1 -prec_i=fp16 -repeat=1000 -s=0 +$EXE -m=700 -n=376 -e=1e-12 -v=1 -prec_i=fp16 -repeat=1000 -s=0 +$EXE -m=700 -n=448 -e=1e-12 -v=1 -prec_i=fp16 -repeat=1000 -s=0 +$EXE -m=700 -n=512 -e=1e-12 -v=1 -prec_i=fp16 -repeat=1000 -s=0 +$EXE -m=700 -n=924 -e=1e-12 -v=1 -prec_i=fp16 -repeat=1000 -s=0 +$EXE -m=700 -n=1024 -e=1e-12 -v=1 -prec_i=fp16 -repeat=1000 -s=0 +$EXE -m=700 -n=1078 -e=1e-12 -v=1 -prec_i=fp16 -repeat=1000 -s=0 +$EXE -m=700 -n=1996 -e=1e-12 -v=1 -prec_i=fp16 -repeat=1000 -s=0 +$EXE -m=700 -n=4080 -e=1e-12 -v=1 -prec_i=fp16 -repeat=1000 -s=0 + +# 1: for T5-like RMSNorm +$EXE -m=1 -n=1 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000 -s=1 +$EXE -m=700 -n=80 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000 -s=1 +$EXE -m=700 -n=128 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000 -s=1 +$EXE -m=700 -n=144 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000 -s=1 +$EXE -m=700 -n=168 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000 -s=1 +$EXE -m=700 -n=184 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000 -s=1 +$EXE -m=700 -n=256 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000 -s=1 +$EXE -m=700 -n=288 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000 -s=1 +$EXE -m=700 -n=344 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000 -s=1 +$EXE -m=700 -n=376 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000 -s=1 +$EXE -m=700 -n=448 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000 -s=1 +$EXE -m=700 -n=512 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000 -s=1 +$EXE -m=700 -n=924 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000 -s=1 +$EXE -m=700 -n=1024 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000 -s=1 +$EXE -m=700 -n=1078 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000 -s=1 +$EXE -m=700 -n=1996 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000 -s=1 +$EXE -m=700 -n=4080 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000 -s=1 + +$EXE -m=700 -n=80 -e=1e-12 -v=1 -prec_i=fp16 -repeat=1000 -s=1 +$EXE -m=700 -n=128 -e=1e-12 -v=1 -prec_i=fp16 -repeat=1000 -s=1 +$EXE -m=700 -n=144 -e=1e-12 -v=1 -prec_i=fp16 -repeat=1000 -s=1 +$EXE -m=700 -n=168 -e=1e-12 -v=1 -prec_i=fp16 -repeat=1000 -s=1 +$EXE -m=700 -n=184 -e=1e-12 -v=1 -prec_i=fp16 -repeat=1000 -s=1 +$EXE -m=700 -n=256 -e=1e-12 -v=1 -prec_i=fp16 -repeat=1000 -s=1 +$EXE -m=700 -n=288 -e=1e-12 -v=1 -prec_i=fp16 -repeat=1000 -s=1 +$EXE -m=700 -n=344 -e=1e-12 -v=1 -prec_i=fp16 -repeat=1000 -s=1 +$EXE -m=700 -n=376 -e=1e-12 -v=1 -prec_i=fp16 -repeat=1000 -s=1 +$EXE -m=700 -n=448 -e=1e-12 -v=1 -prec_i=fp16 -repeat=1000 -s=1 +$EXE -m=700 -n=512 -e=1e-12 -v=1 -prec_i=fp16 -repeat=1000 -s=1 +$EXE -m=700 -n=924 -e=1e-12 -v=1 -prec_i=fp16 -repeat=1000 -s=1 +$EXE -m=700 -n=1024 -e=1e-12 -v=1 -prec_i=fp16 -repeat=1000 -s=1 +$EXE -m=700 -n=1078 -e=1e-12 -v=1 -prec_i=fp16 -repeat=1000 -s=1 +$EXE -m=700 -n=1996 -e=1e-12 -v=1 -prec_i=fp16 -repeat=1000 -s=1 +$EXE -m=700 -n=4080 -e=1e-12 -v=1 -prec_i=fp16 -repeat=1000 -s=1 \ No newline at end of file diff --git a/example/ck_tile/10_rmsnorm2d/script/smoke_test.sh b/example/ck_tile/10_rmsnorm2d/script/smoke_test.sh index 2bad7a00ea..1c79dafadd 100755 --- a/example/ck_tile/10_rmsnorm2d/script/smoke_test.sh +++ b/example/ck_tile/10_rmsnorm2d/script/smoke_test.sh @@ -5,29 +5,32 @@ for fquant in "" "-fquant=1 -prec_o=int8" "-fquant=2 -prec_o=int8" "-fquant=1 -p "-fquant=1 -prec_o=int8 -save_unquant=1" "-fquant=2 -prec_o=int8 -save_unquant=1" "-fquant=1 -prec_o=fp8 -save_unquant=1" "-fquant=2 -prec_o=fp8 -save_unquant=1"; do for pr_i in "fp16" "bf16" ; do for fadd in "0" "1"; do -$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=99 -n=13 -$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=17 -n=16 -$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=1 -n=100 -$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=4 -n=128 -$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=80 -n=127 -$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=22 -n=255 -stride=256 -$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=7 -n=599 -$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=19 -n=512 -$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=33 -n=313 -stride=1000 -$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=11 -n=510 -$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=171 -n=676 -stride=818 -$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=91 -n=636 -$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=12 -n=768 -stride=800 -$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=100 -n=766 -stride=812 -$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=31 -n=1024 -$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=64 -n=1000 -stride=1004 -$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=8 -n=1501 -$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=3 -n=1826 -$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=5 -n=2040 -$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=7 -n=2734 -$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=1 -n=3182 -$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=9 -n=4096 -$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=3 -n=8192 +# 0: for no specific RMSNorm; 1: for T-5 like RMSNorm +for s in "0" "1"; do +$EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=99 -n=13 +$EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=17 -n=16 +$EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=1 -n=100 +$EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=4 -n=128 +$EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=80 -n=127 +# $EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=22 -n=255 -stride=256 +$EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=7 -n=599 +$EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=19 -n=512 +# $EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=33 -n=313 -stride=1000 +$EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=11 -n=510 +# $EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=171 -n=676 -stride=818 +$EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=91 -n=636 +# $EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=12 -n=768 -stride=800 +# $EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=100 -n=766 -stride=812 +$EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=31 -n=1024 +# $EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=64 -n=1000 -stride=1004 +$EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=8 -n=1501 +$EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=3 -n=1826 +$EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=5 -n=2040 +$EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=7 -n=2734 +$EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=1 -n=3182 +$EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=9 -n=4096 +$EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=3 -n=8192 +done done done done @@ -36,8 +39,11 @@ done for fquant in "" for pr_i in "fp16" "bf16" ; do for fadd in "0" "1"; do -$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=1 -n=10547 +# 0: for no specific RMSNorm; 1: for T-5 like RMSNorm +for s in "0" "1"; do +$EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=1 -n=10547 #$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=3 -n=17134 done done done +done diff --git a/include/ck_tile/ops/reduce/block/block_reduce2d.hpp b/include/ck_tile/ops/reduce/block/block_reduce2d.hpp index 6a1f926a9a..62c9944bd2 100644 --- a/include/ck_tile/ops/reduce/block/block_reduce2d.hpp +++ b/include/ck_tile/ops/reduce/block/block_reduce2d.hpp @@ -272,4 +272,137 @@ struct BlockReduce2dCrossWarpSync } }; +template +struct BlockReduce2dTreeCrossWarpSync +{ + using Problem = remove_cvref_t; + using BlockShape = typename Problem::BlockShape; + + template + CK_TILE_DEVICE static constexpr index_t GetReduceWarps() + { + constexpr index_t num_reduce_warps = [&]() { + using Dstr = typename YDistributedTensor_::StaticTileDistribution; + using DstrEncode = typename Dstr::DstrEncode; + using DstrEncodeDetail = typename DstrEncode::detail; + + constexpr index_t NDimR = Dstr::get_num_of_dimension_r(); + + constexpr index_t idim_p_warp = 0; + + index_t len_ = 1; + static_for<0, NDimR, 1>{}([&](auto idim_r) { + if constexpr(DstrEncodeDetail::does_p_own_r_[idim_p_warp][idim_r]) + { + constexpr index_t r_length = DstrEncode::rs_lengths_[idim_r]; + len_ *= r_length; + } + }); + return len_; + }(); + return num_reduce_warps; + } + + // return in byte + template + CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() + { + using DataType = typename YDistributedTensor_::DataType; + constexpr index_t thread_buf_size = YDistributedTensor_::get_thread_buffer_size(); + + // we need to store all data from every wave into smem + // e.g. 2x2 reduce along N + // -------------> reduce N + // | w0 | w1 | ___> | w01 | + // | w2 | w3 | | w23 | + // + // -> store data from every wave into LDS + // + // + // -------------> reduce N + // | w0 | w1 | w2 | w3 | -----> | w0123 | + // + // -> also store data from every wave into LDS + constexpr index_t num_warps = BlockShape::BlockSize / warpSize; + return num_warps * thread_buf_size * sizeof(DataType); + } + + template + CK_TILE_DEVICE void + operator()(YDistributedTensor_& y_tensor, void* smem, const ReduceFunc& reduce_func) + { + using Dstr = typename YDistributedTensor_::StaticTileDistribution; + using DstrEncode = typename Dstr::DstrEncode; + using DstrEncodeDetail = typename DstrEncode::detail; + using DataType = typename YDistributedTensor_::DataType; + + constexpr index_t NDimP = Dstr::get_num_of_dimension_p(); + constexpr index_t NDimR = Dstr::get_num_of_dimension_r(); + + constexpr index_t idim_p_lane = NDimP - 1; + constexpr index_t thread_buf_size = YDistributedTensor_::get_thread_buffer_size(); + + DataType* smem_ptr = reinterpret_cast(smem); + const index_t lane_id = get_lane_id(); + const index_t warp_id = get_warp_id(); + + constexpr index_t num_warps = BlockShape::BlockSize / get_warp_size(); + constexpr index_t num_reduce_warps = GetReduceWarps(); + + if constexpr(num_reduce_warps == 1) + return; + + // Each warp's lane 0 writes its partial results to shared memory + const index_t smem_offset = warp_id; + if(lane_id == 0) + { + static_for<0, thread_buf_size, 1>{}([&](auto i) { + // Store the i-th element of this warp's thread_buffer into SMEM + smem_ptr[smem_offset + i * num_warps] = y_tensor.get_thread_buffer()[i]; + }); + } + block_sync_lds(); + + // We let each warp holds a duplication to do reduction. + static_for<0, thread_buf_size, 1>{}([&](auto i) { + DataType v = 0; + if(lane_id < num_reduce_warps) + { + v = smem_ptr[lane_id + i * num_warps]; + } + + // cross-lane reduce for replication + // only reduce on R dimension correspond to lane + // (lane id maps to this R dimension) + static_for<0, NDimR, 1>{}([&](auto idim_r) { + // FIXME: nasty to use does_p_own_r_ + if constexpr(DstrEncodeDetail::does_p_own_r_[idim_p_lane][idim_r]) + { + constexpr index_t r_length = DstrEncode::rs_lengths_[idim_r]; + + constexpr index_t lid_over_rid_derivative = + DstrEncodeDetail::ps_over_rs_derivative_[idim_p_lane][idim_r]; + + static_assert(is_power_of_two_integer(r_length), + "wrong! only support power of 2 reduction"); + + constexpr index_t nstage = integer_log2_floor(r_length); + + // reduction sweep forward + static_for<0, nstage, 1>{}([&](auto istage) { + // pull data from remote lane + const auto o = + __shfl_xor(v, number{}.value); + + // reduce + v = reduce_func(v, o); + }); + } + }); + + y_tensor.get_thread_buffer()(i) = v; + }); + } +}; + } // namespace ck_tile diff --git a/include/ck_tile/ops/rmsnorm2d.hpp b/include/ck_tile/ops/rmsnorm2d.hpp index 3eec2a1ab6..610541b2e4 100644 --- a/include/ck_tile/ops/rmsnorm2d.hpp +++ b/include/ck_tile/ops/rmsnorm2d.hpp @@ -5,6 +5,7 @@ #include "ck_tile/ops/rmsnorm2d/kernel/rmsnorm2d_fwd_kernel.hpp" #include "ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_default_policy.hpp" +#include "ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_model_sensitive_pass.hpp" #include "ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_one_pass.hpp" #include "ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_problem.hpp" #include "ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_two_pass.hpp" diff --git a/include/ck_tile/ops/rmsnorm2d/kernel/rmsnorm2d_fwd_kernel.hpp b/include/ck_tile/ops/rmsnorm2d/kernel/rmsnorm2d_fwd_kernel.hpp index f0251177d4..6cb81b8856 100644 --- a/include/ck_tile/ops/rmsnorm2d/kernel/rmsnorm2d_fwd_kernel.hpp +++ b/include/ck_tile/ops/rmsnorm2d/kernel/rmsnorm2d_fwd_kernel.hpp @@ -58,13 +58,14 @@ struct Rmsnorm2dFwd static constexpr bool kSaveInvRms = Problem::Traits::kSaveInvRms; static constexpr bool kSaveUnquant = Problem::Traits::kSaveUnquant; - static constexpr index_t Block_M = Problem::BlockShape::Block_M; - static constexpr index_t Block_N = Problem::BlockShape::Block_N; - static constexpr bool kPadM = false; // always no need to pad along M - static constexpr bool kPadN = Problem::Traits::kPadN; - static constexpr bool kTwoPass = Problem::Traits::kTwoPass; - static constexpr auto kFusedAdd = Problem::Traits::kFusedAdd; - static constexpr auto kFusedQuant = Problem::Traits::kFusedQuant; + static constexpr index_t Block_M = Problem::BlockShape::Block_M; + static constexpr index_t Block_N = Problem::BlockShape::Block_N; + static constexpr bool kPadM = false; // always no need to pad along M + static constexpr bool kPadN = Problem::Traits::kPadN; + static constexpr bool kTwoPass = Problem::Traits::kTwoPass; + static constexpr auto kFusedAdd = Problem::Traits::kFusedAdd; + static constexpr auto kFusedQuant = Problem::Traits::kFusedQuant; + static constexpr auto kUseModelSensitiveRMSNorm = Problem::Traits::kUseModelSensitiveRMSNorm; static constexpr index_t ThreadPerWarp_N = Problem::BlockShape::ThreadPerWarp_N; static constexpr index_t Vector_N = Problem::BlockShape::Vector_N; @@ -150,6 +151,8 @@ struct Rmsnorm2dFwd if (kPadN) n += "_pn"; if (kSaveInvRms) n += "_rms"; if (kTwoPass) n += "_2p"; + if (kUseModelSensitiveRMSNorm == Rmsnorm2dSensitiveEnum::NO_SPECIFIC_MODEL) n += "_nsm"; + else if (kUseModelSensitiveRMSNorm == Rmsnorm2dSensitiveEnum::T5_MODEL_LIKE) n += "_t5ml"; return n; }(); auto prec_str = [&] () { diff --git a/include/ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_default_policy.hpp b/include/ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_default_policy.hpp index 356a2e12ca..df689c6b46 100644 --- a/include/ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_default_policy.hpp +++ b/include/ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_default_policy.hpp @@ -69,6 +69,15 @@ struct Rmsnorm2dFwdPipelineDefaultPolicy return BlockReduce2dCrossWarpSync{}; } + template + CK_TILE_HOST_DEVICE static constexpr auto GetBlockReduce2dTreeCrossWarpSync() + { + using P_ = BlockReduce2dProblem; + return BlockReduce2dTreeCrossWarpSync{}; + } + template CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() { diff --git a/include/ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_model_sensitive_pass.hpp b/include/ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_model_sensitive_pass.hpp new file mode 100644 index 0000000000..810c3c5243 --- /dev/null +++ b/include/ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_model_sensitive_pass.hpp @@ -0,0 +1,228 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_default_policy.hpp" +#include +#include + +namespace ck_tile { + +/** + * @brief This T5Pass implements the RMSNorm2d forward pipeline as a variant + * based on Rmsnorm2dFwdPipelineOnePass and Rmsnorm2dFwdPipelineTwoPass using a T5 model-like + * method. + * + * The T5 model, developed by Google, is a transformer-based architecture designed to perform + * a variety of NLP tasks. The T5-like approach employed here is characterized by how RMS + * normalization is handled, particularly where intermediate values are cast to BF16. This aims to + * achieve a similar value distribution to that produced by the VLLM hip implementation, thereby + * enhancing model accuracy. + * + * Note: While this implementation improves precision and can reduce discrepancies with VLLM, it is + * not guaranteed to eliminate all differences or ensure uniform outcomes across every use case. + * + * This implementation is a variant based on the original one-pass and two-pass approaches, + * allowing for both fused and non-fused add operations. + */ + +template +struct Rmsnorm2dFwdPipelineModelSensitiveT5Pass +{ + using Problem = ck_tile::remove_cvref_t; + using Policy = ck_tile::remove_cvref_t; + + using XDataType = ck_tile::remove_cvref_t; + using GammaDataType = ck_tile::remove_cvref_t; + using ComputeDataType = ck_tile::remove_cvref_t; + using YDataType = ck_tile::remove_cvref_t; + using InvRmsDataType = ck_tile::remove_cvref_t; + + using XResidualDataType = XDataType; + using YResidualDataType = XDataType; + + static constexpr bool kHasGamma = !std::is_same_v; + static constexpr bool kSaveInvRms = Problem::Traits::kSaveInvRms; + static constexpr bool kSaveUnquant = Problem::Traits::kSaveUnquant; + + static constexpr bool kNeedCrossWarpSync = Problem::kNeedCrossWarpSync; + static constexpr bool kPadM = false; // TODO - BlockRmsnorm2dFwdProblem::kPadM + static constexpr bool kPadN = Problem::Traits::kPadN; + static constexpr auto kFusedAdd = Problem::Traits::kFusedAdd; + static constexpr auto kFusedQuant = Problem::Traits::kFusedQuant; + + static constexpr const char* name = []() { + if constexpr(kNeedCrossWarpSync) + return "bpr_op"; // block per row + else + return "wpr_op"; // warp per row + }(); + + CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() + { + return Policy::template GetSmemSize(); + } + + template + CK_TILE_DEVICE auto operator()(const XWindow& x_window_, + const XResidualWindow& x_residual_window_, + const GammaWindow& gamma_window_, + YWindow& y_window_, + const YResidualWindow& y_residual_window_, + InvRmsWindow& inv_rms_window, + const SmoothScaleWindow& sm_scale_window_, + YScaleWindow& y_scale_window_, + UnquantYWindow& unquant_y_window, + ComputeDataType epsilon, + ck_tile::index_t row_size, + void* smem, + Epilogue) const + { + const auto x_window = + make_tile_window(x_window_, Policy::template MakeXBlockTileDistribution()); + const auto gamma_window = make_tile_window( + gamma_window_, Policy::template MakeGammaBlockTileDistribution()); + const auto x_residual_window = make_tile_window( + x_residual_window_, Policy::template MakeXBlockTileDistribution()); + auto y_residual_window = make_tile_window( + y_residual_window_, Policy::template MakeXBlockTileDistribution()); + + auto reduce_square_sum_func = ReduceOp::SquareAdd{}; + auto reduce_sum_func = ReduceOp::Add{}; + auto block_reduce2d = Policy::template GetBlockReduce2d(); + auto block_reduce2d_sync = Policy::template GetBlockReduce2dSync(); + auto block_reduce2d_tree_cross_warp_sync = + Policy::template GetBlockReduce2dTreeCrossWarpSync(); + + auto x = load_tile(x_window); + auto x_resi = load_tile(x_residual_window); + + // load gamma (TODO: support no gamma?) + const auto gamma = load_tile(gamma_window); + + auto acc = cast_tile(x); + + if constexpr(kFusedAdd == Rmsnorm2dFusedAddEnum::PRE_ADD || + kFusedAdd == Rmsnorm2dFusedAddEnum::PRE_ADD_STORE) + { + [[maybe_unused]] auto pre_out = + make_static_distributed_tensor(x.get_tile_distribution()); + + sweep_tile(x_resi, [&](auto idx) { + // compute x = x_resi + x + acc(idx) = type_convert(x_resi(idx)) + acc(idx); + + // To make norm input align with residual output + if constexpr(kFusedAdd == Rmsnorm2dFusedAddEnum::PRE_ADD_STORE) + { + if constexpr(std::is_same_v) + { + pre_out(idx) = float_to_bf16(acc(idx)); + } + else + { + pre_out(idx) = type_convert(acc(idx)); + } + acc(idx) = type_convert(pre_out(idx)); + } + }); + if constexpr(kFusedAdd == Rmsnorm2dFusedAddEnum::PRE_ADD_STORE) + { + store_tile(y_residual_window, pre_out); + } + } + + // compute mean square each-thread->cross-lane->cross-warp + auto square_sum = block_reduce2d.template MakeYBlockTile(); + set_tile(square_sum, 0); + if constexpr(Problem::BlockShape::Vector_N % 2 == 0) + { + sweep_tile( + acc, + [&](auto idx_0, auto idx_1) { + square_sum(idx_0) += acc[idx_0] * acc[idx_0] + acc[idx_1] * acc[idx_1]; + }, + sequence<1, 2>{}); + } + else + { + square_sum = block_reduce2d(acc, + reduce_square_sum_func.GetIdentityValue(), + reduce_square_sum_func); + } + block_reduce2d_sync(square_sum, reduce_sum_func); + block_reduce2d_tree_cross_warp_sync(square_sum, smem, reduce_sum_func); + + // compute inv-rms + auto inv_rms = tile_elementwise_in( + [&](const auto& v_) { return rsqrtf(v_ / row_size + epsilon); }, square_sum); + + if constexpr(kSaveInvRms) + store_tile(inv_rms_window, cast_tile(inv_rms)); + + // rmsnorm computation + auto rmsn = make_static_distributed_tensor(x.get_tile_distribution()); + sweep_tile(rmsn, [&, inv_rms_ = inv_rms](auto idx) { + constexpr auto i_idx = make_tuple(idx[number<0>{}]); + constexpr auto j_idx = make_tuple(idx[number<1>{}]); + + const auto gamma_ = type_convert(gamma[j_idx]); + + if constexpr(std::is_same_v) + { + const auto tmp0 = + float_to_bf16(acc[idx] * inv_rms_[i_idx]); + const auto tmp1 = float_to_bf16( + type_convert(tmp0) * gamma_); + const auto rmsn_ = type_convert(tmp1); + rmsn(idx) = rmsn_; + } + else + { + const auto tmp = type_convert(acc[idx] * inv_rms_[i_idx]); + const auto rmsn_ = type_convert(tmp) * gamma_; + rmsn(idx) = rmsn_; + } + }); + + if constexpr(kFusedQuant == Rmsnorm2dFusedQuantEnum::SMOOTH_DYNAMIC_QUANT) + { + if constexpr(kSaveUnquant) + { + Epilogue{}( + unquant_y_window, y_window_, sm_scale_window_, y_scale_window_, rmsn, smem); + } + else + { + Epilogue{}(y_window_, sm_scale_window_, y_scale_window_, rmsn, smem); + } + } + else if constexpr(kFusedQuant == Rmsnorm2dFusedQuantEnum::DYNAMIC_QUANT) + { + if constexpr(kSaveUnquant) + { + Epilogue{}(unquant_y_window, y_window_, y_scale_window_, rmsn, smem); + } + else + { + Epilogue{}(y_window_, y_scale_window_, rmsn, smem); + } + } + else + { + Epilogue{}(y_window_, rmsn); + } + } +}; +} // namespace ck_tile diff --git a/include/ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_one_pass.hpp b/include/ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_one_pass.hpp index 58159142d0..c77d61872e 100644 --- a/include/ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_one_pass.hpp +++ b/include/ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_one_pass.hpp @@ -117,10 +117,7 @@ struct Rmsnorm2dFwdPipelineOnePass // compute inv-rms auto inv_rms = tile_elementwise_in( - [&](const auto& v_) { - return type_convert(1.0f) / (sqrt(v_ / row_size + epsilon)); - }, - square_sum); + [&](const auto& v_) { return rsqrtf(v_ / row_size + epsilon); }, square_sum); if constexpr(kSaveInvRms) store_tile(inv_rms_window, cast_tile(inv_rms)); diff --git a/include/ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_traits.hpp b/include/ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_traits.hpp index 152da60c01..b91f17ffdd 100644 --- a/include/ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_traits.hpp +++ b/include/ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_traits.hpp @@ -37,20 +37,37 @@ template<> struct Rmsnorm2dFusedQuantEnumName struct Rmsnorm2dFusedQuantEnumName { static constexpr const char * name = "smdqt"; }; // clang-format on +enum class Rmsnorm2dSensitiveEnum +{ + NO_SPECIFIC_MODEL = 0, + // T5-like model for RMSNorm. The T5 model, developed by Google, is a transformer-based + // architecture designed for a variety of NLP tasks. This option mimics T5's approach to + // RMSNorm, aiming to ensure similar value distributions and enhance accuracy. + T5_MODEL_LIKE = 1, +}; + +// clang-format off +template struct Rmsnorm2dSensitiveEnumName; +template<> struct Rmsnorm2dSensitiveEnumName { static constexpr const char * name = "nsm"; }; +template<> struct Rmsnorm2dSensitiveEnumName { static constexpr const char * name = "t5ml"; }; +// clang-format on + template + Rmsnorm2dFusedQuantEnum kFusedQuant_, + Rmsnorm2dSensitiveEnum kUseModelSensitiveRMSNorm_> struct Rmsnorm2dFwdTraits { - static constexpr bool kPadN = kPadN_; - static constexpr bool kSaveInvRms = kSaveInvRms_; - static constexpr bool kSaveUnquant = kSaveUnquant_; - static constexpr bool kTwoPass = kTwoPass_; - static constexpr Rmsnorm2dFusedAddEnum kFusedAdd = kFusedAdd_; - static constexpr Rmsnorm2dFusedQuantEnum kFusedQuant = kFusedQuant_; + static constexpr bool kPadN = kPadN_; + static constexpr bool kSaveInvRms = kSaveInvRms_; + static constexpr bool kSaveUnquant = kSaveUnquant_; + static constexpr bool kTwoPass = kTwoPass_; + static constexpr Rmsnorm2dFusedAddEnum kFusedAdd = kFusedAdd_; + static constexpr Rmsnorm2dFusedQuantEnum kFusedQuant = kFusedQuant_; + static constexpr Rmsnorm2dSensitiveEnum kUseModelSensitiveRMSNorm = kUseModelSensitiveRMSNorm_; }; } // namespace ck_tile From c1badfd30c1679f4c8e176c8f0608db2c6ac6505 Mon Sep 17 00:00:00 2001 From: huaiguxu <145733371+huaiguxu@users.noreply.github.com> Date: Wed, 16 Jul 2025 15:44:34 +0800 Subject: [PATCH 2/5] Handle moe_fp8 no-mainloop cases. Supprese no-mainloop check (#2438) Co-authored-by: felix --- .../gpu/device/impl/device_moe_gemm.hpp | 50 ++++++++++++++++--- .../gpu/grid/gridwise_moe_gemm.hpp | 2 +- 2 files changed, 45 insertions(+), 7 deletions(-) diff --git a/include/ck/tensor_operation/gpu/device/impl/device_moe_gemm.hpp b/include/ck/tensor_operation/gpu/device/impl/device_moe_gemm.hpp index 08d177035e..27d3c378ac 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_moe_gemm.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_moe_gemm.hpp @@ -325,12 +325,50 @@ struct DeviceMoeGemm : public DeviceGemmMultipleDSplitKBPreShuffle; - RunKernel(kernel); + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd) + { + const auto kernel = kernel_moe_gemm; + RunKernel(kernel); + } + else + { + const auto kernel = kernel_moe_gemm; + RunKernel(kernel); + } + } + else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v2 || + BlkGemmPipelineVer == BlockGemmPipelineVersion::v3) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd) + { + const auto kernel = kernel_moe_gemm_2lds; + RunKernel(kernel); + } + else + { + const auto kernel = kernel_moe_gemm_2lds; + RunKernel(kernel); + } + } + else + { + throw std::runtime_error("todo: only v1 & v2 support now"); } } #endif diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm.hpp index 36f8fd7cc1..3d5066d52d 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm.hpp @@ -1112,7 +1112,7 @@ struct GridwiseMoeGemm } // check gridwise gemm pipeline -#if 1 +#if 0 const auto num_k_loop = karg.AK0 / (KPerBlock / AK1Value); if(num_k_loop <= BlockwiseGemmPipe::PrefetchStages) From f5d1e3fa4878fcfa380082e357e89152756327ce Mon Sep 17 00:00:00 2001 From: Illia Silin <98187287+illsilin@users.noreply.github.com> Date: Wed, 16 Jul 2025 07:37:53 -0700 Subject: [PATCH 3/5] Use a clang20 compiler for gfx950 builds. (#2504) * update docker tag for gfx950 ci build * update compiler path for gfx950 ci build * suppress compiler path override for gfx950 * clean up --- Jenkinsfile | 9 ++------- 1 file changed, 2 insertions(+), 7 deletions(-) diff --git a/Jenkinsfile b/Jenkinsfile index 50c15701a7..a7dc8360ee 100644 --- a/Jenkinsfile +++ b/Jenkinsfile @@ -234,11 +234,6 @@ def cmake_build(Map conf=[:]){ def build_type_debug = (conf.get("build_type",'release') == 'debug') - // use special compiler for gfx950 - if ( check_arch() == 7){ - compiler = "/llvm-project/build/bin/clang++" - } - //cmake_env can overwrite default CXX variables. def cmake_envs = "CXX=${compiler} CXXFLAGS='-Werror' " + conf.get("cmake_ex_env","") @@ -1352,12 +1347,12 @@ pipeline { execute_args = """ cd ../client_example && rm -rf build && mkdir build && cd build && \ cmake -DCMAKE_PREFIX_PATH="${env.WORKSPACE}/install;/opt/rocm" \ -DGPU_TARGETS="gfx950" \ - -DCMAKE_CXX_COMPILER=/llvm-project/build/bin/clang++ \ + -DCMAKE_CXX_COMPILER=/opt/rocm/llvm/bin/clang++ \ -DCMAKE_C_COMPILER=/opt/rocm/llvm/bin/clang \ -DCMAKE_CXX_FLAGS=" -O3 " .. && make -j """ } steps{ - Build_CK_and_Reboot(setup_args: setup_args, docker_name: "${env.CK_DOCKERHUB_PRIVATE}:ck_ub22.04_rocm7.0", config_targets: "install", no_reboot:true, build_type: 'Release', execute_cmd: execute_args, prefixpath: '/usr/local') + Build_CK_and_Reboot(setup_args: setup_args, docker_name: "${env.CK_DOCKERHUB_PRIVATE}:ck_ub24.04_rocm7.0", config_targets: "install", no_reboot:true, build_type: 'Release', execute_cmd: execute_args, prefixpath: '/usr/local') cleanWs() } } From a4bf78ac0ec5882692423bd5b58d84feb3488629 Mon Sep 17 00:00:00 2001 From: Illia Silin <98187287+illsilin@users.noreply.github.com> Date: Wed, 16 Jul 2025 07:39:15 -0700 Subject: [PATCH 4/5] replace obsolete warpSize system variable with the new one (#2496) --- .../gpu/grid/gridwise_moe_mx_gemm_bpreshuffle.hpp | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_moe_mx_gemm_bpreshuffle.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_moe_mx_gemm_bpreshuffle.hpp index 156db6e636..be85528f28 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_moe_mx_gemm_bpreshuffle.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_moe_mx_gemm_bpreshuffle.hpp @@ -467,7 +467,7 @@ struct GridwiseMoeGemmMX_BPreshuffle __host__ __device__ static auto MakeBGridDescriptor_Preshuffled(index_t N0, index_t K0) { - constexpr index_t NkSwizzleNumber = Number{}; + constexpr index_t NkSwizzleNumber = Number{}; return make_naive_tensor_descriptor_packed( make_tuple(N0 / NWave / NXdlPack, NWave, NXdlPack, K0, NkSwizzleNumber)); } @@ -1474,7 +1474,7 @@ struct GridwiseMoeGemmMX_BPreshuffle make_multi_index(n_block_data_idx_on_grid, get_warp_local_1d_id() % NWave, 0, - KPack / KGroup * (get_thread_local_1d_id() % warpSize))); + KPack / KGroup * (get_thread_local_1d_id() % WarpSize))); // LDS allocation for A and B: be careful of alignment // Cast after lds @@ -1567,7 +1567,7 @@ struct GridwiseMoeGemmMX_BPreshuffle make_multi_index(n_block_data_idx_on_grid, get_warp_local_1d_id() % NWave, 0, - KPack / KGroup * (get_thread_local_1d_id() % warpSize))); + KPack / KGroup * (get_thread_local_1d_id() % WarpSize))); const BScaleDataType* p_b_scale_grid_up = p_b_scale_grid + expert_scale_stride / 2; const auto b_scale_grid_buf_up = make_dynamic_buffer( p_b_scale_grid_up + expert_id * expert_scale_stride, @@ -2185,7 +2185,7 @@ struct GridwiseMoeGemmMX_BPreshuffle get_warp_local_1d_id() % NWave, 0, 0, - KPack * (get_thread_local_1d_id() % warpSize))); + KPack * (get_thread_local_1d_id() % WarpSize))); // LDS allocation for A and B: be careful of alignment // Cast after lds @@ -2289,7 +2289,7 @@ struct GridwiseMoeGemmMX_BPreshuffle get_warp_local_1d_id() % NWave, 0, 0, - KPack * (get_thread_local_1d_id() % warpSize))); + KPack * (get_thread_local_1d_id() % WarpSize))); const BScaleDataType* p_b_scale_grid_up = p_b_scale_grid + expert_scale_stride / 2 / sizeof(BScaleDataType); const auto b_scale_grid_buf_up = make_dynamic_buffer( From 6e76b82059eceb1a1614f4a335c70faa2d122c97 Mon Sep 17 00:00:00 2001 From: linqunAMD Date: Wed, 16 Jul 2025 22:58:23 +0800 Subject: [PATCH 5/5] Fix build errors on windows (#2456) * Fix build errors on windows * correct clang format --------- Co-authored-by: Lin, Qun --- cmake/gtest.cmake | 3 ++ .../34_batchnorm/batchnorm_backward_nhwc.cpp | 4 +- .../batchnorm_forward_inferring_nhwc.cpp | 5 +-- .../batchnorm_forward_training_nhwc.cpp | 7 ++-- ...tchnorm_forward_training_nhwc_obsolete.cpp | 7 ++-- example/CMakeLists.txt | 1 + include/ck/utility/amd_xdlops.hpp | 32 +++++++------- include/ck/utility/env.hpp | 1 + include/ck/utility/synchronization.hpp | 2 +- .../ops/gemm/kernel/batched_gemm_kernel.hpp | 2 +- .../ck_tile/ops/gemm/kernel/gemm_kernel.hpp | 2 +- .../ops/gemm/kernel/grouped_gemm_kernel.hpp | 4 +- .../warp/warp_gemm_attribute_mfma_impl.hpp | 42 +++++++++---------- .../include/profiler/profile_gemm_impl.hpp | 4 ++ profiler/src/profile_batched_gemm_b_scale.cpp | 3 +- profiler/src/profile_gemm_b_scale.cpp | 3 +- test/scatter_gather/scatter_gather.cpp | 4 +- 17 files changed, 67 insertions(+), 59 deletions(-) diff --git a/cmake/gtest.cmake b/cmake/gtest.cmake index 0915f53411..6587f4c4be 100644 --- a/cmake/gtest.cmake +++ b/cmake/gtest.cmake @@ -68,3 +68,6 @@ endif() target_compile_options(gtest PRIVATE ${GTEST_CXX_FLAGS}) target_compile_options(gtest_main PRIVATE ${GTEST_CXX_FLAGS}) +target_compile_definitions(gtest PRIVATE GTEST_HAS_SEH=0) +target_compile_definitions(gtest_main PRIVATE GTEST_HAS_SEH=0) + diff --git a/example/34_batchnorm/batchnorm_backward_nhwc.cpp b/example/34_batchnorm/batchnorm_backward_nhwc.cpp index 3756310fd7..9737b0d99b 100644 --- a/example/34_batchnorm/batchnorm_backward_nhwc.cpp +++ b/example/34_batchnorm/batchnorm_backward_nhwc.cpp @@ -403,10 +403,10 @@ bool bnorm_bwd_nhwc_test(bool do_verification, return (pass); }; -static const double epsilon = std::numeric_limits::epsilon(); - int main(int argc, char* argv[]) { + static const double epsilon = std::numeric_limits::epsilon(); + bool pass = true; if(argc > 1) diff --git a/example/34_batchnorm/batchnorm_forward_inferring_nhwc.cpp b/example/34_batchnorm/batchnorm_forward_inferring_nhwc.cpp index 6a8002025a..1ffbabd04b 100644 --- a/example/34_batchnorm/batchnorm_forward_inferring_nhwc.cpp +++ b/example/34_batchnorm/batchnorm_forward_inferring_nhwc.cpp @@ -314,11 +314,10 @@ bool bnorm_infer_nhwc_test(bool do_verification, return (pass); }; -static const double epsilon = std::numeric_limits::epsilon(); - int main(int argc, char* argv[]) { - bool pass = true; + static const double epsilon = std::numeric_limits::epsilon(); + bool pass = true; if(argc > 1) { diff --git a/example/34_batchnorm/batchnorm_forward_training_nhwc.cpp b/example/34_batchnorm/batchnorm_forward_training_nhwc.cpp index b27358fd9d..06441be860 100644 --- a/example/34_batchnorm/batchnorm_forward_training_nhwc.cpp +++ b/example/34_batchnorm/batchnorm_forward_training_nhwc.cpp @@ -453,12 +453,11 @@ bool bnorm_fwd_nhwc_test(bool do_verification, return (pass); }; -const double epsilon = std::numeric_limits::epsilon(); -static const double averageFactor = 0.1; - int main(int argc, char* argv[]) { - bool pass = true; + const double epsilon = std::numeric_limits::epsilon(); + static const double averageFactor = 0.1; + bool pass = true; if(argc > 1) { diff --git a/example/34_batchnorm/batchnorm_forward_training_nhwc_obsolete.cpp b/example/34_batchnorm/batchnorm_forward_training_nhwc_obsolete.cpp index ffb9f4b584..8f2b7613b5 100644 --- a/example/34_batchnorm/batchnorm_forward_training_nhwc_obsolete.cpp +++ b/example/34_batchnorm/batchnorm_forward_training_nhwc_obsolete.cpp @@ -453,12 +453,11 @@ bool bnorm_fwd_nhwc_test(bool do_verification, return (pass); }; -const double epsilon = std::numeric_limits::epsilon(); -static const double averageFactor = 0.1; - int main(int argc, char* argv[]) { - bool pass = true; + const double epsilon = std::numeric_limits::epsilon(); + static const double averageFactor = 0.1; + bool pass = true; if(argc > 1) { diff --git a/example/CMakeLists.txt b/example/CMakeLists.txt index 56d709f41b..3c67e9214f 100644 --- a/example/CMakeLists.txt +++ b/example/CMakeLists.txt @@ -128,6 +128,7 @@ function(add_example_executable EXAMPLE_NAME FILE_NAME) set_source_files_properties(${FILE_NAME} PROPERTIES LANGUAGE HIP) add_executable(${EXAMPLE_NAME} ${FILE_NAME}) target_link_libraries(${EXAMPLE_NAME} PRIVATE utility) + target_link_libraries(${EXAMPLE_NAME} PRIVATE getopt::getopt) add_test(NAME ${EXAMPLE_NAME} COMMAND $ ${ARGN}) set_property(TARGET ${EXAMPLE_NAME} PROPERTY HIP_ARCHITECTURES ${EX_TARGETS} ) add_dependencies(examples ${EXAMPLE_NAME}) diff --git a/include/ck/utility/amd_xdlops.hpp b/include/ck/utility/amd_xdlops.hpp index 8646b8393b..02a7a72b8c 100644 --- a/include/ck/utility/amd_xdlops.hpp +++ b/include/ck/utility/amd_xdlops.hpp @@ -1396,8 +1396,8 @@ struct intrin_mfma_f32_32x32x16f8f8<32, 32> #if defined(__gfx94__) reg_c.template AsType()(Number<0>{}) = __builtin_amdgcn_mfma_f32_32x32x16_fp8_fp8( - bit_cast(reg_a), - bit_cast(reg_b), + bit_cast(reg_a), + bit_cast(reg_b), reg_c.template AsType()[Number<0>{}], 0, 0, @@ -1427,8 +1427,8 @@ struct intrin_mfma_f32_16x16x32f8f8<16, 16> { #if defined(__gfx94__) reg_c.template AsType()(Number<0>{}) = __builtin_amdgcn_mfma_f32_16x16x32_fp8_fp8( - bit_cast(reg_a), - bit_cast(reg_b), + bit_cast(reg_a), + bit_cast(reg_b), reg_c.template AsType()[Number<0>{}], 0, 0, @@ -1459,8 +1459,8 @@ struct intrin_mfma_f32_32x32x16bf8bf8<32, 32> #if defined(__gfx94__) reg_c.template AsType()(Number<0>{}) = __builtin_amdgcn_mfma_f32_32x32x16_bf8_bf8( - bit_cast(reg_a), - bit_cast(reg_b), + bit_cast(reg_a), + bit_cast(reg_b), reg_c.template AsType()[Number<0>{}], 0, 0, @@ -1490,8 +1490,8 @@ struct intrin_mfma_f32_16x16x32bf8bf8<16, 16> { #if defined(__gfx94__) reg_c.template AsType()(Number<0>{}) = __builtin_amdgcn_mfma_f32_16x16x32_bf8_bf8( - bit_cast(reg_a), - bit_cast(reg_b), + bit_cast(reg_a), + bit_cast(reg_b), reg_c.template AsType()[Number<0>{}], 0, 0, @@ -1522,8 +1522,8 @@ struct intrin_mfma_f32_32x32x16f8bf8<32, 32> #if defined(__gfx94__) reg_c.template AsType()(Number<0>{}) = __builtin_amdgcn_mfma_f32_32x32x16_fp8_bf8( - bit_cast(reg_a), - bit_cast(reg_b), + bit_cast(reg_a), + bit_cast(reg_b), reg_c.template AsType()[Number<0>{}], 0, 0, @@ -1553,8 +1553,8 @@ struct intrin_mfma_f32_16x16x32f8bf8<16, 16> { #if defined(__gfx94__) reg_c.template AsType()(Number<0>{}) = __builtin_amdgcn_mfma_f32_16x16x32_fp8_bf8( - bit_cast(reg_a), - bit_cast(reg_b), + bit_cast(reg_a), + bit_cast(reg_b), reg_c.template AsType()[Number<0>{}], 0, 0, @@ -1585,8 +1585,8 @@ struct intrin_mfma_f32_32x32x16bf8f8<32, 32> #if defined(__gfx94__) reg_c.template AsType()(Number<0>{}) = __builtin_amdgcn_mfma_f32_32x32x16_bf8_fp8( - bit_cast(reg_a), - bit_cast(reg_b), + bit_cast(reg_a), + bit_cast(reg_b), reg_c.template AsType()[Number<0>{}], 0, 0, @@ -1616,8 +1616,8 @@ struct intrin_mfma_f32_16x16x32bf8f8<16, 16> { #if defined(__gfx94__) reg_c.template AsType()(Number<0>{}) = __builtin_amdgcn_mfma_f32_16x16x32_bf8_fp8( - bit_cast(reg_a), - bit_cast(reg_b), + bit_cast(reg_a), + bit_cast(reg_b), reg_c.template AsType()[Number<0>{}], 0, 0, diff --git a/include/ck/utility/env.hpp b/include/ck/utility/env.hpp index 46ba32bb87..2f5b804d16 100644 --- a/include/ck/utility/env.hpp +++ b/include/ck/utility/env.hpp @@ -8,6 +8,7 @@ #include #include #include +#include namespace ck { namespace internal { diff --git a/include/ck/utility/synchronization.hpp b/include/ck/utility/synchronization.hpp index d6b6eac26c..7652e73809 100644 --- a/include/ck/utility/synchronization.hpp +++ b/include/ck/utility/synchronization.hpp @@ -33,7 +33,7 @@ __device__ void block_sync_lds_direct_load() { #ifdef __gfx12__ asm volatile("\ - s_wait_vmcnt 0x0 \n \ + s_wait_loadcnt 0x0 \n \ s_wait_dscnt 0x0 \n \ s_barrier_signal -1 \n \ s_barrier_wait -1 \ diff --git a/include/ck_tile/ops/gemm/kernel/batched_gemm_kernel.hpp b/include/ck_tile/ops/gemm/kernel/batched_gemm_kernel.hpp index 09c7d58558..fc72138abf 100644 --- a/include/ck_tile/ops/gemm/kernel/batched_gemm_kernel.hpp +++ b/include/ck_tile/ops/gemm/kernel/batched_gemm_kernel.hpp @@ -74,7 +74,7 @@ struct BatchedGemmKernel : public GemmKernel, + return concat('_', "gemm_batched", gemm_prec_str(), concat('x', P_::MPerBlock, P_::NPerBlock, P_::KPerBlock), concat('x', P_::GetVectorSizeA(), P_::GetVectorSizeB(), P_::GetVectorSizeC()), concat('x', P_::kPadM, P_::kPadN, P_::kPadK)); diff --git a/include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp b/include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp index 516d4298ef..53c21b49f5 100755 --- a/include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp +++ b/include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp @@ -196,7 +196,7 @@ struct GemmKernel [[nodiscard]] CK_TILE_HOST static const std::string GetName() { // clang-format off - return concat('_', "gemm", gemm_prec_str, GemmPipeline::GetName()); + return concat('_', "gemm", gemm_prec_str(), GemmPipeline::GetName()); // clang-format on } diff --git a/include/ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp b/include/ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp index 533cabb736..2605b1afbc 100644 --- a/include/ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp +++ b/include/ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp @@ -57,7 +57,7 @@ struct GroupedGemmKernel : public GemmKernel, + return concat('_', "gemm_grouped", gemm_prec_str(), concat('x', P_::MPerBlock, P_::NPerBlock, P_::KPerBlock), concat('x', P_::GetVectorSizeA(), P_::GetVectorSizeB(), P_::GetVectorSizeC()), concat('x', P_::kPadM, P_::kPadN, P_::kPadK), @@ -95,7 +95,7 @@ struct GroupedGemmKernel : public GemmKernel>& gemm_descs) { index_t grid_size = 0; diff --git a/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma_impl.hpp b/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma_impl.hpp index 80f38f263b..0831cf85c4 100644 --- a/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma_impl.hpp +++ b/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma_impl.hpp @@ -1095,16 +1095,16 @@ struct WarpGemmAttributeMfmaImpl_f32_16x16x32_f8_base #if defined(__gfx94__) or defined(__gfx95__) if constexpr(std::is_same_v && std::is_same_v) c_vec = __builtin_amdgcn_mfma_f32_16x16x32_fp8_fp8( - bit_cast(a_vec), bit_cast(b_vec), c_vec, 0, 0, 0); + bit_cast(a_vec), bit_cast(b_vec), c_vec, 0, 0, 0); else if constexpr(std::is_same_v && std::is_same_v) c_vec = __builtin_amdgcn_mfma_f32_16x16x32_fp8_bf8( - bit_cast(a_vec), bit_cast(b_vec), c_vec, 0, 0, 0); + bit_cast(a_vec), bit_cast(b_vec), c_vec, 0, 0, 0); else if constexpr(std::is_same_v && std::is_same_v) c_vec = __builtin_amdgcn_mfma_f32_16x16x32_bf8_fp8( - bit_cast(a_vec), bit_cast(b_vec), c_vec, 0, 0, 0); + bit_cast(a_vec), bit_cast(b_vec), c_vec, 0, 0, 0); else if constexpr(std::is_same_v && std::is_same_v) c_vec = __builtin_amdgcn_mfma_f32_16x16x32_bf8_bf8( - bit_cast(a_vec), bit_cast(b_vec), c_vec, 0, 0, 0); + bit_cast(a_vec), bit_cast(b_vec), c_vec, 0, 0, 0); #else ck_tile::ignore = c_vec; ck_tile::ignore = a_vec; @@ -1119,16 +1119,16 @@ struct WarpGemmAttributeMfmaImpl_f32_16x16x32_f8_base #if defined(__gfx94__) or defined(__gfx95__) if constexpr(std::is_same_v && std::is_same_v) return bit_cast(__builtin_amdgcn_mfma_f32_16x16x32_fp8_fp8( - bit_cast(a_vec), bit_cast(b_vec), CVecType{0.f}, 0, 0, 0)); + bit_cast(a_vec), bit_cast(b_vec), CVecType{0.f}, 0, 0, 0)); else if constexpr(std::is_same_v && std::is_same_v) return bit_cast(__builtin_amdgcn_mfma_f32_16x16x32_fp8_bf8( - bit_cast(a_vec), bit_cast(b_vec), CVecType{0.f}, 0, 0, 0)); + bit_cast(a_vec), bit_cast(b_vec), CVecType{0.f}, 0, 0, 0)); else if constexpr(std::is_same_v && std::is_same_v) return bit_cast(__builtin_amdgcn_mfma_f32_16x16x32_bf8_fp8( - bit_cast(a_vec), bit_cast(b_vec), CVecType{0.f}, 0, 0, 0)); + bit_cast(a_vec), bit_cast(b_vec), CVecType{0.f}, 0, 0, 0)); else if constexpr(std::is_same_v && std::is_same_v) return bit_cast(__builtin_amdgcn_mfma_f32_16x16x32_bf8_bf8( - bit_cast(a_vec), bit_cast(b_vec), CVecType{0.f}, 0, 0, 0)); + bit_cast(a_vec), bit_cast(b_vec), CVecType{0.f}, 0, 0, 0)); #else ck_tile::ignore = a_vec; ck_tile::ignore = b_vec; @@ -1254,16 +1254,16 @@ struct WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base #if defined(__gfx94__) or defined(__gfx95__) if constexpr(std::is_same_v && std::is_same_v) c_vec = __builtin_amdgcn_mfma_f32_32x32x16_fp8_fp8( - bit_cast(a_vec), bit_cast(b_vec), c_vec, 0, 0, 0); + bit_cast(a_vec), bit_cast(b_vec), c_vec, 0, 0, 0); else if constexpr(std::is_same_v && std::is_same_v) c_vec = __builtin_amdgcn_mfma_f32_32x32x16_fp8_bf8( - bit_cast(a_vec), bit_cast(b_vec), c_vec, 0, 0, 0); + bit_cast(a_vec), bit_cast(b_vec), c_vec, 0, 0, 0); else if constexpr(std::is_same_v && std::is_same_v) c_vec = __builtin_amdgcn_mfma_f32_32x32x16_bf8_fp8( - bit_cast(a_vec), bit_cast(b_vec), c_vec, 0, 0, 0); + bit_cast(a_vec), bit_cast(b_vec), c_vec, 0, 0, 0); else if constexpr(std::is_same_v && std::is_same_v) c_vec = __builtin_amdgcn_mfma_f32_32x32x16_bf8_bf8( - bit_cast(a_vec), bit_cast(b_vec), c_vec, 0, 0, 0); + bit_cast(a_vec), bit_cast(b_vec), c_vec, 0, 0, 0); #elif defined(__gfx908__) || defined(__gfx90a__) static_for<0, 8, 1>{}([&](auto k) { float a_f32 = @@ -1289,16 +1289,16 @@ struct WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base #if defined(__gfx94__) or defined(__gfx95__) if constexpr(std::is_same_v && std::is_same_v) return bit_cast(__builtin_amdgcn_mfma_f32_32x32x16_fp8_fp8( - bit_cast(a_vec), bit_cast(b_vec), CVecType{0.f}, 0, 0, 0)); + bit_cast(a_vec), bit_cast(b_vec), CVecType{0.f}, 0, 0, 0)); else if constexpr(std::is_same_v && std::is_same_v) return bit_cast(__builtin_amdgcn_mfma_f32_32x32x16_fp8_bf8( - bit_cast(a_vec), bit_cast(b_vec), CVecType{0.f}, 0, 0, 0)); + bit_cast(a_vec), bit_cast(b_vec), CVecType{0.f}, 0, 0, 0)); else if constexpr(std::is_same_v && std::is_same_v) return bit_cast(__builtin_amdgcn_mfma_f32_32x32x16_bf8_fp8( - bit_cast(a_vec), bit_cast(b_vec), CVecType{0.f}, 0, 0, 0)); + bit_cast(a_vec), bit_cast(b_vec), CVecType{0.f}, 0, 0, 0)); else if constexpr(std::is_same_v && std::is_same_v) return bit_cast(__builtin_amdgcn_mfma_f32_32x32x16_bf8_bf8( - bit_cast(a_vec), bit_cast(b_vec), CVecType{0.f}, 0, 0, 0)); + bit_cast(a_vec), bit_cast(b_vec), CVecType{0.f}, 0, 0, 0)); #elif defined(__gfx908__) || defined(__gfx90a__) CVecType c_vec{0.f}; static_for<0, 8, 1>{}([&](auto k) { @@ -1580,7 +1580,7 @@ struct WarpGemmAttributeMfmaImpl_i32_32x32x16_i8 { #if defined(__gfx94__) or defined(__gfx95__) c_vec = __builtin_amdgcn_mfma_i32_32x32x16_i8( - bit_cast(a_vec), bit_cast(b_vec), c_vec, 0, 0, 0); + bit_cast(a_vec), bit_cast(b_vec), c_vec, 0, 0, 0); #elif defined(__gfx908__) || defined(__gfx90a__) static_for<0, 8, 1>{}([&](auto k) { float a_f32 = @@ -1650,7 +1650,7 @@ struct WarpGemmAttributeMfmaImpl_i32_16x16x32_i8 { #if defined(__gfx94__) or defined(__gfx95__) c_vec = __builtin_amdgcn_mfma_i32_16x16x32_i8( - bit_cast(a_vec), bit_cast(b_vec), c_vec, 0, 0, 0); + bit_cast(a_vec), bit_cast(b_vec), c_vec, 0, 0, 0); #else ck_tile::ignore = c_vec; ck_tile::ignore = a_vec; @@ -1709,7 +1709,7 @@ struct WarpGemmAttributeMfmaImpl_i32_16x16x64_i8 { #if defined(__gfx95__) c_vec = __builtin_amdgcn_mfma_i32_16x16x64_i8( - bit_cast(a_vec), bit_cast(b_vec), c_vec, 0, 0, 0); + bit_cast(a_vec), bit_cast(b_vec), c_vec, 0, 0, 0); #else ck_tile::ignore = c_vec; ck_tile::ignore = a_vec; @@ -1767,8 +1767,8 @@ struct WarpGemmAttributeMfmaImpl_i32_32x32x32_i8 else { #if defined(__gfx95__) - c_vec = - __builtin_amdgcn_mfma_i32_32x32x32_i8(a_vec, bit_cast(b_vec), c_vec, 0, 0, 0); + c_vec = __builtin_amdgcn_mfma_i32_32x32x32_i8( + a_vec, bit_cast(b_vec), c_vec, 0, 0, 0); #else ck_tile::ignore = c_vec; ck_tile::ignore = a_vec; diff --git a/profiler/include/profiler/profile_gemm_impl.hpp b/profiler/include/profiler/profile_gemm_impl.hpp index 1373dbc497..d2a38b2a81 100644 --- a/profiler/include/profiler/profile_gemm_impl.hpp +++ b/profiler/include/profiler/profile_gemm_impl.hpp @@ -6,7 +6,9 @@ #include #include #include +#if defined(__unix__) #include +#endif #include "ck/ck.hpp" #include "ck/tensor_operation/gpu/device/tensor_layout.hpp" @@ -213,7 +215,9 @@ int profile_gemm_impl(int do_verification, instance_id++; } +#if defined(__unix__) sleep(2); +#endif // Run the best instance again { diff --git a/profiler/src/profile_batched_gemm_b_scale.cpp b/profiler/src/profile_batched_gemm_b_scale.cpp index f768a17570..5fe6f490be 100644 --- a/profiler/src/profile_batched_gemm_b_scale.cpp +++ b/profiler/src/profile_batched_gemm_b_scale.cpp @@ -5,6 +5,7 @@ #include #include #include +#include #include "profiler/profile_batched_gemm_b_scale_impl.hpp" #include "profiler_operation_registry.hpp" @@ -114,7 +115,7 @@ int profile_batched_gemm_b_scale(int argc, char* argv[]) n_iter = std::stoi(argv[18]); rotating = std::stoull(argv[19]) * 1024 * 1024; - printf("n_warmup:%d, n_iter:%d, rotating:%lu\n", n_warmup, n_iter, rotating); + printf("n_warmup:%d, n_iter:%d, rotating:%" PRIu64 "\n", n_warmup, n_iter, rotating); } using F32 = float; diff --git a/profiler/src/profile_gemm_b_scale.cpp b/profiler/src/profile_gemm_b_scale.cpp index 443ebff834..7bcc96a434 100644 --- a/profiler/src/profile_gemm_b_scale.cpp +++ b/profiler/src/profile_gemm_b_scale.cpp @@ -5,6 +5,7 @@ #include #include #include +#include #include "profiler/profile_gemm_b_scale_impl.hpp" #include "profiler_operation_registry.hpp" @@ -100,7 +101,7 @@ int profile_gemm_b_scale(int argc, char* argv[]) n_iter = std::stoi(argv[17]); rotating = std::stoull(argv[18]) * 1024 * 1024; - printf("n_warmup:%d, n_iter:%d, rotating:%lu\n", n_warmup, n_iter, rotating); + printf("n_warmup:%d, n_iter:%d, rotating:%" PRIu64 "\n", n_warmup, n_iter, rotating); } using F32 = float; diff --git a/test/scatter_gather/scatter_gather.cpp b/test/scatter_gather/scatter_gather.cpp index 81765b43e5..874c4d86c0 100644 --- a/test/scatter_gather/scatter_gather.cpp +++ b/test/scatter_gather/scatter_gather.cpp @@ -140,8 +140,8 @@ union pixel { struct __attribute__((packed)) { - unsigned int r : 6; - unsigned int c : 10; + ushort r : 6; + ushort c : 10; }; ushort data; };