mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-03 13:11:25 +00:00
* General 2D Reduction Kernel * Move the reduction kernel from the example * Split the code and add the necessary policy, problem, shape files as per ck_tile convention * Add/modify the headers * Modified the example to work with the 'new' kernel * Added tests for the kernel * N-D reference reduce * Added support for N-D input with transform to 2D * Added padding to support various input sized tensors * Bug fix in the thread buffer constructor * Some comments to explain the reduce2d block kernel * comments resolution * clang-format * comments resolution * clang-format * clang-format * comments resolution * clang-format
38 lines
1.4 KiB
C++
38 lines
1.4 KiB
C++
// SPDX-License-Identifier: MIT
|
|
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
|
|
|
#pragma once
|
|
|
|
#include "ck_tile/core.hpp"
|
|
|
|
namespace ck_tile {
|
|
|
|
template <typename BlockWarps, // num warps along seq<M, N>
|
|
typename BlockTile, // block size, seq<M, N>
|
|
typename WarpTile, // warp size, seq<M, N>
|
|
typename ThreadTile> // contiguous pixels(vector size) along seq<M, N>
|
|
struct Reduce2dShape
|
|
{
|
|
static constexpr index_t Block_M = BlockTile::at(number<0>{});
|
|
static constexpr index_t Block_N = BlockTile::at(number<1>{});
|
|
|
|
static constexpr index_t Warp_M = WarpTile::at(number<0>{});
|
|
static constexpr index_t Warp_N = WarpTile::at(number<1>{});
|
|
|
|
static constexpr index_t ThreadTile_M = ThreadTile::at(number<0>{});
|
|
static constexpr index_t ThreadTile_N = ThreadTile::at(number<1>{});
|
|
|
|
static constexpr index_t WarpPerBlock_M = BlockWarps::at(number<0>{});
|
|
static constexpr index_t WarpPerBlock_N = BlockWarps::at(number<1>{});
|
|
|
|
static constexpr index_t ThreadPerWarp_M = Warp_M / ThreadTile_M;
|
|
static constexpr index_t ThreadPerWarp_N = Warp_N / ThreadTile_N;
|
|
|
|
static constexpr index_t Repeat_M = Block_M / (WarpPerBlock_M * Warp_M);
|
|
static constexpr index_t Repeat_N = Block_N / (WarpPerBlock_N * Warp_N);
|
|
|
|
static constexpr index_t BlockSize =
|
|
ck_tile::get_warp_size() * reduce_on_sequence(BlockWarps{}, multiplies{}, number<1>{});
|
|
};
|
|
} // namespace ck_tile
|