mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-20 21:09:08 +00:00
Extend pool3d fwd avg, max operations by f8_t, int8_t types (#1483)
* Extend pool3d fwd avg, max operations by f8_t, int8_t types
* Pack MaxPool3dFwd params together
* Fix MaxPool3dFwd AVG instances
* Decrease verification precision for bf16
* Adjust tests + review changes
* Adjust threshold for F8
* Adjusted compute types for MAX op instances
* Fix ComputeDataType mismatch in tests and profiler for AVG
* Fix naming from max_pool3d_fwd to pool3d_fwd
* Adjust CMakeLists
---------
Co-authored-by: Adam Osewski <19374865+aosewski@users.noreply.github.com>
[ROCm/composable_kernel commit: a793afc961]
This commit is contained in:
@@ -1,6 +1,10 @@
|
||||
set(DEVICE_POOL3D_FWD_INSTANCES)
|
||||
list(APPEND DEVICE_POOL3D_FWD_INSTANCES device_avg_pool3d_fwd_ndhwc_f16_instance.cpp
|
||||
device_max_pool3d_fwd_ndhwc_f16_instance.cpp
|
||||
device_max_pool3d_fwd_ndhwc_f8_instance.cpp
|
||||
device_avg_pool3d_fwd_ndhwc_f8_instance.cpp
|
||||
device_max_pool3d_fwd_ndhwc_i8_instance.cpp
|
||||
device_avg_pool3d_fwd_ndhwc_i8_instance.cpp
|
||||
device_avg_pool3d_fwd_ndhwc_f32_instance.cpp
|
||||
device_max_pool3d_fwd_ndhwc_f32_instance.cpp
|
||||
device_avg_pool3d_fwd_ndhwc_bf16_instance.cpp
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "pool_fwd_instance_common.hpp"
|
||||
|
||||
|
||||
@@ -0,0 +1,24 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "pool_fwd_instance_common.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
static constexpr auto ReduceOpId = ck::ReduceTensorOp::AVG;
|
||||
|
||||
void add_device_pool3d_fwd_ndhwc_f8_instances(
|
||||
std::vector<std::unique_ptr<DevicePoolFwd<5, 3, F8, F8, I32, NDHWC, NDHWC, ReduceOpId, false>>>&
|
||||
instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances, device_pool3d_fwd_ndhwc_instances<F8, F8, I32, F32, ReduceOpId, false>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,24 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "pool_fwd_instance_common.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
static constexpr auto ReduceOpId = ck::ReduceTensorOp::AVG;
|
||||
|
||||
void add_device_pool3d_fwd_ndhwc_i8_instances(
|
||||
std::vector<std::unique_ptr<DevicePoolFwd<5, 3, I8, I8, I32, NDHWC, NDHWC, ReduceOpId, false>>>&
|
||||
instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances, device_pool3d_fwd_ndhwc_instances<I8, I8, I32, I32, ReduceOpId, false>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "pool_fwd_instance_common.hpp"
|
||||
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "pool_fwd_instance_common.hpp"
|
||||
|
||||
|
||||
@@ -0,0 +1,32 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "pool_fwd_instance_common.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
static constexpr auto ReduceOpId = ck::ReduceTensorOp::MAX;
|
||||
|
||||
void add_device_pool3d_fwd_ndhwc_f8_instances(
|
||||
std::vector<std::unique_ptr<DevicePoolFwd<5, 3, F8, F8, I32, NDHWC, NDHWC, ReduceOpId, false>>>&
|
||||
instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances, device_pool3d_fwd_ndhwc_instances<F8, F8, I32, F8, ReduceOpId, false>{});
|
||||
}
|
||||
|
||||
void add_device_pool3d_fwd_ndhwc_index_f8_instances(
|
||||
std::vector<std::unique_ptr<DevicePoolFwd<5, 3, F8, F8, I32, NDHWC, NDHWC, ReduceOpId, true>>>&
|
||||
instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances, device_pool3d_fwd_ndhwc_instances<F8, F8, I32, F8, ReduceOpId, true>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,32 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "pool_fwd_instance_common.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
static constexpr auto ReduceOpId = ck::ReduceTensorOp::MAX;
|
||||
|
||||
void add_device_pool3d_fwd_ndhwc_i8_instances(
|
||||
std::vector<std::unique_ptr<DevicePoolFwd<5, 3, I8, I8, I32, NDHWC, NDHWC, ReduceOpId, false>>>&
|
||||
instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances, device_pool3d_fwd_ndhwc_instances<I8, I8, I32, I8, ReduceOpId, false>{});
|
||||
}
|
||||
|
||||
void add_device_pool3d_fwd_ndhwc_index_i8_instances(
|
||||
std::vector<std::unique_ptr<DevicePoolFwd<5, 3, I8, I8, I32, NDHWC, NDHWC, ReduceOpId, true>>>&
|
||||
instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances, device_pool3d_fwd_ndhwc_instances<I8, I8, I32, I8, ReduceOpId, true>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -15,6 +15,8 @@ namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
using I8 = int8_t;
|
||||
using F8 = ck::f8_t;
|
||||
using I32 = int32_t;
|
||||
using F16 = ck::half_t;
|
||||
using BF16 = ck::bhalf_t;
|
||||
|
||||
Reference in New Issue
Block a user