mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-14 02:02:46 +00:00
Refactor to combine two kernels
This commit is contained in:
@@ -318,26 +318,26 @@ struct BlockFmhaPipelineQRKSVSAsyncJenga
|
||||
{
|
||||
if(!block_relation_onehot[i_total_loops])
|
||||
{
|
||||
i_total_loops++;
|
||||
if(i_total_loops < num_total_loop)
|
||||
{
|
||||
// move K tile windows
|
||||
move_tile_window(k_dram_block_window, {kN0, 0});
|
||||
k_dram_window.set_window_origin(k_dram_block_window.get_window_origin());
|
||||
|
||||
if(block_relation_onehot[i_total_loops])
|
||||
{
|
||||
async_load_tile_raw(k_lds_store(LdsSeq.at(number<0>{})),
|
||||
k_dram_window,
|
||||
number<-1>{},
|
||||
k_oob_ck,
|
||||
k_pre_np);
|
||||
}
|
||||
move_tile_window(k_dram_window, {0, kK0});
|
||||
move_tile_window(v_dram_window, {0, kN0});
|
||||
continue;
|
||||
}
|
||||
break;
|
||||
// scan-ahead: find the next active block in one shot
|
||||
index_t next = i_total_loops + 1;
|
||||
while(next < num_total_loop && !block_relation_onehot[next])
|
||||
next++;
|
||||
if(next >= num_total_loop)
|
||||
break;
|
||||
const index_t delta = next - i_total_loops;
|
||||
i_total_loops = next;
|
||||
// jump K/V windows to the next active block
|
||||
move_tile_window(k_dram_block_window, {kN0 * delta, 0});
|
||||
k_dram_window.set_window_origin(k_dram_block_window.get_window_origin());
|
||||
move_tile_window(v_dram_window, {0, kN0 * delta});
|
||||
// immediately prefetch the active K tile
|
||||
async_load_tile_raw(k_lds_store(LdsSeq.at(number<0>{})),
|
||||
k_dram_window,
|
||||
number<-1>{},
|
||||
k_oob_ck,
|
||||
k_pre_np);
|
||||
move_tile_window(k_dram_window, {0, kK0});
|
||||
continue;
|
||||
}
|
||||
|
||||
// STAGE 1, QK gemm
|
||||
|
||||
Reference in New Issue
Block a user