mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-07 08:15:04 +00:00
change warp_welford.hpp
This commit is contained in:
@@ -44,9 +44,9 @@ struct WarpMergeWelford
|
||||
|
||||
constexpr index_t idim_p_lane = NDimP - 1;
|
||||
|
||||
const auto ps_idx = make_array<index_t>(get_warp_id(), get_lane_id());
|
||||
const auto rs_idx =
|
||||
mean_tensor.get_tile_distribution().calculate_rs_index_from_ps_index(ps_idx);
|
||||
// const auto ps_idx = make_array<index_t>(get_warp_id(), get_lane_id());
|
||||
// const auto rs_idx =
|
||||
// mean_tensor.get_tile_distribution().calculate_rs_index_from_ps_index(ps_idx);
|
||||
|
||||
constexpr index_t thread_buf_size = MeanDistributedTensor_::get_thread_buffer_size();
|
||||
static_assert(thread_buf_size == VarDistributedTensor_::get_thread_buffer_size());
|
||||
@@ -78,13 +78,15 @@ struct WarpMergeWelford
|
||||
|
||||
// reduction sweep forward
|
||||
static_for<0, nstage, 1>{}([&](auto istage) {
|
||||
constexpr index_t lid_delta =
|
||||
lid_over_rid_derivative * (1 << (nstage - istage - 1));
|
||||
// xor
|
||||
index_t src_lane =
|
||||
(__lane_id()) ^
|
||||
(number<lid_over_rid_derivative << istage.value>{}.value);
|
||||
|
||||
// pull data from remote lane
|
||||
const auto v_remote_mean = warp_shuffle_down(v_local_mean, lid_delta);
|
||||
const auto v_remote_var = warp_shuffle_down(v_local_var, lid_delta);
|
||||
const auto v_remote_count = warp_shuffle_down(v_local_count, lid_delta);
|
||||
const auto v_remote_mean = warp_shuffle(v_local_mean, src_lane);
|
||||
const auto v_remote_var = warp_shuffle(v_local_var, src_lane);
|
||||
const auto v_remote_count = warp_shuffle(v_local_count, src_lane);
|
||||
|
||||
// welford merge
|
||||
Merge(v_local_mean,
|
||||
@@ -97,48 +99,6 @@ struct WarpMergeWelford
|
||||
}
|
||||
});
|
||||
|
||||
// cross-lane broadcast for replication
|
||||
// only broadcast on R dimension correspond to lane
|
||||
// (lane id maps to this R dimension)
|
||||
if constexpr(BroadcastLane)
|
||||
{
|
||||
static_for<0, NDimR, 1>{}([&](auto idim_r) {
|
||||
// FIXME: nasty to use does_p_own_r_
|
||||
if constexpr(DstrEncodeDetail::does_p_own_r_[idim_p_lane][idim_r])
|
||||
{
|
||||
const index_t r_id = rs_idx[idim_r];
|
||||
|
||||
constexpr index_t r_length = DstrEncode::rs_lengths_[idim_r];
|
||||
|
||||
constexpr index_t lid_over_rid_derivative =
|
||||
DstrEncodeDetail::ps_over_rs_derivative_[NDimP - 1][idim_r];
|
||||
|
||||
static_assert(is_power_of_two_integer(r_length),
|
||||
"wrong! only support power of 2 reduction");
|
||||
|
||||
constexpr index_t nstage = integer_log2_floor(r_length);
|
||||
|
||||
// broadcast sweep backward
|
||||
static_for<0, nstage, 1>{}([&](auto istage) {
|
||||
// do I hold reduced data?
|
||||
const bool do_i_hold_reduced_data = r_id < (1 << istage);
|
||||
|
||||
constexpr index_t lid_delta = lid_over_rid_derivative * (1 << istage);
|
||||
|
||||
// pull data from remote lane
|
||||
const auto v_remote_mean = warp_shuffle_up(v_local_mean, lid_delta);
|
||||
const auto v_remote_var = warp_shuffle_up(v_local_var, lid_delta);
|
||||
const auto v_remote_count = warp_shuffle_up(v_local_count, lid_delta);
|
||||
|
||||
// decide whether to update local data with remote data
|
||||
v_local_mean = do_i_hold_reduced_data ? v_local_mean : v_remote_mean;
|
||||
v_local_var = do_i_hold_reduced_data ? v_local_var : v_remote_var;
|
||||
v_local_count = do_i_hold_reduced_data ? v_local_count : v_remote_count;
|
||||
});
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
mean_tensor.get_thread_buffer()(i) = v_local_mean;
|
||||
|
||||
if constexpr(GetActualVariance)
|
||||
|
||||
Reference in New Issue
Block a user