From d2dadc22a71288b3e6671a6407914a92cf0c5c0c Mon Sep 17 00:00:00 2001 From: Qianfeng Zhang Date: Fri, 26 Dec 2025 15:13:28 +0000 Subject: [PATCH] Remove un-needed constexpr checking for loading v_tiles in Gemm0 loop --- .../hstu_attention_with_softmax_fwd_pipeline.hpp | 14 +++++--------- ..._attention_with_softmax_fwd_trload_pipeline.hpp | 13 +++++-------- 2 files changed, 10 insertions(+), 17 deletions(-) diff --git a/example/ck_tile/18_hstu_attention/hstu_attention_with_softmax_fwd_pipeline.hpp b/example/ck_tile/18_hstu_attention/hstu_attention_with_softmax_fwd_pipeline.hpp index c7d2d76f5e..38a6319fbe 100644 --- a/example/ck_tile/18_hstu_attention/hstu_attention_with_softmax_fwd_pipeline.hpp +++ b/example/ck_tile/18_hstu_attention/hstu_attention_with_softmax_fwd_pipeline.hpp @@ -342,14 +342,11 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVS } else { - // We assume NumPrefetchV >= NumPrefetchK - if constexpr(i_n0 - (n0_loops - NumPrefetchK) < NumPrefetchK) - { - // load v_tiles used in current iteration - v_tiles[number{}] = - load_tile(v_dram_window); - move_tile_window(v_dram_window, {0, kK1}); - } + // Since NumPrefetchV >= NumPrefetchK, we are able to have NumPrefetchK + // prefetchings of v_tile arranged in n0_loops + + v_tiles[number{}] = load_tile(v_dram_window); + move_tile_window(v_dram_window, {0, kK1}); }; __builtin_amdgcn_sched_barrier(0x00000001); @@ -449,7 +446,6 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVS __builtin_amdgcn_sched_barrier(0x00000001); static_for{}([&](auto i_k1) { - // load v_tiles used in current iteration v_tiles[i_k1] = load_tile(v_dram_window); move_tile_window(v_dram_window, {0, kK1}); }); diff --git a/example/ck_tile/18_hstu_attention/hstu_attention_with_softmax_fwd_trload_pipeline.hpp b/example/ck_tile/18_hstu_attention/hstu_attention_with_softmax_fwd_trload_pipeline.hpp index 3230ad96df..c151c15d0e 100644 --- a/example/ck_tile/18_hstu_attention/hstu_attention_with_softmax_fwd_trload_pipeline.hpp +++ b/example/ck_tile/18_hstu_attention/hstu_attention_with_softmax_fwd_trload_pipeline.hpp @@ -342,14 +342,11 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVSTrLoad } else { - // We assume NumPrefetchV >= NumPrefetchK - if constexpr(i_n0 - (n0_loops - NumPrefetchK) < NumPrefetchK) - { - // load v_tiles used in current iteration - v_tiles[number{}] = - load_tile(v_dram_window); - move_tile_window(v_dram_window, {kK1, 0}); - } + // Since NumPrefetchV >= NumPrefetchK, we are able to have NumPrefetchK + // prefetchings of v_tile arranged in n0_loops + + v_tiles[number{}] = load_tile(v_dram_window); + move_tile_window(v_dram_window, {kK1, 0}); }; __builtin_amdgcn_sched_barrier(0x00000001);