mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-18 12:00:07 +00:00
Maxpool bwd (#750)
* Add maxpool f32 kernel and example
* Revise copyright
* Add device pool bwd device op
* Support f16 and bf16
* Add compute datatype for reference code.
Prevent error in bf16
* Fix type error
* Remove layout
* Fix bf16 error
* Add f16 and bf16 example
* Add more operations
* Implement IsSupportedArgument
* Add changelog
* Add comment
* Add comment
* Remove useless header
* Move initialize of workspace to the run
* Move set din zero to the device operator
* Save din_length_raw
* Remove useless header
* Calculate gridsize according to the number of CU
* Calculate gridSize according to the number of CU.
Remove useless header
* Add put example
* Remove useless header
* Fix CI fail
[ROCm/composable_kernel commit: 341ad95665]
This commit is contained in:
@@ -0,0 +1,103 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <iostream>
|
||||
#include <sstream>
|
||||
#include <vector>
|
||||
|
||||
#include "ck/tensor_operation/gpu/device/device_base.hpp"
|
||||
#include "ck/library/utility/host_tensor.hpp"
|
||||
#include "ck/library/utility/host_tensor_generator.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace host {
|
||||
using namespace std;
|
||||
|
||||
template <typename DOutDataType,
|
||||
typename IndexDataType,
|
||||
typename ConputeDataType,
|
||||
typename DInDataType,
|
||||
typename ElementwiseOperation>
|
||||
struct ReferenceMaxPoolBwd : public device::BaseOperator
|
||||
{
|
||||
// Argument
|
||||
struct Argument : public device::BaseArgument
|
||||
{
|
||||
Argument(const Tensor<DOutDataType>& dout,
|
||||
const Tensor<IndexDataType>& indices,
|
||||
Tensor<DInDataType>& din,
|
||||
ElementwiseOperation elementwise_op)
|
||||
: dout_(dout), indices_(indices), din_(din), elementwise_op_(elementwise_op)
|
||||
{
|
||||
}
|
||||
|
||||
const Tensor<DOutDataType>& dout_;
|
||||
const Tensor<IndexDataType>& indices_;
|
||||
Tensor<DInDataType>& din_;
|
||||
ElementwiseOperation elementwise_op_;
|
||||
};
|
||||
|
||||
// Invoker
|
||||
struct Invoker : public device::BaseInvoker
|
||||
{
|
||||
float Run(const Argument& arg)
|
||||
{
|
||||
int din_length = arg.din_.GetElementSpaceSize();
|
||||
int dout_length = arg.dout_.GetElementSpaceSize();
|
||||
std::vector<ConputeDataType> buf(din_length, 0);
|
||||
|
||||
for(int i = 0; i < dout_length; ++i)
|
||||
{
|
||||
int index = arg.indices_.mData[i];
|
||||
if(index >= 0 && index < din_length)
|
||||
buf[index] += ck::type_convert<ConputeDataType>(arg.dout_.mData[i]);
|
||||
}
|
||||
|
||||
for(int i = 0; i < din_length; ++i)
|
||||
arg.din_.mData[i] = ck::type_convert<DInDataType>(buf[i]);
|
||||
return 0;
|
||||
}
|
||||
|
||||
float Run(const device::BaseArgument* p_arg,
|
||||
const StreamConfig& /* stream_config */ = StreamConfig{}) override
|
||||
{
|
||||
return Run(*dynamic_cast<const Argument*>(p_arg));
|
||||
}
|
||||
};
|
||||
|
||||
bool IsSupportedArgument(const device::BaseArgument*) override { return true; }
|
||||
|
||||
static auto MakeArgument(const Tensor<DOutDataType>& dout,
|
||||
const Tensor<IndexDataType>& indices,
|
||||
Tensor<DInDataType>& din,
|
||||
ElementwiseOperation elementwise_op)
|
||||
{
|
||||
return Argument{dout, indices, din, elementwise_op};
|
||||
}
|
||||
|
||||
static auto MakeInvoker() { return Invoker{}; }
|
||||
|
||||
virtual std::unique_ptr<device::BaseInvoker> MakeInvokerPointer()
|
||||
{
|
||||
return std::make_unique<Invoker>(Invoker{});
|
||||
}
|
||||
|
||||
std::string GetTypeString() const override
|
||||
{
|
||||
auto str = std::stringstream();
|
||||
|
||||
// clang-format off
|
||||
str << "ReferenceMaxPoolBwd"
|
||||
<< std::endl;
|
||||
// clang-format on
|
||||
|
||||
return str.str();
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace host
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -100,8 +100,8 @@ struct ReferencePoolingFwd : public device::BaseOperator
|
||||
wi >= 0 &&
|
||||
wi < static_cast<ck::index_t>(arg.in_.mDesc.GetLengths()[4]))
|
||||
{
|
||||
ComputeDataType currVal =
|
||||
static_cast<ComputeDataType>(arg.in_(n, c, di, hi, wi));
|
||||
ComputeDataType currVal = ck::type_convert<ComputeDataType>(
|
||||
arg.in_(n, c, di, hi, wi));
|
||||
|
||||
in_elementwise_op(currVal, currVal);
|
||||
|
||||
@@ -112,7 +112,7 @@ struct ReferencePoolingFwd : public device::BaseOperator
|
||||
}
|
||||
acc_elementwise_op(accuVal, accuVal);
|
||||
|
||||
arg.out_(n, c, do_, ho, wo) = accuVal;
|
||||
arg.out_(n, c, do_, ho, wo) = ck::type_convert<OutDataType>(accuVal);
|
||||
};
|
||||
|
||||
make_ParallelTensorFunctor(f_ncdhw,
|
||||
@@ -151,8 +151,8 @@ struct ReferencePoolingFwd : public device::BaseOperator
|
||||
wi >= 0 &&
|
||||
wi < static_cast<ck::index_t>(arg.in_.mDesc.GetLengths()[4]))
|
||||
{
|
||||
ComputeDataType currVal =
|
||||
static_cast<ComputeDataType>(arg.in_(n, c, di, hi, wi));
|
||||
ComputeDataType currVal = ck::type_convert<ComputeDataType>(
|
||||
arg.in_(n, c, di, hi, wi));
|
||||
IndexDataType currIndex =
|
||||
arg.in_.GetOffsetFromMultiIndex(n, c, di, hi, wi);
|
||||
|
||||
@@ -166,7 +166,7 @@ struct ReferencePoolingFwd : public device::BaseOperator
|
||||
|
||||
acc_elementwise_op(accuVal, accuVal);
|
||||
|
||||
arg.out_(n, c, do_, ho, wo) = accuVal;
|
||||
arg.out_(n, c, do_, ho, wo) = ck::type_convert<OutDataType>(accuVal);
|
||||
arg.out_indices_(n, c, do_, ho, wo) = accuIndex;
|
||||
};
|
||||
|
||||
@@ -212,7 +212,7 @@ struct ReferencePoolingFwd : public device::BaseOperator
|
||||
wi < static_cast<ck::index_t>(arg.in_.mDesc.GetLengths()[3]))
|
||||
{
|
||||
ComputeDataType currVal =
|
||||
static_cast<ComputeDataType>(arg.in_(n, c, hi, wi));
|
||||
ck::type_convert<ComputeDataType>(arg.in_(n, c, hi, wi));
|
||||
|
||||
in_elementwise_op(currVal, currVal);
|
||||
|
||||
@@ -222,7 +222,7 @@ struct ReferencePoolingFwd : public device::BaseOperator
|
||||
}
|
||||
|
||||
acc_elementwise_op(accuVal, accuVal);
|
||||
arg.out_(n, c, ho, wo) = accuVal;
|
||||
arg.out_(n, c, ho, wo) = ck::type_convert<OutDataType>(accuVal);
|
||||
};
|
||||
|
||||
make_ParallelTensorFunctor(f_nchw,
|
||||
@@ -255,7 +255,7 @@ struct ReferencePoolingFwd : public device::BaseOperator
|
||||
wi < static_cast<ck::index_t>(arg.in_.mDesc.GetLengths()[3]))
|
||||
{
|
||||
ComputeDataType currVal =
|
||||
static_cast<ComputeDataType>(arg.in_(n, c, hi, wi));
|
||||
ck::type_convert<ComputeDataType>(arg.in_(n, c, hi, wi));
|
||||
|
||||
IndexDataType currIndex =
|
||||
arg.in_.GetOffsetFromMultiIndex(n, c, hi, wi);
|
||||
@@ -268,7 +268,7 @@ struct ReferencePoolingFwd : public device::BaseOperator
|
||||
}
|
||||
|
||||
acc_elementwise_op(accuVal, accuVal);
|
||||
arg.out_(n, c, ho, wo) = accuVal;
|
||||
arg.out_(n, c, ho, wo) = ck::type_convert<OutDataType>(accuVal);
|
||||
arg.out_indices_(n, c, ho, wo) = accuIndex;
|
||||
};
|
||||
|
||||
|
||||
Reference in New Issue
Block a user