mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-01 12:11:19 +00:00
[CK_TILE] Add conv bwd weight two stage support (#2855)
* resolved conflicts * add conv bwd weight twostage * fix one file * fixes after review * fixes * fixes * Fix --------- Co-authored-by: Bartlomiej Kocot <barkocot@amd.com>
This commit is contained in:
@@ -49,7 +49,10 @@ template <index_t NDimSpatial_,
|
||||
typename InLayout_,
|
||||
typename WeiLayout_,
|
||||
typename DsLayout_,
|
||||
typename OutLayout_>
|
||||
typename OutLayout_,
|
||||
index_t VectorSizeA_ = 1,
|
||||
index_t VectorSizeB_ = 1,
|
||||
index_t VectorSizeC_ = 1>
|
||||
struct GroupedConvTraits
|
||||
{
|
||||
private:
|
||||
@@ -67,14 +70,38 @@ struct GroupedConvTraits
|
||||
using WeiLayout = WeiLayout_;
|
||||
using DsLayout = DsLayout_;
|
||||
using OutLayout = OutLayout_;
|
||||
using GroupedConvImplicitGemmTraits = TileGemmTraits<true,
|
||||
true,
|
||||
true,
|
||||
ck_tile::tensor_layout::gemm::RowMajor,
|
||||
ck_tile::tensor_layout::gemm::ColumnMajor,
|
||||
ck_tile::tensor_layout::gemm::RowMajor>;
|
||||
static constexpr index_t NumDTensor = DsLayout::size();
|
||||
using ImplicitGemmDsLayout = decltype(generate_implicit_gemm_layout());
|
||||
using GroupedConvImplicitGemmTraitsFwd =
|
||||
TileGemmTraits<true,
|
||||
true,
|
||||
true,
|
||||
ck_tile::tensor_layout::gemm::RowMajor,
|
||||
ck_tile::tensor_layout::gemm::ColumnMajor,
|
||||
ck_tile::tensor_layout::gemm::RowMajor>;
|
||||
using GroupedConvImplicitGemmTraitsBwdData =
|
||||
TileGemmTraits<true,
|
||||
true,
|
||||
true,
|
||||
ck_tile::tensor_layout::gemm::RowMajor,
|
||||
ck_tile::tensor_layout::gemm::ColumnMajor,
|
||||
// TODO: Change to and enable vector load
|
||||
// ck_tile::tensor_layout::gemm::RowMajor,
|
||||
// ck_tile::tensor_layout::gemm::RowMajor,
|
||||
ck_tile::tensor_layout::gemm::RowMajor>;
|
||||
using GroupedConvImplicitGemmTraitsBwdWeight =
|
||||
TileGemmTraits<true,
|
||||
true,
|
||||
true,
|
||||
ck_tile::tensor_layout::gemm::RowMajor,
|
||||
ck_tile::tensor_layout::gemm::ColumnMajor,
|
||||
// TODO: Change to and enable vector load
|
||||
// ck_tile::tensor_layout::gemm::ColumnMajor,
|
||||
// ck_tile::tensor_layout::gemm::RowMajor,
|
||||
ck_tile::tensor_layout::gemm::RowMajor>;
|
||||
static constexpr ck_tile::index_t VectorSizeA = VectorSizeA_;
|
||||
static constexpr ck_tile::index_t VectorSizeB = VectorSizeB_;
|
||||
static constexpr ck_tile::index_t VectorSizeC = VectorSizeC_;
|
||||
static constexpr index_t NumDTensor = DsLayout::size();
|
||||
using ImplicitGemmDsLayout = decltype(generate_implicit_gemm_layout());
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -10,6 +10,9 @@ namespace ck_tile {
|
||||
|
||||
template <index_t NDimSpatial,
|
||||
ConvolutionSpecialization ConvolutionSpecialization,
|
||||
index_t VectorSizeA,
|
||||
index_t VectorSizeB,
|
||||
index_t VectorSizeC,
|
||||
bool SplitN = false,
|
||||
typename ADataType = float,
|
||||
typename CDataType = float,
|
||||
@@ -442,14 +445,17 @@ struct TransformConvBwdDataToGemm
|
||||
// TODO Add support for NumGroupsToMerge > 1
|
||||
|
||||
return make_naive_tensor_descriptor(make_tuple(N_, Wo_, K_),
|
||||
make_tuple(NStride, WoStride, KStride));
|
||||
make_tuple(NStride, WoStride, KStride),
|
||||
number<VectorSizeA>{},
|
||||
I1);
|
||||
}
|
||||
|
||||
template <index_t NDim = NDimSpatial, typename std::enable_if<NDim == 1, bool>::type = false>
|
||||
CK_TILE_HOST auto make_wei_grid_desc() const
|
||||
{
|
||||
// GKXC
|
||||
return make_naive_tensor_descriptor_packed(make_tuple(K_, X_, C_));
|
||||
return make_naive_tensor_descriptor(
|
||||
make_tuple(K_, X_, C_), make_tuple(X_ * C_, C_, I1), number<VectorSizeB>{}, I1);
|
||||
}
|
||||
|
||||
template <index_t NDim = NDimSpatial, typename std::enable_if<NDim == 1, bool>::type = false>
|
||||
@@ -462,7 +468,9 @@ struct TransformConvBwdDataToGemm
|
||||
|
||||
// TODO Add support for NumGroupsToMerge > 1
|
||||
return make_naive_tensor_descriptor(make_tuple(N_, Wi_, C_),
|
||||
make_tuple(NStride, WiStride, CStride));
|
||||
make_tuple(NStride, WiStride, CStride),
|
||||
number<VectorSizeC>{},
|
||||
I1);
|
||||
}
|
||||
|
||||
template <index_t NDim = NDimSpatial, typename std::enable_if<NDim == 2, bool>::type = false>
|
||||
@@ -477,7 +485,9 @@ struct TransformConvBwdDataToGemm
|
||||
// TODO Add support for NumGroupsToMerge > 1
|
||||
|
||||
return make_naive_tensor_descriptor(make_tuple(N_, Ho_, Wo_, K_),
|
||||
make_tuple(NStride, HoStride, WoStride, KStride));
|
||||
make_tuple(NStride, HoStride, WoStride, KStride),
|
||||
number<VectorSizeA>{},
|
||||
I1);
|
||||
}
|
||||
|
||||
template <index_t NDim = NDimSpatial, typename std::enable_if<NDim == 2, bool>::type = false>
|
||||
@@ -491,14 +501,19 @@ struct TransformConvBwdDataToGemm
|
||||
|
||||
// TODO Add support for NumGroupsToMerge > 1
|
||||
return make_naive_tensor_descriptor(make_tuple(N_, Hi_, Wi_, C_),
|
||||
make_tuple(NStride, HiStride, WiStride, CStride));
|
||||
make_tuple(NStride, HiStride, WiStride, CStride),
|
||||
number<VectorSizeB>{},
|
||||
I1);
|
||||
}
|
||||
|
||||
template <index_t NDim = NDimSpatial, typename std::enable_if<NDim == 2, bool>::type = false>
|
||||
CK_TILE_HOST auto make_wei_grid_desc() const
|
||||
{
|
||||
// GKYXC
|
||||
return make_naive_tensor_descriptor_packed(make_tuple(K_, Y_, X_, C_));
|
||||
return make_naive_tensor_descriptor(make_tuple(K_, Y_, X_, C_),
|
||||
make_tuple(C_ * X_ * Y_, C_ * X_, C_, I1),
|
||||
number<VectorSizeC>{},
|
||||
I1);
|
||||
}
|
||||
|
||||
template <index_t NDim = NDimSpatial, typename std::enable_if<NDim == 3, bool>::type = false>
|
||||
@@ -514,7 +529,9 @@ struct TransformConvBwdDataToGemm
|
||||
// TODO Add support for NumGroupsToMerge > 1
|
||||
return make_naive_tensor_descriptor(
|
||||
make_tuple(N_, Do_, Ho_, Wo_, K_),
|
||||
make_tuple(NStride, DoStride, HoStride, WoStride, KStride));
|
||||
make_tuple(NStride, DoStride, HoStride, WoStride, KStride),
|
||||
number<VectorSizeA>{},
|
||||
I1);
|
||||
}
|
||||
|
||||
template <index_t NDim = NDimSpatial, typename std::enable_if<NDim == 3, bool>::type = false>
|
||||
@@ -529,14 +546,20 @@ struct TransformConvBwdDataToGemm
|
||||
// TODO Add support for NumGroupsToMerge > 1
|
||||
return make_naive_tensor_descriptor(
|
||||
make_tuple(N_, Di_, Hi_, Wi_, C_),
|
||||
make_tuple(NStride, DiStride, HiStride, WiStride, CStride));
|
||||
make_tuple(NStride, DiStride, HiStride, WiStride, CStride),
|
||||
number<VectorSizeB>{},
|
||||
I1);
|
||||
}
|
||||
|
||||
template <index_t NDim = NDimSpatial, typename std::enable_if<NDim == 3, bool>::type = false>
|
||||
CK_TILE_HOST auto make_wei_grid_desc() const
|
||||
{
|
||||
// GKZYXC
|
||||
return make_naive_tensor_descriptor_packed(make_tuple(K_, Z_, Y_, X_, C_));
|
||||
return make_naive_tensor_descriptor(
|
||||
make_tuple(K_, Z_, Y_, X_, C_),
|
||||
make_tuple(C_ * X_ * Y_ * Z_, C_ * X_ * Y_, C_ * X_, C_, I1),
|
||||
number<VectorSizeC>{},
|
||||
I1);
|
||||
}
|
||||
// TODO: implement ck_tile::tensor_layout::convolution that describe packed/strided dimemsion as
|
||||
// properties
|
||||
|
||||
@@ -10,6 +10,9 @@ namespace ck_tile {
|
||||
|
||||
template <index_t NDimSpatial,
|
||||
ConvolutionSpecialization ConvolutionSpecialization,
|
||||
index_t VectorSizeA,
|
||||
index_t VectorSizeB,
|
||||
index_t VectorSizeC,
|
||||
bool SplitN = false,
|
||||
typename ADataType = float,
|
||||
typename CDataType = float,
|
||||
@@ -420,7 +423,9 @@ struct TransformConvBwdWeightToGemm
|
||||
// TODO Add support for NumGroupsToMerge > 1
|
||||
|
||||
return make_naive_tensor_descriptor(make_tuple(K_, N_ * Wo_),
|
||||
make_tuple(KStride, NDoHoWoStride));
|
||||
make_tuple(KStride, NDoHoWoStride),
|
||||
number<VectorSizeA>{},
|
||||
I1);
|
||||
}
|
||||
|
||||
template <index_t NDim = NDimSpatial, typename std::enable_if<NDim == 1, bool>::type = false>
|
||||
@@ -433,7 +438,9 @@ struct TransformConvBwdWeightToGemm
|
||||
|
||||
// TODO Add support for NumGroupsToMerge > 1
|
||||
return make_naive_tensor_descriptor(make_tuple(N_, Wi_, C_),
|
||||
make_tuple(NStride, WiStride, CStride));
|
||||
make_tuple(NStride, WiStride, CStride),
|
||||
number<VectorSizeB>{},
|
||||
I1);
|
||||
}
|
||||
|
||||
template <index_t NDim = NDimSpatial, typename std::enable_if<NDim == 1, bool>::type = false>
|
||||
@@ -444,7 +451,8 @@ struct TransformConvBwdWeightToGemm
|
||||
constexpr auto CXStride = I1;
|
||||
|
||||
// TODO Add support for NumGroupsToMerge > 1
|
||||
return make_naive_tensor_descriptor(make_tuple(K_, X_ * C_), make_tuple(KStride, CXStride));
|
||||
return make_naive_tensor_descriptor(
|
||||
make_tuple(K_, X_ * C_), make_tuple(KStride, CXStride), number<VectorSizeC>{}, I1);
|
||||
}
|
||||
|
||||
template <index_t NDim = NDimSpatial, typename std::enable_if<NDim == 2, bool>::type = false>
|
||||
@@ -457,7 +465,9 @@ struct TransformConvBwdWeightToGemm
|
||||
// TODO Add support for NumGroupsToMerge > 1
|
||||
|
||||
return make_naive_tensor_descriptor(make_tuple(K_, N_ * Ho_ * Wo_),
|
||||
make_tuple(KStride, NDoHoWoStride));
|
||||
make_tuple(KStride, NDoHoWoStride),
|
||||
number<VectorSizeA>{},
|
||||
I1);
|
||||
}
|
||||
|
||||
template <index_t NDim = NDimSpatial, typename std::enable_if<NDim == 2, bool>::type = false>
|
||||
@@ -471,7 +481,9 @@ struct TransformConvBwdWeightToGemm
|
||||
|
||||
// TODO Add support for NumGroupsToMerge > 1
|
||||
return make_naive_tensor_descriptor(make_tuple(N_, Hi_, Wi_, C_),
|
||||
make_tuple(NStride, HiStride, WiStride, CStride));
|
||||
make_tuple(NStride, HiStride, WiStride, CStride),
|
||||
number<VectorSizeB>{},
|
||||
I1);
|
||||
}
|
||||
|
||||
template <index_t NDim = NDimSpatial, typename std::enable_if<NDim == 2, bool>::type = false>
|
||||
@@ -482,8 +494,8 @@ struct TransformConvBwdWeightToGemm
|
||||
constexpr auto CStride = I1;
|
||||
|
||||
// TODO Add support for NumGroupsToMerge > 1
|
||||
return make_naive_tensor_descriptor(make_tuple(K_, Y_ * X_ * C_),
|
||||
make_tuple(KStride, CStride));
|
||||
return make_naive_tensor_descriptor(
|
||||
make_tuple(K_, Y_ * X_ * C_), make_tuple(KStride, CStride), number<VectorSizeC>{}, I1);
|
||||
}
|
||||
|
||||
template <index_t NDim = NDimSpatial, typename std::enable_if<NDim == 3, bool>::type = false>
|
||||
@@ -496,7 +508,9 @@ struct TransformConvBwdWeightToGemm
|
||||
// TODO Add support for NumGroupsToMerge > 1
|
||||
|
||||
return make_naive_tensor_descriptor(make_tuple(K_, N_ * Do_ * Ho_ * Wo_),
|
||||
make_tuple(KStride, NDoHoWoStride));
|
||||
make_tuple(KStride, NDoHoWoStride),
|
||||
number<VectorSizeA>{},
|
||||
I1);
|
||||
}
|
||||
|
||||
template <index_t NDim = NDimSpatial, typename std::enable_if<NDim == 3, bool>::type = false>
|
||||
@@ -511,7 +525,9 @@ struct TransformConvBwdWeightToGemm
|
||||
// TODO Add support for NumGroupsToMerge > 1
|
||||
return make_naive_tensor_descriptor(
|
||||
make_tuple(N_, Di_, Hi_, Wi_, C_),
|
||||
make_tuple(NStride, DiStride, HiStride, WiStride, CStride));
|
||||
make_tuple(NStride, DiStride, HiStride, WiStride, CStride),
|
||||
number<VectorSizeB>{},
|
||||
I1);
|
||||
}
|
||||
|
||||
template <index_t NDim = NDimSpatial, typename std::enable_if<NDim == 3, bool>::type = false>
|
||||
@@ -523,7 +539,9 @@ struct TransformConvBwdWeightToGemm
|
||||
|
||||
// TODO Add support for NumGroupsToMerge > 1
|
||||
return make_naive_tensor_descriptor(make_tuple(K_, Z_ * Y_ * X_ * C_),
|
||||
make_tuple(KStride, CStride));
|
||||
make_tuple(KStride, CStride),
|
||||
number<VectorSizeC>{},
|
||||
I1);
|
||||
}
|
||||
|
||||
// TODO: implement ck_tile::tensor_layout::convolution that describe packed/strided dimemsion as
|
||||
|
||||
@@ -10,6 +10,9 @@ namespace ck_tile {
|
||||
|
||||
template <index_t NDimSpatial,
|
||||
ConvolutionSpecialization ConvSpecialization,
|
||||
index_t VectorSizeA,
|
||||
index_t VectorSizeB,
|
||||
index_t VectorSizeC,
|
||||
bool SplitN = false,
|
||||
typename ADataType = float,
|
||||
typename CDataType = float,
|
||||
@@ -446,7 +449,9 @@ struct TransformConvFwdToGemm
|
||||
{
|
||||
const auto in_gemmm_gemmk_desc = make_naive_tensor_descriptor(
|
||||
make_tuple(N_, Wo_, C_),
|
||||
make_tuple(NStrideTensorA_, WiStride_, CStrideTensorA_));
|
||||
make_tuple(NStrideTensorA_, WiStride_, CStrideTensorA_),
|
||||
number<VectorSizeA>{},
|
||||
I1);
|
||||
return transform_tensor_descriptor(
|
||||
in_gemmm_gemmk_desc,
|
||||
make_tuple(make_merge_transform(make_tuple(N_, Wo_)),
|
||||
@@ -458,7 +463,9 @@ struct TransformConvFwdToGemm
|
||||
{
|
||||
const auto in_gemmm_groups_gemmk_desc = make_naive_tensor_descriptor(
|
||||
make_tuple(N_, Wo_, NumGroupsToMerge, C_),
|
||||
make_tuple(NStrideTensorA_, WiStride_, GStrideTensorA_, CStrideTensorA_));
|
||||
make_tuple(NStrideTensorA_, WiStride_, GStrideTensorA_, CStrideTensorA_),
|
||||
number<VectorSizeA>{},
|
||||
I1);
|
||||
|
||||
return transform_tensor_descriptor(
|
||||
in_gemmm_groups_gemmk_desc,
|
||||
@@ -473,8 +480,11 @@ struct TransformConvFwdToGemm
|
||||
if constexpr(NumGroupsToMerge == 1)
|
||||
{
|
||||
|
||||
const auto in_n_wi_c_desc = make_naive_tensor_descriptor(
|
||||
make_tuple(N_, Wi_), make_tuple(NStrideTensorA_, WiStride_));
|
||||
const auto in_n_wi_c_desc =
|
||||
make_naive_tensor_descriptor(make_tuple(N_, Wi_),
|
||||
make_tuple(NStrideTensorA_, WiStride_),
|
||||
number<VectorSizeA>{},
|
||||
I1);
|
||||
|
||||
const auto in_n_wip_c_desc = transform_tensor_descriptor(
|
||||
in_n_wi_c_desc,
|
||||
@@ -502,7 +512,9 @@ struct TransformConvFwdToGemm
|
||||
{
|
||||
const auto in_n_wi_c_desc = make_naive_tensor_descriptor(
|
||||
make_tuple(N_, Wi_, NumGroupsToMerge),
|
||||
make_tuple(NStrideTensorA_, WiStride_, GStrideTensorA_));
|
||||
make_tuple(NStrideTensorA_, WiStride_, GStrideTensorA_),
|
||||
number<VectorSizeA>{},
|
||||
I1);
|
||||
|
||||
const auto in_n_wip_c_desc = transform_tensor_descriptor(
|
||||
in_n_wi_c_desc,
|
||||
@@ -535,7 +547,9 @@ struct TransformConvFwdToGemm
|
||||
{
|
||||
const auto in_n_wi_c_desc = make_naive_tensor_descriptor(
|
||||
make_tuple(N_, Wi_, C_),
|
||||
make_tuple(NStrideTensorA_, WiStride_, CStrideTensorA_));
|
||||
make_tuple(NStrideTensorA_, WiStride_, CStrideTensorA_),
|
||||
number<VectorSizeA>{},
|
||||
I1);
|
||||
|
||||
const auto in_n_wo_c_desc = transform_tensor_descriptor(
|
||||
in_n_wi_c_desc,
|
||||
@@ -556,7 +570,9 @@ struct TransformConvFwdToGemm
|
||||
{
|
||||
const auto in_n_wi_c_desc = make_naive_tensor_descriptor(
|
||||
make_tuple(N_, Wi_, NumGroupsToMerge, C_),
|
||||
make_tuple(NStrideTensorA_, WiStride_, GStrideTensorA_, CStrideTensorA_));
|
||||
make_tuple(NStrideTensorA_, WiStride_, GStrideTensorA_, CStrideTensorA_),
|
||||
number<VectorSizeA>{},
|
||||
I1);
|
||||
|
||||
const auto in_n_wo_c_desc = transform_tensor_descriptor(
|
||||
in_n_wi_c_desc,
|
||||
@@ -581,7 +597,9 @@ struct TransformConvFwdToGemm
|
||||
{
|
||||
const auto in_n_wi_c_desc = make_naive_tensor_descriptor(
|
||||
make_tuple(N_, Wi_, C_),
|
||||
make_tuple(NStrideTensorA_, WiStride_, CStrideTensorA_));
|
||||
make_tuple(NStrideTensorA_, WiStride_, CStrideTensorA_),
|
||||
number<VectorSizeA>{},
|
||||
I1);
|
||||
|
||||
const auto in_n_wip_c_desc = transform_tensor_descriptor(
|
||||
in_n_wi_c_desc,
|
||||
@@ -611,7 +629,9 @@ struct TransformConvFwdToGemm
|
||||
{
|
||||
const auto in_n_wi_c_desc = make_naive_tensor_descriptor(
|
||||
make_tuple(N_, Wi_, NumGroupsToMerge, C_),
|
||||
make_tuple(NStrideTensorA_, WiStride_, GStrideTensorA_, CStrideTensorA_));
|
||||
make_tuple(NStrideTensorA_, WiStride_, GStrideTensorA_, CStrideTensorA_),
|
||||
number<VectorSizeA>{},
|
||||
I1);
|
||||
|
||||
const auto in_n_wip_c_desc = transform_tensor_descriptor(
|
||||
in_n_wi_c_desc,
|
||||
@@ -661,7 +681,9 @@ struct TransformConvFwdToGemm
|
||||
{
|
||||
const auto in_gemmm_gemmk_desc = make_naive_tensor_descriptor(
|
||||
make_tuple(N_, Ho_, Wo_, C_),
|
||||
make_tuple(NStrideTensorA_, HiStride_, WiStride_, CStrideTensorA_));
|
||||
make_tuple(NStrideTensorA_, HiStride_, WiStride_, CStrideTensorA_),
|
||||
number<VectorSizeA>{},
|
||||
I1);
|
||||
|
||||
return transform_tensor_descriptor(
|
||||
in_gemmm_gemmk_desc,
|
||||
@@ -675,7 +697,9 @@ struct TransformConvFwdToGemm
|
||||
const auto in_gemmm_groups_gemmk_desc = make_naive_tensor_descriptor(
|
||||
make_tuple(N_, Ho_, Wo_, NumGroupsToMerge, C_),
|
||||
make_tuple(
|
||||
NStrideTensorA_, HiStride_, WiStride_, GStrideTensorA_, CStrideTensorA_));
|
||||
NStrideTensorA_, HiStride_, WiStride_, GStrideTensorA_, CStrideTensorA_),
|
||||
number<VectorSizeA>{},
|
||||
I1);
|
||||
|
||||
return transform_tensor_descriptor(
|
||||
in_gemmm_groups_gemmk_desc,
|
||||
@@ -689,8 +713,11 @@ struct TransformConvFwdToGemm
|
||||
{
|
||||
if constexpr(NumGroupsToMerge == 1)
|
||||
{
|
||||
const auto in_n_hi_wi_c_desc = make_naive_tensor_descriptor(
|
||||
make_tuple(N_, Hi_, Wi_), make_tuple(NStrideTensorA_, HiStride_, WiStride_));
|
||||
const auto in_n_hi_wi_c_desc =
|
||||
make_naive_tensor_descriptor(make_tuple(N_, Hi_, Wi_),
|
||||
make_tuple(NStrideTensorA_, HiStride_, WiStride_),
|
||||
number<VectorSizeA>{},
|
||||
I1);
|
||||
|
||||
const auto in_n_hip_wip_c_desc = transform_tensor_descriptor(
|
||||
in_n_hi_wi_c_desc,
|
||||
@@ -721,7 +748,9 @@ struct TransformConvFwdToGemm
|
||||
{
|
||||
const auto in_n_hi_wi_groups_c_desc = make_naive_tensor_descriptor(
|
||||
make_tuple(N_, Hi_, Wi_, NumGroupsToMerge),
|
||||
make_tuple(NStrideTensorA_, HiStride_, WiStride_, GStrideTensorA_));
|
||||
make_tuple(NStrideTensorA_, HiStride_, WiStride_, GStrideTensorA_),
|
||||
number<VectorSizeA>{},
|
||||
I1);
|
||||
|
||||
const auto in_n_hip_wip_groups_c_desc = transform_tensor_descriptor(
|
||||
in_n_hi_wi_groups_c_desc,
|
||||
@@ -757,7 +786,9 @@ struct TransformConvFwdToGemm
|
||||
{
|
||||
const auto in_n_hi_wi_c_desc = make_naive_tensor_descriptor(
|
||||
make_tuple(N_, Hi_, Wi_, C_),
|
||||
make_tuple(NStrideTensorA_, HiStride_, WiStride_, CStrideTensorA_));
|
||||
make_tuple(NStrideTensorA_, HiStride_, WiStride_, CStrideTensorA_),
|
||||
number<VectorSizeA>{},
|
||||
I1);
|
||||
|
||||
const auto in_n_ho_wo_c_desc = transform_tensor_descriptor(
|
||||
in_n_hi_wi_c_desc,
|
||||
@@ -780,7 +811,9 @@ struct TransformConvFwdToGemm
|
||||
const auto in_n_hi_wi_groups_c_desc = make_naive_tensor_descriptor(
|
||||
make_tuple(N_, Hi_, Wi_, NumGroupsToMerge, C_),
|
||||
make_tuple(
|
||||
NStrideTensorA_, HiStride_, WiStride_, GStrideTensorA_, CStrideTensorA_));
|
||||
NStrideTensorA_, HiStride_, WiStride_, GStrideTensorA_, CStrideTensorA_),
|
||||
number<VectorSizeA>{},
|
||||
I1);
|
||||
|
||||
const auto in_n_ho_wo_groups_c_desc = transform_tensor_descriptor(
|
||||
in_n_hi_wi_groups_c_desc,
|
||||
@@ -808,7 +841,9 @@ struct TransformConvFwdToGemm
|
||||
{
|
||||
const auto in_n_hi_wi_c_desc = make_naive_tensor_descriptor(
|
||||
make_tuple(N_, Hi_, Wi_, C_),
|
||||
make_tuple(NStrideTensorA_, HiStride_, WiStride_, CStrideTensorA_));
|
||||
make_tuple(NStrideTensorA_, HiStride_, WiStride_, CStrideTensorA_),
|
||||
number<VectorSizeA>{},
|
||||
I1);
|
||||
|
||||
const auto in_n_hip_wip_c_desc = transform_tensor_descriptor(
|
||||
in_n_hi_wi_c_desc,
|
||||
@@ -843,7 +878,9 @@ struct TransformConvFwdToGemm
|
||||
const auto in_n_hi_wi_groups_c_desc = make_naive_tensor_descriptor(
|
||||
make_tuple(N_, Hi_, Wi_, NumGroupsToMerge, C_),
|
||||
make_tuple(
|
||||
NStrideTensorA_, HiStride_, WiStride_, GStrideTensorA_, CStrideTensorA_));
|
||||
NStrideTensorA_, HiStride_, WiStride_, GStrideTensorA_, CStrideTensorA_),
|
||||
number<VectorSizeA>{},
|
||||
I1);
|
||||
|
||||
const auto in_n_hip_wip_groups_c_desc = transform_tensor_descriptor(
|
||||
in_n_hi_wi_groups_c_desc,
|
||||
@@ -904,7 +941,9 @@ struct TransformConvFwdToGemm
|
||||
{
|
||||
const auto in_gemmm_gemmk_desc = make_naive_tensor_descriptor(
|
||||
make_tuple(N_, Do_, Ho_, Wo_, C_),
|
||||
make_tuple(NStrideTensorA_, DiStride_, HiStride_, WiStride_, CStrideTensorA_));
|
||||
make_tuple(NStrideTensorA_, DiStride_, HiStride_, WiStride_, CStrideTensorA_),
|
||||
number<VectorSizeA>{},
|
||||
I1);
|
||||
|
||||
return transform_tensor_descriptor(
|
||||
in_gemmm_gemmk_desc,
|
||||
@@ -922,7 +961,9 @@ struct TransformConvFwdToGemm
|
||||
HiStride_,
|
||||
WiStride_,
|
||||
GStrideTensorA_,
|
||||
CStrideTensorA_));
|
||||
CStrideTensorA_),
|
||||
number<VectorSizeA>{},
|
||||
I1);
|
||||
|
||||
return transform_tensor_descriptor(
|
||||
in_gemmm_groups_gemmk_desc,
|
||||
@@ -939,7 +980,9 @@ struct TransformConvFwdToGemm
|
||||
{
|
||||
const auto in_n_di_hi_wi_c_desc = make_naive_tensor_descriptor(
|
||||
make_tuple(N_, Di_, Hi_, Wi_),
|
||||
make_tuple(NStrideTensorA_, DiStride_, HiStride_, WiStride_));
|
||||
make_tuple(NStrideTensorA_, DiStride_, HiStride_, WiStride_),
|
||||
number<VectorSizeA>{},
|
||||
I1);
|
||||
|
||||
const auto in_n_hip_wip_c_desc = transform_tensor_descriptor(
|
||||
in_n_di_hi_wi_c_desc,
|
||||
@@ -975,7 +1018,9 @@ struct TransformConvFwdToGemm
|
||||
{
|
||||
const auto in_n_di_hi_wi_c_desc = make_naive_tensor_descriptor(
|
||||
make_tuple(N_, Di_, Hi_, Wi_, NumGroupsToMerge),
|
||||
make_tuple(NStrideTensorA_, DiStride_, HiStride_, WiStride_, GStrideTensorA_));
|
||||
make_tuple(NStrideTensorA_, DiStride_, HiStride_, WiStride_, GStrideTensorA_),
|
||||
number<VectorSizeA>{},
|
||||
I1);
|
||||
|
||||
const auto in_n_hip_wip_c_desc = transform_tensor_descriptor(
|
||||
in_n_di_hi_wi_c_desc,
|
||||
@@ -1022,7 +1067,9 @@ struct TransformConvFwdToGemm
|
||||
{
|
||||
const auto in_n_di_hi_wi_c_desc = make_naive_tensor_descriptor(
|
||||
make_tuple(N_, Di_, Hi_, Wi_, C_),
|
||||
make_tuple(NStrideTensorA_, DiStride_, HiStride_, WiStride_, CStrideTensorA_));
|
||||
make_tuple(NStrideTensorA_, DiStride_, HiStride_, WiStride_, CStrideTensorA_),
|
||||
number<VectorSizeA>{},
|
||||
I1);
|
||||
|
||||
const auto in_n_do_ho_wo_c_desc = transform_tensor_descriptor(
|
||||
in_n_di_hi_wi_c_desc,
|
||||
@@ -1052,7 +1099,9 @@ struct TransformConvFwdToGemm
|
||||
HiStride_,
|
||||
WiStride_,
|
||||
GStrideTensorA_,
|
||||
CStrideTensorA_));
|
||||
CStrideTensorA_),
|
||||
number<VectorSizeA>{},
|
||||
I1);
|
||||
|
||||
const auto in_n_do_ho_wo_c_desc = transform_tensor_descriptor(
|
||||
in_n_di_hi_wi_c_desc,
|
||||
@@ -1090,7 +1139,9 @@ struct TransformConvFwdToGemm
|
||||
{
|
||||
const auto in_n_di_hi_wi_c_desc = make_naive_tensor_descriptor(
|
||||
make_tuple(N_, Di_, Hi_, Wi_, C_),
|
||||
make_tuple(NStrideTensorA_, DiStride_, HiStride_, WiStride_, CStrideTensorA_));
|
||||
make_tuple(NStrideTensorA_, DiStride_, HiStride_, WiStride_, CStrideTensorA_),
|
||||
number<VectorSizeA>{},
|
||||
I1);
|
||||
|
||||
const auto in_n_hip_wip_c_desc = transform_tensor_descriptor(
|
||||
in_n_di_hi_wi_c_desc,
|
||||
@@ -1138,7 +1189,9 @@ struct TransformConvFwdToGemm
|
||||
HiStride_,
|
||||
WiStride_,
|
||||
GStrideTensorA_,
|
||||
CStrideTensorA_));
|
||||
CStrideTensorA_),
|
||||
number<VectorSizeA>{},
|
||||
I1);
|
||||
|
||||
const auto in_n_hip_wip_c_desc = transform_tensor_descriptor(
|
||||
in_n_di_hi_wi_c_desc,
|
||||
@@ -1217,14 +1270,19 @@ struct TransformConvFwdToGemm
|
||||
|
||||
if constexpr(NumGroupsToMerge == 1)
|
||||
{
|
||||
return make_naive_tensor_descriptor_packed(make_tuple(K_, FilterSizeNumType{}));
|
||||
return make_naive_tensor_descriptor(make_tuple(K_, FilterSizeNumType{}),
|
||||
make_tuple(FilterSizeNumType{}, I1),
|
||||
number<VectorSizeB>{},
|
||||
I1);
|
||||
}
|
||||
else
|
||||
{
|
||||
|
||||
const auto wei_gemmn_groups_gemmk_desc = make_naive_tensor_descriptor(
|
||||
make_tuple(K_, NumGroupsToMerge, FilterSizeNumType{}),
|
||||
make_tuple(KStrideTensorB_, GStrideTensorB_, CStrideTensorB_));
|
||||
make_tuple(KStrideTensorB_, GStrideTensorB_, CStrideTensorB_),
|
||||
number<VectorSizeB>{},
|
||||
I1);
|
||||
return transform_tensor_descriptor(
|
||||
wei_gemmn_groups_gemmk_desc,
|
||||
make_tuple(make_merge_transform(make_tuple(K_, NumGroupsToMerge)),
|
||||
@@ -1237,13 +1295,18 @@ struct TransformConvFwdToGemm
|
||||
{
|
||||
if constexpr(NumGroupsToMerge == 1)
|
||||
{
|
||||
return make_naive_tensor_descriptor_packed(make_tuple(K_, ZYX_ * C_));
|
||||
return make_naive_tensor_descriptor(make_tuple(K_, ZYX_ * C_),
|
||||
make_tuple(ZYX_ * C_, I1),
|
||||
number<VectorSizeB>{},
|
||||
I1);
|
||||
}
|
||||
else
|
||||
{
|
||||
const auto wei_gemmn_groups_gemmk_desc = make_naive_tensor_descriptor(
|
||||
make_tuple(K_, NumGroupsToMerge, ZYX_ * C_),
|
||||
make_tuple(KStrideTensorB_, GStrideTensorB_, CStrideTensorB_));
|
||||
make_tuple(KStrideTensorB_, GStrideTensorB_, CStrideTensorB_),
|
||||
number<VectorSizeB>{},
|
||||
I1);
|
||||
return transform_tensor_descriptor(
|
||||
wei_gemmn_groups_gemmk_desc,
|
||||
make_tuple(make_merge_transform(make_tuple(K_, NumGroupsToMerge)),
|
||||
@@ -1270,14 +1333,18 @@ struct TransformConvFwdToGemm
|
||||
if constexpr(NumGroupsToMerge == 1)
|
||||
{
|
||||
return make_naive_tensor_descriptor(make_tuple(NDoHoWo, K_),
|
||||
make_tuple(WoStride_, KStrideTensorC_));
|
||||
make_tuple(WoStride_, KStrideTensorC_),
|
||||
number<VectorSizeC>{},
|
||||
I1);
|
||||
}
|
||||
else
|
||||
{
|
||||
const auto nhwo_groups_k_1_desc = make_naive_tensor_descriptor(
|
||||
make_tuple(N_, Wo_, NumGroupsToMerge, K_, 1),
|
||||
make_tuple(
|
||||
NStrideTensorC_, WoStride_, GStrideTensorC_, KStrideTensorC_, GStrideTensorC_));
|
||||
NStrideTensorC_, WoStride_, GStrideTensorC_, KStrideTensorC_, GStrideTensorC_),
|
||||
number<VectorSizeC>{},
|
||||
I1);
|
||||
// Padd 1 to NumGroupsToMerge
|
||||
const auto padded_desc = transform_tensor_descriptor(
|
||||
nhwo_groups_k_1_desc,
|
||||
@@ -1328,7 +1395,9 @@ struct TransformConvFwdToGemm
|
||||
if constexpr(NumGroupsToMerge == 1)
|
||||
{
|
||||
return make_naive_tensor_descriptor(make_tuple(NDoHoWo, K_),
|
||||
make_tuple(WoStride_, KStrideTensorC_));
|
||||
make_tuple(WoStride_, KStrideTensorC_),
|
||||
number<VectorSizeC>{},
|
||||
I1);
|
||||
}
|
||||
else
|
||||
{
|
||||
@@ -1339,7 +1408,9 @@ struct TransformConvFwdToGemm
|
||||
WoStride_,
|
||||
GStrideTensorC_,
|
||||
KStrideTensorC_,
|
||||
GStrideTensorC_));
|
||||
GStrideTensorC_),
|
||||
number<VectorSizeC>{},
|
||||
I1);
|
||||
// Padd 1 to NumGroupsToMerge
|
||||
const auto padded_desc = transform_tensor_descriptor(
|
||||
nhwo_groups_k_1_desc,
|
||||
@@ -1390,7 +1461,9 @@ struct TransformConvFwdToGemm
|
||||
if constexpr(NumGroupsToMerge == 1)
|
||||
{
|
||||
return make_naive_tensor_descriptor(make_tuple(NDoHoWo, K_),
|
||||
make_tuple(WoStride_, KStrideTensorC_));
|
||||
make_tuple(WoStride_, KStrideTensorC_),
|
||||
number<VectorSizeC>{},
|
||||
I1);
|
||||
}
|
||||
else
|
||||
{
|
||||
@@ -1402,7 +1475,9 @@ struct TransformConvFwdToGemm
|
||||
WoStride_,
|
||||
GStrideTensorC_,
|
||||
KStrideTensorC_,
|
||||
GStrideTensorC_));
|
||||
GStrideTensorC_),
|
||||
number<VectorSizeC>{},
|
||||
I1);
|
||||
// Padd 1 to NumGroupsToMerge
|
||||
const auto padded_desc = transform_tensor_descriptor(
|
||||
nhwo_groups_k_1_desc,
|
||||
|
||||
Reference in New Issue
Block a user