mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-03 13:11:25 +00:00
* Pooling 2D/3D with reference * Tests & cleanup - added test for pooling - cleanup - removed 2d example * Comment resolution - README added - example target name rectified - appropriate arg description and comments added * clang-format * appropriate blocksize calc * modifications for future indexing addition - instead of transforming views we now transform the descriptors, so that the same descriptor can be re-used for index tensor in the future * some basic fixes * comment resolutions * comment resolutions --------- Co-authored-by: Illia Silin <98187287+illsilin@users.noreply.github.com>
58 lines
2.8 KiB
C++
58 lines
2.8 KiB
C++
// SPDX-License-Identifier: MIT
|
|
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
|
|
|
#pragma once
|
|
|
|
#include "ck_tile/core.hpp"
|
|
|
|
namespace ck_tile {
|
|
|
|
template <typename BlockWarps, // num warps along seq<M, N>
|
|
typename BlockTile, // block size, seq<M, N>
|
|
typename WarpTile, // warp size, seq<M, N>
|
|
typename ThreadTile> // contiguous pixels(vector size) along seq<M, N>
|
|
struct PoolShape
|
|
{
|
|
static constexpr index_t Block_M = BlockTile::at(number<0>{});
|
|
static constexpr index_t Block_N = BlockTile::at(number<1>{});
|
|
|
|
static constexpr index_t Warp_M = WarpTile::at(number<0>{});
|
|
static constexpr index_t Warp_N = WarpTile::at(number<1>{});
|
|
|
|
static constexpr index_t ThreadTile_M = ThreadTile::at(number<0>{});
|
|
static constexpr index_t ThreadTile_N = ThreadTile::at(number<1>{});
|
|
|
|
static constexpr index_t WarpPerBlock_M = BlockWarps::at(number<0>{});
|
|
static constexpr index_t WarpPerBlock_N = BlockWarps::at(number<1>{});
|
|
|
|
static_assert(Warp_M % ThreadTile_M == 0, "Warp_M must be divisible by ThreadTile_M");
|
|
static_assert(Warp_N % ThreadTile_N == 0, "Warp_N must be divisible by ThreadTile_N");
|
|
static_assert((Warp_M * Warp_N / ThreadTile_M / ThreadTile_N) % ck_tile::get_warp_size() == 0,
|
|
"Warp_M * Warp_N / ThreadTile_M / ThreadTile_N must be a multiple of warp size");
|
|
|
|
// Scale factor to account for warp size
|
|
// WarpSizeScaleFactor = warp tile/ thread tile / warp size
|
|
static constexpr index_t WarpSizeScaleFactor =
|
|
Warp_M * Warp_N / ThreadTile_M / ThreadTile_N / ck_tile::get_warp_size();
|
|
|
|
static constexpr index_t WarpSizeScaleFactor_M =
|
|
(Warp_M / ThreadTile_M > Warp_N / ThreadTile_N) ? WarpSizeScaleFactor : 1;
|
|
static constexpr index_t WarpSizeScaleFactor_N =
|
|
(Warp_M / ThreadTile_M > Warp_N / ThreadTile_N) ? 1 : WarpSizeScaleFactor;
|
|
|
|
static constexpr index_t ThreadPerWarp_M = Warp_M / ThreadTile_M / WarpSizeScaleFactor_M;
|
|
static constexpr index_t ThreadPerWarp_N = Warp_N / ThreadTile_N / WarpSizeScaleFactor_N;
|
|
|
|
static_assert((Block_M * WarpSizeScaleFactor_M) % (WarpPerBlock_M * Warp_M) == 0,
|
|
"Block_M * WarpSizeScaleFactor_M must be divisible by WarpPerBlock_M * Warp_M");
|
|
static_assert((Block_N * WarpSizeScaleFactor_N) % (WarpPerBlock_N * Warp_N) == 0,
|
|
"Block_N * WarpSizeScaleFactor_N must be divisible by WarpPerBlock_N * Warp_N");
|
|
|
|
static constexpr index_t Repeat_M = Block_M * WarpSizeScaleFactor_M / (WarpPerBlock_M * Warp_M);
|
|
static constexpr index_t Repeat_N = Block_N * WarpSizeScaleFactor_N / (WarpPerBlock_N * Warp_N);
|
|
|
|
static constexpr index_t BlockSize =
|
|
ck_tile::get_warp_size() * reduce_on_sequence(BlockWarps{}, multiplies{}, number<1>{});
|
|
};
|
|
} // namespace ck_tile
|