Merge commit '3052d7c9e6972d5ea7d2225ab78e45554ba70efd' into develop

This commit is contained in:
assistant-librarian[bot]
2025-10-29 08:15:15 +00:00
parent e571490afc
commit 6f6c855c0e
13 changed files with 860 additions and 99 deletions

View File

@@ -38,7 +38,10 @@ auto create_args(int argc, char* argv[])
return std::make_tuple(result, arg_parser);
}
template <typename InDataType, typename OutDataType, typename ComputeDataType>
template <typename InDataType,
typename OutDataType,
typename ComputeDataType,
typename IndexDataType>
bool run(const ck_tile::ArgParser& arg_parser)
{
@@ -84,6 +87,9 @@ bool run(const ck_tile::ArgParser& arg_parser)
int warmup = arg_parser.get_int("warmup");
int repeat = arg_parser.get_int("repeat");
constexpr bool OutputIndex = true;
constexpr bool PropagateNan = false;
// Shapes / strides / parameters (NDHWC)
const auto input_shape = ck_tile::make_tuple(N, D, H, W, C);
const auto output_shape = ck_tile::make_tuple(N, Do, Ho, Wo, C);
@@ -100,11 +106,16 @@ bool run(const ck_tile::ArgParser& arg_parser)
{Do * Ho * Wo * C, Ho * Wo * C, Wo * C, C, 1});
ck_tile::HostTensor<OutDataType> out_ref({N, Do, Ho, Wo, C},
{Do * Ho * Wo * C, Ho * Wo * C, Wo * C, C, 1});
ck_tile::HostTensor<IndexDataType> out_index({N, Do, Ho, Wo, C},
{Do * Ho * Wo * C, Ho * Wo * C, Wo * C, C, 1});
ck_tile::HostTensor<IndexDataType> out_ref_index({N, Do, Ho, Wo, C},
{Do * Ho * Wo * C, Ho * Wo * C, Wo * C, C, 1});
ck_tile::FillUniformDistribution<InDataType>{-5.f, 5.f}(in);
ck_tile::DeviceMem in_buf(in.get_element_space_size_in_bytes());
ck_tile::DeviceMem out_buf(out.get_element_space_size_in_bytes());
ck_tile::DeviceMem out_index_buf(OutputIndex ? out_index.get_element_space_size_in_bytes() : 0);
in_buf.ToDevice(in.data());
@@ -118,10 +129,10 @@ bool run(const ck_tile::ArgParser& arg_parser)
using Problem = ck_tile::PoolProblem<InDataType,
OutDataType,
ComputeDataType,
OutDataType,
IndexDataType,
ReduceOp,
false,
false,
OutputIndex,
PropagateNan,
Shape>;
using Kernel = ck_tile::PoolKernel<Problem>;
@@ -131,6 +142,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
auto host_args = ck_tile::PoolHostArgs<decltype(input_shape), decltype(window_spatial_lengths)>{
static_cast<InDataType*>(in_buf.GetDeviceBuffer()),
static_cast<OutDataType*>(out_buf.GetDeviceBuffer()),
OutputIndex ? static_cast<IndexDataType*>(out_index_buf.GetDeviceBuffer()) : nullptr,
input_shape,
output_shape,
input_strides,
@@ -167,12 +179,28 @@ bool run(const ck_tile::ArgParser& arg_parser)
if(do_validation)
{
ck_tile::reference_pool3d<InDataType, ComputeDataType, OutDataType>(
in, out_ref, kernel_args, ReduceOp{});
out_buf.FromDevice(out.mData.data());
pass = ck_tile::check_err(out, out_ref);
std::cout << "valid:" << (pass ? "y" : "n") << std::flush << std::endl;
ck_tile::reference_pool3d<InDataType,
ComputeDataType,
OutDataType,
IndexDataType,
ReduceOp,
decltype(input_shape),
decltype(window_spatial_lengths),
OutputIndex>(in, out_ref, out_ref_index, kernel_args, ReduceOp{});
if constexpr(OutputIndex)
{
out_index_buf.FromDevice(out_index.mData.data());
pass = ck_tile::check_err(out, out_ref) && ck_tile::check_err(out_index, out_ref_index);
}
else
{
pass = ck_tile::check_err(out, out_ref);
}
std::cout << "valid:" << (pass ? "y" : "n") << std::endl;
}
return pass;
@@ -184,5 +212,5 @@ int main(int argc, char* argv[])
if(!result)
return -1;
return run<ck_tile::half_t, ck_tile::half_t, float>(arg_parser) ? 0 : -2;
return run<ck_tile::half_t, ck_tile::half_t, float, ck_tile::index_t>(arg_parser) ? 0 : -2;
}

View File

@@ -78,6 +78,7 @@
#include "ck_tile/core/utility/print.hpp"
#include "ck_tile/core/utility/random.hpp"
#include "ck_tile/core/utility/reduce_operator.hpp"
#include "ck_tile/core/utility/reduce_operator_accumulate.hpp"
#include "ck_tile/core/utility/static_counter.hpp"
#include "ck_tile/core/utility/to_sequence.hpp"
#include "ck_tile/core/utility/transpose_vectors.hpp"

View File

@@ -4,6 +4,7 @@
#pragma once
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/utility/type_traits.hpp"
namespace ck_tile {
@@ -18,16 +19,14 @@ struct Add
};
template <typename T,
typename = std::enable_if_t<std::is_same_v<T, float> || std::is_same_v<T, double> ||
std::is_same_v<T, int32_t> || std::is_same_v<T, int8_t>>>
typename = std::enable_if_t<is_any_of<T, float, double, int32_t, int8_t>::value>>
CK_TILE_HOST_DEVICE constexpr T operator()(const T& y, const T x) const
{
return y + x;
}
template <typename T,
typename = std::enable_if_t<std::is_same_v<T, half_t> || std::is_same_v<T, bf16_t> ||
std::is_same_v<T, fp8_t> || std::is_same_v<T, bf8_t>>>
typename = std::enable_if_t<is_any_of<T, half_t, bf16_t, fp8_t, bf8_t>::value>>
CK_TILE_HOST_DEVICE constexpr T operator()(T& y, T x) const
{
float y_ = type_convert<float>(y);
@@ -46,16 +45,14 @@ struct SquareAdd
};
template <typename T,
typename = std::enable_if_t<std::is_same_v<T, float> || std::is_same_v<T, double> ||
std::is_same_v<T, int32_t> || std::is_same_v<T, int8_t>>>
typename = std::enable_if_t<is_any_of<T, float, double, int32_t, int8_t>::value>>
CK_TILE_HOST_DEVICE constexpr T operator()(const T& y, const T x) const
{
return y + (x * x);
}
template <typename T,
typename = std::enable_if_t<std::is_same_v<T, half_t> || std::is_same_v<T, bf16_t> ||
std::is_same_v<T, fp8_t> || std::is_same_v<T, bf8_t>>>
typename = std::enable_if_t<is_any_of<T, half_t, bf16_t, fp8_t, bf8_t>::value>>
CK_TILE_HOST_DEVICE constexpr T operator()(T& y, T x) const
{
float y_ = type_convert<float>(y);
@@ -66,48 +63,74 @@ struct SquareAdd
struct Max
{
template <typename T,
typename = std::enable_if_t<std::is_same_v<T, float> || std::is_same_v<T, double> ||
std::is_same_v<T, int32_t> || std::is_same_v<T, int8_t> ||
std::is_same_v<T, half_t> || std::is_same_v<T, bf16_t> ||
std::is_same_v<T, fp8_t> || std::is_same_v<T, bf8_t>>>
template <
typename T,
typename = std::enable_if_t<
is_any_of<T, float, double, int32_t, int8_t, half_t, bf16_t, fp8_t, bf8_t>::value>>
CK_TILE_HOST_DEVICE static constexpr T GetIdentityValue()
{
return numeric<T>::lowest();
};
template <typename T,
typename = std::enable_if_t<std::is_same_v<T, float> || std::is_same_v<T, double> ||
std::is_same_v<T, int32_t> || std::is_same_v<T, int8_t> ||
std::is_same_v<T, half_t> || std::is_same_v<T, bf16_t> ||
std::is_same_v<T, fp8_t> || std::is_same_v<T, bf8_t>>>
template <
typename T,
typename = std::enable_if_t<
is_any_of<T, float, double, int32_t, int8_t, half_t, bf16_t, fp8_t, bf8_t>::value>>
CK_TILE_HOST_DEVICE constexpr T operator()(const T& y, const T x) const
{
return max(y, x);
}
// Overload with changed flag for index tracking
template <
typename T,
typename = std::enable_if_t<
is_any_of<T, float, double, int32_t, int8_t, half_t, bf16_t, fp8_t, bf8_t>::value>>
CK_TILE_HOST_DEVICE constexpr T operator()(const T& y, const T x, bool& changed) const
{
T new_max = max(y, x);
if(x > y)
{
changed = true;
}
return new_max;
}
};
struct AbsMax
{
template <typename T,
typename = std::enable_if_t<std::is_same_v<T, float> || std::is_same_v<T, double> ||
std::is_same_v<T, int32_t> || std::is_same_v<T, int8_t> ||
std::is_same_v<T, half_t> || std::is_same_v<T, bf16_t> ||
std::is_same_v<T, fp8_t> || std::is_same_v<T, bf8_t>>>
template <
typename T,
typename = std::enable_if_t<
is_any_of<T, float, double, int32_t, int8_t, half_t, bf16_t, fp8_t, bf8_t>::value>>
CK_TILE_HOST_DEVICE static constexpr T GetIdentityValue()
{
return numeric<T>::zero();
};
template <typename T,
typename = std::enable_if_t<std::is_same_v<T, float> || std::is_same_v<T, double> ||
std::is_same_v<T, int32_t> || std::is_same_v<T, int8_t> ||
std::is_same_v<T, half_t> || std::is_same_v<T, bf16_t> ||
std::is_same_v<T, fp8_t> || std::is_same_v<T, bf8_t>>>
template <
typename T,
typename = std::enable_if_t<
is_any_of<T, float, double, int32_t, int8_t, half_t, bf16_t, fp8_t, bf8_t>::value>>
CK_TILE_HOST_DEVICE constexpr T operator()(const T& y, const T x) const
{
return max(y, abs(x));
}
// Overload with changed flag for index tracking
template <
typename T,
typename = std::enable_if_t<
is_any_of<T, float, double, int32_t, int8_t, half_t, bf16_t, fp8_t, bf8_t>::value>>
CK_TILE_HOST_DEVICE constexpr T operator()(const T& y, const T x, bool& changed) const
{
T new_max = max(y, abs(x));
if(abs(x) > y)
{
changed = true;
}
return new_max;
}
};
} // namespace ReduceOp

