[CK_TILE] add atomic IGLP scheduler for wp gemm (#2739)

* add atomic IGLP scheduler

* clang format

---------

Co-authored-by: ThomasNing <thomas.ning@amd.com>

[ROCm/composable_kernel commit: e4a7728903]
This commit is contained in:
lalala-sh
2025-09-09 05:57:14 +08:00
committed by GitHub
parent 44bd944957
commit b2f2800468
2 changed files with 320 additions and 390 deletions

View File

@@ -91,7 +91,7 @@ int main(int argc, char* argv[])
try try
{ {
return !run_gemm_example<GemmConfigPreshuffleDecode>(arg_parser); return !run_gemm_example<GemmConfigPreshufflePrefill>(arg_parser);
} }
catch(const std::runtime_error& e) catch(const std::runtime_error& e)
{ {

View File

@@ -43,6 +43,7 @@ struct BaseWeightPreshufflePipelineAGmemBGmemCRegV2
return run_func(bool_constant<true>{}, return run_func(bool_constant<true>{},
integral_constant<TailNumber, TailNumber::Even>{}); integral_constant<TailNumber, TailNumber::Even>{});
} }
return run_func(bool_constant<true>{}, integral_constant<TailNumber, TailNumber::Empty>{});
} }
}; };
@@ -69,7 +70,11 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV2
using WG = remove_cvref_t<decltype(config.template at<0>())>; using WG = remove_cvref_t<decltype(config.template at<0>())>;
static constexpr index_t DsWritePreIssue = 3; // default 2, ds write at MIter - 2
static constexpr index_t DsReadPreload = 2; // default 2, preload 2 ds read
static constexpr index_t BlockSize = Problem::kBlockSize; static constexpr index_t BlockSize = Problem::kBlockSize;
static constexpr index_t WaveSize = get_warp_size();
static constexpr index_t kMPerBlock = BlockGemmShape::kM; static constexpr index_t kMPerBlock = BlockGemmShape::kM;
static constexpr index_t kNPerBlock = BlockGemmShape::kN; static constexpr index_t kNPerBlock = BlockGemmShape::kN;
@@ -129,13 +134,33 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV2
static constexpr index_t MPerBlockPerIter = kMPerBlock / MIterPerWarp; static constexpr index_t MPerBlockPerIter = kMPerBlock / MIterPerWarp;
static constexpr index_t KPerBlockPerIter = kKPerBlock / KIterPerWarp; static constexpr index_t KPerBlockPerIter = kKPerBlock / KIterPerWarp;
static constexpr index_t K1 = Problem::VectorLoadSize / sizeof(ADataType); static constexpr index_t K1 = Problem::VectorLoadSize / sizeof(ADataType);
static constexpr index_t ACopyLoadNum = kMPerBlock * kKPerBlock / BlockSize / K1; static constexpr index_t m_preload = (MIterPerWarp * KIterPerWarp >= DsReadPreload)
static constexpr auto TailNum = Problem::TailNum; ? DsReadPreload
: MIterPerWarp * KIterPerWarp;
static constexpr auto TailNum = Problem::TailNum;
static constexpr auto warp_m = WarpTile::at(idxM); #ifdef __gfx942__
static constexpr auto warp_n = WarpTile::at(idxN); static constexpr index_t mfma_per_wg = 2;
static constexpr auto warp_k = WarpTile::at(idxK); #else
static constexpr index_t mfma_per_wg = 1;
#endif
static constexpr index_t dsread_per_wg =
WG::kM * WG::kK * sizeof(ADataType) / WaveSize / Problem::VectorLoadSize;
static_assert((WG::kM * WG::kK * sizeof(ADataType) / WaveSize) % Problem::VectorLoadSize == 0);
static constexpr index_t dsread_num_perK = dsread_per_wg * MIterPerWarp;
static constexpr index_t dswrite_num_perK = dsread_num_perK / (MWarp * NWarp);
static constexpr index_t dswrite_rep = (dswrite_num_perK + MIterPerWarp - 1) / MIterPerWarp;
static constexpr index_t Aload_num_perK = dswrite_num_perK;
static constexpr index_t Aload_rep = dswrite_rep;
static constexpr index_t Bload_num_perK = kNPerBlock * WG::kK / NWarp / K1 / WaveSize;
static constexpr index_t HalfMIter = (MIterPerWarp + 1) / 2;
static constexpr index_t Bload_rep = (Bload_num_perK + HalfMIter - 1) / HalfMIter;
static constexpr index_t mfma_perM_perK = NIterPerWarp * mfma_per_wg;
static constexpr index_t dswrite_mIter = (DsWritePreIssue - 1) % MIterPerWarp;
static constexpr index_t dswrite_kIter = (DsWritePreIssue - 1) / MIterPerWarp;
[[nodiscard]] CK_TILE_HOST static const std::string GetName() [[nodiscard]] CK_TILE_HOST static const std::string GetName()
{ {
@@ -160,411 +185,314 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV2
return PipelinePolicy::template GetSmemSize<Problem>(); return PipelinePolicy::template GetSmemSize<Problem>();
} }
CK_TILE_HOST_DEVICE static constexpr auto HotLoopScheduler() // dsread_perM: how many LDS reads want to issue in this M-iter
// dswrite_perM: how many LDS writes you want to do this M-iter
// load_perM: how many global loads VMEM want to do in this M-iter
CK_TILE_HOST_DEVICE static constexpr auto
SchedulerPerM(index_t dsread_perM, index_t dswrite_perM, index_t load_perM)
{ {
constexpr index_t KPerLoad = Problem::VectorLoadSize / sizeof(ADataType); // Init inst order
constexpr index_t A_Buffer_Load_Inst_Num = kMPerBlock * kKPerBlock / BlockSize / KPerLoad; index_t max_data_inst = dsread_perM > load_perM
constexpr index_t A_LDS_Read_Inst_Num = MIterPerWarp * KIterPerWarp; ? (dsread_perM > dswrite_perM ? dsread_perM : dswrite_perM)
constexpr index_t B_Buffer_Load_Inst_Num = NIterPerWarp * KIterPerWarp; : (load_perM > dswrite_perM ? load_perM : dswrite_perM);
index_t sum_data_inst = dsread_perM + load_perM + dswrite_perM;
index_t round_data_inst = ck_tile::integer_divide_ceil(sum_data_inst, mfma_perM_perK);
constexpr int kOrderCap = NIterPerWarp * 10;
index_t inst_order[kOrderCap] = {};
index_t index = 0;
#pragma unroll
// round-robin
// Index: 0 1 2 3 4 5 ...
// Value: 1 2 3 1 2 3 ...
for(int j = 0; j < max_data_inst; j++)
{
if(dswrite_perM > j)
{
inst_order[index] = 1;
index++;
}
if(load_perM > j)
{
inst_order[index] = 2;
index++;
}
if(dsread_perM > j)
{
inst_order[index] = 3;
index++;
}
}
// Schedule IGLP
#pragma unroll
for(int j = 0; j < mfma_perM_perK; j++)
{
index_t inst_idx = 0;
if(j == 0)
;
else if(j == 1)
inst_idx = mfma_perM_perK == 2 ? 1 : mfma_perM_perK - 2;
else if(j == 2)
inst_idx = mfma_perM_perK - 1;
else
inst_idx = mfma_perM_perK - j;
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
#pragma unroll
for(int r = 0; r < round_data_inst; r++)
{
if(r % 2 == 0)
{
if(inst_order[inst_idx + r * mfma_perM_perK] == 1)
{
__builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
}
if(inst_order[inst_idx + r * mfma_perM_perK] == 2)
{
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
}
if(inst_order[inst_idx + r * mfma_perM_perK] == 3)
{
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
}
}
else
{
if(inst_order[(r + 1) * mfma_perM_perK - 1 - inst_idx] == 1)
{
__builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
}
if(inst_order[(r + 1) * mfma_perM_perK - 1 - inst_idx] == 2)
{
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
}
if(inst_order[(r + 1) * mfma_perM_perK - 1 - inst_idx] == 3)
{
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
}
}
}
}
}
CK_TILE_HOST_DEVICE static constexpr auto HotLoopScheduler()
{
// Keypoint of pipeline optimize is workload balance in time // Keypoint of pipeline optimize is workload balance in time
// instruction schedule example(128X256X256, 1X4, 16X16X128): // instruction schedule example(128X256X256, 1X4, 16X16X128):
// Iter MNK MFMA ds_read ds_write A_load b_load // Iter MNK MFMA ds_read ds_write A_load b_load
// -1 M6N0: 57 - 8 - -
// -1 M6N1: 58 1 - - -
// -1 M6N2: 59 - - 7 -
// -1 M6N3: 60 2 - - - // -1 M6N3: 60 2 - - -
// -1 M7N0: 61 - - - - // -1 M7N0: 61 - - - -
// -1 M7N1: 62 - - - - // -1 M7N1: 62 3 - - -
// -1 M7N2: 63 - - - - // -1 M7N2: 63 - - 8 -
// -1 M7N3: 64 4 - - - // -1 M7N3: 64 4 - - -
// 0 M0N0K0: 1 - - - - // 0 M0N0K0: 1 - - - 1
// 0 M0N1: 2 - - - 2 // 0 M0N1: 2 5 - - -
// 0 M0N2: 3 - - - - // 0 M0N2: 3 - - - 2
// 0 M0N3: 4 6 - - - // 0 M0N3: 4 6 - - -
// 0 M1N0: 5 - - - - // 0 M1N0: 5 - - - 3
// 0 M1N1: 6 - - - 4 // 0 M1N1: 6 7 - - -
// 0 M1N2: 7 - - - - // 0 M1N2: 7 - - - 4
// 0 M1N3: 8 8 - - - // 0 M1N3: 8 8 - - -
// 0 M2N0: 9 - - - - // 0 M2N0: 9 - - - 5
// 0 M2N1: 10 - - - 6 // 0 M2N1: 10 9 - - -
// 0 M2N2: 11 - - - - // 0 M2N2: 11 - - - 6
// 0 M2N3: 12 10 - - - // 0 M2N3: 12 10 - - -
// 0 M3N0: 13 - 1 - - // 0 M3N0: 13 - 1 - 7
// 0 M3N1: 14 - - - 8 // 0 M3N1: 14 11 - - -
// 0 M3N2: 15 - - - - // 0 M3N2: 15 - - - 8
// 0 M3N3: 16 12 - - - // 0 M3N3: 16 12 - - -
// 0 M4N0: 17 - 2 - - // 0 M4N0: 17 - 2 - -
// 0 M4N1: 18 - - - - // 0 M4N1: 18 13 - - -
// 0 M4N2: 19 - - 1 - // 0 M4N2: 19 - - 1 -
// 0 M4N3: 20 14 - - - // 0 M4N3: 20 14 - - -
// 0 M5N0: 21 - 3 - - // 0 M5N0: 21 - 3 - -
// 0 M5N1: 22 - - - - // 0 M5N1: 22 15 - - -
// 0 M5N2: 23 - - 2 - // 0 M5N2: 23 - - 2 -
// 0 M5N3: 24 16 - - - // 0 M5N3: 24 16 - - -
// 0 M6N0: 25 - 4 - - // 0 M6N0: 25 - 4 - -
// 0 M6N1: 26 - - - - // 0 M6N1: 26 17 - - -
// 0 M6N2: 27 - - 3 - // 0 M6N2: 27 - - 3 -
// 0 M6N3: 28 17 - - - // 0 M6N3: 28 18 - - -
// 0 M7N0: 29 - - - - // 0 M7N0: 29 - - - -
// 0 M7N1: 30 - - - - // 0 M7N1: 30 19 - - -
// 0 M7N2: 31 - - 4 - // 0 M7N2: 31 - - 4 -
// 0 M7N3: 32 18 - - - // 0 M7N3: 32 20 - - -
// 0 M0N0K1: 33 - - - - // 0 M0N0K1: 33 - - - 9
// 0 M0N1: 34 - - - 10 // 0 M0N1: 34 21 - - -
// 0 M0N2: 35 - - - - // 0 M0N2: 35 - - - 10
// 0 M0N3: 36 20 - - - // 0 M0N3: 36 22 - - -
// 0 M1N0: 37 - - - - // 0 M1N0: 37 - - - 11
// 0 M1N1: 38 - - - 12 // 0 M1N1: 38 23 - - -
// 0 M1N2: 39 - - - - // 0 M1N2: 39 - - - 12
// 0 M1N3: 40 22 - - - // 0 M1N3: 40 24 - - -
// 0 M2N0: 41 - - - - // 0 M2N0: 41 - - - 13
// 0 M2N1: 42 - - - 14 // 0 M2N1: 42 25 - - -
// 0 M2N2: 43 - - - - // 0 M2N2: 43 - - - 14
// 0 M2N3: 44 24 - - - // 0 M2N3: 44 26 - - -
// 0 M3N0: 45 - 5 - - // 0 M3N0: 45 - 5 - 15
// 0 M3N1: 46 - - - 16 // 0 M3N1: 46 27 - - -
// 0 M3N2: 47 - - - - // 0 M3N2: 47 - - - 16
// 0 M3N3: 48 26 - - - // 0 M3N3: 48 28 - - -
// 0 M4N0: 49 - 6 - - // 0 M4N0: 49 - 6 - -
// 0 M4N1: 50 - - - - // 0 M4N1: 50 29 - - -
// 0 M4N2: 51 - - 5 - // 0 M4N2: 51 - - 5 -
// 0 M4N3: 52 28 - - - // 0 M4N3: 52 30 - - -
// 0 M5N0: 53 - 7 - - // 0 M5N0: 53 - 7 - -
// 0 M5N1: 54 - - - - // 0 M5N1: 54 31 - - -
// 0 M5N2: 55 - - 6 - // 0 M5N2: 55 - - 6 -
// 0 M5N3: 56 30 - - - // 0 M5N3: 56 32 - - -
// 0 M6N0: 57 - 8 - - // 0 M6N0: 57 - 8 - -
// 0 M6N1: 58 - - - - // 0 M6N1: 58 1 - - -
// 0 M6N2: 59 - - 7 - // 0 M6N2: 59 - - 7 -
// 0 M6N3: 60 2 - - - // 0 M6N3: 60 2 - - -
// 0 M7N0: 61 - - - - // 0 M7N0: 61 - - - -
// 0 M7N1: 62 - - - - // 0 M7N1: 62 3 - - -
// 0 M7N2: 63 - - 8 - // 0 M7N2: 63 - - 8 -
// 0 M7N3: 64 4 - - - // 0 M7N3: 64 4 - - -
if constexpr(warp_m == 16 && warp_n == 16) #pragma unroll
for(int kIter = 0; kIter < KIterPerWarp; kIter++)
{ {
// MFMA -> VMEM READ -> MFMA -> DS Read -> MFMA #pragma unroll
// hiding the glbal memory VMEM latency for(int mIter = 0; mIter < MIterPerWarp; mIter++)
#if defined(__gfx950__)
if constexpr(kMPerBlock == 128 && kNPerBlock == 256 && kKPerBlock == 256)
{ {
static_for<0, 2, 1>{}([&](auto j) { index_t dsread_perM = 0;
ignore = j; index_t dswrite_perM = 0;
static_for<0, 3, 1>{}([&](auto i) { index_t load_perM = 0;
ignore = i;
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
});
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
static_for<0, 3, 1>{}([&](auto i) { // Calculate ds_read number per M
ignore = i; dsread_perM = dsread_per_wg;
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
});
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
});
__builtin_amdgcn_sched_barrier(0); // Calculate ds_write number per M
if(mIter == 0)
{
dswrite_perM =
(dswrite_num_perK - (MIterPerWarp - DsWritePreIssue) * dswrite_rep) > 0
? dswrite_num_perK - (MIterPerWarp - DsWritePreIssue) * dswrite_rep
: 0;
}
else if(mIter >= MIterPerWarp - DsWritePreIssue + 1)
{
dswrite_perM = 0;
}
else
{
dswrite_perM = (dswrite_num_perK -
(MIterPerWarp - DsWritePreIssue - mIter) * dswrite_rep) > 0
? dswrite_rep
: 0;
}
// Add ds write when ds write data > needed
if(dswrite_num_perK == 0 && kIter == (KIterPerWarp - 1 - dswrite_kIter))
{
if(mIter == MIterPerWarp - 1 - dswrite_mIter)
dswrite_perM = 1;
}
// Calculate buffer_load number per M
if(mIter < HalfMIter)
{
load_perM =
((Aload_num_perK - (MIterPerWarp - 1 - mIter) * Aload_rep) > 0 ? Aload_rep
: 0) +
((Bload_num_perK - (HalfMIter - 1 - mIter) * Bload_rep) > 0 ? Bload_rep
: 0);
}
else
{
load_perM = (Aload_num_perK - (MIterPerWarp - 1 - mIter) * Aload_rep) > 0
? Aload_rep
: 0;
}
SchedulerPerM(dsread_perM, dswrite_perM, load_perM);
} }
else
{
static_for<0, 2, 1>{}([&](auto j) {
ignore = j;
static_for<0, 3, 1>{}([&](auto i) {
ignore = i;
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
});
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
static_for<0, 3, 1>{}([&](auto i) {
ignore = i;
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
});
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
});
__builtin_amdgcn_sched_barrier(0);
}
// MFMA → MFMA → MFMA → MFMA → DS Read
// For other device engine we need more aggressive MFMA with DS writes interleaved
#else
if constexpr(kMPerBlock == 128 && kNPerBlock == 256 && kKPerBlock == 256)
{
static_for<0, 2, 1>{}([&](auto j) {
ignore = j;
// Uses loops to amortize scheduling overhead
static_for<0, 4, 1>{}([&](auto i) {
ignore = i;
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
});
static_for<0, 1, 1>{}([&](auto i) {
ignore = i;
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
});
static_for<0, 1, 1>{}([&](auto i) {
ignore = i;
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
});
static_for<0, 1, 1>{}([&](auto i) {
ignore = i;
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
});
static_for<0, 1, 1>{}([&](auto i) {
ignore = i;
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
});
});
__builtin_amdgcn_sched_barrier(0);
}
else if constexpr(kMPerBlock == 16 && kNPerBlock == 64 && kKPerBlock == 256)
{
static_for<0, 1, 1>{}([&](auto i) {
ignore = i;
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
});
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
__builtin_amdgcn_sched_barrier(0);
}
else if constexpr(kMPerBlock == 128 && kNPerBlock == 128 && kKPerBlock == 128)
{
// prioritize MFMA to avoid LDS write conflicts
static_for<0, 2, 1>{}([&](auto j) {
ignore = j;
static_for<0, 2, 1>{}([&](auto i) {
ignore = i;
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
});
static_for<0, 2, 1>{}([&](auto i) {
ignore = i;
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
});
static_for<0, 1, 1>{}([&](auto i) {
ignore = i;
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
});
static_for<0, 1, 1>{}([&](auto i) {
ignore = i;
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
});
static_for<0, 1, 1>{}([&](auto i) {
ignore = i;
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
});
static_for<0, 1, 1>{}([&](auto i) {
ignore = i;
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
});
});
__builtin_amdgcn_sched_barrier(0);
}
else
{
static_for<0, A_Buffer_Load_Inst_Num, 1>{}([&](auto i) {
ignore = i;
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
});
static_for<0, A_LDS_Read_Inst_Num - A_Buffer_Load_Inst_Num, 1>{}([&](auto i) {
ignore = i;
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
__builtin_amdgcn_sched_group_barrier(0x008, 3, 0); // MFMA
});
static_for<0, B_Buffer_Load_Inst_Num, 1>{}([&](auto i) {
ignore = i;
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
__builtin_amdgcn_sched_group_barrier(0x008, 2, 0); // MFMA
});
static_for<0, A_Buffer_Load_Inst_Num, 1>{}([&](auto i) {
ignore = i;
__builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
__builtin_amdgcn_sched_group_barrier(0x008, 4, 0); // MFMA
});
}
#endif
} }
else // Add Aload when Aload data > needed
if(Aload_num_perK == 0)
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
__builtin_amdgcn_sched_barrier(0);
}
CK_TILE_HOST_DEVICE static constexpr auto Last2ndHotLoopScheduler()
{
#pragma unroll
for(int kIter = 0; kIter < KIterPerWarp; kIter++)
{ {
if constexpr((A_LDS_Read_Inst_Num / 2 > #pragma unroll
A_Buffer_Load_Inst_Num + B_Buffer_Load_Inst_Num)) for(int mIter = 0; mIter < MIterPerWarp; mIter++)
{ {
static_for<0, index_t dsread_perM = 0;
A_LDS_Read_Inst_Num / 2 - A_Buffer_Load_Inst_Num - index_t dswrite_perM = 0;
B_Buffer_Load_Inst_Num, index_t load_perM = 0;
1>{}([&](auto i) {
ignore = i; // Calculate ds_read number per M
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read dsread_perM = dsread_per_wg;
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
}); // Calculate ds_write number per M
if(mIter == 0)
{
dswrite_perM =
(dswrite_num_perK - (MIterPerWarp - DsWritePreIssue) * dswrite_rep) > 0
? dswrite_num_perK - (MIterPerWarp - DsWritePreIssue) * dswrite_rep
: 0;
}
else if(mIter >= MIterPerWarp - DsWritePreIssue + 1)
{
dswrite_perM = 0;
}
else
{
dswrite_perM = (dswrite_num_perK -
(MIterPerWarp - DsWritePreIssue - mIter) * dswrite_rep) > 0
? dswrite_rep
: 0;
}
// Add ds write when ds write data > needed
if(dswrite_num_perK == 0 && kIter == (KIterPerWarp - 1 - dswrite_kIter))
{
if(mIter == MIterPerWarp - 1 - dswrite_mIter)
dswrite_perM = 1;
}
// Calculate buffer_load number per M
if(mIter < HalfMIter)
{
load_perM =
((Bload_num_perK - (HalfMIter - 1 - mIter) * Bload_rep) > 0 ? Bload_rep
: 0);
}
SchedulerPerM(dsread_perM, dswrite_perM, load_perM);
} }
static_for<0, A_Buffer_Load_Inst_Num, 1>{}([&](auto i) {
ignore = i;
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
});
static_for<0, A_LDS_Read_Inst_Num / 2, 1>{}([&](auto i) {
ignore = i;
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
});
static_for<0, B_Buffer_Load_Inst_Num, 1>{}([&](auto i) {
ignore = i;
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
});
static_for<0, A_Buffer_Load_Inst_Num, 1>{}([&](auto i) {
ignore = i;
__builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
__builtin_amdgcn_sched_group_barrier(0x008, 3, 0); // MFMA
});
__builtin_amdgcn_sched_group_barrier(0x008, 4, 0); // MFMA
} }
__builtin_amdgcn_sched_barrier(0);
}
CK_TILE_HOST_DEVICE static constexpr auto LastHotLoopScheduler()
{
#pragma unroll
for(int kIter = 0; kIter < KIterPerWarp; kIter++)
{
#pragma unroll
for(int mIter = 0; mIter < MIterPerWarp; mIter++)
{
index_t dsread_perM = 0;
index_t dswrite_perM = 0;
index_t load_perM = 0;
// Calculate ds_read number per M
if((kIter * MIterPerWarp + mIter) < (KIterPerWarp * MIterPerWarp - m_preload))
dsread_perM = dsread_per_wg;
SchedulerPerM(dsread_perM, dswrite_perM, load_perM);
}
}
// __builtin_amdgcn_sched_barrier(0);
} }
template <TailNumber TailNum, template <TailNumber TailNum,
@@ -738,22 +666,19 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV2
block_sync_lds(); block_sync_lds();
// preload A00,A10 from lds // preload A00,A10 from lds
constexpr auto m_preload = (MIterPerWarp * KIterPerWarp >= 2) ? 2 : 1;
statically_indexed_array<decltype(load_tile(a_warp_windows_ping(number<0>{})(number<0>{}))), statically_indexed_array<decltype(load_tile(a_warp_windows_ping(number<0>{})(number<0>{}))),
m_preload> m_preload>
a_warp_tensor_ping; a_warp_tensor;
statically_indexed_array<decltype(load_tile(a_warp_windows_pong(number<0>{})(number<0>{}))),
m_preload>
a_warp_tensor_pong;
static_for<0, m_preload, 1>{}([&](auto loadIter) { static_for<0, m_preload, 1>{}([&](auto loadIter) {
constexpr auto mIter = loadIter % MIterPerWarp; constexpr auto mIter = loadIter % MIterPerWarp;
constexpr auto kIter = loadIter / MIterPerWarp; constexpr auto kIter = loadIter / MIterPerWarp;
a_warp_tensor_ping(loadIter) = a_warp_tensor(loadIter) =
load_tile(a_warp_windows_ping(number<mIter>{})(number<kIter>{})); load_tile(a_warp_windows_ping(number<mIter>{})(number<kIter>{}));
}); });
__builtin_amdgcn_sched_barrier(0); __builtin_amdgcn_sched_barrier(0);
// MAIN LOOP
index_t iCounter = (num_loop - 1) / 2; index_t iCounter = (num_loop - 1) / 2;
while(iCounter > 0) while(iCounter > 0)
{ {
@@ -792,7 +717,7 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV2
// warp GEMM // warp GEMM
WG{}(c_warp_tensor, WG{}(c_warp_tensor,
a_warp_tensor_ping(number<AwarpIter>{}), a_warp_tensor(number<AwarpIter>{}),
b_warp_tensor_ping(nIter)(kIter)); b_warp_tensor_ping(nIter)(kIter));
// write C warp tensor into C block tensor // write C warp tensor into C block tensor
@@ -809,7 +734,7 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV2
{ {
constexpr auto AmIter = (mIter + m_preload) % MIterPerWarp; constexpr auto AmIter = (mIter + m_preload) % MIterPerWarp;
constexpr auto AkIter = (kIter + (mIter + m_preload) / MIterPerWarp); constexpr auto AkIter = (kIter + (mIter + m_preload) / MIterPerWarp);
a_warp_tensor_ping(number<AwarpIter>{}) = a_warp_tensor(number<AwarpIter>{}) =
load_tile(a_warp_windows_ping(number<AmIter>{})(number<AkIter>{})); load_tile(a_warp_windows_ping(number<AmIter>{})(number<AkIter>{}));
} }
@@ -826,7 +751,7 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV2
static_for<0, m_preload, 1>{}([&](auto loadIter) { static_for<0, m_preload, 1>{}([&](auto loadIter) {
constexpr auto mIter = loadIter % MIterPerWarp; constexpr auto mIter = loadIter % MIterPerWarp;
constexpr auto kIter = loadIter / MIterPerWarp; constexpr auto kIter = loadIter / MIterPerWarp;
a_warp_tensor_pong(loadIter) = a_warp_tensor(loadIter) =
load_tile(a_warp_windows_pong(number<mIter>{})(number<kIter>{})); load_tile(a_warp_windows_pong(number<mIter>{})(number<kIter>{}));
}); });
HotLoopScheduler(); HotLoopScheduler();
@@ -867,7 +792,7 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV2
// warp GEMM // warp GEMM
WG{}(c_warp_tensor, WG{}(c_warp_tensor,
a_warp_tensor_pong(number<AwarpIter>{}), a_warp_tensor(number<AwarpIter>{}),
b_warp_tensor_pong(nIter)(kIter)); b_warp_tensor_pong(nIter)(kIter));
// write C warp tensor into C block tensor // write C warp tensor into C block tensor
@@ -884,7 +809,7 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV2
{ {
constexpr auto AmIter = (mIter + m_preload) % MIterPerWarp; constexpr auto AmIter = (mIter + m_preload) % MIterPerWarp;
constexpr auto AkIter = (kIter + (mIter + m_preload) / MIterPerWarp); constexpr auto AkIter = (kIter + (mIter + m_preload) / MIterPerWarp);
a_warp_tensor_pong(number<AwarpIter>{}) = a_warp_tensor(number<AwarpIter>{}) =
load_tile(a_warp_windows_pong(number<AmIter>{})(number<AkIter>{})); load_tile(a_warp_windows_pong(number<AmIter>{})(number<AkIter>{}));
} }
@@ -901,7 +826,7 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV2
static_for<0, m_preload, 1>{}([&](auto loadIter) { static_for<0, m_preload, 1>{}([&](auto loadIter) {
constexpr auto mIter = loadIter % MIterPerWarp; constexpr auto mIter = loadIter % MIterPerWarp;
constexpr auto kIter = loadIter / MIterPerWarp; constexpr auto kIter = loadIter / MIterPerWarp;
a_warp_tensor_ping(loadIter) = a_warp_tensor(loadIter) =
load_tile(a_warp_windows_ping(number<mIter>{})(number<kIter>{})); load_tile(a_warp_windows_ping(number<mIter>{})(number<kIter>{}));
}); });
HotLoopScheduler(); HotLoopScheduler();
@@ -943,7 +868,7 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV2
// warp GEMM // warp GEMM
WG{}(c_warp_tensor, WG{}(c_warp_tensor,
a_warp_tensor_ping(number<AwarpIter>{}), a_warp_tensor(number<AwarpIter>{}),
b_warp_tensor_ping(nIter)(kIter)); b_warp_tensor_ping(nIter)(kIter));
// write C warp tensor into C block tensor // write C warp tensor into C block tensor
@@ -960,7 +885,7 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV2
{ {
constexpr auto AmIter = (mIter + m_preload) % MIterPerWarp; constexpr auto AmIter = (mIter + m_preload) % MIterPerWarp;
constexpr auto AkIter = (kIter + (mIter + m_preload) / MIterPerWarp); constexpr auto AkIter = (kIter + (mIter + m_preload) / MIterPerWarp);
a_warp_tensor_ping(number<AwarpIter>{}) = a_warp_tensor(number<AwarpIter>{}) =
load_tile(a_warp_windows_ping(number<AmIter>{})(number<AkIter>{})); load_tile(a_warp_windows_ping(number<AmIter>{})(number<AkIter>{}));
} }
@@ -976,11 +901,11 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV2
static_for<0, m_preload, 1>{}([&](auto loadIter) { static_for<0, m_preload, 1>{}([&](auto loadIter) {
constexpr auto mIter = loadIter % MIterPerWarp; constexpr auto mIter = loadIter % MIterPerWarp;
constexpr auto kIter = loadIter / MIterPerWarp; constexpr auto kIter = loadIter / MIterPerWarp;
a_warp_tensor_pong(loadIter) = a_warp_tensor(loadIter) =
load_tile(a_warp_windows_pong(number<mIter>{})(number<kIter>{})); load_tile(a_warp_windows_pong(number<mIter>{})(number<kIter>{}));
}); });
// __builtin_amdgcn_sched_barrier(0); Last2ndHotLoopScheduler();
// GEMM loopK // GEMM loopK
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
@@ -996,7 +921,7 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV2
// warp GEMM // warp GEMM
WG{}(c_warp_tensor, WG{}(c_warp_tensor,
a_warp_tensor_pong(number<AwarpIter>{}), a_warp_tensor(number<AwarpIter>{}),
b_warp_tensor_pong(nIter)(kIter)); b_warp_tensor_pong(nIter)(kIter));
// write C warp tensor into C block tensor // write C warp tensor into C block tensor
@@ -1004,19 +929,23 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV2
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros), merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths), merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
c_warp_tensor.get_thread_buffer()); c_warp_tensor.get_thread_buffer());
__builtin_amdgcn_sched_barrier(0x7F6);
}); });
if constexpr((kIter * MIterPerWarp + mIter) < if constexpr((kIter * MIterPerWarp + mIter) <
(KIterPerWarp * MIterPerWarp - m_preload)) (KIterPerWarp * MIterPerWarp - m_preload))
{ {
constexpr auto AmIter = (mIter + m_preload) % MIterPerWarp; constexpr auto AmIter = (mIter + m_preload) % MIterPerWarp;
constexpr auto AkIter = (kIter + (mIter + m_preload) / MIterPerWarp); constexpr auto AkIter = (kIter + (mIter + m_preload) / MIterPerWarp);
a_warp_tensor_pong(number<AwarpIter>{}) = a_warp_tensor(number<AwarpIter>{}) =
load_tile(a_warp_windows_pong(number<AmIter>{})(number<AkIter>{})); load_tile(a_warp_windows_pong(number<AmIter>{})(number<AkIter>{}));
} }
// barrier
if constexpr((kIter == KIterPerWarp - 1) && (mIter == MIter_2nd_last))
{
block_sync_lds();
}
}); });
}); });
// TailHotLoopScheduler(); LastHotLoopScheduler();
} }
else if constexpr(TailNum == TailNumber::Odd) else if constexpr(TailNum == TailNumber::Odd)
{ {
@@ -1034,7 +963,7 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV2
// warp GEMM // warp GEMM
WG{}(c_warp_tensor, WG{}(c_warp_tensor,
a_warp_tensor_ping(number<AwarpIter>{}), a_warp_tensor(number<AwarpIter>{}),
b_warp_tensor_ping(nIter)(kIter)); b_warp_tensor_ping(nIter)(kIter));
// write C warp tensor into C block tensor // write C warp tensor into C block tensor
@@ -1051,7 +980,7 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV2
{ {
constexpr auto AmIter = (mIter + m_preload) % MIterPerWarp; constexpr auto AmIter = (mIter + m_preload) % MIterPerWarp;
constexpr auto AkIter = (kIter + (mIter + m_preload) / MIterPerWarp); constexpr auto AkIter = (kIter + (mIter + m_preload) / MIterPerWarp);
a_warp_tensor_ping(number<AwarpIter>{}) = a_warp_tensor(number<AwarpIter>{}) =
load_tile(a_warp_windows_ping(number<AmIter>{})(number<AkIter>{})); load_tile(a_warp_windows_ping(number<AmIter>{})(number<AkIter>{}));
} }
@@ -1062,6 +991,7 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV2
} }
}); });
}); });
LastHotLoopScheduler();
} }
return c_block_tile; return c_block_tile;