diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_decode.py b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_decode.py index 4a23d1dd10..3355a0a620 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_decode.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_decode.py @@ -38,8 +38,8 @@ K0_MAX_SUBMAX_MAP = { } SEQLENQ_MAP = { - "16" : "16", - "32" : "32", + # "16" : "16", + # "32" : "32", # "64" : "64" "128" : "128", } @@ -132,18 +132,18 @@ using trait_{F_idx} = fmha_fwd_decode_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_ namespace {{ template void run_instance(const ck_tile::stream_config& s, fmha_fwd_decode_args a) {{ - if constexpr ({F_hdim} == 128 && {F_bias} == ck_tile::BlockAttentionBiasEnum::NO_BIAS - && (std::is_same_v<{F_mask}, ck_tile::SimplifiedGenericAttentionMask> - || std::is_same_v<{F_mask}, FmhaMasks::NoMask>)) {{ - if (a.max_seqlen_q == 1 && a.nhead_k < a.nhead_q) {{ - instance::run(s, a); - }} else {{ - instance::run(s, a); - }} - }} else {{ - instance::run(s, a); - }} - //instance::run(s, a); + //if constexpr ({F_hdim} == 128 && {F_bias} == ck_tile::BlockAttentionBiasEnum::NO_BIAS + // && (std::is_same_v<{F_mask}, ck_tile::SimplifiedGenericAttentionMask> + // || std::is_same_v<{F_mask}, FmhaMasks::NoMask>)) {{ + // if (a.max_seqlen_q == 1 && a.nhead_k < a.nhead_q) {{ + // instance::run(s, a); + // }} else {{ + // instance::run(s, a); + // }} + //}} else {{ + // instance::run(s, a); + //}} + instance::run(s, a); }} }} // anonymous namespace @@ -152,20 +152,20 @@ void run_instance(const ck_tile::stream_config& s, fmha_fwd_decode_args a) {{ template<> void fmha_fwd_decode_oneshot_(const ck_tile::stream_config& s, fmha_fwd_decode_args a) {{ - if constexpr({F_mode} == false) {{ // batch mode - // we don't check every seqlen_k values for kvcache - if (a.seqlen_k_ptr != nullptr) {{ - run_instance(s, a); - // make sure F_bn0 is divisible by F_bk1 - }} else if (a.seqlen_k % (a.num_splits * {F_bn0}) == 0) {{ - run_instance(s, a); - }} else {{ - run_instance(s, a); - }} - }} else {{ - run_instance(s, a); - }} - //run_instance(s, a); + //if constexpr({F_mode} == false) {{ // batch mode + // // we don't check every seqlen_k values for kvcache + // if (a.seqlen_k_ptr != nullptr) {{ + // run_instance(s, a); + // // make sure F_bn0 is divisible by F_bk1 + // }} else if (a.seqlen_k % (a.num_splits * {F_bn0}) == 0) {{ + // run_instance(s, a); + // }} else {{ + // run_instance(s, a); + // }} + //}} else {{ + // run_instance(s, a); + //}} + run_instance(s, a); }} template<> @@ -658,16 +658,16 @@ class FmhaFwdSplitKVCombineKernel: def get_fmha_fwd_tile_dict_from_dtype(dtype : str) -> Optional[dict]: if dtype == 'fp16' or dtype == 'bf16': return { - '64': { + # '64': { # Specialize for different SeqQ - '16': FmhaFwdTileSize(16, 32, 64, 64, 32, 64, 1, 1, 1, 1, 1, 1, 16, 16, 32, 16, 16, 32, -1), - '32': FmhaFwdTileSize(32, 32, 64, 64, 32, 64, 1, 1, 1, 1, 1, 1, 32, 32, 16, 32, 32, 16, -1), - '128': FmhaFwdTileSize(128, 64, 64, 64, 64, 64, 4, 1, 1, 4, 1, 1, 16, 16, 32, 16, 16, 32, -1), - }, + # '16': FmhaFwdTileSize(16, 32, 64, 64, 32, 64, 1, 1, 1, 1, 1, 1, 16, 16, 32, 16, 16, 32, -1), + # '32': FmhaFwdTileSize(32, 32, 64, 64, 32, 64, 1, 1, 1, 1, 1, 1, 32, 32, 16, 32, 32, 16, -1), + # '128': FmhaFwdTileSize(128, 64, 64, 64, 64, 64, 4, 1, 1, 4, 1, 1, 16, 16, 32, 16, 16, 32, -1), + # }, '128': { - '16': FmhaFwdTileSize(16, 32, 64, 128, 32, 128, 1, 1, 1, 1, 1, 1, 16, 16, 32, 16, 16, 32, -1), - '32': FmhaFwdTileSize(32, 32, 128, 128, 32, 128, 1, 1, 1, 1, 1, 1, 32, 32, 16, 32, 32, 16, -1), - '128': FmhaFwdTileSize(128, 64, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1), + # '16': FmhaFwdTileSize(16, 32, 64, 128, 32, 128, 1, 1, 1, 1, 1, 1, 16, 16, 32, 16, 16, 32, -1), + # '32': FmhaFwdTileSize(32, 32, 128, 128, 32, 128, 1, 1, 1, 1, 1, 1, 32, 32, 16, 32, 32, 16, -1), + '128': FmhaFwdTileSize(128, 64, 32, 128, 16, 128, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1), }, } else: diff --git a/include/ck_tile/core/arch/arch.hpp b/include/ck_tile/core/arch/arch.hpp index 9be466a2d1..dfe9e1aa86 100644 --- a/include/ck_tile/core/arch/arch.hpp +++ b/include/ck_tile/core/arch/arch.hpp @@ -92,19 +92,11 @@ CK_TILE_DEVICE index_t get_thread_id() { return threadIdx.x; } CK_TILE_DEVICE index_t get_block_id() { return blockIdx.x; } +template CK_TILE_DEVICE void block_sync_lds() { -#if CK_TILE_EXPERIMENTAL_BLOCK_SYNC_LDS_WITHOUT_SYNC_VMEM - // asm volatile("\ - // s_waitcnt lgkmcnt(0) \n \ - // s_barrier \ - // " ::); - - __builtin_amdgcn_s_waitcnt(0xc07f); + __builtin_amdgcn_s_waitcnt(CK_TILE_S_CNT_MAX & CK_TILE_LGKMCNT(lgkmcnt)); __builtin_amdgcn_s_barrier(); -#else - __syncthreads(); -#endif } CK_TILE_DEVICE void block_sync_load_raw(index_t cnt = 0) diff --git a/include/ck_tile/ops/fmha/kernel/fmha_fwd_decode_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_fwd_decode_kernel.hpp index 55584c6118..ca42d77cc5 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_fwd_decode_kernel.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_fwd_decode_kernel.hpp @@ -831,7 +831,8 @@ struct FmhaFwdDecodeKernel // TODO: Add kVHeadDim // TrLoad Performed in 16x4/16x8/16x16 unit, the fast dimension is 16 elements - constexpr auto XorGroupSize = FmhaPipeline::Problem::BlockFmhaShape::Gemm1WarpTile::at(number<0>{}); + constexpr auto XorGroupSize = + FmhaPipeline::Problem::BlockFmhaShape::Gemm1WarpTile::at(number<0>{}); const auto v_dram_unmerged = transform_tensor_view( v_dram_pad, diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_decode_pipeline_qr_ks_vs.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_decode_pipeline_qr_ks_vs.hpp index b5ed6a483d..2d9833eef8 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_decode_pipeline_qr_ks_vs.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_decode_pipeline_qr_ks_vs.hpp @@ -758,41 +758,24 @@ struct BlockFmhaFwdDecodePipelineQRKSVS auto k_dram_window = make_tile_window( k_dram_block_window_tmp, Policy::template MakeKDramTileDistribution()); - auto k_lds_write_view = - make_tuple(make_tensor_view( - static_cast(smem_ptrk0), - Policy::template MakeKLdsBlockDescriptor()), - make_tensor_view( - static_cast(smem_ptrk1), - Policy::template MakeKLdsBlockDescriptor())); + auto k_lds_write_view = make_tensor_view( + static_cast(smem_ptrk0), + Policy::template MakeKLdsBlockDescriptor()); - auto k_lds_read_view = - make_tuple(make_tensor_view( - static_cast(smem_ptrk0), - Policy::template MakeKLdsBlockDescriptor()), - make_tensor_view( - static_cast(smem_ptrk1), - Policy::template MakeKLdsBlockDescriptor())); + auto k_lds_read_view = make_tensor_view( + static_cast(smem_ptrk0), + Policy::template MakeKLdsBlockDescriptor()); - auto k_lds_write_windows = - make_tuple(make_tile_window( - k_lds_write_view.at(I0), - Policy::template MakeKLdsBlockDescriptor().get_lengths(), - {0, 0}), - make_tile_window( - k_lds_write_view.at(I1), - Policy::template MakeKLdsBlockDescriptor().get_lengths(), - {0, 0})); + auto k_lds_write_window = + make_tile_window(k_lds_write_view, + Policy::template MakeKLdsBlockDescriptor().get_lengths(), + {0, 0}); - auto k_lds_read_windows = - make_tuple(make_tile_window(k_lds_read_view.at(I0), - make_tuple(number{}, number{}), - {0, 0}, - Policy::template MakeKRegTileDistribution()), - make_tile_window(k_lds_read_view.at(I1), - make_tuple(number{}, number{}), - {0, 0}, - Policy::template MakeKRegTileDistribution())); + auto k_lds_read_window = + make_tile_window(k_lds_read_view, + make_tuple(number{}, number{}), + {0, 0}, + Policy::template MakeKRegTileDistribution()); // S tile in LDS auto s_lds = make_tensor_view( @@ -811,39 +794,24 @@ struct BlockFmhaFwdDecodePipelineQRKSVS auto v_dram_window = make_tile_window( v_dram_block_window_tmp, Policy::template MakeVDramTileDistribution()); - auto v_lds_write_view = make_tuple( - make_tensor_view( - reinterpret_cast(static_cast(smem_ptrv0)), - Policy::template MakeVLdsBlockDescriptor()), - make_tensor_view( - reinterpret_cast(static_cast(smem_ptrv1)), - Policy::template MakeVLdsBlockDescriptor())); + auto v_lds_write_view = make_tensor_view( + reinterpret_cast(static_cast(smem_ptrv0)), + Policy::template MakeVLdsBlockDescriptor()); - auto v_lds_read_view = make_tuple( - make_tensor_view( - reinterpret_cast(static_cast(smem_ptrv0)), - Policy::template MakeVLdsBlockDescriptor()), - make_tensor_view( - reinterpret_cast(static_cast(smem_ptrv1)), - Policy::template MakeVLdsBlockDescriptor())); + auto v_lds_read_view = make_tensor_view( + reinterpret_cast(static_cast(smem_ptrv0)), + Policy::template MakeVLdsBlockDescriptor()); - auto v_lds_write_windows = make_tuple( - make_tile_window(v_lds_write_view.at(I0), + auto v_lds_write_window = + make_tile_window(v_lds_write_view, Policy::template MakeVLdsBlockDescriptor().get_lengths(), - {0, 0}), - make_tile_window(v_lds_write_view.at(I1), - Policy::template MakeVLdsBlockDescriptor().get_lengths(), - {0, 0})); + {0, 0}); - auto v_lds_read_windows = - make_tuple(make_tile_window(v_lds_read_view.at(I0), - make_tuple(number{}, number{}), - {0, 0}, - Policy::template MakeVRegTileDistribution()), - make_tile_window(v_lds_read_view.at(I1), - make_tuple(number{}, number{}), - {0, 0}, - Policy::template MakeVRegTileDistribution())); + auto v_lds_read_window = + make_tile_window(v_lds_read_view, + make_tuple(number{}, number{}), + {0, 0}, + Policy::template MakeVRegTileDistribution()); // block_sync_lds_direct_load<0>(); // auto q_tile = load_tile(q_lds_read_window); @@ -857,31 +825,41 @@ struct BlockFmhaFwdDecodePipelineQRKSVS static_assert(1 <= k0_loops); static_assert(1 <= k1_loops); - async_load_tile(k_lds_write_windows.at(I0), k_dram_window); - async_load_tile(v_lds_write_windows.at(I0), v_dram_window); - + async_load_tile(k_lds_write_window, k_dram_window); + async_load_tile(v_lds_write_window, v_dram_window); + move_tile_window(k_dram_window, {kN0, 0}); - async_load_tile(k_lds_write_windows.at(I1), k_dram_window); + k_lds_write_window.set_bottom_tensor_view_data_ptr( + static_cast(smem_ptrk1)); + async_load_tile(k_lds_write_window, k_dram_window); constexpr index_t k_vmem_insts = k_dram_window.get_num_of_access(); constexpr index_t v_vmem_insts = v_dram_window.get_num_of_access(); + constexpr index_t k_lds_insts = k_lds_read_window.get_num_of_access(); + constexpr index_t v_lds_insts = v_lds_read_window.get_num_of_access(); + block_sync_lds_direct_load(); - auto k_tile = load_tile(k_lds_read_windows.at(I0)); + auto k_tile = load_tile(k_lds_read_window); __builtin_amdgcn_sched_barrier(0); auto mainloop = [&](index_t cur_loop) { - - auto k_lds_write_window = (cur_loop%2 == 0)? k_lds_write_windows.at(I0) : k_lds_write_windows.at(I1); - auto k_lds_read_window_cur = (cur_loop%2 == 0)? k_lds_read_windows.at(I0) : k_lds_read_windows.at(I1); - auto k_lds_read_window_next = (cur_loop%2 == 0)? k_lds_read_windows.at(I1) : k_lds_read_windows.at(I0); - auto v_lds_write_window = (cur_loop%2 == 0)? v_lds_write_windows.at(I1) : v_lds_write_windows.at(I0); - auto v_lds_read_window = (cur_loop%2 == 0)? v_lds_read_windows.at(I0) : v_lds_read_windows.at(I1); + const bool is_even_loop = (cur_loop % 2 == 0); + + auto k_lds_write_ptr = is_even_loop ? static_cast(smem_ptrk0) + : static_cast(smem_ptrk1); + auto k_lds_read_ptr = is_even_loop ? static_cast(smem_ptrk1) + : static_cast(smem_ptrk0); + auto v_lds_write_ptr = is_even_loop ? static_cast(smem_ptrv1) + : static_cast(smem_ptrv0); + auto v_lds_read_ptr = is_even_loop ? static_cast(smem_ptrv0) + : static_cast(smem_ptrv1); // move V tile windows - block_sync_lds(); + block_sync_lds(); move_tile_window(v_dram_window, {kN0, 0}); + v_lds_write_window.set_bottom_tensor_view_data_ptr(v_lds_write_ptr); async_load_tile(v_lds_write_window, v_dram_window); // STAGE 1, QK gemm @@ -891,19 +869,19 @@ struct BlockFmhaFwdDecodePipelineQRKSVS { static_for<0, k0_loops - 1, 1>{}([&](auto i_k0) { // loop over along the [K]ey head dimension - move_tile_window(k_lds_read_window_cur, {0, kK0}); - auto k_tile_switch = load_tile(k_lds_read_window_cur); + move_tile_window(k_lds_read_window, {0, kK0}); + auto k_tile_switch = load_tile(k_lds_read_window); gemm_0(s_acc, get_slice_tile(q_tile, sequence<0, i_k0 * kK0>{}, sequence{}), k_tile); - + k_tile = k_tile_switch; }); // move back to the origin - move_tile_window(k_lds_read_window_cur, {0, -kK0 * (k0_loops - 1)}); + move_tile_window(k_lds_read_window, {0, -kK0 * (k0_loops - 1)}); } gemm_0(s_acc, @@ -911,21 +889,10 @@ struct BlockFmhaFwdDecodePipelineQRKSVS sequence<0, (k0_loops - 1) * kK0>{}, sequence{}), k_tile); - - block_sync_lds_direct_load(); - auto v_tile = load_tile_transpose(v_lds_read_window); - - static_for<0, 14, 1>{}([&](auto i) { - ignore = i; - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS_READ - }); - static_for<0, 2, 1>{}([&](auto i) { - ignore = i; - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x100, 2, 0); // DS_READ - }); + block_sync_lds_direct_load(); + v_lds_read_window.set_bottom_tensor_view_data_ptr(v_lds_read_ptr); + auto v_tile = load_tile_transpose(v_lds_read_window); if constexpr(kHasUnevenSplits) { @@ -991,6 +958,18 @@ struct BlockFmhaFwdDecodePipelineQRKSVS -numeric::infinity()); // m_local = rowmax(S{j}) block_tile_reduce_sync(m_local, f_max, bool_constant{}); + static_for<0, 12, 1>{}([&](auto i) { + ignore = i; + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS_READ + }); + + static_for<0, 4, 1>{}([&](auto i) { + ignore = i; + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x100, 2, 0); // DS_READ + }); + const auto m_old = m; // m{j-1} tile_elementwise_inout( [](auto& e0, auto e1, auto e2) { e0 = max(e1, e2); }, m, m_old, m_local); // m{j} @@ -1079,8 +1058,9 @@ struct BlockFmhaFwdDecodePipelineQRKSVS }); }); - block_sync_lds(); + block_sync_lds(); move_tile_window(k_dram_window, {kN0, 0}); + k_lds_write_window.set_bottom_tensor_view_data_ptr(k_lds_write_ptr); async_load_tile(k_lds_write_window, k_dram_window); if constexpr(1 < k1_loops) @@ -1108,15 +1088,16 @@ struct BlockFmhaFwdDecodePipelineQRKSVS sequence{}), v_tile); - k_tile = load_tile(k_lds_read_window_next); - - static_for<0, 14, 1>{}([&](auto i) { + k_lds_read_window.set_bottom_tensor_view_data_ptr(k_lds_read_ptr); + k_tile = load_tile(k_lds_read_window); + + static_for<0, 12, 1>{}([&](auto i) { ignore = i; __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA __builtin_amdgcn_sched_group_barrier(0x100, 2, 0); // DS_READ }); - static_for<0, 2, 1>{}([&](auto i) { + static_for<0, 4, 1>{}([&](auto i) { ignore = i; __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS_READ diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_decode_pipeline_qr_ks_vs_policy.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_decode_pipeline_qr_ks_vs_policy.hpp index dc9be6caff..21b35475fe 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_decode_pipeline_qr_ks_vs_policy.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_decode_pipeline_qr_ks_vs_policy.hpp @@ -292,7 +292,8 @@ struct BlockFmhaFwdDecodePipelineQRKSVSDefaultPolicy constexpr auto v_lds_block_desc = [&]() { if constexpr(Xor) { - constexpr auto XorGroupSize = Problem::BlockFmhaShape::Gemm1WarpTile::at(number<0>{}); + constexpr auto XorGroupSize = + Problem::BlockFmhaShape::Gemm1WarpTile::at(number<0>{}); constexpr auto v_lds_block_desc_naive = make_naive_tensor_descriptor( make_tuple(number{}, number{}), @@ -303,27 +304,24 @@ struct BlockFmhaFwdDecodePipelineQRKSVSDefaultPolicy const auto v_lds_block_desc_unmerged = transform_tensor_descriptor( v_lds_block_desc_naive, make_tuple(make_pass_through_transform(number{}), - make_unmerge_transform( - make_tuple(number{}, - number{}))), + make_unmerge_transform(make_tuple( + number{}, number{}))), make_tuple(sequence<0>{}, sequence<1>{}), make_tuple(sequence<0>{}, sequence<1, 2>{})); const auto v_lds_block_desc_permuted = transform_tensor_descriptor( v_lds_block_desc_unmerged, - make_tuple( - make_xor_transform(make_tuple(number{}, - number{})), - make_pass_through_transform(number{})), + make_tuple(make_xor_transform(make_tuple(number{}, + number{})), + make_pass_through_transform(number{})), make_tuple(sequence<0, 1>{}, sequence<2>{}), make_tuple(sequence<0, 1>{}, sequence<2>{})); return transform_tensor_descriptor( v_lds_block_desc_permuted, make_tuple(make_pass_through_transform(number{}), - make_merge_transform_v3_division_mod( - make_tuple(number{}, - number{}))), + make_merge_transform_v3_division_mod(make_tuple( + number{}, number{}))), make_tuple(sequence<0>{}, sequence<1, 2>{}), make_tuple(sequence<0>{}, sequence<1>{})); }