From c204cdc382fcd695131056547215384b53b04495 Mon Sep 17 00:00:00 2001 From: MHYang Date: Tue, 29 Apr 2025 10:58:08 +0000 Subject: [PATCH] Fix o_spans --- .../03_flash_attention_fwd/flash_attention_fwd_impl.hpp | 5 +++-- .../flash_attention_fwd_impl.hpp | 6 ++++-- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/example/ck_tile/99_toy_example/03_flash_attention_fwd/flash_attention_fwd_impl.hpp b/example/ck_tile/99_toy_example/03_flash_attention_fwd/flash_attention_fwd_impl.hpp index fbca3a95ac..5a9041973a 100644 --- a/example/ck_tile/99_toy_example/03_flash_attention_fwd/flash_attention_fwd_impl.hpp +++ b/example/ck_tile/99_toy_example/03_flash_attention_fwd/flash_attention_fwd_impl.hpp @@ -295,14 +295,15 @@ struct FlashAttentionFwdImpl block_tile_reduce_sync(rowsum_p, f_sum); // l{j}, Oacc{j} - sweep_tile_span(p_spans[I0], [&](auto idx0) { + constexpr auto o_spans = decltype(o_acc)::get_distributed_spans(); + sweep_tile_span(o_spans[I0], [&](auto idx0) { constexpr auto i_idx = make_tuple(idx0); const auto tmp = exp(m_old[i_idx] - m[i_idx]); l(i_idx) = tmp * l[i_idx] + rowsum_p[i_idx]; - sweep_tile_span(p_spans[I1], [&](auto idx1) { + sweep_tile_span(o_spans[I1], [&](auto idx1) { constexpr auto i_j_idx = make_tuple(idx0, idx1); o_acc(i_j_idx) *= tmp; diff --git a/example/ck_tile/99_toy_example/04_codegen_flash_attention_fwd/flash_attention_fwd_impl.hpp b/example/ck_tile/99_toy_example/04_codegen_flash_attention_fwd/flash_attention_fwd_impl.hpp index bd8c209383..5a9041973a 100644 --- a/example/ck_tile/99_toy_example/04_codegen_flash_attention_fwd/flash_attention_fwd_impl.hpp +++ b/example/ck_tile/99_toy_example/04_codegen_flash_attention_fwd/flash_attention_fwd_impl.hpp @@ -209,6 +209,7 @@ struct FlashAttentionFwdImpl auto v_lds_window = make_tile_window( v_lds, make_tuple(number{}, number{}), {0, 0}); #endif + // reduction function for softmax const auto f_max = [](auto e0, auto e1) { return max(e0, e1); }; const auto f_sum = [](auto e0, auto e1) { return e0 + e1; }; @@ -294,14 +295,15 @@ struct FlashAttentionFwdImpl block_tile_reduce_sync(rowsum_p, f_sum); // l{j}, Oacc{j} - sweep_tile_span(p_spans[I0], [&](auto idx0) { + constexpr auto o_spans = decltype(o_acc)::get_distributed_spans(); + sweep_tile_span(o_spans[I0], [&](auto idx0) { constexpr auto i_idx = make_tuple(idx0); const auto tmp = exp(m_old[i_idx] - m[i_idx]); l(i_idx) = tmp * l[i_idx] + rowsum_p[i_idx]; - sweep_tile_span(p_spans[I1], [&](auto idx1) { + sweep_tile_span(o_spans[I1], [&](auto idx1) { constexpr auto i_j_idx = make_tuple(idx0, idx1); o_acc(i_j_idx) *= tmp;