diff --git a/example/ck_tile/36_pooling/pool3d.cpp b/example/ck_tile/36_pooling/pool3d.cpp index bb76efbc03..092020c4ae 100644 --- a/example/ck_tile/36_pooling/pool3d.cpp +++ b/example/ck_tile/36_pooling/pool3d.cpp @@ -38,7 +38,10 @@ auto create_args(int argc, char* argv[]) return std::make_tuple(result, arg_parser); } -template +template bool run(const ck_tile::ArgParser& arg_parser) { @@ -84,6 +87,9 @@ bool run(const ck_tile::ArgParser& arg_parser) int warmup = arg_parser.get_int("warmup"); int repeat = arg_parser.get_int("repeat"); + constexpr bool OutputIndex = true; + constexpr bool PropagateNan = false; + // Shapes / strides / parameters (NDHWC) const auto input_shape = ck_tile::make_tuple(N, D, H, W, C); const auto output_shape = ck_tile::make_tuple(N, Do, Ho, Wo, C); @@ -100,11 +106,16 @@ bool run(const ck_tile::ArgParser& arg_parser) {Do * Ho * Wo * C, Ho * Wo * C, Wo * C, C, 1}); ck_tile::HostTensor out_ref({N, Do, Ho, Wo, C}, {Do * Ho * Wo * C, Ho * Wo * C, Wo * C, C, 1}); + ck_tile::HostTensor out_index({N, Do, Ho, Wo, C}, + {Do * Ho * Wo * C, Ho * Wo * C, Wo * C, C, 1}); + ck_tile::HostTensor out_ref_index({N, Do, Ho, Wo, C}, + {Do * Ho * Wo * C, Ho * Wo * C, Wo * C, C, 1}); ck_tile::FillUniformDistribution{-5.f, 5.f}(in); ck_tile::DeviceMem in_buf(in.get_element_space_size_in_bytes()); ck_tile::DeviceMem out_buf(out.get_element_space_size_in_bytes()); + ck_tile::DeviceMem out_index_buf(OutputIndex ? out_index.get_element_space_size_in_bytes() : 0); in_buf.ToDevice(in.data()); @@ -118,10 +129,10 @@ bool run(const ck_tile::ArgParser& arg_parser) using Problem = ck_tile::PoolProblem; using Kernel = ck_tile::PoolKernel; @@ -131,6 +142,7 @@ bool run(const ck_tile::ArgParser& arg_parser) auto host_args = ck_tile::PoolHostArgs{ static_cast(in_buf.GetDeviceBuffer()), static_cast(out_buf.GetDeviceBuffer()), + OutputIndex ? static_cast(out_index_buf.GetDeviceBuffer()) : nullptr, input_shape, output_shape, input_strides, @@ -167,12 +179,28 @@ bool run(const ck_tile::ArgParser& arg_parser) if(do_validation) { - ck_tile::reference_pool3d( - in, out_ref, kernel_args, ReduceOp{}); out_buf.FromDevice(out.mData.data()); - pass = ck_tile::check_err(out, out_ref); - std::cout << "valid:" << (pass ? "y" : "n") << std::flush << std::endl; + ck_tile::reference_pool3d(in, out_ref, out_ref_index, kernel_args, ReduceOp{}); + + if constexpr(OutputIndex) + { + out_index_buf.FromDevice(out_index.mData.data()); + pass = ck_tile::check_err(out, out_ref) && ck_tile::check_err(out_index, out_ref_index); + } + else + { + pass = ck_tile::check_err(out, out_ref); + } + + std::cout << "valid:" << (pass ? "y" : "n") << std::endl; } return pass; @@ -184,5 +212,5 @@ int main(int argc, char* argv[]) if(!result) return -1; - return run(arg_parser) ? 0 : -2; + return run(arg_parser) ? 0 : -2; } diff --git a/include/ck_tile/core.hpp b/include/ck_tile/core.hpp index 9f3c996873..92d19cf619 100644 --- a/include/ck_tile/core.hpp +++ b/include/ck_tile/core.hpp @@ -78,6 +78,7 @@ #include "ck_tile/core/utility/print.hpp" #include "ck_tile/core/utility/random.hpp" #include "ck_tile/core/utility/reduce_operator.hpp" +#include "ck_tile/core/utility/reduce_operator_accumulate.hpp" #include "ck_tile/core/utility/static_counter.hpp" #include "ck_tile/core/utility/to_sequence.hpp" #include "ck_tile/core/utility/transpose_vectors.hpp" diff --git a/include/ck_tile/core/utility/reduce_operator.hpp b/include/ck_tile/core/utility/reduce_operator.hpp index 218606f303..69449711e0 100644 --- a/include/ck_tile/core/utility/reduce_operator.hpp +++ b/include/ck_tile/core/utility/reduce_operator.hpp @@ -4,6 +4,7 @@ #pragma once #include "ck_tile/core/config.hpp" +#include "ck_tile/core/utility/type_traits.hpp" namespace ck_tile { @@ -18,16 +19,14 @@ struct Add }; template || std::is_same_v || - std::is_same_v || std::is_same_v>> + typename = std::enable_if_t::value>> CK_TILE_HOST_DEVICE constexpr T operator()(const T& y, const T x) const { return y + x; } template || std::is_same_v || - std::is_same_v || std::is_same_v>> + typename = std::enable_if_t::value>> CK_TILE_HOST_DEVICE constexpr T operator()(T& y, T x) const { float y_ = type_convert(y); @@ -46,16 +45,14 @@ struct SquareAdd }; template || std::is_same_v || - std::is_same_v || std::is_same_v>> + typename = std::enable_if_t::value>> CK_TILE_HOST_DEVICE constexpr T operator()(const T& y, const T x) const { return y + (x * x); } template || std::is_same_v || - std::is_same_v || std::is_same_v>> + typename = std::enable_if_t::value>> CK_TILE_HOST_DEVICE constexpr T operator()(T& y, T x) const { float y_ = type_convert(y); @@ -66,48 +63,74 @@ struct SquareAdd struct Max { - template || std::is_same_v || - std::is_same_v || std::is_same_v || - std::is_same_v || std::is_same_v || - std::is_same_v || std::is_same_v>> + template < + typename T, + typename = std::enable_if_t< + is_any_of::value>> CK_TILE_HOST_DEVICE static constexpr T GetIdentityValue() { return numeric::lowest(); }; - template || std::is_same_v || - std::is_same_v || std::is_same_v || - std::is_same_v || std::is_same_v || - std::is_same_v || std::is_same_v>> + template < + typename T, + typename = std::enable_if_t< + is_any_of::value>> CK_TILE_HOST_DEVICE constexpr T operator()(const T& y, const T x) const { return max(y, x); } + + // Overload with changed flag for index tracking + template < + typename T, + typename = std::enable_if_t< + is_any_of::value>> + CK_TILE_HOST_DEVICE constexpr T operator()(const T& y, const T x, bool& changed) const + { + T new_max = max(y, x); + if(x > y) + { + changed = true; + } + return new_max; + } }; struct AbsMax { - template || std::is_same_v || - std::is_same_v || std::is_same_v || - std::is_same_v || std::is_same_v || - std::is_same_v || std::is_same_v>> + template < + typename T, + typename = std::enable_if_t< + is_any_of::value>> CK_TILE_HOST_DEVICE static constexpr T GetIdentityValue() { return numeric::zero(); }; - template || std::is_same_v || - std::is_same_v || std::is_same_v || - std::is_same_v || std::is_same_v || - std::is_same_v || std::is_same_v>> + template < + typename T, + typename = std::enable_if_t< + is_any_of::value>> CK_TILE_HOST_DEVICE constexpr T operator()(const T& y, const T x) const { return max(y, abs(x)); } + + // Overload with changed flag for index tracking + template < + typename T, + typename = std::enable_if_t< + is_any_of::value>> + CK_TILE_HOST_DEVICE constexpr T operator()(const T& y, const T x, bool& changed) const + { + T new_max = max(y, abs(x)); + if(abs(x) > y) + { + changed = true; + } + return new_max; + } }; } // namespace ReduceOp diff --git a/include/ck_tile/core/utility/reduce_operator_accumulate.hpp b/include/ck_tile/core/utility/reduce_operator_accumulate.hpp new file mode 100644 index 0000000000..b49ff41ee0 --- /dev/null +++ b/include/ck_tile/core/utility/reduce_operator_accumulate.hpp @@ -0,0 +1,50 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core/config.hpp" + +namespace ck_tile { + +/// @brief Accumulate with index tracking reductions, provides deterministic first occurring index +struct AccumulateWithIndex +{ + template + CK_TILE_HOST_DEVICE void operator()(const ReduceOp& reduce_func, + T& current_value, + IndexType& current_index, + const T& new_value, + const IndexType& new_index) const + { + bool changed = false; + current_value = reduce_func(current_value, new_value, changed); + + if(changed) + { + current_index = new_index; + } + else if(new_index < current_index) + { + bool reverse_changed = false; + reduce_func(new_value, current_value, reverse_changed); + + if(!reverse_changed) + { + current_index = new_index; + } + } + } +}; + +struct Accumulate +{ + template + CK_TILE_HOST_DEVICE void + operator()(const ReduceOp& reduce_func, T& current_value, const T& new_value) const + { + current_value = reduce_func(current_value, new_value); + } +}; + +} // namespace ck_tile diff --git a/include/ck_tile/host/reference/reference_pool.hpp b/include/ck_tile/host/reference/reference_pool.hpp index 4fdb5fed78..7a2848def5 100644 --- a/include/ck_tile/host/reference/reference_pool.hpp +++ b/include/ck_tile/host/reference/reference_pool.hpp @@ -7,17 +7,21 @@ #include "ck_tile/host/host_tensor.hpp" #include "ck_tile/ops/pooling/kernel/pool_kernel.hpp" #include +#include namespace ck_tile { template + typename WindowShape, + bool OutputIndex = false> CK_TILE_HOST void reference_pool2d(const HostTensor& input, HostTensor& output, + HostTensor& output_index, PoolKernelArgs kargs, ReduceOp reduce_op) { @@ -45,6 +49,8 @@ CK_TILE_HOST void reference_pool2d(const HostTensor& input, auto f = [&](auto n, auto ho, auto wo, auto c) { ComputeDataType v_acc = reduce_op.template GetIdentityValue(); + IndexDataType current_index = 0; // Declare outside if constexpr for efficiency + for(ck_tile::index_t y = 0; y < Y; ++y) { // Calculate input height index with stride, dilation, and padding @@ -58,13 +64,32 @@ CK_TILE_HOST void reference_pool2d(const HostTensor& input, if(hi >= 0 && hi < H && wi >= 0 && wi < W) { const ComputeDataType v_in = type_convert(input(n, hi, wi, c)); - v_acc = reduce_op(v_acc, v_in); + + if constexpr(OutputIndex) + { + IndexDataType flat_index = input.GetOffsetFromMultiIndex(n, hi, wi, c); + bool changed = false; + v_acc = reduce_op(v_acc, v_in, changed); + if(changed) + { + current_index = flat_index; + } + } + else + { + v_acc = reduce_op(v_acc, v_in); + } } // For positions outside bounds, we implicitly use identity value } } output(n, ho, wo, c) = ck_tile::type_convert(v_acc); + + if constexpr(OutputIndex) + { + output_index(n, ho, wo, c) = current_index; + } }; // Parallelize over all output dimensions @@ -74,11 +99,14 @@ CK_TILE_HOST void reference_pool2d(const HostTensor& input, template + typename WindowShape, + bool OutputIndex = false> CK_TILE_HOST void reference_pool3d(const HostTensor& input, HostTensor& output, + HostTensor& output_index, PoolKernelArgs kargs, ReduceOp reduce_op) { @@ -112,6 +140,8 @@ CK_TILE_HOST void reference_pool3d(const HostTensor& input, auto f = [&](auto n, auto do_, auto ho, auto wo, auto c) { ComputeDataType v_acc = reduce_op.template GetIdentityValue(); + IndexDataType current_index = 0; // Declare outside if constexpr for efficiency + for(ck_tile::index_t z = 0; z < Z; ++z) { // Calculate input depth index with stride, dilation, and padding @@ -131,7 +161,22 @@ CK_TILE_HOST void reference_pool3d(const HostTensor& input, { const ComputeDataType v_in = type_convert(input(n, di, hi, wi, c)); - v_acc = reduce_op(v_acc, v_in); + + if constexpr(OutputIndex) + { + IndexDataType flat_index = + input.GetOffsetFromMultiIndex(n, di, hi, wi, c); + bool changed = false; + v_acc = reduce_op(v_acc, v_in, changed); + if(changed) + { + current_index = flat_index; + } + } + else + { + v_acc = reduce_op(v_acc, v_in); + } } // For positions outside bounds, we implicitly use identity value } @@ -139,10 +184,15 @@ CK_TILE_HOST void reference_pool3d(const HostTensor& input, } output(n, do_, ho, wo, c) = ck_tile::type_convert(v_acc); + + if constexpr(OutputIndex) + { + + output_index(n, do_, ho, wo, c) = current_index; + } }; // Parallelize over all output dimensions make_ParallelTensorFunctor(f, N, Do, Ho, Wo, C)(std::thread::hardware_concurrency()); } - } // namespace ck_tile diff --git a/include/ck_tile/ops/pooling/kernel/pool_kernel.hpp b/include/ck_tile/ops/pooling/kernel/pool_kernel.hpp index 93567e7161..b91fe514e8 100644 --- a/include/ck_tile/ops/pooling/kernel/pool_kernel.hpp +++ b/include/ck_tile/ops/pooling/kernel/pool_kernel.hpp @@ -17,6 +17,7 @@ struct PoolHostArgs CK_TILE_HOST PoolHostArgs(const void* input_ptr_, void* output_ptr_, + void* output_index_ptr_, TensorShape input_shape_, TensorShape output_shape_, TensorShape input_strides_, @@ -28,6 +29,7 @@ struct PoolHostArgs WindowShape input_right_pads_) : input_ptr(input_ptr_), output_ptr(output_ptr_), + output_index_ptr(output_index_ptr_), input_shape(input_shape_), output_shape(output_shape_), input_strides(input_strides_), @@ -42,6 +44,7 @@ struct PoolHostArgs const void* input_ptr; void* output_ptr; + void* output_index_ptr; TensorShape input_shape; TensorShape output_shape; @@ -60,6 +63,7 @@ struct PoolKernelArgs { const void* input_ptr; void* output_ptr; + void* output_index_ptr; TensorShape input_shape; TensorShape output_shape; TensorShape input_strides; @@ -80,6 +84,7 @@ struct PoolKernel using InDataType = ck_tile::remove_cvref_t; using ComputeDataType = ck_tile::remove_cvref_t; using OutDataType = ck_tile::remove_cvref_t; + using IndexDataType = ck_tile::remove_cvref_t; static constexpr index_t kBlockSize = Problem::BlockShape::BlockSize; @@ -205,7 +210,23 @@ struct PoolKernel tensor_view{out_buffer_view, out_desc_padded}; - return make_tuple(in_tensor_padded, out_tensor_padded); + if constexpr(Problem::kOutputIndex) + { + auto out_index_buffer_view = make_buffer_view( + static_cast(kargs.output_index_ptr), + out_desc.get_element_space_size(), + IndexDataType(-1)); + const auto out_index_tensor_padded = + tensor_view{ + out_index_buffer_view, out_desc_padded}; + + return make_tuple(in_tensor_padded, out_tensor_padded, out_index_tensor_padded); + } + else + { + // Return a dummy tensor for the third element when index output is not needed + return make_tuple(in_tensor_padded, out_tensor_padded, null_tensor{}); + } } template @@ -338,7 +359,23 @@ struct PoolKernel tensor_view{out_buffer_view, out_desc_padded}; - return make_tuple(in_tensor_padded, out_tensor_padded); + if constexpr(Problem::kOutputIndex) + { + auto out_index_buffer_view = make_buffer_view( + static_cast(kargs.output_index_ptr), + out_desc.get_element_space_size(), + IndexDataType(-1)); + const auto out_index_tensor_padded = + tensor_view{ + out_index_buffer_view, out_desc_padded}; + + return make_tuple(in_tensor_padded, out_tensor_padded, out_index_tensor_padded); + } + else + { + // Return a dummy tensor for the third element when index output is not needed + return make_tuple(in_tensor_padded, out_tensor_padded, null_tensor{}); + } } public: @@ -354,7 +391,7 @@ struct PoolKernel const auto iM = get_block_id() * S::Block_M; // Get tensors based on dimensionality - auto [in_tensor_padded, out_tensor_padded] = [&]() { + auto [in_tensor_padded, out_tensor_padded, out_index_tensor_padded] = [&]() { if constexpr(WindowShape::size() == 2) return MakeTensorView2D(kargs); else if constexpr(WindowShape::size() == 3) @@ -387,16 +424,57 @@ struct PoolKernel auto y_tile = block_reduce2d.template MakeYBlockTile(); set_tile(y_tile, reduce_op.template GetIdentityValue()); - for(int k_tile = __builtin_amdgcn_readfirstlane(0); k_tile < num_k_tiles; ++k_tile) + if constexpr(Problem::kOutputIndex) { - const auto x_tile = load_tile(x_window); - block_reduce2d(x_tile, y_tile, reduce_op); - move_tile_window(x_window, {0, S::Block_N}); - } + auto y_index_window = + make_tile_window(out_index_tensor_padded, make_tuple(number{}), {iM}); - block_reduce2d_sync(y_tile, reduce_op); - block_reduce2d_cross_warp(y_tile, smem, reduce_op); - store_tile(y_window, cast_tile(y_tile)); + auto y_index_tile = + block_reduce2d.template MakeYIndexBlockTile(); + set_tile(y_index_tile, IndexDataType(0)); + + // Main reduction loop - with index tracking + for(int k_tile = amd_wave_read_first_lane(0); k_tile < num_k_tiles; ++k_tile) + { + const auto x_tile = load_tile(x_window); + auto index_calculator = [&](const auto& x_indices) { + // Get global coordinates in the 2D matrix space (M, N) + const auto global_M = x_indices.at(number<0>{}) + iM; + const auto global_N = (k_tile * S::Block_N) + x_indices.at(number<1>{}); + return in_tensor_padded.get_tensor_descriptor().calculate_offset( + make_tuple(global_M, global_N)); + }; + + block_reduce2d(x_tile, y_tile, y_index_tile, reduce_op, index_calculator); + move_tile_window(x_window, {0, S::Block_N}); + } + + block_reduce2d_sync(y_tile, y_index_tile, reduce_op); + if constexpr(Problem::kNeedCrossWarpSync) + { + __shared__ char smem_indices[Policy::template GetIndicesSmemSize()]; + + block_reduce2d_cross_warp(y_tile, y_index_tile, smem, smem_indices, reduce_op); + } + + store_tile(y_window, cast_tile(y_tile)); + store_tile(y_index_window, cast_tile(y_index_tile)); + } + else + { + // Main reduction loop - without index tracking + for(int k_tile = __builtin_amdgcn_readfirstlane(0); k_tile < num_k_tiles; ++k_tile) + { + const auto x_tile = load_tile(x_window); + block_reduce2d(x_tile, y_tile, reduce_op); + move_tile_window(x_window, {0, S::Block_N}); + } + + block_reduce2d_sync(y_tile, reduce_op); + block_reduce2d_cross_warp(y_tile, smem, reduce_op); + + store_tile(y_window, cast_tile(y_tile)); + } } /// @brief Validates if the given arguments are supported by the pooling kernel. @@ -481,6 +559,7 @@ struct PoolKernel { return PoolKernelArgs{host_args.input_ptr, host_args.output_ptr, + host_args.output_index_ptr, host_args.input_shape, host_args.output_shape, host_args.input_strides, diff --git a/include/ck_tile/ops/pooling/pipeline/pool_default_policy.hpp b/include/ck_tile/ops/pooling/pipeline/pool_default_policy.hpp index a5b5fac63d..e08cc42e58 100644 --- a/include/ck_tile/ops/pooling/pipeline/pool_default_policy.hpp +++ b/include/ck_tile/ops/pooling/pipeline/pool_default_policy.hpp @@ -32,7 +32,8 @@ struct PoolDefaultPolicy { using P_ = BlockReduce2dProblem; + typename Problem::BlockShape, + Problem::kOutputIndex>; return BlockReduce2d{}; } @@ -41,7 +42,8 @@ struct PoolDefaultPolicy { using P_ = BlockReduce2dProblem; + typename Problem::BlockShape, + Problem::kOutputIndex>; return BlockReduce2dSync{}; } @@ -50,7 +52,8 @@ struct PoolDefaultPolicy { using P_ = BlockReduce2dProblem; + typename Problem::BlockShape, + Problem::kOutputIndex>; return BlockReduce2dCrossWarpSync{}; } @@ -61,7 +64,8 @@ struct PoolDefaultPolicy { using P_ = BlockReduce2dProblem; + typename Problem::BlockShape, + Problem::kOutputIndex>; using block_reduce2d = BlockReduce2d; using x_block_tile = @@ -76,5 +80,25 @@ struct PoolDefaultPolicy return 1; // zero size arrays are an extension } } + + template + CK_TILE_HOST_DEVICE static constexpr index_t GetIndicesSmemSize() + { + + using P_ = BlockReduce2dProblem; + + using block_reduce2d = BlockReduce2d; + using x_block_tile = decltype(make_static_distributed_tensor( + MakeXBlockTileDistribution())); + using y_index_block_tile = decltype(block_reduce2d::template MakeYIndexBlockTile< + x_block_tile, + typename Problem::IndexDataType>()); + + return GetBlockReduce2dCrossWarpSync() + .template GetIndicesSmemSize(); + } }; } // namespace ck_tile diff --git a/include/ck_tile/ops/pooling/pipeline/pool_problem.hpp b/include/ck_tile/ops/pooling/pipeline/pool_problem.hpp index 83a43318bc..53071b1772 100644 --- a/include/ck_tile/ops/pooling/pipeline/pool_problem.hpp +++ b/include/ck_tile/ops/pooling/pipeline/pool_problem.hpp @@ -26,6 +26,8 @@ struct PoolProblem using OutputIndex = bool_constant; using PropagateNan = bool_constant; + static constexpr bool kOutputIndex = OutputIndex_; + static constexpr bool kPropagateNan = PropagateNan_; static constexpr bool kNeedCrossLaneSync = BlockShape::ThreadPerWarp_N > 1; static constexpr bool kNeedCrossWarpSync = BlockShape::WarpPerBlock_N > 1; }; diff --git a/include/ck_tile/ops/reduce/block/block_reduce2d.hpp b/include/ck_tile/ops/reduce/block/block_reduce2d.hpp index 9cddb0abf2..c666608bfd 100644 --- a/include/ck_tile/ops/reduce/block/block_reduce2d.hpp +++ b/include/ck_tile/ops/reduce/block/block_reduce2d.hpp @@ -4,6 +4,7 @@ #pragma once #include "ck_tile/core.hpp" +#include "ck_tile/core/utility/reduce_operator_accumulate.hpp" namespace ck_tile { @@ -50,6 +51,53 @@ struct BlockReduce2d CK_TILE_DEVICE constexpr BlockReduce2d() {} + private: + template + CK_TILE_DEVICE void reduce_impl(const XDistributedTensor_& x_tensor, + YDistributedTensor_& y_tensor, + YIndexDistributedTensor_& y_index_tensor, + const ReduceFunc& reduce_func, + const IndexCalculatorFunc& index_calculator, + ReducePacksPerXDim) + { + sweep_tile( + [&](auto... idx_) { + constexpr auto idx_0 = make_tuple(make_tuple(idx_[number<0>{}]...)[number<0>{}]); + + (..., [&](auto idx) { + auto val = ck_tile::type_convert(x_tensor[idx]); + + if constexpr(kProcessIndex) + { + + const auto x_indices = get_x_indices_from_distributed_indices( + XDistributedTensor_::get_tile_distribution(), idx); + const auto new_idx = index_calculator(x_indices); + auto current_idx = y_index_tensor(idx_0); + + AccumulateWithIndex{}( + reduce_func, y_tensor(idx_0), current_idx, val, new_idx); + + y_index_tensor(idx_0) = + type_convert(current_idx); + } + else + { + Accumulate{}(reduce_func, y_tensor(idx_0), val); + } + }(idx_)); + }, + ReducePacksPerXDim{}); + } + + public: + // Overload for non-index tracking template < typename XDistributedTensor_, typename YDistributedTensor_, @@ -61,13 +109,36 @@ struct BlockReduce2d const ReduceFunc& reduce_func, ReducePacksPerXDim = {}) { - sweep_tile( - [&](auto... idx_) { - constexpr auto idx_0 = make_tuple(make_tuple(idx_[number<0>{}]...)[number<0>{}]); - y_tensor(idx_0) = reduce_func( - y_tensor(idx_0), ck_tile::type_convert(x_tensor[idx_])...); - }, + reduce_impl( + x_tensor, + y_tensor, + y_tensor, // dummy + reduce_func, + [](auto) { return 0; }, // dummy ReducePacksPerXDim{}); + } + + // Overload for index tracking + template > + CK_TILE_DEVICE void operator()(const XDistributedTensor_& x_tensor, + YDistributedTensor_& y_tensor, + YIndexDistributedTensor_& y_index_tensor, + const ReduceFunc& reduce_func, + const IndexCalculatorFunc& index_calculator, + ReducePacksPerXDim = {}) + { + reduce_impl(x_tensor, + y_tensor, + y_index_tensor, + reduce_func, + index_calculator, + ReducePacksPerXDim{}); + } #if 0 constexpr auto I0 = number<0>{}; @@ -90,7 +161,6 @@ struct BlockReduce2d y_tensor(y_dstr_idx) = y; }); #endif - } template CK_TILE_DEVICE static auto MakeYBlockTile() @@ -111,6 +181,25 @@ struct BlockReduce2d return tensor; } + template + CK_TILE_DEVICE static auto MakeYIndexBlockTile() + { + static_assert(std::is_same_v, "wrong!"); + + // FIXME: hard coded to reduce 2nd axis + constexpr auto reduce_dims = sequence<1>{}; + + constexpr auto dstr = + make_static_tile_distribution(detail::make_reduce_tile_distribution_encoding( + XDistributedTensor_::get_tile_distribution() + .get_static_tile_distribution_encoding(), + reduce_dims)); + + auto tensor = make_static_distributed_tensor(dstr); + + return tensor; + } + // uniform_sequence_gen_t generates sequence of NSize elements filled with Value // e.g., uniform_sequence_gen_t<2, 1> → {1, 1} and uniform_sequence_gen_t<3, 4> → {4, 4, 4} template ; - template - CK_TILE_DEVICE void operator()(YDistributedTensor_& y_tensor, const ReduceFunc& reduce_func) + private: + template + CK_TILE_DEVICE void reduce_impl(YDistributedTensor_& y_tensor, + YIndexDistributedTensor_& y_index_tensor, + const ReduceFunc& reduce_func) { using Dstr = typename YDistributedTensor_::StaticTileDistribution; using DstrEncode = typename Dstr::DstrEncode; @@ -157,6 +252,14 @@ struct BlockReduce2dSync static_for<0, thread_buf_size, 1>{}([&](auto i) { auto v_local = y_tensor.get_thread_buffer()[i]; + using IndexDataType = typename YIndexDistributedTensor_::DataType; + IndexDataType idx_local{}; + + if constexpr(kProcessIndex) + { + idx_local = y_index_tensor.get_thread_buffer()[i]; + } + // cross-lane reduce for replication // only reduce on R dimension correspond to lane // (lane id maps to this R dimension) @@ -183,15 +286,46 @@ struct BlockReduce2dSync // pull data from remote lane const auto v_remote = warp_shuffle(v_local, src_lane); - v_local = reduce_func(v_local, v_remote); + + if constexpr(kProcessIndex) + { + const auto idx_remote = warp_shuffle(idx_local, src_lane); + + AccumulateWithIndex{}( + reduce_func, v_local, idx_local, v_remote, idx_remote); + } + else + { + Accumulate{}(reduce_func, v_local, v_remote); + } }); } }); // TODO - Do we need to broadcast to other lane? y_tensor.get_thread_buffer()(i) = v_local; + + if constexpr(kProcessIndex) + { + y_index_tensor.get_thread_buffer()(i) = idx_local; + } }); } + + public: + template + CK_TILE_DEVICE void operator()(YDistributedTensor_& y_tensor, const ReduceFunc& reduce_func) + { + reduce_impl(y_tensor, y_tensor, reduce_func); + } + + template + CK_TILE_DEVICE void operator()(YDistributedTensor_& y_tensor, + YIndexDistributedTensor_& y_index_tensor, + const ReduceFunc& reduce_func) + { + reduce_impl(y_tensor, y_index_tensor, reduce_func); + } }; // BlockReduce2dCrossWarpSync: Cross-warp reduction (Stage 3) @@ -250,15 +384,39 @@ struct BlockReduce2dCrossWarpSync return num_warps * thread_buf_size * sizeof(DataType); } - template - CK_TILE_DEVICE void - operator()(YDistributedTensor_& y_tensor, void* smem, const ReduceFunc& reduce_func) + // return in byte - separate shared memory size calculation for indices + template + CK_TILE_HOST_DEVICE static constexpr index_t GetIndicesSmemSize() { - using DataType = typename YDistributedTensor_::DataType; + using IndexDataType = typename YIndexDistributedTensor_::DataType; + constexpr index_t thread_buf_size = YIndexDistributedTensor_::get_thread_buffer_size(); + constexpr index_t num_warps = BlockShape::BlockSize / get_warp_size(); + return num_warps * thread_buf_size * sizeof(IndexDataType); + } + + private: + template + CK_TILE_DEVICE void reduce_impl(YDistributedTensor_& y_tensor, + YIndexDistributedTensor_& y_index_tensor, + void* smem, + void* smem_indices_ptr, + const ReduceFunc& reduce_func) + { + using DataType = typename YDistributedTensor_::DataType; + using IndexDataType = typename YIndexDistributedTensor_::DataType; constexpr index_t thread_buf_size = YDistributedTensor_::get_thread_buffer_size(); - DataType* smem_ptr = reinterpret_cast(smem); + DataType* smem_ptr = reinterpret_cast(smem); + IndexDataType* smem_indices = nullptr; + if constexpr(kProcessIndex) + { + smem_indices = reinterpret_cast(smem_indices_ptr); + } + const index_t lane_id = get_lane_id(); const index_t warp_id = get_warp_id(); @@ -275,6 +433,11 @@ struct BlockReduce2dCrossWarpSync static_for<0, thread_buf_size, 1>{}([&](auto i) { // Store the i-th element of this warp's thread_buffer into SMEM smem_ptr[smem_offset + i * num_warps] = y_tensor.get_thread_buffer()[i]; + if constexpr(kProcessIndex) + { + smem_indices[smem_offset + i * num_warps] = + y_index_tensor.get_thread_buffer()[i]; + } }); } block_sync_lds(); @@ -282,10 +445,19 @@ struct BlockReduce2dCrossWarpSync // We let each warp holds a duplication to do reduction. const index_t local_warp_id = warp_id / num_reduce_warps; const index_t local_smem_os = local_warp_id * num_reduce_warps; + static_for<0, thread_buf_size, 1>{}([&](auto i) { DataType v[num_reduce_warps]; - static_for<0, num_reduce_warps, 1>{}( - [&](auto idx) { v[idx] = smem_ptr[i * num_warps + local_smem_os + idx]; }); + [[maybe_unused]] std:: + conditional_t idx_v; + + static_for<0, num_reduce_warps, 1>{}([&](auto idx) { + v[idx] = smem_ptr[i * num_warps + local_smem_os + idx]; + if constexpr(kProcessIndex) + { + idx_v[idx] = smem_indices[i * num_warps + local_smem_os + idx]; + } + }); static_assert(is_power_of_two_integer(num_reduce_warps), "wrong! only support power of 2 reduction"); @@ -299,14 +471,44 @@ struct BlockReduce2dCrossWarpSync constexpr index_t i1 = idx_ + stride; if constexpr(i1 < num_reduce_warps) { - v[i0] = reduce_func(v[i0], v[i1]); + if constexpr(kProcessIndex) + { + AccumulateWithIndex{}(reduce_func, v[i0], idx_v[i0], v[i1], idx_v[i1]); + } + else + { + Accumulate{}(reduce_func, v[i0], v[i1]); + } } }); }); y_tensor.get_thread_buffer()(i) = v[0]; + if constexpr(kProcessIndex) + { + y_index_tensor.get_thread_buffer()(i) = idx_v[0]; + } }); } + + public: + template + CK_TILE_DEVICE void + operator()(YDistributedTensor_& y_tensor, void* smem, const ReduceFunc& reduce_func) + { + reduce_impl(y_tensor, y_tensor, smem, nullptr, reduce_func); + } + + template + CK_TILE_DEVICE void operator()(YDistributedTensor_& y_tensor, + YIndexDistributedTensor_& y_index_tensor, + void* smem, + void* smem_indices, + const ReduceFunc& reduce_func) + { + reduce_impl( + y_tensor, y_index_tensor, smem, smem_indices, reduce_func); + } }; template @@ -364,15 +566,39 @@ struct BlockReduce2dLinearCrossWarpSync return num_warps * thread_buf_size * sizeof(DataType); } - template - CK_TILE_DEVICE void - operator()(YDistributedTensor_& y_tensor, void* smem, const ReduceFunc& reduce_func) + // return in byte - separate shared memory size calculation for indices + template + CK_TILE_HOST_DEVICE static constexpr index_t GetIndicesSmemSize() { - using DataType = typename YDistributedTensor_::DataType; + using IndexDataType = typename YIndexDistributedTensor_::DataType; + constexpr index_t thread_buf_size = YIndexDistributedTensor_::get_thread_buffer_size(); + constexpr index_t num_warps = BlockShape::BlockSize / get_warp_size(); + return num_warps * thread_buf_size * sizeof(IndexDataType); + } + + private: + template + CK_TILE_DEVICE void reduce_impl(YDistributedTensor_& y_tensor, + YIndexDistributedTensor_& y_index_tensor, + void* smem, + void* smem_indices_ptr, + const ReduceFunc& reduce_func) + { + using DataType = typename YDistributedTensor_::DataType; + using IndexDataType = typename YIndexDistributedTensor_::DataType; constexpr index_t thread_buf_size = YDistributedTensor_::get_thread_buffer_size(); - DataType* smem_ptr = reinterpret_cast(smem); + DataType* smem_ptr = reinterpret_cast(smem); + IndexDataType* smem_indices = nullptr; + if constexpr(kProcessIndex) + { + smem_indices = reinterpret_cast(smem_indices_ptr); + } + const index_t lane_id = get_lane_id(); const index_t warp_id = get_warp_id(); constexpr auto num_reduce_warps = GetReduceWarps(); @@ -388,6 +614,11 @@ struct BlockReduce2dLinearCrossWarpSync { static_for<0, thread_buf_size, 1>{}([&](auto i) { smem_ptr[smem_offset + i * num_warps] = y_tensor.get_thread_buffer()[i]; + if constexpr(kProcessIndex) + { + smem_indices[smem_offset + i * num_warps] = + y_index_tensor.get_thread_buffer()[i]; + } }); } block_sync_lds(); @@ -395,31 +626,86 @@ struct BlockReduce2dLinearCrossWarpSync // load from smem. here we let everythread to do compute :) index_t local_warp_id = warp_id / num_reduce_warps; index_t local_smem_os = local_warp_id * num_reduce_warps; + DataType all_scratch[thread_buf_size * num_reduce_warps]; + [[maybe_unused]] std::conditional_t all_indices; + + // Load data from shared memory static_for<0, thread_buf_size, 1>{}([&](auto i_0) { static_for<0, num_reduce_warps, 1>{}([&](auto i_1) { all_scratch[i_0 * num_reduce_warps + i_1] = smem_ptr[i_0 * num_warps + local_smem_os + i_1]; + + if constexpr(kProcessIndex) + { + all_indices[i_0 * num_reduce_warps + i_1] = + smem_indices[i_0 * num_warps + local_smem_os + i_1]; + } }); }); block_sync_lds(); // TODO: we don't need sync here + // Perform reduction static_for<0, thread_buf_size, 1>{}([&](auto i_0) { // TODO: use descriptor for this auto v_local = all_scratch[i_0 * num_reduce_warps]; + IndexDataType idx_local{}; + if constexpr(kProcessIndex) + { + idx_local = all_indices[i_0 * num_reduce_warps]; + } + // further reduce mean/var static_for<0, num_reduce_warps - 1, 1>{}([&](auto i_1_n1) { constexpr auto i_1 = number{}; const DataType v_remote = all_scratch[i_0 * num_reduce_warps + i_1]; - // reduce - v_local = reduce_func(v_local, v_remote); + if constexpr(kProcessIndex) + { + const IndexDataType idx_remote = all_indices[i_0 * num_reduce_warps + i_1]; + + bool changed = false; + v_local = reduce_func(v_local, v_remote, changed); + if(changed) + { + idx_local = idx_remote; + } + } + else + { + v_local = reduce_func(v_local, v_remote); + } }); y_tensor.get_thread_buffer()(i_0) = v_local; + if constexpr(kProcessIndex) + { + y_index_tensor.get_thread_buffer()(i_0) = idx_local; + } }); } + + public: + template + CK_TILE_DEVICE void + operator()(YDistributedTensor_& y_tensor, void* smem, const ReduceFunc& reduce_func) + { + reduce_impl(y_tensor, y_tensor, smem, nullptr, reduce_func); + } + + template + CK_TILE_DEVICE void operator()(YDistributedTensor_& y_tensor, + YIndexDistributedTensor_& y_index_tensor, + void* smem, + void* smem_indices, + const ReduceFunc& reduce_func) + { + reduce_impl( + y_tensor, y_index_tensor, smem, smem_indices, reduce_func); + } }; } // namespace ck_tile diff --git a/include/ck_tile/ops/reduce/block/block_reduce2d_problem.hpp b/include/ck_tile/ops/reduce/block/block_reduce2d_problem.hpp index b75f4f0767..33cc660541 100644 --- a/include/ck_tile/ops/reduce/block/block_reduce2d_problem.hpp +++ b/include/ck_tile/ops/reduce/block/block_reduce2d_problem.hpp @@ -7,12 +7,17 @@ namespace ck_tile { -template +template struct BlockReduce2dProblem { using XDataType = remove_cvref_t; using ComputeDataType = remove_cvref_t; using BlockShape = remove_cvref_t; + + static constexpr bool kOutputIndex = OutputIndex_; }; } // namespace ck_tile diff --git a/include/ck_tile/ops/reduce/pipeline/reduce2d_default_policy.hpp b/include/ck_tile/ops/reduce/pipeline/reduce2d_default_policy.hpp index 27bb4bcdcb..273a764f01 100644 --- a/include/ck_tile/ops/reduce/pipeline/reduce2d_default_policy.hpp +++ b/include/ck_tile/ops/reduce/pipeline/reduce2d_default_policy.hpp @@ -32,7 +32,8 @@ struct Reduce2dDefaultPolicy { using P_ = BlockReduce2dProblem; + typename Problem::BlockShape, + Problem::kOutputIndex>; return BlockReduce2d{}; } @@ -41,7 +42,8 @@ struct Reduce2dDefaultPolicy { using P_ = BlockReduce2dProblem; + typename Problem::BlockShape, + Problem::kOutputIndex>; return BlockReduce2dSync{}; } @@ -50,7 +52,8 @@ struct Reduce2dDefaultPolicy { using P_ = BlockReduce2dProblem; + typename Problem::BlockShape, + Problem::kOutputIndex>; return BlockReduce2dCrossWarpSync{}; } @@ -61,7 +64,8 @@ struct Reduce2dDefaultPolicy { using P_ = BlockReduce2dProblem; + typename Problem::BlockShape, + Problem::kOutputIndex>; using block_reduce2d = BlockReduce2d; using x_block_tile = @@ -76,5 +80,23 @@ struct Reduce2dDefaultPolicy return 1; // zero size arrays are an extension } } + + template + CK_TILE_HOST_DEVICE static constexpr index_t GetIndicesSmemSize() + { + using P_ = BlockReduce2dProblem; + + using block_reduce2d = BlockReduce2d; + using x_block_tile = decltype(make_static_distributed_tensor( + MakeXBlockTileDistribution())); + using y_index_block_tile = + decltype(block_reduce2d::template MakeYIndexBlockTile()); + + return GetBlockReduce2dCrossWarpSync() + .template GetIndicesSmemSize(); + } }; } // namespace ck_tile diff --git a/include/ck_tile/ops/reduce/pipeline/reduce2d_problem.hpp b/include/ck_tile/ops/reduce/pipeline/reduce2d_problem.hpp index 67fdec9286..1570b44271 100644 --- a/include/ck_tile/ops/reduce/pipeline/reduce2d_problem.hpp +++ b/include/ck_tile/ops/reduce/pipeline/reduce2d_problem.hpp @@ -11,7 +11,8 @@ template + typename ReduceOp_, + bool OutputIndex_ = false> struct Reduce2dProblem { using XDataType = remove_cvref_t; @@ -20,6 +21,7 @@ struct Reduce2dProblem using BlockShape = remove_cvref_t; using ReduceOp = ReduceOp_; + static constexpr bool kOutputIndex = OutputIndex_; static constexpr bool kNeedCrossLaneSync = BlockShape::ThreadPerWarp_N > 1; static constexpr bool kNeedCrossWarpSync = BlockShape::WarpPerBlock_N > 1; }; diff --git a/test/ck_tile/pooling/test_pooling.cpp b/test/ck_tile/pooling/test_pooling.cpp index fa98687bda..37f4b6cad6 100644 --- a/test/ck_tile/pooling/test_pooling.cpp +++ b/test/ck_tile/pooling/test_pooling.cpp @@ -28,7 +28,19 @@ class TestCkTilePooling : public ::testing::Test using TestPoolShape = ck_tile::PoolShape; - // 3D pooling configuration + // 2D pooling configuration (NHWC) + struct Config2D + { + ck_tile::index_t N, H, W, C; + ck_tile::index_t Y, X; + ck_tile::index_t Sy, Sx; + ck_tile::index_t Dy, Dx; + ck_tile::index_t LeftPy, LeftPx; + ck_tile::index_t RightPy, RightPx; + std::string name; + }; + + // 3D pooling configuration (NDHWC) struct Config3D { ck_tile::index_t N, D, H, W, C; @@ -40,6 +52,117 @@ class TestCkTilePooling : public ::testing::Test std::string name; }; + bool RunPool2D(const Config2D& config) + { + std::cout << "Testing 2D: " << config.name << " ... "; + + const ck_tile::index_t Ys = (config.Y - 1) * config.Dy + 1; + const ck_tile::index_t Xs = (config.X - 1) * config.Dx + 1; + const ck_tile::index_t Ho = + (config.H + config.LeftPy + config.RightPy - Ys) / config.Sy + 1; + const ck_tile::index_t Wo = + (config.W + config.LeftPx + config.RightPx - Xs) / config.Sx + 1; + + using IndexDataType = ck_tile::index_t; + + // Host tensors + ck_tile::HostTensor h_in({config.N, config.H, config.W, config.C}); + ck_tile::HostTensor h_out({config.N, Ho, Wo, config.C}); + ck_tile::HostTensor h_out_ref({config.N, Ho, Wo, config.C}); + ck_tile::HostTensor h_out_index({config.N, Ho, Wo, config.C}); + ck_tile::HostTensor h_out_ref_index({config.N, Ho, Wo, config.C}); + + // Initialize input with random data + ck_tile::FillUniformDistribution{-5.f, 5.f}(h_in); + + // Device memory + ck_tile::DeviceMem d_in_mem(h_in.get_element_space_size_in_bytes()); + ck_tile::DeviceMem d_out_mem(h_out.get_element_space_size_in_bytes()); + ck_tile::DeviceMem d_out_index_mem(h_out_index.get_element_space_size_in_bytes()); + + d_in_mem.ToDevice(h_in.data()); + d_out_mem.ToDevice(h_out.data()); + d_out_index_mem.ToDevice(h_out_index.data()); + + constexpr ck_tile::index_t kBlockPerCu = 1; + + using Problem = ck_tile::PoolProblem; + using Kernel = ck_tile::PoolKernel; + + const ck_tile::index_t kBlockSize = Kernel::BlockSize(); + + // Shapes and strides (NHWC) + const auto input_shape = ck_tile::make_tuple(config.N, config.H, config.W, config.C); + const auto output_shape = ck_tile::make_tuple(config.N, Ho, Wo, config.C); + const auto input_strides = + ck_tile::make_tuple(config.H * config.W * config.C, config.W * config.C, config.C, 1); + const auto output_strides = + ck_tile::make_tuple(Ho * Wo * config.C, Wo * config.C, config.C, 1); + const auto window_spatial_lengths = ck_tile::make_tuple(config.Y, config.X); + const auto window_strides = ck_tile::make_tuple(config.Sy, config.Sx); + const auto window_dilations = ck_tile::make_tuple(config.Dy, config.Dx); + const auto input_left_pads = ck_tile::make_tuple(config.LeftPy, config.LeftPx); + const auto input_right_pads = ck_tile::make_tuple(config.RightPy, config.RightPx); + + auto host_args = + ck_tile::PoolHostArgs{ + static_cast(d_in_mem.GetDeviceBuffer()), + static_cast(d_out_mem.GetDeviceBuffer()), + static_cast(d_out_index_mem.GetDeviceBuffer()), + input_shape, + output_shape, + input_strides, + output_strides, + window_spatial_lengths, + window_strides, + window_dilations, + input_left_pads, + input_right_pads}; + + auto kernel_args = Kernel::MakeKernelArgs(host_args); + const ck_tile::index_t kGridSize = Kernel::CalculateGridSize(kernel_args); + + if(!Kernel::IsSupportedArgument(kernel_args)) + { + return true; + } + + // Run kernel + ck_tile::launch_kernel( + ck_tile::stream_config{nullptr, false, 0}, + ck_tile::make_kernel(Kernel{}, kGridSize, kBlockSize, 0, kernel_args)); + + // Run reference + ck_tile::reference_pool2d( + h_in, h_out_ref, h_out_ref_index, kernel_args, ReduceOpType{}); + + d_out_mem.FromDevice(h_out.data()); + d_out_index_mem.FromDevice(h_out_index.data()); + + // Validate results + bool pass_value = + ck_tile::check_err(h_out, h_out_ref, "Error: Incorrect values!", 1e-5, 1e-5); + bool pass_index = ck_tile::check_err( + h_out_index, h_out_ref_index, "Error: Incorrect indices!", 1e-5, 1e-5); + + std::cout << (pass_value && pass_index ? "PASS" : "FAIL") << std::endl; + return pass_value && pass_index; + } + bool RunPool3D(const Config3D& config) { std::cout << "Testing 3D: " << config.name << " ... "; @@ -72,6 +195,8 @@ class TestCkTilePooling : public ::testing::Test const auto input_right_pads = ck_tile::make_tuple(config.RightPz, config.RightPy, config.RightPx); + using IndexDataType = ck_tile::index_t; + ck_tile::HostTensor h_in({config.N, config.D, config.H, config.W, config.C}, {config.D * config.H * config.W * config.C, config.H * config.W * config.C, @@ -84,6 +209,12 @@ class TestCkTilePooling : public ::testing::Test ck_tile::HostTensor h_out_ref( {config.N, Do, Ho, Wo, config.C}, {Do * Ho * Wo * config.C, Ho * Wo * config.C, Wo * config.C, config.C, 1}); + ck_tile::HostTensor h_out_index( + {config.N, Do, Ho, Wo, config.C}, + {Do * Ho * Wo * config.C, Ho * Wo * config.C, Wo * config.C, config.C, 1}); + ck_tile::HostTensor h_out_ref_index( + {config.N, Do, Ho, Wo, config.C}, + {Do * Ho * Wo * config.C, Ho * Wo * config.C, Wo * config.C, config.C, 1}); ck_tile::FillUniformDistribution{-5.f, 5.f}(h_in); h_out.SetZero(); @@ -91,17 +222,19 @@ class TestCkTilePooling : public ::testing::Test ck_tile::DeviceMem d_in_mem(h_in.get_element_space_size_in_bytes()); ck_tile::DeviceMem d_out_mem(h_out.get_element_space_size_in_bytes()); + ck_tile::DeviceMem d_out_index_mem(h_out_index.get_element_space_size_in_bytes()); d_in_mem.ToDevice(h_in.data()); d_out_mem.ToDevice(h_out.data()); + d_out_index_mem.ToDevice(h_out_index.data()); using Problem = ck_tile::PoolProblem; using Kernel = ck_tile::PoolKernel; @@ -112,6 +245,7 @@ class TestCkTilePooling : public ::testing::Test ck_tile::PoolHostArgs{ static_cast(d_in_mem.GetDeviceBuffer()), static_cast(d_out_mem.GetDeviceBuffer()), + static_cast(d_out_index_mem.GetDeviceBuffer()), input_shape, output_shape, input_strides, @@ -137,16 +271,27 @@ class TestCkTilePooling : public ::testing::Test ck_tile::make_kernel(Kernel{}, kGridSize, kBlockSize, 0, kernel_args)); // Run reference implementation - ck_tile::reference_pool3d( - h_in, h_out_ref, kernel_args, ReduceOpType{}); + ck_tile::reference_pool3d( + h_in, h_out_ref, h_out_ref_index, kernel_args, ReduceOpType{}); d_out_mem.FromDevice(h_out.data()); + d_out_index_mem.FromDevice(h_out_index.data()); // Validate results - bool pass = ck_tile::check_err(h_out, h_out_ref); - std::cout << (pass ? "PASS" : "FAIL") << std::endl; + bool pass_value = + ck_tile::check_err(h_out, h_out_ref, "Error: Incorrect values!", 1e-5, 1e-5); + bool pass_index = ck_tile::check_err( + h_out_index, h_out_ref_index, "Error: Incorrect indices!", 1e-5, 1e-5); - return pass; + std::cout << (pass_value && pass_index ? "PASS" : "FAIL") << std::endl; + return pass_value && pass_index; } }; @@ -194,6 +339,50 @@ using TestTypes = TYPED_TEST_SUITE(TestCkTilePooling, TestTypes); +// 2D Pooling Tests (NHWC) +TYPED_TEST(TestCkTilePooling, Pool2D_2x2) +{ + typename TestFixture::Config2D config = {1, // N - batch size + 8, // H - height dimension + 8, // W - width dimension + 32, // C - channel dimension + 2, // Y - pooling window height + 2, // X - pooling window width + 2, // Sy - window stride height + 2, // Sx - window stride width + 1, // Dy - window dilation height + 1, // Dx - window dilation width + 0, // LeftPy - left padding height + 0, // LeftPx - left padding width + 0, // RightPy - right padding height + 0, // RightPx - right padding width + "2x2 pooling NHWC"}; + bool pass = this->RunPool2D(config); + EXPECT_TRUE(pass); +} + +TYPED_TEST(TestCkTilePooling, Pool2D_3x3_WithPadding) +{ + typename TestFixture::Config2D config = {2, // N - batch size + 16, // H - height dimension + 16, // W - width dimension + 32, // C - channel dimension + 3, // Y - pooling window height + 3, // X - pooling window width + 2, // Sy - window stride height + 2, // Sx - window stride width + 1, // Dy - window dilation height + 1, // Dx - window dilation width + 1, // LeftPy - left padding height + 1, // LeftPx - left padding width + 1, // RightPy - right padding height + 1, // RightPx - right padding width + "3x3 pooling NHWC with padding"}; + bool pass = this->RunPool2D(config); + EXPECT_TRUE(pass); +} + +// 3D Pooling Tests (NDHWC) TYPED_TEST(TestCkTilePooling, Pool3D_2x2x2) { typename TestFixture::Config3D config = {1, // N - batch size @@ -216,7 +405,7 @@ TYPED_TEST(TestCkTilePooling, Pool3D_2x2x2) 0, // RightPz - right padding depth 0, // RightPy - right padding height 0, // RightPx - right padding width - "2x2x2 pooling"}; + "2x2x2 pooling NDHWC"}; bool pass = this->RunPool3D(config); EXPECT_TRUE(pass); } @@ -243,7 +432,7 @@ TYPED_TEST(TestCkTilePooling, Pool3D_3x3x3) 1, // RightPz - right padding depth 1, // RightPy - right padding height 1, // RightPx - right padding width - "3x3x3 pooling"}; + "3x3x3 pooling NDHWC with padding"}; bool pass = this->RunPool3D(config); EXPECT_TRUE(pass); }