diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_decode.py b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_decode.py index 4c6933eebb..4a23d1dd10 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_decode.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_decode.py @@ -38,8 +38,8 @@ K0_MAX_SUBMAX_MAP = { } SEQLENQ_MAP = { - # "16" : "16", - # "32" : "32", + "16" : "16", + "32" : "32", # "64" : "64" "128" : "128", } @@ -132,18 +132,18 @@ using trait_{F_idx} = fmha_fwd_decode_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_ namespace {{ template void run_instance(const ck_tile::stream_config& s, fmha_fwd_decode_args a) {{ - //if constexpr ({F_hdim} == 128 && {F_bias} == ck_tile::BlockAttentionBiasEnum::NO_BIAS - // && (std::is_same_v<{F_mask}, ck_tile::SimplifiedGenericAttentionMask> - // || std::is_same_v<{F_mask}, FmhaMasks::NoMask>)) {{ - // if (a.max_seqlen_q == 1 && a.nhead_k < a.nhead_q) {{ - // instance::run(s, a); - // }} else {{ - // instance::run(s, a); - // }} - //}} else {{ - // instance::run(s, a); - //}} - instance::run(s, a); + if constexpr ({F_hdim} == 128 && {F_bias} == ck_tile::BlockAttentionBiasEnum::NO_BIAS + && (std::is_same_v<{F_mask}, ck_tile::SimplifiedGenericAttentionMask> + || std::is_same_v<{F_mask}, FmhaMasks::NoMask>)) {{ + if (a.max_seqlen_q == 1 && a.nhead_k < a.nhead_q) {{ + instance::run(s, a); + }} else {{ + instance::run(s, a); + }} + }} else {{ + instance::run(s, a); + }} + //instance::run(s, a); }} }} // anonymous namespace @@ -152,20 +152,20 @@ void run_instance(const ck_tile::stream_config& s, fmha_fwd_decode_args a) {{ template<> void fmha_fwd_decode_oneshot_(const ck_tile::stream_config& s, fmha_fwd_decode_args a) {{ - //if constexpr({F_mode} == false) {{ // batch mode - // // we don't check every seqlen_k values for kvcache - // if (a.seqlen_k_ptr != nullptr) {{ - // run_instance(s, a); - // // make sure F_bn0 is divisible by F_bk1 - // }} else if (a.seqlen_k % (a.num_splits * {F_bn0}) == 0) {{ - // run_instance(s, a); - // }} else {{ - // run_instance(s, a); - // }} - //}} else {{ - // run_instance(s, a); - //}} - run_instance(s, a); + if constexpr({F_mode} == false) {{ // batch mode + // we don't check every seqlen_k values for kvcache + if (a.seqlen_k_ptr != nullptr) {{ + run_instance(s, a); + // make sure F_bn0 is divisible by F_bk1 + }} else if (a.seqlen_k % (a.num_splits * {F_bn0}) == 0) {{ + run_instance(s, a); + }} else {{ + run_instance(s, a); + }} + }} else {{ + run_instance(s, a); + }} + //run_instance(s, a); }} template<> @@ -658,15 +658,15 @@ class FmhaFwdSplitKVCombineKernel: def get_fmha_fwd_tile_dict_from_dtype(dtype : str) -> Optional[dict]: if dtype == 'fp16' or dtype == 'bf16': return { - # '64': { - # # Specialize for different SeqQ - # '16': FmhaFwdTileSize(16, 32, 64, 64, 32, 64, 1, 1, 1, 1, 1, 1, 16, 16, 32, 16, 16, 32, -1), - # '32': FmhaFwdTileSize(32, 32, 64, 64, 32, 64, 1, 1, 1, 1, 1, 1, 32, 32, 16, 32, 32, 16, -1), - # '128': FmhaFwdTileSize(128, 64, 64, 64, 64, 64, 4, 1, 1, 4, 1, 1, 16, 16, 32, 16, 16, 32, -1), - # }, + '64': { + # Specialize for different SeqQ + '16': FmhaFwdTileSize(16, 32, 64, 64, 32, 64, 1, 1, 1, 1, 1, 1, 16, 16, 32, 16, 16, 32, -1), + '32': FmhaFwdTileSize(32, 32, 64, 64, 32, 64, 1, 1, 1, 1, 1, 1, 32, 32, 16, 32, 32, 16, -1), + '128': FmhaFwdTileSize(128, 64, 64, 64, 64, 64, 4, 1, 1, 4, 1, 1, 16, 16, 32, 16, 16, 32, -1), + }, '128': { - # '16': FmhaFwdTileSize(16, 32, 64, 128, 32, 128, 1, 1, 1, 1, 1, 1, 16, 16, 32, 16, 16, 32, -1), - # '32': FmhaFwdTileSize(32, 32, 128, 128, 32, 128, 1, 1, 1, 1, 1, 1, 32, 32, 16, 32, 32, 16, -1), + '16': FmhaFwdTileSize(16, 32, 64, 128, 32, 128, 1, 1, 1, 1, 1, 1, 16, 16, 32, 16, 16, 32, -1), + '32': FmhaFwdTileSize(32, 32, 128, 128, 32, 128, 1, 1, 1, 1, 1, 1, 32, 32, 16, 32, 32, 16, -1), '128': FmhaFwdTileSize(128, 64, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1), }, } diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_decode_pipeline_qr_ks_vs.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_decode_pipeline_qr_ks_vs.hpp index e10cb23f93..b5ed6a483d 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_decode_pipeline_qr_ks_vs.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_decode_pipeline_qr_ks_vs.hpp @@ -857,47 +857,53 @@ struct BlockFmhaFwdDecodePipelineQRKSVS static_assert(1 <= k0_loops); static_assert(1 <= k1_loops); - - // block_sync_lds(); async_load_tile(k_lds_write_windows.at(I0), k_dram_window); async_load_tile(v_lds_write_windows.at(I0), v_dram_window); + + move_tile_window(k_dram_window, {kN0, 0}); + async_load_tile(k_lds_write_windows.at(I1), k_dram_window); constexpr index_t k_vmem_insts = k_dram_window.get_num_of_access(); constexpr index_t v_vmem_insts = v_dram_window.get_num_of_access(); + block_sync_lds_direct_load(); + auto k_tile = load_tile(k_lds_read_windows.at(I0)); + + __builtin_amdgcn_sched_barrier(0); + auto mainloop = [&](index_t cur_loop) { - auto k_lds_write_window = (cur_loop%2 == 0)? k_lds_write_windows.at(I1) : k_lds_write_windows.at(I0); - auto k_lds_read_window = (cur_loop%2 == 0)? k_lds_read_windows.at(I0) : k_lds_read_windows.at(I1); + auto k_lds_write_window = (cur_loop%2 == 0)? k_lds_write_windows.at(I0) : k_lds_write_windows.at(I1); + auto k_lds_read_window_cur = (cur_loop%2 == 0)? k_lds_read_windows.at(I0) : k_lds_read_windows.at(I1); + auto k_lds_read_window_next = (cur_loop%2 == 0)? k_lds_read_windows.at(I1) : k_lds_read_windows.at(I0); auto v_lds_write_window = (cur_loop%2 == 0)? v_lds_write_windows.at(I1) : v_lds_write_windows.at(I0); auto v_lds_read_window = (cur_loop%2 == 0)? v_lds_read_windows.at(I0) : v_lds_read_windows.at(I1); + // move V tile windows block_sync_lds(); - // move K tile windows - move_tile_window(k_dram_window, {kN0, 0}); - async_load_tile(k_lds_write_window, k_dram_window); + move_tile_window(v_dram_window, {kN0, 0}); + async_load_tile(v_lds_write_window, v_dram_window); // STAGE 1, QK gemm clear_tile(s_acc); // initialize C - block_sync_lds_direct_load(); - auto k_tile = load_tile(k_lds_read_window); - if constexpr(1 < k0_loops) { static_for<0, k0_loops - 1, 1>{}([&](auto i_k0) { + // loop over along the [K]ey head dimension + move_tile_window(k_lds_read_window_cur, {0, kK0}); + auto k_tile_switch = load_tile(k_lds_read_window_cur); + gemm_0(s_acc, get_slice_tile(q_tile, sequence<0, i_k0 * kK0>{}, sequence{}), k_tile); - - // loop over along the [K]ey head dimension - move_tile_window(k_lds_read_window, {0, kK0}); - k_tile = load_tile(k_lds_read_window); + + k_tile = k_tile_switch; }); // move back to the origin - move_tile_window(k_lds_read_window, {0, -kK0 * (k0_loops - 1)}); + move_tile_window(k_lds_read_window_cur, {0, -kK0 * (k0_loops - 1)}); } gemm_0(s_acc, @@ -905,6 +911,21 @@ struct BlockFmhaFwdDecodePipelineQRKSVS sequence<0, (k0_loops - 1) * kK0>{}, sequence{}), k_tile); + + block_sync_lds_direct_load(); + auto v_tile = load_tile_transpose(v_lds_read_window); + + static_for<0, 14, 1>{}([&](auto i) { + ignore = i; + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS_READ + }); + + static_for<0, 2, 1>{}([&](auto i) { + ignore = i; + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x100, 2, 0); // DS_READ + }); if constexpr(kHasUnevenSplits) { @@ -1058,25 +1079,24 @@ struct BlockFmhaFwdDecodePipelineQRKSVS }); }); - move_tile_window(v_dram_window, {kN0, 0}); - async_load_tile(v_lds_write_window, v_dram_window); - - block_sync_lds_direct_load(); - // Will insert unexpected vmcnt(0) here, probably the aliasing issue. - auto v_tile = load_tile_transpose(v_lds_read_window); + block_sync_lds(); + move_tile_window(k_dram_window, {kN0, 0}); + async_load_tile(k_lds_write_window, k_dram_window); if constexpr(1 < k1_loops) { static_for<0, k1_loops - 1, 1>{}([&](auto i_k1) { + // loop over along the [V]alue Sequence length + move_tile_window(v_lds_read_window, {kK1, 0}); + auto v_tile_switch = load_tile_transpose(v_lds_read_window); + gemm_1(o_acc, get_slice_tile(p_tile, sequence<0, i_k1 * kK1>{}, sequence{}), v_tile); - // loop over along the [V]alue Sequence length - move_tile_window(v_lds_read_window, {kK1, 0}); - v_tile = load_tile_transpose(v_lds_read_window); + v_tile = v_tile_switch; }); // move back to the origin move_tile_window(v_lds_read_window, {-kK1 * (k1_loops - 1), 0}); @@ -1087,20 +1107,26 @@ struct BlockFmhaFwdDecodePipelineQRKSVS sequence<0, (k1_loops - 1) * kK1>{}, sequence{}), v_tile); + + k_tile = load_tile(k_lds_read_window_next); + + static_for<0, 14, 1>{}([&](auto i) { + ignore = i; + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x100, 2, 0); // DS_READ + }); + + static_for<0, 2, 1>{}([&](auto i) { + ignore = i; + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS_READ + }); }; do { mainloop(i_total_loops); i_total_loops++; - // mainloop(I1, I0); - // i_total_loops++; - // if(i_total_loops == (num_total_loop)) - // { - // continue; - // } - // mainloop(I0, I1); - // i_total_loops++; } while(i_total_loops < num_total_loop); if constexpr(kStoreLSE) diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_decode_pipeline_qr_ks_vs_policy.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_decode_pipeline_qr_ks_vs_policy.hpp index 93837af5f5..dc9be6caff 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_decode_pipeline_qr_ks_vs_policy.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_decode_pipeline_qr_ks_vs_policy.hpp @@ -127,7 +127,7 @@ struct BlockFmhaFwdDecodePipelineQRKSVSDefaultPolicy tuple, sequence>, tuple>, tuple>, - sequence<1, 2>, + sequence<2, 1>, sequence<0, 0>>{}; constexpr auto q_block_dstr_encode = detail::make_embed_tile_distribution_encoding( @@ -183,12 +183,14 @@ struct BlockFmhaFwdDecodePipelineQRKSVSDefaultPolicy constexpr index_t MIterPerWarp = kMPerBlock / (MWarp * WarpGemm::kM); constexpr index_t KIterPerWarp = kKPerBlock / WarpGemm::kK; + // Read M first, then K + // This is the same data consume order as BlockGEMM constexpr auto q_block_outer_dstr_encoding = tile_distribution_encoding, tuple, sequence>, tuple>, tuple>, - sequence<1, 2>, + sequence<2, 1>, sequence<0, 0>>{}; constexpr auto q_block_dstr_encode = detail::make_embed_tile_distribution_encoding( @@ -428,12 +430,14 @@ struct BlockFmhaFwdDecodePipelineQRKSVSDefaultPolicy constexpr index_t NIterPerWarp = kNPerBlock / (NWarp * WarpGemm::kN); constexpr index_t KIterPerWarp = kKPerBlock / WarpGemm::kK; + // Read N first, then K + // This is the same data consume order as BlockGEMM constexpr auto k_block_outer_dstr_encoding = tile_distribution_encoding, tuple, sequence>, tuple>, tuple>, - sequence<1, 2>, + sequence<2, 1>, sequence<0, 0>>{}; constexpr auto k_block_dstr_encode = detail::make_embed_tile_distribution_encoding( @@ -489,12 +493,14 @@ struct BlockFmhaFwdDecodePipelineQRKSVSDefaultPolicy constexpr index_t MIterPerWarp = kMPerBlock / (MWarp * WarpGemm::kM); constexpr index_t KIterPerWarp = kKPerBlock / WarpGemm::kK; + // Read M first, then K + // This is the same data consume order as BlockGEMM constexpr auto p_block_outer_dstr_encoding = tile_distribution_encoding, tuple, sequence>, tuple>, tuple>, - sequence<1, 2>, + sequence<2, 1>, sequence<0, 0>>{}; constexpr auto p_block_dstr_encode = detail::make_embed_tile_distribution_encoding( @@ -521,12 +527,14 @@ struct BlockFmhaFwdDecodePipelineQRKSVSDefaultPolicy constexpr index_t NIterPerWarp = kNPerBlock / (NWarp * WarpGemm::kN); constexpr index_t KIterPerWarp = kKPerBlock / WarpGemm::kK; + // Read N first, then K + // This is the same data consume order as BlockGEMM constexpr auto v_block_outer_dstr_encoding = tile_distribution_encoding, tuple, sequence>, tuple>, tuple>, - sequence<1, 2>, + sequence<2, 1>, sequence<0, 0>>{}; constexpr auto v_block_dstr_encode = detail::make_embed_tile_distribution_encoding( diff --git a/include/ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v1.hpp b/include/ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v1.hpp index 28d8b3eead..a8f42d50f7 100644 --- a/include/ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v1.hpp +++ b/include/ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v1.hpp @@ -88,7 +88,7 @@ struct BlockGemmARegBRegCRegV1 tuple, sequence>, tuple>, tuple>, - sequence<1, 2>, + sequence<2, 1>, sequence<0, 0>>{}; constexpr auto a_block_dstr_encode = detail::make_embed_tile_distribution_encoding( a_block_outer_dstr_encoding, typename WarpGemm::AWarpDstrEncoding{}); @@ -120,7 +120,7 @@ struct BlockGemmARegBRegCRegV1 tuple, sequence>, tuple>, tuple>, - sequence<1, 2>, + sequence<2, 1>, sequence<0, 0>>{}; constexpr auto b_block_dstr_encode = detail::make_embed_tile_distribution_encoding( b_block_outer_dstr_encoding, typename WarpGemm::BWarpDstrEncoding{}); @@ -221,7 +221,7 @@ struct BlockGemmARegBRegCRegV1 AWarpTensor a_warp_tensor; a_warp_tensor.get_thread_buffer() = a_block_tensor.get_y_sliced_thread_data( - merge_sequences(sequence{}, a_warp_y_index_zeros), + merge_sequences(sequence{}, a_warp_y_index_zeros), merge_sequences(sequence<1, 1>{}, a_warp_y_lengths)); static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { @@ -229,7 +229,7 @@ struct BlockGemmARegBRegCRegV1 BWarpTensor b_warp_tensor; b_warp_tensor.get_thread_buffer() = b_block_tensor.get_y_sliced_thread_data( - merge_sequences(sequence{}, b_warp_y_index_zeros), + merge_sequences(sequence{}, b_warp_y_index_zeros), merge_sequences(sequence<1, 1>{}, b_warp_y_lengths)); // read C warp tensor from C block tensor