[CK_TILE MOE] add NT & preshuffle permute to cktile MOE (#3377)

* update coherence
---------

Co-authored-by: Zzz9990 <Zzz9990>
This commit is contained in:
Zzz9990
2025-12-10 10:03:28 +08:00
committed by GitHub
parent 934ba1208a
commit 1aa93ef551
8 changed files with 88 additions and 29 deletions

View File

@@ -24,6 +24,18 @@ struct BaseFlatmmPipelineAGmemBGmemCRegV1
return num_loop % 2 == 0 ? TailNumber::Even : TailNumber::Odd;
}
// Chooses the buffer-coherence mode for B-matrix memory accesses from the
// problem dimensions.
//
// Only M takes part in the decision; N and K are accepted for interface
// symmetry and are explicitly discarded below.
//
// Returns WAVE_NT1 when M is at most 416, and coherence_default otherwise.
// NOTE(review): "NT" presumably stands for non-temporal, and 416 looks like
// an empirically tuned cut-off -- confirm with the author.
CK_TILE_HOST static constexpr amd_buffer_coherence_enum
GetBMemNTType(index_t M, index_t N, index_t K)
{
    // Largest M for which the WAVE_NT1 mode is still selected.
    constexpr index_t nt_m_limit = 416;

    // Mark the currently unused dimensions as intentionally ignored.
    ck_tile::ignore = N;
    ck_tile::ignore = K;

    return (M > nt_m_limit) ? ck_tile::amd_buffer_coherence_enum::coherence_default
                            : ck_tile::amd_buffer_coherence_enum::WAVE_NT1;
}
template <bool DispatchHotloop = false, TailNumber tail_num, typename RunFunction>
CK_TILE_HOST_DEVICE static auto TailHandler(const RunFunction& run_func, bool has_hot_loop)
{

View File

@@ -16,10 +16,12 @@ template <typename ADataType_,
typename CDataType_,
typename BlockGemmShape_,
typename Traits_,
GemmPipelineScheduler Scheduler_ = GemmPipelineScheduler::Intrawave,
bool HasHotLoop_ = true,
TailNumber TailNum_ = TailNumber::Full,
typename ComputeDataType_ = ADataType_>
GemmPipelineScheduler Scheduler_ = GemmPipelineScheduler::Intrawave,
bool HasHotLoop_ = true,
TailNumber TailNum_ = TailNumber::Full,
amd_buffer_coherence_enum BMemNTType_ = amd_buffer_coherence_enum::coherence_default,
bool BPreShufflePermute_ = false,
typename ComputeDataType_ = ADataType_>
struct F16xMXF4FlatmmPipelineProblem : FlatmmPipelineProblem<ADataType_,
ADataType_,
CDataType_,
@@ -28,6 +30,8 @@ struct F16xMXF4FlatmmPipelineProblem : FlatmmPipelineProblem<ADataType_,
Scheduler_,
HasHotLoop_,
TailNum_,
BMemNTType_,
BPreShufflePermute_,
ComputeDataType_>
{
using BlockGemmShape = BlockGemmShape_;
@@ -183,6 +187,9 @@ struct F16xMXF4FlatmmPipelineAGmemBGmemCRegV1
// For the basic gemm pipeline, DoubleSmemBuffer is naturally set to false.
static constexpr bool DoubleSmemBuffer = false;
static constexpr auto BMemNTType = Problem::BMemNTType;
static constexpr bool BPreShufflePermute = Problem::BPreShufflePermute;
CK_TILE_HOST_DEVICE static constexpr auto
SchedulerPerM(index_t dsread_perM, index_t dswrite_perM, index_t load_perM)
{

View File

@@ -115,6 +115,9 @@ struct MoeFlatmmPipelineAGmemBGmemCRegV1
static constexpr index_t dswrite_mIter = (DsWritePreIssue - 1) % MIterPerWarp;
static constexpr index_t dswrite_kIter = (DsWritePreIssue - 1) / MIterPerWarp;
static constexpr auto BMemNTType = Problem::BMemNTType;
static constexpr bool BPreShufflePermute = Problem::BPreShufflePermute;
[[nodiscard]] CK_TILE_HOST static const std::string GetName()
{
// clang-format off

View File

@@ -16,10 +16,12 @@ template <typename ADataType_,
typename CDataType_,
typename BlockGemmShape_,
typename Traits_,
GemmPipelineScheduler Scheduler_ = GemmPipelineScheduler::Intrawave,
bool HasHotLoop_ = true,
TailNumber TailNum_ = TailNumber::Full,
typename ComputeDataType_ = ADataType_>
GemmPipelineScheduler Scheduler_ = GemmPipelineScheduler::Intrawave,
bool HasHotLoop_ = true,
TailNumber TailNum_ = TailNumber::Full,
amd_buffer_coherence_enum BMemNTType_ = amd_buffer_coherence_enum::coherence_default,
bool BPreShufflePermute_ = false,
typename ComputeDataType_ = ADataType_>
struct MXFlatmmPipelineProblem : FlatmmPipelineProblem<ADataType_,
ADataType_,
CDataType_,
@@ -28,6 +30,8 @@ struct MXFlatmmPipelineProblem : FlatmmPipelineProblem<ADataType_,
Scheduler_,
HasHotLoop_,
TailNum_,
BMemNTType_,
BPreShufflePermute_,
ComputeDataType_>
{
using BlockGemmShape = BlockGemmShape_;