[CK_TILE] add atomic IGLP scheduler for wp gemm (#2739)

* add atomic IGLP scheduler

* clang format

---------

Co-authored-by: ThomasNing <thomas.ning@amd.com>
This commit is contained in:
lalala-sh
2025-09-09 05:57:14 +08:00
committed by GitHub
parent e11f694eda
commit e4a7728903
2 changed files with 320 additions and 390 deletions

View File

@@ -91,7 +91,7 @@ int main(int argc, char* argv[])
try
{
return !run_gemm_example<GemmConfigPreshuffleDecode>(arg_parser);
return !run_gemm_example<GemmConfigPreshufflePrefill>(arg_parser);
}
catch(const std::runtime_error& e)
{

View File

@@ -43,6 +43,7 @@ struct BaseWeightPreshufflePipelineAGmemBGmemCRegV2
return run_func(bool_constant<true>{},
integral_constant<TailNumber, TailNumber::Even>{});
}
return run_func(bool_constant<true>{}, integral_constant<TailNumber, TailNumber::Empty>{});
}
};
@@ -69,7 +70,11 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV2
using WG = remove_cvref_t<decltype(config.template at<0>())>;
static constexpr index_t DsWritePreIssue = 3; // default 2, ds write at MIter - 2
static constexpr index_t DsReadPreload = 2; // default 2, preload 2 ds read
static constexpr index_t BlockSize = Problem::kBlockSize;
static constexpr index_t WaveSize = get_warp_size();
static constexpr index_t kMPerBlock = BlockGemmShape::kM;
static constexpr index_t kNPerBlock = BlockGemmShape::kN;
@@ -129,13 +134,33 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV2
static constexpr index_t MPerBlockPerIter = kMPerBlock / MIterPerWarp;
static constexpr index_t KPerBlockPerIter = kKPerBlock / KIterPerWarp;
static constexpr index_t K1 = Problem::VectorLoadSize / sizeof(ADataType);
static constexpr index_t ACopyLoadNum = kMPerBlock * kKPerBlock / BlockSize / K1;
static constexpr auto TailNum = Problem::TailNum;
static constexpr index_t K1 = Problem::VectorLoadSize / sizeof(ADataType);
static constexpr index_t m_preload = (MIterPerWarp * KIterPerWarp >= DsReadPreload)
? DsReadPreload
: MIterPerWarp * KIterPerWarp;
static constexpr auto TailNum = Problem::TailNum;
static constexpr auto warp_m = WarpTile::at(idxM);
static constexpr auto warp_n = WarpTile::at(idxN);
static constexpr auto warp_k = WarpTile::at(idxK);
#ifdef __gfx942__
static constexpr index_t mfma_per_wg = 2;
#else
static constexpr index_t mfma_per_wg = 1;
#endif
static constexpr index_t dsread_per_wg =
WG::kM * WG::kK * sizeof(ADataType) / WaveSize / Problem::VectorLoadSize;
static_assert((WG::kM * WG::kK * sizeof(ADataType) / WaveSize) % Problem::VectorLoadSize == 0);
static constexpr index_t dsread_num_perK = dsread_per_wg * MIterPerWarp;
static constexpr index_t dswrite_num_perK = dsread_num_perK / (MWarp * NWarp);
static constexpr index_t dswrite_rep = (dswrite_num_perK + MIterPerWarp - 1) / MIterPerWarp;
static constexpr index_t Aload_num_perK = dswrite_num_perK;
static constexpr index_t Aload_rep = dswrite_rep;
static constexpr index_t Bload_num_perK = kNPerBlock * WG::kK / NWarp / K1 / WaveSize;
static constexpr index_t HalfMIter = (MIterPerWarp + 1) / 2;
static constexpr index_t Bload_rep = (Bload_num_perK + HalfMIter - 1) / HalfMIter;
static constexpr index_t mfma_perM_perK = NIterPerWarp * mfma_per_wg;
static constexpr index_t dswrite_mIter = (DsWritePreIssue - 1) % MIterPerWarp;
static constexpr index_t dswrite_kIter = (DsWritePreIssue - 1) / MIterPerWarp;
[[nodiscard]] CK_TILE_HOST static const std::string GetName()
{
@@ -160,411 +185,314 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV2
return PipelinePolicy::template GetSmemSize<Problem>();
}
CK_TILE_HOST_DEVICE static constexpr auto HotLoopScheduler()
// dsread_perM: how many LDS reads to issue in this M-iter
// dswrite_perM: how many LDS writes to issue in this M-iter
// load_perM: how many VMEM global loads to issue in this M-iter
CK_TILE_HOST_DEVICE static constexpr auto
SchedulerPerM(index_t dsread_perM, index_t dswrite_perM, index_t load_perM)
{
constexpr index_t KPerLoad = Problem::VectorLoadSize / sizeof(ADataType);
constexpr index_t A_Buffer_Load_Inst_Num = kMPerBlock * kKPerBlock / BlockSize / KPerLoad;
constexpr index_t A_LDS_Read_Inst_Num = MIterPerWarp * KIterPerWarp;
constexpr index_t B_Buffer_Load_Inst_Num = NIterPerWarp * KIterPerWarp;
// Init inst order
index_t max_data_inst = dsread_perM > load_perM
? (dsread_perM > dswrite_perM ? dsread_perM : dswrite_perM)
: (load_perM > dswrite_perM ? load_perM : dswrite_perM);
index_t sum_data_inst = dsread_perM + load_perM + dswrite_perM;
index_t round_data_inst = ck_tile::integer_divide_ceil(sum_data_inst, mfma_perM_perK);
constexpr int kOrderCap = NIterPerWarp * 10;
index_t inst_order[kOrderCap] = {};
index_t index = 0;
#pragma unroll
// round-robin
// Index: 0 1 2 3 4 5 ...
// Value: 1 2 3 1 2 3 ...
for(int j = 0; j < max_data_inst; j++)
{
if(dswrite_perM > j)
{
inst_order[index] = 1;
index++;
}
if(load_perM > j)
{
inst_order[index] = 2;
index++;
}
if(dsread_perM > j)
{
inst_order[index] = 3;
index++;
}
}
// Schedule IGLP
#pragma unroll
for(int j = 0; j < mfma_perM_perK; j++)
{
index_t inst_idx = 0;
if(j == 0)
;
else if(j == 1)
inst_idx = mfma_perM_perK == 2 ? 1 : mfma_perM_perK - 2;
else if(j == 2)
inst_idx = mfma_perM_perK - 1;
else
inst_idx = mfma_perM_perK - j;
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
#pragma unroll
for(int r = 0; r < round_data_inst; r++)
{
if(r % 2 == 0)
{
if(inst_order[inst_idx + r * mfma_perM_perK] == 1)
{
__builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
}
if(inst_order[inst_idx + r * mfma_perM_perK] == 2)
{
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
}
if(inst_order[inst_idx + r * mfma_perM_perK] == 3)
{
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
}
}
else
{
if(inst_order[(r + 1) * mfma_perM_perK - 1 - inst_idx] == 1)
{
__builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
}
if(inst_order[(r + 1) * mfma_perM_perK - 1 - inst_idx] == 2)
{
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
}
if(inst_order[(r + 1) * mfma_perM_perK - 1 - inst_idx] == 3)
{
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
}
}
}
}
}
CK_TILE_HOST_DEVICE static constexpr auto HotLoopScheduler()
{
// The key point of pipeline optimization is balancing the workload over time
// instruction schedule example(128X256X256, 1X4, 16X16X128):
// Iter MNK MFMA ds_read ds_write A_load b_load
// -1 M6N0: 57 - 8 - -
// -1 M6N1: 58 1 - - -
// -1 M6N2: 59 - - 7 -
// -1 M6N3: 60 2 - - -
// -1 M7N0: 61 - - - -
// -1 M7N1: 62 - - - -
// -1 M7N2: 63 - - - -
// -1 M7N1: 62 3 - - -
// -1 M7N2: 63 - - 8 -
// -1 M7N3: 64 4 - - -
// 0 M0N0K0: 1 - - - -
// 0 M0N1: 2 - - - 2
// 0 M0N2: 3 - - - -
// 0 M0N0K0: 1 - - - 1
// 0 M0N1: 2 5 - - -
// 0 M0N2: 3 - - - 2
// 0 M0N3: 4 6 - - -
// 0 M1N0: 5 - - - -
// 0 M1N1: 6 - - - 4
// 0 M1N2: 7 - - - -
// 0 M1N0: 5 - - - 3
// 0 M1N1: 6 7 - - -
// 0 M1N2: 7 - - - 4
// 0 M1N3: 8 8 - - -
// 0 M2N0: 9 - - - -
// 0 M2N1: 10 - - - 6
// 0 M2N2: 11 - - - -
// 0 M2N0: 9 - - - 5
// 0 M2N1: 10 9 - - -
// 0 M2N2: 11 - - - 6
// 0 M2N3: 12 10 - - -
// 0 M3N0: 13 - 1 - -
// 0 M3N1: 14 - - - 8
// 0 M3N2: 15 - - - -
// 0 M3N0: 13 - 1 - 7
// 0 M3N1: 14 11 - - -
// 0 M3N2: 15 - - - 8
// 0 M3N3: 16 12 - - -
// 0 M4N0: 17 - 2 - -
// 0 M4N1: 18 - - - -
// 0 M4N1: 18 13 - - -
// 0 M4N2: 19 - - 1 -
// 0 M4N3: 20 14 - - -
// 0 M5N0: 21 - 3 - -
// 0 M5N1: 22 - - - -
// 0 M5N1: 22 15 - - -
// 0 M5N2: 23 - - 2 -
// 0 M5N3: 24 16 - - -
// 0 M6N0: 25 - 4 - -
// 0 M6N1: 26 - - - -
// 0 M6N1: 26 17 - - -
// 0 M6N2: 27 - - 3 -
// 0 M6N3: 28 17 - - -
// 0 M6N3: 28 18 - - -
// 0 M7N0: 29 - - - -
// 0 M7N1: 30 - - - -
// 0 M7N1: 30 19 - - -
// 0 M7N2: 31 - - 4 -
// 0 M7N3: 32 18 - - -
// 0 M0N0K1: 33 - - - -
// 0 M0N1: 34 - - - 10
// 0 M0N2: 35 - - - -
// 0 M0N3: 36 20 - - -
// 0 M1N0: 37 - - - -
// 0 M1N1: 38 - - - 12
// 0 M1N2: 39 - - - -
// 0 M1N3: 40 22 - - -
// 0 M2N0: 41 - - - -
// 0 M2N1: 42 - - - 14
// 0 M2N2: 43 - - - -
// 0 M2N3: 44 24 - - -
// 0 M3N0: 45 - 5 - -
// 0 M3N1: 46 - - - 16
// 0 M3N2: 47 - - - -
// 0 M3N3: 48 26 - - -
// 0 M7N3: 32 20 - - -
// 0 M0N0K1: 33 - - - 9
// 0 M0N1: 34 21 - - -
// 0 M0N2: 35 - - - 10
// 0 M0N3: 36 22 - - -
// 0 M1N0: 37 - - - 11
// 0 M1N1: 38 23 - - -
// 0 M1N2: 39 - - - 12
// 0 M1N3: 40 24 - - -
// 0 M2N0: 41 - - - 13
// 0 M2N1: 42 25 - - -
// 0 M2N2: 43 - - - 14
// 0 M2N3: 44 26 - - -
// 0 M3N0: 45 - 5 - 15
// 0 M3N1: 46 27 - - -
// 0 M3N2: 47 - - - 16
// 0 M3N3: 48 28 - - -
// 0 M4N0: 49 - 6 - -
// 0 M4N1: 50 - - - -
// 0 M4N1: 50 29 - - -
// 0 M4N2: 51 - - 5 -
// 0 M4N3: 52 28 - - -
// 0 M4N3: 52 30 - - -
// 0 M5N0: 53 - 7 - -
// 0 M5N1: 54 - - - -
// 0 M5N1: 54 31 - - -
// 0 M5N2: 55 - - 6 -
// 0 M5N3: 56 30 - - -
// 0 M5N3: 56 32 - - -
// 0 M6N0: 57 - 8 - -
// 0 M6N1: 58 - - - -
// 0 M6N1: 58 1 - - -
// 0 M6N2: 59 - - 7 -
// 0 M6N3: 60 2 - - -
// 0 M7N0: 61 - - - -
// 0 M7N1: 62 - - - -
// 0 M7N1: 62 3 - - -
// 0 M7N2: 63 - - 8 -
// 0 M7N3: 64 4 - - -
if constexpr(warp_m == 16 && warp_n == 16)
#pragma unroll
for(int kIter = 0; kIter < KIterPerWarp; kIter++)
{
// MFMA -> VMEM READ -> MFMA -> DS Read -> MFMA
// hiding the global memory VMEM latency
#if defined(__gfx950__)
if constexpr(kMPerBlock == 128 && kNPerBlock == 256 && kKPerBlock == 256)
#pragma unroll
for(int mIter = 0; mIter < MIterPerWarp; mIter++)
{
static_for<0, 2, 1>{}([&](auto j) {
ignore = j;
static_for<0, 3, 1>{}([&](auto i) {
ignore = i;
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
});
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
index_t dsread_perM = 0;
index_t dswrite_perM = 0;
index_t load_perM = 0;
static_for<0, 3, 1>{}([&](auto i) {
ignore = i;
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
});
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
});
// Calculate ds_read number per M
dsread_perM = dsread_per_wg;
__builtin_amdgcn_sched_barrier(0);
// Calculate ds_write number per M
if(mIter == 0)
{
dswrite_perM =
(dswrite_num_perK - (MIterPerWarp - DsWritePreIssue) * dswrite_rep) > 0
? dswrite_num_perK - (MIterPerWarp - DsWritePreIssue) * dswrite_rep
: 0;
}
else if(mIter >= MIterPerWarp - DsWritePreIssue + 1)
{
dswrite_perM = 0;
}
else
{
dswrite_perM = (dswrite_num_perK -
(MIterPerWarp - DsWritePreIssue - mIter) * dswrite_rep) > 0
? dswrite_rep
: 0;
}
// Add ds write when ds write data > needed
if(dswrite_num_perK == 0 && kIter == (KIterPerWarp - 1 - dswrite_kIter))
{
if(mIter == MIterPerWarp - 1 - dswrite_mIter)
dswrite_perM = 1;
}
// Calculate buffer_load number per M
if(mIter < HalfMIter)
{
load_perM =
((Aload_num_perK - (MIterPerWarp - 1 - mIter) * Aload_rep) > 0 ? Aload_rep
: 0) +
((Bload_num_perK - (HalfMIter - 1 - mIter) * Bload_rep) > 0 ? Bload_rep
: 0);
}
else
{
load_perM = (Aload_num_perK - (MIterPerWarp - 1 - mIter) * Aload_rep) > 0
? Aload_rep
: 0;
}
SchedulerPerM(dsread_perM, dswrite_perM, load_perM);
}
else
{
static_for<0, 2, 1>{}([&](auto j) {
ignore = j;
static_for<0, 3, 1>{}([&](auto i) {
ignore = i;
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
});
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
static_for<0, 3, 1>{}([&](auto i) {
ignore = i;
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
});
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
});
__builtin_amdgcn_sched_barrier(0);
}
// MFMA → MFMA → MFMA → MFMA → DS Read
// For other device engines we need more aggressive MFMA scheduling with DS writes interleaved
#else
if constexpr(kMPerBlock == 128 && kNPerBlock == 256 && kKPerBlock == 256)
{
static_for<0, 2, 1>{}([&](auto j) {
ignore = j;
// Uses loops to amortize scheduling overhead
static_for<0, 4, 1>{}([&](auto i) {
ignore = i;
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
});
static_for<0, 1, 1>{}([&](auto i) {
ignore = i;
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
});
static_for<0, 1, 1>{}([&](auto i) {
ignore = i;
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
});
static_for<0, 1, 1>{}([&](auto i) {
ignore = i;
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
});
static_for<0, 1, 1>{}([&](auto i) {
ignore = i;
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
});
});
__builtin_amdgcn_sched_barrier(0);
}
else if constexpr(kMPerBlock == 16 && kNPerBlock == 64 && kKPerBlock == 256)
{
static_for<0, 1, 1>{}([&](auto i) {
ignore = i;
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
});
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
__builtin_amdgcn_sched_barrier(0);
}
else if constexpr(kMPerBlock == 128 && kNPerBlock == 128 && kKPerBlock == 128)
{
// prioritize MFMA to avoid LDS write conflicts
static_for<0, 2, 1>{}([&](auto j) {
ignore = j;
static_for<0, 2, 1>{}([&](auto i) {
ignore = i;
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
});
static_for<0, 2, 1>{}([&](auto i) {
ignore = i;
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
});
static_for<0, 1, 1>{}([&](auto i) {
ignore = i;
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
});
static_for<0, 1, 1>{}([&](auto i) {
ignore = i;
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
});
static_for<0, 1, 1>{}([&](auto i) {
ignore = i;
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
});
static_for<0, 1, 1>{}([&](auto i) {
ignore = i;
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
});
});
__builtin_amdgcn_sched_barrier(0);
}
else
{
static_for<0, A_Buffer_Load_Inst_Num, 1>{}([&](auto i) {
ignore = i;
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
});
static_for<0, A_LDS_Read_Inst_Num - A_Buffer_Load_Inst_Num, 1>{}([&](auto i) {
ignore = i;
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
__builtin_amdgcn_sched_group_barrier(0x008, 3, 0); // MFMA
});
static_for<0, B_Buffer_Load_Inst_Num, 1>{}([&](auto i) {
ignore = i;
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
__builtin_amdgcn_sched_group_barrier(0x008, 2, 0); // MFMA
});
static_for<0, A_Buffer_Load_Inst_Num, 1>{}([&](auto i) {
ignore = i;
__builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
__builtin_amdgcn_sched_group_barrier(0x008, 4, 0); // MFMA
});
}
#endif
}
else
// Add Aload when Aload data > needed
if(Aload_num_perK == 0)
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
__builtin_amdgcn_sched_barrier(0);
}
CK_TILE_HOST_DEVICE static constexpr auto Last2ndHotLoopScheduler()
{
#pragma unroll
for(int kIter = 0; kIter < KIterPerWarp; kIter++)
{
if constexpr((A_LDS_Read_Inst_Num / 2 >
A_Buffer_Load_Inst_Num + B_Buffer_Load_Inst_Num))
#pragma unroll
for(int mIter = 0; mIter < MIterPerWarp; mIter++)
{
static_for<0,
A_LDS_Read_Inst_Num / 2 - A_Buffer_Load_Inst_Num -
B_Buffer_Load_Inst_Num,
1>{}([&](auto i) {
ignore = i;
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
});
index_t dsread_perM = 0;
index_t dswrite_perM = 0;
index_t load_perM = 0;
// Calculate ds_read number per M
dsread_perM = dsread_per_wg;
// Calculate ds_write number per M
if(mIter == 0)
{
dswrite_perM =
(dswrite_num_perK - (MIterPerWarp - DsWritePreIssue) * dswrite_rep) > 0
? dswrite_num_perK - (MIterPerWarp - DsWritePreIssue) * dswrite_rep
: 0;
}
else if(mIter >= MIterPerWarp - DsWritePreIssue + 1)
{
dswrite_perM = 0;
}
else
{
dswrite_perM = (dswrite_num_perK -
(MIterPerWarp - DsWritePreIssue - mIter) * dswrite_rep) > 0
? dswrite_rep
: 0;
}
// Add ds write when ds write data > needed
if(dswrite_num_perK == 0 && kIter == (KIterPerWarp - 1 - dswrite_kIter))
{
if(mIter == MIterPerWarp - 1 - dswrite_mIter)
dswrite_perM = 1;
}
// Calculate buffer_load number per M
if(mIter < HalfMIter)
{
load_perM =
((Bload_num_perK - (HalfMIter - 1 - mIter) * Bload_rep) > 0 ? Bload_rep
: 0);
}
SchedulerPerM(dsread_perM, dswrite_perM, load_perM);
}
static_for<0, A_Buffer_Load_Inst_Num, 1>{}([&](auto i) {
ignore = i;
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
});
static_for<0, A_LDS_Read_Inst_Num / 2, 1>{}([&](auto i) {
ignore = i;
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
});
static_for<0, B_Buffer_Load_Inst_Num, 1>{}([&](auto i) {
ignore = i;
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
});
static_for<0, A_Buffer_Load_Inst_Num, 1>{}([&](auto i) {
ignore = i;
__builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
__builtin_amdgcn_sched_group_barrier(0x008, 3, 0); // MFMA
});
__builtin_amdgcn_sched_group_barrier(0x008, 4, 0); // MFMA
}
__builtin_amdgcn_sched_barrier(0);
}
CK_TILE_HOST_DEVICE static constexpr auto LastHotLoopScheduler()
{
#pragma unroll
for(int kIter = 0; kIter < KIterPerWarp; kIter++)
{
#pragma unroll
for(int mIter = 0; mIter < MIterPerWarp; mIter++)
{
index_t dsread_perM = 0;
index_t dswrite_perM = 0;
index_t load_perM = 0;
// Calculate ds_read number per M
if((kIter * MIterPerWarp + mIter) < (KIterPerWarp * MIterPerWarp - m_preload))
dsread_perM = dsread_per_wg;
SchedulerPerM(dsread_perM, dswrite_perM, load_perM);
}
}
// __builtin_amdgcn_sched_barrier(0);
}
template <TailNumber TailNum,
@@ -738,22 +666,19 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV2
block_sync_lds();
// preload A00,A10 from lds
constexpr auto m_preload = (MIterPerWarp * KIterPerWarp >= 2) ? 2 : 1;
statically_indexed_array<decltype(load_tile(a_warp_windows_ping(number<0>{})(number<0>{}))),
m_preload>
a_warp_tensor_ping;
statically_indexed_array<decltype(load_tile(a_warp_windows_pong(number<0>{})(number<0>{}))),
m_preload>
a_warp_tensor_pong;
a_warp_tensor;
static_for<0, m_preload, 1>{}([&](auto loadIter) {
constexpr auto mIter = loadIter % MIterPerWarp;
constexpr auto kIter = loadIter / MIterPerWarp;
a_warp_tensor_ping(loadIter) =
a_warp_tensor(loadIter) =
load_tile(a_warp_windows_ping(number<mIter>{})(number<kIter>{}));
});
__builtin_amdgcn_sched_barrier(0);
// MAIN LOOP
index_t iCounter = (num_loop - 1) / 2;
while(iCounter > 0)
{
@@ -792,7 +717,7 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV2
// warp GEMM
WG{}(c_warp_tensor,
a_warp_tensor_ping(number<AwarpIter>{}),
a_warp_tensor(number<AwarpIter>{}),
b_warp_tensor_ping(nIter)(kIter));
// write C warp tensor into C block tensor
@@ -809,7 +734,7 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV2
{
constexpr auto AmIter = (mIter + m_preload) % MIterPerWarp;
constexpr auto AkIter = (kIter + (mIter + m_preload) / MIterPerWarp);
a_warp_tensor_ping(number<AwarpIter>{}) =
a_warp_tensor(number<AwarpIter>{}) =
load_tile(a_warp_windows_ping(number<AmIter>{})(number<AkIter>{}));
}
@@ -826,7 +751,7 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV2
static_for<0, m_preload, 1>{}([&](auto loadIter) {
constexpr auto mIter = loadIter % MIterPerWarp;
constexpr auto kIter = loadIter / MIterPerWarp;
a_warp_tensor_pong(loadIter) =
a_warp_tensor(loadIter) =
load_tile(a_warp_windows_pong(number<mIter>{})(number<kIter>{}));
});
HotLoopScheduler();
@@ -867,7 +792,7 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV2
// warp GEMM
WG{}(c_warp_tensor,
a_warp_tensor_pong(number<AwarpIter>{}),
a_warp_tensor(number<AwarpIter>{}),
b_warp_tensor_pong(nIter)(kIter));
// write C warp tensor into C block tensor
@@ -884,7 +809,7 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV2
{
constexpr auto AmIter = (mIter + m_preload) % MIterPerWarp;
constexpr auto AkIter = (kIter + (mIter + m_preload) / MIterPerWarp);
a_warp_tensor_pong(number<AwarpIter>{}) =
a_warp_tensor(number<AwarpIter>{}) =
load_tile(a_warp_windows_pong(number<AmIter>{})(number<AkIter>{}));
}
@@ -901,7 +826,7 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV2
static_for<0, m_preload, 1>{}([&](auto loadIter) {
constexpr auto mIter = loadIter % MIterPerWarp;
constexpr auto kIter = loadIter / MIterPerWarp;
a_warp_tensor_ping(loadIter) =
a_warp_tensor(loadIter) =
load_tile(a_warp_windows_ping(number<mIter>{})(number<kIter>{}));
});
HotLoopScheduler();
@@ -943,7 +868,7 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV2
// warp GEMM
WG{}(c_warp_tensor,
a_warp_tensor_ping(number<AwarpIter>{}),
a_warp_tensor(number<AwarpIter>{}),
b_warp_tensor_ping(nIter)(kIter));
// write C warp tensor into C block tensor
@@ -960,7 +885,7 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV2
{
constexpr auto AmIter = (mIter + m_preload) % MIterPerWarp;
constexpr auto AkIter = (kIter + (mIter + m_preload) / MIterPerWarp);
a_warp_tensor_ping(number<AwarpIter>{}) =
a_warp_tensor(number<AwarpIter>{}) =
load_tile(a_warp_windows_ping(number<AmIter>{})(number<AkIter>{}));
}
@@ -976,11 +901,11 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV2
static_for<0, m_preload, 1>{}([&](auto loadIter) {
constexpr auto mIter = loadIter % MIterPerWarp;
constexpr auto kIter = loadIter / MIterPerWarp;
a_warp_tensor_pong(loadIter) =
a_warp_tensor(loadIter) =
load_tile(a_warp_windows_pong(number<mIter>{})(number<kIter>{}));
});
// __builtin_amdgcn_sched_barrier(0);
Last2ndHotLoopScheduler();
// GEMM loopK
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
@@ -996,7 +921,7 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV2
// warp GEMM
WG{}(c_warp_tensor,
a_warp_tensor_pong(number<AwarpIter>{}),
a_warp_tensor(number<AwarpIter>{}),
b_warp_tensor_pong(nIter)(kIter));
// write C warp tensor into C block tensor
@@ -1004,19 +929,23 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV2
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
c_warp_tensor.get_thread_buffer());
__builtin_amdgcn_sched_barrier(0x7F6);
});
if constexpr((kIter * MIterPerWarp + mIter) <
(KIterPerWarp * MIterPerWarp - m_preload))
{
constexpr auto AmIter = (mIter + m_preload) % MIterPerWarp;
constexpr auto AkIter = (kIter + (mIter + m_preload) / MIterPerWarp);
a_warp_tensor_pong(number<AwarpIter>{}) =
a_warp_tensor(number<AwarpIter>{}) =
load_tile(a_warp_windows_pong(number<AmIter>{})(number<AkIter>{}));
}
// barrier
if constexpr((kIter == KIterPerWarp - 1) && (mIter == MIter_2nd_last))
{
block_sync_lds();
}
});
});
// TailHotLoopScheduler();
LastHotLoopScheduler();
}
else if constexpr(TailNum == TailNumber::Odd)
{
@@ -1034,7 +963,7 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV2
// warp GEMM
WG{}(c_warp_tensor,
a_warp_tensor_ping(number<AwarpIter>{}),
a_warp_tensor(number<AwarpIter>{}),
b_warp_tensor_ping(nIter)(kIter));
// write C warp tensor into C block tensor
@@ -1051,7 +980,7 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV2
{
constexpr auto AmIter = (mIter + m_preload) % MIterPerWarp;
constexpr auto AkIter = (kIter + (mIter + m_preload) / MIterPerWarp);
a_warp_tensor_ping(number<AwarpIter>{}) =
a_warp_tensor(number<AwarpIter>{}) =
load_tile(a_warp_windows_ping(number<AmIter>{})(number<AkIter>{}));
}
@@ -1062,6 +991,7 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV2
}
});
});
LastHotLoopScheduler();
}
return c_block_tile;