mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-03 13:11:25 +00:00
Add column to image kernel (#930)
* Add column to image kernel
* Minor fixes for dtypes and client examples
* Disable tests for disabled dtypes
* Disable add instances functions for disabled data types
* Minor stylistic fixes
* Revert "Disable add instances functions for disabled data types"
This reverts commit 728b869563.
* Instances reduction
* Add comments in device_column_to_image_impl
* Update changelog and Copyrights
* Improve changelog
This commit is contained in:
@@ -0,0 +1,33 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
namespace ck {
|
||||
namespace conv_tensor_rearrange_op {
|
||||
|
||||
struct BaseConvTensorRearrangeOp
|
||||
{
|
||||
};
|
||||
|
||||
struct ImageToColumn : public BaseConvTensorRearrangeOp
|
||||
{
|
||||
static constexpr const char* name = "Image to Column";
|
||||
};
|
||||
|
||||
struct ColumnToImage : public BaseConvTensorRearrangeOp
|
||||
{
|
||||
static constexpr const char* name = "Column to Image";
|
||||
};
|
||||
|
||||
template <typename Op,
|
||||
typename std::enable_if<std::is_base_of<BaseConvTensorRearrangeOp, Op>::value,
|
||||
bool>::type = false>
|
||||
std::ostream& operator<<(std::ostream& os, const BaseConvTensorRearrangeOp&)
|
||||
{
|
||||
os << Op::name;
|
||||
return os;
|
||||
}
|
||||
|
||||
} // namespace conv_tensor_rearrange_op
|
||||
} // namespace ck
|
||||
@@ -12,21 +12,26 @@ namespace tensor_operation {
|
||||
namespace device {
|
||||
|
||||
/**
|
||||
* \brief Image to column.
|
||||
* \brief Convolution Tensor Rearrange.
|
||||
*
|
||||
* This Device operator converts image ([G, N, Di, Hi, Wi, C]) to the gemm
|
||||
* problem([N * Do * Ho * Wo, Z * Y * X * C]). G must be equal to 1.
|
||||
* This Device operator supports conversion image ([G, N, Di, Hi, Wi, C]) to
|
||||
* the gemm problem([N * Do * Ho * Wo, Z * Y * X * C]) (Image to Column) and
|
||||
* conversion gemm form to the image (Column to Image).
|
||||
*
|
||||
* Note that G must be equal to 1.
|
||||
*
|
||||
* \tparam NDimSpatial Number of spatial dimensions.
|
||||
* \tparam InputLayout Input Layout.
|
||||
* \tparam ImageLayout Input Layout.
|
||||
* \tparam InputDataType Input Data Type.
|
||||
* \tparam OutputDataType Output Data Type.
|
||||
* \tparam ConvTensorRearrangeOp Operation type: ImageToColumn, ColumnToImage.
|
||||
*/
|
||||
template <index_t NDimSpatial,
|
||||
typename InputLayout,
|
||||
typename ImageLayout,
|
||||
typename InputDataType,
|
||||
typename OutputDataType>
|
||||
struct DeviceImageToColumn : public BaseOperator
|
||||
typename OutputDataType,
|
||||
typename ConvTensorRearrangeOp>
|
||||
struct DeviceConvTensorRearrange : public BaseOperator
|
||||
{
|
||||
|
||||
/**
|
||||
@@ -39,8 +44,8 @@ struct DeviceImageToColumn : public BaseOperator
|
||||
* \param input_spatial_lengths Input spatial lengths.
|
||||
* \param filter_spatial_lengths Filter spatial lengths.
|
||||
* \param output_spatial_lengths Output spatial lengths.
|
||||
* \param input_g_n_c_wis_strides Input strides in order [G, N, C, D, H, W].
|
||||
* \param output_m_k_strides Output strides.
|
||||
* \param image_g_n_c_wis_strides Image strides in order [G, N, C, D, H, W].
|
||||
* \param gemm_m_k_strides Gemm form strides.
|
||||
* \param conv_filter_strides Convolution filter strides.
|
||||
* \param conv_filter_dilations Convolution filter dilations.
|
||||
* \param input_left_pads Convolution left pads.
|
||||
@@ -55,8 +60,8 @@ struct DeviceImageToColumn : public BaseOperator
|
||||
const std::array<index_t, NDimSpatial>& input_spatial_lengths,
|
||||
const std::array<index_t, NDimSpatial>& filter_spatial_lengths,
|
||||
const std::array<index_t, NDimSpatial>& output_spatial_lengths,
|
||||
const std::array<index_t, NDimSpatial + 3>& input_g_n_c_wis_strides,
|
||||
const std::array<index_t, 2>& output_m_k_strides,
|
||||
const std::array<index_t, NDimSpatial + 3>& image_g_n_c_wis_strides,
|
||||
const std::array<index_t, 2>& gemm_m_k_strides,
|
||||
const std::array<index_t, NDimSpatial>& conv_filter_strides,
|
||||
const std::array<index_t, NDimSpatial>& conv_filter_dilations,
|
||||
const std::array<index_t, NDimSpatial>& input_left_pads,
|
||||
@@ -0,0 +1,621 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck/tensor_description/tensor_descriptor.hpp"
|
||||
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/device_conv_tensor_rearrange.hpp"
|
||||
#include "ck/tensor_operation/gpu/grid/gridwise_tensor_rearrange.hpp"
|
||||
#include "ck/host_utility/kernel_launch.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/convolution_backward_data_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp"
|
||||
|
||||
#include "ck/tensor_operation/operator_transform/transform_conv_fwd_to_gemm.hpp"
|
||||
|
||||
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/conv_tensor_rearrange_op.hpp"
|
||||
#include "ck/host_utility/io.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
|
||||
// Image to column for input layout NDHWC:
|
||||
// input : image converted to the gemm problem [N * Do * Ho * Wo, Z * Y * X * C]
|
||||
// output : image [N, Di, Hi, Wi, C]
|
||||
template <index_t NDimSpatial,
|
||||
typename ImageLayout,
|
||||
typename InputDataType,
|
||||
typename OutputDataType,
|
||||
index_t BlockSize,
|
||||
index_t MPerBlock,
|
||||
index_t KPerBlock,
|
||||
typename ThreadClusterLengths,
|
||||
index_t ScalarPerVector,
|
||||
typename std::enable_if<NDimSpatial >= 1 && NDimSpatial <= 3, bool>::type = false>
|
||||
struct DeviceColumnToImageImpl
|
||||
: public DeviceConvTensorRearrange<NDimSpatial,
|
||||
ImageLayout,
|
||||
InputDataType,
|
||||
OutputDataType,
|
||||
conv_tensor_rearrange_op::ColumnToImage>
|
||||
{
|
||||
|
||||
static constexpr auto I0 = Number<0>{};
|
||||
static constexpr auto I1 = Number<1>{};
|
||||
static constexpr auto I2 = Number<2>{};
|
||||
|
||||
static constexpr auto ZIdx = Number<I0>{};
|
||||
static constexpr auto YIdx = NDimSpatial == 1 ? I0 : Number<NDimSpatial - I2>{};
|
||||
static constexpr auto XIdx = Number<NDimSpatial - I1>{};
|
||||
|
||||
static constexpr auto spatial_offset = Number<3>{};
|
||||
|
||||
static constexpr auto conv_to_gemm_transformer =
|
||||
TransformConvFwdToGemm<NDimSpatial, ConvolutionForwardSpecialization::Default>{};
|
||||
static constexpr auto matrix_padder =
|
||||
MatrixPadder<GemmSpecialization::MKPadding, index_t, index_t, index_t>{
|
||||
MPerBlock, 0 /* NPerBlock*/, KPerBlock};
|
||||
|
||||
// Calculate number of independent filters for given conv params
|
||||
static index_t GetNumberOfIndependentFilters(const index_t input_spatial_len,
|
||||
const index_t left_pad,
|
||||
const index_t right_pad,
|
||||
const index_t filter_len,
|
||||
const index_t filter_stride,
|
||||
const index_t filter_dilation,
|
||||
const index_t image_offset)
|
||||
{
|
||||
const index_t x_eff = (filter_len - 1) * filter_dilation + 1;
|
||||
const index_t next_filter_padded =
|
||||
math::integer_divide_ceil(x_eff, filter_stride) * filter_stride;
|
||||
// If filter_stride >= x_eff then each filter is independent
|
||||
const index_t independent_filter_stride =
|
||||
filter_stride >= x_eff ? filter_stride : next_filter_padded;
|
||||
const index_t w_eff = input_spatial_len - image_offset + left_pad + right_pad - x_eff;
|
||||
// There are no independent filters
|
||||
if(w_eff < 0)
|
||||
return 0;
|
||||
const index_t independent_kernels_num = w_eff / independent_filter_stride + 1;
|
||||
return independent_kernels_num;
|
||||
}
|
||||
|
||||
// Make column form descriptor
|
||||
static auto
|
||||
MakeInputDescriptor_M_K(const ck::index_t N,
|
||||
const ck::index_t C,
|
||||
const std::array<index_t, NDimSpatial>& filter_spatial_lengths,
|
||||
const std::array<index_t, NDimSpatial>& output_spatial_lengths,
|
||||
const std::array<index_t, NDimSpatial>& conv_filter_strides,
|
||||
const std::array<index_t, 2>& gemm_m_k_strides,
|
||||
const std::array<index_t, NDimSpatial>& independent_filters,
|
||||
const std::array<index_t, NDimSpatial>& effs)
|
||||
{
|
||||
const index_t DoHoWo = ck::accumulate_n<index_t>(
|
||||
output_spatial_lengths.begin(), NDimSpatial, 1, std::multiplies<>());
|
||||
const index_t CZYX =
|
||||
C * ck::accumulate_n<index_t>(
|
||||
filter_spatial_lengths.begin(), NDimSpatial, 1, std::multiplies<>());
|
||||
|
||||
const index_t NStride = DoHoWo * gemm_m_k_strides[I0] * gemm_m_k_strides[I1];
|
||||
// Calculate the appropriate stride for each set of independent filters
|
||||
// in each dimension
|
||||
const index_t WStride =
|
||||
math::integer_divide_ceil(effs[XIdx], conv_filter_strides[XIdx]) * gemm_m_k_strides[I0];
|
||||
const index_t HStride = math::integer_divide_ceil(effs[YIdx], conv_filter_strides[YIdx]) *
|
||||
output_spatial_lengths[XIdx] * gemm_m_k_strides[I0];
|
||||
const index_t DStride = math::integer_divide_ceil(effs[ZIdx], conv_filter_strides[ZIdx]) *
|
||||
output_spatial_lengths[YIdx] * output_spatial_lengths[XIdx] *
|
||||
gemm_m_k_strides[I0];
|
||||
// Create descriptor for independent filters in each dimension and
|
||||
// then merge them into column form
|
||||
if constexpr(NDimSpatial == 1)
|
||||
{
|
||||
const auto desc_gemm_form =
|
||||
make_naive_tensor_descriptor(make_tuple(N, independent_filters[XIdx], CZYX),
|
||||
make_tuple(NStride, WStride, gemm_m_k_strides[I1]));
|
||||
const auto desc_gemm_form_merged_filters = transform_tensor_descriptor(
|
||||
desc_gemm_form,
|
||||
make_tuple(make_merge_transform(make_tuple(N, independent_filters[XIdx])),
|
||||
make_pass_through_transform(CZYX)),
|
||||
make_tuple(Sequence<0, 1>{}, Sequence<2>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
const auto desc_m_k = matrix_padder.PadADescriptor_M_K(desc_gemm_form_merged_filters);
|
||||
return desc_m_k;
|
||||
}
|
||||
else if constexpr(NDimSpatial == 2)
|
||||
{
|
||||
const auto desc_gemm_form = make_naive_tensor_descriptor(
|
||||
make_tuple(N, independent_filters[YIdx], independent_filters[XIdx], CZYX),
|
||||
make_tuple(NStride, HStride, WStride, gemm_m_k_strides[I1]));
|
||||
const auto desc_gemm_form_merged_filters = transform_tensor_descriptor(
|
||||
desc_gemm_form,
|
||||
make_tuple(make_merge_transform(
|
||||
make_tuple(N, independent_filters[YIdx], independent_filters[XIdx])),
|
||||
make_pass_through_transform(CZYX)),
|
||||
make_tuple(Sequence<0, 1, 2>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
const auto desc_m_k = matrix_padder.PadADescriptor_M_K(desc_gemm_form_merged_filters);
|
||||
return desc_m_k;
|
||||
}
|
||||
else if constexpr(NDimSpatial == 3)
|
||||
{
|
||||
const auto desc_gemm_form = make_naive_tensor_descriptor(
|
||||
make_tuple(N,
|
||||
independent_filters[ZIdx],
|
||||
independent_filters[YIdx],
|
||||
independent_filters[XIdx],
|
||||
CZYX),
|
||||
make_tuple(NStride, DStride, HStride, WStride, gemm_m_k_strides[I1]));
|
||||
const auto desc_gemm_form_merged_filters = transform_tensor_descriptor(
|
||||
desc_gemm_form,
|
||||
make_tuple(make_merge_transform(make_tuple(N,
|
||||
independent_filters[ZIdx],
|
||||
independent_filters[YIdx],
|
||||
independent_filters[XIdx])),
|
||||
make_pass_through_transform(CZYX)),
|
||||
make_tuple(Sequence<0, 1, 2, 3>{}, Sequence<4>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
const auto desc_m_k = matrix_padder.PadADescriptor_M_K(desc_gemm_form_merged_filters);
|
||||
return desc_m_k;
|
||||
}
|
||||
}
|
||||
|
||||
// Use MakeADescriptor_M_K from grouped convolution forward
|
||||
static auto
|
||||
MakeOutDescriptor_M_K(const ck::index_t N,
|
||||
const ck::index_t C,
|
||||
const std::array<index_t, NDimSpatial>& input_spatial_lengths,
|
||||
const std::array<index_t, NDimSpatial>& filter_spatial_lengths,
|
||||
const std::array<index_t, NDimSpatial + 3>& image_g_n_c_wis_strides,
|
||||
const std::array<index_t, NDimSpatial>& conv_filter_strides,
|
||||
const std::array<index_t, NDimSpatial>& conv_filter_dilations,
|
||||
const std::array<index_t, NDimSpatial>& input_left_pads,
|
||||
const std::array<index_t, NDimSpatial>& input_right_pads,
|
||||
const std::array<index_t, NDimSpatial>& image_offsets,
|
||||
const std::array<index_t, NDimSpatial>& independent_filters,
|
||||
const std::array<index_t, NDimSpatial>& effs)
|
||||
{
|
||||
std::array<index_t, NDimSpatial + 3> a_g_n_c_wis_lengths{1};
|
||||
std::array<index_t, NDimSpatial + 3> b_g_k_c_xs_lengths{1};
|
||||
std::array<index_t, NDimSpatial + 3> c_g_n_k_wos_lengths{1};
|
||||
|
||||
auto copy = [](const auto& x, auto& y, index_t dst_offset) {
|
||||
std::copy(x.begin(), x.end(), y.begin() + dst_offset);
|
||||
};
|
||||
|
||||
copy(input_spatial_lengths, a_g_n_c_wis_lengths, spatial_offset);
|
||||
copy(filter_spatial_lengths, b_g_k_c_xs_lengths, spatial_offset);
|
||||
// Calculate descriptor only for independent filters
|
||||
copy(independent_filters, c_g_n_k_wos_lengths, spatial_offset);
|
||||
|
||||
// fill only significant values (C and N)
|
||||
a_g_n_c_wis_lengths[I1] = N;
|
||||
a_g_n_c_wis_lengths[I2] = C;
|
||||
b_g_k_c_xs_lengths[I2] = C;
|
||||
c_g_n_k_wos_lengths[I1] = N;
|
||||
|
||||
// Modify pads to apply offsets
|
||||
std::array<index_t, NDimSpatial> input_left_pads_with_offset;
|
||||
for(index_t i = 0; i < NDimSpatial; i++)
|
||||
{
|
||||
input_left_pads_with_offset[i] = math::max(0, input_left_pads[i] - image_offsets[i]);
|
||||
}
|
||||
// Modify input spatial lengths to apply offsets
|
||||
for(index_t i = 0; i < NDimSpatial; i++)
|
||||
{
|
||||
a_g_n_c_wis_lengths[i + spatial_offset] -=
|
||||
math::max(0, image_offsets[i] - input_left_pads[i]);
|
||||
}
|
||||
|
||||
// Strides to next independent filters
|
||||
std::array<index_t, NDimSpatial> independent_filter_strides;
|
||||
for(index_t i = 0; i < NDimSpatial; i++)
|
||||
{
|
||||
index_t independent_filter_stride =
|
||||
math::integer_divide_ceil(effs[i], conv_filter_strides[i]) * conv_filter_strides[i];
|
||||
// If conv stride is greater than whole filter size, use conv stride
|
||||
independent_filter_strides[i] = conv_filter_strides[i] >= effs[i]
|
||||
? conv_filter_strides[i]
|
||||
: independent_filter_stride;
|
||||
}
|
||||
|
||||
// Calculate image form descriptor for the modified convolution problem
|
||||
const auto in_gemmmraw_gemmkraw_desc =
|
||||
conv_to_gemm_transformer.template MakeADescriptor_M_K<ImageLayout>(
|
||||
a_g_n_c_wis_lengths,
|
||||
image_g_n_c_wis_strides,
|
||||
b_g_k_c_xs_lengths,
|
||||
{}, // not needed for A Descriptor
|
||||
c_g_n_k_wos_lengths,
|
||||
{}, // not needed for A Descriptor
|
||||
// conv_filter_strides,
|
||||
independent_filter_strides,
|
||||
conv_filter_dilations,
|
||||
input_left_pads_with_offset,
|
||||
input_right_pads);
|
||||
|
||||
const auto in_gemmm_gemmk_desc =
|
||||
matrix_padder.PadADescriptor_M_K(in_gemmmraw_gemmkraw_desc);
|
||||
return in_gemmm_gemmk_desc;
|
||||
}
|
||||
|
||||
using InputGridDesc =
|
||||
remove_cvref_t<decltype(MakeInputDescriptor_M_K(1, 1, {}, {}, {}, {}, {}, {}))>;
|
||||
using OutputGridDesc = remove_cvref_t<decltype(MakeOutDescriptor_M_K(
|
||||
1, 1, {}, {}, {}, {}, {}, {}, {}, {}, {}, {}))>;
|
||||
|
||||
using Block2ETileMap = remove_cvref_t<
|
||||
decltype(BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, KPerBlock, InputGridDesc>(
|
||||
InputGridDesc{}))>;
|
||||
|
||||
using GridwiseTensorRearrangeKernel = GridwiseTensorRearrange<InputGridDesc,
|
||||
InputDataType,
|
||||
OutputGridDesc,
|
||||
OutputDataType,
|
||||
BlockSize,
|
||||
MPerBlock,
|
||||
KPerBlock,
|
||||
ThreadClusterLengths,
|
||||
ScalarPerVector,
|
||||
InMemoryDataOperationEnum::Add,
|
||||
Block2ETileMap>;
|
||||
|
||||
struct Argument : public BaseArgument
|
||||
{
|
||||
Argument(const void* p_in, // input image
|
||||
void* p_out, // output image
|
||||
const ck::index_t N,
|
||||
const ck::index_t C,
|
||||
const std::array<index_t, NDimSpatial>& input_spatial_lengths,
|
||||
const std::array<index_t, NDimSpatial>& filter_spatial_lengths,
|
||||
const std::array<index_t, NDimSpatial>& output_spatial_lengths,
|
||||
const std::array<index_t, NDimSpatial + 3>& image_g_n_c_wis_strides,
|
||||
const std::array<index_t, 2>& gemm_m_k_strides,
|
||||
const std::array<index_t, NDimSpatial>& conv_filter_strides,
|
||||
const std::array<index_t, NDimSpatial>& conv_filter_dilations,
|
||||
const std::array<index_t, NDimSpatial>& input_left_pads,
|
||||
const std::array<index_t, NDimSpatial>& input_right_pads)
|
||||
: C_(C),
|
||||
X_(filter_spatial_lengths[NDimSpatial - I1]),
|
||||
p_in_{static_cast<const InputDataType*>(p_in)},
|
||||
p_out_{static_cast<OutputDataType*>(p_out)},
|
||||
image_g_n_c_wis_strides_{image_g_n_c_wis_strides},
|
||||
conv_filter_strides_{conv_filter_strides},
|
||||
conv_filter_dilations_{conv_filter_dilations},
|
||||
input_left_pads_{input_left_pads},
|
||||
input_right_pads_{input_right_pads}
|
||||
{
|
||||
const index_t x_eff =
|
||||
(filter_spatial_lengths[XIdx] - 1) * conv_filter_dilations[XIdx] + 1;
|
||||
const index_t y_eff =
|
||||
NDimSpatial < 2
|
||||
? I1
|
||||
: (filter_spatial_lengths[YIdx] - 1) * conv_filter_dilations[YIdx] + 1;
|
||||
const index_t z_eff =
|
||||
NDimSpatial < 3
|
||||
? I1
|
||||
: (filter_spatial_lengths[ZIdx] - 1) * conv_filter_dilations[ZIdx] + 1;
|
||||
|
||||
// Iterate over sets of independent filters
|
||||
for(int z_img_offset = 0; z_img_offset < z_eff;
|
||||
z_img_offset += conv_filter_strides[ZIdx])
|
||||
{
|
||||
for(int y_img_offset = 0; y_img_offset < y_eff;
|
||||
y_img_offset += conv_filter_strides[YIdx])
|
||||
{
|
||||
for(int x_img_offset = 0; x_img_offset < x_eff;
|
||||
x_img_offset += conv_filter_strides[XIdx])
|
||||
{
|
||||
|
||||
std::array<index_t, NDimSpatial> image_offsets;
|
||||
std::array<index_t, NDimSpatial> effs;
|
||||
// Calculate the starting offset for a given set of
|
||||
// independent filters
|
||||
if constexpr(NDimSpatial == 1)
|
||||
{
|
||||
image_offsets = {x_img_offset};
|
||||
effs = {x_eff};
|
||||
}
|
||||
if constexpr(NDimSpatial == 2)
|
||||
{
|
||||
image_offsets = {y_img_offset, x_img_offset};
|
||||
effs = {y_eff, x_eff};
|
||||
}
|
||||
else if constexpr(NDimSpatial == 3)
|
||||
{
|
||||
image_offsets = {z_img_offset, y_img_offset, x_img_offset};
|
||||
effs = {z_eff, y_eff, x_eff};
|
||||
}
|
||||
|
||||
std::array<index_t, NDimSpatial> independent_filters;
|
||||
for(index_t i = 0; i < NDimSpatial; i++)
|
||||
{
|
||||
independent_filters[i] =
|
||||
GetNumberOfIndependentFilters(input_spatial_lengths[i],
|
||||
input_left_pads[i],
|
||||
input_right_pads[i],
|
||||
filter_spatial_lengths[i],
|
||||
conv_filter_strides[i],
|
||||
conv_filter_dilations[i],
|
||||
image_offsets[i]);
|
||||
}
|
||||
const index_t independent_filters_acum = ck::accumulate_n<index_t>(
|
||||
independent_filters.begin(), NDimSpatial, 1, std::multiplies<>());
|
||||
if(independent_filters_acum <= 0)
|
||||
continue;
|
||||
|
||||
const auto in_grid_desc_m_k =
|
||||
MakeInputDescriptor_M_K(N,
|
||||
C,
|
||||
filter_spatial_lengths,
|
||||
output_spatial_lengths,
|
||||
conv_filter_strides,
|
||||
gemm_m_k_strides,
|
||||
independent_filters,
|
||||
effs);
|
||||
const auto out_grid_desc_m_k =
|
||||
MakeOutDescriptor_M_K(N,
|
||||
C,
|
||||
input_spatial_lengths,
|
||||
filter_spatial_lengths,
|
||||
image_g_n_c_wis_strides,
|
||||
conv_filter_strides,
|
||||
conv_filter_dilations,
|
||||
input_left_pads,
|
||||
input_right_pads,
|
||||
image_offsets,
|
||||
independent_filters,
|
||||
effs);
|
||||
in_grid_desc_m_k_container_.push_back(in_grid_desc_m_k);
|
||||
out_grid_desc_m_k_container_.push_back(out_grid_desc_m_k);
|
||||
|
||||
const index_t x_idx = x_img_offset / conv_filter_strides[XIdx];
|
||||
const index_t y_idx = y_img_offset / conv_filter_strides[YIdx];
|
||||
const index_t z_idx = z_img_offset / conv_filter_strides[ZIdx];
|
||||
|
||||
const index_t x_offset_with_pad =
|
||||
math::max(0, x_img_offset - input_left_pads[XIdx]);
|
||||
const index_t y_offset_with_pad =
|
||||
math::max(0, y_img_offset - input_left_pads[YIdx]);
|
||||
const index_t z_offset_with_pad =
|
||||
math::max(0, z_img_offset - input_left_pads[ZIdx]);
|
||||
|
||||
// Memory offsets to next set of independent filters,
|
||||
// move to independent filters in each dimension
|
||||
const index_t in_offset =
|
||||
x_idx * gemm_m_k_strides[0] +
|
||||
y_idx * gemm_m_k_strides[0] * output_spatial_lengths[XIdx] +
|
||||
z_idx * gemm_m_k_strides[0] * output_spatial_lengths[YIdx] *
|
||||
output_spatial_lengths[XIdx];
|
||||
// Move to independent filters in appropriate dimensions
|
||||
const index_t out_offset =
|
||||
x_offset_with_pad * image_g_n_c_wis_strides[spatial_offset + XIdx] +
|
||||
y_offset_with_pad * image_g_n_c_wis_strides[spatial_offset + YIdx] +
|
||||
z_offset_with_pad * image_g_n_c_wis_strides[spatial_offset + ZIdx];
|
||||
|
||||
const InputDataType* p_in_with_offset =
|
||||
static_cast<const InputDataType*>(p_in) + in_offset;
|
||||
OutputDataType* p_out_with_offset =
|
||||
static_cast<OutputDataType*>(p_out) + out_offset;
|
||||
p_in_container_.push_back(p_in_with_offset);
|
||||
p_out_container_.push_back(p_out_with_offset);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void Print() const
|
||||
{
|
||||
for(std::size_t i = 0; i < in_grid_desc_m_k_container_.size(); i++)
|
||||
{
|
||||
std::cout << in_grid_desc_m_k_container_[i] << std::endl;
|
||||
std::cout << out_grid_desc_m_k_container_[i] << std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
const ck::index_t C_;
|
||||
const ck::index_t X_;
|
||||
|
||||
const InputDataType* p_in_;
|
||||
OutputDataType* p_out_;
|
||||
|
||||
const std::array<index_t, NDimSpatial + 3>& image_g_n_c_wis_strides_;
|
||||
const std::array<index_t, NDimSpatial>& conv_filter_strides_;
|
||||
const std::array<index_t, NDimSpatial>& conv_filter_dilations_;
|
||||
const std::array<index_t, NDimSpatial>& input_left_pads_;
|
||||
const std::array<index_t, NDimSpatial>& input_right_pads_;
|
||||
|
||||
std::vector<InputGridDesc> in_grid_desc_m_k_container_;
|
||||
std::vector<OutputGridDesc> out_grid_desc_m_k_container_;
|
||||
|
||||
std::vector<const InputDataType*> p_in_container_;
|
||||
std::vector<OutputDataType*> p_out_container_;
|
||||
};
|
||||
|
||||
struct Invoker : public BaseInvoker
|
||||
{
|
||||
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
|
||||
{
|
||||
if(stream_config.log_level_ > 0)
|
||||
{
|
||||
arg.Print();
|
||||
}
|
||||
|
||||
float elapsed_time = 0.f;
|
||||
const auto kernel = kernel_tensor_rearrange<InputGridDesc,
|
||||
InputDataType,
|
||||
OutputGridDesc,
|
||||
OutputDataType,
|
||||
Block2ETileMap,
|
||||
GridwiseTensorRearrangeKernel>;
|
||||
|
||||
// Execute each set of independent filters
|
||||
for(std::size_t i = 0; i < arg.in_grid_desc_m_k_container_.size(); i++)
|
||||
{
|
||||
const auto block_2_tile_map =
|
||||
BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, KPerBlock, InputGridDesc>(
|
||||
arg.out_grid_desc_m_k_container_[i]);
|
||||
const index_t grid_size =
|
||||
block_2_tile_map.CalculateGridSize(arg.in_grid_desc_m_k_container_[i]);
|
||||
elapsed_time += launch_and_time_kernel(stream_config,
|
||||
kernel,
|
||||
dim3(grid_size),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
arg.in_grid_desc_m_k_container_[i],
|
||||
arg.p_in_container_[i],
|
||||
arg.out_grid_desc_m_k_container_[i],
|
||||
arg.p_out_container_[i],
|
||||
block_2_tile_map);
|
||||
}
|
||||
return elapsed_time;
|
||||
}
|
||||
|
||||
float Run(const BaseArgument* p_arg,
|
||||
const StreamConfig& stream_config = StreamConfig{}) override
|
||||
{
|
||||
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
|
||||
}
|
||||
};
|
||||
|
||||
bool IsSupportedArgument(const Argument& arg)
|
||||
{
|
||||
using namespace tensor_layout::convolution;
|
||||
if constexpr(!(std::is_same_v<ImageLayout, GNWC> || std::is_same_v<ImageLayout, GNHWC> ||
|
||||
std::is_same_v<ImageLayout, GNDHWC>))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
const auto w_pad_left = arg.input_left_pads_[NDimSpatial - I1];
|
||||
const auto w_pad_right = arg.input_right_pads_[NDimSpatial - I1];
|
||||
const auto dilation_x = arg.conv_filter_dilations_[NDimSpatial - I1];
|
||||
const auto stride_x = arg.conv_filter_strides_[NDimSpatial - I1];
|
||||
bool is_w_packed = arg.image_g_n_c_wis_strides_[NDimSpatial + I2] == arg.C_;
|
||||
bool is_c_packed = arg.image_g_n_c_wis_strides_[I2] == 1;
|
||||
|
||||
// check vector acces with c not packed
|
||||
if(!is_c_packed && ScalarPerVector != 1)
|
||||
return false;
|
||||
// check vector access of filter window row (only C if C is not packed)
|
||||
if(!is_w_packed && arg.C_ % ScalarPerVector != 0)
|
||||
return false;
|
||||
// check vector access of filter window row (X * C)
|
||||
if(arg.X_ * arg.C_ % ScalarPerVector != 0)
|
||||
return false;
|
||||
// check vector access of pads (w_pad_left/w_pad_right * C)
|
||||
if(w_pad_left * arg.C_ % ScalarPerVector != 0 ||
|
||||
w_pad_right * arg.C_ % ScalarPerVector != 0)
|
||||
return false;
|
||||
// check vector access of with stride and pad
|
||||
if((w_pad_left != 0 || w_pad_right != 0) && stride_x > 1 && arg.C_ % ScalarPerVector != 0)
|
||||
return false;
|
||||
// check vector access of with dilation
|
||||
if(dilation_x > 1 && arg.C_ % ScalarPerVector != 0)
|
||||
return false;
|
||||
|
||||
bool valid = true;
|
||||
for(std::size_t i = 0; i < arg.in_grid_desc_m_k_container_.size(); i++)
|
||||
{
|
||||
valid &= GridwiseTensorRearrangeKernel::CheckValidity(
|
||||
arg.in_grid_desc_m_k_container_[i], arg.out_grid_desc_m_k_container_[i]);
|
||||
}
|
||||
return valid;
|
||||
}
|
||||
|
||||
bool IsSupportedArgument(const BaseArgument* p_arg) override
|
||||
{
|
||||
return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
|
||||
}
|
||||
|
||||
static auto MakeArgument(const void* p_in, // input image
|
||||
void* p_out, // output image
|
||||
const ck::index_t N,
|
||||
const ck::index_t C,
|
||||
const std::array<index_t, NDimSpatial>& input_spatial_lengths,
|
||||
const std::array<index_t, NDimSpatial>& filter_spatial_lengths,
|
||||
const std::array<index_t, NDimSpatial>& output_spatial_lengths,
|
||||
const std::array<index_t, NDimSpatial + 3>& image_g_n_c_wis_strides,
|
||||
const std::array<index_t, 2>& gemm_m_k_strides,
|
||||
const std::array<index_t, NDimSpatial>& conv_filter_strides,
|
||||
const std::array<index_t, NDimSpatial>& conv_filter_dilations,
|
||||
const std::array<index_t, NDimSpatial>& input_left_pads,
|
||||
const std::array<index_t, NDimSpatial>& input_right_pads)
|
||||
{
|
||||
return Argument{static_cast<const InputDataType*>(p_in),
|
||||
static_cast<OutputDataType*>(p_out),
|
||||
N,
|
||||
C,
|
||||
input_spatial_lengths,
|
||||
filter_spatial_lengths,
|
||||
output_spatial_lengths,
|
||||
image_g_n_c_wis_strides,
|
||||
gemm_m_k_strides,
|
||||
conv_filter_strides,
|
||||
conv_filter_dilations,
|
||||
input_left_pads,
|
||||
input_right_pads};
|
||||
}
|
||||
|
||||
static auto MakeInvoker() { return Invoker{}; }
|
||||
|
||||
std::unique_ptr<BaseArgument>
|
||||
MakeArgumentPointer(const void* p_in, // input image
|
||||
void* p_out, // output image
|
||||
const ck::index_t N,
|
||||
const ck::index_t C,
|
||||
const std::array<index_t, NDimSpatial>& input_spatial_lengths,
|
||||
const std::array<index_t, NDimSpatial>& filter_spatial_lengths,
|
||||
const std::array<index_t, NDimSpatial>& output_spatial_lengths,
|
||||
const std::array<index_t, NDimSpatial + 3>& image_g_n_c_wis_strides,
|
||||
const std::array<index_t, 2>& gemm_m_k_strides,
|
||||
const std::array<index_t, NDimSpatial>& conv_filter_strides,
|
||||
const std::array<index_t, NDimSpatial>& conv_filter_dilations,
|
||||
const std::array<index_t, NDimSpatial>& input_left_pads,
|
||||
const std::array<index_t, NDimSpatial>& input_right_pads) override
|
||||
{
|
||||
return std::make_unique<Argument>(static_cast<const InputDataType*>(p_in),
|
||||
static_cast<OutputDataType*>(p_out),
|
||||
N,
|
||||
C,
|
||||
input_spatial_lengths,
|
||||
filter_spatial_lengths,
|
||||
output_spatial_lengths,
|
||||
image_g_n_c_wis_strides,
|
||||
gemm_m_k_strides,
|
||||
conv_filter_strides,
|
||||
conv_filter_dilations,
|
||||
input_left_pads,
|
||||
input_right_pads);
|
||||
}
|
||||
|
||||
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
|
||||
{
|
||||
return std::make_unique<Invoker>(Invoker{});
|
||||
}
|
||||
|
||||
std::string GetTypeString() const override
|
||||
{
|
||||
auto str = std::stringstream();
|
||||
|
||||
// clang-format off
|
||||
str << "DeviceColumnToImage"
|
||||
<< "<"
|
||||
<< BlockSize << ", "
|
||||
<< MPerBlock << ", "
|
||||
<< KPerBlock << ", "
|
||||
<< ScalarPerVector
|
||||
<< ">";
|
||||
// clang-format on
|
||||
|
||||
return str.str();
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -5,64 +5,41 @@
|
||||
|
||||
#include "ck/tensor_description/tensor_descriptor.hpp"
|
||||
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/device_image_to_column.hpp"
|
||||
#include "ck/tensor_operation/gpu/grid/gridwise_image_to_column.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/device_conv_tensor_rearrange.hpp"
|
||||
#include "ck/tensor_operation/gpu/grid/gridwise_tensor_rearrange.hpp"
|
||||
#include "ck/host_utility/kernel_launch.hpp"
|
||||
#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp"
|
||||
#include "ck/tensor_operation/operator_transform/transform_conv_fwd_to_gemm.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/conv_tensor_rearrange_op.hpp"
|
||||
#include "ck/host_utility/io.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
|
||||
template <typename InputGridDesc,
|
||||
typename InputDataType,
|
||||
typename OutputGridDesc,
|
||||
typename OutputDataType,
|
||||
typename Block2ETileMap,
|
||||
typename GridwiseImageToColumnKernel>
|
||||
__global__ void
|
||||
#if CK_USE_LAUNCH_BOUNDS
|
||||
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
|
||||
#endif
|
||||
kernel_image_to_column(const InputGridDesc in_grid_desc,
|
||||
const InputDataType* __restrict__ p_in_global,
|
||||
const OutputGridDesc out_grid_desc,
|
||||
OutputDataType* __restrict__ p_out_global,
|
||||
const Block2ETileMap block_2_tile_map)
|
||||
{
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx908__) || \
|
||||
defined(__gfx90a__) || defined(__gfx940__) || defined(__gfx1030__) || defined(__gfx1100__) || \
|
||||
defined(__gfx1101__) || defined(__gfx1102__) || defined(__gfx941__) || defined(__gfx942__))
|
||||
GridwiseImageToColumnKernel::Run(
|
||||
in_grid_desc, p_in_global, out_grid_desc, p_out_global, block_2_tile_map);
|
||||
#else
|
||||
ignore = in_grid_desc;
|
||||
ignore = p_in_global;
|
||||
ignore = out_grid_desc;
|
||||
ignore = p_out_global;
|
||||
ignore = block_2_tile_map;
|
||||
#endif
|
||||
}
|
||||
|
||||
// Image to column for input layout NDHWC:
|
||||
// input : input image [N, Di, Hi, Wi, C],
|
||||
// output : output image [N * Do * Ho * Wo, Z * Y * X * C]
|
||||
// input : input image [N, Di, Hi, Wi, C]
|
||||
// output : gemm form [N * Do * Ho * Wo, Z * Y * X * C]
|
||||
template <index_t NDimSpatial,
|
||||
typename InputLayout,
|
||||
typename ImageLayout,
|
||||
typename InputDataType,
|
||||
typename OutputDataType,
|
||||
index_t BlockSize,
|
||||
index_t MPerBlock,
|
||||
index_t KPerBlock,
|
||||
typename ThreadClusterLengths,
|
||||
index_t ScalarPerVector>
|
||||
index_t ScalarPerVector,
|
||||
typename std::enable_if<NDimSpatial >= 1 && NDimSpatial <= 3, bool>::type = false>
|
||||
struct DeviceImageToColumnImpl
|
||||
: public DeviceImageToColumn<NDimSpatial, InputLayout, InputDataType, OutputDataType>
|
||||
: public DeviceConvTensorRearrange<NDimSpatial,
|
||||
ImageLayout,
|
||||
InputDataType,
|
||||
OutputDataType,
|
||||
conv_tensor_rearrange_op::ImageToColumn>
|
||||
{
|
||||
|
||||
static constexpr auto I0 = Number<0>{};
|
||||
@@ -83,7 +60,7 @@ struct DeviceImageToColumnImpl
|
||||
const std::array<index_t, NDimSpatial>& input_spatial_lengths,
|
||||
const std::array<index_t, NDimSpatial>& filter_spatial_lengths,
|
||||
const std::array<index_t, NDimSpatial>& output_spatial_lengths,
|
||||
const std::array<index_t, NDimSpatial + 3>& input_g_n_c_wis_strides,
|
||||
const std::array<index_t, NDimSpatial + 3>& image_g_n_c_wis_strides,
|
||||
const std::array<index_t, NDimSpatial>& conv_filter_strides,
|
||||
const std::array<index_t, NDimSpatial>& conv_filter_dilations,
|
||||
const std::array<index_t, NDimSpatial>& input_left_pads,
|
||||
@@ -110,9 +87,9 @@ struct DeviceImageToColumnImpl
|
||||
c_g_n_k_wos_lengths[I1] = N;
|
||||
|
||||
const auto in_gemmmraw_gemmkraw_desc =
|
||||
conv_to_gemm_transformer.template MakeADescriptor_M_K<InputLayout>(
|
||||
conv_to_gemm_transformer.template MakeADescriptor_M_K<ImageLayout>(
|
||||
a_g_n_c_wis_lengths,
|
||||
input_g_n_c_wis_strides,
|
||||
image_g_n_c_wis_strides,
|
||||
b_g_k_c_xs_lengths,
|
||||
{}, // not needed for A Descriptor
|
||||
c_g_n_k_wos_lengths,
|
||||
@@ -132,7 +109,7 @@ struct DeviceImageToColumnImpl
|
||||
const ck::index_t C,
|
||||
const std::array<index_t, NDimSpatial>& filter_spatial_lengths,
|
||||
const std::array<index_t, NDimSpatial>& output_spatial_lengths,
|
||||
const std::array<index_t, 2>& output_m_k_strides)
|
||||
const std::array<index_t, 2>& gemm_m_k_strides)
|
||||
{
|
||||
const index_t NDoHoWo =
|
||||
N * ck::accumulate_n<index_t>(
|
||||
@@ -141,7 +118,7 @@ struct DeviceImageToColumnImpl
|
||||
C * ck::accumulate_n<index_t>(
|
||||
filter_spatial_lengths.begin(), NDimSpatial, 1, std::multiplies<>());
|
||||
const auto desc_mraw_kraw = make_naive_tensor_descriptor(
|
||||
make_tuple(NDoHoWo, CZYX), make_tuple(output_m_k_strides[I0], output_m_k_strides[I1]));
|
||||
make_tuple(NDoHoWo, CZYX), make_tuple(gemm_m_k_strides[I0], gemm_m_k_strides[I1]));
|
||||
|
||||
const auto desc_m_k = matrix_padder.PadADescriptor_M_K(desc_mraw_kraw);
|
||||
return desc_m_k;
|
||||
@@ -155,28 +132,29 @@ struct DeviceImageToColumnImpl
|
||||
decltype(BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, KPerBlock, OutputGridDesc>(
|
||||
OutputGridDesc{}))>;
|
||||
|
||||
using GridwiseImageToColumnKernel = GridwiseImageToColumn<InputGridDesc,
|
||||
InputDataType,
|
||||
OutputGridDesc,
|
||||
OutputDataType,
|
||||
BlockSize,
|
||||
MPerBlock,
|
||||
KPerBlock,
|
||||
ThreadClusterLengths,
|
||||
ScalarPerVector,
|
||||
Block2ETileMap>;
|
||||
using GridwiseTensorRearrangeKernel = GridwiseTensorRearrange<InputGridDesc,
|
||||
InputDataType,
|
||||
OutputGridDesc,
|
||||
OutputDataType,
|
||||
BlockSize,
|
||||
MPerBlock,
|
||||
KPerBlock,
|
||||
ThreadClusterLengths,
|
||||
ScalarPerVector,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
Block2ETileMap>;
|
||||
|
||||
struct Argument : public BaseArgument
|
||||
{
|
||||
Argument(const void* p_in, // input image
|
||||
void* p_out, // output image
|
||||
void* p_out, // gemm form
|
||||
const ck::index_t N,
|
||||
const ck::index_t C,
|
||||
const std::array<index_t, NDimSpatial>& input_spatial_lengths,
|
||||
const std::array<index_t, NDimSpatial>& filter_spatial_lengths,
|
||||
const std::array<index_t, NDimSpatial>& output_spatial_lengths,
|
||||
const std::array<index_t, NDimSpatial + 3>& input_g_n_c_wis_strides,
|
||||
const std::array<index_t, 2>& output_m_k_strides,
|
||||
const std::array<index_t, NDimSpatial + 3>& image_g_n_c_wis_strides,
|
||||
const std::array<index_t, 2>& gemm_m_k_strides,
|
||||
const std::array<index_t, NDimSpatial>& conv_filter_strides,
|
||||
const std::array<index_t, NDimSpatial>& conv_filter_dilations,
|
||||
const std::array<index_t, NDimSpatial>& input_left_pads,
|
||||
@@ -185,7 +163,7 @@ struct DeviceImageToColumnImpl
|
||||
X_(filter_spatial_lengths[NDimSpatial - I1]),
|
||||
p_in_{static_cast<const InputDataType*>(p_in)},
|
||||
p_out_{static_cast<OutputDataType*>(p_out)},
|
||||
input_g_n_c_wis_strides_{input_g_n_c_wis_strides},
|
||||
image_g_n_c_wis_strides_{image_g_n_c_wis_strides},
|
||||
conv_filter_strides_{conv_filter_strides},
|
||||
conv_filter_dilations_{conv_filter_dilations},
|
||||
input_left_pads_{input_left_pads},
|
||||
@@ -197,7 +175,7 @@ struct DeviceImageToColumnImpl
|
||||
input_spatial_lengths,
|
||||
filter_spatial_lengths,
|
||||
output_spatial_lengths,
|
||||
input_g_n_c_wis_strides,
|
||||
image_g_n_c_wis_strides,
|
||||
|
||||
conv_filter_strides,
|
||||
conv_filter_dilations,
|
||||
@@ -205,7 +183,7 @@ struct DeviceImageToColumnImpl
|
||||
input_right_pads);
|
||||
|
||||
out_grid_desc_m_k_ = MakeOutDescriptor_M_K(
|
||||
N, C, filter_spatial_lengths, output_spatial_lengths, output_m_k_strides);
|
||||
N, C, filter_spatial_lengths, output_spatial_lengths, gemm_m_k_strides);
|
||||
}
|
||||
|
||||
void Print() const
|
||||
@@ -220,7 +198,7 @@ struct DeviceImageToColumnImpl
|
||||
const InputDataType* p_in_;
|
||||
OutputDataType* p_out_;
|
||||
|
||||
const std::array<index_t, NDimSpatial + 3>& input_g_n_c_wis_strides_;
|
||||
const std::array<index_t, NDimSpatial + 3>& image_g_n_c_wis_strides_;
|
||||
const std::array<index_t, NDimSpatial>& conv_filter_strides_;
|
||||
const std::array<index_t, NDimSpatial>& conv_filter_dilations_;
|
||||
const std::array<index_t, NDimSpatial>& input_left_pads_;
|
||||
@@ -243,12 +221,12 @@ struct DeviceImageToColumnImpl
|
||||
BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, KPerBlock, OutputGridDesc>(
|
||||
arg.out_grid_desc_m_k_);
|
||||
const index_t grid_size = block_2_tile_map.CalculateGridSize(arg.out_grid_desc_m_k_);
|
||||
const auto kernel = kernel_image_to_column<InputGridDesc,
|
||||
InputDataType,
|
||||
OutputGridDesc,
|
||||
OutputDataType,
|
||||
Block2ETileMap,
|
||||
GridwiseImageToColumnKernel>;
|
||||
const auto kernel = kernel_tensor_rearrange<InputGridDesc,
|
||||
InputDataType,
|
||||
OutputGridDesc,
|
||||
OutputDataType,
|
||||
Block2ETileMap,
|
||||
GridwiseTensorRearrangeKernel>;
|
||||
|
||||
float elapsed_time = launch_and_time_kernel(stream_config,
|
||||
kernel,
|
||||
@@ -273,12 +251,8 @@ struct DeviceImageToColumnImpl
|
||||
bool IsSupportedArgument(const Argument& arg)
|
||||
{
|
||||
using namespace tensor_layout::convolution;
|
||||
if(!(std::is_same_v<InputLayout, GNWC> || std::is_same_v<InputLayout, GNHWC> ||
|
||||
std::is_same_v<InputLayout, GNDHWC>))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
if(!(NDimSpatial >= 1 && NDimSpatial <= 3))
|
||||
if constexpr(!(std::is_same_v<ImageLayout, GNWC> || std::is_same_v<ImageLayout, GNHWC> ||
|
||||
std::is_same_v<ImageLayout, GNDHWC>))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
@@ -287,8 +261,8 @@ struct DeviceImageToColumnImpl
|
||||
const auto w_pad_right = arg.input_right_pads_[NDimSpatial - I1];
|
||||
const auto dilation_x = arg.conv_filter_dilations_[NDimSpatial - I1];
|
||||
const auto stride_x = arg.conv_filter_strides_[NDimSpatial - I1];
|
||||
bool is_w_packed = arg.input_g_n_c_wis_strides_[NDimSpatial + I2] == arg.C_;
|
||||
bool is_c_packed = arg.input_g_n_c_wis_strides_[I2] == 1;
|
||||
bool is_w_packed = arg.image_g_n_c_wis_strides_[NDimSpatial + I2] == arg.C_;
|
||||
bool is_c_packed = arg.image_g_n_c_wis_strides_[I2] == 1;
|
||||
|
||||
// check vector acces with c not packed
|
||||
if(!is_c_packed && ScalarPerVector != 1)
|
||||
@@ -310,8 +284,8 @@ struct DeviceImageToColumnImpl
|
||||
if(dilation_x > 1 && arg.C_ % ScalarPerVector != 0)
|
||||
return false;
|
||||
|
||||
return GridwiseImageToColumnKernel::CheckValidity(arg.in_grid_desc_m_k_,
|
||||
arg.out_grid_desc_m_k_);
|
||||
return GridwiseTensorRearrangeKernel::CheckValidity(arg.in_grid_desc_m_k_,
|
||||
arg.out_grid_desc_m_k_);
|
||||
}
|
||||
|
||||
bool IsSupportedArgument(const BaseArgument* p_arg) override
|
||||
@@ -320,14 +294,14 @@ struct DeviceImageToColumnImpl
|
||||
}
|
||||
|
||||
static auto MakeArgument(const void* p_in, // input image
|
||||
void* p_out, // output image
|
||||
void* p_out, // gemm form
|
||||
const ck::index_t N,
|
||||
const ck::index_t C,
|
||||
const std::array<index_t, NDimSpatial>& input_spatial_lengths,
|
||||
const std::array<index_t, NDimSpatial>& filter_spatial_lengths,
|
||||
const std::array<index_t, NDimSpatial>& output_spatial_lengths,
|
||||
const std::array<index_t, NDimSpatial + 3>& input_g_n_c_wis_strides,
|
||||
const std::array<index_t, 2>& output_m_k_strides,
|
||||
const std::array<index_t, NDimSpatial + 3>& image_g_n_c_wis_strides,
|
||||
const std::array<index_t, 2>& gemm_m_k_strides,
|
||||
const std::array<index_t, NDimSpatial>& conv_filter_strides,
|
||||
const std::array<index_t, NDimSpatial>& conv_filter_dilations,
|
||||
const std::array<index_t, NDimSpatial>& input_left_pads,
|
||||
@@ -340,8 +314,8 @@ struct DeviceImageToColumnImpl
|
||||
input_spatial_lengths,
|
||||
filter_spatial_lengths,
|
||||
output_spatial_lengths,
|
||||
input_g_n_c_wis_strides,
|
||||
output_m_k_strides,
|
||||
image_g_n_c_wis_strides,
|
||||
gemm_m_k_strides,
|
||||
conv_filter_strides,
|
||||
conv_filter_dilations,
|
||||
input_left_pads,
|
||||
@@ -352,14 +326,14 @@ struct DeviceImageToColumnImpl
|
||||
|
||||
std::unique_ptr<BaseArgument>
|
||||
MakeArgumentPointer(const void* p_in, // input image
|
||||
void* p_out, // output image
|
||||
void* p_out, // gemm form
|
||||
const ck::index_t N,
|
||||
const ck::index_t C,
|
||||
const std::array<index_t, NDimSpatial>& input_spatial_lengths,
|
||||
const std::array<index_t, NDimSpatial>& filter_spatial_lengths,
|
||||
const std::array<index_t, NDimSpatial>& output_spatial_lengths,
|
||||
const std::array<index_t, NDimSpatial + 3>& input_g_n_c_wis_strides,
|
||||
const std::array<index_t, 2>& output_m_k_strides,
|
||||
const std::array<index_t, NDimSpatial + 3>& image_g_n_c_wis_strides,
|
||||
const std::array<index_t, 2>& gemm_m_k_strides,
|
||||
const std::array<index_t, NDimSpatial>& conv_filter_strides,
|
||||
const std::array<index_t, NDimSpatial>& conv_filter_dilations,
|
||||
const std::array<index_t, NDimSpatial>& input_left_pads,
|
||||
@@ -372,8 +346,8 @@ struct DeviceImageToColumnImpl
|
||||
input_spatial_lengths,
|
||||
filter_spatial_lengths,
|
||||
output_spatial_lengths,
|
||||
input_g_n_c_wis_strides,
|
||||
output_m_k_strides,
|
||||
image_g_n_c_wis_strides,
|
||||
gemm_m_k_strides,
|
||||
conv_filter_strides,
|
||||
conv_filter_dilations,
|
||||
input_left_pads,
|
||||
|
||||
Reference in New Issue
Block a user