From 561c2213429635ec998eae7827642ed9bf804a9b Mon Sep 17 00:00:00 2001 From: dummycoderfe Date: Fri, 8 Nov 2024 12:28:23 +0800 Subject: [PATCH] [Ck tile] layernorm2d fwd optimize (#1637) * optimze small N case using vec io and using rcp div * [Ck_tile] layernorm, add param to control fastdiv; change generate codes and test pass * [Ck_tile] fix blockSize compute in Generic2dBlockShape * [Ck_tile]fix kfastfdiv template style * [Ck_tile] layernorm, fix stype in review --------- Co-authored-by: dummycoderfe [ROCm/composable_kernel commit: 686a58a912f6884a9b66841cf04b4b81ba35aa7f] --- example/ck_tile/02_layernorm2d/generate.py | 105 ++++++++++-------- .../ops/common/generic_2d_block_shape.hpp | 12 +- ...ayernorm2d_fwd_pipeline_default_policy.hpp | 12 +- .../layernorm2d_fwd_pipeline_one_pass.hpp | 11 +- .../pipeline/layernorm2d_fwd_traits.hpp | 2 + .../ops/welford/block/block_welford.hpp | 34 ++++-- .../welford/block/block_welford_problem.hpp | 9 +- .../ops/welford/thread/thread_welford.hpp | 43 +++++-- 8 files changed, 144 insertions(+), 84 deletions(-) diff --git a/example/ck_tile/02_layernorm2d/generate.py b/example/ck_tile/02_layernorm2d/generate.py index 09aa6b65f8..ca9e432a4f 100644 --- a/example/ck_tile/02_layernorm2d/generate.py +++ b/example/ck_tile/02_layernorm2d/generate.py @@ -57,6 +57,7 @@ template @@ -118,6 +119,7 @@ struct layernorm2d_fwd_traits_ static constexpr bool kPadN = kPadN_; static constexpr bool kSaveMeanInvStd = kSaveMeanInvStd_; + static constexpr bool kFastFDiv = kFastFDiv_; static constexpr bool kTwoPass = kTwoPass_; static constexpr ck_tile::index_t kFusedAdd = kFusedAdd_; static constexpr ck_tile::index_t kFusedQuant = kFusedQuant_; @@ -134,6 +136,7 @@ template @@ -148,6 +151,7 @@ using traits_ = layernorm2d_fwd_traits_; @@ -179,6 +183,7 @@ float layernorm2d_fwd_(const S& s, A a) using PipelineTraits = ck_tile::Layernorm2dFwdTraits(Traits_::kFusedAdd), static_cast(Traits_::kFusedQuant)>; @@ -269,7 +274,7 @@ float layernorm2d_fwd(layernorm2d_fwd_traits t, #include "layernorm2d_fwd_api_common.hpp" // clang-format off -// prec_i prec_o prec_sy rm rn tm tn vn pd mv 2p add sweep +// prec_i prec_o prec_sy rm rn tm tn vn pd mv rpcf 2p add sweep {F_instance_def} // clang-format on @@ -356,6 +361,7 @@ float layernorm2d_fwd(layernorm2d_fwd_traits t, F_Vector_N : int F_kPadN : bool F_kSaveMeanInvStd_ : bool + F_kFastFDiv_ : bool F_kTwoPass_ : bool F_kFusedAdd : int F_kFusedQuant : int @@ -363,7 +369,7 @@ float layernorm2d_fwd(layernorm2d_fwd_traits t, @property def trait_name(self) ->str: t_ = f'{DATA_TYPE_MAP[self.F_XDataType]}, {DATA_TYPE_MAP[self.F_YDataType]}, {DATA_TYPE_MAP[self.F_XScaleDataType]}, {DATA_TYPE_MAP[self.F_YScaleDataType]}, {self.F_Repeat_M:2}, {self.F_Repeat_N:2}, {self.F_ThreadPerBlock_M:2}, {self.F_ThreadPerBlock_N:4}' - t_ += f', {self.F_Vector_N:2}, {BOOL_MAP(self.F_kPadN):5}, {BOOL_MAP(self.F_kSaveMeanInvStd_):5}' + t_ += f', {self.F_Vector_N:2}, {BOOL_MAP(self.F_kPadN):5}, {BOOL_MAP(self.F_kSaveMeanInvStd_):5}, {BOOL_MAP(self.F_kFastFDiv_):5}' t_ += f', {BOOL_MAP(self.F_kTwoPass_):5}, {self.F_kFusedAdd:4}, {self.F_kFusedQuant:4}' return t_ @@ -483,52 +489,55 @@ float layernorm2d_fwd(layernorm2d_fwd_traits t, fused_add_list = [0, 1] fused_sweep_list = [0, 1] # NOTE: only single pass can use fused dynamic quant - # rm rn tm tn vn pd mv 2p add sweep - h_trait_dict = {'64' : [ h_traits('x', 'y', 'xs', 'ys', 1, 1, 4, 64, 1, True, False, False, 0, 0)], - '128' : [ h_traits('x', 'y', 'xs', 'ys', 1, 1, 4, 64, 2, True, False, False, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 1, 2, 4, 64, 1, True, False, False, 0, 0)], - '256' : [ h_traits('x', 'y', 'xs', 'ys', 1, 1, 4, 64, 4, True, False, False, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 1, 2, 4, 64, 2, True, False, False, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 1, 4, 4, 64, 1, True, False, False, 0, 0)], - '512' : [ h_traits('x', 'y', 'xs', 'ys', 1, 1, 4, 64, 8, True, False, False, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 1, 2, 4, 64, 4, True, False, False, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 1, 4, 4, 64, 2, True, False, False, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 1, 8, 4, 64, 1, True, False, False, 0, 0)], - '768' : [ h_traits('x', 'y', 'xs', 'ys', 1, 3, 4, 64, 4, True, False, False, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 1, 6, 4, 64, 2, True, False, False, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 1, 12, 4, 64, 1, True, False, False, 0, 0)], - '1024' :[ h_traits('x', 'y', 'xs', 'ys', 1, 1, 2, 128, 8, True, False, False, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 1, 2, 2, 128, 4, True, False, False, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 1, 4, 2, 128, 2, True, False, False, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 1, 4, 1, 256, 1, True, False, False, 0, 0)], - '1536' :[ h_traits('x', 'y', 'xs', 'ys', 1, 3, 4, 64, 8, True, False, False, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 1, 3, 2, 128, 4, True, False, False, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 1, 3, 1, 256, 2, True, False, False, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 1, 6, 1, 256, 1, True, False, False, 0, 0)], - '2048' :[ h_traits('x', 'y', 'xs', 'ys', 1, 1, 1, 256, 8, True, False, False, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 1, 2, 1, 256, 4, True, False, False, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 1, 4, 1, 256, 2, True, False, False, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 1, 8, 1, 256, 1, True, False, False, 0, 0)], - '3072' :[ h_traits('x', 'y', 'xs', 'ys', 1, 3, 1, 128, 8, True, False, False, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 1, 3, 1, 256, 4, True, False, False, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 1, 6, 1, 256, 2, True, False, False, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 1, 3, 1,1024, 1, True, False, False, 0, 0)], - '4096' :[ h_traits('x', 'y', 'xs', 'ys', 1, 2, 1, 256, 8, True, False, False, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 1, 4, 1, 256, 4, True, False, False, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 1, 2, 1,1024, 2, True, False, False, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 1, 4, 1,1024, 1, True, False, False, 0, 0)], - '6144' :[ h_traits('x', 'y', 'xs', 'ys', 1, 3, 1, 256, 8, True, False, False, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 1, 3, 1, 512, 4, True, False, False, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 1, 3, 1,1024, 2, True, False, False, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 1, 6, 1,1024, 1, True, False, False, 0, 0)], - '8192' :[ h_traits('x', 'y', 'xs', 'ys', 1, 4, 1, 256, 8, True, False, False, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 1, 4, 1, 512, 4, True, False, False, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 1, 4, 1,1024, 2, True, False, False, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 1, 8, 1,1024, 1, True, False, False, 0, 0)], - 'big' :[ h_traits('x', 'y', 'xs', 'ys', 1, 2, 1, 256, 8, True, False, True, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 1, 4, 1, 256, 4, True, False, True, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 1, 2, 1,1024, 2, True, False, True, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 1, 4, 1,1024, 1, True, False, True, 0, 0)]} + # rm rn tm tn vn pd mv fdiv 2p add sweep + h_trait_dict = {'64' : [ h_traits('x', 'y', 'xs', 'ys', 1, 1, 8, 8, 8, True, False, True, False, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 1, 4, 16, 4, True, False, True, False, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 1, 4, 64, 1, True, False, True, False, 0, 0)], + '128' : [ h_traits('x', 'y', 'xs', 'ys', 1, 1, 4, 16, 8, True, False, True, False, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 1, 4, 64, 2, True, False, True, False, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 2, 4, 64, 1, True, False, True, False, 0, 0)], + '256' : [ h_traits('x', 'y', 'xs', 'ys', 1, 1, 4, 64, 4, True, False, True, False, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 2, 4, 64, 2, True, False, True, False, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 4, 4, 64, 1, True, False, True, False, 0, 0)], + '512' : [ h_traits('x', 'y', 'xs', 'ys', 1, 1, 4, 64, 8, True, False, True, False, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 2, 4, 64, 4, True, False, True, False, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 4, 4, 64, 2, True, False, True, False, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 8, 4, 64, 1, True, False, True, False, 0, 0)], + '768' : [ h_traits('x', 'y', 'xs', 'ys', 1, 3, 4, 64, 4, True, False, True, False, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 6, 4, 64, 2, True, False, True, False, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 12, 4, 64, 1, True, False, True, False, 0, 0)], + '1024' :[ h_traits('x', 'y', 'xs', 'ys', 1, 1, 2, 128, 8, True, False, True, False, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 2, 2, 128, 4, True, False, True, False, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 4, 2, 128, 2, True, False, True, False, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 4, 1, 256, 1, True, False, True, False, 0, 0)], + '1536' :[ h_traits('x', 'y', 'xs', 'ys', 1, 3, 4, 64, 8, True, False, True, False, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 3, 2, 128, 4, True, False, True, False, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 3, 1, 256, 2, True, False, True, False, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 6, 1, 256, 1, True, False, True, False, 0, 0)], + '2048' :[ h_traits('x', 'y', 'xs', 'ys', 1, 1, 1, 256, 8, True, False, True, False, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 2, 1, 256, 4, True, False, True, False, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 4, 1, 256, 2, True, False, True, False, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 8, 1, 256, 1, True, False, True, False, 0, 0)], + '3072' :[ h_traits('x', 'y', 'xs', 'ys', 1, 3, 1, 128, 8, True, False, True, False, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 3, 1, 256, 4, True, False, True, False, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 6, 1, 256, 2, True, False, True, False, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 3, 1,1024, 1, True, False, True, False, 0, 0)], + '4096' :[ h_traits('x', 'y', 'xs', 'ys', 1, 2, 1, 256, 8, True, False, True, False, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 4, 1, 256, 4, True, False, True, False, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 2, 1,1024, 2, True, False, True, False, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 4, 1,1024, 1, True, False, True, False, 0, 0)], + '6144' :[ h_traits('x', 'y', 'xs', 'ys', 1, 3, 1, 256, 8, True, False, True, False, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 3, 1, 512, 4, True, False, True, False, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 3, 1,1024, 2, True, False, True, False, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 6, 1,1024, 1, True, False, True, False, 0, 0)], + '8192' :[ h_traits('x', 'y', 'xs', 'ys', 1, 4, 1, 256, 8, True, False, True, False, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 4, 1, 512, 4, True, False, True, False, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 4, 1,1024, 2, True, False, True, False, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 8, 1,1024, 1, True, False, True, False, 0, 0)], + 'big' :[ h_traits('x', 'y', 'xs', 'ys', 1, 2, 1, 256, 8, True, False, True, True, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 4, 1, 256, 4, True, False, True, True, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 2, 1,1024, 2, True, False, True, True, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 4, 1,1024, 1, True, False, True, True, 0, 0)]} total_blob = list() for hs_key in h_trait_dict: hs = h_trait_dict[hs_key] diff --git a/include/ck_tile/ops/common/generic_2d_block_shape.hpp b/include/ck_tile/ops/common/generic_2d_block_shape.hpp index 64ad20c3be..c0bfd93198 100644 --- a/include/ck_tile/ops/common/generic_2d_block_shape.hpp +++ b/include/ck_tile/ops/common/generic_2d_block_shape.hpp @@ -38,9 +38,7 @@ namespace ck_tile { template typename WarpPerBlock_, // num warps along seq typename WarpTile_, // warp size, seq - typename Vector_, // contiguous pixels(vector size) along seq - index_t BlockSize_ = - warpSize* reduce_on_sequence(WarpPerBlock_{}, multiplies{}, number<1>{})> + typename Vector_> // contiguous pixels(vector size) along seq)> struct Generic2dBlockShape { // block size @@ -68,10 +66,12 @@ struct Generic2dBlockShape static_assert(Warp_M % Vector_M == 0); static_assert(Warp_N % Vector_N == 0); // num of threads along seq, within each warp - static constexpr index_t ThreadPerWarp_M = Warp_M / Vector_M; - static constexpr index_t ThreadPerWarp_N = Warp_N / Vector_N; + static constexpr index_t ThreadPerWarp_M = Warp_M / Vector_M; + static constexpr index_t ThreadPerWarp_N = Warp_N / Vector_N; + static constexpr index_t ThreadPerBlock_M = Block_M / Repeat_M / Vector_M; + static constexpr index_t ThreadPerBlock_N = Block_N / Repeat_N / Vector_N; - static constexpr index_t BlockSize = BlockSize_; + static constexpr index_t BlockSize = ThreadPerBlock_M * ThreadPerBlock_N; }; } // namespace ck_tile diff --git a/include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_default_policy.hpp b/include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_default_policy.hpp index 1de230c144..724f6261d5 100644 --- a/include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_default_policy.hpp +++ b/include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_default_policy.hpp @@ -47,7 +47,8 @@ struct Layernorm2dFwdPipelineDefaultPolicy { using P_ = BlockWelfordProblem; + typename Problem::BlockShape, + Problem::Traits::kFastFDiv>; return BlockWelford{}; } @@ -57,7 +58,8 @@ struct Layernorm2dFwdPipelineDefaultPolicy { using P_ = BlockWelfordProblem; + typename Problem::BlockShape, + Problem::Traits::kFastFDiv>; return BlockWelfordSync{}; } @@ -67,7 +69,8 @@ struct Layernorm2dFwdPipelineDefaultPolicy { using P_ = BlockWelfordProblem; + typename Problem::BlockShape, + Problem::Traits::kFastFDiv>; return BlockWelfordCrossWarpSync{}; } @@ -79,7 +82,8 @@ struct Layernorm2dFwdPipelineDefaultPolicy { using P_ = BlockWelfordProblem; + typename Problem::BlockShape, + Problem::Traits::kFastFDiv>; using block_welford = BlockWelford; using x_block_tile = diff --git a/include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_one_pass.hpp b/include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_one_pass.hpp index 83cdab428e..4b83ed4fbf 100644 --- a/include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_one_pass.hpp +++ b/include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_one_pass.hpp @@ -36,6 +36,7 @@ struct Layernorm2dFwdPipelineOnePass static constexpr bool kNeedCrossWarpSync = Problem::kNeedCrossWarpSync; static constexpr bool kPadM = false; // TODO - BlockLayernorm2dFwdProblem::kPadM static constexpr bool kPadN = Problem::Traits::kPadN; + static constexpr bool kFastFDiv = Problem::Traits::kFastFDiv; static constexpr auto kFusedAdd = Problem::Traits::kFusedAdd; static constexpr auto kFusedQuant = Problem::Traits::kFusedQuant; @@ -125,7 +126,15 @@ struct Layernorm2dFwdPipelineOnePass // compute inv-std auto inv_std = tile_elementwise_in( [&](const auto& v_) { - return type_convert(1.0f) / (sqrt(v_ + epsilon)); + if(kFastFDiv && std::is_same_v) + { + return type_convert(1.0f) * + __builtin_amdgcn_rcpf(sqrt(v_ + epsilon)); + } + else + { + return type_convert(1.0f) / sqrt(v_ + epsilon); + } }, var); diff --git a/include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_traits.hpp b/include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_traits.hpp index ed9e18be30..e8c22f8ab5 100644 --- a/include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_traits.hpp +++ b/include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_traits.hpp @@ -39,6 +39,7 @@ template<> struct Layernorm2dFusedQuantEnumName @@ -46,6 +47,7 @@ struct Layernorm2dFwdTraits { static constexpr bool kPadN = kPadN_; static constexpr bool kSaveMeanInvStd = kSaveMeanInvStd_; + static constexpr bool kFastFDiv = kFastFDiv_; static constexpr bool kTwoPass = kTwoPass_; static constexpr Layernorm2dFusedAddEnum kFusedAdd = kFusedAdd_; static constexpr Layernorm2dFusedQuantEnum kFusedQuant = kFusedQuant_; diff --git a/include/ck_tile/ops/welford/block/block_welford.hpp b/include/ck_tile/ops/welford/block/block_welford.hpp index ce73c183e1..968895e38e 100644 --- a/include/ck_tile/ops/welford/block/block_welford.hpp +++ b/include/ck_tile/ops/welford/block/block_welford.hpp @@ -11,9 +11,10 @@ namespace ck_tile { template struct BlockWelford { - using Problem = remove_cvref_t; - using XDataType = typename Problem::XDataType; - using ComputeDataType = typename Problem::ComputeDataType; + using Problem = remove_cvref_t; + using XDataType = typename Problem::XDataType; + using ComputeDataType = typename Problem::ComputeDataType; + static constexpr bool kFastFDiv = Problem::kFastFDiv; CK_TILE_DEVICE constexpr BlockWelford() {} @@ -89,7 +90,8 @@ struct BlockWelford template struct BlockWelfordSync { - using Problem = remove_cvref_t; + using Problem = remove_cvref_t; + static constexpr bool kFastFDiv = Problem::kFastFDiv; template CK_TILE_DEVICE void @@ -173,8 +175,9 @@ struct BlockWelfordSync template struct BlockWelfordCrossWarpSync { - using Problem = remove_cvref_t; - using BlockShape = typename Problem::BlockShape; + using Problem = remove_cvref_t; + using BlockShape = typename Problem::BlockShape; + static constexpr bool kFastFDiv = Problem::kFastFDiv; template CK_TILE_DEVICE static constexpr index_t GetReduceWarps() @@ -351,12 +354,23 @@ CK_TILE_DEVICE constexpr index_t block_tile_welford_calculate_max_count(int row_ } // Note: this function must be called after all the computation -template +template CK_TILE_DEVICE constexpr void block_tile_welford_post_scale_var(VarDistributedTensor_& var_tensor, - int count) + int count, + bool_constant = {}) { using DataType = typename VarDistributedTensor_::DataType; - tile_elementwise_inout([&count](auto& x) { x = x / type_convert(count); }, - var_tensor); + tile_elementwise_inout( + [&count](auto& x) { + if(FastFdiv_ && std::is_same_v) + { + x = x * __builtin_amdgcn_rcpf(type_convert(count)); + } + else + { + x = x / type_convert(count); + } + }, + var_tensor); } } // namespace ck_tile diff --git a/include/ck_tile/ops/welford/block/block_welford_problem.hpp b/include/ck_tile/ops/welford/block/block_welford_problem.hpp index dcae1ef2ee..bcbfb7d76e 100644 --- a/include/ck_tile/ops/welford/block/block_welford_problem.hpp +++ b/include/ck_tile/ops/welford/block/block_welford_problem.hpp @@ -7,12 +7,13 @@ namespace ck_tile { -template +template struct BlockWelfordProblem { - using XDataType = remove_cvref_t; - using ComputeDataType = remove_cvref_t; - using BlockShape = remove_cvref_t; + using XDataType = remove_cvref_t; + using ComputeDataType = remove_cvref_t; + using BlockShape = remove_cvref_t; + static constexpr bool kFastFDiv = kFastFDiv_; }; } // namespace ck_tile diff --git a/include/ck_tile/ops/welford/thread/thread_welford.hpp b/include/ck_tile/ops/welford/thread/thread_welford.hpp index 4c61cdcf4b..52b253e5f7 100644 --- a/include/ck_tile/ops/welford/thread/thread_welford.hpp +++ b/include/ck_tile/ops/welford/thread/thread_welford.hpp @@ -7,25 +7,46 @@ namespace ck_tile { -template -CK_TILE_DEVICE void welford_update(T& mean, T& var, T x, int count) +template +CK_TILE_DEVICE void welford_update(T& mean, T& var, T x, int count, bool_constant = {}) { // TODO: check nan? maybe no T delta = x - mean; - mean += delta / count; + if(kFastFDiv && std::is_same_v) + { + mean += delta * __builtin_amdgcn_rcpf(count); + } + else + { + mean += delta / count; + } T delta2 = x - mean; var += delta * delta2; } -template -CK_TILE_DEVICE static void -welford_merge(T& mean_a, T& var_a, int& count_a, T mean_b, T var_b, int count_b) +template +CK_TILE_DEVICE static void welford_merge(T& mean_a, + T& var_a, + int& count_a, + T mean_b, + T var_b, + int count_b, + bool_constant = {}) { - int count = count_a + count_b; - T count_ = type_convert(count); - T count_a_ = type_convert(count_a); - T count_b_ = type_convert(count_b); - T count_b_over_count = count == 0 ? type_convert(0) : count_b_ / count_; + int count = count_a + count_b; + T count_ = type_convert(count); + T count_a_ = type_convert(count_a); + T count_b_ = type_convert(count_b); + T count_b_over_count; + if(kFastFDiv && std::is_same_v) + { + count_b_over_count = + count == 0 ? type_convert(0) : count_b_ * __builtin_amdgcn_rcpf(count_); + } + else + { + count_b_over_count = count == 0 ? type_convert(0) : count_b_ / count_; + } T delta = mean_b - mean_a; mean_a += delta * count_b_over_count;