diff --git a/example/ck_tile/99_toy_example/03_flash_attention_fwd/block_gemm_areg_bsmem_creg_v1.hpp b/example/ck_tile/99_toy_example/03_flash_attention_fwd/block_gemm_areg_bsmem_creg_v1.hpp
index 740c540d6c..d502655210 100644
--- a/example/ck_tile/99_toy_example/03_flash_attention_fwd/block_gemm_areg_bsmem_creg_v1.hpp
+++ b/example/ck_tile/99_toy_example/03_flash_attention_fwd/block_gemm_areg_bsmem_creg_v1.hpp
@@ -26,6 +26,121 @@ struct BlockGemmARegBSmemCRegV1
 
     static constexpr index_t kBlockSize = Problem::kBlockSize;
 
+    // PackedSize trait value; HotLoopScheduler() divides sizeof(BDataType) by it.
+    // NOTE(review): the template argument of numeric_traits was stripped from this
+    // patch text (only the closing '>' survives) -- restore it before applying.
+    static constexpr index_t kPackedSize =
+        ck_tile::numeric_traits>::PackedSize;
+
+    // Emits sched_group_barrier hints that interleave the B-tile DS writes, VMEM reads
+    // and DS reads of one hot-loop iteration with the MFMA instructions.
+    //
+    // Stage 1: per B buffer load, each DS write is followed by 1 or 2 MFMAs, then one
+    // VMEM read and the MFMAs left over for that load are scheduled.
+    // Stage 2: DS reads are grouped, at most ds_read_b_mfma_rate per MFMA.
+    //
+    // VectorSizeB and SmemPack are not declared in the visible body.
+    // NOTE(review): presumably they are the template parameters (one caller passes
+    // <8, 4>), but the list after `template`, the argument of GetWarpGemmMWarpNWarp
+    // and the argument of remove_cvref_t were stripped from this patch text --
+    // restore them from the real commit before applying.
+    template
+    CK_TILE_DEVICE static constexpr auto HotLoopScheduler()
+    {
+        constexpr index_t MPerBlock = BlockGemmShape::kM;
+        constexpr index_t NPerBlock = BlockGemmShape::kN;
+        constexpr index_t KPerBlock = BlockGemmShape::kK;
+
+        constexpr auto config = Policy::template GetWarpGemmMWarpNWarp();
+        using WG = remove_cvref_t())>;
+        constexpr index_t MPerXDL = WG::kM;
+        constexpr index_t NPerXDL = WG::kN;
+        constexpr index_t KPerXDL = WG::WarpGemmAttribute::Impl::kK;
+
+        constexpr index_t WaveSize = get_warp_size();
+        constexpr index_t WaveNumM = config.template get<1>();
+
+        constexpr index_t B_LDS_RW_Width = SmemPack;
+
+        constexpr index_t B_Buffer_Load_Inst_Num =
+            NPerBlock * KPerBlock / (kBlockSize * VectorSizeB);
+
+        constexpr index_t B_LDS_Write_Inst_Num =
+            NPerBlock * KPerBlock / (kBlockSize * B_LDS_RW_Width);
+
+        constexpr index_t B_LDS_Read_Inst_Num =
+            WaveNumM * NPerBlock * KPerBlock /
+            (kBlockSize * B_LDS_RW_Width);
+
+        constexpr index_t C_MFMA_Inst_Num =
+            MPerBlock * NPerBlock * KPerBlock / (kBlockSize / WaveSize) /
+            (MPerXDL * NPerXDL * KPerXDL);
+
+        // B split schedule
+        // All B_LDS_Read_Inst_Num reads are scheduled when one LDS access covers
+        // 16 bytes (after the kPackedSize scaling), otherwise half as many.
+        constexpr auto num_ds_read_inst_b = B_LDS_RW_Width * sizeof(BDataType) / kPackedSize == 16
+                                                ? B_LDS_Read_Inst_Num
+                                                : B_LDS_Read_Inst_Num / 2;
+
+        constexpr auto num_ds_write_inst_b = B_LDS_Write_Inst_Num;
+
+        constexpr auto num_buffer_load_inst_b = B_Buffer_Load_Inst_Num;
+
+        constexpr auto num_mfma_inst = C_MFMA_Inst_Num;
+
+        constexpr auto mfma_cycle = NPerXDL == 16 ? 16 : 32;
+        constexpr auto ds_read_b_issue_cycle =
+            B_LDS_RW_Width * sizeof(BDataType) / kPackedSize == 16 ? 8 : 4;
+        constexpr auto ds_read_b_mfma_rate =
+            (mfma_cycle - 4 + 2 * ds_read_b_issue_cycle - 1) / (2 * ds_read_b_issue_cycle);
+
+        constexpr auto num_dsread_b_mfma =
+            (num_ds_read_inst_b + ds_read_b_mfma_rate - 1) / ds_read_b_mfma_rate;
+
+        // stage 1
+        constexpr auto num_mfma_stage1 = num_mfma_inst - (num_dsread_b_mfma);
+        constexpr auto num_mfma_per_issue = num_mfma_stage1 / (num_buffer_load_inst_b);
+        constexpr auto num_dswrite_per_issue_b = num_ds_write_inst_b / num_buffer_load_inst_b;
+        constexpr auto num_mfma_per_dswrite_b =
+            (num_mfma_per_issue - num_dswrite_per_issue_b * 2 >= 1) ? 2 : 1;
+
+        static_for<0, num_buffer_load_inst_b, 1>{}([&](auto i) {
+            ignore = i;
+            static_for<0, num_dswrite_per_issue_b, 1>{}([&](auto idswrite) {
+                ignore = idswrite;
+                __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
+                __builtin_amdgcn_sched_group_barrier(0x008, num_mfma_per_dswrite_b, 0); // MFMA
+            });
+            __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
+            __builtin_amdgcn_sched_group_barrier(0x008,
+                                                 num_mfma_per_issue - num_mfma_per_dswrite_b *
+                                                                          num_dswrite_per_issue_b,
+                                                 0); // MFMA
+        });
+
+        // stage 2
+        static_for<0, num_dsread_b_mfma, 1>{}([&](auto i) {
+            // Size each group from the reads still unscheduled before group i.
+            // Testing (i + 1) here would also shrink the second-to-last group when
+            // the count is not a multiple of the rate (10 reads at rate 4 would give
+            // 4 + 2 + 2 = 8) and leave reads outside every group.
+            if constexpr((num_ds_read_inst_b - i * ds_read_b_mfma_rate) >=
+                         ds_read_b_mfma_rate)
+            {
+                __builtin_amdgcn_sched_group_barrier(0x100, ds_read_b_mfma_rate, 0); // DS read
+            }
+            else
+            {
+                __builtin_amdgcn_sched_group_barrier(0x100,
+                                                     num_ds_read_inst_b - (num_dsread_b_mfma - 1) *
+                                                                              ds_read_b_mfma_rate,
+                                                     0); // DS read
+            }
+            __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
+        });
+    }
+
     // C
