diff --git a/include/ck_tile/ops/welford/block/block_welford.hpp b/include/ck_tile/ops/welford/block/block_welford.hpp index 55d55402d8..623e1e16d8 100644 --- a/include/ck_tile/ops/welford/block/block_welford.hpp +++ b/include/ck_tile/ops/welford/block/block_welford.hpp @@ -276,8 +276,8 @@ struct BlockWelfordCrossWarpSync fp32x4_t all_scratch[thread_buf_size * num_reduce_warps]; static_for<0, thread_buf_size, 1>{}([&](auto i_0) { static_for<0, num_reduce_warps, 1>{}([&](auto i_1) { - all_scratch[i_0 * num_warps + i_1] = - smem_ptr[i_0 * num_reduce_warps + local_smem_os + i_1]; + all_scratch[i_0 * num_reduce_warps + i_1] = + smem_ptr[i_0 * num_warps + local_smem_os + i_1]; }); }); block_sync_lds(); // TODO: we don't need sync here @@ -286,7 +286,7 @@ struct BlockWelfordCrossWarpSync static_for<0, thread_buf_size, 1>{}([&](auto i_0) { // TODO: use descriptor for this - auto v_local = all_scratch[i_0 * num_warps]; + auto v_local = all_scratch[i_0 * num_reduce_warps]; auto v_local_mean = bit_cast(v_local[0]); auto v_local_var = bit_cast(v_local[1]); auto v_local_count = bit_cast(v_local[2]); @@ -294,7 +294,7 @@ struct BlockWelfordCrossWarpSync // further reduce mean/var static_for<0, num_reduce_warps - 1, 1>{}([&](auto i_1_n1) { constexpr auto i_1 = number{}; - const fp32x4_t v_remote = all_scratch[i_0 * num_warps + i_1]; + const fp32x4_t v_remote = all_scratch[i_0 * num_reduce_warps + i_1]; const auto v_remote_mean = bit_cast(v_remote[0]); const auto v_remote_var = bit_cast(v_remote[1]); const auto v_remote_count = bit_cast(v_remote[2]);