mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-18 03:49:41 +00:00
Using in-place version of block_tile_reduce() so that using of m_local is avoided
This commit is contained in:
@@ -454,14 +454,10 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVS
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0x00000001);
|
||||
|
||||
auto m_local = block_tile_reduce<CompDataType>(
|
||||
pcomp_tile, sequence<1>{}, f_max, -numeric<CompDataType>::infinity());
|
||||
block_tile_reduce_sync(m_local, f_max, bool_constant<false>{});
|
||||
|
||||
const auto m_old = m;
|
||||
|
||||
tile_elementwise_inout(
|
||||
[](auto& e0, auto e1, auto e2) { e0 = max(e1, e2); }, m, m_old, m_local);
|
||||
block_tile_reduce(m, pcomp_tile, sequence<1>{}, f_max);
|
||||
block_tile_reduce_sync(m, f_max, bool_constant<false>{});
|
||||
|
||||
constexpr auto p_spans = decltype(pcomp_tile)::get_distributed_spans();
|
||||
sweep_tile_span(p_spans[number<0>{}], [&](auto idx0) {
|
||||
|
||||
@@ -451,14 +451,10 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVSTrLoad
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0x00000001);
|
||||
|
||||
auto m_local = block_tile_reduce<CompDataType>(
|
||||
pcomp_tile, sequence<1>{}, f_max, -numeric<CompDataType>::infinity());
|
||||
block_tile_reduce_sync(m_local, f_max, bool_constant<false>{});
|
||||
|
||||
const auto m_old = m;
|
||||
|
||||
tile_elementwise_inout(
|
||||
[](auto& e0, auto e1, auto e2) { e0 = max(e1, e2); }, m, m_old, m_local);
|
||||
block_tile_reduce(m, pcomp_tile, sequence<1>{}, f_max);
|
||||
block_tile_reduce_sync(m, f_max, bool_constant<false>{});
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0x00000001);
|
||||
|
||||
|
||||
Reference in New Issue
Block a user