From 022365802b43a398deee2bc672785fa31a89297d Mon Sep 17 00:00:00 2001 From: rocking Date: Mon, 14 Oct 2024 12:40:30 +0000 Subject: [PATCH] refine welford max count calculation --- .../kernel/layernorm2d_fwd_kernel.hpp | 73 +++++++------------ 1 file changed, 27 insertions(+), 46 deletions(-) diff --git a/include/ck_tile/ops/layernorm2d/kernel/layernorm2d_fwd_kernel.hpp b/include/ck_tile/ops/layernorm2d/kernel/layernorm2d_fwd_kernel.hpp index 503939626b..4572b51b06 100644 --- a/include/ck_tile/ops/layernorm2d/kernel/layernorm2d_fwd_kernel.hpp +++ b/include/ck_tile/ops/layernorm2d/kernel/layernorm2d_fwd_kernel.hpp @@ -31,9 +31,9 @@ struct Layernorm2dFwd static constexpr ck_tile::index_t kMPerBlock = Problem::BlockShape::kMPerBlock; static constexpr ck_tile::index_t kNPerBlock = Problem::BlockShape::kNPerBlock; - static constexpr bool kPadM = false; // TODO - Problem::kPadM - static constexpr bool kPadN = Problem::kPadN; - static constexpr bool kTwoPass = Problem::kTwoPass; + static constexpr bool kPadM = false; // TODO - BlockLayernorm2dFwdProblem::kPadM + static constexpr bool kPadN = Problem::kPadN; + static constexpr bool kTwoPass = Problem::kTwoPass; static constexpr ck_tile::index_t kNThreadPerWarp = Problem::BlockShape::kNThreadPerWarp; static constexpr ck_tile::index_t kNPerThread = Problem::BlockShape::kNPerThread; @@ -106,21 +106,6 @@ struct Layernorm2dFwd sequence<0, 3>>{}); } - template - CK_TILE_DEVICE static constexpr auto GetNPerThread(Dstr) - { - constexpr auto nDstrSpan = Dstr::get_distributed_spans().template at<1>(); - - using Lengths = decltype(nDstrSpan.impl_); - - ck_tile::index_t ret = 1; - - ck_tile::static_for<0, Lengths::size(), 1>{}( - [&](auto idx) { ret *= Lengths::template at(idx); }); - - return ret; - } - template CK_TILE_DEVICE static auto InvSqrt(const DistributedTensor& in_dstr_tensor, const ComputeDataType epsilon) @@ -139,20 +124,25 @@ struct Layernorm2dFwd return out_dstr_tensor; } - CK_TILE_HOST_DEVICE static constexpr auto - GetLastloopLayerNormIntraLaneReduceCount(index_t NLength) + CK_TILE_DEVICE static int GetWelfordMaxCount(int N) { - using S = typename Problem::BlockShape; - // S::kNWarpPerBlock, S::kNThreadPerWarp, S::kNPerThread - auto LastloopN = NLength % kNPerBlock == 0 ? kNPerBlock : NLength % kNPerBlock; - constexpr auto NThread = S::kNWarpPerBlock * S::kNThreadPerWarp; - auto iNLane = get_thread_local_1d_id() % NThread; - auto iN0 = LastloopN / (S::kNPerThread * S::kNThreadPerWarp); - auto iN1 = (LastloopN % (S::kNPerThread * S::kNThreadPerWarp)) / S::kNPerThread; - auto N2 = (LastloopN % (S::kNPerThread * S::kNThreadPerWarp)) % S::kNPerThread; - auto iN3 = iNLane < iN1 ? S::kNPerThread : iNLane == iN1 ? N2 : 0; + constexpr ck_tile::index_t kNThreadPerBlock = kNPerBlock / kNPerThread; - return iN0 * S::kNPerThread + iN3; + int thread_id_n = get_thread_id() % kNThreadPerBlock; + int max_count = + __builtin_amdgcn_readfirstlane(N < kNPerBlock ? 0 : kNPerThread * (N / kNPerBlock)); + int n_per_block_tail_loop = + __builtin_amdgcn_readfirstlane(N - max_count * kNThreadPerBlock); + + if(n_per_block_tail_loop > 0) + { + int thread_max_n = (thread_id_n + 1) * kNPerThread; + int delta = thread_max_n - n_per_block_tail_loop; + delta = clamp(thread_max_n - n_per_block_tail_loop, 0, kNPerThread); + max_count += kNPerThread - delta; + } + + return max_count; } template thread_welford{intra_thread_count_last}; + int welford_max_count = GetWelfordMaxCount(N); + ThreadWelford thread_welford{welford_max_count}; using XTensorType = decltype(load_tile(x_block_window)); auto mean_compute_block_tensor = @@ -246,15 +236,11 @@ struct Layernorm2dFwd ComputeDataType epsilon, ck_tile::index_t N) const { - using S = typename Problem::BlockShape; index_t num_n_tile_iteration = - __builtin_amdgcn_readfirstlane((N + kNPerBlock - 1) / kNPerBlock); + __builtin_amdgcn_readfirstlane(integer_divide_ceil(N, kNPerBlock)); - auto intra_thread_count = S::kNRepeat * S::kNPerThread * (num_n_tile_iteration - 1); - auto intra_thread_count_last = GetLastloopLayerNormIntraLaneReduceCount(N); - - ThreadWelford thread_welford{intra_thread_count}; - ThreadWelford thread_welford_last{intra_thread_count_last}; + int welford_max_count = GetWelfordMaxCount(N); + ThreadWelford thread_welford{welford_max_count}; using XTensorType = decltype(load_tile(x_block_window)); auto mean_compute_block_tensor = @@ -265,19 +251,13 @@ struct Layernorm2dFwd clear_tile(mean_compute_block_tensor); clear_tile(var_compute_block_tensor); - for(int iN = __builtin_amdgcn_readfirstlane(0); iN < num_n_tile_iteration - 1; ++iN) + for(int iN = __builtin_amdgcn_readfirstlane(0); iN < num_n_tile_iteration; ++iN) { const auto x_block_tensor = load_tile(x_block_window); thread_welford(x_block_tensor, mean_compute_block_tensor, var_compute_block_tensor); move_tile_window(x_block_window, {0, kNPerBlock}); } - const auto x_block_tensor_ = load_tile(x_block_window); - - thread_welford_last.cur_count_ += intra_thread_count; - thread_welford_last.max_count_ += intra_thread_count; - thread_welford_last(x_block_tensor_, mean_compute_block_tensor, var_compute_block_tensor); - thread_welford.cur_count_ += intra_thread_count_last; // TODO: support cross warp Welford WarpMergeWelford{}( @@ -295,6 +275,7 @@ struct Layernorm2dFwd ck_tile::index_t stride_to_right_most_window = N % kNPerBlock == 0 ? N - kNPerBlock : N - N % kNPerBlock; + move_tile_window(x_block_window, {0, -kNPerBlock}); move_tile_window(gamma_block_window, {stride_to_right_most_window}); move_tile_window(beta_block_window, {stride_to_right_most_window}); move_tile_window(y_block_window, {0, stride_to_right_most_window});