diff --git a/include/ck_tile/ops/reduce/block/block_reduce2d.hpp b/include/ck_tile/ops/reduce/block/block_reduce2d.hpp index 5715d82390..beb8c718e3 100644 --- a/include/ck_tile/ops/reduce/block/block_reduce2d.hpp +++ b/include/ck_tile/ops/reduce/block/block_reduce2d.hpp @@ -218,8 +218,6 @@ struct BlockReduce2dCrossWarpSync if constexpr(num_reduce_warps == 1) return; - // CAUSION - Imitate BlockWelfordCrossWarpSync, but looks like there are some bugs. - // store into smem only for lane-0 within one warp if(lane_id == 0) { @@ -237,22 +235,18 @@ struct BlockReduce2dCrossWarpSync static_for<0, num_reduce_warps, 1>{}([&](auto i_1) { all_scratch[i_0 * num_reduce_warps + i_1] = smem_ptr[i_0 * num_warps + local_smem_os + i_1]; - // all_scratch[i_0 * num_warps + i_1] = - // smem_ptr[i_0 * num_reduce_warps + local_smem_os + i_1]; }); }); block_sync_lds(); // TODO: we don't need sync here static_for<0, thread_buf_size, 1>{}([&](auto i_0) { // TODO: use descriptor for this - // auto v_local = all_scratch[i_0 * num_warps]; auto v_local = all_scratch[i_0 * num_reduce_warps]; // further reduce mean/var static_for<0, num_reduce_warps - 1, 1>{}([&](auto i_1_n1) { constexpr auto i_1 = number{}; const DataType v_remote = all_scratch[i_0 * num_reduce_warps + i_1]; - // const DataType v_remote = all_scratch[i_0 * num_warps + i_1]; // reduce v_local = reduce_func(v_local, v_remote);