diff --git a/example/ck_tile/18_hstu_attention/hstu_attention_with_softmax_fwd_pipeline.hpp b/example/ck_tile/18_hstu_attention/hstu_attention_with_softmax_fwd_pipeline.hpp index c7d2d76f5e..38a6319fbe 100644 --- a/example/ck_tile/18_hstu_attention/hstu_attention_with_softmax_fwd_pipeline.hpp +++ b/example/ck_tile/18_hstu_attention/hstu_attention_with_softmax_fwd_pipeline.hpp @@ -342,14 +342,11 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVS } else { - // We assume NumPrefetchV >= NumPrefetchK - if constexpr(i_n0 - (n0_loops - NumPrefetchK) < NumPrefetchK) - { - // load v_tiles used in current iteration - v_tiles[number{}] = - load_tile(v_dram_window); - move_tile_window(v_dram_window, {0, kK1}); - } + // Since NumPrefetchV >= NumPrefetchK, we are able to have NumPrefetchK + // prefetchings of v_tile arranged in n0_loops + + v_tiles[number{}] = load_tile(v_dram_window); + move_tile_window(v_dram_window, {0, kK1}); }; __builtin_amdgcn_sched_barrier(0x00000001); @@ -449,7 +446,6 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVS __builtin_amdgcn_sched_barrier(0x00000001); static_for{}([&](auto i_k1) { - // load v_tiles used in current iteration v_tiles[i_k1] = load_tile(v_dram_window); move_tile_window(v_dram_window, {0, kK1}); }); diff --git a/example/ck_tile/18_hstu_attention/hstu_attention_with_softmax_fwd_trload_pipeline.hpp b/example/ck_tile/18_hstu_attention/hstu_attention_with_softmax_fwd_trload_pipeline.hpp index 3230ad96df..c151c15d0e 100644 --- a/example/ck_tile/18_hstu_attention/hstu_attention_with_softmax_fwd_trload_pipeline.hpp +++ b/example/ck_tile/18_hstu_attention/hstu_attention_with_softmax_fwd_trload_pipeline.hpp @@ -342,14 +342,11 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVSTrLoad } else { - // We assume NumPrefetchV >= NumPrefetchK - if constexpr(i_n0 - (n0_loops - NumPrefetchK) < NumPrefetchK) - { - // load v_tiles used in current iteration - v_tiles[number{}] = - load_tile(v_dram_window); - move_tile_window(v_dram_window, {kK1, 0}); - } + // Since NumPrefetchV >= NumPrefetchK, we are able to have NumPrefetchK + // prefetchings of v_tile arranged in n0_loops + + v_tiles[number{}] = load_tile(v_dram_window); + move_tile_window(v_dram_window, {kK1, 0}); }; __builtin_amdgcn_sched_barrier(0x00000001);