mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-05 06:01:23 +00:00
[CK_TILE MOE] add NT & preshuffle permute to cktile MOE (#3377)
* update coherence --------- Co-authored-by: Zzz9990 <Zzz9990>
This commit is contained in:
@@ -24,6 +24,18 @@ struct BaseFlatmmPipelineAGmemBGmemCRegV1
|
||||
return num_loop % 2 == 0 ? TailNumber::Even : TailNumber::Odd;
|
||||
}
|
||||
|
||||
// Chooses an amd_buffer_coherence_enum value from the problem's M extent.
//
// Returns WAVE_NT1 when M is at most 416, and coherence_default otherwise.
// N and K are part of the signature but do not take part in the decision;
// they are assigned to ck_tile::ignore so they do not count as unused.
// NOTE(review): the name suggests the result is the coherence mode used for
// B-matrix memory accesses, and 416 looks like a tuned threshold - confirm
// both against the call sites.
CK_TILE_HOST static constexpr amd_buffer_coherence_enum
GetBMemNTType(index_t M, index_t N, index_t K)
{
    ck_tile::ignore = N;
    ck_tile::ignore = K;

    return (M <= 416) ? ck_tile::amd_buffer_coherence_enum::WAVE_NT1
                      : ck_tile::amd_buffer_coherence_enum::coherence_default;
}
|
||||
|
||||
template <bool DispatchHotloop = false, TailNumber tail_num, typename RunFunction>
|
||||
CK_TILE_HOST_DEVICE static auto TailHandler(const RunFunction& run_func, bool has_hot_loop)
|
||||
{
|
||||
|
||||
@@ -16,10 +16,12 @@ template <typename ADataType_,
|
||||
typename CDataType_,
|
||||
typename BlockGemmShape_,
|
||||
typename Traits_,
|
||||
GemmPipelineScheduler Scheduler_ = GemmPipelineScheduler::Intrawave,
|
||||
bool HasHotLoop_ = true,
|
||||
TailNumber TailNum_ = TailNumber::Full,
|
||||
typename ComputeDataType_ = ADataType_>
|
||||
GemmPipelineScheduler Scheduler_ = GemmPipelineScheduler::Intrawave,
|
||||
bool HasHotLoop_ = true,
|
||||
TailNumber TailNum_ = TailNumber::Full,
|
||||
amd_buffer_coherence_enum BMemNTType_ = amd_buffer_coherence_enum::coherence_default,
|
||||
bool BPreShufflePermute_ = false,
|
||||
typename ComputeDataType_ = ADataType_>
|
||||
struct F16xMXF4FlatmmPipelineProblem : FlatmmPipelineProblem<ADataType_,
|
||||
ADataType_,
|
||||
CDataType_,
|
||||
@@ -28,6 +30,8 @@ struct F16xMXF4FlatmmPipelineProblem : FlatmmPipelineProblem<ADataType_,
|
||||
Scheduler_,
|
||||
HasHotLoop_,
|
||||
TailNum_,
|
||||
BMemNTType_,
|
||||
BPreShufflePermute_,
|
||||
ComputeDataType_>
|
||||
{
|
||||
using BlockGemmShape = BlockGemmShape_;
|
||||
@@ -183,6 +187,9 @@ struct F16xMXF4FlatmmPipelineAGmemBGmemCRegV1
|
||||
// For the basic gemm pipeline, DoubleSmemBuffer is naturally set to false.
|
||||
static constexpr bool DoubleSmemBuffer = false;
|
||||
|
||||
static constexpr auto BMemNTType = Problem::BMemNTType;
|
||||
static constexpr bool BPreShufflePermute = Problem::BPreShufflePermute;
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr auto
|
||||
SchedulerPerM(index_t dsread_perM, index_t dswrite_perM, index_t load_perM)
|
||||
{
|
||||
|
||||
@@ -115,6 +115,9 @@ struct MoeFlatmmPipelineAGmemBGmemCRegV1
|
||||
static constexpr index_t dswrite_mIter = (DsWritePreIssue - 1) % MIterPerWarp;
|
||||
static constexpr index_t dswrite_kIter = (DsWritePreIssue - 1) / MIterPerWarp;
|
||||
|
||||
static constexpr auto BMemNTType = Problem::BMemNTType;
|
||||
static constexpr bool BPreShufflePermute = Problem::BPreShufflePermute;
|
||||
|
||||
[[nodiscard]] CK_TILE_HOST static const std::string GetName()
|
||||
{
|
||||
// clang-format off
|
||||
|
||||
@@ -16,10 +16,12 @@ template <typename ADataType_,
|
||||
typename CDataType_,
|
||||
typename BlockGemmShape_,
|
||||
typename Traits_,
|
||||
GemmPipelineScheduler Scheduler_ = GemmPipelineScheduler::Intrawave,
|
||||
bool HasHotLoop_ = true,
|
||||
TailNumber TailNum_ = TailNumber::Full,
|
||||
typename ComputeDataType_ = ADataType_>
|
||||
GemmPipelineScheduler Scheduler_ = GemmPipelineScheduler::Intrawave,
|
||||
bool HasHotLoop_ = true,
|
||||
TailNumber TailNum_ = TailNumber::Full,
|
||||
amd_buffer_coherence_enum BMemNTType_ = amd_buffer_coherence_enum::coherence_default,
|
||||
bool BPreShufflePermute_ = false,
|
||||
typename ComputeDataType_ = ADataType_>
|
||||
struct MXFlatmmPipelineProblem : FlatmmPipelineProblem<ADataType_,
|
||||
ADataType_,
|
||||
CDataType_,
|
||||
@@ -28,6 +30,8 @@ struct MXFlatmmPipelineProblem : FlatmmPipelineProblem<ADataType_,
|
||||
Scheduler_,
|
||||
HasHotLoop_,
|
||||
TailNum_,
|
||||
BMemNTType_,
|
||||
BPreShufflePermute_,
|
||||
ComputeDataType_>
|
||||
{
|
||||
using BlockGemmShape = BlockGemmShape_;
|
||||
|
||||
Reference in New Issue
Block a user