diff --git a/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline.hpp b/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline.hpp index 3bb30149bf..5844285ffe 100644 --- a/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline.hpp +++ b/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline.hpp @@ -410,7 +410,7 @@ struct UnifiedAttentionPipeline HEAD_SIZE_PADDED == VDramBlockWindowTmp{}.get_window_lengths()[number<1>{}], "wrong!"); - static_assert(sizeof(SaccDataType) * BLOCK_SIZE <= GetSmemSize()); + static_assert(sizeof(SaccDataType) * BLOCK_SIZE * BLOCK_M <= GetSmemSize()); auto s_lds = make_tensor_view( reinterpret_cast(static_cast(smem_ptr)), MakeSimpleLdsDesc()); @@ -426,7 +426,7 @@ struct UnifiedAttentionPipeline auto o_lds = make_tensor_view( reinterpret_cast(static_cast(smem_ptr)), - MakeSimpleLdsDesc()); + MakeSimpleLdsDesc()); [[maybe_unused]] auto o_lds_window = make_tile_window( o_lds, make_tuple(number{}, number{}), {0, 0}); @@ -542,16 +542,9 @@ struct UnifiedAttentionPipeline clear_tile(l); const auto q_origin = q_dram_window.get_window_origin(); - // const auto [seqlen_k_start, seqlen_k_end] = - // mask.GetTileRangeAlongX(q_origin.at(number<0>{}), number{}, - // number{}); - // const auto num_total_loop = integer_divide_ceil(seqlen_k_end - seqlen_k_start, - // BLOCK_SIZE); const auto num_total_loop = num_blocks; - // index_t kv_token_start = seqlen_k_start; - // TODO check is paddings kPadSeqLenK // check early exit if no work to do if constexpr(FmhaMask::IsMasking) { @@ -567,20 +560,19 @@ struct UnifiedAttentionPipeline index_t i_total_loops = num_blocks_start; const ck_tile::index_t* block_tables_ptr_ = reinterpret_cast(block_tables_ptr); - index_t kv_blk_idx = block_tables_ptr_[block_table_offset + i_total_loops]; - index_t kv_blk_idx_prev = 0; + index_t kv_blk_idx_intial = block_tables_ptr_[block_table_offset + i_total_loops]; auto k_dram_window = make_tile_window(k_dram_block_window_tmp.get_bottom_tensor_view(), k_dram_block_window_tmp.get_window_lengths(), - {(kv_blk_idx - kv_blk_idx_prev) * BLOCK_SIZE, 0}, + {kv_blk_idx_intial * BLOCK_SIZE, 0}, Policy::template MakeKDramTileDistribution()); k_dram_window.init_raw(); auto v_dram_window = make_tile_window(v_dram_block_window_tmp.get_bottom_tensor_view(), v_dram_block_window_tmp.get_window_lengths(), - {(kv_blk_idx - kv_blk_idx_prev) * BLOCK_SIZE, 0}, // TODO: hdim split? + {kv_blk_idx_intial * BLOCK_SIZE, 0}, Policy::template MakeVDramTileDistribution()); v_dram_window.init_raw(); @@ -676,6 +668,7 @@ struct UnifiedAttentionPipeline async_load_tile_raw(k_lds_window_store(k_lds_write_idx), k_dram_window); // TODO maybe needs i_total_loops as argument. Or maybe needs to use the k_lds_write_idx // as the index + index_t kv_blk_idx = block_tables_ptr_[block_table_offset + i_total_loops]; /// FIXME: use the future-predicting method to move the window k_dram_window.set_window_origin({kv_blk_idx * BLOCK_SIZE, 0}); }; @@ -686,7 +679,7 @@ struct UnifiedAttentionPipeline auto V_mem_load = [&](auto v_lds_write_idx) { async_load_tile_raw(v_lds_window_store(v_lds_write_idx), v_dram_window); - // kv_blk_idx = block_tables_ptr_[block_table_offset + i_total_loops]; + index_t kv_blk_idx = block_tables_ptr_[block_table_offset + i_total_loops]; /// FIXME: use the future-predicting method to move the window v_dram_window.set_window_origin({kv_blk_idx * BLOCK_SIZE, 0}); }; @@ -985,7 +978,6 @@ struct UnifiedAttentionPipeline __builtin_amdgcn_s_barrier(); __builtin_amdgcn_sched_barrier(0); cl_load(memK, K_w0_lds_wr_idx, V_w0_lds_rd_idx); - // TODO what is this??? Scheduler::schedule(cl_p, number<1>{}); fmha_mask(xdl_SP_p01_reg_idx);