mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-06 07:51:52 +00:00
Rename two pass to three pass
This commit is contained in:
@@ -97,7 +97,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
b_buf.ToDevice(b_host.data());
|
||||
gamma_buf.ToDevice(gamma_host.data());
|
||||
|
||||
constexpr bool kTwoPass = false;
|
||||
constexpr bool kThreePass = true;
|
||||
|
||||
using BlockWarps = ck_tile::sequence<2, 2>;
|
||||
using BlockTile = ck_tile::sequence<2, 128>;
|
||||
@@ -115,12 +115,12 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
Shape,
|
||||
true, // kPadN
|
||||
true, // kSaveX
|
||||
kTwoPass>;
|
||||
kThreePass>;
|
||||
|
||||
using OnePassPipeline = ck_tile::AddRmsnorm2dRdquantFwdPipelineOnePass<Problem>;
|
||||
using TwoPassPipeline = ck_tile::AddRmsnorm2dRdquantFwdPipelineTwoPass<Problem>;
|
||||
using Pipeline = std::conditional_t<kTwoPass, TwoPassPipeline, OnePassPipeline>;
|
||||
using Kernel = ck_tile::AddRmsnorm2dRdquantFwd<Pipeline>;
|
||||
using OnePassPipeline = ck_tile::AddRmsnorm2dRdquantFwdPipelineOnePass<Problem>;
|
||||
using ThreePassPipeline = ck_tile::AddRmsnorm2dRdquantFwdPipelineThreePass<Problem>;
|
||||
using Pipeline = std::conditional_t<kThreePass, ThreePassPipeline, OnePassPipeline>;
|
||||
using Kernel = ck_tile::AddRmsnorm2dRdquantFwd<Pipeline>;
|
||||
|
||||
ck_tile::AddRmsnorm2dRdquantFwdHostArgs args{a_buf.GetDeviceBuffer(),
|
||||
b_buf.GetDeviceBuffer(),
|
||||
|
||||
@@ -8,5 +8,5 @@
|
||||
#include "ck_tile/ops/add_rmsnorm2d_rdquant/pipeline/add_rmsnorm2d_rdquant_fwd_pipeline_default_policy.hpp"
|
||||
#include "ck_tile/ops/add_rmsnorm2d_rdquant/pipeline/add_rmsnorm2d_rdquant_fwd_pipeline_one_pass.hpp"
|
||||
#include "ck_tile/ops/add_rmsnorm2d_rdquant/pipeline/add_rmsnorm2d_rdquant_fwd_pipeline_problem.hpp"
|
||||
#include "ck_tile/ops/add_rmsnorm2d_rdquant/pipeline/add_rmsnorm2d_rdquant_fwd_pipeline_two_pass.hpp"
|
||||
#include "ck_tile/ops/add_rmsnorm2d_rdquant/pipeline/add_rmsnorm2d_rdquant_fwd_pipeline_three_pass.hpp"
|
||||
#include "ck_tile/ops/common/tensor_layout.hpp"
|
||||
|
||||
@@ -47,7 +47,7 @@ struct AddRmsnorm2dRdquantFwd
|
||||
static constexpr index_t Block_N = Problem::BlockShape::Block_N;
|
||||
static constexpr bool kPadM = false; // always no need to pad along M
|
||||
static constexpr bool kPadN = Problem::kPadN;
|
||||
static constexpr bool kTwoPass = Problem::kTwoPass;
|
||||
static constexpr bool kThreePass = Problem::kThreePass;
|
||||
|
||||
static constexpr index_t ThreadPerWarp_N = Problem::BlockShape::ThreadPerWarp_N;
|
||||
static constexpr index_t Vector_N = Problem::BlockShape::Vector_N;
|
||||
@@ -115,7 +115,7 @@ struct AddRmsnorm2dRdquantFwd
|
||||
std::string n;
|
||||
if (kPadN) n += "_pn";
|
||||
if (kSaveX) n += "_x";
|
||||
if (kTwoPass) n += "_2p";
|
||||
if (kThreePass) n += "_2p";
|
||||
return n; }();
|
||||
|
||||
#define _SS_ std::string
|
||||
|
||||
@@ -18,7 +18,7 @@ template <typename ADataType_,
|
||||
typename BlockShape_,
|
||||
bool kPadN_,
|
||||
bool kSaveX_,
|
||||
bool kTwoPass_>
|
||||
bool kThreePass_>
|
||||
struct AddRmsnorm2dRdquantFwdPipelineProblem
|
||||
{
|
||||
using ADataType = remove_cvref_t<ADataType_>;
|
||||
@@ -33,9 +33,9 @@ struct AddRmsnorm2dRdquantFwdPipelineProblem
|
||||
static constexpr bool kNeedCrossLaneSync = BlockShape::ThreadPerWarp_N > 1;
|
||||
static constexpr bool kNeedCrossWarpSync = BlockShape::WarpPerBlock_N > 1;
|
||||
|
||||
static constexpr bool kPadN = kPadN_;
|
||||
static constexpr bool kSaveX = kSaveX_;
|
||||
static constexpr bool kTwoPass = kTwoPass_;
|
||||
static constexpr bool kPadN = kPadN_;
|
||||
static constexpr bool kSaveX = kSaveX_;
|
||||
static constexpr bool kThreePass = kThreePass_;
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -11,7 +11,7 @@
|
||||
namespace ck_tile {
|
||||
|
||||
template <typename Problem_, typename Policy_ = AddRmsnorm2dRdquantFwdPipelineDefaultPolicy>
|
||||
struct AddRmsnorm2dRdquantFwdPipelineTwoPass
|
||||
struct AddRmsnorm2dRdquantFwdPipelineThreePass
|
||||
{
|
||||
using Problem = ck_tile::remove_cvref_t<Problem_>;
|
||||
using Policy = ck_tile::remove_cvref_t<Policy_>;
|
||||
Reference in New Issue
Block a user