mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-16 10:59:55 +00:00
Pool3d fwd (#697)
* Expand the base class of pool2d, prepare to share base class with pool3d
* Add pool3d device op
* Add pool3d f16 example
* Refactor the base class. implement generic pooling in the future
* clang format
* get original index in max pooling
* Add outputindex to base class
* Fix dimension
* Add pooling instance
* Use indexType instead
* Remove useless header
* Extract IndexDataType to template
* Extract pooling reference code
* clang format
* clang format
* Fix typo
* Add tensor stride
* Add missing header
* Add index stride and output stride
* Refine naming
* Add type to base class
* Rename file
* Use proper size
* Fix typo
* Refine naming
* Modify the argument into vector.
* Add max pool profiler
* Refine naming
* Support f32 pool
* Fix typo
* Add avg pool2d fwd in profiler
* clang format
* Rename AccDatatype to ComputeDatatype
* Fix init
* test pool
* Extract variable
* Add client example
* Check the pooling dim
* clang format
* Connect argv and arg_parser
* Add found check
* Remove useless header
* Refine naming
* Adjust the order of device_pool_fwd
[ROCm/composable_kernel commit: 76ec0089fb]
This commit is contained in:
@@ -0,0 +1,10 @@
|
||||
add_instance_library(device_pool_fwd_instance
|
||||
device_avg_pool2d_fwd_nhwc_f16_instance.cpp
|
||||
device_avg_pool2d_fwd_nhwc_f32_instance.cpp
|
||||
device_avg_pool3d_fwd_ndhwc_f16_instance.cpp
|
||||
device_avg_pool3d_fwd_ndhwc_f32_instance.cpp
|
||||
device_max_pool2d_fwd_nhwc_f16_instance.cpp
|
||||
device_max_pool2d_fwd_nhwc_f32_instance.cpp
|
||||
device_max_pool3d_fwd_ndhwc_f16_instance.cpp
|
||||
device_max_pool3d_fwd_ndhwc_f32_instance.cpp
|
||||
)
|
||||
@@ -0,0 +1,23 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "pool_fwd_instance_common.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
static constexpr auto ReduceOpId = ck::ReduceTensorOp::AVG;
|
||||
|
||||
void add_device_pool2d_fwd_nhwc_f16_instances(
|
||||
std::vector<std::unique_ptr<DevicePoolFwd<4, 2, F16, F16, I32, ReduceOpId, false>>>& instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances, device_pool2d_fwd_nhwc_instances<F16, F16, I32, F32, ReduceOpId, false>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,23 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "pool_fwd_instance_common.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
static constexpr auto ReduceOpId = ck::ReduceTensorOp::AVG;
|
||||
|
||||
void add_device_pool2d_fwd_nhwc_f32_instances(
|
||||
std::vector<std::unique_ptr<DevicePoolFwd<4, 2, F32, F32, I32, ReduceOpId, false>>>& instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances, device_pool2d_fwd_nhwc_instances<F32, F32, I32, F32, ReduceOpId, false>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,23 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "pool_fwd_instance_common.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
static constexpr auto ReduceOpId = ck::ReduceTensorOp::AVG;
|
||||
|
||||
void add_device_pool3d_fwd_ndhwc_f16_instances(
|
||||
std::vector<std::unique_ptr<DevicePoolFwd<5, 3, F16, F16, I32, ReduceOpId, false>>>& instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances, device_pool3d_fwd_ndhwc_instances<F16, F16, I32, F32, ReduceOpId, false>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,23 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "pool_fwd_instance_common.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
static constexpr auto ReduceOpId = ck::ReduceTensorOp::AVG;
|
||||
|
||||
void add_device_pool3d_fwd_ndhwc_f32_instances(
|
||||
std::vector<std::unique_ptr<DevicePoolFwd<5, 3, F32, F32, I32, ReduceOpId, false>>>& instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances, device_pool3d_fwd_ndhwc_instances<F32, F32, I32, F32, ReduceOpId, false>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,30 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "pool_fwd_instance_common.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
static constexpr auto ReduceOpId = ck::ReduceTensorOp::MAX;
|
||||
|
||||
void add_device_pool2d_fwd_nhwc_f16_instances(
|
||||
std::vector<std::unique_ptr<DevicePoolFwd<4, 2, F16, F16, I32, ReduceOpId, false>>>& instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances, device_pool2d_fwd_nhwc_instances<F16, F16, I32, F16, ReduceOpId, false>{});
|
||||
}
|
||||
|
||||
void add_device_pool2d_fwd_nhwc_index_f16_instances(
|
||||
std::vector<std::unique_ptr<DevicePoolFwd<4, 2, F16, F16, I32, ReduceOpId, true>>>& instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances, device_pool2d_fwd_nhwc_instances<F16, F16, I32, F16, ReduceOpId, true>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,30 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "pool_fwd_instance_common.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
static constexpr auto ReduceOpId = ck::ReduceTensorOp::MAX;
|
||||
|
||||
void add_device_pool2d_fwd_nhwc_f32_instances(
|
||||
std::vector<std::unique_ptr<DevicePoolFwd<4, 2, F32, F32, I32, ReduceOpId, false>>>& instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances, device_pool2d_fwd_nhwc_instances<F32, F32, I32, F32, ReduceOpId, false>{});
|
||||
}
|
||||
|
||||
void add_device_pool2d_fwd_nhwc_index_f32_instances(
|
||||
std::vector<std::unique_ptr<DevicePoolFwd<4, 2, F32, F32, I32, ReduceOpId, true>>>& instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances, device_pool2d_fwd_nhwc_instances<F32, F32, I32, F32, ReduceOpId, true>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,30 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "pool_fwd_instance_common.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
static constexpr auto ReduceOpId = ck::ReduceTensorOp::MAX;
|
||||
|
||||
void add_device_pool3d_fwd_ndhwc_f16_instances(
|
||||
std::vector<std::unique_ptr<DevicePoolFwd<5, 3, F16, F16, I32, ReduceOpId, false>>>& instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances, device_pool3d_fwd_ndhwc_instances<F16, F16, I32, F16, ReduceOpId, false>{});
|
||||
}
|
||||
|
||||
void add_device_pool3d_fwd_ndhwc_index_f16_instances(
|
||||
std::vector<std::unique_ptr<DevicePoolFwd<5, 3, F16, F16, I32, ReduceOpId, true>>>& instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances, device_pool3d_fwd_ndhwc_instances<F16, F16, I32, F16, ReduceOpId, true>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,30 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "pool_fwd_instance_common.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
static constexpr auto ReduceOpId = ck::ReduceTensorOp::MAX;
|
||||
|
||||
void add_device_pool3d_fwd_ndhwc_f32_instances(
|
||||
std::vector<std::unique_ptr<DevicePoolFwd<5, 3, F32, F32, I32, ReduceOpId, false>>>& instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances, device_pool3d_fwd_ndhwc_instances<F32, F32, I32, F32, ReduceOpId, false>{});
|
||||
}
|
||||
|
||||
void add_device_pool3d_fwd_ndhwc_index_f32_instances(
|
||||
std::vector<std::unique_ptr<DevicePoolFwd<5, 3, F32, F32, I32, ReduceOpId, true>>>& instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances, device_pool3d_fwd_ndhwc_instances<F32, F32, I32, F32, ReduceOpId, true>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,55 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_pool2d_fwd_nhwc_nhwc.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_pool3d_fwd_ndhwc_ndhwc.hpp"
|
||||
#include "ck/utility/data_type.hpp"
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
using I32 = int32_t;
|
||||
using F16 = ck::half_t;
|
||||
using F32 = float;
|
||||
|
||||
template <typename InDataType,
|
||||
typename OutDataType,
|
||||
typename IndexDataType,
|
||||
typename ComputeDataType,
|
||||
ReduceTensorOp ReduceOpId,
|
||||
bool OutputIndex>
|
||||
using device_pool2d_fwd_nhwc_instances =
|
||||
// clang-format off
|
||||
std::tuple <
|
||||
DevicePool2dFwd_Input_N_Hi_Wi_C_Output_N_Ho_Wo_C<InDataType, OutDataType, IndexDataType, ComputeDataType, ReduceOpId, OutputIndex, 256, 256, 1, 1, 1, 1>,
|
||||
DevicePool2dFwd_Input_N_Hi_Wi_C_Output_N_Ho_Wo_C<InDataType, OutDataType, IndexDataType, ComputeDataType, ReduceOpId, OutputIndex, 256, 256, 1, 2, 1, 2>,
|
||||
DevicePool2dFwd_Input_N_Hi_Wi_C_Output_N_Ho_Wo_C<InDataType, OutDataType, IndexDataType, ComputeDataType, ReduceOpId, OutputIndex, 256, 256, 1, 4, 1, 4>
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
template <typename InDataType,
|
||||
typename OutDataType,
|
||||
typename IndexDataType,
|
||||
typename ComputeDataType,
|
||||
ReduceTensorOp ReduceOpId,
|
||||
bool OutputIndex>
|
||||
using device_pool3d_fwd_ndhwc_instances =
|
||||
// clang-format off
|
||||
std::tuple <
|
||||
DevicePool3dFwd_Input_N_Di_Hi_Wi_C_Output_N_Do_Ho_Wo_C<InDataType, OutDataType, IndexDataType, ComputeDataType, ReduceOpId, OutputIndex, 256, 256, 1, 1, 1, 1>,
|
||||
DevicePool3dFwd_Input_N_Di_Hi_Wi_C_Output_N_Do_Ho_Wo_C<InDataType, OutDataType, IndexDataType, ComputeDataType, ReduceOpId, OutputIndex, 256, 256, 1, 2, 1, 2>,
|
||||
DevicePool3dFwd_Input_N_Di_Hi_Wi_C_Output_N_Do_Ho_Wo_C<InDataType, OutDataType, IndexDataType, ComputeDataType, ReduceOpId, OutputIndex, 256, 256, 1, 4, 1, 4>
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
Reference in New Issue
Block a user