mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-30 03:37:38 +00:00
Pipeline minor fixes
This commit is contained in:
@@ -410,7 +410,7 @@ struct UnifiedAttentionPipeline
|
||||
HEAD_SIZE_PADDED == VDramBlockWindowTmp{}.get_window_lengths()[number<1>{}],
|
||||
"wrong!");
|
||||
|
||||
static_assert(sizeof(SaccDataType) * BLOCK_SIZE <= GetSmemSize());
|
||||
static_assert(sizeof(SaccDataType) * BLOCK_SIZE * BLOCK_M <= GetSmemSize());
|
||||
auto s_lds = make_tensor_view<address_space_enum::lds>(
|
||||
reinterpret_cast<SaccDataType*>(static_cast<char*>(smem_ptr)),
|
||||
MakeSimpleLdsDesc<BLOCK_M, BLOCK_SIZE>());
|
||||
@@ -426,7 +426,7 @@ struct UnifiedAttentionPipeline
|
||||
|
||||
auto o_lds = make_tensor_view<address_space_enum::lds>(
|
||||
reinterpret_cast<PDataType*>(static_cast<char*>(smem_ptr)),
|
||||
MakeSimpleLdsDesc<BLOCK_M, BLOCK_SIZE>());
|
||||
MakeSimpleLdsDesc<BLOCK_M, HEAD_SIZE_PADDED>());
|
||||
[[maybe_unused]] auto o_lds_window = make_tile_window(
|
||||
o_lds, make_tuple(number<BLOCK_M>{}, number<HEAD_SIZE_PADDED>{}), {0, 0});
|
||||
|
||||
@@ -542,16 +542,9 @@ struct UnifiedAttentionPipeline
|
||||
clear_tile(l);
|
||||
|
||||
const auto q_origin = q_dram_window.get_window_origin();
|
||||
// const auto [seqlen_k_start, seqlen_k_end] =
|
||||
// mask.GetTileRangeAlongX(q_origin.at(number<0>{}), number<BLOCK_M>{},
|
||||
// number<BLOCK_SIZE>{});
|
||||
|
||||
// const auto num_total_loop = integer_divide_ceil(seqlen_k_end - seqlen_k_start,
|
||||
// BLOCK_SIZE);
|
||||
const auto num_total_loop = num_blocks;
|
||||
// index_t kv_token_start = seqlen_k_start;
|
||||
|
||||
// TODO: check whether kPadSeqLenK padding also needs to be considered here
|
||||
// check early exit if no work to do
|
||||
if constexpr(FmhaMask::IsMasking)
|
||||
{
|
||||
@@ -567,20 +560,19 @@ struct UnifiedAttentionPipeline
|
||||
index_t i_total_loops = num_blocks_start;
|
||||
const ck_tile::index_t* block_tables_ptr_ =
|
||||
reinterpret_cast<const ck_tile::index_t*>(block_tables_ptr);
|
||||
index_t kv_blk_idx = block_tables_ptr_[block_table_offset + i_total_loops];
|
||||
index_t kv_blk_idx_prev = 0;
|
||||
index_t kv_blk_idx_intial = block_tables_ptr_[block_table_offset + i_total_loops];
|
||||
|
||||
auto k_dram_window =
|
||||
make_tile_window(k_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
k_dram_block_window_tmp.get_window_lengths(),
|
||||
{(kv_blk_idx - kv_blk_idx_prev) * BLOCK_SIZE, 0},
|
||||
{kv_blk_idx_intial * BLOCK_SIZE, 0},
|
||||
Policy::template MakeKDramTileDistribution<Problem>());
|
||||
k_dram_window.init_raw();
|
||||
|
||||
auto v_dram_window =
|
||||
make_tile_window(v_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
v_dram_block_window_tmp.get_window_lengths(),
|
||||
{(kv_blk_idx - kv_blk_idx_prev) * BLOCK_SIZE, 0}, // TODO: hdim split?
|
||||
{kv_blk_idx_intial * BLOCK_SIZE, 0},
|
||||
Policy::template MakeVDramTileDistribution<Problem>());
|
||||
v_dram_window.init_raw();
|
||||
|
||||
@@ -676,6 +668,7 @@ struct UnifiedAttentionPipeline
|
||||
async_load_tile_raw(k_lds_window_store(k_lds_write_idx), k_dram_window);
|
||||
// TODO maybe needs i_total_loops as argument. Or maybe needs to use the k_lds_write_idx
|
||||
// as the index
|
||||
index_t kv_blk_idx = block_tables_ptr_[block_table_offset + i_total_loops];
|
||||
/// FIXME: use the future-predicting method to move the window
|
||||
k_dram_window.set_window_origin({kv_blk_idx * BLOCK_SIZE, 0});
|
||||
};
|
||||
@@ -686,7 +679,7 @@ struct UnifiedAttentionPipeline
|
||||
|
||||
auto V_mem_load = [&](auto v_lds_write_idx) {
|
||||
async_load_tile_raw(v_lds_window_store(v_lds_write_idx), v_dram_window);
|
||||
// kv_blk_idx = block_tables_ptr_[block_table_offset + i_total_loops];
|
||||
index_t kv_blk_idx = block_tables_ptr_[block_table_offset + i_total_loops];
|
||||
/// FIXME: use the future-predicting method to move the window
|
||||
v_dram_window.set_window_origin({kv_blk_idx * BLOCK_SIZE, 0});
|
||||
};
|
||||
@@ -985,7 +978,6 @@ struct UnifiedAttentionPipeline
|
||||
__builtin_amdgcn_s_barrier();
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
cl_load(memK, K_w0_lds_wr_idx, V_w0_lds_rd_idx);
|
||||
// TODO what is this???
|
||||
Scheduler::schedule(cl_p, number<1>{});
|
||||
fmha_mask(xdl_SP_p01_reg_idx);
|
||||
|
||||
|
||||
Reference in New Issue
Block a user