diff --git a/example/ck_tile/02_layernorm2d/layernorm2d_fwd.cpp b/example/ck_tile/02_layernorm2d/layernorm2d_fwd.cpp index 9cbd286104..35f291e060 100644 --- a/example/ck_tile/02_layernorm2d/layernorm2d_fwd.cpp +++ b/example/ck_tile/02_layernorm2d/layernorm2d_fwd.cpp @@ -35,7 +35,9 @@ float layernorm2d_fwd(layernorm2d_fwd_traits t, YDataType, MeanDataType, InvStdDataType, - Shape>; + Shape, + true, + true>; using Kernel = ck_tile::Layernorm2dFwd; diff --git a/include/ck_tile/ops/layernorm2d/kernel/layernorm2d_fwd_kernel.hpp b/include/ck_tile/ops/layernorm2d/kernel/layernorm2d_fwd_kernel.hpp index 4be3e56874..468df793da 100644 --- a/include/ck_tile/ops/layernorm2d/kernel/layernorm2d_fwd_kernel.hpp +++ b/include/ck_tile/ops/layernorm2d/kernel/layernorm2d_fwd_kernel.hpp @@ -31,8 +31,14 @@ struct Layernorm2dFwd static constexpr ck_tile::index_t kMPerBlock = Problem::BlockShape::kMPerBlock; static constexpr ck_tile::index_t kNPerBlock = Problem::BlockShape::kNPerBlock; + static constexpr bool kPadM = Problem::kPadM; + static constexpr bool kPadN = Problem::kPadN; static constexpr ck_tile::index_t kNThreadPerWarp = Problem::BlockShape::kNThreadPerWarp; + static constexpr ck_tile::index_t kNPerThread = Problem::BlockShape::kNPerThread; + + static constexpr auto I0 = number<0>{}; + static constexpr auto I1 = number<1>{}; struct Kargs { @@ -96,19 +102,25 @@ struct Layernorm2dFwd sequence<2>>{}); } - template - CK_TILE_DEVICE static constexpr auto GetNPerThread(Dstr) + CK_TILE_DEVICE static int GetWelfordMaxCount(int N) { - constexpr auto nDstrSpan = Dstr::get_distributed_spans().template at<1>(); + constexpr ck_tile::index_t kNThreadPerBlock = kNPerBlock / kNPerThread; - using Lengths = decltype(nDstrSpan.impl_); + int thread_id_n = get_thread_id() % kNThreadPerBlock; + int max_count = + __builtin_amdgcn_readfirstlane(N < kNPerBlock ? 0 : kNPerThread * (N / kNPerBlock)); + int n_per_block_tail_loop = + __builtin_amdgcn_readfirstlane(N - max_count * kNThreadPerBlock); - ck_tile::index_t ret = 1; + if(n_per_block_tail_loop > 0) + { + int thread_max_n = (thread_id_n + 1) * kNPerThread; + int delta = thread_max_n - n_per_block_tail_loop; + delta = clamp(thread_max_n - n_per_block_tail_loop, 0, kNPerThread); + max_count += kNPerThread - delta; + } - ck_tile::static_for<0, Lengths::size(), 1>{}( - [&](auto idx) { ret *= Lengths::template at(idx); }); - - return ret; + return max_count; } template @@ -129,42 +141,29 @@ struct Layernorm2dFwd return out_dstr_tensor; } - template - CK_TILE_DEVICE std::enable_if_t TwoPassLayernorm2dFwd(const XDataType* p_x, - const GammaDataType* p_gamma, - const BetaDataType* p_beta, - YDataType* p_y, - MeanDataType* p_mean, - InvStdDataType* p_invStd, - const ComputeDataType epsilon, - ck_tile::index_t M, - ck_tile::index_t N) const + template + CK_TILE_DEVICE std::enable_if_t + TwoPassLayernorm2dFwd(XBlockWindow& x_block_window, + GammaBlockWindow& gamma_block_window, + BetaBlockWindow& beta_block_window, + YBlockWindow& y_block_window, + MeanBlockWindow& mean_block_window, + InvStdBlockWindow& inv_std_block_window, + ComputeDataType epsilon, + ck_tile::index_t N) const { - constexpr auto I0 = number<0>{}; - constexpr auto I1 = number<1>{}; + // TODO - Optimize tail loop to reduce move_tile_window() + index_t num_n_tile_iteration = + __builtin_amdgcn_readfirstlane(integer_divide_ceil(N, kNPerBlock)); - const auto x_m_n = make_naive_tensor_view( - p_x, make_tuple(M, N), make_tuple(N, 1), number<32>{}, number<1>{}); - - const auto gamma_n = make_naive_tensor_view( - p_gamma, make_tuple(N), make_tuple(1), number<32>{}, number<1>{}); - - const auto beta_n = make_naive_tensor_view( - p_beta, make_tuple(N), make_tuple(1), number<32>{}, number<1>{}); - - const auto iM = get_block_id() * kMPerBlock; - - constexpr auto xDstr = MakeXBlockTileDistribution(); - - auto x_block_window = make_tile_window( - x_m_n, make_tuple(number{}, number{}), {iM, 0}, xDstr); - - index_t num_n_tile_iteration = __builtin_amdgcn_readfirstlane(N / kNPerBlock); - - // TODO: padding - handle max_count if N % kNPerBlock != 0 - constexpr auto NPerThread = GetNPerThread(xDstr); - ThreadWelford thread_welford{ - type_convert(NPerThread * N / kNPerBlock)}; + int welford_max_count = GetWelfordMaxCount(N); + ThreadWelford thread_welford{welford_max_count}; using XTensorType = decltype(load_tile(x_block_window)); auto mean_compute_block_tensor = @@ -190,44 +189,14 @@ struct Layernorm2dFwd auto inv_std_compute_block_tensor = InvSqrt(var_compute_block_tensor, epsilon); if constexpr(kSaveMean) - { - const auto mean_m = make_naive_tensor_view_packed( - p_mean, make_tuple(M), number<32>{}); - - auto mean_block_window = - make_tile_window(mean_m, make_tuple(number{}), {iM}); - store_tile(mean_block_window, cast_tile(mean_compute_block_tensor)); - } if constexpr(kSaveInvStd) - { - const auto inv_std_m = make_naive_tensor_view_packed( - p_invStd, make_tuple(M), number<32>{}); - - auto inv_std_block_window = - make_tile_window(inv_std_m, make_tuple(number{}), {iM}); - - store_tile(inv_std_block_window, cast_tile(inv_std_compute_block_tensor)); - } - - // TODO: Extract normalize pipeline - const auto y_m_n = make_naive_tensor_view( - p_y, make_tuple(M, N), make_tuple(N, 1), number<32>{}, number<1>{}); - - auto y_block_window = make_tile_window( - y_m_n, make_tuple(number{}, number{}), {iM, 0}); - - constexpr auto gammaDstr = MakeGammaBetaBlockTileDistribution(); - constexpr auto betaDstr = gammaDstr; - - auto gamma_block_window = - make_tile_window(gamma_n, make_tuple(number{}), {0}, gammaDstr); - - auto beta_block_window = make_tile_window( - beta_n, make_tuple(number{}, number{}), {0}, betaDstr); + store_tile(inv_std_block_window, + cast_tile(inv_std_compute_block_tensor)); // reverse read x to reuse cache - ck_tile::index_t stride_to_right_most_window = N - kNPerBlock; + ck_tile::index_t stride_to_right_most_window = + N % kNPerBlock == 0 ? N - kNPerBlock : N - N % kNPerBlock; move_tile_window(x_block_window, {0, -kNPerBlock}); move_tile_window(gamma_block_window, {stride_to_right_most_window}); @@ -274,17 +243,209 @@ struct Layernorm2dFwd } } + template + CK_TILE_DEVICE std::enable_if_t + OnePassLayernorm2dFwd(XBlockWindow& x_block_window, + GammaBlockWindow& gamma_block_window, + BetaBlockWindow& beta_block_window, + YBlockWindow& y_block_window, + MeanBlockWindow& mean_block_window, + InvStdBlockWindow& inv_std_block_window, + ComputeDataType epsilon, + ck_tile::index_t N) const + { + int welford_max_count = GetWelfordMaxCount(N); + ThreadWelford thread_welford{welford_max_count}; + + using XTensorType = decltype(load_tile(x_block_window)); + auto mean_compute_block_tensor = + thread_welford.template MakeInitialMeanVarDistributedTensor(); + auto var_compute_block_tensor = + thread_welford.template MakeInitialMeanVarDistributedTensor(); + + clear_tile(mean_compute_block_tensor); + clear_tile(var_compute_block_tensor); + + const auto x_block_tensor = load_tile(x_block_window); + thread_welford(x_block_tensor, mean_compute_block_tensor, var_compute_block_tensor); + // TODO: support cross warp Welford + WarpMergeWelford{}( + mean_compute_block_tensor, var_compute_block_tensor, thread_welford.cur_count_); + + auto inv_std_compute_block_tensor = InvSqrt(var_compute_block_tensor, epsilon); + + if constexpr(kSaveMean) + store_tile(mean_block_window, cast_tile(mean_compute_block_tensor)); + if constexpr(kSaveInvStd) + store_tile(inv_std_block_window, + cast_tile(inv_std_compute_block_tensor)); + + // normalize + const auto gamma_block_tensor = load_tile(gamma_block_window); + const auto beta_block_tensor = load_tile(beta_block_window); + + constexpr auto x_spans = decltype(x_block_tensor)::get_distributed_spans(); + + auto y_block_tensor = + make_static_distributed_tensor(x_block_tensor.get_tile_distribution()); + + sweep_tile_span(x_spans[I1], [&](auto idx1) { + constexpr auto j_idx = make_tuple(idx1); + const auto gamma = type_convert(gamma_block_tensor[j_idx]); + const auto beta = type_convert(beta_block_tensor[j_idx]); + + sweep_tile_span(x_spans[I0], [&](auto idx0) { + constexpr auto i_idx = make_tuple(idx0); + constexpr auto i_j_idx = make_tuple(idx0, idx1); + + const auto mean = mean_compute_block_tensor[i_idx]; + const auto inv_std = inv_std_compute_block_tensor[i_idx]; + + const auto x = type_convert(x_block_tensor[i_j_idx]); + auto y = (x - mean) * inv_std * gamma + beta; + + y_block_tensor(i_j_idx) = type_convert(y); + }); + }); + + store_tile(y_block_window, y_block_tensor); + } + CK_TILE_DEVICE void operator()(Kargs kargs) const { - TwoPassLayernorm2dFwd(static_cast(kargs.p_x), - static_cast(kargs.p_gamma), - static_cast(kargs.p_beta), - static_cast(kargs.p_y), - static_cast(kargs.p_mean), - static_cast(kargs.p_invStd), - static_cast(kargs.epsilon), - kargs.M, - kargs.N); + const auto x_m_n = [&]() { + const auto x_dram_naive = make_naive_tensor_view( + static_cast(kargs.p_x), + make_tuple(kargs.M, kargs.N), + make_tuple(kargs.N, 1), + number{}, + number<1>{}); + + return pad_tensor_view(x_dram_naive, + make_tuple(number{}, number{}), + sequence{}); + }(); + + const auto gamma_n = [&]() { + const auto gamma_dram_naive = make_naive_tensor_view( + static_cast(kargs.p_gamma), + make_tuple(kargs.N), + make_tuple(1), + number{}, + number<1>{}); + + return pad_tensor_view( + gamma_dram_naive, make_tuple(number{}), sequence{}); + }(); + + const auto beta_n = [&]() { + const auto gamma_dram_naive = make_naive_tensor_view( + static_cast(kargs.p_beta), + make_tuple(kargs.N), + make_tuple(1), + number{}, + number<1>{}); + + return pad_tensor_view( + gamma_dram_naive, make_tuple(number{}), sequence{}); + }(); + + const auto iM = get_block_id() * kMPerBlock; + + constexpr auto xDstr = MakeXBlockTileDistribution(); + + auto x_block_window = make_tile_window( + x_m_n, make_tuple(number{}, number{}), {iM, 0}, xDstr); + + const auto y_m_n = [&]() { + const auto y_dram_naive = make_naive_tensor_view( + static_cast(kargs.p_y), + make_tuple(kargs.M, kargs.N), + make_tuple(kargs.N, 1), + number{}, + number<1>{}); + + return pad_tensor_view(y_dram_naive, + make_tuple(number{}, number{}), + sequence{}); + }(); + + auto y_block_window = make_tile_window( + y_m_n, make_tuple(number{}, number{}), {iM, 0}); + + constexpr auto gammaDstr = MakeGammaBetaBlockTileDistribution(); + constexpr auto betaDstr = gammaDstr; + + auto gamma_block_window = + make_tile_window(gamma_n, make_tuple(number{}), {0}, gammaDstr); + + auto beta_block_window = make_tile_window( + beta_n, make_tuple(number{}, number{}), {0}, betaDstr); + + auto mean_block_window = [&]() { + if constexpr(kSaveMean) + { + const auto mean_m = [&]() { + const auto mean_dram_naive = + make_naive_tensor_view_packed( + static_cast(kargs.p_mean), + make_tuple(kargs.M), + number<1>{}); + + return pad_tensor_view( + mean_dram_naive, make_tuple(number{}), sequence{}); + }(); + + return make_tile_window(mean_m, make_tuple(number{}), {iM}); + } + else + return make_null_tile_window(make_tuple(number{})); + }(); + + auto inv_std_block_window = [&]() { + if constexpr(kSaveInvStd) + { + const auto inv_std_m = [&]() { + const auto inv_std_dram_naive = + make_naive_tensor_view_packed( + static_cast(kargs.p_invStd), + make_tuple(kargs.M), + number<1>{}); + + return pad_tensor_view( + inv_std_dram_naive, make_tuple(number{}), sequence{}); + }(); + + return make_tile_window(inv_std_m, make_tuple(number{}), {iM}); + } + else + return make_null_tile_window(make_tuple(number{})); + }(); + + if(kargs.N <= kNPerBlock) + OnePassLayernorm2dFwd(x_block_window, + gamma_block_window, + beta_block_window, + y_block_window, + mean_block_window, + inv_std_block_window, + static_cast(kargs.epsilon), + kargs.N); + else + TwoPassLayernorm2dFwd(x_block_window, + gamma_block_window, + beta_block_window, + y_block_window, + mean_block_window, + inv_std_block_window, + static_cast(kargs.epsilon), + kargs.N); } }; diff --git a/include/ck_tile/ops/layernorm2d/pipeline/block_layernorm2d_fwd_problem.hpp b/include/ck_tile/ops/layernorm2d/pipeline/block_layernorm2d_fwd_problem.hpp index 5206d36d7d..707a38f621 100644 --- a/include/ck_tile/ops/layernorm2d/pipeline/block_layernorm2d_fwd_problem.hpp +++ b/include/ck_tile/ops/layernorm2d/pipeline/block_layernorm2d_fwd_problem.hpp @@ -14,17 +14,21 @@ template + typename BlockShape_, + bool kPadM_, + bool kPadN_> struct BlockLayernorm2dFwdProblem { - using XDataType = remove_cvref_t; - using GammaDataType = remove_cvref_t; - using BetaDataType = remove_cvref_t; - using ComputeDataType = remove_cvref_t; - using YDataType = remove_cvref_t; - using MeanDataType = remove_cvref_t; - using InvStdDataType = remove_cvref_t; - using BlockShape = remove_cvref_t; + using XDataType = remove_cvref_t; + using GammaDataType = remove_cvref_t; + using BetaDataType = remove_cvref_t; + using ComputeDataType = remove_cvref_t; + using YDataType = remove_cvref_t; + using MeanDataType = remove_cvref_t; + using InvStdDataType = remove_cvref_t; + using BlockShape = remove_cvref_t; + static constexpr bool kPadM = kPadM_; + static constexpr bool kPadN = kPadN_; }; } // namespace ck_tile