mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-24 23:05:54 +00:00
Maxpool bwd (#750)
* Add maxpool f32 kernel and example
* Revise copyright
* Add device pool bwd device op
* Support f16 and bf16
* Add compute datatype for reference code.
Prevent error in bf16
* Fix type error
* Remove layout
* Fix bf16 error
* Add f16 and bf16 example
* Add more operations
* Implement IsSupportedArgument
* Add changelog
* Add comment
* Add comment
* Remove useless header
* Move initialize of workspace to the run
* Move set din zero to the device operator
* Save din_length_raw
* Remove useless header
* Calculate gridsize according to the number of CU
* Calculate gridSize according to the number of CU.
Remove useless header
* Add put example
* Remove useless header
* Fix CI fail
[ROCm/composable_kernel commit: 341ad95665]
This commit is contained in:
62
example/49_maxpool2d_bwd/maxpool2d_bwd_fp32.cpp
Normal file
62
example/49_maxpool2d_bwd/maxpool2d_bwd_fp32.cpp
Normal file
@@ -0,0 +1,62 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <iostream>
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/utility/reduction_enums.hpp"
|
||||
|
||||
#include "maxpool2d_bwd_common.hpp"
|
||||
|
||||
using InDataType = float;
|
||||
using OutDataType = float;
|
||||
using IndexDataType = int32_t;
|
||||
using ComputeDataType = float;
|
||||
using DInDataType = float;
|
||||
using DOutDataType = float;
|
||||
|
||||
static constexpr bool PropagateNan = false;
|
||||
|
||||
int main()
|
||||
{
|
||||
bool do_verification = true;
|
||||
bool time_kernel = false;
|
||||
|
||||
// Pool shape
|
||||
ck::index_t N = 1;
|
||||
ck::index_t C = 1;
|
||||
ck::index_t Y = 2;
|
||||
ck::index_t X = 2;
|
||||
ck::index_t Hi = 32;
|
||||
ck::index_t Wi = 32;
|
||||
ck::index_t window_stride_h = 2;
|
||||
ck::index_t window_stride_w = 2;
|
||||
ck::index_t in_left_pad_h = 0;
|
||||
ck::index_t in_left_pad_w = 0;
|
||||
ck::index_t in_right_pad_h = 0;
|
||||
ck::index_t in_right_pad_w = 0;
|
||||
|
||||
bool pass = maxpool_bwd_test<InDataType,
|
||||
OutDataType,
|
||||
IndexDataType,
|
||||
ComputeDataType,
|
||||
DInDataType,
|
||||
DOutDataType,
|
||||
PropagateNan>(do_verification,
|
||||
time_kernel,
|
||||
N,
|
||||
C,
|
||||
Y,
|
||||
X,
|
||||
Hi,
|
||||
Wi,
|
||||
window_stride_h,
|
||||
window_stride_w,
|
||||
in_left_pad_h,
|
||||
in_left_pad_w,
|
||||
in_right_pad_h,
|
||||
in_right_pad_w);
|
||||
|
||||
return (pass ? 0 : 1);
|
||||
}
|
||||
Reference in New Issue
Block a user