mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-17 11:30:02 +00:00
* Expand the base class of pool2d, prepare to share base class with pool3d
* Add pool3d device op
* Add pool3d f16 example
* Refactor the base class; implement generic pooling in the future
* clang format
* get original index in max pooling
* Add outputindex to base class
* Fix dimension
* Add pooling instance
* Use indexType instead
* Remove useless header
* Extract IndexDataType to template
* Extract pooling reference code
* clang format
* clang format
* Fix typo
* Add tensor stride
* Add missing header
* Add index stride and output stride
* Refine naming
* Add type to base class
* Rename file
* Use proper size
* Fix typo
* Refine naming
* Modify the argument into vector.
* Add max pool profiler
* Refine naming
* Support f32 pool
* Fix typo
* Add avg pool2d fwd in profiler
* clang format
* Rename AccDataType to ComputeDataType
* Fix init
* test pool
* Extract variable
* Add client example
* Check the pooling dim
* clang format
* Connect argv and arg_parser
* Add found check
* Remove useless header
* Refine naming
* Adjust the order of device_pool_fwd
[ROCm/composable_kernel commit: 76ec0089fb]
117 lines
3.7 KiB
C++
117 lines
3.7 KiB
C++
// SPDX-License-Identifier: MIT
|
|
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
|
|
|
|
#include <cstdio>
#include <cstdlib>
#include <iostream>
#include <stdexcept>
#include <string>

#include "ck/ck.hpp"
#include "ck/utility/reduction_enums.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"

#include "pool2d_fwd_common.hpp"
|
|
|
|
using InDataType = float;
|
|
using OutDataType = float;
|
|
using ComputeDataType = float;
|
|
|
|
using IndexDataType = int32_t;
|
|
|
|
using InLayout = ck::tensor_layout::convolution::NHWC;
|
|
using OutLayout = ck::tensor_layout::convolution::NHWC;
|
|
|
|
#if 1
|
|
static constexpr auto ReduceOpId = ck::ReduceTensorOp::MAX;
|
|
#else
|
|
static constexpr auto ReduceOpId = ck::ReduceTensorOp::AVG;
|
|
#endif
|
|
|
|
static constexpr bool OutputIndex = false;
|
|
static constexpr bool PropagateNan = false;
|
|
|
|
int main(int argc, char* argv[])
|
|
{
|
|
bool do_verification;
|
|
int init_method;
|
|
bool time_kernel;
|
|
|
|
// Pool shape
|
|
ck::index_t N = 128;
|
|
ck::index_t C = 192;
|
|
ck::index_t Y = 3;
|
|
ck::index_t X = 3;
|
|
ck::index_t Hi = 71;
|
|
ck::index_t Wi = 71;
|
|
ck::index_t window_stride_h = 2;
|
|
ck::index_t window_stride_w = 2;
|
|
ck::index_t in_left_pad_h = 1;
|
|
ck::index_t in_left_pad_w = 1;
|
|
ck::index_t in_right_pad_h = 1;
|
|
ck::index_t in_right_pad_w = 1;
|
|
|
|
if(argc == 1)
|
|
{
|
|
do_verification = true;
|
|
init_method = 1;
|
|
time_kernel = true;
|
|
}
|
|
else if(argc == 4)
|
|
{
|
|
do_verification = std::stoi(argv[1]);
|
|
init_method = std::stoi(argv[2]);
|
|
time_kernel = static_cast<bool>(std::stoi(argv[3]));
|
|
}
|
|
else if(argc == 16)
|
|
{
|
|
do_verification = std::stoi(argv[1]);
|
|
init_method = std::stoi(argv[2]);
|
|
time_kernel = static_cast<bool>(std::stoi(argv[3]));
|
|
|
|
N = std::stoi(argv[4]);
|
|
C = std::stoi(argv[5]);
|
|
Y = std::stoi(argv[6]);
|
|
X = std::stoi(argv[7]);
|
|
Hi = std::stoi(argv[8]);
|
|
Wi = std::stoi(argv[9]);
|
|
window_stride_h = std::stoi(argv[10]);
|
|
window_stride_w = std::stoi(argv[11]);
|
|
in_left_pad_h = std::stoi(argv[12]);
|
|
in_left_pad_w = std::stoi(argv[13]);
|
|
in_right_pad_h = std::stoi(argv[14]);
|
|
in_right_pad_w = std::stoi(argv[15]);
|
|
}
|
|
else
|
|
{
|
|
printf("arg1: verification (0=no, 1=yes)\n");
|
|
printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
|
|
printf("arg3: time kernel (0=no, 1=yes)\n");
|
|
printf("arg4 to 15: N, C, Y, X, Hi, Wi, Sy, Sx, LeftPy, LeftPx, RightPy, "
|
|
"RightPx\n");
|
|
exit(0);
|
|
}
|
|
|
|
bool pass = pool_test<InDataType,
|
|
OutDataType,
|
|
ComputeDataType,
|
|
IndexDataType,
|
|
InLayout,
|
|
OutLayout,
|
|
ReduceOpId,
|
|
PropagateNan,
|
|
OutputIndex>(do_verification,
|
|
init_method,
|
|
time_kernel,
|
|
N,
|
|
C,
|
|
Y,
|
|
X,
|
|
Hi,
|
|
Wi,
|
|
window_stride_h,
|
|
window_stride_w,
|
|
in_left_pad_h,
|
|
in_left_pad_w,
|
|
in_right_pad_h,
|
|
in_right_pad_w);
|
|
|
|
return (pass ? 0 : 1);
|
|
}
|