mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-03-21 23:57:39 +00:00
[CK_TILE] add atomic IGLP scheduler for wp gemm (#2739)
* add atomic IGLP scheduler * clang format --------- Co-authored-by: ThomasNing <thomas.ning@amd.com>
This commit is contained in:
@@ -91,7 +91,7 @@ int main(int argc, char* argv[])
|
||||
|
||||
try
|
||||
{
|
||||
return !run_gemm_example<GemmConfigPreshuffleDecode>(arg_parser);
|
||||
return !run_gemm_example<GemmConfigPreshufflePrefill>(arg_parser);
|
||||
}
|
||||
catch(const std::runtime_error& e)
|
||||
{
|
||||
|
||||
@@ -43,6 +43,7 @@ struct BaseWeightPreshufflePipelineAGmemBGmemCRegV2
|
||||
return run_func(bool_constant<true>{},
|
||||
integral_constant<TailNumber, TailNumber::Even>{});
|
||||
}
|
||||
return run_func(bool_constant<true>{}, integral_constant<TailNumber, TailNumber::Empty>{});
|
||||
}
|
||||
};
|
||||
|
||||
@@ -69,7 +70,11 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV2
|
||||
|
||||
using WG = remove_cvref_t<decltype(config.template at<0>())>;
|
||||
|
||||
static constexpr index_t DsWritePreIssue = 3; // default 2, ds write at MIter - 2
|
||||
static constexpr index_t DsReadPreload = 2; // default 2, preload 2 ds read
|
||||
|
||||
static constexpr index_t BlockSize = Problem::kBlockSize;
|
||||
static constexpr index_t WaveSize = get_warp_size();
|
||||
|
||||
static constexpr index_t kMPerBlock = BlockGemmShape::kM;
|
||||
static constexpr index_t kNPerBlock = BlockGemmShape::kN;
|
||||
@@ -129,13 +134,33 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV2
|
||||
static constexpr index_t MPerBlockPerIter = kMPerBlock / MIterPerWarp;
|
||||
static constexpr index_t KPerBlockPerIter = kKPerBlock / KIterPerWarp;
|
||||
|
||||
static constexpr index_t K1 = Problem::VectorLoadSize / sizeof(ADataType);
|
||||
static constexpr index_t ACopyLoadNum = kMPerBlock * kKPerBlock / BlockSize / K1;
|
||||
static constexpr auto TailNum = Problem::TailNum;
|
||||
static constexpr index_t K1 = Problem::VectorLoadSize / sizeof(ADataType);
|
||||
static constexpr index_t m_preload = (MIterPerWarp * KIterPerWarp >= DsReadPreload)
|
||||
? DsReadPreload
|
||||
: MIterPerWarp * KIterPerWarp;
|
||||
static constexpr auto TailNum = Problem::TailNum;
|
||||
|
||||
static constexpr auto warp_m = WarpTile::at(idxM);
|
||||
static constexpr auto warp_n = WarpTile::at(idxN);
|
||||
static constexpr auto warp_k = WarpTile::at(idxK);
|
||||
#ifdef __gfx942__
|
||||
static constexpr index_t mfma_per_wg = 2;
|
||||
#else
|
||||
static constexpr index_t mfma_per_wg = 1;
|
||||
#endif
|
||||
static constexpr index_t dsread_per_wg =
|
||||
WG::kM * WG::kK * sizeof(ADataType) / WaveSize / Problem::VectorLoadSize;
|
||||
static_assert((WG::kM * WG::kK * sizeof(ADataType) / WaveSize) % Problem::VectorLoadSize == 0);
|
||||
|
||||
static constexpr index_t dsread_num_perK = dsread_per_wg * MIterPerWarp;
|
||||
static constexpr index_t dswrite_num_perK = dsread_num_perK / (MWarp * NWarp);
|
||||
static constexpr index_t dswrite_rep = (dswrite_num_perK + MIterPerWarp - 1) / MIterPerWarp;
|
||||
static constexpr index_t Aload_num_perK = dswrite_num_perK;
|
||||
static constexpr index_t Aload_rep = dswrite_rep;
|
||||
static constexpr index_t Bload_num_perK = kNPerBlock * WG::kK / NWarp / K1 / WaveSize;
|
||||
static constexpr index_t HalfMIter = (MIterPerWarp + 1) / 2;
|
||||
static constexpr index_t Bload_rep = (Bload_num_perK + HalfMIter - 1) / HalfMIter;
|
||||
|
||||
static constexpr index_t mfma_perM_perK = NIterPerWarp * mfma_per_wg;
|
||||
static constexpr index_t dswrite_mIter = (DsWritePreIssue - 1) % MIterPerWarp;
|
||||
static constexpr index_t dswrite_kIter = (DsWritePreIssue - 1) / MIterPerWarp;
|
||||
|
||||
[[nodiscard]] CK_TILE_HOST static const std::string GetName()
|
||||
{
|
||||
@@ -160,411 +185,314 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV2
|
||||
return PipelinePolicy::template GetSmemSize<Problem>();
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr auto HotLoopScheduler()
|
||||
// dsread_perM: how many LDS reads want to issue in this M-iter
|
||||
// dswrite_perM: how many LDS writes you want to do this M-iter
|
||||
// load_perM: how many global loads VMEM want to do in this M-iter
|
||||
CK_TILE_HOST_DEVICE static constexpr auto
|
||||
SchedulerPerM(index_t dsread_perM, index_t dswrite_perM, index_t load_perM)
|
||||
{
|
||||
|
||||
constexpr index_t KPerLoad = Problem::VectorLoadSize / sizeof(ADataType);
|
||||
constexpr index_t A_Buffer_Load_Inst_Num = kMPerBlock * kKPerBlock / BlockSize / KPerLoad;
|
||||
constexpr index_t A_LDS_Read_Inst_Num = MIterPerWarp * KIterPerWarp;
|
||||
constexpr index_t B_Buffer_Load_Inst_Num = NIterPerWarp * KIterPerWarp;
|
||||
// Init inst order
|
||||
index_t max_data_inst = dsread_perM > load_perM
|
||||
? (dsread_perM > dswrite_perM ? dsread_perM : dswrite_perM)
|
||||
: (load_perM > dswrite_perM ? load_perM : dswrite_perM);
|
||||
index_t sum_data_inst = dsread_perM + load_perM + dswrite_perM;
|
||||
index_t round_data_inst = ck_tile::integer_divide_ceil(sum_data_inst, mfma_perM_perK);
|
||||
|
||||
constexpr int kOrderCap = NIterPerWarp * 10;
|
||||
index_t inst_order[kOrderCap] = {};
|
||||
index_t index = 0;
|
||||
#pragma unroll
|
||||
// round-robin
|
||||
// Index: 0 1 2 3 4 5 ...
|
||||
// Value: 1 2 3 1 2 3 ...
|
||||
for(int j = 0; j < max_data_inst; j++)
|
||||
{
|
||||
if(dswrite_perM > j)
|
||||
{
|
||||
inst_order[index] = 1;
|
||||
index++;
|
||||
}
|
||||
if(load_perM > j)
|
||||
{
|
||||
inst_order[index] = 2;
|
||||
index++;
|
||||
}
|
||||
if(dsread_perM > j)
|
||||
{
|
||||
inst_order[index] = 3;
|
||||
index++;
|
||||
}
|
||||
}
|
||||
|
||||
// Schedule IGLP
|
||||
#pragma unroll
|
||||
for(int j = 0; j < mfma_perM_perK; j++)
|
||||
{
|
||||
index_t inst_idx = 0;
|
||||
if(j == 0)
|
||||
;
|
||||
else if(j == 1)
|
||||
inst_idx = mfma_perM_perK == 2 ? 1 : mfma_perM_perK - 2;
|
||||
else if(j == 2)
|
||||
inst_idx = mfma_perM_perK - 1;
|
||||
else
|
||||
inst_idx = mfma_perM_perK - j;
|
||||
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
|
||||
#pragma unroll
|
||||
for(int r = 0; r < round_data_inst; r++)
|
||||
{
|
||||
if(r % 2 == 0)
|
||||
{
|
||||
if(inst_order[inst_idx + r * mfma_perM_perK] == 1)
|
||||
{
|
||||
__builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
|
||||
}
|
||||
if(inst_order[inst_idx + r * mfma_perM_perK] == 2)
|
||||
{
|
||||
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
|
||||
}
|
||||
if(inst_order[inst_idx + r * mfma_perM_perK] == 3)
|
||||
{
|
||||
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
if(inst_order[(r + 1) * mfma_perM_perK - 1 - inst_idx] == 1)
|
||||
{
|
||||
__builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
|
||||
}
|
||||
if(inst_order[(r + 1) * mfma_perM_perK - 1 - inst_idx] == 2)
|
||||
{
|
||||
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
|
||||
}
|
||||
if(inst_order[(r + 1) * mfma_perM_perK - 1 - inst_idx] == 3)
|
||||
{
|
||||
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr auto HotLoopScheduler()
|
||||
{
|
||||
// Keypoint of pipeline optimize is workload balance in time
|
||||
// instruction schedule example(128X256X256, 1X4, 16X16X128):
|
||||
// Iter MNK MFMA ds_read ds_write A_load b_load
|
||||
// -1 M6N0: 57 - 8 - -
|
||||
// -1 M6N1: 58 1 - - -
|
||||
// -1 M6N2: 59 - - 7 -
|
||||
// -1 M6N3: 60 2 - - -
|
||||
// -1 M7N0: 61 - - - -
|
||||
// -1 M7N1: 62 - - - -
|
||||
// -1 M7N2: 63 - - - -
|
||||
// -1 M7N1: 62 3 - - -
|
||||
// -1 M7N2: 63 - - 8 -
|
||||
// -1 M7N3: 64 4 - - -
|
||||
// 0 M0N0K0: 1 - - - -
|
||||
// 0 M0N1: 2 - - - 2
|
||||
// 0 M0N2: 3 - - - -
|
||||
// 0 M0N0K0: 1 - - - 1
|
||||
// 0 M0N1: 2 5 - - -
|
||||
// 0 M0N2: 3 - - - 2
|
||||
// 0 M0N3: 4 6 - - -
|
||||
// 0 M1N0: 5 - - - -
|
||||
// 0 M1N1: 6 - - - 4
|
||||
// 0 M1N2: 7 - - - -
|
||||
// 0 M1N0: 5 - - - 3
|
||||
// 0 M1N1: 6 7 - - -
|
||||
// 0 M1N2: 7 - - - 4
|
||||
// 0 M1N3: 8 8 - - -
|
||||
// 0 M2N0: 9 - - - -
|
||||
// 0 M2N1: 10 - - - 6
|
||||
// 0 M2N2: 11 - - - -
|
||||
// 0 M2N0: 9 - - - 5
|
||||
// 0 M2N1: 10 9 - - -
|
||||
// 0 M2N2: 11 - - - 6
|
||||
// 0 M2N3: 12 10 - - -
|
||||
// 0 M3N0: 13 - 1 - -
|
||||
// 0 M3N1: 14 - - - 8
|
||||
// 0 M3N2: 15 - - - -
|
||||
// 0 M3N0: 13 - 1 - 7
|
||||
// 0 M3N1: 14 11 - - -
|
||||
// 0 M3N2: 15 - - - 8
|
||||
// 0 M3N3: 16 12 - - -
|
||||
// 0 M4N0: 17 - 2 - -
|
||||
// 0 M4N1: 18 - - - -
|
||||
// 0 M4N1: 18 13 - - -
|
||||
// 0 M4N2: 19 - - 1 -
|
||||
// 0 M4N3: 20 14 - - -
|
||||
// 0 M5N0: 21 - 3 - -
|
||||
// 0 M5N1: 22 - - - -
|
||||
// 0 M5N1: 22 15 - - -
|
||||
// 0 M5N2: 23 - - 2 -
|
||||
// 0 M5N3: 24 16 - - -
|
||||
// 0 M6N0: 25 - 4 - -
|
||||
// 0 M6N1: 26 - - - -
|
||||
// 0 M6N1: 26 17 - - -
|
||||
// 0 M6N2: 27 - - 3 -
|
||||
// 0 M6N3: 28 17 - - -
|
||||
// 0 M6N3: 28 18 - - -
|
||||
// 0 M7N0: 29 - - - -
|
||||
// 0 M7N1: 30 - - - -
|
||||
// 0 M7N1: 30 19 - - -
|
||||
// 0 M7N2: 31 - - 4 -
|
||||
// 0 M7N3: 32 18 - - -
|
||||
// 0 M0N0K1: 33 - - - -
|
||||
// 0 M0N1: 34 - - - 10
|
||||
// 0 M0N2: 35 - - - -
|
||||
// 0 M0N3: 36 20 - - -
|
||||
// 0 M1N0: 37 - - - -
|
||||
// 0 M1N1: 38 - - - 12
|
||||
// 0 M1N2: 39 - - - -
|
||||
// 0 M1N3: 40 22 - - -
|
||||
// 0 M2N0: 41 - - - -
|
||||
// 0 M2N1: 42 - - - 14
|
||||
// 0 M2N2: 43 - - - -
|
||||
// 0 M2N3: 44 24 - - -
|
||||
// 0 M3N0: 45 - 5 - -
|
||||
// 0 M3N1: 46 - - - 16
|
||||
// 0 M3N2: 47 - - - -
|
||||
// 0 M3N3: 48 26 - - -
|
||||
// 0 M7N3: 32 20 - - -
|
||||
// 0 M0N0K1: 33 - - - 9
|
||||
// 0 M0N1: 34 21 - - -
|
||||
// 0 M0N2: 35 - - - 10
|
||||
// 0 M0N3: 36 22 - - -
|
||||
// 0 M1N0: 37 - - - 11
|
||||
// 0 M1N1: 38 23 - - -
|
||||
// 0 M1N2: 39 - - - 12
|
||||
// 0 M1N3: 40 24 - - -
|
||||
// 0 M2N0: 41 - - - 13
|
||||
// 0 M2N1: 42 25 - - -
|
||||
// 0 M2N2: 43 - - - 14
|
||||
// 0 M2N3: 44 26 - - -
|
||||
// 0 M3N0: 45 - 5 - 15
|
||||
// 0 M3N1: 46 27 - - -
|
||||
// 0 M3N2: 47 - - - 16
|
||||
// 0 M3N3: 48 28 - - -
|
||||
// 0 M4N0: 49 - 6 - -
|
||||
// 0 M4N1: 50 - - - -
|
||||
// 0 M4N1: 50 29 - - -
|
||||
// 0 M4N2: 51 - - 5 -
|
||||
// 0 M4N3: 52 28 - - -
|
||||
// 0 M4N3: 52 30 - - -
|
||||
// 0 M5N0: 53 - 7 - -
|
||||
// 0 M5N1: 54 - - - -
|
||||
// 0 M5N1: 54 31 - - -
|
||||
// 0 M5N2: 55 - - 6 -
|
||||
// 0 M5N3: 56 30 - - -
|
||||
// 0 M5N3: 56 32 - - -
|
||||
// 0 M6N0: 57 - 8 - -
|
||||
// 0 M6N1: 58 - - - -
|
||||
// 0 M6N1: 58 1 - - -
|
||||
// 0 M6N2: 59 - - 7 -
|
||||
// 0 M6N3: 60 2 - - -
|
||||
// 0 M7N0: 61 - - - -
|
||||
// 0 M7N1: 62 - - - -
|
||||
// 0 M7N1: 62 3 - - -
|
||||
// 0 M7N2: 63 - - 8 -
|
||||
// 0 M7N3: 64 4 - - -
|
||||
|
||||
if constexpr(warp_m == 16 && warp_n == 16)
|
||||
#pragma unroll
|
||||
for(int kIter = 0; kIter < KIterPerWarp; kIter++)
|
||||
{
|
||||
// MFMA -> VMEM READ -> MFMA -> DS Read -> MFMA
|
||||
// hiding the glbal memory VMEM latency
|
||||
#if defined(__gfx950__)
|
||||
if constexpr(kMPerBlock == 128 && kNPerBlock == 256 && kKPerBlock == 256)
|
||||
#pragma unroll
|
||||
for(int mIter = 0; mIter < MIterPerWarp; mIter++)
|
||||
{
|
||||
static_for<0, 2, 1>{}([&](auto j) {
|
||||
ignore = j;
|
||||
static_for<0, 3, 1>{}([&](auto i) {
|
||||
ignore = i;
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
|
||||
});
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
__builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
|
||||
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
|
||||
index_t dsread_perM = 0;
|
||||
index_t dswrite_perM = 0;
|
||||
index_t load_perM = 0;
|
||||
|
||||
static_for<0, 3, 1>{}([&](auto i) {
|
||||
ignore = i;
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
__builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
|
||||
});
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
|
||||
});
|
||||
// Calculate ds_read number per M
|
||||
dsread_perM = dsread_per_wg;
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
// Calculate ds_write number per M
|
||||
if(mIter == 0)
|
||||
{
|
||||
dswrite_perM =
|
||||
(dswrite_num_perK - (MIterPerWarp - DsWritePreIssue) * dswrite_rep) > 0
|
||||
? dswrite_num_perK - (MIterPerWarp - DsWritePreIssue) * dswrite_rep
|
||||
: 0;
|
||||
}
|
||||
else if(mIter >= MIterPerWarp - DsWritePreIssue + 1)
|
||||
{
|
||||
dswrite_perM = 0;
|
||||
}
|
||||
else
|
||||
{
|
||||
dswrite_perM = (dswrite_num_perK -
|
||||
(MIterPerWarp - DsWritePreIssue - mIter) * dswrite_rep) > 0
|
||||
? dswrite_rep
|
||||
: 0;
|
||||
}
|
||||
// Add ds write when ds write data > needed
|
||||
if(dswrite_num_perK == 0 && kIter == (KIterPerWarp - 1 - dswrite_kIter))
|
||||
{
|
||||
if(mIter == MIterPerWarp - 1 - dswrite_mIter)
|
||||
dswrite_perM = 1;
|
||||
}
|
||||
|
||||
// Calculate buffer_load number per M
|
||||
if(mIter < HalfMIter)
|
||||
{
|
||||
load_perM =
|
||||
((Aload_num_perK - (MIterPerWarp - 1 - mIter) * Aload_rep) > 0 ? Aload_rep
|
||||
: 0) +
|
||||
((Bload_num_perK - (HalfMIter - 1 - mIter) * Bload_rep) > 0 ? Bload_rep
|
||||
: 0);
|
||||
}
|
||||
else
|
||||
{
|
||||
load_perM = (Aload_num_perK - (MIterPerWarp - 1 - mIter) * Aload_rep) > 0
|
||||
? Aload_rep
|
||||
: 0;
|
||||
}
|
||||
SchedulerPerM(dsread_perM, dswrite_perM, load_perM);
|
||||
}
|
||||
else
|
||||
{
|
||||
static_for<0, 2, 1>{}([&](auto j) {
|
||||
ignore = j;
|
||||
static_for<0, 3, 1>{}([&](auto i) {
|
||||
ignore = i;
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
|
||||
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
|
||||
});
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
__builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
|
||||
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
|
||||
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
|
||||
|
||||
static_for<0, 3, 1>{}([&](auto i) {
|
||||
ignore = i;
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
__builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
|
||||
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
|
||||
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
|
||||
});
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
|
||||
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
|
||||
});
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
}
|
||||
// MFMA → MFMA → MFMA → MFMA → DS Read
|
||||
// For other device engine we need more aggressive MFMA with DS writes interleaved
|
||||
#else
|
||||
if constexpr(kMPerBlock == 128 && kNPerBlock == 256 && kKPerBlock == 256)
|
||||
{
|
||||
static_for<0, 2, 1>{}([&](auto j) {
|
||||
ignore = j;
|
||||
// Uses loops to amortize scheduling overhead
|
||||
static_for<0, 4, 1>{}([&](auto i) {
|
||||
ignore = i;
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
|
||||
});
|
||||
static_for<0, 1, 1>{}([&](auto i) {
|
||||
ignore = i;
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
|
||||
});
|
||||
|
||||
static_for<0, 1, 1>{}([&](auto i) {
|
||||
ignore = i;
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
__builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
|
||||
});
|
||||
static_for<0, 1, 1>{}([&](auto i) {
|
||||
ignore = i;
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
__builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
|
||||
});
|
||||
static_for<0, 1, 1>{}([&](auto i) {
|
||||
ignore = i;
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
|
||||
});
|
||||
});
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
}
|
||||
else if constexpr(kMPerBlock == 16 && kNPerBlock == 64 && kKPerBlock == 256)
|
||||
{
|
||||
static_for<0, 1, 1>{}([&](auto i) {
|
||||
ignore = i;
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
|
||||
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
|
||||
});
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
|
||||
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
__builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
|
||||
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
|
||||
|
||||
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
}
|
||||
else if constexpr(kMPerBlock == 128 && kNPerBlock == 128 && kKPerBlock == 128)
|
||||
{
|
||||
// prioritize MFMA to avoid LDS write conflicts
|
||||
static_for<0, 2, 1>{}([&](auto j) {
|
||||
ignore = j;
|
||||
static_for<0, 2, 1>{}([&](auto i) {
|
||||
ignore = i;
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
|
||||
});
|
||||
static_for<0, 2, 1>{}([&](auto i) {
|
||||
ignore = i;
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
|
||||
});
|
||||
static_for<0, 1, 1>{}([&](auto i) {
|
||||
ignore = i;
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
|
||||
});
|
||||
static_for<0, 1, 1>{}([&](auto i) {
|
||||
ignore = i;
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
__builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
|
||||
});
|
||||
static_for<0, 1, 1>{}([&](auto i) {
|
||||
ignore = i;
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
__builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
|
||||
});
|
||||
static_for<0, 1, 1>{}([&](auto i) {
|
||||
ignore = i;
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
|
||||
});
|
||||
});
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
}
|
||||
else
|
||||
{
|
||||
static_for<0, A_Buffer_Load_Inst_Num, 1>{}([&](auto i) {
|
||||
ignore = i;
|
||||
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
|
||||
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
});
|
||||
static_for<0, A_LDS_Read_Inst_Num - A_Buffer_Load_Inst_Num, 1>{}([&](auto i) {
|
||||
ignore = i;
|
||||
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 3, 0); // MFMA
|
||||
});
|
||||
static_for<0, B_Buffer_Load_Inst_Num, 1>{}([&](auto i) {
|
||||
ignore = i;
|
||||
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 2, 0); // MFMA
|
||||
});
|
||||
static_for<0, A_Buffer_Load_Inst_Num, 1>{}([&](auto i) {
|
||||
ignore = i;
|
||||
__builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 4, 0); // MFMA
|
||||
});
|
||||
}
|
||||
|
||||
#endif
|
||||
}
|
||||
else
|
||||
// Add Aload when Aload data > needed
|
||||
if(Aload_num_perK == 0)
|
||||
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr auto Last2ndHotLoopScheduler()
|
||||
{
|
||||
#pragma unroll
|
||||
for(int kIter = 0; kIter < KIterPerWarp; kIter++)
|
||||
{
|
||||
if constexpr((A_LDS_Read_Inst_Num / 2 >
|
||||
A_Buffer_Load_Inst_Num + B_Buffer_Load_Inst_Num))
|
||||
#pragma unroll
|
||||
for(int mIter = 0; mIter < MIterPerWarp; mIter++)
|
||||
{
|
||||
static_for<0,
|
||||
A_LDS_Read_Inst_Num / 2 - A_Buffer_Load_Inst_Num -
|
||||
B_Buffer_Load_Inst_Num,
|
||||
1>{}([&](auto i) {
|
||||
ignore = i;
|
||||
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
});
|
||||
index_t dsread_perM = 0;
|
||||
index_t dswrite_perM = 0;
|
||||
index_t load_perM = 0;
|
||||
|
||||
// Calculate ds_read number per M
|
||||
dsread_perM = dsread_per_wg;
|
||||
|
||||
// Calculate ds_write number per M
|
||||
if(mIter == 0)
|
||||
{
|
||||
dswrite_perM =
|
||||
(dswrite_num_perK - (MIterPerWarp - DsWritePreIssue) * dswrite_rep) > 0
|
||||
? dswrite_num_perK - (MIterPerWarp - DsWritePreIssue) * dswrite_rep
|
||||
: 0;
|
||||
}
|
||||
else if(mIter >= MIterPerWarp - DsWritePreIssue + 1)
|
||||
{
|
||||
dswrite_perM = 0;
|
||||
}
|
||||
else
|
||||
{
|
||||
dswrite_perM = (dswrite_num_perK -
|
||||
(MIterPerWarp - DsWritePreIssue - mIter) * dswrite_rep) > 0
|
||||
? dswrite_rep
|
||||
: 0;
|
||||
}
|
||||
// Add ds write when ds write data > needed
|
||||
if(dswrite_num_perK == 0 && kIter == (KIterPerWarp - 1 - dswrite_kIter))
|
||||
{
|
||||
if(mIter == MIterPerWarp - 1 - dswrite_mIter)
|
||||
dswrite_perM = 1;
|
||||
}
|
||||
|
||||
// Calculate buffer_load number per M
|
||||
if(mIter < HalfMIter)
|
||||
{
|
||||
load_perM =
|
||||
((Bload_num_perK - (HalfMIter - 1 - mIter) * Bload_rep) > 0 ? Bload_rep
|
||||
: 0);
|
||||
}
|
||||
SchedulerPerM(dsread_perM, dswrite_perM, load_perM);
|
||||
}
|
||||
static_for<0, A_Buffer_Load_Inst_Num, 1>{}([&](auto i) {
|
||||
ignore = i;
|
||||
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
});
|
||||
static_for<0, A_LDS_Read_Inst_Num / 2, 1>{}([&](auto i) {
|
||||
ignore = i;
|
||||
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
});
|
||||
static_for<0, B_Buffer_Load_Inst_Num, 1>{}([&](auto i) {
|
||||
ignore = i;
|
||||
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
});
|
||||
static_for<0, A_Buffer_Load_Inst_Num, 1>{}([&](auto i) {
|
||||
ignore = i;
|
||||
__builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 3, 0); // MFMA
|
||||
});
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 4, 0); // MFMA
|
||||
}
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr auto LastHotLoopScheduler()
|
||||
{
|
||||
#pragma unroll
|
||||
for(int kIter = 0; kIter < KIterPerWarp; kIter++)
|
||||
{
|
||||
#pragma unroll
|
||||
for(int mIter = 0; mIter < MIterPerWarp; mIter++)
|
||||
{
|
||||
index_t dsread_perM = 0;
|
||||
index_t dswrite_perM = 0;
|
||||
index_t load_perM = 0;
|
||||
|
||||
// Calculate ds_read number per M
|
||||
if((kIter * MIterPerWarp + mIter) < (KIterPerWarp * MIterPerWarp - m_preload))
|
||||
dsread_perM = dsread_per_wg;
|
||||
|
||||
SchedulerPerM(dsread_perM, dswrite_perM, load_perM);
|
||||
}
|
||||
}
|
||||
// __builtin_amdgcn_sched_barrier(0);
|
||||
}
|
||||
|
||||
template <TailNumber TailNum,
|
||||
@@ -738,22 +666,19 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV2
|
||||
block_sync_lds();
|
||||
|
||||
// preload A00,A10 from lds
|
||||
constexpr auto m_preload = (MIterPerWarp * KIterPerWarp >= 2) ? 2 : 1;
|
||||
statically_indexed_array<decltype(load_tile(a_warp_windows_ping(number<0>{})(number<0>{}))),
|
||||
m_preload>
|
||||
a_warp_tensor_ping;
|
||||
statically_indexed_array<decltype(load_tile(a_warp_windows_pong(number<0>{})(number<0>{}))),
|
||||
m_preload>
|
||||
a_warp_tensor_pong;
|
||||
a_warp_tensor;
|
||||
|
||||
static_for<0, m_preload, 1>{}([&](auto loadIter) {
|
||||
constexpr auto mIter = loadIter % MIterPerWarp;
|
||||
constexpr auto kIter = loadIter / MIterPerWarp;
|
||||
a_warp_tensor_ping(loadIter) =
|
||||
a_warp_tensor(loadIter) =
|
||||
load_tile(a_warp_windows_ping(number<mIter>{})(number<kIter>{}));
|
||||
});
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
|
||||
// MAIN LOOP
|
||||
index_t iCounter = (num_loop - 1) / 2;
|
||||
while(iCounter > 0)
|
||||
{
|
||||
@@ -792,7 +717,7 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV2
|
||||
|
||||
// warp GEMM
|
||||
WG{}(c_warp_tensor,
|
||||
a_warp_tensor_ping(number<AwarpIter>{}),
|
||||
a_warp_tensor(number<AwarpIter>{}),
|
||||
b_warp_tensor_ping(nIter)(kIter));
|
||||
|
||||
// write C warp tensor into C block tensor
|
||||
@@ -809,7 +734,7 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV2
|
||||
{
|
||||
constexpr auto AmIter = (mIter + m_preload) % MIterPerWarp;
|
||||
constexpr auto AkIter = (kIter + (mIter + m_preload) / MIterPerWarp);
|
||||
a_warp_tensor_ping(number<AwarpIter>{}) =
|
||||
a_warp_tensor(number<AwarpIter>{}) =
|
||||
load_tile(a_warp_windows_ping(number<AmIter>{})(number<AkIter>{}));
|
||||
}
|
||||
|
||||
@@ -826,7 +751,7 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV2
|
||||
static_for<0, m_preload, 1>{}([&](auto loadIter) {
|
||||
constexpr auto mIter = loadIter % MIterPerWarp;
|
||||
constexpr auto kIter = loadIter / MIterPerWarp;
|
||||
a_warp_tensor_pong(loadIter) =
|
||||
a_warp_tensor(loadIter) =
|
||||
load_tile(a_warp_windows_pong(number<mIter>{})(number<kIter>{}));
|
||||
});
|
||||
HotLoopScheduler();
|
||||
@@ -867,7 +792,7 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV2
|
||||
|
||||
// warp GEMM
|
||||
WG{}(c_warp_tensor,
|
||||
a_warp_tensor_pong(number<AwarpIter>{}),
|
||||
a_warp_tensor(number<AwarpIter>{}),
|
||||
b_warp_tensor_pong(nIter)(kIter));
|
||||
|
||||
// write C warp tensor into C block tensor
|
||||
@@ -884,7 +809,7 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV2
|
||||
{
|
||||
constexpr auto AmIter = (mIter + m_preload) % MIterPerWarp;
|
||||
constexpr auto AkIter = (kIter + (mIter + m_preload) / MIterPerWarp);
|
||||
a_warp_tensor_pong(number<AwarpIter>{}) =
|
||||
a_warp_tensor(number<AwarpIter>{}) =
|
||||
load_tile(a_warp_windows_pong(number<AmIter>{})(number<AkIter>{}));
|
||||
}
|
||||
|
||||
@@ -901,7 +826,7 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV2
|
||||
static_for<0, m_preload, 1>{}([&](auto loadIter) {
|
||||
constexpr auto mIter = loadIter % MIterPerWarp;
|
||||
constexpr auto kIter = loadIter / MIterPerWarp;
|
||||
a_warp_tensor_ping(loadIter) =
|
||||
a_warp_tensor(loadIter) =
|
||||
load_tile(a_warp_windows_ping(number<mIter>{})(number<kIter>{}));
|
||||
});
|
||||
HotLoopScheduler();
|
||||
@@ -943,7 +868,7 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV2
|
||||
|
||||
// warp GEMM
|
||||
WG{}(c_warp_tensor,
|
||||
a_warp_tensor_ping(number<AwarpIter>{}),
|
||||
a_warp_tensor(number<AwarpIter>{}),
|
||||
b_warp_tensor_ping(nIter)(kIter));
|
||||
|
||||
// write C warp tensor into C block tensor
|
||||
@@ -960,7 +885,7 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV2
|
||||
{
|
||||
constexpr auto AmIter = (mIter + m_preload) % MIterPerWarp;
|
||||
constexpr auto AkIter = (kIter + (mIter + m_preload) / MIterPerWarp);
|
||||
a_warp_tensor_ping(number<AwarpIter>{}) =
|
||||
a_warp_tensor(number<AwarpIter>{}) =
|
||||
load_tile(a_warp_windows_ping(number<AmIter>{})(number<AkIter>{}));
|
||||
}
|
||||
|
||||
@@ -976,11 +901,11 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV2
|
||||
static_for<0, m_preload, 1>{}([&](auto loadIter) {
|
||||
constexpr auto mIter = loadIter % MIterPerWarp;
|
||||
constexpr auto kIter = loadIter / MIterPerWarp;
|
||||
a_warp_tensor_pong(loadIter) =
|
||||
a_warp_tensor(loadIter) =
|
||||
load_tile(a_warp_windows_pong(number<mIter>{})(number<kIter>{}));
|
||||
});
|
||||
|
||||
// __builtin_amdgcn_sched_barrier(0);
|
||||
Last2ndHotLoopScheduler();
|
||||
|
||||
// GEMM loopK
|
||||
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
|
||||
@@ -996,7 +921,7 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV2
|
||||
|
||||
// warp GEMM
|
||||
WG{}(c_warp_tensor,
|
||||
a_warp_tensor_pong(number<AwarpIter>{}),
|
||||
a_warp_tensor(number<AwarpIter>{}),
|
||||
b_warp_tensor_pong(nIter)(kIter));
|
||||
|
||||
// write C warp tensor into C block tensor
|
||||
@@ -1004,19 +929,23 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV2
|
||||
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
|
||||
c_warp_tensor.get_thread_buffer());
|
||||
__builtin_amdgcn_sched_barrier(0x7F6);
|
||||
});
|
||||
if constexpr((kIter * MIterPerWarp + mIter) <
|
||||
(KIterPerWarp * MIterPerWarp - m_preload))
|
||||
{
|
||||
constexpr auto AmIter = (mIter + m_preload) % MIterPerWarp;
|
||||
constexpr auto AkIter = (kIter + (mIter + m_preload) / MIterPerWarp);
|
||||
a_warp_tensor_pong(number<AwarpIter>{}) =
|
||||
a_warp_tensor(number<AwarpIter>{}) =
|
||||
load_tile(a_warp_windows_pong(number<AmIter>{})(number<AkIter>{}));
|
||||
}
|
||||
// barrier
|
||||
if constexpr((kIter == KIterPerWarp - 1) && (mIter == MIter_2nd_last))
|
||||
{
|
||||
block_sync_lds();
|
||||
}
|
||||
});
|
||||
});
|
||||
// TailHotLoopScheduler();
|
||||
LastHotLoopScheduler();
|
||||
}
|
||||
else if constexpr(TailNum == TailNumber::Odd)
|
||||
{
|
||||
@@ -1034,7 +963,7 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV2
|
||||
|
||||
// warp GEMM
|
||||
WG{}(c_warp_tensor,
|
||||
a_warp_tensor_ping(number<AwarpIter>{}),
|
||||
a_warp_tensor(number<AwarpIter>{}),
|
||||
b_warp_tensor_ping(nIter)(kIter));
|
||||
|
||||
// write C warp tensor into C block tensor
|
||||
@@ -1051,7 +980,7 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV2
|
||||
{
|
||||
constexpr auto AmIter = (mIter + m_preload) % MIterPerWarp;
|
||||
constexpr auto AkIter = (kIter + (mIter + m_preload) / MIterPerWarp);
|
||||
a_warp_tensor_ping(number<AwarpIter>{}) =
|
||||
a_warp_tensor(number<AwarpIter>{}) =
|
||||
load_tile(a_warp_windows_ping(number<AmIter>{})(number<AkIter>{}));
|
||||
}
|
||||
|
||||
@@ -1062,6 +991,7 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV2
|
||||
}
|
||||
});
|
||||
});
|
||||
LastHotLoopScheduler();
|
||||
}
|
||||
|
||||
return c_block_tile;
|
||||
|
||||
Reference in New Issue
Block a user