mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-01 20:21:23 +00:00
[CK_TILE] Pooling FWD (Lwpck 3683) (#2956)
* Pooling 2D/3D with refernce * Tests & cleanup - added test for ppoling - cleanup - removed 2d example * Comment resolution - README added - example target name rectified - appropriate arg description and comments added * clang-format * appropriate blocksize calc * modifications for future indexing addition - instead of transforming views we now transform the descriptors, so that the same descriptor can be re-used for index tensor in the future * some basic fixes * comment resolutions * comment resolutions --------- Co-authored-by: Illia Silin <98187287+illsilin@users.noreply.github.com>
This commit is contained in:
committed by
GitHub
parent
9d4bfe3932
commit
7b6451b68e
496
include/ck_tile/ops/pooling/kernel/pool_kernel.hpp
Normal file
496
include/ck_tile/ops/pooling/kernel/pool_kernel.hpp
Normal file
@@ -0,0 +1,496 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/pooling/pipeline/pool_default_policy.hpp"
|
||||
#include "ck_tile/ops/common.hpp"
|
||||
#include <type_traits>
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
/// @brief Host arguments for pooling operations
|
||||
template <typename TensorShape, typename WindowShape>
|
||||
struct PoolHostArgs
|
||||
{
|
||||
|
||||
CK_TILE_HOST PoolHostArgs(const void* input_ptr_,
|
||||
void* output_ptr_,
|
||||
TensorShape input_shape_,
|
||||
TensorShape output_shape_,
|
||||
TensorShape input_strides_,
|
||||
TensorShape output_strides_,
|
||||
WindowShape window_lengths_,
|
||||
WindowShape window_strides_,
|
||||
WindowShape window_dilations_,
|
||||
WindowShape input_left_pads_,
|
||||
WindowShape input_right_pads_)
|
||||
: input_ptr(input_ptr_),
|
||||
output_ptr(output_ptr_),
|
||||
input_shape(input_shape_),
|
||||
output_shape(output_shape_),
|
||||
input_strides(input_strides_),
|
||||
output_strides(output_strides_),
|
||||
window_lengths(window_lengths_),
|
||||
window_strides(window_strides_),
|
||||
window_dilations(window_dilations_),
|
||||
input_left_pads(input_left_pads_),
|
||||
input_right_pads(input_right_pads_)
|
||||
{
|
||||
}
|
||||
|
||||
const void* input_ptr;
|
||||
void* output_ptr;
|
||||
|
||||
TensorShape input_shape;
|
||||
TensorShape output_shape;
|
||||
TensorShape input_strides;
|
||||
TensorShape output_strides;
|
||||
WindowShape window_lengths;
|
||||
WindowShape window_strides;
|
||||
WindowShape window_dilations;
|
||||
WindowShape input_left_pads;
|
||||
WindowShape input_right_pads;
|
||||
};
|
||||
|
||||
/// @brief Kernel arguments for pooling operations
|
||||
template <typename TensorShape, typename WindowShape>
|
||||
struct PoolKernelArgs
|
||||
{
|
||||
const void* input_ptr;
|
||||
void* output_ptr;
|
||||
TensorShape input_shape;
|
||||
TensorShape output_shape;
|
||||
TensorShape input_strides;
|
||||
TensorShape output_strides;
|
||||
WindowShape window_lengths;
|
||||
WindowShape window_strides;
|
||||
WindowShape window_dilations;
|
||||
WindowShape input_left_pads;
|
||||
WindowShape input_right_pads;
|
||||
};
|
||||
|
||||
template <typename Problem_, typename Policy_ = PoolDefaultPolicy>
|
||||
struct PoolKernel
|
||||
{
|
||||
using Problem = ck_tile::remove_cvref_t<Problem_>;
|
||||
using Policy = ck_tile::remove_cvref_t<Policy_>;
|
||||
|
||||
using InDataType = ck_tile::remove_cvref_t<typename Problem::InDataType>;
|
||||
using ComputeDataType = ck_tile::remove_cvref_t<typename Problem::ComputeDataType>;
|
||||
using OutDataType = ck_tile::remove_cvref_t<typename Problem::OutDataType>;
|
||||
|
||||
static constexpr index_t kBlockSize = Problem::BlockShape::BlockSize;
|
||||
|
||||
CK_TILE_HOST static constexpr auto BlockSize()
|
||||
{
|
||||
return is_wave32() ? kBlockSize / 2 : kBlockSize;
|
||||
}
|
||||
|
||||
template <typename TensorShape, typename WindowShape>
|
||||
static CK_TILE_DEVICE auto MakeTensorView2D(PoolKernelArgs<TensorShape, WindowShape> kargs)
|
||||
{
|
||||
using S = typename Problem::BlockShape;
|
||||
|
||||
// Compile-time validation for 2D pooling
|
||||
static_assert(TensorShape::size() == 4, "2D pooling requires 4D input tensor (N,H,W,C)");
|
||||
static_assert(WindowShape::size() == 2, "2D pooling requires 2D window shape (Y,X)");
|
||||
|
||||
// Extract dimension values
|
||||
const index_t N = kargs.input_shape.at(number<0>{});
|
||||
const index_t H = kargs.input_shape.at(number<1>{});
|
||||
const index_t W = kargs.input_shape.at(number<2>{});
|
||||
const index_t C = kargs.input_shape.at(number<3>{});
|
||||
|
||||
const index_t No = kargs.output_shape.at(number<0>{});
|
||||
const index_t Ho = kargs.output_shape.at(number<1>{});
|
||||
const index_t Wo = kargs.output_shape.at(number<2>{});
|
||||
const index_t Co = kargs.output_shape.at(number<3>{});
|
||||
|
||||
const index_t Y = kargs.window_lengths.at(number<0>{});
|
||||
const index_t X = kargs.window_lengths.at(number<1>{});
|
||||
|
||||
const index_t WindowStrideH = kargs.window_strides.at(number<0>{});
|
||||
const index_t WindowStrideW = kargs.window_strides.at(number<1>{});
|
||||
|
||||
const index_t WindowDilationH = kargs.window_dilations.at(number<0>{});
|
||||
const index_t WindowDilationW = kargs.window_dilations.at(number<1>{});
|
||||
|
||||
const index_t InLeftPadH = kargs.input_left_pads.at(number<0>{});
|
||||
const index_t InLeftPadW = kargs.input_left_pads.at(number<1>{});
|
||||
|
||||
const index_t InRightPadH = kargs.input_right_pads.at(number<0>{});
|
||||
const index_t InRightPadW = kargs.input_right_pads.at(number<1>{});
|
||||
|
||||
const index_t MRaw = N * Ho * Wo * C;
|
||||
const index_t KRaw = Y * X;
|
||||
const index_t MPad = integer_least_multiple(MRaw, S::Block_M) - MRaw;
|
||||
const index_t KPad = integer_least_multiple(KRaw, S::Block_N) - KRaw;
|
||||
|
||||
auto reduce_op = typename Problem::ReduceOp{};
|
||||
|
||||
// Create input descriptor with all transformations
|
||||
auto in_desc = make_naive_tensor_descriptor(kargs.input_shape, kargs.input_strides);
|
||||
|
||||
// Apply spatial padding to input descriptor
|
||||
const auto padded_in_desc = transform_tensor_descriptor(
|
||||
in_desc,
|
||||
make_tuple(make_pass_through_transform(N),
|
||||
make_pad_transform(H, InLeftPadH, InRightPadH),
|
||||
make_pad_transform(W, InLeftPadW, InRightPadW),
|
||||
make_pass_through_transform(C)),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}));
|
||||
|
||||
// Create sliding windows by embedding pooling windows into descriptor
|
||||
const auto embed_in_desc = transform_tensor_descriptor(
|
||||
padded_in_desc,
|
||||
make_tuple(
|
||||
make_pass_through_transform(N),
|
||||
make_embed_transform(make_tuple(Y, Ho), make_tuple(WindowDilationH, WindowStrideH)),
|
||||
make_embed_transform(make_tuple(X, Wo), make_tuple(WindowDilationW, WindowStrideW)),
|
||||
make_pass_through_transform(C)),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1, 2>{}, sequence<3, 4>{}, sequence<5>{}));
|
||||
|
||||
// Reshape into 2D matrix: output positions (M) x pooling window elements (K)
|
||||
const auto merged_embed_in_desc =
|
||||
transform_tensor_descriptor(embed_in_desc,
|
||||
make_tuple(make_merge_transform(make_tuple(N, Ho, Wo, C)),
|
||||
make_merge_transform(make_tuple(Y, X))),
|
||||
make_tuple(sequence<0, 2, 4, 5>{}, sequence<1, 3>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
|
||||
const auto in_desc_padded = transform_tensor_descriptor(
|
||||
merged_embed_in_desc,
|
||||
make_tuple(make_right_pad_transform(MRaw, MPad), make_right_pad_transform(KRaw, KPad)),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
|
||||
// Create output descriptor with transformations
|
||||
auto out_desc = make_naive_tensor_descriptor(kargs.output_shape, kargs.output_strides);
|
||||
|
||||
const auto merged_out_desc = transform_tensor_descriptor(
|
||||
out_desc,
|
||||
make_tuple(make_merge_transform(make_tuple(No, Ho, Wo, Co))),
|
||||
make_tuple(sequence<0, 1, 2, 3>{}),
|
||||
make_tuple(sequence<0>{}));
|
||||
|
||||
const auto out_desc_padded =
|
||||
transform_tensor_descriptor(merged_out_desc,
|
||||
make_tuple(make_right_pad_transform(MRaw, MPad)),
|
||||
make_tuple(sequence<0>{}),
|
||||
make_tuple(sequence<0>{}));
|
||||
|
||||
// Now create buffer views and tensor views with the fully transformed descriptors
|
||||
const InDataType in_identity =
|
||||
type_convert<InDataType>(reduce_op.template GetIdentityValue<ComputeDataType>());
|
||||
const OutDataType out_identity =
|
||||
type_convert<OutDataType>(reduce_op.template GetIdentityValue<ComputeDataType>());
|
||||
|
||||
auto in_buffer_view = make_buffer_view<address_space_enum::global>(
|
||||
static_cast<const InDataType*>(kargs.input_ptr),
|
||||
in_desc.get_element_space_size(),
|
||||
in_identity);
|
||||
const auto in_tensor_padded =
|
||||
tensor_view<decltype(in_buffer_view), decltype(in_desc_padded)>{in_buffer_view,
|
||||
in_desc_padded};
|
||||
|
||||
auto out_buffer_view = make_buffer_view<address_space_enum::global>(
|
||||
static_cast<OutDataType*>(kargs.output_ptr),
|
||||
out_desc.get_element_space_size(),
|
||||
out_identity);
|
||||
const auto out_tensor_padded =
|
||||
tensor_view<decltype(out_buffer_view), decltype(out_desc_padded)>{out_buffer_view,
|
||||
out_desc_padded};
|
||||
|
||||
return make_tuple(in_tensor_padded, out_tensor_padded);
|
||||
}
|
||||
|
||||
template <typename TensorShape, typename WindowShape>
|
||||
static CK_TILE_DEVICE auto MakeTensorView3D(PoolKernelArgs<TensorShape, WindowShape> kargs)
|
||||
{
|
||||
using S = typename Problem::BlockShape;
|
||||
|
||||
// Compile-time validation for 3D pooling
|
||||
static_assert(TensorShape::size() == 5, "3D pooling requires 5D input tensor (N,D,H,W,C)");
|
||||
static_assert(WindowShape::size() == 3, "3D pooling requires 3D window shape (Z,Y,X)");
|
||||
|
||||
// Extract dimension values
|
||||
const index_t N = kargs.input_shape.at(number<0>{});
|
||||
const index_t D = kargs.input_shape.at(number<1>{});
|
||||
const index_t H = kargs.input_shape.at(number<2>{});
|
||||
const index_t W = kargs.input_shape.at(number<3>{});
|
||||
const index_t C = kargs.input_shape.at(number<4>{});
|
||||
|
||||
const index_t No = kargs.output_shape.at(number<0>{});
|
||||
const index_t Do = kargs.output_shape.at(number<1>{});
|
||||
const index_t Ho = kargs.output_shape.at(number<2>{});
|
||||
const index_t Wo = kargs.output_shape.at(number<3>{});
|
||||
const index_t Co = kargs.output_shape.at(number<4>{});
|
||||
|
||||
const index_t Z = kargs.window_lengths.at(number<0>{});
|
||||
const index_t Y = kargs.window_lengths.at(number<1>{});
|
||||
const index_t X = kargs.window_lengths.at(number<2>{});
|
||||
|
||||
const index_t WindowStrideD = kargs.window_strides.at(number<0>{});
|
||||
const index_t WindowStrideH = kargs.window_strides.at(number<1>{});
|
||||
const index_t WindowStrideW = kargs.window_strides.at(number<2>{});
|
||||
|
||||
const index_t WindowDilationD = kargs.window_dilations.at(number<0>{});
|
||||
const index_t WindowDilationH = kargs.window_dilations.at(number<1>{});
|
||||
const index_t WindowDilationW = kargs.window_dilations.at(number<2>{});
|
||||
|
||||
const index_t InLeftPadD = kargs.input_left_pads.at(number<0>{});
|
||||
const index_t InLeftPadH = kargs.input_left_pads.at(number<1>{});
|
||||
const index_t InLeftPadW = kargs.input_left_pads.at(number<2>{});
|
||||
|
||||
const index_t InRightPadD = kargs.input_right_pads.at(number<0>{});
|
||||
const index_t InRightPadH = kargs.input_right_pads.at(number<1>{});
|
||||
const index_t InRightPadW = kargs.input_right_pads.at(number<2>{});
|
||||
|
||||
const index_t MRaw = N * Do * Ho * Wo * C;
|
||||
const index_t KRaw = Z * Y * X;
|
||||
const index_t MPad = integer_least_multiple(MRaw, S::Block_M) - MRaw;
|
||||
const index_t KPad = integer_least_multiple(KRaw, S::Block_N) - KRaw;
|
||||
|
||||
auto reduce_op = typename Problem::ReduceOp{};
|
||||
|
||||
// Create input descriptor with all transformations
|
||||
auto in_desc = make_naive_tensor_descriptor(kargs.input_shape, kargs.input_strides);
|
||||
|
||||
// Apply spatial padding to input descriptor (all 3D dimensions)
|
||||
const auto padded_in_desc = transform_tensor_descriptor(
|
||||
in_desc,
|
||||
make_tuple(make_pass_through_transform(N),
|
||||
make_pad_transform(D, InLeftPadD, InRightPadD),
|
||||
make_pad_transform(H, InLeftPadH, InRightPadH),
|
||||
make_pad_transform(W, InLeftPadW, InRightPadW),
|
||||
make_pass_through_transform(C)),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}, sequence<4>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}, sequence<4>{}));
|
||||
|
||||
// Create 3D sliding windows by embedding pooling windows into descriptor
|
||||
const auto embed_in_desc = transform_tensor_descriptor(
|
||||
padded_in_desc,
|
||||
make_tuple(
|
||||
make_pass_through_transform(N),
|
||||
make_embed_transform(make_tuple(Z, Do), make_tuple(WindowDilationD, WindowStrideD)),
|
||||
make_embed_transform(make_tuple(Y, Ho), make_tuple(WindowDilationH, WindowStrideH)),
|
||||
make_embed_transform(make_tuple(X, Wo), make_tuple(WindowDilationW, WindowStrideW)),
|
||||
make_pass_through_transform(C)),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}, sequence<4>{}),
|
||||
make_tuple(sequence<0>{},
|
||||
sequence<1, 2>{},
|
||||
sequence<3, 4>{},
|
||||
sequence<5, 6>{},
|
||||
sequence<7>{}));
|
||||
|
||||
// Reshape into 2D matrix: output positions (M) x pooling window elements (K)
|
||||
const auto merged_embed_in_desc = transform_tensor_descriptor(
|
||||
embed_in_desc,
|
||||
make_tuple(make_merge_transform(make_tuple(N, Do, Ho, Wo, C)),
|
||||
make_merge_transform(make_tuple(Z, Y, X))),
|
||||
make_tuple(sequence<0, 2, 4, 6, 7>{}, sequence<1, 3, 5>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
|
||||
const auto in_desc_padded = transform_tensor_descriptor(
|
||||
merged_embed_in_desc,
|
||||
make_tuple(make_right_pad_transform(MRaw, MPad), make_right_pad_transform(KRaw, KPad)),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
|
||||
// Create output descriptor with transformations
|
||||
auto out_desc = make_naive_tensor_descriptor(kargs.output_shape, kargs.output_strides);
|
||||
|
||||
const auto merged_out_desc = transform_tensor_descriptor(
|
||||
out_desc,
|
||||
make_tuple(make_merge_transform(make_tuple(No, Do, Ho, Wo, Co))),
|
||||
make_tuple(sequence<0, 1, 2, 3, 4>{}),
|
||||
make_tuple(sequence<0>{}));
|
||||
|
||||
const auto out_desc_padded =
|
||||
transform_tensor_descriptor(merged_out_desc,
|
||||
make_tuple(make_right_pad_transform(MRaw, MPad)),
|
||||
make_tuple(sequence<0>{}),
|
||||
make_tuple(sequence<0>{}));
|
||||
|
||||
// Now create buffer views and tensor views with the fully transformed descriptors
|
||||
const InDataType in_identity =
|
||||
type_convert<InDataType>(reduce_op.template GetIdentityValue<ComputeDataType>());
|
||||
const OutDataType out_identity =
|
||||
type_convert<OutDataType>(reduce_op.template GetIdentityValue<ComputeDataType>());
|
||||
|
||||
auto in_buffer_view = make_buffer_view<address_space_enum::global>(
|
||||
static_cast<const InDataType*>(kargs.input_ptr),
|
||||
in_desc.get_element_space_size(),
|
||||
in_identity);
|
||||
const auto in_tensor_padded =
|
||||
tensor_view<decltype(in_buffer_view), decltype(in_desc_padded)>{in_buffer_view,
|
||||
in_desc_padded};
|
||||
|
||||
auto out_buffer_view = make_buffer_view<address_space_enum::global>(
|
||||
static_cast<OutDataType*>(kargs.output_ptr),
|
||||
out_desc.get_element_space_size(),
|
||||
out_identity);
|
||||
const auto out_tensor_padded =
|
||||
tensor_view<decltype(out_buffer_view), decltype(out_desc_padded)>{out_buffer_view,
|
||||
out_desc_padded};
|
||||
|
||||
return make_tuple(in_tensor_padded, out_tensor_padded);
|
||||
}
|
||||
|
||||
public:
|
||||
template <typename TensorShape, typename WindowShape>
|
||||
CK_TILE_DEVICE void operator()(PoolKernelArgs<TensorShape, WindowShape> kargs) const
|
||||
{
|
||||
using S = typename Problem::BlockShape;
|
||||
|
||||
// Compile-time validation for supported window dimensions
|
||||
static_assert(WindowShape::size() == 2 || WindowShape::size() == 3,
|
||||
"Only 2D and 3D pooling operations are supported");
|
||||
|
||||
const auto iM = get_block_id() * S::Block_M;
|
||||
|
||||
// Get tensors based on dimensionality
|
||||
auto [in_tensor_padded, out_tensor_padded] = [&]() {
|
||||
if constexpr(WindowShape::size() == 2)
|
||||
return MakeTensorView2D(kargs);
|
||||
else if constexpr(WindowShape::size() == 3)
|
||||
return MakeTensorView3D(kargs);
|
||||
else
|
||||
static_assert(WindowShape::size() == 2 || WindowShape::size() == 3,
|
||||
"Unsupported WindowShape rank: only 2D or 3D pooling is supported");
|
||||
}();
|
||||
|
||||
auto reduce_op = typename Problem::ReduceOp{};
|
||||
|
||||
auto x_window = make_tile_window(in_tensor_padded,
|
||||
make_tuple(number<S::Block_M>{}, number<S::Block_N>{}),
|
||||
{iM, 0},
|
||||
Policy::template MakeXBlockTileDistribution<Problem>());
|
||||
auto y_window = make_tile_window(out_tensor_padded, make_tuple(number<S::Block_M>{}), {iM});
|
||||
|
||||
__shared__ char smem[Policy::template GetSmemSize<Problem>()];
|
||||
|
||||
const auto reduce_len =
|
||||
in_tensor_padded.get_tensor_descriptor().get_lengths().at(number<1>{});
|
||||
index_t num_k_tiles =
|
||||
__builtin_amdgcn_readfirstlane(integer_divide_ceil(reduce_len, S::Block_N));
|
||||
|
||||
auto block_reduce2d = Policy::template GetBlockReduce2d<Problem>();
|
||||
auto block_reduce2d_sync = Policy::template GetBlockReduce2dSync<Problem>();
|
||||
auto block_reduce2d_cross_warp = Policy::template GetBlockReduce2dCrossWarpSync<Problem>();
|
||||
|
||||
using XTensorTile = decltype(load_tile(x_window));
|
||||
auto y_tile = block_reduce2d.template MakeYBlockTile<XTensorTile>();
|
||||
set_tile(y_tile, reduce_op.template GetIdentityValue<ComputeDataType>());
|
||||
|
||||
for(int k_tile = __builtin_amdgcn_readfirstlane(0); k_tile < num_k_tiles; ++k_tile)
|
||||
{
|
||||
const auto x_tile = load_tile(x_window);
|
||||
block_reduce2d(x_tile, y_tile, reduce_op);
|
||||
move_tile_window(x_window, {0, S::Block_N});
|
||||
}
|
||||
|
||||
block_reduce2d_sync(y_tile, reduce_op);
|
||||
block_reduce2d_cross_warp(y_tile, smem, reduce_op);
|
||||
store_tile(y_window, cast_tile<OutDataType>(y_tile));
|
||||
}
|
||||
|
||||
/// @brief Validates if the given arguments are supported by the pooling kernel.
|
||||
///
|
||||
/// @param kargs The pooling kernel arguments containing all necessary parameters.
|
||||
///
|
||||
/// @return true if the arguments are supported, false otherwise.
|
||||
///
|
||||
/// @note Requirements:
|
||||
/// - Last dimension (C) must be contiguous (stride = 1) for vectorized access
|
||||
/// - Window dimensions must be supported (2D or 3D)
|
||||
/// - All dimension sizes must be consistent between input and output
|
||||
template <typename TensorShape, typename WindowShape>
|
||||
CK_TILE_HOST static bool IsSupportedArgument(PoolKernelArgs<TensorShape, WindowShape> kargs)
|
||||
{
|
||||
constexpr index_t InputRank = TensorShape::size();
|
||||
constexpr index_t OutputRank = TensorShape::size(); // Same as input rank
|
||||
constexpr index_t WindowRank = WindowShape::size();
|
||||
|
||||
// Validate window dimensions (only 2D and 3D supported)
|
||||
if constexpr(WindowRank != 2 && WindowRank != 3)
|
||||
{
|
||||
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
|
||||
{
|
||||
CK_TILE_ERROR("Only 2D and 3D pooling are supported!");
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
// Validate that input rank matches expected rank for window dimensions
|
||||
if constexpr((WindowRank == 2 && InputRank != 4) || (WindowRank == 3 && InputRank != 5))
|
||||
{
|
||||
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
|
||||
{
|
||||
CK_TILE_ERROR("Input tensor rank doesn't match window dimensions!");
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
// Check that channel dimension (last dimension) is contiguous for both input and output
|
||||
if(kargs.input_strides.at(number<InputRank - 1>{}) != 1)
|
||||
{
|
||||
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
|
||||
{
|
||||
CK_TILE_ERROR("Input tensor's channel dimension must have stride 1!");
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
if(kargs.output_strides.at(number<OutputRank - 1>{}) != 1)
|
||||
{
|
||||
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
|
||||
{
|
||||
CK_TILE_ERROR("Output tensor's channel dimension must have stride 1!");
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
/// @param kargs The pooling kernel arguments
|
||||
/// @return The calculated grid size
|
||||
template <typename TensorShape, typename WindowShape>
|
||||
CK_TILE_HOST static constexpr index_t
|
||||
CalculateGridSize(PoolKernelArgs<TensorShape, WindowShape> kargs)
|
||||
{
|
||||
using S = typename Problem::BlockShape;
|
||||
|
||||
// Calculate total output elements (M dimension)
|
||||
index_t M = 1;
|
||||
static_for<0, TensorShape::size(), 1>{}([&](auto i) { M *= kargs.output_shape.at(i); });
|
||||
|
||||
// Calculate grid size: ceil(M / Block_M)
|
||||
return (M + S::Block_M - 1) / S::Block_M;
|
||||
}
|
||||
|
||||
/// @brief Create kernel arguments from host arguments
|
||||
template <typename TensorShape, typename WindowShape>
|
||||
CK_TILE_HOST static constexpr auto
|
||||
MakeKernelArgs(PoolHostArgs<TensorShape, WindowShape>& host_args)
|
||||
{
|
||||
return PoolKernelArgs<TensorShape, WindowShape>{host_args.input_ptr,
|
||||
host_args.output_ptr,
|
||||
host_args.input_shape,
|
||||
host_args.output_shape,
|
||||
host_args.input_strides,
|
||||
host_args.output_strides,
|
||||
host_args.window_lengths,
|
||||
host_args.window_strides,
|
||||
host_args.window_dilations,
|
||||
host_args.input_left_pads,
|
||||
host_args.input_right_pads};
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
Reference in New Issue
Block a user