mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-19 14:29:05 +00:00
[CK_TILE] Add indexing to pooling operator (Lwpck 3892) (#3013)
* Add indexing support to pooling operator - Add IndexDataType template parameter to pooling problem and kernel definitions - Enable pooling kernel to output indices of selected elements during max/absmax pooling - Add overloaded operators for Max and AbsMax that track when values change using bool changed parameter - Support optional index buffer allocation and management in device memory - Modify BlockReduce2d classes to handle index tensors alongside value tensors - Add separate shared memory allocation for index data in cross-warp reductions - Create validate_pool_indices function to verify index correctness - Modify pool3d.cpp example to demonstrate index output functionality - Add tests for index output * fixes * Refactor BlockReduce2D functions to get rid auxiliary private types. * comment resolutions and some changes to block_reduce2d - index reference implementation improved - reduce_operator.hpp cleanedup - updated the block_reduce2d.hpp to have index calculation for BlockReduce2dLinearCrossWarpSync as well * conditionally used variable declaration improvement - the conditionally used vairbales are used only when indexing is enabled. To inform the compiler that they may be unused and declare them with least size possible. This may allow it to be optimized compared to the previous declarations * comment resolutions * lexical ordering of the indicies - introduced accumulate methods that handle the intermediate steps if needed to order the indexes * add reduce_operator_accumulate.hpp to core.hpp --------- Co-authored-by: Adam Osewski <Adam.Osewski@amd.com>
This commit is contained in:
committed by
GitHub
parent
7c6430eca0
commit
3052d7c9e6
@@ -38,7 +38,10 @@ auto create_args(int argc, char* argv[])
|
||||
return std::make_tuple(result, arg_parser);
|
||||
}
|
||||
|
||||
template <typename InDataType, typename OutDataType, typename ComputeDataType>
|
||||
template <typename InDataType,
|
||||
typename OutDataType,
|
||||
typename ComputeDataType,
|
||||
typename IndexDataType>
|
||||
bool run(const ck_tile::ArgParser& arg_parser)
|
||||
{
|
||||
|
||||
@@ -84,6 +87,9 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
int warmup = arg_parser.get_int("warmup");
|
||||
int repeat = arg_parser.get_int("repeat");
|
||||
|
||||
constexpr bool OutputIndex = true;
|
||||
constexpr bool PropagateNan = false;
|
||||
|
||||
// Shapes / strides / parameters (NDHWC)
|
||||
const auto input_shape = ck_tile::make_tuple(N, D, H, W, C);
|
||||
const auto output_shape = ck_tile::make_tuple(N, Do, Ho, Wo, C);
|
||||
@@ -100,11 +106,16 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
{Do * Ho * Wo * C, Ho * Wo * C, Wo * C, C, 1});
|
||||
ck_tile::HostTensor<OutDataType> out_ref({N, Do, Ho, Wo, C},
|
||||
{Do * Ho * Wo * C, Ho * Wo * C, Wo * C, C, 1});
|
||||
ck_tile::HostTensor<IndexDataType> out_index({N, Do, Ho, Wo, C},
|
||||
{Do * Ho * Wo * C, Ho * Wo * C, Wo * C, C, 1});
|
||||
ck_tile::HostTensor<IndexDataType> out_ref_index({N, Do, Ho, Wo, C},
|
||||
{Do * Ho * Wo * C, Ho * Wo * C, Wo * C, C, 1});
|
||||
|
||||
ck_tile::FillUniformDistribution<InDataType>{-5.f, 5.f}(in);
|
||||
|
||||
ck_tile::DeviceMem in_buf(in.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem out_buf(out.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem out_index_buf(OutputIndex ? out_index.get_element_space_size_in_bytes() : 0);
|
||||
|
||||
in_buf.ToDevice(in.data());
|
||||
|
||||
@@ -118,10 +129,10 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
using Problem = ck_tile::PoolProblem<InDataType,
|
||||
OutDataType,
|
||||
ComputeDataType,
|
||||
OutDataType,
|
||||
IndexDataType,
|
||||
ReduceOp,
|
||||
false,
|
||||
false,
|
||||
OutputIndex,
|
||||
PropagateNan,
|
||||
Shape>;
|
||||
using Kernel = ck_tile::PoolKernel<Problem>;
|
||||
|
||||
@@ -131,6 +142,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
auto host_args = ck_tile::PoolHostArgs<decltype(input_shape), decltype(window_spatial_lengths)>{
|
||||
static_cast<InDataType*>(in_buf.GetDeviceBuffer()),
|
||||
static_cast<OutDataType*>(out_buf.GetDeviceBuffer()),
|
||||
OutputIndex ? static_cast<IndexDataType*>(out_index_buf.GetDeviceBuffer()) : nullptr,
|
||||
input_shape,
|
||||
output_shape,
|
||||
input_strides,
|
||||
@@ -167,12 +179,28 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
|
||||
if(do_validation)
|
||||
{
|
||||
ck_tile::reference_pool3d<InDataType, ComputeDataType, OutDataType>(
|
||||
in, out_ref, kernel_args, ReduceOp{});
|
||||
out_buf.FromDevice(out.mData.data());
|
||||
pass = ck_tile::check_err(out, out_ref);
|
||||
|
||||
std::cout << "valid:" << (pass ? "y" : "n") << std::flush << std::endl;
|
||||
ck_tile::reference_pool3d<InDataType,
|
||||
ComputeDataType,
|
||||
OutDataType,
|
||||
IndexDataType,
|
||||
ReduceOp,
|
||||
decltype(input_shape),
|
||||
decltype(window_spatial_lengths),
|
||||
OutputIndex>(in, out_ref, out_ref_index, kernel_args, ReduceOp{});
|
||||
|
||||
if constexpr(OutputIndex)
|
||||
{
|
||||
out_index_buf.FromDevice(out_index.mData.data());
|
||||
pass = ck_tile::check_err(out, out_ref) && ck_tile::check_err(out_index, out_ref_index);
|
||||
}
|
||||
else
|
||||
{
|
||||
pass = ck_tile::check_err(out, out_ref);
|
||||
}
|
||||
|
||||
std::cout << "valid:" << (pass ? "y" : "n") << std::endl;
|
||||
}
|
||||
|
||||
return pass;
|
||||
@@ -184,5 +212,5 @@ int main(int argc, char* argv[])
|
||||
if(!result)
|
||||
return -1;
|
||||
|
||||
return run<ck_tile::half_t, ck_tile::half_t, float>(arg_parser) ? 0 : -2;
|
||||
return run<ck_tile::half_t, ck_tile::half_t, float, ck_tile::index_t>(arg_parser) ? 0 : -2;
|
||||
}
|
||||
|
||||
@@ -78,6 +78,7 @@
|
||||
#include "ck_tile/core/utility/print.hpp"
|
||||
#include "ck_tile/core/utility/random.hpp"
|
||||
#include "ck_tile/core/utility/reduce_operator.hpp"
|
||||
#include "ck_tile/core/utility/reduce_operator_accumulate.hpp"
|
||||
#include "ck_tile/core/utility/static_counter.hpp"
|
||||
#include "ck_tile/core/utility/to_sequence.hpp"
|
||||
#include "ck_tile/core/utility/transpose_vectors.hpp"
|
||||
|
||||
@@ -4,6 +4,7 @@
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core/config.hpp"
|
||||
#include "ck_tile/core/utility/type_traits.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
@@ -18,16 +19,14 @@ struct Add
|
||||
};
|
||||
|
||||
template <typename T,
|
||||
typename = std::enable_if_t<std::is_same_v<T, float> || std::is_same_v<T, double> ||
|
||||
std::is_same_v<T, int32_t> || std::is_same_v<T, int8_t>>>
|
||||
typename = std::enable_if_t<is_any_of<T, float, double, int32_t, int8_t>::value>>
|
||||
CK_TILE_HOST_DEVICE constexpr T operator()(const T& y, const T x) const
|
||||
{
|
||||
return y + x;
|
||||
}
|
||||
|
||||
template <typename T,
|
||||
typename = std::enable_if_t<std::is_same_v<T, half_t> || std::is_same_v<T, bf16_t> ||
|
||||
std::is_same_v<T, fp8_t> || std::is_same_v<T, bf8_t>>>
|
||||
typename = std::enable_if_t<is_any_of<T, half_t, bf16_t, fp8_t, bf8_t>::value>>
|
||||
CK_TILE_HOST_DEVICE constexpr T operator()(T& y, T x) const
|
||||
{
|
||||
float y_ = type_convert<float>(y);
|
||||
@@ -46,16 +45,14 @@ struct SquareAdd
|
||||
};
|
||||
|
||||
template <typename T,
|
||||
typename = std::enable_if_t<std::is_same_v<T, float> || std::is_same_v<T, double> ||
|
||||
std::is_same_v<T, int32_t> || std::is_same_v<T, int8_t>>>
|
||||
typename = std::enable_if_t<is_any_of<T, float, double, int32_t, int8_t>::value>>
|
||||
CK_TILE_HOST_DEVICE constexpr T operator()(const T& y, const T x) const
|
||||
{
|
||||
return y + (x * x);
|
||||
}
|
||||
|
||||
template <typename T,
|
||||
typename = std::enable_if_t<std::is_same_v<T, half_t> || std::is_same_v<T, bf16_t> ||
|
||||
std::is_same_v<T, fp8_t> || std::is_same_v<T, bf8_t>>>
|
||||
typename = std::enable_if_t<is_any_of<T, half_t, bf16_t, fp8_t, bf8_t>::value>>
|
||||
CK_TILE_HOST_DEVICE constexpr T operator()(T& y, T x) const
|
||||
{
|
||||
float y_ = type_convert<float>(y);
|
||||
@@ -66,48 +63,74 @@ struct SquareAdd
|
||||
|
||||
struct Max
|
||||
{
|
||||
template <typename T,
|
||||
typename = std::enable_if_t<std::is_same_v<T, float> || std::is_same_v<T, double> ||
|
||||
std::is_same_v<T, int32_t> || std::is_same_v<T, int8_t> ||
|
||||
std::is_same_v<T, half_t> || std::is_same_v<T, bf16_t> ||
|
||||
std::is_same_v<T, fp8_t> || std::is_same_v<T, bf8_t>>>
|
||||
template <
|
||||
typename T,
|
||||
typename = std::enable_if_t<
|
||||
is_any_of<T, float, double, int32_t, int8_t, half_t, bf16_t, fp8_t, bf8_t>::value>>
|
||||
CK_TILE_HOST_DEVICE static constexpr T GetIdentityValue()
|
||||
{
|
||||
return numeric<T>::lowest();
|
||||
};
|
||||
|
||||
template <typename T,
|
||||
typename = std::enable_if_t<std::is_same_v<T, float> || std::is_same_v<T, double> ||
|
||||
std::is_same_v<T, int32_t> || std::is_same_v<T, int8_t> ||
|
||||
std::is_same_v<T, half_t> || std::is_same_v<T, bf16_t> ||
|
||||
std::is_same_v<T, fp8_t> || std::is_same_v<T, bf8_t>>>
|
||||
template <
|
||||
typename T,
|
||||
typename = std::enable_if_t<
|
||||
is_any_of<T, float, double, int32_t, int8_t, half_t, bf16_t, fp8_t, bf8_t>::value>>
|
||||
CK_TILE_HOST_DEVICE constexpr T operator()(const T& y, const T x) const
|
||||
{
|
||||
return max(y, x);
|
||||
}
|
||||
|
||||
// Overload with changed flag for index tracking
|
||||
template <
|
||||
typename T,
|
||||
typename = std::enable_if_t<
|
||||
is_any_of<T, float, double, int32_t, int8_t, half_t, bf16_t, fp8_t, bf8_t>::value>>
|
||||
CK_TILE_HOST_DEVICE constexpr T operator()(const T& y, const T x, bool& changed) const
|
||||
{
|
||||
T new_max = max(y, x);
|
||||
if(x > y)
|
||||
{
|
||||
changed = true;
|
||||
}
|
||||
return new_max;
|
||||
}
|
||||
};
|
||||
|
||||
struct AbsMax
|
||||
{
|
||||
template <typename T,
|
||||
typename = std::enable_if_t<std::is_same_v<T, float> || std::is_same_v<T, double> ||
|
||||
std::is_same_v<T, int32_t> || std::is_same_v<T, int8_t> ||
|
||||
std::is_same_v<T, half_t> || std::is_same_v<T, bf16_t> ||
|
||||
std::is_same_v<T, fp8_t> || std::is_same_v<T, bf8_t>>>
|
||||
template <
|
||||
typename T,
|
||||
typename = std::enable_if_t<
|
||||
is_any_of<T, float, double, int32_t, int8_t, half_t, bf16_t, fp8_t, bf8_t>::value>>
|
||||
CK_TILE_HOST_DEVICE static constexpr T GetIdentityValue()
|
||||
{
|
||||
return numeric<T>::zero();
|
||||
};
|
||||
|
||||
template <typename T,
|
||||
typename = std::enable_if_t<std::is_same_v<T, float> || std::is_same_v<T, double> ||
|
||||
std::is_same_v<T, int32_t> || std::is_same_v<T, int8_t> ||
|
||||
std::is_same_v<T, half_t> || std::is_same_v<T, bf16_t> ||
|
||||
std::is_same_v<T, fp8_t> || std::is_same_v<T, bf8_t>>>
|
||||
template <
|
||||
typename T,
|
||||
typename = std::enable_if_t<
|
||||
is_any_of<T, float, double, int32_t, int8_t, half_t, bf16_t, fp8_t, bf8_t>::value>>
|
||||
CK_TILE_HOST_DEVICE constexpr T operator()(const T& y, const T x) const
|
||||
{
|
||||
return max(y, abs(x));
|
||||
}
|
||||
|
||||
// Overload with changed flag for index tracking
|
||||
template <
|
||||
typename T,
|
||||
typename = std::enable_if_t<
|
||||
is_any_of<T, float, double, int32_t, int8_t, half_t, bf16_t, fp8_t, bf8_t>::value>>
|
||||
CK_TILE_HOST_DEVICE constexpr T operator()(const T& y, const T x, bool& changed) const
|
||||
{
|
||||
T new_max = max(y, abs(x));
|
||||
if(abs(x) > y)
|
||||
{
|
||||
changed = true;
|
||||
}
|
||||
return new_max;
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ReduceOp
|
||||
|
||||
50
include/ck_tile/core/utility/reduce_operator_accumulate.hpp
Normal file
50
include/ck_tile/core/utility/reduce_operator_accumulate.hpp
Normal file
@@ -0,0 +1,50 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core/config.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
/// @brief Accumulate with index tracking reductions, provides deterministic first occurring index
|
||||
struct AccumulateWithIndex
|
||||
{
|
||||
template <typename ReduceOp, typename T, typename IndexType>
|
||||
CK_TILE_HOST_DEVICE void operator()(const ReduceOp& reduce_func,
|
||||
T& current_value,
|
||||
IndexType& current_index,
|
||||
const T& new_value,
|
||||
const IndexType& new_index) const
|
||||
{
|
||||
bool changed = false;
|
||||
current_value = reduce_func(current_value, new_value, changed);
|
||||
|
||||
if(changed)
|
||||
{
|
||||
current_index = new_index;
|
||||
}
|
||||
else if(new_index < current_index)
|
||||
{
|
||||
bool reverse_changed = false;
|
||||
reduce_func(new_value, current_value, reverse_changed);
|
||||
|
||||
if(!reverse_changed)
|
||||
{
|
||||
current_index = new_index;
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
struct Accumulate
|
||||
{
|
||||
template <typename ReduceOp, typename T>
|
||||
CK_TILE_HOST_DEVICE void
|
||||
operator()(const ReduceOp& reduce_func, T& current_value, const T& new_value) const
|
||||
{
|
||||
current_value = reduce_func(current_value, new_value);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -7,17 +7,21 @@
|
||||
#include "ck_tile/host/host_tensor.hpp"
|
||||
#include "ck_tile/ops/pooling/kernel/pool_kernel.hpp"
|
||||
#include <thread>
|
||||
#include <cmath>
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <typename InDataType,
|
||||
typename ComputeDataType,
|
||||
typename OutDataType,
|
||||
typename IndexDataType,
|
||||
typename ReduceOp,
|
||||
typename TensorShape,
|
||||
typename WindowShape>
|
||||
typename WindowShape,
|
||||
bool OutputIndex = false>
|
||||
CK_TILE_HOST void reference_pool2d(const HostTensor<InDataType>& input,
|
||||
HostTensor<OutDataType>& output,
|
||||
HostTensor<IndexDataType>& output_index,
|
||||
PoolKernelArgs<TensorShape, WindowShape> kargs,
|
||||
ReduceOp reduce_op)
|
||||
{
|
||||
@@ -45,6 +49,8 @@ CK_TILE_HOST void reference_pool2d(const HostTensor<InDataType>& input,
|
||||
auto f = [&](auto n, auto ho, auto wo, auto c) {
|
||||
ComputeDataType v_acc = reduce_op.template GetIdentityValue<ComputeDataType>();
|
||||
|
||||
IndexDataType current_index = 0; // Declare outside if constexpr for efficiency
|
||||
|
||||
for(ck_tile::index_t y = 0; y < Y; ++y)
|
||||
{
|
||||
// Calculate input height index with stride, dilation, and padding
|
||||
@@ -58,13 +64,32 @@ CK_TILE_HOST void reference_pool2d(const HostTensor<InDataType>& input,
|
||||
if(hi >= 0 && hi < H && wi >= 0 && wi < W)
|
||||
{
|
||||
const ComputeDataType v_in = type_convert<ComputeDataType>(input(n, hi, wi, c));
|
||||
v_acc = reduce_op(v_acc, v_in);
|
||||
|
||||
if constexpr(OutputIndex)
|
||||
{
|
||||
IndexDataType flat_index = input.GetOffsetFromMultiIndex(n, hi, wi, c);
|
||||
bool changed = false;
|
||||
v_acc = reduce_op(v_acc, v_in, changed);
|
||||
if(changed)
|
||||
{
|
||||
current_index = flat_index;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
v_acc = reduce_op(v_acc, v_in);
|
||||
}
|
||||
}
|
||||
// For positions outside bounds, we implicitly use identity value
|
||||
}
|
||||
}
|
||||
|
||||
output(n, ho, wo, c) = ck_tile::type_convert<OutDataType>(v_acc);
|
||||
|
||||
if constexpr(OutputIndex)
|
||||
{
|
||||
output_index(n, ho, wo, c) = current_index;
|
||||
}
|
||||
};
|
||||
|
||||
// Parallelize over all output dimensions
|
||||
@@ -74,11 +99,14 @@ CK_TILE_HOST void reference_pool2d(const HostTensor<InDataType>& input,
|
||||
template <typename InDataType,
|
||||
typename ComputeDataType,
|
||||
typename OutDataType,
|
||||
typename IndexDataType,
|
||||
typename ReduceOp,
|
||||
typename TensorShape,
|
||||
typename WindowShape>
|
||||
typename WindowShape,
|
||||
bool OutputIndex = false>
|
||||
CK_TILE_HOST void reference_pool3d(const HostTensor<InDataType>& input,
|
||||
HostTensor<OutDataType>& output,
|
||||
HostTensor<IndexDataType>& output_index,
|
||||
PoolKernelArgs<TensorShape, WindowShape> kargs,
|
||||
ReduceOp reduce_op)
|
||||
{
|
||||
@@ -112,6 +140,8 @@ CK_TILE_HOST void reference_pool3d(const HostTensor<InDataType>& input,
|
||||
auto f = [&](auto n, auto do_, auto ho, auto wo, auto c) {
|
||||
ComputeDataType v_acc = reduce_op.template GetIdentityValue<ComputeDataType>();
|
||||
|
||||
IndexDataType current_index = 0; // Declare outside if constexpr for efficiency
|
||||
|
||||
for(ck_tile::index_t z = 0; z < Z; ++z)
|
||||
{
|
||||
// Calculate input depth index with stride, dilation, and padding
|
||||
@@ -131,7 +161,22 @@ CK_TILE_HOST void reference_pool3d(const HostTensor<InDataType>& input,
|
||||
{
|
||||
const ComputeDataType v_in =
|
||||
type_convert<ComputeDataType>(input(n, di, hi, wi, c));
|
||||
v_acc = reduce_op(v_acc, v_in);
|
||||
|
||||
if constexpr(OutputIndex)
|
||||
{
|
||||
IndexDataType flat_index =
|
||||
input.GetOffsetFromMultiIndex(n, di, hi, wi, c);
|
||||
bool changed = false;
|
||||
v_acc = reduce_op(v_acc, v_in, changed);
|
||||
if(changed)
|
||||
{
|
||||
current_index = flat_index;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
v_acc = reduce_op(v_acc, v_in);
|
||||
}
|
||||
}
|
||||
// For positions outside bounds, we implicitly use identity value
|
||||
}
|
||||
@@ -139,10 +184,15 @@ CK_TILE_HOST void reference_pool3d(const HostTensor<InDataType>& input,
|
||||
}
|
||||
|
||||
output(n, do_, ho, wo, c) = ck_tile::type_convert<OutDataType>(v_acc);
|
||||
|
||||
if constexpr(OutputIndex)
|
||||
{
|
||||
|
||||
output_index(n, do_, ho, wo, c) = current_index;
|
||||
}
|
||||
};
|
||||
|
||||
// Parallelize over all output dimensions
|
||||
make_ParallelTensorFunctor(f, N, Do, Ho, Wo, C)(std::thread::hardware_concurrency());
|
||||
}
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -17,6 +17,7 @@ struct PoolHostArgs
|
||||
|
||||
CK_TILE_HOST PoolHostArgs(const void* input_ptr_,
|
||||
void* output_ptr_,
|
||||
void* output_index_ptr_,
|
||||
TensorShape input_shape_,
|
||||
TensorShape output_shape_,
|
||||
TensorShape input_strides_,
|
||||
@@ -28,6 +29,7 @@ struct PoolHostArgs
|
||||
WindowShape input_right_pads_)
|
||||
: input_ptr(input_ptr_),
|
||||
output_ptr(output_ptr_),
|
||||
output_index_ptr(output_index_ptr_),
|
||||
input_shape(input_shape_),
|
||||
output_shape(output_shape_),
|
||||
input_strides(input_strides_),
|
||||
@@ -42,6 +44,7 @@ struct PoolHostArgs
|
||||
|
||||
const void* input_ptr;
|
||||
void* output_ptr;
|
||||
void* output_index_ptr;
|
||||
|
||||
TensorShape input_shape;
|
||||
TensorShape output_shape;
|
||||
@@ -60,6 +63,7 @@ struct PoolKernelArgs
|
||||
{
|
||||
const void* input_ptr;
|
||||
void* output_ptr;
|
||||
void* output_index_ptr;
|
||||
TensorShape input_shape;
|
||||
TensorShape output_shape;
|
||||
TensorShape input_strides;
|
||||
@@ -80,6 +84,7 @@ struct PoolKernel
|
||||
using InDataType = ck_tile::remove_cvref_t<typename Problem::InDataType>;
|
||||
using ComputeDataType = ck_tile::remove_cvref_t<typename Problem::ComputeDataType>;
|
||||
using OutDataType = ck_tile::remove_cvref_t<typename Problem::OutDataType>;
|
||||
using IndexDataType = ck_tile::remove_cvref_t<typename Problem::IndexDataType>;
|
||||
|
||||
static constexpr index_t kBlockSize = Problem::BlockShape::BlockSize;
|
||||
|
||||
@@ -205,7 +210,23 @@ struct PoolKernel
|
||||
tensor_view<decltype(out_buffer_view), decltype(out_desc_padded)>{out_buffer_view,
|
||||
out_desc_padded};
|
||||
|
||||
return make_tuple(in_tensor_padded, out_tensor_padded);
|
||||
if constexpr(Problem::kOutputIndex)
|
||||
{
|
||||
auto out_index_buffer_view = make_buffer_view<address_space_enum::global>(
|
||||
static_cast<IndexDataType*>(kargs.output_index_ptr),
|
||||
out_desc.get_element_space_size(),
|
||||
IndexDataType(-1));
|
||||
const auto out_index_tensor_padded =
|
||||
tensor_view<decltype(out_index_buffer_view), decltype(out_desc_padded)>{
|
||||
out_index_buffer_view, out_desc_padded};
|
||||
|
||||
return make_tuple(in_tensor_padded, out_tensor_padded, out_index_tensor_padded);
|
||||
}
|
||||
else
|
||||
{
|
||||
// Return a dummy tensor for the third element when index output is not needed
|
||||
return make_tuple(in_tensor_padded, out_tensor_padded, null_tensor{});
|
||||
}
|
||||
}
|
||||
|
||||
template <typename TensorShape, typename WindowShape>
|
||||
@@ -338,7 +359,23 @@ struct PoolKernel
|
||||
tensor_view<decltype(out_buffer_view), decltype(out_desc_padded)>{out_buffer_view,
|
||||
out_desc_padded};
|
||||
|
||||
return make_tuple(in_tensor_padded, out_tensor_padded);
|
||||
if constexpr(Problem::kOutputIndex)
|
||||
{
|
||||
auto out_index_buffer_view = make_buffer_view<address_space_enum::global>(
|
||||
static_cast<IndexDataType*>(kargs.output_index_ptr),
|
||||
out_desc.get_element_space_size(),
|
||||
IndexDataType(-1));
|
||||
const auto out_index_tensor_padded =
|
||||
tensor_view<decltype(out_index_buffer_view), decltype(out_desc_padded)>{
|
||||
out_index_buffer_view, out_desc_padded};
|
||||
|
||||
return make_tuple(in_tensor_padded, out_tensor_padded, out_index_tensor_padded);
|
||||
}
|
||||
else
|
||||
{
|
||||
// Return a dummy tensor for the third element when index output is not needed
|
||||
return make_tuple(in_tensor_padded, out_tensor_padded, null_tensor{});
|
||||
}
|
||||
}
|
||||
|
||||
public:
|
||||
@@ -354,7 +391,7 @@ struct PoolKernel
|
||||
const auto iM = get_block_id() * S::Block_M;
|
||||
|
||||
// Get tensors based on dimensionality
|
||||
auto [in_tensor_padded, out_tensor_padded] = [&]() {
|
||||
auto [in_tensor_padded, out_tensor_padded, out_index_tensor_padded] = [&]() {
|
||||
if constexpr(WindowShape::size() == 2)
|
||||
return MakeTensorView2D(kargs);
|
||||
else if constexpr(WindowShape::size() == 3)
|
||||
@@ -387,16 +424,57 @@ struct PoolKernel
|
||||
auto y_tile = block_reduce2d.template MakeYBlockTile<XTensorTile>();
|
||||
set_tile(y_tile, reduce_op.template GetIdentityValue<ComputeDataType>());
|
||||
|
||||
for(int k_tile = __builtin_amdgcn_readfirstlane(0); k_tile < num_k_tiles; ++k_tile)
|
||||
if constexpr(Problem::kOutputIndex)
|
||||
{
|
||||
const auto x_tile = load_tile(x_window);
|
||||
block_reduce2d(x_tile, y_tile, reduce_op);
|
||||
move_tile_window(x_window, {0, S::Block_N});
|
||||
}
|
||||
auto y_index_window =
|
||||
make_tile_window(out_index_tensor_padded, make_tuple(number<S::Block_M>{}), {iM});
|
||||
|
||||
block_reduce2d_sync(y_tile, reduce_op);
|
||||
block_reduce2d_cross_warp(y_tile, smem, reduce_op);
|
||||
store_tile(y_window, cast_tile<OutDataType>(y_tile));
|
||||
auto y_index_tile =
|
||||
block_reduce2d.template MakeYIndexBlockTile<XTensorTile, IndexDataType>();
|
||||
set_tile(y_index_tile, IndexDataType(0));
|
||||
|
||||
// Main reduction loop - with index tracking
|
||||
for(int k_tile = amd_wave_read_first_lane(0); k_tile < num_k_tiles; ++k_tile)
|
||||
{
|
||||
const auto x_tile = load_tile(x_window);
|
||||
auto index_calculator = [&](const auto& x_indices) {
|
||||
// Get global coordinates in the 2D matrix space (M, N)
|
||||
const auto global_M = x_indices.at(number<0>{}) + iM;
|
||||
const auto global_N = (k_tile * S::Block_N) + x_indices.at(number<1>{});
|
||||
return in_tensor_padded.get_tensor_descriptor().calculate_offset(
|
||||
make_tuple(global_M, global_N));
|
||||
};
|
||||
|
||||
block_reduce2d(x_tile, y_tile, y_index_tile, reduce_op, index_calculator);
|
||||
move_tile_window(x_window, {0, S::Block_N});
|
||||
}
|
||||
|
||||
block_reduce2d_sync(y_tile, y_index_tile, reduce_op);
|
||||
if constexpr(Problem::kNeedCrossWarpSync)
|
||||
{
|
||||
__shared__ char smem_indices[Policy::template GetIndicesSmemSize<Problem>()];
|
||||
|
||||
block_reduce2d_cross_warp(y_tile, y_index_tile, smem, smem_indices, reduce_op);
|
||||
}
|
||||
|
||||
store_tile(y_window, cast_tile<OutDataType>(y_tile));
|
||||
store_tile(y_index_window, cast_tile<IndexDataType>(y_index_tile));
|
||||
}
|
||||
else
|
||||
{
|
||||
// Main reduction loop - without index tracking
|
||||
for(int k_tile = __builtin_amdgcn_readfirstlane(0); k_tile < num_k_tiles; ++k_tile)
|
||||
{
|
||||
const auto x_tile = load_tile(x_window);
|
||||
block_reduce2d(x_tile, y_tile, reduce_op);
|
||||
move_tile_window(x_window, {0, S::Block_N});
|
||||
}
|
||||
|
||||
block_reduce2d_sync(y_tile, reduce_op);
|
||||
block_reduce2d_cross_warp(y_tile, smem, reduce_op);
|
||||
|
||||
store_tile(y_window, cast_tile<OutDataType>(y_tile));
|
||||
}
|
||||
}
|
||||
|
||||
/// @brief Validates if the given arguments are supported by the pooling kernel.
|
||||
@@ -481,6 +559,7 @@ struct PoolKernel
|
||||
{
|
||||
return PoolKernelArgs<TensorShape, WindowShape>{host_args.input_ptr,
|
||||
host_args.output_ptr,
|
||||
host_args.output_index_ptr,
|
||||
host_args.input_shape,
|
||||
host_args.output_shape,
|
||||
host_args.input_strides,
|
||||
|
||||
@@ -32,7 +32,8 @@ struct PoolDefaultPolicy
|
||||
{
|
||||
using P_ = BlockReduce2dProblem<typename Problem::InDataType,
|
||||
typename Problem::ComputeDataType,
|
||||
typename Problem::BlockShape>;
|
||||
typename Problem::BlockShape,
|
||||
Problem::kOutputIndex>;
|
||||
return BlockReduce2d<P_>{};
|
||||
}
|
||||
|
||||
@@ -41,7 +42,8 @@ struct PoolDefaultPolicy
|
||||
{
|
||||
using P_ = BlockReduce2dProblem<typename Problem::InDataType,
|
||||
typename Problem::ComputeDataType,
|
||||
typename Problem::BlockShape>;
|
||||
typename Problem::BlockShape,
|
||||
Problem::kOutputIndex>;
|
||||
return BlockReduce2dSync<P_>{};
|
||||
}
|
||||
|
||||
@@ -50,7 +52,8 @@ struct PoolDefaultPolicy
|
||||
{
|
||||
using P_ = BlockReduce2dProblem<typename Problem::InDataType,
|
||||
typename Problem::ComputeDataType,
|
||||
typename Problem::BlockShape>;
|
||||
typename Problem::BlockShape,
|
||||
Problem::kOutputIndex>;
|
||||
return BlockReduce2dCrossWarpSync<P_>{};
|
||||
}
|
||||
|
||||
@@ -61,7 +64,8 @@ struct PoolDefaultPolicy
|
||||
{
|
||||
using P_ = BlockReduce2dProblem<typename Problem::InDataType,
|
||||
typename Problem::ComputeDataType,
|
||||
typename Problem::BlockShape>;
|
||||
typename Problem::BlockShape,
|
||||
Problem::kOutputIndex>;
|
||||
|
||||
using block_reduce2d = BlockReduce2d<P_>;
|
||||
using x_block_tile =
|
||||
@@ -76,5 +80,25 @@ struct PoolDefaultPolicy
|
||||
return 1; // zero size arrays are an extension
|
||||
}
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr index_t GetIndicesSmemSize()
|
||||
{
|
||||
|
||||
using P_ = BlockReduce2dProblem<typename Problem::InDataType,
|
||||
typename Problem::ComputeDataType,
|
||||
typename Problem::BlockShape,
|
||||
Problem::kOutputIndex>;
|
||||
|
||||
using block_reduce2d = BlockReduce2d<P_>;
|
||||
using x_block_tile = decltype(make_static_distributed_tensor<typename Problem::InDataType>(
|
||||
MakeXBlockTileDistribution<Problem>()));
|
||||
using y_index_block_tile = decltype(block_reduce2d::template MakeYIndexBlockTile<
|
||||
x_block_tile,
|
||||
typename Problem::IndexDataType>());
|
||||
|
||||
return GetBlockReduce2dCrossWarpSync<Problem>()
|
||||
.template GetIndicesSmemSize<y_index_block_tile>();
|
||||
}
|
||||
};
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -26,6 +26,8 @@ struct PoolProblem
|
||||
using OutputIndex = bool_constant<OutputIndex_>;
|
||||
using PropagateNan = bool_constant<PropagateNan_>;
|
||||
|
||||
static constexpr bool kOutputIndex = OutputIndex_;
|
||||
static constexpr bool kPropagateNan = PropagateNan_;
|
||||
static constexpr bool kNeedCrossLaneSync = BlockShape::ThreadPerWarp_N > 1;
|
||||
static constexpr bool kNeedCrossWarpSync = BlockShape::WarpPerBlock_N > 1;
|
||||
};
|
||||
|
||||
@@ -4,6 +4,7 @@
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/core/utility/reduce_operator_accumulate.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
@@ -50,6 +51,53 @@ struct BlockReduce2d
|
||||
|
||||
CK_TILE_DEVICE constexpr BlockReduce2d() {}
|
||||
|
||||
private:
|
||||
template <bool kProcessIndex,
|
||||
typename XDistributedTensor_,
|
||||
typename YDistributedTensor_,
|
||||
typename YIndexDistributedTensor_,
|
||||
typename ReduceFunc,
|
||||
typename IndexCalculatorFunc,
|
||||
typename ReducePacksPerXDim>
|
||||
CK_TILE_DEVICE void reduce_impl(const XDistributedTensor_& x_tensor,
|
||||
YDistributedTensor_& y_tensor,
|
||||
YIndexDistributedTensor_& y_index_tensor,
|
||||
const ReduceFunc& reduce_func,
|
||||
const IndexCalculatorFunc& index_calculator,
|
||||
ReducePacksPerXDim)
|
||||
{
|
||||
sweep_tile<XDistributedTensor_>(
|
||||
[&](auto... idx_) {
|
||||
constexpr auto idx_0 = make_tuple(make_tuple(idx_[number<0>{}]...)[number<0>{}]);
|
||||
|
||||
(..., [&](auto idx) {
|
||||
auto val = ck_tile::type_convert<ComputeDataType>(x_tensor[idx]);
|
||||
|
||||
if constexpr(kProcessIndex)
|
||||
{
|
||||
|
||||
const auto x_indices = get_x_indices_from_distributed_indices(
|
||||
XDistributedTensor_::get_tile_distribution(), idx);
|
||||
const auto new_idx = index_calculator(x_indices);
|
||||
auto current_idx = y_index_tensor(idx_0);
|
||||
|
||||
AccumulateWithIndex{}(
|
||||
reduce_func, y_tensor(idx_0), current_idx, val, new_idx);
|
||||
|
||||
y_index_tensor(idx_0) =
|
||||
type_convert<typename YIndexDistributedTensor_::DataType>(current_idx);
|
||||
}
|
||||
else
|
||||
{
|
||||
Accumulate{}(reduce_func, y_tensor(idx_0), val);
|
||||
}
|
||||
}(idx_));
|
||||
},
|
||||
ReducePacksPerXDim{});
|
||||
}
|
||||
|
||||
public:
|
||||
// Overload for non-index tracking
|
||||
template <
|
||||
typename XDistributedTensor_,
|
||||
typename YDistributedTensor_,
|
||||
@@ -61,13 +109,36 @@ struct BlockReduce2d
|
||||
const ReduceFunc& reduce_func,
|
||||
ReducePacksPerXDim = {})
|
||||
{
|
||||
sweep_tile<XDistributedTensor_>(
|
||||
[&](auto... idx_) {
|
||||
constexpr auto idx_0 = make_tuple(make_tuple(idx_[number<0>{}]...)[number<0>{}]);
|
||||
y_tensor(idx_0) = reduce_func(
|
||||
y_tensor(idx_0), ck_tile::type_convert<ComputeDataType>(x_tensor[idx_])...);
|
||||
},
|
||||
reduce_impl<false>(
|
||||
x_tensor,
|
||||
y_tensor,
|
||||
y_tensor, // dummy
|
||||
reduce_func,
|
||||
[](auto) { return 0; }, // dummy
|
||||
ReducePacksPerXDim{});
|
||||
}
|
||||
|
||||
// Overload for index tracking
|
||||
template <typename XDistributedTensor_,
|
||||
typename YDistributedTensor_,
|
||||
typename YIndexDistributedTensor_,
|
||||
typename ReduceFunc,
|
||||
typename IndexCalculatorFunc,
|
||||
typename ReducePacksPerXDim = uniform_sequence_gen_t<2, 1>>
|
||||
CK_TILE_DEVICE void operator()(const XDistributedTensor_& x_tensor,
|
||||
YDistributedTensor_& y_tensor,
|
||||
YIndexDistributedTensor_& y_index_tensor,
|
||||
const ReduceFunc& reduce_func,
|
||||
const IndexCalculatorFunc& index_calculator,
|
||||
ReducePacksPerXDim = {})
|
||||
{
|
||||
reduce_impl<Problem::kOutputIndex>(x_tensor,
|
||||
y_tensor,
|
||||
y_index_tensor,
|
||||
reduce_func,
|
||||
index_calculator,
|
||||
ReducePacksPerXDim{});
|
||||
}
|
||||
|
||||
#if 0
|
||||
constexpr auto I0 = number<0>{};
|
||||
@@ -90,7 +161,6 @@ struct BlockReduce2d
|
||||
y_tensor(y_dstr_idx) = y;
|
||||
});
|
||||
#endif
|
||||
}
|
||||
|
||||
template <typename XDistributedTensor_>
|
||||
CK_TILE_DEVICE static auto MakeYBlockTile()
|
||||
@@ -111,6 +181,25 @@ struct BlockReduce2d
|
||||
return tensor;
|
||||
}
|
||||
|
||||
template <typename XDistributedTensor_, typename IndexDataType = index_t>
|
||||
CK_TILE_DEVICE static auto MakeYIndexBlockTile()
|
||||
{
|
||||
static_assert(std::is_same_v<XDataType, typename XDistributedTensor_::DataType>, "wrong!");
|
||||
|
||||
// FIXME: hard coded to reduce 2nd axis
|
||||
constexpr auto reduce_dims = sequence<1>{};
|
||||
|
||||
constexpr auto dstr =
|
||||
make_static_tile_distribution(detail::make_reduce_tile_distribution_encoding(
|
||||
XDistributedTensor_::get_tile_distribution()
|
||||
.get_static_tile_distribution_encoding(),
|
||||
reduce_dims));
|
||||
|
||||
auto tensor = make_static_distributed_tensor<IndexDataType>(dstr);
|
||||
|
||||
return tensor;
|
||||
}
|
||||
|
||||
// uniform_sequence_gen_t<NSize, Value> generates sequence of NSize elements filled with Value
|
||||
// e.g., uniform_sequence_gen_t<2, 1> → {1, 1} and uniform_sequence_gen_t<3, 4> → {4, 4, 4}
|
||||
template <typename XDistributedTensor_,
|
||||
@@ -135,8 +224,14 @@ struct BlockReduce2dSync
|
||||
{
|
||||
using Problem = remove_cvref_t<Problem_>;
|
||||
|
||||
template <typename YDistributedTensor_, typename ReduceFunc>
|
||||
CK_TILE_DEVICE void operator()(YDistributedTensor_& y_tensor, const ReduceFunc& reduce_func)
|
||||
private:
|
||||
template <bool kProcessIndex,
|
||||
typename YDistributedTensor_,
|
||||
typename YIndexDistributedTensor_,
|
||||
typename ReduceFunc>
|
||||
CK_TILE_DEVICE void reduce_impl(YDistributedTensor_& y_tensor,
|
||||
YIndexDistributedTensor_& y_index_tensor,
|
||||
const ReduceFunc& reduce_func)
|
||||
{
|
||||
using Dstr = typename YDistributedTensor_::StaticTileDistribution;
|
||||
using DstrEncode = typename Dstr::DstrEncode;
|
||||
@@ -157,6 +252,14 @@ struct BlockReduce2dSync
|
||||
static_for<0, thread_buf_size, 1>{}([&](auto i) {
|
||||
auto v_local = y_tensor.get_thread_buffer()[i];
|
||||
|
||||
using IndexDataType = typename YIndexDistributedTensor_::DataType;
|
||||
IndexDataType idx_local{};
|
||||
|
||||
if constexpr(kProcessIndex)
|
||||
{
|
||||
idx_local = y_index_tensor.get_thread_buffer()[i];
|
||||
}
|
||||
|
||||
// cross-lane reduce for replication
|
||||
// only reduce on R dimension correspond to lane
|
||||
// (lane id maps to this R dimension)
|
||||
@@ -183,15 +286,46 @@ struct BlockReduce2dSync
|
||||
|
||||
// pull data from remote lane
|
||||
const auto v_remote = warp_shuffle(v_local, src_lane);
|
||||
v_local = reduce_func(v_local, v_remote);
|
||||
|
||||
if constexpr(kProcessIndex)
|
||||
{
|
||||
const auto idx_remote = warp_shuffle(idx_local, src_lane);
|
||||
|
||||
AccumulateWithIndex{}(
|
||||
reduce_func, v_local, idx_local, v_remote, idx_remote);
|
||||
}
|
||||
else
|
||||
{
|
||||
Accumulate{}(reduce_func, v_local, v_remote);
|
||||
}
|
||||
});
|
||||
}
|
||||
});
|
||||
|
||||
// TODO - Do we need to broadcast to other lane?
|
||||
y_tensor.get_thread_buffer()(i) = v_local;
|
||||
|
||||
if constexpr(kProcessIndex)
|
||||
{
|
||||
y_index_tensor.get_thread_buffer()(i) = idx_local;
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
public:
|
||||
template <typename YDistributedTensor_, typename ReduceFunc>
|
||||
CK_TILE_DEVICE void operator()(YDistributedTensor_& y_tensor, const ReduceFunc& reduce_func)
|
||||
{
|
||||
reduce_impl<false>(y_tensor, y_tensor, reduce_func);
|
||||
}
|
||||
|
||||
template <typename YDistributedTensor_, typename YIndexDistributedTensor_, typename ReduceFunc>
|
||||
CK_TILE_DEVICE void operator()(YDistributedTensor_& y_tensor,
|
||||
YIndexDistributedTensor_& y_index_tensor,
|
||||
const ReduceFunc& reduce_func)
|
||||
{
|
||||
reduce_impl<Problem::kOutputIndex>(y_tensor, y_index_tensor, reduce_func);
|
||||
}
|
||||
};
|
||||
|
||||
// BlockReduce2dCrossWarpSync: Cross-warp reduction (Stage 3)
|
||||
@@ -250,15 +384,39 @@ struct BlockReduce2dCrossWarpSync
|
||||
return num_warps * thread_buf_size * sizeof(DataType);
|
||||
}
|
||||
|
||||
template <typename YDistributedTensor_, typename ReduceFunc>
|
||||
CK_TILE_DEVICE void
|
||||
operator()(YDistributedTensor_& y_tensor, void* smem, const ReduceFunc& reduce_func)
|
||||
// return in byte - separate shared memory size calculation for indices
|
||||
template <typename YIndexDistributedTensor_>
|
||||
CK_TILE_HOST_DEVICE static constexpr index_t GetIndicesSmemSize()
|
||||
{
|
||||
using DataType = typename YDistributedTensor_::DataType;
|
||||
using IndexDataType = typename YIndexDistributedTensor_::DataType;
|
||||
constexpr index_t thread_buf_size = YIndexDistributedTensor_::get_thread_buffer_size();
|
||||
constexpr index_t num_warps = BlockShape::BlockSize / get_warp_size();
|
||||
return num_warps * thread_buf_size * sizeof(IndexDataType);
|
||||
}
|
||||
|
||||
private:
|
||||
template <bool kProcessIndex,
|
||||
typename YDistributedTensor_,
|
||||
typename YIndexDistributedTensor_,
|
||||
typename ReduceFunc>
|
||||
CK_TILE_DEVICE void reduce_impl(YDistributedTensor_& y_tensor,
|
||||
YIndexDistributedTensor_& y_index_tensor,
|
||||
void* smem,
|
||||
void* smem_indices_ptr,
|
||||
const ReduceFunc& reduce_func)
|
||||
{
|
||||
using DataType = typename YDistributedTensor_::DataType;
|
||||
using IndexDataType = typename YIndexDistributedTensor_::DataType;
|
||||
|
||||
constexpr index_t thread_buf_size = YDistributedTensor_::get_thread_buffer_size();
|
||||
|
||||
DataType* smem_ptr = reinterpret_cast<DataType*>(smem);
|
||||
DataType* smem_ptr = reinterpret_cast<DataType*>(smem);
|
||||
IndexDataType* smem_indices = nullptr;
|
||||
if constexpr(kProcessIndex)
|
||||
{
|
||||
smem_indices = reinterpret_cast<IndexDataType*>(smem_indices_ptr);
|
||||
}
|
||||
|
||||
const index_t lane_id = get_lane_id();
|
||||
const index_t warp_id = get_warp_id();
|
||||
|
||||
@@ -275,6 +433,11 @@ struct BlockReduce2dCrossWarpSync
|
||||
static_for<0, thread_buf_size, 1>{}([&](auto i) {
|
||||
// Store the i-th element of this warp's thread_buffer into SMEM
|
||||
smem_ptr[smem_offset + i * num_warps] = y_tensor.get_thread_buffer()[i];
|
||||
if constexpr(kProcessIndex)
|
||||
{
|
||||
smem_indices[smem_offset + i * num_warps] =
|
||||
y_index_tensor.get_thread_buffer()[i];
|
||||
}
|
||||
});
|
||||
}
|
||||
block_sync_lds();
|
||||
@@ -282,10 +445,19 @@ struct BlockReduce2dCrossWarpSync
|
||||
// We let each warp holds a duplication to do reduction.
|
||||
const index_t local_warp_id = warp_id / num_reduce_warps;
|
||||
const index_t local_smem_os = local_warp_id * num_reduce_warps;
|
||||
|
||||
static_for<0, thread_buf_size, 1>{}([&](auto i) {
|
||||
DataType v[num_reduce_warps];
|
||||
static_for<0, num_reduce_warps, 1>{}(
|
||||
[&](auto idx) { v[idx] = smem_ptr[i * num_warps + local_smem_os + idx]; });
|
||||
[[maybe_unused]] std::
|
||||
conditional_t<kProcessIndex, IndexDataType[num_reduce_warps], IndexDataType> idx_v;
|
||||
|
||||
static_for<0, num_reduce_warps, 1>{}([&](auto idx) {
|
||||
v[idx] = smem_ptr[i * num_warps + local_smem_os + idx];
|
||||
if constexpr(kProcessIndex)
|
||||
{
|
||||
idx_v[idx] = smem_indices[i * num_warps + local_smem_os + idx];
|
||||
}
|
||||
});
|
||||
|
||||
static_assert(is_power_of_two_integer(num_reduce_warps),
|
||||
"wrong! only support power of 2 reduction");
|
||||
@@ -299,14 +471,44 @@ struct BlockReduce2dCrossWarpSync
|
||||
constexpr index_t i1 = idx_ + stride;
|
||||
if constexpr(i1 < num_reduce_warps)
|
||||
{
|
||||
v[i0] = reduce_func(v[i0], v[i1]);
|
||||
if constexpr(kProcessIndex)
|
||||
{
|
||||
AccumulateWithIndex{}(reduce_func, v[i0], idx_v[i0], v[i1], idx_v[i1]);
|
||||
}
|
||||
else
|
||||
{
|
||||
Accumulate{}(reduce_func, v[i0], v[i1]);
|
||||
}
|
||||
}
|
||||
});
|
||||
});
|
||||
|
||||
y_tensor.get_thread_buffer()(i) = v[0];
|
||||
if constexpr(kProcessIndex)
|
||||
{
|
||||
y_index_tensor.get_thread_buffer()(i) = idx_v[0];
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
public:
|
||||
template <typename YDistributedTensor_, typename ReduceFunc>
|
||||
CK_TILE_DEVICE void
|
||||
operator()(YDistributedTensor_& y_tensor, void* smem, const ReduceFunc& reduce_func)
|
||||
{
|
||||
reduce_impl<false>(y_tensor, y_tensor, smem, nullptr, reduce_func);
|
||||
}
|
||||
|
||||
template <typename YDistributedTensor_, typename YIndexDistributedTensor_, typename ReduceFunc>
|
||||
CK_TILE_DEVICE void operator()(YDistributedTensor_& y_tensor,
|
||||
YIndexDistributedTensor_& y_index_tensor,
|
||||
void* smem,
|
||||
void* smem_indices,
|
||||
const ReduceFunc& reduce_func)
|
||||
{
|
||||
reduce_impl<Problem::kOutputIndex>(
|
||||
y_tensor, y_index_tensor, smem, smem_indices, reduce_func);
|
||||
}
|
||||
};
|
||||
|
||||
template <typename Problem_, typename Policy_ = void>
|
||||
@@ -364,15 +566,39 @@ struct BlockReduce2dLinearCrossWarpSync
|
||||
return num_warps * thread_buf_size * sizeof(DataType);
|
||||
}
|
||||
|
||||
template <typename YDistributedTensor_, typename ReduceFunc>
|
||||
CK_TILE_DEVICE void
|
||||
operator()(YDistributedTensor_& y_tensor, void* smem, const ReduceFunc& reduce_func)
|
||||
// return in byte - separate shared memory size calculation for indices
|
||||
template <typename YIndexDistributedTensor_>
|
||||
CK_TILE_HOST_DEVICE static constexpr index_t GetIndicesSmemSize()
|
||||
{
|
||||
using DataType = typename YDistributedTensor_::DataType;
|
||||
using IndexDataType = typename YIndexDistributedTensor_::DataType;
|
||||
constexpr index_t thread_buf_size = YIndexDistributedTensor_::get_thread_buffer_size();
|
||||
constexpr index_t num_warps = BlockShape::BlockSize / get_warp_size();
|
||||
return num_warps * thread_buf_size * sizeof(IndexDataType);
|
||||
}
|
||||
|
||||
private:
|
||||
template <bool kProcessIndex,
|
||||
typename YDistributedTensor_,
|
||||
typename YIndexDistributedTensor_,
|
||||
typename ReduceFunc>
|
||||
CK_TILE_DEVICE void reduce_impl(YDistributedTensor_& y_tensor,
|
||||
YIndexDistributedTensor_& y_index_tensor,
|
||||
void* smem,
|
||||
void* smem_indices_ptr,
|
||||
const ReduceFunc& reduce_func)
|
||||
{
|
||||
using DataType = typename YDistributedTensor_::DataType;
|
||||
using IndexDataType = typename YIndexDistributedTensor_::DataType;
|
||||
|
||||
constexpr index_t thread_buf_size = YDistributedTensor_::get_thread_buffer_size();
|
||||
|
||||
DataType* smem_ptr = reinterpret_cast<DataType*>(smem);
|
||||
DataType* smem_ptr = reinterpret_cast<DataType*>(smem);
|
||||
IndexDataType* smem_indices = nullptr;
|
||||
if constexpr(kProcessIndex)
|
||||
{
|
||||
smem_indices = reinterpret_cast<IndexDataType*>(smem_indices_ptr);
|
||||
}
|
||||
|
||||
const index_t lane_id = get_lane_id();
|
||||
const index_t warp_id = get_warp_id();
|
||||
constexpr auto num_reduce_warps = GetReduceWarps<YDistributedTensor_>();
|
||||
@@ -388,6 +614,11 @@ struct BlockReduce2dLinearCrossWarpSync
|
||||
{
|
||||
static_for<0, thread_buf_size, 1>{}([&](auto i) {
|
||||
smem_ptr[smem_offset + i * num_warps] = y_tensor.get_thread_buffer()[i];
|
||||
if constexpr(kProcessIndex)
|
||||
{
|
||||
smem_indices[smem_offset + i * num_warps] =
|
||||
y_index_tensor.get_thread_buffer()[i];
|
||||
}
|
||||
});
|
||||
}
|
||||
block_sync_lds();
|
||||
@@ -395,31 +626,86 @@ struct BlockReduce2dLinearCrossWarpSync
|
||||
// load from smem. here we let everythread to do compute :)
|
||||
index_t local_warp_id = warp_id / num_reduce_warps;
|
||||
index_t local_smem_os = local_warp_id * num_reduce_warps;
|
||||
|
||||
DataType all_scratch[thread_buf_size * num_reduce_warps];
|
||||
[[maybe_unused]] std::conditional_t<kProcessIndex,
|
||||
IndexDataType[thread_buf_size * num_reduce_warps],
|
||||
IndexDataType> all_indices;
|
||||
|
||||
// Load data from shared memory
|
||||
static_for<0, thread_buf_size, 1>{}([&](auto i_0) {
|
||||
static_for<0, num_reduce_warps, 1>{}([&](auto i_1) {
|
||||
all_scratch[i_0 * num_reduce_warps + i_1] =
|
||||
smem_ptr[i_0 * num_warps + local_smem_os + i_1];
|
||||
|
||||
if constexpr(kProcessIndex)
|
||||
{
|
||||
all_indices[i_0 * num_reduce_warps + i_1] =
|
||||
smem_indices[i_0 * num_warps + local_smem_os + i_1];
|
||||
}
|
||||
});
|
||||
});
|
||||
block_sync_lds(); // TODO: we don't need sync here
|
||||
|
||||
// Perform reduction
|
||||
static_for<0, thread_buf_size, 1>{}([&](auto i_0) {
|
||||
// TODO: use descriptor for this
|
||||
auto v_local = all_scratch[i_0 * num_reduce_warps];
|
||||
|
||||
IndexDataType idx_local{};
|
||||
if constexpr(kProcessIndex)
|
||||
{
|
||||
idx_local = all_indices[i_0 * num_reduce_warps];
|
||||
}
|
||||
|
||||
// further reduce mean/var
|
||||
static_for<0, num_reduce_warps - 1, 1>{}([&](auto i_1_n1) {
|
||||
constexpr auto i_1 = number<i_1_n1 + 1>{};
|
||||
const DataType v_remote = all_scratch[i_0 * num_reduce_warps + i_1];
|
||||
|
||||
// reduce
|
||||
v_local = reduce_func(v_local, v_remote);
|
||||
if constexpr(kProcessIndex)
|
||||
{
|
||||
const IndexDataType idx_remote = all_indices[i_0 * num_reduce_warps + i_1];
|
||||
|
||||
bool changed = false;
|
||||
v_local = reduce_func(v_local, v_remote, changed);
|
||||
if(changed)
|
||||
{
|
||||
idx_local = idx_remote;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
v_local = reduce_func(v_local, v_remote);
|
||||
}
|
||||
});
|
||||
|
||||
y_tensor.get_thread_buffer()(i_0) = v_local;
|
||||
if constexpr(kProcessIndex)
|
||||
{
|
||||
y_index_tensor.get_thread_buffer()(i_0) = idx_local;
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
public:
|
||||
template <typename YDistributedTensor_, typename ReduceFunc>
|
||||
CK_TILE_DEVICE void
|
||||
operator()(YDistributedTensor_& y_tensor, void* smem, const ReduceFunc& reduce_func)
|
||||
{
|
||||
reduce_impl<false>(y_tensor, y_tensor, smem, nullptr, reduce_func);
|
||||
}
|
||||
|
||||
template <typename YDistributedTensor_, typename YIndexDistributedTensor_, typename ReduceFunc>
|
||||
CK_TILE_DEVICE void operator()(YDistributedTensor_& y_tensor,
|
||||
YIndexDistributedTensor_& y_index_tensor,
|
||||
void* smem,
|
||||
void* smem_indices,
|
||||
const ReduceFunc& reduce_func)
|
||||
{
|
||||
reduce_impl<Problem::kOutputIndex>(
|
||||
y_tensor, y_index_tensor, smem, smem_indices, reduce_func);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -7,12 +7,17 @@
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <typename XDataType_, typename ComputeDataType_, typename BlockShape_>
|
||||
template <typename XDataType_,
|
||||
typename ComputeDataType_,
|
||||
typename BlockShape_,
|
||||
bool OutputIndex_ = false>
|
||||
struct BlockReduce2dProblem
|
||||
{
|
||||
using XDataType = remove_cvref_t<XDataType_>;
|
||||
using ComputeDataType = remove_cvref_t<ComputeDataType_>;
|
||||
using BlockShape = remove_cvref_t<BlockShape_>;
|
||||
|
||||
static constexpr bool kOutputIndex = OutputIndex_;
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -32,7 +32,8 @@ struct Reduce2dDefaultPolicy
|
||||
{
|
||||
using P_ = BlockReduce2dProblem<typename Problem::XDataType,
|
||||
typename Problem::ComputeDataType,
|
||||
typename Problem::BlockShape>;
|
||||
typename Problem::BlockShape,
|
||||
Problem::kOutputIndex>;
|
||||
return BlockReduce2d<P_>{};
|
||||
}
|
||||
|
||||
@@ -41,7 +42,8 @@ struct Reduce2dDefaultPolicy
|
||||
{
|
||||
using P_ = BlockReduce2dProblem<typename Problem::XDataType,
|
||||
typename Problem::ComputeDataType,
|
||||
typename Problem::BlockShape>;
|
||||
typename Problem::BlockShape,
|
||||
Problem::kOutputIndex>;
|
||||
return BlockReduce2dSync<P_>{};
|
||||
}
|
||||
|
||||
@@ -50,7 +52,8 @@ struct Reduce2dDefaultPolicy
|
||||
{
|
||||
using P_ = BlockReduce2dProblem<typename Problem::XDataType,
|
||||
typename Problem::ComputeDataType,
|
||||
typename Problem::BlockShape>;
|
||||
typename Problem::BlockShape,
|
||||
Problem::kOutputIndex>;
|
||||
return BlockReduce2dCrossWarpSync<P_>{};
|
||||
}
|
||||
|
||||
@@ -61,7 +64,8 @@ struct Reduce2dDefaultPolicy
|
||||
{
|
||||
using P_ = BlockReduce2dProblem<typename Problem::XDataType,
|
||||
typename Problem::ComputeDataType,
|
||||
typename Problem::BlockShape>;
|
||||
typename Problem::BlockShape,
|
||||
Problem::kOutputIndex>;
|
||||
|
||||
using block_reduce2d = BlockReduce2d<P_>;
|
||||
using x_block_tile =
|
||||
@@ -76,5 +80,23 @@ struct Reduce2dDefaultPolicy
|
||||
return 1; // zero size arrays are an extension
|
||||
}
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr index_t GetIndicesSmemSize()
|
||||
{
|
||||
using P_ = BlockReduce2dProblem<typename Problem::XDataType,
|
||||
typename Problem::ComputeDataType,
|
||||
typename Problem::BlockShape,
|
||||
Problem::kOutputIndex>;
|
||||
|
||||
using block_reduce2d = BlockReduce2d<P_>;
|
||||
using x_block_tile = decltype(make_static_distributed_tensor<typename Problem::XDataType>(
|
||||
MakeXBlockTileDistribution<Problem>()));
|
||||
using y_index_block_tile =
|
||||
decltype(block_reduce2d::template MakeYIndexBlockTile<x_block_tile, index_t>());
|
||||
|
||||
return GetBlockReduce2dCrossWarpSync<Problem>()
|
||||
.template GetIndicesSmemSize<y_index_block_tile>();
|
||||
}
|
||||
};
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -11,7 +11,8 @@ template <typename XDataType_,
|
||||
typename ComputeDataType_,
|
||||
typename YDataType_,
|
||||
typename BlockShape_,
|
||||
typename ReduceOp_>
|
||||
typename ReduceOp_,
|
||||
bool OutputIndex_ = false>
|
||||
struct Reduce2dProblem
|
||||
{
|
||||
using XDataType = remove_cvref_t<XDataType_>;
|
||||
@@ -20,6 +21,7 @@ struct Reduce2dProblem
|
||||
using BlockShape = remove_cvref_t<BlockShape_>;
|
||||
using ReduceOp = ReduceOp_;
|
||||
|
||||
static constexpr bool kOutputIndex = OutputIndex_;
|
||||
static constexpr bool kNeedCrossLaneSync = BlockShape::ThreadPerWarp_N > 1;
|
||||
static constexpr bool kNeedCrossWarpSync = BlockShape::WarpPerBlock_N > 1;
|
||||
};
|
||||
|
||||
@@ -28,7 +28,19 @@ class TestCkTilePooling : public ::testing::Test
|
||||
|
||||
using TestPoolShape = ck_tile::PoolShape<BlockWarps_, BlockTile_, WarpTile_, ThreadTile_>;
|
||||
|
||||
// 3D pooling configuration
|
||||
// 2D pooling configuration (NHWC)
|
||||
struct Config2D
|
||||
{
|
||||
ck_tile::index_t N, H, W, C;
|
||||
ck_tile::index_t Y, X;
|
||||
ck_tile::index_t Sy, Sx;
|
||||
ck_tile::index_t Dy, Dx;
|
||||
ck_tile::index_t LeftPy, LeftPx;
|
||||
ck_tile::index_t RightPy, RightPx;
|
||||
std::string name;
|
||||
};
|
||||
|
||||
// 3D pooling configuration (NDHWC)
|
||||
struct Config3D
|
||||
{
|
||||
ck_tile::index_t N, D, H, W, C;
|
||||
@@ -40,6 +52,117 @@ class TestCkTilePooling : public ::testing::Test
|
||||
std::string name;
|
||||
};
|
||||
|
||||
bool RunPool2D(const Config2D& config)
|
||||
{
|
||||
std::cout << "Testing 2D: " << config.name << " ... ";
|
||||
|
||||
const ck_tile::index_t Ys = (config.Y - 1) * config.Dy + 1;
|
||||
const ck_tile::index_t Xs = (config.X - 1) * config.Dx + 1;
|
||||
const ck_tile::index_t Ho =
|
||||
(config.H + config.LeftPy + config.RightPy - Ys) / config.Sy + 1;
|
||||
const ck_tile::index_t Wo =
|
||||
(config.W + config.LeftPx + config.RightPx - Xs) / config.Sx + 1;
|
||||
|
||||
using IndexDataType = ck_tile::index_t;
|
||||
|
||||
// Host tensors
|
||||
ck_tile::HostTensor<InDataType> h_in({config.N, config.H, config.W, config.C});
|
||||
ck_tile::HostTensor<OutDataType> h_out({config.N, Ho, Wo, config.C});
|
||||
ck_tile::HostTensor<OutDataType> h_out_ref({config.N, Ho, Wo, config.C});
|
||||
ck_tile::HostTensor<IndexDataType> h_out_index({config.N, Ho, Wo, config.C});
|
||||
ck_tile::HostTensor<IndexDataType> h_out_ref_index({config.N, Ho, Wo, config.C});
|
||||
|
||||
// Initialize input with random data
|
||||
ck_tile::FillUniformDistribution<InDataType>{-5.f, 5.f}(h_in);
|
||||
|
||||
// Device memory
|
||||
ck_tile::DeviceMem d_in_mem(h_in.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem d_out_mem(h_out.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem d_out_index_mem(h_out_index.get_element_space_size_in_bytes());
|
||||
|
||||
d_in_mem.ToDevice(h_in.data());
|
||||
d_out_mem.ToDevice(h_out.data());
|
||||
d_out_index_mem.ToDevice(h_out_index.data());
|
||||
|
||||
constexpr ck_tile::index_t kBlockPerCu = 1;
|
||||
|
||||
using Problem = ck_tile::PoolProblem<InDataType,
|
||||
OutDataType,
|
||||
ComputeDataType,
|
||||
IndexDataType,
|
||||
ReduceOpType,
|
||||
true, // OutputIndex
|
||||
false, // PropagateNan
|
||||
TestPoolShape>;
|
||||
using Kernel = ck_tile::PoolKernel<Problem>;
|
||||
|
||||
const ck_tile::index_t kBlockSize = Kernel::BlockSize();
|
||||
|
||||
// Shapes and strides (NHWC)
|
||||
const auto input_shape = ck_tile::make_tuple(config.N, config.H, config.W, config.C);
|
||||
const auto output_shape = ck_tile::make_tuple(config.N, Ho, Wo, config.C);
|
||||
const auto input_strides =
|
||||
ck_tile::make_tuple(config.H * config.W * config.C, config.W * config.C, config.C, 1);
|
||||
const auto output_strides =
|
||||
ck_tile::make_tuple(Ho * Wo * config.C, Wo * config.C, config.C, 1);
|
||||
const auto window_spatial_lengths = ck_tile::make_tuple(config.Y, config.X);
|
||||
const auto window_strides = ck_tile::make_tuple(config.Sy, config.Sx);
|
||||
const auto window_dilations = ck_tile::make_tuple(config.Dy, config.Dx);
|
||||
const auto input_left_pads = ck_tile::make_tuple(config.LeftPy, config.LeftPx);
|
||||
const auto input_right_pads = ck_tile::make_tuple(config.RightPy, config.RightPx);
|
||||
|
||||
auto host_args =
|
||||
ck_tile::PoolHostArgs<decltype(input_shape), decltype(window_spatial_lengths)>{
|
||||
static_cast<InDataType*>(d_in_mem.GetDeviceBuffer()),
|
||||
static_cast<OutDataType*>(d_out_mem.GetDeviceBuffer()),
|
||||
static_cast<IndexDataType*>(d_out_index_mem.GetDeviceBuffer()),
|
||||
input_shape,
|
||||
output_shape,
|
||||
input_strides,
|
||||
output_strides,
|
||||
window_spatial_lengths,
|
||||
window_strides,
|
||||
window_dilations,
|
||||
input_left_pads,
|
||||
input_right_pads};
|
||||
|
||||
auto kernel_args = Kernel::MakeKernelArgs(host_args);
|
||||
const ck_tile::index_t kGridSize = Kernel::CalculateGridSize(kernel_args);
|
||||
|
||||
if(!Kernel::IsSupportedArgument(kernel_args))
|
||||
{
|
||||
return true;
|
||||
}
|
||||
|
||||
// Run kernel
|
||||
ck_tile::launch_kernel(
|
||||
ck_tile::stream_config{nullptr, false, 0},
|
||||
ck_tile::make_kernel<kBlockPerCu>(Kernel{}, kGridSize, kBlockSize, 0, kernel_args));
|
||||
|
||||
// Run reference
|
||||
ck_tile::reference_pool2d<InDataType,
|
||||
ComputeDataType,
|
||||
OutDataType,
|
||||
IndexDataType,
|
||||
ReduceOpType,
|
||||
decltype(input_shape),
|
||||
decltype(window_spatial_lengths),
|
||||
true>(
|
||||
h_in, h_out_ref, h_out_ref_index, kernel_args, ReduceOpType{});
|
||||
|
||||
d_out_mem.FromDevice(h_out.data());
|
||||
d_out_index_mem.FromDevice(h_out_index.data());
|
||||
|
||||
// Validate results
|
||||
bool pass_value =
|
||||
ck_tile::check_err(h_out, h_out_ref, "Error: Incorrect values!", 1e-5, 1e-5);
|
||||
bool pass_index = ck_tile::check_err(
|
||||
h_out_index, h_out_ref_index, "Error: Incorrect indices!", 1e-5, 1e-5);
|
||||
|
||||
std::cout << (pass_value && pass_index ? "PASS" : "FAIL") << std::endl;
|
||||
return pass_value && pass_index;
|
||||
}
|
||||
|
||||
bool RunPool3D(const Config3D& config)
|
||||
{
|
||||
std::cout << "Testing 3D: " << config.name << " ... ";
|
||||
@@ -72,6 +195,8 @@ class TestCkTilePooling : public ::testing::Test
|
||||
const auto input_right_pads =
|
||||
ck_tile::make_tuple(config.RightPz, config.RightPy, config.RightPx);
|
||||
|
||||
using IndexDataType = ck_tile::index_t;
|
||||
|
||||
ck_tile::HostTensor<InDataType> h_in({config.N, config.D, config.H, config.W, config.C},
|
||||
{config.D * config.H * config.W * config.C,
|
||||
config.H * config.W * config.C,
|
||||
@@ -84,6 +209,12 @@ class TestCkTilePooling : public ::testing::Test
|
||||
ck_tile::HostTensor<OutDataType> h_out_ref(
|
||||
{config.N, Do, Ho, Wo, config.C},
|
||||
{Do * Ho * Wo * config.C, Ho * Wo * config.C, Wo * config.C, config.C, 1});
|
||||
ck_tile::HostTensor<IndexDataType> h_out_index(
|
||||
{config.N, Do, Ho, Wo, config.C},
|
||||
{Do * Ho * Wo * config.C, Ho * Wo * config.C, Wo * config.C, config.C, 1});
|
||||
ck_tile::HostTensor<IndexDataType> h_out_ref_index(
|
||||
{config.N, Do, Ho, Wo, config.C},
|
||||
{Do * Ho * Wo * config.C, Ho * Wo * config.C, Wo * config.C, config.C, 1});
|
||||
|
||||
ck_tile::FillUniformDistribution<InDataType>{-5.f, 5.f}(h_in);
|
||||
h_out.SetZero();
|
||||
@@ -91,17 +222,19 @@ class TestCkTilePooling : public ::testing::Test
|
||||
|
||||
ck_tile::DeviceMem d_in_mem(h_in.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem d_out_mem(h_out.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem d_out_index_mem(h_out_index.get_element_space_size_in_bytes());
|
||||
|
||||
d_in_mem.ToDevice(h_in.data());
|
||||
d_out_mem.ToDevice(h_out.data());
|
||||
d_out_index_mem.ToDevice(h_out_index.data());
|
||||
|
||||
using Problem = ck_tile::PoolProblem<InDataType,
|
||||
OutDataType,
|
||||
ComputeDataType,
|
||||
OutDataType,
|
||||
IndexDataType,
|
||||
ReduceOpType,
|
||||
false,
|
||||
false,
|
||||
true, // OutputIndex
|
||||
false, // PropagateNan
|
||||
TestPoolShape>;
|
||||
using Kernel = ck_tile::PoolKernel<Problem>;
|
||||
|
||||
@@ -112,6 +245,7 @@ class TestCkTilePooling : public ::testing::Test
|
||||
ck_tile::PoolHostArgs<decltype(input_shape), decltype(window_spatial_lengths)>{
|
||||
static_cast<InDataType*>(d_in_mem.GetDeviceBuffer()),
|
||||
static_cast<OutDataType*>(d_out_mem.GetDeviceBuffer()),
|
||||
static_cast<IndexDataType*>(d_out_index_mem.GetDeviceBuffer()),
|
||||
input_shape,
|
||||
output_shape,
|
||||
input_strides,
|
||||
@@ -137,16 +271,27 @@ class TestCkTilePooling : public ::testing::Test
|
||||
ck_tile::make_kernel<kBlockPerCu>(Kernel{}, kGridSize, kBlockSize, 0, kernel_args));
|
||||
|
||||
// Run reference implementation
|
||||
ck_tile::reference_pool3d<InDataType, ComputeDataType, OutDataType>(
|
||||
h_in, h_out_ref, kernel_args, ReduceOpType{});
|
||||
ck_tile::reference_pool3d<InDataType,
|
||||
ComputeDataType,
|
||||
OutDataType,
|
||||
IndexDataType,
|
||||
ReduceOpType,
|
||||
decltype(input_shape),
|
||||
decltype(window_spatial_lengths),
|
||||
true>(
|
||||
h_in, h_out_ref, h_out_ref_index, kernel_args, ReduceOpType{});
|
||||
|
||||
d_out_mem.FromDevice(h_out.data());
|
||||
d_out_index_mem.FromDevice(h_out_index.data());
|
||||
|
||||
// Validate results
|
||||
bool pass = ck_tile::check_err(h_out, h_out_ref);
|
||||
std::cout << (pass ? "PASS" : "FAIL") << std::endl;
|
||||
bool pass_value =
|
||||
ck_tile::check_err(h_out, h_out_ref, "Error: Incorrect values!", 1e-5, 1e-5);
|
||||
bool pass_index = ck_tile::check_err(
|
||||
h_out_index, h_out_ref_index, "Error: Incorrect indices!", 1e-5, 1e-5);
|
||||
|
||||
return pass;
|
||||
std::cout << (pass_value && pass_index ? "PASS" : "FAIL") << std::endl;
|
||||
return pass_value && pass_index;
|
||||
}
|
||||
};
|
||||
|
||||
@@ -194,6 +339,50 @@ using TestTypes =
|
||||
|
||||
TYPED_TEST_SUITE(TestCkTilePooling, TestTypes);
|
||||
|
||||
// 2D Pooling Tests (NHWC)
|
||||
TYPED_TEST(TestCkTilePooling, Pool2D_2x2)
|
||||
{
|
||||
typename TestFixture::Config2D config = {1, // N - batch size
|
||||
8, // H - height dimension
|
||||
8, // W - width dimension
|
||||
32, // C - channel dimension
|
||||
2, // Y - pooling window height
|
||||
2, // X - pooling window width
|
||||
2, // Sy - window stride height
|
||||
2, // Sx - window stride width
|
||||
1, // Dy - window dilation height
|
||||
1, // Dx - window dilation width
|
||||
0, // LeftPy - left padding height
|
||||
0, // LeftPx - left padding width
|
||||
0, // RightPy - right padding height
|
||||
0, // RightPx - right padding width
|
||||
"2x2 pooling NHWC"};
|
||||
bool pass = this->RunPool2D(config);
|
||||
EXPECT_TRUE(pass);
|
||||
}
|
||||
|
||||
TYPED_TEST(TestCkTilePooling, Pool2D_3x3_WithPadding)
|
||||
{
|
||||
typename TestFixture::Config2D config = {2, // N - batch size
|
||||
16, // H - height dimension
|
||||
16, // W - width dimension
|
||||
32, // C - channel dimension
|
||||
3, // Y - pooling window height
|
||||
3, // X - pooling window width
|
||||
2, // Sy - window stride height
|
||||
2, // Sx - window stride width
|
||||
1, // Dy - window dilation height
|
||||
1, // Dx - window dilation width
|
||||
1, // LeftPy - left padding height
|
||||
1, // LeftPx - left padding width
|
||||
1, // RightPy - right padding height
|
||||
1, // RightPx - right padding width
|
||||
"3x3 pooling NHWC with padding"};
|
||||
bool pass = this->RunPool2D(config);
|
||||
EXPECT_TRUE(pass);
|
||||
}
|
||||
|
||||
// 3D Pooling Tests (NDHWC)
|
||||
TYPED_TEST(TestCkTilePooling, Pool3D_2x2x2)
|
||||
{
|
||||
typename TestFixture::Config3D config = {1, // N - batch size
|
||||
@@ -216,7 +405,7 @@ TYPED_TEST(TestCkTilePooling, Pool3D_2x2x2)
|
||||
0, // RightPz - right padding depth
|
||||
0, // RightPy - right padding height
|
||||
0, // RightPx - right padding width
|
||||
"2x2x2 pooling"};
|
||||
"2x2x2 pooling NDHWC"};
|
||||
bool pass = this->RunPool3D(config);
|
||||
EXPECT_TRUE(pass);
|
||||
}
|
||||
@@ -243,7 +432,7 @@ TYPED_TEST(TestCkTilePooling, Pool3D_3x3x3)
|
||||
1, // RightPz - right padding depth
|
||||
1, // RightPy - right padding height
|
||||
1, // RightPx - right padding width
|
||||
"3x3x3 pooling"};
|
||||
"3x3x3 pooling NDHWC with padding"};
|
||||
bool pass = this->RunPool3D(config);
|
||||
EXPECT_TRUE(pass);
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user