mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-04 05:31:24 +00:00
* Add maxpool f32 kernel and example * Revise copyright * Add device pool bwd device op * Support f16 and bf16 * Add compute datatype for reference code. Prevent error in bf16 * Fix type error * Remove layout * Fix bf16 error * Add f16 and bf16 example * Add more operations * Implement IsSupportedArgument * Add changelog * Add comment * Add comment * Remove useless header * Move initialize of workspace to the run * Move set din zero to the device operator * Save din_length_raw * Remove useless header * Calculate gridsize according to the number of CU * Calculate gridSize according to the number of CU. Remove useless header * Add put example * Remove useless header * Fix CI fail
63 lines
2.1 KiB
C++
63 lines
2.1 KiB
C++
// SPDX-License-Identifier: MIT
|
|
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
|
|
|
#include <iostream>
|
|
|
|
#include "ck/ck.hpp"
|
|
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
|
#include "ck/utility/reduction_enums.hpp"
|
|
|
|
#include "maxpool2d_bwd_common.hpp"
|
|
|
|
// Element types for the fp32 max-pool backward example: every tensor
// (forward input/output, input/output gradients) and the compute type are
// float; the argmax index tensor uses int32_t.
using InDataType      = float;
using OutDataType     = float;
using IndexDataType   = int32_t;
using ComputeDataType = float;
using DInDataType     = float;
using DOutDataType    = float;

// Forwarded as the PropagateNan template argument of maxpool_bwd_test.
// NOTE(review): presumably selects whether NaNs win the max reduction —
// confirm against maxpool2d_bwd_common.hpp.
static constexpr bool PropagateNan = false;

int main()
|
|
{
|
|
bool do_verification = true;
|
|
bool time_kernel = false;
|
|
|
|
// Pool shape
|
|
ck::index_t N = 1;
|
|
ck::index_t C = 1;
|
|
ck::index_t Y = 2;
|
|
ck::index_t X = 2;
|
|
ck::index_t Hi = 32;
|
|
ck::index_t Wi = 32;
|
|
ck::index_t window_stride_h = 2;
|
|
ck::index_t window_stride_w = 2;
|
|
ck::index_t in_left_pad_h = 0;
|
|
ck::index_t in_left_pad_w = 0;
|
|
ck::index_t in_right_pad_h = 0;
|
|
ck::index_t in_right_pad_w = 0;
|
|
|
|
bool pass = maxpool_bwd_test<InDataType,
|
|
OutDataType,
|
|
IndexDataType,
|
|
ComputeDataType,
|
|
DInDataType,
|
|
DOutDataType,
|
|
PropagateNan>(do_verification,
|
|
time_kernel,
|
|
N,
|
|
C,
|
|
Y,
|
|
X,
|
|
Hi,
|
|
Wi,
|
|
window_stride_h,
|
|
window_stride_w,
|
|
in_left_pad_h,
|
|
in_left_pad_w,
|
|
in_right_pad_h,
|
|
in_right_pad_w);
|
|
|
|
return (pass ? 0 : 1);
|
|
}
|