From 73d6e0eb67b53e3f97abd5ed848e80cabd37486e Mon Sep 17 00:00:00 2001 From: Qianfeng Zhang Date: Thu, 5 Mar 2026 16:27:41 +0000 Subject: [PATCH] Using in-place version of block_tile_reduce() so that using of m_local is avoided --- .../hstu_attention_with_softmax_fwd_pipeline.hpp | 8 ++------ .../hstu_attention_with_softmax_fwd_trload_pipeline.hpp | 8 ++------ 2 files changed, 4 insertions(+), 12 deletions(-) diff --git a/example/ck_tile/18_hstu_attention/hstu_attention_with_softmax_fwd_pipeline.hpp b/example/ck_tile/18_hstu_attention/hstu_attention_with_softmax_fwd_pipeline.hpp index c2e6907e3d..c77948e54c 100644 --- a/example/ck_tile/18_hstu_attention/hstu_attention_with_softmax_fwd_pipeline.hpp +++ b/example/ck_tile/18_hstu_attention/hstu_attention_with_softmax_fwd_pipeline.hpp @@ -454,14 +454,10 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVS __builtin_amdgcn_sched_barrier(0x00000001); - auto m_local = block_tile_reduce( - pcomp_tile, sequence<1>{}, f_max, -numeric::infinity()); - block_tile_reduce_sync(m_local, f_max, bool_constant{}); - const auto m_old = m; - tile_elementwise_inout( - [](auto& e0, auto e1, auto e2) { e0 = max(e1, e2); }, m, m_old, m_local); + block_tile_reduce(m, pcomp_tile, sequence<1>{}, f_max); + block_tile_reduce_sync(m, f_max, bool_constant{}); constexpr auto p_spans = decltype(pcomp_tile)::get_distributed_spans(); sweep_tile_span(p_spans[number<0>{}], [&](auto idx0) { diff --git a/example/ck_tile/18_hstu_attention/hstu_attention_with_softmax_fwd_trload_pipeline.hpp b/example/ck_tile/18_hstu_attention/hstu_attention_with_softmax_fwd_trload_pipeline.hpp index 5df7dbabea..1b3ed7ad71 100644 --- a/example/ck_tile/18_hstu_attention/hstu_attention_with_softmax_fwd_trload_pipeline.hpp +++ b/example/ck_tile/18_hstu_attention/hstu_attention_with_softmax_fwd_trload_pipeline.hpp @@ -451,14 +451,10 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVSTrLoad __builtin_amdgcn_sched_barrier(0x00000001); - auto m_local = block_tile_reduce( - pcomp_tile, sequence<1>{}, f_max, -numeric::infinity()); - block_tile_reduce_sync(m_local, f_max, bool_constant{}); - const auto m_old = m; - tile_elementwise_inout( - [](auto& e0, auto e1, auto e2) { e0 = max(e1, e2); }, m, m_old, m_local); + block_tile_reduce(m, pcomp_tile, sequence<1>{}, f_max); + block_tile_reduce_sync(m, f_max, bool_constant{}); __builtin_amdgcn_sched_barrier(0x00000001);