Optimize fmha fwd decode & prefill for gfx950 (#2641)

* Fix for fwd/bwd kernel build filter

* fix bwd code

* save an example for __bf16 type

* temp save, waiting for debug

* tempsave, fmha_decode

* temp save, change all instance to 1wave

* fix async copytest bug

* Add block_sync_lds_direct_load utility

* fix the s_waitcnt_imm calculation

* Improve s_waitcnt_imm calculation

* fix vmcnt shift

* add input validation and bug fix

* remove unnecessary output

* move test_copy into test

* temp save

* tempsave

* compile pass

* tempsave, trload+asyncload done

* tempsave. asynccopy+trload sanity checked

* remove unnecessary features

* fix the lds alignment caused performance regression

* enable prefill overload operator().

* remove all lds bankconflict with xor layouts

* enable larger tile size; upgrade xor pattern

* upgrade prefill pipeline; simple iglp; consistent data produce and consume order

* small refactor

* Load Q through lds, implement xor;

* add vmcnt guard before load ktile

* Add v_permlaneb32 for block_reduce. Disable it as it will cause un-coexecutable packed math in FA

* Add XOR fold strategy for hdim<128, but perf dropped; disable it by default; wait further perf debug

* add __restrict__ to tr load

* merge fa_decode pipeline into fmha_fwd api

* remove unnecessary files; rename some files

* Remove unnecessary changes

* bug fix, clang format;

* remove non-necessary change

* fix clangformat with 18.1.3

* fix bugs

* fix bug

* fix bug on non-gfx950

* fix bugs in gemm

* fix bug in pki4

* tempsave, update the blocksync functions

* change the warp setting for hdim32 fmha fwd

* clang format

* fix conflict. disable all v-col instance for fmha fwd

* Fix the bug

* clang format

---------

Co-authored-by: Max Podkorytov <4273004+tenpercent@users.noreply.github.com>
This commit is contained in:
Haocong WANG
2025-08-12 19:43:14 +08:00
committed by GitHub
parent c0c2ded566
commit b7322a521a
31 changed files with 3533 additions and 627 deletions

View File

@@ -39,6 +39,12 @@ enum struct TailNumber
Full,
};
enum struct GemmLoopOrder
{
KMN,
MNK,
};
} // namespace ck_tile
inline std::ostream& operator<<(std::ostream& os, const ck_tile::GemmPipelineScheduler& s)

View File

@@ -14,10 +14,11 @@ template <typename ADataType_,
typename CDataType_,
typename BlockGemmShape_,
typename Traits_,
typename ComputeDataType_ = ADataType_,
bool FixedVectorSize_ = false,
index_t VectorSizeA_ = 1,
index_t VectorSizeB_ = 1>
typename ComputeDataType_ = ADataType_,
bool FixedVectorSize_ = false,
index_t VectorSizeA_ = 1,
index_t VectorSizeB_ = 1,
GemmLoopOrder BlockGemmLoopOrder_ = GemmLoopOrder::KMN>
struct GemmPipelineProblemBase
{
using Traits = remove_cvref_t<Traits_>;
@@ -45,9 +46,10 @@ struct GemmPipelineProblemBase
static constexpr bool kPadN = Traits::kPadN;
static constexpr bool kPadK = Traits::kPadK;
static constexpr bool DoubleSmemBuffer = Traits::DoubleSmemBuffer;
static constexpr auto Scheduler = GemmPipelineScheduler::Default;
static constexpr index_t VectorLoadSize = Traits::_VectorSize;
static constexpr bool DoubleSmemBuffer = Traits::DoubleSmemBuffer;
static constexpr auto Scheduler = GemmPipelineScheduler::Default;
static constexpr index_t VectorLoadSize = Traits::_VectorSize;
static constexpr GemmLoopOrder BlockGemmLoopOrder = BlockGemmLoopOrder_;
// In the base situation, the Preshuffle setting should be false.
static constexpr bool Preshuffle = false;
@@ -167,10 +169,11 @@ template <typename ADataType_,
typename CDataType_,
typename BlockGemmShape_,
typename Traits_,
typename ComputeDataType_ = ADataType_,
bool FixedVectorSize_ = false,
index_t VectorSizeA_ = 1,
index_t VectorSizeB_ = 1>
typename ComputeDataType_ = ADataType_,
bool FixedVectorSize_ = false,
index_t VectorSizeA_ = 1,
index_t VectorSizeB_ = 1,
GemmLoopOrder BlockGemmLoopOrder_ = GemmLoopOrder::KMN>
using GemmPipelineProblem = GemmPipelineProblemBase<ADataType_,
BDataType_,
CDataType_,
@@ -179,20 +182,22 @@ using GemmPipelineProblem = GemmPipelineProblemBase<ADataType_,
ComputeDataType_,
FixedVectorSize_,
VectorSizeA_,
VectorSizeB_>;
VectorSizeB_,
BlockGemmLoopOrder_>;
template <typename ADataType_,
typename BDataType_,
typename CDataType_,
typename BlockGemmShape_,
typename Traits_,
GemmPipelineScheduler Scheduler_ = GemmPipelineScheduler::Intrawave,
bool HasHotLoop_ = true,
TailNumber TailNum_ = TailNumber::Full,
typename ComputeDataType_ = ADataType_,
bool FixedVectorSize_ = false,
index_t VectorSizeA_ = 1,
index_t VectorSizeB_ = 1>
GemmPipelineScheduler Scheduler_ = GemmPipelineScheduler::Intrawave,
bool HasHotLoop_ = true,
TailNumber TailNum_ = TailNumber::Full,
typename ComputeDataType_ = ADataType_,
bool FixedVectorSize_ = false,
index_t VectorSizeA_ = 1,
index_t VectorSizeB_ = 1,
GemmLoopOrder BlockGemmLoopOrder_ = GemmLoopOrder::KMN>
struct UniversalGemmPipelineProblem
{
using Traits = remove_cvref_t<Traits_>;
@@ -224,8 +229,9 @@ struct UniversalGemmPipelineProblem
static constexpr auto Scheduler = Scheduler_;
static constexpr bool Preshuffle = Traits::Preshuffle;
static constexpr index_t VectorSizeA = VectorSizeA_;
static constexpr index_t VectorSizeB = VectorSizeB_;
static constexpr index_t VectorSizeA = VectorSizeA_;
static constexpr index_t VectorSizeB = VectorSizeB_;
static constexpr GemmLoopOrder BlockGemmLoopOrder = BlockGemmLoopOrder_;
static constexpr auto HasHotLoop = HasHotLoop_;
static constexpr auto TailNum = TailNum_;