diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp index 24bd66a59e..7a9d7ab1e0 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp @@ -32,6 +32,15 @@ struct GemmPipelineAgBgCrImplBase move_tile_window(dram_tile_window, dram_tile_window_step); } + template + CK_TILE_DEVICE void GlobalPrefetchAsync(DstBlockTile& dst_block_tile, + SrcTileWindow& dram_tile_window, + const DramTileWindowStep& dram_tile_window_step) const + { + async_load_tile(dst_block_tile, dram_tile_window); + move_tile_window(dram_tile_window, dram_tile_window_step); + } + template CK_TILE_DEVICE void LocalPrefill(DstTileWindow& lds_tile_window, const SrcBlockTile& src_block_tile, diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v3.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v3.hpp index c198c9443a..f96c42e01a 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v3.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v3.hpp @@ -408,8 +408,8 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3 // prefetch // global read 0 - Base::GlobalPrefetch(a_block_tile, a_copy_dram_window, a_dram_tile_window_step); - Base::GlobalPrefetch(b_block_tile, b_copy_dram_window, b_dram_tile_window_step); + Base::GlobalPrefetchAsync(a_block_tile, a_copy_dram_window, a_dram_tile_window_step); + Base::GlobalPrefetchAsync(b_block_tile, b_copy_dram_window, b_dram_tile_window_step); // initialize C tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile); @@ -438,8 +438,8 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3 Base::LocalPrefill(b_copy_lds_window, b_block_tile, b_element_func); } - Base::GlobalPrefetch(a_block_tile, a_copy_dram_window, a_dram_tile_window_step); - Base::GlobalPrefetch(b_block_tile, b_copy_dram_window, b_dram_tile_window_step); + Base::GlobalPrefetchAsync(a_block_tile, a_copy_dram_window, a_dram_tile_window_step); + Base::GlobalPrefetchAsync(b_block_tile, b_copy_dram_window, b_dram_tile_window_step); block_sync_lds(); block_gemm.LocalPrefetch(a_lds_gemm_window, b_lds_gemm_window); @@ -477,8 +477,10 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3 Base::LocalPrefill(b_copy_lds_window, b_block_tile, b_element_func); } - Base::GlobalPrefetch(a_block_tile, a_copy_dram_window, a_dram_tile_window_step); - Base::GlobalPrefetch(b_block_tile, b_copy_dram_window, b_dram_tile_window_step); + Base::GlobalPrefetchAsync( + a_block_tile, a_copy_dram_window, a_dram_tile_window_step); + Base::GlobalPrefetchAsync( + b_block_tile, b_copy_dram_window, b_dram_tile_window_step); block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window);