Fix o_spans

This commit is contained in:
MHYang
2025-04-29 10:58:08 +00:00
parent 89bf0765fb
commit c204cdc382
2 changed files with 7 additions and 4 deletions

View File

@@ -295,14 +295,15 @@ struct FlashAttentionFwdImpl
block_tile_reduce_sync(rowsum_p, f_sum);
// l{j}, Oacc{j}
sweep_tile_span(p_spans[I0], [&](auto idx0) {
constexpr auto o_spans = decltype(o_acc)::get_distributed_spans();
sweep_tile_span(o_spans[I0], [&](auto idx0) {
constexpr auto i_idx = make_tuple(idx0);
const auto tmp = exp(m_old[i_idx] - m[i_idx]);
l(i_idx) = tmp * l[i_idx] + rowsum_p[i_idx];
sweep_tile_span(p_spans[I1], [&](auto idx1) {
sweep_tile_span(o_spans[I1], [&](auto idx1) {
constexpr auto i_j_idx = make_tuple(idx0, idx1);
o_acc(i_j_idx) *= tmp;

View File

@@ -209,6 +209,7 @@ struct FlashAttentionFwdImpl
auto v_lds_window = make_tile_window(
v_lds, make_tuple(number<kN1PerBlock>{}, number<kK1PerBlock>{}), {0, 0});
#endif
// reduction function for softmax
const auto f_max = [](auto e0, auto e1) { return max(e0, e1); };
const auto f_sum = [](auto e0, auto e1) { return e0 + e1; };
@@ -294,14 +295,15 @@ struct FlashAttentionFwdImpl
block_tile_reduce_sync(rowsum_p, f_sum);
// l{j}, Oacc{j}
sweep_tile_span(p_spans[I0], [&](auto idx0) {
constexpr auto o_spans = decltype(o_acc)::get_distributed_spans();
sweep_tile_span(o_spans[I0], [&](auto idx0) {
constexpr auto i_idx = make_tuple(idx0);
const auto tmp = exp(m_old[i_idx] - m[i_idx]);
l(i_idx) = tmp * l[i_idx] + rowsum_p[i_idx];
sweep_tile_span(p_spans[I1], [&](auto idx1) {
sweep_tile_span(o_spans[I1], [&](auto idx1) {
constexpr auto i_j_idx = make_tuple(idx0, idx1);
o_acc(i_j_idx) *= tmp;