[CK_Tile] Merge multiple convolution groups into a single GEMM batch (#2986)

* Fix compilation of the grouped conv examples.

* Fix grouped conv bwd weight example output in CK Tile.

* Add number of groups to merge to ck tile grouped gemm example.

* Initial set of tests for TransformConvBwdWeightToGemm.

* Added unit tests for TransformConvBwdWeightToGemm conv groups are merged.

* WIP: Tensor transformations.

* Add unit tests for coordinate transforms.

* Fully working conv group merging for TransformConvBwdWeightToGemm.

* WIP: Merged conv groups offset calculation.

* Added unit tests for tensor view.

* WIP: Merged conv groups epilogue.

* Enable running multiple conv groups per batch.

* Add tests for tile_distribution_encoding.

* Change example to optimally match depthwise convolution with merged groups.

* Add more tests for tensor view.

* Integration test for reading diagonal blocks from grouped distributed tensor.

* Improved integration test.

* Improve test for accessing diagonal blocks.

* Added integration test for cshuffle epilogue LDS tile distribution.

* Add more logging.

* Increase the max number of reported errors.

* WIP: merged conv groups GEMM epilogue changes.

* LDS to global memory copy.

* Fix tile window size for c block.

* Integration test for CShuffle epilogue.

* Improved CShuffle test.

* WIP: Separate epilogue for merged conv groups.

* Tile example parameters changes to match depthwise conv.

* Offset fixes.

* Epilogue fixes.

* Working baseline for depthwise convolution with merged conv groups.

* Fix build.

* Initial unit tests for tensor descriptor.

* Add one more unit test for tensor view.

* WIP: LDS to global mem transfer using CK tile tensor descriptor and tile distribution encoding.

* Fully functional LDS to global mem transfer using tensor descriptor and tile distribution encoding.

* Add more comments, disable debug code.

* Remove debug and other dead code.

* Code clean-up for bwd tensor transformations.

* Enable running multiple GEMM batches of merged conv groups.

* Add compile check for assumed row-major layout.

* Fix strides in 1D conv to gemm transformation.

* WIP: Simplify conv to gemm transformations and handle K > 1 and C > 1 cases.

* Fix case k > 1 and c=1.

* Remove debug code.

* Make MPerGroup and NPerGroup template parameters.

* Add additional check for non-supported c > 1 case.

* WIP: Put back the generic tensor descriptors for convolutions.

* Fix tensor descriptors.

* Remove the obsolete template parameters.

* Add more instances.

* Fix bugs in merged conv groups tensor descriptors.

* Fix tensor descriptors for merged conv groups when K > 1.

* Remove debug output.

* Remove dead code.

* Fix merge conflicts.

* Code clean-up.

* Remove unused code.

* Run clang-formatting.

* Remove debug prints and obsolete tests.

* Check that number of convolution groups is multiple of merged groups.

* Fix build after removing obsolete functionality.

* Remove obsolete enumeration.

* Fix new unit projects.

* Remove unnecessary includes.

* Fix passing the number of merged groups.

* Remove unrelated tests.

* Fix IsSupportedArgument for bwd weight conv kernel.

* Fix clang formatting.

* Fix the bwd weight conv to gemm mapping for num merged groups > 1.

* GEMM config for conv group merging.

* Fix clang-formatting.

* Remove obsolete comment.

* Fix typos in comment strings.

* Increase the max number of reported errors when testing against reference implementation.

* Rename gemm_config to conv_config.

* Rename GemmConfig to ConvConfig and move NumGroupsToMerge into ConvConfig.

* Change num_groups_to_merge to a boolean flag in the ck tile grouped conv example.

* Run clang-format.

* Add number of merged groups into kernel name string.

* Remove group merging flag from CK Tile grouped conv example.
This commit is contained in:
Ville Pietilä
2025-10-29 16:49:28 +02:00
committed by GitHub
parent aa22da07be
commit 121bf0e1f3
17 changed files with 755 additions and 269 deletions

View File

@@ -17,7 +17,7 @@
#define CK_TILE_PIPELINE_COMPUTE_V4 3
#define CK_TILE_PIPELINE_COMPUTE_V5 4
struct GemmConfigBase
struct ConvConfigBase
{
static constexpr bool kPadM = true;
static constexpr bool kPadN = true;
@@ -29,6 +29,10 @@ struct GemmConfigBase
static constexpr bool TransposeC = false;
static constexpr bool UseStructuredSparsity = false;
static constexpr ck_tile::index_t VectorSizeA = 4;
static constexpr ck_tile::index_t VectorSizeB = 8;
static constexpr ck_tile::index_t VectorSizeC = 8;
static constexpr int kBlockPerCu = 1;
static constexpr ck_tile::index_t TileParitionerGroupNum = 8;
static constexpr ck_tile::index_t TileParitionerM01 = 4;
@@ -37,10 +41,12 @@ struct GemmConfigBase
static constexpr ck_tile::index_t NumWaveGroups = 1;
static constexpr bool Preshuffle = false;
static constexpr bool TiledMMAPermuteN = false;
static constexpr ck_tile::index_t NumGroupsToMerge = 1;
};
template <typename PrecType>
struct GemmConfigMemoryInterwave : public GemmConfigBase
struct ConvConfigMemoryInterwave : public ConvConfigBase
{
// Memory friendly for Interwave scheduler
static constexpr ck_tile::index_t M_Tile = 128;
@@ -61,7 +67,7 @@ struct GemmConfigMemoryInterwave : public GemmConfigBase
};
template <typename PrecType>
struct GemmConfigMemoryIntrawave : public GemmConfigBase
struct ConvConfigMemoryIntrawave : public ConvConfigBase
{
static constexpr ck_tile::index_t M_Tile = 128;
static constexpr ck_tile::index_t N_Tile = 32;
@@ -80,7 +86,7 @@ struct GemmConfigMemoryIntrawave : public GemmConfigBase
};
template <typename PrecType>
struct GemmConfigComputeV3 : public GemmConfigBase
struct ConvConfigComputeV3 : public ConvConfigBase
{
// Compute V3 only support Intrawave scheduler
static constexpr ck_tile::index_t M_Tile = 16;
@@ -100,7 +106,7 @@ struct GemmConfigComputeV3 : public GemmConfigBase
};
template <typename PrecType>
struct GemmConfigComputeV3_1 : public GemmConfigBase
struct ConvConfigComputeV3_1 : public ConvConfigBase
{
static constexpr ck_tile::index_t M_Tile = 256;
static constexpr ck_tile::index_t N_Tile = 256;
@@ -119,7 +125,7 @@ struct GemmConfigComputeV3_1 : public GemmConfigBase
};
template <typename PrecType>
struct GemmConfigComputeV3_2 : public GemmConfigBase
struct ConvConfigComputeV3_2 : public ConvConfigBase
{
static constexpr ck_tile::index_t M_Tile = 128;
static constexpr ck_tile::index_t N_Tile = 128;
@@ -140,7 +146,7 @@ struct GemmConfigComputeV3_2 : public GemmConfigBase
};
template <typename PrecType>
struct GemmConfigComputeV3_WMMA : public GemmConfigBase
struct ConvConfigComputeV3_WMMA : public ConvConfigBase
{
static constexpr ck_tile::index_t M_Tile = 128;
static constexpr ck_tile::index_t N_Tile = 128;
@@ -161,7 +167,7 @@ struct GemmConfigComputeV3_WMMA : public GemmConfigBase
};
template <typename PrecType>
struct GemmConfigComputeV4 : public GemmConfigBase
struct ConvConfigComputeV4 : public ConvConfigBase
{
// Compute V4 only support Intrawave scheduler
// Using the ping pong reader in the lds level
@@ -182,7 +188,7 @@ struct GemmConfigComputeV4 : public GemmConfigBase
};
template <typename PrecType>
struct GemmConfigComputeV4_1 : public GemmConfigBase
struct ConvConfigComputeV4_1 : public ConvConfigBase
{
static constexpr ck_tile::index_t M_Tile = 256;
static constexpr ck_tile::index_t N_Tile = 256;
@@ -201,7 +207,7 @@ struct GemmConfigComputeV4_1 : public GemmConfigBase
};
template <typename PrecType>
struct GemmConfigComputeV5 : public GemmConfigBase
struct ConvConfigComputeV5 : public ConvConfigBase
{
static constexpr ck_tile::index_t M_Tile = 128;
static constexpr ck_tile::index_t N_Tile = 128;
@@ -220,6 +226,31 @@ struct GemmConfigComputeV5 : public GemmConfigBase
static constexpr ck_tile::index_t NumWaNumWaveGroups = 2;
};
template <typename PrecType>
struct ConvConfigComputeV3_merged_groups : public ConvConfigBase
{
static constexpr ck_tile::index_t VectorSizeA = 4;
static constexpr ck_tile::index_t VectorSizeB = 8;
static constexpr ck_tile::index_t VectorSizeC = 8;
static constexpr ck_tile::index_t M_Tile = 16;
static constexpr ck_tile::index_t N_Tile = 32;
static constexpr ck_tile::index_t K_Tile = 32;
static constexpr ck_tile::index_t M_Warp = 1;
static constexpr ck_tile::index_t N_Warp = 2;
static constexpr ck_tile::index_t K_Warp = 1;
static constexpr ck_tile::index_t M_Warp_Tile = 16;
static constexpr ck_tile::index_t N_Warp_Tile = 16;
static constexpr ck_tile::index_t K_Warp_Tile = 32;
static constexpr bool DoubleSmemBuffer = false;
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3;
static constexpr ck_tile::index_t NumGroupsToMerge = 2;
};
template <typename InDataType, typename WeiDataType = InDataType, typename OutDataType = InDataType>
struct ConvTypeConfig;

View File

@@ -51,8 +51,8 @@ int run_grouped_conv_bwd_data_example(int argc, char* argv[])
int main(int argc, char* argv[])
{
#if CK_TILE_USE_WMMA
return !run_grouped_conv_bwd_data_example<GemmConfigComputeV3_WMMA>(argc, argv);
return !run_grouped_conv_bwd_data_example<ConvConfigComputeV3_WMMA>(argc, argv);
#else
return !run_grouped_conv_bwd_data_example<GemmConfigComputeV3>(argc, argv);
return !run_grouped_conv_bwd_data_example<ConvConfigComputeV3>(argc, argv);
#endif
}

View File

@@ -14,7 +14,7 @@
#include "grouped_convolution_backward_weight_invoker.hpp"
#include "run_grouped_convolution_bwd_weight_example.inc"
template <template <typename PrecType> typename GemmConfig>
template <template <typename PrecType> typename ConvConfig>
int run_grouped_conv_bwd_weight_example(ck_tile::ArgParser& arg_parser)
{
using Invoker = GroupedConvolutionBackwardWeightInvoker;
@@ -27,14 +27,14 @@ int run_grouped_conv_bwd_weight_example(ck_tile::ArgParser& arg_parser)
if(data_type == "fp16")
{
return run_grouped_conv_bwd_weight_example_prec_type<Invoker,
GemmConfig<ck_tile::half_t>,
ConvConfig<ck_tile::half_t>,
ck_tile::half_t>(
in_layout, wei_layout, out_layout, arg_parser);
}
else if(data_type == "bf16")
{
return run_grouped_conv_bwd_weight_example_prec_type<Invoker,
GemmConfig<ck_tile::bf16_t>,
ConvConfig<ck_tile::bf16_t>,
ck_tile::bf16_t>(
in_layout, wei_layout, out_layout, arg_parser);
}
@@ -54,9 +54,9 @@ int main(int argc, char* argv[])
try
{
#if CK_TILE_USE_WMMA
return !run_grouped_conv_bwd_weight_example<GemmConfigComputeV3_WMMA>(arg_parser);
return !run_grouped_conv_bwd_weight_example<ConvConfigComputeV3_WMMA>(arg_parser);
#else
return !run_grouped_conv_bwd_weight_example<GemmConfigComputeV3>(arg_parser);
return !run_grouped_conv_bwd_weight_example<ConvConfigComputeV3>(arg_parser);
#endif
}
catch(const std::runtime_error& e)

View File

@@ -7,7 +7,7 @@
struct GroupedConvolutionBackwardWeightInvoker
{
template <ck_tile::index_t NDimSpatial,
typename GemmConfig,
typename ConvConfig,
typename InDataType,
typename WeiDataType,
typename AccDataType,
@@ -25,22 +25,22 @@ struct GroupedConvolutionBackwardWeightInvoker
// Implicit GEMM Traits
using GemmShape = ck_tile::TileGemmShape<
ck_tile::sequence<GemmConfig::M_Tile, GemmConfig::N_Tile, GemmConfig::K_Tile>,
ck_tile::sequence<GemmConfig::M_Warp, GemmConfig::N_Warp, GemmConfig::K_Warp>,
ck_tile::sequence<ConvConfig::M_Tile, ConvConfig::N_Tile, ConvConfig::K_Tile>,
ck_tile::sequence<ConvConfig::M_Warp, ConvConfig::N_Warp, ConvConfig::K_Warp>,
ck_tile::
sequence<GemmConfig::M_Warp_Tile, GemmConfig::N_Warp_Tile, GemmConfig::K_Warp_Tile>,
GemmConfig::PermuteA,
GemmConfig::PermuteB>;
sequence<ConvConfig::M_Warp_Tile, ConvConfig::N_Warp_Tile, ConvConfig::K_Warp_Tile>,
ConvConfig::PermuteA,
ConvConfig::PermuteB>;
constexpr ck_tile::index_t VectorSizeA = 4;
constexpr ck_tile::index_t VectorSizeB = 8;
constexpr ck_tile::index_t VectorSizeC = 8;
constexpr ck_tile::index_t VectorSizeA = ConvConfig::VectorSizeA;
constexpr ck_tile::index_t VectorSizeB = ConvConfig::VectorSizeB;
constexpr ck_tile::index_t VectorSizeC = ConvConfig::VectorSizeC;
constexpr auto ConvSpec = ck_tile::ConvolutionSpecialization::Default;
using TilePartitioner =
ck_tile::GemmSpatiallyLocalTilePartitioner<GemmShape,
GemmConfig::TileParitionerGroupNum,
GemmConfig::TileParitionerM01>;
ConvConfig::TileParitionerGroupNum,
ConvConfig::TileParitionerM01>;
using GroupedConvTraitsType = ck_tile::GroupedConvTraits<NDimSpatial,
ConvSpec,
InLayout,
@@ -49,20 +49,21 @@ struct GroupedConvolutionBackwardWeightInvoker
OutLayout,
VectorSizeA,
VectorSizeB,
VectorSizeC>;
VectorSizeC,
ConvConfig::NumGroupsToMerge>;
using GemmUniversalTraits = ck_tile::TileGemmUniversalTraits<
GemmConfig::kPadM,
GemmConfig::kPadN,
GemmConfig::kPadK,
GemmConfig::DoubleSmemBuffer,
ConvConfig::kPadM,
ConvConfig::kPadN,
ConvConfig::kPadK,
ConvConfig::DoubleSmemBuffer,
typename GroupedConvTraitsType::GroupedConvImplicitGemmTraitsBwdWeight::AsLayout,
typename GroupedConvTraitsType::GroupedConvImplicitGemmTraitsBwdWeight::BsLayout,
typename GroupedConvTraitsType::GroupedConvImplicitGemmTraitsBwdWeight::CLayout,
GemmConfig::TransposeC,
GemmConfig::UseStructuredSparsity,
ConvConfig::TransposeC,
ConvConfig::UseStructuredSparsity,
false, // Persistent,
GemmConfig::NumWaveGroups>;
ConvConfig::NumWaveGroups>;
using GemmPipelineProblem = ck_tile::GemmPipelineProblem<
OutDataType,
@@ -78,7 +79,7 @@ struct GroupedConvolutionBackwardWeightInvoker
VectorSizeB>;
using BaseGemmPipeline = typename PipelineTypeTraits<
GemmConfig::Pipeline>::template UniversalGemmPipeline<GemmPipelineProblem>;
ConvConfig::Pipeline>::template UniversalGemmPipeline<GemmPipelineProblem>;
const ck_tile::index_t gemm_k =
args.N_ * std::accumulate(args.output_spatial_lengths_.begin(),
@@ -86,8 +87,8 @@ struct GroupedConvolutionBackwardWeightInvoker
1,
std::multiplies<ck_tile::index_t>());
const ck_tile::index_t k_grain = args.k_batch * GemmConfig::K_Tile;
const ck_tile::index_t K_split = (gemm_k + k_grain - 1) / k_grain * GemmConfig::K_Tile;
const ck_tile::index_t k_grain = args.k_batch * ConvConfig::K_Tile;
const ck_tile::index_t K_split = (gemm_k + k_grain - 1) / k_grain * ConvConfig::K_Tile;
const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split);
const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop);
const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop);
@@ -98,7 +99,7 @@ struct GroupedConvolutionBackwardWeightInvoker
const auto memory_operation_) {
constexpr bool has_hot_loop_v = has_hot_loop_.value;
constexpr auto tail_number_v = tail_number_.value;
constexpr auto scheduler = GemmConfig::Scheduler;
constexpr auto scheduler = ConvConfig::Scheduler;
constexpr auto memory_operation = memory_operation_.value;
using UniversalGemmProblem =
@@ -118,7 +119,7 @@ struct GroupedConvolutionBackwardWeightInvoker
VectorSizeB>;
using GemmPipeline = typename PipelineTypeTraits<
GemmConfig::Pipeline>::template GemmPipeline<UniversalGemmProblem>;
ConvConfig::Pipeline>::template GemmPipeline<UniversalGemmProblem>;
using ConvEpilogue = ck_tile::CShuffleEpilogue<ck_tile::CShuffleEpilogueProblem<
OutDataType,
@@ -131,12 +132,12 @@ struct GroupedConvolutionBackwardWeightInvoker
CDEElementWise,
TilePartitioner::MPerBlock,
TilePartitioner::NPerBlock,
GemmConfig::M_Warp,
GemmConfig::N_Warp,
GemmConfig::M_Warp_Tile,
GemmConfig::N_Warp_Tile,
GemmConfig::K_Warp_Tile,
GemmConfig::TransposeC,
ConvConfig::M_Warp,
ConvConfig::N_Warp,
ConvConfig::M_Warp_Tile,
ConvConfig::N_Warp_Tile,
ConvConfig::K_Warp_Tile,
ConvConfig::TransposeC,
memory_operation,
1,
true,

View File

@@ -13,9 +13,9 @@
#include "grouped_convolution_utils.hpp"
#include "grouped_convolution_backward_weight_two_stage_invoker.hpp"
#include "run_grouped_convolution_bwd_weight_example.inc"
#include "gemm_configs.hpp"
#include "conv_configs.hpp"
template <template <typename PrecType> typename GemmConfig>
template <template <typename PrecType> typename ConvConfig>
int run_grouped_conv_bwd_weight_example(ck_tile::ArgParser& arg_parser)
{
using Invoker = GroupedConvolutionBackwardWeightTwoStageInvoker;
@@ -28,14 +28,14 @@ int run_grouped_conv_bwd_weight_example(ck_tile::ArgParser& arg_parser)
if(data_type == "fp16")
{
return run_grouped_conv_bwd_weight_example_prec_type<Invoker,
GemmConfig<ck_tile::half_t>,
ConvConfig<ck_tile::half_t>,
ck_tile::half_t>(
in_layout, wei_layout, out_layout, arg_parser);
}
else if(data_type == "bf16")
{
return run_grouped_conv_bwd_weight_example_prec_type<Invoker,
GemmConfig<ck_tile::bf16_t>,
ConvConfig<ck_tile::bf16_t>,
ck_tile::bf16_t>(
in_layout, wei_layout, out_layout, arg_parser);
}
@@ -55,9 +55,9 @@ int main(int argc, char* argv[])
try
{
#if CK_TILE_USE_WMMA
return !run_grouped_conv_bwd_weight_example<GemmConfigComputeV3_WMMA>(arg_parser);
return !run_grouped_conv_bwd_weight_example<ConvConfigComputeV3_WMMA>(arg_parser);
#else
return !run_grouped_conv_bwd_weight_example<GemmConfigComputeV3>(arg_parser);
return !run_grouped_conv_bwd_weight_example<ConvConfigComputeV3>(arg_parser);
#endif
}
catch(const std::runtime_error& e)

View File

@@ -7,7 +7,7 @@
struct GroupedConvolutionBackwardWeightTwoStageInvoker
{
template <ck_tile::index_t NDimSpatial,
typename GemmConfig,
typename ConvConfig,
typename InDataType,
typename WeiDataType,
typename AccDataType,
@@ -27,12 +27,12 @@ struct GroupedConvolutionBackwardWeightTwoStageInvoker
// Implicit GEMM Traits
using GemmShape = ck_tile::TileGemmShape<
ck_tile::sequence<GemmConfig::M_Tile, GemmConfig::N_Tile, GemmConfig::K_Tile>,
ck_tile::sequence<GemmConfig::M_Warp, GemmConfig::N_Warp, GemmConfig::K_Warp>,
ck_tile::sequence<ConvConfig::M_Tile, ConvConfig::N_Tile, ConvConfig::K_Tile>,
ck_tile::sequence<ConvConfig::M_Warp, ConvConfig::N_Warp, ConvConfig::K_Warp>,
ck_tile::
sequence<GemmConfig::M_Warp_Tile, GemmConfig::N_Warp_Tile, GemmConfig::K_Warp_Tile>,
GemmConfig::PermuteA,
GemmConfig::PermuteB>;
sequence<ConvConfig::M_Warp_Tile, ConvConfig::N_Warp_Tile, ConvConfig::K_Warp_Tile>,
ConvConfig::PermuteA,
ConvConfig::PermuteB>;
constexpr ck_tile::index_t VectorSizeA = 4;
constexpr ck_tile::index_t VectorSizeB = 8;
@@ -41,8 +41,8 @@ struct GroupedConvolutionBackwardWeightTwoStageInvoker
constexpr auto ConvSpec = ck_tile::ConvolutionSpecialization::Default;
using TilePartitioner =
ck_tile::GemmSpatiallyLocalTilePartitioner<GemmShape,
GemmConfig::TileParitionerGroupNum,
GemmConfig::TileParitionerM01>;
ConvConfig::TileParitionerGroupNum,
ConvConfig::TileParitionerM01>;
using GroupedConvTraitsType = ck_tile::GroupedConvTraits<NDimSpatial,
ConvSpec,
InLayout,
@@ -54,17 +54,17 @@ struct GroupedConvolutionBackwardWeightTwoStageInvoker
VectorSizeC>;
using GemmUniversalTraits = ck_tile::TileGemmUniversalTraits<
GemmConfig::kPadM,
GemmConfig::kPadN,
GemmConfig::kPadK,
GemmConfig::DoubleSmemBuffer,
ConvConfig::kPadM,
ConvConfig::kPadN,
ConvConfig::kPadK,
ConvConfig::DoubleSmemBuffer,
typename GroupedConvTraitsType::GroupedConvImplicitGemmTraitsBwdWeight::AsLayout,
typename GroupedConvTraitsType::GroupedConvImplicitGemmTraitsBwdWeight::BsLayout,
typename GroupedConvTraitsType::GroupedConvImplicitGemmTraitsBwdWeight::CLayout,
GemmConfig::TransposeC,
GemmConfig::UseStructuredSparsity,
ConvConfig::TransposeC,
ConvConfig::UseStructuredSparsity,
false, // Persistent,
GemmConfig::NumWaveGroups>;
ConvConfig::NumWaveGroups>;
using GemmPipelineProblem = ck_tile::GemmPipelineProblem<
OutDataType,
@@ -80,7 +80,7 @@ struct GroupedConvolutionBackwardWeightTwoStageInvoker
VectorSizeB>;
using BaseGemmPipeline = typename PipelineTypeTraits<
GemmConfig::Pipeline>::template UniversalGemmPipeline<GemmPipelineProblem>;
ConvConfig::Pipeline>::template UniversalGemmPipeline<GemmPipelineProblem>;
const ck_tile::index_t gemm_k =
args.N_ * std::accumulate(args.output_spatial_lengths_.begin(),
@@ -88,8 +88,8 @@ struct GroupedConvolutionBackwardWeightTwoStageInvoker
1,
std::multiplies<ck_tile::index_t>());
const ck_tile::index_t k_grain = args.k_batch * GemmConfig::K_Tile;
const ck_tile::index_t K_split = (gemm_k + k_grain - 1) / k_grain * GemmConfig::K_Tile;
const ck_tile::index_t k_grain = args.k_batch * ConvConfig::K_Tile;
const ck_tile::index_t K_split = (gemm_k + k_grain - 1) / k_grain * ConvConfig::K_Tile;
const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split);
const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop);
const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop);
@@ -100,7 +100,7 @@ struct GroupedConvolutionBackwardWeightTwoStageInvoker
const auto memory_operation_) {
constexpr bool has_hot_loop_v = has_hot_loop_.value;
constexpr auto tail_number_v = tail_number_.value;
constexpr auto scheduler = GemmConfig::Scheduler;
constexpr auto scheduler = ConvConfig::Scheduler;
constexpr auto memory_operation = memory_operation_.value;
using UniversalGemmProblem =
@@ -120,7 +120,7 @@ struct GroupedConvolutionBackwardWeightTwoStageInvoker
VectorSizeB>;
using GemmPipeline = typename PipelineTypeTraits<
GemmConfig::Pipeline>::template GemmPipeline<UniversalGemmProblem>;
ConvConfig::Pipeline>::template GemmPipeline<UniversalGemmProblem>;
using ConvEpilogue = ck_tile::CShuffleEpilogue<ck_tile::CShuffleEpilogueProblem<
OutDataType, // A: Out
@@ -133,11 +133,11 @@ struct GroupedConvolutionBackwardWeightTwoStageInvoker
CDEElementWise,
TilePartitioner::MPerBlock,
TilePartitioner::NPerBlock,
GemmConfig::M_Warp,
GemmConfig::N_Warp,
GemmConfig::M_Warp_Tile,
GemmConfig::N_Warp_Tile,
GemmConfig::K_Warp_Tile,
ConvConfig::M_Warp,
ConvConfig::N_Warp,
ConvConfig::M_Warp_Tile,
ConvConfig::N_Warp_Tile,
ConvConfig::K_Warp_Tile,
GemmPipelineProblem::TransposeC,
memory_operation,
1,

View File

@@ -51,8 +51,8 @@ int run_grouped_conv_fwd_example(int argc, char* argv[])
int main(int argc, char* argv[])
{
#if CK_TILE_USE_WMMA
return !run_grouped_conv_fwd_example<GemmConfigComputeV3_WMMA>(argc, argv);
return !run_grouped_conv_fwd_example<ConvConfigComputeV3_WMMA>(argc, argv);
#else
return !run_grouped_conv_fwd_example<GemmConfigComputeV3>(argc, argv);
return !run_grouped_conv_fwd_example<ConvConfigComputeV3>(argc, argv);
#endif
}

View File

@@ -32,9 +32,10 @@ struct GroupedConvolutionForwardInvoker
GemmConfig::PermuteA,
GemmConfig::PermuteB>;
constexpr ck_tile::index_t VectorSizeA = 8;
constexpr ck_tile::index_t VectorSizeB = 8;
constexpr ck_tile::index_t VectorSizeC = 8;
constexpr ck_tile::index_t VectorSizeA = 8;
constexpr ck_tile::index_t VectorSizeB = 8;
constexpr ck_tile::index_t VectorSizeC = 8;
constexpr ck_tile::index_t NumGroupsToMerge = 1;
constexpr auto ConvSpec = ck_tile::ConvolutionSpecialization::Default;
using TilePartitioner =
@@ -50,6 +51,7 @@ struct GroupedConvolutionForwardInvoker
VectorSizeA,
VectorSizeB,
VectorSizeC,
NumGroupsToMerge,
CDElementWise>;
using GemmUniversalTraits = ck_tile::TileGemmUniversalTraits<

View File

@@ -11,24 +11,11 @@
#include "ck_tile/ops/epilogue.hpp"
#include "ck_tile/ops/gemm.hpp"
#include "ck_tile/ops/grouped_convolution.hpp"
#include "gemm_configs.hpp"
#include "conv_configs.hpp"
using MemoryOpSet =
std::integral_constant<ck_tile::memory_operation_enum, ck_tile::memory_operation_enum::set>;
using MemoryOpAtomicAdd = std::integral_constant<ck_tile::memory_operation_enum,
ck_tile::memory_operation_enum::atomic_add>;
struct GemmWarpConfig_Mfma
{
static constexpr ck_tile::index_t M_Warp_Tile = 32;
static constexpr ck_tile::index_t N_Warp_Tile = 32;
static constexpr ck_tile::index_t K_Warp_Tile = 16;
};
struct GemmWarpConfig_Wmma
{
static constexpr ck_tile::index_t M_Warp_Tile = 16;
static constexpr ck_tile::index_t N_Warp_Tile = 16;
static constexpr ck_tile::index_t K_Warp_Tile = 16;
};
template <typename InDataType, typename WeiDataType, typename AccDataType, typename OutDataType>
auto calculate_rtol_atol(const ck_tile::index_t GemmK,

View File

@@ -3,7 +3,7 @@
#pragma once
template <ck_tile::index_t NDimSpatial,
typename GemmConfig,
typename ConvConfig,
typename Invoker,
typename InDataType,
typename WeiDataType,
@@ -17,7 +17,7 @@ float invoke_grouped_conv_bwd_weight(ck_tile::GroupedConvBwdWeightHostArgs& args
int n_repeat)
{
float ave_time = Invoker::template grouped_conv_bwd_weight<NDimSpatial,
GemmConfig,
ConvConfig,
InDataType,
WeiDataType,
AccDataType,
@@ -31,7 +31,7 @@ float invoke_grouped_conv_bwd_weight(ck_tile::GroupedConvBwdWeightHostArgs& args
}
template <ck_tile::index_t NDimSpatial,
typename GemmConfig,
typename ConvConfig,
typename Invoker,
typename InDataType,
typename WeiDataType = InDataType,
@@ -131,7 +131,7 @@ int run_grouped_conv_bwd_weight_example_with_layouts(ck_tile::ArgParser& arg_par
std::cout << "output: " << output.mDesc << std::endl;
float ave_time = invoke_grouped_conv_bwd_weight<NDimSpatial,
GemmConfig,
ConvConfig,
Invoker,
InDataType,
WeiDataType,
@@ -193,7 +193,7 @@ int run_grouped_conv_bwd_weight_example_with_layouts(ck_tile::ArgParser& arg_par
}
template <typename Invoker,
typename GemmConfig,
typename ConvConfig,
typename InPrecType,
typename WeiPrecType = InPrecType,
typename OutPrecType = InPrecType>
@@ -217,7 +217,7 @@ int run_grouped_conv_bwd_weight_example_prec_type(std::string in_layout,
if(in_layout == "NWGC" && wei_layout == "GKXC" && out_layout == "NWGK")
{
return run_grouped_conv_bwd_weight_example_with_layouts<ck_tile::number<1>{},
GemmConfig,
ConvConfig,
Invoker,
InPrecType,
WeiPrecType,
@@ -227,7 +227,7 @@ int run_grouped_conv_bwd_weight_example_prec_type(std::string in_layout,
else if(in_layout == "NHWGC" && wei_layout == "GKYXC" && out_layout == "NHWGK")
{
return run_grouped_conv_bwd_weight_example_with_layouts<ck_tile::number<2>{},
GemmConfig,
ConvConfig,
Invoker,
InPrecType,
WeiPrecType,
@@ -237,7 +237,7 @@ int run_grouped_conv_bwd_weight_example_prec_type(std::string in_layout,
else if(in_layout == "NDHWGC" && wei_layout == "GKZYXC" && out_layout == "NDHWGK")
{
return run_grouped_conv_bwd_weight_example_with_layouts<ck_tile::number<3>{},
GemmConfig,
ConvConfig,
Invoker,
InPrecType,
WeiPrecType,