View File

@@ -0,0 +1,50 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core/config.hpp"
namespace ck_tile {
/// @brief Reduction accumulator that also tracks the winning element's index.
///
/// Guarantees a deterministic result when several elements compare equal under
/// the reduction: the smallest (first-occurring) index is kept. This is done
/// without requiring an equality operator on T — only the ReduceOp's
/// "changed"-flag overload is used, once forward and (on ties) once reversed.
struct AccumulateWithIndex
{
    /// @param reduce_func  reduce op exposing operator()(y, x, bool& changed)
    /// @param current_value running reduction value (updated in place)
    /// @param current_index index associated with current_value (updated in place)
    /// @param new_value     candidate element
    /// @param new_index     candidate element's index
    template <typename ReduceOp, typename T, typename IndexType>
    CK_TILE_HOST_DEVICE void operator()(const ReduceOp& reduce_func,
                                        T& current_value,
                                        IndexType& current_index,
                                        const T& new_value,
                                        const IndexType& new_index) const
    {
        bool value_replaced = false;
        current_value       = reduce_func(current_value, new_value, value_replaced);
        if(value_replaced)
        {
            // Candidate strictly won the reduction — adopt its index.
            current_index = new_index;
            return;
        }
        if(new_index >= current_index)
        {
            return; // candidate lost, and its index is not smaller anyway
        }
        // Candidate did not win; probe the reversed comparison. If the reversed
        // direction also reports no change, the two values are equivalent under
        // the reduce op, so prefer the smaller (earlier) index.
        bool reversed_replaced = false;
        reduce_func(new_value, current_value, reversed_replaced);
        if(!reversed_replaced)
        {
            current_index = new_index;
        }
    }
};
/// @brief Plain reduction accumulator (no index tracking).
///
/// Folds one element into the running value using the reduce op's
/// two-argument overload.
struct Accumulate
{
    /// @param reduce_func  reduce op exposing operator()(y, x)
    /// @param current_value running reduction value (updated in place)
    /// @param new_value     element to fold in
    template <typename ReduceOp, typename T>
    CK_TILE_HOST_DEVICE void
    operator()(const ReduceOp& reduce_func, T& current_value, const T& new_value) const
    {
        const T reduced = reduce_func(current_value, new_value);
        current_value   = reduced;
    }
};
} // namespace ck_tile

View File

