mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-07-02 04:37:02 +00:00
Fix o_spans
This commit is contained in:
@@ -295,14 +295,15 @@ struct FlashAttentionFwdImpl
|
||||
block_tile_reduce_sync(rowsum_p, f_sum);
|
||||
|
||||
// l{j}, Oacc{j}
|
||||
sweep_tile_span(p_spans[I0], [&](auto idx0) {
|
||||
constexpr auto o_spans = decltype(o_acc)::get_distributed_spans();
|
||||
sweep_tile_span(o_spans[I0], [&](auto idx0) {
|
||||
constexpr auto i_idx = make_tuple(idx0);
|
||||
|
||||
const auto tmp = exp(m_old[i_idx] - m[i_idx]);
|
||||
|
||||
l(i_idx) = tmp * l[i_idx] + rowsum_p[i_idx];
|
||||
|
||||
sweep_tile_span(p_spans[I1], [&](auto idx1) {
|
||||
sweep_tile_span(o_spans[I1], [&](auto idx1) {
|
||||
constexpr auto i_j_idx = make_tuple(idx0, idx1);
|
||||
|
||||
o_acc(i_j_idx) *= tmp;
|
||||
|
||||
@@ -209,6 +209,7 @@ struct FlashAttentionFwdImpl
|
||||
auto v_lds_window = make_tile_window(
|
||||
v_lds, make_tuple(number<kN1PerBlock>{}, number<kK1PerBlock>{}), {0, 0});
|
||||
#endif
|
||||
|
||||
// reduction function for softmax
|
||||
const auto f_max = [](auto e0, auto e1) { return max(e0, e1); };
|
||||
const auto f_sum = [](auto e0, auto e1) { return e0 + e1; };
|
||||
@@ -294,14 +295,15 @@ struct FlashAttentionFwdImpl
|
||||
block_tile_reduce_sync(rowsum_p, f_sum);
|
||||
|
||||
// l{j}, Oacc{j}
|
||||
sweep_tile_span(p_spans[I0], [&](auto idx0) {
|
||||
constexpr auto o_spans = decltype(o_acc)::get_distributed_spans();
|
||||
sweep_tile_span(o_spans[I0], [&](auto idx0) {
|
||||
constexpr auto i_idx = make_tuple(idx0);
|
||||
|
||||
const auto tmp = exp(m_old[i_idx] - m[i_idx]);
|
||||
|
||||
l(i_idx) = tmp * l[i_idx] + rowsum_p[i_idx];
|
||||
|
||||
sweep_tile_span(p_spans[I1], [&](auto idx1) {
|
||||
sweep_tile_span(o_spans[I1], [&](auto idx1) {
|
||||
constexpr auto i_j_idx = make_tuple(idx0, idx1);
|
||||
|
||||
o_acc(i_j_idx) *= tmp;
|
||||
|
||||
Reference in New Issue
Block a user