+= A * B template __device__ void operator()(CBlockTensor& c_block_tensor, diff --git a/example/ck_tile/99_toy_example/03_flash_attention_fwd/block_gemm_pipeline_agmem_bgmem_creg_v2_askiplds.hpp b/example/ck_tile/99_toy_example/03_flash_attention_fwd/block_gemm_pipeline_agmem_bgmem_creg_v2_askiplds.hpp index 4a8c2beeb7..d54b460b71 100644 --- a/example/ck_tile/99_toy_example/03_flash_attention_fwd/block_gemm_pipeline_agmem_bgmem_creg_v2_askiplds.hpp +++ b/example/ck_tile/99_toy_example/03_flash_attention_fwd/block_gemm_pipeline_agmem_bgmem_creg_v2_askiplds.hpp @@ -134,6 +134,8 @@ struct BlockGemmPipelineAGmemBGmemCReg< b_block_tile = load_tile(b_copy_dram_window); } + __builtin_amdgcn_sched_barrier(0); + if constexpr(k_loops > 2) { static_for<0, k_loops - 2, 1>{}([&](auto i_k0) { @@ -158,6 +160,9 @@ struct BlockGemmPipelineAGmemBGmemCReg< store_tile(b_copy_lds_window, b_block_tile); b_block_tile = load_tile(b_copy_dram_window); + + block_gemm.HotLoopScheduler(); + __builtin_amdgcn_sched_barrier(0); }); } @@ -276,6 +281,7 @@ struct BlockGemmPipelineAGmemBGmemCReg< b_block_tile = load_tile(b_copy_dram_window); } + __builtin_amdgcn_sched_barrier(0); if constexpr(k_loops > 2) { static_for<0, k_loops - 2, 1>{}([&](auto i_k0) { @@ -293,6 +299,9 @@ struct BlockGemmPipelineAGmemBGmemCReg< store_tile(b_copy_lds_window, b_block_tile); b_block_tile = load_tile(b_copy_dram_window); + + block_gemm.HotLoopScheduler(); + __builtin_amdgcn_sched_barrier(0); }); } diff --git a/example/ck_tile/99_toy_example/03_flash_attention_fwd/flash_attention_fwd_impl.hpp b/example/ck_tile/99_toy_example/03_flash_attention_fwd/flash_attention_fwd_impl.hpp index 94016025c4..0d734aca3d 100644 --- a/example/ck_tile/99_toy_example/03_flash_attention_fwd/flash_attention_fwd_impl.hpp +++ b/example/ck_tile/99_toy_example/03_flash_attention_fwd/flash_attention_fwd_impl.hpp @@ -310,6 +310,7 @@ struct FlashAttentionFwdImpl if constexpr(k1_loops > 1) { + __builtin_amdgcn_sched_barrier(0); static_for<0, 
k1_loops - 1, 1>{}([&](auto i_k1) { const auto v = load_tile(v_dram_window); // load next v block_sync_lds(); @@ -321,6 +322,9 @@ struct FlashAttentionFwdImpl block_sync_lds(); store_tile(v_lds_window, v); move_tile_window(v_dram_window, {0, kK1PerBlock}); + + gemm1.template HotLoopScheduler<8, 4>(); + __builtin_amdgcn_sched_barrier(0); }); } // tail