mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-30 19:57:40 +00:00
mixed-prec flatmm pipeline improve
This commit is contained in:
@@ -71,7 +71,11 @@ struct F16xMXF4FlatmmPipelineAGmemBGmemCRegV1
|
||||
using WG = remove_cvref_t<decltype(config.template at<0>())>;
|
||||
|
||||
static constexpr index_t DsWritePreIssue = 3; // default 2, ds write at MIter - 2
|
||||
#if CKTILE_FLATMM_USE_BUFFER_LOAD_LDS
|
||||
static constexpr index_t DsReadPreload = 16; // default 8, if using lds, register pressure is alleviated, improve preload
|
||||
#else
|
||||
static constexpr index_t DsReadPreload = 2; // default 2, preload 2 ds read
|
||||
#endif
|
||||
|
||||
static constexpr index_t BlockSize = Problem::kBlockSize;
|
||||
static constexpr index_t WaveSize = get_warp_size();
|
||||
@@ -186,11 +190,6 @@ struct F16xMXF4FlatmmPipelineAGmemBGmemCRegV1
|
||||
CK_TILE_HOST_DEVICE static constexpr auto
|
||||
SchedulerPerM(index_t dsread_perM, index_t dswrite_perM, index_t load_perM)
|
||||
{
|
||||
#if CKTILE_FLATMM_USE_BUFFER_LOAD_LDS
|
||||
// GFX950 use BUFFER_LOAD_LDS to fill lds_buffer_A.
|
||||
// There is no separate DS_WRITE instruction at all.
|
||||
dswrite_perM = 0;
|
||||
#endif
|
||||
// Init inst order
|
||||
index_t max_data_inst = dsread_perM > load_perM
|
||||
? (dsread_perM > dswrite_perM ? dsread_perM : dswrite_perM)
|
||||
@@ -360,7 +359,36 @@ struct F16xMXF4FlatmmPipelineAGmemBGmemCRegV1
|
||||
// Calculate ds_read number per M
|
||||
dsread_perM = dsread_per_wg;
|
||||
|
||||
#if CKTILE_FLATMM_USE_BUFFER_LOAD_LDS == 0
|
||||
// Calculate ds_write number per M
|
||||
if(mIter == 0)
|
||||
{
|
||||
dswrite_perM =
|
||||
(dswrite_num_perK - (MIterPerWarp - DsWritePreIssue) * dswrite_rep) > 0
|
||||
? dswrite_num_perK - (MIterPerWarp - DsWritePreIssue) * dswrite_rep
|
||||
: 0;
|
||||
}
|
||||
else if(mIter >= MIterPerWarp - DsWritePreIssue + 1)
|
||||
{
|
||||
dswrite_perM = 0;
|
||||
}
|
||||
else
|
||||
{
|
||||
dswrite_perM = (dswrite_num_perK -
|
||||
(MIterPerWarp - DsWritePreIssue - mIter) * dswrite_rep) > 0
|
||||
? dswrite_rep
|
||||
: 0;
|
||||
}
|
||||
// Add ds write when ds write data > needed
|
||||
if(dswrite_num_perK == 0 && kIter == (KIterPerWarp - 1 - dswrite_kIter))
|
||||
{
|
||||
if(mIter == MIterPerWarp - 1 - dswrite_mIter)
|
||||
dswrite_perM = 1;
|
||||
}
|
||||
#endif
|
||||
|
||||
// Calculate buffer_load number per M
|
||||
#if CKTILE_FLATMM_USE_BUFFER_LOAD_LDS == 0
|
||||
if(mIter < HalfMIter)
|
||||
{
|
||||
load_perM =
|
||||
@@ -375,10 +403,17 @@ struct F16xMXF4FlatmmPipelineAGmemBGmemCRegV1
|
||||
? Aload_rep
|
||||
: 0;
|
||||
}
|
||||
if((kIter % KPerScaleLoad == 0) && (mIter == 0))
|
||||
{
|
||||
load_perM = load_perM + 1;
|
||||
#else
|
||||
if ((kIter * MIterPerWarp + mIter) >=
|
||||
(KIterPerWarp * MIterPerWarp - m_preload)) {
|
||||
load_perM = 1;
|
||||
}
|
||||
#endif
|
||||
// if((kIter % KPerScaleLoad == 0) && (mIter == 0))
|
||||
// {
|
||||
// load_perM = load_perM + 1;
|
||||
// }
|
||||
// SchedulerPerM(dsread_perM, dswrite_perM, load_perM);
|
||||
SchedulerPerM(dsread_perM, dswrite_perM, load_perM);
|
||||
}
|
||||
}
|
||||
@@ -821,6 +856,8 @@ struct F16xMXF4FlatmmPipelineAGmemBGmemCRegV1
|
||||
});
|
||||
});
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
|
||||
// Prefill A(2i+1)
|
||||
prefill_lds_a_stage2(a_copy_lds_window_pong);
|
||||
|
||||
@@ -866,12 +903,23 @@ struct F16xMXF4FlatmmPipelineAGmemBGmemCRegV1
|
||||
load_tile(a_warp_windows_ping(number<AmIter>{})(number<AkIter>{}));
|
||||
}
|
||||
|
||||
// yadai comments out the following
|
||||
/*
|
||||
// barrier
|
||||
if constexpr((kIter == KIterPerWarp - 1) && (mIter == MIter_2nd_last))
|
||||
{
|
||||
__builtin_amdgcn_s_waitcnt(Bload_total_num);
|
||||
block_sync_lds();
|
||||
}
|
||||
*/
|
||||
|
||||
// sync should be made as early as possible
|
||||
if constexpr((kIter * MIterPerWarp + mIter) ==
|
||||
(KIterPerWarp * MIterPerWarp - m_preload))
|
||||
{
|
||||
__builtin_amdgcn_s_waitcnt(Bload_total_num);
|
||||
block_sync_lds();
|
||||
}
|
||||
});
|
||||
});
|
||||
prefill_lds_a_stage1(
|
||||
@@ -928,6 +976,8 @@ struct F16xMXF4FlatmmPipelineAGmemBGmemCRegV1
|
||||
});
|
||||
});
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
|
||||
// Prefill A(2i+2)
|
||||
prefill_lds_a_stage2(a_copy_lds_window_ping);
|
||||
|
||||
@@ -973,12 +1023,23 @@ struct F16xMXF4FlatmmPipelineAGmemBGmemCRegV1
|
||||
load_tile(a_warp_windows_pong(number<AmIter>{})(number<AkIter>{}));
|
||||
}
|
||||
|
||||
// yadai comments out the following
|
||||
/*
|
||||
// barrier
|
||||
if constexpr((kIter == KIterPerWarp - 1) && (mIter == MIter_2nd_last))
|
||||
{
|
||||
__builtin_amdgcn_s_waitcnt(Bload_total_num);
|
||||
block_sync_lds();
|
||||
}
|
||||
*/
|
||||
|
||||
// sync should be made as early as possible
|
||||
if constexpr((kIter * MIterPerWarp + mIter) ==
|
||||
(KIterPerWarp * MIterPerWarp - m_preload))
|
||||
{
|
||||
__builtin_amdgcn_s_waitcnt(Bload_total_num);
|
||||
block_sync_lds();
|
||||
}
|
||||
});
|
||||
});
|
||||
prefill_lds_a_stage1(
|
||||
|
||||
Reference in New Issue
Block a user