Revert "Optimize fmha fwd decode & prefill for gfx950 (#2641)" (#2670)

This reverts commit b7322a521a.
This commit is contained in:
asleepzzz
2025-08-12 20:27:10 +08:00
committed by GitHub
parent b7322a521a
commit 5b39de4bb6
31 changed files with 639 additions and 3545 deletions

View File

@@ -39,12 +39,6 @@ enum struct TailNumber
Full,
};
enum struct GemmLoopOrder
{
KMN,
MNK,
};
} // namespace ck_tile
inline std::ostream& operator<<(std::ostream& os, const ck_tile::GemmPipelineScheduler& s)

View File

@@ -14,11 +14,10 @@ template <typename ADataType_,
typename CDataType_,
typename BlockGemmShape_,
typename Traits_,
typename ComputeDataType_ = ADataType_,
bool FixedVectorSize_ = false,
index_t VectorSizeA_ = 1,
index_t VectorSizeB_ = 1,
GemmLoopOrder BlockGemmLoopOrder_ = GemmLoopOrder::KMN>
typename ComputeDataType_ = ADataType_,
bool FixedVectorSize_ = false,
index_t VectorSizeA_ = 1,
index_t VectorSizeB_ = 1>
struct GemmPipelineProblemBase
{
using Traits = remove_cvref_t<Traits_>;
@@ -46,10 +45,9 @@ struct GemmPipelineProblemBase
static constexpr bool kPadN = Traits::kPadN;
static constexpr bool kPadK = Traits::kPadK;
static constexpr bool DoubleSmemBuffer = Traits::DoubleSmemBuffer;
static constexpr auto Scheduler = GemmPipelineScheduler::Default;
static constexpr index_t VectorLoadSize = Traits::_VectorSize;
static constexpr GemmLoopOrder BlockGemmLoopOrder = BlockGemmLoopOrder_;
static constexpr bool DoubleSmemBuffer = Traits::DoubleSmemBuffer;
static constexpr auto Scheduler = GemmPipelineScheduler::Default;
static constexpr index_t VectorLoadSize = Traits::_VectorSize;
// In the base situation, the Preshuffle setting should be false.
static constexpr bool Preshuffle = false;
@@ -169,11 +167,10 @@ template <typename ADataType_,
typename CDataType_,
typename BlockGemmShape_,
typename Traits_,
typename ComputeDataType_ = ADataType_,
bool FixedVectorSize_ = false,
index_t VectorSizeA_ = 1,
index_t VectorSizeB_ = 1,
GemmLoopOrder BlockGemmLoopOrder_ = GemmLoopOrder::KMN>
typename ComputeDataType_ = ADataType_,
bool FixedVectorSize_ = false,
index_t VectorSizeA_ = 1,
index_t VectorSizeB_ = 1>
using GemmPipelineProblem = GemmPipelineProblemBase<ADataType_,
BDataType_,
CDataType_,
@@ -182,22 +179,20 @@ using GemmPipelineProblem = GemmPipelineProblemBase<ADataType_,
ComputeDataType_,
FixedVectorSize_,
VectorSizeA_,
VectorSizeB_,
BlockGemmLoopOrder_>;
VectorSizeB_>;
template <typename ADataType_,
typename BDataType_,
typename CDataType_,
typename BlockGemmShape_,
typename Traits_,
GemmPipelineScheduler Scheduler_ = GemmPipelineScheduler::Intrawave,
bool HasHotLoop_ = true,
TailNumber TailNum_ = TailNumber::Full,
typename ComputeDataType_ = ADataType_,
bool FixedVectorSize_ = false,
index_t VectorSizeA_ = 1,
index_t VectorSizeB_ = 1,
GemmLoopOrder BlockGemmLoopOrder_ = GemmLoopOrder::KMN>
GemmPipelineScheduler Scheduler_ = GemmPipelineScheduler::Intrawave,
bool HasHotLoop_ = true,
TailNumber TailNum_ = TailNumber::Full,
typename ComputeDataType_ = ADataType_,
bool FixedVectorSize_ = false,
index_t VectorSizeA_ = 1,
index_t VectorSizeB_ = 1>
struct UniversalGemmPipelineProblem
{
using Traits = remove_cvref_t<Traits_>;
@@ -229,9 +224,8 @@ struct UniversalGemmPipelineProblem
static constexpr auto Scheduler = Scheduler_;
static constexpr bool Preshuffle = Traits::Preshuffle;
static constexpr index_t VectorSizeA = VectorSizeA_;
static constexpr index_t VectorSizeB = VectorSizeB_;
static constexpr GemmLoopOrder BlockGemmLoopOrder = BlockGemmLoopOrder_;
static constexpr index_t VectorSizeA = VectorSizeA_;
static constexpr index_t VectorSizeB = VectorSizeB_;
static constexpr auto HasHotLoop = HasHotLoop_;
static constexpr auto TailNum = TailNum_;