[CK TILE] Convolution remove magic values (#3160)

* [CK TILE] Refactor Conv configs and Conv Elementwise

* fix

* [CK TILE] Convolution remove magix values

* fix partitioner
This commit is contained in:
Bartłomiej Kocot
2025-11-06 11:26:30 +01:00
committed by GitHub
parent 12922120d2
commit 2234ff830b
7 changed files with 355 additions and 352 deletions

View File

@@ -14,19 +14,11 @@
struct ConvConfigBase
{
static constexpr bool kPadM = true;
static constexpr bool kPadN = true;
static constexpr bool kPadK = true;
static constexpr bool TransposeC = false;
static constexpr ck_tile::index_t VectorSizeA = 4;
static constexpr ck_tile::index_t VectorSizeB = 8;
static constexpr ck_tile::index_t VectorSizeC = 8;
static constexpr int kBlockPerCu = 1;
static constexpr ck_tile::index_t TileParitionerGroupNum = 8;
static constexpr ck_tile::index_t TileParitionerM01 = 4;
static constexpr int kBlockPerCu = 1;
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave;
static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V3;
static constexpr ck_tile::index_t NumWaveGroups = 1;
@@ -210,9 +202,9 @@ struct ConvConfigComputeV5 : public ConvConfigBase
static constexpr ck_tile::index_t N_Warp_Tile = 32;
static constexpr ck_tile::index_t K_Warp_Tile = 16;
static constexpr bool DoubleSmemBuffer = false;
static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V5;
static constexpr ck_tile::index_t NumWaNumWaveGroups = 2;
static constexpr bool DoubleSmemBuffer = false;
static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V5;
static constexpr ck_tile::index_t NumWaveGroups = 2;
};
template <typename PrecType>

View File

@@ -22,8 +22,6 @@ struct GroupedConvolutionBackwardDataInvoker
static float grouped_conv_bwd_data(const ck_tile::GroupedConvBwdDataHostArgs& args,
const ck_tile::stream_config& s)
{
constexpr int kBlockPerCu = 1;
// Implicit GEMM Traits
using GemmShape = ck_tile::TileGemmShape<
ck_tile::sequence<ConvConfig::M_Tile, ConvConfig::N_Tile, ConvConfig::K_Tile>,
@@ -32,36 +30,33 @@ struct GroupedConvolutionBackwardDataInvoker
ConvConfig::N_Warp_Tile,
ConvConfig::K_Warp_Tile>>;
constexpr ck_tile::index_t VectorSizeA = 8;
constexpr ck_tile::index_t VectorSizeB = 8;
constexpr ck_tile::index_t VectorSizeC = 8;
constexpr auto ConvSpec = ck_tile::ConvolutionSpecialization::Default;
using TilePartitioner =
ck_tile::GemmSpatiallyLocalTilePartitioner<GemmShape,
ConvConfig::TileParitionerGroupNum,
ConvConfig::TileParitionerM01>;
constexpr auto ConvSpec = ck_tile::ConvolutionSpecialization::Default;
using GroupedConvTraitsType = ck_tile::GroupedConvTraits<NDimSpatial,
ConvSpec,
InLayout,
WeiLayout,
DsLayout,
OutLayout,
VectorSizeA,
VectorSizeB,
VectorSizeC>;
ConvConfig::VectorSizeA,
ConvConfig::VectorSizeB,
ConvConfig::VectorSizeC>;
using TilePartitioner = ck_tile::GemmSpatiallyLocalTilePartitioner<
GemmShape,
GroupedConvTraitsType::FixedGemmParams::TilePartitionerGroupNum,
GroupedConvTraitsType::FixedGemmParams::TilePartitionerM01>;
using GemmUniversalTraits = ck_tile::TileGemmUniversalTraits<
ConvConfig::kPadM,
ConvConfig::kPadN,
ConvConfig::kPadK,
GroupedConvTraitsType::FixedGemmParams::kPadM,
GroupedConvTraitsType::FixedGemmParams::kPadN,
GroupedConvTraitsType::FixedGemmParams::kPadK,
ConvConfig::DoubleSmemBuffer,
typename GroupedConvTraitsType::GroupedConvImplicitGemmTraitsBwdData::AsLayout,
typename GroupedConvTraitsType::GroupedConvImplicitGemmTraitsBwdData::BsLayout,
typename GroupedConvTraitsType::GroupedConvImplicitGemmTraitsBwdData::CLayout,
ConvConfig::TransposeC,
false,
false, // Persistent,
typename GroupedConvTraitsType::AsLayoutBwdData,
typename GroupedConvTraitsType::BsLayoutBwdData,
typename GroupedConvTraitsType::CLayoutBwdData,
GroupedConvTraitsType::FixedGemmParams::TransposeC,
GroupedConvTraitsType::FixedGemmParams::UseStructuredSparsity,
GroupedConvTraitsType::FixedGemmParams::Persistent,
ConvConfig::NumWaveGroups>;
using GemmPipelineProblem = ck_tile::GemmPipelineProblem<
@@ -69,13 +64,14 @@ struct GroupedConvolutionBackwardDataInvoker
WeiDataType,
AccDataType,
GemmShape,
typename GroupedConvTraitsType::GroupedConvImplicitGemmTraitsBwdData,
typename GroupedConvTraitsType::template GroupedConvImplicitGemmTraitsBwdData<
ConvConfig::NumWaveGroups>,
ck_tile::element_wise::PassThrough,
ck_tile::element_wise::PassThrough,
InDataType,
true,
VectorSizeA,
VectorSizeB>;
GroupedConvTraitsType::FixedGemmParams::FixedVectorSize,
GroupedConvTraitsType::VectorSizeA,
GroupedConvTraitsType::VectorSizeB>;
using BaseGemmPipeline = typename PipelineTypeTraits<
ConvConfig::Pipeline>::template UniversalGemmPipeline<GemmPipelineProblem>;
@@ -93,95 +89,96 @@ struct GroupedConvolutionBackwardDataInvoker
const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop);
float ave_time{0};
const auto Run =
[&](const auto has_hot_loop_, const auto tail_number_, const auto memory_operation_) {
constexpr bool has_hot_loop_v = has_hot_loop_.value;
constexpr auto tail_number_v = tail_number_.value;
constexpr auto scheduler = ConvConfig::Scheduler;
constexpr auto memory_operation = memory_operation_.value;
const auto Run = [&](const auto has_hot_loop_,
const auto tail_number_,
const auto memory_operation_) {
constexpr bool has_hot_loop_v = has_hot_loop_.value;
constexpr auto tail_number_v = tail_number_.value;
constexpr auto scheduler = ConvConfig::Scheduler;
constexpr auto memory_operation = memory_operation_.value;
using UniversalGemmProblem =
ck_tile::UniversalGemmPipelineProblem<OutDataType,
WeiDataType,
AccDataType,
GemmShape,
GemmUniversalTraits,
scheduler,
has_hot_loop_v,
tail_number_v,
ck_tile::element_wise::PassThrough,
ck_tile::element_wise::PassThrough,
InDataType,
true,
VectorSizeA,
VectorSizeB>;
using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem<
OutDataType,
WeiDataType,
AccDataType,
GemmShape,
GemmUniversalTraits,
scheduler,
has_hot_loop_v,
tail_number_v,
ck_tile::element_wise::PassThrough,
ck_tile::element_wise::PassThrough,
InDataType,
GroupedConvTraitsType::FixedGemmParams::FixedVectorSize,
GroupedConvTraitsType::VectorSizeA,
GroupedConvTraitsType::VectorSizeB>;
using GemmPipeline = typename PipelineTypeTraits<
ConvConfig::Pipeline>::template GemmPipeline<UniversalGemmProblem>;
using GemmPipeline = typename PipelineTypeTraits<
ConvConfig::Pipeline>::template GemmPipeline<UniversalGemmProblem>;
using ConvEpilogue = ck_tile::CShuffleEpilogue<ck_tile::CShuffleEpilogueProblem<
OutDataType,
WeiDataType,
DsDataType,
AccDataType,
InDataType,
typename GroupedConvTraitsType::ImplicitGemmDsLayout,
ck_tile::tensor_layout::gemm::RowMajor,
CDEElementWise,
TilePartitioner::MPerBlock,
TilePartitioner::NPerBlock,
ConvConfig::M_Warp,
ConvConfig::N_Warp,
ConvConfig::M_Warp_Tile,
ConvConfig::N_Warp_Tile,
ConvConfig::K_Warp_Tile,
ConvConfig::TransposeC,
memory_operation,
1,
true,
GroupedConvTraitsType::VectorSizeC>>;
using ConvEpilogue = ck_tile::CShuffleEpilogue<ck_tile::CShuffleEpilogueProblem<
OutDataType,
WeiDataType,
DsDataType,
AccDataType,
InDataType,
typename GroupedConvTraitsType::ImplicitGemmDsLayout,
typename GroupedConvTraitsType::FixedGemmParams::ELayout,
CDEElementWise,
TilePartitioner::MPerBlock,
TilePartitioner::NPerBlock,
ConvConfig::M_Warp,
ConvConfig::N_Warp,
ConvConfig::M_Warp_Tile,
ConvConfig::N_Warp_Tile,
ConvConfig::K_Warp_Tile,
GroupedConvTraitsType::FixedGemmParams::TransposeC,
memory_operation,
ConvConfig::NumWaveGroups,
GroupedConvTraitsType::FixedGemmParams::FixedVectorSize,
GroupedConvTraitsType::VectorSizeC>>;
using Kernel = ck_tile::GroupedConvolutionBackwardDataKernel<GroupedConvTraitsType,
TilePartitioner,
GemmPipeline,
ConvEpilogue>;
auto kargs = Kernel::MakeKernelArgs(args);
using Kernel = ck_tile::GroupedConvolutionBackwardDataKernel<GroupedConvTraitsType,
TilePartitioner,
GemmPipeline,
ConvEpilogue>;
auto kargs = Kernel::MakeKernelArgs(args);
const dim3 grids = Kernel::GridSize(args);
const dim3 blocks = Kernel::BlockSize();
const dim3 grids = Kernel::GridSize(args);
const dim3 blocks = Kernel::BlockSize();
if(!Kernel::IsSupportedArgument(kargs))
{
throw std::runtime_error("Wrong! Arguments not supported! Skipping conv!\n");
}
if(!Kernel::IsSupportedArgument(kargs))
{
throw std::runtime_error("Wrong! Arguments not supported! Skipping conv!\n");
}
if(s.log_level_ > 0)
{
std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n'
<< "shape: " << GemmShape::GetName() << '\n'
<< "problem: " << UniversalGemmProblem::GetName() << '\n'
<< "pipeline: " << GemmPipeline::GetName() << '\n'
<< "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
<< ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z
<< "}" << '\n'
<< "Vector size A: " << GemmPipeline::GetVectorSizeA()
<< ", Vector size B: " << GemmPipeline::GetVectorSizeB()
<< ", Vector size C: " << ConvEpilogue::GetVectorSizeC() << std::endl;
}
if(s.log_level_ > 0)
{
std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n'
<< "shape: " << GemmShape::GetName() << '\n'
<< "problem: " << UniversalGemmProblem::GetName() << '\n'
<< "pipeline: " << GemmPipeline::GetName() << '\n'
<< "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
<< ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z
<< "}" << '\n'
<< "Vector size A: " << GemmPipeline::GetVectorSizeA()
<< ", Vector size B: " << GemmPipeline::GetVectorSizeB()
<< ", Vector size C: " << ConvEpilogue::GetVectorSizeC() << std::endl;
}
auto preprocess = [&]() {
ck_tile::hip_check_error(hipMemsetAsync(
kargs.in_ptr, 0, args.template GetInputByte<InDataType>(), s.stream_id_));
};
ave_time = ck_tile::launch_kernel_time_mask(
s,
preprocess,
ck_tile::make_kernel<kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
return ave_time;
auto preprocess = [&]() {
ck_tile::hip_check_error(hipMemsetAsync(
kargs.in_ptr, 0, args.template GetInputByte<InDataType>(), s.stream_id_));
};
ave_time = ck_tile::launch_kernel_time_mask(
s,
preprocess,
ck_tile::make_kernel<ConvConfig::kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
return ave_time;
};
const auto RunSplitk = [&](const auto has_hot_loop_, const auto tail_number_) {
if(args.k_batch == 1)
{

View File

@@ -21,8 +21,6 @@ struct GroupedConvolutionBackwardWeightInvoker
static float grouped_conv_bwd_weight(const ck_tile::GroupedConvBwdWeightHostArgs& args,
const ck_tile::stream_config& s)
{
constexpr int kBlockPerCu = 1;
// Implicit GEMM Traits
using GemmShape = ck_tile::TileGemmShape<
ck_tile::sequence<ConvConfig::M_Tile, ConvConfig::N_Tile, ConvConfig::K_Tile>,
@@ -31,37 +29,34 @@ struct GroupedConvolutionBackwardWeightInvoker
ConvConfig::N_Warp_Tile,
ConvConfig::K_Warp_Tile>>;
constexpr ck_tile::index_t VectorSizeA = ConvConfig::VectorSizeA;
constexpr ck_tile::index_t VectorSizeB = ConvConfig::VectorSizeB;
constexpr ck_tile::index_t VectorSizeC = ConvConfig::VectorSizeC;
constexpr auto ConvSpec = ck_tile::ConvolutionSpecialization::Default;
using TilePartitioner =
ck_tile::GemmSpatiallyLocalTilePartitioner<GemmShape,
ConvConfig::TileParitionerGroupNum,
ConvConfig::TileParitionerM01>;
constexpr auto ConvSpec = ck_tile::ConvolutionSpecialization::Default;
using GroupedConvTraitsType = ck_tile::GroupedConvTraits<NDimSpatial,
ConvSpec,
InLayout,
WeiLayout,
DsLayout,
OutLayout,
VectorSizeA,
VectorSizeB,
VectorSizeC,
ConvConfig::VectorSizeA,
ConvConfig::VectorSizeB,
ConvConfig::VectorSizeC,
ConvConfig::NumGroupsToMerge>;
using TilePartitioner = ck_tile::GemmSpatiallyLocalTilePartitioner<
GemmShape,
GroupedConvTraitsType::FixedGemmParams::TilePartitionerGroupNum,
GroupedConvTraitsType::FixedGemmParams::TilePartitionerM01>;
using GemmUniversalTraits = ck_tile::TileGemmUniversalTraits<
ConvConfig::kPadM,
ConvConfig::kPadN,
ConvConfig::kPadK,
GroupedConvTraitsType::FixedGemmParams::kPadM,
GroupedConvTraitsType::FixedGemmParams::kPadN,
GroupedConvTraitsType::FixedGemmParams::kPadK,
ConvConfig::DoubleSmemBuffer,
typename GroupedConvTraitsType::GroupedConvImplicitGemmTraitsBwdWeight::AsLayout,
typename GroupedConvTraitsType::GroupedConvImplicitGemmTraitsBwdWeight::BsLayout,
typename GroupedConvTraitsType::GroupedConvImplicitGemmTraitsBwdWeight::CLayout,
ConvConfig::TransposeC,
false,
false, // Persistent,
typename GroupedConvTraitsType::AsLayoutBwdWeight,
typename GroupedConvTraitsType::BsLayoutBwdWeight,
typename GroupedConvTraitsType::CLayoutBwdWeight,
GroupedConvTraitsType::FixedGemmParams::TransposeC,
GroupedConvTraitsType::FixedGemmParams::UseStructuredSparsity,
GroupedConvTraitsType::FixedGemmParams::Persistent,
ConvConfig::NumWaveGroups>;
using GemmPipelineProblem = ck_tile::GemmPipelineProblem<
@@ -69,13 +64,14 @@ struct GroupedConvolutionBackwardWeightInvoker
InDataType,
AccDataType,
GemmShape,
typename GroupedConvTraitsType::GroupedConvImplicitGemmTraitsBwdWeight,
typename GroupedConvTraitsType::template GroupedConvImplicitGemmTraitsBwdWeight<
ConvConfig::NumWaveGroups>,
ck_tile::element_wise::PassThrough,
ck_tile::element_wise::PassThrough,
WeiDataType,
true,
VectorSizeA,
VectorSizeB>;
GroupedConvTraitsType::FixedGemmParams::FixedVectorSize,
GroupedConvTraitsType::VectorSizeA,
GroupedConvTraitsType::VectorSizeB>;
using BaseGemmPipeline = typename PipelineTypeTraits<
ConvConfig::Pipeline>::template UniversalGemmPipeline<GemmPipelineProblem>;
@@ -101,21 +97,21 @@ struct GroupedConvolutionBackwardWeightInvoker
constexpr auto scheduler = ConvConfig::Scheduler;
constexpr auto memory_operation = memory_operation_.value;
using UniversalGemmProblem =
ck_tile::UniversalGemmPipelineProblem<OutDataType,
InDataType,
AccDataType,
GemmShape,
GemmUniversalTraits,
scheduler,
has_hot_loop_v,
tail_number_v,
ck_tile::element_wise::PassThrough,
ck_tile::element_wise::PassThrough,
WeiDataType,
true,
VectorSizeA,
VectorSizeB>;
using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem<
OutDataType,
InDataType,
AccDataType,
GemmShape,
GemmUniversalTraits,
scheduler,
has_hot_loop_v,
tail_number_v,
ck_tile::element_wise::PassThrough,
ck_tile::element_wise::PassThrough,
WeiDataType,
GroupedConvTraitsType::FixedGemmParams::FixedVectorSize,
GroupedConvTraitsType::VectorSizeA,
GroupedConvTraitsType::VectorSizeB>;
using GemmPipeline = typename PipelineTypeTraits<
ConvConfig::Pipeline>::template GemmPipeline<UniversalGemmProblem>;
@@ -127,7 +123,7 @@ struct GroupedConvolutionBackwardWeightInvoker
AccDataType,
WeiDataType,
typename GroupedConvTraitsType::ImplicitGemmDsLayout,
ck_tile::tensor_layout::gemm::RowMajor,
typename GroupedConvTraitsType::FixedGemmParams::ELayout,
CDEElementWise,
TilePartitioner::MPerBlock,
TilePartitioner::NPerBlock,
@@ -136,10 +132,10 @@ struct GroupedConvolutionBackwardWeightInvoker
ConvConfig::M_Warp_Tile,
ConvConfig::N_Warp_Tile,
ConvConfig::K_Warp_Tile,
ConvConfig::TransposeC,
GroupedConvTraitsType::FixedGemmParams::TransposeC,
memory_operation,
1,
true,
ConvConfig::NumWaveGroups,
GroupedConvTraitsType::FixedGemmParams::FixedVectorSize,
GroupedConvTraitsType::VectorSizeC>>;
using Kernel = ck_tile::GroupedConvolutionBackwardWeightKernel<GroupedConvTraitsType,
@@ -184,7 +180,7 @@ struct GroupedConvolutionBackwardWeightInvoker
ave_time = ck_tile::launch_kernel_time_mask(
s,
preprocess,
ck_tile::make_kernel<kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
ck_tile::make_kernel<ConvConfig::kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
return ave_time;
};

View File

@@ -23,8 +23,6 @@ struct GroupedConvolutionBackwardWeightTwoStageInvoker
{
using WorkspaceDataType = float;
constexpr int kBlockPerCu = 1;
// Implicit GEMM Traits
using GemmShape = ck_tile::TileGemmShape<
ck_tile::sequence<ConvConfig::M_Tile, ConvConfig::N_Tile, ConvConfig::K_Tile>,
@@ -33,36 +31,34 @@ struct GroupedConvolutionBackwardWeightTwoStageInvoker
ConvConfig::N_Warp_Tile,
ConvConfig::K_Warp_Tile>>;
constexpr ck_tile::index_t VectorSizeA = 4;
constexpr ck_tile::index_t VectorSizeB = 8;
constexpr ck_tile::index_t VectorSizeC = 8;
constexpr auto ConvSpec = ck_tile::ConvolutionSpecialization::Default;
using TilePartitioner =
ck_tile::GemmSpatiallyLocalTilePartitioner<GemmShape,
ConvConfig::TileParitionerGroupNum,
ConvConfig::TileParitionerM01>;
constexpr auto ConvSpec = ck_tile::ConvolutionSpecialization::Default;
using GroupedConvTraitsType = ck_tile::GroupedConvTraits<NDimSpatial,
ConvSpec,
InLayout,
WeiLayout,
DsLayout,
OutLayout,
VectorSizeA,
VectorSizeB,
VectorSizeC>;
ConvConfig::VectorSizeA,
ConvConfig::VectorSizeB,
ConvConfig::VectorSizeC,
ConvConfig::NumGroupsToMerge>;
using TilePartitioner = ck_tile::GemmSpatiallyLocalTilePartitioner<
GemmShape,
GroupedConvTraitsType::FixedGemmParams::TilePartitionerGroupNum,
GroupedConvTraitsType::FixedGemmParams::TilePartitionerM01>;
using GemmUniversalTraits = ck_tile::TileGemmUniversalTraits<
ConvConfig::kPadM,
ConvConfig::kPadN,
ConvConfig::kPadK,
GroupedConvTraitsType::FixedGemmParams::kPadM,
GroupedConvTraitsType::FixedGemmParams::kPadN,
GroupedConvTraitsType::FixedGemmParams::kPadK,
ConvConfig::DoubleSmemBuffer,
typename GroupedConvTraitsType::GroupedConvImplicitGemmTraitsBwdWeight::AsLayout,
typename GroupedConvTraitsType::GroupedConvImplicitGemmTraitsBwdWeight::BsLayout,
typename GroupedConvTraitsType::GroupedConvImplicitGemmTraitsBwdWeight::CLayout,
ConvConfig::TransposeC,
false,
false, // Persistent,
typename GroupedConvTraitsType::AsLayoutBwdWeight,
typename GroupedConvTraitsType::BsLayoutBwdWeight,
typename GroupedConvTraitsType::CLayoutBwdWeight,
GroupedConvTraitsType::FixedGemmParams::TransposeC,
GroupedConvTraitsType::FixedGemmParams::UseStructuredSparsity,
GroupedConvTraitsType::FixedGemmParams::Persistent,
ConvConfig::NumWaveGroups>;
using GemmPipelineProblem = ck_tile::GemmPipelineProblem<
@@ -70,13 +66,14 @@ struct GroupedConvolutionBackwardWeightTwoStageInvoker
InDataType,
AccDataType,
GemmShape,
typename GroupedConvTraitsType::GroupedConvImplicitGemmTraitsBwdWeight,
typename GroupedConvTraitsType::template GroupedConvImplicitGemmTraitsBwdWeight<
ConvConfig::NumWaveGroups>,
ck_tile::element_wise::PassThrough,
ck_tile::element_wise::PassThrough,
WeiDataType,
true,
VectorSizeA,
VectorSizeB>;
GroupedConvTraitsType::FixedGemmParams::FixedVectorSize,
GroupedConvTraitsType::VectorSizeA,
GroupedConvTraitsType::VectorSizeB>;
using BaseGemmPipeline = typename PipelineTypeTraits<
ConvConfig::Pipeline>::template UniversalGemmPipeline<GemmPipelineProblem>;
@@ -102,21 +99,21 @@ struct GroupedConvolutionBackwardWeightTwoStageInvoker
constexpr auto scheduler = ConvConfig::Scheduler;
constexpr auto memory_operation = memory_operation_.value;
using UniversalGemmProblem =
ck_tile::UniversalGemmPipelineProblem<OutDataType,
InDataType,
AccDataType,
GemmShape,
GemmUniversalTraits,
scheduler,
has_hot_loop_v,
tail_number_v,
ck_tile::element_wise::PassThrough,
ck_tile::element_wise::PassThrough,
WeiDataType,
true,
VectorSizeA,
VectorSizeB>;
using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem<
OutDataType,
InDataType,
AccDataType,
GemmShape,
GemmUniversalTraits,
scheduler,
has_hot_loop_v,
tail_number_v,
ck_tile::element_wise::PassThrough,
ck_tile::element_wise::PassThrough,
WeiDataType,
GroupedConvTraitsType::FixedGemmParams::FixedVectorSize,
GroupedConvTraitsType::VectorSizeA,
GroupedConvTraitsType::VectorSizeB>;
using GemmPipeline = typename PipelineTypeTraits<
ConvConfig::Pipeline>::template GemmPipeline<UniversalGemmProblem>;
@@ -128,7 +125,7 @@ struct GroupedConvolutionBackwardWeightTwoStageInvoker
AccDataType,
WorkspaceDataType, // C: Workspace normally Out
typename GroupedConvTraitsType::ImplicitGemmDsLayout,
ck_tile::tensor_layout::gemm::RowMajor,
typename GroupedConvTraitsType::FixedGemmParams::ELayout,
CDEElementWise,
TilePartitioner::MPerBlock,
TilePartitioner::NPerBlock,
@@ -139,8 +136,8 @@ struct GroupedConvolutionBackwardWeightTwoStageInvoker
ConvConfig::K_Warp_Tile,
GemmPipelineProblem::TransposeC,
memory_operation,
1,
true,
ConvConfig::NumWaveGroups,
GroupedConvTraitsType::FixedGemmParams::FixedVectorSize,
GroupedConvTraitsType::VectorSizeC>>;
using Kernel = ck_tile::GroupedConvolutionBackwardWeightKernel<GroupedConvTraitsType,
@@ -235,16 +232,17 @@ struct GroupedConvolutionBackwardWeightTwoStageInvoker
ave_time = ck_tile::launch_kernel_time_mask(
s,
preprocess,
ck_tile::make_kernel<kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs),
ck_tile::make_kernel<kBlockPerCu>(ElementwiseKernel{},
kGridSize,
kBlockSize,
0,
input_size,
ck_tile::make_tuple(shape[1], 1), // Input Stride
ck_tile::make_tuple(shape[1], 1), // Output Stride
input_tensors,
static_cast<WeiDataType*>(c_ptr)));
ck_tile::make_kernel<ConvConfig::kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs),
ck_tile::make_kernel<ConvConfig::kBlockPerCu>(
ElementwiseKernel{},
kGridSize,
kBlockSize,
0,
input_size,
ck_tile::make_tuple(shape[1], 1), // Input Stride
ck_tile::make_tuple(shape[1], 1), // Output Stride
input_tensors,
static_cast<WeiDataType*>(c_ptr)));
return ave_time;
};

View File

@@ -32,8 +32,6 @@ struct GroupedConvolutionForwardInvoker
{
std::cout << "[INVOKER] grouped_conv_fwd called, NDimSpatial=" << NDimSpatial << "\n";
}
constexpr int kBlockPerCu = 1;
// Implicit GEMM Traits
using GemmShape = ck_tile::TileGemmShape<
ck_tile::sequence<ConvConfig::M_Tile, ConvConfig::N_Tile, ConvConfig::K_Tile>,
@@ -42,38 +40,34 @@ struct GroupedConvolutionForwardInvoker
ConvConfig::N_Warp_Tile,
ConvConfig::K_Warp_Tile>>;
constexpr ck_tile::index_t VectorSizeA = 8;
constexpr ck_tile::index_t VectorSizeB = 8;
constexpr ck_tile::index_t VectorSizeC = 8;
constexpr ck_tile::index_t NumGroupsToMerge = 1;
constexpr auto ConvSpec = ck_tile::ConvolutionSpecialization::Default;
using TilePartitioner =
ck_tile::GemmSpatiallyLocalTilePartitioner<GemmShape,
ConvConfig::TileParitionerGroupNum,
ConvConfig::TileParitionerM01>;
constexpr auto ConvSpec = ck_tile::ConvolutionSpecialization::Default;
using GroupedConvTraitsType = ck_tile::GroupedConvTraits<NDimSpatial,
ConvSpec,
InLayout,
WeiLayout,
DsLayout,
OutLayout,
VectorSizeA,
VectorSizeB,
VectorSizeC,
NumGroupsToMerge>;
ConvConfig::VectorSizeA,
ConvConfig::VectorSizeB,
ConvConfig::VectorSizeC,
ConvConfig::NumGroupsToMerge>;
using TilePartitioner = ck_tile::GemmSpatiallyLocalTilePartitioner<
GemmShape,
GroupedConvTraitsType::FixedGemmParams::TilePartitionerGroupNum,
GroupedConvTraitsType::FixedGemmParams::TilePartitionerM01>;
using GemmUniversalTraits = ck_tile::TileGemmUniversalTraits<
ConvConfig::kPadM,
ConvConfig::kPadN,
ConvConfig::kPadK,
GroupedConvTraitsType::FixedGemmParams::kPadM,
GroupedConvTraitsType::FixedGemmParams::kPadN,
GroupedConvTraitsType::FixedGemmParams::kPadK,
ConvConfig::DoubleSmemBuffer,
typename GroupedConvTraitsType::GroupedConvImplicitGemmTraitsFwd::AsLayout,
typename GroupedConvTraitsType::GroupedConvImplicitGemmTraitsFwd::BsLayout,
typename GroupedConvTraitsType::GroupedConvImplicitGemmTraitsFwd::CLayout,
ConvConfig::TransposeC,
false,
false, // Persistent,
typename GroupedConvTraitsType::AsLayoutFwd,
typename GroupedConvTraitsType::BsLayoutFwd,
typename GroupedConvTraitsType::CLayoutFwd,
GroupedConvTraitsType::FixedGemmParams::TransposeC,
GroupedConvTraitsType::FixedGemmParams::UseStructuredSparsity,
GroupedConvTraitsType::FixedGemmParams::Persistent,
ConvConfig::NumWaveGroups>;
using GemmPipelineProblem = ck_tile::GemmPipelineProblem<
@@ -81,13 +75,14 @@ struct GroupedConvolutionForwardInvoker
WeiDataType,
AccDataType,
GemmShape,
typename GroupedConvTraitsType::GroupedConvImplicitGemmTraitsFwd,
typename GroupedConvTraitsType::template GroupedConvImplicitGemmTraitsFwd<
ConvConfig::NumWaveGroups>,
ck_tile::element_wise::PassThrough,
ck_tile::element_wise::PassThrough,
OutDataType,
true,
VectorSizeA,
VectorSizeB>;
GroupedConvTraitsType::FixedGemmParams::FixedVectorSize,
GroupedConvTraitsType::VectorSizeA,
GroupedConvTraitsType::VectorSizeB>;
using BaseGemmPipeline = typename PipelineTypeTraits<
ConvConfig::Pipeline>::template UniversalGemmPipeline<GemmPipelineProblem>;
@@ -116,21 +111,21 @@ struct GroupedConvolutionForwardInvoker
constexpr auto scheduler = ConvConfig::Scheduler;
constexpr auto memory_operation = memory_operation_.value;
using UniversalGemmProblem =
ck_tile::UniversalGemmPipelineProblem<InDataType,
WeiDataType,
AccDataType,
GemmShape,
GemmUniversalTraits,
scheduler,
has_hot_loop_v,
tail_number_v,
ck_tile::element_wise::PassThrough,
ck_tile::element_wise::PassThrough,
OutDataType,
true,
VectorSizeA,
VectorSizeB>;
using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem<
InDataType,
WeiDataType,
AccDataType,
GemmShape,
GemmUniversalTraits,
scheduler,
has_hot_loop_v,
tail_number_v,
ck_tile::element_wise::PassThrough,
ck_tile::element_wise::PassThrough,
OutDataType,
GroupedConvTraitsType::FixedGemmParams::FixedVectorSize,
GroupedConvTraitsType::VectorSizeA,
GroupedConvTraitsType::VectorSizeB>;
using GemmPipeline = typename PipelineTypeTraits<
ConvConfig::Pipeline>::template GemmPipeline<UniversalGemmProblem>;
@@ -142,7 +137,7 @@ struct GroupedConvolutionForwardInvoker
AccDataType,
OutDataType,
typename GroupedConvTraitsType::ImplicitGemmDsLayout,
ck_tile::tensor_layout::gemm::RowMajor,
typename GroupedConvTraitsType::FixedGemmParams::ELayout,
CDElementWise,
TilePartitioner::MPerBlock,
TilePartitioner::NPerBlock,
@@ -151,10 +146,10 @@ struct GroupedConvolutionForwardInvoker
ConvConfig::M_Warp_Tile,
ConvConfig::N_Warp_Tile,
ConvConfig::K_Warp_Tile,
ConvConfig::TransposeC,
GroupedConvTraitsType::FixedGemmParams::TransposeC,
memory_operation,
1,
true,
ConvConfig::NumWaveGroups,
GroupedConvTraitsType::FixedGemmParams::FixedVectorSize,
GroupedConvTraitsType::VectorSizeC>>;
using Kernel = ck_tile::GroupedConvolutionForwardKernel<GroupedConvTraitsType,
@@ -185,8 +180,9 @@ struct GroupedConvolutionForwardInvoker
<< ", Vector size C: " << ConvEpilogue::GetVectorSizeC() << std::endl;
}
ave_time = ck_tile::launch_kernel(
s, ck_tile::make_kernel<kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
ave_time = ck_tile::launch_kernel(s,
ck_tile::make_kernel<ConvConfig::kBlockPerCu>(
Kernel{}, grids, blocks, 0, kargs));
return ave_time;
};

View File

@@ -25,7 +25,6 @@ struct GroupedConvolutionForwardInvoker
{
std::cout << "[INVOKER] grouped_conv_fwd called, NDimSpatial=" << NDimSpatial << "\n";
}
constexpr int kBlockPerCu = 1;
// Implicit GEMM Traits
using GemmShape = ck_tile::TileGemmShape<
@@ -35,27 +34,18 @@ struct GroupedConvolutionForwardInvoker
ConvConfig::N_Warp_Tile,
ConvConfig::K_Warp_Tile>>;
constexpr ck_tile::index_t VectorSizeA = 8;
constexpr ck_tile::index_t VectorSizeB = 8;
constexpr ck_tile::index_t VectorSizeC = 8;
constexpr auto ConvSpec = ck_tile::ConvolutionSpecialization::Default;
using TilePartitioner =
ck_tile::GemmSpatiallyLocalTilePartitioner<GemmShape,
ConvConfig::TileParitionerGroupNum,
ConvConfig::TileParitionerM01>;
using GroupedConvTraitsTypeDefault = ck_tile::GroupedConvTraits<NDimSpatial,
ConvSpec,
InLayout,
WeiLayout,
DsLayout,
OutLayout,
VectorSizeA,
VectorSizeB,
VectorSizeC,
1, /*NumGroupsToMerge*/
false /*EnableSplitImage*/>;
using GroupedConvTraitsTypeDefault =
ck_tile::GroupedConvTraits<NDimSpatial,
ConvSpec,
InLayout,
WeiLayout,
DsLayout,
OutLayout,
ConvConfig::VectorSizeA,
ConvConfig::VectorSizeB,
ConvConfig::VectorSizeC,
ConvConfig::NumGroupsToMerge>;
using GroupedConvTraitsTypeLargeTensor =
ck_tile::GroupedConvTraits<NDimSpatial,
@@ -64,23 +54,28 @@ struct GroupedConvolutionForwardInvoker
WeiLayout,
DsLayout,
OutLayout,
VectorSizeA,
VectorSizeB,
VectorSizeC,
1, /*NumGroupsToMerge*/
ConvConfig::VectorSizeA,
ConvConfig::VectorSizeB,
ConvConfig::VectorSizeC,
ConvConfig::NumGroupsToMerge,
true /*EnableSplitImage*/>;
using TilePartitioner = ck_tile::GemmSpatiallyLocalTilePartitioner<
GemmShape,
GroupedConvTraitsTypeDefault::FixedGemmParams::TilePartitionerGroupNum,
GroupedConvTraitsTypeDefault::FixedGemmParams::TilePartitionerM01>;
using GemmUniversalTraits = ck_tile::TileGemmUniversalTraits<
ConvConfig::kPadM,
ConvConfig::kPadN,
ConvConfig::kPadK,
GroupedConvTraitsTypeDefault::FixedGemmParams::kPadM,
GroupedConvTraitsTypeDefault::FixedGemmParams::kPadN,
GroupedConvTraitsTypeDefault::FixedGemmParams::kPadK,
ConvConfig::DoubleSmemBuffer,
typename GroupedConvTraitsTypeDefault::GroupedConvImplicitGemmTraitsFwd::AsLayout,
typename GroupedConvTraitsTypeDefault::GroupedConvImplicitGemmTraitsFwd::BsLayout,
typename GroupedConvTraitsTypeDefault::GroupedConvImplicitGemmTraitsFwd::CLayout,
ConvConfig::TransposeC,
false,
false, // Persistent,
typename GroupedConvTraitsTypeDefault::AsLayoutFwd,
typename GroupedConvTraitsTypeDefault::BsLayoutFwd,
typename GroupedConvTraitsTypeDefault::CLayoutFwd,
GroupedConvTraitsTypeDefault::FixedGemmParams::TransposeC,
GroupedConvTraitsTypeDefault::FixedGemmParams::UseStructuredSparsity,
GroupedConvTraitsTypeDefault::FixedGemmParams::Persistent,
ConvConfig::NumWaveGroups>;
using GemmPipelineProblem = ck_tile::GemmPipelineProblem<
@@ -88,13 +83,14 @@ struct GroupedConvolutionForwardInvoker
WeiDataType,
AccDataType,
GemmShape,
typename GroupedConvTraitsTypeDefault::GroupedConvImplicitGemmTraitsFwd,
typename GroupedConvTraitsTypeDefault::template GroupedConvImplicitGemmTraitsFwd<
ConvConfig::NumWaveGroups>,
ck_tile::element_wise::PassThrough,
ck_tile::element_wise::PassThrough,
OutDataType,
true,
VectorSizeA,
VectorSizeB>;
GroupedConvTraitsTypeDefault::FixedGemmParams::FixedVectorSize,
GroupedConvTraitsTypeDefault::VectorSizeA,
GroupedConvTraitsTypeDefault::VectorSizeB>;
using BaseGemmPipeline = typename PipelineTypeTraits<
ConvConfig::Pipeline>::template UniversalGemmPipeline<GemmPipelineProblem>;
@@ -116,9 +112,9 @@ struct GroupedConvolutionForwardInvoker
using TransformType =
ck_tile::TransformConvFwdToGemm<NDimSpatial,
ck_tile::ConvolutionSpecialization::Default,
VectorSizeA,
VectorSizeB,
VectorSizeC,
GroupedConvTraitsTypeDefault::VectorSizeA,
GroupedConvTraitsTypeDefault::VectorSizeB,
GroupedConvTraitsTypeDefault::VectorSizeC,
1, // NumGroupsToMerge
false, // SplitN
InDataType,
@@ -264,21 +260,21 @@ struct GroupedConvolutionForwardInvoker
GroupedConvTraitsTypeLargeTensor,
GroupedConvTraitsTypeDefault>;
using UniversalGemmProblem =
ck_tile::UniversalGemmPipelineProblem<InDataType,
WeiDataType,
AccDataType,
GemmShape,
GemmUniversalTraits,
scheduler,
has_hot_loop_v,
tail_number_v,
ck_tile::element_wise::PassThrough,
ck_tile::element_wise::PassThrough,
OutDataType,
true,
VectorSizeA,
VectorSizeB>;
using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem<
InDataType,
WeiDataType,
AccDataType,
GemmShape,
GemmUniversalTraits,
scheduler,
has_hot_loop_v,
tail_number_v,
ck_tile::element_wise::PassThrough,
ck_tile::element_wise::PassThrough,
OutDataType,
GroupedConvTraitsType::FixedGemmParams::FixedVectorSize,
GroupedConvTraitsType::VectorSizeA,
GroupedConvTraitsType::VectorSizeB>;
using GemmPipeline = typename PipelineTypeTraits<
ConvConfig::Pipeline>::template GemmPipeline<UniversalGemmProblem>;
@@ -290,7 +286,7 @@ struct GroupedConvolutionForwardInvoker
AccDataType,
OutDataType,
typename GroupedConvTraitsType::ImplicitGemmDsLayout,
ck_tile::tensor_layout::gemm::RowMajor,
typename GroupedConvTraitsType::FixedGemmParams::ELayout,
CDEElementWise,
TilePartitioner::MPerBlock,
TilePartitioner::NPerBlock,
@@ -299,10 +295,10 @@ struct GroupedConvolutionForwardInvoker
ConvConfig::M_Warp_Tile,
ConvConfig::N_Warp_Tile,
ConvConfig::K_Warp_Tile,
ConvConfig::TransposeC,
GroupedConvTraitsType::FixedGemmParams::TransposeC,
memory_operation,
1,
true,
ConvConfig::NumWaveGroups,
GroupedConvTraitsType::FixedGemmParams::FixedVectorSize,
GroupedConvTraitsType::VectorSizeC>>;
// Use split-image kernel if layout supports it, otherwise use regular kernel
@@ -368,7 +364,8 @@ struct GroupedConvolutionForwardInvoker
}
ave_time = ck_tile::launch_kernel(
s, ck_tile::make_kernel<kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
s,
ck_tile::make_kernel<ConvConfig::kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
return ave_time;
};