@@ -7,17 +7,21 @@
#include "ck_tile/host/host_tensor.hpp"
#include "ck_tile/ops/pooling/kernel/pool_kernel.hpp"
#include <thread>
#include <cmath>
namespace ck_tile {
template <typename InDataType,
typename ComputeDataType,
typename OutDataType,
typename IndexDataType,
typename ReduceOp,
typename TensorShape,
typename WindowShape>
typename WindowShape,
bool OutputIndex = false>
CK_TILE_HOST void reference_pool2d(const HostTensor<InDataType>& input,
HostTensor<OutDataType>& output,
HostTensor<IndexDataType>& output_index,
PoolKernelArgs<TensorShape, WindowShape> kargs,
ReduceOp reduce_op)
{
@@ -45,6 +49,8 @@ CK_TILE_HOST void reference_pool2d(const HostTensor<InDataType>& input,
auto f = [&](auto n, auto ho, auto wo, auto c) {
ComputeDataType v_acc = reduce_op.template GetIdentityValue<ComputeDataType>();
IndexDataType current_index = 0; // Declare outside if constexpr for efficiency
for(ck_tile::index_t y = 0; y < Y; ++y)
{
// Calculate input height index with stride, dilation, and padding
@@ -58,13 +64,32 @@ CK_TILE_HOST void reference_pool2d(const HostTensor<InDataType>& input,
if(hi >= 0 && hi < H && wi >= 0 && wi < W)
{
const ComputeDataType v_in = type_convert<ComputeDataType>(input(n, hi, wi, c));
v_acc = reduce_op(v_acc, v_in);
if constexpr(OutputIndex)
{
IndexDataType flat_index = input.GetOffsetFromMultiIndex(n, hi, wi, c);
bool changed = false;
v_acc = reduce_op(v_acc, v_in, changed);
if(changed)
{
current_index = flat_index;
}
}
else
{
v_acc = reduce_op(v_acc, v_in);
}
}
// For positions outside bounds, we implicitly use identity value
}
}
output(n, ho, wo, c) = ck_tile::type_convert<OutDataType>(v_acc);
if constexpr(OutputIndex)
{
output_index(n, ho, wo, c) = current_index;
}
};
// Parallelize over all output dimensions
@@ -74,11 +99,14 @@ CK_TILE_HOST void reference_pool2d(const HostTensor<InDataType>& input,
template <typename InDataType,
typename ComputeDataType,
typename OutDataType,
typename IndexDataType,
typename ReduceOp,
typename TensorShape,
typename WindowShape>
typename WindowShape,
bool OutputIndex = false>
CK_TILE_HOST void reference_pool3d(const HostTensor<InDataType>& input,
HostTensor<OutDataType>& output,
HostTensor<IndexDataType>& output_index,
PoolKernelArgs<TensorShape, WindowShape> kargs,
ReduceOp reduce_op)
{
@@ -112,6 +140,8 @@ CK_TILE_HOST void reference_pool3d(const HostTensor<InDataType>& input,
auto f = [&](auto n, auto do_, auto ho, auto wo, auto c) {
ComputeDataType v_acc = reduce_op.template GetIdentityValue<ComputeDataType>();
IndexDataType current_index = 0; // Declare outside if constexpr for efficiency
for(ck_tile::index_t z = 0; z < Z; ++z)
{
// Calculate input depth index with stride, dilation, and padding
@@ -131,7 +161,22 @@ CK_TILE_HOST void reference_pool3d(const HostTensor<InDataType>& input,
{
const ComputeDataType v_in =
type_convert<ComputeDataType>(input(n, di, hi, wi, c));
v_acc = reduce_op(v_acc, v_in);
if constexpr(OutputIndex)
{
IndexDataType flat_index =
input.GetOffsetFromMultiIndex(n, di, hi, wi, c);
bool changed = false;
v_acc = reduce_op(v_acc, v_in, changed);
if(changed)
{
current_index = flat_index;
}
}
else
{
v_acc = reduce_op(v_acc, v_in);
}
}
// For positions outside bounds, we implicitly use identity value
}
@@ -139,10 +184,15 @@ CK_TILE_HOST void reference_pool3d(const HostTensor<InDataType>& input,
}
output(n, do_, ho, wo, c) = ck_tile::type_convert<OutDataType>(v_acc);
if constexpr(OutputIndex)
{
output_index(n, do_, ho, wo, c) = current_index;
}
};
// Parallelize over all output dimensions
make_ParallelTensorFunctor(f, N, Do, Ho, Wo, C)(std::thread::hardware_concurrency());
}
} // namespace ck_tile

View File

@@ -17,6 +17,7 @@ struct PoolHostArgs
CK_TILE_HOST PoolHostArgs(const void* input_ptr_,
void* output_ptr_,
void* output_index_ptr_,
TensorShape input_shape_,
TensorShape output_shape_,
TensorShape input_strides_,
@@ -28,6 +29,7 @@ struct PoolHostArgs
WindowShape input_right_pads_)
: input_ptr(input_ptr_),
output_ptr(output_ptr_),
output_index_ptr(output_index_ptr_),
input_shape(input_shape_),
output_shape(output_shape_),
input_strides(input_strides_),
@@ -42,6 +44,7 @@ struct PoolHostArgs
const void* input_ptr;
void* output_ptr;
void* output_index_ptr;
TensorShape input_shape;
TensorShape output_shape;
@@ -60,6 +63,7 @@ struct PoolKernelArgs
{
const void* input_ptr;
void* output_ptr;
void* output_index_ptr;
TensorShape input_shape;
TensorShape output_shape;
TensorShape input_strides;
@@ -80,6 +84,7 @@ struct PoolKernel
using InDataType = ck_tile::remove_cvref_t<typename Problem::InDataType>;
using ComputeDataType = ck_tile::remove_cvref_t<typename Problem::ComputeDataType>;
using OutDataType = ck_tile::remove_cvref_t<typename Problem::OutDataType>;
using IndexDataType = ck_tile::remove_cvref_t<typename Problem::IndexDataType>;
static constexpr index_t kBlockSize = Problem::BlockShape::BlockSize;
@@ -205,7 +210,23 @@ struct PoolKernel
tensor_view<decltype(out_buffer_view), decltype(out_desc_padded)>{out_buffer_view,
out_desc_padded};
return make_tuple(in_tensor_padded, out_tensor_padded);
if constexpr(Problem::kOutputIndex)
{
auto out_index_buffer_view = make_buffer_view<address_space_enum::global>(
static_cast<IndexDataType*>(kargs.output_index_ptr),
out_desc.get_element_space_size(),
IndexDataType(-1));
const auto out_index_tensor_padded =
tensor_view<decltype(out_index_buffer_view), decltype(out_desc_padded)>{
out_index_buffer_view, out_desc_padded};
return make_tuple(in_tensor_padded, out_tensor_padded, out_index_tensor_padded);
}
else
{
// Return a dummy tensor for the third element when index output is not needed
return make_tuple(in_tensor_padded, out_tensor_padded, null_tensor{});
}
}
template <typename TensorShape, typename WindowShape>
@@ -338,7 +359,23 @@ struct PoolKernel
tensor_view<decltype(out_buffer_view), decltype(out_desc_padded)>{out_buffer_view,
out_desc_padded};
return make_tuple(in_tensor_padded, out_tensor_padded);
if constexpr(Problem::kOutputIndex)
{
auto out_index_buffer_view = make_buffer_view<address_space_enum::global>(
static_cast<IndexDataType*>(kargs.output_index_ptr),
out_desc.get_element_space_size(),
IndexDataType(-1));
const auto out_index_tensor_padded =
tensor_view<decltype(out_index_buffer_view), decltype(out_desc_padded)>{
out_index_buffer_view, out_desc_padded};
return make_tuple(in_tensor_padded, out_tensor_padded, out_index_tensor_padded);
}
else
{
// Return a dummy tensor for the third element when index output is not needed
return make_tuple(in_tensor_padded, out_tensor_padded, null_tensor{});
}
}
public:
@@ -354,7 +391,7 @@ struct PoolKernel
const auto iM = get_block_id() * S::Block_M;
// Get tensors based on dimensionality
auto [in_tensor_padded, out_tensor_padded] = [&]() {
auto [in_tensor_padded, out_tensor_padded, out_index_tensor_padded] = [&]() {
if constexpr(WindowShape::size() == 2)
return MakeTensorView2D(kargs);
else if constexpr(WindowShape::size() == 3)
@@ -387,16 +424,57 @@ struct PoolKernel
auto y_tile = block_reduce2d.template MakeYBlockTile<XTensorTile>();
set_tile(y_tile, reduce_op.template GetIdentityValue<ComputeDataType>());
for(int k_tile = __builtin_amdgcn_readfirstlane(0); k_tile < num_k_tiles; ++k_tile)
if constexpr(Problem::kOutputIndex)
{
const auto x_tile = load_tile(x_window);
block_reduce2d(x_tile, y_tile, reduce_op);
move_tile_window(x_window, {0, S::Block_N});
}
auto y_index_window =
make_tile_window(out_index_tensor_padded, make_tuple(number<S::Block_M>{}), {iM});
block_reduce2d_sync(y_tile, reduce_op);
block_reduce2d_cross_warp(y_tile, smem, reduce_op);
store_tile(y_window, cast_tile<OutDataType>(y_tile));
auto y_index_tile =
block_reduce2d.template MakeYIndexBlockTile<XTensorTile, IndexDataType>();
set_tile(y_index_tile, IndexDataType(0));
// Main reduction loop - with index tracking
for(int k_tile = amd_wave_read_first_lane(0); k_tile < num_k_tiles; ++k_tile)
{
const auto x_tile = load_tile(x_window);
auto index_calculator = [&](const auto& x_indices) {
// Get global coordinates in the 2D matrix space (M, N)
const auto global_M = x_indices.at(number<0>{}) + iM;
const auto global_N = (k_tile * S::Block_N) + x_indices.at(number<1>{});
return in_tensor_padded.get_tensor_descriptor().calculate_offset(
make_tuple(global_M, global_N));
};
block_reduce2d(x_tile, y_tile, y_index_tile, reduce_op, index_calculator);
move_tile_window(x_window, {0, S::Block_N});
}
block_reduce2d_sync(y_tile, y_index_tile, reduce_op);
if constexpr(Problem::kNeedCrossWarpSync)
{
__shared__ char smem_indices[Policy::template GetIndicesSmemSize<Problem>()];
block_reduce2d_cross_warp(y_tile, y_index_tile, smem, smem_indices, reduce_op);
}
store_tile(y_window, cast_tile<OutDataType>(y_tile));
store_tile(y_index_window, cast_tile<IndexDataType>(y_index_tile));
}
else
{
// Main reduction loop - without index tracking
for(int k_tile = __builtin_amdgcn_readfirstlane(0); k_tile < num_k_tiles; ++k_tile)
{
const auto x_tile = load_tile(x_window);
block_reduce2d(x_tile, y_tile, reduce_op);
move_tile_window(x_window, {0, S::Block_N});
}
block_reduce2d_sync(y_tile, reduce_op);
block_reduce2d_cross_warp(y_tile, smem, reduce_op);
store_tile(y_window, cast_tile<OutDataType>(y_tile));
}
}
/// @brief Validates if the given arguments are supported by the pooling kernel.
@@ -481,6 +559,7 @@ struct PoolKernel
{
return PoolKernelArgs<TensorShape, WindowShape>{host_args.input_ptr,
host_args.output_ptr,
host_args.output_index_ptr,
host_args.input_shape,
host_args.output_shape,
host_args.input_strides,

View File

@@ -32,7 +32,8 @@ struct PoolDefaultPolicy
{
using P_ = BlockReduce2dProblem<typename Problem::InDataType,
typename Problem::ComputeDataType,
typename Problem::BlockShape>;
typename Problem::BlockShape,
Problem::kOutputIndex>;
return BlockReduce2d<P_>{};
}
@@ -41,7 +42,8 @@ struct PoolDefaultPolicy
{
using P_ = BlockReduce2dProblem<typename Problem::InDataType,
typename Problem::ComputeDataType,
typename Problem::BlockShape>;
typename Problem::BlockShape,
Problem::kOutputIndex>;
return BlockReduce2dSync<P_>{};
}
@@ -50,7 +52,8 @@ struct PoolDefaultPolicy
{
using P_ = BlockReduce2dProblem<typename Problem::InDataType,
typename Problem::ComputeDataType,
typename Problem::BlockShape>;
typename Problem::BlockShape,
Problem::kOutputIndex>;
return BlockReduce2dCrossWarpSync<P_>{};
}
@@ -61,7 +64,8 @@ struct PoolDefaultPolicy
{
using P_ = BlockReduce2dProblem<typename Problem::InDataType,
typename Problem::ComputeDataType,
typename Problem::BlockShape>;
typename Problem::BlockShape,
Problem::kOutputIndex>;
using block_reduce2d = BlockReduce2d<P_>;
using x_block_tile =
@@ -76,5 +80,25 @@ struct PoolDefaultPolicy
return 1; // zero size arrays are an extension
}
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr index_t GetIndicesSmemSize()
{
using P_ = BlockReduce2dProblem<typename Problem::InDataType,
typename Problem::ComputeDataType,
typename Problem::BlockShape,
Problem::kOutputIndex>;
using block_reduce2d = BlockReduce2d<P_>;
using x_block_tile = decltype(make_static_distributed_tensor<typename Problem::InDataType>(
MakeXBlockTileDistribution<Problem>()));
using y_index_block_tile = decltype(block_reduce2d::template MakeYIndexBlockTile<
x_block_tile,
typename Problem::IndexDataType>());
return GetBlockReduce2dCrossWarpSync<Problem>()
.template GetIndicesSmemSize<y_index_block_tile>();
}
};
} // namespace ck_tile

View File

@@ -26,6 +26,8 @@ struct PoolProblem
using OutputIndex = bool_constant<OutputIndex_>;
using PropagateNan = bool_constant<PropagateNan_>;
static constexpr bool kOutputIndex = OutputIndex_;
static constexpr bool kPropagateNan = PropagateNan_;
static constexpr bool kNeedCrossLaneSync = BlockShape::ThreadPerWarp_N > 1;
static constexpr bool kNeedCrossWarpSync = BlockShape::WarpPerBlock_N > 1;
};

View File

@@ -4,6 +4,7 @@
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/core/utility/reduce_operator_accumulate.hpp"
namespace ck_tile {
@@ -50,6 +51,53 @@ struct BlockReduce2d
CK_TILE_DEVICE constexpr BlockReduce2d() {}
private:
template <bool kProcessIndex,
typename XDistributedTensor_,
typename YDistributedTensor_,
typename YIndexDistributedTensor_,
typename ReduceFunc,
typename IndexCalculatorFunc,
typename ReducePacksPerXDim>
CK_TILE_DEVICE void reduce_impl(const XDistributedTensor_& x_tensor,
YDistributedTensor_& y_tensor,
YIndexDistributedTensor_& y_index_tensor,
const ReduceFunc& reduce_func,
const IndexCalculatorFunc& index_calculator,
ReducePacksPerXDim)
{
sweep_tile<XDistributedTensor_>(
[&](auto... idx_) {
constexpr auto idx_0 = make_tuple(make_tuple(idx_[number<0>{}]...)[number<0>{}]);
(..., [&](auto idx) {
auto val = ck_tile::type_convert<ComputeDataType>(x_tensor[idx]);
if constexpr(kProcessIndex)
{
const auto x_indices = get_x_indices_from_distributed_indices(
XDistributedTensor_::get_tile_distribution(), idx);
const auto new_idx = index_calculator(x_indices);
auto current_idx = y_index_tensor(idx_0);
AccumulateWithIndex{}(
reduce_func, y_tensor(idx_0), current_idx, val, new_idx);
y_index_tensor(idx_0) =
type_convert<typename YIndexDistributedTensor_::DataType>(current_idx);
}
else
{
Accumulate{}(reduce_func, y_tensor(idx_0), val);
}
}(idx_));
},
ReducePacksPerXDim{});
}
public:
// Overload for non-index tracking
template <
typename XDistributedTensor_,
typename YDistributedTensor_,
@@ -61,13 +109,36 @@ struct BlockReduce2d
const ReduceFunc& reduce_func,
ReducePacksPerXDim = {})
{
sweep_tile<XDistributedTensor_>(
[&](auto... idx_) {
constexpr auto idx_0 = make_tuple(make_tuple(idx_[number<0>{}]...)[number<0>{}]);
y_tensor(idx_0) = reduce_func(
y_tensor(idx_0), ck_tile::type_convert<ComputeDataType>(x_tensor[idx_])...);
},
reduce_impl<false>(
x_tensor,
y_tensor,
y_tensor, // dummy
reduce_func,
[](auto) { return 0; }, // dummy
ReducePacksPerXDim{});
}
// Overload for index tracking
template <typename XDistributedTensor_,
typename YDistributedTensor_,
typename YIndexDistributedTensor_,
typename ReduceFunc,
typename IndexCalculatorFunc,
typename ReducePacksPerXDim = uniform_sequence_gen_t<2, 1>>
CK_TILE_DEVICE void operator()(const XDistributedTensor_& x_tensor,
YDistributedTensor_& y_tensor,
YIndexDistributedTensor_& y_index_tensor,
const ReduceFunc& reduce_func,
const IndexCalculatorFunc& index_calculator,
ReducePacksPerXDim = {})
{
reduce_impl<Problem::kOutputIndex>(x_tensor,
y_tensor,
y_index_tensor,
reduce_func,
index_calculator,
ReducePacksPerXDim{});
}
#if 0
constexpr auto I0 = number<0>{};
@@ -90,7 +161,6 @@ struct BlockReduce2d
y_tensor(y_dstr_idx) = y;
});
#endif
}
template <typename XDistributedTensor_>
CK_TILE_DEVICE static auto MakeYBlockTile()
@@ -111,6 +181,25 @@ struct BlockReduce2d
return tensor;
}
template <typename XDistributedTensor_, typename IndexDataType = index_t>
CK_TILE_DEVICE static auto MakeYIndexBlockTile()
{
static_assert(std::is_same_v<XDataType, typename XDistributedTensor_::DataType>, "wrong!");
// FIXME: hard coded to reduce 2nd axis
constexpr auto reduce_dims = sequence<1>{};
constexpr auto dstr =
make_static_tile_distribution(detail::make_reduce_tile_distribution_encoding(
XDistributedTensor_::get_tile_distribution()
.get_static_tile_distribution_encoding(),
reduce_dims));
auto tensor = make_static_distributed_tensor<IndexDataType>(dstr);
return tensor;
}
// uniform_sequence_gen_t<NSize, Value> generates sequence of NSize elements filled with Value
// e.g., uniform_sequence_gen_t<2, 1> → {1, 1} and uniform_sequence_gen_t<3, 4> → {4, 4, 4}
template <typename XDistributedTensor_,
@@ -135,8 +224,14 @@ struct BlockReduce2dSync
{
using Problem = remove_cvref_t<Problem_>;
template <typename YDistributedTensor_, typename ReduceFunc>
CK_TILE_DEVICE void operator()(YDistributedTensor_& y_tensor, const ReduceFunc& reduce_func)
private:
template <bool kProcessIndex,
typename YDistributedTensor_,
typename YIndexDistributedTensor_,
typename ReduceFunc>
CK_TILE_DEVICE void reduce_impl(YDistributedTensor_& y_tensor,
YIndexDistributedTensor_& y_index_tensor,
const ReduceFunc& reduce_func)
{
using Dstr = typename YDistributedTensor_::StaticTileDistribution;
using DstrEncode = typename Dstr::DstrEncode;
@@ -157,6 +252,14 @@ struct BlockReduce2dSync
static_for<0, thread_buf_size, 1>{}([&](auto i) {
auto v_local = y_tensor.get_thread_buffer()[i];
using IndexDataType = typename YIndexDistributedTensor_::DataType;
IndexDataType idx_local{};
if constexpr(kProcessIndex)
{
idx_local = y_index_tensor.get_thread_buffer()[i];
}
// cross-lane reduce for replication
// only reduce on R dimension correspond to lane
// (lane id maps to this R dimension)
@@ -183,15 +286,46 @@ struct BlockReduce2dSync
// pull data from remote lane
const auto v_remote = warp_shuffle(v_local, src_lane);
v_local = reduce_func(v_local, v_remote);
if constexpr(kProcessIndex)
{
const auto idx_remote = warp_shuffle(idx_local, src_lane);
AccumulateWithIndex{}(
reduce_func, v_local, idx_local, v_remote, idx_remote);
}
else
{
Accumulate{}(reduce_func, v_local, v_remote);
}
});
}
});
// TODO - Do we need to broadcast to other lane?
y_tensor.get_thread_buffer()(i) = v_local;
if constexpr(kProcessIndex)
{
y_index_tensor.get_thread_buffer()(i) = idx_local;
}
});
}
public:
template <typename YDistributedTensor_, typename ReduceFunc>
CK_TILE_DEVICE void operator()(YDistributedTensor_& y_tensor, const ReduceFunc& reduce_func)
{
reduce_impl<false>(y_tensor, y_tensor, reduce_func);
}
template <typename YDistributedTensor_, typename YIndexDistributedTensor_, typename ReduceFunc>
CK_TILE_DEVICE void operator()(YDistributedTensor_& y_tensor,
YIndexDistributedTensor_& y_index_tensor,
const ReduceFunc& reduce_func)
{
reduce_impl<Problem::kOutputIndex>(y_tensor, y_index_tensor, reduce_func);
}
};
// BlockReduce2dCrossWarpSync: Cross-warp reduction (Stage 3)
@@ -250,15 +384,39 @@ struct BlockReduce2dCrossWarpSync
return num_warps * thread_buf_size * sizeof(DataType);
}
template <typename YDistributedTensor_, typename ReduceFunc>
CK_TILE_DEVICE void
operator()(YDistributedTensor_& y_tensor, void* smem, const ReduceFunc& reduce_func)
// return in byte - separate shared memory size calculation for indices
template <typename YIndexDistributedTensor_>
CK_TILE_HOST_DEVICE static constexpr index_t GetIndicesSmemSize()
{
using DataType = typename YDistributedTensor_::DataType;
using IndexDataType = typename YIndexDistributedTensor_::DataType;
constexpr index_t thread_buf_size = YIndexDistributedTensor_::get_thread_buffer_size();
constexpr index_t num_warps = BlockShape::BlockSize / get_warp_size();
return num_warps * thread_buf_size * sizeof(IndexDataType);
}
private:
template <bool kProcessIndex,
typename YDistributedTensor_,
typename YIndexDistributedTensor_,
typename ReduceFunc>
CK_TILE_DEVICE void reduce_impl(YDistributedTensor_& y_tensor,
YIndexDistributedTensor_& y_index_tensor,
void* smem,
void* smem_indices_ptr,
const ReduceFunc& reduce_func)
{
using DataType = typename YDistributedTensor_::DataType;
using IndexDataType = typename YIndexDistributedTensor_::DataType;
constexpr index_t thread_buf_size = YDistributedTensor_::get_thread_buffer_size();
DataType* smem_ptr = reinterpret_cast<DataType*>(smem);
DataType* smem_ptr = reinterpret_cast<DataType*>(smem);
IndexDataType* smem_indices = nullptr;
if constexpr(kProcessIndex)
{
smem_indices = reinterpret_cast<IndexDataType*>(smem_indices_ptr);
}
const index_t lane_id = get_lane_id();
const index_t warp_id = get_warp_id();
@@ -275,6 +433,11 @@ struct BlockReduce2dCrossWarpSync
static_for<0, thread_buf_size, 1>{}([&](auto i) {
// Store the i-th element of this warp's thread_buffer into SMEM
smem_ptr[smem_offset + i * num_warps] = y_tensor.get_thread_buffer()[i];
if constexpr(kProcessIndex)
{
smem_indices[smem_offset + i * num_warps] =
y_index_tensor.get_thread_buffer()[i];
}
});
}
block_sync_lds();
@@ -282,10 +445,19 @@ struct BlockReduce2dCrossWarpSync
// We let each warp holds a duplication to do reduction.
const index_t local_warp_id = warp_id / num_reduce_warps;
const index_t local_smem_os = local_warp_id * num_reduce_warps;
static_for<0, thread_buf_size, 1>{}([&](auto i) {
DataType v[num_reduce_warps];
static_for<0, num_reduce_warps, 1>{}(
[&](auto idx) { v[idx] = smem_ptr[i * num_warps + local_smem_os + idx]; });
[[maybe_unused]] std::
conditional_t<kProcessIndex, IndexDataType[num_reduce_warps], IndexDataType> idx_v;
static_for<0, num_reduce_warps, 1>{}([&](auto idx) {
v[idx] = smem_ptr[i * num_warps + local_smem_os + idx];
if constexpr(kProcessIndex)
{
idx_v[idx] = smem_indices[i * num_warps + local_smem_os + idx];
}
});
static_assert(is_power_of_two_integer(num_reduce_warps),
"wrong! only support power of 2 reduction");
@@ -299,14 +471,44 @@ struct BlockReduce2dCrossWarpSync
constexpr index_t i1 = idx_ + stride;
if constexpr(i1 < num_reduce_warps)
{
v[i0] = reduce_func(v[i0], v[i1]);
if constexpr(kProcessIndex)
{
AccumulateWithIndex{}(reduce_func, v[i0], idx_v[i0], v[i1], idx_v[i1]);
}
else
{
Accumulate{}(reduce_func, v[i0], v[i1]);
}
}
});
});
y_tensor.get_thread_buffer()(i) = v[0];
if constexpr(kProcessIndex)
{
y_index_tensor.get_thread_buffer()(i) = idx_v[0];
}
});
}
public:
template <typename YDistributedTensor_, typename ReduceFunc>
CK_TILE_DEVICE void
operator()(YDistributedTensor_& y_tensor, void* smem, const ReduceFunc& reduce_func)
{
reduce_impl<false>(y_tensor, y_tensor, smem, nullptr, reduce_func);
}
template <typename YDistributedTensor_, typename YIndexDistributedTensor_, typename ReduceFunc>
CK_TILE_DEVICE void operator()(YDistributedTensor_& y_tensor,
YIndexDistributedTensor_& y_index_tensor,
void* smem,
void* smem_indices,
const ReduceFunc& reduce_func)
{
reduce_impl<Problem::kOutputIndex>(
y_tensor, y_index_tensor, smem, smem_indices, reduce_func);
}
};
template <typename Problem_, typename Policy_ = void>
@@ -364,15 +566,39 @@ struct BlockReduce2dLinearCrossWarpSync
return num_warps * thread_buf_size * sizeof(DataType);
}
template <typename YDistributedTensor_, typename ReduceFunc>
CK_TILE_DEVICE void
operator()(YDistributedTensor_& y_tensor, void* smem, const ReduceFunc& reduce_func)
// return in byte - separate shared memory size calculation for indices
template <typename YIndexDistributedTensor_>
CK_TILE_HOST_DEVICE static constexpr index_t GetIndicesSmemSize()
{
using DataType = typename YDistributedTensor_::DataType;
using IndexDataType = typename YIndexDistributedTensor_::DataType;
constexpr index_t thread_buf_size = YIndexDistributedTensor_::get_thread_buffer_size();
constexpr index_t num_warps = BlockShape::BlockSize / get_warp_size();
return num_warps * thread_buf_size * sizeof(IndexDataType);
}
private:
template <bool kProcessIndex,
typename YDistributedTensor_,
typename YIndexDistributedTensor_,
typename ReduceFunc>
CK_TILE_DEVICE void reduce_impl(YDistributedTensor_& y_tensor,
YIndexDistributedTensor_& y_index_tensor,
void* smem,
void* smem_indices_ptr,
const ReduceFunc& reduce_func)
{
using DataType = typename YDistributedTensor_::DataType;
using IndexDataType = typename YIndexDistributedTensor_::DataType;
constexpr index_t thread_buf_size = YDistributedTensor_::get_thread_buffer_size();
DataType* smem_ptr = reinterpret_cast<DataType*>(smem);
DataType* smem_ptr = reinterpret_cast<DataType*>(smem);
IndexDataType* smem_indices = nullptr;
if constexpr(kProcessIndex)
{
smem_indices = reinterpret_cast<IndexDataType*>(smem_indices_ptr);
}
const index_t lane_id = get_lane_id();
const index_t warp_id = get_warp_id();
constexpr auto num_reduce_warps = GetReduceWarps<YDistributedTensor_>();
@@ -388,6 +614,11 @@ struct BlockReduce2dLinearCrossWarpSync
{
static_for<0, thread_buf_size, 1>{}([&](auto i) {
smem_ptr[smem_offset + i * num_warps] = y_tensor.get_thread_buffer()[i];
if constexpr(kProcessIndex)
{
smem_indices[smem_offset + i * num_warps] =
y_index_tensor.get_thread_buffer()[i];
}
});
}
block_sync_lds();
@@ -395,31 +626,86 @@ struct BlockReduce2dLinearCrossWarpSync
// load from smem. here we let everythread to do compute :)
index_t local_warp_id = warp_id / num_reduce_warps;
index_t local_smem_os = local_warp_id * num_reduce_warps;
DataType all_scratch[thread_buf_size * num_reduce_warps];
[[maybe_unused]] std::conditional_t<kProcessIndex,
IndexDataType[thread_buf_size * num_reduce_warps],
IndexDataType> all_indices;
// Load data from shared memory
static_for<0, thread_buf_size, 1>{}([&](auto i_0) {
static_for<0, num_reduce_warps, 1>{}([&](auto i_1) {
all_scratch[i_0 * num_reduce_warps + i_1] =
smem_ptr[i_0 * num_warps + local_smem_os + i_1];
if constexpr(kProcessIndex)
{
all_indices[i_0 * num_reduce_warps + i_1] =
smem_indices[i_0 * num_warps + local_smem_os + i_1];
}
});
});
block_sync_lds(); // TODO: we don't need sync here
// Perform reduction
static_for<0, thread_buf_size, 1>{}([&](auto i_0) {
// TODO: use descriptor for this
auto v_local = all_scratch[i_0 * num_reduce_warps];
IndexDataType idx_local{};
if constexpr(kProcessIndex)
{
idx_local = all_indices[i_0 * num_reduce_warps];
}
// further reduce mean/var
static_for<0, num_reduce_warps - 1, 1>{}([&](auto i_1_n1) {
constexpr auto i_1 = number<i_1_n1 + 1>{};
const DataType v_remote = all_scratch[i_0 * num_reduce_warps + i_1];
// reduce
v_local = reduce_func(v_local, v_remote);
if constexpr(kProcessIndex)
{
const IndexDataType idx_remote = all_indices[i_0 * num_reduce_warps + i_1];
bool changed = false;
v_local = reduce_func(v_local, v_remote, changed);
if(changed)
{
idx_local = idx_remote;
}
}
else
{
v_local = reduce_func(v_local, v_remote);
}
});
y_tensor.get_thread_buffer()(i_0) = v_local;
if constexpr(kProcessIndex)
{
y_index_tensor.get_thread_buffer()(i_0) = idx_local;
}
});
}
public:
    /// Value-only cross-warp reduction entry point.
    /// Reduces y_tensor across warps through shared memory; index tracking is
    /// disabled (reduce_impl<false> receives y_tensor again as a dummy index
    /// tensor and a null index smem pointer that is never dereferenced).
    template <typename YDistributedTensor_, typename ReduceFunc>
    CK_TILE_DEVICE void
    operator()(YDistributedTensor_& y_tensor, void* smem, const ReduceFunc& reduce_func)
    {
        reduce_impl<false>(y_tensor, y_tensor, smem, nullptr, reduce_func);
    }

    /// Value + index cross-warp reduction entry point.
    /// When Problem::kOutputIndex is true, y_index_tensor follows y_tensor
    /// through the LDS round-trip so each reduced value keeps the source index
    /// of the winning element. smem_indices must point to a separate LDS
    /// region sized by GetIndicesSmemSize<YIndexDistributedTensor_>().
    template <typename YDistributedTensor_, typename YIndexDistributedTensor_, typename ReduceFunc>
    CK_TILE_DEVICE void operator()(YDistributedTensor_& y_tensor,
                                   YIndexDistributedTensor_& y_index_tensor,
                                   void* smem,
                                   void* smem_indices,
                                   const ReduceFunc& reduce_func)
    {
        reduce_impl<Problem::kOutputIndex>(
            y_tensor, y_index_tensor, smem, smem_indices, reduce_func);
    }
};
} // namespace ck_tile

