[CK TILE] Refactor Conv configs and Conv Elementwise (#3151)

* [CK TILE] Refactor Conv configs and Conv Elementwise

* fix
Author: Bartłomiej Kocot
Date: 2025-11-04 15:04:53 +01:00
Committed by: GitHub
Parent: 99f38e4d9b
Commit: 8681ced962
14 changed files with 230 additions and 219 deletions

View File

@@ -18,11 +18,7 @@ struct ConvConfigBase
static constexpr bool kPadN = true;
static constexpr bool kPadK = true;
-static constexpr bool PermuteA = false;
-static constexpr bool PermuteB = false;
-static constexpr bool TransposeC = false;
-static constexpr bool UseStructuredSparsity = false;
+static constexpr bool TransposeC = false;
static constexpr ck_tile::index_t VectorSizeA = 4;
static constexpr ck_tile::index_t VectorSizeB = 8;
@@ -34,8 +30,6 @@ struct ConvConfigBase
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave;
static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V3;
static constexpr ck_tile::index_t NumWaveGroups = 1;
-static constexpr bool Preshuffle = false;
-static constexpr bool TiledMMAPermuteN = false;
static constexpr ck_tile::index_t NumGroupsToMerge = 1;
};
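Read together, the two hunks above leave ConvConfigBase with only the knobs the conv invokers still consult. A minimal sketch of the struct after this commit; fields not visible in the hunks (kPadM, the tile/warp sizes, DoubleSmemBuffer) are assumptions:

struct ConvConfigBase
{
    static constexpr bool kPadM = false; // assumed, not shown in the hunks
    static constexpr bool kPadN = true;
    static constexpr bool kPadK = true;
    static constexpr bool TransposeC = false;

    static constexpr ck_tile::index_t VectorSizeA = 4;
    static constexpr ck_tile::index_t VectorSizeB = 8;

    static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave;
    static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V3;
    static constexpr ck_tile::index_t NumWaveGroups    = 1;
    static constexpr ck_tile::index_t NumGroupsToMerge = 1;

    // Removed by this commit: PermuteA, PermuteB, UseStructuredSparsity,
    // Preshuffle, TiledMMAPermuteN. Call sites now pass literals instead.
};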

View File

@@ -14,7 +14,7 @@
#include "grouped_convolution_backward_data_invoker.hpp"
#include "run_grouped_convolution_bwd_data_example.inc"
-template <template <typename PrecType> typename GemmConfig>
+template <template <typename PrecType> typename ConvConfig>
int run_grouped_conv_bwd_data_example(int argc, char* argv[])
{
using Invoker = GroupedConvolutionBackwardDataInvoker;
@@ -31,14 +31,14 @@ int run_grouped_conv_bwd_data_example(int argc, char* argv[])
if(data_type == "fp16")
{
return run_grouped_conv_bwd_data_example_prec_type<Invoker,
-GemmConfig<ck_tile::half_t>,
+ConvConfig<ck_tile::half_t>,
ck_tile::half_t>(
in_layout, wei_layout, out_layout, argc, argv);
}
else if(data_type == "bf16")
{
return run_grouped_conv_bwd_data_example_prec_type<Invoker,
-GemmConfig<ck_tile::bf16_t>,
+ConvConfig<ck_tile::bf16_t>,
ck_tile::bf16_t>(
in_layout, wei_layout, out_layout, argc, argv);
}
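The rename does not change the dispatch mechanism: the runner takes the config as a template-template parameter and instantiates it per precision selected at runtime. A self-contained sketch of that idiom, with hypothetical stand-ins for the ck_tile types:

#include <iostream>
#include <string>
#include <type_traits>

struct half_t {}; // hypothetical stand-ins for ck_tile::half_t / bf16_t
struct bf16_t {};

template <typename PrecType>
struct MyConvConfig // hypothetical config; one instantiation per precision
{
    static constexpr int M_Tile = std::is_same_v<PrecType, half_t> ? 256 : 128;
};

// One runtime string picks the compile-time instantiation of the config.
template <template <typename PrecType> typename ConvConfig>
int run_example(const std::string& data_type)
{
    if(data_type == "fp16")
        return ConvConfig<half_t>::M_Tile; // would launch the fp16 kernel
    else if(data_type == "bf16")
        return ConvConfig<bf16_t>::M_Tile; // would launch the bf16 kernel
    return -1;
}

int main() { std::cout << run_example<MyConvConfig>("fp16") << '\n'; } // prints 256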

View File

@@ -8,7 +8,7 @@ struct GroupedConvolutionBackwardDataInvoker
{
template <ck_tile::index_t NDimSpatial,
-typename GemmConfig,
+typename ConvConfig,
typename InDataType,
typename WeiDataType,
typename AccDataType,
@@ -26,12 +26,11 @@ struct GroupedConvolutionBackwardDataInvoker
// Implicit GEMM Traits
using GemmShape = ck_tile::TileGemmShape<
-ck_tile::sequence<GemmConfig::M_Tile, GemmConfig::N_Tile, GemmConfig::K_Tile>,
-ck_tile::sequence<GemmConfig::M_Warp, GemmConfig::N_Warp, GemmConfig::K_Warp>,
-ck_tile::
-sequence<GemmConfig::M_Warp_Tile, GemmConfig::N_Warp_Tile, GemmConfig::K_Warp_Tile>,
-GemmConfig::PermuteA,
-GemmConfig::PermuteB>;
+ck_tile::sequence<ConvConfig::M_Tile, ConvConfig::N_Tile, ConvConfig::K_Tile>,
+ck_tile::sequence<ConvConfig::M_Warp, ConvConfig::N_Warp, ConvConfig::K_Warp>,
+ck_tile::sequence<ConvConfig::M_Warp_Tile,
+ConvConfig::N_Warp_Tile,
+ConvConfig::K_Warp_Tile>>;
constexpr ck_tile::index_t VectorSizeA = 8;
constexpr ck_tile::index_t VectorSizeB = 8;
@@ -40,8 +39,8 @@ struct GroupedConvolutionBackwardDataInvoker
constexpr auto ConvSpec = ck_tile::ConvolutionSpecialization::Default;
using TilePartitioner =
ck_tile::GemmSpatiallyLocalTilePartitioner<GemmShape,
-GemmConfig::TileParitionerGroupNum,
-GemmConfig::TileParitionerM01>;
+ConvConfig::TileParitionerGroupNum,
+ConvConfig::TileParitionerM01>;
using GroupedConvTraitsType = ck_tile::GroupedConvTraits<NDimSpatial,
ConvSpec,
InLayout,
@@ -53,17 +52,17 @@ struct GroupedConvolutionBackwardDataInvoker
VectorSizeC>;
using GemmUniversalTraits = ck_tile::TileGemmUniversalTraits<
-GemmConfig::kPadM,
-GemmConfig::kPadN,
-GemmConfig::kPadK,
-GemmConfig::DoubleSmemBuffer,
+ConvConfig::kPadM,
+ConvConfig::kPadN,
+ConvConfig::kPadK,
+ConvConfig::DoubleSmemBuffer,
typename GroupedConvTraitsType::GroupedConvImplicitGemmTraitsBwdData::AsLayout,
typename GroupedConvTraitsType::GroupedConvImplicitGemmTraitsBwdData::BsLayout,
typename GroupedConvTraitsType::GroupedConvImplicitGemmTraitsBwdData::CLayout,
-GemmConfig::TransposeC,
-GemmConfig::UseStructuredSparsity,
+ConvConfig::TransposeC,
+false,
false, // Persistent,
-GemmConfig::NumWaveGroups>;
+ConvConfig::NumWaveGroups>;
using GemmPipelineProblem = ck_tile::GemmPipelineProblem<
OutDataType,
@@ -79,7 +78,7 @@ struct GroupedConvolutionBackwardDataInvoker
VectorSizeB>;
using BaseGemmPipeline = typename PipelineTypeTraits<
-GemmConfig::Pipeline>::template UniversalGemmPipeline<GemmPipelineProblem>;
+ConvConfig::Pipeline>::template UniversalGemmPipeline<GemmPipelineProblem>;
const ck_tile::index_t gemm_k =
args.K_ * std::accumulate(args.filter_spatial_lengths_.begin(),
@@ -87,8 +86,8 @@ struct GroupedConvolutionBackwardDataInvoker
1,
std::multiplies<ck_tile::index_t>());
-const ck_tile::index_t k_grain = args.k_batch * GemmConfig::K_Tile;
-const ck_tile::index_t K_split = (gemm_k + k_grain - 1) / k_grain * GemmConfig::K_Tile;
+const ck_tile::index_t k_grain = args.k_batch * ConvConfig::K_Tile;
+const ck_tile::index_t K_split = (gemm_k + k_grain - 1) / k_grain * ConvConfig::K_Tile;
const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split);
const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop);
const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop);
@@ -98,7 +97,7 @@ struct GroupedConvolutionBackwardDataInvoker
[&](const auto has_hot_loop_, const auto tail_number_, const auto memory_operation_) {
constexpr bool has_hot_loop_v = has_hot_loop_.value;
constexpr auto tail_number_v = tail_number_.value;
-constexpr auto scheduler = GemmConfig::Scheduler;
+constexpr auto scheduler = ConvConfig::Scheduler;
constexpr auto memory_operation = memory_operation_.value;
using UniversalGemmProblem =
@@ -118,7 +117,7 @@ struct GroupedConvolutionBackwardDataInvoker
VectorSizeB>;
using GemmPipeline = typename PipelineTypeTraits<
-GemmConfig::Pipeline>::template GemmPipeline<UniversalGemmProblem>;
+ConvConfig::Pipeline>::template GemmPipeline<UniversalGemmProblem>;
using ConvEpilogue = ck_tile::CShuffleEpilogue<ck_tile::CShuffleEpilogueProblem<
OutDataType,
@@ -131,12 +130,12 @@ struct GroupedConvolutionBackwardDataInvoker
CDEElementWise,
TilePartitioner::MPerBlock,
TilePartitioner::NPerBlock,
-GemmConfig::M_Warp,
-GemmConfig::N_Warp,
-GemmConfig::M_Warp_Tile,
-GemmConfig::N_Warp_Tile,
-GemmConfig::K_Warp_Tile,
-GemmConfig::TransposeC,
+ConvConfig::M_Warp,
+ConvConfig::N_Warp,
+ConvConfig::M_Warp_Tile,
+ConvConfig::N_Warp_Tile,
+ConvConfig::K_Warp_Tile,
+ConvConfig::TransposeC,
memory_operation,
1,
true,
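The split-K bookkeeping in this invoker is plain ceiling arithmetic: k_grain is the K coverage of one pass over all k_batch splits, and K_split rounds the per-split K up to whole K_Tile iterations. A worked example with assumed numbers (not the example's defaults):

#include <cassert>

int main()
{
    const int K_Tile  = 64;   // ConvConfig::K_Tile (assumed)
    const int k_batch = 2;    // args.k_batch (assumed)
    const int gemm_k  = 1000; // K_ * product of filter spatial lengths (assumed)

    const int k_grain = k_batch * K_Tile;                          // 128
    const int K_split = (gemm_k + k_grain - 1) / k_grain * K_Tile; // ceil(1000/128)*64 = 512

    assert(k_grain == 128);
    assert(K_split == 512); // each split then runs 512 / 64 = 8 main-loop iterations
}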

View File

@@ -27,10 +27,9 @@ struct GroupedConvolutionBackwardWeightInvoker
using GemmShape = ck_tile::TileGemmShape<
ck_tile::sequence<ConvConfig::M_Tile, ConvConfig::N_Tile, ConvConfig::K_Tile>,
ck_tile::sequence<ConvConfig::M_Warp, ConvConfig::N_Warp, ConvConfig::K_Warp>,
-ck_tile::
-sequence<ConvConfig::M_Warp_Tile, ConvConfig::N_Warp_Tile, ConvConfig::K_Warp_Tile>,
-ConvConfig::PermuteA,
-ConvConfig::PermuteB>;
+ck_tile::sequence<ConvConfig::M_Warp_Tile,
+ConvConfig::N_Warp_Tile,
+ConvConfig::K_Warp_Tile>>;
constexpr ck_tile::index_t VectorSizeA = ConvConfig::VectorSizeA;
constexpr ck_tile::index_t VectorSizeB = ConvConfig::VectorSizeB;
@@ -61,7 +60,7 @@ struct GroupedConvolutionBackwardWeightInvoker
typename GroupedConvTraitsType::GroupedConvImplicitGemmTraitsBwdWeight::BsLayout,
typename GroupedConvTraitsType::GroupedConvImplicitGemmTraitsBwdWeight::CLayout,
ConvConfig::TransposeC,
-ConvConfig::UseStructuredSparsity,
+false,
false, // Persistent,
ConvConfig::NumWaveGroups>;
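Note that the GemmShape rewrite in this file and the invokers above does not rename PermuteA/PermuteB, it drops them, which suggests ck_tile::TileGemmShape defaults those trailing template parameters to false. A toy illustration of why that compiles (TileShape and seq are hypothetical, not the ck_tile definitions):

#include <type_traits>

template <int... Is>
struct seq {};

// Hypothetical shape with defaulted trailing permute flags.
template <typename BlockTile, typename WarpTile, typename WarpSubTile,
          bool PermuteA = false, bool PermuteB = false>
struct TileShape {};

// Omitting the flags names exactly the same type, so non-permuting call
// sites can simply drop the last two arguments, as the invokers now do.
static_assert(std::is_same_v<
    TileShape<seq<256, 128, 32>, seq<2, 2, 1>, seq<32, 32, 16>>,
    TileShape<seq<256, 128, 32>, seq<2, 2, 1>, seq<32, 32, 16>, false, false>>);

int main() {}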

View File

@@ -29,10 +29,9 @@ struct GroupedConvolutionBackwardWeightTwoStageInvoker
using GemmShape = ck_tile::TileGemmShape<
ck_tile::sequence<ConvConfig::M_Tile, ConvConfig::N_Tile, ConvConfig::K_Tile>,
ck_tile::sequence<ConvConfig::M_Warp, ConvConfig::N_Warp, ConvConfig::K_Warp>,
-ck_tile::
-sequence<ConvConfig::M_Warp_Tile, ConvConfig::N_Warp_Tile, ConvConfig::K_Warp_Tile>,
-ConvConfig::PermuteA,
-ConvConfig::PermuteB>;
+ck_tile::sequence<ConvConfig::M_Warp_Tile,
+ConvConfig::N_Warp_Tile,
+ConvConfig::K_Warp_Tile>>;
constexpr ck_tile::index_t VectorSizeA = 4;
constexpr ck_tile::index_t VectorSizeB = 8;
@@ -62,7 +61,7 @@ struct GroupedConvolutionBackwardWeightTwoStageInvoker
typename GroupedConvTraitsType::GroupedConvImplicitGemmTraitsBwdWeight::BsLayout,
typename GroupedConvTraitsType::GroupedConvImplicitGemmTraitsBwdWeight::CLayout,
ConvConfig::TransposeC,
-ConvConfig::UseStructuredSparsity,
+false,
false, // Persistent,
ConvConfig::NumWaveGroups>;

View File

@@ -14,7 +14,7 @@
#include "grouped_convolution_forward_invoker.hpp"
#include "run_grouped_convolution_fwd_example.inc"
-template <template <typename PrecType> typename GemmConfig>
+template <template <typename PrecType> typename ConvConfig>
int run_grouped_conv_fwd_example(int argc, char* argv[])
{
using Invoker = GroupedConvolutionForwardInvoker;
@@ -31,14 +31,14 @@ int run_grouped_conv_fwd_example(int argc, char* argv[])
if(data_type == "fp16")
{
return run_grouped_conv_fwd_example_prec_type<Invoker,
-GemmConfig<ck_tile::half_t>,
+ConvConfig<ck_tile::half_t>,
ck_tile::half_t>(
in_layout, wei_layout, out_layout, argc, argv);
}
else if(data_type == "bf16")
{
return run_grouped_conv_fwd_example_prec_type<Invoker,
-GemmConfig<ck_tile::bf16_t>,
+ConvConfig<ck_tile::bf16_t>,
ck_tile::bf16_t>(
in_layout, wei_layout, out_layout, argc, argv);
}

View File

@@ -14,7 +14,7 @@
#include "grouped_convolution_forward_invoker.hpp"
#include "run_grouped_convolution_fwd_bias_clamp_example.inc"
-template <template <typename PrecType> typename GemmConfig>
+template <template <typename PrecType> typename ConvConfig>
int run_grouped_conv_fwd_bias_clamp_example(int argc, char* argv[])
{
using Invoker = GroupedConvolutionForwardInvoker;
@@ -31,14 +31,14 @@ int run_grouped_conv_fwd_bias_clamp_example(int argc, char* argv[])
if(data_type == "fp16")
{
return run_grouped_conv_fwd_bias_clamp_example_prec_type<Invoker,
-GemmConfig<ck_tile::half_t>,
+ConvConfig<ck_tile::half_t>,
ck_tile::half_t>(
in_layout, wei_layout, out_layout, argc, argv);
}
else if(data_type == "bf16")
{
return run_grouped_conv_fwd_bias_clamp_example_prec_type<Invoker,
-GemmConfig<ck_tile::bf16_t>,
+ConvConfig<ck_tile::bf16_t>,
ck_tile::bf16_t>(
in_layout, wei_layout, out_layout, argc, argv);
}

View File

@@ -14,7 +14,7 @@
struct GroupedConvolutionForwardInvoker
{
template <ck_tile::index_t NDimSpatial,
-typename GemmConfig,
+typename ConvConfig,
typename InDataType,
typename WeiDataType,
typename AccDataType,
@@ -36,12 +36,11 @@ struct GroupedConvolutionForwardInvoker
// Implicit GEMM Traits
using GemmShape = ck_tile::TileGemmShape<
-ck_tile::sequence<GemmConfig::M_Tile, GemmConfig::N_Tile, GemmConfig::K_Tile>,
-ck_tile::sequence<GemmConfig::M_Warp, GemmConfig::N_Warp, GemmConfig::K_Warp>,
-ck_tile::
-sequence<GemmConfig::M_Warp_Tile, GemmConfig::N_Warp_Tile, GemmConfig::K_Warp_Tile>,
-GemmConfig::PermuteA,
-GemmConfig::PermuteB>;
+ck_tile::sequence<ConvConfig::M_Tile, ConvConfig::N_Tile, ConvConfig::K_Tile>,
+ck_tile::sequence<ConvConfig::M_Warp, ConvConfig::N_Warp, ConvConfig::K_Warp>,
+ck_tile::sequence<ConvConfig::M_Warp_Tile,
+ConvConfig::N_Warp_Tile,
+ConvConfig::K_Warp_Tile>>;
constexpr ck_tile::index_t VectorSizeA = 8;
constexpr ck_tile::index_t VectorSizeB = 8;
@@ -51,8 +50,8 @@ struct GroupedConvolutionForwardInvoker
constexpr auto ConvSpec = ck_tile::ConvolutionSpecialization::Default;
using TilePartitioner =
ck_tile::GemmSpatiallyLocalTilePartitioner<GemmShape,
-GemmConfig::TileParitionerGroupNum,
-GemmConfig::TileParitionerM01>;
+ConvConfig::TileParitionerGroupNum,
+ConvConfig::TileParitionerM01>;
using GroupedConvTraitsType = ck_tile::GroupedConvTraits<NDimSpatial,
ConvSpec,
InLayout,
@@ -62,22 +61,20 @@ struct GroupedConvolutionForwardInvoker
VectorSizeA,
VectorSizeB,
VectorSizeC,
-NumGroupsToMerge,
-CDElementWise>;
+NumGroupsToMerge>;
using GemmUniversalTraits = ck_tile::TileGemmUniversalTraits<
-GemmConfig::kPadM,
-GemmConfig::kPadN,
-GemmConfig::kPadK,
-GemmConfig::DoubleSmemBuffer,
+ConvConfig::kPadM,
+ConvConfig::kPadN,
+ConvConfig::kPadK,
+ConvConfig::DoubleSmemBuffer,
typename GroupedConvTraitsType::GroupedConvImplicitGemmTraitsFwd::AsLayout,
typename GroupedConvTraitsType::GroupedConvImplicitGemmTraitsFwd::BsLayout,
typename GroupedConvTraitsType::GroupedConvImplicitGemmTraitsFwd::CLayout,
-GemmConfig::TransposeC,
-GemmConfig::UseStructuredSparsity,
+ConvConfig::TransposeC,
+false,
false, // Persistent,
-GemmConfig::NumWaveGroups,
-GemmConfig::Preshuffle>;
+ConvConfig::NumWaveGroups>;
using GemmPipelineProblem = ck_tile::GemmPipelineProblem<
InDataType,
@@ -93,7 +90,7 @@ struct GroupedConvolutionForwardInvoker
VectorSizeB>;
using BaseGemmPipeline = typename PipelineTypeTraits<
-GemmConfig::Pipeline>::template UniversalGemmPipeline<GemmPipelineProblem>;
+ConvConfig::Pipeline>::template UniversalGemmPipeline<GemmPipelineProblem>;
const ck_tile::index_t gemm_k =
args.C_ * std::accumulate(args.filter_spatial_lengths_.begin(),
@@ -102,8 +99,8 @@ struct GroupedConvolutionForwardInvoker
std::multiplies<ck_tile::index_t>());
// Split-K parameters
-const ck_tile::index_t k_grain = args.k_batch * GemmConfig::K_Tile;
-const ck_tile::index_t K_split = (gemm_k + k_grain - 1) / k_grain * GemmConfig::K_Tile;
+const ck_tile::index_t k_grain = args.k_batch * ConvConfig::K_Tile;
+const ck_tile::index_t K_split = (gemm_k + k_grain - 1) / k_grain * ConvConfig::K_Tile;
const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split);
const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop);
const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop);
@@ -116,7 +113,7 @@ struct GroupedConvolutionForwardInvoker
[&](const auto has_hot_loop_, const auto tail_number_, const auto memory_operation_) {
constexpr bool has_hot_loop_v = has_hot_loop_.value;
constexpr auto tail_number_v = tail_number_.value;
-constexpr auto scheduler = GemmConfig::Scheduler;
+constexpr auto scheduler = ConvConfig::Scheduler;
constexpr auto memory_operation = memory_operation_.value;
using UniversalGemmProblem =
@@ -136,7 +133,7 @@ struct GroupedConvolutionForwardInvoker
VectorSizeB>;
using GemmPipeline = typename PipelineTypeTraits<
-GemmConfig::Pipeline>::template GemmPipeline<UniversalGemmProblem>;
+ConvConfig::Pipeline>::template GemmPipeline<UniversalGemmProblem>;
using ConvEpilogue = ck_tile::CShuffleEpilogue<ck_tile::CShuffleEpilogueProblem<
InDataType,
@@ -149,12 +146,12 @@ struct GroupedConvolutionForwardInvoker
CDElementWise,
TilePartitioner::MPerBlock,
TilePartitioner::NPerBlock,
-GemmConfig::M_Warp,
-GemmConfig::N_Warp,
-GemmConfig::M_Warp_Tile,
-GemmConfig::N_Warp_Tile,
-GemmConfig::K_Warp_Tile,
-GemmConfig::TransposeC,
+ConvConfig::M_Warp,
+ConvConfig::N_Warp,
+ConvConfig::M_Warp_Tile,
+ConvConfig::N_Warp_Tile,
+ConvConfig::K_Warp_Tile,
+ConvConfig::TransposeC,
memory_operation,
1,
true,
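This file also shows the "Conv Elementwise" half of the refactor: CDElementWise no longer rides through GroupedConvTraits (see the NumGroupsToMerge hunk above) and the functor is consumed only by the CShuffleEpilogue, as CDEElementWise. As a rough sketch, such an epilogue functor is typically a callable fusing the GEMM output with the extra D tensors; the signature below is a guess, not the ck_tile definition:

// Hypothetical CDE elementwise functor: fuse accumulator C with a bias D,
// clamp, and write E. The real ck_tile call shape may differ.
struct AddBiasClamp
{
    template <typename E, typename C, typename D>
    void operator()(E& e, const C& c, const D& d_bias) const
    {
        const float v = static_cast<float>(c) + static_cast<float>(d_bias);
        e = static_cast<E>(v < 0.f ? 0.f : (v > 6.f ? 6.f : v)); // clamp to [0, 6]
    }
};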

View File

@@ -19,7 +19,7 @@
#include "grouped_convolution_forward_large_tensor_invoker.hpp"
#include "run_grouped_convolution_fwd_example.inc"
-template <template <typename PrecType> typename GemmConfig>
+template <template <typename PrecType> typename ConvConfig>
int run_grouped_conv_fwd_example(int argc, char* argv[])
{
using Invoker = GroupedConvolutionForwardInvoker;
@@ -36,14 +36,14 @@ int run_grouped_conv_fwd_example(int argc, char* argv[])
if(data_type == "fp16")
{
return run_grouped_conv_fwd_example_prec_type<Invoker,
-GemmConfig<ck_tile::half_t>,
+ConvConfig<ck_tile::half_t>,
ck_tile::half_t>(
in_layout, wei_layout, out_layout, argc, argv);
}
else if(data_type == "bf16")
{
return run_grouped_conv_fwd_example_prec_type<Invoker,
-GemmConfig<ck_tile::bf16_t>,
+ConvConfig<ck_tile::bf16_t>,
ck_tile::bf16_t>(
in_layout, wei_layout, out_layout, argc, argv);
}

View File

@@ -7,7 +7,7 @@
struct GroupedConvolutionForwardInvoker
{
template <ck_tile::index_t NDimSpatial,
-typename GemmConfig,
+typename ConvConfig,
typename InDataType,
typename WeiDataType,
typename AccDataType,
@@ -29,12 +29,11 @@ struct GroupedConvolutionForwardInvoker
// Implicit GEMM Traits
using GemmShape = ck_tile::TileGemmShape<
-ck_tile::sequence<GemmConfig::M_Tile, GemmConfig::N_Tile, GemmConfig::K_Tile>,
-ck_tile::sequence<GemmConfig::M_Warp, GemmConfig::N_Warp, GemmConfig::K_Warp>,
-ck_tile::
-sequence<GemmConfig::M_Warp_Tile, GemmConfig::N_Warp_Tile, GemmConfig::K_Warp_Tile>,
-GemmConfig::PermuteA,
-GemmConfig::PermuteB>;
+ck_tile::sequence<ConvConfig::M_Tile, ConvConfig::N_Tile, ConvConfig::K_Tile>,
+ck_tile::sequence<ConvConfig::M_Warp, ConvConfig::N_Warp, ConvConfig::K_Warp>,
+ck_tile::sequence<ConvConfig::M_Warp_Tile,
+ConvConfig::N_Warp_Tile,
+ConvConfig::K_Warp_Tile>>;
constexpr ck_tile::index_t VectorSizeA = 8;
constexpr ck_tile::index_t VectorSizeB = 8;
@@ -43,41 +42,53 @@ struct GroupedConvolutionForwardInvoker
constexpr auto ConvSpec = ck_tile::ConvolutionSpecialization::Default;
using TilePartitioner =
ck_tile::GemmSpatiallyLocalTilePartitioner<GemmShape,
-GemmConfig::TileParitionerGroupNum,
-GemmConfig::TileParitionerM01>;
-using GroupedConvTraitsType = ck_tile::GroupedConvTraits<NDimSpatial,
-ConvSpec,
-InLayout,
-WeiLayout,
-DsLayout,
-OutLayout,
-VectorSizeA,
-VectorSizeB,
-VectorSizeC,
-1, /*NumGroupsToMerge*/
-ck_tile::element_wise::PassThrough,
-true /*EnableSplitImage*/>;
+ConvConfig::TileParitionerGroupNum,
+ConvConfig::TileParitionerM01>;
+using GroupedConvTraitsTypeDefault = ck_tile::GroupedConvTraits<NDimSpatial,
+ConvSpec,
+InLayout,
+WeiLayout,
+DsLayout,
+OutLayout,
+VectorSizeA,
+VectorSizeB,
+VectorSizeC,
+1, /*NumGroupsToMerge*/
+false /*EnableSplitImage*/>;
+using GroupedConvTraitsTypeLargeTensor =
+ck_tile::GroupedConvTraits<NDimSpatial,
+ConvSpec,
+InLayout,
+WeiLayout,
+DsLayout,
+OutLayout,
+VectorSizeA,
+VectorSizeB,
+VectorSizeC,
+1, /*NumGroupsToMerge*/
+true /*EnableSplitImage*/>;
using GemmUniversalTraits = ck_tile::TileGemmUniversalTraits<
-GemmConfig::kPadM,
-GemmConfig::kPadN,
-GemmConfig::kPadK,
-GemmConfig::DoubleSmemBuffer,
-typename GroupedConvTraitsType::GroupedConvImplicitGemmTraitsFwd::AsLayout,
-typename GroupedConvTraitsType::GroupedConvImplicitGemmTraitsFwd::BsLayout,
-typename GroupedConvTraitsType::GroupedConvImplicitGemmTraitsFwd::CLayout,
-GemmConfig::TransposeC,
-GemmConfig::UseStructuredSparsity,
+ConvConfig::kPadM,
+ConvConfig::kPadN,
+ConvConfig::kPadK,
+ConvConfig::DoubleSmemBuffer,
+typename GroupedConvTraitsTypeDefault::GroupedConvImplicitGemmTraitsFwd::AsLayout,
+typename GroupedConvTraitsTypeDefault::GroupedConvImplicitGemmTraitsFwd::BsLayout,
+typename GroupedConvTraitsTypeDefault::GroupedConvImplicitGemmTraitsFwd::CLayout,
+ConvConfig::TransposeC,
+false,
false, // Persistent,
-GemmConfig::NumWaveGroups,
-GemmConfig::Preshuffle>;
+ConvConfig::NumWaveGroups>;
using GemmPipelineProblem = ck_tile::GemmPipelineProblem<
InDataType,
WeiDataType,
AccDataType,
GemmShape,
-typename GroupedConvTraitsType::GroupedConvImplicitGemmTraitsFwd,
+typename GroupedConvTraitsTypeDefault::GroupedConvImplicitGemmTraitsFwd,
ck_tile::element_wise::PassThrough,
ck_tile::element_wise::PassThrough,
OutDataType,
@@ -86,7 +97,7 @@ struct GroupedConvolutionForwardInvoker
VectorSizeB>;
using BaseGemmPipeline = typename PipelineTypeTraits<
-GemmConfig::Pipeline>::template UniversalGemmPipeline<GemmPipelineProblem>;
+ConvConfig::Pipeline>::template UniversalGemmPipeline<GemmPipelineProblem>;
const ck_tile::index_t gemm_k =
args.C_ * std::accumulate(args.filter_spatial_lengths_.begin(),
@@ -95,8 +106,8 @@ struct GroupedConvolutionForwardInvoker
std::multiplies<ck_tile::index_t>());
// Split-K parameters
-const ck_tile::index_t k_grain = args.k_batch * GemmConfig::K_Tile;
-const ck_tile::index_t K_split = (gemm_k + k_grain - 1) / k_grain * GemmConfig::K_Tile;
+const ck_tile::index_t k_grain = args.k_batch * ConvConfig::K_Tile;
+const ck_tile::index_t K_split = (gemm_k + k_grain - 1) / k_grain * ConvConfig::K_Tile;
const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split);
const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop);
const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop);
@@ -241,64 +252,71 @@ struct GroupedConvolutionForwardInvoker
// =====================================================================
// Kernel launch lambda: Uses EnableSplitImage based on layout support
// =====================================================================
-const auto Run =
-[&](const auto has_hot_loop_, const auto tail_number_, const auto memory_operation_) {
-constexpr bool has_hot_loop_v = has_hot_loop_.value;
-constexpr auto tail_number_v = tail_number_.value;
-constexpr auto scheduler = GemmConfig::Scheduler;
-constexpr auto memory_operation = memory_operation_.value;
+const auto Run = [&]<bool EnableSplitImage>(const auto has_hot_loop_,
+const auto tail_number_,
+const auto memory_operation_) {
+constexpr bool has_hot_loop_v = has_hot_loop_.value;
+constexpr auto tail_number_v = tail_number_.value;
+constexpr auto scheduler = ConvConfig::Scheduler;
+constexpr auto memory_operation = memory_operation_.value;
-using UniversalGemmProblem =
-ck_tile::UniversalGemmPipelineProblem<InDataType,
-WeiDataType,
-AccDataType,
-GemmShape,
-GemmUniversalTraits,
-scheduler,
-has_hot_loop_v,
-tail_number_v,
-ck_tile::element_wise::PassThrough,
-ck_tile::element_wise::PassThrough,
-OutDataType,
-true,
-VectorSizeA,
-VectorSizeB>;
+using GroupedConvTraitsType = std::conditional_t<EnableSplitImage,
+GroupedConvTraitsTypeLargeTensor,
+GroupedConvTraitsTypeDefault>;
-using GemmPipeline = typename PipelineTypeTraits<
-GemmConfig::Pipeline>::template GemmPipeline<UniversalGemmProblem>;
+using UniversalGemmProblem =
+ck_tile::UniversalGemmPipelineProblem<InDataType,
+WeiDataType,
+AccDataType,
+GemmShape,
+GemmUniversalTraits,
+scheduler,
+has_hot_loop_v,
+tail_number_v,
+ck_tile::element_wise::PassThrough,
+ck_tile::element_wise::PassThrough,
+OutDataType,
+true,
+VectorSizeA,
+VectorSizeB>;
-using ConvEpilogue = ck_tile::CShuffleEpilogue<ck_tile::CShuffleEpilogueProblem<
-InDataType,
-WeiDataType,
-DsDataType,
-AccDataType,
-OutDataType,
-typename GroupedConvTraitsType::ImplicitGemmDsLayout,
-ck_tile::tensor_layout::gemm::RowMajor,
-CDEElementWise,
-TilePartitioner::MPerBlock,
-TilePartitioner::NPerBlock,
-GemmConfig::M_Warp,
-GemmConfig::N_Warp,
-GemmConfig::M_Warp_Tile,
-GemmConfig::N_Warp_Tile,
-GemmConfig::K_Warp_Tile,
-GemmConfig::TransposeC,
-memory_operation,
-1,
-true,
-GroupedConvTraitsType::VectorSizeC>>;
+using GemmPipeline = typename PipelineTypeTraits<
+ConvConfig::Pipeline>::template GemmPipeline<UniversalGemmProblem>;
-// Use split-image kernel if layout supports it, otherwise use regular kernel
-using Kernel = ck_tile::GroupedConvolutionForwardKernel<GroupedConvTraitsType,
-TilePartitioner,
-GemmPipeline,
-ConvEpilogue>;
+using ConvEpilogue = ck_tile::CShuffleEpilogue<ck_tile::CShuffleEpilogueProblem<
+InDataType,
+WeiDataType,
+DsDataType,
+AccDataType,
+OutDataType,
+typename GroupedConvTraitsType::ImplicitGemmDsLayout,
+ck_tile::tensor_layout::gemm::RowMajor,
+CDEElementWise,
+TilePartitioner::MPerBlock,
+TilePartitioner::NPerBlock,
+ConvConfig::M_Warp,
+ConvConfig::N_Warp,
+ConvConfig::M_Warp_Tile,
+ConvConfig::N_Warp_Tile,
+ConvConfig::K_Warp_Tile,
+ConvConfig::TransposeC,
+memory_operation,
+1,
+true,
+GroupedConvTraitsType::VectorSizeC>>;
-// Create kargs
-auto kargs = Kernel::MakeKernelArgs(args);
+// Use split-image kernel if layout supports it, otherwise use regular kernel
+using Kernel = ck_tile::GroupedConvolutionForwardKernel<GroupedConvTraitsType,
+TilePartitioner,
+GemmPipeline,
+ConvEpilogue>;
-// Populate split-image metadata ONLY if using split-image kernel
+// Create kargs
+auto kargs = Kernel::MakeKernelArgs(args);
+// Populate split-image metadata ONLY if using split-image kernel
+if constexpr(EnableSplitImage)
+{
kargs.num_spatial_pieces = total_pieces;
kargs.split_image.total_d = total_d;
kargs.split_image.total_h = total_h;
@@ -319,35 +337,41 @@ struct GroupedConvolutionForwardInvoker
temp_pieces[i].h_size,
temp_pieces[i].w_size};
}
}
-// Calculate grid: use total_blocks for split-image, or normal GridSize for regular
-const dim3 grids = dim3(total_blocks, kargs.GemmBatch, kargs.n_splits);
-const dim3 blocks = Kernel::BlockSize();
+// Calculate grid: use total_blocks for split-image, or normal GridSize for regular
+const dim3 grids = [&]() {
+if constexpr(EnableSplitImage)
+return dim3(total_blocks, kargs.GemmBatch, kargs.n_splits);
+else
+return Kernel::GridSize(kargs);
+}();
+const dim3 blocks = Kernel::BlockSize();
-if(!Kernel::IsSupportedArgument(kargs))
-{
-throw std::runtime_error("Wrong! Arguments not supported! Skipping conv!\n");
-}
+if(!Kernel::IsSupportedArgument(kargs))
+{
+throw std::runtime_error("Wrong! Arguments not supported! Skipping conv!\n");
+}
-if(s.log_level_ > 0)
-{
-std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n'
-<< "shape: " << GemmShape::GetName() << '\n'
-<< "problem: " << UniversalGemmProblem::GetName() << '\n'
-<< "pipeline: " << GemmPipeline::GetName() << '\n'
-<< "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
-<< ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z
-<< "}" << '\n'
-<< "Vector size A: " << GemmPipeline::GetVectorSizeA()
-<< ", Vector size B: " << GemmPipeline::GetVectorSizeB()
-<< ", Vector size C: " << ConvEpilogue::GetVectorSizeC() << std::endl;
-}
+if(s.log_level_ > 0)
+{
+std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n'
+<< "shape: " << GemmShape::GetName() << '\n'
+<< "problem: " << UniversalGemmProblem::GetName() << '\n'
+<< "pipeline: " << GemmPipeline::GetName() << '\n'
+<< "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
+<< ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z
+<< "}" << '\n'
+<< "Vector size A: " << GemmPipeline::GetVectorSizeA()
+<< ", Vector size B: " << GemmPipeline::GetVectorSizeB()
+<< ", Vector size C: " << ConvEpilogue::GetVectorSizeC() << std::endl;
+}
-ave_time = ck_tile::launch_kernel(
-s, ck_tile::make_kernel<kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
+ave_time = ck_tile::launch_kernel(
+s, ck_tile::make_kernel<kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
-return ave_time;
-};
+return ave_time;
+};
// =====================================================================
// Step 4: Dispatch kernel (split-image or regular based on decision)
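The rewritten Run lambda combines three idioms worth naming: a C++20 explicit template parameter list on a lambda, std::conditional_t to select the traits type at compile time, and an immediately-invoked lambda so grids stays const while branching on a constexpr flag. A stripped-down sketch of the combination (all names hypothetical):

#include <cstdio>
#include <type_traits>

struct DefaultTraits     { static constexpr int blocks = 1; };
struct LargeTensorTraits { static constexpr int blocks = 4; };

int main()
{
    // C++20: lambda with an explicit template parameter, like Run above.
    const auto run = []<bool EnableSplitImage>() {
        using Traits = std::conditional_t<EnableSplitImage,
                                          LargeTensorTraits,
                                          DefaultTraits>;
        // Immediately-invoked lambda: const result, constexpr branch inside.
        const int grid_x = []() {
            if constexpr(EnableSplitImage)
                return Traits::blocks * 16; // split-image: blocks per piece
            else
                return Traits::blocks;      // regular grid size
        }();
        std::printf("grid_x = %d\n", grid_x);
    };

    run.template operator()<true>();  // split-image path
    run.template operator()<false>(); // regular path
}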

View File

@@ -3,7 +3,7 @@
#pragma once
template <ck_tile::index_t NDimSpatial,
-typename GemmConfig,
+typename ConvConfig,
typename Invoker,
typename InDataType,
typename WeiDataType,
@@ -17,7 +17,7 @@ float invoke_grouped_conv_bwd_data(ck_tile::GroupedConvBwdDataHostArgs& args,
int n_repeat)
{
float ave_time = Invoker::template grouped_conv_bwd_data<NDimSpatial,
-GemmConfig,
+ConvConfig,
InDataType,
WeiDataType,
AccDataType,
@@ -39,7 +39,7 @@ float invoke_grouped_conv_bwd_data(ck_tile::GroupedConvBwdDataHostArgs& args,
}
template <ck_tile::index_t NDimSpatial,
-typename GemmConfig,
+typename ConvConfig,
typename Invoker,
typename InDataType,
typename WeiDataType = InDataType,
@@ -141,7 +141,7 @@ int run_grouped_conv_bwd_data_example_with_layouts(
std::cout << "output: " << output.mDesc << std::endl;
invoke_grouped_conv_bwd_data<NDimSpatial,
-GemmConfig,
+ConvConfig,
Invoker,
InDataType,
WeiDataType,
@@ -193,7 +193,7 @@ int run_grouped_conv_bwd_data_example_with_layouts(
}
template <typename Invoker,
-typename GemmConfig,
+typename ConvConfig,
typename InPrecType,
typename WeiPrecType = InPrecType,
typename OutPrecType = InPrecType>
@@ -215,7 +215,7 @@ int run_grouped_conv_bwd_data_example_prec_type(
if(in_layout == "NWGC" && wei_layout == "GKXC" && out_layout == "NWGK")
{
return run_grouped_conv_bwd_data_example_with_layouts<ck_tile::number<1>{},
-GemmConfig,
+ConvConfig,
Invoker,
InPrecType,
WeiPrecType,
@@ -225,7 +225,7 @@ int run_grouped_conv_bwd_data_example_prec_type(
else if(in_layout == "NHWGC" && wei_layout == "GKYXC" && out_layout == "NHWGK")
{
return run_grouped_conv_bwd_data_example_with_layouts<ck_tile::number<2>{},
-GemmConfig,
+ConvConfig,
Invoker,
InPrecType,
WeiPrecType,
@@ -235,7 +235,7 @@ int run_grouped_conv_bwd_data_example_prec_type(
else if(in_layout == "NDHWGC" && wei_layout == "GKZYXC" && out_layout == "NDHWGK")
{
return run_grouped_conv_bwd_data_example_with_layouts<ck_tile::number<3>{},
-GemmConfig,
+ConvConfig,
Invoker,
InPrecType,
WeiPrecType,
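Both .inc files close with the same ladder: runtime layout strings choose a compile-time spatial rank, forwarded as ck_tile::number<N>. A reduced sketch of the idiom (number and the runner are hypothetical stand-ins):

#include <iostream>
#include <string>

template <int N>
struct number { static constexpr int value = N; }; // stand-in for ck_tile::number

template <int NDimSpatial>
int run_with_rank() // stands in for run_..._example_with_layouts
{
    return NDimSpatial; // the real runner builds 1D/2D/3D conv descriptors here
}

// Each supported layout triple implies exactly one spatial rank.
int dispatch(const std::string& in, const std::string& wei, const std::string& out)
{
    if(in == "NWGC" && wei == "GKXC" && out == "NWGK")
        return run_with_rank<number<1>::value>();
    else if(in == "NHWGC" && wei == "GKYXC" && out == "NHWGK")
        return run_with_rank<number<2>::value>();
    else if(in == "NDHWGC" && wei == "GKZYXC" && out == "NDHWGK")
        return run_with_rank<number<3>::value>();
    std::cerr << "unsupported layout\n";
    return -1;
}

int main() { std::cout << dispatch("NHWGC", "GKYXC", "NHWGK") << '\n'; } // prints 2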

View File

@@ -3,7 +3,7 @@
#pragma once
template <ck_tile::index_t NDimSpatial,
-typename GemmConfig,
+typename ConvConfig,
typename Invoker,
typename InDataType,
typename WeiDataType,
@@ -17,7 +17,7 @@ float invoke_grouped_conv_fwd(const ck_tile::GroupedConvFwdHostArgs<>& args,
int n_repeat)
{
float ave_time = Invoker::template grouped_conv_fwd<NDimSpatial,
-GemmConfig,
+ConvConfig,
InDataType,
WeiDataType,
AccDataType,
@@ -39,7 +39,7 @@ float invoke_grouped_conv_fwd(const ck_tile::GroupedConvFwdHostArgs<>& args,
}
template <ck_tile::index_t NDimSpatial,
-typename GemmConfig,
+typename ConvConfig,
typename Invoker,
typename InDataType,
typename WeiDataType = InDataType,
@@ -141,7 +141,7 @@ int run_grouped_conv_fwd_example_with_layouts(
std::cout << "output: " << output.mDesc << std::endl;
invoke_grouped_conv_fwd<NDimSpatial,
-GemmConfig,
+ConvConfig,
Invoker,
InDataType,
WeiDataType,
@@ -193,7 +193,7 @@ int run_grouped_conv_fwd_example_with_layouts(
}
template <typename Invoker,
-typename GemmConfig,
+typename ConvConfig,
typename InPrecType,
typename WeiPrecType = InPrecType,
typename OutPrecType = InPrecType>
@@ -215,7 +215,7 @@ int run_grouped_conv_fwd_example_prec_type(
if(in_layout == "NWGC" && wei_layout == "GKXC" && out_layout == "NWGK")
{
return run_grouped_conv_fwd_example_with_layouts<ck_tile::number<1>{},
-GemmConfig,
+ConvConfig,
Invoker,
InPrecType,
WeiPrecType,
@@ -225,7 +225,7 @@ int run_grouped_conv_fwd_example_prec_type(
else if(in_layout == "NHWGC" && wei_layout == "GKYXC" && out_layout == "NHWGK")
{
return run_grouped_conv_fwd_example_with_layouts<ck_tile::number<2>{},
-GemmConfig,
+ConvConfig,
Invoker,
InPrecType,
WeiPrecType,
@@ -235,7 +235,7 @@ int run_grouped_conv_fwd_example_prec_type(
else if(in_layout == "NDHWGC" && wei_layout == "GKZYXC" && out_layout == "NDHWGK")
{
return run_grouped_conv_fwd_example_with_layouts<ck_tile::number<3>{},
-GemmConfig,
+ConvConfig,
Invoker,
InPrecType,
WeiPrecType,