mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-19 22:39:03 +00:00
[CK_Tile] Merge multiple convolution groups into a single GEMM batch (#2986)
* Fix compilation of the grouped conv examples. * Fix grouped conv bwd weight example output in CK Tile. * Add number of groups to merge to ck tile grouped gemm example. * Initial set of tests for TransformConvBwdWeightToGemm. * Added unit tests for TransformConvBwdWeightToGemm conv groups are merged. * WIP: Tensor transformations. * Add unit tests for coordinate transforms. * Fully working conv group merging for TransformConvBwdWeightToGemm. * WIP: Merged conv groups offset calculation. * Adde unit tests for tensor view. * WIP: Merged conv groups epilogue. * Enable running multiple conv groups per batch. * Add tests for tile_distribution_encoding. * Change example to match optimally depthwise convolution with merged groups. * Add more tests for tensor view. * Integration test for reading diagonal blocks from grouped distributed tensor. * Improved integration test. * Improve test for accessing diagonal blocks. * Added integration test for cshuffle epilogue LDS tile distribution. * Add more logging. * Increase the max number of reported errors. * WIP: merged conv groups GEMM epilogue changes. * LDS to global memory copy. * Fix tile window size for c block. * Integration test for CShuffle epilogue. * Improved CShuffle test. * WIP: Separate epilogue for merged conv groups. * Tile example parameters changes to match depthwise conv. * Offset fixes. * Epilogue fixes. * Working baseline for depthwise covolution with merged conv groups. * Fix build. * Initial unit tests for tensor descriptor. * Add one more unit test for tensor view. * WIP: LDS to global mem transfer using CK tile tensor descriptor and tile distribution encoding. * Fully functional LDS to global mem transfer using tensor descriptor and tile distribution encoding. * Add more comments, disable debug code. * Remove debug and other dead code. * Code clean-up for bwd tensor transformations. * Enable running multiple GEMM batches of merged conv groups. * Add compile check for assumed row-mjor layout. * Fix strides in 1D conv to gemm transformation. * WIP: Simplify conv to gemm transformations and handle K > 1 and C > 1 cases. * Fix case k > 1 and c=1. * Remove debug code. * Make MPerGroup and NPerGroup template parameters. * Add additional check for non-supported c > 1 case. * WIP: Put back the generic tensor descriptors for convolutions. * Fix tensor descriptors. * Remove the obsolete template parameters. * Add more instances. * Fix bugs in merged conv groups tensor descriptors. * Fix tensor descriptors for merged conv groups when K > 1. * Remove debug output. * Remove dead code. * Fix merge conflicts. * Code clean-up. * Remove unused code. * Run clang-formatting. * Remove debug prints and obsolete tests. * Check that number of convolution groups is multiple of merged groups. * Fix build after removing obsolete functionality. * Remove obsolete enumeration. * Fix new unit projects. * Remove unnecessary includes. * Fix passing the number of merged groups. * Remove unrelated tests. * Fix IsSupportedArgument for bwd weight conv kernel. * Fix clang formatting. * Fix the bwd weight conv to gemm mapping for num merged groups > 1. * GEMM config for conv group merging. * Fix clang-formatting. * Remove obsolete comment. * Fix typos in comment strings. * Increase the max number of reported errors when testing against reference implementation. * Rename gemm_config to conv_config. * Rename GemmConfig to ConvConfig and move NumGroupsToMerge into ConvConfig. * Change num_groups_to_merge to a boolean flag in the ck tile grouped conv example. * Run clang-format. * Add number of merged groups into kernel name string. * Remove group merging flag from CK Tile grouped conv example.
This commit is contained in:
@@ -17,7 +17,7 @@
|
||||
#define CK_TILE_PIPELINE_COMPUTE_V4 3
|
||||
#define CK_TILE_PIPELINE_COMPUTE_V5 4
|
||||
|
||||
struct GemmConfigBase
|
||||
struct ConvConfigBase
|
||||
{
|
||||
static constexpr bool kPadM = true;
|
||||
static constexpr bool kPadN = true;
|
||||
@@ -29,6 +29,10 @@ struct GemmConfigBase
|
||||
static constexpr bool TransposeC = false;
|
||||
static constexpr bool UseStructuredSparsity = false;
|
||||
|
||||
static constexpr ck_tile::index_t VectorSizeA = 4;
|
||||
static constexpr ck_tile::index_t VectorSizeB = 8;
|
||||
static constexpr ck_tile::index_t VectorSizeC = 8;
|
||||
|
||||
static constexpr int kBlockPerCu = 1;
|
||||
static constexpr ck_tile::index_t TileParitionerGroupNum = 8;
|
||||
static constexpr ck_tile::index_t TileParitionerM01 = 4;
|
||||
@@ -37,10 +41,12 @@ struct GemmConfigBase
|
||||
static constexpr ck_tile::index_t NumWaveGroups = 1;
|
||||
static constexpr bool Preshuffle = false;
|
||||
static constexpr bool TiledMMAPermuteN = false;
|
||||
|
||||
static constexpr ck_tile::index_t NumGroupsToMerge = 1;
|
||||
};
|
||||
|
||||
template <typename PrecType>
|
||||
struct GemmConfigMemoryInterwave : public GemmConfigBase
|
||||
struct ConvConfigMemoryInterwave : public ConvConfigBase
|
||||
{
|
||||
// Memory friendly for Interwave scheduler
|
||||
static constexpr ck_tile::index_t M_Tile = 128;
|
||||
@@ -61,7 +67,7 @@ struct GemmConfigMemoryInterwave : public GemmConfigBase
|
||||
};
|
||||
|
||||
template <typename PrecType>
|
||||
struct GemmConfigMemoryIntrawave : public GemmConfigBase
|
||||
struct ConvConfigMemoryIntrawave : public ConvConfigBase
|
||||
{
|
||||
static constexpr ck_tile::index_t M_Tile = 128;
|
||||
static constexpr ck_tile::index_t N_Tile = 32;
|
||||
@@ -80,7 +86,7 @@ struct GemmConfigMemoryIntrawave : public GemmConfigBase
|
||||
};
|
||||
|
||||
template <typename PrecType>
|
||||
struct GemmConfigComputeV3 : public GemmConfigBase
|
||||
struct ConvConfigComputeV3 : public ConvConfigBase
|
||||
{
|
||||
// Compute V3 only support Intrawave scheduler
|
||||
static constexpr ck_tile::index_t M_Tile = 16;
|
||||
@@ -100,7 +106,7 @@ struct GemmConfigComputeV3 : public GemmConfigBase
|
||||
};
|
||||
|
||||
template <typename PrecType>
|
||||
struct GemmConfigComputeV3_1 : public GemmConfigBase
|
||||
struct ConvConfigComputeV3_1 : public ConvConfigBase
|
||||
{
|
||||
static constexpr ck_tile::index_t M_Tile = 256;
|
||||
static constexpr ck_tile::index_t N_Tile = 256;
|
||||
@@ -119,7 +125,7 @@ struct GemmConfigComputeV3_1 : public GemmConfigBase
|
||||
};
|
||||
|
||||
template <typename PrecType>
|
||||
struct GemmConfigComputeV3_2 : public GemmConfigBase
|
||||
struct ConvConfigComputeV3_2 : public ConvConfigBase
|
||||
{
|
||||
static constexpr ck_tile::index_t M_Tile = 128;
|
||||
static constexpr ck_tile::index_t N_Tile = 128;
|
||||
@@ -140,7 +146,7 @@ struct GemmConfigComputeV3_2 : public GemmConfigBase
|
||||
};
|
||||
|
||||
template <typename PrecType>
|
||||
struct GemmConfigComputeV3_WMMA : public GemmConfigBase
|
||||
struct ConvConfigComputeV3_WMMA : public ConvConfigBase
|
||||
{
|
||||
static constexpr ck_tile::index_t M_Tile = 128;
|
||||
static constexpr ck_tile::index_t N_Tile = 128;
|
||||
@@ -161,7 +167,7 @@ struct GemmConfigComputeV3_WMMA : public GemmConfigBase
|
||||
};
|
||||
|
||||
template <typename PrecType>
|
||||
struct GemmConfigComputeV4 : public GemmConfigBase
|
||||
struct ConvConfigComputeV4 : public ConvConfigBase
|
||||
{
|
||||
// Compute V4 only support Intrawave scheduler
|
||||
// Using the ping pong reader in the lds level
|
||||
@@ -182,7 +188,7 @@ struct GemmConfigComputeV4 : public GemmConfigBase
|
||||
};
|
||||
|
||||
template <typename PrecType>
|
||||
struct GemmConfigComputeV4_1 : public GemmConfigBase
|
||||
struct ConvConfigComputeV4_1 : public ConvConfigBase
|
||||
{
|
||||
static constexpr ck_tile::index_t M_Tile = 256;
|
||||
static constexpr ck_tile::index_t N_Tile = 256;
|
||||
@@ -201,7 +207,7 @@ struct GemmConfigComputeV4_1 : public GemmConfigBase
|
||||
};
|
||||
|
||||
template <typename PrecType>
|
||||
struct GemmConfigComputeV5 : public GemmConfigBase
|
||||
struct ConvConfigComputeV5 : public ConvConfigBase
|
||||
{
|
||||
static constexpr ck_tile::index_t M_Tile = 128;
|
||||
static constexpr ck_tile::index_t N_Tile = 128;
|
||||
@@ -220,6 +226,31 @@ struct GemmConfigComputeV5 : public GemmConfigBase
|
||||
static constexpr ck_tile::index_t NumWaNumWaveGroups = 2;
|
||||
};
|
||||
|
||||
template <typename PrecType>
|
||||
struct ConvConfigComputeV3_merged_groups : public ConvConfigBase
|
||||
{
|
||||
static constexpr ck_tile::index_t VectorSizeA = 4;
|
||||
static constexpr ck_tile::index_t VectorSizeB = 8;
|
||||
static constexpr ck_tile::index_t VectorSizeC = 8;
|
||||
|
||||
static constexpr ck_tile::index_t M_Tile = 16;
|
||||
static constexpr ck_tile::index_t N_Tile = 32;
|
||||
static constexpr ck_tile::index_t K_Tile = 32;
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp = 1;
|
||||
static constexpr ck_tile::index_t N_Warp = 2;
|
||||
static constexpr ck_tile::index_t K_Warp = 1;
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = 32;
|
||||
|
||||
static constexpr bool DoubleSmemBuffer = false;
|
||||
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3;
|
||||
|
||||
static constexpr ck_tile::index_t NumGroupsToMerge = 2;
|
||||
};
|
||||
|
||||
template <typename InDataType, typename WeiDataType = InDataType, typename OutDataType = InDataType>
|
||||
struct ConvTypeConfig;
|
||||
|
||||
@@ -51,8 +51,8 @@ int run_grouped_conv_bwd_data_example(int argc, char* argv[])
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
#if CK_TILE_USE_WMMA
|
||||
return !run_grouped_conv_bwd_data_example<GemmConfigComputeV3_WMMA>(argc, argv);
|
||||
return !run_grouped_conv_bwd_data_example<ConvConfigComputeV3_WMMA>(argc, argv);
|
||||
#else
|
||||
return !run_grouped_conv_bwd_data_example<GemmConfigComputeV3>(argc, argv);
|
||||
return !run_grouped_conv_bwd_data_example<ConvConfigComputeV3>(argc, argv);
|
||||
#endif
|
||||
}
|
||||
|
||||
@@ -14,7 +14,7 @@
|
||||
#include "grouped_convolution_backward_weight_invoker.hpp"
|
||||
#include "run_grouped_convolution_bwd_weight_example.inc"
|
||||
|
||||
template <template <typename PrecType> typename GemmConfig>
|
||||
template <template <typename PrecType> typename ConvConfig>
|
||||
int run_grouped_conv_bwd_weight_example(ck_tile::ArgParser& arg_parser)
|
||||
{
|
||||
using Invoker = GroupedConvolutionBackwardWeightInvoker;
|
||||
@@ -27,14 +27,14 @@ int run_grouped_conv_bwd_weight_example(ck_tile::ArgParser& arg_parser)
|
||||
if(data_type == "fp16")
|
||||
{
|
||||
return run_grouped_conv_bwd_weight_example_prec_type<Invoker,
|
||||
GemmConfig<ck_tile::half_t>,
|
||||
ConvConfig<ck_tile::half_t>,
|
||||
ck_tile::half_t>(
|
||||
in_layout, wei_layout, out_layout, arg_parser);
|
||||
}
|
||||
else if(data_type == "bf16")
|
||||
{
|
||||
return run_grouped_conv_bwd_weight_example_prec_type<Invoker,
|
||||
GemmConfig<ck_tile::bf16_t>,
|
||||
ConvConfig<ck_tile::bf16_t>,
|
||||
ck_tile::bf16_t>(
|
||||
in_layout, wei_layout, out_layout, arg_parser);
|
||||
}
|
||||
@@ -54,9 +54,9 @@ int main(int argc, char* argv[])
|
||||
try
|
||||
{
|
||||
#if CK_TILE_USE_WMMA
|
||||
return !run_grouped_conv_bwd_weight_example<GemmConfigComputeV3_WMMA>(arg_parser);
|
||||
return !run_grouped_conv_bwd_weight_example<ConvConfigComputeV3_WMMA>(arg_parser);
|
||||
#else
|
||||
return !run_grouped_conv_bwd_weight_example<GemmConfigComputeV3>(arg_parser);
|
||||
return !run_grouped_conv_bwd_weight_example<ConvConfigComputeV3>(arg_parser);
|
||||
#endif
|
||||
}
|
||||
catch(const std::runtime_error& e)
|
||||
|
||||
@@ -7,7 +7,7 @@
|
||||
struct GroupedConvolutionBackwardWeightInvoker
|
||||
{
|
||||
template <ck_tile::index_t NDimSpatial,
|
||||
typename GemmConfig,
|
||||
typename ConvConfig,
|
||||
typename InDataType,
|
||||
typename WeiDataType,
|
||||
typename AccDataType,
|
||||
@@ -25,22 +25,22 @@ struct GroupedConvolutionBackwardWeightInvoker
|
||||
|
||||
// Implicit GEMM Traits
|
||||
using GemmShape = ck_tile::TileGemmShape<
|
||||
ck_tile::sequence<GemmConfig::M_Tile, GemmConfig::N_Tile, GemmConfig::K_Tile>,
|
||||
ck_tile::sequence<GemmConfig::M_Warp, GemmConfig::N_Warp, GemmConfig::K_Warp>,
|
||||
ck_tile::sequence<ConvConfig::M_Tile, ConvConfig::N_Tile, ConvConfig::K_Tile>,
|
||||
ck_tile::sequence<ConvConfig::M_Warp, ConvConfig::N_Warp, ConvConfig::K_Warp>,
|
||||
ck_tile::
|
||||
sequence<GemmConfig::M_Warp_Tile, GemmConfig::N_Warp_Tile, GemmConfig::K_Warp_Tile>,
|
||||
GemmConfig::PermuteA,
|
||||
GemmConfig::PermuteB>;
|
||||
sequence<ConvConfig::M_Warp_Tile, ConvConfig::N_Warp_Tile, ConvConfig::K_Warp_Tile>,
|
||||
ConvConfig::PermuteA,
|
||||
ConvConfig::PermuteB>;
|
||||
|
||||
constexpr ck_tile::index_t VectorSizeA = 4;
|
||||
constexpr ck_tile::index_t VectorSizeB = 8;
|
||||
constexpr ck_tile::index_t VectorSizeC = 8;
|
||||
constexpr ck_tile::index_t VectorSizeA = ConvConfig::VectorSizeA;
|
||||
constexpr ck_tile::index_t VectorSizeB = ConvConfig::VectorSizeB;
|
||||
constexpr ck_tile::index_t VectorSizeC = ConvConfig::VectorSizeC;
|
||||
|
||||
constexpr auto ConvSpec = ck_tile::ConvolutionSpecialization::Default;
|
||||
using TilePartitioner =
|
||||
ck_tile::GemmSpatiallyLocalTilePartitioner<GemmShape,
|
||||
GemmConfig::TileParitionerGroupNum,
|
||||
GemmConfig::TileParitionerM01>;
|
||||
ConvConfig::TileParitionerGroupNum,
|
||||
ConvConfig::TileParitionerM01>;
|
||||
using GroupedConvTraitsType = ck_tile::GroupedConvTraits<NDimSpatial,
|
||||
ConvSpec,
|
||||
InLayout,
|
||||
@@ -49,20 +49,21 @@ struct GroupedConvolutionBackwardWeightInvoker
|
||||
OutLayout,
|
||||
VectorSizeA,
|
||||
VectorSizeB,
|
||||
VectorSizeC>;
|
||||
VectorSizeC,
|
||||
ConvConfig::NumGroupsToMerge>;
|
||||
|
||||
using GemmUniversalTraits = ck_tile::TileGemmUniversalTraits<
|
||||
GemmConfig::kPadM,
|
||||
GemmConfig::kPadN,
|
||||
GemmConfig::kPadK,
|
||||
GemmConfig::DoubleSmemBuffer,
|
||||
ConvConfig::kPadM,
|
||||
ConvConfig::kPadN,
|
||||
ConvConfig::kPadK,
|
||||
ConvConfig::DoubleSmemBuffer,
|
||||
typename GroupedConvTraitsType::GroupedConvImplicitGemmTraitsBwdWeight::AsLayout,
|
||||
typename GroupedConvTraitsType::GroupedConvImplicitGemmTraitsBwdWeight::BsLayout,
|
||||
typename GroupedConvTraitsType::GroupedConvImplicitGemmTraitsBwdWeight::CLayout,
|
||||
GemmConfig::TransposeC,
|
||||
GemmConfig::UseStructuredSparsity,
|
||||
ConvConfig::TransposeC,
|
||||
ConvConfig::UseStructuredSparsity,
|
||||
false, // Persistent,
|
||||
GemmConfig::NumWaveGroups>;
|
||||
ConvConfig::NumWaveGroups>;
|
||||
|
||||
using GemmPipelineProblem = ck_tile::GemmPipelineProblem<
|
||||
OutDataType,
|
||||
@@ -78,7 +79,7 @@ struct GroupedConvolutionBackwardWeightInvoker
|
||||
VectorSizeB>;
|
||||
|
||||
using BaseGemmPipeline = typename PipelineTypeTraits<
|
||||
GemmConfig::Pipeline>::template UniversalGemmPipeline<GemmPipelineProblem>;
|
||||
ConvConfig::Pipeline>::template UniversalGemmPipeline<GemmPipelineProblem>;
|
||||
|
||||
const ck_tile::index_t gemm_k =
|
||||
args.N_ * std::accumulate(args.output_spatial_lengths_.begin(),
|
||||
@@ -86,8 +87,8 @@ struct GroupedConvolutionBackwardWeightInvoker
|
||||
1,
|
||||
std::multiplies<ck_tile::index_t>());
|
||||
|
||||
const ck_tile::index_t k_grain = args.k_batch * GemmConfig::K_Tile;
|
||||
const ck_tile::index_t K_split = (gemm_k + k_grain - 1) / k_grain * GemmConfig::K_Tile;
|
||||
const ck_tile::index_t k_grain = args.k_batch * ConvConfig::K_Tile;
|
||||
const ck_tile::index_t K_split = (gemm_k + k_grain - 1) / k_grain * ConvConfig::K_Tile;
|
||||
const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split);
|
||||
const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop);
|
||||
const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop);
|
||||
@@ -98,7 +99,7 @@ struct GroupedConvolutionBackwardWeightInvoker
|
||||
const auto memory_operation_) {
|
||||
constexpr bool has_hot_loop_v = has_hot_loop_.value;
|
||||
constexpr auto tail_number_v = tail_number_.value;
|
||||
constexpr auto scheduler = GemmConfig::Scheduler;
|
||||
constexpr auto scheduler = ConvConfig::Scheduler;
|
||||
constexpr auto memory_operation = memory_operation_.value;
|
||||
|
||||
using UniversalGemmProblem =
|
||||
@@ -118,7 +119,7 @@ struct GroupedConvolutionBackwardWeightInvoker
|
||||
VectorSizeB>;
|
||||
|
||||
using GemmPipeline = typename PipelineTypeTraits<
|
||||
GemmConfig::Pipeline>::template GemmPipeline<UniversalGemmProblem>;
|
||||
ConvConfig::Pipeline>::template GemmPipeline<UniversalGemmProblem>;
|
||||
|
||||
using ConvEpilogue = ck_tile::CShuffleEpilogue<ck_tile::CShuffleEpilogueProblem<
|
||||
OutDataType,
|
||||
@@ -131,12 +132,12 @@ struct GroupedConvolutionBackwardWeightInvoker
|
||||
CDEElementWise,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
GemmConfig::M_Warp,
|
||||
GemmConfig::N_Warp,
|
||||
GemmConfig::M_Warp_Tile,
|
||||
GemmConfig::N_Warp_Tile,
|
||||
GemmConfig::K_Warp_Tile,
|
||||
GemmConfig::TransposeC,
|
||||
ConvConfig::M_Warp,
|
||||
ConvConfig::N_Warp,
|
||||
ConvConfig::M_Warp_Tile,
|
||||
ConvConfig::N_Warp_Tile,
|
||||
ConvConfig::K_Warp_Tile,
|
||||
ConvConfig::TransposeC,
|
||||
memory_operation,
|
||||
1,
|
||||
true,
|
||||
|
||||
@@ -13,9 +13,9 @@
|
||||
#include "grouped_convolution_utils.hpp"
|
||||
#include "grouped_convolution_backward_weight_two_stage_invoker.hpp"
|
||||
#include "run_grouped_convolution_bwd_weight_example.inc"
|
||||
#include "gemm_configs.hpp"
|
||||
#include "conv_configs.hpp"
|
||||
|
||||
template <template <typename PrecType> typename GemmConfig>
|
||||
template <template <typename PrecType> typename ConvConfig>
|
||||
int run_grouped_conv_bwd_weight_example(ck_tile::ArgParser& arg_parser)
|
||||
{
|
||||
using Invoker = GroupedConvolutionBackwardWeightTwoStageInvoker;
|
||||
@@ -28,14 +28,14 @@ int run_grouped_conv_bwd_weight_example(ck_tile::ArgParser& arg_parser)
|
||||
if(data_type == "fp16")
|
||||
{
|
||||
return run_grouped_conv_bwd_weight_example_prec_type<Invoker,
|
||||
GemmConfig<ck_tile::half_t>,
|
||||
ConvConfig<ck_tile::half_t>,
|
||||
ck_tile::half_t>(
|
||||
in_layout, wei_layout, out_layout, arg_parser);
|
||||
}
|
||||
else if(data_type == "bf16")
|
||||
{
|
||||
return run_grouped_conv_bwd_weight_example_prec_type<Invoker,
|
||||
GemmConfig<ck_tile::bf16_t>,
|
||||
ConvConfig<ck_tile::bf16_t>,
|
||||
ck_tile::bf16_t>(
|
||||
in_layout, wei_layout, out_layout, arg_parser);
|
||||
}
|
||||
@@ -55,9 +55,9 @@ int main(int argc, char* argv[])
|
||||
try
|
||||
{
|
||||
#if CK_TILE_USE_WMMA
|
||||
return !run_grouped_conv_bwd_weight_example<GemmConfigComputeV3_WMMA>(arg_parser);
|
||||
return !run_grouped_conv_bwd_weight_example<ConvConfigComputeV3_WMMA>(arg_parser);
|
||||
#else
|
||||
return !run_grouped_conv_bwd_weight_example<GemmConfigComputeV3>(arg_parser);
|
||||
return !run_grouped_conv_bwd_weight_example<ConvConfigComputeV3>(arg_parser);
|
||||
#endif
|
||||
}
|
||||
catch(const std::runtime_error& e)
|
||||
|
||||
@@ -7,7 +7,7 @@
|
||||
struct GroupedConvolutionBackwardWeightTwoStageInvoker
|
||||
{
|
||||
template <ck_tile::index_t NDimSpatial,
|
||||
typename GemmConfig,
|
||||
typename ConvConfig,
|
||||
typename InDataType,
|
||||
typename WeiDataType,
|
||||
typename AccDataType,
|
||||
@@ -27,12 +27,12 @@ struct GroupedConvolutionBackwardWeightTwoStageInvoker
|
||||
|
||||
// Implicit GEMM Traits
|
||||
using GemmShape = ck_tile::TileGemmShape<
|
||||
ck_tile::sequence<GemmConfig::M_Tile, GemmConfig::N_Tile, GemmConfig::K_Tile>,
|
||||
ck_tile::sequence<GemmConfig::M_Warp, GemmConfig::N_Warp, GemmConfig::K_Warp>,
|
||||
ck_tile::sequence<ConvConfig::M_Tile, ConvConfig::N_Tile, ConvConfig::K_Tile>,
|
||||
ck_tile::sequence<ConvConfig::M_Warp, ConvConfig::N_Warp, ConvConfig::K_Warp>,
|
||||
ck_tile::
|
||||
sequence<GemmConfig::M_Warp_Tile, GemmConfig::N_Warp_Tile, GemmConfig::K_Warp_Tile>,
|
||||
GemmConfig::PermuteA,
|
||||
GemmConfig::PermuteB>;
|
||||
sequence<ConvConfig::M_Warp_Tile, ConvConfig::N_Warp_Tile, ConvConfig::K_Warp_Tile>,
|
||||
ConvConfig::PermuteA,
|
||||
ConvConfig::PermuteB>;
|
||||
|
||||
constexpr ck_tile::index_t VectorSizeA = 4;
|
||||
constexpr ck_tile::index_t VectorSizeB = 8;
|
||||
@@ -41,8 +41,8 @@ struct GroupedConvolutionBackwardWeightTwoStageInvoker
|
||||
constexpr auto ConvSpec = ck_tile::ConvolutionSpecialization::Default;
|
||||
using TilePartitioner =
|
||||
ck_tile::GemmSpatiallyLocalTilePartitioner<GemmShape,
|
||||
GemmConfig::TileParitionerGroupNum,
|
||||
GemmConfig::TileParitionerM01>;
|
||||
ConvConfig::TileParitionerGroupNum,
|
||||
ConvConfig::TileParitionerM01>;
|
||||
using GroupedConvTraitsType = ck_tile::GroupedConvTraits<NDimSpatial,
|
||||
ConvSpec,
|
||||
InLayout,
|
||||
@@ -54,17 +54,17 @@ struct GroupedConvolutionBackwardWeightTwoStageInvoker
|
||||
VectorSizeC>;
|
||||
|
||||
using GemmUniversalTraits = ck_tile::TileGemmUniversalTraits<
|
||||
GemmConfig::kPadM,
|
||||
GemmConfig::kPadN,
|
||||
GemmConfig::kPadK,
|
||||
GemmConfig::DoubleSmemBuffer,
|
||||
ConvConfig::kPadM,
|
||||
ConvConfig::kPadN,
|
||||
ConvConfig::kPadK,
|
||||
ConvConfig::DoubleSmemBuffer,
|
||||
typename GroupedConvTraitsType::GroupedConvImplicitGemmTraitsBwdWeight::AsLayout,
|
||||
typename GroupedConvTraitsType::GroupedConvImplicitGemmTraitsBwdWeight::BsLayout,
|
||||
typename GroupedConvTraitsType::GroupedConvImplicitGemmTraitsBwdWeight::CLayout,
|
||||
GemmConfig::TransposeC,
|
||||
GemmConfig::UseStructuredSparsity,
|
||||
ConvConfig::TransposeC,
|
||||
ConvConfig::UseStructuredSparsity,
|
||||
false, // Persistent,
|
||||
GemmConfig::NumWaveGroups>;
|
||||
ConvConfig::NumWaveGroups>;
|
||||
|
||||
using GemmPipelineProblem = ck_tile::GemmPipelineProblem<
|
||||
OutDataType,
|
||||
@@ -80,7 +80,7 @@ struct GroupedConvolutionBackwardWeightTwoStageInvoker
|
||||
VectorSizeB>;
|
||||
|
||||
using BaseGemmPipeline = typename PipelineTypeTraits<
|
||||
GemmConfig::Pipeline>::template UniversalGemmPipeline<GemmPipelineProblem>;
|
||||
ConvConfig::Pipeline>::template UniversalGemmPipeline<GemmPipelineProblem>;
|
||||
|
||||
const ck_tile::index_t gemm_k =
|
||||
args.N_ * std::accumulate(args.output_spatial_lengths_.begin(),
|
||||
@@ -88,8 +88,8 @@ struct GroupedConvolutionBackwardWeightTwoStageInvoker
|
||||
1,
|
||||
std::multiplies<ck_tile::index_t>());
|
||||
|
||||
const ck_tile::index_t k_grain = args.k_batch * GemmConfig::K_Tile;
|
||||
const ck_tile::index_t K_split = (gemm_k + k_grain - 1) / k_grain * GemmConfig::K_Tile;
|
||||
const ck_tile::index_t k_grain = args.k_batch * ConvConfig::K_Tile;
|
||||
const ck_tile::index_t K_split = (gemm_k + k_grain - 1) / k_grain * ConvConfig::K_Tile;
|
||||
const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split);
|
||||
const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop);
|
||||
const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop);
|
||||
@@ -100,7 +100,7 @@ struct GroupedConvolutionBackwardWeightTwoStageInvoker
|
||||
const auto memory_operation_) {
|
||||
constexpr bool has_hot_loop_v = has_hot_loop_.value;
|
||||
constexpr auto tail_number_v = tail_number_.value;
|
||||
constexpr auto scheduler = GemmConfig::Scheduler;
|
||||
constexpr auto scheduler = ConvConfig::Scheduler;
|
||||
constexpr auto memory_operation = memory_operation_.value;
|
||||
|
||||
using UniversalGemmProblem =
|
||||
@@ -120,7 +120,7 @@ struct GroupedConvolutionBackwardWeightTwoStageInvoker
|
||||
VectorSizeB>;
|
||||
|
||||
using GemmPipeline = typename PipelineTypeTraits<
|
||||
GemmConfig::Pipeline>::template GemmPipeline<UniversalGemmProblem>;
|
||||
ConvConfig::Pipeline>::template GemmPipeline<UniversalGemmProblem>;
|
||||
|
||||
using ConvEpilogue = ck_tile::CShuffleEpilogue<ck_tile::CShuffleEpilogueProblem<
|
||||
OutDataType, // A: Out
|
||||
@@ -133,11 +133,11 @@ struct GroupedConvolutionBackwardWeightTwoStageInvoker
|
||||
CDEElementWise,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
GemmConfig::M_Warp,
|
||||
GemmConfig::N_Warp,
|
||||
GemmConfig::M_Warp_Tile,
|
||||
GemmConfig::N_Warp_Tile,
|
||||
GemmConfig::K_Warp_Tile,
|
||||
ConvConfig::M_Warp,
|
||||
ConvConfig::N_Warp,
|
||||
ConvConfig::M_Warp_Tile,
|
||||
ConvConfig::N_Warp_Tile,
|
||||
ConvConfig::K_Warp_Tile,
|
||||
GemmPipelineProblem::TransposeC,
|
||||
memory_operation,
|
||||
1,
|
||||
|
||||
@@ -51,8 +51,8 @@ int run_grouped_conv_fwd_example(int argc, char* argv[])
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
#if CK_TILE_USE_WMMA
|
||||
return !run_grouped_conv_fwd_example<GemmConfigComputeV3_WMMA>(argc, argv);
|
||||
return !run_grouped_conv_fwd_example<ConvConfigComputeV3_WMMA>(argc, argv);
|
||||
#else
|
||||
return !run_grouped_conv_fwd_example<GemmConfigComputeV3>(argc, argv);
|
||||
return !run_grouped_conv_fwd_example<ConvConfigComputeV3>(argc, argv);
|
||||
#endif
|
||||
}
|
||||
|
||||
@@ -32,9 +32,10 @@ struct GroupedConvolutionForwardInvoker
|
||||
GemmConfig::PermuteA,
|
||||
GemmConfig::PermuteB>;
|
||||
|
||||
constexpr ck_tile::index_t VectorSizeA = 8;
|
||||
constexpr ck_tile::index_t VectorSizeB = 8;
|
||||
constexpr ck_tile::index_t VectorSizeC = 8;
|
||||
constexpr ck_tile::index_t VectorSizeA = 8;
|
||||
constexpr ck_tile::index_t VectorSizeB = 8;
|
||||
constexpr ck_tile::index_t VectorSizeC = 8;
|
||||
constexpr ck_tile::index_t NumGroupsToMerge = 1;
|
||||
|
||||
constexpr auto ConvSpec = ck_tile::ConvolutionSpecialization::Default;
|
||||
using TilePartitioner =
|
||||
@@ -50,6 +51,7 @@ struct GroupedConvolutionForwardInvoker
|
||||
VectorSizeA,
|
||||
VectorSizeB,
|
||||
VectorSizeC,
|
||||
NumGroupsToMerge,
|
||||
CDElementWise>;
|
||||
|
||||
using GemmUniversalTraits = ck_tile::TileGemmUniversalTraits<
|
||||
|
||||
@@ -11,24 +11,11 @@
|
||||
#include "ck_tile/ops/epilogue.hpp"
|
||||
#include "ck_tile/ops/gemm.hpp"
|
||||
#include "ck_tile/ops/grouped_convolution.hpp"
|
||||
#include "gemm_configs.hpp"
|
||||
#include "conv_configs.hpp"
|
||||
using MemoryOpSet =
|
||||
std::integral_constant<ck_tile::memory_operation_enum, ck_tile::memory_operation_enum::set>;
|
||||
using MemoryOpAtomicAdd = std::integral_constant<ck_tile::memory_operation_enum,
|
||||
ck_tile::memory_operation_enum::atomic_add>;
|
||||
struct GemmWarpConfig_Mfma
|
||||
{
|
||||
static constexpr ck_tile::index_t M_Warp_Tile = 32;
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 32;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = 16;
|
||||
};
|
||||
|
||||
struct GemmWarpConfig_Wmma
|
||||
{
|
||||
static constexpr ck_tile::index_t M_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = 16;
|
||||
};
|
||||
|
||||
template <typename InDataType, typename WeiDataType, typename AccDataType, typename OutDataType>
|
||||
auto calculate_rtol_atol(const ck_tile::index_t GemmK,
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
#pragma once
|
||||
|
||||
template <ck_tile::index_t NDimSpatial,
|
||||
typename GemmConfig,
|
||||
typename ConvConfig,
|
||||
typename Invoker,
|
||||
typename InDataType,
|
||||
typename WeiDataType,
|
||||
@@ -17,7 +17,7 @@ float invoke_grouped_conv_bwd_weight(ck_tile::GroupedConvBwdWeightHostArgs& args
|
||||
int n_repeat)
|
||||
{
|
||||
float ave_time = Invoker::template grouped_conv_bwd_weight<NDimSpatial,
|
||||
GemmConfig,
|
||||
ConvConfig,
|
||||
InDataType,
|
||||
WeiDataType,
|
||||
AccDataType,
|
||||
@@ -31,7 +31,7 @@ float invoke_grouped_conv_bwd_weight(ck_tile::GroupedConvBwdWeightHostArgs& args
|
||||
}
|
||||
|
||||
template <ck_tile::index_t NDimSpatial,
|
||||
typename GemmConfig,
|
||||
typename ConvConfig,
|
||||
typename Invoker,
|
||||
typename InDataType,
|
||||
typename WeiDataType = InDataType,
|
||||
@@ -131,7 +131,7 @@ int run_grouped_conv_bwd_weight_example_with_layouts(ck_tile::ArgParser& arg_par
|
||||
std::cout << "output: " << output.mDesc << std::endl;
|
||||
|
||||
float ave_time = invoke_grouped_conv_bwd_weight<NDimSpatial,
|
||||
GemmConfig,
|
||||
ConvConfig,
|
||||
Invoker,
|
||||
InDataType,
|
||||
WeiDataType,
|
||||
@@ -193,7 +193,7 @@ int run_grouped_conv_bwd_weight_example_with_layouts(ck_tile::ArgParser& arg_par
|
||||
}
|
||||
|
||||
template <typename Invoker,
|
||||
typename GemmConfig,
|
||||
typename ConvConfig,
|
||||
typename InPrecType,
|
||||
typename WeiPrecType = InPrecType,
|
||||
typename OutPrecType = InPrecType>
|
||||
@@ -217,7 +217,7 @@ int run_grouped_conv_bwd_weight_example_prec_type(std::string in_layout,
|
||||
if(in_layout == "NWGC" && wei_layout == "GKXC" && out_layout == "NWGK")
|
||||
{
|
||||
return run_grouped_conv_bwd_weight_example_with_layouts<ck_tile::number<1>{},
|
||||
GemmConfig,
|
||||
ConvConfig,
|
||||
Invoker,
|
||||
InPrecType,
|
||||
WeiPrecType,
|
||||
@@ -227,7 +227,7 @@ int run_grouped_conv_bwd_weight_example_prec_type(std::string in_layout,
|
||||
else if(in_layout == "NHWGC" && wei_layout == "GKYXC" && out_layout == "NHWGK")
|
||||
{
|
||||
return run_grouped_conv_bwd_weight_example_with_layouts<ck_tile::number<2>{},
|
||||
GemmConfig,
|
||||
ConvConfig,
|
||||
Invoker,
|
||||
InPrecType,
|
||||
WeiPrecType,
|
||||
@@ -237,7 +237,7 @@ int run_grouped_conv_bwd_weight_example_prec_type(std::string in_layout,
|
||||
else if(in_layout == "NDHWGC" && wei_layout == "GKZYXC" && out_layout == "NDHWGK")
|
||||
{
|
||||
return run_grouped_conv_bwd_weight_example_with_layouts<ck_tile::number<3>{},
|
||||
GemmConfig,
|
||||
ConvConfig,
|
||||
Invoker,
|
||||
InPrecType,
|
||||
WeiPrecType,
|
||||
|
||||
Reference in New Issue
Block a user