From ce1d20c2c63ce480ebaa3682714b81be85f57fa6 Mon Sep 17 00:00:00 2001 From: ruanjm Date: Tue, 25 Mar 2025 20:09:45 +0800 Subject: [PATCH] [CK_TILE] Improve RMS/Layer Normalization 2 Pass Pipeline Performance (#1861) * 50ms -> 28ms * Fix bug in non fuse_add_store cases * Fine tuned setting for 2 pass pipeline * adjust workload * remove unnecessary change * add layernorm * Adding output quant and unquant results at the same time. * fix test * fix format * tune for cases 128x640 and 128x1024 * bug ifx [ROCm/composable_kernel commit: d49abdaa8711cc3f690f8ffe00f7393b2708a28f] --- example/ck_tile/02_layernorm2d/generate.py | 4 +- .../10_rmsnorm2d/example_rmsnorm2d_fwd.cpp | 11 +- example/ck_tile/10_rmsnorm2d/generate.py | 154 ++++++++++------- .../ck_tile/10_rmsnorm2d/rmsnorm2d_fwd.cpp | 161 +++++++++++++++--- .../ck_tile/10_rmsnorm2d/rmsnorm2d_fwd.hpp | 3 + .../ck_tile/10_rmsnorm2d/script/smoke_test.sh | 11 +- .../reference/reference_rmsnorm2d_fwd.hpp | 13 +- include/ck_tile/ops/epilogue.hpp | 1 + .../default_2d_and_dynamic_quant_epilogue.hpp | 91 ++++++++++ .../layernorm2d_fwd_pipeline_two_pass.hpp | 67 +++++--- .../rmsnorm2d/kernel/rmsnorm2d_fwd_kernel.hpp | 34 +++- .../rmsnorm2d_fwd_pipeline_one_pass.hpp | 26 ++- .../rmsnorm2d_fwd_pipeline_problem.hpp | 2 + .../rmsnorm2d_fwd_pipeline_two_pass.hpp | 47 +++-- .../pipeline/rmsnorm2d_fwd_traits.hpp | 2 + 15 files changed, 492 insertions(+), 135 deletions(-) create mode 100644 include/ck_tile/ops/epilogue/default_2d_and_dynamic_quant_epilogue.hpp diff --git a/example/ck_tile/02_layernorm2d/generate.py b/example/ck_tile/02_layernorm2d/generate.py index 700b007fad..0238a125dc 100644 --- a/example/ck_tile/02_layernorm2d/generate.py +++ b/example/ck_tile/02_layernorm2d/generate.py @@ -564,9 +564,9 @@ float layernorm2d_fwd(layernorm2d_fwd_traits t, h_traits('x', 'y', 'xs', 'ys', 1, 4, 1, 512, 4, True, False, True, True, False, 0, 0, 0), h_traits('x', 'y', 'xs', 'ys', 1, 4, 1,1024, 2, True, False, True, True, False, 0, 0, 0), h_traits('x', 'y', 'xs', 'ys', 1, 8, 1,1024, 1, True, False, True, True, False, 0, 0, 0)], - 'big' :[ h_traits('x', 'y', 'xs', 'ys', 1, 2, 1, 256, 8, True, False, True, True, True, 0, 0, 0), + 'big' :[ h_traits('x', 'y', 'xs', 'ys', 1, 1, 1,1024, 8, True, False, True, True, True, 0, 0, 0), h_traits('x', 'y', 'xs', 'ys', 1, 4, 1, 256, 4, True, False, True, True, True, 0, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 1, 2, 1,1024, 2, True, False, True, True, True, 0, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 12, 1, 256, 2, True, False, True, True, True, 0, 0, 0), h_traits('x', 'y', 'xs', 'ys', 1, 4, 1,1024, 1, True, False, True, True, True, 0, 0, 0)]} total_blob = list() for hs_key in h_trait_dict: diff --git a/example/ck_tile/10_rmsnorm2d/example_rmsnorm2d_fwd.cpp b/example/ck_tile/10_rmsnorm2d/example_rmsnorm2d_fwd.cpp index 48c150009e..25598282e3 100644 --- a/example/ck_tile/10_rmsnorm2d/example_rmsnorm2d_fwd.cpp +++ b/example/ck_tile/10_rmsnorm2d/example_rmsnorm2d_fwd.cpp @@ -41,6 +41,7 @@ bool run(const ck_tile::ArgParser& arg_parser) using YDataType = DataType; using GammaDataType = DataType; using InvRmsDataType = ck_tile::null_type; + using UnquantYDataType = ck_tile::null_type; using SmoothScaleDataType = ck_tile::null_type; using YScaleDataType = ck_tile::null_type; @@ -55,6 +56,8 @@ bool run(const ck_tile::ArgParser& arg_parser) ck_tile::HostTensor invRms_host_ref({m}); + ck_tile::HostTensor unquant_y_host_ref({m, n}, {stride, 1}); + ck_tile::FillUniformDistribution{-.5f, .5f}(x_host); ck_tile::FillUniformDistribution{-.5f, .5f}(gamma_host); @@ -76,6 +79,7 @@ bool run(const ck_tile::ArgParser& arg_parser) using PipelineTraits = ck_tile::Rmsnorm2dFwdTraits; // fuse quant @@ -85,6 +89,7 @@ bool run(const ck_tile::ArgParser& arg_parser) ComputeDataType, YDataType, InvRmsDataType, + UnquantYDataType, SmoothScaleDataType, YScaleDataType, Shape, @@ -108,6 +113,7 @@ bool run(const ck_tile::ArgParser& arg_parser) nullptr, nullptr, nullptr, + nullptr, epsilon, m, n, @@ -135,8 +141,9 @@ bool run(const ck_tile::ArgParser& arg_parser) GammaDataType, ComputeDataType, YDataType, - InvRmsDataType>( - x_host, gamma_host, y_host_ref, invRms_host_ref, epsilon); + InvRmsDataType, + UnquantYDataType>( + x_host, gamma_host, y_host_ref, invRms_host_ref, unquant_y_host_ref, epsilon); y_buf.FromDevice(y_host_dev.data()); diff --git a/example/ck_tile/10_rmsnorm2d/generate.py b/example/ck_tile/10_rmsnorm2d/generate.py index dadb2268b2..39d42e5ff1 100644 --- a/example/ck_tile/10_rmsnorm2d/generate.py +++ b/example/ck_tile/10_rmsnorm2d/generate.py @@ -54,6 +54,7 @@ template @@ -70,6 +72,7 @@ struct rmsnorm2d_fwd_traits_ using YDataType = ck_tile::remove_cvref_t; using SmoothScaleDataType = ck_tile::remove_cvref_t; using YScaleDataType = ck_tile::remove_cvref_t; + using UnquantYDataType = ck_tile::remove_cvref_t; static constexpr bool is_warp_per_row = ThreadPerBlock_N_ <= warpSize; static_assert((ThreadPerBlock_M_ * ThreadPerBlock_N_) % warpSize == 0); @@ -120,9 +123,10 @@ struct rmsnorm2d_fwd_traits_ using Shape = ck_tile::Generic2dBlockShape; - static constexpr bool kPadN = kPadN_; - static constexpr bool kSaveInvRms = kSaveInvRms_; - static constexpr bool kTwoPass = kTwoPass_; + static constexpr bool kPadN = kPadN_; + static constexpr bool kSaveInvRms = kSaveInvRms_; + static constexpr bool kSaveUnquant = kSaveUnquant_; + static constexpr bool kTwoPass = kTwoPass_; static constexpr ck_tile::index_t kFusedAdd = kFusedAdd_; static constexpr ck_tile::index_t kFusedQuant = kFusedQuant_; }; @@ -131,6 +135,7 @@ template @@ -145,6 +151,7 @@ using traits_ = rmsnorm2d_fwd_traits_; @@ -180,11 +188,13 @@ float rmsnorm2d_fwd_(const S& s, A a) using YDataType = typename Traits_::YDataType; using SmoothScaleDataType = typename Traits_::SmoothScaleDataType; using YScaleDataType = typename Traits_::YScaleDataType; + using UnquantYDataType = typename Traits_::UnquantYDataType; using ComputeDataType = typename RmsnormTypeConfig::ComputeDataType; using PipelineTraits = ck_tile::Rmsnorm2dFwdTraits(Traits_::kFusedAdd), static_cast(Traits_::kFusedQuant)>; @@ -195,6 +205,7 @@ float rmsnorm2d_fwd_(const S& s, A a) typename RmsnormTypeConfig::ComputeDataType, typename RmsnormTypeConfig::YDataType, typename RmsnormTypeConfig::InvRmsDataType, + typename RmsnormTypeConfig::UnquantYDataType, typename RmsnormTypeConfig::SmoothScaleDataType, typename RmsnormTypeConfig::YScaleDataType, typename Traits_::Shape, @@ -213,7 +224,16 @@ float rmsnorm2d_fwd_(const S& s, A a) using DynamicQuantEpilogue = ck_tile::DynamicQuantEpilogue; - using Epilogue = std::conditional_t; + using Default2DAndDynamicQuantEpilogueProblem = ck_tile::Default2DAndDynamicQuantEpilogueProblem< + ComputeDataType, SmoothScaleDataType, YScaleDataType, YDataType, UnquantYDataType, typename Traits_::Shape, + ck_tile::Default2DAndDynamicQuantEpilogueTraits>; + using Default2DAndDynamicQuantEpilogue = ck_tile::Default2DAndDynamicQuantEpilogue; + + using Epilogue = std::conditional_t, + Default2DEpilogue>; using Kernel = ck_tile::Rmsnorm2dFwd; @@ -355,6 +375,7 @@ float rmsnorm2d_fwd(rmsnorm2d_fwd_traits t, F_YDataType : str F_SmoothScaleDataType : str F_YScaleDataType : str + F_UnquantYDataType : str F_Repeat_M : int F_Repeat_N : int F_ThreadPerBlock_M : int @@ -362,14 +383,15 @@ float rmsnorm2d_fwd(rmsnorm2d_fwd_traits t, F_Vector_N : int F_kPadN : bool F_kSaveInvRms : bool + F_kSaveUnquant: bool F_kTwoPass : bool F_kFusedAdd : int F_kFusedQuant : int @property def trait_name(self) ->str: - t_ = f'{DATA_TYPE_MAP[self.F_XDataType]}, {DATA_TYPE_MAP[self.F_YDataType]}, {DATA_TYPE_MAP[self.F_SmoothScaleDataType]}, {DATA_TYPE_MAP[self.F_YScaleDataType]}, {self.F_Repeat_M:2}, {self.F_Repeat_N:2}, {self.F_ThreadPerBlock_M:2}, {self.F_ThreadPerBlock_N:4}' - t_ += f', {self.F_Vector_N:2}, {BOOL_MAP(self.F_kPadN):5}, {BOOL_MAP(self.F_kSaveInvRms):5}' + t_ = f'{DATA_TYPE_MAP[self.F_XDataType]}, {DATA_TYPE_MAP[self.F_YDataType]}, {DATA_TYPE_MAP[self.F_SmoothScaleDataType]}, {DATA_TYPE_MAP[self.F_YScaleDataType]}, {DATA_TYPE_MAP[self.F_UnquantYDataType]}, {self.F_Repeat_M:2}, {self.F_Repeat_N:2}, {self.F_ThreadPerBlock_M:2}, {self.F_ThreadPerBlock_N:4}' + t_ += f', {self.F_Vector_N:2}, {BOOL_MAP(self.F_kPadN):5}, {BOOL_MAP(self.F_kSaveInvRms):5}, {BOOL_MAP(self.F_kSaveUnquant):5}' t_ += f', {BOOL_MAP(self.F_kTwoPass):5}, {self.F_kFusedAdd:4}, {self.F_kFusedQuant:4}' return t_ @@ -390,6 +412,7 @@ float rmsnorm2d_fwd(rmsnorm2d_fwd_traits t, F_N : str F_add : int F_sweep : int + F_saveunquant : bool instance_list : List[Any] # List[h_traits] @property @@ -401,6 +424,8 @@ float rmsnorm2d_fwd(rmsnorm2d_fwd_traits t, nnn = nnn + '_' + FUSED_ADD_ENUM_STR_MAP[self.F_add] if self.F_sweep != 0: nnn = nnn + '_' + FUSED_FUSED_SWEEP_STR_MAP[self.F_sweep] + if self.F_saveunquant: + nnn = nnn + '_saveunquant' return nnn @property @@ -451,11 +476,11 @@ float rmsnorm2d_fwd(rmsnorm2d_fwd_traits t, if ins.F_kFusedQuant == 0: _sweep_cond = 't.fused_quant == {f_fused_sweep}'.format(f_fused_sweep = ins.F_kFusedQuant) elif ins.F_kFusedQuant == 1: - _sweep_cond = 't.fused_quant == {f_fused_sweep} && (t.prec_sm == \"{f_sx_type}\" && t.prec_sy == \"{f_sy_type}\")'.format( - f_fused_sweep = ins.F_kFusedQuant, f_sx_type=ins.F_SmoothScaleDataType, f_sy_type=ins.F_YScaleDataType) + _sweep_cond = 't.fused_quant == {f_fused_sweep} && (t.prec_sm == \"{f_sx_type}\" && t.prec_sy == \"{f_sy_type}\" && t.save_unquant == {f_suq})'.format( + f_fused_sweep = ins.F_kFusedQuant, f_sx_type=ins.F_SmoothScaleDataType, f_sy_type=ins.F_YScaleDataType, f_suq=BOOL_MAP(ins.F_kSaveUnquant)) elif ins.F_kFusedQuant == 2: - _sweep_cond = 't.fused_quant == {f_fused_sweep} && (t.prec_sy == \"{f_sy_type}\")'.format( - f_fused_sweep = ins.F_kFusedQuant, f_sy_type=ins.F_YScaleDataType) + _sweep_cond = 't.fused_quant == {f_fused_sweep} && (t.prec_sy == \"{f_sy_type}\" && t.save_unquant == {f_suq})'.format( + f_fused_sweep = ins.F_kFusedQuant, f_sy_type=ins.F_YScaleDataType, f_suq=BOOL_MAP(ins.F_kSaveUnquant)) _cond = '((a.n % {f_vec_n} == 0) && (t.fused_add == {f_fused_add}) && ({f_sweep_cond}))'.format( f_vec_n = ins.F_Vector_N, f_fused_add = ins.F_kFusedAdd, f_sweep_cond = _sweep_cond) @@ -489,67 +514,72 @@ float rmsnorm2d_fwd(rmsnorm2d_fwd_traits t, #fused_sweep_list = [0, 1, 2] # NOTE: only single pass can use fused (smooth) dynamic quant fused_add_list = [0, 1] fused_sweep_list = [0, 1, 2] # NOTE: only single pass can use fused (smooth) dynamic quant + bool_list = [False, True] - # rm rn tm tn vn pd mv 2p add sweep - h_trait_dict = {'64' : [ h_traits('x', 'y', 'xs', 'ys', 1, 1, 8, 8, 8, True, False, False, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 1, 1, 4, 16, 4, True, False, False, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 1, 1, 4, 64, 1, True, False, False, 0, 0)], - '128' : [ h_traits('x', 'y', 'xs', 'ys', 1, 1, 4, 16, 8, True, False, False, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 1, 1, 4, 64, 2, True, False, False, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 1, 2, 4, 64, 1, True, False, False, 0, 0)], - '256' : [ h_traits('x', 'y', 'xs', 'ys', 1, 1, 4, 64, 4, True, False, False, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 1, 2, 4, 64, 2, True, False, False, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 1, 4, 4, 64, 1, True, False, False, 0, 0)], - '512' : [ h_traits('x', 'y', 'xs', 'ys', 1, 1, 4, 64, 8, True, False, False, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 1, 2, 4, 64, 4, True, False, False, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 1, 4, 4, 64, 2, True, False, False, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 1, 8, 4, 64, 1, True, False, False, 0, 0)], - '768' : [ h_traits('x', 'y', 'xs', 'ys', 1, 3, 4, 64, 4, True, False, False, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 1, 6, 4, 64, 2, True, False, False, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 1, 12, 4, 64, 1, True, False, False, 0, 0)], - '1024' :[ h_traits('x', 'y', 'xs', 'ys', 1, 1, 2, 128, 8, True, False, False, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 1, 2, 2, 128, 4, True, False, False, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 1, 4, 2, 128, 2, True, False, False, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 1, 4, 1, 256, 1, True, False, False, 0, 0)], - '1536' :[ h_traits('x', 'y', 'xs', 'ys', 1, 3, 4, 64, 8, True, False, False, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 1, 3, 2, 128, 4, True, False, False, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 1, 3, 1, 256, 2, True, False, False, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 1, 6, 1, 256, 1, True, False, False, 0, 0)], - '2048' :[ h_traits('x', 'y', 'xs', 'ys', 1, 1, 1, 256, 8, True, False, False, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 1, 2, 1, 256, 4, True, False, False, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 1, 4, 1, 256, 2, True, False, False, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 1, 8, 1, 256, 1, True, False, False, 0, 0)], - '3072' :[ h_traits('x', 'y', 'xs', 'ys', 1, 3, 1, 128, 8, True, False, False, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 1, 3, 1, 256, 4, True, False, False, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 1, 6, 1, 256, 2, True, False, False, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 1, 3, 1,1024, 1, True, False, False, 0, 0)], - '4096' :[ h_traits('x', 'y', 'xs', 'ys', 1, 2, 1, 256, 8, True, False, False, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 1, 4, 1, 256, 4, True, False, False, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 1, 2, 1,1024, 2, True, False, False, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 1, 4, 1,1024, 1, True, False, False, 0, 0)], - '6144' :[ h_traits('x', 'y', 'xs', 'ys', 1, 3, 1, 256, 8, True, False, False, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 1, 3, 1, 512, 4, True, False, False, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 1, 3, 1,1024, 2, True, False, False, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 1, 6, 1,1024, 1, True, False, False, 0, 0)], - '8192' :[ h_traits('x', 'y', 'xs', 'ys', 1, 4, 1, 256, 8, True, False, False, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 1, 4, 1, 512, 4, True, False, False, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 1, 4, 1,1024, 2, True, False, False, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 1, 8, 1,1024, 1, True, False, False, 0, 0)], - 'big' :[ h_traits('x', 'y', 'xs', 'ys', 1, 2, 1, 256, 8, True, False, True, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 1, 4, 1, 256, 4, True, False, True, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 1, 2, 1,1024, 2, True, False, True, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 1, 4, 1,1024, 1, True, False, True, 0, 0)]} + # rm rn tm tn vn pd mv unquant 2p add sweep + h_trait_dict = {'64' : [ h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 1, 8, 8, 8, True, False, False, False, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 1, 4, 16, 4, True, False, False, False, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 1, 4, 64, 1, True, False, False, False, 0, 0)], + '128' : [ h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 1, 4, 16, 8, True, False, False, False, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 1, 4, 64, 2, True, False, False, False, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 2, 4, 64, 1, True, False, False, False, 0, 0)], + '256' : [ h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 1, 4, 64, 4, True, False, False, False, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 2, 4, 64, 2, True, False, False, False, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 4, 4, 64, 1, True, False, False, False, 0, 0)], + '512' : [ h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 1, 4, 64, 8, True, False, False, False, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 2, 4, 64, 4, True, False, False, False, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 4, 4, 64, 2, True, False, False, False, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 8, 4, 64, 1, True, False, False, False, 0, 0)], + '640' : [ h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 5, 4, 64, 2, True, False, False, False, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 5, 4, 128, 1, True, False, False, False, 0, 0)], + '768' : [ h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 3, 4, 64, 4, True, False, False, False, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 6, 4, 64, 2, True, False, False, False, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 12, 4, 64, 1, True, False, False, False, 0, 0)], + '1024' :[ h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 2, 2, 64, 8, True, False, False, False, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 4, 2, 64, 4, True, False, False, False, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 8, 2, 64, 2, True, False, False, False, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 4, 1, 256, 1, True, False, False, False, 0, 0)], + '1536' :[ h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 3, 4, 64, 8, True, False, False, False, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 3, 2, 128, 4, True, False, False, False, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 3, 1, 256, 2, True, False, False, False, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 6, 1, 256, 1, True, False, False, False, 0, 0)], + '2048' :[ h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 1, 1, 256, 8, True, False, False, False, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 2, 1, 256, 4, True, False, False, False, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 4, 1, 256, 2, True, False, False, False, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 8, 1, 256, 1, True, False, False, False, 0, 0)], + '3072' :[ h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 3, 1, 128, 8, True, False, False, False, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 3, 1, 256, 4, True, False, False, False, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 6, 1, 256, 2, True, False, False, False, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 3, 1,1024, 1, True, False, False, False, 0, 0)], + '4096' :[ h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 2, 1, 256, 8, True, False, False, False, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 4, 1, 256, 4, True, False, False, False, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 2, 1,1024, 2, True, False, False, False, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 4, 1,1024, 1, True, False, False, False, 0, 0)], + '6144' :[ h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 3, 1, 256, 8, True, False, False, False, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 3, 1, 512, 4, True, False, False, False, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 3, 1,1024, 2, True, False, False, False, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 6, 1,1024, 1, True, False, False, False, 0, 0)], + '8192' :[ h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 4, 1, 256, 8, True, False, False, False, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 4, 1, 512, 4, True, False, False, False, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 4, 1,1024, 2, True, False, False, False, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 8, 1,1024, 1, True, False, False, False, 0, 0)], + 'big' :[ h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 1, 1,1024, 8, True, False, False, True, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 4, 1, 256, 4, True, False, False, True, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 12, 1, 256, 2, True, False, False, True, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 4, 1,1024, 1, True, False, False, True, 0, 0)]} total_blob = list() for hs_key in h_trait_dict: hs = h_trait_dict[hs_key] current_n = hs[0].F_Repeat_N * hs[0].F_ThreadPerBlock_N * hs[0].F_Vector_N - for dtype, scale_type, fused_add, fused_quant in itertools.product(dtype_list, scale_list, fused_add_list, fused_sweep_list): + for dtype, scale_type, fused_add, fused_quant, save_unquant in itertools.product(dtype_list, scale_list, fused_add_list, fused_sweep_list, bool_list): prec_i, prec_o = dtype.split(',') scale_sm, scale_y = scale_type.split(',') if prec_o in dynamic_quant_out_dtype and fused_quant != 1 and fused_quant != 2: continue # skip non dynamic quant case if (fused_quant == 1 or fused_quant == 2) and hs_key == 'big': continue + if (fused_quant == 0 and save_unquant == True): + continue # save_unquant should always be false when there is no quant enabled current_hs = list() for chs_ in hs: h_ = copy.copy(chs_) # copy the base instance out @@ -557,12 +587,14 @@ float rmsnorm2d_fwd(rmsnorm2d_fwd_traits t, h_.F_YDataType = prec_o h_.F_SmoothScaleDataType = scale_sm h_.F_YScaleDataType = scale_y + h_.F_UnquantYDataType = prec_i h_.F_kFusedAdd = fused_add h_.F_kFusedQuant = fused_quant + h_.F_kSaveUnquant = save_unquant current_hs.append(h_) # + "\n" #f.write(str(f.parent / GEN_DIR / (blobs.api_common_header_ current_n_str = 'big' if hs_key == 'big' else current_n - total_blob.append(h_instance(dtype, current_n_str, fused_add, fused_quant, current_hs)) + total_blob.append(h_instance(dtype, current_n_str, fused_add, fused_quant, save_unquant, current_hs)) return total_blob def list_blobs(self) -> None: diff --git a/example/ck_tile/10_rmsnorm2d/rmsnorm2d_fwd.cpp b/example/ck_tile/10_rmsnorm2d/rmsnorm2d_fwd.cpp index cdee6dfb80..d5be4384ab 100644 --- a/example/ck_tile/10_rmsnorm2d/rmsnorm2d_fwd.cpp +++ b/example/ck_tile/10_rmsnorm2d/rmsnorm2d_fwd.cpp @@ -38,6 +38,7 @@ auto create_args(int argc, char* argv[]) .insert("yr_stride", "-1", "y residule row_stride, if -1 then equal to n") .insert("e", "1e-5", "epsilon") .insert("save_rms", "0", "save rms(invrms) or not. set to 1 in training case") + .insert("save_unquant", "0", "save result before quant") .insert("v", "1", "cpu validation or not") .insert("kname", "1", "print kernel name or not") .insert("prec_i", "fp16", "input precision") @@ -61,7 +62,8 @@ template + bool SaveRms, + bool SaveUnquant> bool run(const ck_tile::ArgParser& arg_parser) { ck_tile::index_t m = arg_parser.get_int("m"); @@ -113,6 +115,14 @@ bool run(const ck_tile::ArgParser& arg_parser) return false; } + if((fused_quant == 0) && SaveUnquant) + { + std::cout + << "save_unquant should be 0 if quant output is not enabled because it is meaningless. " + << "Output Y is what wanted." << std::endl; + return false; + } + using TypeConfig = RmsnormTypeConfig; @@ -124,6 +134,8 @@ bool run(const ck_tile::ArgParser& arg_parser) using InvRmsDataType = std::conditional_t; + using UnquantYDataType = + std::conditional_t; using ComputeDataType = typename TypeConfig::ComputeDataType; @@ -143,6 +155,10 @@ bool run(const ck_tile::ArgParser& arg_parser) ck_tile::HostTensor invRms_host_ref({m}); + ck_tile::HostTensor unquant_y_host_ref({m, n}, {y_stride, 1}); + ck_tile::HostTensor unquant_y_host_dev({m, n}, {y_stride, 1}); + ck_tile::HostTensor unquant_y_null({1}); + ck_tile::FillUniformDistribution{-.5f, .5f}(x_host); ck_tile::FillUniformDistribution{-.5f, .5f}(x_residual_host); ck_tile::FillUniformDistribution{-1.f, 1.f}(sm_scale_host); @@ -155,6 +171,7 @@ bool run(const ck_tile::ArgParser& arg_parser) ck_tile::DeviceMem sm_scale_buf(sm_scale_host_dev.get_element_space_size_in_bytes()); ck_tile::DeviceMem x_residual_buf(x_residual_host.get_element_space_size_in_bytes()); ck_tile::DeviceMem y_residual_buf(y_residual_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem unquant_y_buf(unquant_y_host_dev.get_element_space_size_in_bytes()); x_buf.ToDevice(x_host.data()); gamma_buf.ToDevice(gamma_host.data()); @@ -179,7 +196,8 @@ bool run(const ck_tile::ArgParser& arg_parser) << ", xr_stride:" << xr_stride << ", y_stride:" << y_stride << ", yr_stride:" << yr_stride << std::flush; - rmsnorm2d_fwd_traits traits{prec_i, prec_o, prec_sm, prec_sy, SaveRms, fused_add, fused_quant}; + rmsnorm2d_fwd_traits traits{ + prec_i, prec_o, prec_sm, prec_sy, SaveRms, SaveUnquant, fused_add, fused_quant}; rmsnorm2d_fwd_args args{x_buf.GetDeviceBuffer(), fused_add != 0 ? x_residual_buf.GetDeviceBuffer() : nullptr, @@ -189,6 +207,7 @@ bool run(const ck_tile::ArgParser& arg_parser) fused_add == 1 ? y_residual_buf.GetDeviceBuffer() : nullptr, fused_quant != 0 ? y_scale_buf.GetDeviceBuffer() : nullptr, nullptr, // p_invRms, unsupported yet + SaveUnquant ? unquant_y_buf.GetDeviceBuffer() : nullptr, epsilon, m, n, @@ -203,6 +222,7 @@ bool run(const ck_tile::ArgParser& arg_parser) std::size_t num_byte = sizeof(XDataType) * m * n + sizeof(GammaDataType) * n + sizeof(YDataType) * m * n; num_byte += SaveRms ? sizeof(InvRmsDataType) * m * n : 0; + num_byte += SaveUnquant ? sizeof(UnquantYDataType) * m * n : 0; num_byte += fused_add ? sizeof(XResidualDataType) * m * n : 0; num_byte += ((fused_quant == 1) || (fused_quant == 2)) ? sizeof(YScaleDataType) * m : 0; num_byte += (fused_quant == 1) ? sizeof(SmoothScaleDataType) * n : 0; @@ -262,21 +282,57 @@ bool run(const ck_tile::ArgParser& arg_parser) } }; - ck_tile::reference_rmsnorm2d_fwd( - x_host, gamma_host, y_host_ref, invRms_host_ref, epsilon, dquant_functor); + auto default_and_dquant_functor = [&](int m_, auto& o_unquant_, auto& o_, auto& acc_) { + const int N = acc_.mDesc.get_lengths()[1]; + for(int n_ = 0; n_ < N; ++n_) + { + o_unquant_(m_, n_) = ck_tile::type_convert(acc_(m_, n_)); + } + + dquant_functor(m_, o_, acc_); + }; + + if constexpr(SaveUnquant) + { + ck_tile::reference_rmsnorm2d_fwd(x_host, + gamma_host, + y_host_ref, + invRms_host_ref, + unquant_y_host_ref, + epsilon, + default_and_dquant_functor); + } + else + { + ck_tile::reference_rmsnorm2d_fwd(x_host, + gamma_host, + y_host_ref, + invRms_host_ref, + unquant_y_host_ref, + epsilon, + dquant_functor); + } } else { + assert(SaveUnquant == false); ck_tile::reference_rmsnorm2d_fwd( - x_host, gamma_host, y_host_ref, invRms_host_ref, epsilon); + InvRmsDataType, + ck_tile::null_type>( + x_host, gamma_host, y_host_ref, invRms_host_ref, unquant_y_null, epsilon); } y_buf.FromDevice(y_host_dev.data()); @@ -293,6 +349,15 @@ bool run(const ck_tile::ArgParser& arg_parser) pass = ck_tile::check_err( y_host_dev, y_host_ref, std::string("\nOUT Error: Incorrect results!"), rtol, atol); + if constexpr(SaveUnquant) + { + pass &= ck_tile::check_err(unquant_y_host_dev, + unquant_y_host_ref, + std::string("\n OUT ERROR: Incorrect unquant results!"), + rtol, + atol); + } + if(fused_add == 1) { pass &= ck_tile::check_err(y_residual_host_dev, @@ -331,6 +396,23 @@ bool run(const ck_tile::ArgParser& arg_parser) rtol, atol); } + + if constexpr(SaveUnquant) + { + std::vector unquant_y_host_dev_row( + unquant_y_host_dev.begin() + i_r * y_stride, + unquant_y_host_dev.begin() + i_r * y_stride + n); + std::vector unquant_y_host_ref_row( + unquant_y_host_ref.begin() + i_r * y_stride, + unquant_y_host_ref.begin() + i_r * y_stride + n); + pass &= + ck_tile::check_err(unquant_y_host_dev_row, + unquant_y_host_ref_row, + std::string("\nOUT[") + std::to_string(i_r) + + std::string("] Error: Incorrect unquant y results!"), + rtol, + atol); + } } } @@ -350,6 +432,8 @@ bool run(const ck_tile::ArgParser& arg_parser) return pass; } +bool is_quant_data_type(const std::string& prec) { return (prec == "int8") || (prec == "fp8"); } + int main(int argc, char* argv[]) { auto [result, arg_parser] = create_args(argc, argv); @@ -373,48 +457,79 @@ int main(int argc, char* argv[]) prec_sy = "fp32"; } - int save_rms = arg_parser.get_int("save_rms"); + int save_rms = arg_parser.get_int("save_rms"); + int fused_quant = arg_parser.get_int("fquant"); + int save_unquant = + arg_parser.get_int("save_unquant") && is_quant_data_type(prec_o) && (fused_quant != 0); if(prec_i == "fp16" && prec_o == "fp16" && prec_sm == "fp32" && prec_sy == "fp32" && save_rms) { - return run(arg_parser) ? 0 : -2; + return run(arg_parser) ? 0 + : -2; } else if(prec_i == "fp16" && prec_o == "fp16" && prec_sm == "fp32" && prec_sy == "fp32" && !save_rms) { - return run(arg_parser) ? 0 : -2; + return run(arg_parser) ? 0 + : -2; } else if(prec_i == "bf16" && prec_o == "bf16" && prec_sm == "fp32" && prec_sy == "fp32" && save_rms) { - return run(arg_parser) ? 0 : -2; + return run(arg_parser) ? 0 + : -2; } else if(prec_i == "bf16" && prec_o == "bf16" && prec_sm == "fp32" && prec_sy == "fp32" && !save_rms) { - return run(arg_parser) ? 0 : -2; + return run(arg_parser) ? 0 + : -2; } // dynamic quant case, only in inference else if(prec_i == "fp16" && prec_o == "int8" && prec_sm == "fp32" && prec_sy == "fp32" && - !save_rms) + !save_rms && !save_unquant) { - return run(arg_parser) ? 0 : -2; + return run(arg_parser) ? 0 + : -2; } else if(prec_i == "bf16" && prec_o == "int8" && prec_sm == "fp32" && prec_sy == "fp32" && - !save_rms) + !save_rms && !save_unquant) { - return run(arg_parser) ? 0 : -2; + return run(arg_parser) ? 0 + : -2; } else if(prec_i == "fp16" && prec_o == "fp8" && prec_sm == "fp32" && prec_sy == "fp32" && - !save_rms) + !save_rms && !save_unquant) { - return run(arg_parser) ? 0 : -2; + return run(arg_parser) ? 0 + : -2; } else if(prec_i == "bf16" && prec_o == "fp8" && prec_sm == "fp32" && prec_sy == "fp32" && - !save_rms) + !save_rms && !save_unquant) { - return run(arg_parser) ? 0 : -2; + return run(arg_parser) ? 0 + : -2; + } + else if(prec_i == "fp16" && prec_o == "int8" && prec_sm == "fp32" && prec_sy == "fp32" && + !save_rms && save_unquant) + { + return run(arg_parser) ? 0 : -2; + } + else if(prec_i == "bf16" && prec_o == "int8" && prec_sm == "fp32" && prec_sy == "fp32" && + !save_rms && save_unquant) + { + return run(arg_parser) ? 0 : -2; + } + else if(prec_i == "fp16" && prec_o == "fp8" && prec_sm == "fp32" && prec_sy == "fp32" && + !save_rms && save_unquant) + { + return run(arg_parser) ? 0 : -2; + } + else if(prec_i == "bf16" && prec_o == "fp8" && prec_sm == "fp32" && prec_sy == "fp32" && + !save_rms && save_unquant) + { + return run(arg_parser) ? 0 : -2; } return -3; diff --git a/example/ck_tile/10_rmsnorm2d/rmsnorm2d_fwd.hpp b/example/ck_tile/10_rmsnorm2d/rmsnorm2d_fwd.hpp index 566b94442d..bb4a2f5ef4 100644 --- a/example/ck_tile/10_rmsnorm2d/rmsnorm2d_fwd.hpp +++ b/example/ck_tile/10_rmsnorm2d/rmsnorm2d_fwd.hpp @@ -21,6 +21,7 @@ struct RmsnormTypeConfig void reference_rmsnorm2d_fwd(const HostTensor& x_m_n, const HostTensor& gamma_n, HostTensor& y_m_n, HostTensor& invRms_m, + HostTensor& unquant_y_m_n, ComputeDataType epsilon, Epilogue epilogue_functor = {}) { @@ -69,7 +71,14 @@ void reference_rmsnorm2d_fwd(const HostTensor& x_m_n, acc(m, n) = x * divisor * gamma; } - epilogue_functor(m, y_m_n, acc); + if constexpr(!std::is_same_v) + { + epilogue_functor(m, unquant_y_m_n, y_m_n, acc); + } + else + { + epilogue_functor(m, y_m_n, acc); + } }; make_ParallelTensorFunctor(rmsnorm2d_fwd_func, invRms_m.mDesc.get_lengths()[0])( diff --git a/include/ck_tile/ops/epilogue.hpp b/include/ck_tile/ops/epilogue.hpp index 9d2ed407c9..12e53e13e6 100644 --- a/include/ck_tile/ops/epilogue.hpp +++ b/include/ck_tile/ops/epilogue.hpp @@ -6,6 +6,7 @@ #include "ck_tile/ops/epilogue/cshuffle_epilogue.hpp" #include "ck_tile/ops/epilogue/default_2d_epilogue.hpp" #include "ck_tile/ops/epilogue/dynamic_quant_epilogue.hpp" +#include "ck_tile/ops/epilogue/default_2d_and_dynamic_quant_epilogue.hpp" #include "ck_tile/ops/common/generic_2d_block_shape.hpp" #include "ck_tile/ops/common/tensor_layout.hpp" #include "ck_tile/ops/common/utils.hpp" diff --git a/include/ck_tile/ops/epilogue/default_2d_and_dynamic_quant_epilogue.hpp b/include/ck_tile/ops/epilogue/default_2d_and_dynamic_quant_epilogue.hpp new file mode 100644 index 0000000000..6c5a2ac149 --- /dev/null +++ b/include/ck_tile/ops/epilogue/default_2d_and_dynamic_quant_epilogue.hpp @@ -0,0 +1,91 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "default_2d_epilogue.hpp" +#include "dynamic_quant_epilogue.hpp" + +namespace ck_tile { + +// User can reuse DynamicQuantEpilogueTraits with this epilogue +template +using Default2DAndDynamicQuantEpilogueTraits = + DynamicQuantEpilogueTraits; + +// This epilogue just store out a M*N matrix, row major +template +struct Default2DAndDynamicQuantEpilogueProblem +{ + using AccDataType = remove_cvref_t; + using SmoothScaleDataType = remove_cvref_t; + using YScaleDataType = remove_cvref_t; + using ODataType = remove_cvref_t; + using UnquantYDataType = remove_cvref_t; + using BlockShape = remove_cvref_t; // can consum generic 2d shape + using Traits = remove_cvref_t; +}; + +template +struct Default2DAndDynamicQuantEpilogue +{ + using Problem = remove_cvref_t; + using AccDataType = remove_cvref_t; + using UnquantYDataType = remove_cvref_t; + + static constexpr bool kPadM = Problem::Traits::kPadM; + static constexpr bool kPadN = Problem::Traits::kPadN; + static constexpr bool UseRawStore = Problem::Traits::UseRawStore; + + using Default2DProblem = + Default2DEpilogueProblem; + using Default2D = Default2DEpilogue; + using DynamicQuant = DynamicQuantEpilogue; + + CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() + { + return max(Default2D::GetSmemSize(), DynamicQuant::GetSmemSize()); + } + + template + CK_TILE_DEVICE auto operator()(ODramWindowTmpD& o_direct_dram_window_tmp, + ODramWindowTmpQ& o_quant_dram_window_tmp, + const SmoothScaleWindow& sm_scale_window_, + YScaleWindow& y_scale_window, + const OAccTile& o_acc_tile, + void* smem) + { + Default2D{}(o_direct_dram_window_tmp, o_acc_tile, smem); + DynamicQuant{}(o_quant_dram_window_tmp, sm_scale_window_, y_scale_window, o_acc_tile, smem); + } + + template + CK_TILE_DEVICE auto operator()(ODramWindowTmpD& o_direct_dram_window_tmp, + ODramWindowTmpQ& o_quant_dram_window_tmp, + YScaleWindow& y_scale_window, + const OAccTile& o_acc_tile, + void* smem) + { + Default2D{}(o_direct_dram_window_tmp, o_acc_tile, smem); + DynamicQuant{}(o_quant_dram_window_tmp, y_scale_window, o_acc_tile, smem); + } +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_two_pass.hpp b/include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_two_pass.hpp index b0b0c194ad..73cdd084c6 100644 --- a/include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_two_pass.hpp +++ b/include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_two_pass.hpp @@ -182,9 +182,16 @@ struct Layernorm2dFwdPipelineTwoPass ck_tile::index_t stride_to_right_most_window = row_size % Block_N == 0 ? row_size - Block_N : row_size - row_size % Block_N; - move_tile_window(x_window, {0, -Block_N}); - move_tile_window(x_residual_window, {0, -Block_N}); - move_tile_window(x_bias_window, {-Block_N}); + if constexpr(kFusedAdd == Layernorm2dFusedAddEnum::PRE_ADD_STORE) + { + move_tile_window(y_residual_window, {0, -Block_N}); + } + else + { + move_tile_window(x_window, {0, -Block_N}); + move_tile_window(x_residual_window, {0, -Block_N}); + move_tile_window(x_bias_window, {-Block_N}); + } move_tile_window(gamma_window, {stride_to_right_most_window}); move_tile_window(beta_window, {stride_to_right_most_window}); move_tile_window(y_window, {0, stride_to_right_most_window}); @@ -192,28 +199,43 @@ struct Layernorm2dFwdPipelineTwoPass // layernorm computation for(int iN = __builtin_amdgcn_readfirstlane(0); iN < num_n_tile_iteration; ++iN) { - auto x = load_tile(x_window); - auto x_resi = load_tile(x_residual_window); - const auto x_bias = load_tile(x_bias_window); - auto acc = cast_tile(x); + auto acc = make_static_distributed_tensor( + decltype(load_tile(x_window))::get_tile_distribution()); - if constexpr(kXbias == Layernorm2dXBiasEnum::ADD_BIAS) + if constexpr(kFusedAdd == Layernorm2dFusedAddEnum::PRE_ADD_STORE) { - sweep_tile(x, [&](auto idx) { - // compute x = bias + x - constexpr auto j_idx = make_tuple(idx[number<1>{}]); - acc(idx) = type_convert(x_bias[j_idx]) + acc(idx); - }); + acc = cast_tile(load_tile(y_residual_window)); + move_tile_window(y_residual_window, {0, -Block_N}); + } + else + { + acc = cast_tile(load_tile(x_window)); + move_tile_window(x_window, {0, -Block_N}); + + if constexpr(kXbias == Layernorm2dXBiasEnum::ADD_BIAS) + { + const auto x_bias = load_tile(x_bias_window); + move_tile_window(x_bias_window, {-Block_N}); + + sweep_tile(acc, [&](auto idx) { + // compute x = bias + x + constexpr auto j_idx = make_tuple(idx[number<1>{}]); + acc(idx) = type_convert(x_bias[j_idx]) + acc(idx); + }); + } + + if constexpr(kFusedAdd == Layernorm2dFusedAddEnum::PRE_ADD) + { + auto x_resi = load_tile(x_residual_window); + move_tile_window(x_residual_window, {0, -Block_N}); + + sweep_tile(x_resi, [&](auto idx) { + // compute x = x_resi + x + acc(idx) = type_convert(x_resi(idx)) + acc(idx); + }); + } } - if constexpr(kFusedAdd == Layernorm2dFusedAddEnum::PRE_ADD_STORE || - kFusedAdd == Layernorm2dFusedAddEnum::PRE_ADD) - { - sweep_tile(x_resi, [&](auto idx) { - // compute x = x_resi + x - acc(idx) = type_convert(x_resi(idx)) + acc(idx); - }); - } // load gamma/beta (TODO: support no gamma/beta?) const auto gamma = load_tile(gamma_window); const auto beta = load_tile(beta_window); @@ -235,9 +257,6 @@ struct Layernorm2dFwdPipelineTwoPass static_assert(kFusedQuant != Layernorm2dFusedQuantEnum::DYNAMIC_QUANT); Epilogue{}(y_window, ln); - move_tile_window(x_window, {0, -Block_N}); - move_tile_window(x_residual_window, {0, -Block_N}); - move_tile_window(x_bias_window, {-Block_N}); move_tile_window(gamma_window, {-Block_N}); move_tile_window(beta_window, {-Block_N}); move_tile_window(y_window, {0, -Block_N}); diff --git a/include/ck_tile/ops/rmsnorm2d/kernel/rmsnorm2d_fwd_kernel.hpp b/include/ck_tile/ops/rmsnorm2d/kernel/rmsnorm2d_fwd_kernel.hpp index 88c8084de6..f0251177d4 100644 --- a/include/ck_tile/ops/rmsnorm2d/kernel/rmsnorm2d_fwd_kernel.hpp +++ b/include/ck_tile/ops/rmsnorm2d/kernel/rmsnorm2d_fwd_kernel.hpp @@ -21,6 +21,7 @@ struct Rmsnorm2dFwdHostArgs void* p_y_residual; // [m, n], shortcut output, prec same as input, nullptr if not used void* p_y_scale; // [m, 1], output a dynamic quant per row, nullptr if not used void* p_invRms; // [m, 1], output inv-rms, prec same as input, nullptr if not used + void* p_y_unquant; // [m, n], output result before quant, nullptr if not used float epsilon; @@ -47,13 +48,15 @@ struct Rmsnorm2dFwd using InvRmsDataType = remove_cvref_t; using SmoothScaleDataType = remove_cvref_t; using YScaleDataType = remove_cvref_t; + using UnquantYDataType = remove_cvref_t; // for simplicity, shortcut input/output type is same as X using XResidualDataType = XDataType; using YResidualDataType = XDataType; - static constexpr bool kHasGamma = !std::is_same_v; - static constexpr bool kSaveInvRms = Problem::Traits::kSaveInvRms; + static constexpr bool kHasGamma = !std::is_same_v; + static constexpr bool kSaveInvRms = Problem::Traits::kSaveInvRms; + static constexpr bool kSaveUnquant = Problem::Traits::kSaveUnquant; static constexpr index_t Block_M = Problem::BlockShape::Block_M; static constexpr index_t Block_N = Problem::BlockShape::Block_N; @@ -81,6 +84,7 @@ struct Rmsnorm2dFwd void* p_y_residual; void* p_y_scale; void* p_invRms; + void* p_y_unquant; float epsilon; @@ -103,6 +107,7 @@ struct Rmsnorm2dFwd hargs.p_y_residual, hargs.p_y_scale, hargs.p_invRms, + hargs.p_y_unquant, hargs.epsilon, hargs.m, hargs.n, @@ -323,6 +328,30 @@ struct Rmsnorm2dFwd } }(); + auto unquant_y_window = [&]() { + if constexpr((kFusedQuant == Rmsnorm2dFusedQuantEnum::SMOOTH_DYNAMIC_QUANT || + kFusedQuant == Rmsnorm2dFusedQuantEnum::DYNAMIC_QUANT) && + kSaveUnquant) + { + auto tmp_ = make_naive_tensor_view( + static_cast(kargs.p_y_unquant), + make_tuple(kargs.m, kargs.n), + make_tuple(kargs.y_stride, 1), + number{}, + number<1>{}); + + auto tmp2_ = pad_tensor_view(tmp_, + make_tuple(number{}, number{}), + sequence{}); + return make_tile_window( + tmp2_, make_tuple(number{}, number{}), {iM, 0}); + } + else + { + return make_null_tile_window(make_tuple(number{}, number{})); + } + }(); + __shared__ char smem[GetSmemSize()]; Pipeline{}(x_window, @@ -333,6 +362,7 @@ struct Rmsnorm2dFwd inv_rms_window, sm_scale_window, y_scale_window, + unquant_y_window, static_cast(kargs.epsilon), kargs.n, smem, diff --git a/include/ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_one_pass.hpp b/include/ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_one_pass.hpp index 93c2833be4..58159142d0 100644 --- a/include/ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_one_pass.hpp +++ b/include/ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_one_pass.hpp @@ -25,8 +25,9 @@ struct Rmsnorm2dFwdPipelineOnePass using XResidualDataType = XDataType; using YResidualDataType = XDataType; - static constexpr bool kHasGamma = !std::is_same_v; - static constexpr bool kSaveInvRms = Problem::Traits::kSaveInvRms; + static constexpr bool kHasGamma = !std::is_same_v; + static constexpr bool kSaveInvRms = Problem::Traits::kSaveInvRms; + static constexpr bool kSaveUnquant = Problem::Traits::kSaveUnquant; static constexpr bool kNeedCrossWarpSync = Problem::kNeedCrossWarpSync; static constexpr bool kPadM = false; // TODO - BlockRmsnorm2dFwdProblem::kPadM @@ -54,6 +55,7 @@ struct Rmsnorm2dFwdPipelineOnePass typename InvRmsWindow, typename SmoothScaleWindow, typename YScaleWindow, + typename UnquantYWindow, typename Epilogue> CK_TILE_DEVICE auto operator()(const XWindow& x_window_, const XResidualWindow& x_residual_window_, @@ -63,6 +65,7 @@ struct Rmsnorm2dFwdPipelineOnePass InvRmsWindow& inv_rms_window, const SmoothScaleWindow& sm_scale_window_, YScaleWindow& y_scale_window_, + UnquantYWindow& unquant_y_window, ComputeDataType epsilon, ck_tile::index_t row_size, void* smem, @@ -137,11 +140,26 @@ struct Rmsnorm2dFwdPipelineOnePass if constexpr(kFusedQuant == Rmsnorm2dFusedQuantEnum::SMOOTH_DYNAMIC_QUANT) { - Epilogue{}(y_window_, sm_scale_window_, y_scale_window_, rmsn, smem); + if constexpr(kSaveUnquant) + { + Epilogue{}( + unquant_y_window, y_window_, sm_scale_window_, y_scale_window_, rmsn, smem); + } + else + { + Epilogue{}(y_window_, sm_scale_window_, y_scale_window_, rmsn, smem); + } } else if constexpr(kFusedQuant == Rmsnorm2dFusedQuantEnum::DYNAMIC_QUANT) { - Epilogue{}(y_window_, y_scale_window_, rmsn, smem); + if constexpr(kSaveUnquant) + { + Epilogue{}(unquant_y_window, y_window_, y_scale_window_, rmsn, smem); + } + else + { + Epilogue{}(y_window_, y_scale_window_, rmsn, smem); + } } else { diff --git a/include/ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_problem.hpp b/include/ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_problem.hpp index baf56246f3..773df4f0f4 100644 --- a/include/ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_problem.hpp +++ b/include/ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_problem.hpp @@ -12,6 +12,7 @@ template ; using YDataType = remove_cvref_t; using InvRmsDataType = remove_cvref_t; + using UnquantYDataType = remove_cvref_t; using SmoothScaleDataType = remove_cvref_t; using YScaleDataType = remove_cvref_t; using BlockShape = remove_cvref_t; diff --git a/include/ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_two_pass.hpp b/include/ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_two_pass.hpp index c29a6cb07d..4ca1dbc5da 100644 --- a/include/ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_two_pass.hpp +++ b/include/ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_two_pass.hpp @@ -54,6 +54,7 @@ struct Rmsnorm2dFwdPipelineTwoPass typename InvRmsWindow, typename SmoothScaleWindow, typename YScaleWindow, + typename UnquantYWindow, typename Epilogue> CK_TILE_DEVICE auto operator()(const XWindow& x_window_, const XResidualWindow& x_residual_window_, @@ -63,6 +64,7 @@ struct Rmsnorm2dFwdPipelineTwoPass InvRmsWindow& inv_rms_window, const SmoothScaleWindow& /*sm_scale_window_*/, YScaleWindow& /*y_scale_window*/, + UnquantYWindow& /*unquant_y_window*/, ComputeDataType epsilon, ck_tile::index_t row_size, void* smem, @@ -136,32 +138,51 @@ struct Rmsnorm2dFwdPipelineTwoPass ck_tile::index_t stride_to_right_most_window = row_size % Block_N == 0 ? row_size - Block_N : row_size - row_size % Block_N; - move_tile_window(x_window, {0, -Block_N}); - move_tile_window(x_residual_window, {0, -Block_N}); + if constexpr(kFusedAdd == Rmsnorm2dFusedAddEnum::PRE_ADD_STORE) + { + move_tile_window(y_residual_window, {0, -Block_N}); + } + else + { + move_tile_window(x_window, {0, -Block_N}); + move_tile_window(x_residual_window, {0, -Block_N}); + } move_tile_window(gamma_window, {stride_to_right_most_window}); move_tile_window(y_window, {0, stride_to_right_most_window}); // rmsnorm computation for(int iN = __builtin_amdgcn_readfirstlane(0); iN < num_n_tile_iteration; ++iN) { - auto x = load_tile(x_window); - auto x_resi = load_tile(x_residual_window); - auto acc = cast_tile(x); + auto acc = make_static_distributed_tensor( + decltype(load_tile(x_window))::get_tile_distribution()); - if constexpr(kFusedAdd == Rmsnorm2dFusedAddEnum::PRE_ADD_STORE || - kFusedAdd == Rmsnorm2dFusedAddEnum::PRE_ADD) + if constexpr(kFusedAdd == Rmsnorm2dFusedAddEnum::PRE_ADD_STORE) { - sweep_tile(x_resi, [&](auto idx) { - // compute x = x_resi + x - acc(idx) = type_convert(x_resi(idx)) + acc(idx); - }); + acc = cast_tile(load_tile(y_residual_window)); + move_tile_window(y_residual_window, {0, -Block_N}); + } + else + { + acc = cast_tile(load_tile(x_window)); + move_tile_window(x_window, {0, -Block_N}); + + if constexpr(kFusedAdd == Rmsnorm2dFusedAddEnum::PRE_ADD) + { + auto x_resi = load_tile(x_residual_window); + sweep_tile(x_resi, [&](auto idx) { + // compute x = x_resi + x + acc(idx) = type_convert(x_resi(idx)) + acc(idx); + }); + move_tile_window(x_residual_window, {0, -Block_N}); + } } // load gamma (TODO: support no gamma?) const auto gamma = load_tile(gamma_window); // rmsnorm computation - auto rmsn = make_static_distributed_tensor(x.get_tile_distribution()); + auto rmsn = make_static_distributed_tensor( + decltype(load_tile(x_window))::get_tile_distribution()); sweep_tile(rmsn, [&, inv_rms_ = inv_rms](auto idx) { constexpr auto i_idx = make_tuple(idx[number<0>{}]); constexpr auto j_idx = make_tuple(idx[number<1>{}]); @@ -176,8 +197,6 @@ struct Rmsnorm2dFwdPipelineTwoPass static_assert(kFusedQuant == Rmsnorm2dFusedQuantEnum::NO_SWEEP); Epilogue{}(y_window, rmsn); - move_tile_window(x_window, {0, -Block_N}); - move_tile_window(x_residual_window, {0, -Block_N}); move_tile_window(gamma_window, {-Block_N}); move_tile_window(y_window, {0, -Block_N}); } diff --git a/include/ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_traits.hpp b/include/ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_traits.hpp index cb7beba291..152da60c01 100644 --- a/include/ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_traits.hpp +++ b/include/ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_traits.hpp @@ -39,6 +39,7 @@ template<> struct Rmsnorm2dFusedQuantEnumName @@ -46,6 +47,7 @@ struct Rmsnorm2dFwdTraits { static constexpr bool kPadN = kPadN_; static constexpr bool kSaveInvRms = kSaveInvRms_; + static constexpr bool kSaveUnquant = kSaveUnquant_; static constexpr bool kTwoPass = kTwoPass_; static constexpr Rmsnorm2dFusedAddEnum kFusedAdd = kFusedAdd_; static constexpr Rmsnorm2dFusedQuantEnum kFusedQuant = kFusedQuant_;