mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-19 22:39:03 +00:00
[CK_TILE] Add indexing to pooling operator (Lwpck 3892) (#3013)
* Add indexing support to pooling operator - Add IndexDataType template parameter to pooling problem and kernel definitions - Enable pooling kernel to output indices of selected elements during max/absmax pooling - Add overloaded operators for Max and AbsMax that track when values change using bool changed parameter - Support optional index buffer allocation and management in device memory - Modify BlockReduce2d classes to handle index tensors alongside value tensors - Add separate shared memory allocation for index data in cross-warp reductions - Create validate_pool_indices function to verify index correctness - Modify pool3d.cpp example to demonstrate index output functionality - Add tests for index output * fixes * Refactor BlockReduce2D functions to get rid auxiliary private types. * comment resolutions and some changes to block_reduce2d - index reference implementation improved - reduce_operator.hpp cleanedup - updated the block_reduce2d.hpp to have index calculation for BlockReduce2dLinearCrossWarpSync as well * conditionally used variable declaration improvement - the conditionally used vairbales are used only when indexing is enabled. To inform the compiler that they may be unused and declare them with least size possible. This may allow it to be optimized compared to the previous declarations * comment resolutions * lexical ordering of the indicies - introduced accumulate methods that handle the intermediate steps if needed to order the indexes * add reduce_operator_accumulate.hpp to core.hpp --------- Co-authored-by: Adam Osewski <Adam.Osewski@amd.com>
This commit is contained in:
committed by
GitHub
parent
7c6430eca0
commit
3052d7c9e6
@@ -28,7 +28,19 @@ class TestCkTilePooling : public ::testing::Test
|
||||
|
||||
using TestPoolShape = ck_tile::PoolShape<BlockWarps_, BlockTile_, WarpTile_, ThreadTile_>;
|
||||
|
||||
// 3D pooling configuration
|
||||
// 2D pooling configuration (NHWC)
|
||||
struct Config2D
|
||||
{
|
||||
ck_tile::index_t N, H, W, C;
|
||||
ck_tile::index_t Y, X;
|
||||
ck_tile::index_t Sy, Sx;
|
||||
ck_tile::index_t Dy, Dx;
|
||||
ck_tile::index_t LeftPy, LeftPx;
|
||||
ck_tile::index_t RightPy, RightPx;
|
||||
std::string name;
|
||||
};
|
||||
|
||||
// 3D pooling configuration (NDHWC)
|
||||
struct Config3D
|
||||
{
|
||||
ck_tile::index_t N, D, H, W, C;
|
||||
@@ -40,6 +52,117 @@ class TestCkTilePooling : public ::testing::Test
|
||||
std::string name;
|
||||
};
|
||||
|
||||
bool RunPool2D(const Config2D& config)
|
||||
{
|
||||
std::cout << "Testing 2D: " << config.name << " ... ";
|
||||
|
||||
const ck_tile::index_t Ys = (config.Y - 1) * config.Dy + 1;
|
||||
const ck_tile::index_t Xs = (config.X - 1) * config.Dx + 1;
|
||||
const ck_tile::index_t Ho =
|
||||
(config.H + config.LeftPy + config.RightPy - Ys) / config.Sy + 1;
|
||||
const ck_tile::index_t Wo =
|
||||
(config.W + config.LeftPx + config.RightPx - Xs) / config.Sx + 1;
|
||||
|
||||
using IndexDataType = ck_tile::index_t;
|
||||
|
||||
// Host tensors
|
||||
ck_tile::HostTensor<InDataType> h_in({config.N, config.H, config.W, config.C});
|
||||
ck_tile::HostTensor<OutDataType> h_out({config.N, Ho, Wo, config.C});
|
||||
ck_tile::HostTensor<OutDataType> h_out_ref({config.N, Ho, Wo, config.C});
|
||||
ck_tile::HostTensor<IndexDataType> h_out_index({config.N, Ho, Wo, config.C});
|
||||
ck_tile::HostTensor<IndexDataType> h_out_ref_index({config.N, Ho, Wo, config.C});
|
||||
|
||||
// Initialize input with random data
|
||||
ck_tile::FillUniformDistribution<InDataType>{-5.f, 5.f}(h_in);
|
||||
|
||||
// Device memory
|
||||
ck_tile::DeviceMem d_in_mem(h_in.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem d_out_mem(h_out.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem d_out_index_mem(h_out_index.get_element_space_size_in_bytes());
|
||||
|
||||
d_in_mem.ToDevice(h_in.data());
|
||||
d_out_mem.ToDevice(h_out.data());
|
||||
d_out_index_mem.ToDevice(h_out_index.data());
|
||||
|
||||
constexpr ck_tile::index_t kBlockPerCu = 1;
|
||||
|
||||
using Problem = ck_tile::PoolProblem<InDataType,
|
||||
OutDataType,
|
||||
ComputeDataType,
|
||||
IndexDataType,
|
||||
ReduceOpType,
|
||||
true, // OutputIndex
|
||||
false, // PropagateNan
|
||||
TestPoolShape>;
|
||||
using Kernel = ck_tile::PoolKernel<Problem>;
|
||||
|
||||
const ck_tile::index_t kBlockSize = Kernel::BlockSize();
|
||||
|
||||
// Shapes and strides (NHWC)
|
||||
const auto input_shape = ck_tile::make_tuple(config.N, config.H, config.W, config.C);
|
||||
const auto output_shape = ck_tile::make_tuple(config.N, Ho, Wo, config.C);
|
||||
const auto input_strides =
|
||||
ck_tile::make_tuple(config.H * config.W * config.C, config.W * config.C, config.C, 1);
|
||||
const auto output_strides =
|
||||
ck_tile::make_tuple(Ho * Wo * config.C, Wo * config.C, config.C, 1);
|
||||
const auto window_spatial_lengths = ck_tile::make_tuple(config.Y, config.X);
|
||||
const auto window_strides = ck_tile::make_tuple(config.Sy, config.Sx);
|
||||
const auto window_dilations = ck_tile::make_tuple(config.Dy, config.Dx);
|
||||
const auto input_left_pads = ck_tile::make_tuple(config.LeftPy, config.LeftPx);
|
||||
const auto input_right_pads = ck_tile::make_tuple(config.RightPy, config.RightPx);
|
||||
|
||||
auto host_args =
|
||||
ck_tile::PoolHostArgs<decltype(input_shape), decltype(window_spatial_lengths)>{
|
||||
static_cast<InDataType*>(d_in_mem.GetDeviceBuffer()),
|
||||
static_cast<OutDataType*>(d_out_mem.GetDeviceBuffer()),
|
||||
static_cast<IndexDataType*>(d_out_index_mem.GetDeviceBuffer()),
|
||||
input_shape,
|
||||
output_shape,
|
||||
input_strides,
|
||||
output_strides,
|
||||
window_spatial_lengths,
|
||||
window_strides,
|
||||
window_dilations,
|
||||
input_left_pads,
|
||||
input_right_pads};
|
||||
|
||||
auto kernel_args = Kernel::MakeKernelArgs(host_args);
|
||||
const ck_tile::index_t kGridSize = Kernel::CalculateGridSize(kernel_args);
|
||||
|
||||
if(!Kernel::IsSupportedArgument(kernel_args))
|
||||
{
|
||||
return true;
|
||||
}
|
||||
|
||||
// Run kernel
|
||||
ck_tile::launch_kernel(
|
||||
ck_tile::stream_config{nullptr, false, 0},
|
||||
ck_tile::make_kernel<kBlockPerCu>(Kernel{}, kGridSize, kBlockSize, 0, kernel_args));
|
||||
|
||||
// Run reference
|
||||
ck_tile::reference_pool2d<InDataType,
|
||||
ComputeDataType,
|
||||
OutDataType,
|
||||
IndexDataType,
|
||||
ReduceOpType,
|
||||
decltype(input_shape),
|
||||
decltype(window_spatial_lengths),
|
||||
true>(
|
||||
h_in, h_out_ref, h_out_ref_index, kernel_args, ReduceOpType{});
|
||||
|
||||
d_out_mem.FromDevice(h_out.data());
|
||||
d_out_index_mem.FromDevice(h_out_index.data());
|
||||
|
||||
// Validate results
|
||||
bool pass_value =
|
||||
ck_tile::check_err(h_out, h_out_ref, "Error: Incorrect values!", 1e-5, 1e-5);
|
||||
bool pass_index = ck_tile::check_err(
|
||||
h_out_index, h_out_ref_index, "Error: Incorrect indices!", 1e-5, 1e-5);
|
||||
|
||||
std::cout << (pass_value && pass_index ? "PASS" : "FAIL") << std::endl;
|
||||
return pass_value && pass_index;
|
||||
}
|
||||
|
||||
bool RunPool3D(const Config3D& config)
|
||||
{
|
||||
std::cout << "Testing 3D: " << config.name << " ... ";
|
||||
@@ -72,6 +195,8 @@ class TestCkTilePooling : public ::testing::Test
|
||||
const auto input_right_pads =
|
||||
ck_tile::make_tuple(config.RightPz, config.RightPy, config.RightPx);
|
||||
|
||||
using IndexDataType = ck_tile::index_t;
|
||||
|
||||
ck_tile::HostTensor<InDataType> h_in({config.N, config.D, config.H, config.W, config.C},
|
||||
{config.D * config.H * config.W * config.C,
|
||||
config.H * config.W * config.C,
|
||||
@@ -84,6 +209,12 @@ class TestCkTilePooling : public ::testing::Test
|
||||
ck_tile::HostTensor<OutDataType> h_out_ref(
|
||||
{config.N, Do, Ho, Wo, config.C},
|
||||
{Do * Ho * Wo * config.C, Ho * Wo * config.C, Wo * config.C, config.C, 1});
|
||||
ck_tile::HostTensor<IndexDataType> h_out_index(
|
||||
{config.N, Do, Ho, Wo, config.C},
|
||||
{Do * Ho * Wo * config.C, Ho * Wo * config.C, Wo * config.C, config.C, 1});
|
||||
ck_tile::HostTensor<IndexDataType> h_out_ref_index(
|
||||
{config.N, Do, Ho, Wo, config.C},
|
||||
{Do * Ho * Wo * config.C, Ho * Wo * config.C, Wo * config.C, config.C, 1});
|
||||
|
||||
ck_tile::FillUniformDistribution<InDataType>{-5.f, 5.f}(h_in);
|
||||
h_out.SetZero();
|
||||
@@ -91,17 +222,19 @@ class TestCkTilePooling : public ::testing::Test
|
||||
|
||||
ck_tile::DeviceMem d_in_mem(h_in.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem d_out_mem(h_out.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem d_out_index_mem(h_out_index.get_element_space_size_in_bytes());
|
||||
|
||||
d_in_mem.ToDevice(h_in.data());
|
||||
d_out_mem.ToDevice(h_out.data());
|
||||
d_out_index_mem.ToDevice(h_out_index.data());
|
||||
|
||||
using Problem = ck_tile::PoolProblem<InDataType,
|
||||
OutDataType,
|
||||
ComputeDataType,
|
||||
OutDataType,
|
||||
IndexDataType,
|
||||
ReduceOpType,
|
||||
false,
|
||||
false,
|
||||
true, // OutputIndex
|
||||
false, // PropagateNan
|
||||
TestPoolShape>;
|
||||
using Kernel = ck_tile::PoolKernel<Problem>;
|
||||
|
||||
@@ -112,6 +245,7 @@ class TestCkTilePooling : public ::testing::Test
|
||||
ck_tile::PoolHostArgs<decltype(input_shape), decltype(window_spatial_lengths)>{
|
||||
static_cast<InDataType*>(d_in_mem.GetDeviceBuffer()),
|
||||
static_cast<OutDataType*>(d_out_mem.GetDeviceBuffer()),
|
||||
static_cast<IndexDataType*>(d_out_index_mem.GetDeviceBuffer()),
|
||||
input_shape,
|
||||
output_shape,
|
||||
input_strides,
|
||||
@@ -137,16 +271,27 @@ class TestCkTilePooling : public ::testing::Test
|
||||
ck_tile::make_kernel<kBlockPerCu>(Kernel{}, kGridSize, kBlockSize, 0, kernel_args));
|
||||
|
||||
// Run reference implementation
|
||||
ck_tile::reference_pool3d<InDataType, ComputeDataType, OutDataType>(
|
||||
h_in, h_out_ref, kernel_args, ReduceOpType{});
|
||||
ck_tile::reference_pool3d<InDataType,
|
||||
ComputeDataType,
|
||||
OutDataType,
|
||||
IndexDataType,
|
||||
ReduceOpType,
|
||||
decltype(input_shape),
|
||||
decltype(window_spatial_lengths),
|
||||
true>(
|
||||
h_in, h_out_ref, h_out_ref_index, kernel_args, ReduceOpType{});
|
||||
|
||||
d_out_mem.FromDevice(h_out.data());
|
||||
d_out_index_mem.FromDevice(h_out_index.data());
|
||||
|
||||
// Validate results
|
||||
bool pass = ck_tile::check_err(h_out, h_out_ref);
|
||||
std::cout << (pass ? "PASS" : "FAIL") << std::endl;
|
||||
bool pass_value =
|
||||
ck_tile::check_err(h_out, h_out_ref, "Error: Incorrect values!", 1e-5, 1e-5);
|
||||
bool pass_index = ck_tile::check_err(
|
||||
h_out_index, h_out_ref_index, "Error: Incorrect indices!", 1e-5, 1e-5);
|
||||
|
||||
return pass;
|
||||
std::cout << (pass_value && pass_index ? "PASS" : "FAIL") << std::endl;
|
||||
return pass_value && pass_index;
|
||||
}
|
||||
};
|
||||
|
||||
@@ -194,6 +339,50 @@ using TestTypes =
|
||||
|
||||
TYPED_TEST_SUITE(TestCkTilePooling, TestTypes);
|
||||
|
||||
// 2D Pooling Tests (NHWC)
|
||||
TYPED_TEST(TestCkTilePooling, Pool2D_2x2)
|
||||
{
|
||||
typename TestFixture::Config2D config = {1, // N - batch size
|
||||
8, // H - height dimension
|
||||
8, // W - width dimension
|
||||
32, // C - channel dimension
|
||||
2, // Y - pooling window height
|
||||
2, // X - pooling window width
|
||||
2, // Sy - window stride height
|
||||
2, // Sx - window stride width
|
||||
1, // Dy - window dilation height
|
||||
1, // Dx - window dilation width
|
||||
0, // LeftPy - left padding height
|
||||
0, // LeftPx - left padding width
|
||||
0, // RightPy - right padding height
|
||||
0, // RightPx - right padding width
|
||||
"2x2 pooling NHWC"};
|
||||
bool pass = this->RunPool2D(config);
|
||||
EXPECT_TRUE(pass);
|
||||
}
|
||||
|
||||
TYPED_TEST(TestCkTilePooling, Pool2D_3x3_WithPadding)
|
||||
{
|
||||
typename TestFixture::Config2D config = {2, // N - batch size
|
||||
16, // H - height dimension
|
||||
16, // W - width dimension
|
||||
32, // C - channel dimension
|
||||
3, // Y - pooling window height
|
||||
3, // X - pooling window width
|
||||
2, // Sy - window stride height
|
||||
2, // Sx - window stride width
|
||||
1, // Dy - window dilation height
|
||||
1, // Dx - window dilation width
|
||||
1, // LeftPy - left padding height
|
||||
1, // LeftPx - left padding width
|
||||
1, // RightPy - right padding height
|
||||
1, // RightPx - right padding width
|
||||
"3x3 pooling NHWC with padding"};
|
||||
bool pass = this->RunPool2D(config);
|
||||
EXPECT_TRUE(pass);
|
||||
}
|
||||
|
||||
// 3D Pooling Tests (NDHWC)
|
||||
TYPED_TEST(TestCkTilePooling, Pool3D_2x2x2)
|
||||
{
|
||||
typename TestFixture::Config3D config = {1, // N - batch size
|
||||
@@ -216,7 +405,7 @@ TYPED_TEST(TestCkTilePooling, Pool3D_2x2x2)
|
||||
0, // RightPz - right padding depth
|
||||
0, // RightPy - right padding height
|
||||
0, // RightPx - right padding width
|
||||
"2x2x2 pooling"};
|
||||
"2x2x2 pooling NDHWC"};
|
||||
bool pass = this->RunPool3D(config);
|
||||
EXPECT_TRUE(pass);
|
||||
}
|
||||
@@ -243,7 +432,7 @@ TYPED_TEST(TestCkTilePooling, Pool3D_3x3x3)
|
||||
1, // RightPz - right padding depth
|
||||
1, // RightPy - right padding height
|
||||
1, // RightPx - right padding width
|
||||
"3x3x3 pooling"};
|
||||
"3x3x3 pooling NDHWC with padding"};
|
||||
bool pass = this->RunPool3D(config);
|
||||
EXPECT_TRUE(pass);
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user