mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-07 08:15:04 +00:00
Add reduce op
This commit is contained in:
@@ -40,6 +40,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
|
||||
x_buf.ToDevice(x_host.data());
|
||||
|
||||
using ReduceOp = ck_tile::ReduceOp::Add;
|
||||
using BlockWarps = ck_tile::sequence<4, 1>;
|
||||
using BlockTile = ck_tile::sequence<128, 128>;
|
||||
using WarpTile = ck_tile::sequence<32, 128>;
|
||||
@@ -56,8 +57,9 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
ck_tile::index_t kGridSize = (m / BlockTile::at(ck_tile::number<0>{}));
|
||||
std::cout << "grid size " << kGridSize << std::endl;
|
||||
|
||||
using Shape = ck_tile::Reduce2dShape<BlockWarps, BlockTile, WarpTile, Vector>;
|
||||
using Porblem = ck_tile::Reduce2dProblem<XDataType, ComputeDataType, YDataType, Shape>;
|
||||
using Shape = ck_tile::Reduce2dShape<BlockWarps, BlockTile, WarpTile, Vector>;
|
||||
using Porblem =
|
||||
ck_tile::Reduce2dProblem<XDataType, ComputeDataType, YDataType, Shape, ReduceOp>;
|
||||
|
||||
using Kernel = ck_tile::Reduce<Porblem>;
|
||||
|
||||
@@ -83,7 +85,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
if(do_validation)
|
||||
{
|
||||
// reference
|
||||
ck_tile::reference_reduce<XDataType, ComputeDataType, YDataType>(x_host, y_host_ref);
|
||||
ck_tile::reference_reduce<XDataType, ComputeDataType, YDataType>(
|
||||
x_host, y_host_ref, ReduceOp{});
|
||||
y_buf.FromDevice(y_host_dev.mData.data());
|
||||
pass = ck_tile::check_err(y_host_dev, y_host_ref);
|
||||
|
||||
|
||||
@@ -38,13 +38,18 @@ struct Reduce2dShape
|
||||
warpSize * reduce_on_sequence(BlockWarps{}, multiplies{}, number<1>{});
|
||||
};
|
||||
|
||||
template <typename XDataType_, typename ComputeDataType_, typename YDataType_, typename BlockShape_>
|
||||
template <typename XDataType_,
|
||||
typename ComputeDataType_,
|
||||
typename YDataType_,
|
||||
typename BlockShape_,
|
||||
typename ReduceOp_>
|
||||
struct Reduce2dProblem
|
||||
{
|
||||
using XDataType = remove_cvref_t<XDataType_>;
|
||||
using ComputeDataType = remove_cvref_t<ComputeDataType_>;
|
||||
using YDataType = remove_cvref_t<YDataType_>;
|
||||
using BlockShape = remove_cvref_t<BlockShape_>;
|
||||
using ReduceOp = ReduceOp_;
|
||||
|
||||
static constexpr bool kNeedCrossLaneSync = BlockShape::ThreadPerWarp_N > 1;
|
||||
static constexpr bool kNeedCrossWarpSync = BlockShape::WarpPerBlock_N > 1;
|
||||
@@ -131,7 +136,7 @@ struct Reduce
|
||||
index_t num_n_tile_iteration =
|
||||
__builtin_amdgcn_readfirstlane(integer_divide_ceil(N, S::Block_N));
|
||||
|
||||
auto reduce_func = [](const auto& v0, const auto& v1) { return v0 + v1; };
|
||||
auto reduce_func = typename Problem::ReduceOp{};
|
||||
auto block_reduce2d = Policy::template GetBlockReduce2d<Problem>();
|
||||
auto block_reduce2d_sync = Policy::template GetBlockReduce2dSync<Problem>();
|
||||
auto block_reduce2d_cross_warp_sync =
|
||||
@@ -139,7 +144,7 @@ struct Reduce
|
||||
|
||||
using XTensorType = decltype(load_tile(x_window));
|
||||
auto y_compute = block_reduce2d.template MakeYBlockTile<XTensorType>();
|
||||
set_tile(y_compute, 0);
|
||||
set_tile(y_compute, reduce_func.template GetIdentityValue<ComputeDataType>());
|
||||
|
||||
for(int iN = __builtin_amdgcn_readfirstlane(0); iN < num_n_tile_iteration; ++iN)
|
||||
{
|
||||
|
||||
@@ -57,6 +57,7 @@
|
||||
#include "ck_tile/core/utility/magic_div.hpp"
|
||||
#include "ck_tile/core/utility/philox_rand.hpp"
|
||||
#include "ck_tile/core/utility/random.hpp"
|
||||
#include "ck_tile/core/utility/reduce_operator.hpp"
|
||||
#include "ck_tile/core/utility/to_sequence.hpp"
|
||||
#include "ck_tile/core/utility/transpose_vectors.hpp"
|
||||
#include "ck_tile/core/utility/type_traits.hpp"
|
||||
|
||||
95
include/ck_tile/core/utility/reduce_operator.hpp
Normal file
95
include/ck_tile/core/utility/reduce_operator.hpp
Normal file
@@ -0,0 +1,95 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core/config.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
namespace ReduceOp {
|
||||
// y = ReduceOp(y, x);
|
||||
struct Add
|
||||
{
|
||||
template <typename T>
|
||||
CK_TILE_HOST_DEVICE static constexpr T GetIdentityValue()
|
||||
{
|
||||
return type_convert<T>(0.0f);
|
||||
};
|
||||
|
||||
template <typename T,
|
||||
typename = std::enable_if_t<std::is_same_v<T, float> || std::is_same_v<T, double> ||
|
||||
std::is_same_v<T, int32_t> || std::is_same_v<T, int8_t>>>
|
||||
CK_TILE_HOST_DEVICE constexpr T operator()(const T& y, const T x) const
|
||||
{
|
||||
return y + x;
|
||||
}
|
||||
|
||||
template <typename T,
|
||||
typename = std::enable_if_t<std::is_same_v<T, half_t> || std::is_same_v<T, bf16_t>>>
|
||||
CK_TILE_HOST_DEVICE constexpr T operator()(T& y, T x) const
|
||||
{
|
||||
float y_ = type_convert<float>(y);
|
||||
float x_ = type_convert<float>(x);
|
||||
|
||||
return type_convert<T>(y_ + x_);
|
||||
}
|
||||
};
|
||||
|
||||
struct SquareAdd
|
||||
{
|
||||
template <typename T>
|
||||
CK_TILE_HOST_DEVICE static constexpr T GetIdentityValue()
|
||||
{
|
||||
return type_convert<T>(0.0f);
|
||||
};
|
||||
|
||||
template <typename T,
|
||||
typename = std::enable_if_t<std::is_same_v<T, float> || std::is_same_v<T, double> ||
|
||||
std::is_same_v<T, int32_t> || std::is_same_v<T, int8_t>>>
|
||||
CK_TILE_HOST_DEVICE constexpr T operator()(const T& y, const T x) const
|
||||
{
|
||||
return y + (x * x);
|
||||
}
|
||||
};
|
||||
|
||||
struct Max
|
||||
{
|
||||
template <typename T,
|
||||
typename = std::enable_if_t<std::is_same_v<T, float> || std::is_same_v<T, double> ||
|
||||
std::is_same_v<T, int32_t> || std::is_same_v<T, int8_t>>>
|
||||
CK_TILE_HOST_DEVICE static constexpr T GetIdentityValue()
|
||||
{
|
||||
return numeric<T>::min();
|
||||
};
|
||||
|
||||
template <typename T,
|
||||
typename = std::enable_if_t<std::is_same_v<T, float> || std::is_same_v<T, double> ||
|
||||
std::is_same_v<T, int32_t> || std::is_same_v<T, int8_t>>>
|
||||
CK_TILE_HOST_DEVICE constexpr T operator()(const T& y, const T x) const
|
||||
{
|
||||
return max(y, x);
|
||||
}
|
||||
};
|
||||
|
||||
struct AbsMax
|
||||
{
|
||||
template <typename T,
|
||||
typename = std::enable_if_t<std::is_same_v<T, float> || std::is_same_v<T, double> ||
|
||||
std::is_same_v<T, int32_t> || std::is_same_v<T, int8_t>>>
|
||||
CK_TILE_HOST_DEVICE static constexpr T GetIdentityValue()
|
||||
{
|
||||
return numeric<T>::min();
|
||||
};
|
||||
|
||||
template <typename T,
|
||||
typename = std::enable_if_t<std::is_same_v<T, float> || std::is_same_v<T, double> ||
|
||||
std::is_same_v<T, int32_t> || std::is_same_v<T, int8_t>>>
|
||||
CK_TILE_HOST_DEVICE constexpr T operator()(const T& y, const T x) const
|
||||
{
|
||||
return max(y, abs(x));
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ReduceOp
|
||||
} // namespace ck_tile
|
||||
@@ -9,24 +9,25 @@
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <typename ADataType, typename AccDataType, typename BDataType>
|
||||
CK_TILE_HOST void reference_reduce(const HostTensor<ADataType>& a_m_n, HostTensor<BDataType>& b_m)
|
||||
template <typename XDataType, typename ComputeDataType, typename YDataType, typename ReduceOp>
|
||||
CK_TILE_HOST void
|
||||
reference_reduce(const HostTensor<XDataType>& x_m_n, HostTensor<YDataType>& y_m, ReduceOp reduce_op)
|
||||
{
|
||||
auto f = [&](auto m) {
|
||||
const int N = a_m_n.mDesc.get_lengths()[1];
|
||||
const int N = x_m_n.mDesc.get_lengths()[1];
|
||||
|
||||
AccDataType v_acc = 0;
|
||||
ComputeDataType v_acc = reduce_op.template GetIdentityValue<ComputeDataType>();
|
||||
|
||||
for(int n = 0; n < N; ++n)
|
||||
{
|
||||
const ADataType v_a = a_m_n(m, n);
|
||||
const ComputeDataType v_a = type_convert<ComputeDataType>(x_m_n(m, n));
|
||||
|
||||
v_acc += v_a;
|
||||
v_acc = reduce_op(v_acc, v_a);
|
||||
}
|
||||
|
||||
b_m(m) = ck_tile::type_convert<BDataType>(v_acc);
|
||||
y_m(m) = ck_tile::type_convert<YDataType>(v_acc);
|
||||
};
|
||||
|
||||
make_ParallelTensorFunctor(f, b_m.mDesc.get_lengths()[0])(std::thread::hardware_concurrency());
|
||||
make_ParallelTensorFunctor(f, y_m.mDesc.get_lengths()[0])(std::thread::hardware_concurrency());
|
||||
}
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -66,12 +66,12 @@ struct AddRmsnorm2dRdquantFwdPipelineOnePass
|
||||
const auto gamma_window = make_tile_window(
|
||||
gamma_window_, Policy::template MakeGammaBlockTileDistribution<Problem>());
|
||||
|
||||
auto reduce_square_sum_func = [](const auto& v0, const auto& v1) { return v0 + v1 * v1; };
|
||||
auto reduce_sum_func = [](const auto& v0, const auto& v1) { return v0 + v1; };
|
||||
auto reduce_absmax_func = [](const auto& v0, const auto& v1) { return max(v0, abs(v1)); };
|
||||
auto reduce_max_func = [](const auto& v0, const auto& v1) { return max(v0, v1); };
|
||||
auto block_reduce2d = Policy::template GetBlockReduce2d<Problem>();
|
||||
auto block_reduce2d_sync = Policy::template GetBlockReduce2dSync<Problem>();
|
||||
auto reduce_square_sum_func = ReduceOp::SquareAdd{};
|
||||
auto reduce_sum_func = ReduceOp::Add{};
|
||||
auto reduce_absmax_func = ReduceOp::AbsMax{};
|
||||
auto reduce_max_func = ReduceOp::Max{};
|
||||
auto block_reduce2d = Policy::template GetBlockReduce2d<Problem>();
|
||||
auto block_reduce2d_sync = Policy::template GetBlockReduce2dSync<Problem>();
|
||||
auto block_reduce2d_cross_warp_sync =
|
||||
Policy::template GetBlockReduce2dCrossWarpSync<Problem>();
|
||||
|
||||
@@ -90,7 +90,8 @@ struct AddRmsnorm2dRdquantFwdPipelineOnePass
|
||||
store_tile(x_window, cast_tile<XDataType>(x));
|
||||
|
||||
// compute mean square, each-thread->cross-lane->cross-warp
|
||||
auto square_sum = block_reduce2d(x, 0, reduce_square_sum_func);
|
||||
auto square_sum = block_reduce2d(
|
||||
x, reduce_square_sum_func.GetIdentityValue<ComputeDataType>(), reduce_square_sum_func);
|
||||
block_reduce2d_sync(square_sum, reduce_sum_func);
|
||||
block_reduce2d_cross_warp_sync(square_sum, smem, reduce_sum_func);
|
||||
|
||||
@@ -115,7 +116,8 @@ struct AddRmsnorm2dRdquantFwdPipelineOnePass
|
||||
});
|
||||
|
||||
// compute absmax, each-thread->cross-lane->cross-warp
|
||||
auto absmax = block_reduce2d(x, numeric<YScaleDataType>::min(), reduce_absmax_func);
|
||||
auto absmax = block_reduce2d(
|
||||
x, reduce_absmax_func.GetIdentityValue<ComputeDataType>(), reduce_absmax_func);
|
||||
block_reduce2d_sync(absmax, reduce_max_func);
|
||||
block_reduce2d_cross_warp_sync(absmax, smem, reduce_max_func);
|
||||
|
||||
@@ -130,7 +132,6 @@ struct AddRmsnorm2dRdquantFwdPipelineOnePass
|
||||
auto qy = make_static_distributed_tensor<QYDataType>(y.get_tile_distribution());
|
||||
sweep_tile(qy, [&, yscale_ = yscale](auto idx) {
|
||||
constexpr auto i_idx = make_tuple(idx[number<0>{}]);
|
||||
constexpr auto j_idx = make_tuple(idx[number<1>{}]);
|
||||
auto qy_ = y[idx] / yscale_[i_idx];
|
||||
qy(idx) = saturates<QYDataType>{}(qy_);
|
||||
});
|
||||
|
||||
@@ -63,8 +63,86 @@ struct AddRmsnorm2dRdquantFwdPipelineTwoPass
|
||||
make_tile_window(a_window_, Policy::template MakeABXBlockTileDistribution<Problem>());
|
||||
auto b_window =
|
||||
make_tile_window(b_window_, Policy::template MakeABXBlockTileDistribution<Problem>());
|
||||
const auto gamma_window = make_tile_window(
|
||||
auto gamma_window = make_tile_window(
|
||||
gamma_window_, Policy::template MakeGammaBlockTileDistribution<Problem>());
|
||||
|
||||
auto reduce_square_sum_func = [](const auto& v0, const auto& v1) { return v0 + v1 * v1; };
|
||||
auto reduce_sum_func = [](const auto& v0, const auto& v1) { return v0 + v1; };
|
||||
auto reduce_absmax_func = [](const auto& v0, const auto& v1) { return max(v0, abs(v1)); };
|
||||
auto reduce_max_func = [](const auto& v0, const auto& v1) { return max(v0, v1); };
|
||||
auto block_reduce2d = Policy::template GetBlockReduce2d<Problem>();
|
||||
auto block_reduce2d_sync = Policy::template GetBlockReduce2dSync<Problem>();
|
||||
auto block_reduce2d_cross_warp_sync =
|
||||
Policy::template GetBlockReduce2dCrossWarpSync<Problem>();
|
||||
|
||||
static constexpr index_t Block_N = Problem::BlockShape::Block_N;
|
||||
index_t num_n_tile_iteration =
|
||||
__builtin_amdgcn_readfirstlane(integer_divide_ceil(row_size, Block_N));
|
||||
|
||||
using XTensorType = decltype(cast_tile<ComputeDataType>(load_tile(a_window)));
|
||||
auto square_sum = block_reduce2d.template MakeYBlockTile<XTensorType>();
|
||||
set_tile(square_sum, 0);
|
||||
|
||||
for(int iN = __builtin_amdgcn_readfirstlane(0); iN < num_n_tile_iteration; ++iN)
|
||||
{
|
||||
const auto a = load_tile(a_window);
|
||||
const auto b = load_tile(b_window);
|
||||
|
||||
auto x = tile_elementwise_in(
|
||||
[&](const auto& a_, const auto& b_) {
|
||||
return type_convert<ComputeDataType>(a_) + type_convert<ComputeDataType>(b_);
|
||||
},
|
||||
a,
|
||||
b);
|
||||
|
||||
if constexpr(kSaveX)
|
||||
store_tile(x_window, cast_tile<XDataType>(x));
|
||||
|
||||
block_reduce2d(x, square_sum, reduce_square_sum_func);
|
||||
move_tile_window(x_window, {0, Block_N});
|
||||
}
|
||||
|
||||
block_reduce2d_sync(square_sum, reduce_sum_func);
|
||||
block_reduce2d_cross_warp_sync(square_sum, smem, reduce_sum_func);
|
||||
|
||||
auto inv_rms = tile_elementwise_in(
|
||||
[&](const auto& v_) {
|
||||
return type_convert<ComputeDataType>(1.0f) / (sqrt(v_ / row_size + epsilon));
|
||||
},
|
||||
square_sum);
|
||||
|
||||
// reverse read x to reuse cache
|
||||
ck_tile::index_t stride_to_right_most_window =
|
||||
row_size % Block_N == 0 ? row_size - Block_N : row_size - row_size % Block_N;
|
||||
|
||||
move_tile_window(x_window, {0, -Block_N});
|
||||
move_tile_window(gamma_window, {stride_to_right_most_window});
|
||||
move_tile_window(y_window, {0, stride_to_right_most_window});
|
||||
|
||||
// rmsnorm computation + absmax + quantization
|
||||
for(int iN = __builtin_amdgcn_readfirstlane(0); iN < num_n_tile_iteration; ++iN)
|
||||
{
|
||||
const auto x = load_tile(x_window);
|
||||
const auto gamma = load_tile(gamma_window);
|
||||
|
||||
auto y = make_static_distributed_tensor<ComputeDataType>(x.get_tile_distribution());
|
||||
|
||||
sweep_tile(y, [&, inv_rms_ = inv_rms](auto idx) {
|
||||
constexpr auto i_idx = make_tuple(idx[number<0>{}]);
|
||||
constexpr auto j_idx = make_tuple(idx[number<1>{}]);
|
||||
|
||||
const auto gamma_ = type_convert<ComputeDataType>(gamma[j_idx]);
|
||||
|
||||
const auto x_ = type_convert<ComputeDataType>(x[idx]);
|
||||
auto y_ = x_ * inv_rms_[i_idx] * gamma_;
|
||||
|
||||
y(idx) = type_convert<ComputeDataType>(y_);
|
||||
});
|
||||
|
||||
move_tile_window(x_window, {0, -Block_N});
|
||||
move_tile_window(gamma_window, {-Block_N});
|
||||
move_tile_window(y_window, {0, -Block_N});
|
||||
}
|
||||
}
|
||||
};
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -55,8 +55,8 @@ struct Rmsnorm2dFwdPipelineOnePass
|
||||
const auto gamma_window = make_tile_window(
|
||||
gamma_window_, Policy::template MakeGammaBlockTileDistribution<Problem>());
|
||||
|
||||
auto reduce_square_sum_func = [](const auto& v0, const auto& v1) { return v0 + v1 * v1; };
|
||||
auto reduce_sum_func = [](const auto& v0, const auto& v1) { return v0 + v1; };
|
||||
auto reduce_square_sum_func = ReduceOp::SquareAdd{};
|
||||
auto reduce_sum_func = ReduceOp::Add{};
|
||||
auto block_reduce2d = Policy::template GetBlockReduce2d<Problem>();
|
||||
auto block_reduce2d_sync = Policy::template GetBlockReduce2dSync<Problem>();
|
||||
auto block_reduce2d_cross_warp_sync =
|
||||
@@ -67,7 +67,8 @@ struct Rmsnorm2dFwdPipelineOnePass
|
||||
const auto gamma = load_tile(gamma_window);
|
||||
|
||||
// compute mean square each-thread->cross-lane->cross-warp
|
||||
auto square_sum = block_reduce2d(x, 0, reduce_square_sum_func);
|
||||
auto square_sum = block_reduce2d(
|
||||
x, reduce_square_sum_func.GetIdentityValue<ComputeDataType>(), reduce_square_sum_func);
|
||||
block_reduce2d_sync(square_sum, reduce_sum_func);
|
||||
block_reduce2d_cross_warp_sync(square_sum, smem, reduce_sum_func);
|
||||
|
||||
|
||||
@@ -60,8 +60,8 @@ struct Rmsnorm2dFwdPipelineTwoPass
|
||||
index_t num_n_tile_iteration =
|
||||
__builtin_amdgcn_readfirstlane(integer_divide_ceil(row_size, Block_N));
|
||||
|
||||
auto reduce_square_sum_func = [](const auto& v0, const auto& v1) { return v0 + v1 * v1; };
|
||||
auto reduce_sum_func = [](const auto& v0, const auto& v1) { return v0 + v1; };
|
||||
auto reduce_square_sum_func = ReduceOp::SquareAdd{};
|
||||
auto reduce_sum_func = ReduceOp::Add{};
|
||||
auto block_reduce2d = Policy::template GetBlockReduce2d<Problem>();
|
||||
auto block_reduce2d_sync = Policy::template GetBlockReduce2dSync<Problem>();
|
||||
auto block_reduce2d_cross_warp_sync =
|
||||
@@ -69,7 +69,7 @@ struct Rmsnorm2dFwdPipelineTwoPass
|
||||
|
||||
using XTensorType = decltype(load_tile(x_window));
|
||||
auto square_sum = block_reduce2d.template MakeYBlockTile<XTensorType>();
|
||||
set_tile(square_sum, 0);
|
||||
set_tile(square_sum, reduce_square_sum_func.GetIdentityValue<ComputeDataType>());
|
||||
|
||||
for(int iN = __builtin_amdgcn_readfirstlane(0); iN < num_n_tile_iteration; ++iN)
|
||||
{
|
||||
|
||||
Reference in New Issue
Block a user