Improve mixed-precision flatmm pipeline

This commit is contained in:
yadaish
2025-11-18 10:12:11 +00:00
parent bb3d2f5be6
commit b99c48da2e

View File

@@ -71,7 +71,11 @@ struct F16xMXF4FlatmmPipelineAGmemBGmemCRegV1
using WG = remove_cvref_t<decltype(config.template at<0>())>;
static constexpr index_t DsWritePreIssue = 3; // default 2, ds write at MIter - 2
#if CKTILE_FLATMM_USE_BUFFER_LOAD_LDS
static constexpr index_t DsReadPreload = 16; // default 8, if using lds, register pressure is alleviated, improve preload
#else
static constexpr index_t DsReadPreload = 2; // default 2, preload 2 ds read
#endif
static constexpr index_t BlockSize = Problem::kBlockSize;
static constexpr index_t WaveSize = get_warp_size();
@@ -186,11 +190,6 @@ struct F16xMXF4FlatmmPipelineAGmemBGmemCRegV1
CK_TILE_HOST_DEVICE static constexpr auto
SchedulerPerM(index_t dsread_perM, index_t dswrite_perM, index_t load_perM)
{
#if CKTILE_FLATMM_USE_BUFFER_LOAD_LDS
// GFX950 use BUFFER_LOAD_LDS to fill lds_buffer_A.
// There is no separate DS_WRITE instruction at all.
dswrite_perM = 0;
#endif
// Init inst order
index_t max_data_inst = dsread_perM > load_perM
? (dsread_perM > dswrite_perM ? dsread_perM : dswrite_perM)
@@ -360,7 +359,36 @@ struct F16xMXF4FlatmmPipelineAGmemBGmemCRegV1
// Calculate ds_read number per M
dsread_perM = dsread_per_wg;
#if CKTILE_FLATMM_USE_BUFFER_LOAD_LDS == 0
// Calculate ds_write number per M
if(mIter == 0)
{
dswrite_perM =
(dswrite_num_perK - (MIterPerWarp - DsWritePreIssue) * dswrite_rep) > 0
? dswrite_num_perK - (MIterPerWarp - DsWritePreIssue) * dswrite_rep
: 0;
}
else if(mIter >= MIterPerWarp - DsWritePreIssue + 1)
{
dswrite_perM = 0;
}
else
{
dswrite_perM = (dswrite_num_perK -
(MIterPerWarp - DsWritePreIssue - mIter) * dswrite_rep) > 0
? dswrite_rep
: 0;
}
// Add ds write when ds write data > needed
if(dswrite_num_perK == 0 && kIter == (KIterPerWarp - 1 - dswrite_kIter))
{
if(mIter == MIterPerWarp - 1 - dswrite_mIter)
dswrite_perM = 1;
}
#endif
// Calculate buffer_load number per M
#if CKTILE_FLATMM_USE_BUFFER_LOAD_LDS == 0
if(mIter < HalfMIter)
{
load_perM =
@@ -375,10 +403,17 @@ struct F16xMXF4FlatmmPipelineAGmemBGmemCRegV1
? Aload_rep
: 0;
}
if((kIter % KPerScaleLoad == 0) && (mIter == 0))
{
load_perM = load_perM + 1;
#else
if ((kIter * MIterPerWarp + mIter) >=
(KIterPerWarp * MIterPerWarp - m_preload)) {
load_perM = 1;
}
#endif
// if((kIter % KPerScaleLoad == 0) && (mIter == 0))
// {
// load_perM = load_perM + 1;
// }
// SchedulerPerM(dsread_perM, dswrite_perM, load_perM);
SchedulerPerM(dsread_perM, dswrite_perM, load_perM);
}
}
@@ -821,6 +856,8 @@ struct F16xMXF4FlatmmPipelineAGmemBGmemCRegV1
});
});
__builtin_amdgcn_sched_barrier(0);
// Prefill A(2i+1)
prefill_lds_a_stage2(a_copy_lds_window_pong);
@@ -866,12 +903,23 @@ struct F16xMXF4FlatmmPipelineAGmemBGmemCRegV1
load_tile(a_warp_windows_ping(number<AmIter>{})(number<AkIter>{}));
}
// yadai comments out the following
/*
// barrier
if constexpr((kIter == KIterPerWarp - 1) && (mIter == MIter_2nd_last))
{
__builtin_amdgcn_s_waitcnt(Bload_total_num);
block_sync_lds();
}
*/
// sync should be made as early as possible
if constexpr((kIter * MIterPerWarp + mIter) ==
(KIterPerWarp * MIterPerWarp - m_preload))
{
__builtin_amdgcn_s_waitcnt(Bload_total_num);
block_sync_lds();
}
});
});
prefill_lds_a_stage1(
@@ -928,6 +976,8 @@ struct F16xMXF4FlatmmPipelineAGmemBGmemCRegV1
});
});
__builtin_amdgcn_sched_barrier(0);
// Prefill A(2i+2)
prefill_lds_a_stage2(a_copy_lds_window_ping);
@@ -973,12 +1023,23 @@ struct F16xMXF4FlatmmPipelineAGmemBGmemCRegV1
load_tile(a_warp_windows_pong(number<AmIter>{})(number<AkIter>{}));
}
// yadai comments out the following
/*
// barrier
if constexpr((kIter == KIterPerWarp - 1) && (mIter == MIter_2nd_last))
{
__builtin_amdgcn_s_waitcnt(Bload_total_num);
block_sync_lds();
}
*/
// sync should be made as early as possible
if constexpr((kIter * MIterPerWarp + mIter) ==
(KIterPerWarp * MIterPerWarp - m_preload))
{
__builtin_amdgcn_s_waitcnt(Bload_total_num);
block_sync_lds();
}
});
});
prefill_lds_a_stage1(