Mirror of https://github.com/ROCm/composable_kernel.git (synced 2026-05-14 18:17:44 +00:00)

Merge commit 'cafaeb6b7bac4e18b0a5341cd14f54224292a0c9' into develop
Jenkinsfile (vendored): 22 changed lines
@@ -20,6 +20,28 @@ def failurePatterns = [
    [pattern: /cat: .* No such file or directory/, description: "GPU not found"],
]

// Given a pattern, check if the log contains the pattern and return the context.
def checkForPattern(pattern, log) {
    def lines = log.split('\n')
    for (int i = 0; i < lines.size(); i++) {
        if (lines[i] =~ pattern) {
            echo "Found pattern match in log for ${pattern}"

            // Get the two lines before and after failure.
            def contextStart = Math.max(0, i - 2)
            def contextEnd = Math.min(lines.size() - 1, i + 2)
            def contextLines = []
            for (int j = contextStart; j <= contextEnd; j++) {
                contextLines.add(lines[j])
            }

            return [found: true, matchedLine: lines[i], context: contextLines.join('\n')]
        }
    }
    echo "No pattern match found in log for ${pattern}"
    return [found: false, matchedLine: "", context: ""]
}

class Version {
    int major, minor, patch
    @Override

@@ -17,7 +17,7 @@
#define CK_TILE_PIPELINE_COMPUTE_V4 3
#define CK_TILE_PIPELINE_COMPUTE_V5 4

struct GemmConfigBase
struct ConvConfigBase
{
    static constexpr bool kPadM = true;
    static constexpr bool kPadN = true;
@@ -29,6 +29,10 @@ struct GemmConfigBase
    static constexpr bool TransposeC = false;
    static constexpr bool UseStructuredSparsity = false;

    static constexpr ck_tile::index_t VectorSizeA = 4;
    static constexpr ck_tile::index_t VectorSizeB = 8;
    static constexpr ck_tile::index_t VectorSizeC = 8;

    static constexpr int kBlockPerCu = 1;
    static constexpr ck_tile::index_t TileParitionerGroupNum = 8;
    static constexpr ck_tile::index_t TileParitionerM01 = 4;
@@ -37,10 +41,12 @@ struct GemmConfigBase
    static constexpr ck_tile::index_t NumWaveGroups = 1;
    static constexpr bool Preshuffle = false;
    static constexpr bool TiledMMAPermuteN = false;

    static constexpr ck_tile::index_t NumGroupsToMerge = 1;
};

template <typename PrecType>
struct GemmConfigMemoryInterwave : public GemmConfigBase
struct ConvConfigMemoryInterwave : public ConvConfigBase
{
    // Memory friendly for Interwave scheduler
    static constexpr ck_tile::index_t M_Tile = 128;
@@ -61,7 +67,7 @@ struct GemmConfigMemoryInterwave : public GemmConfigBase
};

template <typename PrecType>
struct GemmConfigMemoryIntrawave : public GemmConfigBase
struct ConvConfigMemoryIntrawave : public ConvConfigBase
{
    static constexpr ck_tile::index_t M_Tile = 128;
    static constexpr ck_tile::index_t N_Tile = 32;
@@ -80,7 +86,7 @@ struct GemmConfigMemoryIntrawave : public GemmConfigBase
};

template <typename PrecType>
struct GemmConfigComputeV3 : public GemmConfigBase
struct ConvConfigComputeV3 : public ConvConfigBase
{
    // Compute V3 only support Intrawave scheduler
    static constexpr ck_tile::index_t M_Tile = 16;
@@ -100,7 +106,7 @@ struct GemmConfigComputeV3 : public GemmConfigBase
};

template <typename PrecType>
struct GemmConfigComputeV3_1 : public GemmConfigBase
struct ConvConfigComputeV3_1 : public ConvConfigBase
{
    static constexpr ck_tile::index_t M_Tile = 256;
    static constexpr ck_tile::index_t N_Tile = 256;
@@ -119,7 +125,7 @@ struct GemmConfigComputeV3_1 : public GemmConfigBase
};

template <typename PrecType>
struct GemmConfigComputeV3_2 : public GemmConfigBase
struct ConvConfigComputeV3_2 : public ConvConfigBase
{
    static constexpr ck_tile::index_t M_Tile = 128;
    static constexpr ck_tile::index_t N_Tile = 128;
@@ -140,7 +146,7 @@ struct GemmConfigComputeV3_2 : public GemmConfigBase
};

template <typename PrecType>
struct GemmConfigComputeV3_WMMA : public GemmConfigBase
struct ConvConfigComputeV3_WMMA : public ConvConfigBase
{
    static constexpr ck_tile::index_t M_Tile = 128;
    static constexpr ck_tile::index_t N_Tile = 128;
@@ -161,7 +167,7 @@ struct GemmConfigComputeV3_WMMA : public GemmConfigBase
};

template <typename PrecType>
struct GemmConfigComputeV4 : public GemmConfigBase
struct ConvConfigComputeV4 : public ConvConfigBase
{
    // Compute V4 only support Intrawave scheduler
    // Using the ping pong reader in the lds level
@@ -182,7 +188,7 @@ struct GemmConfigComputeV4 : public GemmConfigBase
};

template <typename PrecType>
struct GemmConfigComputeV4_1 : public GemmConfigBase
struct ConvConfigComputeV4_1 : public ConvConfigBase
{
    static constexpr ck_tile::index_t M_Tile = 256;
    static constexpr ck_tile::index_t N_Tile = 256;
@@ -201,7 +207,7 @@ struct GemmConfigComputeV4_1 : public GemmConfigBase
};

template <typename PrecType>
struct GemmConfigComputeV5 : public GemmConfigBase
struct ConvConfigComputeV5 : public ConvConfigBase
{
    static constexpr ck_tile::index_t M_Tile = 128;
    static constexpr ck_tile::index_t N_Tile = 128;
@@ -220,6 +226,31 @@ struct GemmConfigComputeV5 : public GemmConfigBase
    static constexpr ck_tile::index_t NumWaveGroups = 2;
};

template <typename PrecType>
struct ConvConfigComputeV3_merged_groups : public ConvConfigBase
{
    static constexpr ck_tile::index_t VectorSizeA = 4;
    static constexpr ck_tile::index_t VectorSizeB = 8;
    static constexpr ck_tile::index_t VectorSizeC = 8;

    static constexpr ck_tile::index_t M_Tile = 16;
    static constexpr ck_tile::index_t N_Tile = 32;
    static constexpr ck_tile::index_t K_Tile = 32;

    static constexpr ck_tile::index_t M_Warp = 1;
    static constexpr ck_tile::index_t N_Warp = 2;
    static constexpr ck_tile::index_t K_Warp = 1;

    static constexpr ck_tile::index_t M_Warp_Tile = 16;
    static constexpr ck_tile::index_t N_Warp_Tile = 16;
    static constexpr ck_tile::index_t K_Warp_Tile = 32;

    static constexpr bool DoubleSmemBuffer = false;
    static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3;

    static constexpr ck_tile::index_t NumGroupsToMerge = 2;
};
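
// Illustrative usage (not part of this commit): like the other ConvConfig templates above,
// the merged-groups config can be selected in the example drivers in the same way, e.g.
//     return !run_grouped_conv_bwd_weight_example<ConvConfigComputeV3_merged_groups>(arg_parser);
// so that NumGroupsToMerge = 2 propagates through GroupedConvTraits into the generated kernel.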

template <typename InDataType, typename WeiDataType = InDataType, typename OutDataType = InDataType>
struct ConvTypeConfig;

@@ -51,8 +51,8 @@ int run_grouped_conv_bwd_data_example(int argc, char* argv[])
int main(int argc, char* argv[])
{
#if CK_TILE_USE_WMMA
    return !run_grouped_conv_bwd_data_example<GemmConfigComputeV3_WMMA>(argc, argv);
    return !run_grouped_conv_bwd_data_example<ConvConfigComputeV3_WMMA>(argc, argv);
#else
    return !run_grouped_conv_bwd_data_example<GemmConfigComputeV3>(argc, argv);
    return !run_grouped_conv_bwd_data_example<ConvConfigComputeV3>(argc, argv);
#endif
}

@@ -14,7 +14,7 @@
#include "grouped_convolution_backward_weight_invoker.hpp"
#include "run_grouped_convolution_bwd_weight_example.inc"

template <template <typename PrecType> typename GemmConfig>
template <template <typename PrecType> typename ConvConfig>
int run_grouped_conv_bwd_weight_example(ck_tile::ArgParser& arg_parser)
{
    using Invoker = GroupedConvolutionBackwardWeightInvoker;
@@ -27,14 +27,14 @@ int run_grouped_conv_bwd_weight_example(ck_tile::ArgParser& arg_parser)
    if(data_type == "fp16")
    {
        return run_grouped_conv_bwd_weight_example_prec_type<Invoker,
            GemmConfig<ck_tile::half_t>,
            ConvConfig<ck_tile::half_t>,
            ck_tile::half_t>(
            in_layout, wei_layout, out_layout, arg_parser);
    }
    else if(data_type == "bf16")
    {
        return run_grouped_conv_bwd_weight_example_prec_type<Invoker,
            GemmConfig<ck_tile::bf16_t>,
            ConvConfig<ck_tile::bf16_t>,
            ck_tile::bf16_t>(
            in_layout, wei_layout, out_layout, arg_parser);
    }
@@ -54,9 +54,9 @@ int main(int argc, char* argv[])
    try
    {
#if CK_TILE_USE_WMMA
        return !run_grouped_conv_bwd_weight_example<GemmConfigComputeV3_WMMA>(arg_parser);
        return !run_grouped_conv_bwd_weight_example<ConvConfigComputeV3_WMMA>(arg_parser);
#else
        return !run_grouped_conv_bwd_weight_example<GemmConfigComputeV3>(arg_parser);
        return !run_grouped_conv_bwd_weight_example<ConvConfigComputeV3>(arg_parser);
#endif
    }
    catch(const std::runtime_error& e)

@@ -7,7 +7,7 @@
struct GroupedConvolutionBackwardWeightInvoker
{
    template <ck_tile::index_t NDimSpatial,
              typename GemmConfig,
              typename ConvConfig,
              typename InDataType,
              typename WeiDataType,
              typename AccDataType,
@@ -25,22 +25,22 @@ struct GroupedConvolutionBackwardWeightInvoker

        // Implicit GEMM Traits
        using GemmShape = ck_tile::TileGemmShape<
            ck_tile::sequence<GemmConfig::M_Tile, GemmConfig::N_Tile, GemmConfig::K_Tile>,
            ck_tile::sequence<GemmConfig::M_Warp, GemmConfig::N_Warp, GemmConfig::K_Warp>,
            ck_tile::sequence<ConvConfig::M_Tile, ConvConfig::N_Tile, ConvConfig::K_Tile>,
            ck_tile::sequence<ConvConfig::M_Warp, ConvConfig::N_Warp, ConvConfig::K_Warp>,
            ck_tile::
                sequence<GemmConfig::M_Warp_Tile, GemmConfig::N_Warp_Tile, GemmConfig::K_Warp_Tile>,
            GemmConfig::PermuteA,
            GemmConfig::PermuteB>;
                sequence<ConvConfig::M_Warp_Tile, ConvConfig::N_Warp_Tile, ConvConfig::K_Warp_Tile>,
            ConvConfig::PermuteA,
            ConvConfig::PermuteB>;

        constexpr ck_tile::index_t VectorSizeA = 4;
        constexpr ck_tile::index_t VectorSizeB = 8;
        constexpr ck_tile::index_t VectorSizeC = 8;
        constexpr ck_tile::index_t VectorSizeA = ConvConfig::VectorSizeA;
        constexpr ck_tile::index_t VectorSizeB = ConvConfig::VectorSizeB;
        constexpr ck_tile::index_t VectorSizeC = ConvConfig::VectorSizeC;

        constexpr auto ConvSpec = ck_tile::ConvolutionSpecialization::Default;
        using TilePartitioner =
            ck_tile::GemmSpatiallyLocalTilePartitioner<GemmShape,
                GemmConfig::TileParitionerGroupNum,
                GemmConfig::TileParitionerM01>;
                ConvConfig::TileParitionerGroupNum,
                ConvConfig::TileParitionerM01>;
        using GroupedConvTraitsType = ck_tile::GroupedConvTraits<NDimSpatial,
            ConvSpec,
            InLayout,
@@ -49,20 +49,21 @@ struct GroupedConvolutionBackwardWeightInvoker
            OutLayout,
            VectorSizeA,
            VectorSizeB,
            VectorSizeC>;
            VectorSizeC,
            ConvConfig::NumGroupsToMerge>;

        using GemmUniversalTraits = ck_tile::TileGemmUniversalTraits<
            GemmConfig::kPadM,
            GemmConfig::kPadN,
            GemmConfig::kPadK,
            GemmConfig::DoubleSmemBuffer,
            ConvConfig::kPadM,
            ConvConfig::kPadN,
            ConvConfig::kPadK,
            ConvConfig::DoubleSmemBuffer,
            typename GroupedConvTraitsType::GroupedConvImplicitGemmTraitsBwdWeight::AsLayout,
            typename GroupedConvTraitsType::GroupedConvImplicitGemmTraitsBwdWeight::BsLayout,
            typename GroupedConvTraitsType::GroupedConvImplicitGemmTraitsBwdWeight::CLayout,
            GemmConfig::TransposeC,
            GemmConfig::UseStructuredSparsity,
            ConvConfig::TransposeC,
            ConvConfig::UseStructuredSparsity,
            false, // Persistent,
            GemmConfig::NumWaveGroups>;
            ConvConfig::NumWaveGroups>;

        using GemmPipelineProblem = ck_tile::GemmPipelineProblem<
            OutDataType,
@@ -78,7 +79,7 @@ struct GroupedConvolutionBackwardWeightInvoker
            VectorSizeB>;

        using BaseGemmPipeline = typename PipelineTypeTraits<
            GemmConfig::Pipeline>::template UniversalGemmPipeline<GemmPipelineProblem>;
            ConvConfig::Pipeline>::template UniversalGemmPipeline<GemmPipelineProblem>;

        const ck_tile::index_t gemm_k =
            args.N_ * std::accumulate(args.output_spatial_lengths_.begin(),
@@ -86,8 +87,8 @@ struct GroupedConvolutionBackwardWeightInvoker
                1,
                std::multiplies<ck_tile::index_t>());

        const ck_tile::index_t k_grain = args.k_batch * GemmConfig::K_Tile;
        const ck_tile::index_t K_split = (gemm_k + k_grain - 1) / k_grain * GemmConfig::K_Tile;
        const ck_tile::index_t k_grain = args.k_batch * ConvConfig::K_Tile;
        const ck_tile::index_t K_split = (gemm_k + k_grain - 1) / k_grain * ConvConfig::K_Tile;
        const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split);
        const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop);
        const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop);
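
        // Worked example (illustrative numbers, not part of the commit): with gemm_k = 1000,
        // args.k_batch = 4 and ConvConfig::K_Tile = 32, k_grain = 4 * 32 = 128 and
        // K_split = (1000 + 127) / 128 * 32 = 8 * 32 = 256. In other words, gemm_k is rounded
        // up to a multiple of k_grain so that each of the k_batch splits covers a K_Tile-aligned
        // slice of roughly gemm_k / k_batch, and the loop count, hot-loop flag and tail number
        // are then derived from that per-split length.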
@@ -98,7 +99,7 @@ struct GroupedConvolutionBackwardWeightInvoker
            const auto memory_operation_) {
            constexpr bool has_hot_loop_v = has_hot_loop_.value;
            constexpr auto tail_number_v = tail_number_.value;
            constexpr auto scheduler = GemmConfig::Scheduler;
            constexpr auto scheduler = ConvConfig::Scheduler;
            constexpr auto memory_operation = memory_operation_.value;

            using UniversalGemmProblem =
@@ -118,7 +119,7 @@ struct GroupedConvolutionBackwardWeightInvoker
                VectorSizeB>;

            using GemmPipeline = typename PipelineTypeTraits<
                GemmConfig::Pipeline>::template GemmPipeline<UniversalGemmProblem>;
                ConvConfig::Pipeline>::template GemmPipeline<UniversalGemmProblem>;

            using ConvEpilogue = ck_tile::CShuffleEpilogue<ck_tile::CShuffleEpilogueProblem<
                OutDataType,
@@ -131,12 +132,12 @@ struct GroupedConvolutionBackwardWeightInvoker
                CDEElementWise,
                TilePartitioner::MPerBlock,
                TilePartitioner::NPerBlock,
                GemmConfig::M_Warp,
                GemmConfig::N_Warp,
                GemmConfig::M_Warp_Tile,
                GemmConfig::N_Warp_Tile,
                GemmConfig::K_Warp_Tile,
                GemmConfig::TransposeC,
                ConvConfig::M_Warp,
                ConvConfig::N_Warp,
                ConvConfig::M_Warp_Tile,
                ConvConfig::N_Warp_Tile,
                ConvConfig::K_Warp_Tile,
                ConvConfig::TransposeC,
                memory_operation,
                1,
                true,

@@ -13,9 +13,9 @@
#include "grouped_convolution_utils.hpp"
#include "grouped_convolution_backward_weight_two_stage_invoker.hpp"
#include "run_grouped_convolution_bwd_weight_example.inc"
#include "gemm_configs.hpp"
#include "conv_configs.hpp"

template <template <typename PrecType> typename GemmConfig>
template <template <typename PrecType> typename ConvConfig>
int run_grouped_conv_bwd_weight_example(ck_tile::ArgParser& arg_parser)
{
    using Invoker = GroupedConvolutionBackwardWeightTwoStageInvoker;
@@ -28,14 +28,14 @@ int run_grouped_conv_bwd_weight_example(ck_tile::ArgParser& arg_parser)
    if(data_type == "fp16")
    {
        return run_grouped_conv_bwd_weight_example_prec_type<Invoker,
            GemmConfig<ck_tile::half_t>,
            ConvConfig<ck_tile::half_t>,
            ck_tile::half_t>(
            in_layout, wei_layout, out_layout, arg_parser);
    }
    else if(data_type == "bf16")
    {
        return run_grouped_conv_bwd_weight_example_prec_type<Invoker,
            GemmConfig<ck_tile::bf16_t>,
            ConvConfig<ck_tile::bf16_t>,
            ck_tile::bf16_t>(
            in_layout, wei_layout, out_layout, arg_parser);
    }
@@ -55,9 +55,9 @@ int main(int argc, char* argv[])
    try
    {
#if CK_TILE_USE_WMMA
        return !run_grouped_conv_bwd_weight_example<GemmConfigComputeV3_WMMA>(arg_parser);
        return !run_grouped_conv_bwd_weight_example<ConvConfigComputeV3_WMMA>(arg_parser);
#else
        return !run_grouped_conv_bwd_weight_example<GemmConfigComputeV3>(arg_parser);
        return !run_grouped_conv_bwd_weight_example<ConvConfigComputeV3>(arg_parser);
#endif
    }
    catch(const std::runtime_error& e)

@@ -7,7 +7,7 @@
struct GroupedConvolutionBackwardWeightTwoStageInvoker
{
    template <ck_tile::index_t NDimSpatial,
              typename GemmConfig,
              typename ConvConfig,
              typename InDataType,
              typename WeiDataType,
              typename AccDataType,
@@ -27,12 +27,12 @@ struct GroupedConvolutionBackwardWeightTwoStageInvoker

        // Implicit GEMM Traits
        using GemmShape = ck_tile::TileGemmShape<
            ck_tile::sequence<GemmConfig::M_Tile, GemmConfig::N_Tile, GemmConfig::K_Tile>,
            ck_tile::sequence<GemmConfig::M_Warp, GemmConfig::N_Warp, GemmConfig::K_Warp>,
            ck_tile::sequence<ConvConfig::M_Tile, ConvConfig::N_Tile, ConvConfig::K_Tile>,
            ck_tile::sequence<ConvConfig::M_Warp, ConvConfig::N_Warp, ConvConfig::K_Warp>,
            ck_tile::
                sequence<GemmConfig::M_Warp_Tile, GemmConfig::N_Warp_Tile, GemmConfig::K_Warp_Tile>,
            GemmConfig::PermuteA,
            GemmConfig::PermuteB>;
                sequence<ConvConfig::M_Warp_Tile, ConvConfig::N_Warp_Tile, ConvConfig::K_Warp_Tile>,
            ConvConfig::PermuteA,
            ConvConfig::PermuteB>;

        constexpr ck_tile::index_t VectorSizeA = 4;
        constexpr ck_tile::index_t VectorSizeB = 8;
@@ -41,8 +41,8 @@ struct GroupedConvolutionBackwardWeightTwoStageInvoker
        constexpr auto ConvSpec = ck_tile::ConvolutionSpecialization::Default;
        using TilePartitioner =
            ck_tile::GemmSpatiallyLocalTilePartitioner<GemmShape,
                GemmConfig::TileParitionerGroupNum,
                GemmConfig::TileParitionerM01>;
                ConvConfig::TileParitionerGroupNum,
                ConvConfig::TileParitionerM01>;
        using GroupedConvTraitsType = ck_tile::GroupedConvTraits<NDimSpatial,
            ConvSpec,
            InLayout,
@@ -54,17 +54,17 @@ struct GroupedConvolutionBackwardWeightTwoStageInvoker
            VectorSizeC>;

        using GemmUniversalTraits = ck_tile::TileGemmUniversalTraits<
            GemmConfig::kPadM,
            GemmConfig::kPadN,
            GemmConfig::kPadK,
            GemmConfig::DoubleSmemBuffer,
            ConvConfig::kPadM,
            ConvConfig::kPadN,
            ConvConfig::kPadK,
            ConvConfig::DoubleSmemBuffer,
            typename GroupedConvTraitsType::GroupedConvImplicitGemmTraitsBwdWeight::AsLayout,
            typename GroupedConvTraitsType::GroupedConvImplicitGemmTraitsBwdWeight::BsLayout,
            typename GroupedConvTraitsType::GroupedConvImplicitGemmTraitsBwdWeight::CLayout,
            GemmConfig::TransposeC,
            GemmConfig::UseStructuredSparsity,
            ConvConfig::TransposeC,
            ConvConfig::UseStructuredSparsity,
            false, // Persistent,
            GemmConfig::NumWaveGroups>;
            ConvConfig::NumWaveGroups>;

        using GemmPipelineProblem = ck_tile::GemmPipelineProblem<
            OutDataType,
@@ -80,7 +80,7 @@ struct GroupedConvolutionBackwardWeightTwoStageInvoker
            VectorSizeB>;

        using BaseGemmPipeline = typename PipelineTypeTraits<
            GemmConfig::Pipeline>::template UniversalGemmPipeline<GemmPipelineProblem>;
            ConvConfig::Pipeline>::template UniversalGemmPipeline<GemmPipelineProblem>;

        const ck_tile::index_t gemm_k =
            args.N_ * std::accumulate(args.output_spatial_lengths_.begin(),
@@ -88,8 +88,8 @@ struct GroupedConvolutionBackwardWeightTwoStageInvoker
                1,
                std::multiplies<ck_tile::index_t>());

        const ck_tile::index_t k_grain = args.k_batch * GemmConfig::K_Tile;
        const ck_tile::index_t K_split = (gemm_k + k_grain - 1) / k_grain * GemmConfig::K_Tile;
        const ck_tile::index_t k_grain = args.k_batch * ConvConfig::K_Tile;
        const ck_tile::index_t K_split = (gemm_k + k_grain - 1) / k_grain * ConvConfig::K_Tile;
        const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split);
        const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop);
        const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop);
@@ -100,7 +100,7 @@ struct GroupedConvolutionBackwardWeightTwoStageInvoker
            const auto memory_operation_) {
            constexpr bool has_hot_loop_v = has_hot_loop_.value;
            constexpr auto tail_number_v = tail_number_.value;
            constexpr auto scheduler = GemmConfig::Scheduler;
            constexpr auto scheduler = ConvConfig::Scheduler;
            constexpr auto memory_operation = memory_operation_.value;

            using UniversalGemmProblem =
@@ -120,7 +120,7 @@ struct GroupedConvolutionBackwardWeightTwoStageInvoker
                VectorSizeB>;

            using GemmPipeline = typename PipelineTypeTraits<
                GemmConfig::Pipeline>::template GemmPipeline<UniversalGemmProblem>;
                ConvConfig::Pipeline>::template GemmPipeline<UniversalGemmProblem>;

            using ConvEpilogue = ck_tile::CShuffleEpilogue<ck_tile::CShuffleEpilogueProblem<
                OutDataType, // A: Out
@@ -133,11 +133,11 @@ struct GroupedConvolutionBackwardWeightTwoStageInvoker
                CDEElementWise,
                TilePartitioner::MPerBlock,
                TilePartitioner::NPerBlock,
                GemmConfig::M_Warp,
                GemmConfig::N_Warp,
                GemmConfig::M_Warp_Tile,
                GemmConfig::N_Warp_Tile,
                GemmConfig::K_Warp_Tile,
                ConvConfig::M_Warp,
                ConvConfig::N_Warp,
                ConvConfig::M_Warp_Tile,
                ConvConfig::N_Warp_Tile,
                ConvConfig::K_Warp_Tile,
                GemmPipelineProblem::TransposeC,
                memory_operation,
                1,

@@ -51,8 +51,8 @@ int run_grouped_conv_fwd_example(int argc, char* argv[])
int main(int argc, char* argv[])
{
#if CK_TILE_USE_WMMA
    return !run_grouped_conv_fwd_example<GemmConfigComputeV3_WMMA>(argc, argv);
    return !run_grouped_conv_fwd_example<ConvConfigComputeV3_WMMA>(argc, argv);
#else
    return !run_grouped_conv_fwd_example<GemmConfigComputeV3>(argc, argv);
    return !run_grouped_conv_fwd_example<ConvConfigComputeV3>(argc, argv);
#endif
}

@@ -32,9 +32,10 @@ struct GroupedConvolutionForwardInvoker
            GemmConfig::PermuteA,
            GemmConfig::PermuteB>;

        constexpr ck_tile::index_t VectorSizeA = 8;
        constexpr ck_tile::index_t VectorSizeB = 8;
        constexpr ck_tile::index_t VectorSizeC = 8;
        constexpr ck_tile::index_t VectorSizeA = 8;
        constexpr ck_tile::index_t VectorSizeB = 8;
        constexpr ck_tile::index_t VectorSizeC = 8;
        constexpr ck_tile::index_t NumGroupsToMerge = 1;

        constexpr auto ConvSpec = ck_tile::ConvolutionSpecialization::Default;
        using TilePartitioner =
@@ -50,6 +51,7 @@ struct GroupedConvolutionForwardInvoker
            VectorSizeA,
            VectorSizeB,
            VectorSizeC,
            NumGroupsToMerge,
            CDElementWise>;

        using GemmUniversalTraits = ck_tile::TileGemmUniversalTraits<

@@ -11,24 +11,11 @@
#include "ck_tile/ops/epilogue.hpp"
#include "ck_tile/ops/gemm.hpp"
#include "ck_tile/ops/grouped_convolution.hpp"
#include "gemm_configs.hpp"
#include "conv_configs.hpp"
using MemoryOpSet =
    std::integral_constant<ck_tile::memory_operation_enum, ck_tile::memory_operation_enum::set>;
using MemoryOpAtomicAdd = std::integral_constant<ck_tile::memory_operation_enum,
    ck_tile::memory_operation_enum::atomic_add>;
struct GemmWarpConfig_Mfma
{
    static constexpr ck_tile::index_t M_Warp_Tile = 32;
    static constexpr ck_tile::index_t N_Warp_Tile = 32;
    static constexpr ck_tile::index_t K_Warp_Tile = 16;
};

struct GemmWarpConfig_Wmma
{
    static constexpr ck_tile::index_t M_Warp_Tile = 16;
    static constexpr ck_tile::index_t N_Warp_Tile = 16;
    static constexpr ck_tile::index_t K_Warp_Tile = 16;
};

template <typename InDataType, typename WeiDataType, typename AccDataType, typename OutDataType>
auto calculate_rtol_atol(const ck_tile::index_t GemmK,

@@ -3,7 +3,7 @@
#pragma once

template <ck_tile::index_t NDimSpatial,
          typename GemmConfig,
          typename ConvConfig,
          typename Invoker,
          typename InDataType,
          typename WeiDataType,
@@ -17,7 +17,7 @@ float invoke_grouped_conv_bwd_weight(ck_tile::GroupedConvBwdWeightHostArgs& args
                                     int n_repeat)
{
    float ave_time = Invoker::template grouped_conv_bwd_weight<NDimSpatial,
        GemmConfig,
        ConvConfig,
        InDataType,
        WeiDataType,
        AccDataType,
@@ -31,7 +31,7 @@ float invoke_grouped_conv_bwd_weight(ck_tile::GroupedConvBwdWeightHostArgs& args
}

template <ck_tile::index_t NDimSpatial,
          typename GemmConfig,
          typename ConvConfig,
          typename Invoker,
          typename InDataType,
          typename WeiDataType = InDataType,
@@ -131,7 +131,7 @@ int run_grouped_conv_bwd_weight_example_with_layouts(ck_tile::ArgParser& arg_par
    std::cout << "output: " << output.mDesc << std::endl;

    float ave_time = invoke_grouped_conv_bwd_weight<NDimSpatial,
        GemmConfig,
        ConvConfig,
        Invoker,
        InDataType,
        WeiDataType,
@@ -193,7 +193,7 @@ int run_grouped_conv_bwd_weight_example_with_layouts(ck_tile::ArgParser& arg_par
}

template <typename Invoker,
          typename GemmConfig,
          typename ConvConfig,
          typename InPrecType,
          typename WeiPrecType = InPrecType,
          typename OutPrecType = InPrecType>
@@ -217,7 +217,7 @@ int run_grouped_conv_bwd_weight_example_prec_type(std::string in_layout,
    if(in_layout == "NWGC" && wei_layout == "GKXC" && out_layout == "NWGK")
    {
        return run_grouped_conv_bwd_weight_example_with_layouts<ck_tile::number<1>{},
            GemmConfig,
            ConvConfig,
            Invoker,
            InPrecType,
            WeiPrecType,
@@ -227,7 +227,7 @@ int run_grouped_conv_bwd_weight_example_prec_type(std::string in_layout,
    else if(in_layout == "NHWGC" && wei_layout == "GKYXC" && out_layout == "NHWGK")
    {
        return run_grouped_conv_bwd_weight_example_with_layouts<ck_tile::number<2>{},
            GemmConfig,
            ConvConfig,
            Invoker,
            InPrecType,
            WeiPrecType,
@@ -237,7 +237,7 @@ int run_grouped_conv_bwd_weight_example_prec_type(std::string in_layout,
    else if(in_layout == "NDHWGC" && wei_layout == "GKZYXC" && out_layout == "NDHWGK")
    {
        return run_grouped_conv_bwd_weight_example_with_layouts<ck_tile::number<3>{},
            GemmConfig,
            ConvConfig,
            Invoker,
            InPrecType,
            WeiPrecType,

@@ -0,0 +1,350 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
// InstanceTraits specialization for DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
|
||||
//
|
||||
// CRITICAL MAINTENANCE NOTE:
|
||||
// This InstanceTraits file MUST be kept strictly in sync with the device implementation header:
|
||||
// ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp
|
||||
// "In sync" means that the template parameter order, names, and types in the declaration below
|
||||
// MUST EXACTLY MATCH those in the device implementation. If these diverge, you may encounter
|
||||
// compilation errors, subtle template instantiation mismatches, or silent runtime bugs that are
|
||||
// difficult to diagnose. Always update both files together and review changes carefully.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "instance_traits.hpp"
|
||||
|
||||
// Forward declaration to avoid circular dependency.
|
||||
namespace ck::tensor_operation::device {
|
||||
|
||||
template <ck::index_t NDimSpatial,
|
||||
typename ALayout,
|
||||
typename BLayout,
|
||||
typename DsLayout,
|
||||
typename ELayout,
|
||||
typename ADataType,
|
||||
typename BDataType,
|
||||
typename AccDataType,
|
||||
typename CShuffleDataType,
|
||||
typename DsDataType,
|
||||
typename EDataType,
|
||||
typename AElementwiseOperation,
|
||||
typename BElementwiseOperation,
|
||||
typename CDEElementwiseOperation,
|
||||
ConvolutionForwardSpecialization ConvForwardSpecialization,
|
||||
GemmSpecialization GemmSpec,
|
||||
ck::index_t NumGemmKPrefetchStage,
|
||||
ck::index_t BlockSize,
|
||||
ck::index_t MPerBlock,
|
||||
ck::index_t NPerBlock,
|
||||
ck::index_t KPerBlock,
|
||||
ck::index_t AK1,
|
||||
ck::index_t BK1,
|
||||
ck::index_t MPerXDL,
|
||||
ck::index_t NPerXDL,
|
||||
ck::index_t MXdlPerWave,
|
||||
ck::index_t NXdlPerWave,
|
||||
typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
|
||||
typename ABlockTransferThreadClusterArrangeOrder,
|
||||
typename ABlockTransferSrcAccessOrder,
|
||||
ck::index_t ABlockTransferSrcVectorDim,
|
||||
ck::index_t ABlockTransferSrcScalarPerVector,
|
||||
ck::index_t ABlockTransferDstScalarPerVector_AK1,
|
||||
ck::index_t ABlockLdsExtraM,
|
||||
typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
|
||||
typename BBlockTransferThreadClusterArrangeOrder,
|
||||
typename BBlockTransferSrcAccessOrder,
|
||||
ck::index_t BBlockTransferSrcVectorDim,
|
||||
ck::index_t BBlockTransferSrcScalarPerVector,
|
||||
ck::index_t BBlockTransferDstScalarPerVector_BK1,
|
||||
ck::index_t BBlockLdsExtraN,
|
||||
ck::index_t CShuffleMXdlPerWavePerShuffle,
|
||||
ck::index_t CShuffleNXdlPerWavePerShuffle,
|
||||
typename CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
ck::index_t CDEBlockTransferScalarPerVector_NPerBlock,
|
||||
typename AComputeDataType,
|
||||
typename BComputeDataType,
|
||||
LoopScheduler LoopSched,
|
||||
ck::index_t NumGroupsToMerge>
|
||||
struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle;
|
||||
|
||||
} // namespace ck::tensor_operation::device
|
||||
|
||||
namespace ck_tile::reflect {
|
||||
|
||||
// Specialization for DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
|
||||
template <ck::index_t NDimSpatial,
|
||||
typename ALayout_,
|
||||
typename BLayout_,
|
||||
typename DsLayout_,
|
||||
typename ELayout_,
|
||||
typename ADataType_,
|
||||
typename BDataType_,
|
||||
typename AccDataType_,
|
||||
typename CShuffleDataType_,
|
||||
typename DsDataType_,
|
||||
typename EDataType_,
|
||||
typename AElementwiseOperation_,
|
||||
typename BElementwiseOperation_,
|
||||
typename CDEElementwiseOperation_,
|
||||
ck::tensor_operation::device::ConvolutionForwardSpecialization ConvForwardSpecialization,
|
||||
ck::tensor_operation::device::GemmSpecialization GemmSpec,
|
||||
ck::index_t NumGemmKPrefetchStage,
|
||||
ck::index_t BlockSize,
|
||||
ck::index_t MPerBlock,
|
||||
ck::index_t NPerBlock,
|
||||
ck::index_t KPerBlock,
|
||||
ck::index_t AK1,
|
||||
ck::index_t BK1,
|
||||
ck::index_t MPerXDL,
|
||||
ck::index_t NPerXDL,
|
||||
ck::index_t MXdlPerWave,
|
||||
ck::index_t NXdlPerWave,
|
||||
typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
|
||||
typename ABlockTransferThreadClusterArrangeOrder,
|
||||
typename ABlockTransferSrcAccessOrder_,
|
||||
ck::index_t ABlockTransferSrcVectorDim,
|
||||
ck::index_t ABlockTransferSrcScalarPerVector,
|
||||
ck::index_t ABlockTransferDstScalarPerVector_AK1,
|
||||
ck::index_t ABlockLdsExtraM,
|
||||
typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
|
||||
typename BBlockTransferThreadClusterArrangeOrder,
|
||||
typename BBlockTransferSrcAccessOrder_,
|
||||
ck::index_t BBlockTransferSrcVectorDim,
|
||||
ck::index_t BBlockTransferSrcScalarPerVector,
|
||||
ck::index_t BBlockTransferDstScalarPerVector_BK1,
|
||||
ck::index_t BBlockLdsExtraN,
|
||||
ck::index_t CShuffleMXdlPerWavePerShuffle,
|
||||
ck::index_t CShuffleNXdlPerWavePerShuffle,
|
||||
typename CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
ck::index_t CDEBlockTransferScalarPerVector_NPerBlock,
|
||||
typename AComputeDataType_,
|
||||
typename BComputeDataType_,
|
||||
ck::LoopScheduler LoopSched,
|
||||
ck::index_t NumGroupsToMerge>
|
||||
struct InstanceTraits<ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<
|
||||
NDimSpatial,
|
||||
ALayout_,
|
||||
BLayout_,
|
||||
DsLayout_,
|
||||
ELayout_,
|
||||
ADataType_,
|
||||
BDataType_,
|
||||
AccDataType_,
|
||||
CShuffleDataType_,
|
||||
DsDataType_,
|
||||
EDataType_,
|
||||
AElementwiseOperation_,
|
||||
BElementwiseOperation_,
|
||||
CDEElementwiseOperation_,
|
||||
ConvForwardSpecialization,
|
||||
GemmSpec,
|
||||
NumGemmKPrefetchStage,
|
||||
BlockSize,
|
||||
MPerBlock,
|
||||
NPerBlock,
|
||||
KPerBlock,
|
||||
AK1,
|
||||
BK1,
|
||||
MPerXDL,
|
||||
NPerXDL,
|
||||
MXdlPerWave,
|
||||
NXdlPerWave,
|
||||
ABlockTransferThreadClusterLengths_AK0_M_AK1,
|
||||
ABlockTransferThreadClusterArrangeOrder,
|
||||
ABlockTransferSrcAccessOrder_,
|
||||
ABlockTransferSrcVectorDim,
|
||||
ABlockTransferSrcScalarPerVector,
|
||||
ABlockTransferDstScalarPerVector_AK1,
|
||||
ABlockLdsExtraM,
|
||||
BBlockTransferThreadClusterLengths_BK0_N_BK1,
|
||||
BBlockTransferThreadClusterArrangeOrder,
|
||||
BBlockTransferSrcAccessOrder_,
|
||||
BBlockTransferSrcVectorDim,
|
||||
BBlockTransferSrcScalarPerVector,
|
||||
BBlockTransferDstScalarPerVector_BK1,
|
||||
BBlockLdsExtraN,
|
||||
CShuffleMXdlPerWavePerShuffle,
|
||||
CShuffleNXdlPerWavePerShuffle,
|
||||
CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
CDEBlockTransferScalarPerVector_NPerBlock,
|
||||
AComputeDataType_,
|
||||
BComputeDataType_,
|
||||
LoopSched,
|
||||
NumGroupsToMerge>>
|
||||
{
|
||||
// Spatial dimension
|
||||
static constexpr int kSpatialDim = NDimSpatial;
|
||||
|
||||
// Layout types
|
||||
using ALayout = ALayout_;
|
||||
using BLayout = BLayout_;
|
||||
using DsLayout = DsLayout_;
|
||||
using ELayout = ELayout_;
|
||||
|
||||
// Data types
|
||||
using ADataType = ADataType_;
|
||||
using BDataType = BDataType_;
|
||||
using AccDataType = AccDataType_;
|
||||
using CShuffleDataType = CShuffleDataType_;
|
||||
using DsDataType = DsDataType_;
|
||||
using EDataType = EDataType_;
|
||||
|
||||
// Element-wise operations
|
||||
using AElementwiseOperation = AElementwiseOperation_;
|
||||
using BElementwiseOperation = BElementwiseOperation_;
|
||||
using CDEElementwiseOperation = CDEElementwiseOperation_;
|
||||
|
||||
// Specialization
|
||||
static constexpr ck::tensor_operation::device::ConvolutionForwardSpecialization
|
||||
kConvForwardSpecialization = ConvForwardSpecialization;
|
||||
static constexpr ck::tensor_operation::device::GemmSpecialization kGemmSpecialization =
|
||||
GemmSpec;
|
||||
|
||||
// Prefetch stage
|
||||
static constexpr int kNumGemmKPrefetchStage = NumGemmKPrefetchStage;
|
||||
|
||||
// Block configuration
|
||||
static constexpr int kBlockSize = BlockSize;
|
||||
static constexpr int kMPerBlock = MPerBlock;
|
||||
static constexpr int kNPerBlock = NPerBlock;
|
||||
static constexpr int kKPerBlock = KPerBlock;
|
||||
|
||||
// Tuning parameters
|
||||
static constexpr int kAK1 = AK1;
|
||||
static constexpr int kBK1 = BK1;
|
||||
static constexpr int kMPerXDL = MPerXDL;
|
||||
static constexpr int kNPerXDL = NPerXDL;
|
||||
static constexpr int kMXdlPerWave = MXdlPerWave;
|
||||
static constexpr int kNXdlPerWave = NXdlPerWave;
|
||||
|
||||
// A block transfer thread cluster dimensions (converted to std::array)
|
||||
static constexpr auto kAThreadClusterLengths =
|
||||
detail::SequenceToArray<ABlockTransferThreadClusterLengths_AK0_M_AK1>::value;
|
||||
static constexpr auto kAThreadClusterArrangeOrder =
|
||||
detail::SequenceToArray<ABlockTransferThreadClusterArrangeOrder>::value;
|
||||
static constexpr auto kABlockTransferSrcAccessOrder =
|
||||
detail::SequenceToArray<ABlockTransferSrcAccessOrder_>::value;
|
||||
static constexpr int kABlockTransferSrcVectorDim = ABlockTransferSrcVectorDim;
|
||||
static constexpr int kABlockTransferSrcScalarPerVector = ABlockTransferSrcScalarPerVector;
|
||||
static constexpr int kABlockTransferDstScalarPerVectorK1 = ABlockTransferDstScalarPerVector_AK1;
|
||||
static constexpr int kABlockLdsExtraM = ABlockLdsExtraM;
|
||||
|
||||
// B block transfer thread cluster dimensions (converted to std::array)
|
||||
static constexpr auto kBThreadClusterLengths =
|
||||
detail::SequenceToArray<BBlockTransferThreadClusterLengths_BK0_N_BK1>::value;
|
||||
static constexpr auto kBThreadClusterArrangeOrder =
|
||||
detail::SequenceToArray<BBlockTransferThreadClusterArrangeOrder>::value;
|
||||
static constexpr auto kBBlockTransferSrcAccessOrder =
|
||||
detail::SequenceToArray<BBlockTransferSrcAccessOrder_>::value;
|
||||
static constexpr int kBBlockTransferSrcVectorDim = BBlockTransferSrcVectorDim;
|
||||
static constexpr int kBBlockTransferSrcScalarPerVector = BBlockTransferSrcScalarPerVector;
|
||||
static constexpr int kBBlockTransferDstScalarPerVectorK1 = BBlockTransferDstScalarPerVector_BK1;
|
||||
static constexpr int kBBlockLdsExtraN = BBlockLdsExtraN;
|
||||
|
||||
// C shuffle parameters (converted to std::array)
|
||||
static constexpr int kCShuffleMXdlPerWavePerShuffle = CShuffleMXdlPerWavePerShuffle;
|
||||
static constexpr int kCShuffleNXdlPerWavePerShuffle = CShuffleNXdlPerWavePerShuffle;
|
||||
static constexpr auto kCThreadClusterLengths = detail::SequenceToArray<
|
||||
CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock>::value;
|
||||
static constexpr int kCBlockTransferScalarPerVector = CDEBlockTransferScalarPerVector_NPerBlock;
|
||||
|
||||
// Compute data types
|
||||
using AComputeDataType = AComputeDataType_;
|
||||
using BComputeDataType = BComputeDataType_;
|
||||
|
||||
// Loop scheduler
|
||||
static constexpr ck::LoopScheduler kLoopScheduler = LoopSched;
|
||||
|
||||
// Groups to merge
|
||||
static constexpr int kNumGroupsToMerge = NumGroupsToMerge;
|
||||
|
||||
// Static member function to generate instance string
|
||||
static std::string instance_string()
|
||||
{
|
||||
std::ostringstream oss;
|
||||
|
||||
// Kernel type name
|
||||
oss << "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle";
|
||||
|
||||
// Template parameters in exact order matching InstanceTraits member order
|
||||
oss << "<" << kSpatialDim; // 1. NDimSpatial
|
||||
oss << "," << detail::layout_name<ALayout>(); // 2. ALayout
|
||||
oss << "," << detail::layout_name<BLayout>(); // 3. BLayout
|
||||
oss << "," << detail::tuple_name<DsLayout>(); // 4. DsLayout
|
||||
oss << "," << detail::layout_name<ELayout>(); // 5. ELayout
|
||||
oss << "," << detail::type_name<ADataType>(); // 6. ADataType
|
||||
oss << "," << detail::type_name<BDataType>(); // 7. BDataType
|
||||
oss << "," << detail::type_name<AccDataType>(); // 8. AccDataType
|
||||
oss << "," << detail::type_name<CShuffleDataType>(); // 9. CShuffleDataType
|
||||
oss << "," << detail::tuple_name<DsDataType>(); // 10. DsDataType
|
||||
oss << "," << detail::type_name<EDataType>(); // 11. EDataType
|
||||
oss << ","
|
||||
<< detail::elementwise_op_name<AElementwiseOperation>(); // 12. AElementwiseOperation
|
||||
oss << ","
|
||||
<< detail::elementwise_op_name<BElementwiseOperation>(); // 13. BElementwiseOperation
|
||||
oss << ","
|
||||
<< detail::elementwise_op_name<CDEElementwiseOperation>(); // 14.
|
||||
// CDEElementwiseOperation
|
||||
oss << ","
|
||||
<< detail::conv_fwd_spec_name(
|
||||
kConvForwardSpecialization); // 15. ConvForwardSpecialization
|
||||
oss << "," << detail::gemm_spec_name(kGemmSpecialization); // 16. GemmSpec
|
||||
oss << "," << kNumGemmKPrefetchStage; // 17. NumGemmKPrefetchStage
|
||||
oss << "," << kBlockSize; // 18. BlockSize
|
||||
oss << "," << kMPerBlock; // 19. MPerBlock
|
||||
oss << "," << kNPerBlock; // 20. NPerBlock
|
||||
oss << "," << kKPerBlock; // 21. KPerBlock
|
||||
oss << "," << kAK1; // 22. AK1
|
||||
oss << "," << kBK1; // 23. BK1
|
||||
oss << "," << kMPerXDL; // 24. MPerXDL
|
||||
oss << "," << kNPerXDL; // 25. NPerXDL
|
||||
oss << "," << kMXdlPerWave; // 26. MXdlPerWave
|
||||
oss << "," << kNXdlPerWave; // 27. NXdlPerWave
|
||||
oss << ","
|
||||
<< detail::array_to_string(
|
||||
kAThreadClusterLengths); // 28. ABlockTransferThreadClusterLengths
|
||||
oss << ","
|
||||
<< detail::array_to_string(
|
||||
kAThreadClusterArrangeOrder); // 29. ABlockTransferThreadClusterArrangeOrder
|
||||
oss << ","
|
||||
<< detail::array_to_string(
|
||||
kABlockTransferSrcAccessOrder); // 30. ABlockTransferSrcAccessOrder
|
||||
oss << "," << kABlockTransferSrcVectorDim; // 31. ABlockTransferSrcVectorDim
|
||||
oss << "," << kABlockTransferSrcScalarPerVector; // 32. ABlockTransferSrcScalarPerVector
|
||||
oss << ","
|
||||
<< kABlockTransferDstScalarPerVectorK1; // 33. ABlockTransferDstScalarPerVector_AK1
|
||||
oss << "," << kABlockLdsExtraM; // 34. ABlockLdsExtraM
|
||||
oss << ","
|
||||
<< detail::array_to_string(
|
||||
kBThreadClusterLengths); // 35. BBlockTransferThreadClusterLengths
|
||||
oss << ","
|
||||
<< detail::array_to_string(
|
||||
kBThreadClusterArrangeOrder); // 36. BBlockTransferThreadClusterArrangeOrder
|
||||
oss << ","
|
||||
<< detail::array_to_string(
|
||||
kBBlockTransferSrcAccessOrder); // 37. BBlockTransferSrcAccessOrder
|
||||
oss << "," << kBBlockTransferSrcVectorDim; // 38. BBlockTransferSrcVectorDim
|
||||
oss << "," << kBBlockTransferSrcScalarPerVector; // 39. BBlockTransferSrcScalarPerVector
|
||||
oss << ","
|
||||
<< kBBlockTransferDstScalarPerVectorK1; // 40. BBlockTransferDstScalarPerVector_BK1
|
||||
oss << "," << kBBlockLdsExtraN; // 41. BBlockLdsExtraN
|
||||
oss << "," << kCShuffleMXdlPerWavePerShuffle; // 42. CShuffleMXdlPerWavePerShuffle
|
||||
oss << "," << kCShuffleNXdlPerWavePerShuffle; // 43. CShuffleNXdlPerWavePerShuffle
|
||||
oss << ","
|
||||
<< detail::array_to_string(
|
||||
kCThreadClusterLengths); // 44. CDEBlockTransferClusterLengths
|
||||
oss << ","
|
||||
<< kCBlockTransferScalarPerVector; // 45. CDEBlockTransferScalarPerVector_NPerBlock
|
||||
oss << "," << detail::type_name<AComputeDataType>(); // 46. AComputeDataType
|
||||
oss << "," << detail::type_name<BComputeDataType>(); // 47. BComputeDataType
|
||||
oss << "," << detail::loop_scheduler_name(kLoopScheduler); // 48. LoopSched
|
||||
oss << "," << kNumGroupsToMerge; // 49. NumGroupsToMerge
|
||||
oss << ">";
|
||||
|
||||
return oss.str();
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile::reflect
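
// Usage sketch (hypothetical instance alias, not part of this commit): given a concrete
// instantiation of the device kernel, e.g.
//     using MyInstance = ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle< /* template arguments */ >;
// the traits can be queried without constructing the kernel:
//     std::string name = ck_tile::reflect::InstanceTraits<MyInstance>::instance_string();
// which reproduces the template parameters in the fixed 49-field order commented above.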
|
||||
@@ -0,0 +1,344 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
// InstanceTraits specialization for DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor
|
||||
//
|
||||
// CRITICAL MAINTENANCE NOTE:
|
||||
// This InstanceTraits file MUST be kept strictly in sync with the device implementation header:
|
||||
// ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp
|
||||
// "In sync" means that the template parameter order, names, and types in the declaration below
|
||||
// MUST EXACTLY MATCH those in the device implementation. If these diverge, you may encounter
|
||||
// compilation errors, subtle template instantiation mismatches, or silent runtime bugs that are
|
||||
// difficult to diagnose. Always update both files together and review changes carefully.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "instance_traits.hpp"
|
||||
|
||||
// Forward declaration to avoid circular dependency.
|
||||
namespace ck::tensor_operation::device {
|
||||
|
||||
template <ck::index_t NDimSpatial,
|
||||
typename ALayout,
|
||||
typename BLayout,
|
||||
typename DsLayout,
|
||||
typename ELayout,
|
||||
typename ADataType,
|
||||
typename BDataType,
|
||||
typename AccDataType,
|
||||
typename CShuffleDataType,
|
||||
typename DsDataType,
|
||||
typename EDataType,
|
||||
typename AElementwiseOperation,
|
||||
typename BElementwiseOperation,
|
||||
typename CDEElementwiseOperation,
|
||||
ConvolutionForwardSpecialization ConvForwardSpecialization,
|
||||
GemmSpecialization GemmSpec,
|
||||
ck::index_t NumGemmKPrefetchStage,
|
||||
ck::index_t BlockSize,
|
||||
ck::index_t MPerBlock,
|
||||
ck::index_t NPerBlock,
|
||||
ck::index_t KPerBlock,
|
||||
ck::index_t AK1,
|
||||
ck::index_t BK1,
|
||||
ck::index_t MPerXDL,
|
||||
ck::index_t NPerXDL,
|
||||
ck::index_t MXdlPerWave,
|
||||
ck::index_t NXdlPerWave,
|
||||
typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
|
||||
typename ABlockTransferThreadClusterArrangeOrder,
|
||||
typename ABlockTransferSrcAccessOrder,
|
||||
ck::index_t ABlockTransferSrcVectorDim,
|
||||
ck::index_t ABlockTransferSrcScalarPerVector,
|
||||
ck::index_t ABlockTransferDstScalarPerVector_AK1,
|
||||
ck::index_t ABlockLdsExtraM,
|
||||
typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
|
||||
typename BBlockTransferThreadClusterArrangeOrder,
|
||||
typename BBlockTransferSrcAccessOrder,
|
||||
ck::index_t BBlockTransferSrcVectorDim,
|
||||
ck::index_t BBlockTransferSrcScalarPerVector,
|
||||
ck::index_t BBlockTransferDstScalarPerVector_BK1,
|
||||
ck::index_t BBlockLdsExtraN,
|
||||
ck::index_t CShuffleMXdlPerWavePerShuffle,
|
||||
ck::index_t CShuffleNXdlPerWavePerShuffle,
|
||||
typename CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
ck::index_t CDEBlockTransferScalarPerVector_NPerBlock,
|
||||
typename AComputeDataType,
|
||||
typename BComputeDataType,
|
||||
LoopScheduler LoopSched>
|
||||
struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor;
|
||||
|
||||
} // namespace ck::tensor_operation::device
|
||||
|
||||
namespace ck_tile::reflect {
|
||||
|
||||
// Specialization for DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor
|
||||
template <ck::index_t NDimSpatial,
|
||||
typename ALayout_,
|
||||
typename BLayout_,
|
||||
typename DsLayout_,
|
||||
typename ELayout_,
|
||||
typename ADataType_,
|
||||
typename BDataType_,
|
||||
typename AccDataType_,
|
||||
typename CShuffleDataType_,
|
||||
typename DsDataType_,
|
||||
typename EDataType_,
|
||||
typename AElementwiseOperation_,
|
||||
typename BElementwiseOperation_,
|
||||
typename CDEElementwiseOperation_,
|
||||
ck::tensor_operation::device::ConvolutionForwardSpecialization ConvForwardSpecialization,
|
||||
ck::tensor_operation::device::GemmSpecialization GemmSpec,
|
||||
ck::index_t NumGemmKPrefetchStage,
|
||||
ck::index_t BlockSize,
|
||||
ck::index_t MPerBlock,
|
||||
ck::index_t NPerBlock,
|
||||
ck::index_t KPerBlock,
|
||||
ck::index_t AK1,
|
||||
ck::index_t BK1,
|
||||
ck::index_t MPerXDL,
|
||||
ck::index_t NPerXDL,
|
||||
ck::index_t MXdlPerWave,
|
||||
ck::index_t NXdlPerWave,
|
||||
typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
|
||||
typename ABlockTransferThreadClusterArrangeOrder,
|
||||
typename ABlockTransferSrcAccessOrder_,
|
||||
ck::index_t ABlockTransferSrcVectorDim,
|
||||
ck::index_t ABlockTransferSrcScalarPerVector,
|
||||
ck::index_t ABlockTransferDstScalarPerVector_AK1,
|
||||
ck::index_t ABlockLdsExtraM,
|
||||
typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
|
||||
typename BBlockTransferThreadClusterArrangeOrder,
|
||||
typename BBlockTransferSrcAccessOrder_,
|
||||
ck::index_t BBlockTransferSrcVectorDim,
|
||||
ck::index_t BBlockTransferSrcScalarPerVector,
|
||||
ck::index_t BBlockTransferDstScalarPerVector_BK1,
|
||||
ck::index_t BBlockLdsExtraN,
|
||||
ck::index_t CShuffleMXdlPerWavePerShuffle,
|
||||
ck::index_t CShuffleNXdlPerWavePerShuffle,
|
||||
typename CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
ck::index_t CDEBlockTransferScalarPerVector_NPerBlock,
|
||||
typename AComputeDataType_,
|
||||
typename BComputeDataType_,
|
||||
ck::LoopScheduler LoopSched>
|
||||
struct InstanceTraits<
|
||||
ck::tensor_operation::device::DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor<
|
||||
NDimSpatial,
|
||||
ALayout_,
|
||||
BLayout_,
|
||||
DsLayout_,
|
||||
ELayout_,
|
||||
ADataType_,
|
||||
BDataType_,
|
||||
AccDataType_,
|
||||
CShuffleDataType_,
|
||||
DsDataType_,
|
||||
EDataType_,
|
||||
AElementwiseOperation_,
|
||||
BElementwiseOperation_,
|
||||
CDEElementwiseOperation_,
|
||||
ConvForwardSpecialization,
|
||||
GemmSpec,
|
||||
NumGemmKPrefetchStage,
|
||||
BlockSize,
|
||||
MPerBlock,
|
||||
NPerBlock,
|
||||
KPerBlock,
|
||||
AK1,
|
||||
BK1,
|
||||
MPerXDL,
|
||||
NPerXDL,
|
||||
MXdlPerWave,
|
||||
NXdlPerWave,
|
||||
ABlockTransferThreadClusterLengths_AK0_M_AK1,
|
||||
ABlockTransferThreadClusterArrangeOrder,
|
||||
ABlockTransferSrcAccessOrder_,
|
||||
ABlockTransferSrcVectorDim,
|
||||
ABlockTransferSrcScalarPerVector,
|
||||
ABlockTransferDstScalarPerVector_AK1,
|
||||
ABlockLdsExtraM,
|
||||
BBlockTransferThreadClusterLengths_BK0_N_BK1,
|
||||
BBlockTransferThreadClusterArrangeOrder,
|
||||
BBlockTransferSrcAccessOrder_,
|
||||
BBlockTransferSrcVectorDim,
|
||||
BBlockTransferSrcScalarPerVector,
|
||||
BBlockTransferDstScalarPerVector_BK1,
|
||||
BBlockLdsExtraN,
|
||||
CShuffleMXdlPerWavePerShuffle,
|
||||
CShuffleNXdlPerWavePerShuffle,
|
||||
CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
CDEBlockTransferScalarPerVector_NPerBlock,
|
||||
AComputeDataType_,
|
||||
BComputeDataType_,
|
||||
LoopSched>>
|
||||
{
|
||||
// Spatial dimension
|
||||
static constexpr int kSpatialDim = NDimSpatial;
|
||||
|
||||
// Layout types
|
||||
using ALayout = ALayout_;
|
||||
using BLayout = BLayout_;
|
||||
using DsLayout = DsLayout_;
|
||||
using ELayout = ELayout_;
|
||||
|
||||
// Data types
|
||||
using ADataType = ADataType_;
|
||||
using BDataType = BDataType_;
|
||||
using AccDataType = AccDataType_;
|
||||
using CShuffleDataType = CShuffleDataType_;
|
||||
using DsDataType = DsDataType_;
|
||||
using EDataType = EDataType_;
|
||||
|
||||
// Element-wise operations
|
||||
using AElementwiseOperation = AElementwiseOperation_;
|
||||
using BElementwiseOperation = BElementwiseOperation_;
|
||||
using CDEElementwiseOperation = CDEElementwiseOperation_;
|
||||
|
||||
// Specialization
|
||||
static constexpr ck::tensor_operation::device::ConvolutionForwardSpecialization
|
||||
kConvForwardSpecialization = ConvForwardSpecialization;
|
||||
static constexpr ck::tensor_operation::device::GemmSpecialization kGemmSpecialization =
|
||||
GemmSpec;
|
||||
|
||||
// Prefetch stage
|
||||
static constexpr int kNumGemmKPrefetchStage = NumGemmKPrefetchStage;
|
||||
|
||||
// Block configuration
|
||||
static constexpr int kBlockSize = BlockSize;
|
||||
static constexpr int kMPerBlock = MPerBlock;
|
||||
static constexpr int kNPerBlock = NPerBlock;
|
||||
static constexpr int kKPerBlock = KPerBlock;
|
||||
|
||||
// Tuning parameters
|
||||
static constexpr int kAK1 = AK1;
|
||||
static constexpr int kBK1 = BK1;
|
||||
static constexpr int kMPerXDL = MPerXDL;
|
||||
static constexpr int kNPerXDL = NPerXDL;
|
||||
static constexpr int kMXdlPerWave = MXdlPerWave;
|
||||
static constexpr int kNXdlPerWave = NXdlPerWave;
|
||||
|
||||
// A block transfer thread cluster dimensions (converted to std::array)
|
||||
static constexpr auto kAThreadClusterLengths =
|
||||
detail::SequenceToArray<ABlockTransferThreadClusterLengths_AK0_M_AK1>::value;
|
||||
static constexpr auto kAThreadClusterArrangeOrder =
|
||||
detail::SequenceToArray<ABlockTransferThreadClusterArrangeOrder>::value;
|
||||
static constexpr auto kABlockTransferSrcAccessOrder =
|
||||
detail::SequenceToArray<ABlockTransferSrcAccessOrder_>::value;
|
||||
static constexpr int kABlockTransferSrcVectorDim = ABlockTransferSrcVectorDim;
|
||||
static constexpr int kABlockTransferSrcScalarPerVector = ABlockTransferSrcScalarPerVector;
|
||||
static constexpr int kABlockTransferDstScalarPerVectorK1 = ABlockTransferDstScalarPerVector_AK1;
|
||||
static constexpr int kABlockLdsExtraM = ABlockLdsExtraM;
|
||||
|
||||
// B block transfer thread cluster dimensions (converted to std::array)
|
||||
static constexpr auto kBThreadClusterLengths =
|
||||
detail::SequenceToArray<BBlockTransferThreadClusterLengths_BK0_N_BK1>::value;
|
||||
static constexpr auto kBThreadClusterArrangeOrder =
|
||||
detail::SequenceToArray<BBlockTransferThreadClusterArrangeOrder>::value;
|
||||
static constexpr auto kBBlockTransferSrcAccessOrder =
|
||||
detail::SequenceToArray<BBlockTransferSrcAccessOrder_>::value;
|
||||
static constexpr int kBBlockTransferSrcVectorDim = BBlockTransferSrcVectorDim;
|
||||
static constexpr int kBBlockTransferSrcScalarPerVector = BBlockTransferSrcScalarPerVector;
|
||||
static constexpr int kBBlockTransferDstScalarPerVectorK1 = BBlockTransferDstScalarPerVector_BK1;
|
||||
static constexpr int kBBlockLdsExtraN = BBlockLdsExtraN;
|
||||
|
||||
// C shuffle parameters (converted to std::array)
|
||||
static constexpr int kCShuffleMXdlPerWavePerShuffle = CShuffleMXdlPerWavePerShuffle;
|
||||
static constexpr int kCShuffleNXdlPerWavePerShuffle = CShuffleNXdlPerWavePerShuffle;
|
||||
static constexpr auto kCThreadClusterLengths = detail::SequenceToArray<
|
||||
CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock>::value;
|
||||
static constexpr int kCBlockTransferScalarPerVector = CDEBlockTransferScalarPerVector_NPerBlock;
|
||||
|
||||
// Compute data types
|
||||
using AComputeDataType = AComputeDataType_;
|
||||
using BComputeDataType = BComputeDataType_;
|
||||
|
||||
// Loop scheduler
|
||||
static constexpr ck::LoopScheduler kLoopScheduler = LoopSched;
|
||||
|
||||
// Static member function to generate instance string
|
||||
static std::string instance_string()
|
||||
{
|
||||
std::ostringstream oss;
|
||||
|
||||
// Kernel type name
|
||||
oss << "DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor";
|
||||
|
||||
// Template parameters in exact order matching InstanceTraits member order
|
||||
oss << "<" << kSpatialDim; // 1. NDimSpatial
|
||||
oss << "," << detail::layout_name<ALayout>(); // 2. ALayout
|
||||
oss << "," << detail::layout_name<BLayout>(); // 3. BLayout
|
||||
oss << "," << detail::tuple_name<DsLayout>(); // 4. DsLayout
|
||||
oss << "," << detail::layout_name<ELayout>(); // 5. ELayout
|
||||
oss << "," << detail::type_name<ADataType>(); // 6. ADataType
|
||||
oss << "," << detail::type_name<BDataType>(); // 7. BDataType
|
||||
oss << "," << detail::type_name<AccDataType>(); // 8. AccDataType
|
||||
oss << "," << detail::type_name<CShuffleDataType>(); // 9. CShuffleDataType
|
||||
oss << "," << detail::tuple_name<DsDataType>(); // 10. DsDataType
|
||||
oss << "," << detail::type_name<EDataType>(); // 11. EDataType
|
||||
oss << ","
|
||||
<< detail::elementwise_op_name<AElementwiseOperation>(); // 12. AElementwiseOperation
|
||||
oss << ","
|
||||
<< detail::elementwise_op_name<BElementwiseOperation>(); // 13. BElementwiseOperation
|
||||
oss << ","
|
||||
<< detail::elementwise_op_name<CDEElementwiseOperation>(); // 14.
|
||||
// CDEElementwiseOperation
|
||||
oss << ","
|
||||
<< detail::conv_fwd_spec_name(
|
||||
kConvForwardSpecialization); // 15. ConvForwardSpecialization
|
||||
oss << "," << detail::gemm_spec_name(kGemmSpecialization); // 16. GemmSpec
|
||||
oss << "," << kNumGemmKPrefetchStage; // 17. NumGemmKPrefetchStage
|
||||
oss << "," << kBlockSize; // 18. BlockSize
|
||||
oss << "," << kMPerBlock; // 19. MPerBlock
|
||||
oss << "," << kNPerBlock; // 20. NPerBlock
|
||||
oss << "," << kKPerBlock; // 21. KPerBlock
|
||||
oss << "," << kAK1; // 22. AK1
|
||||
oss << "," << kBK1; // 23. BK1
|
||||
oss << "," << kMPerXDL; // 24. MPerXDL
|
||||
oss << "," << kNPerXDL; // 25. NPerXDL
|
||||
oss << "," << kMXdlPerWave; // 26. MXdlPerWave
|
||||
oss << "," << kNXdlPerWave; // 27. NXdlPerWave
|
||||
oss << ","
|
||||
<< detail::array_to_string(
|
||||
kAThreadClusterLengths); // 28. ABlockTransferThreadClusterLengths
|
||||
oss << ","
|
||||
<< detail::array_to_string(
|
||||
kAThreadClusterArrangeOrder); // 29. ABlockTransferThreadClusterArrangeOrder
|
||||
oss << ","
|
||||
<< detail::array_to_string(
|
||||
kABlockTransferSrcAccessOrder); // 30. ABlockTransferSrcAccessOrder
|
||||
oss << "," << kABlockTransferSrcVectorDim; // 31. ABlockTransferSrcVectorDim
|
||||
oss << "," << kABlockTransferSrcScalarPerVector; // 32. ABlockTransferSrcScalarPerVector
|
||||
oss << ","
|
||||
<< kABlockTransferDstScalarPerVectorK1; // 33. ABlockTransferDstScalarPerVector_AK1
|
||||
oss << "," << kABlockLdsExtraM; // 34. ABlockLdsExtraM
|
||||
oss << ","
|
||||
<< detail::array_to_string(
|
||||
kBThreadClusterLengths); // 35. BBlockTransferThreadClusterLengths
|
||||
oss << ","
|
||||
<< detail::array_to_string(
|
||||
kBThreadClusterArrangeOrder); // 36. BBlockTransferThreadClusterArrangeOrder
|
||||
oss << ","
|
||||
<< detail::array_to_string(
|
||||
kBBlockTransferSrcAccessOrder); // 37. BBlockTransferSrcAccessOrder
|
||||
oss << "," << kBBlockTransferSrcVectorDim; // 38. BBlockTransferSrcVectorDim
|
||||
oss << "," << kBBlockTransferSrcScalarPerVector; // 39. BBlockTransferSrcScalarPerVector
|
||||
oss << ","
|
||||
<< kBBlockTransferDstScalarPerVectorK1; // 40. BBlockTransferDstScalarPerVector_BK1
|
||||
oss << "," << kBBlockLdsExtraN; // 41. BBlockLdsExtraN
|
||||
oss << "," << kCShuffleMXdlPerWavePerShuffle; // 42. CShuffleMXdlPerWavePerShuffle
|
||||
oss << "," << kCShuffleNXdlPerWavePerShuffle; // 43. CShuffleNXdlPerWavePerShuffle
|
||||
oss << ","
|
||||
<< detail::array_to_string(
|
||||
kCThreadClusterLengths); // 44. CDEBlockTransferClusterLengths
|
||||
oss << ","
|
||||
<< kCBlockTransferScalarPerVector; // 45. CDEBlockTransferScalarPerVector_NPerBlock
|
||||
oss << "," << detail::type_name<AComputeDataType>(); // 46. AComputeDataType
|
||||
oss << "," << detail::type_name<BComputeDataType>(); // 47. BComputeDataType
|
||||
oss << "," << detail::loop_scheduler_name(kLoopScheduler); // 48. LoopSched
|
||||
oss << ">";
|
||||
|
||||
return oss.str();
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile::reflect
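A minimal usage sketch (not part of this patch; assumes the reflect header above is included
and that HasInstanceTraits and instance_string are the public entry points it defines):

// Hypothetical helper: print the reflected template-parameter string of a device op.
template <typename DeviceOp>
void print_instance()
{
    static_assert(ck_tile::reflect::HasInstanceTraits<DeviceOp>,
                  "no instance_traits specialization for this device op");
    std::cout << ck_tile::reflect::instance_string<DeviceOp>() << '\n'; // needs <iostream>
}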
|
||||
@@ -15,6 +15,7 @@
|
||||
#include <ck/utility/data_type.hpp>
|
||||
#include <ck/utility/sequence.hpp>
|
||||
#include <ck/utility/blkgemmpipe_scheduler.hpp>
|
||||
#include <ck/utility/loop_scheduler.hpp>
|
||||
#include <ck/tensor_operation/gpu/device/tensor_layout.hpp>
|
||||
#include <ck_tile/ops/common/tensor_layout.hpp>
|
||||
#include <ck/tensor_operation/gpu/element/element_wise_operation.hpp>
|
||||
@@ -160,6 +161,17 @@ constexpr std::string_view pipeline_version_name(ck::BlockGemmPipelineVersion ve
    }
}

// Convert LoopScheduler enum to string
constexpr std::string_view loop_scheduler_name(ck::LoopScheduler sched)
{
    using enum ck::LoopScheduler;
    switch(sched)
    {
    case Default: return "Default";
    case Interwave: return "Interwave";
    }
}

// Convert std::array to string
template <typename T, std::size_t N>
inline std::string array_to_string(const std::array<T, N>& arr)

@@ -26,7 +26,9 @@ add_ck_builder_test(test_inline_diff test_inline_diff.cpp)

# Testing the virtual GetInstanceString methods requires kernel compilation.
add_ck_builder_test(test_get_instance_string
    test_get_instance_string.cpp)
    test_get_instance_string_fwd_grp_conv_v3.cpp
    test_get_instance_string_fwd_grp_conv.cpp
    test_get_instance_string_fwd_grp_conv_large_tensor.cpp)

# Testing the fwd convolution builder requires kernel compilation.
# To enable parallel compilation, the individual tests are split into separate files.

@@ -3,19 +3,18 @@
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
#include <gmock/gmock.h>
|
||||
#include <ck/ck.hpp>
|
||||
#include <ck/utility/reduction_operator.hpp>
|
||||
#include <ck_tile/builder/reflect/instance_traits.hpp>
|
||||
#include <ck_tile/builder/reflect/instance_traits_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp>
|
||||
#include <ck_tile/builder/reflect/instance_traits_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp>
|
||||
#include <ck_tile/builder/reflect/instance_traits_device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp>
|
||||
|
||||
namespace {
|
||||
|
||||
using ::testing::ElementsAre;
|
||||
// Test fixture for InstanceTraits tests
|
||||
class InstanceTraitsTest : public ::testing::Test
|
||||
{
|
||||
};
|
||||
|
||||
// Test InstanceTraits with DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
|
||||
TEST_F(InstanceTraitsTest, ConvFwdInstanceTraitsExtraction)
|
||||
TEST(InstanceTraitsTest, ConvFwdInstanceTraitsExtraction)
|
||||
{
|
||||
// Define a concrete instance type with specific template parameters
|
||||
using DeviceInstance =
|
||||
@@ -156,8 +155,7 @@ TEST_F(InstanceTraitsTest, ConvFwdInstanceTraitsExtraction)
|
||||
ck::tensor_operation::element_wise::PassThrough>::value));
|
||||
}
|
||||
|
||||
// Test instance_string function
|
||||
TEST_F(InstanceTraitsTest, InstanceStringGeneration)
|
||||
TEST(InstanceTraitsTest, V3InstanceStringGeneration)
|
||||
{
|
||||
// Define a concrete instance type with specific template parameters
|
||||
using DeviceInstance =
|
||||
@@ -215,10 +213,8 @@ TEST_F(InstanceTraitsTest, InstanceStringGeneration)
|
||||
ck::half_t, // AComputeDataType
|
||||
ck::half_t>; // BComputeDataType
|
||||
|
||||
// Generate instance string
|
||||
std::string instance_str = ck_tile::reflect::instance_string<DeviceInstance>();
|
||||
|
||||
// Expected string with all template parameters in exact order
|
||||
std::string expected_str = "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3"
|
||||
"<2" // NDimSpatial
|
||||
",GNHWC" // ALayout
|
||||
@@ -269,6 +265,234 @@ TEST_F(InstanceTraitsTest, InstanceStringGeneration)
|
||||
",fp16" // AComputeDataType
|
||||
",fp16>"; // BComputeDataType
|
||||
|
||||
EXPECT_EQ(instance_str, expected_str);
|
||||
}
|
||||
|
||||
TEST(InstanceTraitsTest, BaseInstanceStringGeneration)
|
||||
{
|
||||
using DeviceInstance =
|
||||
ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<
|
||||
2, // NDimSpatial
|
||||
ck::tensor_layout::convolution::GNHWC, // ALayout
|
||||
ck::tensor_layout::convolution::GKYXC, // BLayout
|
||||
ck::Tuple<>, // DsLayout
|
||||
ck::tensor_layout::convolution::GNHWK, // ELayout
|
||||
ck::half_t, // ADataType
|
||||
ck::half_t, // BDataType
|
||||
float, // AccDataType
|
||||
ck::half_t, // CShuffleDataType
|
||||
ck::Tuple<>, // DsDataType
|
||||
ck::half_t, // EDataType
|
||||
ck::tensor_operation::element_wise::PassThrough, // AElementwiseOperation
|
||||
ck::tensor_operation::element_wise::PassThrough, // BElementwiseOperation
|
||||
ck::tensor_operation::element_wise::PassThrough, // CDEElementwiseOperation
|
||||
ck::tensor_operation::device::ConvolutionForwardSpecialization::
|
||||
Default, // ConvForwardSpecialization
|
||||
ck::tensor_operation::device::GemmSpecialization::Default, // GemmSpec
|
||||
1, // NumGemmKPrefetchStage
|
||||
256, // BlockSize
|
||||
128, // MPerBlock
|
||||
128, // NPerBlock
|
||||
16, // KPerBlock
|
||||
8, // AK1
|
||||
8, // BK1
|
||||
32, // MPerXDL
|
||||
32, // NPerXDL
|
||||
4, // MXdlPerWave
|
||||
4, // NXdlPerWave
|
||||
ck::Sequence<4, 64, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1
|
||||
ck::Sequence<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder
|
||||
ck::Sequence<1, 0, 2>, // ABlockTransferSrcAccessOrder
|
||||
2, // ABlockTransferSrcVectorDim
|
||||
8, // ABlockTransferSrcScalarPerVector
|
||||
8, // ABlockTransferDstScalarPerVector_AK1
|
||||
1, // ABlockLdsExtraM
|
||||
ck::Sequence<4, 64, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1
|
||||
ck::Sequence<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder
|
||||
ck::Sequence<1, 0, 2>, // BBlockTransferSrcAccessOrder
|
||||
2, // BBlockTransferSrcVectorDim
|
||||
8, // BBlockTransferSrcScalarPerVector
|
||||
8, // BBlockTransferDstScalarPerVector_BK1
|
||||
1, // BBlockLdsExtraN
|
||||
1, // CShuffleMXdlPerWavePerShuffle
|
||||
1, // CShuffleNXdlPerWavePerShuffle
|
||||
ck::Sequence<1,
|
||||
32,
|
||||
1,
|
||||
8>, // CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
|
||||
8, // CDEBlockTransferScalarPerVector_NPerBlock
|
||||
ck::half_t, // AComputeDataType
|
||||
ck::half_t, // BComputeDataType
|
||||
ck::LoopScheduler::Default, // LoopSched
|
||||
1>; // NumGroupsToMerge
|
||||
|
||||
std::string instance_str = ck_tile::reflect::instance_string<DeviceInstance>();
|
||||
|
||||
std::string expected_str = "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle"
|
||||
"<2" // NDimSpatial
|
||||
",GNHWC" // ALayout
|
||||
",GKYXC" // BLayout
|
||||
",EmptyTuple" // DsLayout
|
||||
",GNHWK" // ELayout
|
||||
",fp16" // ADataType
|
||||
",fp16" // BDataType
|
||||
",fp32" // AccDataType
|
||||
",fp16" // CShuffleDataType
|
||||
",EmptyTuple" // DsDataType
|
||||
",fp16" // EDataType
|
||||
",PassThrough" // AElementwiseOperation
|
||||
",PassThrough" // BElementwiseOperation
|
||||
",PassThrough" // CDEElementwiseOperation
|
||||
",Default" // ConvForwardSpecialization
|
||||
",Default" // GemmSpec
|
||||
",1" // NumGemmKPrefetchStage
|
||||
",256" // BlockSize
|
||||
",128" // MPerBlock
|
||||
",128" // NPerBlock
|
||||
",16" // KPerBlock
|
||||
",8" // AK1
|
||||
",8" // BK1
|
||||
",32" // MPerXDL
|
||||
",32" // NPerXDL
|
||||
",4" // MXdlPerWave
|
||||
",4" // NXdlPerWave
|
||||
",Seq(4,64,1)" // ABlockTransferThreadClusterLengths
|
||||
",Seq(1,0,2)" // ABlockTransferThreadClusterArrangeOrder
|
||||
",Seq(1,0,2)" // ABlockTransferSrcAccessOrder
|
||||
",2" // ABlockTransferSrcVectorDim
|
||||
",8" // ABlockTransferSrcScalarPerVector
|
||||
",8" // ABlockTransferDstScalarPerVector_AK1
|
||||
",1" // ABlockLdsExtraM
|
||||
",Seq(4,64,1)" // BBlockTransferThreadClusterLengths
|
||||
",Seq(1,0,2)" // BBlockTransferThreadClusterArrangeOrder
|
||||
",Seq(1,0,2)" // BBlockTransferSrcAccessOrder
|
||||
",2" // BBlockTransferSrcVectorDim
|
||||
",8" // BBlockTransferSrcScalarPerVector
|
||||
",8" // BBlockTransferDstScalarPerVector_BK1
|
||||
",1" // BBlockLdsExtraN
|
||||
",1" // CShuffleMXdlPerWavePerShuffle
|
||||
",1" // CShuffleNXdlPerWavePerShuffle
|
||||
",Seq(1,32,1,8)" // CDEBlockTransferClusterLengths
|
||||
",8" // CDEBlockTransferScalarPerVector_NPerBlock
|
||||
",fp16" // AComputeDataType
|
||||
",fp16" // BComputeDataType
|
||||
",Default" // LoopSched
|
||||
",1>"; // NumGroupsToMerge
|
||||
|
||||
EXPECT_EQ(instance_str, expected_str);
|
||||
}
|
||||
|
||||
TEST(InstanceTraitsTest, LargeTensorInstanceStringGeneration)
|
||||
{
|
||||
using DeviceInstance =
|
||||
ck::tensor_operation::device::DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor<
|
||||
2, // NDimSpatial
|
||||
ck::tensor_layout::convolution::GNHWC, // ALayout
|
||||
ck::tensor_layout::convolution::GKYXC, // BLayout
|
||||
ck::Tuple<>, // DsLayout
|
||||
ck::tensor_layout::convolution::GNHWK, // ELayout
|
||||
ck::half_t, // ADataType
|
||||
ck::half_t, // BDataType
|
||||
float, // AccDataType
|
||||
ck::half_t, // CShuffleDataType
|
||||
ck::Tuple<>, // DsDataType
|
||||
ck::half_t, // EDataType
|
||||
ck::tensor_operation::element_wise::PassThrough, // AElementwiseOperation
|
||||
ck::tensor_operation::element_wise::PassThrough, // BElementwiseOperation
|
||||
ck::tensor_operation::element_wise::PassThrough, // CDEElementwiseOperation
|
||||
ck::tensor_operation::device::ConvolutionForwardSpecialization::
|
||||
Default, // ConvForwardSpecialization
|
||||
ck::tensor_operation::device::GemmSpecialization::Default, // GemmSpec
|
||||
1, // NumGemmKPrefetchStage
|
||||
256, // BlockSize
|
||||
128, // MPerBlock
|
||||
128, // NPerBlock
|
||||
16, // KPerBlock
|
||||
8, // AK1
|
||||
8, // BK1
|
||||
32, // MPerXDL
|
||||
32, // NPerXDL
|
||||
4, // MXdlPerWave
|
||||
4, // NXdlPerWave
|
||||
ck::Sequence<4, 64, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1
|
||||
ck::Sequence<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder
|
||||
ck::Sequence<1, 0, 2>, // ABlockTransferSrcAccessOrder
|
||||
2, // ABlockTransferSrcVectorDim
|
||||
8, // ABlockTransferSrcScalarPerVector
|
||||
8, // ABlockTransferDstScalarPerVector_AK1
|
||||
1, // ABlockLdsExtraM
|
||||
ck::Sequence<4, 64, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1
|
||||
ck::Sequence<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder
|
||||
ck::Sequence<1, 0, 2>, // BBlockTransferSrcAccessOrder
|
||||
2, // BBlockTransferSrcVectorDim
|
||||
8, // BBlockTransferSrcScalarPerVector
|
||||
8, // BBlockTransferDstScalarPerVector_BK1
|
||||
1, // BBlockLdsExtraN
|
||||
1, // CShuffleMXdlPerWavePerShuffle
|
||||
1, // CShuffleNXdlPerWavePerShuffle
|
||||
ck::Sequence<1,
|
||||
32,
|
||||
1,
|
||||
8>, // CDEBlockTransferClusterLengths
|
||||
8, // CDEBlockTransferScalarPerVector_NPerBlock
|
||||
ck::half_t, // AComputeDataType
|
||||
ck::half_t, // BComputeDataType
|
||||
ck::LoopScheduler::Default>; // LoopSched
|
||||
|
||||
// Generate instance string
|
||||
std::string instance_str = ck_tile::reflect::instance_string<DeviceInstance>();
|
||||
|
||||
// Expected string with all 48 template parameters
|
||||
std::string expected_str = "DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor"
|
||||
"<2" // NDimSpatial
|
||||
",GNHWC" // ALayout
|
||||
",GKYXC" // BLayout
|
||||
",EmptyTuple" // DsLayout
|
||||
",GNHWK" // ELayout
|
||||
",fp16" // ADataType
|
||||
",fp16" // BDataType
|
||||
",fp32" // AccDataType
|
||||
",fp16" // CShuffleDataType
|
||||
",EmptyTuple" // DsDataType
|
||||
",fp16" // EDataType
|
||||
",PassThrough" // AElementwiseOperation
|
||||
",PassThrough" // BElementwiseOperation
|
||||
",PassThrough" // CDEElementwiseOperation
|
||||
",Default" // ConvForwardSpecialization
|
||||
",Default" // GemmSpec
|
||||
",1" // NumGemmKPrefetchStage
|
||||
",256" // BlockSize
|
||||
",128" // MPerBlock
|
||||
",128" // NPerBlock
|
||||
",16" // KPerBlock
|
||||
",8" // AK1
|
||||
",8" // BK1
|
||||
",32" // MPerXDL
|
||||
",32" // NPerXDL
|
||||
",4" // MXdlPerWave
|
||||
",4" // NXdlPerWave
|
||||
",Seq(4,64,1)" // ABlockTransferThreadClusterLengths
|
||||
",Seq(1,0,2)" // ABlockTransferThreadClusterArrangeOrder
|
||||
",Seq(1,0,2)" // ABlockTransferSrcAccessOrder
|
||||
",2" // ABlockTransferSrcVectorDim
|
||||
",8" // ABlockTransferSrcScalarPerVector
|
||||
",8" // ABlockTransferDstScalarPerVector_AK1
|
||||
",1" // ABlockLdsExtraM
|
||||
",Seq(4,64,1)" // BBlockTransferThreadClusterLengths
|
||||
",Seq(1,0,2)" // BBlockTransferThreadClusterArrangeOrder
|
||||
",Seq(1,0,2)" // BBlockTransferSrcAccessOrder
|
||||
",2" // BBlockTransferSrcVectorDim
|
||||
",8" // BBlockTransferSrcScalarPerVector
|
||||
",8" // BBlockTransferDstScalarPerVector_BK1
|
||||
",1" // BBlockLdsExtraN
|
||||
",1" // CShuffleMXdlPerWavePerShuffle
|
||||
",1" // CShuffleNXdlPerWavePerShuffle
|
||||
",Seq(1,32,1,8)" // CDEBlockTransferClusterLengths
|
||||
",8" // CDEBlockTransferScalarPerVector_NPerBlock
|
||||
",fp16" // AComputeDataType
|
||||
",fp16" // BComputeDataType
|
||||
",Default>"; // LoopSched
|
||||
|
||||
// Verify the generated string matches exactly
|
||||
EXPECT_EQ(instance_str, expected_str);
|
||||
}
|
||||
@@ -0,0 +1,104 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
#include <ck_tile/builder/reflect/instance_traits.hpp>
|
||||
#include <ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_abd.hpp>
|
||||
#include <ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_instance.hpp>
|
||||
|
||||
// Test GetInstanceString through base class pointer for non-V3 variant
|
||||
TEST(GetInstanceString, ReturnsStringForFwdGrpConvInstance)
|
||||
{
|
||||
// Use the template helper to get a working instance configuration
|
||||
using InstanceTuple =
|
||||
ck::tensor_operation::device::instance::device_grouped_conv_fwd_xdl_f16_instances<
|
||||
2, // NDimSpatial
|
||||
ck::tensor_operation::device::instance::GNHWC, // ALayout
|
||||
ck::tensor_operation::device::instance::GKYXC, // BLayout
|
||||
ck::tensor_operation::device::instance::Empty_Tuple, // DsLayout
|
||||
ck::tensor_operation::device::instance::GNHWK, // ELayout
|
||||
ck::tensor_operation::device::instance::ConvFwdDefault>; // ConvForwardSpecialization
|
||||
|
||||
// Get the first instance from the tuple
|
||||
using DeviceInstance = typename std::tuple_element<0, InstanceTuple>::type;
|
||||
|
||||
// Define the base class type using DeviceGroupedConvFwdMultipleABD
|
||||
using BaseClass = ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD<
|
||||
2, // NDimSpatial
|
||||
ck::tensor_operation::device::instance::GNHWC, // ALayout
|
||||
ck::tensor_operation::device::instance::GKYXC, // BLayout
|
||||
ck::tensor_operation::device::instance::Empty_Tuple, // DsLayout
|
||||
ck::tensor_operation::device::instance::GNHWK, // ELayout
|
||||
ck::half_t, // ADataType
|
||||
ck::half_t, // BDataType
|
||||
ck::Tuple<>, // DsDataType
|
||||
ck::half_t, // EDataType
|
||||
ck::tensor_operation::element_wise::PassThrough, // AElementwiseOperation
|
||||
ck::tensor_operation::element_wise::PassThrough, // BElementwiseOperation
|
||||
ck::tensor_operation::element_wise::PassThrough, // CDEElementwiseOperation
|
||||
ck::half_t, // AComputeType
|
||||
ck::half_t>; // BComputeType
|
||||
|
||||
// Create an instance of the derived class
|
||||
DeviceInstance device_instance;
|
||||
|
||||
// Get a pointer to the base class
|
||||
BaseClass* base_ptr = &device_instance;
|
||||
|
||||
// Call GetInstanceString through the base class pointer
|
||||
std::string instance_str = base_ptr->GetInstanceString();
|
||||
|
||||
// Expected complete instance string based on the first instance from
|
||||
// device_grouped_conv_fwd_xdl_f16_instances
|
||||
std::string expected_str = "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle"
|
||||
"<2" // NDimSpatial
|
||||
",GNHWC" // ALayout
|
||||
",GKYXC" // BLayout
|
||||
",EmptyTuple" // DsLayout
|
||||
",GNHWK" // ELayout
|
||||
",fp16" // ADataType
|
||||
",fp16" // BDataType
|
||||
",fp32" // AccDataType
|
||||
",fp16" // CShuffleDataType
|
||||
",EmptyTuple" // DsDataType
|
||||
",fp16" // EDataType
|
||||
",PassThrough" // AElementwiseOperation
|
||||
",PassThrough" // BElementwiseOperation
|
||||
",PassThrough" // CDEElementwiseOperation
|
||||
",Default" // ConvForwardSpecialization
|
||||
",MNKPadding" // GemmSpec
|
||||
",1" // NumGemmKPrefetchStage
|
||||
",64" // BlockSize
|
||||
",64" // MPerBlock
|
||||
",64" // NPerBlock
|
||||
",32" // KPerBlock
|
||||
",8" // AK1
|
||||
",8" // BK1
|
||||
",32" // MPerXDL
|
||||
",32" // NPerXDL
|
||||
",2" // MXdlPerWave
|
||||
",2" // NXdlPerWave
|
||||
",Seq(4,16,1)" // ABlockTransferThreadClusterLengths
|
||||
",Seq(1,0,2)" // ABlockTransferThreadClusterArrangeOrder
|
||||
",Seq(1,0,2)" // ABlockTransferSrcAccessOrder
|
||||
",2" // ABlockTransferSrcVectorDim
|
||||
",1" // ABlockTransferSrcScalarPerVector
|
||||
",8" // ABlockTransferDstScalarPerVector_AK1
|
||||
",1" // ABlockLdsExtraM
|
||||
",Seq(4,16,1)" // BBlockTransferThreadClusterLengths
|
||||
",Seq(1,0,2)" // BBlockTransferThreadClusterArrangeOrder
|
||||
",Seq(1,0,2)" // BBlockTransferSrcAccessOrder
|
||||
",2" // BBlockTransferSrcVectorDim
|
||||
",1" // BBlockTransferSrcScalarPerVector
|
||||
",8" // BBlockTransferDstScalarPerVector_BK1
|
||||
",1" // BBlockLdsExtraN
|
||||
",1" // CShuffleMXdlPerWavePerShuffle
|
||||
",1" // CShuffleNXdlPerWavePerShuffle
|
||||
",Seq(1,16,1,4)" // CDEBlockTransferClusterLengths
|
||||
",1" // CDEBlockTransferScalarPerVector_NPerBlock
|
||||
",fp16" // AComputeDataType
|
||||
",fp16" // BComputeDataType
|
||||
",Default" // LoopScheduler
|
||||
",1>"; // NumGroupsToMerge
|
||||
EXPECT_EQ(instance_str, expected_str);
|
||||
}
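A sketch only (this helper is not in the PR): the same GetInstanceString() call can be made
for every kernel in the instance tuple, assuming each device op is default-constructible as
in the test above.

// Hypothetical helper: print the instance string of every element of a std::tuple of
// device-op types (needs <iostream>, <tuple>, <utility>).
template <typename Tuple, std::size_t... Is>
void print_all_instance_strings(std::index_sequence<Is...>)
{
    ((std::cout << std::tuple_element_t<Is, Tuple>{}.GetInstanceString() << '\n'), ...);
}
// usage:
// print_all_instance_strings<InstanceTuple>(
//     std::make_index_sequence<std::tuple_size_v<InstanceTuple>>{});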
|
||||
@@ -0,0 +1,103 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
#include <ck_tile/builder/reflect/instance_traits.hpp>
|
||||
#include <ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_abd.hpp>
|
||||
#include <ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_large_tensor_instance.hpp>
|
||||
|
||||
// Test GetInstanceString through base class pointer for large tensor variant
|
||||
TEST(GetInstanceString, ReturnsStringForFwdGrpConvLargeTensorInstance)
|
||||
{
|
||||
// Use the template helper to get a working instance configuration
|
||||
using InstanceTuple = ck::tensor_operation::device::instance::
|
||||
device_grouped_conv_fwd_xdl_large_tensor_f16_instances<
|
||||
2, // NDimSpatial
|
||||
ck::tensor_operation::device::instance::GNHWC, // ALayout
|
||||
ck::tensor_operation::device::instance::GKYXC, // BLayout
|
||||
ck::tensor_operation::device::instance::Empty_Tuple, // DsLayout
|
||||
ck::tensor_operation::device::instance::GNHWK, // ELayout
|
||||
ck::tensor_operation::device::instance::ConvFwdDefault>; // ConvForwardSpecialization
|
||||
|
||||
// Get the first instance from the tuple
|
||||
using DeviceInstance = typename std::tuple_element<0, InstanceTuple>::type;
|
||||
|
||||
// Define the base class type using DeviceGroupedConvFwdMultipleABD
|
||||
using BaseClass = ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD<
|
||||
2, // NDimSpatial
|
||||
ck::tensor_operation::device::instance::GNHWC, // ALayout
|
||||
ck::tensor_operation::device::instance::GKYXC, // BLayout
|
||||
ck::tensor_operation::device::instance::Empty_Tuple, // DsLayout
|
||||
ck::tensor_operation::device::instance::GNHWK, // ELayout
|
||||
ck::half_t, // ADataType
|
||||
ck::half_t, // BDataType
|
||||
ck::Tuple<>, // DsDataType
|
||||
ck::half_t, // EDataType
|
||||
ck::tensor_operation::element_wise::PassThrough, // AElementwiseOperation
|
||||
ck::tensor_operation::element_wise::PassThrough, // BElementwiseOperation
|
||||
ck::tensor_operation::element_wise::PassThrough, // CDEElementwiseOperation
|
||||
ck::half_t, // AComputeType
|
||||
ck::half_t>; // BComputeType
|
||||
|
||||
// Create an instance of the derived class
|
||||
DeviceInstance device_instance;
|
||||
|
||||
// Get a pointer to the base class
|
||||
BaseClass* base_ptr = &device_instance;
|
||||
|
||||
// Call GetInstanceString through the base class pointer
|
||||
std::string instance_str = base_ptr->GetInstanceString();
|
||||
|
||||
// Expected complete instance string based on the first instance from
|
||||
// device_grouped_conv_fwd_xdl_large_tensor_f16_instances
|
||||
std::string expected_str = "DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor"
|
||||
"<2" // NDimSpatial
|
||||
",GNHWC" // ALayout
|
||||
",GKYXC" // BLayout
|
||||
",EmptyTuple" // DsLayout
|
||||
",GNHWK" // ELayout
|
||||
",fp16" // ADataType
|
||||
",fp16" // BDataType
|
||||
",fp32" // AccDataType
|
||||
",fp16" // CShuffleDataType
|
||||
",EmptyTuple" // DsDataType
|
||||
",fp16" // EDataType
|
||||
",PassThrough" // AElementwiseOperation
|
||||
",PassThrough" // BElementwiseOperation
|
||||
",PassThrough" // CDEElementwiseOperation
|
||||
",Default" // ConvForwardSpecialization
|
||||
",MNKPadding" // GemmSpec
|
||||
",1" // NumGemmKPrefetchStage
|
||||
",64" // BlockSize
|
||||
",64" // MPerBlock
|
||||
",64" // NPerBlock
|
||||
",32" // KPerBlock
|
||||
",8" // AK1
|
||||
",8" // BK1
|
||||
",32" // MPerXDL
|
||||
",32" // NPerXDL
|
||||
",2" // MXdlPerWave
|
||||
",2" // NXdlPerWave
|
||||
",Seq(4,16,1)" // ABlockTransferThreadClusterLengths
|
||||
",Seq(1,0,2)" // ABlockTransferThreadClusterArrangeOrder
|
||||
",Seq(1,0,2)" // ABlockTransferSrcAccessOrder
|
||||
",2" // ABlockTransferSrcVectorDim
|
||||
",1" // ABlockTransferSrcScalarPerVector
|
||||
",8" // ABlockTransferDstScalarPerVector_AK1
|
||||
",1" // ABlockLdsExtraM
|
||||
",Seq(4,16,1)" // BBlockTransferThreadClusterLengths
|
||||
",Seq(1,0,2)" // BBlockTransferThreadClusterArrangeOrder
|
||||
",Seq(1,0,2)" // BBlockTransferSrcAccessOrder
|
||||
",2" // BBlockTransferSrcVectorDim
|
||||
",1" // BBlockTransferSrcScalarPerVector
|
||||
",8" // BBlockTransferDstScalarPerVector_BK1
|
||||
",1" // BBlockLdsExtraN
|
||||
",1" // CShuffleMXdlPerWavePerShuffle
|
||||
",1" // CShuffleNXdlPerWavePerShuffle
|
||||
",Seq(1,16,1,4)" // CDEBlockTransferClusterLengths
|
||||
",1" // CDEBlockTransferScalarPerVector_NPerBlock
|
||||
",fp16" // AComputeDataType
|
||||
",fp16" // BComputeDataType
|
||||
",Default>"; // LoopScheduler
|
||||
EXPECT_EQ(instance_str, expected_str);
|
||||
}
|
||||
@@ -6,8 +6,8 @@
|
||||
#include <ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_abd.hpp>
|
||||
#include <ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_comp_instance.hpp>
|
||||
|
||||
// Test GetInstanceString through base class pointer
|
||||
TEST(GetInstanceStringTest, GetInstanceStringThroughBaseClass)
|
||||
// Test GetInstanceString through base class pointer for V3 variant
|
||||
TEST(GetInstanceString, ReturnsStringForFwdGrpConvV3Instance)
|
||||
{
|
||||
// Use the template helper to get a working instance configuration
|
||||
using InstanceTuple =
|
||||
@@ -199,6 +199,14 @@ TEST(InstanceTraitsUtil, PipelineVersionNameReturnsCorrectStrings)
|
||||
ElementsAre("v1", "v2", "v3", "v4", "v5"));
|
||||
}
|
||||
|
||||
TEST(InstanceTraitsUtil, LoopSchedulerNameReturnsCorrectStrings)
{
    using enum ck::LoopScheduler;
    const std::vector<std::string_view> names = {loop_scheduler_name(Default),
                                                 loop_scheduler_name(Interwave)};
    EXPECT_THAT(names, ElementsAre("Default", "Interwave"));
}

TEST(InstanceTraitsUtil, TupleNameReturnsEmptyTupleForEmptyTuple)
{
    EXPECT_EQ(tuple_name<ck::Tuple<>>(), "EmptyTuple");

@@ -28,6 +28,9 @@
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
#include "ck/host_utility/io.hpp"
#ifdef CK_EXPERIMENTAL_BUILDER
#include "ck_tile/builder/reflect/instance_traits_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp"
#endif

namespace ck {
namespace tensor_operation {
@@ -2063,6 +2066,19 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
        return str.str();
    }

#ifdef CK_EXPERIMENTAL_BUILDER
    std::string GetInstanceString() const override
    {
        static_assert(ck_tile::reflect::HasInstanceTraits<DeviceOp>,
                      "Specialization of instance_traits not found. Please check that a "
                      "specialization exists in file "
                      "ck_tile/builder/reflect/"
                      "instance_traits_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp "
                      "for the given template parameters.");
        return ck_tile::reflect::instance_string<DeviceOp>();
    }
#endif

    size_t GetWorkSpaceSize(const BaseArgument* p_arg) const override
    {
        auto arg = dynamic_cast<const Argument*>(p_arg);

@@ -24,6 +24,9 @@
|
||||
#include "ck/host_utility/device_prop.hpp"
|
||||
#include "ck/host_utility/kernel_launch.hpp"
|
||||
#include "ck/host_utility/io.hpp"
|
||||
#ifdef CK_EXPERIMENTAL_BUILDER
|
||||
#include "ck_tile/builder/reflect/instance_traits_device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp"
|
||||
#endif
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
@@ -1220,6 +1223,20 @@ struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor
|
||||
|
||||
return str.str();
|
||||
}
|
||||
|
||||
#ifdef CK_EXPERIMENTAL_BUILDER
|
||||
std::string GetInstanceString() const override
|
||||
{
|
||||
static_assert(
|
||||
ck_tile::reflect::HasInstanceTraits<DeviceOp>,
|
||||
"Specialization of instance_traits not found. Please check that a "
|
||||
"specialization exists in file "
|
||||
"ck_tile/builder/reflect/"
|
||||
"instance_traits_device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp "
|
||||
"for the given template parameters.");
|
||||
return ck_tile::reflect::instance_string<DeviceOp>();
|
||||
}
|
||||
#endif
|
||||
};
|
||||
|
||||
} // namespace device
|
||||
|
||||
@@ -25,7 +25,7 @@
|
||||
* (3) number of iterations to cover the entire Y axis.
|
||||
|
||||
* The raked here represents how data is partitioned across different processing granularity.
|
||||
* It represents howe we are going to access the data in thread, warp, or blocked in contiguous
|
||||
* It represents how we are going to access the data in thread, warp, or blocked in contiguous
|
||||
region.
|
||||
* From below, the qualifier for 'raked' is the part of warp/thread hierarchy
|
||||
* in the split of Y tile dimension where the iteration happens,
|
||||
@@ -101,7 +101,7 @@ enum struct tile_distribution_pattern
|
||||
* @brief Block raked pattern - aka linear.
|
||||
*
|
||||
*/
|
||||
block_raked,
|
||||
block_raked
|
||||
};
|
||||
|
||||
struct tile_distribution_encoding_pattern
|
||||
@@ -144,7 +144,6 @@ struct tile_distribution_encoding_pattern_2d<BlockSize,
|
||||
NumWaveGroups>
|
||||
: public tile_distribution_encoding_pattern
|
||||
{
|
||||
|
||||
// TODO: make pattern where below condition does not need to hold - GGemmMultiDSplitk!
|
||||
static_assert(XPerTile % VecSize == 0, "XPerTile must be a multiple of VecSize!");
|
||||
static constexpr index_t warp_size = get_warp_size();
|
||||
|
||||
@@ -19,7 +19,7 @@
|
||||
namespace ck_tile {
|
||||
|
||||
/** @brief Maximum number of error values to display when checking errors */
|
||||
constexpr int ERROR_DETAIL_LIMIT = 5;
|
||||
constexpr int ERROR_DETAIL_LIMIT = 128;
|
||||
|
||||
/** @brief 8-bit floating point type */
|
||||
using F8 = ck_tile::fp8_t;
|
||||
|
||||
@@ -26,7 +26,8 @@ struct GroupedConvBwdWeightKernelArgs
|
||||
GroupedConvTraitsType_::ConvSpecialization,
|
||||
GroupedConvTraitsType_::VectorSizeA,
|
||||
GroupedConvTraitsType_::VectorSizeB,
|
||||
GroupedConvTraitsType_::VectorSizeC>;
|
||||
GroupedConvTraitsType_::VectorSizeC,
|
||||
GroupedConvTraitsType_::NumGroupsToMerge>;
|
||||
static constexpr index_t NumDTensor = GroupedConvTraitsType_::NumDTensor;
|
||||
|
||||
template <
|
||||
@@ -84,9 +85,11 @@ struct GroupedConvBwdWeightKernelArgs
        b_grid_desc_k_n = grid_descs.at(number<1>{});
        c_grid_desc_m_n = grid_descs.at(number<2>{});

        group_stride_a = args.K_; // A: Out NWGK
        group_stride_b = args.C_; // B: In NWGC
        group_stride_c = args.K_ * args.C_ * // C: Wei GKXC
        NumGroupsPerBatch = GroupedConvTraitsType_::NumGroupsToMerge;
        group_stride_a = args.K_ * NumGroupsPerBatch; // A: Out NWGK
        group_stride_b = args.C_ * NumGroupsPerBatch; // B: In NWGC
        group_stride_c = args.K_ * args.C_ // C: Wei GKXC
                         * NumGroupsPerBatch *
                         std::accumulate(args.filter_spatial_lengths_.begin(),
                                         args.filter_spatial_lengths_.end(),
                                         1,
@@ -95,7 +98,14 @@ struct GroupedConvBwdWeightKernelArgs
        GemmM = a_grid_desc_k_m.get_length(number<1>{});
        GemmN = b_grid_desc_k_n.get_length(number<1>{});
        GemmK = a_grid_desc_k_m.get_length(number<0>{});
        GemmBatch = args.G_;
        GemmBatch = integer_divide_ceil(args.G_, NumGroupsPerBatch);

        if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
        {
            std::cout << "GemmM: " << GemmM << ", GemmN: " << GemmN << ", GemmK: " << GemmK
                      << ", GemmBatch: " << GemmBatch
                      << ", NumGroupsPerBatch: " << NumGroupsPerBatch << std::endl;
        }
}
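        // Worked example (illustrative numbers only, not from this patch): with
        // args.G_ = 16 convolution groups and NumGroupsToMerge = 4,
        //   NumGroupsPerBatch = 4
        //   GemmBatch         = integer_divide_ceil(16, 4) = 4   // 4 merged GEMMs
        //   group_stride_a    = args.K_ * 4,  group_stride_b = args.C_ * 4
        // so the four groups merged into one GEMM are adjacent along the G axis.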
|
||||
|
||||
template <
|
||||
@@ -160,9 +170,11 @@ struct GroupedConvBwdWeightKernelArgs
|
||||
b_grid_desc_k_n = grid_descs.at(number<1>{});
|
||||
c_grid_desc_m_n = grid_descs.at(number<2>{});
|
||||
|
||||
group_stride_a = args.K_; // A: Out NHWGK
|
||||
group_stride_b = args.C_; // B: In NHWGC
|
||||
group_stride_c = args.K_ * args.C_ * // C: Wei GKYXC
|
||||
NumGroupsPerBatch = GroupedConvTraitsType_::NumGroupsToMerge;
|
||||
group_stride_a = args.K_ * NumGroupsPerBatch; // A: Out NHWGK
|
||||
group_stride_b = args.C_ * NumGroupsPerBatch; // B: In NHWGC
|
||||
group_stride_c = args.K_ * args.C_ // C: Wei GKYXC
|
||||
* NumGroupsPerBatch *
|
||||
std::accumulate(args.filter_spatial_lengths_.begin(),
|
||||
args.filter_spatial_lengths_.end(),
|
||||
1,
|
||||
@@ -171,7 +183,14 @@ struct GroupedConvBwdWeightKernelArgs
|
||||
GemmM = a_grid_desc_k_m.get_length(number<1>{});
|
||||
GemmN = b_grid_desc_k_n.get_length(number<1>{});
|
||||
GemmK = a_grid_desc_k_m.get_length(number<0>{});
|
||||
GemmBatch = args.G_;
|
||||
GemmBatch = integer_divide_ceil(args.G_, NumGroupsPerBatch);
|
||||
|
||||
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
|
||||
{
|
||||
std::cout << "GemmM: " << GemmM << ", GemmN: " << GemmN << ", GemmK: " << GemmK
|
||||
<< ", GemmBatch: " << GemmBatch
|
||||
<< ", NumGroupsPerBatch: " << NumGroupsPerBatch << std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
template <
|
||||
@@ -243,9 +262,11 @@ struct GroupedConvBwdWeightKernelArgs
|
||||
b_grid_desc_k_n = grid_descs.at(number<1>{});
|
||||
c_grid_desc_m_n = grid_descs.at(number<2>{});
|
||||
|
||||
group_stride_a = args.K_; // A: Out NDHWGK
|
||||
group_stride_b = args.C_; // B: In NDHWGC
|
||||
group_stride_c = args.K_ * args.C_ * // C: wEI GKZYXC
|
||||
NumGroupsPerBatch = GroupedConvTraitsType_::NumGroupsToMerge;
|
||||
group_stride_a = args.K_ * NumGroupsPerBatch; // A: Out NDHWGK
|
||||
group_stride_b = args.C_ * NumGroupsPerBatch; // B: In NDHWGC
|
||||
group_stride_c = args.K_ * args.C_ // C: Wei GKZYXC
|
||||
* NumGroupsPerBatch *
|
||||
std::accumulate(args.filter_spatial_lengths_.begin(),
|
||||
args.filter_spatial_lengths_.end(),
|
||||
1,
|
||||
@@ -254,7 +275,14 @@ struct GroupedConvBwdWeightKernelArgs
|
||||
GemmM = a_grid_desc_k_m.get_length(number<1>{});
|
||||
GemmN = b_grid_desc_k_n.get_length(number<1>{});
|
||||
GemmK = a_grid_desc_k_m.get_length(number<0>{});
|
||||
GemmBatch = args.G_;
|
||||
GemmBatch = integer_divide_ceil(args.G_, NumGroupsPerBatch);
|
||||
|
||||
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
|
||||
{
|
||||
std::cout << "GemmM: " << GemmM << ", GemmN: " << GemmN << ", GemmK: " << GemmK
|
||||
<< ", GemmBatch: " << GemmBatch
|
||||
<< ", NumGroupsPerBatch: " << NumGroupsPerBatch << std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
using ABCGridDescs = remove_cvref_t<
|
||||
@@ -279,6 +307,7 @@ struct GroupedConvBwdWeightKernelArgs
|
||||
index_t GemmN;
|
||||
index_t GemmK;
|
||||
index_t GemmBatch;
|
||||
index_t NumGroupsPerBatch;
|
||||
|
||||
const void* out_ptr;
|
||||
const void* in_ptr;
|
||||
@@ -317,10 +346,9 @@ struct GroupedConvBwdWeightKernelArgs
|
||||
/// the policy is responsible for definition of all necessary data layouts and thread's
|
||||
/// work distribution.
|
||||
///
|
||||
/// @tparam GroupedConvTraitsType_ The type of class providing traits for grouped convolution.
|
||||
/// @tparam GroupedConvTraitsType_ The type of class providing traits for grouped convolution.
|
||||
/// @tparam TilePartitioner_ The type of class providing mapping of workgroup index into
|
||||
/// the
|
||||
/// output data tile to be calculated. It determines the
|
||||
/// the output data tile to be calculated. It determines the
|
||||
/// workgroup to data relationship (or in other words - which
|
||||
/// data would be processed and calculated by which workgroup).
|
||||
/// @tparam GemmPipeline_ The type of class which provides the core part of matrix
|
||||
@@ -382,8 +410,12 @@ struct GroupedConvolutionBackwardWeightKernel
|
||||
|
||||
[[nodiscard]] CK_TILE_HOST static const std::string GetName()
|
||||
{
|
||||
constexpr auto NumGroupsToMerge = GroupedConvTraitsType_::NumGroupsToMerge;
|
||||
// clang-format off
|
||||
return concat('_', "grouped_convolution_backward_weight", gemm_prec_str<InDataType, WeiDataType>, GemmPipeline::GetName());
|
||||
if (NumGroupsToMerge > 1)
|
||||
return concat('_', "grouped_convolution_backward_weight", gemm_prec_str<InDataType, WeiDataType>, GemmPipeline::GetName(), "merge", NumGroupsToMerge);
|
||||
else
|
||||
return concat('_', "grouped_convolution_backward_weight", gemm_prec_str<InDataType, WeiDataType>, GemmPipeline::GetName());
|
||||
// clang-format on
|
||||
}
|
||||
|
||||
@@ -402,6 +434,12 @@ struct GroupedConvolutionBackwardWeightKernel
|
||||
CK_TILE_HOST static constexpr GroupedConvBwdWeightKernelArgsSpecialized
|
||||
MakeKernelArgs(const GroupedConvBwdWeightHostArgs& hostArgs)
|
||||
{
|
||||
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
|
||||
{
|
||||
std::cout << "MPerBlock: " << number<TilePartitioner::MPerBlock>{} << std::endl;
|
||||
std::cout << "NPerBlock: " << number<TilePartitioner::NPerBlock>{} << std::endl;
|
||||
std::cout << "KPerBlock: " << number<TilePartitioner::KPerBlock>{} << std::endl;
|
||||
}
|
||||
return GroupedConvBwdWeightKernelArgsSpecialized(hostArgs);
|
||||
}
|
||||
|
||||
@@ -442,11 +480,14 @@ struct GroupedConvolutionBackwardWeightKernel
    {
        return [&]() {
            if(kargs.k_batch > 1)
                hipGetErrorString(hipMemsetAsync(kargs.wei_ptr,
                                                 0,
                                                 kargs.GemmBatch * kargs.GemmM * kargs.GemmN *
                                                     sizeof(WeiDataType),
                                                 s.stream_id_));
            {
                // Total number of convolution groups (ConvG) = GemmBatch * NumGroupsPerBatch
                // since we require that ConvG % NumGroupsPerBatch == 0.
                const auto wei_size =
                    kargs.GemmBatch * kargs.GemmM * kargs.GemmN * kargs.NumGroupsPerBatch;
                hipGetErrorString(
                    hipMemsetAsync(kargs.wei_ptr, 0, wei_size * sizeof(WeiDataType), s.stream_id_));
            }
        };
}
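    // Sizing sketch (illustrative values, not from this patch): the zero-fill above must
    // cover all ConvG = GemmBatch * NumGroupsPerBatch weight slices, e.g. with
    // GemmM = GemmN = 64, GemmBatch = 4 and NumGroupsPerBatch = 4:
    //   wei_size = 4 * 64 * 64 * 4 = 65536 elements of WeiDataType.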
|
||||
|
||||
@@ -527,7 +568,8 @@ struct GroupedConvolutionBackwardWeightKernel
|
||||
// Check access per C
|
||||
if(ConvC % GroupedConvTraitsType_::VectorSizeB != 0)
|
||||
{
|
||||
CK_TILE_ERROR("Conv C is not a multiple of vector load size for input image!");
|
||||
CK_TILE_ERROR("Conv C is not a multiple of vector load size for "
|
||||
"input image!");
|
||||
return false;
|
||||
}
|
||||
}
|
||||
@@ -559,7 +601,8 @@ struct GroupedConvolutionBackwardWeightKernel
|
||||
{
|
||||
if(ConvK % GroupedConvTraitsType_::VectorSizeA != 0)
|
||||
{
|
||||
CK_TILE_ERROR("Conv K is not a multiple of vector store size for output image!");
|
||||
CK_TILE_ERROR("Conv K is not a multiple of vector store size "
|
||||
"for output image!");
|
||||
return false;
|
||||
}
|
||||
}
|
||||
@@ -569,6 +612,18 @@ struct GroupedConvolutionBackwardWeightKernel
|
||||
return false;
|
||||
}
|
||||
|
||||
if constexpr(GroupedConvTraitsType_::NumGroupsToMerge > 1)
|
||||
{
|
||||
const index_t ConvG = kargs.wei_g_k_c_xs_lengths[number<0>{}];
|
||||
if(ConvG % GroupedConvTraitsType_::NumGroupsToMerge != 0)
|
||||
{
|
||||
CK_TILE_ERROR("ConvG must be a multiple of NumGroupsToMerge!");
|
||||
return false;
|
||||
}
|
||||
|
||||
// TODO: Should we also check that GemmM <= MPerBlock and GemmN <= NPerBlock?
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
@@ -654,6 +709,16 @@ struct GroupedConvolutionBackwardWeightKernel
|
||||
return make_tuple(a_pad_view, b_pad_view, ds_pad_view, c_pad_view);
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Create views to the data that each workgroup will process.
|
||||
*
|
||||
* @param views padded views of A, B, D and C tensors
|
||||
* @param i_m block m-index
|
||||
* @param i_n block n-index
|
||||
* @param i_k block k-index
|
||||
*
|
||||
* @return tuple of tile windows for A, B, D and C tensors
|
||||
*/
|
||||
template <typename PadView>
|
||||
CK_TILE_DEVICE static auto MakeGemmTileWindows(const PadView& views,
|
||||
const index_t i_m,
|
||||
@@ -818,7 +883,6 @@ struct GroupedConvolutionBackwardWeightKernel
|
||||
const InDataType* b_ptr = static_cast<const InDataType*>(kargs.in_ptr) + group_offset_b;
|
||||
WeiDataType* c_ptr = static_cast<WeiDataType*>(kargs.wei_ptr) + group_offset_c;
|
||||
|
||||
// allocate LDS
|
||||
__shared__ char smem_ptr_0[GetSmemSize()];
|
||||
|
||||
if constexpr(GemmPipeline::DoubleSmemBuffer == true)
|
||||
|
||||
@@ -29,6 +29,7 @@ struct GroupedConvFwdKernelArgs
|
||||
GroupedConvTraitsType_::VectorSizeA,
|
||||
GroupedConvTraitsType_::VectorSizeB,
|
||||
GroupedConvTraitsType_::VectorSizeC,
|
||||
GroupedConvTraitsType_::NumGroupsToMerge,
|
||||
true>; // Split N enabled
|
||||
using CDElementwise = typename GroupedConvTraitsType_::CDElementwise;
|
||||
static constexpr index_t NumDTensor = GroupedConvTraitsType_::NumDTensor;
|
||||
|
||||
@@ -59,10 +59,11 @@ template <index_t NDimSpatial_,
|
||||
typename WeiLayout_,
|
||||
typename DsLayout_,
|
||||
typename OutLayout_,
|
||||
index_t VectorSizeA_ = 1,
|
||||
index_t VectorSizeB_ = 1,
|
||||
index_t VectorSizeC_ = 1,
|
||||
typename CDElementwise_ = PassThrough>
|
||||
index_t VectorSizeA_ = 1,
|
||||
index_t VectorSizeB_ = 1,
|
||||
index_t VectorSizeC_ = 1,
|
||||
index_t NumGroupsToMerge_ = 1,
|
||||
typename CDElementwise_ = PassThrough>
|
||||
struct GroupedConvTraits
|
||||
{
|
||||
private:
|
||||
@@ -73,7 +74,7 @@ struct GroupedConvTraits
|
||||
}
|
||||
|
||||
public:
|
||||
static constexpr index_t NumGroupsToMerge = 1;
|
||||
static constexpr index_t NumGroupsToMerge = NumGroupsToMerge_;
|
||||
static constexpr index_t NDimSpatial = NDimSpatial_;
|
||||
static constexpr ConvolutionSpecialization ConvSpecialization = ConvSpecialization_;
|
||||
using InLayout = InLayout_;
|
||||
|
||||
@@ -13,10 +13,10 @@ template <index_t NDimSpatial,
|
||||
index_t VectorSizeA,
|
||||
index_t VectorSizeB,
|
||||
index_t VectorSizeC,
|
||||
index_t NumGroupsToMerge = 1,
|
||||
bool SplitN = false,
|
||||
typename ADataType = float,
|
||||
typename CDataType = float,
|
||||
index_t NumGroupsToMerge = 1,
|
||||
typename IndexType = index_t>
|
||||
struct TransformConvBwdWeightToGemm
|
||||
{
|
||||
@@ -125,8 +125,7 @@ struct TransformConvBwdWeightToGemm
|
||||
InLeftPadW_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.InLeftPadW_)},
|
||||
InRightPadD_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.InRightPadD_)},
|
||||
InRightPadH_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.InRightPadH_)},
|
||||
InRightPadW_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.InRightPadW_)},
|
||||
ZYX_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.ZYX_)}
|
||||
InRightPadW_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.InRightPadW_)}
|
||||
{
|
||||
}
|
||||
|
||||
@@ -164,8 +163,7 @@ struct TransformConvBwdWeightToGemm
|
||||
InLeftPadW_{input_left_pads[I0]},
|
||||
InRightPadD_{I0},
|
||||
InRightPadH_{I0},
|
||||
InRightPadW_{input_right_pads[I0]},
|
||||
ZYX_{X_}
|
||||
InRightPadW_{input_right_pads[I0]}
|
||||
{
|
||||
static_assert(std::is_same_v<ConvSpatialDimsType, std::array<IndexType, NDimSpatial>> ||
|
||||
std::is_same_v<ConvSpatialDimsType, ck_tile::array<IndexType, NDimSpatial>>);
|
||||
@@ -219,8 +217,7 @@ struct TransformConvBwdWeightToGemm
|
||||
InLeftPadW_{input_left_pads[I1]},
|
||||
InRightPadD_{I0},
|
||||
InRightPadH_{input_right_pads[I0]},
|
||||
InRightPadW_{input_right_pads[I1]},
|
||||
ZYX_{Y_ * X_}
|
||||
InRightPadW_{input_right_pads[I1]}
|
||||
{
|
||||
static_assert(std::is_same_v<ConvSpatialDimsType, std::array<IndexType, NDimSpatial>> ||
|
||||
std::is_same_v<ConvSpatialDimsType, ck_tile::array<IndexType, NDimSpatial>>);
|
||||
@@ -274,8 +271,7 @@ struct TransformConvBwdWeightToGemm
|
||||
InLeftPadW_{input_left_pads[I2]},
|
||||
InRightPadD_{input_right_pads[I0]},
|
||||
InRightPadH_{input_right_pads[I1]},
|
||||
InRightPadW_{input_right_pads[I2]},
|
||||
ZYX_{Z_ * Y_ * X_}
|
||||
InRightPadW_{input_right_pads[I2]}
|
||||
{
|
||||
static_assert(std::is_same_v<ConvSpatialDimsType, std::array<IndexType, NDimSpatial>> ||
|
||||
std::is_same_v<ConvSpatialDimsType, ck_tile::array<IndexType, NDimSpatial>>);
|
||||
@@ -420,11 +416,21 @@ struct TransformConvBwdWeightToGemm
|
||||
const index_t NDoHoWoStride = G_ * K_;
|
||||
constexpr auto KStride = I1;
|
||||
|
||||
// TODO Add support for NumGroupsToMerge > 1
|
||||
return make_naive_tensor_descriptor(make_tuple(K_, N_ * Wo_),
|
||||
make_tuple(KStride, NDoHoWoStride),
|
||||
number<VectorSizeA>{},
|
||||
I1);
|
||||
if constexpr(NumGroupsToMerge > 1)
|
||||
{
|
||||
const index_t BatchStride = K_;
|
||||
return make_naive_tensor_descriptor(make_tuple(K_, NumGroupsToMerge, N_ * Wo_),
|
||||
make_tuple(KStride, BatchStride, NDoHoWoStride),
|
||||
number<VectorSizeA>{},
|
||||
I1);
|
||||
}
|
||||
else
|
||||
{
|
||||
return make_naive_tensor_descriptor(make_tuple(K_, N_ * Wo_),
|
||||
make_tuple(KStride, NDoHoWoStride),
|
||||
number<VectorSizeA>{},
|
||||
I1);
|
||||
}
|
||||
}
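        // Layout sketch for the merged-group descriptor above (illustrative values, assumed
        // NWGK output layout as in the kernel-args comments): with K_ = 32 and
        // NumGroupsToMerge = 4 the descriptor is (K_, NumGroupsToMerge, N_ * Wo_) with
        // strides (1, K_, G_ * K_), i.e.
        //   offset(k, g, nw) = k + g * 32 + nw * (G_ * 32)
        // so the merged groups are adjacent K-sized blocks inside one output row.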
|
||||
|
||||
template <index_t NDim = NDimSpatial, typename std::enable_if<NDim == 1, bool>::type = false>
|
||||
@@ -435,11 +441,22 @@ struct TransformConvBwdWeightToGemm
|
||||
const index_t WiStride = G_ * C_;
|
||||
constexpr auto CStride = I1;
|
||||
|
||||
// TODO Add support for NumGroupsToMerge > 1
|
||||
return make_naive_tensor_descriptor(make_tuple(N_, Wi_, C_),
|
||||
make_tuple(NStride, WiStride, CStride),
|
||||
number<VectorSizeB>{},
|
||||
I1);
|
||||
if constexpr(NumGroupsToMerge > 1)
|
||||
{
|
||||
const auto BatchStride = C_;
|
||||
return make_naive_tensor_descriptor(make_tuple(N_, Wi_, NumGroupsToMerge, C_),
|
||||
make_tuple(NStride, WiStride, BatchStride, CStride),
|
||||
number<VectorSizeB>{},
|
||||
I1);
|
||||
}
|
||||
else
|
||||
{
|
||||
|
||||
return make_naive_tensor_descriptor(make_tuple(N_, Wi_, C_),
|
||||
make_tuple(NStride, WiStride, CStride),
|
||||
number<VectorSizeB>{},
|
||||
I1);
|
||||
}
|
||||
}
|
||||
|
||||
template <index_t NDim = NDimSpatial, typename std::enable_if<NDim == 1, bool>::type = false>
|
||||
@@ -449,9 +466,56 @@ struct TransformConvBwdWeightToGemm
|
||||
const index_t KStride = X_ * C_;
|
||||
constexpr auto CXStride = I1;
|
||||
|
||||
// TODO Add support for NumGroupsToMerge > 1
|
||||
return make_naive_tensor_descriptor(
|
||||
make_tuple(K_, X_ * C_), make_tuple(KStride, CXStride), number<VectorSizeC>{}, I1);
|
||||
if constexpr(NumGroupsToMerge > 1)
|
||||
{
|
||||
const index_t XStride = C_;
|
||||
const index_t BatchStride = K_ * X_ * C_;
|
||||
// Add NumGroupsToMerge for Batch+M dimension and, 1 as a placeholder
|
||||
// for Batch+N dimension
|
||||
const auto desc = make_naive_tensor_descriptor(
|
||||
make_tuple(NumGroupsToMerge, K_, X_, 1, C_),
|
||||
make_tuple(BatchStride, KStride, XStride, BatchStride, CXStride),
|
||||
number<VectorSizeC>{},
|
||||
I1);
|
||||
// Pad 1 to NumGroupsToMerge
|
||||
const auto padded_desc = transform_tensor_descriptor(
|
||||
desc,
|
||||
make_tuple(make_pass_through_transform(NumGroupsToMerge),
|
||||
make_pass_through_transform(K_),
|
||||
make_pass_through_transform(X_),
|
||||
make_pad_transform(1, 0, NumGroupsToMerge - 1),
|
||||
make_pass_through_transform(C_)),
|
||||
make_tuple(
|
||||
sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}, sequence<4>{}),
|
||||
make_tuple(
|
||||
sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}, sequence<4>{}));
|
||||
// We need only matrices from diagonal. Xor returns 0 for the same
|
||||
// values. So if matrices is not on diagonal then it will be stored in padding.
|
||||
// To avoid use of modulo after xor we assume that NumBatch to merge is power of 2.
|
||||
static_assert(NumGroupsToMerge == 1 || NumGroupsToMerge == 2 || NumGroupsToMerge == 4 ||
|
||||
NumGroupsToMerge == 8 || NumGroupsToMerge == 16 ||
|
||||
NumGroupsToMerge == 32 || NumGroupsToMerge == 64);
|
||||
const auto unmerged_padded_desc = transform_tensor_descriptor(
|
||||
padded_desc,
|
||||
make_tuple(make_xor_transform(make_tuple(NumGroupsToMerge, NumGroupsToMerge)),
|
||||
make_pass_through_transform(K_),
|
||||
make_pass_through_transform(X_),
|
||||
make_pass_through_transform(C_)),
|
||||
make_tuple(sequence<0, 3>{}, sequence<1>{}, sequence<2>{}, sequence<4>{}),
|
||||
make_tuple(sequence<0, 3>{}, sequence<1>{}, sequence<2>{}, sequence<4>{}));
|
||||
// Merge To M, N
|
||||
return transform_tensor_descriptor(
|
||||
unmerged_padded_desc,
|
||||
make_tuple(make_merge_transform(make_tuple(NumGroupsToMerge, K_)),
|
||||
make_merge_transform(make_tuple(X_, NumGroupsToMerge, C_))),
|
||||
make_tuple(sequence<0, 1>{}, sequence<2, 3, 4>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
}
|
||||
else
|
||||
{
|
||||
return make_naive_tensor_descriptor(
|
||||
make_tuple(K_, X_ * C_), make_tuple(KStride, CXStride), number<VectorSizeC>{}, I1);
|
||||
}
|
||||
}
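    // Sketch of the xor trick used above (not part of this patch): for group indices
    // g_m, g_n in [0, NumGroupsToMerge), g_m ^ g_n == 0 iff g_m == g_n, so only diagonal
    // (g_m == g_n) blocks map to the real weight data, while off-diagonal blocks map into
    // the region added by make_pad_transform, as the comment above notes. For
    // NumGroupsToMerge = 4 the xor table is
    //   0 1 2 3
    //   1 0 3 2
    //   2 3 0 1
    //   3 2 1 0
    // and requiring a power of two keeps every row a permutation of [0, NumGroupsToMerge),
    // which is why the static_assert above restricts the allowed values.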
|
||||
|
||||
template <index_t NDim = NDimSpatial, typename std::enable_if<NDim == 2, bool>::type = false>
|
||||
@@ -461,11 +525,22 @@ struct TransformConvBwdWeightToGemm
|
||||
const index_t NDoHoWoStride = G_ * K_;
|
||||
constexpr auto KStride = I1;
|
||||
|
||||
// TODO Add support for NumGroupsToMerge > 1
|
||||
return make_naive_tensor_descriptor(make_tuple(N_ * Ho_ * Wo_, K_), // K_M
|
||||
make_tuple(NDoHoWoStride, KStride),
|
||||
number<VectorSizeA>{},
|
||||
I1);
|
||||
if constexpr(NumGroupsToMerge > 1)
|
||||
{
|
||||
const index_t BatchStride = K_;
|
||||
return make_naive_tensor_descriptor(
|
||||
make_tuple(N_ * Ho_ * Wo_, NumGroupsToMerge, K_), // K_Gm_M
|
||||
make_tuple(NDoHoWoStride, BatchStride, KStride),
|
||||
number<VectorSizeA>{},
|
||||
I1);
|
||||
}
|
||||
else
|
||||
{
|
||||
return make_naive_tensor_descriptor(make_tuple(N_ * Ho_ * Wo_, K_), // K_M
|
||||
make_tuple(NDoHoWoStride, KStride),
|
||||
number<VectorSizeA>{},
|
||||
I1);
|
||||
}
|
||||
}
|
||||
|
||||
template <index_t NDim = NDimSpatial, typename std::enable_if<NDim == 2, bool>::type = false>
|
||||
@@ -477,11 +552,22 @@ struct TransformConvBwdWeightToGemm
|
||||
const index_t WiStride = G_ * C_;
|
||||
constexpr auto CStride = I1;
|
||||
|
||||
// TODO Add support for NumGroupsToMerge > 1
|
||||
return make_naive_tensor_descriptor(make_tuple(N_, Hi_, Wi_, C_), // K_N
|
||||
make_tuple(NStride, HiStride, WiStride, CStride),
|
||||
number<VectorSizeB>{},
|
||||
I1);
|
||||
if constexpr(NumGroupsToMerge > 1)
|
||||
{
|
||||
const auto BatchStride = C_;
|
||||
return make_naive_tensor_descriptor(
|
||||
make_tuple(N_, Hi_, Wi_, NumGroupsToMerge, C_), // K_Gm_N
|
||||
make_tuple(NStride, HiStride, WiStride, BatchStride, CStride),
|
||||
number<VectorSizeB>{},
|
||||
I1);
|
||||
}
|
||||
else
|
||||
{
|
||||
return make_naive_tensor_descriptor(make_tuple(N_, Hi_, Wi_, C_), // K_N
|
||||
make_tuple(NStride, HiStride, WiStride, CStride),
|
||||
number<VectorSizeB>{},
|
||||
I1);
|
||||
}
|
||||
}
|
||||
|
||||
template <index_t NDim = NDimSpatial, typename std::enable_if<NDim == 2, bool>::type = false>
|
||||
@@ -491,9 +577,58 @@ struct TransformConvBwdWeightToGemm
|
||||
const index_t KStride = Y_ * X_ * C_;
|
||||
constexpr auto CStride = I1;
|
||||
|
||||
// TODO Add support for NumGroupsToMerge > 1
|
||||
return make_naive_tensor_descriptor(
|
||||
make_tuple(K_, Y_ * X_ * C_), make_tuple(KStride, CStride), number<VectorSizeC>{}, I1);
|
||||
if constexpr(NumGroupsToMerge > 1)
|
||||
{
|
||||
const index_t YXStride = C_;
|
||||
const index_t BatchStride = K_ * Y_ * X_ * C_;
|
||||
// Add NumGroupsToMerge for Batch+M dimension and, 1 as a placeholder
|
||||
// for Batch+N dimension
|
||||
const auto desc = make_naive_tensor_descriptor(
|
||||
make_tuple(NumGroupsToMerge, K_, Y_ * X_, 1, C_),
|
||||
make_tuple(BatchStride, KStride, YXStride, BatchStride, CStride),
|
||||
number<VectorSizeC>{},
|
||||
I1);
|
||||
// Pad 1 to NumGroupsToMerge
|
||||
const auto padded_desc = transform_tensor_descriptor(
|
||||
desc,
|
||||
make_tuple(make_pass_through_transform(NumGroupsToMerge),
|
||||
make_pass_through_transform(K_),
|
||||
make_pass_through_transform(Y_ * X_),
|
||||
make_pad_transform(1, 0, NumGroupsToMerge - 1),
|
||||
make_pass_through_transform(C_)),
|
||||
make_tuple(
|
||||
sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}, sequence<4>{}),
|
||||
make_tuple(
|
||||
sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}, sequence<4>{}));
|
||||
// We need only matrices from diagonal. Xor returns 0 for the same
|
||||
// values. So if matrices is not on diagonal then it will be stored in padding.
|
||||
// To avoid use of modulo after xor we assume that NumBatch to merge is power of 2.
|
||||
static_assert(NumGroupsToMerge == 1 || NumGroupsToMerge == 2 || NumGroupsToMerge == 4 ||
|
||||
NumGroupsToMerge == 8 || NumGroupsToMerge == 16 ||
|
||||
NumGroupsToMerge == 32 || NumGroupsToMerge == 64);
|
||||
const auto unmerged_padded_desc = transform_tensor_descriptor(
|
||||
padded_desc,
|
||||
make_tuple(make_xor_transform(make_tuple(NumGroupsToMerge, NumGroupsToMerge)),
|
||||
make_pass_through_transform(K_),
|
||||
make_pass_through_transform(Y_ * X_),
|
||||
make_pass_through_transform(C_)),
|
||||
make_tuple(sequence<0, 3>{}, sequence<1>{}, sequence<2>{}, sequence<4>{}),
|
||||
make_tuple(sequence<0, 3>{}, sequence<1>{}, sequence<2>{}, sequence<4>{}));
|
||||
// Merge To M, N
|
||||
return transform_tensor_descriptor(
|
||||
unmerged_padded_desc,
|
||||
make_tuple(make_merge_transform(make_tuple(NumGroupsToMerge, K_)),
|
||||
make_merge_transform(make_tuple(Y_ * X_, NumGroupsToMerge, C_))),
|
||||
make_tuple(sequence<0, 1>{}, sequence<2, 3, 4>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
}
|
||||
else
|
||||
{
|
||||
return make_naive_tensor_descriptor(make_tuple(K_, Y_ * X_ * C_),
|
||||
make_tuple(KStride, CStride),
|
||||
number<VectorSizeC>{},
|
||||
I1);
|
||||
}
|
||||
}
|
||||
|
||||
template <index_t NDim = NDimSpatial, typename std::enable_if<NDim == 3, bool>::type = false>
|
||||
@@ -503,11 +638,22 @@ struct TransformConvBwdWeightToGemm
|
||||
const index_t NDoHoWoStride = G_ * K_;
|
||||
constexpr auto KStride = I1;
|
||||
|
||||
// TODO Add support for NumGroupsToMerge > 1
|
||||
return make_naive_tensor_descriptor(make_tuple(N_ * Do_ * Ho_ * Wo_, K_),
|
||||
make_tuple(NDoHoWoStride, KStride),
|
||||
number<VectorSizeA>{},
|
||||
I1);
|
||||
if constexpr(NumGroupsToMerge > 1)
|
||||
{
|
||||
const auto BatchStride = K_;
|
||||
return make_naive_tensor_descriptor(
|
||||
make_tuple(N_ * Do_ * Ho_ * Wo_, NumGroupsToMerge, K_),
|
||||
make_tuple(NDoHoWoStride, BatchStride, KStride),
|
||||
number<VectorSizeA>{},
|
||||
I1);
|
||||
}
|
||||
else
|
||||
{
|
||||
return make_naive_tensor_descriptor(make_tuple(N_ * Do_ * Ho_ * Wo_, K_),
|
||||
make_tuple(NDoHoWoStride, KStride),
|
||||
number<VectorSizeA>{},
|
||||
I1);
|
||||
}
|
||||
}
|
||||
|
||||
    template <index_t NDim = NDimSpatial, typename std::enable_if<NDim == 3, bool>::type = false>
@@ -519,26 +665,84 @@ struct TransformConvBwdWeightToGemm
        const index_t WiStride = G_ * C_;
        constexpr auto CStride = I1;

        // TODO Add support for NumGroupsToMerge > 1
        return make_naive_tensor_descriptor(
            make_tuple(N_, Di_, Hi_, Wi_, C_),
            make_tuple(NStride, DiStride, HiStride, WiStride, CStride),
            number<VectorSizeB>{},
            I1);
        if constexpr(NumGroupsToMerge > 1)
        {
            const index_t BatchStride = C_;
            return make_naive_tensor_descriptor(
                make_tuple(N_, Di_, Hi_, Wi_, NumGroupsToMerge, C_),
                make_tuple(NStride, DiStride, HiStride, WiStride, BatchStride, CStride),
                number<VectorSizeB>{},
                I1);
        }
        else
        {
            return make_naive_tensor_descriptor(
                make_tuple(N_, Di_, Hi_, Wi_, C_),
                make_tuple(NStride, DiStride, HiStride, WiStride, CStride),
                number<VectorSizeB>{},
                I1);
        }
    }

    template <index_t NDim = NDimSpatial, typename std::enable_if<NDim == 3, bool>::type = false>
    CK_TILE_HOST auto make_wei_grid_desc() const
    {
        // KZYXC
        // GKZYXC
        const index_t KStride  = Z_ * Y_ * X_ * C_;
        constexpr auto CStride = I1;

        // TODO Add support for NumGroupsToMerge > 1
        return make_naive_tensor_descriptor(make_tuple(K_, Z_ * Y_ * X_ * C_),
                                            make_tuple(KStride, CStride),
                                            number<VectorSizeC>{},
                                            I1);
        if constexpr(NumGroupsToMerge > 1)
        {
            const index_t ZYXStride   = C_;
            const index_t BatchStride = K_ * Z_ * Y_ * X_ * C_;
            // Add NumGroupsToMerge for the Batch+M dimension and 1 as a placeholder
            // for the Batch+N dimension
            const auto desc = make_naive_tensor_descriptor(
                make_tuple(NumGroupsToMerge, K_, Z_ * Y_ * X_, 1, C_),
                make_tuple(BatchStride, KStride, ZYXStride, BatchStride, CStride),
                number<VectorSizeC>{},
                I1);
            // Pad 1 to NumGroupsToMerge
            const auto padded_desc = transform_tensor_descriptor(
                desc,
                make_tuple(make_pass_through_transform(NumGroupsToMerge),
                           make_pass_through_transform(K_),
                           make_pass_through_transform(Z_ * Y_ * X_),
                           make_pad_transform(1, 0, NumGroupsToMerge - 1),
                           make_pass_through_transform(C_)),
                make_tuple(
                    sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}, sequence<4>{}),
                make_tuple(
                    sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}, sequence<4>{}));
            // We only need the matrices on the diagonal. Xor returns 0 for equal
            // values, so a matrix that is not on the diagonal ends up in the padding.
            // To avoid a modulo after the xor we require NumGroupsToMerge to be a power of 2.
            static_assert(NumGroupsToMerge == 1 || NumGroupsToMerge == 2 || NumGroupsToMerge == 4 ||
                          NumGroupsToMerge == 8 || NumGroupsToMerge == 16 ||
                          NumGroupsToMerge == 32 || NumGroupsToMerge == 64);
            const auto unmerged_padded_desc = transform_tensor_descriptor(
                padded_desc,
                make_tuple(make_xor_transform(make_tuple(NumGroupsToMerge, NumGroupsToMerge)),
                           make_pass_through_transform(K_),
                           make_pass_through_transform(Z_ * Y_ * X_),
                           make_pass_through_transform(C_)),
                make_tuple(sequence<0, 3>{}, sequence<1>{}, sequence<2>{}, sequence<4>{}),
                make_tuple(sequence<0, 3>{}, sequence<1>{}, sequence<2>{}, sequence<4>{}));
            // Merge to M, N
            return transform_tensor_descriptor(
                unmerged_padded_desc,
                make_tuple(make_merge_transform(make_tuple(NumGroupsToMerge, K_)),
                           make_merge_transform(make_tuple(Z_ * Y_ * X_, NumGroupsToMerge, C_))),
                make_tuple(sequence<0, 1>{}, sequence<2, 3, 4>{}),
                make_tuple(sequence<0>{}, sequence<1>{}));
        }
        else
        {
            return make_naive_tensor_descriptor(make_tuple(K_, Z_ * Y_ * X_ * C_),
                                                make_tuple(KStride, CStride),
                                                number<VectorSizeC>{},
                                                I1);
        }
    }

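The xor construction above can be checked in isolation: the pair of group coordinates coming from the M-side and N-side merges is sent to (g_m, g_m xor g_n), and since the placeholder axis has only one valid entry (the rest is padding), only the diagonal g_m == g_n addresses real weight data; every off-diagonal pair lands in the padding. The power-of-two static_assert is what keeps g_m xor g_n below NumGroupsToMerge without a modulo. A minimal standalone check of those two facts (not part of the diff; the names are illustrative):

#include <cassert>

int main()
{
    constexpr int NumGroupsToMerge = 8; // must be a power of two, per the static_assert
    for(int gm = 0; gm < NumGroupsToMerge; ++gm)
        for(int gn = 0; gn < NumGroupsToMerge; ++gn)
        {
            const int mapped = gm ^ gn;          // coordinate on the padded placeholder axis
            assert(mapped < NumGroupsToMerge);   // stays in range, no modulo needed
            assert((mapped == 0) == (gm == gn)); // 0, the only valid entry, exactly on the diagonal
        }
    return 0;
}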
    // TODO: implement ck_tile::tensor_layout::convolution that describes packed/strided dimension as
@@ -552,31 +756,84 @@ struct TransformConvBwdWeightToGemm
        const auto wei_grid_desc = make_wei_grid_desc<NDimSpatial>();

        // B: input tensor comes in K_N
        const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor(
            in_grid_desc,
            make_tuple(make_pass_through_transform(N_),
                       make_pad_transform(Wi_, InLeftPadW_, InRightPadW_),
                       make_pass_through_transform(C_)),
            make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}),
            make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}));
        if constexpr(NumGroupsToMerge > 1)
        {
            // Output tensor transformation
            // [0, 1, 2] -> [0, 1]
            // [(N*Wo), Gm, K] -> [(N*Wo), (Gm*K)]
            const auto out_gemm_k_gemm_m_grid_desc = transform_tensor_descriptor(
                out_grid_desc,
                make_tuple(make_pass_through_transform(N_ * Wo_),
                           make_merge_transform(make_tuple(NumGroupsToMerge, K_))),
                make_tuple(sequence<0>{}, sequence<1, 2>{}),
                make_tuple(sequence<0>{}, sequence<1>{}));

            const auto in_n_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor(
                in_n_hip_wip_c_grid_desc,
                make_tuple(
                    make_pass_through_transform(N_),
                    make_embed_transform(make_tuple(X_, Wo_), make_tuple(ConvDilationW_, ConvStrideW_)),
                    make_pass_through_transform(C_)),
                make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}),
                make_tuple(sequence<0>{}, sequence<1, 2>{}, sequence<3>{}));
            // Input tensor transformation, part 1.
            // [N, Wi, Gm, C] -> [N, (Wi + InLeftPadW + InRightPadW), Gm, C] = [N, Wip, Gm, C]
            const auto in_n_wip_gm_c_grid_desc = transform_tensor_descriptor(
                in_grid_desc,
                make_tuple(make_pass_through_transform(N_),
                           make_pad_transform(Wi_, InLeftPadW_, InRightPadW_),
                           make_pass_through_transform(NumGroupsToMerge),
                           make_pass_through_transform(C_)),
                make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}),
                make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}));

            const auto in_gemmn_gemmktotal_grid_desc =
                transform_tensor_descriptor(in_n_y_ho_x_wo_c_grid_desc,
                                            make_tuple(make_merge_transform(make_tuple(X_, C_)),
                                                       make_merge_transform(make_tuple(N_, Wo_))),
                                            make_tuple(sequence<1, 3>{}, sequence<0, 2>{}),
                                            make_tuple(sequence<1>{}, sequence<0>{}));
            // Input tensor transformation, part 2.
            // [N, Wip, Gm, C] -> [N, X, Wo, Gm, C]
            const auto in_n_x_wo_gm_c_grid_desc = transform_tensor_descriptor(
                in_n_wip_gm_c_grid_desc,
                make_tuple(make_pass_through_transform(N_),
                           make_embed_transform(make_tuple(X_, Wo_),
                                                make_tuple(ConvDilationW_, ConvStrideW_)),
                           make_pass_through_transform(NumGroupsToMerge),
                           make_pass_through_transform(C_)),
                make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}),
                make_tuple(sequence<0>{}, sequence<1, 2>{}, sequence<3>{}, sequence<4>{}));

            return make_tuple(out_grid_desc, in_gemmn_gemmktotal_grid_desc, wei_grid_desc);
            // Input tensor transformation, part 3.
            // [0, 1, 2, 3, 4] -> [0, 1]
            // [N, X, Wo, Gm, C] -> [(N*Wo), (Gm*X*C)]
            const auto in_gemm_n_gemm_k_grid_desc = transform_tensor_descriptor(
                in_n_x_wo_gm_c_grid_desc,
                make_tuple(make_merge_transform(make_tuple(X_, NumGroupsToMerge, C_)),
                           make_merge_transform(make_tuple(N_, Wo_))),
                make_tuple(sequence<1, 3, 4>{}, sequence<0, 2>{}),
                make_tuple(sequence<1>{}, sequence<0>{}));

            return make_tuple(
                out_gemm_k_gemm_m_grid_desc, in_gemm_n_gemm_k_grid_desc, wei_grid_desc);
        }
        else
        {
            // [N, Wi, C] -> [N, (Wi + InLeftPadW + InRightPadW), C] = [N, Wip, C]
            const auto in_n_wip_c_grid_desc = transform_tensor_descriptor(
                in_grid_desc,
                make_tuple(make_pass_through_transform(N_),
                           make_pad_transform(Wi_, InLeftPadW_, InRightPadW_),
                           make_pass_through_transform(C_)),
                make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}),
                make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}));

            // [N, Wip, C] -> [N, X, Wo, C]
            const auto in_n_x_wo_c_grid_desc = transform_tensor_descriptor(
                in_n_wip_c_grid_desc,
                make_tuple(make_pass_through_transform(N_),
                           make_embed_transform(make_tuple(X_, Wo_),
                                                make_tuple(ConvDilationW_, ConvStrideW_)),
                           make_pass_through_transform(C_)),
                make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}),
                make_tuple(sequence<0>{}, sequence<1, 2>{}, sequence<3>{}));

            const auto in_gemmn_gemmktotal_grid_desc =
                transform_tensor_descriptor(in_n_x_wo_c_grid_desc,
                                            make_tuple(make_merge_transform(make_tuple(X_, C_)),
                                                       make_merge_transform(make_tuple(N_, Wo_))),
                                            make_tuple(sequence<1, 3>{}, sequence<0, 2>{}),
                                            make_tuple(sequence<1>{}, sequence<0>{}));

            return make_tuple(out_grid_desc, in_gemmn_gemmktotal_grid_desc, wei_grid_desc);
        }
    }

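For the merged 1-D path above, the GEMM extents follow from the merge transforms: the reduction dimension comes from (N, Wo), the M dimension from (NumGroupsToMerge, K), and the N dimension from (X, NumGroupsToMerge, C). A small illustrative check with made-up sizes (these numbers are not taken from the library or the diff; Gm stands for NumGroupsToMerge and the fragment is valid at namespace scope):

constexpr int Gm = 4, K = 64, X = 3, C = 32, N = 2, Wo = 28;
constexpr int GemmM = Gm * K;     // from merge(NumGroupsToMerge, K)
constexpr int GemmN = X * Gm * C; // from merge(X, NumGroupsToMerge, C)
constexpr int GemmK = N * Wo;     // reduction dimension, from merge(N, Wo)
static_assert(GemmM == 256 && GemmN == 384 && GemmK == 56, "illustrative sizes");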
    template <index_t NDim = NDimSpatial, typename std::enable_if<NDim == 2, bool>::type = false>
@@ -587,33 +844,95 @@ struct TransformConvBwdWeightToGemm
        const auto wei_grid_desc = make_wei_grid_desc<NDimSpatial>();

        // B: input tensor comes in K_N
        const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor(
            in_grid_desc,
            make_tuple(make_pass_through_transform(N_),
                       make_pad_transform(Hi_, InLeftPadH_, InRightPadH_),
                       make_pad_transform(Wi_, InLeftPadW_, InRightPadW_),
                       make_pass_through_transform(C_)),
            make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}),
            make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}));
        if constexpr(NumGroupsToMerge > 1)
        {
            // Output tensor transformation
            // [0, 1, 2] -> [0, 1]
            // [(N*Ho*Wo), Gm, K] -> [(N*Ho*Wo), (Gm*K)]
            const auto out_gemm_k_gemm_m_grid_desc = transform_tensor_descriptor(
                out_grid_desc,
                make_tuple(make_pass_through_transform(N_ * Ho_ * Wo_),
                           make_merge_transform(make_tuple(NumGroupsToMerge, K_))),
                make_tuple(sequence<0>{}, sequence<1, 2>{}),
                make_tuple(sequence<0>{}, sequence<1>{}));

            const auto in_n_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor(
                in_n_hip_wip_c_grid_desc,
                make_tuple(
                    make_pass_through_transform(N_),
                    make_embed_transform(make_tuple(Y_, Ho_), make_tuple(ConvDilationH_, ConvStrideH_)),
                    make_embed_transform(make_tuple(X_, Wo_), make_tuple(ConvDilationW_, ConvStrideW_)),
                    make_pass_through_transform(C_)),
                make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}),
                make_tuple(sequence<0>{}, sequence<1, 2>{}, sequence<3, 4>{}, sequence<5>{}));
            // Input tensor transformation, part 1.
            // [N, Hi, Wi, Gm, C] -> [N, Hip, Wip, Gm, C]
            const auto in_n_hip_wip_gm_c_grid_desc = transform_tensor_descriptor(
                in_grid_desc,
                make_tuple(make_pass_through_transform(N_),
                           make_pad_transform(Hi_, InLeftPadH_, InRightPadH_),
                           make_pad_transform(Wi_, InLeftPadW_, InRightPadW_),
                           make_pass_through_transform(NumGroupsToMerge),
                           make_pass_through_transform(C_)),
                make_tuple(
                    sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}, sequence<4>{}),
                make_tuple(
                    sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}, sequence<4>{}));

            const auto in_gemmn_gemmktotal_grid_desc =
                transform_tensor_descriptor(in_n_y_ho_x_wo_c_grid_desc,
                                            make_tuple(make_merge_transform(make_tuple(Y_, X_, C_)),
                                                       make_merge_transform(make_tuple(N_, Ho_, Wo_))),
                                            make_tuple(sequence<1, 3, 5>{}, sequence<0, 2, 4>{}),
                                            make_tuple(sequence<1>{}, sequence<0>{}));
            // Input tensor transformation, part 2.
            // [N, Hip, Wip, Gm, C] -> [N, (Y, Ho), (X, Wo), Gm, C]
            const auto in_n_y_ho_x_wo_gm_c_grid_desc = transform_tensor_descriptor(
                in_n_hip_wip_gm_c_grid_desc,
                make_tuple(make_pass_through_transform(N_),
                           make_embed_transform(make_tuple(Y_, Ho_),
                                                make_tuple(ConvDilationH_, ConvStrideH_)),
                           make_embed_transform(make_tuple(X_, Wo_),
                                                make_tuple(ConvDilationW_, ConvStrideW_)),
                           make_pass_through_transform(NumGroupsToMerge),
                           make_pass_through_transform(C_)),
                make_tuple(
                    sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}, sequence<4>{}),
                make_tuple(sequence<0>{},
                           sequence<1, 2>{},
                           sequence<3, 4>{},
                           sequence<5>{},
                           sequence<6>{}));

            return make_tuple(out_grid_desc, in_gemmn_gemmktotal_grid_desc, wei_grid_desc);
            // Input tensor transformation, part 3.
            // [0, 1, 2, 3, 4, 5, 6] -> [0, 1]
            // [N, Y, Ho, X, Wo, Gm, C] -> [(N*Ho*Wo), (Gm*Y*X*C)]
            const auto in_gemm_n_gemm_k_grid_desc = transform_tensor_descriptor(
                in_n_y_ho_x_wo_gm_c_grid_desc,
                make_tuple(make_merge_transform(make_tuple(Y_, X_, NumGroupsToMerge, C_)),
                           make_merge_transform(make_tuple(N_, Ho_, Wo_))),
                make_tuple(sequence<1, 3, 5, 6>{}, sequence<0, 2, 4>{}),
                make_tuple(sequence<1>{}, sequence<0>{}));

            return make_tuple(
                out_gemm_k_gemm_m_grid_desc, in_gemm_n_gemm_k_grid_desc, wei_grid_desc);
        }
        else
        {
            const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor(
                in_grid_desc,
                make_tuple(make_pass_through_transform(N_),
                           make_pad_transform(Hi_, InLeftPadH_, InRightPadH_),
                           make_pad_transform(Wi_, InLeftPadW_, InRightPadW_),
                           make_pass_through_transform(C_)),
                make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}),
                make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}));

            const auto in_n_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor(
                in_n_hip_wip_c_grid_desc,
                make_tuple(make_pass_through_transform(N_),
                           make_embed_transform(make_tuple(Y_, Ho_),
                                                make_tuple(ConvDilationH_, ConvStrideH_)),
                           make_embed_transform(make_tuple(X_, Wo_),
                                                make_tuple(ConvDilationW_, ConvStrideW_)),
                           make_pass_through_transform(C_)),
                make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}),
                make_tuple(sequence<0>{}, sequence<1, 2>{}, sequence<3, 4>{}, sequence<5>{}));

            const auto in_gemmn_gemmktotal_grid_desc = transform_tensor_descriptor(
                in_n_y_ho_x_wo_c_grid_desc,
                make_tuple(make_merge_transform(make_tuple(Y_, X_, C_)),
                           make_merge_transform(make_tuple(N_, Ho_, Wo_))),
                make_tuple(sequence<1, 3, 5>{}, sequence<0, 2, 4>{}),
                make_tuple(sequence<1>{}, sequence<0>{}));

            return make_tuple(out_grid_desc, in_gemmn_gemmktotal_grid_desc, wei_grid_desc);
        }
    }

    template <index_t NDim = NDimSpatial, typename std::enable_if<NDim == 3, bool>::type = false>
@@ -624,39 +943,121 @@ struct TransformConvBwdWeightToGemm
        const auto wei_grid_desc = make_wei_grid_desc<NDimSpatial>();

        // B: input tensor comes in K_N
        const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor(
            in_grid_desc,
            make_tuple(make_pass_through_transform(N_),
                       make_pad_transform(Di_, InLeftPadD_, InRightPadD_),
                       make_pad_transform(Hi_, InLeftPadH_, InRightPadH_),
                       make_pad_transform(Wi_, InLeftPadW_, InRightPadW_),
                       make_pass_through_transform(C_)),
            make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}, sequence<4>{}),
            make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}, sequence<4>{}));
        if constexpr(NumGroupsToMerge > 1)
        {
            // Output tensor transformation
            // [0, 1, 2] -> [0, 1]
            // [(N*Do*Ho*Wo), Gm, K] -> [(N*Do*Ho*Wo), (Gm*K)]
            const auto out_gemm_k_gemm_m_grid_desc = transform_tensor_descriptor(
                out_grid_desc,
                make_tuple(make_pass_through_transform(N_ * Do_ * Ho_ * Wo_),
                           make_merge_transform(make_tuple(NumGroupsToMerge, K_))),
                make_tuple(sequence<0>{}, sequence<1, 2>{}),
                make_tuple(sequence<0>{}, sequence<1>{}));

            const auto in_n_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor(
                in_n_hip_wip_c_grid_desc,
                make_tuple(
                    make_pass_through_transform(N_),
                    make_embed_transform(make_tuple(Z_, Do_), make_tuple(ConvDilationD_, ConvStrideD_)),
                    make_embed_transform(make_tuple(Y_, Ho_), make_tuple(ConvDilationH_, ConvStrideH_)),
                    make_embed_transform(make_tuple(X_, Wo_), make_tuple(ConvDilationW_, ConvStrideW_)),
                    make_pass_through_transform(C_)),
                make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}, sequence<4>{}),
                make_tuple(sequence<0>{},
                           sequence<1, 2>{},
                           sequence<3, 4>{},
                           sequence<5, 6>{},
                           sequence<7>{}));
            // Input tensor transformation, part 1.
            // [N, Di, Hi, Wi, Gm, C] -> [N, Dip, Hip, Wip, Gm, C]
            const auto in_n_dip_hip_wip_gm_c_grid_desc = transform_tensor_descriptor(
                in_grid_desc,
                make_tuple(make_pass_through_transform(N_),
                           make_pad_transform(Di_, InLeftPadD_, InRightPadD_),
                           make_pad_transform(Hi_, InLeftPadH_, InRightPadH_),
                           make_pad_transform(Wi_, InLeftPadW_, InRightPadW_),
                           make_pass_through_transform(NumGroupsToMerge),
                           make_pass_through_transform(C_)),
                make_tuple(sequence<0>{},
                           sequence<1>{},
                           sequence<2>{},
                           sequence<3>{},
                           sequence<4>{},
                           sequence<5>{}),
                make_tuple(sequence<0>{},
                           sequence<1>{},
                           sequence<2>{},
                           sequence<3>{},
                           sequence<4>{},
                           sequence<5>{}));

            const auto in_gemmn_gemmktotal_grid_desc = transform_tensor_descriptor(
                in_n_y_ho_x_wo_c_grid_desc,
                make_tuple(make_merge_transform(make_tuple(Z_, Y_, X_, C_)),
                           make_merge_transform(make_tuple(N_, Do_, Ho_, Wo_))),
                make_tuple(sequence<1, 3, 5, 7>{}, sequence<0, 2, 4, 6>{}),
                make_tuple(sequence<1>{}, sequence<0>{}));
            // Input tensor transformation, part 2.
            // [N, Dip, Hip, Wip, Gm, C] -> [N, (Z, Do), (Y, Ho), (X, Wo), Gm, C]
            const auto in_n_z_do_y_ho_x_wo_gm_c_grid_desc = transform_tensor_descriptor(
                in_n_dip_hip_wip_gm_c_grid_desc,
                make_tuple(make_pass_through_transform(N_),
                           make_embed_transform(make_tuple(Z_, Do_),
                                                make_tuple(ConvDilationD_, ConvStrideD_)),
                           make_embed_transform(make_tuple(Y_, Ho_),
                                                make_tuple(ConvDilationH_, ConvStrideH_)),
                           make_embed_transform(make_tuple(X_, Wo_),
                                                make_tuple(ConvDilationW_, ConvStrideW_)),
                           make_pass_through_transform(NumGroupsToMerge),
                           make_pass_through_transform(C_)),
                make_tuple(sequence<0>{},
                           sequence<1>{},
                           sequence<2>{},
                           sequence<3>{},
                           sequence<4>{},
                           sequence<5>{}),
                make_tuple(sequence<0>{},
                           sequence<1, 2>{},
                           sequence<3, 4>{},
                           sequence<5, 6>{},
                           sequence<7>{},
                           sequence<8>{}));

            return make_tuple(out_grid_desc, in_gemmn_gemmktotal_grid_desc, wei_grid_desc);
            // Input tensor transformation, part 3.
            // [0, 1, 2, 3, 4, 5, 6, 7, 8] -> [0, 1]
            // [N, Z, Do, Y, Ho, X, Wo, Gm, C] -> [(N*Do*Ho*Wo), (Z*Y*X*Gm*C)]
            const auto in_gemm_k_gemm_n_grid_desc = transform_tensor_descriptor(
                in_n_z_do_y_ho_x_wo_gm_c_grid_desc,
                make_tuple(make_merge_transform(make_tuple(Z_, Y_, X_, NumGroupsToMerge, C_)),
                           make_merge_transform(make_tuple(N_, Do_, Ho_, Wo_))),
                make_tuple(sequence<1, 3, 5, 7, 8>{}, sequence<0, 2, 4, 6>{}),
                make_tuple(sequence<1>{}, sequence<0>{}));

            return make_tuple(
                out_gemm_k_gemm_m_grid_desc, in_gemm_k_gemm_n_grid_desc, wei_grid_desc);
        }
        else
        {
            const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor(
                in_grid_desc,
                make_tuple(make_pass_through_transform(N_),
                           make_pad_transform(Di_, InLeftPadD_, InRightPadD_),
                           make_pad_transform(Hi_, InLeftPadH_, InRightPadH_),
                           make_pad_transform(Wi_, InLeftPadW_, InRightPadW_),
                           make_pass_through_transform(C_)),
                make_tuple(
                    sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}, sequence<4>{}),
                make_tuple(
                    sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}, sequence<4>{}));

            const auto in_n_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor(
                in_n_hip_wip_c_grid_desc,
                make_tuple(make_pass_through_transform(N_),
                           make_embed_transform(make_tuple(Z_, Do_),
                                                make_tuple(ConvDilationD_, ConvStrideD_)),
                           make_embed_transform(make_tuple(Y_, Ho_),
                                                make_tuple(ConvDilationH_, ConvStrideH_)),
                           make_embed_transform(make_tuple(X_, Wo_),
                                                make_tuple(ConvDilationW_, ConvStrideW_)),
                           make_pass_through_transform(C_)),
                make_tuple(
                    sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}, sequence<4>{}),
                make_tuple(sequence<0>{},
                           sequence<1, 2>{},
                           sequence<3, 4>{},
                           sequence<5, 6>{},
                           sequence<7>{}));

            const auto in_gemmn_gemmktotal_grid_desc = transform_tensor_descriptor(
                in_n_y_ho_x_wo_c_grid_desc,
                make_tuple(make_merge_transform(make_tuple(Z_, Y_, X_, C_)),
                           make_merge_transform(make_tuple(N_, Do_, Ho_, Wo_))),
                make_tuple(sequence<1, 3, 5, 7>{}, sequence<0, 2, 4, 6>{}),
                make_tuple(sequence<1>{}, sequence<0>{}));

            return make_tuple(out_grid_desc, in_gemmn_gemmktotal_grid_desc, wei_grid_desc);
        }
    }

    IndexType G_, N_;
@@ -668,7 +1069,6 @@ struct TransformConvBwdWeightToGemm
    IndexType ConvDilationD_, ConvDilationH_, ConvDilationW_;
    IndexType InLeftPadD_, InLeftPadH_, InLeftPadW_;
    IndexType InRightPadD_, InRightPadH_, InRightPadW_;
    IndexType ZYX_;
};

} // namespace ck_tile

@@ -13,10 +13,10 @@ template <index_t NDimSpatial,
          index_t VectorSizeA,
          index_t VectorSizeB,
          index_t VectorSizeC,
          index_t NumGroupsToMerge = 1,
          bool SplitN = false,
          typename ADataType = float,
          typename CDataType = float,
          index_t NumGroupsToMerge = 1,
          typename IndexType = index_t>
struct TransformConvFwdToGemm
{
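The final hunk (@@ -13,10 +13,10 @@) reads as moving the defaulted NumGroupsToMerge parameter of TransformConvFwdToGemm forward in the list; the two identical lines are presumably the old and the new position of the same parameter. One plausible motivation, since template arguments are positional, is that callers can then override NumGroupsToMerge without also spelling out SplitN, ADataType and CDataType. A hypothetical, simplified illustration (Example is not a type from the library):

template <int VectorSize, int NumGroupsToMerge = 1, bool SplitN = false, typename ADataType = float>
struct Example
{
};

// With NumGroupsToMerge right after VectorSize, overriding it stays short:
using MergedFour = Example<8, 4>;
// Had it been declared after the data types, the same override would have to
// repeat every preceding default, e.g. Example<8, false, float, 4>.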