conv+conv (1x1 only) example using gemm+gemm (#393)

* refactor conv

* add conv+conv example, 1x1 only
This commit is contained in:
Chao Liu
2022-08-31 11:27:11 -05:00
committed by GitHub
parent d00e6115b9
commit 4df6d93f60
14 changed files with 1524 additions and 1055 deletions

View File

@@ -16,6 +16,7 @@
#include "ck/tensor_operation/gpu/grid/gridwise_batched_gemm_gemm_xdl_cshuffle_v1.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
#include "ck/host_utility/io.hpp"
namespace ck {
namespace tensor_operation {
@@ -464,6 +465,14 @@ struct DeviceBatchedGemmGemm_Xdl_CShuffle : public DeviceBatchedGemmGemm<ALayout
}
}
void Print() const
{
std::cout << "A[AK0, M, AK1]: " << a_grid_desc_ak0_m_ak1_ << std::endl;
std::cout << "B0[BK0, N, BK1]: " << b_grid_desc_bk0_n_bk1_ << std::endl;
std::cout << "B1[BK0, N, BK1]: " << b1_grid_desc_bk0_n_bk1_ << std::endl;
std::cout << "C[M, N]: " << c_grid_desc_m_n_ << std::endl;
}
// private:
const ADataType* p_a_grid_;
const BDataType* p_b_grid_;

View File

@@ -16,6 +16,7 @@
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
#include "ck/host_utility/io.hpp"
namespace ck {
namespace tensor_operation {

View File

@@ -292,8 +292,6 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout,
using BGridDesc_BK0_N_BK1 = remove_cvref_t<decltype(
GridwiseGemm::MakeDefaultBGridDescriptor_BK0_N_BK1(BGridDesc_N_K{}))>;
using Block2ETileMap = typename GridwiseGemm::DefaultBlock2ETileMap;
// Argument
struct Argument : public BaseArgument
{
@@ -391,7 +389,7 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout,
e_grid_desc_mblock_mperblock_nblock_nperblock_;
// block-to-e-tile map
Block2ETileMap block_2_etile_map_;
typename GridwiseGemm::DefaultBlock2ETileMap block_2_etile_map_;
// element-wise op
AElementwiseOperation a_element_op_;

View File

@@ -3,7 +3,7 @@
#pragma once
#include <vector>
#include <array>
#include "ck/tensor_operation/gpu/device/device_base.hpp"

View File

@@ -13,8 +13,9 @@
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_d.hpp"
#include "ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp"
#include "ck/tensor_operation/operator_transform/transform_conv_fwd_to_gemm.hpp"
#include "ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_d.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp"
@@ -296,922 +297,71 @@ struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle
static constexpr auto I2 = Number<2>{};
static constexpr auto I3 = Number<3>{};
static constexpr auto conv_to_gemm_transformer =
TransformConvFwdToGemm<NDimSpatial, ConvForwardSpecialization>{};
static constexpr auto matrix_padder =
MatrixPadder<GemmSpec, index_t, index_t, index_t>{MPerBlock, NPerBlock, KPerBlock};
template <typename ALay,
typename std::enable_if<NDimSpatial == 1 &&
is_same_v<ALay, tensor_layout::convolution::GNWC>,
bool>::type = false>
static auto
MakeAGridDescriptor_M_K(const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths,
const std::array<index_t, NDimSpatial + 3>& /* a_g_n_c_wis_strides */,
const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
const std::array<index_t, NDimSpatial + 3>& /* b_g_k_c_xs_strides */,
const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths,
const std::array<index_t, NDimSpatial + 3>& /* e_g_n_k_wos_strides */,
const std::array<index_t, NDimSpatial>& conv_filter_strides,
const std::array<index_t, NDimSpatial>& conv_filter_dilations,
const std::array<index_t, NDimSpatial>& input_left_pads,
const std::array<index_t, NDimSpatial>& input_right_pads)
{
const index_t N = a_g_n_c_wis_lengths[1];
const index_t C = a_g_n_c_wis_lengths[2];
const index_t Wi = a_g_n_c_wis_lengths[3];
const index_t Wo = e_g_n_k_wos_lengths[3];
const index_t ConvStrideW = conv_filter_strides[0];
if constexpr(ConvForwardSpecialization ==
ConvolutionForwardSpecialization::Filter1x1Stride1Pad0)
{
const index_t NWo = N * std::accumulate(e_g_n_k_wos_lengths.begin() + 3,
e_g_n_k_wos_lengths.begin() + 3 + NDimSpatial,
index_t{1},
std::multiplies<index_t>());
const auto in_gemmmraw_gemmk_grid_desc =
make_naive_tensor_descriptor_packed(make_tuple(NWo, C));
const auto in_gemmm_gemmk_grid_desc =
matrix_padder.PadADescriptor_M_K(in_gemmmraw_gemmk_grid_desc);
return in_gemmm_gemmk_grid_desc;
}
else if constexpr(ConvForwardSpecialization ==
ConvolutionForwardSpecialization::Filter1x1Pad0)
{
const auto in_n_wi_c_grid_desc =
make_naive_tensor_descriptor_packed(make_tuple(N, Wi, C));
const auto in_n_wo_c_grid_desc = transform_tensor_descriptor(
in_n_wi_c_grid_desc,
make_tuple(make_pass_through_transform(N),
make_embed_transform(make_tuple(Wo), make_tuple(ConvStrideW)),
make_pass_through_transform(C)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
const auto in_gemmmraw_gemmkraw_grid_desc = transform_tensor_descriptor(
in_n_wo_c_grid_desc,
make_tuple(make_merge_transform(make_tuple(N, Wo)), make_pass_through_transform(C)),
make_tuple(Sequence<0, 1>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
const auto in_gemmm_gemmk_grid_desc =
matrix_padder.PadADescriptor_M_K(in_gemmmraw_gemmkraw_grid_desc);
return in_gemmm_gemmk_grid_desc;
}
else
{
const index_t X = b_g_k_c_xs_lengths[3];
const index_t ConvDilationW = conv_filter_dilations[0];
const index_t InLeftPadW = input_left_pads[0];
const index_t InRightPadW = input_right_pads[0];
const auto in_n_wi_c_grid_desc =
make_naive_tensor_descriptor_packed(make_tuple(N, Wi, C));
const auto in_n_wip_c_grid_desc = transform_tensor_descriptor(
in_n_wi_c_grid_desc,
make_tuple(make_pass_through_transform(N),
make_pad_transform(Wi, InLeftPadW, InRightPadW),
make_pass_through_transform(C)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
const auto in_n_x_wo_c_grid_desc = transform_tensor_descriptor(
in_n_wip_c_grid_desc,
make_tuple(
make_pass_through_transform(N),
make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)),
make_pass_through_transform(C)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}));
const auto in_gemmmraw_gemmk_grid_desc =
transform_tensor_descriptor(in_n_x_wo_c_grid_desc,
make_tuple(make_merge_transform(make_tuple(N, Wo)),
make_merge_transform(make_tuple(X, C))),
make_tuple(Sequence<0, 2>{}, Sequence<1, 3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
const auto in_gemmm_gemmk_grid_desc =
matrix_padder.PadADescriptor_M_K(in_gemmmraw_gemmk_grid_desc);
return in_gemmm_gemmk_grid_desc;
}
}
template <typename ALay,
typename std::enable_if<NDimSpatial == 2 &&
is_same_v<ALay, tensor_layout::convolution::GNHWC>,
bool>::type = false>
static auto
MakeAGridDescriptor_M_K(const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths,
const std::array<index_t, NDimSpatial + 3>& /* a_g_n_c_wis_strides */,
const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
const std::array<index_t, NDimSpatial + 3>& /* b_g_k_c_xs_strides */,
const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths,
const std::array<index_t, NDimSpatial + 3>& /* e_g_n_k_wos_strides */,
const std::array<index_t, NDimSpatial>& conv_filter_strides,
const std::array<index_t, NDimSpatial>& conv_filter_dilations,
const std::array<index_t, NDimSpatial>& input_left_pads,
const std::array<index_t, NDimSpatial>& input_right_pads)
{
const index_t N = a_g_n_c_wis_lengths[1];
const index_t C = a_g_n_c_wis_lengths[2];
const index_t Hi = a_g_n_c_wis_lengths[3];
const index_t Wi = a_g_n_c_wis_lengths[4];
const index_t Ho = e_g_n_k_wos_lengths[3];
const index_t Wo = e_g_n_k_wos_lengths[4];
const index_t ConvStrideH = conv_filter_strides[0];
const index_t ConvStrideW = conv_filter_strides[1];
if constexpr(ConvForwardSpecialization ==
ConvolutionForwardSpecialization::Filter1x1Stride1Pad0)
{
const index_t NHoWo = N * std::accumulate(e_g_n_k_wos_lengths.begin() + 3,
e_g_n_k_wos_lengths.begin() + 3 + NDimSpatial,
index_t{1},
std::multiplies<index_t>());
const auto in_gemmmraw_gemmkraw_grid_desc =
make_naive_tensor_descriptor_packed(make_tuple(NHoWo, C));
const auto in_gemmm_gemmk_grid_desc =
matrix_padder.PadADescriptor_M_K(in_gemmmraw_gemmkraw_grid_desc);
return in_gemmm_gemmk_grid_desc;
}
else if constexpr(ConvForwardSpecialization ==
ConvolutionForwardSpecialization::Filter1x1Pad0)
{
const auto in_n_hi_wi_c_grid_desc =
make_naive_tensor_descriptor_packed(make_tuple(N, Hi, Wi, C));
const auto in_n_ho_wo_c_grid_desc = transform_tensor_descriptor(
in_n_hi_wi_c_grid_desc,
make_tuple(make_pass_through_transform(N),
make_embed_transform(make_tuple(Ho), make_tuple(ConvStrideH)),
make_embed_transform(make_tuple(Wo), make_tuple(ConvStrideW)),
make_pass_through_transform(C)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
const auto in_gemmmraw_gemmk_grid_desc =
transform_tensor_descriptor(in_n_ho_wo_c_grid_desc,
make_tuple(make_merge_transform(make_tuple(N, Ho, Wo)),
make_pass_through_transform(C)),
make_tuple(Sequence<0, 1, 2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
const auto in_gemmm_gemmk_grid_desc =
matrix_padder.PadADescriptor_M_K(in_gemmmraw_gemmk_grid_desc);
return in_gemmm_gemmk_grid_desc;
}
else
{
const index_t Y = b_g_k_c_xs_lengths[3];
const index_t X = b_g_k_c_xs_lengths[4];
const index_t ConvDilationH = conv_filter_dilations[0];
const index_t ConvDilationW = conv_filter_dilations[1];
const index_t InLeftPadH = input_left_pads[0];
const index_t InLeftPadW = input_left_pads[1];
const index_t InRightPadH = input_right_pads[0];
const index_t InRightPadW = input_right_pads[1];
const auto in_n_hi_wi_c_grid_desc =
make_naive_tensor_descriptor_packed(make_tuple(N, Hi, Wi, C));
const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor(
in_n_hi_wi_c_grid_desc,
make_tuple(make_pass_through_transform(N),
make_pad_transform(Hi, InLeftPadH, InRightPadH),
make_pad_transform(Wi, InLeftPadW, InRightPadW),
make_pass_through_transform(C)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
const auto in_n_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor(
in_n_hip_wip_c_grid_desc,
make_tuple(
make_pass_through_transform(N),
make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)),
make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)),
make_pass_through_transform(C)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}));
const auto in_gemmmraw_gemmk_grid_desc =
transform_tensor_descriptor(in_n_y_ho_x_wo_c_grid_desc,
make_tuple(make_merge_transform(make_tuple(N, Ho, Wo)),
make_merge_transform(make_tuple(Y, X, C))),
make_tuple(Sequence<0, 2, 4>{}, Sequence<1, 3, 5>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
const auto in_gemmm_gemmk_grid_desc =
matrix_padder.PadADescriptor_M_K(in_gemmmraw_gemmk_grid_desc);
return in_gemmm_gemmk_grid_desc;
}
}
template <typename ALay,
typename std::enable_if<NDimSpatial == 3 &&
is_same_v<ALay, tensor_layout::convolution::GNDHWC>,
bool>::type = false>
static auto
MakeAGridDescriptor_M_K(const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths,
const std::array<index_t, NDimSpatial + 3>& /* a_g_n_c_wis_strides */,
const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
const std::array<index_t, NDimSpatial + 3>& /* b_g_k_c_xs_strides */,
const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths,
const std::array<index_t, NDimSpatial + 3>& /* e_g_n_k_wos_strides */,
const std::array<index_t, NDimSpatial>& conv_filter_strides,
const std::array<index_t, NDimSpatial>& conv_filter_dilations,
const std::array<index_t, NDimSpatial>& input_left_pads,
const std::array<index_t, NDimSpatial>& input_right_pads)
{
const index_t N = a_g_n_c_wis_lengths[1];
const index_t C = a_g_n_c_wis_lengths[2];
const index_t Di = a_g_n_c_wis_lengths[3];
const index_t Hi = a_g_n_c_wis_lengths[4];
const index_t Wi = a_g_n_c_wis_lengths[5];
const index_t Do = e_g_n_k_wos_lengths[3];
const index_t Ho = e_g_n_k_wos_lengths[4];
const index_t Wo = e_g_n_k_wos_lengths[5];
const index_t ConvStrideD = conv_filter_strides[0];
const index_t ConvStrideH = conv_filter_strides[1];
const index_t ConvStrideW = conv_filter_strides[2];
if constexpr(ConvForwardSpecialization ==
ConvolutionForwardSpecialization::Filter1x1Stride1Pad0)
{
const index_t NDoHoWo =
N * std::accumulate(e_g_n_k_wos_lengths.begin() + 3,
e_g_n_k_wos_lengths.begin() + 3 + NDimSpatial,
index_t{1},
std::multiplies<index_t>());
const auto in_gemmmraw_gemmkraw_grid_desc =
make_naive_tensor_descriptor_packed(make_tuple(NDoHoWo, C));
const auto in_gemmm_gemmk_grid_desc =
matrix_padder.PadADescriptor_M_K(in_gemmmraw_gemmkraw_grid_desc);
return in_gemmm_gemmk_grid_desc;
}
else if constexpr(ConvForwardSpecialization ==
ConvolutionForwardSpecialization::Filter1x1Pad0)
{
const auto in_n_di_hi_wi_c_grid_desc =
make_naive_tensor_descriptor_packed(make_tuple(N, Di, Hi, Wi, C));
const auto in_n_do_ho_wo_c_grid_desc = transform_tensor_descriptor(
in_n_di_hi_wi_c_grid_desc,
make_tuple(make_pass_through_transform(N),
make_embed_transform(make_tuple(Do), make_tuple(ConvStrideD)),
make_embed_transform(make_tuple(Ho), make_tuple(ConvStrideH)),
make_embed_transform(make_tuple(Wo), make_tuple(ConvStrideW)),
make_pass_through_transform(C)),
make_tuple(
Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}),
make_tuple(
Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}));
const auto in_gemmmraw_gemmkraw_grid_desc = transform_tensor_descriptor(
in_n_do_ho_wo_c_grid_desc,
make_tuple(make_merge_transform(make_tuple(N, Do, Ho, Wo)),
make_pass_through_transform(C)),
make_tuple(Sequence<0, 1, 2, 3>{}, Sequence<4>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
const auto in_gemmm_gemmk_grid_desc =
matrix_padder.PadADescriptor_M_K(in_gemmmraw_gemmkraw_grid_desc);
return in_gemmm_gemmk_grid_desc;
}
else
{
const index_t Z = b_g_k_c_xs_lengths[3];
const index_t Y = b_g_k_c_xs_lengths[4];
const index_t X = b_g_k_c_xs_lengths[5];
const index_t ConvDilationD = conv_filter_dilations[0];
const index_t ConvDilationH = conv_filter_dilations[1];
const index_t ConvDilationW = conv_filter_dilations[2];
const index_t InLeftPadD = input_left_pads[0];
const index_t InLeftPadH = input_left_pads[1];
const index_t InLeftPadW = input_left_pads[2];
const index_t InRightPadD = input_right_pads[0];
const index_t InRightPadH = input_right_pads[1];
const index_t InRightPadW = input_right_pads[2];
const auto in_n_di_hi_wi_c_grid_desc =
make_naive_tensor_descriptor_packed(make_tuple(N, Di, Hi, Wi, C));
const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor(
in_n_di_hi_wi_c_grid_desc,
make_tuple(make_pass_through_transform(N),
make_pad_transform(Di, InLeftPadD, InRightPadD),
make_pad_transform(Hi, InLeftPadH, InRightPadH),
make_pad_transform(Wi, InLeftPadW, InRightPadW),
make_pass_through_transform(C)),
make_tuple(
Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}),
make_tuple(
Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}));
const auto in_n_z_do_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor(
in_n_hip_wip_c_grid_desc,
make_tuple(
make_pass_through_transform(N),
make_embed_transform(make_tuple(Z, Do), make_tuple(ConvDilationD, ConvStrideD)),
make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)),
make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)),
make_pass_through_transform(C)),
make_tuple(
Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}),
make_tuple(Sequence<0>{},
Sequence<1, 2>{},
Sequence<3, 4>{},
Sequence<5, 6>{},
Sequence<7>{}));
const auto in_gemmmraw_gemmkraw_grid_desc = transform_tensor_descriptor(
in_n_z_do_y_ho_x_wo_c_grid_desc,
make_tuple(make_merge_transform(make_tuple(N, Do, Ho, Wo)),
make_merge_transform(make_tuple(Z, Y, X, C))),
make_tuple(Sequence<0, 2, 4, 6>{}, Sequence<1, 3, 5, 7>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
const auto in_gemmm_gemmk_grid_desc =
matrix_padder.PadADescriptor_M_K(in_gemmmraw_gemmkraw_grid_desc);
return in_gemmm_gemmk_grid_desc;
}
}
// TODO: implement ck::tensor_layout::convolution that describe packed/strided dimemsion as
// properties
template <typename ALay,
typename std::enable_if<NDimSpatial == 1 &&
(is_same_v<ALay, tensor_layout::convolution::G_NW_C> ||
is_same_v<ALay, tensor_layout::convolution::NWGC>),
bool>::type = false>
template <typename ALay>
static auto
MakeAGridDescriptor_M_K(const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths,
const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
const std::array<index_t, NDimSpatial + 3>& /* b_g_k_c_xs_strides */,
const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_strides,
const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths,
const std::array<index_t, NDimSpatial + 3>& /* e_g_n_k_wos_strides */,
const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_strides,
const std::array<index_t, NDimSpatial>& conv_filter_strides,
const std::array<index_t, NDimSpatial>& conv_filter_dilations,
const std::array<index_t, NDimSpatial>& input_left_pads,
const std::array<index_t, NDimSpatial>& input_right_pads)
{
const index_t N = a_g_n_c_wis_lengths[1];
const index_t C = a_g_n_c_wis_lengths[2];
const auto in_gemmmraw_gemmkraw_desc =
conv_to_gemm_transformer.template MakeADescriptor_M_K<ALay>(a_g_n_c_wis_lengths,
a_g_n_c_wis_strides,
b_g_k_c_xs_lengths,
b_g_k_c_xs_strides,
e_g_n_k_wos_lengths,
e_g_n_k_wos_strides,
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
input_right_pads);
const index_t Wi = a_g_n_c_wis_lengths[3];
const auto in_gemmm_gemmk_desc =
matrix_padder.PadADescriptor_M_K(in_gemmmraw_gemmkraw_desc);
const index_t Wo = e_g_n_k_wos_lengths[3];
const index_t ConvStrideW = conv_filter_strides[0];
if constexpr(ConvForwardSpecialization ==
ConvolutionForwardSpecialization::Filter1x1Stride1Pad0)
{
const index_t NHoWo = N * std::accumulate(e_g_n_k_wos_lengths.begin() + 3,
e_g_n_k_wos_lengths.begin() + 3 + NDimSpatial,
index_t{1},
std::multiplies<index_t>());
// This is different
const index_t WiStride = a_g_n_c_wis_strides[2 + NDimSpatial];
const auto CStride = I1;
const auto in_gemmmraw_gemmk_grid_desc =
make_naive_tensor_descriptor(make_tuple(NHoWo, C), make_tuple(WiStride, CStride));
const auto in_gemmm_gemmk_grid_desc =
matrix_padder.PadADescriptor_M_K(in_gemmmraw_gemmk_grid_desc);
return in_gemmm_gemmk_grid_desc;
}
else if constexpr(ConvForwardSpecialization ==
ConvolutionForwardSpecialization::Filter1x1Pad0)
{
// This is different
const index_t NStride = a_g_n_c_wis_strides[1];
const index_t WiStride = a_g_n_c_wis_strides[3];
const auto CStride = I1;
const auto in_n_wi_c_grid_desc = make_naive_tensor_descriptor(
make_tuple(N, Wi, C), make_tuple(NStride, WiStride, CStride));
const auto in_n_wo_c_grid_desc = transform_tensor_descriptor(
in_n_wi_c_grid_desc,
make_tuple(make_pass_through_transform(N),
make_embed_transform(make_tuple(Wo), make_tuple(ConvStrideW)),
make_pass_through_transform(C)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
const auto in_gemmmraw_gemmkraw_grid_desc = transform_tensor_descriptor(
in_n_wo_c_grid_desc,
make_tuple(make_merge_transform(make_tuple(N, Wo)), make_pass_through_transform(C)),
make_tuple(Sequence<0, 1>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
const auto in_gemmm_gemmk_grid_desc =
matrix_padder.PadADescriptor_M_K(in_gemmmraw_gemmkraw_grid_desc);
return in_gemmm_gemmk_grid_desc;
}
else
{
const index_t X = b_g_k_c_xs_lengths[3];
const index_t ConvDilationW = conv_filter_dilations[0];
const index_t InLeftPadW = input_left_pads[0];
const index_t InRightPadW = input_right_pads[0];
// This is different
const index_t NStride = a_g_n_c_wis_strides[1];
const index_t WiStride = a_g_n_c_wis_strides[3];
const auto CStride = I1;
const auto in_n_wi_c_grid_desc = make_naive_tensor_descriptor(
make_tuple(N, Wi, C), make_tuple(NStride, WiStride, CStride));
const auto in_n_wip_c_grid_desc = transform_tensor_descriptor(
in_n_wi_c_grid_desc,
make_tuple(make_pass_through_transform(N),
make_pad_transform(Wi, InLeftPadW, InRightPadW),
make_pass_through_transform(C)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
const auto in_n_x_wo_c_grid_desc = transform_tensor_descriptor(
in_n_wip_c_grid_desc,
make_tuple(
make_pass_through_transform(N),
make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)),
make_pass_through_transform(C)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}));
const auto in_gemmmraw_gemmk_grid_desc =
transform_tensor_descriptor(in_n_x_wo_c_grid_desc,
make_tuple(make_merge_transform(make_tuple(N, Wo)),
make_merge_transform(make_tuple(X, C))),
make_tuple(Sequence<0, 2>{}, Sequence<1, 3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
const auto in_gemmm_gemmk_grid_desc =
matrix_padder.PadADescriptor_M_K(in_gemmmraw_gemmk_grid_desc);
return in_gemmm_gemmk_grid_desc;
}
return in_gemmm_gemmk_desc;
}
template <typename ALay,
typename std::enable_if<NDimSpatial == 2 &&
(is_same_v<ALay, tensor_layout::convolution::G_NHW_C> ||
is_same_v<ALay, tensor_layout::convolution::NHWGC>),
bool>::type = false>
static auto
MakeAGridDescriptor_M_K(const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths,
const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
const std::array<index_t, NDimSpatial + 3>& /* b_g_k_c_xs_strides */,
const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths,
const std::array<index_t, NDimSpatial + 3>& /* e_g_n_k_wos_strides */,
const std::array<index_t, NDimSpatial>& conv_filter_strides,
const std::array<index_t, NDimSpatial>& conv_filter_dilations,
const std::array<index_t, NDimSpatial>& input_left_pads,
const std::array<index_t, NDimSpatial>& input_right_pads)
{
const index_t N = a_g_n_c_wis_lengths[1];
const index_t C = a_g_n_c_wis_lengths[2];
const index_t Hi = a_g_n_c_wis_lengths[3];
const index_t Wi = a_g_n_c_wis_lengths[4];
const index_t Ho = e_g_n_k_wos_lengths[3];
const index_t Wo = e_g_n_k_wos_lengths[4];
const index_t ConvStrideH = conv_filter_strides[0];
const index_t ConvStrideW = conv_filter_strides[1];
if constexpr(ConvForwardSpecialization ==
ConvolutionForwardSpecialization::Filter1x1Stride1Pad0)
{
const index_t NHoWo = N * std::accumulate(e_g_n_k_wos_lengths.begin() + 3,
e_g_n_k_wos_lengths.begin() + 3 + NDimSpatial,
index_t{1},
std::multiplies<index_t>());
// This is different
const index_t WiStride = a_g_n_c_wis_strides[2 + NDimSpatial];
const auto CStride = I1;
const auto in_gemmmraw_gemmkraw_grid_desc =
make_naive_tensor_descriptor(make_tuple(NHoWo, C), make_tuple(WiStride, CStride));
const auto in_gemmm_gemmk_grid_desc =
matrix_padder.PadADescriptor_M_K(in_gemmmraw_gemmkraw_grid_desc);
return in_gemmm_gemmk_grid_desc;
}
else if constexpr(ConvForwardSpecialization ==
ConvolutionForwardSpecialization::Filter1x1Pad0)
{
// This is different
const index_t NStride = a_g_n_c_wis_strides[1];
const index_t HiStride = a_g_n_c_wis_strides[3];
const index_t WiStride = a_g_n_c_wis_strides[4];
const auto CStride = I1;
const auto in_n_hi_wi_c_grid_desc = make_naive_tensor_descriptor(
make_tuple(N, Hi, Wi, C), make_tuple(NStride, HiStride, WiStride, CStride));
const auto in_n_ho_wo_c_grid_desc = transform_tensor_descriptor(
in_n_hi_wi_c_grid_desc,
make_tuple(make_pass_through_transform(N),
make_embed_transform(make_tuple(Ho), make_tuple(ConvStrideH)),
make_embed_transform(make_tuple(Wo), make_tuple(ConvStrideW)),
make_pass_through_transform(C)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
const auto in_gemmmraw_gemmk_grid_desc =
transform_tensor_descriptor(in_n_ho_wo_c_grid_desc,
make_tuple(make_merge_transform(make_tuple(N, Ho, Wo)),
make_pass_through_transform(C)),
make_tuple(Sequence<0, 1, 2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
const auto in_gemmm_gemmk_grid_desc =
matrix_padder.PadADescriptor_M_K(in_gemmmraw_gemmk_grid_desc);
return in_gemmm_gemmk_grid_desc;
}
else
{
const index_t Y = b_g_k_c_xs_lengths[3];
const index_t X = b_g_k_c_xs_lengths[4];
const index_t ConvDilationH = conv_filter_dilations[0];
const index_t ConvDilationW = conv_filter_dilations[1];
const index_t InLeftPadH = input_left_pads[0];
const index_t InLeftPadW = input_left_pads[1];
const index_t InRightPadH = input_right_pads[0];
const index_t InRightPadW = input_right_pads[1];
// This is different
const index_t NStride = a_g_n_c_wis_strides[1];
const index_t HiStride = a_g_n_c_wis_strides[3];
const index_t WiStride = a_g_n_c_wis_strides[4];
const auto CStride = I1;
const auto in_n_hi_wi_c_grid_desc = make_naive_tensor_descriptor(
make_tuple(N, Hi, Wi, C), make_tuple(NStride, HiStride, WiStride, CStride));
const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor(
in_n_hi_wi_c_grid_desc,
make_tuple(make_pass_through_transform(N),
make_pad_transform(Hi, InLeftPadH, InRightPadH),
make_pad_transform(Wi, InLeftPadW, InRightPadW),
make_pass_through_transform(C)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
const auto in_n_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor(
in_n_hip_wip_c_grid_desc,
make_tuple(
make_pass_through_transform(N),
make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)),
make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)),
make_pass_through_transform(C)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}));
const auto in_gemmmraw_gemmk_grid_desc =
transform_tensor_descriptor(in_n_y_ho_x_wo_c_grid_desc,
make_tuple(make_merge_transform(make_tuple(N, Ho, Wo)),
make_merge_transform(make_tuple(Y, X, C))),
make_tuple(Sequence<0, 2, 4>{}, Sequence<1, 3, 5>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
const auto in_gemmm_gemmk_grid_desc =
matrix_padder.PadADescriptor_M_K(in_gemmmraw_gemmk_grid_desc);
return in_gemmm_gemmk_grid_desc;
}
}
template <typename ALay,
typename std::enable_if<NDimSpatial == 3 &&
(is_same_v<ALay, tensor_layout::convolution::G_NDHW_C> ||
is_same_v<ALay, tensor_layout::convolution::NDHWGC>),
bool>::type = false>
static auto
MakeAGridDescriptor_M_K(const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths,
const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
const std::array<index_t, NDimSpatial + 3>& /* b_g_k_c_xs_strides */,
const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths,
const std::array<index_t, NDimSpatial + 3>& /* e_g_n_k_wos_strides */,
const std::array<index_t, NDimSpatial>& conv_filter_strides,
const std::array<index_t, NDimSpatial>& conv_filter_dilations,
const std::array<index_t, NDimSpatial>& input_left_pads,
const std::array<index_t, NDimSpatial>& input_right_pads)
{
const index_t N = a_g_n_c_wis_lengths[1];
const index_t C = a_g_n_c_wis_lengths[2];
const index_t Di = a_g_n_c_wis_lengths[3];
const index_t Hi = a_g_n_c_wis_lengths[4];
const index_t Wi = a_g_n_c_wis_lengths[5];
const index_t Do = e_g_n_k_wos_lengths[3];
const index_t Ho = e_g_n_k_wos_lengths[4];
const index_t Wo = e_g_n_k_wos_lengths[5];
const index_t ConvStrideD = conv_filter_strides[0];
const index_t ConvStrideH = conv_filter_strides[1];
const index_t ConvStrideW = conv_filter_strides[2];
if constexpr(ConvForwardSpecialization ==
ConvolutionForwardSpecialization::Filter1x1Stride1Pad0)
{
const index_t NDoHoWo =
N * std::accumulate(e_g_n_k_wos_lengths.begin() + 3,
e_g_n_k_wos_lengths.begin() + 3 + NDimSpatial,
index_t{1},
std::multiplies<index_t>());
// This is different
const index_t WiStride = a_g_n_c_wis_strides[2 + NDimSpatial];
const auto CStride = I1;
const auto in_gemmmraw_gemmkraw_grid_desc =
make_naive_tensor_descriptor(make_tuple(NDoHoWo, C), make_tuple(WiStride, CStride));
const auto in_gemmm_gemmk_grid_desc =
matrix_padder.PadADescriptor_M_K(in_gemmmraw_gemmkraw_grid_desc);
return in_gemmm_gemmk_grid_desc;
}
else if constexpr(ConvForwardSpecialization ==
ConvolutionForwardSpecialization::Filter1x1Pad0)
{
// This is different
const index_t NStride = a_g_n_c_wis_strides[1];
const index_t DiStride = a_g_n_c_wis_strides[3];
const index_t HiStride = a_g_n_c_wis_strides[4];
const index_t WiStride = a_g_n_c_wis_strides[5];
const auto CStride = I1;
const auto in_n_di_hi_wi_c_grid_desc = make_naive_tensor_descriptor(
make_tuple(N, Di, Hi, Wi, C),
make_tuple(NStride, DiStride, HiStride, WiStride, CStride));
const auto in_n_do_ho_wo_c_grid_desc = transform_tensor_descriptor(
in_n_di_hi_wi_c_grid_desc,
make_tuple(make_pass_through_transform(N),
make_embed_transform(make_tuple(Do), make_tuple(ConvStrideD)),
make_embed_transform(make_tuple(Ho), make_tuple(ConvStrideH)),
make_embed_transform(make_tuple(Wo), make_tuple(ConvStrideW)),
make_pass_through_transform(C)),
make_tuple(
Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}),
make_tuple(
Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}));
const auto in_gemmmraw_gemmkraw_grid_desc = transform_tensor_descriptor(
in_n_do_ho_wo_c_grid_desc,
make_tuple(make_merge_transform(make_tuple(N, Do, Ho, Wo)),
make_pass_through_transform(C)),
make_tuple(Sequence<0, 1, 2, 3>{}, Sequence<4>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
const auto in_gemmm_gemmk_grid_desc =
matrix_padder.PadADescriptor_M_K(in_gemmmraw_gemmkraw_grid_desc);
return in_gemmm_gemmk_grid_desc;
}
else
{
const index_t Z = b_g_k_c_xs_lengths[3];
const index_t Y = b_g_k_c_xs_lengths[4];
const index_t X = b_g_k_c_xs_lengths[5];
const index_t ConvDilationD = conv_filter_dilations[0];
const index_t ConvDilationH = conv_filter_dilations[1];
const index_t ConvDilationW = conv_filter_dilations[2];
const index_t InLeftPadD = input_left_pads[0];
const index_t InLeftPadH = input_left_pads[1];
const index_t InLeftPadW = input_left_pads[2];
const index_t InRightPadD = input_right_pads[0];
const index_t InRightPadH = input_right_pads[1];
const index_t InRightPadW = input_right_pads[2];
// This is different
const index_t NStride = a_g_n_c_wis_strides[1];
const index_t DiStride = a_g_n_c_wis_strides[3];
const index_t HiStride = a_g_n_c_wis_strides[4];
const index_t WiStride = a_g_n_c_wis_strides[5];
const auto CStride = I1;
const auto in_n_di_hi_wi_c_grid_desc = make_naive_tensor_descriptor(
make_tuple(N, Di, Hi, Wi, C),
make_tuple(NStride, DiStride, HiStride, WiStride, CStride));
const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor(
in_n_di_hi_wi_c_grid_desc,
make_tuple(make_pass_through_transform(N),
make_pad_transform(Di, InLeftPadD, InRightPadD),
make_pad_transform(Hi, InLeftPadH, InRightPadH),
make_pad_transform(Wi, InLeftPadW, InRightPadW),
make_pass_through_transform(C)),
make_tuple(
Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}),
make_tuple(
Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}));
const auto in_n_z_do_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor(
in_n_hip_wip_c_grid_desc,
make_tuple(
make_pass_through_transform(N),
make_embed_transform(make_tuple(Z, Do), make_tuple(ConvDilationD, ConvStrideD)),
make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)),
make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)),
make_pass_through_transform(C)),
make_tuple(
Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}),
make_tuple(Sequence<0>{},
Sequence<1, 2>{},
Sequence<3, 4>{},
Sequence<5, 6>{},
Sequence<7>{}));
const auto in_gemmmraw_gemmkraw_grid_desc = transform_tensor_descriptor(
in_n_z_do_y_ho_x_wo_c_grid_desc,
make_tuple(make_merge_transform(make_tuple(N, Do, Ho, Wo)),
make_merge_transform(make_tuple(Z, Y, X, C))),
make_tuple(Sequence<0, 2, 4, 6>{}, Sequence<1, 3, 5, 7>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
const auto in_gemmm_gemmk_grid_desc =
matrix_padder.PadADescriptor_M_K(in_gemmmraw_gemmkraw_grid_desc);
return in_gemmm_gemmk_grid_desc;
}
}
template <typename BLay,
typename std::enable_if<is_same_v<BLay, tensor_layout::convolution::GKXC> ||
is_same_v<BLay, tensor_layout::convolution::GKYXC> ||
is_same_v<BLay, tensor_layout::convolution::GKZYXC>,
bool>::type = false>
static auto
MakeBGridDescriptor_N_K(const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
const std::array<index_t, NDimSpatial + 3>& /* b_g_k_c_xs_strides */)
{
const index_t K = b_g_k_c_xs_lengths[1];
const index_t C = b_g_k_c_xs_lengths[2];
const index_t YX = std::accumulate(b_g_k_c_xs_lengths.begin() + 3,
b_g_k_c_xs_lengths.begin() + 3 + NDimSpatial,
index_t{1},
std::multiplies<index_t>());
const auto wei_k_yxc_grid_desc = make_naive_tensor_descriptor_packed(make_tuple(K, YX * C));
const auto wei_gemmn_gemmk_grid_desc =
matrix_padder.PadBDescriptor_N_K(wei_k_yxc_grid_desc);
return wei_gemmn_gemmk_grid_desc;
}
template <typename BLay,
typename std::enable_if<is_same_v<BLay, tensor_layout::convolution::G_K_X_C> ||
is_same_v<BLay, tensor_layout::convolution::G_K_YX_C> ||
is_same_v<BLay, tensor_layout::convolution::G_K_ZYX_C> ||
is_same_v<BLay, tensor_layout::convolution::KXGC> ||
is_same_v<BLay, tensor_layout::convolution::KYXGC> ||
is_same_v<BLay, tensor_layout::convolution::KZYXGC>,
bool>::type = false>
template <typename BLay>
static auto
MakeBGridDescriptor_N_K(const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_strides)
{
const index_t K = b_g_k_c_xs_lengths[1];
const index_t C = b_g_k_c_xs_lengths[2];
const auto wei_gemmnraw_gemmkraw_desc =
conv_to_gemm_transformer.template MakeBDescriptor_N_K<BLay>(b_g_k_c_xs_lengths,
b_g_k_c_xs_strides);
const index_t YX = std::accumulate(b_g_k_c_xs_lengths.begin() + 3,
b_g_k_c_xs_lengths.begin() + 3 + NDimSpatial,
index_t{1},
std::multiplies<index_t>());
const auto wei_gemmn_gemmk_desc =
matrix_padder.PadBDescriptor_N_K(wei_gemmnraw_gemmkraw_desc);
const index_t KStride = b_g_k_c_xs_strides[1];
const index_t XStride = b_g_k_c_xs_strides[2 + NDimSpatial];
const auto CStride = I1;
const auto wei_k_yx_c_grid_desc = make_naive_tensor_descriptor(
make_tuple(K, YX, C), make_tuple(KStride, XStride, CStride));
const auto wei_gemmnraw_gemmkraw_grid_desc = transform_tensor_descriptor(
wei_k_yx_c_grid_desc,
make_tuple(make_pass_through_transform(K), make_merge_transform(make_tuple(YX, C))),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
const auto wei_gemmn_gemmk_grid_desc =
matrix_padder.PadBDescriptor_N_K(wei_gemmnraw_gemmkraw_grid_desc);
return wei_gemmn_gemmk_grid_desc;
return wei_gemmn_gemmk_desc;
}
template <typename ELay,
typename std::enable_if<is_same_v<ELay, tensor_layout::convolution::GNWK> ||
is_same_v<ELay, tensor_layout::convolution::GNHWK> ||
is_same_v<ELay, tensor_layout::convolution::GNDHWK>,
bool>::type = false>
static auto
MakeEGridDescriptor_M_N(const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths,
const std::array<index_t, NDimSpatial + 3>& /* e_g_n_k_wos_strides */)
{
const index_t N = e_g_n_k_wos_lengths[1];
const index_t K = e_g_n_k_wos_lengths[2];
const index_t NHoWo = N * std::accumulate(e_g_n_k_wos_lengths.begin() + 3,
e_g_n_k_wos_lengths.begin() + 3 + NDimSpatial,
index_t{1},
std::multiplies<index_t>());
const auto out_gemmmraw_gemmnraw_grid_desc =
make_naive_tensor_descriptor_packed(make_tuple(NHoWo, K));
const auto out_gemmm_gemmn_grid_desc =
matrix_padder.PadCDescriptor_M_N(out_gemmmraw_gemmnraw_grid_desc);
return out_gemmm_gemmn_grid_desc;
}
template <typename ELay,
typename std::enable_if<is_same_v<ELay, tensor_layout::convolution::G_NW_K> ||
is_same_v<ELay, tensor_layout::convolution::G_NHW_K> ||
is_same_v<ELay, tensor_layout::convolution::G_NDHW_K> ||
is_same_v<ELay, tensor_layout::convolution::NWGK> ||
is_same_v<ELay, tensor_layout::convolution::NHWGK> ||
is_same_v<ELay, tensor_layout::convolution::NDHWGK>,
bool>::type = false>
template <typename ELay>
static auto
MakeEGridDescriptor_M_N(const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths,
const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_strides)
{
const index_t N = e_g_n_k_wos_lengths[1];
const index_t K = e_g_n_k_wos_lengths[2];
const auto out_gemmmraw_gemmnraw_desc =
conv_to_gemm_transformer.template MakeCDescriptor_M_N<ELay>(e_g_n_k_wos_lengths,
e_g_n_k_wos_strides);
const auto KStride = I1;
const index_t WoStride = e_g_n_k_wos_strides[NDimSpatial + 2];
const auto out_gemmm_gemmn_desc =
matrix_padder.PadCDescriptor_M_N(out_gemmmraw_gemmnraw_desc);
const index_t NHoWo = N * std::accumulate(e_g_n_k_wos_lengths.begin() + 3,
e_g_n_k_wos_lengths.begin() + 3 + NDimSpatial,
index_t{1},
std::multiplies<index_t>());
const auto out_gemmmraw_gemmnraw_grid_desc =
make_naive_tensor_descriptor(make_tuple(NHoWo, K), make_tuple(WoStride, KStride));
const auto out_gemmm_gemmn_grid_desc =
matrix_padder.PadCDescriptor_M_N(out_gemmmraw_gemmnraw_grid_desc);
return out_gemmm_gemmn_grid_desc;
return out_gemmm_gemmn_desc;
}
static auto MakeDsGridDescriptor_M_N(

View File

@@ -12,70 +12,45 @@ namespace ck {
namespace tensor_operation {
namespace device {
// For padding tensors without batch dimension
template <bool PadM,
bool PadN,
typename TensorDesc_MRaw_NRaw,
typename MPerBlockType,
typename NPerBlockType,
enable_if_t<TensorDesc_MRaw_NRaw::GetNumOfVisibleDimension() == 2, bool> = false>
template <typename TensorDesc,
typename TileLengths, // Tuple<...>
typename DoPads> // Sequence<bool, bool, ...>
__host__ __device__ constexpr auto
PadTensorDescriptor(const TensorDesc_MRaw_NRaw& tensor_desc_mraw_nraw,
MPerBlockType MPerBlock,
NPerBlockType NPerBlock)
PadTensorDescriptor(const TensorDesc& desc, const TileLengths& tile_lengths, DoPads)
{
const auto MRaw = tensor_desc_mraw_nraw.GetLength(Number<0>{});
const auto NRaw = tensor_desc_mraw_nraw.GetLength(Number<1>{});
constexpr index_t num_dim = DoPads::Size();
const auto M = math::integer_divide_ceil(MRaw, MPerBlock) * MPerBlock;
const auto N = math::integer_divide_ceil(NRaw, NPerBlock) * NPerBlock;
static_assert(num_dim == TileLengths::Size() && num_dim == TensorDesc::GetNumOfDimension(),
"wrong! inconsistent # of dimensions");
const auto MPad = M - MRaw;
const auto NPad = N - NRaw;
// transforms
const auto transforms = generate_tuple(
[&](auto idim) {
const auto MRaw = desc.GetLength(idim);
const auto MTransform = conditional_expr<PadM>(make_right_pad_transform(MRaw, MPad),
make_pass_through_transform(MRaw));
const auto NTransform = conditional_expr<PadN>(make_right_pad_transform(NRaw, NPad),
make_pass_through_transform(NRaw));
const auto MPerTile = tile_lengths[idim];
return transform_tensor_descriptor(tensor_desc_mraw_nraw,
make_tuple(MTransform, NTransform),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
}
const auto M = math::integer_divide_ceil(MRaw, MPerTile) * MPerTile;
// For padding tensors with batch dimension
template <bool PadM,
bool PadN,
typename TensorDesc_GRaw_MRaw_NRaw,
typename MPerBlockType,
typename NPerBlockType,
enable_if_t<TensorDesc_GRaw_MRaw_NRaw::GetNumOfVisibleDimension() == 3, bool> = false>
__host__ __device__ constexpr auto
PadTensorDescriptor(const TensorDesc_GRaw_MRaw_NRaw& tensor_desc_graw_mraw_nraw,
MPerBlockType MPerBlock,
NPerBlockType NPerBlock)
{
const auto GRaw = tensor_desc_graw_mraw_nraw.GetLength(Number<0>{});
const auto MRaw = tensor_desc_graw_mraw_nraw.GetLength(Number<1>{});
const auto NRaw = tensor_desc_graw_mraw_nraw.GetLength(Number<2>{});
const auto MPad = M - MRaw;
const auto M = math::integer_divide_ceil(MRaw, MPerBlock) * MPerBlock;
const auto N = math::integer_divide_ceil(NRaw, NPerBlock) * NPerBlock;
const bool DoPadM = DoPads::At(idim);
const auto MPad = M - MRaw;
const auto NPad = N - NRaw;
const auto MTransform = conditional_expr<DoPadM>(make_right_pad_transform(MRaw, MPad),
make_pass_through_transform(MRaw));
const auto MTransform = conditional_expr<PadM>(make_right_pad_transform(MRaw, MPad),
make_pass_through_transform(MRaw));
const auto NTransform = conditional_expr<PadN>(make_right_pad_transform(NRaw, NPad),
make_pass_through_transform(NRaw));
return MTransform;
},
Number<num_dim>{});
return transform_tensor_descriptor(
tensor_desc_graw_mraw_nraw,
make_tuple(make_pass_through_transform(GRaw), MTransform, NTransform),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
// lower dimension Id
const auto lower_dimss =
generate_tuple([&](auto idim) { return Sequence<idim.value>{}; }, Number<num_dim>{});
// upper dimension Id
const auto upper_dimss = lower_dimss;
return transform_tensor_descriptor(desc, transforms, lower_dimss, upper_dimss);
}
// M/N/K/OPerTileType could be index_t or Number<>
@@ -113,7 +88,8 @@ struct GemmGemmPadder
__host__ __device__ constexpr auto
PadADescriptor_M_K(const ADesc_MRaw_KRaw& a_desc_mraw_kraw) const
{
return PadTensorDescriptor<PadM, PadK>(a_desc_mraw_kraw, MPerTile_, KPerTile_);
return PadTensorDescriptor(
a_desc_mraw_kraw, make_tuple(MPerTile_, KPerTile_), Sequence<PadM, PadK>{});
}
// B[K, N]
@@ -121,7 +97,8 @@ struct GemmGemmPadder
__host__ __device__ constexpr auto
PadBDescriptor_N_K(const BDesc_NRaw_KRaw& b_desc_nraw_kraw) const
{
return PadTensorDescriptor<PadN, PadK>(b_desc_nraw_kraw, NPerTile_, KPerTile_);
return PadTensorDescriptor(
b_desc_nraw_kraw, make_tuple(NPerTile_, KPerTile_), Sequence<PadN, PadK>{});
}
// B1[Gemm1N, Gemm1K] = B1[O, N]
@@ -129,7 +106,8 @@ struct GemmGemmPadder
__host__ __device__ constexpr auto
PadB1Descriptor_N_K(const B1Desc_NRaw_KRaw& b1_desc_nraw_kraw) const
{
return PadTensorDescriptor<PadO, PadN>(b1_desc_nraw_kraw, OPerTile_, NPerTile_);
return PadTensorDescriptor(
b1_desc_nraw_kraw, make_tuple(OPerTile_, NPerTile_), Sequence<PadO, PadN>{});
}
// C[M, Gemm1N] = C[M, O]
@@ -137,7 +115,8 @@ struct GemmGemmPadder
__host__ __device__ constexpr auto
PadCDescriptor_M_N(const CDesc_MRaw_NRaw& c_desc_mraw_nraw) const
{
return PadTensorDescriptor<PadM, PadO>(c_desc_mraw_nraw, MPerTile_, OPerTile_);
return PadTensorDescriptor(
c_desc_mraw_nraw, make_tuple(MPerTile_, OPerTile_), Sequence<PadM, PadO>{});
}
MPerTileType MPerTile_;
@@ -167,21 +146,24 @@ struct GemmPadder
__host__ __device__ constexpr auto
PadADescriptor_M_K(const ADesc_MRaw_KRaw& a_desc_mraw_kraw) const
{
return PadTensorDescriptor<PadM, PadK>(a_desc_mraw_kraw, MPerTile_, KPerTile_);
return PadTensorDescriptor(
a_desc_mraw_kraw, make_tuple(MPerTile_, KPerTile_), Sequence<PadM, PadK>{});
}
template <typename BDesc_NRaw_KRaw>
__host__ __device__ constexpr auto
PadBDescriptor_N_K(const BDesc_NRaw_KRaw& b_desc_nraw_kraw) const
{
return PadTensorDescriptor<PadN, PadK>(b_desc_nraw_kraw, NPerTile_, KPerTile_);
return PadTensorDescriptor(
b_desc_nraw_kraw, make_tuple(NPerTile_, KPerTile_), Sequence<PadN, PadK>{});
}
template <typename CDesc_MRaw_NRaw>
__host__ __device__ constexpr auto
PadCDescriptor_M_N(const CDesc_MRaw_NRaw& c_desc_mraw_nraw) const
{
return PadTensorDescriptor<PadM, PadN>(c_desc_mraw_nraw, MPerTile_, NPerTile_);
return PadTensorDescriptor(
c_desc_mraw_nraw, make_tuple(MPerTile_, NPerTile_), Sequence<PadM, PadN>{});
}
MPerTileType MPerTile_;
@@ -198,6 +180,44 @@ struct MatrixPadder : public GemmPadder<GemmSpec, MPerTileType, NPerTileType, KP
{
};
// M/N/KPerTileType could be index_t or Number<>
template <bool PadM,
bool PadN,
bool PadK,
typename MPerTileType,
typename NPerTileType,
typename KPerTileType>
struct GemmPadder_v2
{
template <typename ADesc_MRaw_KRaw>
__host__ __device__ constexpr auto
PadADescriptor_M_K(const ADesc_MRaw_KRaw& a_desc_mraw_kraw) const
{
return PadTensorDescriptor(
a_desc_mraw_kraw, make_tuple(MPerTile_, KPerTile_), Sequence<PadM, PadK>{});
}
template <typename BDesc_NRaw_KRaw>
__host__ __device__ constexpr auto
PadBDescriptor_N_K(const BDesc_NRaw_KRaw& b_desc_nraw_kraw) const
{
return PadTensorDescriptor(
b_desc_nraw_kraw, make_tuple(NPerTile_, KPerTile_), Sequence<PadN, PadK>{});
}
template <typename CDesc_MRaw_NRaw>
__host__ __device__ constexpr auto
PadCDescriptor_M_N(const CDesc_MRaw_NRaw& c_desc_mraw_nraw) const
{
return PadTensorDescriptor(
c_desc_mraw_nraw, make_tuple(MPerTile_, NPerTile_), Sequence<PadM, PadN>{});
}
MPerTileType MPerTile_;
NPerTileType NPerTile_;
KPerTileType KPerTile_;
};
} // namespace device
} // namespace tensor_operation
} // namespace ck