mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-07 00:04:37 +00:00
Rename two pass to three pass
This commit is contained in:
@@ -97,7 +97,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
b_buf.ToDevice(b_host.data());
|
||||
gamma_buf.ToDevice(gamma_host.data());
|
||||
|
||||
constexpr bool kTwoPass = false;
|
||||
constexpr bool kThreePass = true;
|
||||
|
||||
using BlockWarps = ck_tile::sequence<2, 2>;
|
||||
using BlockTile = ck_tile::sequence<2, 128>;
|
||||
@@ -115,12 +115,12 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
Shape,
|
||||
true, // kPadN
|
||||
true, // kSaveX
|
||||
kTwoPass>;
|
||||
kThreePass>;
|
||||
|
||||
using OnePassPipeline = ck_tile::AddRmsnorm2dRdquantFwdPipelineOnePass<Problem>;
|
||||
using TwoPassPipeline = ck_tile::AddRmsnorm2dRdquantFwdPipelineTwoPass<Problem>;
|
||||
using Pipeline = std::conditional_t<kTwoPass, TwoPassPipeline, OnePassPipeline>;
|
||||
using Kernel = ck_tile::AddRmsnorm2dRdquantFwd<Pipeline>;
|
||||
using OnePassPipeline = ck_tile::AddRmsnorm2dRdquantFwdPipelineOnePass<Problem>;
|
||||
using ThreePassPipeline = ck_tile::AddRmsnorm2dRdquantFwdPipelineThreePass<Problem>;
|
||||
using Pipeline = std::conditional_t<kThreePass, ThreePassPipeline, OnePassPipeline>;
|
||||
using Kernel = ck_tile::AddRmsnorm2dRdquantFwd<Pipeline>;
|
||||
|
||||
ck_tile::AddRmsnorm2dRdquantFwdHostArgs args{a_buf.GetDeviceBuffer(),
|
||||
b_buf.GetDeviceBuffer(),
|
||||
|
||||
Reference in New Issue
Block a user