General 2D Reduction Kernel (#2535)

* General 2D Reduction Kernel

* Move the reduction kernel from the example
* Split the code and add the necessary policy, problem, shape files as
per ck_tile convention
* Add/modify the headers
* Modified the example to work with the 'new' kernel
* Added tests for the kernel
* N-D reference reduce
* Added support for N-D input with transform to 2D (see the sketch after this list)
* Added padding to support input tensors of various sizes
* Bug fix in the thread buffer constructor
* Some comments to explain the reduce2d block kernel

* comments resolution

* clang-format

* comments resolution

* clang-format

* clang-format

* comments resolution

* clang-format
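
To make the N-D-to-2D bullet above concrete: the kernel still performs a 2D (M, N) reduction, and an N-D input is folded so that M is the product of the kept-dimension lengths and N is the product of the reduced-dimension lengths. A minimal host-side sketch of that mapping (the fold_to_2d helper below is illustrative only and is not part of ck_tile):

    #include <cstdint>
    #include <utility>
    #include <vector>

    // Illustrative helper (not part of ck_tile): fold an N-D problem into the
    // 2D (M, N) shape the reduce2d kernel operates on. M = product of kept
    // dimension lengths, N = product of reduced dimension lengths. For the NHWC
    // example in this commit (kept_dim = {0, 3}, reduce_dims = {1, 2}) this
    // yields M = N_batch * C and N = H * W.
    inline std::pair<int64_t, int64_t> fold_to_2d(const std::vector<int64_t>& shape,
                                                  const std::vector<int>& kept_dims,
                                                  const std::vector<int>& reduce_dims)
    {
        int64_t m = 1;
        int64_t n = 1;
        for(int d : kept_dims)
            m *= shape[d];
        for(int d : reduce_dims)
            n *= shape[d];
        return {m, n};
    }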
Commit 4750b293fe by Yashvardhan Agarwal, 2025-08-06 16:36:59 +03:00, committed via GitHub. Parent: 2622ff06cb. 14 changed files with 905 additions and 199 deletions.


@@ -1,16 +1,21 @@
 // SPDX-License-Identifier: MIT
 // Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
 #include "ck_tile/host.hpp"
-#include "reduce.hpp"
+#include "ck_tile/ops/reduce.hpp"
+#include <cstring>
 auto create_args(int argc, char* argv[])
 {
     ck_tile::ArgParser arg_parser;
-    arg_parser.insert("m", "3328", "m dimension")
-        .insert("n", "4096", "n dimension")
+    arg_parser.insert("n", "32", "n dimension")
+        .insert("h", "7", "h dimension")
+        .insert("w", "7", "w dimension")
+        .insert("c", "512", "c dimension")
         .insert("v", "1", "cpu validation or not")
         .insert("prec", "fp16", "precision")
-        .insert("warmup", "5", "cold iter")
-        .insert("repeat", "20", "hot iter");
+        .insert("warmup", "0", "cold iter")
+        .insert("repeat", "1", "hot iter");
     bool result = arg_parser.parse(argc, argv);
     return std::make_tuple(result, arg_parser);
@@ -23,15 +28,28 @@ bool run(const ck_tile::ArgParser& arg_parser)
     using ComputeDataType = float;
     using YDataType = DataType;
-    ck_tile::index_t m = arg_parser.get_int("m");
-    ck_tile::index_t n = arg_parser.get_int("n");
+    ck_tile::index_t N = arg_parser.get_int("n");
+    ck_tile::index_t H = arg_parser.get_int("h");
+    ck_tile::index_t W = arg_parser.get_int("w");
+    ck_tile::index_t C = arg_parser.get_int("c");
     int do_validation = arg_parser.get_int("v");
     int warmup = arg_parser.get_int("warmup");
     int repeat = arg_parser.get_int("repeat");
-    ck_tile::HostTensor<XDataType> x_host({m, n});
-    ck_tile::HostTensor<YDataType> y_host_ref({m});
-    ck_tile::HostTensor<YDataType> y_host_dev({m});
+    std::vector<ck_tile::index_t> problem_shape = {N, H, W, C};
+    std::vector<ck_tile::index_t> strides(4);
+    strides[0] = H * W * C;
+    strides[1] = W * C;
+    strides[2] = C;
+    strides[3] = 1;
+    // Define reduction specification:
+    constexpr auto kept_dim = ck_tile::sequence<0, 3>{}; // Which dimension to keep
+    constexpr auto reduce_dims = ck_tile::sequence<1, 2>{}; // Which dimensions to reduce
+    ck_tile::HostTensor<XDataType> x_host(problem_shape, strides);
+    ck_tile::HostTensor<YDataType> y_host_ref({N, C}, {C, 1});
+    ck_tile::HostTensor<YDataType> y_host_dev({N, C}, {C, 1});
     ck_tile::FillUniformDistribution<XDataType>{-5.f, 5.f}(x_host);
@@ -54,7 +72,9 @@ bool run(const ck_tile::ArgParser& arg_parser)
     constexpr ck_tile::index_t kBlockSize = 256;
     constexpr ck_tile::index_t kBlockPerCu = 1;
-    ck_tile::index_t kGridSize = (m / BlockTile::at(ck_tile::number<0>{}));
+    ck_tile::index_t kept_dim_len_prod = N * C;
+    ck_tile::index_t kGridSize = (kept_dim_len_prod + BlockTile::at(ck_tile::number<0>{}) - 1) /
+                                 BlockTile::at(ck_tile::number<0>{});
     std::cout << "grid size " << kGridSize << std::endl;
     using Shape = ck_tile::Reduce2dShape<BlockWarps, BlockTile, WarpTile, Vector>;
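
For a concrete sense of the new grid-size formula: with the default arguments above (n = 32, c = 512) the kept-dimension product is 32 * 512 = 16384, and with a hypothetical Block_M of 64 the grid would be (16384 + 63) / 64 = 256 blocks. The "+ Block_M - 1" term rounds up whenever the product is not an exact multiple of Block_M, which is presumably the case the added padding support is meant to cover.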
@@ -63,6 +83,17 @@ bool run(const ck_tile::ArgParser& arg_parser)
     using Kernel = ck_tile::Reduce<Porblem>;
+    // Create input tensor shape and strides
+    auto input_shape =
+        ck_tile::make_tuple(problem_shape[0], problem_shape[1], problem_shape[2], problem_shape[3]);
+    auto input_strides = ck_tile::make_tuple(strides[0], strides[1], strides[2], strides[3]);
+    if(!Kernel::IsSupportedArgument(
+           C, input_strides)) // output tensor's continuous dimension and input strides
+    {
+        throw std::runtime_error("Wrong! Arguments not supported!\n");
+    }
     float ave_time = launch_kernel(ck_tile::stream_config{nullptr, true, 0, warmup, repeat},
                                    ck_tile::make_kernel<kBlockSize, kBlockPerCu>(
                                        Kernel{},
@@ -71,10 +102,12 @@ bool run(const ck_tile::ArgParser& arg_parser)
                                        0,
                                        static_cast<XDataType*>(x_buf.GetDeviceBuffer()),
                                        static_cast<YDataType*>(y_buf.GetDeviceBuffer()),
-                                       m,
-                                       n));
+                                       input_shape,
+                                       input_strides,
+                                       kept_dim,
+                                       reduce_dims));
-    std::size_t num_btype = sizeof(XDataType) * m * n + sizeof(YDataType) * m;
+    std::size_t num_btype = sizeof(XDataType) * N * C * H * W + sizeof(YDataType) * N * C;
     float gb_per_sec = num_btype / 1.E6 / ave_time;
@@ -86,7 +119,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
     {
         // reference
         ck_tile::reference_reduce<XDataType, ComputeDataType, YDataType>(
-            x_host, y_host_ref, ReduceOp{});
+            x_host, y_host_ref, ReduceOp{}, kept_dim, reduce_dims);
         y_buf.FromDevice(y_host_dev.mData.data());
         pass = ck_tile::check_err(y_host_dev, y_host_ref);
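
To make the validation call concrete: with kept_dim = {0, 3} and reduce_dims = {1, 2} on an NHWC input, the expected result is y[n][c] = op over all (h, w) of x[n][h][w][c]. Below is a minimal sketch of that semantics assuming a sum reduction; it illustrates what the reference is expected to compute, not the actual implementation of ck_tile::reference_reduce:

    #include <cstddef>
    #include <vector>

    // Illustrative NHWC -> NC reduction over H and W (kept_dim = {0, 3},
    // reduce_dims = {1, 2}), assuming a sum ReduceOp and a packed NHWC layout.
    template <typename XT, typename CT, typename YT>
    void naive_reduce_nhwc(const std::vector<XT>& x,
                           std::vector<YT>& y,
                           std::size_t N, std::size_t H, std::size_t W, std::size_t C)
    {
        for(std::size_t n = 0; n < N; ++n)
            for(std::size_t c = 0; c < C; ++c)
            {
                CT acc = 0; // identity value of the assumed add reduction
                for(std::size_t h = 0; h < H; ++h)
                    for(std::size_t w = 0; w < W; ++w)
                        acc += static_cast<CT>(x[((n * H + h) * W + w) * C + c]);
                y[n * C + c] = static_cast<YT>(acc);
            }
    }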


@@ -1,164 +0,0 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/common.hpp"
#include "ck_tile/ops/reduce/block/block_reduce.hpp"
#include "ck_tile/ops/reduce/block/block_reduce2d_default_policy.hpp"
namespace ck_tile {
template <typename BlockWarps, // num warps along seq<M, N>
          typename BlockTile,  // block size, seq<M, N>
          typename WarpTile,   // warp size, seq<M, N>
          typename Vector>     // contiguous pixels(vector size) along seq<M, N>
struct Reduce2dShape
{
    static constexpr index_t Block_M = BlockTile::at(number<0>{});
    static constexpr index_t Block_N = BlockTile::at(number<1>{});

    static constexpr index_t Warp_M = WarpTile::at(number<0>{});
    static constexpr index_t Warp_N = WarpTile::at(number<1>{});

    static constexpr index_t Vector_M = Vector::at(number<0>{});
    static constexpr index_t Vector_N = Vector::at(number<1>{});

    static constexpr index_t WarpPerBlock_M = BlockWarps::at(number<0>{});
    static constexpr index_t WarpPerBlock_N = BlockWarps::at(number<1>{});

    static constexpr index_t ThreadPerWarp_M = Warp_M / Vector_M;
    static constexpr index_t ThreadPerWarp_N = Warp_N / Vector_N;

    static constexpr index_t Repeat_M = Block_M / (WarpPerBlock_M * Warp_M);
    static constexpr index_t Repeat_N = Block_N / (WarpPerBlock_N * Warp_N);

    static constexpr index_t BlockSize =
        ck_tile::get_warp_size() * reduce_on_sequence(BlockWarps{}, multiplies{}, number<1>{});
};
template <typename XDataType_,
          typename ComputeDataType_,
          typename YDataType_,
          typename BlockShape_,
          typename ReduceOp_>
struct Reduce2dProblem
{
    using XDataType = remove_cvref_t<XDataType_>;
    using ComputeDataType = remove_cvref_t<ComputeDataType_>;
    using YDataType = remove_cvref_t<YDataType_>;
    using BlockShape = remove_cvref_t<BlockShape_>;
    using ReduceOp = ReduceOp_;

    static constexpr bool kNeedCrossLaneSync = BlockShape::ThreadPerWarp_N > 1;
    static constexpr bool kNeedCrossWarpSync = BlockShape::WarpPerBlock_N > 1;
};
template <typename Problem_, typename Policy_ = BlockReduce2dDefaultPolicy>
struct Reduce
{
    using Problem = ck_tile::remove_cvref_t<Problem_>;
    using Policy = ck_tile::remove_cvref_t<Policy_>;

    using XDataType = ck_tile::remove_cvref_t<typename Problem::XDataType>;
    using ComputeDataType = ck_tile::remove_cvref_t<typename Problem::ComputeDataType>;
    using YDataType = ck_tile::remove_cvref_t<typename Problem::YDataType>;
#if 0
    CK_TILE_DEVICE void operator()(const XDataType* p_x, YDataType* p_y, index_t M, index_t N)
        const
    {
        using S = typename Problem::BlockShape;

        const auto x_m_n = make_naive_tensor_view<address_space_enum::global>(
            p_x, make_tuple(M, N), make_tuple(N, 1), number<S::Vector_N>{}, number<1>{});
        const auto y_m = make_naive_tensor_view_packed<address_space_enum::global>(
            p_y, make_tuple(M), number<1>{});

        const auto iM = get_block_id() * S::Block_M;

        auto x_window = make_tile_window(x_m_n,
                                         make_tuple(number<S::Block_M>{}, number<S::Block_N>{}),
                                         {iM, 0},
                                         Policy::template MakeXBlockTileDistribution<Problem>());
        auto y_window = make_tile_window(y_m, make_tuple(number<S::Block_M>{}), {iM});

        const auto f_reduce = [](const auto& v0, const auto& v1) { return v0 + v1; };
        const XDataType reduce_init_value = 0;
        constexpr auto reduce_dims = sequence<1>{};

        auto y_compute = decltype(block_tile_reduce<ComputeDataType>(
            load_tile(x_window), reduce_dims, f_reduce, reduce_init_value)){};
        set_tile(y_compute, reduce_init_value);

        index_t num_n_tile_iteration =
            __builtin_amdgcn_readfirstlane(integer_divide_ceil(N, S::Block_N));

        for(int iN = __builtin_amdgcn_readfirstlane(0); iN < num_n_tile_iteration; ++iN)
        {
            const auto x = load_tile(x_window);
            block_tile_reduce(y_compute, x, reduce_dims, f_reduce);
            move_tile_window(x_window, {0, S::Block_N});
        }

        block_tile_reduce_sync(y_compute, f_reduce);
        store_tile(y_window, cast_tile<YDataType>(y_compute));
    }
#else
    CK_TILE_DEVICE void operator()(const XDataType* p_x, YDataType* p_y, index_t M, index_t N) const
    {
        using S = typename Problem::BlockShape;

        const auto x_m_n = make_naive_tensor_view<address_space_enum::global>(
            p_x, make_tuple(M, N), make_tuple(N, 1), number<S::Vector_N>{}, number<1>{});
        const auto y_m = make_naive_tensor_view_packed<address_space_enum::global>(
            p_y, make_tuple(M), number<1>{});

        const auto iM = get_block_id() * S::Block_M;

        auto x_window = make_tile_window(x_m_n,
                                         make_tuple(number<S::Block_M>{}, number<S::Block_N>{}),
                                         {iM, 0},
                                         Policy::template MakeXBlockTileDistribution<Problem>());
        auto y_window = make_tile_window(y_m, make_tuple(number<S::Block_M>{}), {iM});

        __shared__ char smem[Policy::template GetSmemSize<Problem>()];

        index_t num_n_tile_iteration =
            __builtin_amdgcn_readfirstlane(integer_divide_ceil(N, S::Block_N));

        auto reduce_func = typename Problem::ReduceOp{};
        auto block_reduce2d = Policy::template GetBlockReduce2d<Problem>();
        auto block_reduce2d_sync = Policy::template GetBlockReduce2dSync<Problem>();
        auto block_reduce2d_cross_warp_sync =
            Policy::template GetBlockReduce2dCrossWarpSync<Problem>();

        using XTensorType = decltype(load_tile(x_window));
        auto y_compute = block_reduce2d.template MakeYBlockTile<XTensorType>();
        set_tile(y_compute, reduce_func.template GetIdentityValue<ComputeDataType>());

        for(int iN = __builtin_amdgcn_readfirstlane(0); iN < num_n_tile_iteration; ++iN)
        {
            const auto x = load_tile(x_window);
            block_reduce2d(x, y_compute, reduce_func);
            move_tile_window(x_window, {0, S::Block_N});
        }

        block_reduce2d_sync(y_compute, reduce_func);
        block_reduce2d_cross_warp_sync(y_compute, smem, reduce_func);

        store_tile(y_window, cast_tile<YDataType>(y_compute));
    }
#endif
};
} // namespace ck_tile
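
For orientation, here is a minimal sketch of how the shape/problem/kernel types above compose. The tile sizes are illustrative assumptions chosen to be internally consistent (64-lane waves, 4 warps, 256 threads per block, matching kBlockSize in the example); they are not the values the example actually uses, and XDataType/ComputeDataType/YDataType/ReduceOp are as defined in the example:

    // Hypothetical tile configuration (illustrative values, not from the example).
    using BlockWarps = ck_tile::sequence<4, 1>;    // warps along (M, N)
    using BlockTile  = ck_tile::sequence<64, 128>; // per-block tile (M, N)
    using WarpTile   = ck_tile::sequence<8, 64>;   // per-warp tile (M, N)
    using Vector     = ck_tile::sequence<1, 8>;    // vector size along (M, N)

    using Shape   = ck_tile::Reduce2dShape<BlockWarps, BlockTile, WarpTile, Vector>;
    using Problem = ck_tile::Reduce2dProblem<XDataType, ComputeDataType, YDataType, Shape, ReduceOp>;
    using Kernel  = ck_tile::Reduce<Problem>;
    // With these values: ThreadPerWarp_M * ThreadPerWarp_N = 8 * 8 = 64 lanes per warp,
    // and Shape::BlockSize = get_warp_size() * 4 = 256 threads per block.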