diff --git a/include/ck_tile/ops/flatmm/pipeline/mixed_prec_flatmm_pipeline_agmem_bgmem_creg_v1.hpp b/include/ck_tile/ops/flatmm/pipeline/mixed_prec_flatmm_pipeline_agmem_bgmem_creg_v1.hpp index 17c88e4f08..89c40a0d69 100644 --- a/include/ck_tile/ops/flatmm/pipeline/mixed_prec_flatmm_pipeline_agmem_bgmem_creg_v1.hpp +++ b/include/ck_tile/ops/flatmm/pipeline/mixed_prec_flatmm_pipeline_agmem_bgmem_creg_v1.hpp @@ -71,7 +71,11 @@ struct F16xMXF4FlatmmPipelineAGmemBGmemCRegV1 using WG = remove_cvref_t())>; static constexpr index_t DsWritePreIssue = 3; // default 2, ds write at MIter - 2 +#if CKTILE_FLATMM_USE_BUFFER_LOAD_LDS + static constexpr index_t DsReadPreload = 16; // default 8, if using lds, register pressure is alleviated, improve preload +#else static constexpr index_t DsReadPreload = 2; // default 2, preload 2 ds read +#endif static constexpr index_t BlockSize = Problem::kBlockSize; static constexpr index_t WaveSize = get_warp_size(); @@ -186,11 +190,6 @@ struct F16xMXF4FlatmmPipelineAGmemBGmemCRegV1 CK_TILE_HOST_DEVICE static constexpr auto SchedulerPerM(index_t dsread_perM, index_t dswrite_perM, index_t load_perM) { -#if CKTILE_FLATMM_USE_BUFFER_LOAD_LDS - // GFX950 use BUFFER_LOAD_LDS to fill lds_buffer_A. - // There is no separate DS_WRITE instruction at all. - dswrite_perM = 0; -#endif // Init inst order index_t max_data_inst = dsread_perM > load_perM ? (dsread_perM > dswrite_perM ? dsread_perM : dswrite_perM) @@ -360,7 +359,36 @@ struct F16xMXF4FlatmmPipelineAGmemBGmemCRegV1 // Calculate ds_read number per M dsread_perM = dsread_per_wg; +#if CKTILE_FLATMM_USE_BUFFER_LOAD_LDS == 0 + // Calculate ds_write number per M + if(mIter == 0) + { + dswrite_perM = + (dswrite_num_perK - (MIterPerWarp - DsWritePreIssue) * dswrite_rep) > 0 + ? 
dswrite_num_perK - (MIterPerWarp - DsWritePreIssue) * dswrite_rep + : 0; + } + else if(mIter >= MIterPerWarp - DsWritePreIssue + 1) + { + dswrite_perM = 0; + } + else + { + dswrite_perM = (dswrite_num_perK - + (MIterPerWarp - DsWritePreIssue - mIter) * dswrite_rep) > 0 + ? dswrite_rep + : 0; + } + // Add ds write when ds write data > needed + if(dswrite_num_perK == 0 && kIter == (KIterPerWarp - 1 - dswrite_kIter)) + { + if(mIter == MIterPerWarp - 1 - dswrite_mIter) + dswrite_perM = 1; + } +#endif + // Calculate buffer_load number per M +#if CKTILE_FLATMM_USE_BUFFER_LOAD_LDS == 0 if(mIter < HalfMIter) { load_perM = @@ -375,10 +403,17 @@ struct F16xMXF4FlatmmPipelineAGmemBGmemCRegV1 ? Aload_rep : 0; } - if((kIter % KPerScaleLoad == 0) && (mIter == 0)) - { - load_perM = load_perM + 1; +#else + if ((kIter * MIterPerWarp + mIter) >= + (KIterPerWarp * MIterPerWarp - m_preload)) { + load_perM = 1; } +#endif + // if((kIter % KPerScaleLoad == 0) && (mIter == 0)) + // { + // load_perM = load_perM + 1; + // } + // SchedulerPerM(dsread_perM, dswrite_perM, load_perM); SchedulerPerM(dsread_perM, dswrite_perM, load_perM); } } @@ -821,6 +856,8 @@ struct F16xMXF4FlatmmPipelineAGmemBGmemCRegV1 }); }); + __builtin_amdgcn_sched_barrier(0); + // Prefill A(2i+1) prefill_lds_a_stage2(a_copy_lds_window_pong); @@ -866,12 +903,23 @@ struct F16xMXF4FlatmmPipelineAGmemBGmemCRegV1 load_tile(a_warp_windows_ping(number{})(number{})); } + // yadai comments out the following + /* // barrier if constexpr((kIter == KIterPerWarp - 1) && (mIter == MIter_2nd_last)) { __builtin_amdgcn_s_waitcnt(Bload_total_num); block_sync_lds(); } + */ + + // sync should be made as early as possible + if constexpr((kIter * MIterPerWarp + mIter) == + (KIterPerWarp * MIterPerWarp - m_preload)) + { + __builtin_amdgcn_s_waitcnt(Bload_total_num); + block_sync_lds(); + } }); }); prefill_lds_a_stage1( @@ -928,6 +976,8 @@ struct F16xMXF4FlatmmPipelineAGmemBGmemCRegV1 }); }); + __builtin_amdgcn_sched_barrier(0); + // Prefill 
A(2i+2) prefill_lds_a_stage2(a_copy_lds_window_ping); @@ -973,12 +1023,23 @@ struct F16xMXF4FlatmmPipelineAGmemBGmemCRegV1 load_tile(a_warp_windows_pong(number{})(number{})); } + // yadai comments out the following + /* // barrier if constexpr((kIter == KIterPerWarp - 1) && (mIter == MIter_2nd_last)) { __builtin_amdgcn_s_waitcnt(Bload_total_num); block_sync_lds(); } + */ + + // sync should be made as early as possible + if constexpr((kIter * MIterPerWarp + mIter) == + (KIterPerWarp * MIterPerWarp - m_preload)) + { + __builtin_amdgcn_s_waitcnt(Bload_total_num); + block_sync_lds(); + } }); }); prefill_lds_a_stage1(