mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-21 13:29:20 +00:00
Pool2d max/avg kernel in the BWD version (#1494)
* Add pool2d instance BWD AVG
* Add pool2d instance BWD MAX
* Fix: avg review
* Fix review: part2
* Fix - enable test when type is compiled
* Fix review part3
[ROCm/composable_kernel commit: 448c0f56d8]
This commit is contained in:
@@ -0,0 +1,8 @@
|
||||
set(DEVICE_AVGPOOL_2D_BWD_INSTANCES)
|
||||
list(APPEND DEVICE_AVGPOOL_2D_BWD_INSTANCES device_avg_pool2d_bwd_nhwc_bf16_instance.cpp
|
||||
device_avg_pool2d_bwd_nhwc_f16_instance.cpp
|
||||
device_avg_pool2d_bwd_nhwc_f32_instance.cpp
|
||||
device_avg_pool2d_bwd_nhwc_f8_instance.cpp
|
||||
device_avg_pool2d_bwd_nhwc_int8_instance.cpp)
|
||||
add_instance_library(device_avg_pool2d_bwd_instance ${DEVICE_AVGPOOL_2D_BWD_INSTANCES})
|
||||
|
||||
@@ -0,0 +1,21 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "device_avg_pool2d_bwd_nhwc_instance_common.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
void add_device_avgpool_2D_bwd_nhwc_bf16_instances(
|
||||
std::vector<std::unique_ptr<DeviceAvgPoolBwd<2, BF16, BF16, NHWC, NHWC>>>& instances)
|
||||
{
|
||||
add_device_operation_instances(instances,
|
||||
device_avgpool_2D_bwd_nhwc_instances<BF16, BF16, F32>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,21 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "device_avg_pool2d_bwd_nhwc_instance_common.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
void add_device_avgpool_2D_bwd_nhwc_f16_instances(
|
||||
std::vector<std::unique_ptr<DeviceAvgPoolBwd<2, F16, F16, NHWC, NHWC>>>& instances)
|
||||
{
|
||||
add_device_operation_instances(instances,
|
||||
device_avgpool_2D_bwd_nhwc_instances<F16, F16, F32>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,21 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "device_avg_pool2d_bwd_nhwc_instance_common.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
void add_device_avgpool_2D_bwd_nhwc_f32_instances(
|
||||
std::vector<std::unique_ptr<DeviceAvgPoolBwd<2, F32, F32, NHWC, NHWC>>>& instances)
|
||||
{
|
||||
add_device_operation_instances(instances,
|
||||
device_avgpool_2D_bwd_nhwc_instances<F32, F32, F32>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,20 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "device_avg_pool2d_bwd_nhwc_instance_common.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
void add_device_avgpool_2D_bwd_nhwc_f8_instances(
|
||||
std::vector<std::unique_ptr<DeviceAvgPoolBwd<2, F8, F8, NHWC, NHWC>>>& instances)
|
||||
{
|
||||
add_device_operation_instances(instances, device_avgpool_2D_bwd_nhwc_instances<F8, F8, F32>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,38 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_avgpool2d_bwd_nhwc_nhwc.hpp"
|
||||
#include "ck/utility/data_type.hpp"
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
using F16 = ck::half_t;
|
||||
using BF16 = ck::bhalf_t;
|
||||
using F8 = ck::f8_t;
|
||||
using I8 = int8_t;
|
||||
using I32 = int32_t;
|
||||
using F32 = float;
|
||||
using NHWC = ck::tensor_layout::convolution::NHWC;
|
||||
|
||||
template <typename OutType, typename InType, typename ComputeType>
|
||||
using device_avgpool_2D_bwd_nhwc_instances = std::tuple<
|
||||
// clang-format off
|
||||
DeviceAvgPool2dBwd_NHWC_NHWC<OutType, InType, ComputeType, 256, 256, 1, 1, 1, 1>,
|
||||
DeviceAvgPool2dBwd_NHWC_NHWC<OutType, InType, ComputeType, 256, 256, 1, 2, 2, 2>,
|
||||
DeviceAvgPool2dBwd_NHWC_NHWC<OutType, InType, ComputeType, 256, 256, 1, 4, 4, 4>,
|
||||
DeviceAvgPool2dBwd_NHWC_NHWC<OutType, InType, ComputeType, 256, 256, 1, 8, 8, 8>,
|
||||
DeviceAvgPool2dBwd_NHWC_NHWC<OutType, InType, ComputeType, 256, 32, 8, 8, 8, 8>
|
||||
// clang-format on
|
||||
>;
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,20 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "device_avg_pool2d_bwd_nhwc_instance_common.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
void add_device_avgpool_2D_bwd_nhwc_int8_instances(
|
||||
std::vector<std::unique_ptr<DeviceAvgPoolBwd<2, I8, I8, NHWC, NHWC>>>& instances)
|
||||
{
|
||||
add_device_operation_instances(instances, device_avgpool_2D_bwd_nhwc_instances<I8, I8, I32>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -1,5 +1,6 @@
|
||||
set(DEVICE_MAXPOOL_BWD_INSTANCES)
|
||||
list(APPEND DEVICE_MAXPOOL_BWD_INSTANCES device_max_pool_bwd_f16_instance.cpp
|
||||
device_max_pool_bwd_bf16_instance.cpp
|
||||
device_max_pool_bwd_f32_instance.cpp)
|
||||
device_max_pool_bwd_f32_instance.cpp
|
||||
device_max_pool_bwd_int8_instance.cpp)
|
||||
add_instance_library(device_max_pool_bwd_instance ${DEVICE_MAXPOOL_BWD_INSTANCES})
|
||||
|
||||
@@ -0,0 +1,20 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "max_pool_bwd_instance_common.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
void add_device_maxpool_bwd_int8_instances(
|
||||
std::vector<std::unique_ptr<DeviceMaxPoolBwd<I8, I32, I8>>>& instances)
|
||||
{
|
||||
add_device_operation_instances(instances, device_maxpool_bwd_instances<I8, I32, I8>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -17,6 +17,8 @@ namespace instance {
|
||||
using I32 = int32_t;
|
||||
using F16 = ck::half_t;
|
||||
using BF16 = ck::bhalf_t;
|
||||
using I8 = int8_t;
|
||||
using F8 = ck::f8_t;
|
||||
using F32 = float;
|
||||
|
||||
template <typename DOutDataType, typename IndexDataType, typename DInDataType>
|
||||
|
||||
Reference in New Issue
Block a user