mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-18 03:49:41 +00:00
Add pool2d int8 and fp8 instances (#1508)
* add pool2d fp8 and int8
* minor fixes
* add formatting
* add reviewer suggestions
* add reviewer suggestions
[ROCm/composable_kernel commit: 8f8a2ce396]
This commit is contained in:
@@ -67,6 +67,36 @@ void add_device_pool2d_fwd_nhwc_index_f32_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DevicePoolFwd<InOutRank, WindowRank, F32, F32, I32, NHWC, NHWC, MaxOp, true>>>&);
|
||||
#endif
|
||||
#ifdef CK_ENABLE_INT8
// I8
// Two overloads share one name; they differ only in the reduce-operation
// template argument of DevicePoolFwd (MaxOp vs. AvgOp).
void add_device_pool2d_fwd_nhwc_i8_instances(
    std::vector<std::unique_ptr<
        DevicePoolFwd<InOutRank, WindowRank, I8, I8, I32, NHWC, NHWC, MaxOp, false>>>&);

void add_device_pool2d_fwd_nhwc_i8_instances(
    std::vector<std::unique_ptr<
        DevicePoolFwd<InOutRank, WindowRank, I8, I8, I32, NHWC, NHWC, AvgOp, false>>>&);

// I8 - return index
// Declared for MaxOp only; the last DevicePoolFwd argument is `true` here.
void add_device_pool2d_fwd_nhwc_index_i8_instances(
    std::vector<std::unique_ptr<
        DevicePoolFwd<InOutRank, WindowRank, I8, I8, I32, NHWC, NHWC, MaxOp, true>>>&);
#endif
|
||||
#ifdef CK_ENABLE_FP8
// F8
// Two overloads share one name; they differ only in the reduce-operation
// template argument of DevicePoolFwd (MaxOp vs. AvgOp).
void add_device_pool2d_fwd_nhwc_f8_instances(
    std::vector<std::unique_ptr<
        DevicePoolFwd<InOutRank, WindowRank, F8, F8, I32, NHWC, NHWC, MaxOp, false>>>&);

void add_device_pool2d_fwd_nhwc_f8_instances(
    std::vector<std::unique_ptr<
        DevicePoolFwd<InOutRank, WindowRank, F8, F8, I32, NHWC, NHWC, AvgOp, false>>>&);

// F8 - return index
// Declared for MaxOp only; the last DevicePoolFwd argument is `true` here.
void add_device_pool2d_fwd_nhwc_index_f8_instances(
    std::vector<std::unique_ptr<
        DevicePoolFwd<InOutRank, WindowRank, F8, F8, I32, NHWC, NHWC, MaxOp, true>>>&);
