diff --git a/example/ck_tile/05_reduce/reduce.cpp b/example/ck_tile/05_reduce/reduce.cpp index a97f9fe7d4..005541dc62 100644 --- a/example/ck_tile/05_reduce/reduce.cpp +++ b/example/ck_tile/05_reduce/reduce.cpp @@ -40,6 +40,7 @@ bool run(const ck_tile::ArgParser& arg_parser) x_buf.ToDevice(x_host.data()); + using ReduceOp = ck_tile::ReduceOp::Add; using BlockWarps = ck_tile::sequence<4, 1>; using BlockTile = ck_tile::sequence<128, 128>; using WarpTile = ck_tile::sequence<32, 128>; @@ -56,8 +57,9 @@ bool run(const ck_tile::ArgParser& arg_parser) ck_tile::index_t kGridSize = (m / BlockTile::at(ck_tile::number<0>{})); std::cout << "grid size " << kGridSize << std::endl; - using Shape = ck_tile::Reduce2dShape; - using Porblem = ck_tile::Reduce2dProblem; + using Shape = ck_tile::Reduce2dShape; + using Porblem = + ck_tile::Reduce2dProblem; using Kernel = ck_tile::Reduce; @@ -83,7 +85,8 @@ bool run(const ck_tile::ArgParser& arg_parser) if(do_validation) { // reference - ck_tile::reference_reduce(x_host, y_host_ref); + ck_tile::reference_reduce( + x_host, y_host_ref, ReduceOp{}); y_buf.FromDevice(y_host_dev.mData.data()); pass = ck_tile::check_err(y_host_dev, y_host_ref); diff --git a/example/ck_tile/05_reduce/reduce.hpp b/example/ck_tile/05_reduce/reduce.hpp index 4d09d5414a..856b065318 100644 --- a/example/ck_tile/05_reduce/reduce.hpp +++ b/example/ck_tile/05_reduce/reduce.hpp @@ -38,13 +38,18 @@ struct Reduce2dShape warpSize * reduce_on_sequence(BlockWarps{}, multiplies{}, number<1>{}); }; -template +template struct Reduce2dProblem { using XDataType = remove_cvref_t; using ComputeDataType = remove_cvref_t; using YDataType = remove_cvref_t; using BlockShape = remove_cvref_t; + using ReduceOp = ReduceOp_; static constexpr bool kNeedCrossLaneSync = BlockShape::ThreadPerWarp_N > 1; static constexpr bool kNeedCrossWarpSync = BlockShape::WarpPerBlock_N > 1; @@ -131,7 +136,7 @@ struct Reduce index_t num_n_tile_iteration = __builtin_amdgcn_readfirstlane(integer_divide_ceil(N, S::Block_N)); - auto reduce_func = [](const auto& v0, const auto& v1) { return v0 + v1; }; + auto reduce_func = typename Problem::ReduceOp{}; auto block_reduce2d = Policy::template GetBlockReduce2d(); auto block_reduce2d_sync = Policy::template GetBlockReduce2dSync(); auto block_reduce2d_cross_warp_sync = @@ -139,7 +144,7 @@ struct Reduce using XTensorType = decltype(load_tile(x_window)); auto y_compute = block_reduce2d.template MakeYBlockTile(); - set_tile(y_compute, 0); + set_tile(y_compute, reduce_func.template GetIdentityValue()); for(int iN = __builtin_amdgcn_readfirstlane(0); iN < num_n_tile_iteration; ++iN) { diff --git a/include/ck_tile/core.hpp b/include/ck_tile/core.hpp index d96f14710b..cc3d0f1a70 100644 --- a/include/ck_tile/core.hpp +++ b/include/ck_tile/core.hpp @@ -57,6 +57,7 @@ #include "ck_tile/core/utility/magic_div.hpp" #include "ck_tile/core/utility/philox_rand.hpp" #include "ck_tile/core/utility/random.hpp" +#include "ck_tile/core/utility/reduce_operator.hpp" #include "ck_tile/core/utility/to_sequence.hpp" #include "ck_tile/core/utility/transpose_vectors.hpp" #include "ck_tile/core/utility/type_traits.hpp" diff --git a/include/ck_tile/core/utility/reduce_operator.hpp b/include/ck_tile/core/utility/reduce_operator.hpp new file mode 100644 index 0000000000..8b15d187fe --- /dev/null +++ b/include/ck_tile/core/utility/reduce_operator.hpp @@ -0,0 +1,95 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core/config.hpp" + +namespace ck_tile { + +namespace ReduceOp { +// y = ReduceOp(y, x); +struct Add +{ + template + CK_TILE_HOST_DEVICE static constexpr T GetIdentityValue() + { + return type_convert(0.0f); + }; + + template || std::is_same_v || + std::is_same_v || std::is_same_v>> + CK_TILE_HOST_DEVICE constexpr T operator()(const T& y, const T x) const + { + return y + x; + } + + template || std::is_same_v>> + CK_TILE_HOST_DEVICE constexpr T operator()(T& y, T x) const + { + float y_ = type_convert(y); + float x_ = type_convert(x); + + return type_convert(y_ + x_); + } +}; + +struct SquareAdd +{ + template + CK_TILE_HOST_DEVICE static constexpr T GetIdentityValue() + { + return type_convert(0.0f); + }; + + template || std::is_same_v || + std::is_same_v || std::is_same_v>> + CK_TILE_HOST_DEVICE constexpr T operator()(const T& y, const T x) const + { + return y + (x * x); + } +}; + +struct Max +{ + template || std::is_same_v || + std::is_same_v || std::is_same_v>> + CK_TILE_HOST_DEVICE static constexpr T GetIdentityValue() + { + return numeric::min(); + }; + + template || std::is_same_v || + std::is_same_v || std::is_same_v>> + CK_TILE_HOST_DEVICE constexpr T operator()(const T& y, const T x) const + { + return max(y, x); + } +}; + +struct AbsMax +{ + template || std::is_same_v || + std::is_same_v || std::is_same_v>> + CK_TILE_HOST_DEVICE static constexpr T GetIdentityValue() + { + return numeric::min(); + }; + + template || std::is_same_v || + std::is_same_v || std::is_same_v>> + CK_TILE_HOST_DEVICE constexpr T operator()(const T& y, const T x) const + { + return max(y, abs(x)); + } +}; + +} // namespace ReduceOp +} // namespace ck_tile diff --git a/include/ck_tile/host/reference/reference_reduce.hpp b/include/ck_tile/host/reference/reference_reduce.hpp index b16cee3f94..8f8aa23670 100644 --- a/include/ck_tile/host/reference/reference_reduce.hpp +++ b/include/ck_tile/host/reference/reference_reduce.hpp @@ -9,24 +9,25 @@ namespace ck_tile { -template -CK_TILE_HOST void reference_reduce(const HostTensor& a_m_n, HostTensor& b_m) +template +CK_TILE_HOST void +reference_reduce(const HostTensor& x_m_n, HostTensor& y_m, ReduceOp reduce_op) { auto f = [&](auto m) { - const int N = a_m_n.mDesc.get_lengths()[1]; + const int N = x_m_n.mDesc.get_lengths()[1]; - AccDataType v_acc = 0; + ComputeDataType v_acc = reduce_op.template GetIdentityValue(); for(int n = 0; n < N; ++n) { - const ADataType v_a = a_m_n(m, n); + const ComputeDataType v_a = type_convert(x_m_n(m, n)); - v_acc += v_a; + v_acc = reduce_op(v_acc, v_a); } - b_m(m) = ck_tile::type_convert(v_acc); + y_m(m) = ck_tile::type_convert(v_acc); }; - make_ParallelTensorFunctor(f, b_m.mDesc.get_lengths()[0])(std::thread::hardware_concurrency()); + make_ParallelTensorFunctor(f, y_m.mDesc.get_lengths()[0])(std::thread::hardware_concurrency()); } } // namespace ck_tile diff --git a/include/ck_tile/ops/add_rmsnorm2d_rdquant/pipeline/add_rmsnorm2d_rdquant_fwd_pipeline_one_pass.hpp b/include/ck_tile/ops/add_rmsnorm2d_rdquant/pipeline/add_rmsnorm2d_rdquant_fwd_pipeline_one_pass.hpp index f37f5e6475..c8f836beea 100644 --- a/include/ck_tile/ops/add_rmsnorm2d_rdquant/pipeline/add_rmsnorm2d_rdquant_fwd_pipeline_one_pass.hpp +++ b/include/ck_tile/ops/add_rmsnorm2d_rdquant/pipeline/add_rmsnorm2d_rdquant_fwd_pipeline_one_pass.hpp @@ -66,12 +66,12 @@ struct AddRmsnorm2dRdquantFwdPipelineOnePass const auto gamma_window = make_tile_window( gamma_window_, Policy::template MakeGammaBlockTileDistribution()); - auto reduce_square_sum_func = [](const auto& v0, const auto& v1) { return v0 + v1 * v1; }; - auto reduce_sum_func = [](const auto& v0, const auto& v1) { return v0 + v1; }; - auto reduce_absmax_func = [](const auto& v0, const auto& v1) { return max(v0, abs(v1)); }; - auto reduce_max_func = [](const auto& v0, const auto& v1) { return max(v0, v1); }; - auto block_reduce2d = Policy::template GetBlockReduce2d(); - auto block_reduce2d_sync = Policy::template GetBlockReduce2dSync(); + auto reduce_square_sum_func = ReduceOp::SquareAdd{}; + auto reduce_sum_func = ReduceOp::Add{}; + auto reduce_absmax_func = ReduceOp::AbsMax{}; + auto reduce_max_func = ReduceOp::Max{}; + auto block_reduce2d = Policy::template GetBlockReduce2d(); + auto block_reduce2d_sync = Policy::template GetBlockReduce2dSync(); auto block_reduce2d_cross_warp_sync = Policy::template GetBlockReduce2dCrossWarpSync(); @@ -90,7 +90,8 @@ struct AddRmsnorm2dRdquantFwdPipelineOnePass store_tile(x_window, cast_tile(x)); // compute mean square, each-thread->cross-lane->cross-warp - auto square_sum = block_reduce2d(x, 0, reduce_square_sum_func); + auto square_sum = block_reduce2d( + x, reduce_square_sum_func.GetIdentityValue(), reduce_square_sum_func); block_reduce2d_sync(square_sum, reduce_sum_func); block_reduce2d_cross_warp_sync(square_sum, smem, reduce_sum_func); @@ -115,7 +116,8 @@ struct AddRmsnorm2dRdquantFwdPipelineOnePass }); // compute absmax, each-thread->cross-lane->cross-warp - auto absmax = block_reduce2d(x, numeric::min(), reduce_absmax_func); + auto absmax = block_reduce2d( + x, reduce_absmax_func.GetIdentityValue(), reduce_absmax_func); block_reduce2d_sync(absmax, reduce_max_func); block_reduce2d_cross_warp_sync(absmax, smem, reduce_max_func); @@ -130,7 +132,6 @@ struct AddRmsnorm2dRdquantFwdPipelineOnePass auto qy = make_static_distributed_tensor(y.get_tile_distribution()); sweep_tile(qy, [&, yscale_ = yscale](auto idx) { constexpr auto i_idx = make_tuple(idx[number<0>{}]); - constexpr auto j_idx = make_tuple(idx[number<1>{}]); auto qy_ = y[idx] / yscale_[i_idx]; qy(idx) = saturates{}(qy_); }); diff --git a/include/ck_tile/ops/add_rmsnorm2d_rdquant/pipeline/add_rmsnorm2d_rdquant_fwd_pipeline_two_pass.hpp b/include/ck_tile/ops/add_rmsnorm2d_rdquant/pipeline/add_rmsnorm2d_rdquant_fwd_pipeline_two_pass.hpp index 0c8c402d84..6660defb47 100644 --- a/include/ck_tile/ops/add_rmsnorm2d_rdquant/pipeline/add_rmsnorm2d_rdquant_fwd_pipeline_two_pass.hpp +++ b/include/ck_tile/ops/add_rmsnorm2d_rdquant/pipeline/add_rmsnorm2d_rdquant_fwd_pipeline_two_pass.hpp @@ -63,8 +63,86 @@ struct AddRmsnorm2dRdquantFwdPipelineTwoPass make_tile_window(a_window_, Policy::template MakeABXBlockTileDistribution()); auto b_window = make_tile_window(b_window_, Policy::template MakeABXBlockTileDistribution()); - const auto gamma_window = make_tile_window( + auto gamma_window = make_tile_window( gamma_window_, Policy::template MakeGammaBlockTileDistribution()); + + auto reduce_square_sum_func = [](const auto& v0, const auto& v1) { return v0 + v1 * v1; }; + auto reduce_sum_func = [](const auto& v0, const auto& v1) { return v0 + v1; }; + auto reduce_absmax_func = [](const auto& v0, const auto& v1) { return max(v0, abs(v1)); }; + auto reduce_max_func = [](const auto& v0, const auto& v1) { return max(v0, v1); }; + auto block_reduce2d = Policy::template GetBlockReduce2d(); + auto block_reduce2d_sync = Policy::template GetBlockReduce2dSync(); + auto block_reduce2d_cross_warp_sync = + Policy::template GetBlockReduce2dCrossWarpSync(); + + static constexpr index_t Block_N = Problem::BlockShape::Block_N; + index_t num_n_tile_iteration = + __builtin_amdgcn_readfirstlane(integer_divide_ceil(row_size, Block_N)); + + using XTensorType = decltype(cast_tile(load_tile(a_window))); + auto square_sum = block_reduce2d.template MakeYBlockTile(); + set_tile(square_sum, 0); + + for(int iN = __builtin_amdgcn_readfirstlane(0); iN < num_n_tile_iteration; ++iN) + { + const auto a = load_tile(a_window); + const auto b = load_tile(b_window); + + auto x = tile_elementwise_in( + [&](const auto& a_, const auto& b_) { + return type_convert(a_) + type_convert(b_); + }, + a, + b); + + if constexpr(kSaveX) + store_tile(x_window, cast_tile(x)); + + block_reduce2d(x, square_sum, reduce_square_sum_func); + move_tile_window(x_window, {0, Block_N}); + } + + block_reduce2d_sync(square_sum, reduce_sum_func); + block_reduce2d_cross_warp_sync(square_sum, smem, reduce_sum_func); + + auto inv_rms = tile_elementwise_in( + [&](const auto& v_) { + return type_convert(1.0f) / (sqrt(v_ / row_size + epsilon)); + }, + square_sum); + + // reverse read x to reuse cache + ck_tile::index_t stride_to_right_most_window = + row_size % Block_N == 0 ? row_size - Block_N : row_size - row_size % Block_N; + + move_tile_window(x_window, {0, -Block_N}); + move_tile_window(gamma_window, {stride_to_right_most_window}); + move_tile_window(y_window, {0, stride_to_right_most_window}); + + // rmsnorm computation + absmax + quantization + for(int iN = __builtin_amdgcn_readfirstlane(0); iN < num_n_tile_iteration; ++iN) + { + const auto x = load_tile(x_window); + const auto gamma = load_tile(gamma_window); + + auto y = make_static_distributed_tensor(x.get_tile_distribution()); + + sweep_tile(y, [&, inv_rms_ = inv_rms](auto idx) { + constexpr auto i_idx = make_tuple(idx[number<0>{}]); + constexpr auto j_idx = make_tuple(idx[number<1>{}]); + + const auto gamma_ = type_convert(gamma[j_idx]); + + const auto x_ = type_convert(x[idx]); + auto y_ = x_ * inv_rms_[i_idx] * gamma_; + + y(idx) = type_convert(y_); + }); + + move_tile_window(x_window, {0, -Block_N}); + move_tile_window(gamma_window, {-Block_N}); + move_tile_window(y_window, {0, -Block_N}); + } } }; } // namespace ck_tile diff --git a/include/ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_one_pass.hpp b/include/ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_one_pass.hpp index 8559485038..68cfe4282b 100644 --- a/include/ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_one_pass.hpp +++ b/include/ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_one_pass.hpp @@ -55,8 +55,8 @@ struct Rmsnorm2dFwdPipelineOnePass const auto gamma_window = make_tile_window( gamma_window_, Policy::template MakeGammaBlockTileDistribution()); - auto reduce_square_sum_func = [](const auto& v0, const auto& v1) { return v0 + v1 * v1; }; - auto reduce_sum_func = [](const auto& v0, const auto& v1) { return v0 + v1; }; + auto reduce_square_sum_func = ReduceOp::SquareAdd{}; + auto reduce_sum_func = ReduceOp::Add{}; auto block_reduce2d = Policy::template GetBlockReduce2d(); auto block_reduce2d_sync = Policy::template GetBlockReduce2dSync(); auto block_reduce2d_cross_warp_sync = @@ -67,7 +67,8 @@ struct Rmsnorm2dFwdPipelineOnePass const auto gamma = load_tile(gamma_window); // compute mean square each-thread->cross-lane->cross-warp - auto square_sum = block_reduce2d(x, 0, reduce_square_sum_func); + auto square_sum = block_reduce2d( + x, reduce_square_sum_func.GetIdentityValue(), reduce_square_sum_func); block_reduce2d_sync(square_sum, reduce_sum_func); block_reduce2d_cross_warp_sync(square_sum, smem, reduce_sum_func); diff --git a/include/ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_two_pass.hpp b/include/ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_two_pass.hpp index 28e02fe651..a892df6bdb 100644 --- a/include/ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_two_pass.hpp +++ b/include/ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_two_pass.hpp @@ -60,8 +60,8 @@ struct Rmsnorm2dFwdPipelineTwoPass index_t num_n_tile_iteration = __builtin_amdgcn_readfirstlane(integer_divide_ceil(row_size, Block_N)); - auto reduce_square_sum_func = [](const auto& v0, const auto& v1) { return v0 + v1 * v1; }; - auto reduce_sum_func = [](const auto& v0, const auto& v1) { return v0 + v1; }; + auto reduce_square_sum_func = ReduceOp::SquareAdd{}; + auto reduce_sum_func = ReduceOp::Add{}; auto block_reduce2d = Policy::template GetBlockReduce2d(); auto block_reduce2d_sync = Policy::template GetBlockReduce2dSync(); auto block_reduce2d_cross_warp_sync = @@ -69,7 +69,7 @@ struct Rmsnorm2dFwdPipelineTwoPass using XTensorType = decltype(load_tile(x_window)); auto square_sum = block_reduce2d.template MakeYBlockTile(); - set_tile(square_sum, 0); + set_tile(square_sum, reduce_square_sum_func.GetIdentityValue()); for(int iN = __builtin_amdgcn_readfirstlane(0); iN < num_n_tile_iteration; ++iN) {