View File

@@ -7,12 +7,17 @@
namespace ck_tile {
template <typename XDataType_, typename ComputeDataType_, typename BlockShape_>
/// Compile-time problem descriptor for block-level 2D reduction.
/// OutputIndex_ selects whether the reduction also propagates the source
/// index of each reduced value (argmax-style); it defaults to false, so
/// existing value-only instantiations remain unchanged.
template <typename XDataType_,
          typename ComputeDataType_,
          typename BlockShape_,
          bool OutputIndex_ = false>
struct BlockReduce2dProblem
{
    using XDataType       = remove_cvref_t<XDataType_>;       // input element type
    using ComputeDataType = remove_cvref_t<ComputeDataType_>; // accumulation type
    using BlockShape      = remove_cvref_t<BlockShape_>;      // block/tile shape policy
    // When true, reduction primitives carry per-element indices alongside values.
    static constexpr bool kOutputIndex = OutputIndex_;
};
} // namespace ck_tile

View File

@@ -32,7 +32,8 @@ struct Reduce2dDefaultPolicy
{
using P_ = BlockReduce2dProblem<typename Problem::XDataType,
typename Problem::ComputeDataType,
typename Problem::BlockShape>;
typename Problem::BlockShape,
Problem::kOutputIndex>;
return BlockReduce2d<P_>{};
}
@@ -41,7 +42,8 @@ struct Reduce2dDefaultPolicy
{
using P_ = BlockReduce2dProblem<typename Problem::XDataType,
typename Problem::ComputeDataType,
typename Problem::BlockShape>;
typename Problem::BlockShape,
Problem::kOutputIndex>;
return BlockReduce2dSync<P_>{};
}
@@ -50,7 +52,8 @@ struct Reduce2dDefaultPolicy
{
using P_ = BlockReduce2dProblem<typename Problem::XDataType,
typename Problem::ComputeDataType,
typename Problem::BlockShape>;
typename Problem::BlockShape,
Problem::kOutputIndex>;
return BlockReduce2dCrossWarpSync<P_>{};
}
@@ -61,7 +64,8 @@ struct Reduce2dDefaultPolicy
{
using P_ = BlockReduce2dProblem<typename Problem::XDataType,
typename Problem::ComputeDataType,
typename Problem::BlockShape>;
typename Problem::BlockShape,
Problem::kOutputIndex>;
using block_reduce2d = BlockReduce2d<P_>;
using x_block_tile =
@@ -76,5 +80,23 @@ struct Reduce2dDefaultPolicy
return 1; // zero size arrays are an extension
}
}
/// Shared-memory size (in bytes) needed for the *index* side of a cross-warp
/// reduction, mirroring the value-side smem-size query above. Rebuilds the
/// same block-reduce problem/tile types used at kernel instantiation so the
/// returned size matches what reduce_impl will actually touch.
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr index_t GetIndicesSmemSize()
{
    // Re-derive the block-reduce problem with the caller's kOutputIndex flag.
    using P_ = BlockReduce2dProblem<typename Problem::XDataType,
                                    typename Problem::ComputeDataType,
                                    typename Problem::BlockShape,
                                    Problem::kOutputIndex>;
    using block_reduce2d = BlockReduce2d<P_>;
    // Input tile distribution drives the shape of the per-thread index tile.
    using x_block_tile = decltype(make_static_distributed_tensor<typename Problem::XDataType>(
        MakeXBlockTileDistribution<Problem>()));
    // Index tile uses index_t elements with the same distribution as the Y tile.
    using y_index_block_tile =
        decltype(block_reduce2d::template MakeYIndexBlockTile<x_block_tile, index_t>());
    // Delegate to the cross-warp-sync primitive, which owns the LDS layout.
    return GetBlockReduce2dCrossWarpSync<Problem>()
        .template GetIndicesSmemSize<y_index_block_tile>();
}
};
} // namespace ck_tile