#endif
|
||||
template <typename InDataType,
|
||||
typename OutDataType,
|
||||
typename IndexDataType,
|
||||
@@ -140,6 +170,34 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DevicePoolFw
|
||||
add_device_pool2d_fwd_nhwc_f32_instances(op_ptrs);
|
||||
}
|
||||
}
|
||||
#endif
|
||||
#ifdef CK_ENABLE_INT8
|
||||
else if constexpr(is_same_v<InDataType, I8> && is_same_v<OutDataType, I8> &&
|
||||
is_same_v<IndexDataType, I32>)
|
||||
{
|
||||
if constexpr(OutputIndex && ReduceOpId == MaxOp)
|
||||
{
|
||||
add_device_pool2d_fwd_nhwc_index_i8_instances(op_ptrs);
|
||||
}
|
||||
else
|
||||
{
|
||||
add_device_pool2d_fwd_nhwc_i8_instances(op_ptrs);
|
||||
}
|
||||
}
|
||||
#endif
|
||||
#ifdef CK_ENABLE_FP8
|
||||
else if constexpr(is_same_v<InDataType, F8> && is_same_v<OutDataType, F8> &&
|
||||
is_same_v<IndexDataType, I32>)
|
||||
{
|
||||
if constexpr(OutputIndex && ReduceOpId == MaxOp)
|
||||
{
|
||||
add_device_pool2d_fwd_nhwc_index_f8_instances(op_ptrs);
|
||||
}
|
||||
else
|
||||
{
|
||||
add_device_pool2d_fwd_nhwc_f8_instances(op_ptrs);
|
||||
}
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
|
||||
@@ -4,5 +4,9 @@ list(APPEND DEVICE_POOL2D_FWD_INSTANCES device_avg_pool2d_fwd_nhwc_f16_instance.
|
||||
device_avg_pool2d_fwd_nhwc_f32_instance.cpp
|
||||
device_max_pool2d_fwd_nhwc_f32_instance.cpp
|
||||
device_avg_pool2d_fwd_nhwc_bf16_instance.cpp
|
||||
device_max_pool2d_fwd_nhwc_bf16_instance.cpp)
|
||||
device_max_pool2d_fwd_nhwc_bf16_instance.cpp
|
||||
device_avg_pool2d_fwd_nhwc_i8_instance.cpp
|
||||
device_max_pool2d_fwd_nhwc_i8_instance.cpp
|
||||
device_avg_pool2d_fwd_nhwc_f8_instance.cpp
|
||||
device_max_pool2d_fwd_nhwc_f8_instance.cpp)
|
||||
add_instance_library(device_pool2d_fwd_instance ${DEVICE_POOL2D_FWD_INSTANCES})
|
||||
|
||||
@@ -0,0 +1,24 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "pool2d_fwd_instance_common.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
static constexpr auto ReduceOpId = ck::ReduceTensorOp::AVG;
|
||||
|
||||
void add_device_pool2d_fwd_nhwc_f8_instances(
|
||||
std::vector<std::unique_ptr<DevicePoolFwd<4, 2, F8, F8, I32, NHWC, NHWC, ReduceOpId, false>>>&
|
||||
instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances, device_pool2d_fwd_nhwc_instances<F8, F8, I32, F32, ReduceOpId, false>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,24 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "pool2d_fwd_instance_common.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
static constexpr auto ReduceOpId = ck::ReduceTensorOp::AVG;
|
||||
|
||||
void add_device_pool2d_fwd_nhwc_i8_instances(
|
||||
std::vector<std::unique_ptr<DevicePoolFwd<4, 2, I8, I8, I32, NHWC, NHWC, ReduceOpId, false>>>&
|
||||
instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances, device_pool2d_fwd_nhwc_instances<I8, I8, I32, F32, ReduceOpId, false>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,32 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "pool2d_fwd_instance_common.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
static constexpr auto ReduceOpId = ck::ReduceTensorOp::MAX;
|
||||
|
||||
void add_device_pool2d_fwd_nhwc_f8_instances(
|
||||
std::vector<std::unique_ptr<DevicePoolFwd<4, 2, F8, F8, I32, NHWC, NHWC, ReduceOpId, false>>>&
|
||||
instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances, device_pool2d_fwd_nhwc_instances<F8, F8, I32, F32, ReduceOpId, false>{});
|
||||
}
|
||||
|
||||
void add_device_pool2d_fwd_nhwc_index_f8_instances(
|
||||
std::vector<std::unique_ptr<DevicePoolFwd<4, 2, F8, F8, I32, NHWC, NHWC, ReduceOpId, true>>>&
|
||||
instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances, device_pool2d_fwd_nhwc_instances<F8, F8, I32, F32, ReduceOpId, true>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,32 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "pool2d_fwd_instance_common.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
static constexpr auto ReduceOpId = ck::ReduceTensorOp::MAX;
|
||||
|
||||
void add_device_pool2d_fwd_nhwc_i8_instances(
|
||||
std::vector<std::unique_ptr<DevicePoolFwd<4, 2, I8, I8, I32, NHWC, NHWC, ReduceOpId, false>>>&
|
||||
instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances, device_pool2d_fwd_nhwc_instances<I8, I8, I32, F32, ReduceOpId, false>{});
|
||||
}
|
||||
|
||||
void add_device_pool2d_fwd_nhwc_index_i8_instances(
|
||||
std::vector<std::unique_ptr<DevicePoolFwd<4, 2, I8, I8, I32, NHWC, NHWC, ReduceOpId, true>>>&
|
||||
instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances, device_pool2d_fwd_nhwc_instances<I8, I8, I32, F32, ReduceOpId, true>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -15,9 +15,11 @@ namespace device {
|
||||
namespace instance {
|
||||
|
||||
using I32 = int32_t;
|
||||
using F32 = float;
|
||||
using F16 = ck::half_t;
|
||||
using BF16 = ck::bhalf_t;
|
||||
using F32 = float;
|
||||
using I8 = int8_t;
|
||||
using F8 = ck::f8_t;
|
||||
using NHWC = ck::tensor_layout::convolution::NHWC;
|
||||
|
||||
template <typename InDataType,
|
||||
|
||||
@@ -49,9 +49,18 @@ struct maxPoolFwdArgParser
|
||||
}
|
||||
};
|
||||
|
||||
// Data type selected on the profiler command line: the integer parsed from
// argv[2] is static_cast directly to this enum. The enumerator values are
// written out explicitly so they match the numbering printed by
// print_help_max_pool2d_fwd(): 0: fp16; 1: fp32; 2: bf16; 3: int8; 4: fp8.
// (With the previous implicit ordering, 0 selected fp32 and 2 selected fp16,
// contradicting the help text.)
enum struct PoolDataType
{
    F16  = 0,
    F32  = 1,
    BF16 = 2,
    INT8 = 3,
    F8   = 4,
};
|
||||
|
||||
void print_help_max_pool2d_fwd()
|
||||
{
|
||||
std::cout << "arg1: data type (0: fp16; 1: fp32; 5: bf16)\n"
|
||||
std::cout << "arg1: data type (0: fp16; 1: fp32; 2: bf16; 3: int8; 4: fp8)\n"
|
||||
<< "arg2: verification (0: no; 1: yes)\n"
|
||||
<< "arg3: initialization (0: no init; 1: integer value; 2: decimal value)\n"
|
||||
<< "arg4: print tensor value (0: no; 1: yes)\n"
|
||||
@@ -70,12 +79,12 @@ void print_help_max_pool2d_fwd()
|
||||
|
||||
int profile_max_pool2d_fwd(int argc, char* argv[])
|
||||
{
|
||||
ck::DataTypeEnum data_type = ck::DataTypeEnum::Half;
|
||||
bool do_verification = true;
|
||||
int init_method = 0;
|
||||
bool do_log = false;
|
||||
bool time_kernel = true;
|
||||
bool return_index = false;
|
||||
PoolDataType data_type = PoolDataType::F32;
|
||||
bool do_verification = true;
|
||||
int init_method = 0;
|
||||
bool do_log = false;
|
||||
bool time_kernel = true;
|
||||
bool return_index = false;
|
||||
|
||||
std::vector<index_t> in_length = {2, 32, 30, 30};
|
||||
std::vector<index_t> wsize = {2, 2};
|
||||
@@ -91,7 +100,7 @@ int profile_max_pool2d_fwd(int argc, char* argv[])
|
||||
}
|
||||
else if(argc == 28)
|
||||
{
|
||||
data_type = static_cast<ck::DataTypeEnum>(std::stoi(argv[2]));
|
||||
data_type = static_cast<PoolDataType>(std::stoi(argv[2]));
|
||||
do_verification = std::stoi(argv[3]);
|
||||
init_method = std::stoi(argv[4]);
|
||||
do_log = std::stoi(argv[5]);
|
||||
@@ -113,11 +122,13 @@ int profile_max_pool2d_fwd(int argc, char* argv[])
|
||||
using BF16 = ck::bhalf_t;
|
||||
using F32 = float;
|
||||
using I32 = int32_t;
|
||||
using F8 = ck::f8_t;
|
||||
using I8 = int8_t;
|
||||
using NHWC = ck::tensor_layout::convolution::NHWC;
|
||||
|
||||
constexpr auto ReduceOpId = ck::ReduceTensorOp::MAX;
|
||||
|
||||
if(data_type == ck::DataTypeEnum::Half)
|
||||
if(data_type == PoolDataType::F16)
|
||||
{
|
||||
if(return_index)
|
||||
{
|
||||
@@ -150,7 +161,7 @@ int profile_max_pool2d_fwd(int argc, char* argv[])
|
||||
pad2);
|
||||
}
|
||||
}
|
||||
else if(data_type == ck::DataTypeEnum::BFloat16)
|
||||
else if(data_type == PoolDataType::BF16)
|
||||
{
|
||||
if(return_index)
|
||||
{
|
||||
@@ -189,7 +200,7 @@ int profile_max_pool2d_fwd(int argc, char* argv[])
|
||||
pad2);
|
||||
}
|
||||
}
|
||||
else if(data_type == ck::DataTypeEnum::Float)
|
||||
else if(data_type == PoolDataType::F32)
|
||||
{
|
||||
if(return_index)
|
||||
{
|
||||
@@ -222,6 +233,72 @@ int profile_max_pool2d_fwd(int argc, char* argv[])
|
||||
pad2);
|
||||
}
|
||||
}
|
||||
else if(data_type == PoolDataType::INT8)
|
||||
{
|
||||
if(return_index)
|
||||
{
|
||||
ck::profiler::
|
||||
profile_pool2d_fwd_impl<I8, I8, F32, I32, NHWC, NHWC, ReduceOpId, false, true>(
|
||||
do_verification,
|
||||
init_method,
|
||||
do_log,
|
||||
time_kernel,
|
||||
in_length,
|
||||
wsize,
|
||||
wstride,
|
||||
wdilation,
|
||||
pad1,
|
||||
pad2);
|
||||
}
|
||||
else
|
||||
{
|
||||
ck::profiler::
|
||||
profile_pool2d_fwd_impl<I8, I8, F32, I32, NHWC, NHWC, ReduceOpId, false, false>(
|
||||
do_verification,
|
||||
init_method,
|
||||
do_log,
|
||||
time_kernel,
|
||||
in_length,
|
||||
wsize,
|
||||
wstride,
|
||||
wdilation,
|
||||
pad1,
|
||||
pad2);
|
||||
}
|
||||
}
|
||||
else if(data_type == PoolDataType::F8)
|
||||
{
|
||||
if(return_index)
|
||||
{
|
||||
ck::profiler::
|
||||
profile_pool2d_fwd_impl<F8, F8, F32, I32, NHWC, NHWC, ReduceOpId, false, true>(
|
||||
do_verification,
|
||||
init_method,
|
||||
do_log,
|
||||
time_kernel,
|
||||
in_length,
|
||||
wsize,
|
||||
wstride,
|
||||
wdilation,
|
||||
pad1,
|
||||
pad2);
|
||||
}
|
||||
else
|
||||
{
|
||||
ck::profiler::
|
||||
profile_pool2d_fwd_impl<F8, F8, F32, I32, NHWC, NHWC, ReduceOpId, false, false>(
|
||||
do_verification,
|
||||
init_method,
|
||||
do_log,
|
||||
time_kernel,
|
||||
in_length,
|
||||
wsize,
|
||||
wstride,
|
||||
wdilation,
|
||||
pad1,
|
||||
pad2);
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("not implemented yet");
|
||||
|
||||
@@ -14,13 +14,12 @@ class TestAvgPool2dFwd : public ::testing::Test
|
||||
using ComputeDataType = std::tuple_element_t<2, Tuple>;
|
||||
using IndexDataType = std::tuple_element_t<3, Tuple>;
|
||||
|
||||
std::vector<PoolingParam> params;
|
||||
static std::vector<PoolingParam> params;
|
||||
|
||||
void Run()
|
||||
{
|
||||
for(auto param : params)
|
||||
{
|
||||
// avg pool
|
||||
bool success =
|
||||
ck::profiler::profile_pool2d_fwd_impl<InDataType,
|
||||
OutDataType,
|
||||
@@ -45,24 +44,102 @@ class TestAvgPool2dFwd : public ::testing::Test
|
||||
}
|
||||
};
|
||||
|
||||
using KernelTypes = std::conditional_t<
|
||||
CK_ENABLE_FP16 && CK_ENABLE_BF16,
|
||||
::testing::Types<std::tuple<F16, F16, F32, I32>,
|
||||
std::tuple<F16, F16, F32, I32>,
|
||||
std::tuple<BF16, BF16, F32, I32>,
|
||||
std::tuple<BF16, BF16, F32, I32>,
|
||||
std::tuple<F32, F32, F32, I32>,
|
||||
std::tuple<F32, F32, F32, I32>>,
|
||||
::testing::Types<std::tuple<F32, F32, F32, I32>, std::tuple<F32, F32, F32, I32>>>;
|
||||
template <typename T>
|
||||
std::vector<PoolingParam> TestAvgPool2dFwd<T>::params = {
|
||||
{{{1, 1, 1, 1}, {1, 1}, {1, 1}, {1, 1}, {0, 0}, {0, 0}},
|
||||
{{2, 16, 64, 64}, {64, 64}, {1, 1}, {1, 1}, {0, 0}, {0, 0}},
|
||||
{{2, 16, 64, 64}, {4, 4}, {4, 4}, {2, 2}, {0, 0}, {0, 0}},
|
||||
{{2, 32, 30, 30}, {2, 2}, {2, 2}, {1, 1}, {1, 1}, {1, 1}}}};
|
||||
|
||||
TYPED_TEST_SUITE(TestAvgPool2dFwd, KernelTypes);
|
||||
TYPED_TEST(TestAvgPool2dFwd, Test_Pool)
|
||||
using AvgPool2D_F32_Types =
|
||||
::testing::Types<std::tuple<F32, F32, F32, I32>, std::tuple<F32, F32, F32, I32>>;
|
||||
using AvgPool2D_F16_Types =
|
||||
::testing::Types<std::tuple<F16, F16, F32, I32>, std::tuple<F16, F16, F32, I32>>;
|
||||
using AvgPool2D_BF16_Types =
|
||||
::testing::Types<std::tuple<I8, I8, F32, I32>, std::tuple<BF16, BF16, F32, I32>>;
|
||||
using AvgPool2D_I8_Types =
|
||||
::testing::Types<std::tuple<I8, I8, F32, I32>, std::tuple<I8, I8, F32, I32>>;
|
||||
using AvgPool2D_F8_Types =
|
||||
::testing::Types<std::tuple<F8, F8, F32, I32>, std::tuple<F8, F8, F32, I32>>;
|
||||
|
||||
template <typename TType>
|
||||
class AvgPool2D_F32 : public TestAvgPool2dFwd<TType>
|
||||
{
|
||||
// length, window_length, window_stride, window_dilation, left_pad, right_pad
|
||||
this->params = {{{1, 1, 1, 1}, {1, 1}, {1, 1}, {1, 1}, {0, 0}, {0, 0}},
|
||||
{{2, 16, 64, 64}, {64, 64}, {1, 1}, {1, 1}, {0, 0}, {0, 0}},
|
||||
{{2, 16, 64, 64}, {4, 4}, {4, 4}, {2, 2}, {0, 0}, {0, 0}},
|
||||
{{2, 32, 30, 30}, {2, 2}, {2, 2}, {1, 1}, {1, 1}, {1, 1}}};
|
||||
protected:
|
||||
void SetUp() override
|
||||
{
|
||||
if(!CK_ENABLE_FP32)
|
||||
{
|
||||
GTEST_SKIP() << "Skipping AvgPool2D_F32 tests because CK_ENABLE_FP32 is "
|
||||
"not enabled";
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
this->Run();
|
||||
}
|
||||
template <typename TType>
|
||||
class AvgPool2D_F16 : public TestAvgPool2dFwd<TType>
|
||||
{
|
||||
protected:
|
||||
void SetUp() override
|
||||
{
|
||||
if(!CK_ENABLE_FP16)
|
||||
{
|
||||
GTEST_SKIP() << "Skipping AvgPool2D_F16 tests because CK_ENABLE_FP16 is "
|
||||
"not enabled";
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template <typename TType>
|
||||
class AvgPool2D_BF16 : public TestAvgPool2dFwd<TType>
|
||||
{
|
||||
protected:
|
||||
void SetUp() override
|
||||
{
|
||||
if(!CK_ENABLE_BF16)
|
||||
{
|
||||
GTEST_SKIP() << "Skipping AvgPool2D_BF16 tests because CK_ENABLE_BF16 is "
|
||||
"not enabled";
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template <typename TType>
|
||||
class AvgPool2D_I8 : public TestAvgPool2dFwd<TType>
|
||||
{
|
||||
protected:
|
||||
void SetUp() override
|
||||
{
|
||||
if(!CK_ENABLE_INT8)
|
||||
{
|
||||
GTEST_SKIP() << "Skipping AvgPool2D_I8 tests because CK_ENABLE_INT8 is "
|
||||
"not enabled";
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template <typename TType>
|
||||
class AvgPool2D_F8 : public TestAvgPool2dFwd<TType>
|
||||
{
|
||||
protected:
|
||||
void SetUp() override
|
||||
{
|
||||
if(!CK_ENABLE_FP8)
|
||||
{
|
||||
GTEST_SKIP() << "Skipping AvgPool2D_F8 tests because CK_ENABLE_FP8 is "
|
||||
"not enabled";
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
// Bind each per-data-type fixture to its type list.
TYPED_TEST_SUITE(AvgPool2D_F32, AvgPool2D_F32_Types);
TYPED_TEST_SUITE(AvgPool2D_F16, AvgPool2D_F16_Types);
TYPED_TEST_SUITE(AvgPool2D_BF16, AvgPool2D_BF16_Types);
TYPED_TEST_SUITE(AvgPool2D_I8, AvgPool2D_I8_Types);
TYPED_TEST_SUITE(AvgPool2D_F8, AvgPool2D_F8_Types);

// One test per suite; each simply runs the fixture's parameter sweep.
// The F32 test was previously misnamed AvgPool2D_I8_Test, which made it show
// up under an int8 name in reports and test filters.
TYPED_TEST(AvgPool2D_F32, AvgPool2D_F32_Test) { this->Run(); }
TYPED_TEST(AvgPool2D_F16, AvgPool2D_F16_Test) { this->Run(); }
TYPED_TEST(AvgPool2D_BF16, AvgPool2D_BF16_Test) { this->Run(); }
TYPED_TEST(AvgPool2D_I8, AvgPool2D_I8_Test) { this->Run(); }
TYPED_TEST(AvgPool2D_F8, AvgPool2D_F8_Test) { this->Run(); }
|
||||
|
||||
@@ -15,7 +15,7 @@ class TestMaxPool2dFwd : public ::testing::Test
|
||||
using IndexDataType = std::tuple_element_t<3, Tuple>;
|
||||
static constexpr bool ReturnIndex = std::tuple_element_t<4, Tuple>::value;
|
||||
|
||||
std::vector<PoolingParam> params;
|
||||
static std::vector<PoolingParam> params;
|
||||
|
||||
void Run()
|
||||
{
|
||||
@@ -46,27 +46,105 @@ class TestMaxPool2dFwd : public ::testing::Test
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
std::vector<PoolingParam> TestMaxPool2dFwd<T>::params = {
|
||||
{{{1, 1, 1, 1}, {1, 1}, {1, 1}, {1, 1}, {0, 0}, {0, 0}},
|
||||
{{2, 16, 64, 64}, {64, 64}, {1, 1}, {1, 1}, {0, 0}, {0, 0}},
|
||||
{{2, 16, 64, 64}, {4, 4}, {4, 4}, {2, 2}, {0, 0}, {0, 0}},
|
||||
{{2, 32, 30, 30}, {2, 2}, {2, 2}, {1, 1}, {1, 1}, {1, 1}}}};
|
||||
|
||||
using true_t = std::integral_constant<bool, true>;
|
||||
using false_t = std::integral_constant<bool, false>;
|
||||
|
||||
using KernelTypes = std::conditional_t<CK_ENABLE_FP16 && CK_ENABLE_BF16,
|
||||
::testing::Types<std::tuple<F16, F16, F32, I32, true_t>,
|
||||
std::tuple<F16, F16, F32, I32, false_t>,
|
||||
std::tuple<BF16, BF16, F32, I32, true_t>,
|
||||
std::tuple<BF16, BF16, F32, I32, false_t>,
|
||||
std::tuple<F32, F32, F32, I32, true_t>,
|
||||
std::tuple<F32, F32, F32, I32, false_t>>,
|
||||
::testing::Types<std::tuple<F32, F32, F32, I32, true_t>,
|
||||
std::tuple<F32, F32, F32, I32, false_t>>>;
|
||||
// One typed-test type list per data type. The fifth tuple element is read by
// TestMaxPool2dFwd as ReturnIndex, so every list covers both the
// index-returning (true_t) and the plain (false_t) variant.
using MaxPool2D_F32_Types = ::testing::Types<std::tuple<F32, F32, F32, I32, true_t>,
                                             std::tuple<F32, F32, F32, I32, false_t>>;
using MaxPool2D_F16_Types = ::testing::Types<std::tuple<F16, F16, F32, I32, true_t>,
                                             std::tuple<F16, F16, F32, I32, false_t>>;
// Both entries must be BF16: the true_t tuple used to be <I8, I8, ...>, so the
// BF16 return-index path was never tested and int8 ran in its place.
using MaxPool2D_BF16_Types = ::testing::Types<std::tuple<BF16, BF16, F32, I32, true_t>,
                                              std::tuple<BF16, BF16, F32, I32, false_t>>;
using MaxPool2D_I8_Types =
    ::testing::Types<std::tuple<I8, I8, F32, I32, true_t>, std::tuple<I8, I8, F32, I32, false_t>>;
using MaxPool2D_F8_Types =
    ::testing::Types<std::tuple<F8, F8, F32, I32, true_t>, std::tuple<F8, F8, F32, I32, false_t>>;
|
||||
|
||||
TYPED_TEST_SUITE(TestMaxPool2dFwd, KernelTypes);
|
||||
TYPED_TEST(TestMaxPool2dFwd, Test_Pool)
|
||||
template <typename TType>
|
||||
class MaxPool2D_F32 : public TestMaxPool2dFwd<TType>
|
||||
{
|
||||
// length, window_length, window_stride, window_dilation, left_pad, right_pad
|
||||
this->params = {{{1, 1, 1, 1}, {1, 1}, {1, 1}, {1, 1}, {0, 0}, {0, 0}},
|
||||
{{2, 16, 64, 64}, {64, 64}, {1, 1}, {1, 1}, {0, 0}, {0, 0}},
|
||||
{{2, 16, 64, 64}, {4, 4}, {4, 4}, {2, 2}, {0, 0}, {0, 0}},
|
||||
{{2, 32, 30, 30}, {2, 2}, {2, 2}, {1, 1}, {1, 1}, {1, 1}}};
|
||||
protected:
|
||||
void SetUp() override
|
||||
{
|
||||
if(!CK_ENABLE_FP32)
|
||||
{
|
||||
GTEST_SKIP() << "Skipping MaxPool2D_F32 tests because CK_ENABLE_FP32 is "
|
||||
"not enabled";
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
this->Run();
|
||||
}
|
||||
template <typename TType>
|
||||
class MaxPool2D_F16 : public TestMaxPool2dFwd<TType>
|
||||
{
|
||||
protected:
|
||||
void SetUp() override
|
||||
{
|
||||
if(!CK_ENABLE_FP16)
|
||||
{
|
||||
GTEST_SKIP() << "Skipping MaxPool2D_F16 tests because CK_ENABLE_FP16 is "
|
||||
"not enabled";
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template <typename TType>
|
||||
class MaxPool2D_BF16 : public TestMaxPool2dFwd<TType>
|
||||
{
|
||||
protected:
|
||||
void SetUp() override
|
||||
{
|
||||
if(!CK_ENABLE_BF16)
|
||||
{
|
||||
GTEST_SKIP() << "Skipping MaxPool2D_BF16 tests because CK_ENABLE_BF16 is "
|
||||
"not enabled";
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template <typename TType>
|
||||
class MaxPool2D_I8 : public TestMaxPool2dFwd<TType>
|
||||
{
|
||||
protected:
|
||||
void SetUp() override
|
||||
{
|
||||
if(!CK_ENABLE_INT8)
|
||||
{
|
||||
GTEST_SKIP() << "Skipping MaxPool2D_I8 tests because CK_ENABLE_INT8 is "
|
||||
"not enabled";
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template <typename TType>
|
||||
class MaxPool2D_F8 : public TestMaxPool2dFwd<TType>
|
||||
{
|
||||
protected:
|
||||
void SetUp() override
|
||||
{
|
||||
if(!CK_ENABLE_FP8)
|
||||
{
|
||||
GTEST_SKIP() << "Skipping MaxPool2D_F8 tests because CK_ENABLE_FP8 is "
|
||||
"not enabled";
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
// Bind each per-data-type fixture to its type list.
TYPED_TEST_SUITE(MaxPool2D_F32, MaxPool2D_F32_Types);
TYPED_TEST_SUITE(MaxPool2D_F16, MaxPool2D_F16_Types);
TYPED_TEST_SUITE(MaxPool2D_BF16, MaxPool2D_BF16_Types);
TYPED_TEST_SUITE(MaxPool2D_I8, MaxPool2D_I8_Types);
TYPED_TEST_SUITE(MaxPool2D_F8, MaxPool2D_F8_Types);

// One test per suite; each simply runs the fixture's parameter sweep.
// The F32 test was previously misnamed MaxPool2D_I8_Test, which made it show
// up under an int8 name in reports and test filters.
TYPED_TEST(MaxPool2D_F32, MaxPool2D_F32_Test) { this->Run(); }
TYPED_TEST(MaxPool2D_F16, MaxPool2D_F16_Test) { this->Run(); }
TYPED_TEST(MaxPool2D_BF16, MaxPool2D_BF16_Test) { this->Run(); }
TYPED_TEST(MaxPool2D_I8, MaxPool2D_I8_Test) { this->Run(); }
TYPED_TEST(MaxPool2D_F8, MaxPool2D_F8_Test) { this->Run(); }
|
||||
|
||||
Reference in New Issue
Block a user