diff --git a/include/ck_tile/ops/flatmm/pipeline/mixed_prec_flatmm_pipeline_agmem_bgmem_creg_v1.hpp b/include/ck_tile/ops/flatmm/pipeline/mixed_prec_flatmm_pipeline_agmem_bgmem_creg_v1.hpp index 51019b8e47..2d80d6d620 100644 --- a/include/ck_tile/ops/flatmm/pipeline/mixed_prec_flatmm_pipeline_agmem_bgmem_creg_v1.hpp +++ b/include/ck_tile/ops/flatmm/pipeline/mixed_prec_flatmm_pipeline_agmem_bgmem_creg_v1.hpp @@ -184,7 +184,7 @@ struct F16xMXF4FlatmmPipelineAGmemBGmemCRegV1 CK_TILE_HOST_DEVICE static constexpr auto SchedulerPerM(index_t dsread_perM, index_t dswrite_perM, index_t load_perM) { -#if defined(__gfx950__) +#if CKTILE_FLATMM_USE_BUFFER_LOAD_LDS_AS_POSSIBLE && defined(__gfx950__) // GFX950 use BUFFER_LOAD_LDS to fill lds_buffer_A. // There is no separate DS_WRITE instruction at all. dswrite_perM = 0; @@ -658,10 +658,17 @@ struct F16xMXF4FlatmmPipelineAGmemBGmemCRegV1 using ABlockTile = decltype(load_tile(a_copy_dram_window)); ABlockTile a_block_tile; -#if defined(__gfx950__) - auto prefill_lds_a_stage1 = [&](auto lds_tile_a, auto dram_tile_a) { + enum + { + PrefillBeforeGemm = 1, + PrefillAfterGemm = 2, + PrefillAlways = PrefillBeforeGemm | PrefillAfterGemm, + }; +#if CKTILE_FLATMM_USE_BUFFER_LOAD_LDS_AS_POSSIBLE && defined(__gfx950__) + auto prefill_lds_a_stage1 = [&](auto lds_tile_a, auto dram_tile_a, auto prefill_location) { // global -> lds - async_load_tile(lds_tile_a, dram_tile_a); + if constexpr(prefill_location & PrefillAfterGemm) + async_load_tile(lds_tile_a, dram_tile_a); }; auto prefill_lds_a_stage2 = [&](auto lds_tile_a) { // data has been stored in lds, no need more operation. @@ -669,9 +676,10 @@ struct F16xMXF4FlatmmPipelineAGmemBGmemCRegV1 "buffer_load_lds don't support element func fot A before mfma"); }; #else - auto prefill_lds_a_stage1 = [&](auto lds_tile_a, auto dram_tile_a) { + auto prefill_lds_a_stage1 = [&](auto lds_tile_a, auto dram_tile_a, auto prefill_location) { // global -> vgpr - a_block_tile = load_tile(dram_tile_a); + if constexpr(prefill_location & PrefillBeforeGemm) + a_block_tile = load_tile(dram_tile_a); }; auto prefill_lds_a_stage2 = [&](auto lds_tile_a) { // vgpr -> lds @@ -682,7 +690,7 @@ struct F16xMXF4FlatmmPipelineAGmemBGmemCRegV1 // HEAD // Prefetch A0 - prefill_lds_a_stage1(a_copy_lds_window_ping, a_copy_dram_window); + prefill_lds_a_stage1(a_copy_lds_window_ping, a_copy_dram_window, number{}); // move A window to next k move_tile_window(a_copy_dram_window, {0, kKPerBlock}); @@ -725,7 +733,7 @@ struct F16xMXF4FlatmmPipelineAGmemBGmemCRegV1 __builtin_amdgcn_sched_barrier(0); // Prefetch A1 - prefill_lds_a_stage1(a_copy_lds_window_pong, a_copy_dram_window); + prefill_lds_a_stage1(a_copy_lds_window_pong, a_copy_dram_window, number{}); // move A window to next k move_tile_window(a_copy_dram_window, {0, kKPerBlock}); @@ -858,6 +866,12 @@ struct F16xMXF4FlatmmPipelineAGmemBGmemCRegV1 }); }); + // Prefill A(2i+1) + prefill_lds_a_stage2(a_copy_lds_window_pong); + + // Prefetch A(2i+2) + prefill_lds_a_stage1( + a_copy_lds_window_ping, a_copy_dram_window, number{}); // GEMM 2i static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { @@ -904,19 +918,16 @@ struct F16xMXF4FlatmmPipelineAGmemBGmemCRegV1 } }); }); + prefill_lds_a_stage1( + a_copy_lds_window_ping, a_copy_dram_window, number{}); + + // move A window to next k + move_tile_window(a_copy_dram_window, {0, kKPerBlock}); // move B window to next flat K move_tile_window(b_flat_dram_window, {0, MXFP4KPerWarp * KFlatPerBlockPerIter}); move_tile_window(scale_b_flat_dram_window, {0, ScaleKPerWarp * ScaleKFlatPerWarp}); - // Prefill A(2i+1) - prefill_lds_a_stage2(a_copy_lds_window_pong); - - // Prefetch A(2i+2) - prefill_lds_a_stage1(a_copy_lds_window_ping, a_copy_dram_window); - // move A window to next k - move_tile_window(a_copy_dram_window, {0, kKPerBlock}); - static_for<0, m_preload, 1>{}([&](auto loadIter) { constexpr auto mIter = loadIter % MIterPerWarp; constexpr auto kIter = loadIter / MIterPerWarp; @@ -960,6 +971,14 @@ struct F16xMXF4FlatmmPipelineAGmemBGmemCRegV1 b_warp_tensor_ping(nIter)(kIter) = ub.u; }); }); + + // Prefill A(2i+2) + prefill_lds_a_stage2(a_copy_lds_window_ping); + + // Prefetch A(2i+3) + prefill_lds_a_stage1( + a_copy_lds_window_pong, a_copy_dram_window, number{}); + // GEMM 2i+1 static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { @@ -1005,15 +1024,11 @@ struct F16xMXF4FlatmmPipelineAGmemBGmemCRegV1 } }); }); + prefill_lds_a_stage1( + a_copy_lds_window_pong, a_copy_dram_window, number{}); - // Prefill A(2i+2) - prefill_lds_a_stage2(a_copy_lds_window_ping); - - // Prefetch A(2i+3) - prefill_lds_a_stage1(a_copy_lds_window_pong, a_copy_dram_window); // move A window to next k move_tile_window(a_copy_dram_window, {0, kKPerBlock}); - // move B window to next flat K move_tile_window(b_flat_dram_window, {0, MXFP4KPerWarp * KFlatPerBlockPerIter}); move_tile_window(scale_b_flat_dram_window, {0, ScaleKPerWarp * ScaleKFlatPerWarp}); @@ -1075,6 +1090,9 @@ struct F16xMXF4FlatmmPipelineAGmemBGmemCRegV1 } }); + // Prefill A(loopK) + prefill_lds_a_stage2(a_copy_lds_window_pong); + // GEMM loopK-1 static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { @@ -1122,9 +1140,6 @@ struct F16xMXF4FlatmmPipelineAGmemBGmemCRegV1 }); }); - // Prefill A(loopK) - prefill_lds_a_stage2(a_copy_lds_window_pong); - static_for<0, m_preload, 1>{}([&](auto loadIter) { constexpr auto mIter = loadIter % MIterPerWarp; constexpr auto kIter = loadIter / MIterPerWarp; diff --git a/include/ck_tile/ops/flatmm/pipeline/mixed_prec_flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp b/include/ck_tile/ops/flatmm/pipeline/mixed_prec_flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp index df60e0da00..c23bb98bd9 100644 --- a/include/ck_tile/ops/flatmm/pipeline/mixed_prec_flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp +++ b/include/ck_tile/ops/flatmm/pipeline/mixed_prec_flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp @@ -7,6 +7,8 @@ namespace ck_tile { +#define CKTILE_FLATMM_USE_BUFFER_LOAD_LDS_AS_POSSIBLE 1 + struct F16xMXF4FlatmmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy { static constexpr auto I0 = number<0>{}; @@ -21,7 +23,7 @@ struct F16xMXF4FlatmmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy CK_TILE_HOST_DEVICE static constexpr auto TransformF16xF4_ATensorView(const NativeADramTensorView& a_dram_view) { -#if defined(__gfx950__) //|| defined(__gfx942__) +#if CKTILE_FLATMM_USE_BUFFER_LOAD_LDS_AS_POSSIBLE && defined(__gfx950__) constexpr int DynamicTileOffsetFlag = 0; constexpr index_t MPerXdl = Problem::BlockGemmShape::WarpTile::at(I0); @@ -118,7 +120,7 @@ struct F16xMXF4FlatmmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy template CK_TILE_HOST_DEVICE static constexpr auto MakeF16xF4_WriteALdsBlockDescriptor() { -#if defined(__gfx950__) //|| defined(__gfx942__) +#if CKTILE_FLATMM_USE_BUFFER_LOAD_LDS_AS_POSSIBLE && defined(__gfx950__) constexpr index_t MPerBlock = Problem::BlockGemmShape::kM; constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; constexpr index_t KPack = GetSmemPackA();