Rename two pass to three pass

This commit is contained in:
rocking
2024-10-26 20:29:55 +00:00
parent 697558d856
commit 0f9969a894
5 changed files with 14 additions and 14 deletions

View File

@@ -97,7 +97,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
b_buf.ToDevice(b_host.data());
gamma_buf.ToDevice(gamma_host.data());
constexpr bool kTwoPass = false;
constexpr bool kThreePass = true;
using BlockWarps = ck_tile::sequence<2, 2>;
using BlockTile = ck_tile::sequence<2, 128>;
@@ -115,12 +115,12 @@ bool run(const ck_tile::ArgParser& arg_parser)
Shape,
true, // kPadN
true, // kSaveX
kTwoPass>;
kThreePass>;
using OnePassPipeline = ck_tile::AddRmsnorm2dRdquantFwdPipelineOnePass<Problem>;
using TwoPassPipeline = ck_tile::AddRmsnorm2dRdquantFwdPipelineTwoPass<Problem>;
using Pipeline = std::conditional_t<kTwoPass, TwoPassPipeline, OnePassPipeline>;
using Kernel = ck_tile::AddRmsnorm2dRdquantFwd<Pipeline>;
using OnePassPipeline = ck_tile::AddRmsnorm2dRdquantFwdPipelineOnePass<Problem>;
using ThreePassPipeline = ck_tile::AddRmsnorm2dRdquantFwdPipelineThreePass<Problem>;
using Pipeline = std::conditional_t<kThreePass, ThreePassPipeline, OnePassPipeline>;
using Kernel = ck_tile::AddRmsnorm2dRdquantFwd<Pipeline>;
ck_tile::AddRmsnorm2dRdquantFwdHostArgs args{a_buf.GetDeviceBuffer(),
b_buf.GetDeviceBuffer(),