diff --git a/include/ck_tile/ops/fmha/kernel/fmha_fwd_appendkv_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_fwd_appendkv_kernel.hpp index 74b42bc403..73d622ec9c 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_fwd_appendkv_kernel.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_fwd_appendkv_kernel.hpp @@ -403,6 +403,16 @@ struct FmhaFwdAppendKVKernel const index_t num_blocks = integer_divide_ceil(kargs.seqlen_k + kargs.seqlen_knew, kargs.page_block_size); + DEVICE_DEBUG_STMTS + { + printf("[DEVICE] block_indics: "); + for(index_t i_block = 0; i_block < num_blocks; ++i_block) + { + printf("(%d, %d) ", i_block, block_indices[i_block]); + } + printf("\n"); + } + const long_index_t fixed_offset = static_cast(i_nhead_ / kargs.nhead_ratio_qk) * kargs.nhead_stride_v; @@ -734,25 +744,45 @@ struct FmhaFwdAppendKVKernel auto [i_block1, v_dram_window_tmp] = v_tile_navigator.make_tile_window( v_dram_window, {0, skip_append_kv ? 0 : kargs.seqlen_k + i_n0}); - DEVICE_DEBUG_STMTS + if constexpr(kIsPagedKV) { - printf("[DEVICE] i_block1: %d\n", i_block1); - auto local_origin = v_dram_window_tmp.get_window_origin(); - printf("[DEVICE] origin: (%d, %d)\n", - local_origin.at(number<0>{}), - local_origin.at(number<1>{})); + DEVICE_DEBUG_STMTS + { + printf("[DEVICE] i_block1: %d\n", i_block1); + auto local_origin = v_dram_window_tmp.get_window_origin(); + printf("[DEVICE] origin: (%d, %d)\n", + local_origin.at(number<0>{}), + local_origin.at(number<1>{})); + + printf("[DEVICE] psychical block_ptr 0: %p\n", + static_cast(v_tile_navigator.physical_blocks + + 0 * v_tile_navigator.block_stride)); + printf("[DEVICE] psychical block_ptr 1: %p\n", + static_cast(v_tile_navigator.physical_blocks + + 1 * v_tile_navigator.block_stride)); + + printf("[DEVICE] tile window data ptr: %p\n", + static_cast(v_dram_window_tmp.get_bottom_tensor_view().buf_.p_data_)); + } } auto vnew_dram_window = make_tile_window(vnew_dram, make_tuple(number{}, number{}), {0, i_n0}); - + DEVICE_DEBUG_STMTS + { + printf("[DEVICE] skip_transform_q: %d, skip_appendkv: %d\n", + kargs.seqlen_q <= i_m0, + kargs.seqlen_knew <= i_n0); + } if constexpr(kApplyRoPE) { FmhaPipeline{}(q_dram_window, k_dram_window_tmp, + i_block0, knew_dram_window, v_dram_window_tmp, + i_block1, vnew_dram_window, q_rotary_cos_dram_window, q_rotary_sin_dram_window, @@ -768,8 +798,10 @@ struct FmhaFwdAppendKVKernel { FmhaPipeline{}(q_dram_window, k_dram_window_tmp, + i_block0, knew_dram_window, v_dram_window_tmp, + i_block1, vnew_dram_window, q_rotary_cos_dram_window, q_rotary_sin_dram_window, diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_appendkv_pipeline.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_appendkv_pipeline.hpp index 53467c5e93..c9d8ef9a8f 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_appendkv_pipeline.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_appendkv_pipeline.hpp @@ -89,10 +89,12 @@ struct BlockFmhaFwdAppendKVPipeline CK_TILE_HOST_DEVICE auto operator()(QDramBlockWindow& q_dram_block_window, // M0*K0 tile const QElementFunction& q_element_func, - KDramBlockWindow& k_dram_block_window, // N0*K0 tile + KDramBlockWindow& k_dram_block_window, // N0*K0 tile + index_t i_block0, const KnewDramBlockWindow& knew_dram_block_window, // N0*K0 tile const KnewElementFunction& knew_element_func, - VDramBlockWindow& v_dram_block_window, // N1*N0 tile + VDramBlockWindow& v_dram_block_window, // N1*N0 tile + index_t i_block1, const VnewDramBlockWindow& vnew_dram_block_window, // N1*N0 tile const VnewElementFunction& vnew_element_func, const QRotaryCosDramBlockWindow q_rotary_cos_dram_block_window, @@ -148,8 +150,14 @@ struct BlockFmhaFwdAppendKVPipeline if constexpr(kIsPagedKV) { - /// TODO: handle cross-page-block write store_tile(k_dram_block_window, knew_tile); + + // write tile to another block if nesscary + if(k_tile_navigator.is_closs_block(k_dram_block_window)) + { + k_tile_navigator.move_to_block(i_block0, k_dram_block_window, i_block0 + 1); + store_tile(k_dram_block_window, knew_tile); + } } else { @@ -167,8 +175,14 @@ struct BlockFmhaFwdAppendKVPipeline if constexpr(kIsPagedKV) { - /// TODO: handle cross-page-block write store_tile(v_dram_block_window, vnew_tile); + + // write tile to another block if nesscary + if(v_tile_navigator.is_closs_block(v_dram_block_window)) + { + v_tile_navigator.move_to_block(i_block1, v_dram_block_window, i_block1 + 1); + store_tile(v_dram_block_window, vnew_tile); + } } else { @@ -229,8 +243,10 @@ struct BlockFmhaFwdAppendKVPipeline CK_TILE_HOST_DEVICE auto operator()(QDramBlockWindow& q_dram_block_window, KDramBlockWindow& k_dram_block_window, + index_t i_block0, const KnewDramBlockWindow& knew_dram_block_window, VDramBlockWindow& v_dram_block_window, + index_t i_block1, const VnewDramBlockWindow& vnew_dram_block_window, const QRotaryCosDramBlockWindow& q_rotary_cos_dram_block_window, const QRotarySinDramBlockWindow& q_rotary_sin_dram_block_window, @@ -245,9 +261,11 @@ struct BlockFmhaFwdAppendKVPipeline return operator()(q_dram_block_window, identity{}, k_dram_block_window, + i_block0, knew_dram_block_window, identity{}, v_dram_block_window, + i_block1, vnew_dram_block_window, identity{}, q_rotary_cos_dram_block_window,