mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-07 00:04:37 +00:00
refine welford max count calculation
This commit is contained in:
@@ -31,9 +31,9 @@ struct Layernorm2dFwd
|
||||
|
||||
static constexpr ck_tile::index_t kMPerBlock = Problem::BlockShape::kMPerBlock;
|
||||
static constexpr ck_tile::index_t kNPerBlock = Problem::BlockShape::kNPerBlock;
|
||||
static constexpr bool kPadM = false; // TODO - Problem::kPadM
|
||||
static constexpr bool kPadN = Problem::kPadN;
|
||||
static constexpr bool kTwoPass = Problem::kTwoPass;
|
||||
static constexpr bool kPadM = false; // TODO - BlockLayernorm2dFwdProblem::kPadM
|
||||
static constexpr bool kPadN = Problem::kPadN;
|
||||
static constexpr bool kTwoPass = Problem::kTwoPass;
|
||||
|
||||
static constexpr ck_tile::index_t kNThreadPerWarp = Problem::BlockShape::kNThreadPerWarp;
|
||||
static constexpr ck_tile::index_t kNPerThread = Problem::BlockShape::kNPerThread;
|
||||
@@ -106,21 +106,6 @@ struct Layernorm2dFwd
|
||||
sequence<0, 3>>{});
|
||||
}
|
||||
|
||||
template <typename Dstr>
|
||||
CK_TILE_DEVICE static constexpr auto GetNPerThread(Dstr)
|
||||
{
|
||||
constexpr auto nDstrSpan = Dstr::get_distributed_spans().template at<1>();
|
||||
|
||||
using Lengths = decltype(nDstrSpan.impl_);
|
||||
|
||||
ck_tile::index_t ret = 1;
|
||||
|
||||
ck_tile::static_for<0, Lengths::size(), 1>{}(
|
||||
[&](auto idx) { ret *= Lengths::template at(idx); });
|
||||
|
||||
return ret;
|
||||
}
|
||||
|
||||
template <typename DistributedTensor>
|
||||
CK_TILE_DEVICE static auto InvSqrt(const DistributedTensor& in_dstr_tensor,
|
||||
const ComputeDataType epsilon)
|
||||
@@ -139,20 +124,25 @@ struct Layernorm2dFwd
|
||||
return out_dstr_tensor;
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr auto
|
||||
GetLastloopLayerNormIntraLaneReduceCount(index_t NLength)
|
||||
CK_TILE_DEVICE static int GetWelfordMaxCount(int N)
|
||||
{
|
||||
using S = typename Problem::BlockShape;
|
||||
// S::kNWarpPerBlock, S::kNThreadPerWarp, S::kNPerThread
|
||||
auto LastloopN = NLength % kNPerBlock == 0 ? kNPerBlock : NLength % kNPerBlock;
|
||||
constexpr auto NThread = S::kNWarpPerBlock * S::kNThreadPerWarp;
|
||||
auto iNLane = get_thread_local_1d_id() % NThread;
|
||||
auto iN0 = LastloopN / (S::kNPerThread * S::kNThreadPerWarp);
|
||||
auto iN1 = (LastloopN % (S::kNPerThread * S::kNThreadPerWarp)) / S::kNPerThread;
|
||||
auto N2 = (LastloopN % (S::kNPerThread * S::kNThreadPerWarp)) % S::kNPerThread;
|
||||
auto iN3 = iNLane < iN1 ? S::kNPerThread : iNLane == iN1 ? N2 : 0;
|
||||
constexpr ck_tile::index_t kNThreadPerBlock = kNPerBlock / kNPerThread;
|
||||
|
||||
return iN0 * S::kNPerThread + iN3;
|
||||
int thread_id_n = get_thread_id() % kNThreadPerBlock;
|
||||
int max_count =
|
||||
__builtin_amdgcn_readfirstlane(N < kNPerBlock ? 0 : kNPerThread * (N / kNPerBlock));
|
||||
int n_per_block_tail_loop =
|
||||
__builtin_amdgcn_readfirstlane(N - max_count * kNThreadPerBlock);
|
||||
|
||||
if(n_per_block_tail_loop > 0)
|
||||
{
|
||||
int thread_max_n = (thread_id_n + 1) * kNPerThread;
|
||||
int delta = thread_max_n - n_per_block_tail_loop;
|
||||
delta = clamp(thread_max_n - n_per_block_tail_loop, 0, kNPerThread);
|
||||
max_count += kNPerThread - delta;
|
||||
}
|
||||
|
||||
return max_count;
|
||||
}
|
||||
|
||||
template <typename XBlockWindow,
|
||||
@@ -172,8 +162,8 @@ struct Layernorm2dFwd
|
||||
ComputeDataType epsilon,
|
||||
ck_tile::index_t N) const
|
||||
{
|
||||
auto intra_thread_count_last = GetLastloopLayerNormIntraLaneReduceCount(N);
|
||||
ThreadWelford<ComputeDataType, XDataType> thread_welford{intra_thread_count_last};
|
||||
int welford_max_count = GetWelfordMaxCount(N);
|
||||
ThreadWelford<ComputeDataType, XDataType> thread_welford{welford_max_count};
|
||||
|
||||
using XTensorType = decltype(load_tile(x_block_window));
|
||||
auto mean_compute_block_tensor =
|
||||
@@ -246,15 +236,11 @@ struct Layernorm2dFwd
|
||||
ComputeDataType epsilon,
|
||||
ck_tile::index_t N) const
|
||||
{
|
||||
using S = typename Problem::BlockShape;
|
||||
index_t num_n_tile_iteration =
|
||||
__builtin_amdgcn_readfirstlane((N + kNPerBlock - 1) / kNPerBlock);
|
||||
__builtin_amdgcn_readfirstlane(integer_divide_ceil(N, kNPerBlock));
|
||||
|
||||
auto intra_thread_count = S::kNRepeat * S::kNPerThread * (num_n_tile_iteration - 1);
|
||||
auto intra_thread_count_last = GetLastloopLayerNormIntraLaneReduceCount(N);
|
||||
|
||||
ThreadWelford<ComputeDataType, XDataType> thread_welford{intra_thread_count};
|
||||
ThreadWelford<ComputeDataType, XDataType> thread_welford_last{intra_thread_count_last};
|
||||
int welford_max_count = GetWelfordMaxCount(N);
|
||||
ThreadWelford<ComputeDataType, XDataType> thread_welford{welford_max_count};
|
||||
|
||||
using XTensorType = decltype(load_tile(x_block_window));
|
||||
auto mean_compute_block_tensor =
|
||||
@@ -265,19 +251,13 @@ struct Layernorm2dFwd
|
||||
clear_tile(mean_compute_block_tensor);
|
||||
clear_tile(var_compute_block_tensor);
|
||||
|
||||
for(int iN = __builtin_amdgcn_readfirstlane(0); iN < num_n_tile_iteration - 1; ++iN)
|
||||
for(int iN = __builtin_amdgcn_readfirstlane(0); iN < num_n_tile_iteration; ++iN)
|
||||
{
|
||||
const auto x_block_tensor = load_tile(x_block_window);
|
||||
|
||||
thread_welford(x_block_tensor, mean_compute_block_tensor, var_compute_block_tensor);
|
||||
move_tile_window(x_block_window, {0, kNPerBlock});
|
||||
}
|
||||
const auto x_block_tensor_ = load_tile(x_block_window);
|
||||
|
||||
thread_welford_last.cur_count_ += intra_thread_count;
|
||||
thread_welford_last.max_count_ += intra_thread_count;
|
||||
thread_welford_last(x_block_tensor_, mean_compute_block_tensor, var_compute_block_tensor);
|
||||
thread_welford.cur_count_ += intra_thread_count_last;
|
||||
|
||||
// TODO: support cross warp Welford
|
||||
WarpMergeWelford<ComputeDataType, true>{}(
|
||||
@@ -295,6 +275,7 @@ struct Layernorm2dFwd
|
||||
ck_tile::index_t stride_to_right_most_window =
|
||||
N % kNPerBlock == 0 ? N - kNPerBlock : N - N % kNPerBlock;
|
||||
|
||||
move_tile_window(x_block_window, {0, -kNPerBlock});
|
||||
move_tile_window(gamma_block_window, {stride_to_right_most_window});
|
||||
move_tile_window(beta_block_window, {stride_to_right_most_window});
|
||||
move_tile_window(y_block_window, {0, stride_to_right_most_window});
|
||||
|
||||
Reference in New Issue
Block a user