Pipeline minor fixes

This commit is contained in:
Tianxing Wu
2025-11-24 10:26:26 +00:00
parent f2fbc44b7b
commit 76d1866537

View File

@@ -410,7 +410,7 @@ struct UnifiedAttentionPipeline
HEAD_SIZE_PADDED == VDramBlockWindowTmp{}.get_window_lengths()[number<1>{}],
"wrong!");
static_assert(sizeof(SaccDataType) * BLOCK_SIZE <= GetSmemSize());
static_assert(sizeof(SaccDataType) * BLOCK_SIZE * BLOCK_M <= GetSmemSize());
auto s_lds = make_tensor_view<address_space_enum::lds>(
reinterpret_cast<SaccDataType*>(static_cast<char*>(smem_ptr)),
MakeSimpleLdsDesc<BLOCK_M, BLOCK_SIZE>());
@@ -426,7 +426,7 @@ struct UnifiedAttentionPipeline
auto o_lds = make_tensor_view<address_space_enum::lds>(
reinterpret_cast<PDataType*>(static_cast<char*>(smem_ptr)),
MakeSimpleLdsDesc<BLOCK_M, BLOCK_SIZE>());
MakeSimpleLdsDesc<BLOCK_M, HEAD_SIZE_PADDED>());
[[maybe_unused]] auto o_lds_window = make_tile_window(
o_lds, make_tuple(number<BLOCK_M>{}, number<HEAD_SIZE_PADDED>{}), {0, 0});
@@ -542,16 +542,9 @@ struct UnifiedAttentionPipeline
clear_tile(l);
const auto q_origin = q_dram_window.get_window_origin();
// const auto [seqlen_k_start, seqlen_k_end] =
// mask.GetTileRangeAlongX(q_origin.at(number<0>{}), number<BLOCK_M>{},
// number<BLOCK_SIZE>{});
// const auto num_total_loop = integer_divide_ceil(seqlen_k_end - seqlen_k_start,
// BLOCK_SIZE);
const auto num_total_loop = num_blocks;
// index_t kv_token_start = seqlen_k_start;
// TODO check is paddings kPadSeqLenK
// check early exit if no work to do
if constexpr(FmhaMask::IsMasking)
{
@@ -567,20 +560,19 @@ struct UnifiedAttentionPipeline
index_t i_total_loops = num_blocks_start;
const ck_tile::index_t* block_tables_ptr_ =
reinterpret_cast<const ck_tile::index_t*>(block_tables_ptr);
index_t kv_blk_idx = block_tables_ptr_[block_table_offset + i_total_loops];
index_t kv_blk_idx_prev = 0;
index_t kv_blk_idx_intial = block_tables_ptr_[block_table_offset + i_total_loops];
auto k_dram_window =
make_tile_window(k_dram_block_window_tmp.get_bottom_tensor_view(),
k_dram_block_window_tmp.get_window_lengths(),
{(kv_blk_idx - kv_blk_idx_prev) * BLOCK_SIZE, 0},
{kv_blk_idx_intial * BLOCK_SIZE, 0},
Policy::template MakeKDramTileDistribution<Problem>());
k_dram_window.init_raw();
auto v_dram_window =
make_tile_window(v_dram_block_window_tmp.get_bottom_tensor_view(),
v_dram_block_window_tmp.get_window_lengths(),
{(kv_blk_idx - kv_blk_idx_prev) * BLOCK_SIZE, 0}, // TODO: hdim split?
{kv_blk_idx_intial * BLOCK_SIZE, 0},
Policy::template MakeVDramTileDistribution<Problem>());
v_dram_window.init_raw();
@@ -676,6 +668,7 @@ struct UnifiedAttentionPipeline
async_load_tile_raw(k_lds_window_store(k_lds_write_idx), k_dram_window);
// TODO maybe needs i_total_loops as argument. Or maybe needs to use the k_lds_write_idx
// as the index
index_t kv_blk_idx = block_tables_ptr_[block_table_offset + i_total_loops];
/// FIXME: use the future-predicting method to move the window
k_dram_window.set_window_origin({kv_blk_idx * BLOCK_SIZE, 0});
};
@@ -686,7 +679,7 @@ struct UnifiedAttentionPipeline
auto V_mem_load = [&](auto v_lds_write_idx) {
async_load_tile_raw(v_lds_window_store(v_lds_write_idx), v_dram_window);
// kv_blk_idx = block_tables_ptr_[block_table_offset + i_total_loops];
index_t kv_blk_idx = block_tables_ptr_[block_table_offset + i_total_loops];
/// FIXME: use the future-predicting method to move the window
v_dram_window.set_window_origin({kv_blk_idx * BLOCK_SIZE, 0});
};
@@ -985,7 +978,6 @@ struct UnifiedAttentionPipeline
__builtin_amdgcn_s_barrier();
__builtin_amdgcn_sched_barrier(0);
cl_load(memK, K_w0_lds_wr_idx, V_w0_lds_rd_idx);
// TODO what is this???
Scheduler::schedule(cl_p, number<1>{});
fmha_mask(xdl_SP_p01_reg_idx);