diff --git a/include/ck_tile/ops/fmha/kernel/fmha_fwd_appendkv_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_fwd_appendkv_kernel.hpp index ee272cd494..74b42bc403 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_fwd_appendkv_kernel.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_fwd_appendkv_kernel.hpp @@ -255,7 +255,9 @@ struct FmhaFwdAppendKVKernel ck_tile::index_t nhead_stride_knew, ck_tile::index_t nhead_stride_v, ck_tile::index_t nhead_stride_vnew, + ck_tile::index_t batch_stride_k, ck_tile::index_t batch_stride_knew, + ck_tile::index_t batch_stride_v, ck_tile::index_t batch_stride_vnew) { Kargs kargs{{q_ptr, @@ -288,7 +290,9 @@ struct FmhaFwdAppendKVKernel {}, // placeholder for rope reinterpret_cast(seqstart_q_ptr), reinterpret_cast(seqstart_k_ptr), - reinterpret_cast(seqlen_k_ptr)}; + reinterpret_cast(seqlen_k_ptr), + batch_stride_k, + batch_stride_v}; if constexpr(kApplyRoPE) { @@ -371,22 +375,22 @@ struct FmhaFwdAppendKVKernel reinterpret_cast(kargs.block_table_ptr) + i_batch_ * kargs.batch_stride_block_table; const index_t num_blocks = - integer_divide_ceil(kargs.seqlen_k, kargs.page_block_size); + integer_divide_ceil(kargs.seqlen_k + kargs.seqlen_knew, kargs.page_block_size); const long_index_t fixed_offset = static_cast(i_nhead_ / kargs.nhead_ratio_qk) * kargs.nhead_stride_k; - return PagedTileWindowNavigator(kargs.k_ptr, - kargs.batch_stride_k, - fixed_offset, - block_indices, - num_blocks, - kargs.page_block_size); + return PagedTileWindowNavigator(kargs.k_ptr, + kargs.batch_stride_k, + fixed_offset, + block_indices, + num_blocks, + kargs.page_block_size); } else { - return SimpleTileWindowNavigator(); + return SimpleTileWindowNavigator(); } }(); @@ -397,22 +401,22 @@ struct FmhaFwdAppendKVKernel reinterpret_cast(kargs.block_table_ptr) + i_batch_ * kargs.batch_stride_block_table; const index_t num_blocks = - integer_divide_ceil(kargs.seqlen_k, kargs.page_block_size); + integer_divide_ceil(kargs.seqlen_k + kargs.seqlen_knew, kargs.page_block_size); const long_index_t fixed_offset = static_cast(i_nhead_ / kargs.nhead_ratio_qk) * kargs.nhead_stride_v; - return PagedTileWindowNavigator(kargs.v_ptr, - kargs.batch_stride_v, - fixed_offset, - block_indices, - num_blocks, - kargs.page_block_size); + return PagedTileWindowNavigator(kargs.v_ptr, + kargs.batch_stride_v, + fixed_offset, + block_indices, + num_blocks, + kargs.page_block_size); } else { - return SimpleTileWindowNavigator(); + return SimpleTileWindowNavigator(); } }(); @@ -464,7 +468,7 @@ struct FmhaFwdAppendKVKernel }(); const auto k_dram_naive = make_naive_tensor_view( - k_ptr, + k_ptr, // will update this pointer if using paged-kvcache lengths, make_tuple(kargs.stride_k, 1), number{}, @@ -503,7 +507,7 @@ struct FmhaFwdAppendKVKernel }(); const auto v_dram_naive = make_naive_tensor_view( - v_ptr, + v_ptr, // will update this pointer if using paged-kvcache lengths, make_tuple(kargs.stride_v, 1), number{}, @@ -511,8 +515,8 @@ struct FmhaFwdAppendKVKernel const auto v_dram_transposed = transform_tensor_view( v_dram_naive, - make_tuple(make_pass_through_transform(kargs.hdim_v), - make_pass_through_transform(kargs.seqlen_k + kargs.seqlen_knew)), + make_tuple(make_pass_through_transform(lengths.at(number<1>{})), + make_pass_through_transform(lengths.at(number<0>{}))), make_tuple(sequence<1>{}, sequence<0>{}), make_tuple(sequence<0>{}, sequence<1>{})); @@ -535,7 +539,7 @@ struct FmhaFwdAppendKVKernel }(); const auto v_dram_naive = make_naive_tensor_view( - v_ptr, + v_ptr, // will update this pointer if using paged-kvcache lengths, make_tuple(kargs.stride_v, 1), number{}, @@ -700,21 +704,44 @@ struct FmhaFwdAppendKVKernel make_tuple(number{}, number{}), {i_m0, 0}); + /// FIXME: create tile window directly via TileWindowNavigator + const bool skip_append_kv = kargs.seqlen_knew <= i_n0; auto k_dram_window = make_tile_window(k_dram, make_tuple(number{}, number{}), - {kargs.seqlen_k + i_n0, 0}); + {skip_append_kv ? 0 : kargs.seqlen_k + i_n0, 0}); + auto [i_block0, k_dram_window_tmp] = k_tile_navigator.make_tile_window( + k_dram_window, {skip_append_kv ? 0 : kargs.seqlen_k + i_n0, 0}); + DEVICE_DEBUG_STMTS + { + printf("[DEVICE] i_block0: %d\n", i_block0); + auto local_origin = k_dram_window_tmp.get_window_origin(); + printf("[DEVICE] origin: (%d, %d)\n", + local_origin.at(number<0>{}), + local_origin.at(number<1>{})); + } auto knew_dram_window = make_tile_window(knew_dram, make_tuple(number{}, number{}), {i_n0, 0}); + /// FIXME: create tile window directly via TileWindowNavigator auto v_dram_window = make_tile_window(v_dram, make_tuple(number{}, number{}), - {0, kargs.seqlen_k + i_n0}); + {0, skip_append_kv ? 0 : kargs.seqlen_k + i_n0}); + auto [i_block1, v_dram_window_tmp] = v_tile_navigator.make_tile_window( + v_dram_window, {0, skip_append_kv ? 0 : kargs.seqlen_k + i_n0}); + DEVICE_DEBUG_STMTS + { + printf("[DEVICE] i_block1: %d\n", i_block1); + auto local_origin = v_dram_window_tmp.get_window_origin(); + printf("[DEVICE] origin: (%d, %d)\n", + local_origin.at(number<0>{}), + local_origin.at(number<1>{})); + } auto vnew_dram_window = make_tile_window(vnew_dram, make_tuple(number{}, number{}), @@ -723,9 +750,9 @@ struct FmhaFwdAppendKVKernel if constexpr(kApplyRoPE) { FmhaPipeline{}(q_dram_window, - k_dram_window, + k_dram_window_tmp, knew_dram_window, - v_dram_window, + v_dram_window_tmp, vnew_dram_window, q_rotary_cos_dram_window, q_rotary_sin_dram_window, @@ -740,9 +767,9 @@ struct FmhaFwdAppendKVKernel else { FmhaPipeline{}(q_dram_window, - k_dram_window, + k_dram_window_tmp, knew_dram_window, - v_dram_window, + v_dram_window_tmp, vnew_dram_window, q_rotary_cos_dram_window, q_rotary_sin_dram_window,