Rewrite pool2d fwd (#1462)

* added pool2d fwd

* add tests

* add reviewers changes

* Revert "Merge remote-tracking branch 'origin/develop' into jakpiase/pool2d_fwd_new"

This reverts commit 6b2ba7ff89, reversing
changes made to 22c82bea0c.

* Revert "add reviewers changes"

This reverts commit 22c82bea0c.

* added reviewers comments

* revert some old files

* add reviewers requests

---------

Co-authored-by: Adam Osewski <19374865+aosewski@users.noreply.github.com>

[ROCm/composable_kernel commit: e8d2887cb2]
This commit is contained in:
jakpiase
2024-09-11 15:21:00 +02:00
committed by GitHub
parent 681d36db5f
commit 8aeb2afbe2
16 changed files with 1374 additions and 71 deletions

View File

@@ -0,0 +1,8 @@
set(DEVICE_POOL2D_FWD_INSTANCES)
list(APPEND DEVICE_POOL2D_FWD_INSTANCES device_avg_pool2d_fwd_nhwc_f16_instance.cpp
device_max_pool2d_fwd_nhwc_f16_instance.cpp
device_avg_pool2d_fwd_nhwc_f32_instance.cpp
device_max_pool2d_fwd_nhwc_f32_instance.cpp
device_avg_pool2d_fwd_nhwc_bf16_instance.cpp
device_max_pool2d_fwd_nhwc_bf16_instance.cpp)
add_instance_library(device_pool2d_fwd_instance ${DEVICE_POOL2D_FWD_INSTANCES})

View File

@@ -0,0 +1,25 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "pool2d_fwd_instance_common.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
static constexpr auto ReduceOpId = ck::ReduceTensorOp::AVG;
void add_device_pool2d_fwd_nhwc_bf16_instances(
std::vector<
std::unique_ptr<DevicePoolFwd<4, 2, BF16, BF16, I32, NHWC, NHWC, ReduceOpId, false>>>&
instances)
{
add_device_operation_instances(
instances, device_pool2d_fwd_nhwc_instances<BF16, BF16, I32, F32, ReduceOpId, false>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -0,0 +1,24 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "pool2d_fwd_instance_common.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
static constexpr auto ReduceOpId = ck::ReduceTensorOp::AVG;
void add_device_pool2d_fwd_nhwc_f16_instances(
std::vector<std::unique_ptr<DevicePoolFwd<4, 2, F16, F16, I32, NHWC, NHWC, ReduceOpId, false>>>&
instances)
{
add_device_operation_instances(
instances, device_pool2d_fwd_nhwc_instances<F16, F16, I32, F32, ReduceOpId, false>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -0,0 +1,24 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "pool2d_fwd_instance_common.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
static constexpr auto ReduceOpId = ck::ReduceTensorOp::AVG;
void add_device_pool2d_fwd_nhwc_f32_instances(
std::vector<std::unique_ptr<DevicePoolFwd<4, 2, F32, F32, I32, NHWC, NHWC, ReduceOpId, false>>>&
instances)
{
add_device_operation_instances(
instances, device_pool2d_fwd_nhwc_instances<F32, F32, I32, F32, ReduceOpId, false>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -0,0 +1,34 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "pool2d_fwd_instance_common.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
static constexpr auto ReduceOpId = ck::ReduceTensorOp::MAX;
void add_device_pool2d_fwd_nhwc_bf16_instances(
std::vector<
std::unique_ptr<DevicePoolFwd<4, 2, BF16, BF16, I32, NHWC, NHWC, ReduceOpId, false>>>&
instances)
{
add_device_operation_instances(
instances, device_pool2d_fwd_nhwc_instances<BF16, BF16, I32, F32, ReduceOpId, false>{});
}
void add_device_pool2d_fwd_nhwc_index_bf16_instances(
std::vector<
std::unique_ptr<DevicePoolFwd<4, 2, BF16, BF16, I32, NHWC, NHWC, ReduceOpId, true>>>&
instances)
{
add_device_operation_instances(
instances, device_pool2d_fwd_nhwc_instances<BF16, BF16, I32, F32, ReduceOpId, true>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -0,0 +1,32 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "pool2d_fwd_instance_common.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
static constexpr auto ReduceOpId = ck::ReduceTensorOp::MAX;
void add_device_pool2d_fwd_nhwc_f16_instances(
std::vector<std::unique_ptr<DevicePoolFwd<4, 2, F16, F16, I32, NHWC, NHWC, ReduceOpId, false>>>&
instances)
{
add_device_operation_instances(
instances, device_pool2d_fwd_nhwc_instances<F16, F16, I32, F32, ReduceOpId, false>{});
}
void add_device_pool2d_fwd_nhwc_index_f16_instances(
std::vector<std::unique_ptr<DevicePoolFwd<4, 2, F16, F16, I32, NHWC, NHWC, ReduceOpId, true>>>&
instances)
{
add_device_operation_instances(
instances, device_pool2d_fwd_nhwc_instances<F16, F16, I32, F32, ReduceOpId, true>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -0,0 +1,32 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "pool2d_fwd_instance_common.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
static constexpr auto ReduceOpId = ck::ReduceTensorOp::MAX;
void add_device_pool2d_fwd_nhwc_f32_instances(
std::vector<std::unique_ptr<DevicePoolFwd<4, 2, F32, F32, I32, NHWC, NHWC, ReduceOpId, false>>>&
instances)
{
add_device_operation_instances(
instances, device_pool2d_fwd_nhwc_instances<F32, F32, I32, F32, ReduceOpId, false>{});
}
void add_device_pool2d_fwd_nhwc_index_f32_instances(
std::vector<std::unique_ptr<DevicePoolFwd<4, 2, F32, F32, I32, NHWC, NHWC, ReduceOpId, true>>>&
instances)
{
add_device_operation_instances(
instances, device_pool2d_fwd_nhwc_instances<F32, F32, I32, F32, ReduceOpId, true>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -0,0 +1,41 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_pool2d_fwd_nhwc_nhwc.hpp"
#include "ck/utility/data_type.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
using I32 = int32_t;
using F16 = ck::half_t;
using BF16 = ck::bhalf_t;
using F32 = float;
using NHWC = ck::tensor_layout::convolution::NHWC;
template <typename InDataType,
typename OutDataType,
typename IndexDataType,
typename ComputeDataType,
ReduceTensorOp ReduceOpId,
bool OutputIndex>
using device_pool2d_fwd_nhwc_instances =
// clang-format off
std::tuple <
DevicePool2dFwd_NHWC_NHWC<InDataType, OutDataType, IndexDataType, ComputeDataType, ReduceOpId, OutputIndex, 256, 256, 1, 1, 1, 1>,
DevicePool2dFwd_NHWC_NHWC<InDataType, OutDataType, IndexDataType, ComputeDataType, ReduceOpId, OutputIndex, 256, 256, 1, 2, 1, 2>,
DevicePool2dFwd_NHWC_NHWC<InDataType, OutDataType, IndexDataType, ComputeDataType, ReduceOpId, OutputIndex, 256, 256, 1, 4, 1, 4>
// clang-format on
>;
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck