[CK_TILE] Image to Column kernel (#1532)
* [CK_TILE] Image to Column kernel
* Fixes
* Vector loads and stores
* Fixes
* Fixes
* change test dir name
@@ -0,0 +1,266 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck_tile/ops/common/tensor_layout.hpp"
#include "ck_tile/host/convolution_parameter.hpp"
#include "ck_tile/host/host_tensor.hpp"

namespace ck_tile {
namespace conv {
namespace detail {

template <typename OldLayout>
CK_TILE_HOST std::vector<std::size_t> get_layout_transpose_gnchw_to_old()
{
    if constexpr(std::is_same_v<OldLayout, ck_tile::tensor_layout::convolution::GNCW> ||
                 std::is_same_v<OldLayout, ck_tile::tensor_layout::convolution::GKCX> ||
                 std::is_same_v<OldLayout, ck_tile::tensor_layout::convolution::GNKW>)
    {
        return {0, 1, 2, 3};
    }
    else if constexpr(std::is_same_v<OldLayout, ck_tile::tensor_layout::convolution::GNCHW> ||
                      std::is_same_v<OldLayout, ck_tile::tensor_layout::convolution::GKCYX> ||
                      std::is_same_v<OldLayout, ck_tile::tensor_layout::convolution::GNKHW>)
    {
        return {0, 1, 2, 3, 4};
    }
    else if constexpr(std::is_same_v<OldLayout, ck_tile::tensor_layout::convolution::GNCDHW> ||
                      std::is_same_v<OldLayout, ck_tile::tensor_layout::convolution::GKCZYX> ||
                      std::is_same_v<OldLayout, ck_tile::tensor_layout::convolution::GNKDHW>)
    {
        return {0, 1, 2, 3, 4, 5};
    }
    if constexpr(std::is_same_v<OldLayout, ck_tile::tensor_layout::convolution::GNWC> ||
                 std::is_same_v<OldLayout, ck_tile::tensor_layout::convolution::GKXC> ||
                 std::is_same_v<OldLayout, ck_tile::tensor_layout::convolution::GNWK>)
    {
        return {0, 1, 3, 2};
    }
    else if constexpr(std::is_same_v<OldLayout, ck_tile::tensor_layout::convolution::GNHWC> ||
                      std::is_same_v<OldLayout, ck_tile::tensor_layout::convolution::GKYXC> ||
                      std::is_same_v<OldLayout, ck_tile::tensor_layout::convolution::GNHWK>)
    {
        return {0, 1, 4, 2, 3};
    }
    else if constexpr(std::is_same_v<OldLayout, ck_tile::tensor_layout::convolution::GNDHWC> ||
                      std::is_same_v<OldLayout, ck_tile::tensor_layout::convolution::GKZYXC> ||
                      std::is_same_v<OldLayout, ck_tile::tensor_layout::convolution::GNDHWK>)
    {
        return {0, 1, 5, 2, 3, 4};
    }
    else if constexpr(std::is_same_v<OldLayout, ck_tile::tensor_layout::convolution::NWGC> ||
                      std::is_same_v<OldLayout, ck_tile::tensor_layout::convolution::KXGC> ||
                      std::is_same_v<OldLayout, ck_tile::tensor_layout::convolution::NWGK>)
    {
        return {2, 0, 3, 1};
    }
    else if constexpr(std::is_same_v<OldLayout, ck_tile::tensor_layout::convolution::NHWGC> ||
                      std::is_same_v<OldLayout, ck_tile::tensor_layout::convolution::KYXGC> ||
                      std::is_same_v<OldLayout, ck_tile::tensor_layout::convolution::NHWGK>)
    {
        return {3, 0, 4, 1, 2};
    }
    else if constexpr(std::is_same_v<OldLayout, ck_tile::tensor_layout::convolution::NDHWGC> ||
                      std::is_same_v<OldLayout, ck_tile::tensor_layout::convolution::KZYXGC> ||
                      std::is_same_v<OldLayout, ck_tile::tensor_layout::convolution::NDHWGK>)
    {
        return {4, 0, 5, 1, 2, 3};
    }
    else
    {
        printf("%s\n", __func__);
        throw std::runtime_error("wrong! unsupported layout");
    }
}

} // namespace detail

// make tensor descriptor for packed input tensor, and order the dimension in the order of GNCHW
// regardless of physical layout
template <typename InLayout>
CK_TILE_HOST HostTensorDescriptor
make_input_host_tensor_descriptor_g_n_c_wis_packed(const ck_tile::conv::ConvParam& param)
{
    std::vector<std::size_t> physical_lengths;

    if constexpr(std::is_same_v<InLayout, ck_tile::tensor_layout::convolution::GNCW> ||
                 std::is_same_v<InLayout, ck_tile::tensor_layout::convolution::GNCHW> ||
                 std::is_same_v<InLayout, ck_tile::tensor_layout::convolution::GNCDHW>)
    {
        physical_lengths = std::vector<std::size_t>{static_cast<std::size_t>(param.G_),
                                                    static_cast<std::size_t>(param.N_),
                                                    static_cast<std::size_t>(param.C_)};

        physical_lengths.insert(physical_lengths.end(),
                                param.input_spatial_lengths_.begin(),
                                param.input_spatial_lengths_.begin() + param.num_dim_spatial_);
    }
    else if constexpr(std::is_same_v<InLayout, ck_tile::tensor_layout::convolution::GNWC> ||
                      std::is_same_v<InLayout, ck_tile::tensor_layout::convolution::GNHWC> ||
                      std::is_same_v<InLayout, ck_tile::tensor_layout::convolution::GNDHWC>)
    {
        physical_lengths = std::vector<std::size_t>{static_cast<std::size_t>(param.G_),
                                                    static_cast<std::size_t>(param.N_),
                                                    static_cast<std::size_t>(param.C_)};

        physical_lengths.insert(physical_lengths.begin() + 2,
                                param.input_spatial_lengths_.begin(),
                                param.input_spatial_lengths_.begin() + param.num_dim_spatial_);
    }
    else if constexpr(std::is_same_v<InLayout, ck_tile::tensor_layout::convolution::NWGC> ||
                      std::is_same_v<InLayout, ck_tile::tensor_layout::convolution::NHWGC> ||
                      std::is_same_v<InLayout, ck_tile::tensor_layout::convolution::NDHWGC>)
    {
        physical_lengths = std::vector<std::size_t>{static_cast<std::size_t>(param.N_),
                                                    static_cast<std::size_t>(param.G_),
                                                    static_cast<std::size_t>(param.C_)};

        physical_lengths.insert(physical_lengths.begin() + 1,
                                param.input_spatial_lengths_.begin(),
                                param.input_spatial_lengths_.begin() + param.num_dim_spatial_);
    }
    else
    {
        printf("%s\n", __func__);
        printf("%s\n", InLayout::name);
        throw std::runtime_error("wrong! unsupported layout");
    }

    return transpose_host_tensor_descriptor_given_new2old(
        HostTensorDescriptor(physical_lengths),
        detail::get_layout_transpose_gnchw_to_old<InLayout>());
}

// make tensor descriptor for packed weight tensor, and order the dimension in the order of GKCYX
// regardless of physical layout
template <typename WeiLayout>
CK_TILE_HOST HostTensorDescriptor
make_weight_host_tensor_descriptor_g_k_c_xs_packed(const ck_tile::conv::ConvParam& param)
{
    std::vector<std::size_t> physical_lengths;

    if constexpr(std::is_same_v<WeiLayout, ck_tile::tensor_layout::convolution::KXC> ||
                 std::is_same_v<WeiLayout, ck_tile::tensor_layout::convolution::KYXC> ||
                 std::is_same_v<WeiLayout, ck_tile::tensor_layout::convolution::KZYXC>)
    {
        if(param.G_ != 1)
        {
            throw std::runtime_error("wrong! G != 1");
        }

        physical_lengths = std::vector<std::size_t>{static_cast<std::size_t>(param.K_),
                                                    static_cast<std::size_t>(param.C_)};

        physical_lengths.insert(physical_lengths.end(),
                                param.filter_spatial_lengths_.begin(),
                                param.filter_spatial_lengths_.begin() + param.num_dim_spatial_);
    }
    else if constexpr(std::is_same_v<WeiLayout, ck_tile::tensor_layout::convolution::GKCX> ||
                      std::is_same_v<WeiLayout, ck_tile::tensor_layout::convolution::GKCYX> ||
                      std::is_same_v<WeiLayout, ck_tile::tensor_layout::convolution::GKCZYX>)
    {
        physical_lengths = std::vector<std::size_t>{static_cast<std::size_t>(param.G_),
                                                    static_cast<std::size_t>(param.K_),
                                                    static_cast<std::size_t>(param.C_)};

        physical_lengths.insert(physical_lengths.end(),
                                param.filter_spatial_lengths_.begin(),
                                param.filter_spatial_lengths_.begin() + param.num_dim_spatial_);
    }
    else if constexpr(std::is_same_v<WeiLayout, ck_tile::tensor_layout::convolution::GKXC> ||
                      std::is_same_v<WeiLayout, ck_tile::tensor_layout::convolution::GKYXC> ||
                      std::is_same_v<WeiLayout, ck_tile::tensor_layout::convolution::GKZYXC>)
    {
        physical_lengths = std::vector<std::size_t>{static_cast<std::size_t>(param.G_),
                                                    static_cast<std::size_t>(param.K_),
                                                    static_cast<std::size_t>(param.C_)};

        physical_lengths.insert(physical_lengths.begin() + 2,
                                param.filter_spatial_lengths_.begin(),
                                param.filter_spatial_lengths_.begin() + param.num_dim_spatial_);
    }
    else if constexpr(std::is_same_v<WeiLayout, ck_tile::tensor_layout::convolution::KXGC> ||
                      std::is_same_v<WeiLayout, ck_tile::tensor_layout::convolution::KYXGC> ||
                      std::is_same_v<WeiLayout, ck_tile::tensor_layout::convolution::KZYXGC>)
    {
        physical_lengths = std::vector<std::size_t>{static_cast<std::size_t>(param.K_),
                                                    static_cast<std::size_t>(param.G_),
                                                    static_cast<std::size_t>(param.C_)};

        physical_lengths.insert(physical_lengths.begin() + 1,
                                param.filter_spatial_lengths_.begin(),
                                param.filter_spatial_lengths_.begin() + param.num_dim_spatial_);
    }
    else
    {
        printf("%s\n", __func__);
        printf("%s\n", WeiLayout::name);
        throw std::runtime_error("wrong! unsupported layout");
    }

    return transpose_host_tensor_descriptor_given_new2old(
        HostTensorDescriptor(physical_lengths),
        detail::get_layout_transpose_gnchw_to_old<WeiLayout>());
}

// make tensor descriptor for packed output tensor, and order the dimension in the order of GNKHW
// regardless of physical layout
template <typename OutLayout>
CK_TILE_HOST HostTensorDescriptor
make_output_host_tensor_descriptor_g_n_k_wos_packed(const ck_tile::conv::ConvParam& param)
{
    std::vector<std::size_t> physical_lengths;

    if constexpr(std::is_same_v<OutLayout, ck_tile::tensor_layout::convolution::GNKW> ||
                 std::is_same_v<OutLayout, ck_tile::tensor_layout::convolution::GNKHW> ||
                 std::is_same_v<OutLayout, ck_tile::tensor_layout::convolution::GNKDHW>)
    {
        physical_lengths = std::vector<std::size_t>{static_cast<std::size_t>(param.G_),
                                                    static_cast<std::size_t>(param.N_),
                                                    static_cast<std::size_t>(param.K_)};

        physical_lengths.insert(physical_lengths.end(),
                                param.output_spatial_lengths_.begin(),
                                param.output_spatial_lengths_.begin() + param.num_dim_spatial_);
    }
    // separate from legacy code above
    else if constexpr(std::is_same_v<OutLayout, ck_tile::tensor_layout::convolution::GNWK> ||
                      std::is_same_v<OutLayout, ck_tile::tensor_layout::convolution::GNHWK> ||
                      std::is_same_v<OutLayout, ck_tile::tensor_layout::convolution::GNDHWK>)
    {
        physical_lengths = std::vector<std::size_t>{static_cast<std::size_t>(param.G_),
                                                    static_cast<std::size_t>(param.N_),
                                                    static_cast<std::size_t>(param.K_)};

        physical_lengths.insert(physical_lengths.begin() + 2,
                                param.output_spatial_lengths_.begin(),
                                param.output_spatial_lengths_.begin() + param.num_dim_spatial_);
    }
    else if constexpr(std::is_same_v<OutLayout, ck_tile::tensor_layout::convolution::NWGK> ||
                      std::is_same_v<OutLayout, ck_tile::tensor_layout::convolution::NHWGK> ||
                      std::is_same_v<OutLayout, ck_tile::tensor_layout::convolution::NDHWGK>)
    {
        physical_lengths = std::vector<std::size_t>{static_cast<std::size_t>(param.N_),
                                                    static_cast<std::size_t>(param.G_),
                                                    static_cast<std::size_t>(param.K_)};

        physical_lengths.insert(physical_lengths.begin() + 1,
                                param.output_spatial_lengths_.begin(),
                                param.output_spatial_lengths_.begin() + param.num_dim_spatial_);
    }
    else
    {
        printf("%s\n", __func__);
        printf("%s\n", OutLayout::name);
        throw std::runtime_error("wrong! unsupported layout");
    }

    return transpose_host_tensor_descriptor_given_new2old(
        HostTensorDescriptor(physical_lengths),
        detail::get_layout_transpose_gnchw_to_old<OutLayout>());
}

} // namespace conv
} // namespace ck_tile
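
For orientation, a minimal host-side sketch (not part of this commit) of how these descriptor helpers are meant to be combined. It assumes the default ConvParam defined in convolution_parameter.hpp below, the GNHWC / GKYXC / GNHWK layouts, and that HostTensor can be constructed directly from a HostTensorDescriptor, as it is elsewhere in ck_tile; the function name is illustrative only.

void make_conv_host_tensors()
{
    // Default ConvParam: 2D, G=1, N=128, K=256, C=192, 3x3 filter, 71x71 input.
    ck_tile::conv::ConvParam param{};

    // Logical dimension order is always G, N (or K), C, spatial..., regardless of physical layout.
    const auto in_desc = ck_tile::conv::make_input_host_tensor_descriptor_g_n_c_wis_packed<
        ck_tile::tensor_layout::convolution::GNHWC>(param);
    const auto wei_desc = ck_tile::conv::make_weight_host_tensor_descriptor_g_k_c_xs_packed<
        ck_tile::tensor_layout::convolution::GKYXC>(param);
    const auto out_desc = ck_tile::conv::make_output_host_tensor_descriptor_g_n_k_wos_packed<
        ck_tile::tensor_layout::convolution::GNHWK>(param);

    ck_tile::HostTensor<float> in(in_desc);   // lengths ordered as G, N, C, Hi, Wi
    ck_tile::HostTensor<float> wei(wei_desc); // lengths ordered as G, K, C, Y, X
    ck_tile::HostTensor<float> out(out_desc); // lengths ordered as G, N, K, Ho, Wo
}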
include/ck_tile/host/convolution_parameter.hpp (new file, 283 lines)
@@ -0,0 +1,283 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include <cstdlib>
#include <numeric>
#include <iterator>
#include <vector>

namespace ck_tile {
namespace conv {

struct ConvParam
{
    ConvParam();
    ConvParam(ck_tile::index_t n_dim,
              ck_tile::index_t group_count,
              ck_tile::index_t n_batch,
              ck_tile::index_t n_out_channels,
              ck_tile::index_t n_in_channels,
              const std::vector<ck_tile::index_t>& filters_len,
              const std::vector<ck_tile::index_t>& input_len,
              const std::vector<ck_tile::index_t>& strides,
              const std::vector<ck_tile::index_t>& dilations,
              const std::vector<ck_tile::index_t>& left_pads,
              const std::vector<ck_tile::index_t>& right_pads)
        : num_dim_spatial_(static_cast<ck_tile::long_index_t>(n_dim)),
          G_(static_cast<ck_tile::long_index_t>(group_count)),
          N_(static_cast<ck_tile::long_index_t>(n_batch)),
          K_(static_cast<ck_tile::long_index_t>(n_out_channels)),
          C_(static_cast<ck_tile::long_index_t>(n_in_channels)),
          filter_spatial_lengths_(num_dim_spatial_),
          input_spatial_lengths_(num_dim_spatial_),
          output_spatial_lengths_(num_dim_spatial_),
          conv_filter_strides_(num_dim_spatial_),
          conv_filter_dilations_(num_dim_spatial_),
          input_left_pads_(num_dim_spatial_),
          input_right_pads_(num_dim_spatial_)
    {
        if(static_cast<ck_tile::index_t>(filter_spatial_lengths_.size()) != num_dim_spatial_ ||
           static_cast<ck_tile::index_t>(input_spatial_lengths_.size()) != num_dim_spatial_ ||
           static_cast<ck_tile::index_t>(conv_filter_strides_.size()) != num_dim_spatial_ ||
           static_cast<ck_tile::index_t>(conv_filter_dilations_.size()) != num_dim_spatial_ ||
           static_cast<ck_tile::index_t>(input_left_pads_.size()) != num_dim_spatial_ ||
           static_cast<ck_tile::index_t>(input_right_pads_.size()) != num_dim_spatial_)
        {
            throw(std::runtime_error(
                "ConvParam::ConvParam: "
                "parameter size is different from number of declared dimensions!"));
        }

        for(ck_tile::index_t i = 0; i < num_dim_spatial_; ++i)
        {
            filter_spatial_lengths_[i] = static_cast<ck_tile::long_index_t>(filters_len[i]);
            input_spatial_lengths_[i]  = static_cast<ck_tile::long_index_t>(input_len[i]);
            conv_filter_strides_[i]    = static_cast<ck_tile::long_index_t>(strides[i]);
            conv_filter_dilations_[i]  = static_cast<ck_tile::long_index_t>(dilations[i]);
            input_left_pads_[i]        = static_cast<ck_tile::long_index_t>(left_pads[i]);
            input_right_pads_[i]       = static_cast<ck_tile::long_index_t>(right_pads[i]);

            // XEff = (X - 1) * conv_dilation_w + 1;
            // Wo = (Wi + in_left_pad_w + in_right_pad_w - XEff) / conv_stride_w + 1;
            const ck_tile::long_index_t x_eff =
                (filter_spatial_lengths_[i] - 1) * conv_filter_dilations_[i] + 1;

            output_spatial_lengths_[i] =
                (input_spatial_lengths_[i] + input_left_pads_[i] + input_right_pads_[i] - x_eff) /
                    conv_filter_strides_[i] +
                1;
        }
    }

    ConvParam(ck_tile::long_index_t n_dim,
              ck_tile::long_index_t group_count,
              ck_tile::long_index_t n_batch,
              ck_tile::long_index_t n_out_channels,
              ck_tile::long_index_t n_in_channels,
              const std::vector<ck_tile::long_index_t>& filters_len,
              const std::vector<ck_tile::long_index_t>& input_len,
              const std::vector<ck_tile::long_index_t>& strides,
              const std::vector<ck_tile::long_index_t>& dilations,
              const std::vector<ck_tile::long_index_t>& left_pads,
              const std::vector<ck_tile::long_index_t>& right_pads)
        : num_dim_spatial_(n_dim),
          G_(group_count),
          N_(n_batch),
          K_(n_out_channels),
          C_(n_in_channels),
          filter_spatial_lengths_(filters_len),
          input_spatial_lengths_(input_len),
          output_spatial_lengths_(num_dim_spatial_),
          conv_filter_strides_(strides),
          conv_filter_dilations_(dilations),
          input_left_pads_(left_pads),
          input_right_pads_(right_pads)
    {
        if(static_cast<ck_tile::index_t>(filter_spatial_lengths_.size()) != num_dim_spatial_ ||
           static_cast<ck_tile::index_t>(input_spatial_lengths_.size()) != num_dim_spatial_ ||
           static_cast<ck_tile::index_t>(conv_filter_strides_.size()) != num_dim_spatial_ ||
           static_cast<ck_tile::index_t>(conv_filter_dilations_.size()) != num_dim_spatial_ ||
           static_cast<ck_tile::index_t>(input_left_pads_.size()) != num_dim_spatial_ ||
           static_cast<ck_tile::index_t>(input_right_pads_.size()) != num_dim_spatial_)
        {
            throw(std::runtime_error(
                "ConvParam::ConvParam: "
                "parameter size is different from number of declared dimensions!"));
        }

        for(ck_tile::index_t i = 0; i < num_dim_spatial_; ++i)
        {
            // XEff = (X - 1) * conv_dilation_w + 1;
            // Wo = (Wi + in_left_pad_w + in_right_pad_w - XEff) / conv_stride_w + 1;
            const ck_tile::long_index_t x_eff =
                (filter_spatial_lengths_[i] - 1) * conv_filter_dilations_[i] + 1;

            output_spatial_lengths_[i] =
                (input_spatial_lengths_[i] + input_left_pads_[i] + input_right_pads_[i] - x_eff) /
                    conv_filter_strides_[i] +
                1;
        }
    }

    ck_tile::long_index_t num_dim_spatial_;
    ck_tile::long_index_t G_;
    ck_tile::long_index_t N_;
    ck_tile::long_index_t K_;
    ck_tile::long_index_t C_;

    std::vector<ck_tile::long_index_t> filter_spatial_lengths_;
    std::vector<ck_tile::long_index_t> input_spatial_lengths_;
    std::vector<ck_tile::long_index_t> output_spatial_lengths_;

    std::vector<ck_tile::long_index_t> conv_filter_strides_;
    std::vector<ck_tile::long_index_t> conv_filter_dilations_;

    std::vector<ck_tile::long_index_t> input_left_pads_;
    std::vector<ck_tile::long_index_t> input_right_pads_;

    std::vector<ck_tile::long_index_t> GetOutputSpatialLengths() const
    {
        return output_spatial_lengths_;
    }

    std::size_t GetFlops() const
    {
        // 2 * G * N * K * C * <output spatial lengths product> * <filter spatial lengths product>
        return static_cast<std::size_t>(2) * G_ * N_ * K_ * C_ *
               std::accumulate(std::begin(output_spatial_lengths_),
                               std::next(std::begin(output_spatial_lengths_), num_dim_spatial_),
                               1,
                               std::multiplies<>()) *
               std::accumulate(std::begin(filter_spatial_lengths_),
                               std::next(std::begin(filter_spatial_lengths_), num_dim_spatial_),
                               1,
                               std::multiplies<>());
    }

    template <typename InDataType>
    std::size_t GetInputByte() const
    {
        // sizeof(InDataType) * (G * N * C * <input spatial lengths product>)
        return sizeof(InDataType) *
               (G_ * N_ * C_ *
                std::accumulate(std::begin(input_spatial_lengths_),
                                std::next(std::begin(input_spatial_lengths_), num_dim_spatial_),
                                1,
                                std::multiplies<>()));
    }

    template <typename WeiDataType>
    std::size_t GetWeightByte() const
    {
        // sizeof(WeiDataType) * (G * K * C * <filter spatial lengths product>)
        return sizeof(WeiDataType) *
               (G_ * K_ * C_ *
                std::accumulate(std::begin(filter_spatial_lengths_),
                                std::next(std::begin(filter_spatial_lengths_), num_dim_spatial_),
                                1,
                                std::multiplies<>()));
    }

    template <typename OutDataType>
    std::size_t GetOutputByte() const
    {
        // sizeof(OutDataType) * (G * N * K * <output spatial lengths product>);
        return sizeof(OutDataType) * (G_ * N_ * K_ *
                                      std::accumulate(std::begin(output_spatial_lengths_),
                                                      std::end(output_spatial_lengths_),
                                                      static_cast<std::size_t>(1),
                                                      std::multiplies<std::size_t>()));
    }

    template <typename InDataType, typename WeiDataType, typename OutDataType>
    std::size_t GetByte() const
    {
        return GetInputByte<InDataType>() + GetWeightByte<WeiDataType>() +
               GetOutputByte<OutDataType>();
    }
};

ConvParam::ConvParam()
    : ConvParam::ConvParam(2, 1, 128, 256, 192, {3, 3}, {71, 71}, {2, 2}, {1, 1}, {1, 1}, {1, 1})
{
}

CK_TILE_HOST std::string get_conv_param_parser_helper_msg()
{
    std::string msg;

    msg += "Following arguments (depending on number of spatial dims):\n"
           " Number of spatial dimensions (1=Conv1d, 2=Conv2d, 3=Conv3d)\n"
           " G, N, K, C, \n"
           " <filter spatial dimensions>, (ie Y, X for 2D)\n"
           " <input image spatial dimensions>, (ie Hi, Wi for 2D)\n"
           " <strides>, (ie Sy, Sx for 2D)\n"
           " <dilations>, (ie Dy, Dx for 2D)\n"
           " <left padding>, (ie LeftPy, LeftPx for 2D)\n"
           " <right padding>, (ie RightPy, RightPx for 2D)\n";

    return msg;
}

CK_TILE_HOST ck_tile::conv::ConvParam
parse_conv_param(int num_dim_spatial, int arg_idx, char* const argv[])
{
    const ck_tile::long_index_t G = std::stol(argv[arg_idx++]);
    const ck_tile::long_index_t N = std::stol(argv[arg_idx++]);
    const ck_tile::long_index_t K = std::stol(argv[arg_idx++]);
    const ck_tile::long_index_t C = std::stol(argv[arg_idx++]);

    std::vector<ck_tile::long_index_t> filter_spatial_lengths(num_dim_spatial);
    std::vector<ck_tile::long_index_t> input_spatial_lengths(num_dim_spatial);
    std::vector<ck_tile::long_index_t> conv_filter_strides(num_dim_spatial);
    std::vector<ck_tile::long_index_t> conv_filter_dilations(num_dim_spatial);
    std::vector<ck_tile::long_index_t> input_left_pads(num_dim_spatial);
    std::vector<ck_tile::long_index_t> input_right_pads(num_dim_spatial);

    for(int i = 0; i < num_dim_spatial; ++i)
    {
        filter_spatial_lengths[i] = std::stol(argv[arg_idx++]);
    }

    for(int i = 0; i < num_dim_spatial; ++i)
    {
        input_spatial_lengths[i] = std::stol(argv[arg_idx++]);
    }

    for(int i = 0; i < num_dim_spatial; ++i)
    {
        conv_filter_strides[i] = std::stol(argv[arg_idx++]);
    }

    for(int i = 0; i < num_dim_spatial; ++i)
    {
        conv_filter_dilations[i] = std::stol(argv[arg_idx++]);
    }

    for(int i = 0; i < num_dim_spatial; ++i)
    {
        input_left_pads[i] = std::stol(argv[arg_idx++]);
    }

    for(int i = 0; i < num_dim_spatial; ++i)
    {
        input_right_pads[i] = std::stol(argv[arg_idx++]);
    }

    return ck_tile::conv::ConvParam{num_dim_spatial,
                                    G,
                                    N,
                                    K,
                                    C,
                                    filter_spatial_lengths,
                                    input_spatial_lengths,
                                    conv_filter_strides,
                                    conv_filter_dilations,
                                    input_left_pads,
                                    input_right_pads};
}

} // namespace conv
} // namespace ck_tile
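
As a quick check of the output-size formula above (a sketch, not part of the commit; the function name is illustrative): with a 3x3 filter, 10x10 input, stride 2, dilation 1 and padding 1 on both sides, XEff = (3 - 1) * 1 + 1 = 3 and Wo = (10 + 1 + 1 - 3) / 2 + 1 = 5 in each spatial dimension.

#include <cassert>
#include <vector>

void conv_param_output_length_example()
{
    using ck_tile::long_index_t;

    const std::vector<long_index_t> filter{3, 3}, input{10, 10}, stride{2, 2}, dilation{1, 1},
        left_pad{1, 1}, right_pad{1, 1};

    // 2D conv: G=1, N=2, K=8, C=4.
    ck_tile::conv::ConvParam param(
        2, 1, 2, 8, 4, filter, input, stride, dilation, left_pad, right_pad);

    // Both spatial output lengths come out as 5.
    assert(param.GetOutputSpatialLengths() == (std::vector<long_index_t>{5, 5}));
}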
@@ -176,7 +176,20 @@ struct HostTensorDescriptor
         return std::inner_product(iss.begin(), iss.end(), mStrides.begin(), std::size_t{0});
     }
 
-    friend std::ostream& operator<<(std::ostream& os, const HostTensorDescriptor& desc);
+    friend std::ostream& operator<<(std::ostream& os, const HostTensorDescriptor& desc)
+    {
+        os << "dim " << desc.get_num_of_dimension() << ", ";
+
+        os << "lengths {";
+        LogRange(os, desc.get_lengths(), ", ");
+        os << "}, ";
+
+        os << "strides {";
+        LogRange(os, desc.get_strides(), ", ");
+        os << "}";
+
+        return os;
+    }
 
     private:
     std::vector<std::size_t> mLens;
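
For reference, a tiny sketch (an assumption, not in the commit) of what the new streaming operator prints for a descriptor built from lengths only, assuming packed strides are derived automatically:

#include <iostream>
#include <vector>

void print_descriptor_example()
{
    // Packed 2 x 3 descriptor; expected to print roughly:
    //   dim 2, lengths {2, 3}, strides {3, 1}
    ck_tile::HostTensorDescriptor desc(std::vector<std::size_t>{2, 3});
    std::cout << desc << std::endl;
}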
@@ -1,5 +1,5 @@
 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
 
 #pragma once
 
@@ -9,53 +9,125 @@
 namespace ck_tile {
 
-template <typename T>
-CK_TILE_HOST void reference_im2col(HostTensor<T>& in_mtx_host_ref,
-                                   const HostTensor<T>& in_host,
-                                   int /*N*/,
-                                   int /*K*/,
-                                   int C,
-                                   int /*Y*/,
-                                   int X,
-                                   int Hi,
-                                   int Wi,
-                                   int Ho,
-                                   int Wo,
-                                   int ConvStrideH,
-                                   int ConvStrideW,
-                                   int ConvDilationH,
-                                   int ConvDilationW,
-                                   int InLeftPadH,
-                                   int InLeftPadW,
-                                   int /*InRightPadH*/,
-                                   int /*InRightPadW*/)
-{
-    int GemmM = in_mtx_host_ref.get_lengths()[0];
-    int GemmK = in_mtx_host_ref.get_lengths()[1];
-
-    for(int gemm_m = 0; gemm_m < GemmM; ++gemm_m)
-    {
-        int mtmp = gemm_m;
-        int n    = mtmp / (Ho * Wo);
-        mtmp -= n * Ho * Wo;
-        int ho = mtmp / Wo;
-        int wo = mtmp - ho * Wo;
-
-        for(int gemm_k = 0; gemm_k < GemmK; ++gemm_k)
-        {
-            int ktmp = gemm_k;
-            int y    = ktmp / (X * C);
-            ktmp -= y * X * C;
-            int x = ktmp / C;
-            int c = ktmp - x * C;
-
-            int hi = y * ConvDilationH + ho * ConvStrideH - InLeftPadH;
-            int wi = x * ConvDilationW + wo * ConvStrideW - InLeftPadW;
-
-            bool inbound = (hi >= 0 && hi < Hi && wi >= 0 && wi < Wi);
-
-            in_mtx_host_ref(gemm_m, gemm_k) = inbound ? in_host(n, hi, wi, c) : 0;
-        }
-    }
-}
+template <typename InDataType, typename OutDataType, index_t NDimSpatial>
+CK_TILE_HOST void reference_im2col(const HostTensor<InDataType>& in_host,
+                                   HostTensor<OutDataType>& out_host,
+                                   const ck_tile::conv::ConvParam& conv_params)
+{
+    const long_index_t G = in_host.get_lengths()[0];
+    const long_index_t N = in_host.get_lengths()[1];
+    const long_index_t C = in_host.get_lengths()[2];
+
+    if constexpr(NDimSpatial == 1)
+    {
+        const long_index_t Wo = conv_params.output_spatial_lengths_[0];
+
+        auto func = [&](auto g, auto n, auto wo) {
+            long_index_t row    = n * Wo + wo;
+            long_index_t column = 0;
+
+            for(long_index_t x = 0; x < conv_params.filter_spatial_lengths_[0]; ++x)
+            {
+                auto wi = static_cast<long_index_t>(wo * conv_params.conv_filter_strides_[0]) +
+                          static_cast<long_index_t>(x * conv_params.conv_filter_dilations_[0]) -
+                          static_cast<long_index_t>(conv_params.input_left_pads_[0]);
+
+                for(long_index_t c = 0; c < C; ++c)
+                {
+                    if(wi >= 0 && type_convert<std::size_t>(wi) < in_host.get_lengths()[3])
+                    {
+                        InDataType v_in          = in_host(g, n, c, wi);
+                        out_host(g, row, column) = type_convert<OutDataType>(v_in);
+                    }
+                    column++;
+                }
+            }
+        };
+
+        make_ParallelTensorFunctor(func, G, N, Wo)(std::thread::hardware_concurrency());
+    }
+    else if constexpr(NDimSpatial == 2)
+    {
+        const long_index_t Ho = conv_params.output_spatial_lengths_[0];
+        const long_index_t Wo = conv_params.output_spatial_lengths_[1];
+
+        auto func = [&](auto g, auto n, auto ho, auto wo) {
+            long_index_t row    = n * Ho * Wo + ho * Wo + wo;
+            long_index_t column = 0;
+
+            for(long_index_t y = 0; y < conv_params.filter_spatial_lengths_[0]; ++y)
+            {
+                auto hi = static_cast<long_index_t>(ho * conv_params.conv_filter_strides_[0]) +
+                          static_cast<long_index_t>(y * conv_params.conv_filter_dilations_[0]) -
+                          static_cast<long_index_t>(conv_params.input_left_pads_[0]);
+
+                for(long_index_t x = 0; x < conv_params.filter_spatial_lengths_[1]; ++x)
+                {
+                    auto wi = static_cast<long_index_t>(wo * conv_params.conv_filter_strides_[1]) +
+                              static_cast<long_index_t>(x * conv_params.conv_filter_dilations_[1]) -
+                              static_cast<long_index_t>(conv_params.input_left_pads_[1]);
+
+                    for(long_index_t c = 0; c < C; ++c)
+                    {
+                        if(hi >= 0 && type_convert<std::size_t>(hi) < in_host.get_lengths()[3] &&
+                           wi >= 0 && type_convert<std::size_t>(wi) < in_host.get_lengths()[4])
+                        {
+                            InDataType v_in          = in_host(g, n, c, hi, wi);
+                            out_host(g, row, column) = type_convert<OutDataType>(v_in);
+                        }
+                        column++;
+                    }
+                }
+            }
+        };
+
+        make_ParallelTensorFunctor(func, G, N, Ho, Wo)(std::thread::hardware_concurrency());
+    }
+    else if constexpr(NDimSpatial == 3)
+    {
+        const long_index_t Do = conv_params.output_spatial_lengths_[0];
+        const long_index_t Ho = conv_params.output_spatial_lengths_[1];
+        const long_index_t Wo = conv_params.output_spatial_lengths_[2];
+
+        auto func = [&](auto g, auto n, auto d_o, auto ho, auto wo) {
+            long_index_t row    = n * Do * Ho * Wo + d_o * Ho * Wo + ho * Wo + wo;
+            long_index_t column = 0;
+
+            for(long_index_t z = 0; z < conv_params.filter_spatial_lengths_[0]; ++z)
+            {
+                auto di = static_cast<long_index_t>(d_o * conv_params.conv_filter_strides_[0]) +
+                          static_cast<long_index_t>(z * conv_params.conv_filter_dilations_[0]) -
+                          static_cast<long_index_t>(conv_params.input_left_pads_[0]);
+                for(long_index_t y = 0; y < conv_params.filter_spatial_lengths_[1]; ++y)
+                {
+                    auto hi = static_cast<long_index_t>(ho * conv_params.conv_filter_strides_[1]) +
+                              static_cast<long_index_t>(y * conv_params.conv_filter_dilations_[1]) -
+                              static_cast<long_index_t>(conv_params.input_left_pads_[1]);
+                    for(long_index_t x = 0; x < conv_params.filter_spatial_lengths_[2]; ++x)
+                    {
+                        auto wi =
+                            static_cast<long_index_t>(wo * conv_params.conv_filter_strides_[2]) +
+                            static_cast<long_index_t>(x * conv_params.conv_filter_dilations_[2]) -
+                            static_cast<long_index_t>(conv_params.input_left_pads_[2]);
+                        for(long_index_t c = 0; c < C; ++c)
+                        {
+                            if(di >= 0 &&
+                               type_convert<std::size_t>(di) < in_host.get_lengths()[3] &&
+                               hi >= 0 &&
+                               type_convert<std::size_t>(hi) < in_host.get_lengths()[4] &&
+                               wi >= 0 && type_convert<std::size_t>(wi) < in_host.get_lengths()[5])
+                            {
+                                InDataType v_in          = in_host(g, n, c, di, hi, wi);
+                                out_host(g, row, column) = type_convert<OutDataType>(v_in);
+                            }
+                            column++;
+                        }
+                    }
+                }
+            }
+        };
+
+        make_ParallelTensorFunctor(func, G, N, Do, Ho, Wo)(std::thread::hardware_concurrency());
+    }
+}
 } // namespace ck_tile
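
Putting the pieces together, a hedged end-to-end sketch for the 2D case (not part of the commit; names are illustrative): the input tensor uses the packed G, N, C, Hi, Wi descriptor from the helper added above, and reference_im2col fills a (G, N * Ho * Wo, Y * X * C) column matrix, i.e. the GEMM view of the image implied by the row/column indexing in the code.

void run_reference_im2col_2d()
{
    // Default ConvParam: 2D, G=1, N=128, K=256, C=192, 3x3 filter, 71x71 input.
    ck_tile::conv::ConvParam param{};

    const auto in_desc = ck_tile::conv::make_input_host_tensor_descriptor_g_n_c_wis_packed<
        ck_tile::tensor_layout::convolution::GNHWC>(param);

    const ck_tile::long_index_t Ho = param.output_spatial_lengths_[0];
    const ck_tile::long_index_t Wo = param.output_spatial_lengths_[1];

    ck_tile::HostTensor<float> in(in_desc);

    // GemmM = N * Ho * Wo and GemmK = Y * X * C, matching the row/column computation above.
    ck_tile::HostTensor<float> out(ck_tile::HostTensorDescriptor(std::vector<std::size_t>{
        static_cast<std::size_t>(param.G_),
        static_cast<std::size_t>(param.N_ * Ho * Wo),
        static_cast<std::size_t>(param.filter_spatial_lengths_[0] *
                                 param.filter_spatial_lengths_[1] * param.C_)}));

    // Fill `in` with data here, then build the column matrix on the host.
    ck_tile::reference_im2col<float, float, 2>(in, out, param);
}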