[CK_Tile] Merge multiple convolution groups into a single GEMM batch (#2986)

* Fix compilation of the grouped conv examples.

* Fix grouped conv bwd weight example output in CK Tile.

* Add number of groups to merge to ck tile grouped gemm example.

* Initial set of tests for TransformConvBwdWeightToGemm.

* Added unit tests for TransformConvBwdWeightToGemm when conv groups are merged.

* WIP: Tensor transformations.

* Add unit tests for coordinate transforms.

* Fully working conv group merging for TransformConvBwdWeightToGemm.

* WIP: Merged conv groups offset calculation.

* Add unit tests for tensor view.

* WIP: Merged conv groups epilogue.

* Enable running multiple conv groups per batch.

* Add tests for tile_distribution_encoding.

* Change example to optimally match depthwise convolution with merged groups.

* Add more tests for tensor view.

* Integration test for reading diagonal blocks from grouped distributed tensor.

* Improved integration test.

* Improve test for accessing diagonal blocks.

* Added integration test for cshuffle epilogue LDS tile distribution.

* Add more logging.

* Increase the max number of reported errors.

* WIP: merged conv groups GEMM epilogue changes.

* LDS to global memory copy.

* Fix tile window size for c block.

* Integration test for CShuffle epilogue.

* Improved CShuffle test.

* WIP: Separate epilogue for merged conv groups.

* Tile example parameter changes to match depthwise conv.

* Offset fixes.

* Epilogue fixes.

* Working baseline for depthwise convolution with merged conv groups.

* Fix build.

* Initial unit tests for tensor descriptor.

* Add one more unit test for tensor view.

* WIP: LDS to global mem transfer using CK tile tensor descriptor and tile distribution encoding.

* Fully functional LDS to global mem transfer using tensor descriptor and tile distribution encoding.

* Add more comments, disable debug code.

* Remove debug and other dead code.

* Code clean-up for bwd tensor transformations.

* Enable running multiple GEMM batches of merged conv groups.

* Add compile check for assumed row-major layout.

* Fix strides in 1D conv to gemm transformation.

* WIP: Simplify conv to gemm transformations and handle K > 1 and C > 1 cases.

* Fix case k > 1 and c=1.

* Remove debug code.

* Make MPerGroup and NPerGroup template parameters.

* Add additional check for the unsupported c > 1 case.

* WIP: Put back the generic tensor descriptors for convolutions.

* Fix tensor descriptors.

* Remove the obsolete template parameters.

* Add more instances.

* Fix bugs in merged conv groups tensor descriptors.

* Fix tensor descriptors for merged conv groups when K > 1.

* Remove debug output.

* Remove dead code.

* Fix merge conflicts.

* Code clean-up.

* Remove unused code.

* Run clang-format.

* Remove debug prints and obsolete tests.

* Check that the number of convolution groups is a multiple of the number of merged groups.

* Fix build after removing obsolete functionality.

* Remove obsolete enumeration.

* Fix new unit test projects.

* Remove unnecessary includes.

* Fix passing the number of merged groups.

* Remove unrelated tests.

* Fix IsSupportedArgument for bwd weight conv kernel.

* Fix clang formatting.

* Fix the bwd weight conv to gemm mapping for num merged groups > 1.

* GEMM config for conv group merging.

* Fix clang formatting.

* Remove obsolete comment.

* Fix typos in comment strings.

* Increase the max number of reported errors when testing against reference implementation.

* Rename gemm_config to conv_config.

* Rename GemmConfig to ConvConfig and move NumGroupsToMerge into ConvConfig.

* Change num_groups_to_merge to a boolean flag in the ck tile grouped conv example.

* Run clang-format.

* Add number of merged groups into kernel name string.

* Remove group merging flag from CK Tile grouped conv example.

[ROCm/composable_kernel commit: 121bf0e1f3]
Ville Pietilä
2025-10-29 16:49:28 +02:00
committed by GitHub
parent df90bcbfd0
commit 88910537bf
17 changed files with 755 additions and 269 deletions
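To summarize the mechanism before the per-file diffs: instead of launching one GEMM per convolution group, NumGroupsToMerge adjacent groups are folded into a single GEMM batch. A minimal host-side sketch of the batch and stride arithmetic, following the names used in GroupedConvBwdWeightKernelArgs in this diff (1D / NWGC case; the sizes are illustrative, not from the diff):

#include <cstdio>
#include <functional>
#include <numeric>
#include <vector>

int main()
{
    // Illustrative depthwise-style 1D conv: G = 32 groups, K = C = 1, X = 3.
    const int G = 32, K = 1, C = 1;
    const std::vector<int> filter_spatial_lengths{3};
    const int NumGroupsPerBatch = 2; // == GroupedConvTraits::NumGroupsToMerge

    // Per-batch pointer strides step over NumGroupsPerBatch groups at once.
    const int group_stride_a = K * NumGroupsPerBatch; // A: Out NWGK
    const int group_stride_b = C * NumGroupsPerBatch; // B: In  NWGC
    const int group_stride_c =
        K * C * NumGroupsPerBatch *
        std::accumulate(filter_spatial_lengths.begin(),
                        filter_spatial_lengths.end(),
                        1,
                        std::multiplies<int>()); // C: Wei GKXC

    // Each GEMM batch covers NumGroupsPerBatch convolution groups,
    // so 32 groups launch as 16 batches instead of 32.
    const int GemmBatch = (G + NumGroupsPerBatch - 1) / NumGroupsPerBatch;

    std::printf("strides a/b/c: %d %d %d, GemmBatch: %d\n",
                group_stride_a, group_stride_b, group_stride_c, GemmBatch); // 2 2 6, 16
    return 0;
}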

View File

@@ -17,7 +17,7 @@
#define CK_TILE_PIPELINE_COMPUTE_V4 3
#define CK_TILE_PIPELINE_COMPUTE_V5 4
struct GemmConfigBase
struct ConvConfigBase
{
static constexpr bool kPadM = true;
static constexpr bool kPadN = true;
@@ -29,6 +29,10 @@ struct GemmConfigBase
static constexpr bool TransposeC = false;
static constexpr bool UseStructuredSparsity = false;
static constexpr ck_tile::index_t VectorSizeA = 4;
static constexpr ck_tile::index_t VectorSizeB = 8;
static constexpr ck_tile::index_t VectorSizeC = 8;
static constexpr int kBlockPerCu = 1;
static constexpr ck_tile::index_t TileParitionerGroupNum = 8;
static constexpr ck_tile::index_t TileParitionerM01 = 4;
@@ -37,10 +41,12 @@ struct GemmConfigBase
static constexpr ck_tile::index_t NumWaveGroups = 1;
static constexpr bool Preshuffle = false;
static constexpr bool TiledMMAPermuteN = false;
static constexpr ck_tile::index_t NumGroupsToMerge = 1;
};
template <typename PrecType>
struct GemmConfigMemoryInterwave : public GemmConfigBase
struct ConvConfigMemoryInterwave : public ConvConfigBase
{
// Memory friendly for Interwave scheduler
static constexpr ck_tile::index_t M_Tile = 128;
@@ -61,7 +67,7 @@ struct GemmConfigMemoryInterwave : public GemmConfigBase
};
template <typename PrecType>
struct GemmConfigMemoryIntrawave : public GemmConfigBase
struct ConvConfigMemoryIntrawave : public ConvConfigBase
{
static constexpr ck_tile::index_t M_Tile = 128;
static constexpr ck_tile::index_t N_Tile = 32;
@@ -80,7 +86,7 @@ struct GemmConfigMemoryIntrawave : public GemmConfigBase
};
template <typename PrecType>
struct GemmConfigComputeV3 : public GemmConfigBase
struct ConvConfigComputeV3 : public ConvConfigBase
{
// Compute V3 only support Intrawave scheduler
static constexpr ck_tile::index_t M_Tile = 16;
@@ -100,7 +106,7 @@ struct GemmConfigComputeV3 : public GemmConfigBase
};
template <typename PrecType>
struct GemmConfigComputeV3_1 : public GemmConfigBase
struct ConvConfigComputeV3_1 : public ConvConfigBase
{
static constexpr ck_tile::index_t M_Tile = 256;
static constexpr ck_tile::index_t N_Tile = 256;
@@ -119,7 +125,7 @@ struct GemmConfigComputeV3_1 : public GemmConfigBase
};
template <typename PrecType>
struct GemmConfigComputeV3_2 : public GemmConfigBase
struct ConvConfigComputeV3_2 : public ConvConfigBase
{
static constexpr ck_tile::index_t M_Tile = 128;
static constexpr ck_tile::index_t N_Tile = 128;
@@ -140,7 +146,7 @@ struct GemmConfigComputeV3_2 : public GemmConfigBase
};
template <typename PrecType>
struct GemmConfigComputeV3_WMMA : public GemmConfigBase
struct ConvConfigComputeV3_WMMA : public ConvConfigBase
{
static constexpr ck_tile::index_t M_Tile = 128;
static constexpr ck_tile::index_t N_Tile = 128;
@@ -161,7 +167,7 @@ struct GemmConfigComputeV3_WMMA : public GemmConfigBase
};
template <typename PrecType>
struct GemmConfigComputeV4 : public GemmConfigBase
struct ConvConfigComputeV4 : public ConvConfigBase
{
// Compute V4 only support Intrawave scheduler
// Using the ping pong reader in the lds level
@@ -182,7 +188,7 @@ struct GemmConfigComputeV4 : public GemmConfigBase
};
template <typename PrecType>
struct GemmConfigComputeV4_1 : public GemmConfigBase
struct ConvConfigComputeV4_1 : public ConvConfigBase
{
static constexpr ck_tile::index_t M_Tile = 256;
static constexpr ck_tile::index_t N_Tile = 256;
@@ -201,7 +207,7 @@ struct GemmConfigComputeV4_1 : public GemmConfigBase
};
template <typename PrecType>
struct GemmConfigComputeV5 : public GemmConfigBase
struct ConvConfigComputeV5 : public ConvConfigBase
{
static constexpr ck_tile::index_t M_Tile = 128;
static constexpr ck_tile::index_t N_Tile = 128;
@@ -220,6 +226,31 @@ struct GemmConfigComputeV5 : public GemmConfigBase
static constexpr ck_tile::index_t NumWaveGroups = 2;
};
template <typename PrecType>
struct ConvConfigComputeV3_merged_groups : public ConvConfigBase
{
static constexpr ck_tile::index_t VectorSizeA = 4;
static constexpr ck_tile::index_t VectorSizeB = 8;
static constexpr ck_tile::index_t VectorSizeC = 8;
static constexpr ck_tile::index_t M_Tile = 16;
static constexpr ck_tile::index_t N_Tile = 32;
static constexpr ck_tile::index_t K_Tile = 32;
static constexpr ck_tile::index_t M_Warp = 1;
static constexpr ck_tile::index_t N_Warp = 2;
static constexpr ck_tile::index_t K_Warp = 1;
static constexpr ck_tile::index_t M_Warp_Tile = 16;
static constexpr ck_tile::index_t N_Warp_Tile = 16;
static constexpr ck_tile::index_t K_Warp_Tile = 32;
static constexpr bool DoubleSmemBuffer = false;
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3;
static constexpr ck_tile::index_t NumGroupsToMerge = 2;
};
template <typename InDataType, typename WeiDataType = InDataType, typename OutDataType = InDataType>
struct ConvTypeConfig;
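A hedged usage sketch (not part of this diff): the new merged-groups config plugs into the same dispatch the examples use further below for ConvConfigComputeV3. With NumGroupsToMerge = 2 it requires the total conv group count to be even, which IsSupportedArgument in the kernel now checks.

// Hypothetical selection of the merged-groups config in the bwd-weight
// example's main(); mirrors the existing ConvConfigComputeV3 dispatch.
return !run_grouped_conv_bwd_weight_example<ConvConfigComputeV3_merged_groups>(arg_parser);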

View File

@@ -51,8 +51,8 @@ int run_grouped_conv_bwd_data_example(int argc, char* argv[])
int main(int argc, char* argv[])
{
#if CK_TILE_USE_WMMA
return !run_grouped_conv_bwd_data_example<GemmConfigComputeV3_WMMA>(argc, argv);
return !run_grouped_conv_bwd_data_example<ConvConfigComputeV3_WMMA>(argc, argv);
#else
return !run_grouped_conv_bwd_data_example<GemmConfigComputeV3>(argc, argv);
return !run_grouped_conv_bwd_data_example<ConvConfigComputeV3>(argc, argv);
#endif
}

View File

@@ -14,7 +14,7 @@
#include "grouped_convolution_backward_weight_invoker.hpp"
#include "run_grouped_convolution_bwd_weight_example.inc"
template <template <typename PrecType> typename GemmConfig>
template <template <typename PrecType> typename ConvConfig>
int run_grouped_conv_bwd_weight_example(ck_tile::ArgParser& arg_parser)
{
using Invoker = GroupedConvolutionBackwardWeightInvoker;
@@ -27,14 +27,14 @@ int run_grouped_conv_bwd_weight_example(ck_tile::ArgParser& arg_parser)
if(data_type == "fp16")
{
return run_grouped_conv_bwd_weight_example_prec_type<Invoker,
GemmConfig<ck_tile::half_t>,
ConvConfig<ck_tile::half_t>,
ck_tile::half_t>(
in_layout, wei_layout, out_layout, arg_parser);
}
else if(data_type == "bf16")
{
return run_grouped_conv_bwd_weight_example_prec_type<Invoker,
GemmConfig<ck_tile::bf16_t>,
ConvConfig<ck_tile::bf16_t>,
ck_tile::bf16_t>(
in_layout, wei_layout, out_layout, arg_parser);
}
@@ -54,9 +54,9 @@ int main(int argc, char* argv[])
try
{
#if CK_TILE_USE_WMMA
return !run_grouped_conv_bwd_weight_example<GemmConfigComputeV3_WMMA>(arg_parser);
return !run_grouped_conv_bwd_weight_example<ConvConfigComputeV3_WMMA>(arg_parser);
#else
return !run_grouped_conv_bwd_weight_example<GemmConfigComputeV3>(arg_parser);
return !run_grouped_conv_bwd_weight_example<ConvConfigComputeV3>(arg_parser);
#endif
}
catch(const std::runtime_error& e)

View File

@@ -7,7 +7,7 @@
struct GroupedConvolutionBackwardWeightInvoker
{
template <ck_tile::index_t NDimSpatial,
typename GemmConfig,
typename ConvConfig,
typename InDataType,
typename WeiDataType,
typename AccDataType,
@@ -25,22 +25,22 @@ struct GroupedConvolutionBackwardWeightInvoker
// Implicit GEMM Traits
using GemmShape = ck_tile::TileGemmShape<
ck_tile::sequence<GemmConfig::M_Tile, GemmConfig::N_Tile, GemmConfig::K_Tile>,
ck_tile::sequence<GemmConfig::M_Warp, GemmConfig::N_Warp, GemmConfig::K_Warp>,
ck_tile::sequence<ConvConfig::M_Tile, ConvConfig::N_Tile, ConvConfig::K_Tile>,
ck_tile::sequence<ConvConfig::M_Warp, ConvConfig::N_Warp, ConvConfig::K_Warp>,
ck_tile::
sequence<GemmConfig::M_Warp_Tile, GemmConfig::N_Warp_Tile, GemmConfig::K_Warp_Tile>,
GemmConfig::PermuteA,
GemmConfig::PermuteB>;
sequence<ConvConfig::M_Warp_Tile, ConvConfig::N_Warp_Tile, ConvConfig::K_Warp_Tile>,
ConvConfig::PermuteA,
ConvConfig::PermuteB>;
constexpr ck_tile::index_t VectorSizeA = 4;
constexpr ck_tile::index_t VectorSizeB = 8;
constexpr ck_tile::index_t VectorSizeC = 8;
constexpr ck_tile::index_t VectorSizeA = ConvConfig::VectorSizeA;
constexpr ck_tile::index_t VectorSizeB = ConvConfig::VectorSizeB;
constexpr ck_tile::index_t VectorSizeC = ConvConfig::VectorSizeC;
constexpr auto ConvSpec = ck_tile::ConvolutionSpecialization::Default;
using TilePartitioner =
ck_tile::GemmSpatiallyLocalTilePartitioner<GemmShape,
GemmConfig::TileParitionerGroupNum,
GemmConfig::TileParitionerM01>;
ConvConfig::TileParitionerGroupNum,
ConvConfig::TileParitionerM01>;
using GroupedConvTraitsType = ck_tile::GroupedConvTraits<NDimSpatial,
ConvSpec,
InLayout,
@@ -49,20 +49,21 @@ struct GroupedConvolutionBackwardWeightInvoker
OutLayout,
VectorSizeA,
VectorSizeB,
VectorSizeC>;
VectorSizeC,
ConvConfig::NumGroupsToMerge>;
using GemmUniversalTraits = ck_tile::TileGemmUniversalTraits<
GemmConfig::kPadM,
GemmConfig::kPadN,
GemmConfig::kPadK,
GemmConfig::DoubleSmemBuffer,
ConvConfig::kPadM,
ConvConfig::kPadN,
ConvConfig::kPadK,
ConvConfig::DoubleSmemBuffer,
typename GroupedConvTraitsType::GroupedConvImplicitGemmTraitsBwdWeight::AsLayout,
typename GroupedConvTraitsType::GroupedConvImplicitGemmTraitsBwdWeight::BsLayout,
typename GroupedConvTraitsType::GroupedConvImplicitGemmTraitsBwdWeight::CLayout,
GemmConfig::TransposeC,
GemmConfig::UseStructuredSparsity,
ConvConfig::TransposeC,
ConvConfig::UseStructuredSparsity,
false, // Persistent,
GemmConfig::NumWaveGroups>;
ConvConfig::NumWaveGroups>;
using GemmPipelineProblem = ck_tile::GemmPipelineProblem<
OutDataType,
@@ -78,7 +79,7 @@ struct GroupedConvolutionBackwardWeightInvoker
VectorSizeB>;
using BaseGemmPipeline = typename PipelineTypeTraits<
GemmConfig::Pipeline>::template UniversalGemmPipeline<GemmPipelineProblem>;
ConvConfig::Pipeline>::template UniversalGemmPipeline<GemmPipelineProblem>;
const ck_tile::index_t gemm_k =
args.N_ * std::accumulate(args.output_spatial_lengths_.begin(),
@@ -86,8 +87,8 @@ struct GroupedConvolutionBackwardWeightInvoker
1,
std::multiplies<ck_tile::index_t>());
const ck_tile::index_t k_grain = args.k_batch * GemmConfig::K_Tile;
const ck_tile::index_t K_split = (gemm_k + k_grain - 1) / k_grain * GemmConfig::K_Tile;
const ck_tile::index_t k_grain = args.k_batch * ConvConfig::K_Tile;
const ck_tile::index_t K_split = (gemm_k + k_grain - 1) / k_grain * ConvConfig::K_Tile;
const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split);
const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop);
const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop);
@@ -98,7 +99,7 @@ struct GroupedConvolutionBackwardWeightInvoker
const auto memory_operation_) {
constexpr bool has_hot_loop_v = has_hot_loop_.value;
constexpr auto tail_number_v = tail_number_.value;
constexpr auto scheduler = GemmConfig::Scheduler;
constexpr auto scheduler = ConvConfig::Scheduler;
constexpr auto memory_operation = memory_operation_.value;
using UniversalGemmProblem =
@@ -118,7 +119,7 @@ struct GroupedConvolutionBackwardWeightInvoker
VectorSizeB>;
using GemmPipeline = typename PipelineTypeTraits<
GemmConfig::Pipeline>::template GemmPipeline<UniversalGemmProblem>;
ConvConfig::Pipeline>::template GemmPipeline<UniversalGemmProblem>;
using ConvEpilogue = ck_tile::CShuffleEpilogue<ck_tile::CShuffleEpilogueProblem<
OutDataType,
@@ -131,12 +132,12 @@ struct GroupedConvolutionBackwardWeightInvoker
CDEElementWise,
TilePartitioner::MPerBlock,
TilePartitioner::NPerBlock,
GemmConfig::M_Warp,
GemmConfig::N_Warp,
GemmConfig::M_Warp_Tile,
GemmConfig::N_Warp_Tile,
GemmConfig::K_Warp_Tile,
GemmConfig::TransposeC,
ConvConfig::M_Warp,
ConvConfig::N_Warp,
ConvConfig::M_Warp_Tile,
ConvConfig::N_Warp_Tile,
ConvConfig::K_Warp_Tile,
ConvConfig::TransposeC,
memory_operation,
1,
true,

View File

@@ -13,9 +13,9 @@
#include "grouped_convolution_utils.hpp"
#include "grouped_convolution_backward_weight_two_stage_invoker.hpp"
#include "run_grouped_convolution_bwd_weight_example.inc"
#include "gemm_configs.hpp"
#include "conv_configs.hpp"
template <template <typename PrecType> typename GemmConfig>
template <template <typename PrecType> typename ConvConfig>
int run_grouped_conv_bwd_weight_example(ck_tile::ArgParser& arg_parser)
{
using Invoker = GroupedConvolutionBackwardWeightTwoStageInvoker;
@@ -28,14 +28,14 @@ int run_grouped_conv_bwd_weight_example(ck_tile::ArgParser& arg_parser)
if(data_type == "fp16")
{
return run_grouped_conv_bwd_weight_example_prec_type<Invoker,
GemmConfig<ck_tile::half_t>,
ConvConfig<ck_tile::half_t>,
ck_tile::half_t>(
in_layout, wei_layout, out_layout, arg_parser);
}
else if(data_type == "bf16")
{
return run_grouped_conv_bwd_weight_example_prec_type<Invoker,
GemmConfig<ck_tile::bf16_t>,
ConvConfig<ck_tile::bf16_t>,
ck_tile::bf16_t>(
in_layout, wei_layout, out_layout, arg_parser);
}
@@ -55,9 +55,9 @@ int main(int argc, char* argv[])
try
{
#if CK_TILE_USE_WMMA
return !run_grouped_conv_bwd_weight_example<GemmConfigComputeV3_WMMA>(arg_parser);
return !run_grouped_conv_bwd_weight_example<ConvConfigComputeV3_WMMA>(arg_parser);
#else
return !run_grouped_conv_bwd_weight_example<GemmConfigComputeV3>(arg_parser);
return !run_grouped_conv_bwd_weight_example<ConvConfigComputeV3>(arg_parser);
#endif
}
catch(const std::runtime_error& e)

View File

@@ -7,7 +7,7 @@
struct GroupedConvolutionBackwardWeightTwoStageInvoker
{
template <ck_tile::index_t NDimSpatial,
typename GemmConfig,
typename ConvConfig,
typename InDataType,
typename WeiDataType,
typename AccDataType,
@@ -27,12 +27,12 @@ struct GroupedConvolutionBackwardWeightTwoStageInvoker
// Implicit GEMM Traits
using GemmShape = ck_tile::TileGemmShape<
ck_tile::sequence<GemmConfig::M_Tile, GemmConfig::N_Tile, GemmConfig::K_Tile>,
ck_tile::sequence<GemmConfig::M_Warp, GemmConfig::N_Warp, GemmConfig::K_Warp>,
ck_tile::sequence<ConvConfig::M_Tile, ConvConfig::N_Tile, ConvConfig::K_Tile>,
ck_tile::sequence<ConvConfig::M_Warp, ConvConfig::N_Warp, ConvConfig::K_Warp>,
ck_tile::
sequence<GemmConfig::M_Warp_Tile, GemmConfig::N_Warp_Tile, GemmConfig::K_Warp_Tile>,
GemmConfig::PermuteA,
GemmConfig::PermuteB>;
sequence<ConvConfig::M_Warp_Tile, ConvConfig::N_Warp_Tile, ConvConfig::K_Warp_Tile>,
ConvConfig::PermuteA,
ConvConfig::PermuteB>;
constexpr ck_tile::index_t VectorSizeA = 4;
constexpr ck_tile::index_t VectorSizeB = 8;
@@ -41,8 +41,8 @@ struct GroupedConvolutionBackwardWeightTwoStageInvoker
constexpr auto ConvSpec = ck_tile::ConvolutionSpecialization::Default;
using TilePartitioner =
ck_tile::GemmSpatiallyLocalTilePartitioner<GemmShape,
GemmConfig::TileParitionerGroupNum,
GemmConfig::TileParitionerM01>;
ConvConfig::TileParitionerGroupNum,
ConvConfig::TileParitionerM01>;
using GroupedConvTraitsType = ck_tile::GroupedConvTraits<NDimSpatial,
ConvSpec,
InLayout,
@@ -54,17 +54,17 @@ struct GroupedConvolutionBackwardWeightTwoStageInvoker
VectorSizeC>;
using GemmUniversalTraits = ck_tile::TileGemmUniversalTraits<
GemmConfig::kPadM,
GemmConfig::kPadN,
GemmConfig::kPadK,
GemmConfig::DoubleSmemBuffer,
ConvConfig::kPadM,
ConvConfig::kPadN,
ConvConfig::kPadK,
ConvConfig::DoubleSmemBuffer,
typename GroupedConvTraitsType::GroupedConvImplicitGemmTraitsBwdWeight::AsLayout,
typename GroupedConvTraitsType::GroupedConvImplicitGemmTraitsBwdWeight::BsLayout,
typename GroupedConvTraitsType::GroupedConvImplicitGemmTraitsBwdWeight::CLayout,
GemmConfig::TransposeC,
GemmConfig::UseStructuredSparsity,
ConvConfig::TransposeC,
ConvConfig::UseStructuredSparsity,
false, // Persistent,
GemmConfig::NumWaveGroups>;
ConvConfig::NumWaveGroups>;
using GemmPipelineProblem = ck_tile::GemmPipelineProblem<
OutDataType,
@@ -80,7 +80,7 @@ struct GroupedConvolutionBackwardWeightTwoStageInvoker
VectorSizeB>;
using BaseGemmPipeline = typename PipelineTypeTraits<
GemmConfig::Pipeline>::template UniversalGemmPipeline<GemmPipelineProblem>;
ConvConfig::Pipeline>::template UniversalGemmPipeline<GemmPipelineProblem>;
const ck_tile::index_t gemm_k =
args.N_ * std::accumulate(args.output_spatial_lengths_.begin(),
@@ -88,8 +88,8 @@ struct GroupedConvolutionBackwardWeightTwoStageInvoker
1,
std::multiplies<ck_tile::index_t>());
const ck_tile::index_t k_grain = args.k_batch * GemmConfig::K_Tile;
const ck_tile::index_t K_split = (gemm_k + k_grain - 1) / k_grain * GemmConfig::K_Tile;
const ck_tile::index_t k_grain = args.k_batch * ConvConfig::K_Tile;
const ck_tile::index_t K_split = (gemm_k + k_grain - 1) / k_grain * ConvConfig::K_Tile;
const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split);
const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop);
const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop);
@@ -100,7 +100,7 @@ struct GroupedConvolutionBackwardWeightTwoStageInvoker
const auto memory_operation_) {
constexpr bool has_hot_loop_v = has_hot_loop_.value;
constexpr auto tail_number_v = tail_number_.value;
constexpr auto scheduler = GemmConfig::Scheduler;
constexpr auto scheduler = ConvConfig::Scheduler;
constexpr auto memory_operation = memory_operation_.value;
using UniversalGemmProblem =
@@ -120,7 +120,7 @@ struct GroupedConvolutionBackwardWeightTwoStageInvoker
VectorSizeB>;
using GemmPipeline = typename PipelineTypeTraits<
GemmConfig::Pipeline>::template GemmPipeline<UniversalGemmProblem>;
ConvConfig::Pipeline>::template GemmPipeline<UniversalGemmProblem>;
using ConvEpilogue = ck_tile::CShuffleEpilogue<ck_tile::CShuffleEpilogueProblem<
OutDataType, // A: Out
@@ -133,11 +133,11 @@ struct GroupedConvolutionBackwardWeightTwoStageInvoker
CDEElementWise,
TilePartitioner::MPerBlock,
TilePartitioner::NPerBlock,
GemmConfig::M_Warp,
GemmConfig::N_Warp,
GemmConfig::M_Warp_Tile,
GemmConfig::N_Warp_Tile,
GemmConfig::K_Warp_Tile,
ConvConfig::M_Warp,
ConvConfig::N_Warp,
ConvConfig::M_Warp_Tile,
ConvConfig::N_Warp_Tile,
ConvConfig::K_Warp_Tile,
GemmPipelineProblem::TransposeC,
memory_operation,
1,

View File

@@ -51,8 +51,8 @@ int run_grouped_conv_fwd_example(int argc, char* argv[])
int main(int argc, char* argv[])
{
#if CK_TILE_USE_WMMA
return !run_grouped_conv_fwd_example<GemmConfigComputeV3_WMMA>(argc, argv);
return !run_grouped_conv_fwd_example<ConvConfigComputeV3_WMMA>(argc, argv);
#else
return !run_grouped_conv_fwd_example<GemmConfigComputeV3>(argc, argv);
return !run_grouped_conv_fwd_example<ConvConfigComputeV3>(argc, argv);
#endif
}

View File

@@ -32,9 +32,10 @@ struct GroupedConvolutionForwardInvoker
GemmConfig::PermuteA,
GemmConfig::PermuteB>;
constexpr ck_tile::index_t VectorSizeA = 8;
constexpr ck_tile::index_t VectorSizeB = 8;
constexpr ck_tile::index_t VectorSizeC = 8;
constexpr ck_tile::index_t NumGroupsToMerge = 1;
constexpr auto ConvSpec = ck_tile::ConvolutionSpecialization::Default;
using TilePartitioner =
@@ -50,6 +51,7 @@ struct GroupedConvolutionForwardInvoker
VectorSizeA,
VectorSizeB,
VectorSizeC,
NumGroupsToMerge,
CDElementWise>;
using GemmUniversalTraits = ck_tile::TileGemmUniversalTraits<

View File

@@ -11,24 +11,11 @@
#include "ck_tile/ops/epilogue.hpp"
#include "ck_tile/ops/gemm.hpp"
#include "ck_tile/ops/grouped_convolution.hpp"
#include "gemm_configs.hpp"
#include "conv_configs.hpp"
using MemoryOpSet =
std::integral_constant<ck_tile::memory_operation_enum, ck_tile::memory_operation_enum::set>;
using MemoryOpAtomicAdd = std::integral_constant<ck_tile::memory_operation_enum,
ck_tile::memory_operation_enum::atomic_add>;
struct GemmWarpConfig_Mfma
{
static constexpr ck_tile::index_t M_Warp_Tile = 32;
static constexpr ck_tile::index_t N_Warp_Tile = 32;
static constexpr ck_tile::index_t K_Warp_Tile = 16;
};
struct GemmWarpConfig_Wmma
{
static constexpr ck_tile::index_t M_Warp_Tile = 16;
static constexpr ck_tile::index_t N_Warp_Tile = 16;
static constexpr ck_tile::index_t K_Warp_Tile = 16;
};
template <typename InDataType, typename WeiDataType, typename AccDataType, typename OutDataType>
auto calculate_rtol_atol(const ck_tile::index_t GemmK,

View File

@@ -3,7 +3,7 @@
#pragma once
template <ck_tile::index_t NDimSpatial,
typename GemmConfig,
typename ConvConfig,
typename Invoker,
typename InDataType,
typename WeiDataType,
@@ -17,7 +17,7 @@ float invoke_grouped_conv_bwd_weight(ck_tile::GroupedConvBwdWeightHostArgs& args
int n_repeat)
{
float ave_time = Invoker::template grouped_conv_bwd_weight<NDimSpatial,
GemmConfig,
ConvConfig,
InDataType,
WeiDataType,
AccDataType,
@@ -31,7 +31,7 @@ float invoke_grouped_conv_bwd_weight(ck_tile::GroupedConvBwdWeightHostArgs& args
}
template <ck_tile::index_t NDimSpatial,
typename GemmConfig,
typename ConvConfig,
typename Invoker,
typename InDataType,
typename WeiDataType = InDataType,
@@ -131,7 +131,7 @@ int run_grouped_conv_bwd_weight_example_with_layouts(ck_tile::ArgParser& arg_par
std::cout << "output: " << output.mDesc << std::endl;
float ave_time = invoke_grouped_conv_bwd_weight<NDimSpatial,
GemmConfig,
ConvConfig,
Invoker,
InDataType,
WeiDataType,
@@ -193,7 +193,7 @@ int run_grouped_conv_bwd_weight_example_with_layouts(ck_tile::ArgParser& arg_par
}
template <typename Invoker,
typename GemmConfig,
typename ConvConfig,
typename InPrecType,
typename WeiPrecType = InPrecType,
typename OutPrecType = InPrecType>
@@ -217,7 +217,7 @@ int run_grouped_conv_bwd_weight_example_prec_type(std::string in_layout,
if(in_layout == "NWGC" && wei_layout == "GKXC" && out_layout == "NWGK")
{
return run_grouped_conv_bwd_weight_example_with_layouts<ck_tile::number<1>{},
GemmConfig,
ConvConfig,
Invoker,
InPrecType,
WeiPrecType,
@@ -227,7 +227,7 @@ int run_grouped_conv_bwd_weight_example_prec_type(std::string in_layout,
else if(in_layout == "NHWGC" && wei_layout == "GKYXC" && out_layout == "NHWGK")
{
return run_grouped_conv_bwd_weight_example_with_layouts<ck_tile::number<2>{},
GemmConfig,
ConvConfig,
Invoker,
InPrecType,
WeiPrecType,
@@ -237,7 +237,7 @@ int run_grouped_conv_bwd_weight_example_prec_type(std::string in_layout,
else if(in_layout == "NDHWGC" && wei_layout == "GKZYXC" && out_layout == "NDHWGK")
{
return run_grouped_conv_bwd_weight_example_with_layouts<ck_tile::number<3>{},
GemmConfig,
ConvConfig,
Invoker,
InPrecType,
WeiPrecType,

View File

@@ -25,7 +25,7 @@
* (3) number of iterations to cover the entire Y axis.
* The raked here represents how data is partitioned across different processing granularity.
* It represents howe we are going to access the data in thread, warp, or blocked in contiguous
* It represents how we are going to access the data in thread, warp, or blocked in contiguous
region.
* From below, the qualifier for 'raked' is the part of warp/thread hierarchy
* in the split of Y tile dimension where the iteration happens,
@@ -101,7 +101,7 @@ enum struct tile_distribution_pattern
* @brief Block raked pattern - aka linear.
*
*/
block_raked,
block_raked
};
struct tile_distribution_encoding_pattern
@@ -144,7 +144,6 @@ struct tile_distribution_encoding_pattern_2d<BlockSize,
NumWaveGroups>
: public tile_distribution_encoding_pattern
{
// TODO: make pattern where below condition does not need to hold - GGemmMultiDSplitk!
static_assert(XPerTile % VecSize == 0, "XPerTile must be a multiple of VecSize!");
static constexpr index_t warp_size = get_warp_size();

View File

@@ -19,7 +19,7 @@
namespace ck_tile {
/** @brief Maximum number of error values to display when checking errors */
constexpr int ERROR_DETAIL_LIMIT = 5;
constexpr int ERROR_DETAIL_LIMIT = 128;
/** @brief 8-bit floating point type */
using F8 = ck_tile::fp8_t;

View File

@@ -26,7 +26,8 @@ struct GroupedConvBwdWeightKernelArgs
GroupedConvTraitsType_::ConvSpecialization,
GroupedConvTraitsType_::VectorSizeA,
GroupedConvTraitsType_::VectorSizeB,
GroupedConvTraitsType_::VectorSizeC>;
GroupedConvTraitsType_::VectorSizeC,
GroupedConvTraitsType_::NumGroupsToMerge>;
static constexpr index_t NumDTensor = GroupedConvTraitsType_::NumDTensor;
template <
@@ -84,9 +85,11 @@ struct GroupedConvBwdWeightKernelArgs
b_grid_desc_k_n = grid_descs.at(number<1>{});
c_grid_desc_m_n = grid_descs.at(number<2>{});
group_stride_a = args.K_; // A: Out NWGK
group_stride_b = args.C_; // B: In NWGC
group_stride_c = args.K_ * args.C_ * // C: Wei GKXC
NumGroupsPerBatch = GroupedConvTraitsType_::NumGroupsToMerge;
group_stride_a = args.K_ * NumGroupsPerBatch; // A: Out NWGK
group_stride_b = args.C_ * NumGroupsPerBatch; // B: In NWGC
group_stride_c = args.K_ * args.C_ // C: Wei GKXC
* NumGroupsPerBatch *
std::accumulate(args.filter_spatial_lengths_.begin(),
args.filter_spatial_lengths_.end(),
1,
@@ -95,7 +98,14 @@ struct GroupedConvBwdWeightKernelArgs
GemmM = a_grid_desc_k_m.get_length(number<1>{});
GemmN = b_grid_desc_k_n.get_length(number<1>{});
GemmK = a_grid_desc_k_m.get_length(number<0>{});
GemmBatch = args.G_;
GemmBatch = integer_divide_ceil(args.G_, NumGroupsPerBatch);
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
{
std::cout << "GemmM: " << GemmM << ", GemmN: " << GemmN << ", GemmK: " << GemmK
<< ", GemmBatch: " << GemmBatch
<< ", NumGroupsPerBatch: " << NumGroupsPerBatch << std::endl;
}
}
template <
@@ -160,9 +170,11 @@ struct GroupedConvBwdWeightKernelArgs
b_grid_desc_k_n = grid_descs.at(number<1>{});
c_grid_desc_m_n = grid_descs.at(number<2>{});
group_stride_a = args.K_; // A: Out NHWGK
group_stride_b = args.C_; // B: In NHWGC
group_stride_c = args.K_ * args.C_ * // C: Wei GKYXC
NumGroupsPerBatch = GroupedConvTraitsType_::NumGroupsToMerge;
group_stride_a = args.K_ * NumGroupsPerBatch; // A: Out NHWGK
group_stride_b = args.C_ * NumGroupsPerBatch; // B: In NHWGC
group_stride_c = args.K_ * args.C_ // C: Wei GKYXC
* NumGroupsPerBatch *
std::accumulate(args.filter_spatial_lengths_.begin(),
args.filter_spatial_lengths_.end(),
1,
@@ -171,7 +183,14 @@ struct GroupedConvBwdWeightKernelArgs
GemmM = a_grid_desc_k_m.get_length(number<1>{});
GemmN = b_grid_desc_k_n.get_length(number<1>{});
GemmK = a_grid_desc_k_m.get_length(number<0>{});
GemmBatch = args.G_;
GemmBatch = integer_divide_ceil(args.G_, NumGroupsPerBatch);
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
{
std::cout << "GemmM: " << GemmM << ", GemmN: " << GemmN << ", GemmK: " << GemmK
<< ", GemmBatch: " << GemmBatch
<< ", NumGroupsPerBatch: " << NumGroupsPerBatch << std::endl;
}
}
template <
@@ -243,9 +262,11 @@ struct GroupedConvBwdWeightKernelArgs
b_grid_desc_k_n = grid_descs.at(number<1>{});
c_grid_desc_m_n = grid_descs.at(number<2>{});
group_stride_a = args.K_; // A: Out NDHWGK
group_stride_b = args.C_; // B: In NDHWGC
group_stride_c = args.K_ * args.C_ * // C: wEI GKZYXC
NumGroupsPerBatch = GroupedConvTraitsType_::NumGroupsToMerge;
group_stride_a = args.K_ * NumGroupsPerBatch; // A: Out NDHWGK
group_stride_b = args.C_ * NumGroupsPerBatch; // B: In NDHWGC
group_stride_c = args.K_ * args.C_ // C: Wei GKZYXC
* NumGroupsPerBatch *
std::accumulate(args.filter_spatial_lengths_.begin(),
args.filter_spatial_lengths_.end(),
1,
@@ -254,7 +275,14 @@ struct GroupedConvBwdWeightKernelArgs
GemmM = a_grid_desc_k_m.get_length(number<1>{});
GemmN = b_grid_desc_k_n.get_length(number<1>{});
GemmK = a_grid_desc_k_m.get_length(number<0>{});
GemmBatch = args.G_;
GemmBatch = integer_divide_ceil(args.G_, NumGroupsPerBatch);
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
{
std::cout << "GemmM: " << GemmM << ", GemmN: " << GemmN << ", GemmK: " << GemmK
<< ", GemmBatch: " << GemmBatch
<< ", NumGroupsPerBatch: " << NumGroupsPerBatch << std::endl;
}
}
using ABCGridDescs = remove_cvref_t<
@@ -279,6 +307,7 @@ struct GroupedConvBwdWeightKernelArgs
index_t GemmN;
index_t GemmK;
index_t GemmBatch;
index_t NumGroupsPerBatch;
const void* out_ptr;
const void* in_ptr;
@@ -317,10 +346,9 @@ struct GroupedConvBwdWeightKernelArgs
/// the policy is responsible for definition of all necessary data layouts and thread's
/// work distribution.
///
/// @tparam GroupedConvTraitsType_ The type of class providing traits for grouped convolution.
/// @tparam TilePartitioner_ The type of class providing mapping of workgroup index into
/// the
/// output data tile to be calculated. It determines the
/// the output data tile to be calculated. It determines the
/// workgroup to data relationship (or in other words - which
/// data would be processed and calculated by which workgroup).
/// @tparam GemmPipeline_ The type of class which provides the core part of matrix
@@ -382,8 +410,12 @@ struct GroupedConvolutionBackwardWeightKernel
[[nodiscard]] CK_TILE_HOST static const std::string GetName()
{
constexpr auto NumGroupsToMerge = GroupedConvTraitsType_::NumGroupsToMerge;
// clang-format off
return concat('_', "grouped_convolution_backward_weight", gemm_prec_str<InDataType, WeiDataType>, GemmPipeline::GetName());
if (NumGroupsToMerge > 1)
return concat('_', "grouped_convolution_backward_weight", gemm_prec_str<InDataType, WeiDataType>, GemmPipeline::GetName(), "merge", NumGroupsToMerge);
else
return concat('_', "grouped_convolution_backward_weight", gemm_prec_str<InDataType, WeiDataType>, GemmPipeline::GetName());
// clang-format on
}
@@ -402,6 +434,12 @@ struct GroupedConvolutionBackwardWeightKernel
CK_TILE_HOST static constexpr GroupedConvBwdWeightKernelArgsSpecialized
MakeKernelArgs(const GroupedConvBwdWeightHostArgs& hostArgs)
{
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
{
std::cout << "MPerBlock: " << number<TilePartitioner::MPerBlock>{} << std::endl;
std::cout << "NPerBlock: " << number<TilePartitioner::NPerBlock>{} << std::endl;
std::cout << "KPerBlock: " << number<TilePartitioner::KPerBlock>{} << std::endl;
}
return GroupedConvBwdWeightKernelArgsSpecialized(hostArgs);
}
@@ -442,11 +480,14 @@ struct GroupedConvolutionBackwardWeightKernel
{
return [&]() {
if(kargs.k_batch > 1)
hipGetErrorString(hipMemsetAsync(kargs.wei_ptr,
0,
kargs.GemmBatch * kargs.GemmM * kargs.GemmN *
sizeof(WeiDataType),
s.stream_id_));
{
// Total number of convolution groups (ConvG) = GemmBatch * NumGroupsPerBatch
// since we require that ConvG % NumGroupsPerBatch == 0.
const auto wei_size =
kargs.GemmBatch * kargs.GemmM * kargs.GemmN * kargs.NumGroupsPerBatch;
hipGetErrorString(
hipMemsetAsync(kargs.wei_ptr, 0, wei_size * sizeof(WeiDataType), s.stream_id_));
}
};
}
@@ -527,7 +568,8 @@ struct GroupedConvolutionBackwardWeightKernel
// Check access per C
if(ConvC % GroupedConvTraitsType_::VectorSizeB != 0)
{
CK_TILE_ERROR("Conv C is not a multiple of vector load size for input image!");
CK_TILE_ERROR("Conv C is not a multiple of vector load size for "
"input image!");
return false;
}
}
@@ -559,7 +601,8 @@ struct GroupedConvolutionBackwardWeightKernel
{
if(ConvK % GroupedConvTraitsType_::VectorSizeA != 0)
{
CK_TILE_ERROR("Conv K is not a multiple of vector store size for output image!");
CK_TILE_ERROR("Conv K is not a multiple of vector store size "
"for output image!");
return false;
}
}
@@ -569,6 +612,18 @@ struct GroupedConvolutionBackwardWeightKernel
return false;
}
if constexpr(GroupedConvTraitsType_::NumGroupsToMerge > 1)
{
const index_t ConvG = kargs.wei_g_k_c_xs_lengths[number<0>{}];
if(ConvG % GroupedConvTraitsType_::NumGroupsToMerge != 0)
{
CK_TILE_ERROR("ConvG must be a multiple of NumGroupsToMerge!");
return false;
}
// TODO: Should we also check that GemmM <= MPerBlock and GemmN <= NPerBlock?
}
return true;
}
@@ -654,6 +709,16 @@ struct GroupedConvolutionBackwardWeightKernel
return make_tuple(a_pad_view, b_pad_view, ds_pad_view, c_pad_view);
}
/**
* @brief Create views to the data that each workgroup will process.
*
* @param views padded views of A, B, D and C tensors
* @param i_m block m-index
* @param i_n block n-index
* @param i_k block k-index
*
* @return tuple of tile windows for A, B, D and C tensors
*/
template <typename PadView>
CK_TILE_DEVICE static auto MakeGemmTileWindows(const PadView& views,
const index_t i_m,
@@ -818,7 +883,6 @@ struct GroupedConvolutionBackwardWeightKernel
const InDataType* b_ptr = static_cast<const InDataType*>(kargs.in_ptr) + group_offset_b;
WeiDataType* c_ptr = static_cast<WeiDataType*>(kargs.wei_ptr) + group_offset_c;
// allocate LDS
__shared__ char smem_ptr_0[GetSmemSize()];
if constexpr(GemmPipeline::DoubleSmemBuffer == true)

View File

@@ -29,6 +29,7 @@ struct GroupedConvFwdKernelArgs
GroupedConvTraitsType_::VectorSizeA,
GroupedConvTraitsType_::VectorSizeB,
GroupedConvTraitsType_::VectorSizeC,
GroupedConvTraitsType_::NumGroupsToMerge,
true>; // Split N enabled
using CDElementwise = typename GroupedConvTraitsType_::CDElementwise;
static constexpr index_t NumDTensor = GroupedConvTraitsType_::NumDTensor;

View File

@@ -59,10 +59,11 @@ template <index_t NDimSpatial_,
typename WeiLayout_,
typename DsLayout_,
typename OutLayout_,
index_t VectorSizeA_ = 1,
index_t VectorSizeB_ = 1,
index_t VectorSizeC_ = 1,
index_t NumGroupsToMerge_ = 1,
typename CDElementwise_ = PassThrough>
struct GroupedConvTraits
{
private:
@@ -73,7 +74,7 @@ struct GroupedConvTraits
}
public:
static constexpr index_t NumGroupsToMerge = 1;
static constexpr index_t NumGroupsToMerge = NumGroupsToMerge_;
static constexpr index_t NDimSpatial = NDimSpatial_;
static constexpr ConvolutionSpecialization ConvSpecialization = ConvSpecialization_;
using InLayout = InLayout_;
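A hedged instantiation sketch showing where the new NumGroupsToMerge_ parameter lands, mirroring the invoker usage earlier in this diff (GroupedConvTraits<..., VectorSizeC, ConvConfig::NumGroupsToMerge>); InLayout/WeiLayout/DsLayout/OutLayout are assumed aliases for the layout tags the examples already pass (e.g. the NWGC/GKXC/NWGK set for 1D), and the vector sizes are illustrative:

#include "ck_tile/ops/grouped_convolution.hpp" // for ck_tile::GroupedConvTraits

// NumGroupsToMerge = 2 matches ConvConfigComputeV3_merged_groups above.
using MergedGroupConvTraits =
    ck_tile::GroupedConvTraits<1, // NDimSpatial
                               ck_tile::ConvolutionSpecialization::Default,
                               InLayout,
                               WeiLayout,
                               DsLayout,
                               OutLayout,
                               /*VectorSizeA=*/4,
                               /*VectorSizeB=*/8,
                               /*VectorSizeC=*/8,
                               /*NumGroupsToMerge=*/2>;
static_assert(MergedGroupConvTraits::NumGroupsToMerge == 2);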

View File

@@ -13,10 +13,10 @@ template <index_t NDimSpatial,
index_t VectorSizeA,
index_t VectorSizeB,
index_t VectorSizeC,
index_t NumGroupsToMerge = 1,
bool SplitN = false,
typename ADataType = float,
typename CDataType = float,
index_t NumGroupsToMerge = 1,
typename IndexType = index_t>
struct TransformConvBwdWeightToGemm
{
@@ -125,8 +125,7 @@ struct TransformConvBwdWeightToGemm
InLeftPadW_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.InLeftPadW_)},
InRightPadD_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.InRightPadD_)},
InRightPadH_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.InRightPadH_)},
InRightPadW_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.InRightPadW_)},
ZYX_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.ZYX_)}
InRightPadW_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.InRightPadW_)}
{
}
@@ -164,8 +163,7 @@ struct TransformConvBwdWeightToGemm
InLeftPadW_{input_left_pads[I0]},
InRightPadD_{I0},
InRightPadH_{I0},
InRightPadW_{input_right_pads[I0]},
ZYX_{X_}
InRightPadW_{input_right_pads[I0]}
{
static_assert(std::is_same_v<ConvSpatialDimsType, std::array<IndexType, NDimSpatial>> ||
std::is_same_v<ConvSpatialDimsType, ck_tile::array<IndexType, NDimSpatial>>);
@@ -219,8 +217,7 @@ struct TransformConvBwdWeightToGemm
InLeftPadW_{input_left_pads[I1]},
InRightPadD_{I0},
InRightPadH_{input_right_pads[I0]},
InRightPadW_{input_right_pads[I1]},
ZYX_{Y_ * X_}
InRightPadW_{input_right_pads[I1]}
{
static_assert(std::is_same_v<ConvSpatialDimsType, std::array<IndexType, NDimSpatial>> ||
std::is_same_v<ConvSpatialDimsType, ck_tile::array<IndexType, NDimSpatial>>);
@@ -274,8 +271,7 @@ struct TransformConvBwdWeightToGemm
InLeftPadW_{input_left_pads[I2]},
InRightPadD_{input_right_pads[I0]},
InRightPadH_{input_right_pads[I1]},
InRightPadW_{input_right_pads[I2]},
ZYX_{Z_ * Y_ * X_}
InRightPadW_{input_right_pads[I2]}
{
static_assert(std::is_same_v<ConvSpatialDimsType, std::array<IndexType, NDimSpatial>> ||
std::is_same_v<ConvSpatialDimsType, ck_tile::array<IndexType, NDimSpatial>>);
@@ -420,11 +416,21 @@ struct TransformConvBwdWeightToGemm
const index_t NDoHoWoStride = G_ * K_;
constexpr auto KStride = I1;
// TODO Add support for NumGroupsToMerge > 1
return make_naive_tensor_descriptor(make_tuple(K_, N_ * Wo_),
make_tuple(KStride, NDoHoWoStride),
number<VectorSizeA>{},
I1);
if constexpr(NumGroupsToMerge > 1)
{
const index_t BatchStride = K_;
return make_naive_tensor_descriptor(make_tuple(K_, NumGroupsToMerge, N_ * Wo_),
make_tuple(KStride, BatchStride, NDoHoWoStride),
number<VectorSizeA>{},
I1);
}
else
{
return make_naive_tensor_descriptor(make_tuple(K_, N_ * Wo_),
make_tuple(KStride, NDoHoWoStride),
number<VectorSizeA>{},
I1);
}
}
template <index_t NDim = NDimSpatial, typename std::enable_if<NDim == 1, bool>::type = false>
@@ -435,11 +441,22 @@ struct TransformConvBwdWeightToGemm
const index_t WiStride = G_ * C_;
constexpr auto CStride = I1;
// TODO Add support for NumGroupsToMerge > 1
return make_naive_tensor_descriptor(make_tuple(N_, Wi_, C_),
make_tuple(NStride, WiStride, CStride),
number<VectorSizeB>{},
I1);
if constexpr(NumGroupsToMerge > 1)
{
const auto BatchStride = C_;
return make_naive_tensor_descriptor(make_tuple(N_, Wi_, NumGroupsToMerge, C_),
make_tuple(NStride, WiStride, BatchStride, CStride),
number<VectorSizeB>{},
I1);
}
else
{
return make_naive_tensor_descriptor(make_tuple(N_, Wi_, C_),
make_tuple(NStride, WiStride, CStride),
number<VectorSizeB>{},
I1);
}
}
template <index_t NDim = NDimSpatial, typename std::enable_if<NDim == 1, bool>::type = false>
@@ -449,9 +466,56 @@ struct TransformConvBwdWeightToGemm
const index_t KStride = X_ * C_;
constexpr auto CXStride = I1;
// TODO Add support for NumGroupsToMerge > 1
return make_naive_tensor_descriptor(
make_tuple(K_, X_ * C_), make_tuple(KStride, CXStride), number<VectorSizeC>{}, I1);
if constexpr(NumGroupsToMerge > 1)
{
const index_t XStride = C_;
const index_t BatchStride = K_ * X_ * C_;
// Add NumGroupsToMerge for the Batch+M dimension and 1 as a placeholder
// for the Batch+N dimension.
const auto desc = make_naive_tensor_descriptor(
make_tuple(NumGroupsToMerge, K_, X_, 1, C_),
make_tuple(BatchStride, KStride, XStride, BatchStride, CXStride),
number<VectorSizeC>{},
I1);
// Pad 1 to NumGroupsToMerge
const auto padded_desc = transform_tensor_descriptor(
desc,
make_tuple(make_pass_through_transform(NumGroupsToMerge),
make_pass_through_transform(K_),
make_pass_through_transform(X_),
make_pad_transform(1, 0, NumGroupsToMerge - 1),
make_pass_through_transform(C_)),
make_tuple(
sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}, sequence<4>{}),
make_tuple(
sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}, sequence<4>{}));
// We only need the matrices on the diagonal. Xor returns 0 for equal values,
// so any matrix that is not on the diagonal is stored in the padding.
// To avoid a modulo after the xor we assume that NumGroupsToMerge is a power of 2.
static_assert(NumGroupsToMerge == 1 || NumGroupsToMerge == 2 || NumGroupsToMerge == 4 ||
NumGroupsToMerge == 8 || NumGroupsToMerge == 16 ||
NumGroupsToMerge == 32 || NumGroupsToMerge == 64);
const auto unmerged_padded_desc = transform_tensor_descriptor(
padded_desc,
make_tuple(make_xor_transform(make_tuple(NumGroupsToMerge, NumGroupsToMerge)),
make_pass_through_transform(K_),
make_pass_through_transform(X_),
make_pass_through_transform(C_)),
make_tuple(sequence<0, 3>{}, sequence<1>{}, sequence<2>{}, sequence<4>{}),
make_tuple(sequence<0, 3>{}, sequence<1>{}, sequence<2>{}, sequence<4>{}));
// Merge To M, N
return transform_tensor_descriptor(
unmerged_padded_desc,
make_tuple(make_merge_transform(make_tuple(NumGroupsToMerge, K_)),
make_merge_transform(make_tuple(X_, NumGroupsToMerge, C_))),
make_tuple(sequence<0, 1>{}, sequence<2, 3, 4>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
}
else
{
return make_naive_tensor_descriptor(
make_tuple(K_, X_ * C_), make_tuple(KStride, CXStride), number<VectorSizeC>{}, I1);
}
}
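A standalone sketch (plain C++, not CK code) of the diagonal-selection idea in the comments above, assuming the xor transform sends the merged coordinate pair (g_m, g_n) to padded index g_m ^ g_n: only diagonal blocks (g_m == g_n) hit index 0, i.e. real weight data, while every off-diagonal block lands in the padding.

#include <cstdio>

int main()
{
    constexpr int NumGroupsToMerge = 4; // must be a power of two, per the static_assert above
    for(int g_m = 0; g_m < NumGroupsToMerge; ++g_m)
    {
        for(int g_n = 0; g_n < NumGroupsToMerge; ++g_n)
        {
            const int pad_idx = g_m ^ g_n; // 0 -> real data, > 0 -> padding
            std::printf("%c ", pad_idx == 0 ? 'D' : '.');
        }
        std::printf("\n"); // prints 'D' exactly on the diagonal
    }
    return 0;
}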
template <index_t NDim = NDimSpatial, typename std::enable_if<NDim == 2, bool>::type = false>
@@ -461,11 +525,22 @@ struct TransformConvBwdWeightToGemm
const index_t NDoHoWoStride = G_ * K_;
constexpr auto KStride = I1;
// TODO Add support for NumGroupsToMerge > 1
return make_naive_tensor_descriptor(make_tuple(N_ * Ho_ * Wo_, K_), // K_M
make_tuple(NDoHoWoStride, KStride),
number<VectorSizeA>{},
I1);
if constexpr(NumGroupsToMerge > 1)
{
const index_t BatchStride = K_;
return make_naive_tensor_descriptor(
make_tuple(N_ * Ho_ * Wo_, NumGroupsToMerge, K_), // K_Gm_M
make_tuple(NDoHoWoStride, BatchStride, KStride),
number<VectorSizeA>{},
I1);
}
else
{
return make_naive_tensor_descriptor(make_tuple(N_ * Ho_ * Wo_, K_), // K_M
make_tuple(NDoHoWoStride, KStride),
number<VectorSizeA>{},
I1);
}
}
template <index_t NDim = NDimSpatial, typename std::enable_if<NDim == 2, bool>::type = false>
@@ -477,11 +552,22 @@ struct TransformConvBwdWeightToGemm
const index_t WiStride = G_ * C_;
constexpr auto CStride = I1;
// TODO Add support for NumGroupsToMerge > 1
return make_naive_tensor_descriptor(make_tuple(N_, Hi_, Wi_, C_), // K_N
make_tuple(NStride, HiStride, WiStride, CStride),
number<VectorSizeB>{},
I1);
if constexpr(NumGroupsToMerge > 1)
{
const auto BatchStride = C_;
return make_naive_tensor_descriptor(
make_tuple(N_, Hi_, Wi_, NumGroupsToMerge, C_), // K_Gm_N
make_tuple(NStride, HiStride, WiStride, BatchStride, CStride),
number<VectorSizeB>{},
I1);
}
else
{
return make_naive_tensor_descriptor(make_tuple(N_, Hi_, Wi_, C_), // K_N
make_tuple(NStride, HiStride, WiStride, CStride),
number<VectorSizeB>{},
I1);
}
}
template <index_t NDim = NDimSpatial, typename std::enable_if<NDim == 2, bool>::type = false>
@@ -491,9 +577,58 @@ struct TransformConvBwdWeightToGemm
const index_t KStride = Y_ * X_ * C_;
constexpr auto CStride = I1;
// TODO Add support for NumGroupsToMerge > 1
return make_naive_tensor_descriptor(
make_tuple(K_, Y_ * X_ * C_), make_tuple(KStride, CStride), number<VectorSizeC>{}, I1);
if constexpr(NumGroupsToMerge > 1)
{
const index_t YXStride = C_;
const index_t BatchStride = K_ * Y_ * X_ * C_;
// Add NumGroupsToMerge for the Batch+M dimension and 1 as a placeholder
// for the Batch+N dimension.
const auto desc = make_naive_tensor_descriptor(
make_tuple(NumGroupsToMerge, K_, Y_ * X_, 1, C_),
make_tuple(BatchStride, KStride, YXStride, BatchStride, CStride),
number<VectorSizeC>{},
I1);
// Pad 1 to NumGroupsToMerge
const auto padded_desc = transform_tensor_descriptor(
desc,
make_tuple(make_pass_through_transform(NumGroupsToMerge),
make_pass_through_transform(K_),
make_pass_through_transform(Y_ * X_),
make_pad_transform(1, 0, NumGroupsToMerge - 1),
make_pass_through_transform(C_)),
make_tuple(
sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}, sequence<4>{}),
make_tuple(
sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}, sequence<4>{}));
// We only need the matrices on the diagonal. Xor returns 0 for equal values,
// so any matrix that is not on the diagonal is stored in the padding.
// To avoid a modulo after the xor we assume that NumGroupsToMerge is a power of 2.
static_assert(NumGroupsToMerge == 1 || NumGroupsToMerge == 2 || NumGroupsToMerge == 4 ||
NumGroupsToMerge == 8 || NumGroupsToMerge == 16 ||
NumGroupsToMerge == 32 || NumGroupsToMerge == 64);
const auto unmerged_padded_desc = transform_tensor_descriptor(
padded_desc,
make_tuple(make_xor_transform(make_tuple(NumGroupsToMerge, NumGroupsToMerge)),
make_pass_through_transform(K_),
make_pass_through_transform(Y_ * X_),
make_pass_through_transform(C_)),
make_tuple(sequence<0, 3>{}, sequence<1>{}, sequence<2>{}, sequence<4>{}),
make_tuple(sequence<0, 3>{}, sequence<1>{}, sequence<2>{}, sequence<4>{}));
// Merge To M, N
return transform_tensor_descriptor(
unmerged_padded_desc,
make_tuple(make_merge_transform(make_tuple(NumGroupsToMerge, K_)),
make_merge_transform(make_tuple(Y_ * X_, NumGroupsToMerge, C_))),
make_tuple(sequence<0, 1>{}, sequence<2, 3, 4>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
}
else
{
return make_naive_tensor_descriptor(make_tuple(K_, Y_ * X_ * C_),
make_tuple(KStride, CStride),
number<VectorSizeC>{},
I1);
}
}
template <index_t NDim = NDimSpatial, typename std::enable_if<NDim == 3, bool>::type = false>
@@ -503,11 +638,22 @@ struct TransformConvBwdWeightToGemm
const index_t NDoHoWoStride = G_ * K_;
constexpr auto KStride = I1;
// TODO Add support for NumGroupsToMerge > 1
return make_naive_tensor_descriptor(make_tuple(N_ * Do_ * Ho_ * Wo_, K_),
make_tuple(NDoHoWoStride, KStride),
number<VectorSizeA>{},
I1);
if constexpr(NumGroupsToMerge > 1)
{
const auto BatchStride = K_;
return make_naive_tensor_descriptor(
make_tuple(N_ * Do_ * Ho_ * Wo_, NumGroupsToMerge, K_),
make_tuple(NDoHoWoStride, BatchStride, KStride),
number<VectorSizeA>{},
I1);
}
else
{
return make_naive_tensor_descriptor(make_tuple(N_ * Do_ * Ho_ * Wo_, K_),
make_tuple(NDoHoWoStride, KStride),
number<VectorSizeA>{},
I1);
}
}
template <index_t NDim = NDimSpatial, typename std::enable_if<NDim == 3, bool>::type = false>
@@ -519,26 +665,84 @@ struct TransformConvBwdWeightToGemm
const index_t WiStride = G_ * C_;
constexpr auto CStride = I1;
// TODO Add support for NumGroupsToMerge > 1
return make_naive_tensor_descriptor(
make_tuple(N_, Di_, Hi_, Wi_, C_),
make_tuple(NStride, DiStride, HiStride, WiStride, CStride),
number<VectorSizeB>{},
I1);
if constexpr(NumGroupsToMerge > 1)
{
const index_t BatchStride = C_;
return make_naive_tensor_descriptor(
make_tuple(N_, Di_, Hi_, Wi_, NumGroupsToMerge, C_),
make_tuple(NStride, DiStride, HiStride, WiStride, BatchStride, CStride),
number<VectorSizeB>{},
I1);
}
else
{
return make_naive_tensor_descriptor(
make_tuple(N_, Di_, Hi_, Wi_, C_),
make_tuple(NStride, DiStride, HiStride, WiStride, CStride),
number<VectorSizeB>{},
I1);
}
}
template <index_t NDim = NDimSpatial, typename std::enable_if<NDim == 3, bool>::type = false>
CK_TILE_HOST auto make_wei_grid_desc() const
{
// KZYXC
// GKZYXC
const index_t KStride = Z_ * Y_ * X_ * C_;
constexpr auto CStride = I1;
// TODO Add support for NumGroupsToMerge > 1
return make_naive_tensor_descriptor(make_tuple(K_, Z_ * Y_ * X_ * C_),
make_tuple(KStride, CStride),
number<VectorSizeC>{},
I1);
if constexpr(NumGroupsToMerge > 1)
{
const index_t ZYXStride = C_;
const index_t BatchStride = K_ * Z_ * Y_ * X_ * C_;
// Add NumGroupsToMerge for the Batch+M dimension and 1 as a placeholder
// for the Batch+N dimension.
const auto desc = make_naive_tensor_descriptor(
make_tuple(NumGroupsToMerge, K_, Z_ * Y_ * X_, 1, C_),
make_tuple(BatchStride, KStride, ZYXStride, BatchStride, CStride),
number<VectorSizeC>{},
I1);
// Pad 1 to NumGroupsToMerge
const auto padded_desc = transform_tensor_descriptor(
desc,
make_tuple(make_pass_through_transform(NumGroupsToMerge),
make_pass_through_transform(K_),
make_pass_through_transform(Z_ * Y_ * X_),
make_pad_transform(1, 0, NumGroupsToMerge - 1),
make_pass_through_transform(C_)),
make_tuple(
sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}, sequence<4>{}),
make_tuple(
sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}, sequence<4>{}));
// We only need the matrices on the diagonal. Xor returns 0 for equal values,
// so any matrix that is not on the diagonal is stored in the padding.
// To avoid a modulo after the xor we assume that NumGroupsToMerge is a power of 2.
static_assert(NumGroupsToMerge == 1 || NumGroupsToMerge == 2 || NumGroupsToMerge == 4 ||
NumGroupsToMerge == 8 || NumGroupsToMerge == 16 ||
NumGroupsToMerge == 32 || NumGroupsToMerge == 64);
const auto unmerged_padded_desc = transform_tensor_descriptor(
padded_desc,
make_tuple(make_xor_transform(make_tuple(NumGroupsToMerge, NumGroupsToMerge)),
make_pass_through_transform(K_),
make_pass_through_transform(Z_ * Y_ * X_),
make_pass_through_transform(C_)),
make_tuple(sequence<0, 3>{}, sequence<1>{}, sequence<2>{}, sequence<4>{}),
make_tuple(sequence<0, 3>{}, sequence<1>{}, sequence<2>{}, sequence<4>{}));
// Merge To M, N
return transform_tensor_descriptor(
unmerged_padded_desc,
make_tuple(make_merge_transform(make_tuple(NumGroupsToMerge, K_)),
make_merge_transform(make_tuple(Z_ * Y_ * X_, NumGroupsToMerge, C_))),
make_tuple(sequence<0, 1>{}, sequence<2, 3, 4>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
}
else
{
return make_naive_tensor_descriptor(make_tuple(K_, Z_ * Y_ * X_ * C_),
make_tuple(KStride, CStride),
number<VectorSizeC>{},
I1);
}
}
// TODO: implement ck_tile::tensor_layout::convolution that describes packed/strided dimension as
@@ -552,31 +756,84 @@ struct TransformConvBwdWeightToGemm
const auto wei_grid_desc = make_wei_grid_desc<NDimSpatial>();
// B: input tensor comes in K_N
const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor(
in_grid_desc,
make_tuple(make_pass_through_transform(N_),
make_pad_transform(Wi_, InLeftPadW_, InRightPadW_),
make_pass_through_transform(C_)),
make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}),
make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}));
if constexpr(NumGroupsToMerge > 1)
{
// Output tensor transformation
// [0, 1, 2] -> [0, 1]
// [(N*Wo), Gm, K] -> [(N*Wo), (Gm*K)]
const auto out_gemm_k_gemm_m_grid_desc = transform_tensor_descriptor(
out_grid_desc,
make_tuple(make_pass_through_transform(N_ * Wo_),
make_merge_transform(make_tuple(NumGroupsToMerge, K_))),
make_tuple(sequence<0>{}, sequence<1, 2>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
const auto in_n_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor(
in_n_hip_wip_c_grid_desc,
make_tuple(
make_pass_through_transform(N_),
make_embed_transform(make_tuple(X_, Wo_), make_tuple(ConvDilationW_, ConvStrideW_)),
make_pass_through_transform(C_)),
make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}),
make_tuple(sequence<0>{}, sequence<1, 2>{}, sequence<3>{}));
// Input tensor transformation, part 1.
// [N, Wi, Gm, C] -> [N, (Wi + InLeftPadW + InRightPadW), Gm, C] = [N, Wip, Gm, C]
const auto in_n_wip_gm_c_grid_desc = transform_tensor_descriptor(
in_grid_desc,
make_tuple(make_pass_through_transform(N_),
make_pad_transform(Wi_, InLeftPadW_, InRightPadW_),
make_pass_through_transform(NumGroupsToMerge),
make_pass_through_transform(C_)),
make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}),
make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}));
const auto in_gemmn_gemmktotal_grid_desc =
transform_tensor_descriptor(in_n_y_ho_x_wo_c_grid_desc,
make_tuple(make_merge_transform(make_tuple(X_, C_)),
make_merge_transform(make_tuple(N_, Wo_))),
make_tuple(sequence<1, 3>{}, sequence<0, 2>{}),
make_tuple(sequence<1>{}, sequence<0>{}));
// Input tensor transformation, part 2.
// [N, Wip, Gm, C] -> [N, X, Wo, Gm, C]
const auto in_n_x_wo_gm_c_grid_desc = transform_tensor_descriptor(
in_n_wip_gm_c_grid_desc,
make_tuple(make_pass_through_transform(N_),
make_embed_transform(make_tuple(X_, Wo_),
make_tuple(ConvDilationW_, ConvStrideW_)),
make_pass_through_transform(NumGroupsToMerge),
make_pass_through_transform(C_)),
make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}),
make_tuple(sequence<0>{}, sequence<1, 2>{}, sequence<3>{}, sequence<4>{}));
return make_tuple(out_grid_desc, in_gemmn_gemmktotal_grid_desc, wei_grid_desc);
// Input tensor transformation, part 3.
// [0, 1, 2, 3, 4] -> [0, 1]
// [N, X, Wo, Gm, C] -> [(N*Wo), (Gm*X*C)]
const auto in_gemm_n_gemm_k_grid_desc = transform_tensor_descriptor(
in_n_x_wo_gm_c_grid_desc,
make_tuple(make_merge_transform(make_tuple(X_, NumGroupsToMerge, C_)),
make_merge_transform(make_tuple(N_, Wo_))),
make_tuple(sequence<1, 3, 4>{}, sequence<0, 2>{}),
make_tuple(sequence<1>{}, sequence<0>{}));
return make_tuple(
out_gemm_k_gemm_m_grid_desc, in_gemm_n_gemm_k_grid_desc, wei_grid_desc);
}
else
{
// [N, Wi, C] -> [N, (Wi + InLeftPadW + InRightPadW), C] = [N, Wip, C]
const auto in_n_wip_c_grid_desc = transform_tensor_descriptor(
in_grid_desc,
make_tuple(make_pass_through_transform(N_),
make_pad_transform(Wi_, InLeftPadW_, InRightPadW_),
make_pass_through_transform(C_)),
make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}),
make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}));
// [N, Wip, C] -> [N, X, Wo, C]
const auto in_n_x_wo_c_grid_desc = transform_tensor_descriptor(
in_n_wip_c_grid_desc,
make_tuple(make_pass_through_transform(N_),
make_embed_transform(make_tuple(X_, Wo_),
make_tuple(ConvDilationW_, ConvStrideW_)),
make_pass_through_transform(C_)),
make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}),
make_tuple(sequence<0>{}, sequence<1, 2>{}, sequence<3>{}));
const auto in_gemmn_gemmktotal_grid_desc =
transform_tensor_descriptor(in_n_x_wo_c_grid_desc,
make_tuple(make_merge_transform(make_tuple(X_, C_)),
make_merge_transform(make_tuple(N_, Wo_))),
make_tuple(sequence<1, 3>{}, sequence<0, 2>{}),
make_tuple(sequence<1>{}, sequence<0>{}));
return make_tuple(out_grid_desc, in_gemmn_gemmktotal_grid_desc, wei_grid_desc);
}
}
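A worked example of the merged 1D bwd-weight GEMM shape produced by the branch above (illustrative numbers, not from the PR): with NumGroupsToMerge = Gm, the descriptors give GemmM = Gm*K, GemmN = X*Gm*C (from the (X, Gm, C) merge) and GemmK = N*Wo.

// Depthwise-style sizes: G = 32, Gm = 2, K = C = 1, X = 3, N = 4, Wo = 28.
constexpr int G = 32, Gm = 2, K = 1, C = 1, X = 3, N = 4, Wo = 28;
constexpr int GemmM = Gm * K;     // merged (Gm*K)   -> 2
constexpr int GemmN = X * Gm * C; // merged (X*Gm*C) -> 6
constexpr int GemmK = N * Wo;     // merged (N*Wo)   -> 112
constexpr int GemmBatch = G / Gm; // 16 GEMM batches instead of 32
static_assert(GemmM == 2 && GemmN == 6 && GemmK == 112 && GemmBatch == 16, "");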
template <index_t NDim = NDimSpatial, typename std::enable_if<NDim == 2, bool>::type = false>
@@ -587,33 +844,95 @@ struct TransformConvBwdWeightToGemm
const auto wei_grid_desc = make_wei_grid_desc<NDimSpatial>();
// B: input tensor comes in K_N
const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor(
in_grid_desc,
make_tuple(make_pass_through_transform(N_),
make_pad_transform(Hi_, InLeftPadH_, InRightPadH_),
make_pad_transform(Wi_, InLeftPadW_, InRightPadW_),
make_pass_through_transform(C_)),
make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}),
make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}));
if constexpr(NumGroupsToMerge > 1)
{
// Output tensor transformation
// [0, 1, 2] -> [0, 1]
// [(N*Ho*Wo), Gm, K] -> [(N*Ho*Wo), (Gm*K)]
const auto out_gemm_k_gemm_m_grid_desc = transform_tensor_descriptor(
out_grid_desc,
make_tuple(make_pass_through_transform(N_ * Ho_ * Wo_),
make_merge_transform(make_tuple(NumGroupsToMerge, K_))),
make_tuple(sequence<0>{}, sequence<1, 2>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
// Input tensor transformation, part 1.
// [N, Hi, Wi, Gm, C] -> [N, Hip, Wip, Gm, C]
const auto in_n_hip_wip_gm_c_grid_desc = transform_tensor_descriptor(
in_grid_desc,
make_tuple(make_pass_through_transform(N_),
make_pad_transform(Hi_, InLeftPadH_, InRightPadH_),
make_pad_transform(Wi_, InLeftPadW_, InRightPadW_),
make_pass_through_transform(NumGroupsToMerge),
make_pass_through_transform(C_)),
make_tuple(
sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}, sequence<4>{}),
make_tuple(
sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}, sequence<4>{}));
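// Padding is applied independently per spatial dimension; the merged
// group (Gm) and channel (C) dimensions pass through unchanged.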
// Input tensor transformation, part 2.
// [N, Hip, Wip, Gm, C] -> [N, (Y, Ho), (X, Wo), Gm, C]
const auto in_n_y_ho_x_wo_gm_c_grid_desc = transform_tensor_descriptor(
in_n_hip_wip_gm_c_grid_desc,
make_tuple(make_pass_through_transform(N_),
make_embed_transform(make_tuple(Y_, Ho_),
make_tuple(ConvDilationH_, ConvStrideH_)),
make_embed_transform(make_tuple(X_, Wo_),
make_tuple(ConvDilationW_, ConvStrideW_)),
make_pass_through_transform(NumGroupsToMerge),
make_pass_through_transform(C_)),
make_tuple(
sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}, sequence<4>{}),
make_tuple(sequence<0>{},
sequence<1, 2>{},
sequence<3, 4>{},
sequence<5>{},
sequence<6>{}));
// Input tensor transformation, part 3.
// [0, 1, 2, 3, 4, 5, 6] -> [0, 1]
// [N, Y, Ho, X, Wo, Gm, C] -> [(N*Ho*Wo), (Y*X*Gm*C)]
const auto in_gemm_k_gemm_n_grid_desc = transform_tensor_descriptor(
in_n_y_ho_x_wo_gm_c_grid_desc,
make_tuple(make_merge_transform(make_tuple(Y_, X_, NumGroupsToMerge, C_)),
make_merge_transform(make_tuple(N_, Ho_, Wo_))),
make_tuple(sequence<1, 3, 5, 6>{}, sequence<0, 2, 4>{}),
make_tuple(sequence<1>{}, sequence<0>{}));
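// Worked example (hypothetical sizes): N = 1, Ho = Wo = 7, Y = X = 3,
// NumGroupsToMerge = 4, C = 1 gives GemmK = N * Ho * Wo = 49 and
// GemmN = Y * X * Gm * C = 36.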
return make_tuple(
out_gemm_k_gemm_m_grid_desc, in_gemm_k_gemm_n_grid_desc, wei_grid_desc);
}
else
{
const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor(
in_grid_desc,
make_tuple(make_pass_through_transform(N_),
make_pad_transform(Hi_, InLeftPadH_, InRightPadH_),
make_pad_transform(Wi_, InLeftPadW_, InRightPadW_),
make_pass_through_transform(C_)),
make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}),
make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}));
const auto in_n_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor(
in_n_hip_wip_c_grid_desc,
make_tuple(make_pass_through_transform(N_),
make_embed_transform(make_tuple(Y_, Ho_),
make_tuple(ConvDilationH_, ConvStrideH_)),
make_embed_transform(make_tuple(X_, Wo_),
make_tuple(ConvDilationW_, ConvStrideW_)),
make_pass_through_transform(C_)),
make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}),
make_tuple(sequence<0>{}, sequence<1, 2>{}, sequence<3, 4>{}, sequence<5>{}));
const auto in_gemmn_gemmktotal_grid_desc = transform_tensor_descriptor(
in_n_y_ho_x_wo_c_grid_desc,
make_tuple(make_merge_transform(make_tuple(Y_, X_, C_)),
make_merge_transform(make_tuple(N_, Ho_, Wo_))),
make_tuple(sequence<1, 3, 5>{}, sequence<0, 2, 4>{}),
make_tuple(sequence<1>{}, sequence<0>{}));
return make_tuple(out_grid_desc, in_gemmn_gemmktotal_grid_desc, wei_grid_desc);
}
}
template <index_t NDim = NDimSpatial, typename std::enable_if<NDim == 3, bool>::type = false>
@@ -624,39 +943,121 @@ struct TransformConvBwdWeightToGemm
const auto wei_grid_desc = make_wei_grid_desc<NDimSpatial>();
if constexpr(NumGroupsToMerge > 1)
{
// Output tensor transformation
// [0, 1, 2] -> [0, 1]
// [(N*Do*Ho*Wo), Gm, K] -> [(N*Do*Ho*Wo), (Gm*K)]
const auto out_gemm_k_gemm_m_grid_desc = transform_tensor_descriptor(
out_grid_desc,
make_tuple(make_pass_through_transform(N_ * Do_ * Ho_ * Wo_),
make_merge_transform(make_tuple(NumGroupsToMerge, K_))),
make_tuple(sequence<0>{}, sequence<1, 2>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
// Input tensor transformation, part 1.
// [N, Di, Hi, Wi, Gm, C] -> [N, Dip, Hip, Wip, Gm, C]
const auto in_n_dip_hip_wip_gm_c_grid_desc = transform_tensor_descriptor(
in_grid_desc,
make_tuple(make_pass_through_transform(N_),
make_pad_transform(Di_, InLeftPadD_, InRightPadD_),
make_pad_transform(Hi_, InLeftPadH_, InRightPadH_),
make_pad_transform(Wi_, InLeftPadW_, InRightPadW_),
make_pass_through_transform(NumGroupsToMerge),
make_pass_through_transform(C_)),
make_tuple(sequence<0>{},
sequence<1>{},
sequence<2>{},
sequence<3>{},
sequence<4>{},
sequence<5>{}),
make_tuple(sequence<0>{},
sequence<1>{},
sequence<2>{},
sequence<3>{},
sequence<4>{},
sequence<5>{}));
// Input tensor transformation, part 2.
// [N, Dip, Hip, Wip, Gm, C] -> [N, (Z, Do), (Y, Ho), (X, Wo), Gm, C]
const auto in_n_z_do_y_ho_x_wo_gm_c_grid_desc = transform_tensor_descriptor(
in_n_dip_hip_wip_gm_c_grid_desc,
make_tuple(make_pass_through_transform(N_),
make_embed_transform(make_tuple(Z_, Do_),
make_tuple(ConvDilationD_, ConvStrideD_)),
make_embed_transform(make_tuple(Y_, Ho_),
make_tuple(ConvDilationH_, ConvStrideH_)),
make_embed_transform(make_tuple(X_, Wo_),
make_tuple(ConvDilationW_, ConvStrideW_)),
make_pass_through_transform(NumGroupsToMerge),
make_pass_through_transform(C_)),
make_tuple(sequence<0>{},
sequence<1>{},
sequence<2>{},
sequence<3>{},
sequence<4>{},
sequence<5>{}),
make_tuple(sequence<0>{},
sequence<1, 2>{},
sequence<3, 4>{},
sequence<5, 6>{},
sequence<7>{},
sequence<8>{}));
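// At this point the input is a nine-dimensional view; part 3 below
// collapses it into the two GEMM dimensions.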
// Input tensor transformation, part 3.
// [0, 1, 2, 3, 4, 5, 6, 7, 8] -> [0, 1]
// [N, Z, Do, Y, Ho, X, Wo, Gm, C] -> [(N*Do*Ho*Wo), (Z*Y*X*Gm*C)]
const auto in_gemm_k_gemm_n_grid_desc = transform_tensor_descriptor(
in_n_z_do_y_ho_x_wo_gm_c_grid_desc,
make_tuple(make_merge_transform(make_tuple(Z_, Y_, X_, NumGroupsToMerge, C_)),
make_merge_transform(make_tuple(N_, Do_, Ho_, Wo_))),
make_tuple(sequence<1, 3, 5, 7, 8>{}, sequence<0, 2, 4, 6>{}),
make_tuple(sequence<1>{}, sequence<0>{}));
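// Worked example (hypothetical sizes): N = 1, Do = Ho = Wo = 5,
// Z = Y = X = 3, NumGroupsToMerge = 2, C = 1 gives
// GemmK = N * Do * Ho * Wo = 125 and GemmN = Z * Y * X * Gm * C = 54.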
return make_tuple(
out_gemm_k_gemm_m_grid_desc, in_gemm_k_gemm_n_grid_desc, wei_grid_desc);
}
else
{
const auto in_n_dip_hip_wip_c_grid_desc = transform_tensor_descriptor(
in_grid_desc,
make_tuple(make_pass_through_transform(N_),
make_pad_transform(Di_, InLeftPadD_, InRightPadD_),
make_pad_transform(Hi_, InLeftPadH_, InRightPadH_),
make_pad_transform(Wi_, InLeftPadW_, InRightPadW_),
make_pass_through_transform(C_)),
make_tuple(
sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}, sequence<4>{}),
make_tuple(
sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}, sequence<4>{}));
const auto in_n_z_do_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor(
in_n_dip_hip_wip_c_grid_desc,
make_tuple(make_pass_through_transform(N_),
make_embed_transform(make_tuple(Z_, Do_),
make_tuple(ConvDilationD_, ConvStrideD_)),
make_embed_transform(make_tuple(Y_, Ho_),
make_tuple(ConvDilationH_, ConvStrideH_)),
make_embed_transform(make_tuple(X_, Wo_),
make_tuple(ConvDilationW_, ConvStrideW_)),
make_pass_through_transform(C_)),
make_tuple(
sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}, sequence<4>{}),
make_tuple(sequence<0>{},
sequence<1, 2>{},
sequence<3, 4>{},
sequence<5, 6>{},
sequence<7>{}));
const auto in_gemmn_gemmktotal_grid_desc = transform_tensor_descriptor(
in_n_z_do_y_ho_x_wo_c_grid_desc,
make_tuple(make_merge_transform(make_tuple(Z_, Y_, X_, C_)),
make_merge_transform(make_tuple(N_, Do_, Ho_, Wo_))),
make_tuple(sequence<1, 3, 5, 7>{}, sequence<0, 2, 4, 6>{}),
make_tuple(sequence<1>{}, sequence<0>{}));
return make_tuple(out_grid_desc, in_gemmn_gemmktotal_grid_desc, wei_grid_desc);
}
}
IndexType G_, N_;
@@ -668,7 +1069,6 @@ struct TransformConvBwdWeightToGemm
IndexType ConvDilationD_, ConvDilationH_, ConvDilationW_;
IndexType InLeftPadD_, InLeftPadH_, InLeftPadW_;
IndexType InRightPadD_, InRightPadH_, InRightPadW_;
IndexType ZYX_;
};
} // namespace ck_tile


@@ -13,10 +13,10 @@ template <index_t NDimSpatial,
index_t VectorSizeA,
index_t VectorSizeB,
index_t VectorSizeC,
index_t NumGroupsToMerge = 1,
bool SplitN = false,
typename ADataType = float,
typename CDataType = float,
typename IndexType = index_t>
struct TransformConvFwdToGemm
{
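// Note: NumGroupsToMerge now precedes SplitN and the data-type
// parameters, so callers can override it without restating those
// defaults.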