From 0765bd5201d0f87e1ee01470010c30d5a55b1e54 Mon Sep 17 00:00:00 2001 From: feli Date: Thu, 14 Nov 2024 14:06:36 +0800 Subject: [PATCH] [Ck_tile] hot fix, fix rpcf param setting err (#1657) Co-authored-by: dummycoderfe [ROCm/composable_kernel commit: c1f8d53ce83c6ca6d15fec8d987974bc05008c16] --- .../pipeline/layernorm2d_fwd_pipeline_one_pass.hpp | 2 +- .../pipeline/layernorm2d_fwd_pipeline_two_pass.hpp | 14 +++++++++++--- .../ck_tile/ops/welford/block/block_welford.hpp | 13 +++++++++---- 3 files changed, 21 insertions(+), 8 deletions(-) diff --git a/include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_one_pass.hpp b/include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_one_pass.hpp index 4b83ed4fbf..eefdaf9176 100644 --- a/include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_one_pass.hpp +++ b/include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_one_pass.hpp @@ -121,7 +121,7 @@ struct Layernorm2dFwdPipelineOnePass auto [mean, var] = block_welford(acc, cur_count, max_count); block_welford_sync(mean, var, cur_count); block_welford_cross_warp_sync(mean, var, cur_count, smem); - block_tile_welford_post_scale_var(var, cur_count); + block_tile_welford_post_scale_var(var, cur_count, constant{}); // compute inv-std auto inv_std = tile_elementwise_in( diff --git a/include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_two_pass.hpp b/include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_two_pass.hpp index fadf56dfd3..6a86cc43c9 100644 --- a/include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_two_pass.hpp +++ b/include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_two_pass.hpp @@ -35,6 +35,7 @@ struct Layernorm2dFwdPipelineTwoPass static constexpr bool kNeedCrossWarpSync = Problem::kNeedCrossWarpSync; static constexpr bool kPadM = false; // TODO - BlockLayernorm2dFwdProblem::kPadM static constexpr bool kPadN = Problem::Traits::kPadN; + static constexpr bool kFastFDiv = Problem::Traits::kFastFDiv; static constexpr auto kFusedAdd = Problem::Traits::kFusedAdd; static constexpr auto kFusedQuant = Problem::Traits::kFusedQuant; @@ -137,15 +138,22 @@ struct Layernorm2dFwdPipelineTwoPass block_welford_sync(mean, var, cur_count); block_welford_cross_warp_sync(mean, var, cur_count, smem); - block_tile_welford_post_scale_var(var, cur_count); + block_tile_welford_post_scale_var(var, cur_count, constant{}); // compute inv-std auto inv_std = tile_elementwise_in( [&](const auto& v_) { - return type_convert(1.0f) / (sqrt(v_ + epsilon)); + if(kFastFDiv && std::is_same_v) + { + return type_convert(1.0f) * + __builtin_amdgcn_rcpf(sqrt(v_ + epsilon)); + } + else + { + return type_convert(1.0f) / sqrt(v_ + epsilon); + } }, var); - if constexpr(kSaveMean) store_tile(mean_window, cast_tile(mean)); if constexpr(kSaveInvStd) diff --git a/include/ck_tile/ops/welford/block/block_welford.hpp b/include/ck_tile/ops/welford/block/block_welford.hpp index 968895e38e..56ca86d9df 100644 --- a/include/ck_tile/ops/welford/block/block_welford.hpp +++ b/include/ck_tile/ops/welford/block/block_welford.hpp @@ -47,8 +47,11 @@ struct BlockWelford auto x = ck_tile::type_convert(x_tensor[in_dstr_idx]); - welford_update( - mean_tensor(out_dstr_idx), var_tensor(out_dstr_idx), x, cur_count_); + welford_update(mean_tensor(out_dstr_idx), + var_tensor(out_dstr_idx), + x, + cur_count_, + constant{}); }); } }); @@ -159,7 +162,8 @@ struct BlockWelfordSync v_local_count, v_remote_mean, v_remote_var, - v_remote_count); + v_remote_count, + constant{}); }); } }); @@ -307,7 +311,8 @@ struct BlockWelfordCrossWarpSync v_local_count, v_remote_mean, v_remote_var, - v_remote_count); + v_remote_count, + constant{}); }); mean_tensor.get_thread_buffer()(i_0) = v_local_mean;