mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 06:49:15 +00:00
[CK_Tile] Merge multiple convolution groups into a single GEMM batch (#2986)
* Fix compilation of the grouped conv examples. * Fix grouped conv bwd weight example output in CK Tile. * Add number of groups to merge to ck tile grouped gemm example. * Initial set of tests for TransformConvBwdWeightToGemm. * Added unit tests for TransformConvBwdWeightToGemm conv groups are merged. * WIP: Tensor transformations. * Add unit tests for coordinate transforms. * Fully working conv group merging for TransformConvBwdWeightToGemm. * WIP: Merged conv groups offset calculation. * Adde unit tests for tensor view. * WIP: Merged conv groups epilogue. * Enable running multiple conv groups per batch. * Add tests for tile_distribution_encoding. * Change example to match optimally depthwise convolution with merged groups. * Add more tests for tensor view. * Integration test for reading diagonal blocks from grouped distributed tensor. * Improved integration test. * Improve test for accessing diagonal blocks. * Added integration test for cshuffle epilogue LDS tile distribution. * Add more logging. * Increase the max number of reported errors. * WIP: merged conv groups GEMM epilogue changes. * LDS to global memory copy. * Fix tile window size for c block. * Integration test for CShuffle epilogue. * Improved CShuffle test. * WIP: Separate epilogue for merged conv groups. * Tile example parameters changes to match depthwise conv. * Offset fixes. * Epilogue fixes. * Working baseline for depthwise covolution with merged conv groups. * Fix build. * Initial unit tests for tensor descriptor. * Add one more unit test for tensor view. * WIP: LDS to global mem transfer using CK tile tensor descriptor and tile distribution encoding. * Fully functional LDS to global mem transfer using tensor descriptor and tile distribution encoding. * Add more comments, disable debug code. * Remove debug and other dead code. * Code clean-up for bwd tensor transformations. * Enable running multiple GEMM batches of merged conv groups. * Add compile check for assumed row-mjor layout. * Fix strides in 1D conv to gemm transformation. * WIP: Simplify conv to gemm transformations and handle K > 1 and C > 1 cases. * Fix case k > 1 and c=1. * Remove debug code. * Make MPerGroup and NPerGroup template parameters. * Add additional check for non-supported c > 1 case. * WIP: Put back the generic tensor descriptors for convolutions. * Fix tensor descriptors. * Remove the obsolete template parameters. * Add more instances. * Fix bugs in merged conv groups tensor descriptors. * Fix tensor descriptors for merged conv groups when K > 1. * Remove debug output. * Remove dead code. * Fix merge conflicts. * Code clean-up. * Remove unused code. * Run clang-formatting. * Remove debug prints and obsolete tests. * Check that number of convolution groups is multiple of merged groups. * Fix build after removing obsolete functionality. * Remove obsolete enumeration. * Fix new unit projects. * Remove unnecessary includes. * Fix passing the number of merged groups. * Remove unrelated tests. * Fix IsSupportedArgument for bwd weight conv kernel. * Fix clang formatting. * Fix the bwd weight conv to gemm mapping for num merged groups > 1. * GEMM config for conv group merging. * Fix clang-formatting. * Remove obsolete comment. * Fix typos in comment strings. * Increase the max number of reported errors when testing against reference implementation. * Rename gemm_config to conv_config. * Rename GemmConfig to ConvConfig and move NumGroupsToMerge into ConvConfig. * Change num_groups_to_merge to a boolean flag in the ck tile grouped conv example. * Run clang-format. * Add number of merged groups into kernel name string. * Remove group merging flag from CK Tile grouped conv example.
This commit is contained in:
@@ -59,10 +59,11 @@ template <index_t NDimSpatial_,
|
||||
typename WeiLayout_,
|
||||
typename DsLayout_,
|
||||
typename OutLayout_,
|
||||
index_t VectorSizeA_ = 1,
|
||||
index_t VectorSizeB_ = 1,
|
||||
index_t VectorSizeC_ = 1,
|
||||
typename CDElementwise_ = PassThrough>
|
||||
index_t VectorSizeA_ = 1,
|
||||
index_t VectorSizeB_ = 1,
|
||||
index_t VectorSizeC_ = 1,
|
||||
index_t NumGroupsToMerge_ = 1,
|
||||
typename CDElementwise_ = PassThrough>
|
||||
struct GroupedConvTraits
|
||||
{
|
||||
private:
|
||||
@@ -73,7 +74,7 @@ struct GroupedConvTraits
|
||||
}
|
||||
|
||||
public:
|
||||
static constexpr index_t NumGroupsToMerge = 1;
|
||||
static constexpr index_t NumGroupsToMerge = NumGroupsToMerge_;
|
||||
static constexpr index_t NDimSpatial = NDimSpatial_;
|
||||
static constexpr ConvolutionSpecialization ConvSpecialization = ConvSpecialization_;
|
||||
using InLayout = InLayout_;
|
||||
|
||||
@@ -13,10 +13,10 @@ template <index_t NDimSpatial,
|
||||
index_t VectorSizeA,
|
||||
index_t VectorSizeB,
|
||||
index_t VectorSizeC,
|
||||
index_t NumGroupsToMerge = 1,
|
||||
bool SplitN = false,
|
||||
typename ADataType = float,
|
||||
typename CDataType = float,
|
||||
index_t NumGroupsToMerge = 1,
|
||||
typename IndexType = index_t>
|
||||
struct TransformConvBwdWeightToGemm
|
||||
{
|
||||
@@ -125,8 +125,7 @@ struct TransformConvBwdWeightToGemm
|
||||
InLeftPadW_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.InLeftPadW_)},
|
||||
InRightPadD_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.InRightPadD_)},
|
||||
InRightPadH_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.InRightPadH_)},
|
||||
InRightPadW_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.InRightPadW_)},
|
||||
ZYX_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.ZYX_)}
|
||||
InRightPadW_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.InRightPadW_)}
|
||||
{
|
||||
}
|
||||
|
||||
@@ -164,8 +163,7 @@ struct TransformConvBwdWeightToGemm
|
||||
InLeftPadW_{input_left_pads[I0]},
|
||||
InRightPadD_{I0},
|
||||
InRightPadH_{I0},
|
||||
InRightPadW_{input_right_pads[I0]},
|
||||
ZYX_{X_}
|
||||
InRightPadW_{input_right_pads[I0]}
|
||||
{
|
||||
static_assert(std::is_same_v<ConvSpatialDimsType, std::array<IndexType, NDimSpatial>> ||
|
||||
std::is_same_v<ConvSpatialDimsType, ck_tile::array<IndexType, NDimSpatial>>);
|
||||
@@ -219,8 +217,7 @@ struct TransformConvBwdWeightToGemm
|
||||
InLeftPadW_{input_left_pads[I1]},
|
||||
InRightPadD_{I0},
|
||||
InRightPadH_{input_right_pads[I0]},
|
||||
InRightPadW_{input_right_pads[I1]},
|
||||
ZYX_{Y_ * X_}
|
||||
InRightPadW_{input_right_pads[I1]}
|
||||
{
|
||||
static_assert(std::is_same_v<ConvSpatialDimsType, std::array<IndexType, NDimSpatial>> ||
|
||||
std::is_same_v<ConvSpatialDimsType, ck_tile::array<IndexType, NDimSpatial>>);
|
||||
@@ -274,8 +271,7 @@ struct TransformConvBwdWeightToGemm
|
||||
InLeftPadW_{input_left_pads[I2]},
|
||||
InRightPadD_{input_right_pads[I0]},
|
||||
InRightPadH_{input_right_pads[I1]},
|
||||
InRightPadW_{input_right_pads[I2]},
|
||||
ZYX_{Z_ * Y_ * X_}
|
||||
InRightPadW_{input_right_pads[I2]}
|
||||
{
|
||||
static_assert(std::is_same_v<ConvSpatialDimsType, std::array<IndexType, NDimSpatial>> ||
|
||||
std::is_same_v<ConvSpatialDimsType, ck_tile::array<IndexType, NDimSpatial>>);
|
||||
@@ -420,11 +416,21 @@ struct TransformConvBwdWeightToGemm
|
||||
const index_t NDoHoWoStride = G_ * K_;
|
||||
constexpr auto KStride = I1;
|
||||
|
||||
// TODO Add support for NumGroupsToMerge > 1
|
||||
return make_naive_tensor_descriptor(make_tuple(K_, N_ * Wo_),
|
||||
make_tuple(KStride, NDoHoWoStride),
|
||||
number<VectorSizeA>{},
|
||||
I1);
|
||||
if constexpr(NumGroupsToMerge > 1)
|
||||
{
|
||||
const index_t BatchStride = K_;
|
||||
return make_naive_tensor_descriptor(make_tuple(K_, NumGroupsToMerge, N_ * Wo_),
|
||||
make_tuple(KStride, BatchStride, NDoHoWoStride),
|
||||
number<VectorSizeA>{},
|
||||
I1);
|
||||
}
|
||||
else
|
||||
{
|
||||
return make_naive_tensor_descriptor(make_tuple(K_, N_ * Wo_),
|
||||
make_tuple(KStride, NDoHoWoStride),
|
||||
number<VectorSizeA>{},
|
||||
I1);
|
||||
}
|
||||
}
|
||||
|
||||
template <index_t NDim = NDimSpatial, typename std::enable_if<NDim == 1, bool>::type = false>
|
||||
@@ -435,11 +441,22 @@ struct TransformConvBwdWeightToGemm
|
||||
const index_t WiStride = G_ * C_;
|
||||
constexpr auto CStride = I1;
|
||||
|
||||
// TODO Add support for NumGroupsToMerge > 1
|
||||
return make_naive_tensor_descriptor(make_tuple(N_, Wi_, C_),
|
||||
make_tuple(NStride, WiStride, CStride),
|
||||
number<VectorSizeB>{},
|
||||
I1);
|
||||
if constexpr(NumGroupsToMerge > 1)
|
||||
{
|
||||
const auto BatchStride = C_;
|
||||
return make_naive_tensor_descriptor(make_tuple(N_, Wi_, NumGroupsToMerge, C_),
|
||||
make_tuple(NStride, WiStride, BatchStride, CStride),
|
||||
number<VectorSizeB>{},
|
||||
I1);
|
||||
}
|
||||
else
|
||||
{
|
||||
|
||||
return make_naive_tensor_descriptor(make_tuple(N_, Wi_, C_),
|
||||
make_tuple(NStride, WiStride, CStride),
|
||||
number<VectorSizeB>{},
|
||||
I1);
|
||||
}
|
||||
}
|
||||
|
||||
template <index_t NDim = NDimSpatial, typename std::enable_if<NDim == 1, bool>::type = false>
|
||||
@@ -449,9 +466,56 @@ struct TransformConvBwdWeightToGemm
|
||||
const index_t KStride = X_ * C_;
|
||||
constexpr auto CXStride = I1;
|
||||
|
||||
// TODO Add support for NumGroupsToMerge > 1
|
||||
return make_naive_tensor_descriptor(
|
||||
make_tuple(K_, X_ * C_), make_tuple(KStride, CXStride), number<VectorSizeC>{}, I1);
|
||||
if constexpr(NumGroupsToMerge > 1)
|
||||
{
|
||||
const index_t XStride = C_;
|
||||
const index_t BatchStride = K_ * X_ * C_;
|
||||
// Add NumGroupsToMerge for Batch+M dimension and, 1 as a placeholder
|
||||
// for Batch+N dimension
|
||||
const auto desc = make_naive_tensor_descriptor(
|
||||
make_tuple(NumGroupsToMerge, K_, X_, 1, C_),
|
||||
make_tuple(BatchStride, KStride, XStride, BatchStride, CXStride),
|
||||
number<VectorSizeC>{},
|
||||
I1);
|
||||
// Pad 1 to NumGroupsToMerge
|
||||
const auto padded_desc = transform_tensor_descriptor(
|
||||
desc,
|
||||
make_tuple(make_pass_through_transform(NumGroupsToMerge),
|
||||
make_pass_through_transform(K_),
|
||||
make_pass_through_transform(X_),
|
||||
make_pad_transform(1, 0, NumGroupsToMerge - 1),
|
||||
make_pass_through_transform(C_)),
|
||||
make_tuple(
|
||||
sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}, sequence<4>{}),
|
||||
make_tuple(
|
||||
sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}, sequence<4>{}));
|
||||
// We need only matrices from diagonal. Xor returns 0 for the same
|
||||
// values. So if matrices is not on diagonal then it will be stored in padding.
|
||||
// To avoid use of modulo after xor we assume that NumBatch to merge is power of 2.
|
||||
static_assert(NumGroupsToMerge == 1 || NumGroupsToMerge == 2 || NumGroupsToMerge == 4 ||
|
||||
NumGroupsToMerge == 8 || NumGroupsToMerge == 16 ||
|
||||
NumGroupsToMerge == 32 || NumGroupsToMerge == 64);
|
||||
const auto unmerged_padded_desc = transform_tensor_descriptor(
|
||||
padded_desc,
|
||||
make_tuple(make_xor_transform(make_tuple(NumGroupsToMerge, NumGroupsToMerge)),
|
||||
make_pass_through_transform(K_),
|
||||
make_pass_through_transform(X_),
|
||||
make_pass_through_transform(C_)),
|
||||
make_tuple(sequence<0, 3>{}, sequence<1>{}, sequence<2>{}, sequence<4>{}),
|
||||
make_tuple(sequence<0, 3>{}, sequence<1>{}, sequence<2>{}, sequence<4>{}));
|
||||
// Merge To M, N
|
||||
return transform_tensor_descriptor(
|
||||
unmerged_padded_desc,
|
||||
make_tuple(make_merge_transform(make_tuple(NumGroupsToMerge, K_)),
|
||||
make_merge_transform(make_tuple(X_, NumGroupsToMerge, C_))),
|
||||
make_tuple(sequence<0, 1>{}, sequence<2, 3, 4>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
}
|
||||
else
|
||||
{
|
||||
return make_naive_tensor_descriptor(
|
||||
make_tuple(K_, X_ * C_), make_tuple(KStride, CXStride), number<VectorSizeC>{}, I1);
|
||||
}
|
||||
}
|
||||
|
||||
template <index_t NDim = NDimSpatial, typename std::enable_if<NDim == 2, bool>::type = false>
|
||||
@@ -461,11 +525,22 @@ struct TransformConvBwdWeightToGemm
|
||||
const index_t NDoHoWoStride = G_ * K_;
|
||||
constexpr auto KStride = I1;
|
||||
|
||||
// TODO Add support for NumGroupsToMerge > 1
|
||||
return make_naive_tensor_descriptor(make_tuple(N_ * Ho_ * Wo_, K_), // K_M
|
||||
make_tuple(NDoHoWoStride, KStride),
|
||||
number<VectorSizeA>{},
|
||||
I1);
|
||||
if constexpr(NumGroupsToMerge > 1)
|
||||
{
|
||||
const index_t BatchStride = K_;
|
||||
return make_naive_tensor_descriptor(
|
||||
make_tuple(N_ * Ho_ * Wo_, NumGroupsToMerge, K_), // K_Gm_M
|
||||
make_tuple(NDoHoWoStride, BatchStride, KStride),
|
||||
number<VectorSizeA>{},
|
||||
I1);
|
||||
}
|
||||
else
|
||||
{
|
||||
return make_naive_tensor_descriptor(make_tuple(N_ * Ho_ * Wo_, K_), // K_M
|
||||
make_tuple(NDoHoWoStride, KStride),
|
||||
number<VectorSizeA>{},
|
||||
I1);
|
||||
}
|
||||
}
|
||||
|
||||
template <index_t NDim = NDimSpatial, typename std::enable_if<NDim == 2, bool>::type = false>
|
||||
@@ -477,11 +552,22 @@ struct TransformConvBwdWeightToGemm
|
||||
const index_t WiStride = G_ * C_;
|
||||
constexpr auto CStride = I1;
|
||||
|
||||
// TODO Add support for NumGroupsToMerge > 1
|
||||
return make_naive_tensor_descriptor(make_tuple(N_, Hi_, Wi_, C_), // K_N
|
||||
make_tuple(NStride, HiStride, WiStride, CStride),
|
||||
number<VectorSizeB>{},
|
||||
I1);
|
||||
if constexpr(NumGroupsToMerge > 1)
|
||||
{
|
||||
const auto BatchStride = C_;
|
||||
return make_naive_tensor_descriptor(
|
||||
make_tuple(N_, Hi_, Wi_, NumGroupsToMerge, C_), // K_Gm_N
|
||||
make_tuple(NStride, HiStride, WiStride, BatchStride, CStride),
|
||||
number<VectorSizeB>{},
|
||||
I1);
|
||||
}
|
||||
else
|
||||
{
|
||||
return make_naive_tensor_descriptor(make_tuple(N_, Hi_, Wi_, C_), // K_N
|
||||
make_tuple(NStride, HiStride, WiStride, CStride),
|
||||
number<VectorSizeB>{},
|
||||
I1);
|
||||
}
|
||||
}
|
||||
|
||||
template <index_t NDim = NDimSpatial, typename std::enable_if<NDim == 2, bool>::type = false>
|
||||
@@ -491,9 +577,58 @@ struct TransformConvBwdWeightToGemm
|
||||
const index_t KStride = Y_ * X_ * C_;
|
||||
constexpr auto CStride = I1;
|
||||
|
||||
// TODO Add support for NumGroupsToMerge > 1
|
||||
return make_naive_tensor_descriptor(
|
||||
make_tuple(K_, Y_ * X_ * C_), make_tuple(KStride, CStride), number<VectorSizeC>{}, I1);
|
||||
if constexpr(NumGroupsToMerge > 1)
|
||||
{
|
||||
const index_t YXStride = C_;
|
||||
const index_t BatchStride = K_ * Y_ * X_ * C_;
|
||||
// Add NumGroupsToMerge for Batch+M dimension and, 1 as a placeholder
|
||||
// for Batch+N dimension
|
||||
const auto desc = make_naive_tensor_descriptor(
|
||||
make_tuple(NumGroupsToMerge, K_, Y_ * X_, 1, C_),
|
||||
make_tuple(BatchStride, KStride, YXStride, BatchStride, CStride),
|
||||
number<VectorSizeC>{},
|
||||
I1);
|
||||
// Pad 1 to NumGroupsToMerge
|
||||
const auto padded_desc = transform_tensor_descriptor(
|
||||
desc,
|
||||
make_tuple(make_pass_through_transform(NumGroupsToMerge),
|
||||
make_pass_through_transform(K_),
|
||||
make_pass_through_transform(Y_ * X_),
|
||||
make_pad_transform(1, 0, NumGroupsToMerge - 1),
|
||||
make_pass_through_transform(C_)),
|
||||
make_tuple(
|
||||
sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}, sequence<4>{}),
|
||||
make_tuple(
|
||||
sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}, sequence<4>{}));
|
||||
// We need only matrices from diagonal. Xor returns 0 for the same
|
||||
// values. So if matrices is not on diagonal then it will be stored in padding.
|
||||
// To avoid use of modulo after xor we assume that NumBatch to merge is power of 2.
|
||||
static_assert(NumGroupsToMerge == 1 || NumGroupsToMerge == 2 || NumGroupsToMerge == 4 ||
|
||||
NumGroupsToMerge == 8 || NumGroupsToMerge == 16 ||
|
||||
NumGroupsToMerge == 32 || NumGroupsToMerge == 64);
|
||||
const auto unmerged_padded_desc = transform_tensor_descriptor(
|
||||
padded_desc,
|
||||
make_tuple(make_xor_transform(make_tuple(NumGroupsToMerge, NumGroupsToMerge)),
|
||||
make_pass_through_transform(K_),
|
||||
make_pass_through_transform(Y_ * X_),
|
||||
make_pass_through_transform(C_)),
|
||||
make_tuple(sequence<0, 3>{}, sequence<1>{}, sequence<2>{}, sequence<4>{}),
|
||||
make_tuple(sequence<0, 3>{}, sequence<1>{}, sequence<2>{}, sequence<4>{}));
|
||||
// Merge To M, N
|
||||
return transform_tensor_descriptor(
|
||||
unmerged_padded_desc,
|
||||
make_tuple(make_merge_transform(make_tuple(NumGroupsToMerge, K_)),
|
||||
make_merge_transform(make_tuple(Y_ * X_, NumGroupsToMerge, C_))),
|
||||
make_tuple(sequence<0, 1>{}, sequence<2, 3, 4>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
}
|
||||
else
|
||||
{
|
||||
return make_naive_tensor_descriptor(make_tuple(K_, Y_ * X_ * C_),
|
||||
make_tuple(KStride, CStride),
|
||||
number<VectorSizeC>{},
|
||||
I1);
|
||||
}
|
||||
}
|
||||
|
||||
template <index_t NDim = NDimSpatial, typename std::enable_if<NDim == 3, bool>::type = false>
|
||||
@@ -503,11 +638,22 @@ struct TransformConvBwdWeightToGemm
|
||||
const index_t NDoHoWoStride = G_ * K_;
|
||||
constexpr auto KStride = I1;
|
||||
|
||||
// TODO Add support for NumGroupsToMerge > 1
|
||||
return make_naive_tensor_descriptor(make_tuple(N_ * Do_ * Ho_ * Wo_, K_),
|
||||
make_tuple(NDoHoWoStride, KStride),
|
||||
number<VectorSizeA>{},
|
||||
I1);
|
||||
if constexpr(NumGroupsToMerge > 1)
|
||||
{
|
||||
const auto BatchStride = K_;
|
||||
return make_naive_tensor_descriptor(
|
||||
make_tuple(N_ * Do_ * Ho_ * Wo_, NumGroupsToMerge, K_),
|
||||
make_tuple(NDoHoWoStride, BatchStride, KStride),
|
||||
number<VectorSizeA>{},
|
||||
I1);
|
||||
}
|
||||
else
|
||||
{
|
||||
return make_naive_tensor_descriptor(make_tuple(N_ * Do_ * Ho_ * Wo_, K_),
|
||||
make_tuple(NDoHoWoStride, KStride),
|
||||
number<VectorSizeA>{},
|
||||
I1);
|
||||
}
|
||||
}
|
||||
|
||||
template <index_t NDim = NDimSpatial, typename std::enable_if<NDim == 3, bool>::type = false>
|
||||
@@ -519,26 +665,84 @@ struct TransformConvBwdWeightToGemm
|
||||
const index_t WiStride = G_ * C_;
|
||||
constexpr auto CStride = I1;
|
||||
|
||||
// TODO Add support for NumGroupsToMerge > 1
|
||||
return make_naive_tensor_descriptor(
|
||||
make_tuple(N_, Di_, Hi_, Wi_, C_),
|
||||
make_tuple(NStride, DiStride, HiStride, WiStride, CStride),
|
||||
number<VectorSizeB>{},
|
||||
I1);
|
||||
if constexpr(NumGroupsToMerge > 1)
|
||||
{
|
||||
const index_t BatchStride = C_;
|
||||
return make_naive_tensor_descriptor(
|
||||
make_tuple(N_, Di_, Hi_, Wi_, NumGroupsToMerge, C_),
|
||||
make_tuple(NStride, DiStride, HiStride, WiStride, BatchStride, CStride),
|
||||
number<VectorSizeB>{},
|
||||
I1);
|
||||
}
|
||||
else
|
||||
{
|
||||
return make_naive_tensor_descriptor(
|
||||
make_tuple(N_, Di_, Hi_, Wi_, C_),
|
||||
make_tuple(NStride, DiStride, HiStride, WiStride, CStride),
|
||||
number<VectorSizeB>{},
|
||||
I1);
|
||||
}
|
||||
}
|
||||
|
||||
template <index_t NDim = NDimSpatial, typename std::enable_if<NDim == 3, bool>::type = false>
|
||||
CK_TILE_HOST auto make_wei_grid_desc() const
|
||||
{
|
||||
// KZYXC
|
||||
// GKZYXC
|
||||
const index_t KStride = Z_ * Y_ * X_ * C_;
|
||||
constexpr auto CStride = I1;
|
||||
|
||||
// TODO Add support for NumGroupsToMerge > 1
|
||||
return make_naive_tensor_descriptor(make_tuple(K_, Z_ * Y_ * X_ * C_),
|
||||
make_tuple(KStride, CStride),
|
||||
number<VectorSizeC>{},
|
||||
I1);
|
||||
if constexpr(NumGroupsToMerge > 1)
|
||||
{
|
||||
const index_t ZYXStride = C_;
|
||||
const index_t BatchStride = K_ * Z_ * Y_ * X_ * C_;
|
||||
// Add NumGroupsToMerge for Batch+M dimension and, 1 as a placeholder
|
||||
// for Batch+N dimension
|
||||
const auto desc = make_naive_tensor_descriptor(
|
||||
make_tuple(NumGroupsToMerge, K_, Z_ * Y_ * X_, 1, C_),
|
||||
make_tuple(BatchStride, KStride, ZYXStride, BatchStride, CStride),
|
||||
number<VectorSizeC>{},
|
||||
I1);
|
||||
// Pad 1 to NumGroupsToMerge
|
||||
const auto padded_desc = transform_tensor_descriptor(
|
||||
desc,
|
||||
make_tuple(make_pass_through_transform(NumGroupsToMerge),
|
||||
make_pass_through_transform(K_),
|
||||
make_pass_through_transform(Z_ * Y_ * X_),
|
||||
make_pad_transform(1, 0, NumGroupsToMerge - 1),
|
||||
make_pass_through_transform(C_)),
|
||||
make_tuple(
|
||||
sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}, sequence<4>{}),
|
||||
make_tuple(
|
||||
sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}, sequence<4>{}));
|
||||
// We need only matrices from diagonal. Xor returns 0 for the same
|
||||
// values. So if matrices is not on diagonal then it will be stored in padding.
|
||||
// To avoid use of modulo after xor we assume that NumBatch to merge is power of 2.
|
||||
static_assert(NumGroupsToMerge == 1 || NumGroupsToMerge == 2 || NumGroupsToMerge == 4 ||
|
||||
NumGroupsToMerge == 8 || NumGroupsToMerge == 16 ||
|
||||
NumGroupsToMerge == 32 || NumGroupsToMerge == 64);
|
||||
const auto unmerged_padded_desc = transform_tensor_descriptor(
|
||||
padded_desc,
|
||||
make_tuple(make_xor_transform(make_tuple(NumGroupsToMerge, NumGroupsToMerge)),
|
||||
make_pass_through_transform(K_),
|
||||
make_pass_through_transform(Z_ * Y_ * X_),
|
||||
make_pass_through_transform(C_)),
|
||||
make_tuple(sequence<0, 3>{}, sequence<1>{}, sequence<2>{}, sequence<4>{}),
|
||||
make_tuple(sequence<0, 3>{}, sequence<1>{}, sequence<2>{}, sequence<4>{}));
|
||||
// Merge To M, N
|
||||
return transform_tensor_descriptor(
|
||||
unmerged_padded_desc,
|
||||
make_tuple(make_merge_transform(make_tuple(NumGroupsToMerge, K_)),
|
||||
make_merge_transform(make_tuple(Z_ * Y_ * X_, NumGroupsToMerge, C_))),
|
||||
make_tuple(sequence<0, 1>{}, sequence<2, 3, 4>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
}
|
||||
else
|
||||
{
|
||||
return make_naive_tensor_descriptor(make_tuple(K_, Z_ * Y_ * X_ * C_),
|
||||
make_tuple(KStride, CStride),
|
||||
number<VectorSizeC>{},
|
||||
I1);
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: implement ck_tile::tensor_layout::convolution that describe packed/strided dimemsion as
|
||||
@@ -552,31 +756,84 @@ struct TransformConvBwdWeightToGemm
|
||||
const auto wei_grid_desc = make_wei_grid_desc<NDimSpatial>();
|
||||
|
||||
// B: input tensor comes in K_N
|
||||
const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor(
|
||||
in_grid_desc,
|
||||
make_tuple(make_pass_through_transform(N_),
|
||||
make_pad_transform(Wi_, InLeftPadW_, InRightPadW_),
|
||||
make_pass_through_transform(C_)),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}));
|
||||
if constexpr(NumGroupsToMerge > 1)
|
||||
{
|
||||
// Output tensor transformation
|
||||
// [0, 1, 2] -> [0, 1]
|
||||
// [(N*Wo), Gm, K] -> [(N*Wo), (Gm*K)]
|
||||
const auto out_gemm_k_gemm_m_grid_desc = transform_tensor_descriptor(
|
||||
out_grid_desc,
|
||||
make_tuple(make_pass_through_transform(N_ * Wo_),
|
||||
make_merge_transform(make_tuple(NumGroupsToMerge, K_))),
|
||||
make_tuple(sequence<0>{}, sequence<1, 2>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
|
||||
const auto in_n_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor(
|
||||
in_n_hip_wip_c_grid_desc,
|
||||
make_tuple(
|
||||
make_pass_through_transform(N_),
|
||||
make_embed_transform(make_tuple(X_, Wo_), make_tuple(ConvDilationW_, ConvStrideW_)),
|
||||
make_pass_through_transform(C_)),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1, 2>{}, sequence<3>{}));
|
||||
// Input tensor transformation, part 1.
|
||||
// [N, Wi, Gm, C] -> [N, (Wi + InLeftPadW + InRightPadW), Gm, C] = [N, Wip, Gm, C]
|
||||
const auto in_n_wip_gm_c_grid_desc = transform_tensor_descriptor(
|
||||
in_grid_desc,
|
||||
make_tuple(make_pass_through_transform(N_),
|
||||
make_pad_transform(Wi_, InLeftPadW_, InRightPadW_),
|
||||
make_pass_through_transform(NumGroupsToMerge),
|
||||
make_pass_through_transform(C_)),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}));
|
||||
|
||||
const auto in_gemmn_gemmktotal_grid_desc =
|
||||
transform_tensor_descriptor(in_n_y_ho_x_wo_c_grid_desc,
|
||||
make_tuple(make_merge_transform(make_tuple(X_, C_)),
|
||||
make_merge_transform(make_tuple(N_, Wo_))),
|
||||
make_tuple(sequence<1, 3>{}, sequence<0, 2>{}),
|
||||
make_tuple(sequence<1>{}, sequence<0>{}));
|
||||
// Input tensor transformation, part 2.
|
||||
// [N, Wip, Gm, C] -> [N, X, Wo, Gm, C]
|
||||
const auto in_n_x_wo_gm_c_grid_desc = transform_tensor_descriptor(
|
||||
in_n_wip_gm_c_grid_desc,
|
||||
make_tuple(make_pass_through_transform(N_),
|
||||
make_embed_transform(make_tuple(X_, Wo_),
|
||||
make_tuple(ConvDilationW_, ConvStrideW_)),
|
||||
make_pass_through_transform(NumGroupsToMerge),
|
||||
make_pass_through_transform(C_)),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1, 2>{}, sequence<3>{}, sequence<4>{}));
|
||||
|
||||
return make_tuple(out_grid_desc, in_gemmn_gemmktotal_grid_desc, wei_grid_desc);
|
||||
// Input tensor transformation, part 3.
|
||||
// [0, 1, 2, 3, 4] -> [0, 1]
|
||||
// [N, X, Wo, Gm, C] -> [(N*Wo), (Gm*X*C)]
|
||||
const auto in_gemm_n_gemm_k_grid_desc = transform_tensor_descriptor(
|
||||
in_n_x_wo_gm_c_grid_desc,
|
||||
make_tuple(make_merge_transform(make_tuple(X_, NumGroupsToMerge, C_)),
|
||||
make_merge_transform(make_tuple(N_, Wo_))),
|
||||
make_tuple(sequence<1, 3, 4>{}, sequence<0, 2>{}),
|
||||
make_tuple(sequence<1>{}, sequence<0>{}));
|
||||
|
||||
return make_tuple(
|
||||
out_gemm_k_gemm_m_grid_desc, in_gemm_n_gemm_k_grid_desc, wei_grid_desc);
|
||||
}
|
||||
else
|
||||
{
|
||||
// [N, Wi, C] -> [N, (Wi + InLeftPadW + InRightPadW), C] = [N, Wip, C]
|
||||
const auto in_n_wip_c_grid_desc = transform_tensor_descriptor(
|
||||
in_grid_desc,
|
||||
make_tuple(make_pass_through_transform(N_),
|
||||
make_pad_transform(Wi_, InLeftPadW_, InRightPadW_),
|
||||
make_pass_through_transform(C_)),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}));
|
||||
|
||||
// [N, Wip, C] -> [N, X, Wo, C]
|
||||
const auto in_n_x_wo_c_grid_desc = transform_tensor_descriptor(
|
||||
in_n_wip_c_grid_desc,
|
||||
make_tuple(make_pass_through_transform(N_),
|
||||
make_embed_transform(make_tuple(X_, Wo_),
|
||||
make_tuple(ConvDilationW_, ConvStrideW_)),
|
||||
make_pass_through_transform(C_)),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1, 2>{}, sequence<3>{}));
|
||||
|
||||
const auto in_gemmn_gemmktotal_grid_desc =
|
||||
transform_tensor_descriptor(in_n_x_wo_c_grid_desc,
|
||||
make_tuple(make_merge_transform(make_tuple(X_, C_)),
|
||||
make_merge_transform(make_tuple(N_, Wo_))),
|
||||
make_tuple(sequence<1, 3>{}, sequence<0, 2>{}),
|
||||
make_tuple(sequence<1>{}, sequence<0>{}));
|
||||
|
||||
return make_tuple(out_grid_desc, in_gemmn_gemmktotal_grid_desc, wei_grid_desc);
|
||||
}
|
||||
}
|
||||
|
||||
template <index_t NDim = NDimSpatial, typename std::enable_if<NDim == 2, bool>::type = false>
|
||||
@@ -587,33 +844,95 @@ struct TransformConvBwdWeightToGemm
|
||||
const auto wei_grid_desc = make_wei_grid_desc<NDimSpatial>();
|
||||
|
||||
// B: input tensor comes in K_N
|
||||
const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor(
|
||||
in_grid_desc,
|
||||
make_tuple(make_pass_through_transform(N_),
|
||||
make_pad_transform(Hi_, InLeftPadH_, InRightPadH_),
|
||||
make_pad_transform(Wi_, InLeftPadW_, InRightPadW_),
|
||||
make_pass_through_transform(C_)),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}));
|
||||
if constexpr(NumGroupsToMerge > 1)
|
||||
{
|
||||
// Output tensor transformation
|
||||
// [0, 1, 2] -> [0, 1]
|
||||
// [(N*Ho*Wo), Gm, K] -> [(N*Ho*Wo), (K*Gm)]
|
||||
const auto out_gemm_k_gemm_m_grid_desc = transform_tensor_descriptor(
|
||||
out_grid_desc,
|
||||
make_tuple(make_pass_through_transform(N_ * Ho_ * Wo_),
|
||||
make_merge_transform(make_tuple(NumGroupsToMerge, K_))),
|
||||
make_tuple(sequence<0>{}, sequence<1, 2>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
|
||||
const auto in_n_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor(
|
||||
in_n_hip_wip_c_grid_desc,
|
||||
make_tuple(
|
||||
make_pass_through_transform(N_),
|
||||
make_embed_transform(make_tuple(Y_, Ho_), make_tuple(ConvDilationH_, ConvStrideH_)),
|
||||
make_embed_transform(make_tuple(X_, Wo_), make_tuple(ConvDilationW_, ConvStrideW_)),
|
||||
make_pass_through_transform(C_)),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1, 2>{}, sequence<3, 4>{}, sequence<5>{}));
|
||||
// Input tensor transformation, part 1.
|
||||
// [N, Hi, Wi, Gm, C] -> [N, Hip, Wip, Gm, C]
|
||||
const auto in_n_hip_wip_gm_c_grid_desc = transform_tensor_descriptor(
|
||||
in_grid_desc,
|
||||
make_tuple(make_pass_through_transform(N_),
|
||||
make_pad_transform(Hi_, InLeftPadH_, InRightPadH_),
|
||||
make_pad_transform(Wi_, InLeftPadW_, InRightPadW_),
|
||||
make_pass_through_transform(NumGroupsToMerge),
|
||||
make_pass_through_transform(C_)),
|
||||
make_tuple(
|
||||
sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}, sequence<4>{}),
|
||||
make_tuple(
|
||||
sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}, sequence<4>{}));
|
||||
|
||||
const auto in_gemmn_gemmktotal_grid_desc =
|
||||
transform_tensor_descriptor(in_n_y_ho_x_wo_c_grid_desc,
|
||||
make_tuple(make_merge_transform(make_tuple(Y_, X_, C_)),
|
||||
make_merge_transform(make_tuple(N_, Ho_, Wo_))),
|
||||
make_tuple(sequence<1, 3, 5>{}, sequence<0, 2, 4>{}),
|
||||
make_tuple(sequence<1>{}, sequence<0>{}));
|
||||
// Input tensor transformation, part 2.
|
||||
// [N, Hip, Wip, Gm, C] -> [N, (Y, Wo), (X, Wo), Gm, C]
|
||||
const auto in_n_y_ho_x_wo_gm_c_grid_desc = transform_tensor_descriptor(
|
||||
in_n_hip_wip_gm_c_grid_desc,
|
||||
make_tuple(make_pass_through_transform(N_),
|
||||
make_embed_transform(make_tuple(Y_, Ho_),
|
||||
make_tuple(ConvDilationH_, ConvStrideH_)),
|
||||
make_embed_transform(make_tuple(X_, Wo_),
|
||||
make_tuple(ConvDilationW_, ConvStrideW_)),
|
||||
make_pass_through_transform(NumGroupsToMerge),
|
||||
make_pass_through_transform(C_)),
|
||||
make_tuple(
|
||||
sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}, sequence<4>{}),
|
||||
make_tuple(sequence<0>{},
|
||||
sequence<1, 2>{},
|
||||
sequence<3, 4>{},
|
||||
sequence<5>{},
|
||||
sequence<6>{}));
|
||||
|
||||
return make_tuple(out_grid_desc, in_gemmn_gemmktotal_grid_desc, wei_grid_desc);
|
||||
// Input tensor transformation, part 3.
|
||||
// [0, 1, 2, 3, 4 5 6] -> [0, 1]
|
||||
// [N, Y, Ho, X, Wo, Gm, C] -> [(N*Ho*Wo), (Gm*Y*X*C)]
|
||||
const auto in_gemm_n_gemm_k_grid_desc = transform_tensor_descriptor(
|
||||
in_n_y_ho_x_wo_gm_c_grid_desc,
|
||||
make_tuple(make_merge_transform(make_tuple(Y_, X_, NumGroupsToMerge, C_)),
|
||||
make_merge_transform(make_tuple(N_, Ho_, Wo_))),
|
||||
make_tuple(sequence<1, 3, 5, 6>{}, sequence<0, 2, 4>{}),
|
||||
make_tuple(sequence<1>{}, sequence<0>{}));
|
||||
|
||||
return make_tuple(
|
||||
out_gemm_k_gemm_m_grid_desc, in_gemm_n_gemm_k_grid_desc, wei_grid_desc);
|
||||
}
|
||||
else
|
||||
{
|
||||
const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor(
|
||||
in_grid_desc,
|
||||
make_tuple(make_pass_through_transform(N_),
|
||||
make_pad_transform(Hi_, InLeftPadH_, InRightPadH_),
|
||||
make_pad_transform(Wi_, InLeftPadW_, InRightPadW_),
|
||||
make_pass_through_transform(C_)),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}));
|
||||
|
||||
const auto in_n_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor(
|
||||
in_n_hip_wip_c_grid_desc,
|
||||
make_tuple(make_pass_through_transform(N_),
|
||||
make_embed_transform(make_tuple(Y_, Ho_),
|
||||
make_tuple(ConvDilationH_, ConvStrideH_)),
|
||||
make_embed_transform(make_tuple(X_, Wo_),
|
||||
make_tuple(ConvDilationW_, ConvStrideW_)),
|
||||
make_pass_through_transform(C_)),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1, 2>{}, sequence<3, 4>{}, sequence<5>{}));
|
||||
|
||||
const auto in_gemmn_gemmktotal_grid_desc = transform_tensor_descriptor(
|
||||
in_n_y_ho_x_wo_c_grid_desc,
|
||||
make_tuple(make_merge_transform(make_tuple(Y_, X_, C_)),
|
||||
make_merge_transform(make_tuple(N_, Ho_, Wo_))),
|
||||
make_tuple(sequence<1, 3, 5>{}, sequence<0, 2, 4>{}),
|
||||
make_tuple(sequence<1>{}, sequence<0>{}));
|
||||
|
||||
return make_tuple(out_grid_desc, in_gemmn_gemmktotal_grid_desc, wei_grid_desc);
|
||||
}
|
||||
}
|
||||
|
||||
template <index_t NDim = NDimSpatial, typename std::enable_if<NDim == 3, bool>::type = false>
|
||||
@@ -624,39 +943,121 @@ struct TransformConvBwdWeightToGemm
|
||||
const auto wei_grid_desc = make_wei_grid_desc<NDimSpatial>();
|
||||
|
||||
// B: input tensor comes in K_N
|
||||
const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor(
|
||||
in_grid_desc,
|
||||
make_tuple(make_pass_through_transform(N_),
|
||||
make_pad_transform(Di_, InLeftPadD_, InRightPadD_),
|
||||
make_pad_transform(Hi_, InLeftPadH_, InRightPadH_),
|
||||
make_pad_transform(Wi_, InLeftPadW_, InRightPadW_),
|
||||
make_pass_through_transform(C_)),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}, sequence<4>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}, sequence<4>{}));
|
||||
if constexpr(NumGroupsToMerge > 1)
|
||||
{
|
||||
// Output tensor transformation
|
||||
// [0, 1, 2] -> [0, 1]
|
||||
// [(N*Do*Ho*Wo), Gm, K] -> [(N*Do*Ho*Wo), (K*Gm)]
|
||||
const auto out_gemm_k_gemm_m_grid_desc = transform_tensor_descriptor(
|
||||
out_grid_desc,
|
||||
make_tuple(make_pass_through_transform(N_ * Do_ * Ho_ * Wo_),
|
||||
make_merge_transform(make_tuple(NumGroupsToMerge, K_))),
|
||||
make_tuple(sequence<0>{}, sequence<1, 2>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
|
||||
const auto in_n_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor(
|
||||
in_n_hip_wip_c_grid_desc,
|
||||
make_tuple(
|
||||
make_pass_through_transform(N_),
|
||||
make_embed_transform(make_tuple(Z_, Do_), make_tuple(ConvDilationD_, ConvStrideD_)),
|
||||
make_embed_transform(make_tuple(Y_, Ho_), make_tuple(ConvDilationH_, ConvStrideH_)),
|
||||
make_embed_transform(make_tuple(X_, Wo_), make_tuple(ConvDilationW_, ConvStrideW_)),
|
||||
make_pass_through_transform(C_)),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}, sequence<4>{}),
|
||||
make_tuple(sequence<0>{},
|
||||
sequence<1, 2>{},
|
||||
sequence<3, 4>{},
|
||||
sequence<5, 6>{},
|
||||
sequence<7>{}));
|
||||
// Input tensor transformation, part 1.
|
||||
// [N, Di, Hi, Wi, Gm, C] -> [N, Dip, Hip, Wip, Gm, C]
|
||||
const auto in_n_dip_hip_wip_gm_c_grid_desc = transform_tensor_descriptor(
|
||||
in_grid_desc,
|
||||
make_tuple(make_pass_through_transform(N_),
|
||||
make_pad_transform(Di_, InLeftPadD_, InRightPadD_),
|
||||
make_pad_transform(Hi_, InLeftPadH_, InRightPadH_),
|
||||
make_pad_transform(Wi_, InLeftPadW_, InRightPadW_),
|
||||
make_pass_through_transform(NumGroupsToMerge),
|
||||
make_pass_through_transform(C_)),
|
||||
make_tuple(sequence<0>{},
|
||||
sequence<1>{},
|
||||
sequence<2>{},
|
||||
sequence<3>{},
|
||||
sequence<4>{},
|
||||
sequence<5>{}),
|
||||
make_tuple(sequence<0>{},
|
||||
sequence<1>{},
|
||||
sequence<2>{},
|
||||
sequence<3>{},
|
||||
sequence<4>{},
|
||||
sequence<5>{}));
|
||||
|
||||
const auto in_gemmn_gemmktotal_grid_desc = transform_tensor_descriptor(
|
||||
in_n_y_ho_x_wo_c_grid_desc,
|
||||
make_tuple(make_merge_transform(make_tuple(Z_, Y_, X_, C_)),
|
||||
make_merge_transform(make_tuple(N_, Do_, Ho_, Wo_))),
|
||||
make_tuple(sequence<1, 3, 5, 7>{}, sequence<0, 2, 4, 6>{}),
|
||||
make_tuple(sequence<1>{}, sequence<0>{}));
|
||||
// Input tensor transformation, part 2.
|
||||
// [N, Zip, Hip, Wip, Gm, C] -> [N, (Z, Zo), (Y, Wo), (X, Wo), Gm, C]
|
||||
const auto in_n_z_do_y_ho_x_wo_gm_c_grid_desc = transform_tensor_descriptor(
|
||||
in_n_dip_hip_wip_gm_c_grid_desc,
|
||||
make_tuple(make_pass_through_transform(N_),
|
||||
make_embed_transform(make_tuple(Z_, Do_),
|
||||
make_tuple(ConvDilationD_, ConvStrideD_)),
|
||||
make_embed_transform(make_tuple(Y_, Ho_),
|
||||
make_tuple(ConvDilationH_, ConvStrideH_)),
|
||||
make_embed_transform(make_tuple(X_, Wo_),
|
||||
make_tuple(ConvDilationW_, ConvStrideW_)),
|
||||
make_pass_through_transform(NumGroupsToMerge),
|
||||
make_pass_through_transform(C_)),
|
||||
make_tuple(sequence<0>{},
|
||||
sequence<1>{},
|
||||
sequence<2>{},
|
||||
sequence<3>{},
|
||||
sequence<4>{},
|
||||
sequence<5>{}),
|
||||
make_tuple(sequence<0>{},
|
||||
sequence<1, 2>{},
|
||||
sequence<3, 4>{},
|
||||
sequence<5, 6>{},
|
||||
sequence<7>{},
|
||||
sequence<8>{}));
|
||||
|
||||
return make_tuple(out_grid_desc, in_gemmn_gemmktotal_grid_desc, wei_grid_desc);
|
||||
// Input tensor transformation, part 3.
|
||||
// [0, 1, 2, 3, 4, 5, 6, 7, 8] -> [0, 1]
|
||||
// [N, Z, Do, Y, Ho, X, Wo, Gm, C] -> [(N*Do*Ho*Wo), (Z*Y*X*Gm*C)]
|
||||
const auto in_gemm_k_gemm_n_grid_desc = transform_tensor_descriptor(
|
||||
in_n_z_do_y_ho_x_wo_gm_c_grid_desc,
|
||||
make_tuple(make_merge_transform(make_tuple(Z_, Y_, X_, NumGroupsToMerge, C_)),
|
||||
make_merge_transform(make_tuple(N_, Do_, Ho_, Wo_))),
|
||||
make_tuple(sequence<1, 3, 5, 7, 8>{}, sequence<0, 2, 4, 6>{}),
|
||||
make_tuple(sequence<1>{}, sequence<0>{}));
|
||||
|
||||
return make_tuple(
|
||||
out_gemm_k_gemm_m_grid_desc, in_gemm_k_gemm_n_grid_desc, wei_grid_desc);
|
||||
}
|
||||
else
|
||||
{
|
||||
const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor(
|
||||
in_grid_desc,
|
||||
make_tuple(make_pass_through_transform(N_),
|
||||
make_pad_transform(Di_, InLeftPadD_, InRightPadD_),
|
||||
make_pad_transform(Hi_, InLeftPadH_, InRightPadH_),
|
||||
make_pad_transform(Wi_, InLeftPadW_, InRightPadW_),
|
||||
make_pass_through_transform(C_)),
|
||||
make_tuple(
|
||||
sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}, sequence<4>{}),
|
||||
make_tuple(
|
||||
sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}, sequence<4>{}));
|
||||
|
||||
const auto in_n_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor(
|
||||
in_n_hip_wip_c_grid_desc,
|
||||
make_tuple(make_pass_through_transform(N_),
|
||||
make_embed_transform(make_tuple(Z_, Do_),
|
||||
make_tuple(ConvDilationD_, ConvStrideD_)),
|
||||
make_embed_transform(make_tuple(Y_, Ho_),
|
||||
make_tuple(ConvDilationH_, ConvStrideH_)),
|
||||
make_embed_transform(make_tuple(X_, Wo_),
|
||||
make_tuple(ConvDilationW_, ConvStrideW_)),
|
||||
make_pass_through_transform(C_)),
|
||||
make_tuple(
|
||||
sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}, sequence<4>{}),
|
||||
make_tuple(sequence<0>{},
|
||||
sequence<1, 2>{},
|
||||
sequence<3, 4>{},
|
||||
sequence<5, 6>{},
|
||||
sequence<7>{}));
|
||||
|
||||
const auto in_gemmn_gemmktotal_grid_desc = transform_tensor_descriptor(
|
||||
in_n_y_ho_x_wo_c_grid_desc,
|
||||
make_tuple(make_merge_transform(make_tuple(Z_, Y_, X_, C_)),
|
||||
make_merge_transform(make_tuple(N_, Do_, Ho_, Wo_))),
|
||||
make_tuple(sequence<1, 3, 5, 7>{}, sequence<0, 2, 4, 6>{}),
|
||||
make_tuple(sequence<1>{}, sequence<0>{}));
|
||||
|
||||
return make_tuple(out_grid_desc, in_gemmn_gemmktotal_grid_desc, wei_grid_desc);
|
||||
}
|
||||
}
|
||||
|
||||
IndexType G_, N_;
|
||||
@@ -668,7 +1069,6 @@ struct TransformConvBwdWeightToGemm
|
||||
IndexType ConvDilationD_, ConvDilationH_, ConvDilationW_;
|
||||
IndexType InLeftPadD_, InLeftPadH_, InLeftPadW_;
|
||||
IndexType InRightPadD_, InRightPadH_, InRightPadW_;
|
||||
IndexType ZYX_;
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -13,10 +13,10 @@ template <index_t NDimSpatial,
|
||||
index_t VectorSizeA,
|
||||
index_t VectorSizeB,
|
||||
index_t VectorSizeC,
|
||||
index_t NumGroupsToMerge = 1,
|
||||
bool SplitN = false,
|
||||
typename ADataType = float,
|
||||
typename CDataType = float,
|
||||
index_t NumGroupsToMerge = 1,
|
||||
typename IndexType = index_t>
|
||||
struct TransformConvFwdToGemm
|
||||
{
|
||||
|
||||
Reference in New Issue
Block a user