mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-07 08:15:04 +00:00
Rename two pass to three pass
This commit is contained in:
@@ -8,5 +8,5 @@
|
||||
#include "ck_tile/ops/add_rmsnorm2d_rdquant/pipeline/add_rmsnorm2d_rdquant_fwd_pipeline_default_policy.hpp"
|
||||
#include "ck_tile/ops/add_rmsnorm2d_rdquant/pipeline/add_rmsnorm2d_rdquant_fwd_pipeline_one_pass.hpp"
|
||||
#include "ck_tile/ops/add_rmsnorm2d_rdquant/pipeline/add_rmsnorm2d_rdquant_fwd_pipeline_problem.hpp"
|
||||
#include "ck_tile/ops/add_rmsnorm2d_rdquant/pipeline/add_rmsnorm2d_rdquant_fwd_pipeline_two_pass.hpp"
|
||||
#include "ck_tile/ops/add_rmsnorm2d_rdquant/pipeline/add_rmsnorm2d_rdquant_fwd_pipeline_three_pass.hpp"
|
||||
#include "ck_tile/ops/common/tensor_layout.hpp"
|
||||
|
||||
@@ -47,7 +47,7 @@ struct AddRmsnorm2dRdquantFwd
|
||||
static constexpr index_t Block_N = Problem::BlockShape::Block_N;
|
||||
static constexpr bool kPadM = false; // always no need to pad along M
|
||||
static constexpr bool kPadN = Problem::kPadN;
|
||||
static constexpr bool kTwoPass = Problem::kTwoPass;
|
||||
static constexpr bool kThreePass = Problem::kThreePass;
|
||||
|
||||
static constexpr index_t ThreadPerWarp_N = Problem::BlockShape::ThreadPerWarp_N;
|
||||
static constexpr index_t Vector_N = Problem::BlockShape::Vector_N;
|
||||
@@ -115,7 +115,7 @@ struct AddRmsnorm2dRdquantFwd
|
||||
std::string n;
|
||||
if (kPadN) n += "_pn";
|
||||
if (kSaveX) n += "_x";
|
||||
if (kTwoPass) n += "_2p";
|
||||
if (kThreePass) n += "_2p";
|
||||
return n; }();
|
||||
|
||||
#define _SS_ std::string
|
||||
|
||||
@@ -18,7 +18,7 @@ template <typename ADataType_,
|
||||
typename BlockShape_,
|
||||
bool kPadN_,
|
||||
bool kSaveX_,
|
||||
bool kTwoPass_>
|
||||
bool kThreePass_>
|
||||
struct AddRmsnorm2dRdquantFwdPipelineProblem
|
||||
{
|
||||
using ADataType = remove_cvref_t<ADataType_>;
|
||||
@@ -33,9 +33,9 @@ struct AddRmsnorm2dRdquantFwdPipelineProblem
|
||||
static constexpr bool kNeedCrossLaneSync = BlockShape::ThreadPerWarp_N > 1;
|
||||
static constexpr bool kNeedCrossWarpSync = BlockShape::WarpPerBlock_N > 1;
|
||||
|
||||
static constexpr bool kPadN = kPadN_;
|
||||
static constexpr bool kSaveX = kSaveX_;
|
||||
static constexpr bool kTwoPass = kTwoPass_;
|
||||
static constexpr bool kPadN = kPadN_;
|
||||
static constexpr bool kSaveX = kSaveX_;
|
||||
static constexpr bool kThreePass = kThreePass_;
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -11,7 +11,7 @@
|
||||
namespace ck_tile {
|
||||
|
||||
template <typename Problem_, typename Policy_ = AddRmsnorm2dRdquantFwdPipelineDefaultPolicy>
|
||||
struct AddRmsnorm2dRdquantFwdPipelineTwoPass
|
||||
struct AddRmsnorm2dRdquantFwdPipelineThreePass
|
||||
{
|
||||
using Problem = ck_tile::remove_cvref_t<Problem_>;
|
||||
using Policy = ck_tile::remove_cvref_t<Policy_>;
|
||||
Reference in New Issue
Block a user