mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-21 21:39:15 +00:00
Conv3d new (#94)
* conv3d compiles but has memory error
* conv3d works
* fix performance issue by using __builtin_amdgc_readfirstlane
* change MakeBlock2CTileMap to MakeDefaultBlock2CTileMap; change c_blockid_to* to cblockid_to*
* clang-format
* remove CK_EXPERIMENTAL_PASS_TENSOR_DECRIPTOR_BY_*; moved wrapper into DeviceConv3d
* format
* remove useless marc
* add comment
Co-authored-by: Chao Liu <chao.liu2@amd.com>
[ROCm/composable_kernel commit: 6dfb92bbef]
This commit is contained in:
@@ -48,3 +48,102 @@ void host_conv_nchw_kcyx_nkhw(const Tensor<TIn>& in,
|
||||
out.mDesc.GetLengths()[2],
|
||||
out.mDesc.GetLengths()[3])(std::thread::hardware_concurrency());
|
||||
}
|
||||
|
||||
template <typename TIn,
|
||||
typename TWei,
|
||||
typename TOut,
|
||||
typename ConvStrides,
|
||||
typename ConvDilations,
|
||||
typename InLeftPads,
|
||||
typename InRightPads>
|
||||
void host_conv3d_ndhwc_kzyxc_ndhwk(const Tensor<TIn>& in,
|
||||
const Tensor<TWei>& wei,
|
||||
Tensor<TOut>& out,
|
||||
const ConvStrides& conv_strides,
|
||||
const ConvDilations& conv_dilations,
|
||||
const InLeftPads& in_left_pads,
|
||||
const InRightPads&)
|
||||
{
|
||||
using namespace ck;
|
||||
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
const auto Di = in.mDesc.GetLengths()[1];
|
||||
const auto Hi = in.mDesc.GetLengths()[2];
|
||||
const auto Wi = in.mDesc.GetLengths()[3];
|
||||
const auto Z = wei.mDesc.GetLengths()[1];
|
||||
const auto Y = wei.mDesc.GetLengths()[2];
|
||||
const auto X = wei.mDesc.GetLengths()[3];
|
||||
const auto C = wei.mDesc.GetLengths()[4];
|
||||
|
||||
auto f_ndhwc = [&](auto n, auto do__, auto ho_, auto wo_, auto k) {
|
||||
// do__ must be converted to signed integer, otherwise zmin might be wrong in cases
|
||||
// negative values.
|
||||
const int do_ = static_cast<int>(do__);
|
||||
const int ho = static_cast<int>(ho_);
|
||||
const int wo = static_cast<int>(wo_);
|
||||
const int zmin =
|
||||
std::max(0,
|
||||
(in_left_pads[I0] - do_ * conv_strides[I0] + conv_dilations[I0] - 1) /
|
||||
conv_dilations[I0]);
|
||||
const int ymin =
|
||||
std::max(0,
|
||||
(in_left_pads[I1] - ho * conv_strides[I1] + conv_dilations[I1] - 1) /
|
||||
conv_dilations[I1]);
|
||||
const int xmin =
|
||||
std::max(0,
|
||||
(in_left_pads[I2] - wo * conv_strides[I2] + conv_dilations[I2] - 1) /
|
||||
conv_dilations[I2]);
|
||||
const int zmax =
|
||||
std::min(Z, (in_left_pads[I0] - do_ * conv_strides[I0] + Di) / conv_dilations[I0]);
|
||||
const int ymax =
|
||||
std::min(Y, (in_left_pads[I1] - ho * conv_strides[I1] + Hi) / conv_dilations[I1]);
|
||||
const int xmax =
|
||||
std::min(X, (in_left_pads[I2] - wo * conv_strides[I2] + Wi) / conv_dilations[I2]);
|
||||
const int di_min = do_ * conv_strides[I0] + zmin * conv_dilations[I0] - in_left_pads[I0];
|
||||
const int hi_min = ho * conv_strides[I1] + ymin * conv_dilations[I1] - in_left_pads[I1];
|
||||
const int wi_min = wo * conv_strides[I2] + xmin * conv_dilations[I2] - in_left_pads[I2];
|
||||
|
||||
double v = 0;
|
||||
|
||||
const TIn* in_n = in.mData.data() + n * Di * Hi * Wi * C;
|
||||
const TWei* wei_k = wei.mData.data() + k * Z * Y * X * C;
|
||||
|
||||
int di = di_min;
|
||||
for(int z = zmin; z < zmax; ++z, di += conv_dilations[I0])
|
||||
{
|
||||
const TIn* in_n_di = in_n + di * Hi * Wi * C;
|
||||
const TWei* wei_k_z = wei_k + z * Y * X * C;
|
||||
int hi = hi_min;
|
||||
|
||||
for(int y = ymin; y < ymax; ++y, hi += conv_dilations[I1])
|
||||
{
|
||||
const TIn* in_n_di_hi = in_n_di + hi * Wi * C;
|
||||
const TWei* wei_k_z_y = wei_k_z + y * X * C;
|
||||
int wi = wi_min;
|
||||
|
||||
for(int x = xmin; x < xmax; ++x, wi += conv_dilations[I2])
|
||||
{
|
||||
const TIn* in_n_di_hi_wi = in_n_di_hi + wi * C;
|
||||
const TWei* wei_k_z_y_x = wei_k_z_y + x * C;
|
||||
|
||||
for(int c = 0; c < C; ++c)
|
||||
{
|
||||
v += static_cast<const double>(in_n_di_hi_wi[c]) *
|
||||
static_cast<const double>(wei_k_z_y_x[c]);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
out(n, do_, ho, wo, k) = v;
|
||||
};
|
||||
|
||||
make_ParallelTensorFunctor(f_ndhwc,
|
||||
out.mDesc.GetLengths()[0],
|
||||
out.mDesc.GetLengths()[1],
|
||||
out.mDesc.GetLengths()[2],
|
||||
out.mDesc.GetLengths()[3],
|
||||
out.mDesc.GetLengths()[4])(std::thread::hardware_concurrency() - 4);
|
||||
}
|
||||
|
||||
@@ -144,7 +144,7 @@ struct GeneratorTensor_Checkboard
|
||||
template <typename... Ts>
|
||||
float operator()(Ts... Xs) const
|
||||
{
|
||||
std::array<ck::index_t, sizeof...(Ts)> dims = {{static_cast<ck::index_t>(Xs)...}};
|
||||
std::array<ck::index_t, sizeof...(Ts)> dims = {static_cast<ck::index_t>(Xs)...};
|
||||
return std::accumulate(dims.begin(),
|
||||
dims.end(),
|
||||
true,
|
||||
|
||||
Reference in New Issue
Block a user