mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-03 21:21:22 +00:00
Add support for groups in Img2Col/Col2Img (#1007)
* Add support for groups in Img2Col/Col2Img
* Fix interface test
* Fix interface test G to N
* Improve performance
* Change gemm layout to 3d
* Fixes
This commit is contained in:
@@ -17,15 +17,18 @@
|
||||
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/conv_tensor_rearrange_op.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp"
|
||||
#include "ck/host_utility/io.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
|
||||
// Image to column for input layout NDHWC:
|
||||
// input : image converted to the gemm problem [N * Do * Ho * Wo, Z * Y * X * C]
|
||||
// output : image [N, Di, Hi, Wi, C]
|
||||
// Column to Image:
|
||||
// input : gemm form [G, N * Do * Ho * Wo, Z * Y * X * C]
|
||||
// output : input image [G, N, Di, Hi, Wi, C]
|
||||
// input : gemm form [N * Do * Ho * Wo, G, Z * Y * X * C]
|
||||
// output : input image [N, Di, Hi, Wi, G, C]
|
||||
template <index_t NDimSpatial,
|
||||
typename ImageLayout,
|
||||
typename InputDataType,
|
||||
@@ -43,6 +46,14 @@ struct DeviceColumnToImageImpl
|
||||
OutputDataType,
|
||||
conv_tensor_rearrange_op::ColumnToImage>
|
||||
{
|
||||
static constexpr bool is_NSpatialGC =
|
||||
std::is_same_v<ImageLayout, tensor_layout::convolution::NWGC> ||
|
||||
std::is_same_v<ImageLayout, tensor_layout::convolution::NHWGC> ||
|
||||
std::is_same_v<ImageLayout, tensor_layout::convolution::NDHWGC>;
|
||||
static constexpr bool is_GNSpatialC =
|
||||
std::is_same_v<ImageLayout, tensor_layout::convolution::GNWC> ||
|
||||
std::is_same_v<ImageLayout, tensor_layout::convolution::GNHWC> ||
|
||||
std::is_same_v<ImageLayout, tensor_layout::convolution::GNDHWC>;
|
||||
|
||||
static constexpr auto I0 = Number<0>{};
|
||||
static constexpr auto I1 = Number<1>{};
|
||||
@@ -90,7 +101,7 @@ struct DeviceColumnToImageImpl
|
||||
const std::array<index_t, NDimSpatial>& filter_spatial_lengths,
|
||||
const std::array<index_t, NDimSpatial>& output_spatial_lengths,
|
||||
const std::array<index_t, NDimSpatial>& conv_filter_strides,
|
||||
const std::array<index_t, 2>& gemm_m_k_strides,
|
||||
const std::array<index_t, 3>& gemm_g_m_k_strides,
|
||||
const std::array<index_t, NDimSpatial>& independent_filters,
|
||||
const std::array<index_t, NDimSpatial>& effs)
|
||||
{
|
||||
@@ -100,23 +111,23 @@ struct DeviceColumnToImageImpl
|
||||
C * ck::accumulate_n<index_t>(
|
||||
filter_spatial_lengths.begin(), NDimSpatial, 1, std::multiplies<>());
|
||||
|
||||
const index_t NStride = DoHoWo * gemm_m_k_strides[I0] * gemm_m_k_strides[I1];
|
||||
const index_t NStride = DoHoWo * gemm_g_m_k_strides[I1] * gemm_g_m_k_strides[I2];
|
||||
// Calculate the appropriate stride for each set of independent filters
|
||||
// in each dimension
|
||||
const index_t WStride =
|
||||
math::integer_divide_ceil(effs[XIdx], conv_filter_strides[XIdx]) * gemm_m_k_strides[I0];
|
||||
const index_t WStride = math::integer_divide_ceil(effs[XIdx], conv_filter_strides[XIdx]) *
|
||||
gemm_g_m_k_strides[I1];
|
||||
const index_t HStride = math::integer_divide_ceil(effs[YIdx], conv_filter_strides[YIdx]) *
|
||||
output_spatial_lengths[XIdx] * gemm_m_k_strides[I0];
|
||||
output_spatial_lengths[XIdx] * gemm_g_m_k_strides[I1];
|
||||
const index_t DStride = math::integer_divide_ceil(effs[ZIdx], conv_filter_strides[ZIdx]) *
|
||||
output_spatial_lengths[YIdx] * output_spatial_lengths[XIdx] *
|
||||
gemm_m_k_strides[I0];
|
||||
gemm_g_m_k_strides[I1];
|
||||
// Create descriptor for independent filters in each dimension and
|
||||
// then merge them into column form
|
||||
if constexpr(NDimSpatial == 1)
|
||||
{
|
||||
const auto desc_gemm_form =
|
||||
make_naive_tensor_descriptor(make_tuple(N, independent_filters[XIdx], CZYX),
|
||||
make_tuple(NStride, WStride, gemm_m_k_strides[I1]));
|
||||
make_tuple(NStride, WStride, gemm_g_m_k_strides[I2]));
|
||||
const auto desc_gemm_form_merged_filters = transform_tensor_descriptor(
|
||||
desc_gemm_form,
|
||||
make_tuple(make_merge_transform(make_tuple(N, independent_filters[XIdx])),
|
||||
@@ -130,7 +141,7 @@ struct DeviceColumnToImageImpl
|
||||
{
|
||||
const auto desc_gemm_form = make_naive_tensor_descriptor(
|
||||
make_tuple(N, independent_filters[YIdx], independent_filters[XIdx], CZYX),
|
||||
make_tuple(NStride, HStride, WStride, gemm_m_k_strides[I1]));
|
||||
make_tuple(NStride, HStride, WStride, gemm_g_m_k_strides[I2]));
|
||||
const auto desc_gemm_form_merged_filters = transform_tensor_descriptor(
|
||||
desc_gemm_form,
|
||||
make_tuple(make_merge_transform(
|
||||
@@ -149,7 +160,7 @@ struct DeviceColumnToImageImpl
|
||||
independent_filters[YIdx],
|
||||
independent_filters[XIdx],
|
||||
CZYX),
|
||||
make_tuple(NStride, DStride, HStride, WStride, gemm_m_k_strides[I1]));
|
||||
make_tuple(NStride, DStride, HStride, WStride, gemm_g_m_k_strides[I2]));
|
||||
const auto desc_gemm_form_merged_filters = transform_tensor_descriptor(
|
||||
desc_gemm_form,
|
||||
make_tuple(make_merge_transform(make_tuple(N,
|
||||
@@ -252,34 +263,38 @@ struct DeviceColumnToImageImpl
|
||||
decltype(BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, KPerBlock, InputGridDesc>(
|
||||
InputGridDesc{}))>;
|
||||
|
||||
using GridwiseTensorRearrangeKernel = GridwiseTensorRearrange<InputGridDesc,
|
||||
InputDataType,
|
||||
OutputGridDesc,
|
||||
OutputDataType,
|
||||
BlockSize,
|
||||
MPerBlock,
|
||||
KPerBlock,
|
||||
ThreadClusterLengths,
|
||||
ScalarPerVector,
|
||||
InMemoryDataOperationEnum::Add,
|
||||
Block2ETileMap>;
|
||||
using GridwiseTensorRearrangeKernel =
|
||||
GridwiseTensorRearrange<InputGridDesc,
|
||||
InputDataType,
|
||||
OutputGridDesc,
|
||||
OutputDataType,
|
||||
BlockSize,
|
||||
MPerBlock,
|
||||
KPerBlock,
|
||||
ThreadClusterLengths,
|
||||
ScalarPerVector,
|
||||
InMemoryDataOperationEnum::Add,
|
||||
Block2ETileMap,
|
||||
ComputePtrOffsetOfStridedBatch<I0>>;
|
||||
|
||||
struct Argument : public BaseArgument
|
||||
{
|
||||
Argument(const void* p_in, // input image
|
||||
void* p_out, // output image
|
||||
const ck::index_t G,
|
||||
const ck::index_t N,
|
||||
const ck::index_t C,
|
||||
const std::array<index_t, NDimSpatial>& input_spatial_lengths,
|
||||
const std::array<index_t, NDimSpatial>& filter_spatial_lengths,
|
||||
const std::array<index_t, NDimSpatial>& output_spatial_lengths,
|
||||
const std::array<index_t, NDimSpatial + 3>& image_g_n_c_wis_strides,
|
||||
const std::array<index_t, 2>& gemm_m_k_strides,
|
||||
const std::array<index_t, 3>& gemm_g_m_k_strides,
|
||||
const std::array<index_t, NDimSpatial>& conv_filter_strides,
|
||||
const std::array<index_t, NDimSpatial>& conv_filter_dilations,
|
||||
const std::array<index_t, NDimSpatial>& input_left_pads,
|
||||
const std::array<index_t, NDimSpatial>& input_right_pads)
|
||||
: C_(C),
|
||||
: G_(G),
|
||||
C_(C),
|
||||
X_(filter_spatial_lengths[NDimSpatial - I1]),
|
||||
p_in_{static_cast<const InputDataType*>(p_in)},
|
||||
p_out_{static_cast<OutputDataType*>(p_out)},
|
||||
@@ -289,6 +304,9 @@ struct DeviceColumnToImageImpl
|
||||
input_left_pads_{input_left_pads},
|
||||
input_right_pads_{input_right_pads}
|
||||
{
|
||||
compute_ptr_offset_of_batch_.BatchStrideA_ = gemm_g_m_k_strides[I0];
|
||||
compute_ptr_offset_of_batch_.BatchStrideC_ = image_g_n_c_wis_strides[I0];
|
||||
|
||||
const index_t x_eff =
|
||||
(filter_spatial_lengths[XIdx] - 1) * conv_filter_dilations[XIdx] + 1;
|
||||
const index_t y_eff =
|
||||
@@ -354,7 +372,7 @@ struct DeviceColumnToImageImpl
|
||||
filter_spatial_lengths,
|
||||
output_spatial_lengths,
|
||||
conv_filter_strides,
|
||||
gemm_m_k_strides,
|
||||
gemm_g_m_k_strides,
|
||||
independent_filters,
|
||||
effs);
|
||||
const auto out_grid_desc_m_k =
|
||||
@@ -387,10 +405,9 @@ struct DeviceColumnToImageImpl
|
||||
// Memory offsets to next set of independent filters,
|
||||
// move to independent filters in each dimension
|
||||
const index_t in_offset =
|
||||
x_idx * gemm_m_k_strides[0] +
|
||||
y_idx * gemm_m_k_strides[0] * output_spatial_lengths[XIdx] +
|
||||
z_idx * gemm_m_k_strides[0] * output_spatial_lengths[YIdx] *
|
||||
output_spatial_lengths[XIdx];
|
||||
(x_idx + y_idx * output_spatial_lengths[XIdx] +
|
||||
z_idx * output_spatial_lengths[YIdx] * output_spatial_lengths[XIdx]) *
|
||||
gemm_g_m_k_strides[I1];
|
||||
// Move to independent filters in appropriate dimensions
|
||||
const index_t out_offset =
|
||||
x_offset_with_pad * image_g_n_c_wis_strides[spatial_offset + XIdx] +
|
||||
@@ -417,6 +434,7 @@ struct DeviceColumnToImageImpl
|
||||
}
|
||||
}
|
||||
|
||||
const ck::index_t G_;
|
||||
const ck::index_t C_;
|
||||
const ck::index_t X_;
|
||||
|
||||
@@ -434,6 +452,8 @@ struct DeviceColumnToImageImpl
|
||||
|
||||
std::vector<const InputDataType*> p_in_container_;
|
||||
std::vector<OutputDataType*> p_out_container_;
|
||||
|
||||
ComputePtrOffsetOfStridedBatch<I0> compute_ptr_offset_of_batch_;
|
||||
};
|
||||
|
||||
struct Invoker : public BaseInvoker
|
||||
@@ -451,6 +471,7 @@ struct DeviceColumnToImageImpl
|
||||
OutputGridDesc,
|
||||
OutputDataType,
|
||||
Block2ETileMap,
|
||||
ComputePtrOffsetOfStridedBatch<I0>,
|
||||
GridwiseTensorRearrangeKernel>;
|
||||
|
||||
// Execute each set of independent filters
|
||||
@@ -460,7 +481,7 @@ struct DeviceColumnToImageImpl
|
||||
BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, KPerBlock, InputGridDesc>(
|
||||
arg.out_grid_desc_m_k_container_[i]);
|
||||
const index_t grid_size =
|
||||
block_2_tile_map.CalculateGridSize(arg.in_grid_desc_m_k_container_[i]);
|
||||
block_2_tile_map.CalculateGridSize(arg.in_grid_desc_m_k_container_[i]) * arg.G_;
|
||||
elapsed_time += launch_and_time_kernel(stream_config,
|
||||
kernel,
|
||||
dim3(grid_size),
|
||||
@@ -470,7 +491,9 @@ struct DeviceColumnToImageImpl
|
||||
arg.p_in_container_[i],
|
||||
arg.out_grid_desc_m_k_container_[i],
|
||||
arg.p_out_container_[i],
|
||||
block_2_tile_map);
|
||||
arg.G_,
|
||||
block_2_tile_map,
|
||||
arg.compute_ptr_offset_of_batch_);
|
||||
}
|
||||
return elapsed_time;
|
||||
}
|
||||
@@ -485,8 +508,7 @@ struct DeviceColumnToImageImpl
|
||||
bool IsSupportedArgument(const Argument& arg)
|
||||
{
|
||||
using namespace tensor_layout::convolution;
|
||||
if constexpr(!(std::is_same_v<ImageLayout, GNWC> || std::is_same_v<ImageLayout, GNHWC> ||
|
||||
std::is_same_v<ImageLayout, GNDHWC>))
|
||||
if constexpr(!(is_NSpatialGC || is_GNSpatialC))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
@@ -534,13 +556,14 @@ struct DeviceColumnToImageImpl
|
||||
|
||||
static auto MakeArgument(const void* p_in, // input image
|
||||
void* p_out, // output image
|
||||
const ck::index_t G,
|
||||
const ck::index_t N,
|
||||
const ck::index_t C,
|
||||
const std::array<index_t, NDimSpatial>& input_spatial_lengths,
|
||||
const std::array<index_t, NDimSpatial>& filter_spatial_lengths,
|
||||
const std::array<index_t, NDimSpatial>& output_spatial_lengths,
|
||||
const std::array<index_t, NDimSpatial + 3>& image_g_n_c_wis_strides,
|
||||
const std::array<index_t, 2>& gemm_m_k_strides,
|
||||
const std::array<index_t, 3>& gemm_g_m_k_strides,
|
||||
const std::array<index_t, NDimSpatial>& conv_filter_strides,
|
||||
const std::array<index_t, NDimSpatial>& conv_filter_dilations,
|
||||
const std::array<index_t, NDimSpatial>& input_left_pads,
|
||||
@@ -548,13 +571,14 @@ struct DeviceColumnToImageImpl
|
||||
{
|
||||
return Argument{static_cast<const InputDataType*>(p_in),
|
||||
static_cast<OutputDataType*>(p_out),
|
||||
G,
|
||||
N,
|
||||
C,
|
||||
input_spatial_lengths,
|
||||
filter_spatial_lengths,
|
||||
output_spatial_lengths,
|
||||
image_g_n_c_wis_strides,
|
||||
gemm_m_k_strides,
|
||||
gemm_g_m_k_strides,
|
||||
conv_filter_strides,
|
||||
conv_filter_dilations,
|
||||
input_left_pads,
|
||||
@@ -566,13 +590,14 @@ struct DeviceColumnToImageImpl
|
||||
std::unique_ptr<BaseArgument>
|
||||
MakeArgumentPointer(const void* p_in, // input image
|
||||
void* p_out, // output image
|
||||
const ck::index_t G,
|
||||
const ck::index_t N,
|
||||
const ck::index_t C,
|
||||
const std::array<index_t, NDimSpatial>& input_spatial_lengths,
|
||||
const std::array<index_t, NDimSpatial>& filter_spatial_lengths,
|
||||
const std::array<index_t, NDimSpatial>& output_spatial_lengths,
|
||||
const std::array<index_t, NDimSpatial + 3>& image_g_n_c_wis_strides,
|
||||
const std::array<index_t, 2>& gemm_m_k_strides,
|
||||
const std::array<index_t, 3>& gemm_g_m_k_strides,
|
||||
const std::array<index_t, NDimSpatial>& conv_filter_strides,
|
||||
const std::array<index_t, NDimSpatial>& conv_filter_dilations,
|
||||
const std::array<index_t, NDimSpatial>& input_left_pads,
|
||||
@@ -580,13 +605,14 @@ struct DeviceColumnToImageImpl
|
||||
{
|
||||
return std::make_unique<Argument>(static_cast<const InputDataType*>(p_in),
|
||||
static_cast<OutputDataType*>(p_out),
|
||||
G,
|
||||
N,
|
||||
C,
|
||||
input_spatial_lengths,
|
||||
filter_spatial_lengths,
|
||||
output_spatial_lengths,
|
||||
image_g_n_c_wis_strides,
|
||||
gemm_m_k_strides,
|
||||
gemm_g_m_k_strides,
|
||||
conv_filter_strides,
|
||||
conv_filter_dilations,
|
||||
input_left_pads,
|
||||
|
||||
@@ -15,15 +15,18 @@
|
||||
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/conv_tensor_rearrange_op.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp"
|
||||
#include "ck/host_utility/io.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
|
||||
// Image to column for input layout NDHWC:
|
||||
// input : input image [N, Di, Hi, Wi, C]
|
||||
// output : gemm form [N * Do * Ho * Wo, Z * Y * X * C]
|
||||
// Image to column:
|
||||
// input : input image [G, N, Di, Hi, Wi, C]
|
||||
// output : gemm form [G * N * Do * Ho * Wo, Z * Y * X * C]
|
||||
// input : input image [N, Di, Hi, Wi, G, C]
|
||||
// output : gemm form [N * Do * Ho * Wo * G, Z * Y * X * C]
|
||||
template <index_t NDimSpatial,
|
||||
typename ImageLayout,
|
||||
typename InputDataType,
|
||||
@@ -41,6 +44,14 @@ struct DeviceImageToColumnImpl
|
||||
OutputDataType,
|
||||
conv_tensor_rearrange_op::ImageToColumn>
|
||||
{
|
||||
static constexpr bool is_NSpatialGC =
|
||||
std::is_same_v<ImageLayout, tensor_layout::convolution::NWGC> ||
|
||||
std::is_same_v<ImageLayout, tensor_layout::convolution::NHWGC> ||
|
||||
std::is_same_v<ImageLayout, tensor_layout::convolution::NDHWGC>;
|
||||
static constexpr bool is_GNSpatialC =
|
||||
std::is_same_v<ImageLayout, tensor_layout::convolution::GNWC> ||
|
||||
std::is_same_v<ImageLayout, tensor_layout::convolution::GNHWC> ||
|
||||
std::is_same_v<ImageLayout, tensor_layout::convolution::GNDHWC>;
|
||||
|
||||
static constexpr auto I0 = Number<0>{};
|
||||
static constexpr auto I1 = Number<1>{};
|
||||
@@ -109,7 +120,7 @@ struct DeviceImageToColumnImpl
|
||||
const ck::index_t C,
|
||||
const std::array<index_t, NDimSpatial>& filter_spatial_lengths,
|
||||
const std::array<index_t, NDimSpatial>& output_spatial_lengths,
|
||||
const std::array<index_t, 2>& gemm_m_k_strides)
|
||||
const std::array<index_t, 3>& gemm_g_m_k_strides)
|
||||
{
|
||||
const index_t NDoHoWo =
|
||||
N * ck::accumulate_n<index_t>(
|
||||
@@ -117,11 +128,10 @@ struct DeviceImageToColumnImpl
|
||||
const index_t CZYX =
|
||||
C * ck::accumulate_n<index_t>(
|
||||
filter_spatial_lengths.begin(), NDimSpatial, 1, std::multiplies<>());
|
||||
const auto desc_mraw_kraw = make_naive_tensor_descriptor(
|
||||
make_tuple(NDoHoWo, CZYX), make_tuple(gemm_m_k_strides[I0], gemm_m_k_strides[I1]));
|
||||
|
||||
const auto desc_m_k = matrix_padder.PadADescriptor_M_K(desc_mraw_kraw);
|
||||
return desc_m_k;
|
||||
const auto desc_mraw_kraw = make_naive_tensor_descriptor(
|
||||
make_tuple(NDoHoWo, CZYX), make_tuple(gemm_g_m_k_strides[I1], gemm_g_m_k_strides[I2]));
|
||||
return matrix_padder.PadADescriptor_M_K(desc_mraw_kraw);
|
||||
}
|
||||
|
||||
using InputGridDesc =
|
||||
@@ -132,34 +142,38 @@ struct DeviceImageToColumnImpl
|
||||
decltype(BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, KPerBlock, OutputGridDesc>(
|
||||
OutputGridDesc{}))>;
|
||||
|
||||
using GridwiseTensorRearrangeKernel = GridwiseTensorRearrange<InputGridDesc,
|
||||
InputDataType,
|
||||
OutputGridDesc,
|
||||
OutputDataType,
|
||||
BlockSize,
|
||||
MPerBlock,
|
||||
KPerBlock,
|
||||
ThreadClusterLengths,
|
||||
ScalarPerVector,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
Block2ETileMap>;
|
||||
using GridwiseTensorRearrangeKernel =
|
||||
GridwiseTensorRearrange<InputGridDesc,
|
||||
InputDataType,
|
||||
OutputGridDesc,
|
||||
OutputDataType,
|
||||
BlockSize,
|
||||
MPerBlock,
|
||||
KPerBlock,
|
||||
ThreadClusterLengths,
|
||||
ScalarPerVector,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
Block2ETileMap,
|
||||
ComputePtrOffsetOfStridedBatch<I0>>;
|
||||
|
||||
struct Argument : public BaseArgument
|
||||
{
|
||||
Argument(const void* p_in, // input image
|
||||
void* p_out, // gemm form
|
||||
const ck::index_t G,
|
||||
const ck::index_t N,
|
||||
const ck::index_t C,
|
||||
const std::array<index_t, NDimSpatial>& input_spatial_lengths,
|
||||
const std::array<index_t, NDimSpatial>& filter_spatial_lengths,
|
||||
const std::array<index_t, NDimSpatial>& output_spatial_lengths,
|
||||
const std::array<index_t, NDimSpatial + 3>& image_g_n_c_wis_strides,
|
||||
const std::array<index_t, 2>& gemm_m_k_strides,
|
||||
const std::array<index_t, 3>& gemm_g_m_k_strides,
|
||||
const std::array<index_t, NDimSpatial>& conv_filter_strides,
|
||||
const std::array<index_t, NDimSpatial>& conv_filter_dilations,
|
||||
const std::array<index_t, NDimSpatial>& input_left_pads,
|
||||
const std::array<index_t, NDimSpatial>& input_right_pads)
|
||||
: C_(C),
|
||||
: G_(G),
|
||||
C_(C),
|
||||
X_(filter_spatial_lengths[NDimSpatial - I1]),
|
||||
p_in_{static_cast<const InputDataType*>(p_in)},
|
||||
p_out_{static_cast<OutputDataType*>(p_out)},
|
||||
@@ -176,14 +190,16 @@ struct DeviceImageToColumnImpl
|
||||
filter_spatial_lengths,
|
||||
output_spatial_lengths,
|
||||
image_g_n_c_wis_strides,
|
||||
|
||||
conv_filter_strides,
|
||||
conv_filter_dilations,
|
||||
input_left_pads,
|
||||
input_right_pads);
|
||||
|
||||
out_grid_desc_m_k_ = MakeOutDescriptor_M_K(
|
||||
N, C, filter_spatial_lengths, output_spatial_lengths, gemm_m_k_strides);
|
||||
N, C, filter_spatial_lengths, output_spatial_lengths, gemm_g_m_k_strides);
|
||||
|
||||
compute_ptr_offset_of_batch_.BatchStrideA_ = image_g_n_c_wis_strides[I0];
|
||||
compute_ptr_offset_of_batch_.BatchStrideC_ = gemm_g_m_k_strides[I0];
|
||||
}
|
||||
|
||||
void Print() const
|
||||
@@ -192,6 +208,7 @@ struct DeviceImageToColumnImpl
|
||||
std::cout << out_grid_desc_m_k_ << std::endl;
|
||||
}
|
||||
|
||||
const ck::index_t G_;
|
||||
const ck::index_t C_;
|
||||
const ck::index_t X_;
|
||||
|
||||
@@ -206,6 +223,8 @@ struct DeviceImageToColumnImpl
|
||||
|
||||
InputGridDesc in_grid_desc_m_k_;
|
||||
OutputGridDesc out_grid_desc_m_k_;
|
||||
|
||||
ComputePtrOffsetOfStridedBatch<I0> compute_ptr_offset_of_batch_;
|
||||
};
|
||||
|
||||
struct Invoker : public BaseInvoker
|
||||
@@ -220,12 +239,14 @@ struct DeviceImageToColumnImpl
|
||||
const auto block_2_tile_map =
|
||||
BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, KPerBlock, OutputGridDesc>(
|
||||
arg.out_grid_desc_m_k_);
|
||||
const index_t grid_size = block_2_tile_map.CalculateGridSize(arg.out_grid_desc_m_k_);
|
||||
const auto kernel = kernel_tensor_rearrange<InputGridDesc,
|
||||
const index_t grid_size =
|
||||
block_2_tile_map.CalculateGridSize(arg.out_grid_desc_m_k_) * arg.G_;
|
||||
const auto kernel = kernel_tensor_rearrange<InputGridDesc,
|
||||
InputDataType,
|
||||
OutputGridDesc,
|
||||
OutputDataType,
|
||||
Block2ETileMap,
|
||||
ComputePtrOffsetOfStridedBatch<I0>,
|
||||
GridwiseTensorRearrangeKernel>;
|
||||
|
||||
float elapsed_time = launch_and_time_kernel(stream_config,
|
||||
@@ -237,7 +258,9 @@ struct DeviceImageToColumnImpl
|
||||
arg.p_in_,
|
||||
arg.out_grid_desc_m_k_,
|
||||
arg.p_out_,
|
||||
block_2_tile_map);
|
||||
arg.G_,
|
||||
block_2_tile_map,
|
||||
arg.compute_ptr_offset_of_batch_);
|
||||
return elapsed_time;
|
||||
}
|
||||
|
||||
@@ -250,9 +273,7 @@ struct DeviceImageToColumnImpl
|
||||
|
||||
bool IsSupportedArgument(const Argument& arg)
|
||||
{
|
||||
using namespace tensor_layout::convolution;
|
||||
if constexpr(!(std::is_same_v<ImageLayout, GNWC> || std::is_same_v<ImageLayout, GNHWC> ||
|
||||
std::is_same_v<ImageLayout, GNDHWC>))
|
||||
if constexpr(!(is_NSpatialGC || is_GNSpatialC))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
@@ -295,13 +316,14 @@ struct DeviceImageToColumnImpl
|
||||
|
||||
static auto MakeArgument(const void* p_in, // input image
|
||||
void* p_out, // gemm form
|
||||
const ck::index_t G,
|
||||
const ck::index_t N,
|
||||
const ck::index_t C,
|
||||
const std::array<index_t, NDimSpatial>& input_spatial_lengths,
|
||||
const std::array<index_t, NDimSpatial>& filter_spatial_lengths,
|
||||
const std::array<index_t, NDimSpatial>& output_spatial_lengths,
|
||||
const std::array<index_t, NDimSpatial + 3>& image_g_n_c_wis_strides,
|
||||
const std::array<index_t, 2>& gemm_m_k_strides,
|
||||
const std::array<index_t, 3>& gemm_g_m_k_strides,
|
||||
const std::array<index_t, NDimSpatial>& conv_filter_strides,
|
||||
const std::array<index_t, NDimSpatial>& conv_filter_dilations,
|
||||
const std::array<index_t, NDimSpatial>& input_left_pads,
|
||||
@@ -309,13 +331,14 @@ struct DeviceImageToColumnImpl
|
||||
{
|
||||
return Argument{static_cast<const InputDataType*>(p_in),
|
||||
static_cast<OutputDataType*>(p_out),
|
||||
G,
|
||||
N,
|
||||
C,
|
||||
input_spatial_lengths,
|
||||
filter_spatial_lengths,
|
||||
output_spatial_lengths,
|
||||
image_g_n_c_wis_strides,
|
||||
gemm_m_k_strides,
|
||||
gemm_g_m_k_strides,
|
||||
conv_filter_strides,
|
||||
conv_filter_dilations,
|
||||
input_left_pads,
|
||||
@@ -327,13 +350,14 @@ struct DeviceImageToColumnImpl
|
||||
std::unique_ptr<BaseArgument>
|
||||
MakeArgumentPointer(const void* p_in, // input image
|
||||
void* p_out, // gemm form
|
||||
const ck::index_t G,
|
||||
const ck::index_t N,
|
||||
const ck::index_t C,
|
||||
const std::array<index_t, NDimSpatial>& input_spatial_lengths,
|
||||
const std::array<index_t, NDimSpatial>& filter_spatial_lengths,
|
||||
const std::array<index_t, NDimSpatial>& output_spatial_lengths,
|
||||
const std::array<index_t, NDimSpatial + 3>& image_g_n_c_wis_strides,
|
||||
const std::array<index_t, 2>& gemm_m_k_strides,
|
||||
const std::array<index_t, 3>& gemm_g_m_k_strides,
|
||||
const std::array<index_t, NDimSpatial>& conv_filter_strides,
|
||||
const std::array<index_t, NDimSpatial>& conv_filter_dilations,
|
||||
const std::array<index_t, NDimSpatial>& input_left_pads,
|
||||
@@ -341,13 +365,14 @@ struct DeviceImageToColumnImpl
|
||||
{
|
||||
return std::make_unique<Argument>(static_cast<const InputDataType*>(p_in),
|
||||
static_cast<OutputDataType*>(p_out),
|
||||
G,
|
||||
N,
|
||||
C,
|
||||
input_spatial_lengths,
|
||||
filter_spatial_lengths,
|
||||
output_spatial_lengths,
|
||||
image_g_n_c_wis_strides,
|
||||
gemm_m_k_strides,
|
||||
gemm_g_m_k_strides,
|
||||
conv_filter_strides,
|
||||
conv_filter_dilations,
|
||||
input_left_pads,
|
||||
|
||||
Reference in New Issue
Block a user