mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-15 10:37:44 +00:00
Merge commit '3052d7c9e6972d5ea7d2225ab78e45554ba70efd' into develop
This commit is contained in:
@@ -38,7 +38,10 @@ auto create_args(int argc, char* argv[])
|
||||
return std::make_tuple(result, arg_parser);
|
||||
}
|
||||
|
||||
template <typename InDataType, typename OutDataType, typename ComputeDataType>
|
||||
template <typename InDataType,
|
||||
typename OutDataType,
|
||||
typename ComputeDataType,
|
||||
typename IndexDataType>
|
||||
bool run(const ck_tile::ArgParser& arg_parser)
|
||||
{
|
||||
|
||||
@@ -84,6 +87,9 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
int warmup = arg_parser.get_int("warmup");
|
||||
int repeat = arg_parser.get_int("repeat");
|
||||
|
||||
constexpr bool OutputIndex = true;
|
||||
constexpr bool PropagateNan = false;
|
||||
|
||||
// Shapes / strides / parameters (NDHWC)
|
||||
const auto input_shape = ck_tile::make_tuple(N, D, H, W, C);
|
||||
const auto output_shape = ck_tile::make_tuple(N, Do, Ho, Wo, C);
|
||||
@@ -100,11 +106,16 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
{Do * Ho * Wo * C, Ho * Wo * C, Wo * C, C, 1});
|
||||
ck_tile::HostTensor<OutDataType> out_ref({N, Do, Ho, Wo, C},
|
||||
{Do * Ho * Wo * C, Ho * Wo * C, Wo * C, C, 1});
|
||||
ck_tile::HostTensor<IndexDataType> out_index({N, Do, Ho, Wo, C},
|
||||
{Do * Ho * Wo * C, Ho * Wo * C, Wo * C, C, 1});
|
||||
ck_tile::HostTensor<IndexDataType> out_ref_index({N, Do, Ho, Wo, C},
|
||||
{Do * Ho * Wo * C, Ho * Wo * C, Wo * C, C, 1});
|
||||
|
||||
ck_tile::FillUniformDistribution<InDataType>{-5.f, 5.f}(in);
|
||||
|
||||
ck_tile::DeviceMem in_buf(in.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem out_buf(out.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem out_index_buf(OutputIndex ? out_index.get_element_space_size_in_bytes() : 0);
|
||||
|
||||
in_buf.ToDevice(in.data());
|
||||
|
||||
@@ -118,10 +129,10 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
using Problem = ck_tile::PoolProblem<InDataType,
|
||||
OutDataType,
|
||||
ComputeDataType,
|
||||
OutDataType,
|
||||
IndexDataType,
|
||||
ReduceOp,
|
||||
false,
|
||||
false,
|
||||
OutputIndex,
|
||||
PropagateNan,
|
||||
Shape>;
|
||||
using Kernel = ck_tile::PoolKernel<Problem>;
|
||||
|
||||
@@ -131,6 +142,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
auto host_args = ck_tile::PoolHostArgs<decltype(input_shape), decltype(window_spatial_lengths)>{
|
||||
static_cast<InDataType*>(in_buf.GetDeviceBuffer()),
|
||||
static_cast<OutDataType*>(out_buf.GetDeviceBuffer()),
|
||||
OutputIndex ? static_cast<IndexDataType*>(out_index_buf.GetDeviceBuffer()) : nullptr,
|
||||
input_shape,
|
||||
output_shape,
|
||||
input_strides,
|
||||
@@ -167,12 +179,28 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
|
||||
if(do_validation)
|
||||
{
|
||||
ck_tile::reference_pool3d<InDataType, ComputeDataType, OutDataType>(
|
||||
in, out_ref, kernel_args, ReduceOp{});
|
||||
out_buf.FromDevice(out.mData.data());
|
||||
pass = ck_tile::check_err(out, out_ref);
|
||||
|
||||
std::cout << "valid:" << (pass ? "y" : "n") << std::flush << std::endl;
|
||||
ck_tile::reference_pool3d<InDataType,
|
||||
ComputeDataType,
|
||||
OutDataType,
|
||||
IndexDataType,
|
||||
ReduceOp,
|
||||
decltype(input_shape),
|
||||
decltype(window_spatial_lengths),
|
||||
OutputIndex>(in, out_ref, out_ref_index, kernel_args, ReduceOp{});
|
||||
|
||||
if constexpr(OutputIndex)
|
||||
{
|
||||
out_index_buf.FromDevice(out_index.mData.data());
|
||||
pass = ck_tile::check_err(out, out_ref) && ck_tile::check_err(out_index, out_ref_index);
|
||||
}
|
||||
else
|
||||
{
|
||||
pass = ck_tile::check_err(out, out_ref);
|
||||
}
|
||||
|
||||
std::cout << "valid:" << (pass ? "y" : "n") << std::endl;
|
||||
}
|
||||
|
||||
return pass;
|
||||
@@ -184,5 +212,5 @@ int main(int argc, char* argv[])
|
||||
if(!result)
|
||||
return -1;
|
||||
|
||||
return run<ck_tile::half_t, ck_tile::half_t, float>(arg_parser) ? 0 : -2;
|
||||
return run<ck_tile::half_t, ck_tile::half_t, float, ck_tile::index_t>(arg_parser) ? 0 : -2;
|
||||
}
|
||||
|
||||
@@ -78,6 +78,7 @@
|
||||
#include "ck_tile/core/utility/print.hpp"
|
||||
#include "ck_tile/core/utility/random.hpp"
|
||||
#include "ck_tile/core/utility/reduce_operator.hpp"
|
||||
#include "ck_tile/core/utility/reduce_operator_accumulate.hpp"
|
||||
#include "ck_tile/core/utility/static_counter.hpp"
|
||||
#include "ck_tile/core/utility/to_sequence.hpp"
|
||||
#include "ck_tile/core/utility/transpose_vectors.hpp"
|
||||
|
||||
@@ -4,6 +4,7 @@
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core/config.hpp"
|
||||
#include "ck_tile/core/utility/type_traits.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
@@ -18,16 +19,14 @@ struct Add
|
||||
};
|
||||
|
||||
template <typename T,
|
||||
typename = std::enable_if_t<std::is_same_v<T, float> || std::is_same_v<T, double> ||
|
||||
std::is_same_v<T, int32_t> || std::is_same_v<T, int8_t>>>
|
||||
typename = std::enable_if_t<is_any_of<T, float, double, int32_t, int8_t>::value>>
|
||||
CK_TILE_HOST_DEVICE constexpr T operator()(const T& y, const T x) const
|
||||
{
|
||||
return y + x;
|
||||
}
|
||||
|
||||
template <typename T,
|
||||
typename = std::enable_if_t<std::is_same_v<T, half_t> || std::is_same_v<T, bf16_t> ||
|
||||
std::is_same_v<T, fp8_t> || std::is_same_v<T, bf8_t>>>
|
||||
typename = std::enable_if_t<is_any_of<T, half_t, bf16_t, fp8_t, bf8_t>::value>>
|
||||
CK_TILE_HOST_DEVICE constexpr T operator()(T& y, T x) const
|
||||
{
|
||||
float y_ = type_convert<float>(y);
|
||||
@@ -46,16 +45,14 @@ struct SquareAdd
|
||||
};
|
||||
|
||||
template <typename T,
|
||||
typename = std::enable_if_t<std::is_same_v<T, float> || std::is_same_v<T, double> ||
|
||||
std::is_same_v<T, int32_t> || std::is_same_v<T, int8_t>>>
|
||||
typename = std::enable_if_t<is_any_of<T, float, double, int32_t, int8_t>::value>>
|
||||
CK_TILE_HOST_DEVICE constexpr T operator()(const T& y, const T x) const
|
||||
{
|
||||
return y + (x * x);
|
||||
}
|
||||
|
||||
template <typename T,
|
||||
typename = std::enable_if_t<std::is_same_v<T, half_t> || std::is_same_v<T, bf16_t> ||
|
||||
std::is_same_v<T, fp8_t> || std::is_same_v<T, bf8_t>>>
|
||||
typename = std::enable_if_t<is_any_of<T, half_t, bf16_t, fp8_t, bf8_t>::value>>
|
||||
CK_TILE_HOST_DEVICE constexpr T operator()(T& y, T x) const
|
||||
{
|
||||
float y_ = type_convert<float>(y);
|
||||
@@ -66,48 +63,74 @@ struct SquareAdd
|
||||
|
||||
struct Max
|
||||
{
|
||||
template <typename T,
|
||||
typename = std::enable_if_t<std::is_same_v<T, float> || std::is_same_v<T, double> ||
|
||||
std::is_same_v<T, int32_t> || std::is_same_v<T, int8_t> ||
|
||||
std::is_same_v<T, half_t> || std::is_same_v<T, bf16_t> ||
|
||||
std::is_same_v<T, fp8_t> || std::is_same_v<T, bf8_t>>>
|
||||
template <
|
||||
typename T,
|
||||
typename = std::enable_if_t<
|
||||
is_any_of<T, float, double, int32_t, int8_t, half_t, bf16_t, fp8_t, bf8_t>::value>>
|
||||
CK_TILE_HOST_DEVICE static constexpr T GetIdentityValue()
|
||||
{
|
||||
return numeric<T>::lowest();
|
||||
};
|
||||
|
||||
template <typename T,
|
||||
typename = std::enable_if_t<std::is_same_v<T, float> || std::is_same_v<T, double> ||
|
||||
std::is_same_v<T, int32_t> || std::is_same_v<T, int8_t> ||
|
||||
std::is_same_v<T, half_t> || std::is_same_v<T, bf16_t> ||
|
||||
std::is_same_v<T, fp8_t> || std::is_same_v<T, bf8_t>>>
|
||||
template <
|
||||
typename T,
|
||||
typename = std::enable_if_t<
|
||||
is_any_of<T, float, double, int32_t, int8_t, half_t, bf16_t, fp8_t, bf8_t>::value>>
|
||||
CK_TILE_HOST_DEVICE constexpr T operator()(const T& y, const T x) const
|
||||
{
|
||||
return max(y, x);
|
||||
}
|
||||
|
||||
// Overload with changed flag for index tracking
|
||||
template <
|
||||
typename T,
|
||||
typename = std::enable_if_t<
|
||||
is_any_of<T, float, double, int32_t, int8_t, half_t, bf16_t, fp8_t, bf8_t>::value>>
|
||||
CK_TILE_HOST_DEVICE constexpr T operator()(const T& y, const T x, bool& changed) const
|
||||
{
|
||||
T new_max = max(y, x);
|
||||
if(x > y)
|
||||
{
|
||||
changed = true;
|
||||
}
|
||||
return new_max;
|
||||
}
|
||||
};
|
||||
|
||||
struct AbsMax
|
||||
{
|
||||
template <typename T,
|
||||
typename = std::enable_if_t<std::is_same_v<T, float> || std::is_same_v<T, double> ||
|
||||
std::is_same_v<T, int32_t> || std::is_same_v<T, int8_t> ||
|
||||
std::is_same_v<T, half_t> || std::is_same_v<T, bf16_t> ||
|
||||
std::is_same_v<T, fp8_t> || std::is_same_v<T, bf8_t>>>
|
||||
template <
|
||||
typename T,
|
||||
typename = std::enable_if_t<
|
||||
is_any_of<T, float, double, int32_t, int8_t, half_t, bf16_t, fp8_t, bf8_t>::value>>
|
||||
CK_TILE_HOST_DEVICE static constexpr T GetIdentityValue()
|
||||
{
|
||||
return numeric<T>::zero();
|
||||
};
|
||||
|
||||
template <typename T,
|
||||
typename = std::enable_if_t<std::is_same_v<T, float> || std::is_same_v<T, double> ||
|
||||
std::is_same_v<T, int32_t> || std::is_same_v<T, int8_t> ||
|
||||
std::is_same_v<T, half_t> || std::is_same_v<T, bf16_t> ||
|
||||
std::is_same_v<T, fp8_t> || std::is_same_v<T, bf8_t>>>
|
||||
template <
|
||||
typename T,
|
||||
typename = std::enable_if_t<
|
||||
is_any_of<T, float, double, int32_t, int8_t, half_t, bf16_t, fp8_t, bf8_t>::value>>
|
||||
CK_TILE_HOST_DEVICE constexpr T operator()(const T& y, const T x) const
|
||||
{
|
||||
return max(y, abs(x));
|
||||
}
|
||||
|
||||
// Overload with changed flag for index tracking
|
||||
template <
|
||||
typename T,
|
||||
typename = std::enable_if_t<
|
||||
is_any_of<T, float, double, int32_t, int8_t, half_t, bf16_t, fp8_t, bf8_t>::value>>
|
||||
CK_TILE_HOST_DEVICE constexpr T operator()(const T& y, const T x, bool& changed) const
|
||||
{
|
||||
T new_max = max(y, abs(x));
|
||||
if(abs(x) > y)
|
||||
{
|
||||
changed = true;
|
||||
}
|
||||
return new_max;
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ReduceOp
|
||||
|
||||
50
include/ck_tile/core/utility/reduce_operator_accumulate.hpp
Normal file
50
include/ck_tile/core/utility/reduce_operator_accumulate.hpp
Normal file
@@ -0,0 +1,50 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core/config.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
/// @brief Accumulate with index tracking reductions, provides deterministic first occurring index
|
||||
struct AccumulateWithIndex
|
||||
{
|
||||
template <typename ReduceOp, typename T, typename IndexType>
|
||||
CK_TILE_HOST_DEVICE void operator()(const ReduceOp& reduce_func,
|
||||
T& current_value,
|
||||
IndexType& current_index,
|
||||
const T& new_value,
|
||||
const IndexType& new_index) const
|
||||
{
|
||||
bool changed = false;
|
||||
current_value = reduce_func(current_value, new_value, changed);
|
||||
|
||||
if(changed)
|
||||
{
|
||||
current_index = new_index;
|
||||
}
|
||||
else if(new_index < current_index)
|
||||
{
|
||||
bool reverse_changed = false;
|
||||
reduce_func(new_value, current_value, reverse_changed);
|
||||
|
||||
if(!reverse_changed)
|
||||
{
|
||||
current_index = new_index;
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
struct Accumulate
|
||||
{
|
||||
template <typename ReduceOp, typename T>
|
||||
CK_TILE_HOST_DEVICE void
|
||||
operator()(const ReduceOp& reduce_func, T& current_value, const T& new_value) const
|
||||
{
|
||||
current_value = reduce_func(current_value, new_value);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -7,17 +7,21 @@
|
||||
#include "ck_tile/host/host_tensor.hpp"
|
||||
#include "ck_tile/ops/pooling/kernel/pool_kernel.hpp"
|
||||
#include <thread>
|
||||
#include <cmath>
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <typename InDataType,
|
||||
typename ComputeDataType,
|
||||
typename OutDataType,
|
||||
typename IndexDataType,
|
||||
typename ReduceOp,
|
||||
typename TensorShape,
|
||||
typename WindowShape>
|
||||
typename WindowShape,
|
||||
bool OutputIndex = false>
|
||||
CK_TILE_HOST void reference_pool2d(const HostTensor<InDataType>& input,
|
||||
HostTensor<OutDataType>& output,
|
||||
HostTensor<IndexDataType>& output_index,
|
||||
PoolKernelArgs<TensorShape, WindowShape> kargs,
|
||||
ReduceOp reduce_op)
|
||||
{
|
||||
@@ -45,6 +49,8 @@ CK_TILE_HOST void reference_pool2d(const HostTensor<InDataType>& input,
|
||||
auto f = [&](auto n, auto ho, auto wo, auto c) {
|
||||
ComputeDataType v_acc = reduce_op.template GetIdentityValue<ComputeDataType>();
|
||||
|
||||
IndexDataType current_index = 0; // Declare outside if constexpr for efficiency
|
||||
|
||||
for(ck_tile::index_t y = 0; y < Y; ++y)
|
||||
{
|
||||
// Calculate input height index with stride, dilation, and padding
|
||||
@@ -58,13 +64,32 @@ CK_TILE_HOST void reference_pool2d(const HostTensor<InDataType>& input,
|
||||
if(hi >= 0 && hi < H && wi >= 0 && wi < W)
|
||||
{
|
||||
const ComputeDataType v_in = type_convert<ComputeDataType>(input(n, hi, wi, c));
|
||||
v_acc = reduce_op(v_acc, v_in);
|
||||
|
||||
if constexpr(OutputIndex)
|
||||
{
|
||||
IndexDataType flat_index = input.GetOffsetFromMultiIndex(n, hi, wi, c);
|
||||
bool changed = false;
|
||||
v_acc = reduce_op(v_acc, v_in, changed);
|
||||
if(changed)
|
||||
{
|
||||
current_index = flat_index;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
v_acc = reduce_op(v_acc, v_in);
|
||||
}
|
||||
}
|
||||
// For positions outside bounds, we implicitly use identity value
|
||||
}
|
||||
}
|
||||
|
||||
output(n, ho, wo, c) = ck_tile::type_convert<OutDataType>(v_acc);
|
||||
|
||||
if constexpr(OutputIndex)
|
||||
{
|
||||
output_index(n, ho, wo, c) = current_index;
|
||||
}
|
||||
};
|
||||
|
||||
// Parallelize over all output dimensions
|
||||
@@ -74,11 +99,14 @@ CK_TILE_HOST void reference_pool2d(const HostTensor<InDataType>& input,
|
||||
template <typename InDataType,
|
||||
typename ComputeDataType,
|
||||
typename OutDataType,
|
||||
typename IndexDataType,
|
||||
typename ReduceOp,
|
||||
typename TensorShape,
|
||||
typename WindowShape>
|
||||
typename WindowShape,
|
||||
bool OutputIndex = false>
|
||||
CK_TILE_HOST void reference_pool3d(const HostTensor<InDataType>& input,
|
||||
HostTensor<OutDataType>& output,
|
||||
HostTensor<IndexDataType>& output_index,
|
||||
PoolKernelArgs<TensorShape, WindowShape> kargs,
|
||||
ReduceOp reduce_op)
|
||||
{
|
||||
@@ -112,6 +140,8 @@ CK_TILE_HOST void reference_pool3d(const HostTensor<InDataType>& input,
|
||||
auto f = [&](auto n, auto do_, auto ho, auto wo, auto c) {
|
||||
ComputeDataType v_acc = reduce_op.template GetIdentityValue<ComputeDataType>();
|
||||
|
||||
IndexDataType current_index = 0; // Declare outside if constexpr for efficiency
|
||||
|
||||
for(ck_tile::index_t z = 0; z < Z; ++z)
|
||||
{
|
||||
// Calculate input depth index with stride, dilation, and padding
|
||||
@@ -131,7 +161,22 @@ CK_TILE_HOST void reference_pool3d(const HostTensor<InDataType>& input,
|
||||
{
|
||||
const ComputeDataType v_in =
|
||||
type_convert<ComputeDataType>(input(n, di, hi, wi, c));
|
||||
v_acc = reduce_op(v_acc, v_in);
|
||||
|
||||
if constexpr(OutputIndex)
|
||||
{
|
||||
IndexDataType flat_index =
|
||||
input.GetOffsetFromMultiIndex(n, di, hi, wi, c);
|
||||
bool changed = false;
|
||||
v_acc = reduce_op(v_acc, v_in, changed);
|
||||
if(changed)
|
||||
{
|
||||
current_index = flat_index;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
v_acc = reduce_op(v_acc, v_in);
|
||||
}
|
||||
}
|
||||
// For positions outside bounds, we implicitly use identity value
|
||||
}
|
||||
@@ -139,10 +184,15 @@ CK_TILE_HOST void reference_pool3d(const HostTensor<InDataType>& input,
|
||||
}
|
||||
|
||||
output(n, do_, ho, wo, c) = ck_tile::type_convert<OutDataType>(v_acc);
|
||||
|
||||
if constexpr(OutputIndex)
|
||||
{
|
||||
|
||||
output_index(n, do_, ho, wo, c) = current_index;
|
||||
}
|
||||
};
|
||||
|
||||
// Parallelize over all output dimensions
|
||||
make_ParallelTensorFunctor(f, N, Do, Ho, Wo, C)(std::thread::hardware_concurrency());
|
||||
}
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -17,6 +17,7 @@ struct PoolHostArgs
|
||||
|
||||
CK_TILE_HOST PoolHostArgs(const void* input_ptr_,
|
||||
void* output_ptr_,
|
||||
void* output_index_ptr_,
|
||||
TensorShape input_shape_,
|
||||
TensorShape output_shape_,
|
||||
TensorShape input_strides_,
|
||||
@@ -28,6 +29,7 @@ struct PoolHostArgs
|
||||
WindowShape input_right_pads_)
|
||||
: input_ptr(input_ptr_),
|
||||
output_ptr(output_ptr_),
|
||||
output_index_ptr(output_index_ptr_),
|
||||
input_shape(input_shape_),
|
||||
output_shape(output_shape_),
|
||||
input_strides(input_strides_),
|
||||
@@ -42,6 +44,7 @@ struct PoolHostArgs
|
||||
|
||||
const void* input_ptr;
|
||||
void* output_ptr;
|
||||
void* output_index_ptr;
|
||||
|
||||
TensorShape input_shape;
|
||||
TensorShape output_shape;
|
||||
@@ -60,6 +63,7 @@ struct PoolKernelArgs
|
||||
{
|
||||
const void* input_ptr;
|
||||
void* output_ptr;
|
||||
void* output_index_ptr;
|
||||
TensorShape input_shape;
|
||||
TensorShape output_shape;
|
||||
TensorShape input_strides;
|
||||
@@ -80,6 +84,7 @@ struct PoolKernel
|
||||
using InDataType = ck_tile::remove_cvref_t<typename Problem::InDataType>;
|
||||
using ComputeDataType = ck_tile::remove_cvref_t<typename Problem::ComputeDataType>;
|
||||
using OutDataType = ck_tile::remove_cvref_t<typename Problem::OutDataType>;
|
||||
using IndexDataType = ck_tile::remove_cvref_t<typename Problem::IndexDataType>;
|
||||
|
||||
static constexpr index_t kBlockSize = Problem::BlockShape::BlockSize;
|
||||
|
||||
@@ -205,7 +210,23 @@ struct PoolKernel
|
||||
tensor_view<decltype(out_buffer_view), decltype(out_desc_padded)>{out_buffer_view,
|
||||
out_desc_padded};
|
||||
|
||||
return make_tuple(in_tensor_padded, out_tensor_padded);
|
||||
if constexpr(Problem::kOutputIndex)
|
||||
{
|
||||
auto out_index_buffer_view = make_buffer_view<address_space_enum::global>(
|
||||
static_cast<IndexDataType*>(kargs.output_index_ptr),
|
||||
out_desc.get_element_space_size(),
|
||||
IndexDataType(-1));
|
||||
const auto out_index_tensor_padded =
|
||||
tensor_view<decltype(out_index_buffer_view), decltype(out_desc_padded)>{
|
||||
out_index_buffer_view, out_desc_padded};
|
||||
|
||||
return make_tuple(in_tensor_padded, out_tensor_padded, out_index_tensor_padded);
|
||||
}
|
||||
else
|
||||
{
|
||||
// Return a dummy tensor for the third element when index output is not needed
|
||||
return make_tuple(in_tensor_padded, out_tensor_padded, null_tensor{});
|
||||
}
|
||||
}
|
||||
|
||||
template <typename TensorShape, typename WindowShape>
|
||||
@@ -338,7 +359,23 @@ struct PoolKernel
|
||||
tensor_view<decltype(out_buffer_view), decltype(out_desc_padded)>{out_buffer_view,
|
||||
out_desc_padded};
|
||||
|
||||
return make_tuple(in_tensor_padded, out_tensor_padded);
|
||||
if constexpr(Problem::kOutputIndex)
|
||||
{
|
||||
auto out_index_buffer_view = make_buffer_view<address_space_enum::global>(
|
||||
static_cast<IndexDataType*>(kargs.output_index_ptr),
|
||||
out_desc.get_element_space_size(),
|
||||
IndexDataType(-1));
|
||||
const auto out_index_tensor_padded =
|
||||
tensor_view<decltype(out_index_buffer_view), decltype(out_desc_padded)>{
|
||||
out_index_buffer_view, out_desc_padded};
|
||||
|
||||
return make_tuple(in_tensor_padded, out_tensor_padded, out_index_tensor_padded);
|
||||
}
|
||||
else
|
||||
{
|
||||
// Return a dummy tensor for the third element when index output is not needed
|
||||
return make_tuple(in_tensor_padded, out_tensor_padded, null_tensor{});
|
||||
}
|
||||
}
|
||||
|
||||
public:
|
||||
@@ -354,7 +391,7 @@ struct PoolKernel
|
||||
const auto iM = get_block_id() * S::Block_M;
|
||||
|
||||
// Get tensors based on dimensionality
|
||||
auto [in_tensor_padded, out_tensor_padded] = [&]() {
|
||||
auto [in_tensor_padded, out_tensor_padded, out_index_tensor_padded] = [&]() {
|
||||
if constexpr(WindowShape::size() == 2)
|
||||
return MakeTensorView2D(kargs);
|
||||
else if constexpr(WindowShape::size() == 3)
|
||||
@@ -387,16 +424,57 @@ struct PoolKernel
|
||||
auto y_tile = block_reduce2d.template MakeYBlockTile<XTensorTile>();
|
||||
set_tile(y_tile, reduce_op.template GetIdentityValue<ComputeDataType>());
|
||||
|
||||
for(int k_tile = __builtin_amdgcn_readfirstlane(0); k_tile < num_k_tiles; ++k_tile)
|
||||
if constexpr(Problem::kOutputIndex)
|
||||
{
|
||||
const auto x_tile = load_tile(x_window);
|
||||
block_reduce2d(x_tile, y_tile, reduce_op);
|
||||
move_tile_window(x_window, {0, S::Block_N});
|
||||
}
|
||||
auto y_index_window =
|
||||
make_tile_window(out_index_tensor_padded, make_tuple(number<S::Block_M>{}), {iM});
|
||||
|
||||
block_reduce2d_sync(y_tile, reduce_op);
|
||||
block_reduce2d_cross_warp(y_tile, smem, reduce_op);
|
||||
store_tile(y_window, cast_tile<OutDataType>(y_tile));
|
||||
auto y_index_tile =
|
||||
block_reduce2d.template MakeYIndexBlockTile<XTensorTile, IndexDataType>();
|
||||
set_tile(y_index_tile, IndexDataType(0));
|
||||
|
||||
// Main reduction loop - with index tracking
|
||||
for(int k_tile = amd_wave_read_first_lane(0); k_tile < num_k_tiles; ++k_tile)
|
||||
{
|
||||
const auto x_tile = load_tile(x_window);
|
||||
auto index_calculator = [&](const auto& x_indices) {
|
||||
// Get global coordinates in the 2D matrix space (M, N)
|
||||
const auto global_M = x_indices.at(number<0>{}) + iM;
|
||||
const auto global_N = (k_tile * S::Block_N) + x_indices.at(number<1>{});
|
||||
return in_tensor_padded.get_tensor_descriptor().calculate_offset(
|
||||
make_tuple(global_M, global_N));
|
||||
};
|
||||
|
||||
block_reduce2d(x_tile, y_tile, y_index_tile, reduce_op, index_calculator);
|
||||
move_tile_window(x_window, {0, S::Block_N});
|
||||
}
|
||||
|
||||
block_reduce2d_sync(y_tile, y_index_tile, reduce_op);
|
||||
if constexpr(Problem::kNeedCrossWarpSync)
|
||||
{
|
||||
__shared__ char smem_indices[Policy::template GetIndicesSmemSize<Problem>()];
|
||||
|
||||
block_reduce2d_cross_warp(y_tile, y_index_tile, smem, smem_indices, reduce_op);
|
||||
}
|
||||
|
||||
store_tile(y_window, cast_tile<OutDataType>(y_tile));
|
||||
store_tile(y_index_window, cast_tile<IndexDataType>(y_index_tile));
|
||||
}
|
||||
else
|
||||
{
|
||||
// Main reduction loop - without index tracking
|
||||
for(int k_tile = __builtin_amdgcn_readfirstlane(0); k_tile < num_k_tiles; ++k_tile)
|
||||
{
|
||||
const auto x_tile = load_tile(x_window);
|
||||
block_reduce2d(x_tile, y_tile, reduce_op);
|
||||
move_tile_window(x_window, {0, S::Block_N});
|
||||
}
|
||||
|
||||
block_reduce2d_sync(y_tile, reduce_op);
|
||||
block_reduce2d_cross_warp(y_tile, smem, reduce_op);
|
||||
|
||||
store_tile(y_window, cast_tile<OutDataType>(y_tile));
|
||||
}
|
||||
}
|
||||
|
||||
/// @brief Validates if the given arguments are supported by the pooling kernel.
|
||||
@@ -481,6 +559,7 @@ struct PoolKernel
|
||||
{
|
||||
return PoolKernelArgs<TensorShape, WindowShape>{host_args.input_ptr,
|
||||
host_args.output_ptr,
|
||||
host_args.output_index_ptr,
|
||||
host_args.input_shape,
|
||||
host_args.output_shape,
|
||||
host_args.input_strides,
|
||||
|
||||
@@ -32,7 +32,8 @@ struct PoolDefaultPolicy
|
||||
{
|
||||
using P_ = BlockReduce2dProblem<typename Problem::InDataType,
|
||||
typename Problem::ComputeDataType,
|
||||
typename Problem::BlockShape>;
|
||||
typename Problem::BlockShape,
|
||||
Problem::kOutputIndex>;
|
||||
return BlockReduce2d<P_>{};
|
||||
}
|
||||
|
||||
@@ -41,7 +42,8 @@ struct PoolDefaultPolicy
|
||||
{
|
||||
using P_ = BlockReduce2dProblem<typename Problem::InDataType,
|
||||
typename Problem::ComputeDataType,
|
||||
typename Problem::BlockShape>;
|
||||
typename Problem::BlockShape,
|
||||
Problem::kOutputIndex>;
|
||||
return BlockReduce2dSync<P_>{};
|
||||
}
|
||||
|
||||
@@ -50,7 +52,8 @@ struct PoolDefaultPolicy
|
||||
{
|
||||
using P_ = BlockReduce2dProblem<typename Problem::InDataType,
|
||||
typename Problem::ComputeDataType,
|
||||
typename Problem::BlockShape>;
|
||||
typename Problem::BlockShape,
|
||||
Problem::kOutputIndex>;
|
||||
return BlockReduce2dCrossWarpSync<P_>{};
|
||||
}
|
||||
|
||||
@@ -61,7 +64,8 @@ struct PoolDefaultPolicy
|
||||
{
|
||||
using P_ = BlockReduce2dProblem<typename Problem::InDataType,
|
||||
typename Problem::ComputeDataType,
|
||||
typename Problem::BlockShape>;
|
||||
typename Problem::BlockShape,
|
||||
Problem::kOutputIndex>;
|
||||
|
||||
using block_reduce2d = BlockReduce2d<P_>;
|
||||
using x_block_tile =
|
||||
@@ -76,5 +80,25 @@ struct PoolDefaultPolicy
|
||||
return 1; // zero size arrays are an extension
|
||||
}
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr index_t GetIndicesSmemSize()
|
||||
{
|
||||
|
||||
using P_ = BlockReduce2dProblem<typename Problem::InDataType,
|
||||
typename Problem::ComputeDataType,
|
||||
typename Problem::BlockShape,
|
||||
Problem::kOutputIndex>;
|
||||
|
||||
using block_reduce2d = BlockReduce2d<P_>;
|
||||
using x_block_tile = decltype(make_static_distributed_tensor<typename Problem::InDataType>(
|
||||
MakeXBlockTileDistribution<Problem>()));
|
||||
using y_index_block_tile = decltype(block_reduce2d::template MakeYIndexBlockTile<
|
||||
x_block_tile,
|
||||
typename Problem::IndexDataType>());
|
||||
|
||||
return GetBlockReduce2dCrossWarpSync<Problem>()
|
||||
.template GetIndicesSmemSize<y_index_block_tile>();
|
||||
}
|
||||
};
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -26,6 +26,8 @@ struct PoolProblem
|
||||
using OutputIndex = bool_constant<OutputIndex_>;
|
||||
using PropagateNan = bool_constant<PropagateNan_>;
|
||||
|
||||
static constexpr bool kOutputIndex = OutputIndex_;
|
||||
static constexpr bool kPropagateNan = PropagateNan_;
|
||||
static constexpr bool kNeedCrossLaneSync = BlockShape::ThreadPerWarp_N > 1;
|
||||
static constexpr bool kNeedCrossWarpSync = BlockShape::WarpPerBlock_N > 1;
|
||||
};
|
||||
|
||||
@@ -4,6 +4,7 @@
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/core/utility/reduce_operator_accumulate.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
@@ -50,6 +51,53 @@ struct BlockReduce2d
|
||||
|
||||
CK_TILE_DEVICE constexpr BlockReduce2d() {}
|
||||
|
||||
private:
|
||||
// Shared sweep behind both operator() overloads.
// Visits the elements of x_tensor through sweep_tile and folds each one into
// the y_tensor slot selected by idx_0 (built from the first distributed index).
//   kProcessIndex == true : y_index_tensor is kept in step with y_tensor via
//                           AccumulateWithIndex, using index_calculator(x_indices)
//                           as the candidate index.
//   kProcessIndex == false: only Accumulate runs; y_index_tensor and
//                           index_calculator are never touched (the value-only
//                           overload passes dummies for both).
template <bool kProcessIndex,
|
||||
typename XDistributedTensor_,
|
||||
typename YDistributedTensor_,
|
||||
typename YIndexDistributedTensor_,
|
||||
typename ReduceFunc,
|
||||
typename IndexCalculatorFunc,
|
||||
typename ReducePacksPerXDim>
|
||||
CK_TILE_DEVICE void reduce_impl(const XDistributedTensor_& x_tensor,
|
||||
YDistributedTensor_& y_tensor,
|
||||
YIndexDistributedTensor_& y_index_tensor,
|
||||
const ReduceFunc& reduce_func,
|
||||
const IndexCalculatorFunc& index_calculator,
|
||||
ReducePacksPerXDim)
|
||||
{
|
||||
sweep_tile<XDistributedTensor_>(
|
||||
// The callback receives a pack of distributed indices per invocation.
[&](auto... idx_) {
|
||||
// Output slot: first component of the first index in the pack.
constexpr auto idx_0 = make_tuple(make_tuple(idx_[number<0>{}]...)[number<0>{}]);
|
||||
|
||||
// Comma fold: the lambda below is invoked once per index in the pack.
(..., [&](auto idx) {
|
||||
auto val = ck_tile::type_convert<ComputeDataType>(x_tensor[idx]);
|
||||
|
||||
if constexpr(kProcessIndex)
|
||||
{
|
||||
|
||||
// Translate the distributed index into tile-local X indices, then let the
// caller-supplied functor turn those into the index to record.
const auto x_indices = get_x_indices_from_distributed_indices(
|
||||
XDistributedTensor_::get_tile_distribution(), idx);
|
||||
const auto new_idx = index_calculator(x_indices);
|
||||
// Work on a copy of the stored index; it is written back after the update.
auto current_idx = y_index_tensor(idx_0);
|
||||
|
||||
AccumulateWithIndex{}(
|
||||
reduce_func, y_tensor(idx_0), current_idx, val, new_idx);
|
||||
|
||||
y_index_tensor(idx_0) =
|
||||
type_convert<typename YIndexDistributedTensor_::DataType>(current_idx);
|
||||
}
|
||||
else
|
||||
{
|
||||
Accumulate{}(reduce_func, y_tensor(idx_0), val);
|
||||
}
|
||||
}(idx_));
|
||||
},
|
||||
ReducePacksPerXDim{});
|
||||
}
|
||||
|
||||
public:
|
||||
// Overload for non-index tracking
|
||||
template <
|
||||
typename XDistributedTensor_,
|
||||
typename YDistributedTensor_,
|
||||
@@ -61,13 +109,36 @@ struct BlockReduce2d
|
||||
const ReduceFunc& reduce_func,
|
||||
ReducePacksPerXDim = {})
|
||||
{
|
||||
sweep_tile<XDistributedTensor_>(
|
||||
[&](auto... idx_) {
|
||||
constexpr auto idx_0 = make_tuple(make_tuple(idx_[number<0>{}]...)[number<0>{}]);
|
||||
y_tensor(idx_0) = reduce_func(
|
||||
y_tensor(idx_0), ck_tile::type_convert<ComputeDataType>(x_tensor[idx_])...);
|
||||
},
|
||||
reduce_impl<false>(
|
||||
x_tensor,
|
||||
y_tensor,
|
||||
y_tensor, // dummy
|
||||
reduce_func,
|
||||
[](auto) { return 0; }, // dummy
|
||||
ReducePacksPerXDim{});
|
||||
}
|
||||
|
||||
// Overload for index tracking
|
||||
template <typename XDistributedTensor_,
|
||||
typename YDistributedTensor_,
|
||||
typename YIndexDistributedTensor_,
|
||||
typename ReduceFunc,
|
||||
typename IndexCalculatorFunc,
|
||||
typename ReducePacksPerXDim = uniform_sequence_gen_t<2, 1>>
|
||||
CK_TILE_DEVICE void operator()(const XDistributedTensor_& x_tensor,
|
||||
YDistributedTensor_& y_tensor,
|
||||
YIndexDistributedTensor_& y_index_tensor,
|
||||
const ReduceFunc& reduce_func,
|
||||
const IndexCalculatorFunc& index_calculator,
|
||||
ReducePacksPerXDim = {})
|
||||
{
|
||||
reduce_impl<Problem::kOutputIndex>(x_tensor,
|
||||
y_tensor,
|
||||
y_index_tensor,
|
||||
reduce_func,
|
||||
index_calculator,
|
||||
ReducePacksPerXDim{});
|
||||
}
|
||||
|
||||
#if 0
|
||||
constexpr auto I0 = number<0>{};
|
||||
@@ -90,7 +161,6 @@ struct BlockReduce2d
|
||||
y_tensor(y_dstr_idx) = y;
|
||||
});
|
||||
#endif
|
||||
}
|
||||
|
||||
template <typename XDistributedTensor_>
|
||||
CK_TILE_DEVICE static auto MakeYBlockTile()
|
||||
@@ -111,6 +181,25 @@ struct BlockReduce2d
|
||||
return tensor;
|
||||
}
|
||||
|
||||
template <typename XDistributedTensor_, typename IndexDataType = index_t>
|
||||
CK_TILE_DEVICE static auto MakeYIndexBlockTile()
|
||||
{
|
||||
static_assert(std::is_same_v<XDataType, typename XDistributedTensor_::DataType>, "wrong!");
|
||||
|
||||
// FIXME: hard coded to reduce 2nd axis
|
||||
constexpr auto reduce_dims = sequence<1>{};
|
||||
|
||||
constexpr auto dstr =
|
||||
make_static_tile_distribution(detail::make_reduce_tile_distribution_encoding(
|
||||
XDistributedTensor_::get_tile_distribution()
|
||||
.get_static_tile_distribution_encoding(),
|
||||
reduce_dims));
|
||||
|
||||
auto tensor = make_static_distributed_tensor<IndexDataType>(dstr);
|
||||
|
||||
return tensor;
|
||||
}
|
||||
|
||||
// uniform_sequence_gen_t<NSize, Value> generates sequence of NSize elements filled with Value
|
||||
// e.g., uniform_sequence_gen_t<2, 1> → {1, 1} and uniform_sequence_gen_t<3, 4> → {4, 4, 4}
|
||||
template <typename XDistributedTensor_,
|
||||
@@ -135,8 +224,14 @@ struct BlockReduce2dSync
|
||||
{
|
||||
using Problem = remove_cvref_t<Problem_>;
|
||||
|
||||
template <typename YDistributedTensor_, typename ReduceFunc>
|
||||
CK_TILE_DEVICE void operator()(YDistributedTensor_& y_tensor, const ReduceFunc& reduce_func)
|
||||
private:
|
||||
template <bool kProcessIndex,
|
||||
typename YDistributedTensor_,
|
||||
typename YIndexDistributedTensor_,
|
||||
typename ReduceFunc>
|
||||
CK_TILE_DEVICE void reduce_impl(YDistributedTensor_& y_tensor,
|
||||
YIndexDistributedTensor_& y_index_tensor,
|
||||
const ReduceFunc& reduce_func)
|
||||
{
|
||||
using Dstr = typename YDistributedTensor_::StaticTileDistribution;
|
||||
using DstrEncode = typename Dstr::DstrEncode;
|
||||
@@ -157,6 +252,14 @@ struct BlockReduce2dSync
|
||||
static_for<0, thread_buf_size, 1>{}([&](auto i) {
|
||||
auto v_local = y_tensor.get_thread_buffer()[i];
|
||||
|
||||
using IndexDataType = typename YIndexDistributedTensor_::DataType;
|
||||
IndexDataType idx_local{};
|
||||
|
||||
if constexpr(kProcessIndex)
|
||||
{
|
||||
idx_local = y_index_tensor.get_thread_buffer()[i];
|
||||
}
|
||||
|
||||
// cross-lane reduce for replication
|
||||
// only reduce on R dimension correspond to lane
|
||||
// (lane id maps to this R dimension)
|
||||
@@ -183,15 +286,46 @@ struct BlockReduce2dSync
|
||||
|
||||
// pull data from remote lane
|
||||
const auto v_remote = warp_shuffle(v_local, src_lane);
|
||||
v_local = reduce_func(v_local, v_remote);
|
||||
|
||||
if constexpr(kProcessIndex)
|
||||
{
|
||||
const auto idx_remote = warp_shuffle(idx_local, src_lane);
|
||||
|
||||
AccumulateWithIndex{}(
|
||||
reduce_func, v_local, idx_local, v_remote, idx_remote);
|
||||
}
|
||||
else
|
||||
{
|
||||
Accumulate{}(reduce_func, v_local, v_remote);
|
||||
}
|
||||
});
|
||||
}
|
||||
});
|
||||
|
||||
// TODO - Do we need to broadcast to other lane?
|
||||
y_tensor.get_thread_buffer()(i) = v_local;
|
||||
|
||||
if constexpr(kProcessIndex)
|
||||
{
|
||||
y_index_tensor.get_thread_buffer()(i) = idx_local;
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
public:
|
||||
template <typename YDistributedTensor_, typename ReduceFunc>
|
||||
CK_TILE_DEVICE void operator()(YDistributedTensor_& y_tensor, const ReduceFunc& reduce_func)
|
||||
{
|
||||
reduce_impl<false>(y_tensor, y_tensor, reduce_func);
|
||||
}
|
||||
|
||||
template <typename YDistributedTensor_, typename YIndexDistributedTensor_, typename ReduceFunc>
|
||||
CK_TILE_DEVICE void operator()(YDistributedTensor_& y_tensor,
|
||||
YIndexDistributedTensor_& y_index_tensor,
|
||||
const ReduceFunc& reduce_func)
|
||||
{
|
||||
reduce_impl<Problem::kOutputIndex>(y_tensor, y_index_tensor, reduce_func);
|
||||
}
|
||||
};
|
||||
|
||||
// BlockReduce2dCrossWarpSync: Cross-warp reduction (Stage 3)
|
||||
@@ -250,15 +384,39 @@ struct BlockReduce2dCrossWarpSync
|
||||
return num_warps * thread_buf_size * sizeof(DataType);
|
||||
}
|
||||
|
||||
template <typename YDistributedTensor_, typename ReduceFunc>
|
||||
CK_TILE_DEVICE void
|
||||
operator()(YDistributedTensor_& y_tensor, void* smem, const ReduceFunc& reduce_func)
|
||||
// return in byte - separate shared memory size calculation for indices
|
||||
template <typename YIndexDistributedTensor_>
|
||||
CK_TILE_HOST_DEVICE static constexpr index_t GetIndicesSmemSize()
|
||||
{
|
||||
using DataType = typename YDistributedTensor_::DataType;
|
||||
using IndexDataType = typename YIndexDistributedTensor_::DataType;
|
||||
constexpr index_t thread_buf_size = YIndexDistributedTensor_::get_thread_buffer_size();
|
||||
constexpr index_t num_warps = BlockShape::BlockSize / get_warp_size();
|
||||
return num_warps * thread_buf_size * sizeof(IndexDataType);
|
||||
}
|
||||
|
||||
private:
|
||||
template <bool kProcessIndex,
|
||||
typename YDistributedTensor_,
|
||||
typename YIndexDistributedTensor_,
|
||||
typename ReduceFunc>
|
||||
CK_TILE_DEVICE void reduce_impl(YDistributedTensor_& y_tensor,
|
||||
YIndexDistributedTensor_& y_index_tensor,
|
||||
void* smem,
|
||||
void* smem_indices_ptr,
|
||||
const ReduceFunc& reduce_func)
|
||||
{
|
||||
using DataType = typename YDistributedTensor_::DataType;
|
||||
using IndexDataType = typename YIndexDistributedTensor_::DataType;
|
||||
|
||||
constexpr index_t thread_buf_size = YDistributedTensor_::get_thread_buffer_size();
|
||||
|
||||
DataType* smem_ptr = reinterpret_cast<DataType*>(smem);
|
||||
DataType* smem_ptr = reinterpret_cast<DataType*>(smem);
|
||||
IndexDataType* smem_indices = nullptr;
|
||||
if constexpr(kProcessIndex)
|
||||
{
|
||||
smem_indices = reinterpret_cast<IndexDataType*>(smem_indices_ptr);
|
||||
}
|
||||
|
||||
const index_t lane_id = get_lane_id();
|
||||
const index_t warp_id = get_warp_id();
|
||||
|
||||
@@ -275,6 +433,11 @@ struct BlockReduce2dCrossWarpSync
|
||||
static_for<0, thread_buf_size, 1>{}([&](auto i) {
|
||||
// Store the i-th element of this warp's thread_buffer into SMEM
|
||||
smem_ptr[smem_offset + i * num_warps] = y_tensor.get_thread_buffer()[i];
|
||||
if constexpr(kProcessIndex)
|
||||
{
|
||||
smem_indices[smem_offset + i * num_warps] =
|
||||
y_index_tensor.get_thread_buffer()[i];
|
||||
}
|
||||
});
|
||||
}
|
||||
block_sync_lds();
|
||||
@@ -282,10 +445,19 @@ struct BlockReduce2dCrossWarpSync
|
||||
// Each warp holds its own duplicate of the data and performs the reduction.
|
||||
const index_t local_warp_id = warp_id / num_reduce_warps;
|
||||
const index_t local_smem_os = local_warp_id * num_reduce_warps;
|
||||
|
||||
static_for<0, thread_buf_size, 1>{}([&](auto i) {
|
||||
DataType v[num_reduce_warps];
|
||||
static_for<0, num_reduce_warps, 1>{}(
|
||||
[&](auto idx) { v[idx] = smem_ptr[i * num_warps + local_smem_os + idx]; });
|
||||
[[maybe_unused]] std::
|
||||
conditional_t<kProcessIndex, IndexDataType[num_reduce_warps], IndexDataType> idx_v;
|
||||
|
||||
static_for<0, num_reduce_warps, 1>{}([&](auto idx) {
|
||||
v[idx] = smem_ptr[i * num_warps + local_smem_os + idx];
|
||||
if constexpr(kProcessIndex)
|
||||
{
|
||||
idx_v[idx] = smem_indices[i * num_warps + local_smem_os + idx];
|
||||
}
|
||||
});
|
||||
|
||||
static_assert(is_power_of_two_integer(num_reduce_warps),
|
||||
"wrong! only support power of 2 reduction");
|
||||
@@ -299,14 +471,44 @@ struct BlockReduce2dCrossWarpSync
|
||||
constexpr index_t i1 = idx_ + stride;
|
||||
if constexpr(i1 < num_reduce_warps)
|
||||
{
|
||||
v[i0] = reduce_func(v[i0], v[i1]);
|
||||
if constexpr(kProcessIndex)
|
||||
{
|
||||
AccumulateWithIndex{}(reduce_func, v[i0], idx_v[i0], v[i1], idx_v[i1]);
|
||||
}
|
||||
else
|
||||
{
|
||||
Accumulate{}(reduce_func, v[i0], v[i1]);
|
||||
}
|
||||
}
|
||||
});
|
||||
});
|
||||
|
||||
y_tensor.get_thread_buffer()(i) = v[0];
|
||||
if constexpr(kProcessIndex)
|
||||
{
|
||||
y_index_tensor.get_thread_buffer()(i) = idx_v[0];
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
public:
|
||||
template <typename YDistributedTensor_, typename ReduceFunc>
|
||||
CK_TILE_DEVICE void
|
||||
operator()(YDistributedTensor_& y_tensor, void* smem, const ReduceFunc& reduce_func)
|
||||
{
|
||||
reduce_impl<false>(y_tensor, y_tensor, smem, nullptr, reduce_func);
|
||||
}
|
||||
|
||||
template <typename YDistributedTensor_, typename YIndexDistributedTensor_, typename ReduceFunc>
|
||||
CK_TILE_DEVICE void operator()(YDistributedTensor_& y_tensor,
|
||||
YIndexDistributedTensor_& y_index_tensor,
|
||||
void* smem,
|
||||
void* smem_indices,
|
||||
const ReduceFunc& reduce_func)
|
||||
{
|
||||
reduce_impl<Problem::kOutputIndex>(
|
||||
y_tensor, y_index_tensor, smem, smem_indices, reduce_func);
|
||||
}
|
||||
};
|
||||
|
||||
template <typename Problem_, typename Policy_ = void>
|
||||
@@ -364,15 +566,39 @@ struct BlockReduce2dLinearCrossWarpSync
|
||||
return num_warps * thread_buf_size * sizeof(DataType);
|
||||
}
|
||||
|
||||
template <typename YDistributedTensor_, typename ReduceFunc>
|
||||
CK_TILE_DEVICE void
|
||||
operator()(YDistributedTensor_& y_tensor, void* smem, const ReduceFunc& reduce_func)
|
||||
// return in byte - separate shared memory size calculation for indices
|
||||
template <typename YIndexDistributedTensor_>
|
||||
CK_TILE_HOST_DEVICE static constexpr index_t GetIndicesSmemSize()
|
||||
{
|
||||
using DataType = typename YDistributedTensor_::DataType;
|
||||
using IndexDataType = typename YIndexDistributedTensor_::DataType;
|
||||
constexpr index_t thread_buf_size = YIndexDistributedTensor_::get_thread_buffer_size();
|
||||
constexpr index_t num_warps = BlockShape::BlockSize / get_warp_size();
|
||||
return num_warps * thread_buf_size * sizeof(IndexDataType);
|
||||
}
|
||||
|
||||
private:
|
||||
template <bool kProcessIndex,
|
||||
typename YDistributedTensor_,
|
||||
typename YIndexDistributedTensor_,
|
||||
typename ReduceFunc>
|
||||
CK_TILE_DEVICE void reduce_impl(YDistributedTensor_& y_tensor,
|
||||
YIndexDistributedTensor_& y_index_tensor,
|
||||
void* smem,
|
||||
void* smem_indices_ptr,
|
||||
const ReduceFunc& reduce_func)
|
||||
{
|
||||
using DataType = typename YDistributedTensor_::DataType;
|
||||
using IndexDataType = typename YIndexDistributedTensor_::DataType;
|
||||
|
||||
constexpr index_t thread_buf_size = YDistributedTensor_::get_thread_buffer_size();
|
||||
|
||||
DataType* smem_ptr = reinterpret_cast<DataType*>(smem);
|
||||
DataType* smem_ptr = reinterpret_cast<DataType*>(smem);
|
||||
IndexDataType* smem_indices = nullptr;
|
||||
if constexpr(kProcessIndex)
|
||||
{
|
||||
smem_indices = reinterpret_cast<IndexDataType*>(smem_indices_ptr);
|
||||
}
|
||||
|
||||
const index_t lane_id = get_lane_id();
|
||||
const index_t warp_id = get_warp_id();
|
||||
constexpr auto num_reduce_warps = GetReduceWarps<YDistributedTensor_>();
|
||||
@@ -388,6 +614,11 @@ struct BlockReduce2dLinearCrossWarpSync
|
||||
{
|
||||
static_for<0, thread_buf_size, 1>{}([&](auto i) {
|
||||
smem_ptr[smem_offset + i * num_warps] = y_tensor.get_thread_buffer()[i];
|
||||
if constexpr(kProcessIndex)
|
||||
{
|
||||
smem_indices[smem_offset + i * num_warps] =
|
||||
y_index_tensor.get_thread_buffer()[i];
|
||||
}
|
||||
});
|
||||
}
|
||||
block_sync_lds();
|
||||
@@ -395,31 +626,86 @@ struct BlockReduce2dLinearCrossWarpSync
|
||||
// load from smem; here we let every thread do the compute :)
|
||||
index_t local_warp_id = warp_id / num_reduce_warps;
|
||||
index_t local_smem_os = local_warp_id * num_reduce_warps;
|
||||
|
||||
DataType all_scratch[thread_buf_size * num_reduce_warps];
|
||||
[[maybe_unused]] std::conditional_t<kProcessIndex,
|
||||
IndexDataType[thread_buf_size * num_reduce_warps],
|
||||
IndexDataType> all_indices;
|
||||
|
||||
// Load data from shared memory
|
||||
static_for<0, thread_buf_size, 1>{}([&](auto i_0) {
|
||||
static_for<0, num_reduce_warps, 1>{}([&](auto i_1) {
|
||||
all_scratch[i_0 * num_reduce_warps + i_1] =
|
||||
smem_ptr[i_0 * num_warps + local_smem_os + i_1];
|
||||
|
||||
if constexpr(kProcessIndex)
|
||||
{
|
||||
all_indices[i_0 * num_reduce_warps + i_1] =
|
||||
smem_indices[i_0 * num_warps + local_smem_os + i_1];
|
||||
}
|
||||
});
|
||||
});
|
||||
block_sync_lds(); // TODO: we don't need sync here
|
||||
|
||||
// Perform reduction
|
||||
static_for<0, thread_buf_size, 1>{}([&](auto i_0) {
|
||||
// TODO: use descriptor for this
|
||||
auto v_local = all_scratch[i_0 * num_reduce_warps];
|
||||
|
||||
IndexDataType idx_local{};
|
||||
if constexpr(kProcessIndex)
|
||||
{
|
||||
idx_local = all_indices[i_0 * num_reduce_warps];
|
||||
}
|
||||
|
||||
// further reduce mean/var
|
||||
static_for<0, num_reduce_warps - 1, 1>{}([&](auto i_1_n1) {
|
||||
constexpr auto i_1 = number<i_1_n1 + 1>{};
|
||||
const DataType v_remote = all_scratch[i_0 * num_reduce_warps + i_1];
|
||||
|
||||
// reduce
|
||||
v_local = reduce_func(v_local, v_remote);
|
||||
if constexpr(kProcessIndex)
|
||||
{
|
||||
const IndexDataType idx_remote = all_indices[i_0 * num_reduce_warps + i_1];
|
||||
|
||||
bool changed = false;
|
||||
v_local = reduce_func(v_local, v_remote, changed);
|
||||
if(changed)
|
||||
{
|
||||
idx_local = idx_remote;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
v_local = reduce_func(v_local, v_remote);
|
||||
}
|
||||
});
|
||||
|
||||
y_tensor.get_thread_buffer()(i_0) = v_local;
|
||||
if constexpr(kProcessIndex)
|
||||
{
|
||||
y_index_tensor.get_thread_buffer()(i_0) = idx_local;
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
public:
|
||||
template <typename YDistributedTensor_, typename ReduceFunc>
|
||||
CK_TILE_DEVICE void
|
||||
operator()(YDistributedTensor_& y_tensor, void* smem, const ReduceFunc& reduce_func)
|
||||
{
|
||||
reduce_impl<false>(y_tensor, y_tensor, smem, nullptr, reduce_func);
|
||||
}
|
||||
|
||||
template <typename YDistributedTensor_, typename YIndexDistributedTensor_, typename ReduceFunc>
|
||||
CK_TILE_DEVICE void operator()(YDistributedTensor_& y_tensor,
|
||||
YIndexDistributedTensor_& y_index_tensor,
|
||||
void* smem,
|
||||
void* smem_indices,
|
||||
const ReduceFunc& reduce_func)
|
||||
{
|
||||
reduce_impl<Problem::kOutputIndex>(
|
||||
y_tensor, y_index_tensor, smem, smem_indices, reduce_func);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -7,12 +7,17 @@
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <typename XDataType_, typename ComputeDataType_, typename BlockShape_>
|
||||
template <typename XDataType_,
|
||||
typename ComputeDataType_,
|
||||
typename BlockShape_,
|
||||
bool OutputIndex_ = false>
|
||||
struct BlockReduce2dProblem
|
||||
{
|
||||
using XDataType = remove_cvref_t<XDataType_>;
|
||||
using ComputeDataType = remove_cvref_t<ComputeDataType_>;
|
||||
using BlockShape = remove_cvref_t<BlockShape_>;
|
||||
|
||||
static constexpr bool kOutputIndex = OutputIndex_;
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -32,7 +32,8 @@ struct Reduce2dDefaultPolicy
|
||||
{
|
||||
using P_ = BlockReduce2dProblem<typename Problem::XDataType,
|
||||
typename Problem::ComputeDataType,
|
||||
typename Problem::BlockShape>;
|
||||
typename Problem::BlockShape,
|
||||
Problem::kOutputIndex>;
|
||||
return BlockReduce2d<P_>{};
|
||||
}
|
||||
|
||||
@@ -41,7 +42,8 @@ struct Reduce2dDefaultPolicy
|
||||
{
|
||||
using P_ = BlockReduce2dProblem<typename Problem::XDataType,
|
||||
typename Problem::ComputeDataType,
|
||||
typename Problem::BlockShape>;
|
||||
typename Problem::BlockShape,
|
||||
Problem::kOutputIndex>;
|
||||
return BlockReduce2dSync<P_>{};
|
||||
}
|
||||
|
||||
@@ -50,7 +52,8 @@ struct Reduce2dDefaultPolicy
|
||||
{
|
||||
using P_ = BlockReduce2dProblem<typename Problem::XDataType,
|
||||
typename Problem::ComputeDataType,
|
||||
typename Problem::BlockShape>;
|
||||
typename Problem::BlockShape,
|
||||
Problem::kOutputIndex>;
|
||||
return BlockReduce2dCrossWarpSync<P_>{};
|
||||
}
|
||||
|
||||
@@ -61,7 +64,8 @@ struct Reduce2dDefaultPolicy
|
||||
{
|
||||
using P_ = BlockReduce2dProblem<typename Problem::XDataType,
|
||||
typename Problem::ComputeDataType,
|
||||
typename Problem::BlockShape>;
|
||||
typename Problem::BlockShape,
|
||||
Problem::kOutputIndex>;
|
||||
|
||||
using block_reduce2d = BlockReduce2d<P_>;
|
||||
using x_block_tile =
|
||||
@@ -76,5 +80,23 @@ struct Reduce2dDefaultPolicy
|
||||
return 1; // zero size arrays are an extension
|
||||
}
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr index_t GetIndicesSmemSize()
|
||||
{
|
||||
using P_ = BlockReduce2dProblem<typename Problem::XDataType,
|
||||
typename Problem::ComputeDataType,
|
||||
typename Problem::BlockShape,
|
||||
Problem::kOutputIndex>;
|
||||
|
||||
using block_reduce2d = BlockReduce2d<P_>;
|
||||
using x_block_tile = decltype(make_static_distributed_tensor<typename Problem::XDataType>(
|
||||
MakeXBlockTileDistribution<Problem>()));
|
||||
using y_index_block_tile =
|
||||
decltype(block_reduce2d::template MakeYIndexBlockTile<x_block_tile, index_t>());
|
||||
|
||||
return GetBlockReduce2dCrossWarpSync<Problem>()
|
||||
.template GetIndicesSmemSize<y_index_block_tile>();
|
||||
}
|
||||
};
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -11,7 +11,8 @@ template <typename XDataType_,
|
||||
typename ComputeDataType_,
|
||||
typename YDataType_,
|
||||
typename BlockShape_,
|
||||
typename ReduceOp_>
|
||||
typename ReduceOp_,
|
||||
bool OutputIndex_ = false>
|
||||
struct Reduce2dProblem
|
||||
{
|
||||
using XDataType = remove_cvref_t<XDataType_>;
|
||||
@@ -20,6 +21,7 @@ struct Reduce2dProblem
|
||||
using BlockShape = remove_cvref_t<BlockShape_>;
|
||||
using ReduceOp = ReduceOp_;
|
||||
|
||||
static constexpr bool kOutputIndex = OutputIndex_;
|
||||
static constexpr bool kNeedCrossLaneSync = BlockShape::ThreadPerWarp_N > 1;
|
||||
static constexpr bool kNeedCrossWarpSync = BlockShape::WarpPerBlock_N > 1;
|
||||
};
|
||||
|
||||
@@ -28,7 +28,19 @@ class TestCkTilePooling : public ::testing::Test
|
||||
|
||||
using TestPoolShape = ck_tile::PoolShape<BlockWarps_, BlockTile_, WarpTile_, ThreadTile_>;
|
||||
|
||||
// 3D pooling configuration
|
||||
// 2D pooling configuration (NHWC)
|
||||
struct Config2D
|
||||
{
|
||||
ck_tile::index_t N, H, W, C;
|
||||
ck_tile::index_t Y, X;
|
||||
ck_tile::index_t Sy, Sx;
|
||||
ck_tile::index_t Dy, Dx;
|
||||
ck_tile::index_t LeftPy, LeftPx;
|
||||
ck_tile::index_t RightPy, RightPx;
|
||||
std::string name;
|
||||
};
|
||||
|
||||
// 3D pooling configuration (NDHWC)
|
||||
struct Config3D
|
||||
{
|
||||
ck_tile::index_t N, D, H, W, C;
|
||||
@@ -40,6 +52,117 @@ class TestCkTilePooling : public ::testing::Test
|
||||
std::string name;
|
||||
};
|
||||
|
||||
bool RunPool2D(const Config2D& config)
|
||||
{
|
||||
std::cout << "Testing 2D: " << config.name << " ... ";
|
||||
|
||||
const ck_tile::index_t Ys = (config.Y - 1) * config.Dy + 1;
|
||||
const ck_tile::index_t Xs = (config.X - 1) * config.Dx + 1;
|
||||
const ck_tile::index_t Ho =
|
||||
(config.H + config.LeftPy + config.RightPy - Ys) / config.Sy + 1;
|
||||
const ck_tile::index_t Wo =
|
||||
(config.W + config.LeftPx + config.RightPx - Xs) / config.Sx + 1;
|
||||
|
||||
using IndexDataType = ck_tile::index_t;
|
||||
|
||||
// Host tensors
|
||||
ck_tile::HostTensor<InDataType> h_in({config.N, config.H, config.W, config.C});
|
||||
ck_tile::HostTensor<OutDataType> h_out({config.N, Ho, Wo, config.C});
|
||||
ck_tile::HostTensor<OutDataType> h_out_ref({config.N, Ho, Wo, config.C});
|
||||
ck_tile::HostTensor<IndexDataType> h_out_index({config.N, Ho, Wo, config.C});
|
||||
ck_tile::HostTensor<IndexDataType> h_out_ref_index({config.N, Ho, Wo, config.C});
|
||||
|
||||
// Initialize input with random data
|
||||
ck_tile::FillUniformDistribution<InDataType>{-5.f, 5.f}(h_in);
|
||||
|
||||
// Device memory
|
||||
ck_tile::DeviceMem d_in_mem(h_in.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem d_out_mem(h_out.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem d_out_index_mem(h_out_index.get_element_space_size_in_bytes());
|
||||
|
||||
d_in_mem.ToDevice(h_in.data());
|
||||
d_out_mem.ToDevice(h_out.data());
|
||||
d_out_index_mem.ToDevice(h_out_index.data());
|
||||
|
||||
constexpr ck_tile::index_t kBlockPerCu = 1;
|
||||
|
||||
using Problem = ck_tile::PoolProblem<InDataType,
|
||||
OutDataType,
|
||||
ComputeDataType,
|
||||
IndexDataType,
|
||||
ReduceOpType,
|
||||
true, // OutputIndex
|
||||
false, // PropagateNan
|
||||
TestPoolShape>;
|
||||
using Kernel = ck_tile::PoolKernel<Problem>;
|
||||
|
||||
const ck_tile::index_t kBlockSize = Kernel::BlockSize();
|
||||
|
||||
// Shapes and strides (NHWC)
|
||||
const auto input_shape = ck_tile::make_tuple(config.N, config.H, config.W, config.C);
|
||||
const auto output_shape = ck_tile::make_tuple(config.N, Ho, Wo, config.C);
|
||||
const auto input_strides =
|
||||
ck_tile::make_tuple(config.H * config.W * config.C, config.W * config.C, config.C, 1);
|
||||
const auto output_strides =
|
||||
ck_tile::make_tuple(Ho * Wo * config.C, Wo * config.C, config.C, 1);
|
||||
const auto window_spatial_lengths = ck_tile::make_tuple(config.Y, config.X);
|
||||
const auto window_strides = ck_tile::make_tuple(config.Sy, config.Sx);
|
||||
const auto window_dilations = ck_tile::make_tuple(config.Dy, config.Dx);
|
||||
const auto input_left_pads = ck_tile::make_tuple(config.LeftPy, config.LeftPx);
|
||||
const auto input_right_pads = ck_tile::make_tuple(config.RightPy, config.RightPx);
|
||||
|
||||
auto host_args =
|
||||
ck_tile::PoolHostArgs<decltype(input_shape), decltype(window_spatial_lengths)>{
|
||||
static_cast<InDataType*>(d_in_mem.GetDeviceBuffer()),
|
||||
static_cast<OutDataType*>(d_out_mem.GetDeviceBuffer()),
|
||||
static_cast<IndexDataType*>(d_out_index_mem.GetDeviceBuffer()),
|
||||
input_shape,
|
||||
output_shape,
|
||||
input_strides,
|
||||
output_strides,
|
||||
window_spatial_lengths,
|
||||
window_strides,
|
||||
window_dilations,
|
||||
input_left_pads,
|
||||
input_right_pads};
|
||||
|
||||
auto kernel_args = Kernel::MakeKernelArgs(host_args);
|
||||
const ck_tile::index_t kGridSize = Kernel::CalculateGridSize(kernel_args);
|
||||
|
||||
if(!Kernel::IsSupportedArgument(kernel_args))
|
||||
{
|
||||
return true;
|
||||
}
|
||||
|
||||
// Run kernel
|
||||
ck_tile::launch_kernel(
|
||||
ck_tile::stream_config{nullptr, false, 0},
|
||||
ck_tile::make_kernel<kBlockPerCu>(Kernel{}, kGridSize, kBlockSize, 0, kernel_args));
|
||||
|
||||
// Run reference
|
||||
ck_tile::reference_pool2d<InDataType,
|
||||
ComputeDataType,
|
||||
OutDataType,
|
||||
IndexDataType,
|
||||
ReduceOpType,
|
||||
decltype(input_shape),
|
||||
decltype(window_spatial_lengths),
|
||||
true>(
|
||||
h_in, h_out_ref, h_out_ref_index, kernel_args, ReduceOpType{});
|
||||
|
||||
d_out_mem.FromDevice(h_out.data());
|
||||
d_out_index_mem.FromDevice(h_out_index.data());
|
||||
|
||||
// Validate results
|
||||
bool pass_value =
|
||||
ck_tile::check_err(h_out, h_out_ref, "Error: Incorrect values!", 1e-5, 1e-5);
|
||||
bool pass_index = ck_tile::check_err(
|
||||
h_out_index, h_out_ref_index, "Error: Incorrect indices!", 1e-5, 1e-5);
|
||||
|
||||
std::cout << (pass_value && pass_index ? "PASS" : "FAIL") << std::endl;
|
||||
return pass_value && pass_index;
|
||||
}
|
||||
|
||||
bool RunPool3D(const Config3D& config)
|
||||
{
|
||||
std::cout << "Testing 3D: " << config.name << " ... ";
|
||||
@@ -72,6 +195,8 @@ class TestCkTilePooling : public ::testing::Test
|
||||
const auto input_right_pads =
|
||||
ck_tile::make_tuple(config.RightPz, config.RightPy, config.RightPx);
|
||||
|
||||
using IndexDataType = ck_tile::index_t;
|
||||
|
||||
ck_tile::HostTensor<InDataType> h_in({config.N, config.D, config.H, config.W, config.C},
|
||||
{config.D * config.H * config.W * config.C,
|
||||
config.H * config.W * config.C,
|
||||
@@ -84,6 +209,12 @@ class TestCkTilePooling : public ::testing::Test
|
||||
ck_tile::HostTensor<OutDataType> h_out_ref(
|
||||
{config.N, Do, Ho, Wo, config.C},
|
||||
{Do * Ho * Wo * config.C, Ho * Wo * config.C, Wo * config.C, config.C, 1});
|
||||
ck_tile::HostTensor<IndexDataType> h_out_index(
|
||||
{config.N, Do, Ho, Wo, config.C},
|
||||
{Do * Ho * Wo * config.C, Ho * Wo * config.C, Wo * config.C, config.C, 1});
|
||||
ck_tile::HostTensor<IndexDataType> h_out_ref_index(
|
||||
{config.N, Do, Ho, Wo, config.C},
|
||||
{Do * Ho * Wo * config.C, Ho * Wo * config.C, Wo * config.C, config.C, 1});
|
||||
|
||||
ck_tile::FillUniformDistribution<InDataType>{-5.f, 5.f}(h_in);
|
||||
h_out.SetZero();
|
||||
@@ -91,17 +222,19 @@ class TestCkTilePooling : public ::testing::Test
|
||||
|
||||
ck_tile::DeviceMem d_in_mem(h_in.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem d_out_mem(h_out.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem d_out_index_mem(h_out_index.get_element_space_size_in_bytes());
|
||||
|
||||
d_in_mem.ToDevice(h_in.data());
|
||||
d_out_mem.ToDevice(h_out.data());
|
||||
d_out_index_mem.ToDevice(h_out_index.data());
|
||||
|
||||
using Problem = ck_tile::PoolProblem<InDataType,
|
||||
OutDataType,
|
||||
ComputeDataType,
|
||||
OutDataType,
|
||||
IndexDataType,
|
||||
ReduceOpType,
|
||||
false,
|
||||
false,
|
||||
true, // OutputIndex
|
||||
false, // PropagateNan
|
||||
TestPoolShape>;
|
||||
using Kernel = ck_tile::PoolKernel<Problem>;
|
||||
|
||||
@@ -112,6 +245,7 @@ class TestCkTilePooling : public ::testing::Test
|
||||
ck_tile::PoolHostArgs<decltype(input_shape), decltype(window_spatial_lengths)>{
|
||||
static_cast<InDataType*>(d_in_mem.GetDeviceBuffer()),
|
||||
static_cast<OutDataType*>(d_out_mem.GetDeviceBuffer()),
|
||||
static_cast<IndexDataType*>(d_out_index_mem.GetDeviceBuffer()),
|
||||
input_shape,
|
||||
output_shape,
|
||||
input_strides,
|
||||
@@ -137,16 +271,27 @@ class TestCkTilePooling : public ::testing::Test
|
||||
ck_tile::make_kernel<kBlockPerCu>(Kernel{}, kGridSize, kBlockSize, 0, kernel_args));
|
||||
|
||||
// Run reference implementation
|
||||
ck_tile::reference_pool3d<InDataType, ComputeDataType, OutDataType>(
|
||||
h_in, h_out_ref, kernel_args, ReduceOpType{});
|
||||
ck_tile::reference_pool3d<InDataType,
|
||||
ComputeDataType,
|
||||
OutDataType,
|
||||
IndexDataType,
|
||||
ReduceOpType,
|
||||
decltype(input_shape),
|
||||
decltype(window_spatial_lengths),
|
||||
true>(
|
||||
h_in, h_out_ref, h_out_ref_index, kernel_args, ReduceOpType{});
|
||||
|
||||
d_out_mem.FromDevice(h_out.data());
|
||||
d_out_index_mem.FromDevice(h_out_index.data());
|
||||
|
||||
// Validate results
|
||||
bool pass = ck_tile::check_err(h_out, h_out_ref);
|
||||
std::cout << (pass ? "PASS" : "FAIL") << std::endl;
|
||||
bool pass_value =
|
||||
ck_tile::check_err(h_out, h_out_ref, "Error: Incorrect values!", 1e-5, 1e-5);
|
||||
bool pass_index = ck_tile::check_err(
|
||||
h_out_index, h_out_ref_index, "Error: Incorrect indices!", 1e-5, 1e-5);
|
||||
|
||||
return pass;
|
||||
std::cout << (pass_value && pass_index ? "PASS" : "FAIL") << std::endl;
|
||||
return pass_value && pass_index;
|
||||
}
|
||||
};
|
||||
|
||||
@@ -194,6 +339,50 @@ using TestTypes =
|
||||
|
||||
TYPED_TEST_SUITE(TestCkTilePooling, TestTypes);
|
||||
|
||||
// 2D Pooling Tests (NHWC)
|
||||
// 2x2 window, stride 2, no dilation, no padding, on a 1x8x8x32 NHWC input.
TYPED_TEST(TestCkTilePooling, Pool2D_2x2)
{
    typename TestFixture::Config2D config{};

    // Input tensor
    config.N = 1;  // batch size
    config.H = 8;  // height dimension
    config.W = 8;  // width dimension
    config.C = 32; // channel dimension

    // Pooling window
    config.Y = 2; // window height
    config.X = 2; // window width

    // Strides and dilations
    config.Sy = 2; // window stride height
    config.Sx = 2; // window stride width
    config.Dy = 1; // window dilation height
    config.Dx = 1; // window dilation width

    // Padding
    config.LeftPy  = 0; // left padding height
    config.LeftPx  = 0; // left padding width
    config.RightPy = 0; // right padding height
    config.RightPx = 0; // right padding width

    config.name = "2x2 pooling NHWC";

    bool pass = this->RunPool2D(config);
    EXPECT_TRUE(pass);
}
|
||||
|
||||
// 3x3 window, stride 2, no dilation, padding of 1 on every side, on a
// 2x16x16x32 NHWC input.
TYPED_TEST(TestCkTilePooling, Pool2D_3x3_WithPadding)
{
    typename TestFixture::Config2D config{};

    // Input tensor
    config.N = 2;  // batch size
    config.H = 16; // height dimension
    config.W = 16; // width dimension
    config.C = 32; // channel dimension

    // Pooling window
    config.Y = 3; // window height
    config.X = 3; // window width

    // Strides and dilations
    config.Sy = 2; // window stride height
    config.Sx = 2; // window stride width
    config.Dy = 1; // window dilation height
    config.Dx = 1; // window dilation width

    // Padding
    config.LeftPy  = 1; // left padding height
    config.LeftPx  = 1; // left padding width
    config.RightPy = 1; // right padding height
    config.RightPx = 1; // right padding width

    config.name = "3x3 pooling NHWC with padding";

    bool pass = this->RunPool2D(config);
    EXPECT_TRUE(pass);
}
|
||||
|
||||
// 3D Pooling Tests (NDHWC)
|
||||
TYPED_TEST(TestCkTilePooling, Pool3D_2x2x2)
|
||||
{
|
||||
typename TestFixture::Config3D config = {1, // N - batch size
|
||||
@@ -216,7 +405,7 @@ TYPED_TEST(TestCkTilePooling, Pool3D_2x2x2)
|
||||
0, // RightPz - right padding depth
|
||||
0, // RightPy - right padding height
|
||||
0, // RightPx - right padding width
|
||||
"2x2x2 pooling"};
|
||||
"2x2x2 pooling NDHWC"};
|
||||
bool pass = this->RunPool3D(config);
|
||||
EXPECT_TRUE(pass);
|
||||
}
|
||||
@@ -243,7 +432,7 @@ TYPED_TEST(TestCkTilePooling, Pool3D_3x3x3)
|
||||
1, // RightPz - right padding depth
|
||||
1, // RightPy - right padding height
|
||||
1, // RightPx - right padding width
|
||||
"3x3x3 pooling"};
|
||||
"3x3x3 pooling NDHWC with padding"};
|
||||
bool pass = this->RunPool3D(config);
|
||||
EXPECT_TRUE(pass);
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user