View File

@@ -11,7 +11,8 @@ template <typename XDataType_,
typename ComputeDataType_,
typename YDataType_,
typename BlockShape_,
typename ReduceOp_>
typename ReduceOp_,
bool OutputIndex_ = false>
struct Reduce2dProblem
{
using XDataType = remove_cvref_t<XDataType_>;
@@ -20,6 +21,7 @@ struct Reduce2dProblem
using BlockShape = remove_cvref_t<BlockShape_>;
using ReduceOp = ReduceOp_;
static constexpr bool kOutputIndex = OutputIndex_;
static constexpr bool kNeedCrossLaneSync = BlockShape::ThreadPerWarp_N > 1;
static constexpr bool kNeedCrossWarpSync = BlockShape::WarpPerBlock_N > 1;
};

View File

@@ -28,7 +28,19 @@ class TestCkTilePooling : public ::testing::Test
using TestPoolShape = ck_tile::PoolShape<BlockWarps_, BlockTile_, WarpTile_, ThreadTile_>;
// 3D pooling configuration
// 2D pooling configuration (NHWC)
struct Config2D
{
    ck_tile::index_t N, H, W, C;         // batch, input height, input width, channels (NHWC)
    ck_tile::index_t Y, X;               // pooling window height / width
    ck_tile::index_t Sy, Sx;             // window stride along H / W
    ck_tile::index_t Dy, Dx;             // window dilation along H / W
    ck_tile::index_t LeftPy, LeftPx;     // left (leading) padding along H / W
    ck_tile::index_t RightPy, RightPx;   // right (trailing) padding along H / W
    std::string name;                    // human-readable label printed by the test
};
// 3D pooling configuration (NDHWC)
struct Config3D
{
ck_tile::index_t N, D, H, W, C;
@@ -40,6 +52,117 @@ class TestCkTilePooling : public ::testing::Test
std::string name;
};
/// Runs one 2D pooling case (NHWC) on device, compares values and argmax
/// indices against the host reference implementation, and returns pass/fail.
/// Unsupported configurations are skipped and reported as passing.
bool RunPool2D(const Config2D& config)
{
    std::cout << "Testing 2D: " << config.name << " ... ";

    // Effective (dilated) window extents and the resulting output spatial sizes.
    const ck_tile::index_t Ys = (config.Y - 1) * config.Dy + 1;
    const ck_tile::index_t Xs = (config.X - 1) * config.Dx + 1;
    const ck_tile::index_t Ho =
        (config.H + config.LeftPy + config.RightPy - Ys) / config.Sy + 1;
    const ck_tile::index_t Wo =
        (config.W + config.LeftPx + config.RightPx - Xs) / config.Sx + 1;

    using IndexDataType = ck_tile::index_t;

    // Host tensors (NHWC, packed layout).
    ck_tile::HostTensor<InDataType> h_in({config.N, config.H, config.W, config.C});
    ck_tile::HostTensor<OutDataType> h_out({config.N, Ho, Wo, config.C});
    ck_tile::HostTensor<OutDataType> h_out_ref({config.N, Ho, Wo, config.C});
    ck_tile::HostTensor<IndexDataType> h_out_index({config.N, Ho, Wo, config.C});
    ck_tile::HostTensor<IndexDataType> h_out_ref_index({config.N, Ho, Wo, config.C});

    // Initialize input with random data
    ck_tile::FillUniformDistribution<InDataType>{-5.f, 5.f}(h_in);

    // Zero-initialize the output/index tensors before uploading them, so the
    // device buffers are not seeded from indeterminate host memory (mirrors
    // the h_out.SetZero() done in RunPool3D).
    h_out.SetZero();
    h_out_index.SetZero();

    // Device memory
    ck_tile::DeviceMem d_in_mem(h_in.get_element_space_size_in_bytes());
    ck_tile::DeviceMem d_out_mem(h_out.get_element_space_size_in_bytes());
    ck_tile::DeviceMem d_out_index_mem(h_out_index.get_element_space_size_in_bytes());

    d_in_mem.ToDevice(h_in.data());
    d_out_mem.ToDevice(h_out.data());
    d_out_index_mem.ToDevice(h_out_index.data());

    constexpr ck_tile::index_t kBlockPerCu = 1;

    // Pooling kernel instantiated with index output enabled, NaN
    // propagation disabled.
    using Problem = ck_tile::PoolProblem<InDataType,
                                         OutDataType,
                                         ComputeDataType,
                                         IndexDataType,
                                         ReduceOpType,
                                         true,  // OutputIndex
                                         false, // PropagateNan
                                         TestPoolShape>;
    using Kernel = ck_tile::PoolKernel<Problem>;

    const ck_tile::index_t kBlockSize = Kernel::BlockSize();

    // Shapes and strides (NHWC)
    const auto input_shape  = ck_tile::make_tuple(config.N, config.H, config.W, config.C);
    const auto output_shape = ck_tile::make_tuple(config.N, Ho, Wo, config.C);
    const auto input_strides =
        ck_tile::make_tuple(config.H * config.W * config.C, config.W * config.C, config.C, 1);
    const auto output_strides =
        ck_tile::make_tuple(Ho * Wo * config.C, Wo * config.C, config.C, 1);

    const auto window_spatial_lengths = ck_tile::make_tuple(config.Y, config.X);
    const auto window_strides         = ck_tile::make_tuple(config.Sy, config.Sx);
    const auto window_dilations       = ck_tile::make_tuple(config.Dy, config.Dx);
    const auto input_left_pads        = ck_tile::make_tuple(config.LeftPy, config.LeftPx);
    const auto input_right_pads       = ck_tile::make_tuple(config.RightPy, config.RightPx);

    auto host_args =
        ck_tile::PoolHostArgs<decltype(input_shape), decltype(window_spatial_lengths)>{
            static_cast<InDataType*>(d_in_mem.GetDeviceBuffer()),
            static_cast<OutDataType*>(d_out_mem.GetDeviceBuffer()),
            static_cast<IndexDataType*>(d_out_index_mem.GetDeviceBuffer()),
            input_shape,
            output_shape,
            input_strides,
            output_strides,
            window_spatial_lengths,
            window_strides,
            window_dilations,
            input_left_pads,
            input_right_pads};

    auto kernel_args = Kernel::MakeKernelArgs(host_args);

    const ck_tile::index_t kGridSize = Kernel::CalculateGridSize(kernel_args);

    // Skip (treat as pass) configurations the kernel cannot handle.
    if(!Kernel::IsSupportedArgument(kernel_args))
    {
        return true;
    }

    // Run kernel
    ck_tile::launch_kernel(
        ck_tile::stream_config{nullptr, false, 0},
        ck_tile::make_kernel<kBlockPerCu>(Kernel{}, kGridSize, kBlockSize, 0, kernel_args));

    // Run reference
    ck_tile::reference_pool2d<InDataType,
                              ComputeDataType,
                              OutDataType,
                              IndexDataType,
                              ReduceOpType,
                              decltype(input_shape),
                              decltype(window_spatial_lengths),
                              true>(
        h_in, h_out_ref, h_out_ref_index, kernel_args, ReduceOpType{});

    d_out_mem.FromDevice(h_out.data());
    d_out_index_mem.FromDevice(h_out_index.data());

    // Validate results: both the pooled values and the produced indices
    // must match the host reference.
    bool pass_value =
        ck_tile::check_err(h_out, h_out_ref, "Error: Incorrect values!", 1e-5, 1e-5);
    bool pass_index = ck_tile::check_err(
        h_out_index, h_out_ref_index, "Error: Incorrect indices!", 1e-5, 1e-5);

    std::cout << (pass_value && pass_index ? "PASS" : "FAIL") << std::endl;
    return pass_value && pass_index;
}
bool RunPool3D(const Config3D& config)
{
std::cout << "Testing 3D: " << config.name << " ... ";
@@ -72,6 +195,8 @@ class TestCkTilePooling : public ::testing::Test
const auto input_right_pads =
ck_tile::make_tuple(config.RightPz, config.RightPy, config.RightPx);
using IndexDataType = ck_tile::index_t;
ck_tile::HostTensor<InDataType> h_in({config.N, config.D, config.H, config.W, config.C},
{config.D * config.H * config.W * config.C,
config.H * config.W * config.C,
@@ -84,6 +209,12 @@ class TestCkTilePooling : public ::testing::Test
ck_tile::HostTensor<OutDataType> h_out_ref(
{config.N, Do, Ho, Wo, config.C},
{Do * Ho * Wo * config.C, Ho * Wo * config.C, Wo * config.C, config.C, 1});
ck_tile::HostTensor<IndexDataType> h_out_index(
{config.N, Do, Ho, Wo, config.C},
{Do * Ho * Wo * config.C, Ho * Wo * config.C, Wo * config.C, config.C, 1});
ck_tile::HostTensor<IndexDataType> h_out_ref_index(
{config.N, Do, Ho, Wo, config.C},
{Do * Ho * Wo * config.C, Ho * Wo * config.C, Wo * config.C, config.C, 1});
ck_tile::FillUniformDistribution<InDataType>{-5.f, 5.f}(h_in);
h_out.SetZero();
@@ -91,17 +222,19 @@ class TestCkTilePooling : public ::testing::Test
ck_tile::DeviceMem d_in_mem(h_in.get_element_space_size_in_bytes());
ck_tile::DeviceMem d_out_mem(h_out.get_element_space_size_in_bytes());
ck_tile::DeviceMem d_out_index_mem(h_out_index.get_element_space_size_in_bytes());
d_in_mem.ToDevice(h_in.data());
d_out_mem.ToDevice(h_out.data());
d_out_index_mem.ToDevice(h_out_index.data());
using Problem = ck_tile::PoolProblem<InDataType,
OutDataType,
ComputeDataType,
OutDataType,
IndexDataType,
ReduceOpType,
false,
false,
true, // OutputIndex
false, // PropagateNan
TestPoolShape>;
using Kernel = ck_tile::PoolKernel<Problem>;
@@ -112,6 +245,7 @@ class TestCkTilePooling : public ::testing::Test
ck_tile::PoolHostArgs<decltype(input_shape), decltype(window_spatial_lengths)>{
static_cast<InDataType*>(d_in_mem.GetDeviceBuffer()),
static_cast<OutDataType*>(d_out_mem.GetDeviceBuffer()),
static_cast<IndexDataType*>(d_out_index_mem.GetDeviceBuffer()),
input_shape,
output_shape,
input_strides,
@@ -137,16 +271,27 @@ class TestCkTilePooling : public ::testing::Test
ck_tile::make_kernel<kBlockPerCu>(Kernel{}, kGridSize, kBlockSize, 0, kernel_args));
// Run reference implementation
ck_tile::reference_pool3d<InDataType, ComputeDataType, OutDataType>(
h_in, h_out_ref, kernel_args, ReduceOpType{});
ck_tile::reference_pool3d<InDataType,
ComputeDataType,
OutDataType,
IndexDataType,
ReduceOpType,
decltype(input_shape),
decltype(window_spatial_lengths),
true>(
h_in, h_out_ref, h_out_ref_index, kernel_args, ReduceOpType{});
d_out_mem.FromDevice(h_out.data());
d_out_index_mem.FromDevice(h_out_index.data());
// Validate results
bool pass = ck_tile::check_err(h_out, h_out_ref);
std::cout << (pass ? "PASS" : "FAIL") << std::endl;
bool pass_value =
ck_tile::check_err(h_out, h_out_ref, "Error: Incorrect values!", 1e-5, 1e-5);
bool pass_index = ck_tile::check_err(
h_out_index, h_out_ref_index, "Error: Incorrect indices!", 1e-5, 1e-5);
return pass;
std::cout << (pass_value && pass_index ? "PASS" : "FAIL") << std::endl;
return pass_value && pass_index;
}
};
@@ -194,6 +339,50 @@ using TestTypes =
TYPED_TEST_SUITE(TestCkTilePooling, TestTypes);
// 2D Pooling Tests (NHWC)
TYPED_TEST(TestCkTilePooling, Pool2D_2x2)
{
    // Non-overlapping 2x2 window, stride 2, no dilation, no padding.
    typename TestFixture::Config2D config{};
    config.N       = 1;  // batch size
    config.H       = 8;  // input height
    config.W       = 8;  // input width
    config.C       = 32; // channels
    config.Y       = 2;  // pooling window height
    config.X       = 2;  // pooling window width
    config.Sy      = 2;  // window stride height
    config.Sx      = 2;  // window stride width
    config.Dy      = 1;  // window dilation height
    config.Dx      = 1;  // window dilation width
    config.LeftPy  = 0;  // left padding height
    config.LeftPx  = 0;  // left padding width
    config.RightPy = 0;  // right padding height
    config.RightPx = 0;  // right padding width
    config.name    = "2x2 pooling NHWC";

    EXPECT_TRUE(this->RunPool2D(config));
}
TYPED_TEST(TestCkTilePooling, Pool2D_3x3_WithPadding)
{
    // Overlapping 3x3 window, stride 2, symmetric padding of 1 on both axes.
    typename TestFixture::Config2D config{};
    config.N       = 2;  // batch size
    config.H       = 16; // input height
    config.W       = 16; // input width
    config.C       = 32; // channels
    config.Y       = 3;  // pooling window height
    config.X       = 3;  // pooling window width
    config.Sy      = 2;  // window stride height
    config.Sx      = 2;  // window stride width
    config.Dy      = 1;  // window dilation height
    config.Dx      = 1;  // window dilation width
    config.LeftPy  = 1;  // left padding height
    config.LeftPx  = 1;  // left padding width
    config.RightPy = 1;  // right padding height
    config.RightPx = 1;  // right padding width
    config.name    = "3x3 pooling NHWC with padding";

    EXPECT_TRUE(this->RunPool2D(config));
}
// 3D Pooling Tests (NDHWC)
TYPED_TEST(TestCkTilePooling, Pool3D_2x2x2)
{
typename TestFixture::Config3D config = {1, // N - batch size
@@ -216,7 +405,7 @@ TYPED_TEST(TestCkTilePooling, Pool3D_2x2x2)
0, // RightPz - right padding depth
0, // RightPy - right padding height
0, // RightPx - right padding width
"2x2x2 pooling"};
"2x2x2 pooling NDHWC"};
bool pass = this->RunPool3D(config);
EXPECT_TRUE(pass);
}
@@ -243,7 +432,7 @@ TYPED_TEST(TestCkTilePooling, Pool3D_3x3x3)
1, // RightPz - right padding depth
1, // RightPy - right padding height
1, // RightPx - right padding width
"3x3x3 pooling"};
"3x3x3 pooling NDHWC with padding"};
bool pass = this->RunPool3D(config);
EXPECT_TRUE(pass);
}