From 4ee40bcc6fe2cf4599e9f9d7ec99e2cd15db7038 Mon Sep 17 00:00:00 2001 From: letaoqin Date: Sat, 12 Oct 2024 22:45:03 +0800 Subject: [PATCH] change warp_welford.hpp --- .../ck_tile/ops/welford/warp/warp_welford.hpp | 60 ++++--------------- 1 file changed, 10 insertions(+), 50 deletions(-) diff --git a/include/ck_tile/ops/welford/warp/warp_welford.hpp b/include/ck_tile/ops/welford/warp/warp_welford.hpp index 687b61f430..a828e8eb8e 100644 --- a/include/ck_tile/ops/welford/warp/warp_welford.hpp +++ b/include/ck_tile/ops/welford/warp/warp_welford.hpp @@ -44,9 +44,9 @@ struct WarpMergeWelford constexpr index_t idim_p_lane = NDimP - 1; - const auto ps_idx = make_array(get_warp_id(), get_lane_id()); - const auto rs_idx = - mean_tensor.get_tile_distribution().calculate_rs_index_from_ps_index(ps_idx); + // const auto ps_idx = make_array(get_warp_id(), get_lane_id()); + // const auto rs_idx = + // mean_tensor.get_tile_distribution().calculate_rs_index_from_ps_index(ps_idx); constexpr index_t thread_buf_size = MeanDistributedTensor_::get_thread_buffer_size(); static_assert(thread_buf_size == VarDistributedTensor_::get_thread_buffer_size()); @@ -78,13 +78,15 @@ struct WarpMergeWelford // reduction sweep forward static_for<0, nstage, 1>{}([&](auto istage) { - constexpr index_t lid_delta = - lid_over_rid_derivative * (1 << (nstage - istage - 1)); + // xor + index_t src_lane = + (__lane_id()) ^ + (number{}.value); // pull data from remote lane - const auto v_remote_mean = warp_shuffle_down(v_local_mean, lid_delta); - const auto v_remote_var = warp_shuffle_down(v_local_var, lid_delta); - const auto v_remote_count = warp_shuffle_down(v_local_count, lid_delta); + const auto v_remote_mean = warp_shuffle(v_local_mean, src_lane); + const auto v_remote_var = warp_shuffle(v_local_var, src_lane); + const auto v_remote_count = warp_shuffle(v_local_count, src_lane); // welford merge Merge(v_local_mean, @@ -97,48 +99,6 @@ struct WarpMergeWelford } }); - // cross-lane broadcast for replication - // only broadcast on R dimension correspond to lane - // (lane id maps to this R dimension) - if constexpr(BroadcastLane) - { - static_for<0, NDimR, 1>{}([&](auto idim_r) { - // FIXME: nasty to use does_p_own_r_ - if constexpr(DstrEncodeDetail::does_p_own_r_[idim_p_lane][idim_r]) - { - const index_t r_id = rs_idx[idim_r]; - - constexpr index_t r_length = DstrEncode::rs_lengths_[idim_r]; - - constexpr index_t lid_over_rid_derivative = - DstrEncodeDetail::ps_over_rs_derivative_[NDimP - 1][idim_r]; - - static_assert(is_power_of_two_integer(r_length), - "wrong! only support power of 2 reduction"); - - constexpr index_t nstage = integer_log2_floor(r_length); - - // broadcast sweep backward - static_for<0, nstage, 1>{}([&](auto istage) { - // do I hold reduced data? - const bool do_i_hold_reduced_data = r_id < (1 << istage); - - constexpr index_t lid_delta = lid_over_rid_derivative * (1 << istage); - - // pull data from remote lane - const auto v_remote_mean = warp_shuffle_up(v_local_mean, lid_delta); - const auto v_remote_var = warp_shuffle_up(v_local_var, lid_delta); - const auto v_remote_count = warp_shuffle_up(v_local_count, lid_delta); - - // decide whether to update local data with remote data - v_local_mean = do_i_hold_reduced_data ? v_local_mean : v_remote_mean; - v_local_var = do_i_hold_reduced_data ? v_local_var : v_remote_var; - v_local_count = do_i_hold_reduced_data ? v_local_count : v_remote_count; - }); - } - }); - } - mean_tensor.get_thread_buffer()(i) = v_local_mean; if constexpr(GetActualVariance)