Maxpool bwd (#750)

* Add maxpool f32 kernel and example

* Revise copyright

* Add device pool bwd device op

* Support f16 and bf16

* Add compute datatype for reference code.
Prevent error in bf16

* Fix type error

* Remove layout

* Fix bf16 error

* Add f16 and bf16 example

* Add more operations

* Implement IsSupportedArgument

* Add changelog

* Add comment

* Add comment

* Remove useless header

* Move initialize of workspace to the run

* Move set din zero to the device operator

* Save din_length_raw

* Remove useless header

* Calculate gridsize according to the number of CU

* Calculate gridSize according to the number of CU.
Remove useless header

* Add put example

* Remove useless header

* Fix CI fail

[ROCm/composable_kernel commit: 341ad95665]
This commit is contained in:
rocking
2023-06-19 22:44:22 +08:00
committed by GitHub
parent d6f690d361
commit 9c2487d2a0
16 changed files with 1310 additions and 11 deletions

View File

@@ -0,0 +1,103 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include <sstream>
#include <vector>
#include "ck/tensor_operation/gpu/device/device_base.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
namespace ck {
namespace tensor_operation {
namespace host {
using namespace std;
template <typename DOutDataType,
typename IndexDataType,
typename ConputeDataType,
typename DInDataType,
typename ElementwiseOperation>
struct ReferenceMaxPoolBwd : public device::BaseOperator
{
// Argument
struct Argument : public device::BaseArgument
{
Argument(const Tensor<DOutDataType>& dout,
const Tensor<IndexDataType>& indices,
Tensor<DInDataType>& din,
ElementwiseOperation elementwise_op)
: dout_(dout), indices_(indices), din_(din), elementwise_op_(elementwise_op)
{
}
const Tensor<DOutDataType>& dout_;
const Tensor<IndexDataType>& indices_;
Tensor<DInDataType>& din_;
ElementwiseOperation elementwise_op_;
};
// Invoker
struct Invoker : public device::BaseInvoker
{
float Run(const Argument& arg)
{
int din_length = arg.din_.GetElementSpaceSize();
int dout_length = arg.dout_.GetElementSpaceSize();
std::vector<ConputeDataType> buf(din_length, 0);
for(int i = 0; i < dout_length; ++i)
{
int index = arg.indices_.mData[i];
if(index >= 0 && index < din_length)
buf[index] += ck::type_convert<ConputeDataType>(arg.dout_.mData[i]);
}
for(int i = 0; i < din_length; ++i)
arg.din_.mData[i] = ck::type_convert<DInDataType>(buf[i]);
return 0;
}
float Run(const device::BaseArgument* p_arg,
const StreamConfig& /* stream_config */ = StreamConfig{}) override
{
return Run(*dynamic_cast<const Argument*>(p_arg));
}
};
bool IsSupportedArgument(const device::BaseArgument*) override { return true; }
static auto MakeArgument(const Tensor<DOutDataType>& dout,
const Tensor<IndexDataType>& indices,
Tensor<DInDataType>& din,
ElementwiseOperation elementwise_op)
{
return Argument{dout, indices, din, elementwise_op};
}
static auto MakeInvoker() { return Invoker{}; }
virtual std::unique_ptr<device::BaseInvoker> MakeInvokerPointer()
{
return std::make_unique<Invoker>(Invoker{});
}
std::string GetTypeString() const override
{
auto str = std::stringstream();
// clang-format off
str << "ReferenceMaxPoolBwd"
<< std::endl;
// clang-format on
return str.str();
}
};
} // namespace host
} // namespace tensor_operation
} // namespace ck

View File

@@ -100,8 +100,8 @@ struct ReferencePoolingFwd : public device::BaseOperator
wi >= 0 &&
wi < static_cast<ck::index_t>(arg.in_.mDesc.GetLengths()[4]))
{
ComputeDataType currVal =
static_cast<ComputeDataType>(arg.in_(n, c, di, hi, wi));
ComputeDataType currVal = ck::type_convert<ComputeDataType>(
arg.in_(n, c, di, hi, wi));
in_elementwise_op(currVal, currVal);
@@ -112,7 +112,7 @@ struct ReferencePoolingFwd : public device::BaseOperator
}
acc_elementwise_op(accuVal, accuVal);
arg.out_(n, c, do_, ho, wo) = accuVal;
arg.out_(n, c, do_, ho, wo) = ck::type_convert<OutDataType>(accuVal);
};
make_ParallelTensorFunctor(f_ncdhw,
@@ -151,8 +151,8 @@ struct ReferencePoolingFwd : public device::BaseOperator
wi >= 0 &&
wi < static_cast<ck::index_t>(arg.in_.mDesc.GetLengths()[4]))
{
ComputeDataType currVal =
static_cast<ComputeDataType>(arg.in_(n, c, di, hi, wi));
ComputeDataType currVal = ck::type_convert<ComputeDataType>(
arg.in_(n, c, di, hi, wi));
IndexDataType currIndex =
arg.in_.GetOffsetFromMultiIndex(n, c, di, hi, wi);
@@ -166,7 +166,7 @@ struct ReferencePoolingFwd : public device::BaseOperator
acc_elementwise_op(accuVal, accuVal);
arg.out_(n, c, do_, ho, wo) = accuVal;
arg.out_(n, c, do_, ho, wo) = ck::type_convert<OutDataType>(accuVal);
arg.out_indices_(n, c, do_, ho, wo) = accuIndex;
};
@@ -212,7 +212,7 @@ struct ReferencePoolingFwd : public device::BaseOperator
wi < static_cast<ck::index_t>(arg.in_.mDesc.GetLengths()[3]))
{
ComputeDataType currVal =
static_cast<ComputeDataType>(arg.in_(n, c, hi, wi));
ck::type_convert<ComputeDataType>(arg.in_(n, c, hi, wi));
in_elementwise_op(currVal, currVal);
@@ -222,7 +222,7 @@ struct ReferencePoolingFwd : public device::BaseOperator
}
acc_elementwise_op(accuVal, accuVal);
arg.out_(n, c, ho, wo) = accuVal;
arg.out_(n, c, ho, wo) = ck::type_convert<OutDataType>(accuVal);
};
make_ParallelTensorFunctor(f_nchw,
@@ -255,7 +255,7 @@ struct ReferencePoolingFwd : public device::BaseOperator
wi < static_cast<ck::index_t>(arg.in_.mDesc.GetLengths()[3]))
{
ComputeDataType currVal =
static_cast<ComputeDataType>(arg.in_(n, c, hi, wi));
ck::type_convert<ComputeDataType>(arg.in_(n, c, hi, wi));
IndexDataType currIndex =
arg.in_.GetOffsetFromMultiIndex(n, c, hi, wi);
@@ -268,7 +268,7 @@ struct ReferencePoolingFwd : public device::BaseOperator
}
acc_elementwise_op(accuVal, accuVal);
arg.out_(n, c, ho, wo) = accuVal;
arg.out_(n, c, ho, wo) = ck::type_convert<OutDataType>(accuVal);
arg.out_indices_(n, c, ho, wo) = accuIndex;
};