diff --git a/example/ck_tile/18_hstu_attention/hstu_attention_fwd_pipeline.hpp b/example/ck_tile/18_hstu_attention/hstu_attention_fwd_pipeline.hpp index e0c21d26c5..cd16c9a0b3 100644 --- a/example/ck_tile/18_hstu_attention/hstu_attention_fwd_pipeline.hpp +++ b/example/ck_tile/18_hstu_attention/hstu_attention_fwd_pipeline.hpp @@ -254,15 +254,6 @@ struct HstuAttentionFwdPipelineQRKSVS const auto num_loops = integer_divide_ceil(seqlen_k_end - seqlen_k_start, kN0); - // check early exit if no work to do - if constexpr(HstuMask::IsMasking || kPadSeqLenK) - { - if(num_loops <= 0) - { - return o_acc; - } - } - const auto bias_origin = bias_dram_block_window_tmp.get_window_origin(); auto bias_dram_window = make_tile_window(bias_dram_block_window_tmp.get_bottom_tensor_view(), @@ -299,7 +290,7 @@ struct HstuAttentionFwdPipelineQRKSVS index_t i_loop = 0; - do + while(i_loop < num_loops) { static_for<0, k1_loops, 1>{}([&](auto i_k1) { // load v_tile for current unroll @@ -462,7 +453,9 @@ struct HstuAttentionFwdPipelineQRKSVS // this does not occur when k1_loops == 2 and NumKVLdsBuffers == 3 if constexpr(Policy::template IsFirstKLdsBufferOverlapLastVLdsBuffer()) __builtin_amdgcn_s_barrier(); - } while(++i_loop < num_loops); + + i_loop++; + }; tile_elementwise_inout( [&](auto& x) {