From 0f9969a8941507a5072bb7d03f3ff3ca45cba458 Mon Sep 17 00:00:00 2001 From: rocking Date: Sat, 26 Oct 2024 20:29:55 +0000 Subject: [PATCH] Rename two pass to three pass --- .../example_add_rmsnorm2d_rdquant_fwd.cpp | 12 ++++++------ include/ck_tile/ops/add_rmsnorm2d_rdquant.hpp | 2 +- .../kernel/add_rmsnorm2d_rdquant_fwd_kernel.hpp | 4 ++-- .../add_rmsnorm2d_rdquant_fwd_pipeline_problem.hpp | 8 ++++---- ...dd_rmsnorm2d_rdquant_fwd_pipeline_three_pass.hpp} | 2 +- 5 files changed, 14 insertions(+), 14 deletions(-) rename include/ck_tile/ops/add_rmsnorm2d_rdquant/pipeline/{add_rmsnorm2d_rdquant_fwd_pipeline_two_pass.hpp => add_rmsnorm2d_rdquant_fwd_pipeline_three_pass.hpp} (99%) diff --git a/example/ck_tile/07_add_rmsnorm2d_rdquant/example_add_rmsnorm2d_rdquant_fwd.cpp b/example/ck_tile/07_add_rmsnorm2d_rdquant/example_add_rmsnorm2d_rdquant_fwd.cpp index ca84f4c5e8..c964b1a18c 100644 --- a/example/ck_tile/07_add_rmsnorm2d_rdquant/example_add_rmsnorm2d_rdquant_fwd.cpp +++ b/example/ck_tile/07_add_rmsnorm2d_rdquant/example_add_rmsnorm2d_rdquant_fwd.cpp @@ -97,7 +97,7 @@ bool run(const ck_tile::ArgParser& arg_parser) b_buf.ToDevice(b_host.data()); gamma_buf.ToDevice(gamma_host.data()); - constexpr bool kTwoPass = false; + constexpr bool kThreePass = true; using BlockWarps = ck_tile::sequence<2, 2>; using BlockTile = ck_tile::sequence<2, 128>; @@ -115,12 +115,12 @@ bool run(const ck_tile::ArgParser& arg_parser) Shape, true, // kPadN true, // kSaveX - kTwoPass>; + kThreePass>; - using OnePassPipeline = ck_tile::AddRmsnorm2dRdquantFwdPipelineOnePass; - using TwoPassPipeline = ck_tile::AddRmsnorm2dRdquantFwdPipelineTwoPass; - using Pipeline = std::conditional_t; - using Kernel = ck_tile::AddRmsnorm2dRdquantFwd; + using OnePassPipeline = ck_tile::AddRmsnorm2dRdquantFwdPipelineOnePass; + using ThreePassPipeline = ck_tile::AddRmsnorm2dRdquantFwdPipelineThreePass; + using Pipeline = std::conditional_t; + using Kernel = ck_tile::AddRmsnorm2dRdquantFwd; ck_tile::AddRmsnorm2dRdquantFwdHostArgs args{a_buf.GetDeviceBuffer(), b_buf.GetDeviceBuffer(), diff --git a/include/ck_tile/ops/add_rmsnorm2d_rdquant.hpp b/include/ck_tile/ops/add_rmsnorm2d_rdquant.hpp index 21c77f2c9c..eb06fea2dd 100644 --- a/include/ck_tile/ops/add_rmsnorm2d_rdquant.hpp +++ b/include/ck_tile/ops/add_rmsnorm2d_rdquant.hpp @@ -8,5 +8,5 @@ #include "ck_tile/ops/add_rmsnorm2d_rdquant/pipeline/add_rmsnorm2d_rdquant_fwd_pipeline_default_policy.hpp" #include "ck_tile/ops/add_rmsnorm2d_rdquant/pipeline/add_rmsnorm2d_rdquant_fwd_pipeline_one_pass.hpp" #include "ck_tile/ops/add_rmsnorm2d_rdquant/pipeline/add_rmsnorm2d_rdquant_fwd_pipeline_problem.hpp" -#include "ck_tile/ops/add_rmsnorm2d_rdquant/pipeline/add_rmsnorm2d_rdquant_fwd_pipeline_two_pass.hpp" +#include "ck_tile/ops/add_rmsnorm2d_rdquant/pipeline/add_rmsnorm2d_rdquant_fwd_pipeline_three_pass.hpp" #include "ck_tile/ops/common/tensor_layout.hpp" diff --git a/include/ck_tile/ops/add_rmsnorm2d_rdquant/kernel/add_rmsnorm2d_rdquant_fwd_kernel.hpp b/include/ck_tile/ops/add_rmsnorm2d_rdquant/kernel/add_rmsnorm2d_rdquant_fwd_kernel.hpp index 29ba6080da..4a0e290352 100644 --- a/include/ck_tile/ops/add_rmsnorm2d_rdquant/kernel/add_rmsnorm2d_rdquant_fwd_kernel.hpp +++ b/include/ck_tile/ops/add_rmsnorm2d_rdquant/kernel/add_rmsnorm2d_rdquant_fwd_kernel.hpp @@ -47,7 +47,7 @@ struct AddRmsnorm2dRdquantFwd static constexpr index_t Block_N = Problem::BlockShape::Block_N; static constexpr bool kPadM = false; // always no need to pad along M static constexpr bool kPadN = Problem::kPadN; - static constexpr bool kTwoPass = Problem::kTwoPass; + static constexpr bool kThreePass = Problem::kThreePass; static constexpr index_t ThreadPerWarp_N = Problem::BlockShape::ThreadPerWarp_N; static constexpr index_t Vector_N = Problem::BlockShape::Vector_N; @@ -115,7 +115,7 @@ struct AddRmsnorm2dRdquantFwd std::string n; if (kPadN) n += "_pn"; if (kSaveX) n += "_x"; - if (kTwoPass) n += "_2p"; + if (kThreePass) n += "_2p"; return n; }(); #define _SS_ std::string diff --git a/include/ck_tile/ops/add_rmsnorm2d_rdquant/pipeline/add_rmsnorm2d_rdquant_fwd_pipeline_problem.hpp b/include/ck_tile/ops/add_rmsnorm2d_rdquant/pipeline/add_rmsnorm2d_rdquant_fwd_pipeline_problem.hpp index 51cba63a08..8939e36bae 100644 --- a/include/ck_tile/ops/add_rmsnorm2d_rdquant/pipeline/add_rmsnorm2d_rdquant_fwd_pipeline_problem.hpp +++ b/include/ck_tile/ops/add_rmsnorm2d_rdquant/pipeline/add_rmsnorm2d_rdquant_fwd_pipeline_problem.hpp @@ -18,7 +18,7 @@ template + bool kThreePass_> struct AddRmsnorm2dRdquantFwdPipelineProblem { using ADataType = remove_cvref_t; @@ -33,9 +33,9 @@ struct AddRmsnorm2dRdquantFwdPipelineProblem static constexpr bool kNeedCrossLaneSync = BlockShape::ThreadPerWarp_N > 1; static constexpr bool kNeedCrossWarpSync = BlockShape::WarpPerBlock_N > 1; - static constexpr bool kPadN = kPadN_; - static constexpr bool kSaveX = kSaveX_; - static constexpr bool kTwoPass = kTwoPass_; + static constexpr bool kPadN = kPadN_; + static constexpr bool kSaveX = kSaveX_; + static constexpr bool kThreePass = kThreePass_; }; } // namespace ck_tile diff --git a/include/ck_tile/ops/add_rmsnorm2d_rdquant/pipeline/add_rmsnorm2d_rdquant_fwd_pipeline_two_pass.hpp b/include/ck_tile/ops/add_rmsnorm2d_rdquant/pipeline/add_rmsnorm2d_rdquant_fwd_pipeline_three_pass.hpp similarity index 99% rename from include/ck_tile/ops/add_rmsnorm2d_rdquant/pipeline/add_rmsnorm2d_rdquant_fwd_pipeline_two_pass.hpp rename to include/ck_tile/ops/add_rmsnorm2d_rdquant/pipeline/add_rmsnorm2d_rdquant_fwd_pipeline_three_pass.hpp index e1d8bfe1e1..0d436b143f 100644 --- a/include/ck_tile/ops/add_rmsnorm2d_rdquant/pipeline/add_rmsnorm2d_rdquant_fwd_pipeline_two_pass.hpp +++ b/include/ck_tile/ops/add_rmsnorm2d_rdquant/pipeline/add_rmsnorm2d_rdquant_fwd_pipeline_three_pass.hpp @@ -11,7 +11,7 @@ namespace ck_tile { template -struct AddRmsnorm2dRdquantFwdPipelineTwoPass +struct AddRmsnorm2dRdquantFwdPipelineThreePass { using Problem = ck_tile::remove_cvref_t; using Policy = ck_tile::remove_cvref_t;