mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-11 17:00:18 +00:00
tidy
This commit is contained in:
@@ -14,7 +14,7 @@ void host_direct_convolution_backward_data(Tensor<TIn>& in,
|
||||
const ConvStrides& conv_strides,
|
||||
const ConvDilations& conv_dilations,
|
||||
const InLeftPads& in_left_pads,
|
||||
const InRightPads& in_right_pads,
|
||||
const InRightPads& /* in_right_pads */,
|
||||
const ConvTensorLayout layout = ConvTensorLayout::NCHW)
|
||||
{
|
||||
using namespace ck;
|
||||
@@ -25,11 +25,6 @@ void host_direct_convolution_backward_data(Tensor<TIn>& in,
|
||||
constexpr auto I3 = Number<3>{};
|
||||
|
||||
auto f_nchw = [&](auto n, auto c, auto hi, auto wi) {
|
||||
std::size_t N = in.mDesc.GetLengths()[I0];
|
||||
std::size_t C = in.mDesc.GetLengths()[I1];
|
||||
std::size_t Hi = in.mDesc.GetLengths()[I2];
|
||||
std::size_t Wi = in.mDesc.GetLengths()[I3];
|
||||
|
||||
std::size_t K = wei.mDesc.GetLengths()[I0];
|
||||
std::size_t Y = wei.mDesc.GetLengths()[I2];
|
||||
std::size_t X = wei.mDesc.GetLengths()[I3];
|
||||
@@ -74,11 +69,6 @@ void host_direct_convolution_backward_data(Tensor<TIn>& in,
|
||||
};
|
||||
|
||||
auto f_nhwc = [&](auto n, auto hi, auto wi, auto c) {
|
||||
std::size_t N = in.mDesc.GetLengths()[I0];
|
||||
std::size_t Hi = in.mDesc.GetLengths()[I1];
|
||||
std::size_t Wi = in.mDesc.GetLengths()[I2];
|
||||
std::size_t C = in.mDesc.GetLengths()[I3];
|
||||
|
||||
std::size_t K = wei.mDesc.GetLengths()[I0];
|
||||
std::size_t Y = wei.mDesc.GetLengths()[I1];
|
||||
std::size_t X = wei.mDesc.GetLengths()[I2];
|
||||
@@ -122,22 +112,24 @@ void host_direct_convolution_backward_data(Tensor<TIn>& in,
|
||||
in(n, hi, wi, c) = v;
|
||||
};
|
||||
|
||||
switch(layout)
|
||||
if(layout == ConvTensorLayout::NCHW)
|
||||
{
|
||||
case ConvTensorLayout::NCHW:
|
||||
make_ParallelTensorFunctor(f_nchw,
|
||||
in.mDesc.GetLengths()[0],
|
||||
in.mDesc.GetLengths()[1],
|
||||
in.mDesc.GetLengths()[2],
|
||||
in.mDesc.GetLengths()[3])(std::thread::hardware_concurrency());
|
||||
break;
|
||||
case ConvTensorLayout::NHWC:
|
||||
}
|
||||
else if(layout == ConvTensorLayout::NHWC)
|
||||
{
|
||||
make_ParallelTensorFunctor(f_nhwc,
|
||||
in.mDesc.GetLengths()[0],
|
||||
in.mDesc.GetLengths()[1],
|
||||
in.mDesc.GetLengths()[2],
|
||||
in.mDesc.GetLengths()[3])(std::thread::hardware_concurrency());
|
||||
break;
|
||||
default: throw std::runtime_error("wrong! not supported layout");
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("wrong! not supported layout");
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user