diff --git a/library/src/tensor_operation_instance/gpu/pool2d_fwd/device_max_pool2d_fwd_nhwc_bf16_instance.cpp b/library/src/tensor_operation_instance/gpu/pool2d_fwd/device_max_pool2d_fwd_nhwc_bf16_instance.cpp index a1577d40f7..4f02d84305 100644 --- a/library/src/tensor_operation_instance/gpu/pool2d_fwd/device_max_pool2d_fwd_nhwc_bf16_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/pool2d_fwd/device_max_pool2d_fwd_nhwc_bf16_instance.cpp @@ -19,6 +19,15 @@ void add_device_pool2d_fwd_nhwc_bf16_instances( instances, device_pool2d_fwd_nhwc_instances{}); } +void add_device_pool2d_fwd_nhwc_index_bf16_instances( + std::vector< + std::unique_ptr>>& + instances) +{ + add_device_operation_instances( + instances, device_pool2d_fwd_nhwc_instances{}); +} + } // namespace instance } // namespace device } // namespace tensor_operation diff --git a/library/src/tensor_operation_instance/gpu/pool2d_fwd/device_max_pool2d_fwd_nhwc_f16_instance.cpp b/library/src/tensor_operation_instance/gpu/pool2d_fwd/device_max_pool2d_fwd_nhwc_f16_instance.cpp index 26ee5e4bbe..e6a580d07f 100644 --- a/library/src/tensor_operation_instance/gpu/pool2d_fwd/device_max_pool2d_fwd_nhwc_f16_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/pool2d_fwd/device_max_pool2d_fwd_nhwc_f16_instance.cpp @@ -18,6 +18,14 @@ void add_device_pool2d_fwd_nhwc_f16_instances( instances, device_pool2d_fwd_nhwc_instances{}); } +void add_device_pool2d_fwd_nhwc_index_f16_instances( + std::vector>>& + instances) +{ + add_device_operation_instances( + instances, device_pool2d_fwd_nhwc_instances{}); +} + } // namespace instance } // namespace device } // namespace tensor_operation diff --git a/library/src/tensor_operation_instance/gpu/pool2d_fwd/device_max_pool2d_fwd_nhwc_f32_instance.cpp b/library/src/tensor_operation_instance/gpu/pool2d_fwd/device_max_pool2d_fwd_nhwc_f32_instance.cpp index cc40ca48ad..1f104eab7f 100644 --- a/library/src/tensor_operation_instance/gpu/pool2d_fwd/device_max_pool2d_fwd_nhwc_f32_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/pool2d_fwd/device_max_pool2d_fwd_nhwc_f32_instance.cpp @@ -18,6 +18,14 @@ void add_device_pool2d_fwd_nhwc_f32_instances( instances, device_pool2d_fwd_nhwc_instances{}); } +void add_device_pool2d_fwd_nhwc_index_f32_instances( + std::vector>>& + instances) +{ + add_device_operation_instances( + instances, device_pool2d_fwd_nhwc_instances{}); +} + } // namespace instance } // namespace device } // namespace tensor_operation diff --git a/profiler/include/profiler/profile_pool2d_fwd_impl.hpp b/profiler/include/profiler/profile_pool2d_fwd_impl.hpp new file mode 100644 index 0000000000..63c71ea7f5 --- /dev/null +++ b/profiler/include/profiler/profile_pool2d_fwd_impl.hpp @@ -0,0 +1,274 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include + +#include "ck/ck.hpp" +#include "ck/library/tensor_operation_instance/gpu/pool2d_fwd.hpp" +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/literals.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_pool_fwd.hpp" + +namespace ck { +namespace profiler { + +template +bool profile_pool2d_fwd_impl(int do_verification, + int init_method, + bool do_log, + bool time_kernel, + std::vector in_length, // NCHW + std::vector window_spatial_lengths, + std::vector window_strides, + std::vector window_dilations, + std::vector input_left_pads, + std::vector input_right_pads) +{ + constexpr index_t InOutRank = 4; + constexpr index_t WindowRank = 2; + + if(in_length.size() != InOutRank || window_spatial_lengths.size() != WindowRank || + window_strides.size() != WindowRank || window_dilations.size() != WindowRank || + input_left_pads.size() != WindowRank || input_right_pads.size() != WindowRank) + return false; + + std::vector out_length(InOutRank); + + int N = in_length[0]; + int C = in_length[1]; + + out_length[0] = N; + out_length[1] = C; + + // Calculate Do, Ho, Wo + for(int i = 2; i < InOutRank; ++i) + { + auto pad1 = input_left_pads[i - 2]; + auto pad2 = input_right_pads[i - 2]; + auto windows_size = window_spatial_lengths[i - 2]; + auto windows_stride = window_strides[i - 2]; + auto windows_dilation = window_dilations[i - 2]; + auto eff = (windows_size - 1) * windows_dilation + 1; + out_length[i] = (in_length[i] + pad1 + pad2 - eff) / windows_stride + 1; + } + + int Hi = in_length[2]; + int Wi = in_length[3]; + int Ho = out_length[2]; + int Wo = out_length[3]; + + auto f_host_tensor_descriptor = + [](std::size_t N_, std::size_t C_, std::size_t H, std::size_t W) { + using namespace ck::literals; + + return HostTensorDescriptor({N_, C_, H, W}, {C_ * H * W, 1_uz, W * C_, C_}); + }; + + Tensor in_n_c_hi_wi(f_host_tensor_descriptor(N, C, Hi, Wi)); + Tensor out_n_c_ho_wo_host(f_host_tensor_descriptor(N, C, Ho, Wo)); + Tensor out_indices_n_c_ho_wo_host(f_host_tensor_descriptor(N, C, Ho, Wo)); + + Tensor out_n_c_ho_wo_device(f_host_tensor_descriptor(N, C, Ho, Wo)); + Tensor out_indices_n_c_ho_wo_device(f_host_tensor_descriptor(N, C, Ho, Wo)); + + switch(init_method) + { + case 0: in_n_c_hi_wi.GenerateTensorValue(GeneratorTensor_1{}); break; + case 1: in_n_c_hi_wi.GenerateTensorValue(GeneratorTensor_2{-5, 5}); break; + default: in_n_c_hi_wi.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + } + + DeviceMem in_device_buf(sizeof(InDataType) * in_n_c_hi_wi.mDesc.GetElementSpaceSize()); + DeviceMem out_device_buf(sizeof(OutDataType) * + out_n_c_ho_wo_device.mDesc.GetElementSpaceSize()); + DeviceMem out_indices_device_buf(sizeof(IndexDataType) * + out_indices_n_c_ho_wo_device.mDesc.GetElementSpaceSize()); + + in_device_buf.ToDevice(in_n_c_hi_wi.mData.data()); + + // add device normalization instances + using DeviceOp = ck::tensor_operation::device::DevicePoolFwd; + + // get device op instances + const auto instance_ptrs = + ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< + DeviceOp>::GetInstances(); + + std::cout << "found " << instance_ptrs.size() << " instances" << std::endl; + + std::string best_instance_name; + float best_avg_time = std::numeric_limits::max(); + float best_gb_per_sec = 0; + + if(do_verification) + { + using ReferenceInstance = ck::tensor_operation::host::ReferencePoolingFwd; + + ReferenceInstance ref; + auto ref_argument = ref.MakeArgument(in_n_c_hi_wi, + out_n_c_ho_wo_host, + out_indices_n_c_ho_wo_host, + window_spatial_lengths, + window_strides, + window_dilations, + input_left_pads, + input_right_pads); + auto ref_invoker = ref.MakeInvoker(); + ref_invoker.Run(ref_argument); + } + + int num_kernel = 0; + + for(auto& inst_ptr : instance_ptrs) + { + auto argument_ptr = inst_ptr->MakeArgumentPointer( + static_cast(in_device_buf.GetDeviceBuffer()), + static_cast(out_device_buf.GetDeviceBuffer()), + static_cast(out_indices_device_buf.GetDeviceBuffer()), + in_length, + window_spatial_lengths, + out_length, + {C * Hi * Wi, 1, Wi * C, C}, + {C * Ho * Wo, 1, Wo * C, C}, + {C * Ho * Wo, 1, Wo * C, C}, + window_strides, + window_dilations, + input_left_pads, + input_right_pads, + {2, 3}); + + if(inst_ptr->IsSupportedArgument(argument_ptr.get())) + { + ++num_kernel; + } + else + { + if(time_kernel) + { + std::cout << inst_ptr->GetTypeString() << " skipped due to unsupported argument: "; + LogRange(std::cout << "input lengths = ", in_length, ", ") << std::endl; + } + + continue; + } + + auto invoker_ptr = inst_ptr->MakeInvokerPointer(); + + float avg_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel}); + + std::size_t num_bytes = in_n_c_hi_wi.mDesc.GetElementSize() * sizeof(InDataType) + + out_n_c_ho_wo_host.mDesc.GetElementSize() * sizeof(OutDataType); + + if constexpr(OutputIndex) + num_bytes += out_indices_n_c_ho_wo_host.mDesc.GetElementSize() * sizeof(IndexDataType); + + float gb_per_sec = num_bytes / 1.E6 / avg_time; + + if(time_kernel) + std::cout << "Perf: " << std::setw(10) << avg_time << " ms, " << gb_per_sec << " GB/s, " + << inst_ptr->GetTypeString() << std::endl; + + if(avg_time < best_avg_time) + { + best_instance_name = inst_ptr->GetTypeString(); + best_avg_time = avg_time; + best_gb_per_sec = gb_per_sec; + } + + if(do_verification) + { + out_device_buf.FromDevice(out_n_c_ho_wo_device.mData.data()); + + bool pass = ck::utils::check_err(out_n_c_ho_wo_device.mData, + out_n_c_ho_wo_host.mData, + "Error: Incorrect results", + 1e-3, + 1e-3); + + if constexpr(OutputIndex) + { + out_indices_device_buf.FromDevice(out_indices_n_c_ho_wo_device.mData.data()); + + pass = pass && ck::utils::check_err(out_indices_n_c_ho_wo_device, + out_indices_n_c_ho_wo_host); + } + + if(do_log) + { + LogRangeAsType(std::cout << "in_n_c_hi_wi : ", in_n_c_hi_wi.mData, ",") + << std::endl; + LogRangeAsType( + std::cout << "out_n_c_ho_wo_host : ", out_n_c_ho_wo_host.mData, ",") + << std::endl; + LogRangeAsType( + std::cout << "out_n_c_ho_wo_device : ", out_n_c_ho_wo_device.mData, ",") + << std::endl; + + if constexpr(OutputIndex) + LogRangeAsType(std::cout << "out_indices_n_c_ho_wo_device : ", + out_indices_n_c_ho_wo_device.mData, + ",") + << std::endl; + } + + if(!pass) + { + std::cout << inst_ptr->GetTypeString() << " failed verification: "; + LogRange(std::cout << "lengths = [", in_length, ", ") << "]." << std::endl; + return false; + } + else + { + if(time_kernel) + std::cout << "pass" << std::endl; + } + } + } + + if(time_kernel) + { + LogRange(std::cout << "length = ", in_length, ",") << std::endl; + std::cout << "best perf = " << best_avg_time << " ms, " << best_gb_per_sec << " GB/s, " + << best_instance_name << std::endl; + } + + if(num_kernel == 0) + { + std::cout << "Error: No kernel is applicable" << std::endl; + return false; + } + + return true; +} + +} // namespace profiler +} // namespace ck diff --git a/profiler/src/CMakeLists.txt b/profiler/src/CMakeLists.txt index 43bebba8cb..554808cac5 100755 --- a/profiler/src/CMakeLists.txt +++ b/profiler/src/CMakeLists.txt @@ -9,6 +9,7 @@ set(PROFILER_SOURCES profile_layernorm_bwd_gamma_beta.cpp profile_groupnorm_bwd_gamma_beta.cpp profile_layernorm_fwd.cpp + profile_max_pool2d_fwd.cpp profile_max_pool3d_fwd.cpp profile_avg_pool3d_bwd.cpp profile_max_pool3d_bwd.cpp @@ -98,6 +99,7 @@ target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_normalization_bwd_ga target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_softmax_instance) target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_reduce_instance) target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_batchnorm_instance) +target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_pool2d_fwd_instance) target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_pool3d_fwd_instance) target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_avg_pool3d_bwd_instance) target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_max_pool_bwd_instance) diff --git a/profiler/src/profile_max_pool2d_fwd.cpp b/profiler/src/profile_max_pool2d_fwd.cpp new file mode 100644 index 0000000000..7ef3b98c77 --- /dev/null +++ b/profiler/src/profile_max_pool2d_fwd.cpp @@ -0,0 +1,239 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include + +#include "profiler/data_type_enum.hpp" +#include "profiler/profile_pool2d_fwd_impl.hpp" +#include "profiler_operation_registry.hpp" + +using ck::index_t; + +struct maxPoolFwdArgParser +{ + std::unordered_map> long_opts = {{"length", {}}, + {"wsize", {}}, + {"wstride", {}}, + {"wdilation", {}}, + {"pad1", {}}, + {"pad2", {}}}; + + bool parse_opt(int argc, char* argv[], const std::string& key, int i) + { + if(std::string("--") + key == argv[i]) + { + int pos = i; + while(++i < argc && argv[i][0] != '-') {} + int end = i; + for(int j = pos + 1; j < end; j++) + { + long_opts[key].push_back(std::stoi(argv[j])); + } + return true; + } + return false; + } + + void operator()(int argc, char* argv[]) + { + for(auto& kv : long_opts) + { + for(int i = 1; i < argc; i++) + { + if(parse_opt(argc, argv, kv.first, i)) + break; + } + } + } +}; + +void print_help_max_pool2d_fwd() +{ + std::cout << "arg1: data type (0: fp16; 1: fp32; 5: bf16)\n" + << "arg2: verification (0: no; 1: yes)\n" + << "arg3: initialization (0: no init; 1: integer value; 2: decimal value)\n" + << "arg4: print tensor value (0: no; 1: yes)\n" + << "arg5: time kernel (0=no, 1=yes)\n" + << "arg6: return index (0=no, 1=yes)\n" + << "--length: input tensor length for NCDHW(e.g, --length 2 32 30 30 30) \n" + << "--wsize: window size for ZYX (e.g, --wsize 2 2 2) \n" + << "--wstride: window stride for DHW (e.g, --wstride 2 2 2) \n" + << "--wdilation: window dilation for DHW (e.g, --wdilation 1 1 1) \n" + << "--pad1: left side of padding in DHW (e.g, --pad1 1 1 1) \n" + << "--pad2: right side of padding in DHW (e.g, --pad2 1 1 1) \n" + << "eg: ckProfiler max_pool3d_fwd 0 1 2 0 1 0 --length 2 32 30 30 30 --wsize 2 2 2 " + "--wstride 2 2 2 --wdilation 1 1 1 --pad1 1 1 1 --pad2 1 1 1" + << std::endl; +} + +int profile_max_pool2d_fwd(int argc, char* argv[]) +{ + ck::DataTypeEnum data_type = ck::DataTypeEnum::Half; + bool do_verification = true; + int init_method = 0; + bool do_log = false; + bool time_kernel = true; + bool return_index = false; + + std::vector in_length = {2, 32, 30, 30}; + std::vector wsize = {2, 2}; + std::vector wstride = {2, 2}; + std::vector wdilation = {1, 1}; + std::vector pad1 = {1, 1}; + std::vector pad2 = {1, 1}; + + if(argc != 2 && argc != 34) + { + print_help_max_pool2d_fwd(); + return 0; + } + else if(argc == 34) + { + data_type = static_cast(std::stoi(argv[2])); + do_verification = std::stoi(argv[3]); + init_method = std::stoi(argv[4]); + do_log = std::stoi(argv[5]); + time_kernel = std::stoi(argv[6]); + return_index = std::stoi(argv[7]); + + // parse the long options + maxPoolFwdArgParser arg_parser; + arg_parser(argc, argv); + in_length = arg_parser.long_opts["length"]; + wsize = arg_parser.long_opts["wsize"]; + wstride = arg_parser.long_opts["wstride"]; + wdilation = arg_parser.long_opts["wdilation"]; + pad1 = arg_parser.long_opts["pad1"]; + pad2 = arg_parser.long_opts["pad2"]; + } + +#ifdef CK_ENABLE_FP16 + using F16 = ck::half_t; +#endif +#ifdef CK_ENABLE_BF16 + using BF16 = ck::bhalf_t; +#endif +#ifdef CK_ENABLE_FP32 + using F32 = float; +#endif + using I32 = int32_t; + using NHWC = ck::tensor_layout::convolution::NHWC; + +#if 1 + constexpr auto ReduceOpId = ck::ReduceTensorOp::MAX; +#else + constexpr auto ReduceOpId = ck::ReduceTensorOp::AVG; +#endif + + if(false) + ; +#ifdef CK_ENABLE_FP16 + else if(data_type == ck::DataTypeEnum::Half) + { + if(return_index) + ck::profiler:: + profile_pool2d_fwd_impl( + do_verification, + init_method, + do_log, + time_kernel, + in_length, + wsize, + wstride, + wdilation, + pad1, + pad2); + else + ck::profiler:: + profile_pool2d_fwd_impl( + do_verification, + init_method, + do_log, + time_kernel, + in_length, + wsize, + wstride, + wdilation, + pad1, + pad2); + } +#endif +#ifdef CK_ENABLE_BF16 + else if(data_type == ck::DataTypeEnum::BFloat16) + { + if(return_index) + ck::profiler:: + profile_pool2d_fwd_impl( + do_verification, + init_method, + do_log, + time_kernel, + in_length, + wsize, + wstride, + wdilation, + pad1, + pad2); + else + ck::profiler::profile_pool2d_fwd_impl(do_verification, + init_method, + do_log, + time_kernel, + in_length, + wsize, + wstride, + wdilation, + pad1, + pad2); + } +#endif +#ifdef CK_ENABLE_FP32 + else if(data_type == ck::DataTypeEnum::Float) + { + if(return_index) + ck::profiler:: + profile_pool2d_fwd_impl( + do_verification, + init_method, + do_log, + time_kernel, + in_length, + wsize, + wstride, + wdilation, + pad1, + pad2); + else + ck::profiler:: + profile_pool2d_fwd_impl( + do_verification, + init_method, + do_log, + time_kernel, + in_length, + wsize, + wstride, + wdilation, + pad1, + pad2); + } +#endif + else + { + throw std::runtime_error("not implemented yet"); + } + + return 0; +} + +REGISTER_PROFILER_OPERATION("max_pool2d_fwd", "max_pool2d fwd", profile_max_pool2d_fwd); diff --git a/test/pool/CMakeLists.txt b/test/pool/CMakeLists.txt index fac806897a..0118a7591b 100644 --- a/test/pool/CMakeLists.txt +++ b/test/pool/CMakeLists.txt @@ -4,13 +4,19 @@ add_gtest_executable(test_avg_pool3d_bwd test_avg_pool3d_bwd.cpp) add_gtest_executable(test_max_pool3d_bwd test_max_pool3d_bwd.cpp) add_gtest_executable(test_avg_pool3d_fwd test_avg_pool3d_fwd.cpp) add_gtest_executable(test_max_pool3d_fwd test_max_pool3d_fwd.cpp) +add_gtest_executable(test_avg_pool2d_fwd test_avg_pool2d_fwd.cpp) +add_gtest_executable(test_max_pool2d_fwd test_max_pool2d_fwd.cpp) target_link_libraries(test_avg_pool3d_bwd PRIVATE utility device_avg_pool3d_bwd_instance) target_link_libraries(test_max_pool3d_bwd PRIVATE utility device_max_pool_bwd_instance) target_link_libraries(test_avg_pool3d_fwd PRIVATE utility device_pool3d_fwd_instance) target_link_libraries(test_max_pool3d_fwd PRIVATE utility device_pool3d_fwd_instance) +target_link_libraries(test_avg_pool2d_fwd PRIVATE utility device_pool2d_fwd_instance) +target_link_libraries(test_max_pool2d_fwd PRIVATE utility device_pool2d_fwd_instance) add_dependencies(test_pool test_avg_pool3d_bwd) add_dependencies(test_pool test_max_pool3d_bwd) add_dependencies(test_pool test_avg_pool3d_fwd) add_dependencies(test_pool test_max_pool3d_fwd) +add_dependencies(test_pool test_avg_pool2d_fwd) +add_dependencies(test_pool test_max_pool2d_fwd) diff --git a/test/pool/test_avg_pool2d_fwd.cpp b/test/pool/test_avg_pool2d_fwd.cpp new file mode 100644 index 0000000000..7f0e08c6d1 --- /dev/null +++ b/test/pool/test_avg_pool2d_fwd.cpp @@ -0,0 +1,65 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "gtest/gtest.h" +#include "profiler/profile_pool2d_fwd_impl.hpp" +#include "test_pool_fwd_common.hpp" + +template +class TestAvgPool2dFwd : public ::testing::Test +{ + protected: + using InDataType = std::tuple_element_t<0, Tuple>; + using OutDataType = std::tuple_element_t<1, Tuple>; + using ComputeDataType = std::tuple_element_t<2, Tuple>; + using IndexDataType = std::tuple_element_t<3, Tuple>; + + std::vector params; + + void Run() + { + for(auto param : params) + { + // avg pool + bool success = + ck::profiler::profile_pool2d_fwd_impl(true, + 2, + false, + false, + param.length_, + param.window_spatial_lengths_, + param.window_strides_, + param.window_dilations_, + param.input_left_pads_, + param.input_right_pads_); + EXPECT_TRUE(success); + } + } +}; + +#ifdef CK_ENABLE_FP16 +using KernelTypes = + ::testing::Types, std::tuple>; +#else +using KernelTypes = ::testing::Types>; +#endif + +TYPED_TEST_SUITE(TestAvgPool2dFwd, KernelTypes); +TYPED_TEST(TestAvgPool2dFwd, Test_Pool) +{ + // length, window_length, window_stride, window_dilation, left_pad, right_pad + this->params = {{{1, 1, 1, 1}, {1, 1}, {1, 1}, {1, 1}, {0, 0}, {0, 0}}, + {{2, 16, 64, 64}, {64, 64}, {1, 1}, {1, 1}, {0, 0}, {0, 0}}, + {{2, 16, 64, 64}, {4, 4}, {4, 4}, {2, 2}, {0, 0}, {0, 0}}, + {{2, 32, 30, 30}, {2, 2}, {2, 2}, {1, 1}, {1, 1}, {1, 1}}}; + + this->Run(); +} diff --git a/test/pool/test_max_pool2d_fwd.cpp b/test/pool/test_max_pool2d_fwd.cpp new file mode 100644 index 0000000000..06be1e00ab --- /dev/null +++ b/test/pool/test_max_pool2d_fwd.cpp @@ -0,0 +1,86 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "gtest/gtest.h" +#include "profiler/profile_pool2d_fwd_impl.hpp" +#include "test_pool_fwd_common.hpp" + +template +class TestMaxPool2dFwd : public ::testing::Test +{ + protected: + using InDataType = std::tuple_element_t<0, Tuple>; + using OutDataType = std::tuple_element_t<1, Tuple>; + using ComputeDataType = std::tuple_element_t<2, Tuple>; + using IndexDataType = std::tuple_element_t<3, Tuple>; + + std::vector params; + + void Run() + { + for(auto param : params) + { + // max pool + bool success = + ck::profiler::profile_pool2d_fwd_impl(true, + 2, + false, + false, + param.length_, + param.window_spatial_lengths_, + param.window_strides_, + param.window_dilations_, + param.input_left_pads_, + param.input_right_pads_); + EXPECT_TRUE(success); + + // max pool + index + success = ck::profiler::profile_pool2d_fwd_impl(true, + 2, + false, + false, + param.length_, + param.window_spatial_lengths_, + param.window_strides_, + param.window_dilations_, + param.input_left_pads_, + param.input_right_pads_); + EXPECT_TRUE(success); + } + } +}; + +#ifdef CK_ENABLE_FP16 +using KernelTypes = + ::testing::Types, std::tuple>; +#else +using KernelTypes = ::testing::Types>; +#endif + +TYPED_TEST_SUITE(TestMaxPool2dFwd, KernelTypes); +TYPED_TEST(TestMaxPool2dFwd, Test_Pool) +{ + // length, window_length, window_stride, window_dilation, left_pad, right_pad + this->params = {{{1, 1, 1, 1}, {1, 1}, {1, 1}, {1, 1}, {0, 0}, {0, 0}}, + {{2, 16, 64, 64}, {64, 64}, {1, 1}, {1, 1}, {0, 0}, {0, 0}}, + {{2, 16, 64, 64}, {4, 4}, {4, 4}, {2, 2}, {0, 0}, {0, 0}}, + {{2, 32, 30, 30}, {2, 2}, {2, 2}, {1, 1}, {1, 1}, {1, 1}}}; + + this->Run(); +} diff --git a/test/pool/test_pool_fwd_common.hpp b/test/pool/test_pool_fwd_common.hpp index 5917a27e56..fbd14d968c 100644 --- a/test/pool/test_pool_fwd_common.hpp +++ b/test/pool/test_pool_fwd_common.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #include "gtest/gtest.h" #include "ck/ck.hpp"