Rename two pass to three pass

This commit is contained in:
rocking
2024-10-26 20:29:55 +00:00
parent 697558d856
commit 0f9969a894
5 changed files with 14 additions and 14 deletions

View File

@@ -97,7 +97,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
b_buf.ToDevice(b_host.data());
gamma_buf.ToDevice(gamma_host.data());
constexpr bool kTwoPass = false;
constexpr bool kThreePass = true;
using BlockWarps = ck_tile::sequence<2, 2>;
using BlockTile = ck_tile::sequence<2, 128>;
@@ -115,12 +115,12 @@ bool run(const ck_tile::ArgParser& arg_parser)
Shape,
true, // kPadN
true, // kSaveX
kTwoPass>;
kThreePass>;
using OnePassPipeline = ck_tile::AddRmsnorm2dRdquantFwdPipelineOnePass<Problem>;
using TwoPassPipeline = ck_tile::AddRmsnorm2dRdquantFwdPipelineTwoPass<Problem>;
using Pipeline = std::conditional_t<kTwoPass, TwoPassPipeline, OnePassPipeline>;
using Kernel = ck_tile::AddRmsnorm2dRdquantFwd<Pipeline>;
using OnePassPipeline = ck_tile::AddRmsnorm2dRdquantFwdPipelineOnePass<Problem>;
using ThreePassPipeline = ck_tile::AddRmsnorm2dRdquantFwdPipelineThreePass<Problem>;
using Pipeline = std::conditional_t<kThreePass, ThreePassPipeline, OnePassPipeline>;
using Kernel = ck_tile::AddRmsnorm2dRdquantFwd<Pipeline>;
ck_tile::AddRmsnorm2dRdquantFwdHostArgs args{a_buf.GetDeviceBuffer(),
b_buf.GetDeviceBuffer(),

View File

@@ -8,5 +8,5 @@
#include "ck_tile/ops/add_rmsnorm2d_rdquant/pipeline/add_rmsnorm2d_rdquant_fwd_pipeline_default_policy.hpp"
#include "ck_tile/ops/add_rmsnorm2d_rdquant/pipeline/add_rmsnorm2d_rdquant_fwd_pipeline_one_pass.hpp"
#include "ck_tile/ops/add_rmsnorm2d_rdquant/pipeline/add_rmsnorm2d_rdquant_fwd_pipeline_problem.hpp"
#include "ck_tile/ops/add_rmsnorm2d_rdquant/pipeline/add_rmsnorm2d_rdquant_fwd_pipeline_two_pass.hpp"
#include "ck_tile/ops/add_rmsnorm2d_rdquant/pipeline/add_rmsnorm2d_rdquant_fwd_pipeline_three_pass.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"

View File

@@ -47,7 +47,7 @@ struct AddRmsnorm2dRdquantFwd
static constexpr index_t Block_N = Problem::BlockShape::Block_N;
static constexpr bool kPadM = false; // always no need to pad along M
static constexpr bool kPadN = Problem::kPadN;
static constexpr bool kTwoPass = Problem::kTwoPass;
static constexpr bool kThreePass = Problem::kThreePass;
static constexpr index_t ThreadPerWarp_N = Problem::BlockShape::ThreadPerWarp_N;
static constexpr index_t Vector_N = Problem::BlockShape::Vector_N;
@@ -115,7 +115,7 @@ struct AddRmsnorm2dRdquantFwd
std::string n;
if (kPadN) n += "_pn";
if (kSaveX) n += "_x";
if (kTwoPass) n += "_2p";
if (kThreePass) n += "_2p";
return n; }();
#define _SS_ std::string

View File

@@ -18,7 +18,7 @@ template <typename ADataType_,
typename BlockShape_,
bool kPadN_,
bool kSaveX_,
bool kTwoPass_>
bool kThreePass_>
struct AddRmsnorm2dRdquantFwdPipelineProblem
{
using ADataType = remove_cvref_t<ADataType_>;
@@ -33,9 +33,9 @@ struct AddRmsnorm2dRdquantFwdPipelineProblem
static constexpr bool kNeedCrossLaneSync = BlockShape::ThreadPerWarp_N > 1;
static constexpr bool kNeedCrossWarpSync = BlockShape::WarpPerBlock_N > 1;
static constexpr bool kPadN = kPadN_;
static constexpr bool kSaveX = kSaveX_;
static constexpr bool kTwoPass = kTwoPass_;
static constexpr bool kPadN = kPadN_;
static constexpr bool kSaveX = kSaveX_;
static constexpr bool kThreePass = kThreePass_;
};
} // namespace ck_tile

View File

@@ -11,7 +11,7 @@
namespace ck_tile {
template <typename Problem_, typename Policy_ = AddRmsnorm2dRdquantFwdPipelineDefaultPolicy>
struct AddRmsnorm2dRdquantFwdPipelineTwoPass
struct AddRmsnorm2dRdquantFwdPipelineThreePass
{
using Problem = ck_tile::remove_cvref_t<Problem_>;
using Policy = ck_tile::remove_cvref_t<Policy_>;