diff --git a/example/ck_tile/18_hstu_attention/hstu_attention_fwd_pipeline.hpp b/example/ck_tile/18_hstu_attention/hstu_attention_fwd_pipeline.hpp index a0ce38c92d..790f6e2b90 100644 --- a/example/ck_tile/18_hstu_attention/hstu_attention_fwd_pipeline.hpp +++ b/example/ck_tile/18_hstu_attention/hstu_attention_fwd_pipeline.hpp @@ -356,7 +356,7 @@ struct HstuAttentionFwdPipelineQRKSVS set_tile_if( sacc_tiles[i_k1], type_convert(0), [&](auto tile_idx) { const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{}); - const auto col = seqlen_k_curr + i_k1 * kK1 + tile_idx.at(number<1>{}); + const auto col = seqlen_k_curr + tile_idx.at(number<1>{}); return !mask.IsTokenPairInsideMask(row, col); }); } @@ -368,14 +368,15 @@ struct HstuAttentionFwdPipelineQRKSVS sacc_tiles[i_k1], type_convert(0), [&](auto tile_idx) { const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{}); - const auto col = - seqlen_k_curr + i_k1 * kK1 + tile_idx.at(number<1>{}); + const auto col = seqlen_k_curr + tile_idx.at(number<1>{}); return !mask.IsTokenPairInsideMask(row, col); }); } } pcomp_tiles[i_k1] = cast_tile(sacc_tiles[i_k1]); + + seqlen_k_curr += kK1; }); // load one k_tile for next iteration @@ -415,11 +416,9 @@ struct HstuAttentionFwdPipelineQRKSVS reinterpret_cast(smem_ptr) + Policy::template GetSmemSizeKV(); dropout.template Run( - randval_lds_ptr, seqlen_k_curr, pcomp_tiles[I0], null_randval_window); + randval_lds_ptr, seqlen_k_curr - kN0, pcomp_tiles[I0], null_randval_window); } - seqlen_k_curr += kK1; - auto p = [&]() { if constexpr(std::is_same_v) return impl::cast_tile_pk_fp16_fp32( @@ -467,13 +466,11 @@ struct HstuAttentionFwdPipelineQRKSVS dropout.template Run( randval_lds_ptr, - seqlen_k_curr, + seqlen_k_curr - kN0 + (i_k1 + 1) * kK1, pcomp_tiles[number{}], null_randval_window); } - seqlen_k_curr += kK1; - p = [&]() { if constexpr(std::is_same_v) return impl::cast_tile_pk_fp16_fp32(tile_elementwise_in(