Merge commit 'cafaeb6b7bac4e18b0a5341cd14f54224292a0c9' into develop

This commit is contained in:
assistant-librarian[bot]
2025-10-29 15:12:59 +00:00
parent 83b2a1d876
commit 26e9ec020f
29 changed files with 1970 additions and 282 deletions

22
Jenkinsfile vendored
View File

@@ -20,6 +20,28 @@ def failurePatterns = [
[pattern: /cat: .* No such file or directory/, description: "GPU not found"],
]
// Given a pattern, check if the log contains the pattern and return the context.
def checkForPattern(pattern, log) {
def lines = log.split('\n')
for (int i = 0; i < lines.size(); i++) {
if (lines[i] =~ pattern) {
echo "Found pattern match in log for ${pattern}"
// Get the two lines before and after failure.
def contextStart = Math.max(0, i - 2)
def contextEnd = Math.min(lines.size() - 1, i + 2)
def contextLines = []
for (int j = contextStart; j <= contextEnd; j++) {
contextLines.add(lines[j])
}
return [found: true, matchedLine: lines[i], context: contextLines.join('\n')]
}
}
echo "No pattern match found in log for ${pattern}"
return [found: false, matchedLine: "", context: ""]
}
class Version {
int major, minor, patch
@Override

View File

@@ -17,7 +17,7 @@
#define CK_TILE_PIPELINE_COMPUTE_V4 3
#define CK_TILE_PIPELINE_COMPUTE_V5 4
struct GemmConfigBase
struct ConvConfigBase
{
static constexpr bool kPadM = true;
static constexpr bool kPadN = true;
@@ -29,6 +29,10 @@ struct GemmConfigBase
static constexpr bool TransposeC = false;
static constexpr bool UseStructuredSparsity = false;
static constexpr ck_tile::index_t VectorSizeA = 4;
static constexpr ck_tile::index_t VectorSizeB = 8;
static constexpr ck_tile::index_t VectorSizeC = 8;
static constexpr int kBlockPerCu = 1;
static constexpr ck_tile::index_t TileParitionerGroupNum = 8;
static constexpr ck_tile::index_t TileParitionerM01 = 4;
@@ -37,10 +41,12 @@ struct GemmConfigBase
static constexpr ck_tile::index_t NumWaveGroups = 1;
static constexpr bool Preshuffle = false;
static constexpr bool TiledMMAPermuteN = false;
static constexpr ck_tile::index_t NumGroupsToMerge = 1;
};
template <typename PrecType>
struct GemmConfigMemoryInterwave : public GemmConfigBase
struct ConvConfigMemoryInterwave : public ConvConfigBase
{
// Memory friendly for Interwave scheduler
static constexpr ck_tile::index_t M_Tile = 128;
@@ -61,7 +67,7 @@ struct GemmConfigMemoryInterwave : public GemmConfigBase
};
template <typename PrecType>
struct GemmConfigMemoryIntrawave : public GemmConfigBase
struct ConvConfigMemoryIntrawave : public ConvConfigBase
{
static constexpr ck_tile::index_t M_Tile = 128;
static constexpr ck_tile::index_t N_Tile = 32;
@@ -80,7 +86,7 @@ struct GemmConfigMemoryIntrawave : public GemmConfigBase
};
template <typename PrecType>
struct GemmConfigComputeV3 : public GemmConfigBase
struct ConvConfigComputeV3 : public ConvConfigBase
{
// Compute V3 only support Intrawave scheduler
static constexpr ck_tile::index_t M_Tile = 16;
@@ -100,7 +106,7 @@ struct GemmConfigComputeV3 : public GemmConfigBase
};
template <typename PrecType>
struct GemmConfigComputeV3_1 : public GemmConfigBase
struct ConvConfigComputeV3_1 : public ConvConfigBase
{
static constexpr ck_tile::index_t M_Tile = 256;
static constexpr ck_tile::index_t N_Tile = 256;
@@ -119,7 +125,7 @@ struct GemmConfigComputeV3_1 : public GemmConfigBase
};
template <typename PrecType>
struct GemmConfigComputeV3_2 : public GemmConfigBase
struct ConvConfigComputeV3_2 : public ConvConfigBase
{
static constexpr ck_tile::index_t M_Tile = 128;
static constexpr ck_tile::index_t N_Tile = 128;
@@ -140,7 +146,7 @@ struct GemmConfigComputeV3_2 : public GemmConfigBase
};
template <typename PrecType>
struct GemmConfigComputeV3_WMMA : public GemmConfigBase
struct ConvConfigComputeV3_WMMA : public ConvConfigBase
{
static constexpr ck_tile::index_t M_Tile = 128;
static constexpr ck_tile::index_t N_Tile = 128;
@@ -161,7 +167,7 @@ struct GemmConfigComputeV3_WMMA : public GemmConfigBase
};
template <typename PrecType>
struct GemmConfigComputeV4 : public GemmConfigBase
struct ConvConfigComputeV4 : public ConvConfigBase
{
// Compute V4 only support Intrawave scheduler
// Using the ping pong reader in the lds level
@@ -182,7 +188,7 @@ struct GemmConfigComputeV4 : public GemmConfigBase
};
template <typename PrecType>
struct GemmConfigComputeV4_1 : public GemmConfigBase
struct ConvConfigComputeV4_1 : public ConvConfigBase
{
static constexpr ck_tile::index_t M_Tile = 256;
static constexpr ck_tile::index_t N_Tile = 256;
@@ -201,7 +207,7 @@ struct GemmConfigComputeV4_1 : public GemmConfigBase
};
template <typename PrecType>
struct GemmConfigComputeV5 : public GemmConfigBase
struct ConvConfigComputeV5 : public ConvConfigBase
{
static constexpr ck_tile::index_t M_Tile = 128;
static constexpr ck_tile::index_t N_Tile = 128;
@@ -220,6 +226,31 @@ struct GemmConfigComputeV5 : public GemmConfigBase
static constexpr ck_tile::index_t NumWaNumWaveGroups = 2;
};
template <typename PrecType>
struct ConvConfigComputeV3_merged_groups : public ConvConfigBase
{
static constexpr ck_tile::index_t VectorSizeA = 4;
static constexpr ck_tile::index_t VectorSizeB = 8;
static constexpr ck_tile::index_t VectorSizeC = 8;
static constexpr ck_tile::index_t M_Tile = 16;
static constexpr ck_tile::index_t N_Tile = 32;
static constexpr ck_tile::index_t K_Tile = 32;
static constexpr ck_tile::index_t M_Warp = 1;
static constexpr ck_tile::index_t N_Warp = 2;
static constexpr ck_tile::index_t K_Warp = 1;
static constexpr ck_tile::index_t M_Warp_Tile = 16;
static constexpr ck_tile::index_t N_Warp_Tile = 16;
static constexpr ck_tile::index_t K_Warp_Tile = 32;
static constexpr bool DoubleSmemBuffer = false;
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3;
static constexpr ck_tile::index_t NumGroupsToMerge = 2;
};
template <typename InDataType, typename WeiDataType = InDataType, typename OutDataType = InDataType>
struct ConvTypeConfig;

View File

@@ -51,8 +51,8 @@ int run_grouped_conv_bwd_data_example(int argc, char* argv[])
int main(int argc, char* argv[])
{
#if CK_TILE_USE_WMMA
return !run_grouped_conv_bwd_data_example<GemmConfigComputeV3_WMMA>(argc, argv);
return !run_grouped_conv_bwd_data_example<ConvConfigComputeV3_WMMA>(argc, argv);
#else
return !run_grouped_conv_bwd_data_example<GemmConfigComputeV3>(argc, argv);
return !run_grouped_conv_bwd_data_example<ConvConfigComputeV3>(argc, argv);
#endif
}

View File

@@ -14,7 +14,7 @@
#include "grouped_convolution_backward_weight_invoker.hpp"
#include "run_grouped_convolution_bwd_weight_example.inc"
template <template <typename PrecType> typename GemmConfig>
template <template <typename PrecType> typename ConvConfig>
int run_grouped_conv_bwd_weight_example(ck_tile::ArgParser& arg_parser)
{
using Invoker = GroupedConvolutionBackwardWeightInvoker;
@@ -27,14 +27,14 @@ int run_grouped_conv_bwd_weight_example(ck_tile::ArgParser& arg_parser)
if(data_type == "fp16")
{
return run_grouped_conv_bwd_weight_example_prec_type<Invoker,
GemmConfig<ck_tile::half_t>,
ConvConfig<ck_tile::half_t>,
ck_tile::half_t>(
in_layout, wei_layout, out_layout, arg_parser);
}
else if(data_type == "bf16")
{
return run_grouped_conv_bwd_weight_example_prec_type<Invoker,
GemmConfig<ck_tile::bf16_t>,
ConvConfig<ck_tile::bf16_t>,
ck_tile::bf16_t>(
in_layout, wei_layout, out_layout, arg_parser);
}
@@ -54,9 +54,9 @@ int main(int argc, char* argv[])
try
{
#if CK_TILE_USE_WMMA
return !run_grouped_conv_bwd_weight_example<GemmConfigComputeV3_WMMA>(arg_parser);
return !run_grouped_conv_bwd_weight_example<ConvConfigComputeV3_WMMA>(arg_parser);
#else
return !run_grouped_conv_bwd_weight_example<GemmConfigComputeV3>(arg_parser);
return !run_grouped_conv_bwd_weight_example<ConvConfigComputeV3>(arg_parser);
#endif
}
catch(const std::runtime_error& e)

View File

@@ -7,7 +7,7 @@
struct GroupedConvolutionBackwardWeightInvoker
{
template <ck_tile::index_t NDimSpatial,
typename GemmConfig,
typename ConvConfig,
typename InDataType,
typename WeiDataType,
typename AccDataType,
@@ -25,22 +25,22 @@ struct GroupedConvolutionBackwardWeightInvoker
// Implicit GEMM Traits
using GemmShape = ck_tile::TileGemmShape<
ck_tile::sequence<GemmConfig::M_Tile, GemmConfig::N_Tile, GemmConfig::K_Tile>,
ck_tile::sequence<GemmConfig::M_Warp, GemmConfig::N_Warp, GemmConfig::K_Warp>,
ck_tile::sequence<ConvConfig::M_Tile, ConvConfig::N_Tile, ConvConfig::K_Tile>,
ck_tile::sequence<ConvConfig::M_Warp, ConvConfig::N_Warp, ConvConfig::K_Warp>,
ck_tile::
sequence<GemmConfig::M_Warp_Tile, GemmConfig::N_Warp_Tile, GemmConfig::K_Warp_Tile>,
GemmConfig::PermuteA,
GemmConfig::PermuteB>;
sequence<ConvConfig::M_Warp_Tile, ConvConfig::N_Warp_Tile, ConvConfig::K_Warp_Tile>,
ConvConfig::PermuteA,
ConvConfig::PermuteB>;
constexpr ck_tile::index_t VectorSizeA = 4;
constexpr ck_tile::index_t VectorSizeB = 8;
constexpr ck_tile::index_t VectorSizeC = 8;
constexpr ck_tile::index_t VectorSizeA = ConvConfig::VectorSizeA;
constexpr ck_tile::index_t VectorSizeB = ConvConfig::VectorSizeB;
constexpr ck_tile::index_t VectorSizeC = ConvConfig::VectorSizeC;
constexpr auto ConvSpec = ck_tile::ConvolutionSpecialization::Default;
using TilePartitioner =
ck_tile::GemmSpatiallyLocalTilePartitioner<GemmShape,
GemmConfig::TileParitionerGroupNum,
GemmConfig::TileParitionerM01>;
ConvConfig::TileParitionerGroupNum,
ConvConfig::TileParitionerM01>;
using GroupedConvTraitsType = ck_tile::GroupedConvTraits<NDimSpatial,
ConvSpec,
InLayout,
@@ -49,20 +49,21 @@ struct GroupedConvolutionBackwardWeightInvoker
OutLayout,
VectorSizeA,
VectorSizeB,
VectorSizeC>;
VectorSizeC,
ConvConfig::NumGroupsToMerge>;
using GemmUniversalTraits = ck_tile::TileGemmUniversalTraits<
GemmConfig::kPadM,
GemmConfig::kPadN,
GemmConfig::kPadK,
GemmConfig::DoubleSmemBuffer,
ConvConfig::kPadM,
ConvConfig::kPadN,
ConvConfig::kPadK,
ConvConfig::DoubleSmemBuffer,
typename GroupedConvTraitsType::GroupedConvImplicitGemmTraitsBwdWeight::AsLayout,
typename GroupedConvTraitsType::GroupedConvImplicitGemmTraitsBwdWeight::BsLayout,
typename GroupedConvTraitsType::GroupedConvImplicitGemmTraitsBwdWeight::CLayout,
GemmConfig::TransposeC,
GemmConfig::UseStructuredSparsity,
ConvConfig::TransposeC,
ConvConfig::UseStructuredSparsity,
false, // Persistent,
GemmConfig::NumWaveGroups>;
ConvConfig::NumWaveGroups>;
using GemmPipelineProblem = ck_tile::GemmPipelineProblem<
OutDataType,
@@ -78,7 +79,7 @@ struct GroupedConvolutionBackwardWeightInvoker
VectorSizeB>;
using BaseGemmPipeline = typename PipelineTypeTraits<
GemmConfig::Pipeline>::template UniversalGemmPipeline<GemmPipelineProblem>;
ConvConfig::Pipeline>::template UniversalGemmPipeline<GemmPipelineProblem>;
const ck_tile::index_t gemm_k =
args.N_ * std::accumulate(args.output_spatial_lengths_.begin(),
@@ -86,8 +87,8 @@ struct GroupedConvolutionBackwardWeightInvoker
1,
std::multiplies<ck_tile::index_t>());
const ck_tile::index_t k_grain = args.k_batch * GemmConfig::K_Tile;
const ck_tile::index_t K_split = (gemm_k + k_grain - 1) / k_grain * GemmConfig::K_Tile;
const ck_tile::index_t k_grain = args.k_batch * ConvConfig::K_Tile;
const ck_tile::index_t K_split = (gemm_k + k_grain - 1) / k_grain * ConvConfig::K_Tile;
const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split);
const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop);
const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop);
@@ -98,7 +99,7 @@ struct GroupedConvolutionBackwardWeightInvoker
const auto memory_operation_) {
constexpr bool has_hot_loop_v = has_hot_loop_.value;
constexpr auto tail_number_v = tail_number_.value;
constexpr auto scheduler = GemmConfig::Scheduler;
constexpr auto scheduler = ConvConfig::Scheduler;
constexpr auto memory_operation = memory_operation_.value;
using UniversalGemmProblem =
@@ -118,7 +119,7 @@ struct GroupedConvolutionBackwardWeightInvoker
VectorSizeB>;
using GemmPipeline = typename PipelineTypeTraits<
GemmConfig::Pipeline>::template GemmPipeline<UniversalGemmProblem>;
ConvConfig::Pipeline>::template GemmPipeline<UniversalGemmProblem>;
using ConvEpilogue = ck_tile::CShuffleEpilogue<ck_tile::CShuffleEpilogueProblem<
OutDataType,
@@ -131,12 +132,12 @@ struct GroupedConvolutionBackwardWeightInvoker
CDEElementWise,
TilePartitioner::MPerBlock,
TilePartitioner::NPerBlock,
GemmConfig::M_Warp,
GemmConfig::N_Warp,
GemmConfig::M_Warp_Tile,
GemmConfig::N_Warp_Tile,
GemmConfig::K_Warp_Tile,
GemmConfig::TransposeC,
ConvConfig::M_Warp,
ConvConfig::N_Warp,
ConvConfig::M_Warp_Tile,
ConvConfig::N_Warp_Tile,
ConvConfig::K_Warp_Tile,
ConvConfig::TransposeC,
memory_operation,
1,
true,

View File

@@ -13,9 +13,9 @@
#include "grouped_convolution_utils.hpp"
#include "grouped_convolution_backward_weight_two_stage_invoker.hpp"
#include "run_grouped_convolution_bwd_weight_example.inc"
#include "gemm_configs.hpp"
#include "conv_configs.hpp"
template <template <typename PrecType> typename GemmConfig>
template <template <typename PrecType> typename ConvConfig>
int run_grouped_conv_bwd_weight_example(ck_tile::ArgParser& arg_parser)
{
using Invoker = GroupedConvolutionBackwardWeightTwoStageInvoker;
@@ -28,14 +28,14 @@ int run_grouped_conv_bwd_weight_example(ck_tile::ArgParser& arg_parser)
if(data_type == "fp16")
{
return run_grouped_conv_bwd_weight_example_prec_type<Invoker,
GemmConfig<ck_tile::half_t>,
ConvConfig<ck_tile::half_t>,
ck_tile::half_t>(
in_layout, wei_layout, out_layout, arg_parser);
}
else if(data_type == "bf16")
{
return run_grouped_conv_bwd_weight_example_prec_type<Invoker,
GemmConfig<ck_tile::bf16_t>,
ConvConfig<ck_tile::bf16_t>,
ck_tile::bf16_t>(
in_layout, wei_layout, out_layout, arg_parser);
}
@@ -55,9 +55,9 @@ int main(int argc, char* argv[])
try
{
#if CK_TILE_USE_WMMA
return !run_grouped_conv_bwd_weight_example<GemmConfigComputeV3_WMMA>(arg_parser);
return !run_grouped_conv_bwd_weight_example<ConvConfigComputeV3_WMMA>(arg_parser);
#else
return !run_grouped_conv_bwd_weight_example<GemmConfigComputeV3>(arg_parser);
return !run_grouped_conv_bwd_weight_example<ConvConfigComputeV3>(arg_parser);
#endif
}
catch(const std::runtime_error& e)

View File

@@ -7,7 +7,7 @@
struct GroupedConvolutionBackwardWeightTwoStageInvoker
{
template <ck_tile::index_t NDimSpatial,
typename GemmConfig,
typename ConvConfig,
typename InDataType,
typename WeiDataType,
typename AccDataType,
@@ -27,12 +27,12 @@ struct GroupedConvolutionBackwardWeightTwoStageInvoker
// Implicit GEMM Traits
using GemmShape = ck_tile::TileGemmShape<
ck_tile::sequence<GemmConfig::M_Tile, GemmConfig::N_Tile, GemmConfig::K_Tile>,
ck_tile::sequence<GemmConfig::M_Warp, GemmConfig::N_Warp, GemmConfig::K_Warp>,
ck_tile::sequence<ConvConfig::M_Tile, ConvConfig::N_Tile, ConvConfig::K_Tile>,
ck_tile::sequence<ConvConfig::M_Warp, ConvConfig::N_Warp, ConvConfig::K_Warp>,
ck_tile::
sequence<GemmConfig::M_Warp_Tile, GemmConfig::N_Warp_Tile, GemmConfig::K_Warp_Tile>,
GemmConfig::PermuteA,
GemmConfig::PermuteB>;
sequence<ConvConfig::M_Warp_Tile, ConvConfig::N_Warp_Tile, ConvConfig::K_Warp_Tile>,
ConvConfig::PermuteA,
ConvConfig::PermuteB>;
constexpr ck_tile::index_t VectorSizeA = 4;
constexpr ck_tile::index_t VectorSizeB = 8;
@@ -41,8 +41,8 @@ struct GroupedConvolutionBackwardWeightTwoStageInvoker
constexpr auto ConvSpec = ck_tile::ConvolutionSpecialization::Default;
using TilePartitioner =
ck_tile::GemmSpatiallyLocalTilePartitioner<GemmShape,
GemmConfig::TileParitionerGroupNum,
GemmConfig::TileParitionerM01>;
ConvConfig::TileParitionerGroupNum,
ConvConfig::TileParitionerM01>;
using GroupedConvTraitsType = ck_tile::GroupedConvTraits<NDimSpatial,
ConvSpec,
InLayout,
@@ -54,17 +54,17 @@ struct GroupedConvolutionBackwardWeightTwoStageInvoker
VectorSizeC>;
using GemmUniversalTraits = ck_tile::TileGemmUniversalTraits<
GemmConfig::kPadM,
GemmConfig::kPadN,
GemmConfig::kPadK,
GemmConfig::DoubleSmemBuffer,
ConvConfig::kPadM,
ConvConfig::kPadN,
ConvConfig::kPadK,
ConvConfig::DoubleSmemBuffer,
typename GroupedConvTraitsType::GroupedConvImplicitGemmTraitsBwdWeight::AsLayout,
typename GroupedConvTraitsType::GroupedConvImplicitGemmTraitsBwdWeight::BsLayout,
typename GroupedConvTraitsType::GroupedConvImplicitGemmTraitsBwdWeight::CLayout,
GemmConfig::TransposeC,
GemmConfig::UseStructuredSparsity,
ConvConfig::TransposeC,
ConvConfig::UseStructuredSparsity,
false, // Persistent,
GemmConfig::NumWaveGroups>;
ConvConfig::NumWaveGroups>;
using GemmPipelineProblem = ck_tile::GemmPipelineProblem<
OutDataType,
@@ -80,7 +80,7 @@ struct GroupedConvolutionBackwardWeightTwoStageInvoker
VectorSizeB>;
using BaseGemmPipeline = typename PipelineTypeTraits<
GemmConfig::Pipeline>::template UniversalGemmPipeline<GemmPipelineProblem>;
ConvConfig::Pipeline>::template UniversalGemmPipeline<GemmPipelineProblem>;
const ck_tile::index_t gemm_k =
args.N_ * std::accumulate(args.output_spatial_lengths_.begin(),
@@ -88,8 +88,8 @@ struct GroupedConvolutionBackwardWeightTwoStageInvoker
1,
std::multiplies<ck_tile::index_t>());
const ck_tile::index_t k_grain = args.k_batch * GemmConfig::K_Tile;
const ck_tile::index_t K_split = (gemm_k + k_grain - 1) / k_grain * GemmConfig::K_Tile;
const ck_tile::index_t k_grain = args.k_batch * ConvConfig::K_Tile;
const ck_tile::index_t K_split = (gemm_k + k_grain - 1) / k_grain * ConvConfig::K_Tile;
const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split);
const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop);
const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop);
@@ -100,7 +100,7 @@ struct GroupedConvolutionBackwardWeightTwoStageInvoker
const auto memory_operation_) {
constexpr bool has_hot_loop_v = has_hot_loop_.value;
constexpr auto tail_number_v = tail_number_.value;
constexpr auto scheduler = GemmConfig::Scheduler;
constexpr auto scheduler = ConvConfig::Scheduler;
constexpr auto memory_operation = memory_operation_.value;
using UniversalGemmProblem =
@@ -120,7 +120,7 @@ struct GroupedConvolutionBackwardWeightTwoStageInvoker
VectorSizeB>;
using GemmPipeline = typename PipelineTypeTraits<
GemmConfig::Pipeline>::template GemmPipeline<UniversalGemmProblem>;
ConvConfig::Pipeline>::template GemmPipeline<UniversalGemmProblem>;
using ConvEpilogue = ck_tile::CShuffleEpilogue<ck_tile::CShuffleEpilogueProblem<
OutDataType, // A: Out
@@ -133,11 +133,11 @@ struct GroupedConvolutionBackwardWeightTwoStageInvoker
CDEElementWise,
TilePartitioner::MPerBlock,
TilePartitioner::NPerBlock,
GemmConfig::M_Warp,
GemmConfig::N_Warp,
GemmConfig::M_Warp_Tile,
GemmConfig::N_Warp_Tile,
GemmConfig::K_Warp_Tile,
ConvConfig::M_Warp,
ConvConfig::N_Warp,
ConvConfig::M_Warp_Tile,
ConvConfig::N_Warp_Tile,
ConvConfig::K_Warp_Tile,
GemmPipelineProblem::TransposeC,
memory_operation,
1,

View File

@@ -51,8 +51,8 @@ int run_grouped_conv_fwd_example(int argc, char* argv[])
int main(int argc, char* argv[])
{
#if CK_TILE_USE_WMMA
return !run_grouped_conv_fwd_example<GemmConfigComputeV3_WMMA>(argc, argv);
return !run_grouped_conv_fwd_example<ConvConfigComputeV3_WMMA>(argc, argv);
#else
return !run_grouped_conv_fwd_example<GemmConfigComputeV3>(argc, argv);
return !run_grouped_conv_fwd_example<ConvConfigComputeV3>(argc, argv);
#endif
}

View File

@@ -32,9 +32,10 @@ struct GroupedConvolutionForwardInvoker
GemmConfig::PermuteA,
GemmConfig::PermuteB>;
constexpr ck_tile::index_t VectorSizeA = 8;
constexpr ck_tile::index_t VectorSizeB = 8;
constexpr ck_tile::index_t VectorSizeC = 8;
constexpr ck_tile::index_t VectorSizeA = 8;
constexpr ck_tile::index_t VectorSizeB = 8;
constexpr ck_tile::index_t VectorSizeC = 8;
constexpr ck_tile::index_t NumGroupsToMerge = 1;
constexpr auto ConvSpec = ck_tile::ConvolutionSpecialization::Default;
using TilePartitioner =
@@ -50,6 +51,7 @@ struct GroupedConvolutionForwardInvoker
VectorSizeA,
VectorSizeB,
VectorSizeC,
NumGroupsToMerge,
CDElementWise>;
using GemmUniversalTraits = ck_tile::TileGemmUniversalTraits<

View File

@@ -11,24 +11,11 @@
#include "ck_tile/ops/epilogue.hpp"
#include "ck_tile/ops/gemm.hpp"
#include "ck_tile/ops/grouped_convolution.hpp"
#include "gemm_configs.hpp"
#include "conv_configs.hpp"
using MemoryOpSet =
std::integral_constant<ck_tile::memory_operation_enum, ck_tile::memory_operation_enum::set>;
using MemoryOpAtomicAdd = std::integral_constant<ck_tile::memory_operation_enum,
ck_tile::memory_operation_enum::atomic_add>;
struct GemmWarpConfig_Mfma
{
static constexpr ck_tile::index_t M_Warp_Tile = 32;
static constexpr ck_tile::index_t N_Warp_Tile = 32;
static constexpr ck_tile::index_t K_Warp_Tile = 16;
};
struct GemmWarpConfig_Wmma
{
static constexpr ck_tile::index_t M_Warp_Tile = 16;
static constexpr ck_tile::index_t N_Warp_Tile = 16;
static constexpr ck_tile::index_t K_Warp_Tile = 16;
};
template <typename InDataType, typename WeiDataType, typename AccDataType, typename OutDataType>
auto calculate_rtol_atol(const ck_tile::index_t GemmK,

View File

@@ -3,7 +3,7 @@
#pragma once
template <ck_tile::index_t NDimSpatial,
typename GemmConfig,
typename ConvConfig,
typename Invoker,
typename InDataType,
typename WeiDataType,
@@ -17,7 +17,7 @@ float invoke_grouped_conv_bwd_weight(ck_tile::GroupedConvBwdWeightHostArgs& args
int n_repeat)
{
float ave_time = Invoker::template grouped_conv_bwd_weight<NDimSpatial,
GemmConfig,
ConvConfig,
InDataType,
WeiDataType,
AccDataType,
@@ -31,7 +31,7 @@ float invoke_grouped_conv_bwd_weight(ck_tile::GroupedConvBwdWeightHostArgs& args
}
template <ck_tile::index_t NDimSpatial,
typename GemmConfig,
typename ConvConfig,
typename Invoker,
typename InDataType,
typename WeiDataType = InDataType,
@@ -131,7 +131,7 @@ int run_grouped_conv_bwd_weight_example_with_layouts(ck_tile::ArgParser& arg_par
std::cout << "output: " << output.mDesc << std::endl;
float ave_time = invoke_grouped_conv_bwd_weight<NDimSpatial,
GemmConfig,
ConvConfig,
Invoker,
InDataType,
WeiDataType,
@@ -193,7 +193,7 @@ int run_grouped_conv_bwd_weight_example_with_layouts(ck_tile::ArgParser& arg_par
}
template <typename Invoker,
typename GemmConfig,
typename ConvConfig,
typename InPrecType,
typename WeiPrecType = InPrecType,
typename OutPrecType = InPrecType>
@@ -217,7 +217,7 @@ int run_grouped_conv_bwd_weight_example_prec_type(std::string in_layout,
if(in_layout == "NWGC" && wei_layout == "GKXC" && out_layout == "NWGK")
{
return run_grouped_conv_bwd_weight_example_with_layouts<ck_tile::number<1>{},
GemmConfig,
ConvConfig,
Invoker,
InPrecType,
WeiPrecType,
@@ -227,7 +227,7 @@ int run_grouped_conv_bwd_weight_example_prec_type(std::string in_layout,
else if(in_layout == "NHWGC" && wei_layout == "GKYXC" && out_layout == "NHWGK")
{
return run_grouped_conv_bwd_weight_example_with_layouts<ck_tile::number<2>{},
GemmConfig,
ConvConfig,
Invoker,
InPrecType,
WeiPrecType,
@@ -237,7 +237,7 @@ int run_grouped_conv_bwd_weight_example_prec_type(std::string in_layout,
else if(in_layout == "NDHWGC" && wei_layout == "GKZYXC" && out_layout == "NDHWGK")
{
return run_grouped_conv_bwd_weight_example_with_layouts<ck_tile::number<3>{},
GemmConfig,
ConvConfig,
Invoker,
InPrecType,
WeiPrecType,

View File

@@ -0,0 +1,350 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
// InstanceTraits specialization for DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
//
// CRITICAL MAINTENANCE NOTE:
// This InstanceTraits file MUST be kept strictly in sync with the device implementation header:
// ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp
// "In sync" means that the template parameter order, names, and types in the declaration below
// MUST EXACTLY MATCH those in the device implementation. If these diverge, you may encounter
// compilation errors, subtle template instantiation mismatches, or silent runtime bugs that are
// difficult to diagnose. Always update both files together and review changes carefully.
#pragma once
#include "instance_traits.hpp"
// Forward declaration to avoid circular dependency.
namespace ck::tensor_operation::device {
template <ck::index_t NDimSpatial,
typename ALayout,
typename BLayout,
typename DsLayout,
typename ELayout,
typename ADataType,
typename BDataType,
typename AccDataType,
typename CShuffleDataType,
typename DsDataType,
typename EDataType,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CDEElementwiseOperation,
ConvolutionForwardSpecialization ConvForwardSpecialization,
GemmSpecialization GemmSpec,
ck::index_t NumGemmKPrefetchStage,
ck::index_t BlockSize,
ck::index_t MPerBlock,
ck::index_t NPerBlock,
ck::index_t KPerBlock,
ck::index_t AK1,
ck::index_t BK1,
ck::index_t MPerXDL,
ck::index_t NPerXDL,
ck::index_t MXdlPerWave,
ck::index_t NXdlPerWave,
typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
typename ABlockTransferThreadClusterArrangeOrder,
typename ABlockTransferSrcAccessOrder,
ck::index_t ABlockTransferSrcVectorDim,
ck::index_t ABlockTransferSrcScalarPerVector,
ck::index_t ABlockTransferDstScalarPerVector_AK1,
ck::index_t ABlockLdsExtraM,
typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
typename BBlockTransferThreadClusterArrangeOrder,
typename BBlockTransferSrcAccessOrder,
ck::index_t BBlockTransferSrcVectorDim,
ck::index_t BBlockTransferSrcScalarPerVector,
ck::index_t BBlockTransferDstScalarPerVector_BK1,
ck::index_t BBlockLdsExtraN,
ck::index_t CShuffleMXdlPerWavePerShuffle,
ck::index_t CShuffleNXdlPerWavePerShuffle,
typename CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
ck::index_t CDEBlockTransferScalarPerVector_NPerBlock,
typename AComputeDataType,
typename BComputeDataType,
LoopScheduler LoopSched,
ck::index_t NumGroupsToMerge>
struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle;
} // namespace ck::tensor_operation::device
namespace ck_tile::reflect {
// Specialization for DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
template <ck::index_t NDimSpatial,
typename ALayout_,
typename BLayout_,
typename DsLayout_,
typename ELayout_,
typename ADataType_,
typename BDataType_,
typename AccDataType_,
typename CShuffleDataType_,
typename DsDataType_,
typename EDataType_,
typename AElementwiseOperation_,
typename BElementwiseOperation_,
typename CDEElementwiseOperation_,
ck::tensor_operation::device::ConvolutionForwardSpecialization ConvForwardSpecialization,
ck::tensor_operation::device::GemmSpecialization GemmSpec,
ck::index_t NumGemmKPrefetchStage,
ck::index_t BlockSize,
ck::index_t MPerBlock,
ck::index_t NPerBlock,
ck::index_t KPerBlock,
ck::index_t AK1,
ck::index_t BK1,
ck::index_t MPerXDL,
ck::index_t NPerXDL,
ck::index_t MXdlPerWave,
ck::index_t NXdlPerWave,
typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
typename ABlockTransferThreadClusterArrangeOrder,
typename ABlockTransferSrcAccessOrder_,
ck::index_t ABlockTransferSrcVectorDim,
ck::index_t ABlockTransferSrcScalarPerVector,
ck::index_t ABlockTransferDstScalarPerVector_AK1,
ck::index_t ABlockLdsExtraM,
typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
typename BBlockTransferThreadClusterArrangeOrder,
typename BBlockTransferSrcAccessOrder_,
ck::index_t BBlockTransferSrcVectorDim,
ck::index_t BBlockTransferSrcScalarPerVector,
ck::index_t BBlockTransferDstScalarPerVector_BK1,
ck::index_t BBlockLdsExtraN,
ck::index_t CShuffleMXdlPerWavePerShuffle,
ck::index_t CShuffleNXdlPerWavePerShuffle,
typename CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
ck::index_t CDEBlockTransferScalarPerVector_NPerBlock,
typename AComputeDataType_,
typename BComputeDataType_,
ck::LoopScheduler LoopSched,
ck::index_t NumGroupsToMerge>
struct InstanceTraits<ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<
NDimSpatial,
ALayout_,
BLayout_,
DsLayout_,
ELayout_,
ADataType_,
BDataType_,
AccDataType_,
CShuffleDataType_,
DsDataType_,
EDataType_,
AElementwiseOperation_,
BElementwiseOperation_,
CDEElementwiseOperation_,
ConvForwardSpecialization,
GemmSpec,
NumGemmKPrefetchStage,
BlockSize,
MPerBlock,
NPerBlock,
KPerBlock,
AK1,
BK1,
MPerXDL,
NPerXDL,
MXdlPerWave,
NXdlPerWave,
ABlockTransferThreadClusterLengths_AK0_M_AK1,
ABlockTransferThreadClusterArrangeOrder,
ABlockTransferSrcAccessOrder_,
ABlockTransferSrcVectorDim,
ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_AK1,
ABlockLdsExtraM,
BBlockTransferThreadClusterLengths_BK0_N_BK1,
BBlockTransferThreadClusterArrangeOrder,
BBlockTransferSrcAccessOrder_,
BBlockTransferSrcVectorDim,
BBlockTransferSrcScalarPerVector,
BBlockTransferDstScalarPerVector_BK1,
BBlockLdsExtraN,
CShuffleMXdlPerWavePerShuffle,
CShuffleNXdlPerWavePerShuffle,
CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
CDEBlockTransferScalarPerVector_NPerBlock,
AComputeDataType_,
BComputeDataType_,
LoopSched,
NumGroupsToMerge>>
{
// Spatial dimension
static constexpr int kSpatialDim = NDimSpatial;
// Layout types
using ALayout = ALayout_;
using BLayout = BLayout_;
using DsLayout = DsLayout_;
using ELayout = ELayout_;
// Data types
using ADataType = ADataType_;
using BDataType = BDataType_;
using AccDataType = AccDataType_;
using CShuffleDataType = CShuffleDataType_;
using DsDataType = DsDataType_;
using EDataType = EDataType_;
// Element-wise operations
using AElementwiseOperation = AElementwiseOperation_;
using BElementwiseOperation = BElementwiseOperation_;
using CDEElementwiseOperation = CDEElementwiseOperation_;
// Specialization
static constexpr ck::tensor_operation::device::ConvolutionForwardSpecialization
kConvForwardSpecialization = ConvForwardSpecialization;
static constexpr ck::tensor_operation::device::GemmSpecialization kGemmSpecialization =
GemmSpec;
// Prefetch stage
static constexpr int kNumGemmKPrefetchStage = NumGemmKPrefetchStage;
// Block configuration
static constexpr int kBlockSize = BlockSize;
static constexpr int kMPerBlock = MPerBlock;
static constexpr int kNPerBlock = NPerBlock;
static constexpr int kKPerBlock = KPerBlock;
// Tuning parameters
static constexpr int kAK1 = AK1;
static constexpr int kBK1 = BK1;
static constexpr int kMPerXDL = MPerXDL;
static constexpr int kNPerXDL = NPerXDL;
static constexpr int kMXdlPerWave = MXdlPerWave;
static constexpr int kNXdlPerWave = NXdlPerWave;
// A block transfer thread cluster dimensions (converted to std::array)
static constexpr auto kAThreadClusterLengths =
detail::SequenceToArray<ABlockTransferThreadClusterLengths_AK0_M_AK1>::value;
static constexpr auto kAThreadClusterArrangeOrder =
detail::SequenceToArray<ABlockTransferThreadClusterArrangeOrder>::value;
static constexpr auto kABlockTransferSrcAccessOrder =
detail::SequenceToArray<ABlockTransferSrcAccessOrder_>::value;
static constexpr int kABlockTransferSrcVectorDim = ABlockTransferSrcVectorDim;
static constexpr int kABlockTransferSrcScalarPerVector = ABlockTransferSrcScalarPerVector;
static constexpr int kABlockTransferDstScalarPerVectorK1 = ABlockTransferDstScalarPerVector_AK1;
static constexpr int kABlockLdsExtraM = ABlockLdsExtraM;
// B block transfer thread cluster dimensions (converted to std::array)
static constexpr auto kBThreadClusterLengths =
detail::SequenceToArray<BBlockTransferThreadClusterLengths_BK0_N_BK1>::value;
static constexpr auto kBThreadClusterArrangeOrder =
detail::SequenceToArray<BBlockTransferThreadClusterArrangeOrder>::value;
static constexpr auto kBBlockTransferSrcAccessOrder =
detail::SequenceToArray<BBlockTransferSrcAccessOrder_>::value;
static constexpr int kBBlockTransferSrcVectorDim = BBlockTransferSrcVectorDim;
static constexpr int kBBlockTransferSrcScalarPerVector = BBlockTransferSrcScalarPerVector;
static constexpr int kBBlockTransferDstScalarPerVectorK1 = BBlockTransferDstScalarPerVector_BK1;
static constexpr int kBBlockLdsExtraN = BBlockLdsExtraN;
// C shuffle parameters (converted to std::array)
static constexpr int kCShuffleMXdlPerWavePerShuffle = CShuffleMXdlPerWavePerShuffle;
static constexpr int kCShuffleNXdlPerWavePerShuffle = CShuffleNXdlPerWavePerShuffle;
static constexpr auto kCThreadClusterLengths = detail::SequenceToArray<
CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock>::value;
static constexpr int kCBlockTransferScalarPerVector = CDEBlockTransferScalarPerVector_NPerBlock;
// Compute data types
using AComputeDataType = AComputeDataType_;
using BComputeDataType = BComputeDataType_;
// Loop scheduler
static constexpr ck::LoopScheduler kLoopScheduler = LoopSched;
// Groups to merge
static constexpr int kNumGroupsToMerge = NumGroupsToMerge;
// Static member function to generate instance string
static std::string instance_string()
{
std::ostringstream oss;
// Kernel type name
oss << "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle";
// Template parameters in exact order matching InstanceTraits member order
oss << "<" << kSpatialDim; // 1. NDimSpatial
oss << "," << detail::layout_name<ALayout>(); // 2. ALayout
oss << "," << detail::layout_name<BLayout>(); // 3. BLayout
oss << "," << detail::tuple_name<DsLayout>(); // 4. DsLayout
oss << "," << detail::layout_name<ELayout>(); // 5. ELayout
oss << "," << detail::type_name<ADataType>(); // 6. ADataType
oss << "," << detail::type_name<BDataType>(); // 7. BDataType
oss << "," << detail::type_name<AccDataType>(); // 8. AccDataType
oss << "," << detail::type_name<CShuffleDataType>(); // 9. CShuffleDataType
oss << "," << detail::tuple_name<DsDataType>(); // 10. DsDataType
oss << "," << detail::type_name<EDataType>(); // 11. EDataType
oss << ","
<< detail::elementwise_op_name<AElementwiseOperation>(); // 12. AElementwiseOperation
oss << ","
<< detail::elementwise_op_name<BElementwiseOperation>(); // 13. BElementwiseOperation
oss << ","
<< detail::elementwise_op_name<CDEElementwiseOperation>(); // 14.
// CDEElementwiseOperation
oss << ","
<< detail::conv_fwd_spec_name(
kConvForwardSpecialization); // 15. ConvForwardSpecialization
oss << "," << detail::gemm_spec_name(kGemmSpecialization); // 16. GemmSpec
oss << "," << kNumGemmKPrefetchStage; // 17. NumGemmKPrefetchStage
oss << "," << kBlockSize; // 18. BlockSize
oss << "," << kMPerBlock; // 19. MPerBlock
oss << "," << kNPerBlock; // 20. NPerBlock
oss << "," << kKPerBlock; // 21. KPerBlock
oss << "," << kAK1; // 22. AK1
oss << "," << kBK1; // 23. BK1
oss << "," << kMPerXDL; // 24. MPerXDL
oss << "," << kNPerXDL; // 25. NPerXDL
oss << "," << kMXdlPerWave; // 26. MXdlPerWave
oss << "," << kNXdlPerWave; // 27. NXdlPerWave
oss << ","
<< detail::array_to_string(
kAThreadClusterLengths); // 28. ABlockTransferThreadClusterLengths
oss << ","
<< detail::array_to_string(
kAThreadClusterArrangeOrder); // 29. ABlockTransferThreadClusterArrangeOrder
oss << ","
<< detail::array_to_string(
kABlockTransferSrcAccessOrder); // 30. ABlockTransferSrcAccessOrder
oss << "," << kABlockTransferSrcVectorDim; // 31. ABlockTransferSrcVectorDim
oss << "," << kABlockTransferSrcScalarPerVector; // 32. ABlockTransferSrcScalarPerVector
oss << ","
<< kABlockTransferDstScalarPerVectorK1; // 33. ABlockTransferDstScalarPerVector_AK1
oss << "," << kABlockLdsExtraM; // 34. ABlockLdsExtraM
oss << ","
<< detail::array_to_string(
kBThreadClusterLengths); // 35. BBlockTransferThreadClusterLengths
oss << ","
<< detail::array_to_string(
kBThreadClusterArrangeOrder); // 36. BBlockTransferThreadClusterArrangeOrder
oss << ","
<< detail::array_to_string(
kBBlockTransferSrcAccessOrder); // 37. BBlockTransferSrcAccessOrder
oss << "," << kBBlockTransferSrcVectorDim; // 38. BBlockTransferSrcVectorDim
oss << "," << kBBlockTransferSrcScalarPerVector; // 39. BBlockTransferSrcScalarPerVector
oss << ","
<< kBBlockTransferDstScalarPerVectorK1; // 40. BBlockTransferDstScalarPerVector_BK1
oss << "," << kBBlockLdsExtraN; // 41. BBlockLdsExtraN
oss << "," << kCShuffleMXdlPerWavePerShuffle; // 42. CShuffleMXdlPerWavePerShuffle
oss << "," << kCShuffleNXdlPerWavePerShuffle; // 43. CShuffleNXdlPerWavePerShuffle
oss << ","
<< detail::array_to_string(
kCThreadClusterLengths); // 44. CDEBlockTransferClusterLengths
oss << ","
<< kCBlockTransferScalarPerVector; // 45. CDEBlockTransferScalarPerVector_NPerBlock
oss << "," << detail::type_name<AComputeDataType>(); // 46. AComputeDataType
oss << "," << detail::type_name<BComputeDataType>(); // 47. BComputeDataType
oss << "," << detail::loop_scheduler_name(kLoopScheduler); // 48. LoopSched
oss << "," << kNumGroupsToMerge; // 49. NumGroupsToMerge
oss << ">";
return oss.str();
}
};
} // namespace ck_tile::reflect

View File

@@ -0,0 +1,344 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
// InstanceTraits specialization for DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor
//
// CRITICAL MAINTENANCE NOTE:
// This InstanceTraits file MUST be kept strictly in sync with the device implementation header:
// ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp
// "In sync" means that the template parameter order, names, and types in the declaration below
// MUST EXACTLY MATCH those in the device implementation. If these diverge, you may encounter
// compilation errors, subtle template instantiation mismatches, or silent runtime bugs that are
// difficult to diagnose. Always update both files together and review changes carefully.
#pragma once
#include "instance_traits.hpp"
// Forward declaration to avoid circular dependency.
namespace ck::tensor_operation::device {
template <ck::index_t NDimSpatial,
typename ALayout,
typename BLayout,
typename DsLayout,
typename ELayout,
typename ADataType,
typename BDataType,
typename AccDataType,
typename CShuffleDataType,
typename DsDataType,
typename EDataType,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CDEElementwiseOperation,
ConvolutionForwardSpecialization ConvForwardSpecialization,
GemmSpecialization GemmSpec,
ck::index_t NumGemmKPrefetchStage,
ck::index_t BlockSize,
ck::index_t MPerBlock,
ck::index_t NPerBlock,
ck::index_t KPerBlock,
ck::index_t AK1,
ck::index_t BK1,
ck::index_t MPerXDL,
ck::index_t NPerXDL,
ck::index_t MXdlPerWave,
ck::index_t NXdlPerWave,
typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
typename ABlockTransferThreadClusterArrangeOrder,
typename ABlockTransferSrcAccessOrder,
ck::index_t ABlockTransferSrcVectorDim,
ck::index_t ABlockTransferSrcScalarPerVector,
ck::index_t ABlockTransferDstScalarPerVector_AK1,
ck::index_t ABlockLdsExtraM,
typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
typename BBlockTransferThreadClusterArrangeOrder,
typename BBlockTransferSrcAccessOrder,
ck::index_t BBlockTransferSrcVectorDim,
ck::index_t BBlockTransferSrcScalarPerVector,
ck::index_t BBlockTransferDstScalarPerVector_BK1,
ck::index_t BBlockLdsExtraN,
ck::index_t CShuffleMXdlPerWavePerShuffle,
ck::index_t CShuffleNXdlPerWavePerShuffle,
typename CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
ck::index_t CDEBlockTransferScalarPerVector_NPerBlock,
typename AComputeDataType,
typename BComputeDataType,
LoopScheduler LoopSched>
struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor;
} // namespace ck::tensor_operation::device
namespace ck_tile::reflect {
// Specialization for DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor
template <ck::index_t NDimSpatial,
typename ALayout_,
typename BLayout_,
typename DsLayout_,
typename ELayout_,
typename ADataType_,
typename BDataType_,
typename AccDataType_,
typename CShuffleDataType_,
typename DsDataType_,
typename EDataType_,
typename AElementwiseOperation_,
typename BElementwiseOperation_,
typename CDEElementwiseOperation_,
ck::tensor_operation::device::ConvolutionForwardSpecialization ConvForwardSpecialization,
ck::tensor_operation::device::GemmSpecialization GemmSpec,
ck::index_t NumGemmKPrefetchStage,
ck::index_t BlockSize,
ck::index_t MPerBlock,
ck::index_t NPerBlock,
ck::index_t KPerBlock,
ck::index_t AK1,
ck::index_t BK1,
ck::index_t MPerXDL,
ck::index_t NPerXDL,
ck::index_t MXdlPerWave,
ck::index_t NXdlPerWave,
typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
typename ABlockTransferThreadClusterArrangeOrder,
typename ABlockTransferSrcAccessOrder_,
ck::index_t ABlockTransferSrcVectorDim,
ck::index_t ABlockTransferSrcScalarPerVector,
ck::index_t ABlockTransferDstScalarPerVector_AK1,
ck::index_t ABlockLdsExtraM,
typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
typename BBlockTransferThreadClusterArrangeOrder,
typename BBlockTransferSrcAccessOrder_,
ck::index_t BBlockTransferSrcVectorDim,
ck::index_t BBlockTransferSrcScalarPerVector,
ck::index_t BBlockTransferDstScalarPerVector_BK1,
ck::index_t BBlockLdsExtraN,
ck::index_t CShuffleMXdlPerWavePerShuffle,
ck::index_t CShuffleNXdlPerWavePerShuffle,
typename CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
ck::index_t CDEBlockTransferScalarPerVector_NPerBlock,
typename AComputeDataType_,
typename BComputeDataType_,
ck::LoopScheduler LoopSched>
struct InstanceTraits<
ck::tensor_operation::device::DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor<
NDimSpatial,
ALayout_,
BLayout_,
DsLayout_,
ELayout_,
ADataType_,
BDataType_,
AccDataType_,
CShuffleDataType_,
DsDataType_,
EDataType_,
AElementwiseOperation_,
BElementwiseOperation_,
CDEElementwiseOperation_,
ConvForwardSpecialization,
GemmSpec,
NumGemmKPrefetchStage,
BlockSize,
MPerBlock,
NPerBlock,
KPerBlock,
AK1,
BK1,
MPerXDL,
NPerXDL,
MXdlPerWave,
NXdlPerWave,
ABlockTransferThreadClusterLengths_AK0_M_AK1,
ABlockTransferThreadClusterArrangeOrder,
ABlockTransferSrcAccessOrder_,
ABlockTransferSrcVectorDim,
ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_AK1,
ABlockLdsExtraM,
BBlockTransferThreadClusterLengths_BK0_N_BK1,
BBlockTransferThreadClusterArrangeOrder,
BBlockTransferSrcAccessOrder_,
BBlockTransferSrcVectorDim,
BBlockTransferSrcScalarPerVector,
BBlockTransferDstScalarPerVector_BK1,
BBlockLdsExtraN,
CShuffleMXdlPerWavePerShuffle,
CShuffleNXdlPerWavePerShuffle,
CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
CDEBlockTransferScalarPerVector_NPerBlock,
AComputeDataType_,
BComputeDataType_,
LoopSched>>
{
// Spatial dimension
static constexpr int kSpatialDim = NDimSpatial;
// Layout types
using ALayout = ALayout_;
using BLayout = BLayout_;
using DsLayout = DsLayout_;
using ELayout = ELayout_;
// Data types
using ADataType = ADataType_;
using BDataType = BDataType_;
using AccDataType = AccDataType_;
using CShuffleDataType = CShuffleDataType_;
using DsDataType = DsDataType_;
using EDataType = EDataType_;
// Element-wise operations
using AElementwiseOperation = AElementwiseOperation_;
using BElementwiseOperation = BElementwiseOperation_;
using CDEElementwiseOperation = CDEElementwiseOperation_;
// Specialization
static constexpr ck::tensor_operation::device::ConvolutionForwardSpecialization
kConvForwardSpecialization = ConvForwardSpecialization;
static constexpr ck::tensor_operation::device::GemmSpecialization kGemmSpecialization =
GemmSpec;
// Prefetch stage
static constexpr int kNumGemmKPrefetchStage = NumGemmKPrefetchStage;
// Block configuration
static constexpr int kBlockSize = BlockSize;
static constexpr int kMPerBlock = MPerBlock;
static constexpr int kNPerBlock = NPerBlock;
static constexpr int kKPerBlock = KPerBlock;
// Tuning parameters
static constexpr int kAK1 = AK1;
static constexpr int kBK1 = BK1;
static constexpr int kMPerXDL = MPerXDL;
static constexpr int kNPerXDL = NPerXDL;
static constexpr int kMXdlPerWave = MXdlPerWave;
static constexpr int kNXdlPerWave = NXdlPerWave;
// A block transfer thread cluster dimensions (converted to std::array)
static constexpr auto kAThreadClusterLengths =
detail::SequenceToArray<ABlockTransferThreadClusterLengths_AK0_M_AK1>::value;
static constexpr auto kAThreadClusterArrangeOrder =
detail::SequenceToArray<ABlockTransferThreadClusterArrangeOrder>::value;
static constexpr auto kABlockTransferSrcAccessOrder =
detail::SequenceToArray<ABlockTransferSrcAccessOrder_>::value;
static constexpr int kABlockTransferSrcVectorDim = ABlockTransferSrcVectorDim;
static constexpr int kABlockTransferSrcScalarPerVector = ABlockTransferSrcScalarPerVector;
static constexpr int kABlockTransferDstScalarPerVectorK1 = ABlockTransferDstScalarPerVector_AK1;
static constexpr int kABlockLdsExtraM = ABlockLdsExtraM;
// B block transfer thread cluster dimensions (converted to std::array)
static constexpr auto kBThreadClusterLengths =
detail::SequenceToArray<BBlockTransferThreadClusterLengths_BK0_N_BK1>::value;
static constexpr auto kBThreadClusterArrangeOrder =
detail::SequenceToArray<BBlockTransferThreadClusterArrangeOrder>::value;
static constexpr auto kBBlockTransferSrcAccessOrder =
detail::SequenceToArray<BBlockTransferSrcAccessOrder_>::value;
static constexpr int kBBlockTransferSrcVectorDim = BBlockTransferSrcVectorDim;
static constexpr int kBBlockTransferSrcScalarPerVector = BBlockTransferSrcScalarPerVector;
static constexpr int kBBlockTransferDstScalarPerVectorK1 = BBlockTransferDstScalarPerVector_BK1;
static constexpr int kBBlockLdsExtraN = BBlockLdsExtraN;
// C shuffle parameters (converted to std::array)
static constexpr int kCShuffleMXdlPerWavePerShuffle = CShuffleMXdlPerWavePerShuffle;
static constexpr int kCShuffleNXdlPerWavePerShuffle = CShuffleNXdlPerWavePerShuffle;
static constexpr auto kCThreadClusterLengths = detail::SequenceToArray<
CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock>::value;
static constexpr int kCBlockTransferScalarPerVector = CDEBlockTransferScalarPerVector_NPerBlock;
// Compute data types
using AComputeDataType = AComputeDataType_;
using BComputeDataType = BComputeDataType_;
// Loop scheduler
static constexpr ck::LoopScheduler kLoopScheduler = LoopSched;
// Static member function to generate instance string
static std::string instance_string()
{
std::ostringstream oss;
// Kernel type name
oss << "DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor";
// Template parameters in exact order matching InstanceTraits member order
oss << "<" << kSpatialDim; // 1. NDimSpatial
oss << "," << detail::layout_name<ALayout>(); // 2. ALayout
oss << "," << detail::layout_name<BLayout>(); // 3. BLayout
oss << "," << detail::tuple_name<DsLayout>(); // 4. DsLayout
oss << "," << detail::layout_name<ELayout>(); // 5. ELayout
oss << "," << detail::type_name<ADataType>(); // 6. ADataType
oss << "," << detail::type_name<BDataType>(); // 7. BDataType
oss << "," << detail::type_name<AccDataType>(); // 8. AccDataType
oss << "," << detail::type_name<CShuffleDataType>(); // 9. CShuffleDataType
oss << "," << detail::tuple_name<DsDataType>(); // 10. DsDataType
oss << "," << detail::type_name<EDataType>(); // 11. EDataType
oss << ","
<< detail::elementwise_op_name<AElementwiseOperation>(); // 12. AElementwiseOperation
oss << ","
<< detail::elementwise_op_name<BElementwiseOperation>(); // 13. BElementwiseOperation
oss << ","
<< detail::elementwise_op_name<CDEElementwiseOperation>(); // 14.
// CDEElementwiseOperation
oss << ","
<< detail::conv_fwd_spec_name(
kConvForwardSpecialization); // 15. ConvForwardSpecialization
oss << "," << detail::gemm_spec_name(kGemmSpecialization); // 16. GemmSpec
oss << "," << kNumGemmKPrefetchStage; // 17. NumGemmKPrefetchStage
oss << "," << kBlockSize; // 18. BlockSize
oss << "," << kMPerBlock; // 19. MPerBlock
oss << "," << kNPerBlock; // 20. NPerBlock
oss << "," << kKPerBlock; // 21. KPerBlock
oss << "," << kAK1; // 22. AK1
oss << "," << kBK1; // 23. BK1
oss << "," << kMPerXDL; // 24. MPerXDL
oss << "," << kNPerXDL; // 25. NPerXDL
oss << "," << kMXdlPerWave; // 26. MXdlPerWave
oss << "," << kNXdlPerWave; // 27. NXdlPerWave
oss << ","
<< detail::array_to_string(
kAThreadClusterLengths); // 28. ABlockTransferThreadClusterLengths
oss << ","
<< detail::array_to_string(
kAThreadClusterArrangeOrder); // 29. ABlockTransferThreadClusterArrangeOrder
oss << ","
<< detail::array_to_string(
kABlockTransferSrcAccessOrder); // 30. ABlockTransferSrcAccessOrder
oss << "," << kABlockTransferSrcVectorDim; // 31. ABlockTransferSrcVectorDim
oss << "," << kABlockTransferSrcScalarPerVector; // 32. ABlockTransferSrcScalarPerVector
oss << ","
<< kABlockTransferDstScalarPerVectorK1; // 33. ABlockTransferDstScalarPerVector_AK1
oss << "," << kABlockLdsExtraM; // 34. ABlockLdsExtraM
oss << ","
<< detail::array_to_string(
kBThreadClusterLengths); // 35. BBlockTransferThreadClusterLengths
oss << ","
<< detail::array_to_string(
kBThreadClusterArrangeOrder); // 36. BBlockTransferThreadClusterArrangeOrder
oss << ","
<< detail::array_to_string(
kBBlockTransferSrcAccessOrder); // 37. BBlockTransferSrcAccessOrder
oss << "," << kBBlockTransferSrcVectorDim; // 38. BBlockTransferSrcVectorDim
oss << "," << kBBlockTransferSrcScalarPerVector; // 39. BBlockTransferSrcScalarPerVector
oss << ","
<< kBBlockTransferDstScalarPerVectorK1; // 40. BBlockTransferDstScalarPerVector_BK1
oss << "," << kBBlockLdsExtraN; // 41. BBlockLdsExtraN
oss << "," << kCShuffleMXdlPerWavePerShuffle; // 42. CShuffleMXdlPerWavePerShuffle
oss << "," << kCShuffleNXdlPerWavePerShuffle; // 43. CShuffleNXdlPerWavePerShuffle
oss << ","
<< detail::array_to_string(
kCThreadClusterLengths); // 44. CDEBlockTransferClusterLengths
oss << ","
<< kCBlockTransferScalarPerVector; // 45. CDEBlockTransferScalarPerVector_NPerBlock
oss << "," << detail::type_name<AComputeDataType>(); // 46. AComputeDataType
oss << "," << detail::type_name<BComputeDataType>(); // 47. BComputeDataType
oss << "," << detail::loop_scheduler_name(kLoopScheduler); // 48. LoopSched
oss << ">";
return oss.str();
}
};
} // namespace ck_tile::reflect

View File

@@ -15,6 +15,7 @@
#include <ck/utility/data_type.hpp>
#include <ck/utility/sequence.hpp>
#include <ck/utility/blkgemmpipe_scheduler.hpp>
#include <ck/utility/loop_scheduler.hpp>
#include <ck/tensor_operation/gpu/device/tensor_layout.hpp>
#include <ck_tile/ops/common/tensor_layout.hpp>
#include <ck/tensor_operation/gpu/element/element_wise_operation.hpp>
@@ -160,6 +161,17 @@ constexpr std::string_view pipeline_version_name(ck::BlockGemmPipelineVersion ve
}
}
// Convert LoopScheduler enum to string
constexpr std::string_view loop_scheduler_name(ck::LoopScheduler sched)
{
using enum ck::LoopScheduler;
switch(sched)
{
case Default: return "Default";
case Interwave: return "Interwave";
}
}
// Convert std::array to string
template <typename T, std::size_t N>
inline std::string array_to_string(const std::array<T, N>& arr)

View File

@@ -26,7 +26,9 @@ add_ck_builder_test(test_inline_diff test_inline_diff.cpp)
# Testing the virtual GetInstanceString methods requires kernel compilation.
add_ck_builder_test(test_get_instance_string
test_get_instance_string.cpp)
test_get_instance_string_fwd_grp_conv_v3.cpp
test_get_instance_string_fwd_grp_conv.cpp
test_get_instance_string_fwd_grp_conv_large_tensor.cpp)
# Testing the fwd convolution builder requires kernel compilation.
# To enable parallel compilation, the individual tests are split into separate files.

View File

@@ -3,19 +3,18 @@
#include <gtest/gtest.h>
#include <gmock/gmock.h>
#include <ck/ck.hpp>
#include <ck/utility/reduction_operator.hpp>
#include <ck_tile/builder/reflect/instance_traits.hpp>
#include <ck_tile/builder/reflect/instance_traits_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp>
#include <ck_tile/builder/reflect/instance_traits_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp>
#include <ck_tile/builder/reflect/instance_traits_device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp>
namespace {
using ::testing::ElementsAre;
// Test fixture for InstanceTraits tests
class InstanceTraitsTest : public ::testing::Test
{
};
// Test InstanceTraits with DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
TEST_F(InstanceTraitsTest, ConvFwdInstanceTraitsExtraction)
TEST(InstanceTraitsTest, ConvFwdInstanceTraitsExtraction)
{
// Define a concrete instance type with specific template parameters
using DeviceInstance =
@@ -156,8 +155,7 @@ TEST_F(InstanceTraitsTest, ConvFwdInstanceTraitsExtraction)
ck::tensor_operation::element_wise::PassThrough>::value));
}
// Test instance_string function
TEST_F(InstanceTraitsTest, InstanceStringGeneration)
TEST(InstanceTraitsTest, V3InstanceStringGeneration)
{
// Define a concrete instance type with specific template parameters
using DeviceInstance =
@@ -215,10 +213,8 @@ TEST_F(InstanceTraitsTest, InstanceStringGeneration)
ck::half_t, // AComputeDataType
ck::half_t>; // BComputeDataType
// Generate instance string
std::string instance_str = ck_tile::reflect::instance_string<DeviceInstance>();
// Expected string with all template parameters in exact order
std::string expected_str = "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3"
"<2" // NDimSpatial
",GNHWC" // ALayout
@@ -269,6 +265,234 @@ TEST_F(InstanceTraitsTest, InstanceStringGeneration)
",fp16" // AComputeDataType
",fp16>"; // BComputeDataType
EXPECT_EQ(instance_str, expected_str);
}
TEST(InstanceTraitsTest, BaseInstanceStringGeneration)
{
using DeviceInstance =
ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<
2, // NDimSpatial
ck::tensor_layout::convolution::GNHWC, // ALayout
ck::tensor_layout::convolution::GKYXC, // BLayout
ck::Tuple<>, // DsLayout
ck::tensor_layout::convolution::GNHWK, // ELayout
ck::half_t, // ADataType
ck::half_t, // BDataType
float, // AccDataType
ck::half_t, // CShuffleDataType
ck::Tuple<>, // DsDataType
ck::half_t, // EDataType
ck::tensor_operation::element_wise::PassThrough, // AElementwiseOperation
ck::tensor_operation::element_wise::PassThrough, // BElementwiseOperation
ck::tensor_operation::element_wise::PassThrough, // CDEElementwiseOperation
ck::tensor_operation::device::ConvolutionForwardSpecialization::
Default, // ConvForwardSpecialization
ck::tensor_operation::device::GemmSpecialization::Default, // GemmSpec
1, // NumGemmKPrefetchStage
256, // BlockSize
128, // MPerBlock
128, // NPerBlock
16, // KPerBlock
8, // AK1
8, // BK1
32, // MPerXDL
32, // NPerXDL
4, // MXdlPerWave
4, // NXdlPerWave
ck::Sequence<4, 64, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1
ck::Sequence<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder
ck::Sequence<1, 0, 2>, // ABlockTransferSrcAccessOrder
2, // ABlockTransferSrcVectorDim
8, // ABlockTransferSrcScalarPerVector
8, // ABlockTransferDstScalarPerVector_AK1
1, // ABlockLdsExtraM
ck::Sequence<4, 64, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1
ck::Sequence<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder
ck::Sequence<1, 0, 2>, // BBlockTransferSrcAccessOrder
2, // BBlockTransferSrcVectorDim
8, // BBlockTransferSrcScalarPerVector
8, // BBlockTransferDstScalarPerVector_BK1
1, // BBlockLdsExtraN
1, // CShuffleMXdlPerWavePerShuffle
1, // CShuffleNXdlPerWavePerShuffle
ck::Sequence<1,
32,
1,
8>, // CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
8, // CDEBlockTransferScalarPerVector_NPerBlock
ck::half_t, // AComputeDataType
ck::half_t, // BComputeDataType
ck::LoopScheduler::Default, // LoopSched
1>; // NumGroupsToMerge
std::string instance_str = ck_tile::reflect::instance_string<DeviceInstance>();
std::string expected_str = "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle"
"<2" // NDimSpatial
",GNHWC" // ALayout
",GKYXC" // BLayout
",EmptyTuple" // DsLayout
",GNHWK" // ELayout
",fp16" // ADataType
",fp16" // BDataType
",fp32" // AccDataType
",fp16" // CShuffleDataType
",EmptyTuple" // DsDataType
",fp16" // EDataType
",PassThrough" // AElementwiseOperation
",PassThrough" // BElementwiseOperation
",PassThrough" // CDEElementwiseOperation
",Default" // ConvForwardSpecialization
",Default" // GemmSpec
",1" // NumGemmKPrefetchStage
",256" // BlockSize
",128" // MPerBlock
",128" // NPerBlock
",16" // KPerBlock
",8" // AK1
",8" // BK1
",32" // MPerXDL
",32" // NPerXDL
",4" // MXdlPerWave
",4" // NXdlPerWave
",Seq(4,64,1)" // ABlockTransferThreadClusterLengths
",Seq(1,0,2)" // ABlockTransferThreadClusterArrangeOrder
",Seq(1,0,2)" // ABlockTransferSrcAccessOrder
",2" // ABlockTransferSrcVectorDim
",8" // ABlockTransferSrcScalarPerVector
",8" // ABlockTransferDstScalarPerVector_AK1
",1" // ABlockLdsExtraM
",Seq(4,64,1)" // BBlockTransferThreadClusterLengths
",Seq(1,0,2)" // BBlockTransferThreadClusterArrangeOrder
",Seq(1,0,2)" // BBlockTransferSrcAccessOrder
",2" // BBlockTransferSrcVectorDim
",8" // BBlockTransferSrcScalarPerVector
",8" // BBlockTransferDstScalarPerVector_BK1
",1" // BBlockLdsExtraN
",1" // CShuffleMXdlPerWavePerShuffle
",1" // CShuffleNXdlPerWavePerShuffle
",Seq(1,32,1,8)" // CDEBlockTransferClusterLengths
",8" // CDEBlockTransferScalarPerVector_NPerBlock
",fp16" // AComputeDataType
",fp16" // BComputeDataType
",Default" // LoopSched
",1>"; // NumGroupsToMerge
EXPECT_EQ(instance_str, expected_str);
}
TEST(InstanceTraitsTest, LargeTensorInstanceStringGeneration)
{
using DeviceInstance =
ck::tensor_operation::device::DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor<
2, // NDimSpatial
ck::tensor_layout::convolution::GNHWC, // ALayout
ck::tensor_layout::convolution::GKYXC, // BLayout
ck::Tuple<>, // DsLayout
ck::tensor_layout::convolution::GNHWK, // ELayout
ck::half_t, // ADataType
ck::half_t, // BDataType
float, // AccDataType
ck::half_t, // CShuffleDataType
ck::Tuple<>, // DsDataType
ck::half_t, // EDataType
ck::tensor_operation::element_wise::PassThrough, // AElementwiseOperation
ck::tensor_operation::element_wise::PassThrough, // BElementwiseOperation
ck::tensor_operation::element_wise::PassThrough, // CDEElementwiseOperation
ck::tensor_operation::device::ConvolutionForwardSpecialization::
Default, // ConvForwardSpecialization
ck::tensor_operation::device::GemmSpecialization::Default, // GemmSpec
1, // NumGemmKPrefetchStage
256, // BlockSize
128, // MPerBlock
128, // NPerBlock
16, // KPerBlock
8, // AK1
8, // BK1
32, // MPerXDL
32, // NPerXDL
4, // MXdlPerWave
4, // NXdlPerWave
ck::Sequence<4, 64, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1
ck::Sequence<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder
ck::Sequence<1, 0, 2>, // ABlockTransferSrcAccessOrder
2, // ABlockTransferSrcVectorDim
8, // ABlockTransferSrcScalarPerVector
8, // ABlockTransferDstScalarPerVector_AK1
1, // ABlockLdsExtraM
ck::Sequence<4, 64, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1
ck::Sequence<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder
ck::Sequence<1, 0, 2>, // BBlockTransferSrcAccessOrder
2, // BBlockTransferSrcVectorDim
8, // BBlockTransferSrcScalarPerVector
8, // BBlockTransferDstScalarPerVector_BK1
1, // BBlockLdsExtraN
1, // CShuffleMXdlPerWavePerShuffle
1, // CShuffleNXdlPerWavePerShuffle
ck::Sequence<1,
32,
1,
8>, // CDEBlockTransferClusterLengths
8, // CDEBlockTransferScalarPerVector_NPerBlock
ck::half_t, // AComputeDataType
ck::half_t, // BComputeDataType
ck::LoopScheduler::Default>; // LoopSched
// Generate instance string
std::string instance_str = ck_tile::reflect::instance_string<DeviceInstance>();
// Expected string with all 48 template parameters
std::string expected_str = "DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor"
"<2" // NDimSpatial
",GNHWC" // ALayout
",GKYXC" // BLayout
",EmptyTuple" // DsLayout
",GNHWK" // ELayout
",fp16" // ADataType
",fp16" // BDataType
",fp32" // AccDataType
",fp16" // CShuffleDataType
",EmptyTuple" // DsDataType
",fp16" // EDataType
",PassThrough" // AElementwiseOperation
",PassThrough" // BElementwiseOperation
",PassThrough" // CDEElementwiseOperation
",Default" // ConvForwardSpecialization
",Default" // GemmSpec
",1" // NumGemmKPrefetchStage
",256" // BlockSize
",128" // MPerBlock
",128" // NPerBlock
",16" // KPerBlock
",8" // AK1
",8" // BK1
",32" // MPerXDL
",32" // NPerXDL
",4" // MXdlPerWave
",4" // NXdlPerWave
",Seq(4,64,1)" // ABlockTransferThreadClusterLengths
",Seq(1,0,2)" // ABlockTransferThreadClusterArrangeOrder
",Seq(1,0,2)" // ABlockTransferSrcAccessOrder
",2" // ABlockTransferSrcVectorDim
",8" // ABlockTransferSrcScalarPerVector
",8" // ABlockTransferDstScalarPerVector_AK1
",1" // ABlockLdsExtraM
",Seq(4,64,1)" // BBlockTransferThreadClusterLengths
",Seq(1,0,2)" // BBlockTransferThreadClusterArrangeOrder
",Seq(1,0,2)" // BBlockTransferSrcAccessOrder
",2" // BBlockTransferSrcVectorDim
",8" // BBlockTransferSrcScalarPerVector
",8" // BBlockTransferDstScalarPerVector_BK1
",1" // BBlockLdsExtraN
",1" // CShuffleMXdlPerWavePerShuffle
",1" // CShuffleNXdlPerWavePerShuffle
",Seq(1,32,1,8)" // CDEBlockTransferClusterLengths
",8" // CDEBlockTransferScalarPerVector_NPerBlock
",fp16" // AComputeDataType
",fp16" // BComputeDataType
",Default>"; // LoopSched
// Verify the generated string matches exactly
EXPECT_EQ(instance_str, expected_str);
}

View File

@@ -0,0 +1,104 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include <gtest/gtest.h>
#include <ck_tile/builder/reflect/instance_traits.hpp>
#include <ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_abd.hpp>
#include <ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_instance.hpp>
// Test GetInstanceString through base class pointer for non-V3 variant
TEST(GetInstanceString, ReturnsStringForFwdGrpConvInstance)
{
// Use the template helper to get a working instance configuration
using InstanceTuple =
ck::tensor_operation::device::instance::device_grouped_conv_fwd_xdl_f16_instances<
2, // NDimSpatial
ck::tensor_operation::device::instance::GNHWC, // ALayout
ck::tensor_operation::device::instance::GKYXC, // BLayout
ck::tensor_operation::device::instance::Empty_Tuple, // DsLayout
ck::tensor_operation::device::instance::GNHWK, // ELayout
ck::tensor_operation::device::instance::ConvFwdDefault>; // ConvForwardSpecialization
// Get the first instance from the tuple
using DeviceInstance = typename std::tuple_element<0, InstanceTuple>::type;
// Define the base class type using DeviceGroupedConvFwdMultipleABD
using BaseClass = ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD<
2, // NDimSpatial
ck::tensor_operation::device::instance::GNHWC, // ALayout
ck::tensor_operation::device::instance::GKYXC, // BLayout
ck::tensor_operation::device::instance::Empty_Tuple, // DsLayout
ck::tensor_operation::device::instance::GNHWK, // ELayout
ck::half_t, // ADataType
ck::half_t, // BDataType
ck::Tuple<>, // DsDataType
ck::half_t, // EDataType
ck::tensor_operation::element_wise::PassThrough, // AElementwiseOperation
ck::tensor_operation::element_wise::PassThrough, // BElementwiseOperation
ck::tensor_operation::element_wise::PassThrough, // CDEElementwiseOperation
ck::half_t, // AComputeType
ck::half_t>; // BComputeType
// Create an instance of the derived class
DeviceInstance device_instance;
// Get a pointer to the base class
BaseClass* base_ptr = &device_instance;
// Call GetInstanceString through the base class pointer
std::string instance_str = base_ptr->GetInstanceString();
// Expected complete instance string based on the first instance from
// device_grouped_conv_fwd_xdl_f16_instances
std::string expected_str = "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle"
"<2" // NDimSpatial
",GNHWC" // ALayout
",GKYXC" // BLayout
",EmptyTuple" // DsLayout
",GNHWK" // ELayout
",fp16" // ADataType
",fp16" // BDataType
",fp32" // AccDataType
",fp16" // CShuffleDataType
",EmptyTuple" // DsDataType
",fp16" // EDataType
",PassThrough" // AElementwiseOperation
",PassThrough" // BElementwiseOperation
",PassThrough" // CDEElementwiseOperation
",Default" // ConvForwardSpecialization
",MNKPadding" // GemmSpec
",1" // NumGemmKPrefetchStage
",64" // BlockSize
",64" // MPerBlock
",64" // NPerBlock
",32" // KPerBlock
",8" // AK1
",8" // BK1
",32" // MPerXDL
",32" // NPerXDL
",2" // MXdlPerWave
",2" // NXdlPerWave
",Seq(4,16,1)" // ABlockTransferThreadClusterLengths
",Seq(1,0,2)" // ABlockTransferThreadClusterArrangeOrder
",Seq(1,0,2)" // ABlockTransferSrcAccessOrder
",2" // ABlockTransferSrcVectorDim
",1" // ABlockTransferSrcScalarPerVector
",8" // ABlockTransferDstScalarPerVector_AK1
",1" // ABlockLdsExtraM
",Seq(4,16,1)" // BBlockTransferThreadClusterLengths
",Seq(1,0,2)" // BBlockTransferThreadClusterArrangeOrder
",Seq(1,0,2)" // BBlockTransferSrcAccessOrder
",2" // BBlockTransferSrcVectorDim
",1" // BBlockTransferSrcScalarPerVector
",8" // BBlockTransferDstScalarPerVector_BK1
",1" // BBlockLdsExtraN
",1" // CShuffleMXdlPerWavePerShuffle
",1" // CShuffleNXdlPerWavePerShuffle
",Seq(1,16,1,4)" // CDEBlockTransferClusterLengths
",1" // CDEBlockTransferScalarPerVector_NPerBlock
",fp16" // AComputeDataType
",fp16" // BComputeDataType
",Default" // LoopScheduler
",1>"; // NumGroupsToMerge
EXPECT_EQ(instance_str, expected_str);
}

View File

@@ -0,0 +1,103 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include <gtest/gtest.h>
#include <ck_tile/builder/reflect/instance_traits.hpp>
#include <ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_abd.hpp>
#include <ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_large_tensor_instance.hpp>
// Test GetInstanceString through base class pointer for large tensor variant
TEST(GetInstanceString, ReturnsStringForFwdGrpConvLargeTensorInstance)
{
// Use the template helper to get a working instance configuration
using InstanceTuple = ck::tensor_operation::device::instance::
device_grouped_conv_fwd_xdl_large_tensor_f16_instances<
2, // NDimSpatial
ck::tensor_operation::device::instance::GNHWC, // ALayout
ck::tensor_operation::device::instance::GKYXC, // BLayout
ck::tensor_operation::device::instance::Empty_Tuple, // DsLayout
ck::tensor_operation::device::instance::GNHWK, // ELayout
ck::tensor_operation::device::instance::ConvFwdDefault>; // ConvForwardSpecialization
// Get the first instance from the tuple
using DeviceInstance = typename std::tuple_element<0, InstanceTuple>::type;
// Define the base class type using DeviceGroupedConvFwdMultipleABD
using BaseClass = ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD<
2, // NDimSpatial
ck::tensor_operation::device::instance::GNHWC, // ALayout
ck::tensor_operation::device::instance::GKYXC, // BLayout
ck::tensor_operation::device::instance::Empty_Tuple, // DsLayout
ck::tensor_operation::device::instance::GNHWK, // ELayout
ck::half_t, // ADataType
ck::half_t, // BDataType
ck::Tuple<>, // DsDataType
ck::half_t, // EDataType
ck::tensor_operation::element_wise::PassThrough, // AElementwiseOperation
ck::tensor_operation::element_wise::PassThrough, // BElementwiseOperation
ck::tensor_operation::element_wise::PassThrough, // CDEElementwiseOperation
ck::half_t, // AComputeType
ck::half_t>; // BComputeType
// Create an instance of the derived class
DeviceInstance device_instance;
// Get a pointer to the base class
BaseClass* base_ptr = &device_instance;
// Call GetInstanceString through the base class pointer
std::string instance_str = base_ptr->GetInstanceString();
// Expected complete instance string based on the first instance from
// device_grouped_conv_fwd_xdl_large_tensor_f16_instances
std::string expected_str = "DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor"
"<2" // NDimSpatial
",GNHWC" // ALayout
",GKYXC" // BLayout
",EmptyTuple" // DsLayout
",GNHWK" // ELayout
",fp16" // ADataType
",fp16" // BDataType
",fp32" // AccDataType
",fp16" // CShuffleDataType
",EmptyTuple" // DsDataType
",fp16" // EDataType
",PassThrough" // AElementwiseOperation
",PassThrough" // BElementwiseOperation
",PassThrough" // CDEElementwiseOperation
",Default" // ConvForwardSpecialization
",MNKPadding" // GemmSpec
",1" // NumGemmKPrefetchStage
",64" // BlockSize
",64" // MPerBlock
",64" // NPerBlock
",32" // KPerBlock
",8" // AK1
",8" // BK1
",32" // MPerXDL
",32" // NPerXDL
",2" // MXdlPerWave
",2" // NXdlPerWave
",Seq(4,16,1)" // ABlockTransferThreadClusterLengths
",Seq(1,0,2)" // ABlockTransferThreadClusterArrangeOrder
",Seq(1,0,2)" // ABlockTransferSrcAccessOrder
",2" // ABlockTransferSrcVectorDim
",1" // ABlockTransferSrcScalarPerVector
",8" // ABlockTransferDstScalarPerVector_AK1
",1" // ABlockLdsExtraM
",Seq(4,16,1)" // BBlockTransferThreadClusterLengths
",Seq(1,0,2)" // BBlockTransferThreadClusterArrangeOrder
",Seq(1,0,2)" // BBlockTransferSrcAccessOrder
",2" // BBlockTransferSrcVectorDim
",1" // BBlockTransferSrcScalarPerVector
",8" // BBlockTransferDstScalarPerVector_BK1
",1" // BBlockLdsExtraN
",1" // CShuffleMXdlPerWavePerShuffle
",1" // CShuffleNXdlPerWavePerShuffle
",Seq(1,16,1,4)" // CDEBlockTransferClusterLengths
",1" // CDEBlockTransferScalarPerVector_NPerBlock
",fp16" // AComputeDataType
",fp16" // BComputeDataType
",Default>"; // LoopScheduler
EXPECT_EQ(instance_str, expected_str);
}

View File

@@ -6,8 +6,8 @@
#include <ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_abd.hpp>
#include <ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_comp_instance.hpp>
// Test GetInstanceString through base class pointer
TEST(GetInstanceStringTest, GetInstanceStringThroughBaseClass)
// Test GetInstanceString through base class pointer for V3 variant
TEST(GetInstanceString, ReturnsStringForFwdGrpConvV3Instance)
{
// Use the template helper to get a working instance configuration
using InstanceTuple =

View File

@@ -199,6 +199,14 @@ TEST(InstanceTraitsUtil, PipelineVersionNameReturnsCorrectStrings)
ElementsAre("v1", "v2", "v3", "v4", "v5"));
}
TEST(InstanceTraitsUtil, LoopSchedulerNameReturnsCorrectStrings)
{
using enum ck::LoopScheduler;
EXPECT_THAT(std::vector<std::string_view> names = {loop_scheduler_name(Default),
loop_scheduler_name(Interwave)},
ElementsAre("Default", "Interwave"));
}
TEST(InstanceTraitsUtil, TupleNameReturnsEmptyTupleForEmptyTuple)
{
EXPECT_EQ(tuple_name<ck::Tuple<>>(), "EmptyTuple");

View File

@@ -28,6 +28,9 @@
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
#include "ck/host_utility/io.hpp"
#ifdef CK_EXPERIMENTAL_BUILDER
#include "ck_tile/builder/reflect/instance_traits_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp"
#endif
namespace ck {
namespace tensor_operation {
@@ -2063,6 +2066,19 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
return str.str();
}
#ifdef CK_EXPERIMENTAL_BUILDER
std::string GetInstanceString() const override
{
static_assert(ck_tile::reflect::HasInstanceTraits<DeviceOp>,
"Specialization of instance_traits not found. Please check that a "
"specialization exists in file "
"ck_tile/builder/reflect/"
"instance_traits_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp "
"for the given template parameters.");
return ck_tile::reflect::instance_string<DeviceOp>();
}
#endif
size_t GetWorkSpaceSize(const BaseArgument* p_arg) const override
{
auto arg = dynamic_cast<const Argument*>(p_arg);

View File

@@ -24,6 +24,9 @@
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
#include "ck/host_utility/io.hpp"
#ifdef CK_EXPERIMENTAL_BUILDER
#include "ck_tile/builder/reflect/instance_traits_device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp"
#endif
namespace ck {
namespace tensor_operation {
@@ -1220,6 +1223,20 @@ struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor
return str.str();
}
#ifdef CK_EXPERIMENTAL_BUILDER
std::string GetInstanceString() const override
{
static_assert(
ck_tile::reflect::HasInstanceTraits<DeviceOp>,
"Specialization of instance_traits not found. Please check that a "
"specialization exists in file "
"ck_tile/builder/reflect/"
"instance_traits_device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp "
"for the given template parameters.");
return ck_tile::reflect::instance_string<DeviceOp>();
}
#endif
};
} // namespace device

View File

@@ -25,7 +25,7 @@
* (3) number of iterations to cover the entire Y axis.
* The raked here represents how data is partitioned across different processing granularity.
* It represents howe we are going to access the data in thread, warp, or blocked in contiguous
* It represents how we are going to access the data in thread, warp, or blocked in contiguous
region.
* From below, the qualifier for 'raked' is the part of warp/thread hierarchy
* in the split of Y tile dimension where the iteration happens,
@@ -101,7 +101,7 @@ enum struct tile_distribution_pattern
* @brief Block raked pattern - aka linear.
*
*/
block_raked,
block_raked
};
struct tile_distribution_encoding_pattern
@@ -144,7 +144,6 @@ struct tile_distribution_encoding_pattern_2d<BlockSize,
NumWaveGroups>
: public tile_distribution_encoding_pattern
{
// TODO: make pattern where below condition does not need to hold - GGemmMultiDSplitk!
static_assert(XPerTile % VecSize == 0, "XPerTile must be a multiple of VecSize!");
static constexpr index_t warp_size = get_warp_size();

View File

@@ -19,7 +19,7 @@
namespace ck_tile {
/** @brief Maximum number of error values to display when checking errors */
constexpr int ERROR_DETAIL_LIMIT = 5;
constexpr int ERROR_DETAIL_LIMIT = 128;
/** @brief 8-bit floating point type */
using F8 = ck_tile::fp8_t;

View File

@@ -26,7 +26,8 @@ struct GroupedConvBwdWeightKernelArgs
GroupedConvTraitsType_::ConvSpecialization,
GroupedConvTraitsType_::VectorSizeA,
GroupedConvTraitsType_::VectorSizeB,
GroupedConvTraitsType_::VectorSizeC>;
GroupedConvTraitsType_::VectorSizeC,
GroupedConvTraitsType_::NumGroupsToMerge>;
static constexpr index_t NumDTensor = GroupedConvTraitsType_::NumDTensor;
template <
@@ -84,9 +85,11 @@ struct GroupedConvBwdWeightKernelArgs
b_grid_desc_k_n = grid_descs.at(number<1>{});
c_grid_desc_m_n = grid_descs.at(number<2>{});
group_stride_a = args.K_; // A: Out NWGK
group_stride_b = args.C_; // B: In NWGC
group_stride_c = args.K_ * args.C_ * // C: Wei GKXC
NumGroupsPerBatch = GroupedConvTraitsType_::NumGroupsToMerge;
group_stride_a = args.K_ * NumGroupsPerBatch; // A: Out NWGK
group_stride_b = args.C_ * NumGroupsPerBatch; // B: In NWGC
group_stride_c = args.K_ * args.C_ // C: Wei GKXC
* NumGroupsPerBatch *
std::accumulate(args.filter_spatial_lengths_.begin(),
args.filter_spatial_lengths_.end(),
1,
@@ -95,7 +98,14 @@ struct GroupedConvBwdWeightKernelArgs
GemmM = a_grid_desc_k_m.get_length(number<1>{});
GemmN = b_grid_desc_k_n.get_length(number<1>{});
GemmK = a_grid_desc_k_m.get_length(number<0>{});
GemmBatch = args.G_;
GemmBatch = integer_divide_ceil(args.G_, NumGroupsPerBatch);
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
{
std::cout << "GemmM: " << GemmM << ", GemmN: " << GemmN << ", GemmK: " << GemmK
<< ", GemmBatch: " << GemmBatch
<< ", NumGroupsPerBatch: " << NumGroupsPerBatch << std::endl;
}
}
template <
@@ -160,9 +170,11 @@ struct GroupedConvBwdWeightKernelArgs
b_grid_desc_k_n = grid_descs.at(number<1>{});
c_grid_desc_m_n = grid_descs.at(number<2>{});
group_stride_a = args.K_; // A: Out NHWGK
group_stride_b = args.C_; // B: In NHWGC
group_stride_c = args.K_ * args.C_ * // C: Wei GKYXC
NumGroupsPerBatch = GroupedConvTraitsType_::NumGroupsToMerge;
group_stride_a = args.K_ * NumGroupsPerBatch; // A: Out NHWGK
group_stride_b = args.C_ * NumGroupsPerBatch; // B: In NHWGC
group_stride_c = args.K_ * args.C_ // C: Wei GKYXC
* NumGroupsPerBatch *
std::accumulate(args.filter_spatial_lengths_.begin(),
args.filter_spatial_lengths_.end(),
1,
@@ -171,7 +183,14 @@ struct GroupedConvBwdWeightKernelArgs
GemmM = a_grid_desc_k_m.get_length(number<1>{});
GemmN = b_grid_desc_k_n.get_length(number<1>{});
GemmK = a_grid_desc_k_m.get_length(number<0>{});
GemmBatch = args.G_;
GemmBatch = integer_divide_ceil(args.G_, NumGroupsPerBatch);
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
{
std::cout << "GemmM: " << GemmM << ", GemmN: " << GemmN << ", GemmK: " << GemmK
<< ", GemmBatch: " << GemmBatch
<< ", NumGroupsPerBatch: " << NumGroupsPerBatch << std::endl;
}
}
template <
@@ -243,9 +262,11 @@ struct GroupedConvBwdWeightKernelArgs
b_grid_desc_k_n = grid_descs.at(number<1>{});
c_grid_desc_m_n = grid_descs.at(number<2>{});
group_stride_a = args.K_; // A: Out NDHWGK
group_stride_b = args.C_; // B: In NDHWGC
group_stride_c = args.K_ * args.C_ * // C: wEI GKZYXC
NumGroupsPerBatch = GroupedConvTraitsType_::NumGroupsToMerge;
group_stride_a = args.K_ * NumGroupsPerBatch; // A: Out NDHWGK
group_stride_b = args.C_ * NumGroupsPerBatch; // B: In NDHWGC
group_stride_c = args.K_ * args.C_ // C: Wei GKZYXC
* NumGroupsPerBatch *
std::accumulate(args.filter_spatial_lengths_.begin(),
args.filter_spatial_lengths_.end(),
1,
@@ -254,7 +275,14 @@ struct GroupedConvBwdWeightKernelArgs
GemmM = a_grid_desc_k_m.get_length(number<1>{});
GemmN = b_grid_desc_k_n.get_length(number<1>{});
GemmK = a_grid_desc_k_m.get_length(number<0>{});
GemmBatch = args.G_;
GemmBatch = integer_divide_ceil(args.G_, NumGroupsPerBatch);
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
{
std::cout << "GemmM: " << GemmM << ", GemmN: " << GemmN << ", GemmK: " << GemmK
<< ", GemmBatch: " << GemmBatch
<< ", NumGroupsPerBatch: " << NumGroupsPerBatch << std::endl;
}
}
using ABCGridDescs = remove_cvref_t<
@@ -279,6 +307,7 @@ struct GroupedConvBwdWeightKernelArgs
index_t GemmN;
index_t GemmK;
index_t GemmBatch;
index_t NumGroupsPerBatch;
const void* out_ptr;
const void* in_ptr;
@@ -317,10 +346,9 @@ struct GroupedConvBwdWeightKernelArgs
/// the policy is responsible for definition of all necessary data layouts and thread's
/// work distribution.
///
/// @tparam GroupedConvTraitsType_ The type of class providing traits for grouped convolution.
/// @tparam GroupedConvTraitsType_ The type of class providing traits for grouped convolution.
/// @tparam TilePartitioner_ The type of class providing mapping of workgroup index into
/// the
/// output data tile to be calculated. It determines the
/// the output data tile to be calculated. It determines the
/// workgroup to data relationship (or in other words - which
/// data would be processed and calculated by which workgroup).
/// @tparam GemmPipeline_ The type of class which provides the core part of matrix
@@ -382,8 +410,12 @@ struct GroupedConvolutionBackwardWeightKernel
[[nodiscard]] CK_TILE_HOST static const std::string GetName()
{
constexpr auto NumGroupsToMerge = GroupedConvTraitsType_::NumGroupsToMerge;
// clang-format off
return concat('_', "grouped_convolution_backward_weight", gemm_prec_str<InDataType, WeiDataType>, GemmPipeline::GetName());
if (NumGroupsToMerge > 1)
return concat('_', "grouped_convolution_backward_weight", gemm_prec_str<InDataType, WeiDataType>, GemmPipeline::GetName(), "merge", NumGroupsToMerge);
else
return concat('_', "grouped_convolution_backward_weight", gemm_prec_str<InDataType, WeiDataType>, GemmPipeline::GetName());
// clang-format on
}
@@ -402,6 +434,12 @@ struct GroupedConvolutionBackwardWeightKernel
CK_TILE_HOST static constexpr GroupedConvBwdWeightKernelArgsSpecialized
MakeKernelArgs(const GroupedConvBwdWeightHostArgs& hostArgs)
{
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
{
std::cout << "MPerBlock: " << number<TilePartitioner::MPerBlock>{} << std::endl;
std::cout << "NPerBlock: " << number<TilePartitioner::NPerBlock>{} << std::endl;
std::cout << "KPerBlock: " << number<TilePartitioner::KPerBlock>{} << std::endl;
}
return GroupedConvBwdWeightKernelArgsSpecialized(hostArgs);
}
@@ -442,11 +480,14 @@ struct GroupedConvolutionBackwardWeightKernel
{
return [&]() {
if(kargs.k_batch > 1)
hipGetErrorString(hipMemsetAsync(kargs.wei_ptr,
0,
kargs.GemmBatch * kargs.GemmM * kargs.GemmN *
sizeof(WeiDataType),
s.stream_id_));
{
// Total number of convolution groups (ConvG) = GemmBatch * NumGroupsPerBatch
// since we require that ConvG % NumGroupsPerBatch == 0.
const auto wei_size =
kargs.GemmBatch * kargs.GemmM * kargs.GemmN * kargs.NumGroupsPerBatch;
hipGetErrorString(
hipMemsetAsync(kargs.wei_ptr, 0, wei_size * sizeof(WeiDataType), s.stream_id_));
}
};
}
@@ -527,7 +568,8 @@ struct GroupedConvolutionBackwardWeightKernel
// Check access per C
if(ConvC % GroupedConvTraitsType_::VectorSizeB != 0)
{
CK_TILE_ERROR("Conv C is not a multiple of vector load size for input image!");
CK_TILE_ERROR("Conv C is not a multiple of vector load size for "
"input image!");
return false;
}
}
@@ -559,7 +601,8 @@ struct GroupedConvolutionBackwardWeightKernel
{
if(ConvK % GroupedConvTraitsType_::VectorSizeA != 0)
{
CK_TILE_ERROR("Conv K is not a multiple of vector store size for output image!");
CK_TILE_ERROR("Conv K is not a multiple of vector store size "
"for output image!");
return false;
}
}
@@ -569,6 +612,18 @@ struct GroupedConvolutionBackwardWeightKernel
return false;
}
if constexpr(GroupedConvTraitsType_::NumGroupsToMerge > 1)
{
const index_t ConvG = kargs.wei_g_k_c_xs_lengths[number<0>{}];
if(ConvG % GroupedConvTraitsType_::NumGroupsToMerge != 0)
{
CK_TILE_ERROR("ConvG must be a multiple of NumGroupsToMerge!");
return false;
}
// TODO: Should we also check that GemmM <= MPerBlock and GemmN <= NPerBlock?
}
return true;
}
@@ -654,6 +709,16 @@ struct GroupedConvolutionBackwardWeightKernel
return make_tuple(a_pad_view, b_pad_view, ds_pad_view, c_pad_view);
}
/**
* @brief Create views to the data that each workgroup will process.
*
* @param views padded views of A, B, D and C tensors
* @param i_m block m-index
* @param i_n block n-index
* @param i_k block k-index
*
* @return tuple of tile windows for A, B, D and C tensors
*/
template <typename PadView>
CK_TILE_DEVICE static auto MakeGemmTileWindows(const PadView& views,
const index_t i_m,
@@ -818,7 +883,6 @@ struct GroupedConvolutionBackwardWeightKernel
const InDataType* b_ptr = static_cast<const InDataType*>(kargs.in_ptr) + group_offset_b;
WeiDataType* c_ptr = static_cast<WeiDataType*>(kargs.wei_ptr) + group_offset_c;
// allocate LDS
__shared__ char smem_ptr_0[GetSmemSize()];
if constexpr(GemmPipeline::DoubleSmemBuffer == true)

View File

@@ -29,6 +29,7 @@ struct GroupedConvFwdKernelArgs
GroupedConvTraitsType_::VectorSizeA,
GroupedConvTraitsType_::VectorSizeB,
GroupedConvTraitsType_::VectorSizeC,
GroupedConvTraitsType_::NumGroupsToMerge,
true>; // Split N enabled
using CDElementwise = typename GroupedConvTraitsType_::CDElementwise;
static constexpr index_t NumDTensor = GroupedConvTraitsType_::NumDTensor;

View File

@@ -59,10 +59,11 @@ template <index_t NDimSpatial_,
typename WeiLayout_,
typename DsLayout_,
typename OutLayout_,
index_t VectorSizeA_ = 1,
index_t VectorSizeB_ = 1,
index_t VectorSizeC_ = 1,
typename CDElementwise_ = PassThrough>
index_t VectorSizeA_ = 1,
index_t VectorSizeB_ = 1,
index_t VectorSizeC_ = 1,
index_t NumGroupsToMerge_ = 1,
typename CDElementwise_ = PassThrough>
struct GroupedConvTraits
{
private:
@@ -73,7 +74,7 @@ struct GroupedConvTraits
}
public:
static constexpr index_t NumGroupsToMerge = 1;
static constexpr index_t NumGroupsToMerge = NumGroupsToMerge_;
static constexpr index_t NDimSpatial = NDimSpatial_;
static constexpr ConvolutionSpecialization ConvSpecialization = ConvSpecialization_;
using InLayout = InLayout_;

View File

@@ -13,10 +13,10 @@ template <index_t NDimSpatial,
index_t VectorSizeA,
index_t VectorSizeB,
index_t VectorSizeC,
index_t NumGroupsToMerge = 1,
bool SplitN = false,
typename ADataType = float,
typename CDataType = float,
index_t NumGroupsToMerge = 1,
typename IndexType = index_t>
struct TransformConvBwdWeightToGemm
{
@@ -125,8 +125,7 @@ struct TransformConvBwdWeightToGemm
InLeftPadW_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.InLeftPadW_)},
InRightPadD_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.InRightPadD_)},
InRightPadH_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.InRightPadH_)},
InRightPadW_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.InRightPadW_)},
ZYX_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.ZYX_)}
InRightPadW_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.InRightPadW_)}
{
}
@@ -164,8 +163,7 @@ struct TransformConvBwdWeightToGemm
InLeftPadW_{input_left_pads[I0]},
InRightPadD_{I0},
InRightPadH_{I0},
InRightPadW_{input_right_pads[I0]},
ZYX_{X_}
InRightPadW_{input_right_pads[I0]}
{
static_assert(std::is_same_v<ConvSpatialDimsType, std::array<IndexType, NDimSpatial>> ||
std::is_same_v<ConvSpatialDimsType, ck_tile::array<IndexType, NDimSpatial>>);
@@ -219,8 +217,7 @@ struct TransformConvBwdWeightToGemm
InLeftPadW_{input_left_pads[I1]},
InRightPadD_{I0},
InRightPadH_{input_right_pads[I0]},
InRightPadW_{input_right_pads[I1]},
ZYX_{Y_ * X_}
InRightPadW_{input_right_pads[I1]}
{
static_assert(std::is_same_v<ConvSpatialDimsType, std::array<IndexType, NDimSpatial>> ||
std::is_same_v<ConvSpatialDimsType, ck_tile::array<IndexType, NDimSpatial>>);
@@ -274,8 +271,7 @@ struct TransformConvBwdWeightToGemm
InLeftPadW_{input_left_pads[I2]},
InRightPadD_{input_right_pads[I0]},
InRightPadH_{input_right_pads[I1]},
InRightPadW_{input_right_pads[I2]},
ZYX_{Z_ * Y_ * X_}
InRightPadW_{input_right_pads[I2]}
{
static_assert(std::is_same_v<ConvSpatialDimsType, std::array<IndexType, NDimSpatial>> ||
std::is_same_v<ConvSpatialDimsType, ck_tile::array<IndexType, NDimSpatial>>);
@@ -420,11 +416,21 @@ struct TransformConvBwdWeightToGemm
const index_t NDoHoWoStride = G_ * K_;
constexpr auto KStride = I1;
// TODO Add support for NumGroupsToMerge > 1
return make_naive_tensor_descriptor(make_tuple(K_, N_ * Wo_),
make_tuple(KStride, NDoHoWoStride),
number<VectorSizeA>{},
I1);
if constexpr(NumGroupsToMerge > 1)
{
const index_t BatchStride = K_;
return make_naive_tensor_descriptor(make_tuple(K_, NumGroupsToMerge, N_ * Wo_),
make_tuple(KStride, BatchStride, NDoHoWoStride),
number<VectorSizeA>{},
I1);
}
else
{
return make_naive_tensor_descriptor(make_tuple(K_, N_ * Wo_),
make_tuple(KStride, NDoHoWoStride),
number<VectorSizeA>{},
I1);
}
}
template <index_t NDim = NDimSpatial, typename std::enable_if<NDim == 1, bool>::type = false>
@@ -435,11 +441,22 @@ struct TransformConvBwdWeightToGemm
const index_t WiStride = G_ * C_;
constexpr auto CStride = I1;
// TODO Add support for NumGroupsToMerge > 1
return make_naive_tensor_descriptor(make_tuple(N_, Wi_, C_),
make_tuple(NStride, WiStride, CStride),
number<VectorSizeB>{},
I1);
if constexpr(NumGroupsToMerge > 1)
{
const auto BatchStride = C_;
return make_naive_tensor_descriptor(make_tuple(N_, Wi_, NumGroupsToMerge, C_),
make_tuple(NStride, WiStride, BatchStride, CStride),
number<VectorSizeB>{},
I1);
}
else
{
return make_naive_tensor_descriptor(make_tuple(N_, Wi_, C_),
make_tuple(NStride, WiStride, CStride),
number<VectorSizeB>{},
I1);
}
}
template <index_t NDim = NDimSpatial, typename std::enable_if<NDim == 1, bool>::type = false>
@@ -449,9 +466,56 @@ struct TransformConvBwdWeightToGemm
const index_t KStride = X_ * C_;
constexpr auto CXStride = I1;
// TODO Add support for NumGroupsToMerge > 1
return make_naive_tensor_descriptor(
make_tuple(K_, X_ * C_), make_tuple(KStride, CXStride), number<VectorSizeC>{}, I1);
if constexpr(NumGroupsToMerge > 1)
{
const index_t XStride = C_;
const index_t BatchStride = K_ * X_ * C_;
// Add NumGroupsToMerge for Batch+M dimension and, 1 as a placeholder
// for Batch+N dimension
const auto desc = make_naive_tensor_descriptor(
make_tuple(NumGroupsToMerge, K_, X_, 1, C_),
make_tuple(BatchStride, KStride, XStride, BatchStride, CXStride),
number<VectorSizeC>{},
I1);
// Pad 1 to NumGroupsToMerge
const auto padded_desc = transform_tensor_descriptor(
desc,
make_tuple(make_pass_through_transform(NumGroupsToMerge),
make_pass_through_transform(K_),
make_pass_through_transform(X_),
make_pad_transform(1, 0, NumGroupsToMerge - 1),
make_pass_through_transform(C_)),
make_tuple(
sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}, sequence<4>{}),
make_tuple(
sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}, sequence<4>{}));
// We need only matrices from diagonal. Xor returns 0 for the same
// values. So if matrices is not on diagonal then it will be stored in padding.
// To avoid use of modulo after xor we assume that NumBatch to merge is power of 2.
static_assert(NumGroupsToMerge == 1 || NumGroupsToMerge == 2 || NumGroupsToMerge == 4 ||
NumGroupsToMerge == 8 || NumGroupsToMerge == 16 ||
NumGroupsToMerge == 32 || NumGroupsToMerge == 64);
const auto unmerged_padded_desc = transform_tensor_descriptor(
padded_desc,
make_tuple(make_xor_transform(make_tuple(NumGroupsToMerge, NumGroupsToMerge)),
make_pass_through_transform(K_),
make_pass_through_transform(X_),
make_pass_through_transform(C_)),
make_tuple(sequence<0, 3>{}, sequence<1>{}, sequence<2>{}, sequence<4>{}),
make_tuple(sequence<0, 3>{}, sequence<1>{}, sequence<2>{}, sequence<4>{}));
// Merge To M, N
return transform_tensor_descriptor(
unmerged_padded_desc,
make_tuple(make_merge_transform(make_tuple(NumGroupsToMerge, K_)),
make_merge_transform(make_tuple(X_, NumGroupsToMerge, C_))),
make_tuple(sequence<0, 1>{}, sequence<2, 3, 4>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
}
else
{
return make_naive_tensor_descriptor(
make_tuple(K_, X_ * C_), make_tuple(KStride, CXStride), number<VectorSizeC>{}, I1);
}
}
template <index_t NDim = NDimSpatial, typename std::enable_if<NDim == 2, bool>::type = false>
@@ -461,11 +525,22 @@ struct TransformConvBwdWeightToGemm
const index_t NDoHoWoStride = G_ * K_;
constexpr auto KStride = I1;
// TODO Add support for NumGroupsToMerge > 1
return make_naive_tensor_descriptor(make_tuple(N_ * Ho_ * Wo_, K_), // K_M
make_tuple(NDoHoWoStride, KStride),
number<VectorSizeA>{},
I1);
if constexpr(NumGroupsToMerge > 1)
{
const index_t BatchStride = K_;
return make_naive_tensor_descriptor(
make_tuple(N_ * Ho_ * Wo_, NumGroupsToMerge, K_), // K_Gm_M
make_tuple(NDoHoWoStride, BatchStride, KStride),
number<VectorSizeA>{},
I1);
}
else
{
return make_naive_tensor_descriptor(make_tuple(N_ * Ho_ * Wo_, K_), // K_M
make_tuple(NDoHoWoStride, KStride),
number<VectorSizeA>{},
I1);
}
}
template <index_t NDim = NDimSpatial, typename std::enable_if<NDim == 2, bool>::type = false>
@@ -477,11 +552,22 @@ struct TransformConvBwdWeightToGemm
const index_t WiStride = G_ * C_;
constexpr auto CStride = I1;
// TODO Add support for NumGroupsToMerge > 1
return make_naive_tensor_descriptor(make_tuple(N_, Hi_, Wi_, C_), // K_N
make_tuple(NStride, HiStride, WiStride, CStride),
number<VectorSizeB>{},
I1);
if constexpr(NumGroupsToMerge > 1)
{
const auto BatchStride = C_;
return make_naive_tensor_descriptor(
make_tuple(N_, Hi_, Wi_, NumGroupsToMerge, C_), // K_Gm_N
make_tuple(NStride, HiStride, WiStride, BatchStride, CStride),
number<VectorSizeB>{},
I1);
}
else
{
return make_naive_tensor_descriptor(make_tuple(N_, Hi_, Wi_, C_), // K_N
make_tuple(NStride, HiStride, WiStride, CStride),
number<VectorSizeB>{},
I1);
}
}
template <index_t NDim = NDimSpatial, typename std::enable_if<NDim == 2, bool>::type = false>
@@ -491,9 +577,58 @@ struct TransformConvBwdWeightToGemm
const index_t KStride = Y_ * X_ * C_;
constexpr auto CStride = I1;
// TODO Add support for NumGroupsToMerge > 1
return make_naive_tensor_descriptor(
make_tuple(K_, Y_ * X_ * C_), make_tuple(KStride, CStride), number<VectorSizeC>{}, I1);
if constexpr(NumGroupsToMerge > 1)
{
const index_t YXStride = C_;
const index_t BatchStride = K_ * Y_ * X_ * C_;
// Add NumGroupsToMerge for Batch+M dimension and, 1 as a placeholder
// for Batch+N dimension
const auto desc = make_naive_tensor_descriptor(
make_tuple(NumGroupsToMerge, K_, Y_ * X_, 1, C_),
make_tuple(BatchStride, KStride, YXStride, BatchStride, CStride),
number<VectorSizeC>{},
I1);
// Pad 1 to NumGroupsToMerge
const auto padded_desc = transform_tensor_descriptor(
desc,
make_tuple(make_pass_through_transform(NumGroupsToMerge),
make_pass_through_transform(K_),
make_pass_through_transform(Y_ * X_),
make_pad_transform(1, 0, NumGroupsToMerge - 1),
make_pass_through_transform(C_)),
make_tuple(
sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}, sequence<4>{}),
make_tuple(
sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}, sequence<4>{}));
// We need only matrices from diagonal. Xor returns 0 for the same
// values. So if matrices is not on diagonal then it will be stored in padding.
// To avoid use of modulo after xor we assume that NumBatch to merge is power of 2.
static_assert(NumGroupsToMerge == 1 || NumGroupsToMerge == 2 || NumGroupsToMerge == 4 ||
NumGroupsToMerge == 8 || NumGroupsToMerge == 16 ||
NumGroupsToMerge == 32 || NumGroupsToMerge == 64);
const auto unmerged_padded_desc = transform_tensor_descriptor(
padded_desc,
make_tuple(make_xor_transform(make_tuple(NumGroupsToMerge, NumGroupsToMerge)),
make_pass_through_transform(K_),
make_pass_through_transform(Y_ * X_),
make_pass_through_transform(C_)),
make_tuple(sequence<0, 3>{}, sequence<1>{}, sequence<2>{}, sequence<4>{}),
make_tuple(sequence<0, 3>{}, sequence<1>{}, sequence<2>{}, sequence<4>{}));
// Merge To M, N
return transform_tensor_descriptor(
unmerged_padded_desc,
make_tuple(make_merge_transform(make_tuple(NumGroupsToMerge, K_)),
make_merge_transform(make_tuple(Y_ * X_, NumGroupsToMerge, C_))),
make_tuple(sequence<0, 1>{}, sequence<2, 3, 4>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
}
else
{
return make_naive_tensor_descriptor(make_tuple(K_, Y_ * X_ * C_),
make_tuple(KStride, CStride),
number<VectorSizeC>{},
I1);
}
}
template <index_t NDim = NDimSpatial, typename std::enable_if<NDim == 3, bool>::type = false>
@@ -503,11 +638,22 @@ struct TransformConvBwdWeightToGemm
const index_t NDoHoWoStride = G_ * K_;
constexpr auto KStride = I1;
// TODO Add support for NumGroupsToMerge > 1
return make_naive_tensor_descriptor(make_tuple(N_ * Do_ * Ho_ * Wo_, K_),
make_tuple(NDoHoWoStride, KStride),
number<VectorSizeA>{},
I1);
if constexpr(NumGroupsToMerge > 1)
{
const auto BatchStride = K_;
return make_naive_tensor_descriptor(
make_tuple(N_ * Do_ * Ho_ * Wo_, NumGroupsToMerge, K_),
make_tuple(NDoHoWoStride, BatchStride, KStride),
number<VectorSizeA>{},
I1);
}
else
{
return make_naive_tensor_descriptor(make_tuple(N_ * Do_ * Ho_ * Wo_, K_),
make_tuple(NDoHoWoStride, KStride),
number<VectorSizeA>{},
I1);
}
}
template <index_t NDim = NDimSpatial, typename std::enable_if<NDim == 3, bool>::type = false>
@@ -519,26 +665,84 @@ struct TransformConvBwdWeightToGemm
const index_t WiStride = G_ * C_;
constexpr auto CStride = I1;
// TODO Add support for NumGroupsToMerge > 1
return make_naive_tensor_descriptor(
make_tuple(N_, Di_, Hi_, Wi_, C_),
make_tuple(NStride, DiStride, HiStride, WiStride, CStride),
number<VectorSizeB>{},
I1);
if constexpr(NumGroupsToMerge > 1)
{
const index_t BatchStride = C_;
return make_naive_tensor_descriptor(
make_tuple(N_, Di_, Hi_, Wi_, NumGroupsToMerge, C_),
make_tuple(NStride, DiStride, HiStride, WiStride, BatchStride, CStride),
number<VectorSizeB>{},
I1);
}
else
{
return make_naive_tensor_descriptor(
make_tuple(N_, Di_, Hi_, Wi_, C_),
make_tuple(NStride, DiStride, HiStride, WiStride, CStride),
number<VectorSizeB>{},
I1);
}
}
template <index_t NDim = NDimSpatial, typename std::enable_if<NDim == 3, bool>::type = false>
CK_TILE_HOST auto make_wei_grid_desc() const
{
// KZYXC
// GKZYXC
const index_t KStride = Z_ * Y_ * X_ * C_;
constexpr auto CStride = I1;
// TODO Add support for NumGroupsToMerge > 1
return make_naive_tensor_descriptor(make_tuple(K_, Z_ * Y_ * X_ * C_),
make_tuple(KStride, CStride),
number<VectorSizeC>{},
I1);
if constexpr(NumGroupsToMerge > 1)
{
const index_t ZYXStride = C_;
const index_t BatchStride = K_ * Z_ * Y_ * X_ * C_;
// Add NumGroupsToMerge for Batch+M dimension and, 1 as a placeholder
// for Batch+N dimension
const auto desc = make_naive_tensor_descriptor(
make_tuple(NumGroupsToMerge, K_, Z_ * Y_ * X_, 1, C_),
make_tuple(BatchStride, KStride, ZYXStride, BatchStride, CStride),
number<VectorSizeC>{},
I1);
// Pad 1 to NumGroupsToMerge
const auto padded_desc = transform_tensor_descriptor(
desc,
make_tuple(make_pass_through_transform(NumGroupsToMerge),
make_pass_through_transform(K_),
make_pass_through_transform(Z_ * Y_ * X_),
make_pad_transform(1, 0, NumGroupsToMerge - 1),
make_pass_through_transform(C_)),
make_tuple(
sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}, sequence<4>{}),
make_tuple(
sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}, sequence<4>{}));
// We need only matrices from diagonal. Xor returns 0 for the same
// values. So if matrices is not on diagonal then it will be stored in padding.
// To avoid use of modulo after xor we assume that NumBatch to merge is power of 2.
static_assert(NumGroupsToMerge == 1 || NumGroupsToMerge == 2 || NumGroupsToMerge == 4 ||
NumGroupsToMerge == 8 || NumGroupsToMerge == 16 ||
NumGroupsToMerge == 32 || NumGroupsToMerge == 64);
const auto unmerged_padded_desc = transform_tensor_descriptor(
padded_desc,
make_tuple(make_xor_transform(make_tuple(NumGroupsToMerge, NumGroupsToMerge)),
make_pass_through_transform(K_),
make_pass_through_transform(Z_ * Y_ * X_),
make_pass_through_transform(C_)),
make_tuple(sequence<0, 3>{}, sequence<1>{}, sequence<2>{}, sequence<4>{}),
make_tuple(sequence<0, 3>{}, sequence<1>{}, sequence<2>{}, sequence<4>{}));
// Merge To M, N
return transform_tensor_descriptor(
unmerged_padded_desc,
make_tuple(make_merge_transform(make_tuple(NumGroupsToMerge, K_)),
make_merge_transform(make_tuple(Z_ * Y_ * X_, NumGroupsToMerge, C_))),
make_tuple(sequence<0, 1>{}, sequence<2, 3, 4>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
}
else
{
return make_naive_tensor_descriptor(make_tuple(K_, Z_ * Y_ * X_ * C_),
make_tuple(KStride, CStride),
number<VectorSizeC>{},
I1);
}
}
// TODO: implement ck_tile::tensor_layout::convolution that describe packed/strided dimemsion as
@@ -552,31 +756,84 @@ struct TransformConvBwdWeightToGemm
const auto wei_grid_desc = make_wei_grid_desc<NDimSpatial>();
// B: input tensor comes in K_N
const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor(
in_grid_desc,
make_tuple(make_pass_through_transform(N_),
make_pad_transform(Wi_, InLeftPadW_, InRightPadW_),
make_pass_through_transform(C_)),
make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}),
make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}));
if constexpr(NumGroupsToMerge > 1)
{
// Output tensor transformation
// [0, 1, 2] -> [0, 1]
// [(N*Wo), Gm, K] -> [(N*Wo), (Gm*K)]
const auto out_gemm_k_gemm_m_grid_desc = transform_tensor_descriptor(
out_grid_desc,
make_tuple(make_pass_through_transform(N_ * Wo_),
make_merge_transform(make_tuple(NumGroupsToMerge, K_))),
make_tuple(sequence<0>{}, sequence<1, 2>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
const auto in_n_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor(
in_n_hip_wip_c_grid_desc,
make_tuple(
make_pass_through_transform(N_),
make_embed_transform(make_tuple(X_, Wo_), make_tuple(ConvDilationW_, ConvStrideW_)),
make_pass_through_transform(C_)),
make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}),
make_tuple(sequence<0>{}, sequence<1, 2>{}, sequence<3>{}));
// Input tensor transformation, part 1.
// [N, Wi, Gm, C] -> [N, (Wi + InLeftPadW + InRightPadW), Gm, C] = [N, Wip, Gm, C]
const auto in_n_wip_gm_c_grid_desc = transform_tensor_descriptor(
in_grid_desc,
make_tuple(make_pass_through_transform(N_),
make_pad_transform(Wi_, InLeftPadW_, InRightPadW_),
make_pass_through_transform(NumGroupsToMerge),
make_pass_through_transform(C_)),
make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}),
make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}));
const auto in_gemmn_gemmktotal_grid_desc =
transform_tensor_descriptor(in_n_y_ho_x_wo_c_grid_desc,
make_tuple(make_merge_transform(make_tuple(X_, C_)),
make_merge_transform(make_tuple(N_, Wo_))),
make_tuple(sequence<1, 3>{}, sequence<0, 2>{}),
make_tuple(sequence<1>{}, sequence<0>{}));
// Input tensor transformation, part 2.
// [N, Wip, Gm, C] -> [N, X, Wo, Gm, C]
const auto in_n_x_wo_gm_c_grid_desc = transform_tensor_descriptor(
in_n_wip_gm_c_grid_desc,
make_tuple(make_pass_through_transform(N_),
make_embed_transform(make_tuple(X_, Wo_),
make_tuple(ConvDilationW_, ConvStrideW_)),
make_pass_through_transform(NumGroupsToMerge),
make_pass_through_transform(C_)),
make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}),
make_tuple(sequence<0>{}, sequence<1, 2>{}, sequence<3>{}, sequence<4>{}));
return make_tuple(out_grid_desc, in_gemmn_gemmktotal_grid_desc, wei_grid_desc);
// Input tensor transformation, part 3.
// [0, 1, 2, 3, 4] -> [0, 1]
// [N, X, Wo, Gm, C] -> [(N*Wo), (Gm*X*C)]
const auto in_gemm_n_gemm_k_grid_desc = transform_tensor_descriptor(
in_n_x_wo_gm_c_grid_desc,
make_tuple(make_merge_transform(make_tuple(X_, NumGroupsToMerge, C_)),
make_merge_transform(make_tuple(N_, Wo_))),
make_tuple(sequence<1, 3, 4>{}, sequence<0, 2>{}),
make_tuple(sequence<1>{}, sequence<0>{}));
return make_tuple(
out_gemm_k_gemm_m_grid_desc, in_gemm_n_gemm_k_grid_desc, wei_grid_desc);
}
else
{
// [N, Wi, C] -> [N, (Wi + InLeftPadW + InRightPadW), C] = [N, Wip, C]
const auto in_n_wip_c_grid_desc = transform_tensor_descriptor(
in_grid_desc,
make_tuple(make_pass_through_transform(N_),
make_pad_transform(Wi_, InLeftPadW_, InRightPadW_),
make_pass_through_transform(C_)),
make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}),
make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}));
// [N, Wip, C] -> [N, X, Wo, C]
const auto in_n_x_wo_c_grid_desc = transform_tensor_descriptor(
in_n_wip_c_grid_desc,
make_tuple(make_pass_through_transform(N_),
make_embed_transform(make_tuple(X_, Wo_),
make_tuple(ConvDilationW_, ConvStrideW_)),
make_pass_through_transform(C_)),
make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}),
make_tuple(sequence<0>{}, sequence<1, 2>{}, sequence<3>{}));
const auto in_gemmn_gemmktotal_grid_desc =
transform_tensor_descriptor(in_n_x_wo_c_grid_desc,
make_tuple(make_merge_transform(make_tuple(X_, C_)),
make_merge_transform(make_tuple(N_, Wo_))),
make_tuple(sequence<1, 3>{}, sequence<0, 2>{}),
make_tuple(sequence<1>{}, sequence<0>{}));
return make_tuple(out_grid_desc, in_gemmn_gemmktotal_grid_desc, wei_grid_desc);
}
}
template <index_t NDim = NDimSpatial, typename std::enable_if<NDim == 2, bool>::type = false>
@@ -587,33 +844,95 @@ struct TransformConvBwdWeightToGemm
const auto wei_grid_desc = make_wei_grid_desc<NDimSpatial>();
// B: input tensor comes in K_N
const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor(
in_grid_desc,
make_tuple(make_pass_through_transform(N_),
make_pad_transform(Hi_, InLeftPadH_, InRightPadH_),
make_pad_transform(Wi_, InLeftPadW_, InRightPadW_),
make_pass_through_transform(C_)),
make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}),
make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}));
if constexpr(NumGroupsToMerge > 1)
{
// Output tensor transformation
// [0, 1, 2] -> [0, 1]
// [(N*Ho*Wo), Gm, K] -> [(N*Ho*Wo), (K*Gm)]
const auto out_gemm_k_gemm_m_grid_desc = transform_tensor_descriptor(
out_grid_desc,
make_tuple(make_pass_through_transform(N_ * Ho_ * Wo_),
make_merge_transform(make_tuple(NumGroupsToMerge, K_))),
make_tuple(sequence<0>{}, sequence<1, 2>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
const auto in_n_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor(
in_n_hip_wip_c_grid_desc,
make_tuple(
make_pass_through_transform(N_),
make_embed_transform(make_tuple(Y_, Ho_), make_tuple(ConvDilationH_, ConvStrideH_)),
make_embed_transform(make_tuple(X_, Wo_), make_tuple(ConvDilationW_, ConvStrideW_)),
make_pass_through_transform(C_)),
make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}),
make_tuple(sequence<0>{}, sequence<1, 2>{}, sequence<3, 4>{}, sequence<5>{}));
// Input tensor transformation, part 1.
// [N, Hi, Wi, Gm, C] -> [N, Hip, Wip, Gm, C]
const auto in_n_hip_wip_gm_c_grid_desc = transform_tensor_descriptor(
in_grid_desc,
make_tuple(make_pass_through_transform(N_),
make_pad_transform(Hi_, InLeftPadH_, InRightPadH_),
make_pad_transform(Wi_, InLeftPadW_, InRightPadW_),
make_pass_through_transform(NumGroupsToMerge),
make_pass_through_transform(C_)),
make_tuple(
sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}, sequence<4>{}),
make_tuple(
sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}, sequence<4>{}));
const auto in_gemmn_gemmktotal_grid_desc =
transform_tensor_descriptor(in_n_y_ho_x_wo_c_grid_desc,
make_tuple(make_merge_transform(make_tuple(Y_, X_, C_)),
make_merge_transform(make_tuple(N_, Ho_, Wo_))),
make_tuple(sequence<1, 3, 5>{}, sequence<0, 2, 4>{}),
make_tuple(sequence<1>{}, sequence<0>{}));
// Input tensor transformation, part 2.
// [N, Hip, Wip, Gm, C] -> [N, (Y, Wo), (X, Wo), Gm, C]
const auto in_n_y_ho_x_wo_gm_c_grid_desc = transform_tensor_descriptor(
in_n_hip_wip_gm_c_grid_desc,
make_tuple(make_pass_through_transform(N_),
make_embed_transform(make_tuple(Y_, Ho_),
make_tuple(ConvDilationH_, ConvStrideH_)),
make_embed_transform(make_tuple(X_, Wo_),
make_tuple(ConvDilationW_, ConvStrideW_)),
make_pass_through_transform(NumGroupsToMerge),
make_pass_through_transform(C_)),
make_tuple(
sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}, sequence<4>{}),
make_tuple(sequence<0>{},
sequence<1, 2>{},
sequence<3, 4>{},
sequence<5>{},
sequence<6>{}));
return make_tuple(out_grid_desc, in_gemmn_gemmktotal_grid_desc, wei_grid_desc);
// Input tensor transformation, part 3.
// [0, 1, 2, 3, 4 5 6] -> [0, 1]
// [N, Y, Ho, X, Wo, Gm, C] -> [(N*Ho*Wo), (Gm*Y*X*C)]
const auto in_gemm_n_gemm_k_grid_desc = transform_tensor_descriptor(
in_n_y_ho_x_wo_gm_c_grid_desc,
make_tuple(make_merge_transform(make_tuple(Y_, X_, NumGroupsToMerge, C_)),
make_merge_transform(make_tuple(N_, Ho_, Wo_))),
make_tuple(sequence<1, 3, 5, 6>{}, sequence<0, 2, 4>{}),
make_tuple(sequence<1>{}, sequence<0>{}));
return make_tuple(
out_gemm_k_gemm_m_grid_desc, in_gemm_n_gemm_k_grid_desc, wei_grid_desc);
}
else
{
const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor(
in_grid_desc,
make_tuple(make_pass_through_transform(N_),
make_pad_transform(Hi_, InLeftPadH_, InRightPadH_),
make_pad_transform(Wi_, InLeftPadW_, InRightPadW_),
make_pass_through_transform(C_)),
make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}),
make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}));
const auto in_n_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor(
in_n_hip_wip_c_grid_desc,
make_tuple(make_pass_through_transform(N_),
make_embed_transform(make_tuple(Y_, Ho_),
make_tuple(ConvDilationH_, ConvStrideH_)),
make_embed_transform(make_tuple(X_, Wo_),
make_tuple(ConvDilationW_, ConvStrideW_)),
make_pass_through_transform(C_)),
make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}),
make_tuple(sequence<0>{}, sequence<1, 2>{}, sequence<3, 4>{}, sequence<5>{}));
const auto in_gemmn_gemmktotal_grid_desc = transform_tensor_descriptor(
in_n_y_ho_x_wo_c_grid_desc,
make_tuple(make_merge_transform(make_tuple(Y_, X_, C_)),
make_merge_transform(make_tuple(N_, Ho_, Wo_))),
make_tuple(sequence<1, 3, 5>{}, sequence<0, 2, 4>{}),
make_tuple(sequence<1>{}, sequence<0>{}));
return make_tuple(out_grid_desc, in_gemmn_gemmktotal_grid_desc, wei_grid_desc);
}
}
template <index_t NDim = NDimSpatial, typename std::enable_if<NDim == 3, bool>::type = false>
@@ -624,39 +943,121 @@ struct TransformConvBwdWeightToGemm
const auto wei_grid_desc = make_wei_grid_desc<NDimSpatial>();
// B: input tensor comes in K_N
const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor(
in_grid_desc,
make_tuple(make_pass_through_transform(N_),
make_pad_transform(Di_, InLeftPadD_, InRightPadD_),
make_pad_transform(Hi_, InLeftPadH_, InRightPadH_),
make_pad_transform(Wi_, InLeftPadW_, InRightPadW_),
make_pass_through_transform(C_)),
make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}, sequence<4>{}),
make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}, sequence<4>{}));
if constexpr(NumGroupsToMerge > 1)
{
// Output tensor transformation
// [0, 1, 2] -> [0, 1]
// [(N*Do*Ho*Wo), Gm, K] -> [(N*Do*Ho*Wo), (K*Gm)]
const auto out_gemm_k_gemm_m_grid_desc = transform_tensor_descriptor(
out_grid_desc,
make_tuple(make_pass_through_transform(N_ * Do_ * Ho_ * Wo_),
make_merge_transform(make_tuple(NumGroupsToMerge, K_))),
make_tuple(sequence<0>{}, sequence<1, 2>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
const auto in_n_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor(
in_n_hip_wip_c_grid_desc,
make_tuple(
make_pass_through_transform(N_),
make_embed_transform(make_tuple(Z_, Do_), make_tuple(ConvDilationD_, ConvStrideD_)),
make_embed_transform(make_tuple(Y_, Ho_), make_tuple(ConvDilationH_, ConvStrideH_)),
make_embed_transform(make_tuple(X_, Wo_), make_tuple(ConvDilationW_, ConvStrideW_)),
make_pass_through_transform(C_)),
make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}, sequence<4>{}),
make_tuple(sequence<0>{},
sequence<1, 2>{},
sequence<3, 4>{},
sequence<5, 6>{},
sequence<7>{}));
// Input tensor transformation, part 1.
// [N, Di, Hi, Wi, Gm, C] -> [N, Dip, Hip, Wip, Gm, C]
const auto in_n_dip_hip_wip_gm_c_grid_desc = transform_tensor_descriptor(
in_grid_desc,
make_tuple(make_pass_through_transform(N_),
make_pad_transform(Di_, InLeftPadD_, InRightPadD_),
make_pad_transform(Hi_, InLeftPadH_, InRightPadH_),
make_pad_transform(Wi_, InLeftPadW_, InRightPadW_),
make_pass_through_transform(NumGroupsToMerge),
make_pass_through_transform(C_)),
make_tuple(sequence<0>{},
sequence<1>{},
sequence<2>{},
sequence<3>{},
sequence<4>{},
sequence<5>{}),
make_tuple(sequence<0>{},
sequence<1>{},
sequence<2>{},
sequence<3>{},
sequence<4>{},
sequence<5>{}));
const auto in_gemmn_gemmktotal_grid_desc = transform_tensor_descriptor(
in_n_y_ho_x_wo_c_grid_desc,
make_tuple(make_merge_transform(make_tuple(Z_, Y_, X_, C_)),
make_merge_transform(make_tuple(N_, Do_, Ho_, Wo_))),
make_tuple(sequence<1, 3, 5, 7>{}, sequence<0, 2, 4, 6>{}),
make_tuple(sequence<1>{}, sequence<0>{}));
// Input tensor transformation, part 2.
// [N, Zip, Hip, Wip, Gm, C] -> [N, (Z, Zo), (Y, Wo), (X, Wo), Gm, C]
const auto in_n_z_do_y_ho_x_wo_gm_c_grid_desc = transform_tensor_descriptor(
in_n_dip_hip_wip_gm_c_grid_desc,
make_tuple(make_pass_through_transform(N_),
make_embed_transform(make_tuple(Z_, Do_),
make_tuple(ConvDilationD_, ConvStrideD_)),
make_embed_transform(make_tuple(Y_, Ho_),
make_tuple(ConvDilationH_, ConvStrideH_)),
make_embed_transform(make_tuple(X_, Wo_),
make_tuple(ConvDilationW_, ConvStrideW_)),
make_pass_through_transform(NumGroupsToMerge),
make_pass_through_transform(C_)),
make_tuple(sequence<0>{},
sequence<1>{},
sequence<2>{},
sequence<3>{},
sequence<4>{},
sequence<5>{}),
make_tuple(sequence<0>{},
sequence<1, 2>{},
sequence<3, 4>{},
sequence<5, 6>{},
sequence<7>{},
sequence<8>{}));
return make_tuple(out_grid_desc, in_gemmn_gemmktotal_grid_desc, wei_grid_desc);
// Input tensor transformation, part 3.
// [0, 1, 2, 3, 4, 5, 6, 7, 8] -> [0, 1]
// [N, Z, Do, Y, Ho, X, Wo, Gm, C] -> [(N*Do*Ho*Wo), (Z*Y*X*Gm*C)]
const auto in_gemm_k_gemm_n_grid_desc = transform_tensor_descriptor(
in_n_z_do_y_ho_x_wo_gm_c_grid_desc,
make_tuple(make_merge_transform(make_tuple(Z_, Y_, X_, NumGroupsToMerge, C_)),
make_merge_transform(make_tuple(N_, Do_, Ho_, Wo_))),
make_tuple(sequence<1, 3, 5, 7, 8>{}, sequence<0, 2, 4, 6>{}),
make_tuple(sequence<1>{}, sequence<0>{}));
return make_tuple(
out_gemm_k_gemm_m_grid_desc, in_gemm_k_gemm_n_grid_desc, wei_grid_desc);
}
else
{
const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor(
in_grid_desc,
make_tuple(make_pass_through_transform(N_),
make_pad_transform(Di_, InLeftPadD_, InRightPadD_),
make_pad_transform(Hi_, InLeftPadH_, InRightPadH_),
make_pad_transform(Wi_, InLeftPadW_, InRightPadW_),
make_pass_through_transform(C_)),
make_tuple(
sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}, sequence<4>{}),
make_tuple(
sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}, sequence<4>{}));
const auto in_n_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor(
in_n_hip_wip_c_grid_desc,
make_tuple(make_pass_through_transform(N_),
make_embed_transform(make_tuple(Z_, Do_),
make_tuple(ConvDilationD_, ConvStrideD_)),
make_embed_transform(make_tuple(Y_, Ho_),
make_tuple(ConvDilationH_, ConvStrideH_)),
make_embed_transform(make_tuple(X_, Wo_),
make_tuple(ConvDilationW_, ConvStrideW_)),
make_pass_through_transform(C_)),
make_tuple(
sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}, sequence<4>{}),
make_tuple(sequence<0>{},
sequence<1, 2>{},
sequence<3, 4>{},
sequence<5, 6>{},
sequence<7>{}));
const auto in_gemmn_gemmktotal_grid_desc = transform_tensor_descriptor(
in_n_y_ho_x_wo_c_grid_desc,
make_tuple(make_merge_transform(make_tuple(Z_, Y_, X_, C_)),
make_merge_transform(make_tuple(N_, Do_, Ho_, Wo_))),
make_tuple(sequence<1, 3, 5, 7>{}, sequence<0, 2, 4, 6>{}),
make_tuple(sequence<1>{}, sequence<0>{}));
return make_tuple(out_grid_desc, in_gemmn_gemmktotal_grid_desc, wei_grid_desc);
}
}
IndexType G_, N_;
@@ -668,7 +1069,6 @@ struct TransformConvBwdWeightToGemm
IndexType ConvDilationD_, ConvDilationH_, ConvDilationW_;
IndexType InLeftPadD_, InLeftPadH_, InLeftPadW_;
IndexType InRightPadD_, InRightPadH_, InRightPadW_;
IndexType ZYX_;
};
} // namespace ck_tile

View File

@@ -13,10 +13,10 @@ template <index_t NDimSpatial,
index_t VectorSizeA,
index_t VectorSizeB,
index_t VectorSizeC,
index_t NumGroupsToMerge = 1,
bool SplitN = false,
typename ADataType = float,
typename CDataType = float,
index_t NumGroupsToMerge = 1,
typename IndexType = index_t>
struct TransformConvFwdToGemm
{