Rename two pass to three pass

This commit is contained in:
rocking
2024-10-26 20:29:55 +00:00
parent 697558d856
commit 0f9969a894
5 changed files with 14 additions and 14 deletions

View File

@@ -8,5 +8,5 @@
#include "ck_tile/ops/add_rmsnorm2d_rdquant/pipeline/add_rmsnorm2d_rdquant_fwd_pipeline_default_policy.hpp"
#include "ck_tile/ops/add_rmsnorm2d_rdquant/pipeline/add_rmsnorm2d_rdquant_fwd_pipeline_one_pass.hpp"
#include "ck_tile/ops/add_rmsnorm2d_rdquant/pipeline/add_rmsnorm2d_rdquant_fwd_pipeline_problem.hpp"
#include "ck_tile/ops/add_rmsnorm2d_rdquant/pipeline/add_rmsnorm2d_rdquant_fwd_pipeline_two_pass.hpp"
#include "ck_tile/ops/add_rmsnorm2d_rdquant/pipeline/add_rmsnorm2d_rdquant_fwd_pipeline_three_pass.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"

View File

@@ -47,7 +47,7 @@ struct AddRmsnorm2dRdquantFwd
static constexpr index_t Block_N = Problem::BlockShape::Block_N;
static constexpr bool kPadM = false; // always no need to pad along M
static constexpr bool kPadN = Problem::kPadN;
static constexpr bool kTwoPass = Problem::kTwoPass;
static constexpr bool kThreePass = Problem::kThreePass;
static constexpr index_t ThreadPerWarp_N = Problem::BlockShape::ThreadPerWarp_N;
static constexpr index_t Vector_N = Problem::BlockShape::Vector_N;
@@ -115,7 +115,7 @@ struct AddRmsnorm2dRdquantFwd
std::string n;
if (kPadN) n += "_pn";
if (kSaveX) n += "_x";
if (kTwoPass) n += "_2p";
if (kThreePass) n += "_2p";
return n; }();
#define _SS_ std::string

View File

@@ -18,7 +18,7 @@ template <typename ADataType_,
typename BlockShape_,
bool kPadN_,
bool kSaveX_,
bool kTwoPass_>
bool kThreePass_>
struct AddRmsnorm2dRdquantFwdPipelineProblem
{
using ADataType = remove_cvref_t<ADataType_>;
@@ -33,9 +33,9 @@ struct AddRmsnorm2dRdquantFwdPipelineProblem
static constexpr bool kNeedCrossLaneSync = BlockShape::ThreadPerWarp_N > 1;
static constexpr bool kNeedCrossWarpSync = BlockShape::WarpPerBlock_N > 1;
static constexpr bool kPadN = kPadN_;
static constexpr bool kSaveX = kSaveX_;
static constexpr bool kTwoPass = kTwoPass_;
static constexpr bool kPadN = kPadN_;
static constexpr bool kSaveX = kSaveX_;
static constexpr bool kThreePass = kThreePass_;
};
} // namespace ck_tile

View File

@@ -11,7 +11,7 @@
namespace ck_tile {
template <typename Problem_, typename Policy_ = AddRmsnorm2dRdquantFwdPipelineDefaultPolicy>
struct AddRmsnorm2dRdquantFwdPipelineTwoPass
struct AddRmsnorm2dRdquantFwdPipelineThreePass
{
using Problem = ck_tile::remove_cvref_t<Problem_>;
using Policy = ck_tile::remove_cvref_t<Policy_>;