Merge branch 'develop' into moe_xcd_remap

This commit is contained in:
Tianxing Wu
2025-11-10 14:16:51 +02:00
committed by GitHub
47 changed files with 4543 additions and 827 deletions

6
.gitignore vendored
View File

@@ -66,6 +66,12 @@ docs/doxygen/xml
cmake-build*/
build*/
# LSP configuration
.clangd
# User-defined CMake presets
CMakeUserPresets.json
# Python virtualenv
.venv/

View File

@@ -181,7 +181,7 @@ constexpr ck::index_t ScaleBlockSize = 32; // scaling block
constexpr ck::index_t KPerBlock = 256 / DataPackedSize; // 256 f4 = 128 fp4x2
static constexpr ck::index_t Nswizzle = false;
static constexpr ck::index_t ActOP = 0; // 0: gelu_and_mul, 1: silu_and_mul
static constexpr ck::index_t MPerBlock = 128;
static constexpr ck::index_t MPerBlock = 32;
static constexpr bool MulRoutedWeight = true;
// clang-format off
@@ -190,10 +190,10 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemmMXBPreShuffl
A0DataType, A1DataType, B0DataType, B1DataType, DsDataType, EDataType, AccDataType, CShuffleDataType,
AElementOp, BElementOp, CDEElementOp, GemmSpec,
ScaleBlockSize, 256,
MPerBlock, 64, KPerBlock,
MPerBlock, 128, KPerBlock,
16, 16,
16, 16,
4, 2,
2, 2,
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1,
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1,
2, 2, S<1, 32, 1, 8>, S<8, 1, 1, 1>,
@@ -213,10 +213,10 @@ int main(int argc, char* argv[])
ck::index_t sorted_size = sorted_tile_num * MPerBlock;
ck::index_t valid_size = valid_tile_num * MPerBlock;
ck::index_t N = 6144;
ck::index_t K = 4096;
ck::index_t N = 7168;
ck::index_t K = 256;
ck::index_t experts = 8;
ck::index_t tokens = 832;
ck::index_t tokens = 208;
ck::index_t topk = 2;
if(argc == 1)

View File

@@ -14,19 +14,11 @@
struct ConvConfigBase
{
static constexpr bool kPadM = true;
static constexpr bool kPadN = true;
static constexpr bool kPadK = true;
static constexpr bool TransposeC = false;
static constexpr ck_tile::index_t VectorSizeA = 4;
static constexpr ck_tile::index_t VectorSizeB = 8;
static constexpr ck_tile::index_t VectorSizeC = 8;
static constexpr int kBlockPerCu = 1;
static constexpr ck_tile::index_t TileParitionerGroupNum = 8;
static constexpr ck_tile::index_t TileParitionerM01 = 4;
static constexpr int kBlockPerCu = 1;
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave;
static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V3;
static constexpr ck_tile::index_t NumWaveGroups = 1;
@@ -210,9 +202,9 @@ struct ConvConfigComputeV5 : public ConvConfigBase
static constexpr ck_tile::index_t N_Warp_Tile = 32;
static constexpr ck_tile::index_t K_Warp_Tile = 16;
static constexpr bool DoubleSmemBuffer = false;
static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V5;
static constexpr ck_tile::index_t NumWaNumWaveGroups = 2;
static constexpr bool DoubleSmemBuffer = false;
static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V5;
static constexpr ck_tile::index_t NumWaveGroups = 2;
};
template <typename PrecType>

View File

@@ -22,8 +22,6 @@ struct GroupedConvolutionBackwardDataInvoker
static float grouped_conv_bwd_data(const ck_tile::GroupedConvBwdDataHostArgs& args,
const ck_tile::stream_config& s)
{
constexpr int kBlockPerCu = 1;
// Implicit GEMM Traits
using GemmShape = ck_tile::TileGemmShape<
ck_tile::sequence<ConvConfig::M_Tile, ConvConfig::N_Tile, ConvConfig::K_Tile>,
@@ -32,36 +30,33 @@ struct GroupedConvolutionBackwardDataInvoker
ConvConfig::N_Warp_Tile,
ConvConfig::K_Warp_Tile>>;
constexpr ck_tile::index_t VectorSizeA = 8;
constexpr ck_tile::index_t VectorSizeB = 8;
constexpr ck_tile::index_t VectorSizeC = 8;
constexpr auto ConvSpec = ck_tile::ConvolutionSpecialization::Default;
using TilePartitioner =
ck_tile::GemmSpatiallyLocalTilePartitioner<GemmShape,
ConvConfig::TileParitionerGroupNum,
ConvConfig::TileParitionerM01>;
constexpr auto ConvSpec = ck_tile::ConvolutionSpecialization::Default;
using GroupedConvTraitsType = ck_tile::GroupedConvTraits<NDimSpatial,
ConvSpec,
InLayout,
WeiLayout,
DsLayout,
OutLayout,
VectorSizeA,
VectorSizeB,
VectorSizeC>;
ConvConfig::VectorSizeA,
ConvConfig::VectorSizeB,
ConvConfig::VectorSizeC>;
using TilePartitioner = ck_tile::GemmSpatiallyLocalTilePartitioner<
GemmShape,
GroupedConvTraitsType::FixedGemmParams::TilePartitionerGroupNum,
GroupedConvTraitsType::FixedGemmParams::TilePartitionerM01>;
using GemmUniversalTraits = ck_tile::TileGemmUniversalTraits<
ConvConfig::kPadM,
ConvConfig::kPadN,
ConvConfig::kPadK,
GroupedConvTraitsType::FixedGemmParams::kPadM,
GroupedConvTraitsType::FixedGemmParams::kPadN,
GroupedConvTraitsType::FixedGemmParams::kPadK,
ConvConfig::DoubleSmemBuffer,
typename GroupedConvTraitsType::GroupedConvImplicitGemmTraitsBwdData::AsLayout,
typename GroupedConvTraitsType::GroupedConvImplicitGemmTraitsBwdData::BsLayout,
typename GroupedConvTraitsType::GroupedConvImplicitGemmTraitsBwdData::CLayout,
ConvConfig::TransposeC,
false,
false, // Persistent,
typename GroupedConvTraitsType::AsLayoutBwdData,
typename GroupedConvTraitsType::BsLayoutBwdData,
typename GroupedConvTraitsType::CLayoutBwdData,
GroupedConvTraitsType::FixedGemmParams::TransposeC,
GroupedConvTraitsType::FixedGemmParams::UseStructuredSparsity,
GroupedConvTraitsType::FixedGemmParams::Persistent,
ConvConfig::NumWaveGroups>;
using GemmPipelineProblem = ck_tile::GemmPipelineProblem<
@@ -69,13 +64,14 @@ struct GroupedConvolutionBackwardDataInvoker
WeiDataType,
AccDataType,
GemmShape,
typename GroupedConvTraitsType::GroupedConvImplicitGemmTraitsBwdData,
typename GroupedConvTraitsType::template GroupedConvImplicitGemmTraitsBwdData<
ConvConfig::NumWaveGroups>,
ck_tile::element_wise::PassThrough,
ck_tile::element_wise::PassThrough,
InDataType,
true,
VectorSizeA,
VectorSizeB>;
GroupedConvTraitsType::FixedGemmParams::FixedVectorSize,
GroupedConvTraitsType::VectorSizeA,
GroupedConvTraitsType::VectorSizeB>;
using BaseGemmPipeline = typename PipelineTypeTraits<
ConvConfig::Pipeline>::template UniversalGemmPipeline<GemmPipelineProblem>;
@@ -93,95 +89,96 @@ struct GroupedConvolutionBackwardDataInvoker
const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop);
float ave_time{0};
const auto Run =
[&](const auto has_hot_loop_, const auto tail_number_, const auto memory_operation_) {
constexpr bool has_hot_loop_v = has_hot_loop_.value;
constexpr auto tail_number_v = tail_number_.value;
constexpr auto scheduler = ConvConfig::Scheduler;
constexpr auto memory_operation = memory_operation_.value;
const auto Run = [&](const auto has_hot_loop_,
const auto tail_number_,
const auto memory_operation_) {
constexpr bool has_hot_loop_v = has_hot_loop_.value;
constexpr auto tail_number_v = tail_number_.value;
constexpr auto scheduler = ConvConfig::Scheduler;
constexpr auto memory_operation = memory_operation_.value;
using UniversalGemmProblem =
ck_tile::UniversalGemmPipelineProblem<OutDataType,
WeiDataType,
AccDataType,
GemmShape,
GemmUniversalTraits,
scheduler,
has_hot_loop_v,
tail_number_v,
ck_tile::element_wise::PassThrough,
ck_tile::element_wise::PassThrough,
InDataType,
true,
VectorSizeA,
VectorSizeB>;
using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem<
OutDataType,
WeiDataType,
AccDataType,
GemmShape,
GemmUniversalTraits,
scheduler,
has_hot_loop_v,
tail_number_v,
ck_tile::element_wise::PassThrough,
ck_tile::element_wise::PassThrough,
InDataType,
GroupedConvTraitsType::FixedGemmParams::FixedVectorSize,
GroupedConvTraitsType::VectorSizeA,
GroupedConvTraitsType::VectorSizeB>;
using GemmPipeline = typename PipelineTypeTraits<
ConvConfig::Pipeline>::template GemmPipeline<UniversalGemmProblem>;
using GemmPipeline = typename PipelineTypeTraits<
ConvConfig::Pipeline>::template GemmPipeline<UniversalGemmProblem>;
using ConvEpilogue = ck_tile::CShuffleEpilogue<ck_tile::CShuffleEpilogueProblem<
OutDataType,
WeiDataType,
DsDataType,
AccDataType,
InDataType,
typename GroupedConvTraitsType::ImplicitGemmDsLayout,
ck_tile::tensor_layout::gemm::RowMajor,
CDEElementWise,
TilePartitioner::MPerBlock,
TilePartitioner::NPerBlock,
ConvConfig::M_Warp,
ConvConfig::N_Warp,
ConvConfig::M_Warp_Tile,
ConvConfig::N_Warp_Tile,
ConvConfig::K_Warp_Tile,
ConvConfig::TransposeC,
memory_operation,
1,
true,
GroupedConvTraitsType::VectorSizeC>>;
using ConvEpilogue = ck_tile::CShuffleEpilogue<ck_tile::CShuffleEpilogueProblem<
OutDataType,
WeiDataType,
DsDataType,
AccDataType,
InDataType,
typename GroupedConvTraitsType::ImplicitGemmDsLayout,
typename GroupedConvTraitsType::FixedGemmParams::ELayout,
CDEElementWise,
TilePartitioner::MPerBlock,
TilePartitioner::NPerBlock,
ConvConfig::M_Warp,
ConvConfig::N_Warp,
ConvConfig::M_Warp_Tile,
ConvConfig::N_Warp_Tile,
ConvConfig::K_Warp_Tile,
GroupedConvTraitsType::FixedGemmParams::TransposeC,
memory_operation,
ConvConfig::NumWaveGroups,
GroupedConvTraitsType::FixedGemmParams::FixedVectorSize,
GroupedConvTraitsType::VectorSizeC>>;
using Kernel = ck_tile::GroupedConvolutionBackwardDataKernel<GroupedConvTraitsType,
TilePartitioner,
GemmPipeline,
ConvEpilogue>;
auto kargs = Kernel::MakeKernelArgs(args);
using Kernel = ck_tile::GroupedConvolutionBackwardDataKernel<GroupedConvTraitsType,
TilePartitioner,
GemmPipeline,
ConvEpilogue>;
auto kargs = Kernel::MakeKernelArgs(args);
const dim3 grids = Kernel::GridSize(args);
const dim3 blocks = Kernel::BlockSize();
const dim3 grids = Kernel::GridSize(args);
const dim3 blocks = Kernel::BlockSize();
if(!Kernel::IsSupportedArgument(kargs))
{
throw std::runtime_error("Wrong! Arguments not supported! Skipping conv!\n");
}
if(!Kernel::IsSupportedArgument(kargs))
{
throw std::runtime_error("Wrong! Arguments not supported! Skipping conv!\n");
}
if(s.log_level_ > 0)
{
std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n'
<< "shape: " << GemmShape::GetName() << '\n'
<< "problem: " << UniversalGemmProblem::GetName() << '\n'
<< "pipeline: " << GemmPipeline::GetName() << '\n'
<< "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
<< ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z
<< "}" << '\n'
<< "Vector size A: " << GemmPipeline::GetVectorSizeA()
<< ", Vector size B: " << GemmPipeline::GetVectorSizeB()
<< ", Vector size C: " << ConvEpilogue::GetVectorSizeC() << std::endl;
}
if(s.log_level_ > 0)
{
std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n'
<< "shape: " << GemmShape::GetName() << '\n'
<< "problem: " << UniversalGemmProblem::GetName() << '\n'
<< "pipeline: " << GemmPipeline::GetName() << '\n'
<< "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
<< ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z
<< "}" << '\n'
<< "Vector size A: " << GemmPipeline::GetVectorSizeA()
<< ", Vector size B: " << GemmPipeline::GetVectorSizeB()
<< ", Vector size C: " << ConvEpilogue::GetVectorSizeC() << std::endl;
}
auto preprocess = [&]() {
ck_tile::hip_check_error(hipMemsetAsync(
kargs.in_ptr, 0, args.template GetInputByte<InDataType>(), s.stream_id_));
};
ave_time = ck_tile::launch_kernel_time_mask(
s,
preprocess,
ck_tile::make_kernel<kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
return ave_time;
auto preprocess = [&]() {
ck_tile::hip_check_error(hipMemsetAsync(
kargs.in_ptr, 0, args.template GetInputByte<InDataType>(), s.stream_id_));
};
ave_time = ck_tile::launch_kernel_time_mask(
s,
preprocess,
ck_tile::make_kernel<ConvConfig::kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
return ave_time;
};
const auto RunSplitk = [&](const auto has_hot_loop_, const auto tail_number_) {
if(args.k_batch == 1)
{

View File

@@ -21,8 +21,6 @@ struct GroupedConvolutionBackwardWeightInvoker
static float grouped_conv_bwd_weight(const ck_tile::GroupedConvBwdWeightHostArgs& args,
const ck_tile::stream_config& s)
{
constexpr int kBlockPerCu = 1;
// Implicit GEMM Traits
using GemmShape = ck_tile::TileGemmShape<
ck_tile::sequence<ConvConfig::M_Tile, ConvConfig::N_Tile, ConvConfig::K_Tile>,
@@ -31,37 +29,34 @@ struct GroupedConvolutionBackwardWeightInvoker
ConvConfig::N_Warp_Tile,
ConvConfig::K_Warp_Tile>>;
constexpr ck_tile::index_t VectorSizeA = ConvConfig::VectorSizeA;
constexpr ck_tile::index_t VectorSizeB = ConvConfig::VectorSizeB;
constexpr ck_tile::index_t VectorSizeC = ConvConfig::VectorSizeC;
constexpr auto ConvSpec = ck_tile::ConvolutionSpecialization::Default;
using TilePartitioner =
ck_tile::GemmSpatiallyLocalTilePartitioner<GemmShape,
ConvConfig::TileParitionerGroupNum,
ConvConfig::TileParitionerM01>;
constexpr auto ConvSpec = ck_tile::ConvolutionSpecialization::Default;
using GroupedConvTraitsType = ck_tile::GroupedConvTraits<NDimSpatial,
ConvSpec,
InLayout,
WeiLayout,
DsLayout,
OutLayout,
VectorSizeA,
VectorSizeB,
VectorSizeC,
ConvConfig::VectorSizeA,
ConvConfig::VectorSizeB,
ConvConfig::VectorSizeC,
ConvConfig::NumGroupsToMerge>;
using TilePartitioner = ck_tile::GemmSpatiallyLocalTilePartitioner<
GemmShape,
GroupedConvTraitsType::FixedGemmParams::TilePartitionerGroupNum,
GroupedConvTraitsType::FixedGemmParams::TilePartitionerM01>;
using GemmUniversalTraits = ck_tile::TileGemmUniversalTraits<
ConvConfig::kPadM,
ConvConfig::kPadN,
ConvConfig::kPadK,
GroupedConvTraitsType::FixedGemmParams::kPadM,
GroupedConvTraitsType::FixedGemmParams::kPadN,
GroupedConvTraitsType::FixedGemmParams::kPadK,
ConvConfig::DoubleSmemBuffer,
typename GroupedConvTraitsType::GroupedConvImplicitGemmTraitsBwdWeight::AsLayout,
typename GroupedConvTraitsType::GroupedConvImplicitGemmTraitsBwdWeight::BsLayout,
typename GroupedConvTraitsType::GroupedConvImplicitGemmTraitsBwdWeight::CLayout,
ConvConfig::TransposeC,
false,
false, // Persistent,
typename GroupedConvTraitsType::AsLayoutBwdWeight,
typename GroupedConvTraitsType::BsLayoutBwdWeight,
typename GroupedConvTraitsType::CLayoutBwdWeight,
GroupedConvTraitsType::FixedGemmParams::TransposeC,
GroupedConvTraitsType::FixedGemmParams::UseStructuredSparsity,
GroupedConvTraitsType::FixedGemmParams::Persistent,
ConvConfig::NumWaveGroups>;
using GemmPipelineProblem = ck_tile::GemmPipelineProblem<
@@ -69,13 +64,14 @@ struct GroupedConvolutionBackwardWeightInvoker
InDataType,
AccDataType,
GemmShape,
typename GroupedConvTraitsType::GroupedConvImplicitGemmTraitsBwdWeight,
typename GroupedConvTraitsType::template GroupedConvImplicitGemmTraitsBwdWeight<
ConvConfig::NumWaveGroups>,
ck_tile::element_wise::PassThrough,
ck_tile::element_wise::PassThrough,
WeiDataType,
true,
VectorSizeA,
VectorSizeB>;
GroupedConvTraitsType::FixedGemmParams::FixedVectorSize,
GroupedConvTraitsType::VectorSizeA,
GroupedConvTraitsType::VectorSizeB>;
using BaseGemmPipeline = typename PipelineTypeTraits<
ConvConfig::Pipeline>::template UniversalGemmPipeline<GemmPipelineProblem>;
@@ -101,21 +97,21 @@ struct GroupedConvolutionBackwardWeightInvoker
constexpr auto scheduler = ConvConfig::Scheduler;
constexpr auto memory_operation = memory_operation_.value;
using UniversalGemmProblem =
ck_tile::UniversalGemmPipelineProblem<OutDataType,
InDataType,
AccDataType,
GemmShape,
GemmUniversalTraits,
scheduler,
has_hot_loop_v,
tail_number_v,
ck_tile::element_wise::PassThrough,
ck_tile::element_wise::PassThrough,
WeiDataType,
true,
VectorSizeA,
VectorSizeB>;
using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem<
OutDataType,
InDataType,
AccDataType,
GemmShape,
GemmUniversalTraits,
scheduler,
has_hot_loop_v,
tail_number_v,
ck_tile::element_wise::PassThrough,
ck_tile::element_wise::PassThrough,
WeiDataType,
GroupedConvTraitsType::FixedGemmParams::FixedVectorSize,
GroupedConvTraitsType::VectorSizeA,
GroupedConvTraitsType::VectorSizeB>;
using GemmPipeline = typename PipelineTypeTraits<
ConvConfig::Pipeline>::template GemmPipeline<UniversalGemmProblem>;
@@ -127,7 +123,7 @@ struct GroupedConvolutionBackwardWeightInvoker
AccDataType,
WeiDataType,
typename GroupedConvTraitsType::ImplicitGemmDsLayout,
ck_tile::tensor_layout::gemm::RowMajor,
typename GroupedConvTraitsType::FixedGemmParams::ELayout,
CDEElementWise,
TilePartitioner::MPerBlock,
TilePartitioner::NPerBlock,
@@ -136,10 +132,10 @@ struct GroupedConvolutionBackwardWeightInvoker
ConvConfig::M_Warp_Tile,
ConvConfig::N_Warp_Tile,
ConvConfig::K_Warp_Tile,
ConvConfig::TransposeC,
GroupedConvTraitsType::FixedGemmParams::TransposeC,
memory_operation,
1,
true,
ConvConfig::NumWaveGroups,
GroupedConvTraitsType::FixedGemmParams::FixedVectorSize,
GroupedConvTraitsType::VectorSizeC>>;
using Kernel = ck_tile::GroupedConvolutionBackwardWeightKernel<GroupedConvTraitsType,
@@ -184,7 +180,7 @@ struct GroupedConvolutionBackwardWeightInvoker
ave_time = ck_tile::launch_kernel_time_mask(
s,
preprocess,
ck_tile::make_kernel<kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
ck_tile::make_kernel<ConvConfig::kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
return ave_time;
};

View File

@@ -23,8 +23,6 @@ struct GroupedConvolutionBackwardWeightTwoStageInvoker
{
using WorkspaceDataType = float;
constexpr int kBlockPerCu = 1;
// Implicit GEMM Traits
using GemmShape = ck_tile::TileGemmShape<
ck_tile::sequence<ConvConfig::M_Tile, ConvConfig::N_Tile, ConvConfig::K_Tile>,
@@ -33,36 +31,34 @@ struct GroupedConvolutionBackwardWeightTwoStageInvoker
ConvConfig::N_Warp_Tile,
ConvConfig::K_Warp_Tile>>;
constexpr ck_tile::index_t VectorSizeA = 4;
constexpr ck_tile::index_t VectorSizeB = 8;
constexpr ck_tile::index_t VectorSizeC = 8;
constexpr auto ConvSpec = ck_tile::ConvolutionSpecialization::Default;
using TilePartitioner =
ck_tile::GemmSpatiallyLocalTilePartitioner<GemmShape,
ConvConfig::TileParitionerGroupNum,
ConvConfig::TileParitionerM01>;
constexpr auto ConvSpec = ck_tile::ConvolutionSpecialization::Default;
using GroupedConvTraitsType = ck_tile::GroupedConvTraits<NDimSpatial,
ConvSpec,
InLayout,
WeiLayout,
DsLayout,
OutLayout,
VectorSizeA,
VectorSizeB,
VectorSizeC>;
ConvConfig::VectorSizeA,
ConvConfig::VectorSizeB,
ConvConfig::VectorSizeC,
ConvConfig::NumGroupsToMerge>;
using TilePartitioner = ck_tile::GemmSpatiallyLocalTilePartitioner<
GemmShape,
GroupedConvTraitsType::FixedGemmParams::TilePartitionerGroupNum,
GroupedConvTraitsType::FixedGemmParams::TilePartitionerM01>;
using GemmUniversalTraits = ck_tile::TileGemmUniversalTraits<
ConvConfig::kPadM,
ConvConfig::kPadN,
ConvConfig::kPadK,
GroupedConvTraitsType::FixedGemmParams::kPadM,
GroupedConvTraitsType::FixedGemmParams::kPadN,
GroupedConvTraitsType::FixedGemmParams::kPadK,
ConvConfig::DoubleSmemBuffer,
typename GroupedConvTraitsType::GroupedConvImplicitGemmTraitsBwdWeight::AsLayout,
typename GroupedConvTraitsType::GroupedConvImplicitGemmTraitsBwdWeight::BsLayout,
typename GroupedConvTraitsType::GroupedConvImplicitGemmTraitsBwdWeight::CLayout,
ConvConfig::TransposeC,
false,
false, // Persistent,
typename GroupedConvTraitsType::AsLayoutBwdWeight,
typename GroupedConvTraitsType::BsLayoutBwdWeight,
typename GroupedConvTraitsType::CLayoutBwdWeight,
GroupedConvTraitsType::FixedGemmParams::TransposeC,
GroupedConvTraitsType::FixedGemmParams::UseStructuredSparsity,
GroupedConvTraitsType::FixedGemmParams::Persistent,
ConvConfig::NumWaveGroups>;
using GemmPipelineProblem = ck_tile::GemmPipelineProblem<
@@ -70,13 +66,14 @@ struct GroupedConvolutionBackwardWeightTwoStageInvoker
InDataType,
AccDataType,
GemmShape,
typename GroupedConvTraitsType::GroupedConvImplicitGemmTraitsBwdWeight,
typename GroupedConvTraitsType::template GroupedConvImplicitGemmTraitsBwdWeight<
ConvConfig::NumWaveGroups>,
ck_tile::element_wise::PassThrough,
ck_tile::element_wise::PassThrough,
WeiDataType,
true,
VectorSizeA,
VectorSizeB>;
GroupedConvTraitsType::FixedGemmParams::FixedVectorSize,
GroupedConvTraitsType::VectorSizeA,
GroupedConvTraitsType::VectorSizeB>;
using BaseGemmPipeline = typename PipelineTypeTraits<
ConvConfig::Pipeline>::template UniversalGemmPipeline<GemmPipelineProblem>;
@@ -102,21 +99,21 @@ struct GroupedConvolutionBackwardWeightTwoStageInvoker
constexpr auto scheduler = ConvConfig::Scheduler;
constexpr auto memory_operation = memory_operation_.value;
using UniversalGemmProblem =
ck_tile::UniversalGemmPipelineProblem<OutDataType,
InDataType,
AccDataType,
GemmShape,
GemmUniversalTraits,
scheduler,
has_hot_loop_v,
tail_number_v,
ck_tile::element_wise::PassThrough,
ck_tile::element_wise::PassThrough,
WeiDataType,
true,
VectorSizeA,
VectorSizeB>;
using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem<
OutDataType,
InDataType,
AccDataType,
GemmShape,
GemmUniversalTraits,
scheduler,
has_hot_loop_v,
tail_number_v,
ck_tile::element_wise::PassThrough,
ck_tile::element_wise::PassThrough,
WeiDataType,
GroupedConvTraitsType::FixedGemmParams::FixedVectorSize,
GroupedConvTraitsType::VectorSizeA,
GroupedConvTraitsType::VectorSizeB>;
using GemmPipeline = typename PipelineTypeTraits<
ConvConfig::Pipeline>::template GemmPipeline<UniversalGemmProblem>;
@@ -128,7 +125,7 @@ struct GroupedConvolutionBackwardWeightTwoStageInvoker
AccDataType,
WorkspaceDataType, // C: Workspace normally Out
typename GroupedConvTraitsType::ImplicitGemmDsLayout,
ck_tile::tensor_layout::gemm::RowMajor,
typename GroupedConvTraitsType::FixedGemmParams::ELayout,
CDEElementWise,
TilePartitioner::MPerBlock,
TilePartitioner::NPerBlock,
@@ -139,8 +136,8 @@ struct GroupedConvolutionBackwardWeightTwoStageInvoker
ConvConfig::K_Warp_Tile,
GemmPipelineProblem::TransposeC,
memory_operation,
1,
true,
ConvConfig::NumWaveGroups,
GroupedConvTraitsType::FixedGemmParams::FixedVectorSize,
GroupedConvTraitsType::VectorSizeC>>;
using Kernel = ck_tile::GroupedConvolutionBackwardWeightKernel<GroupedConvTraitsType,
@@ -235,16 +232,17 @@ struct GroupedConvolutionBackwardWeightTwoStageInvoker
ave_time = ck_tile::launch_kernel_time_mask(
s,
preprocess,
ck_tile::make_kernel<kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs),
ck_tile::make_kernel<kBlockPerCu>(ElementwiseKernel{},
kGridSize,
kBlockSize,
0,
input_size,
ck_tile::make_tuple(shape[1], 1), // Input Stride
ck_tile::make_tuple(shape[1], 1), // Output Stride
input_tensors,
static_cast<WeiDataType*>(c_ptr)));
ck_tile::make_kernel<ConvConfig::kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs),
ck_tile::make_kernel<ConvConfig::kBlockPerCu>(
ElementwiseKernel{},
kGridSize,
kBlockSize,
0,
input_size,
ck_tile::make_tuple(shape[1], 1), // Input Stride
ck_tile::make_tuple(shape[1], 1), // Output Stride
input_tensors,
static_cast<WeiDataType*>(c_ptr)));
return ave_time;
};

View File

@@ -32,8 +32,6 @@ struct GroupedConvolutionForwardInvoker
{
std::cout << "[INVOKER] grouped_conv_fwd called, NDimSpatial=" << NDimSpatial << "\n";
}
constexpr int kBlockPerCu = 1;
// Implicit GEMM Traits
using GemmShape = ck_tile::TileGemmShape<
ck_tile::sequence<ConvConfig::M_Tile, ConvConfig::N_Tile, ConvConfig::K_Tile>,
@@ -42,38 +40,34 @@ struct GroupedConvolutionForwardInvoker
ConvConfig::N_Warp_Tile,
ConvConfig::K_Warp_Tile>>;
constexpr ck_tile::index_t VectorSizeA = 8;
constexpr ck_tile::index_t VectorSizeB = 8;
constexpr ck_tile::index_t VectorSizeC = 8;
constexpr ck_tile::index_t NumGroupsToMerge = 1;
constexpr auto ConvSpec = ck_tile::ConvolutionSpecialization::Default;
using TilePartitioner =
ck_tile::GemmSpatiallyLocalTilePartitioner<GemmShape,
ConvConfig::TileParitionerGroupNum,
ConvConfig::TileParitionerM01>;
constexpr auto ConvSpec = ck_tile::ConvolutionSpecialization::Default;
using GroupedConvTraitsType = ck_tile::GroupedConvTraits<NDimSpatial,
ConvSpec,
InLayout,
WeiLayout,
DsLayout,
OutLayout,
VectorSizeA,
VectorSizeB,
VectorSizeC,
NumGroupsToMerge>;
ConvConfig::VectorSizeA,
ConvConfig::VectorSizeB,
ConvConfig::VectorSizeC,
ConvConfig::NumGroupsToMerge>;
using TilePartitioner = ck_tile::GemmSpatiallyLocalTilePartitioner<
GemmShape,
GroupedConvTraitsType::FixedGemmParams::TilePartitionerGroupNum,
GroupedConvTraitsType::FixedGemmParams::TilePartitionerM01>;
using GemmUniversalTraits = ck_tile::TileGemmUniversalTraits<
ConvConfig::kPadM,
ConvConfig::kPadN,
ConvConfig::kPadK,
GroupedConvTraitsType::FixedGemmParams::kPadM,
GroupedConvTraitsType::FixedGemmParams::kPadN,
GroupedConvTraitsType::FixedGemmParams::kPadK,
ConvConfig::DoubleSmemBuffer,
typename GroupedConvTraitsType::GroupedConvImplicitGemmTraitsFwd::AsLayout,
typename GroupedConvTraitsType::GroupedConvImplicitGemmTraitsFwd::BsLayout,
typename GroupedConvTraitsType::GroupedConvImplicitGemmTraitsFwd::CLayout,
ConvConfig::TransposeC,
false,
false, // Persistent,
typename GroupedConvTraitsType::AsLayoutFwd,
typename GroupedConvTraitsType::BsLayoutFwd,
typename GroupedConvTraitsType::CLayoutFwd,
GroupedConvTraitsType::FixedGemmParams::TransposeC,
GroupedConvTraitsType::FixedGemmParams::UseStructuredSparsity,
GroupedConvTraitsType::FixedGemmParams::Persistent,
ConvConfig::NumWaveGroups>;
using GemmPipelineProblem = ck_tile::GemmPipelineProblem<
@@ -81,13 +75,14 @@ struct GroupedConvolutionForwardInvoker
WeiDataType,
AccDataType,
GemmShape,
typename GroupedConvTraitsType::GroupedConvImplicitGemmTraitsFwd,
typename GroupedConvTraitsType::template GroupedConvImplicitGemmTraitsFwd<
ConvConfig::NumWaveGroups>,
ck_tile::element_wise::PassThrough,
ck_tile::element_wise::PassThrough,
OutDataType,
true,
VectorSizeA,
VectorSizeB>;
GroupedConvTraitsType::FixedGemmParams::FixedVectorSize,
GroupedConvTraitsType::VectorSizeA,
GroupedConvTraitsType::VectorSizeB>;
using BaseGemmPipeline = typename PipelineTypeTraits<
ConvConfig::Pipeline>::template UniversalGemmPipeline<GemmPipelineProblem>;
@@ -116,21 +111,21 @@ struct GroupedConvolutionForwardInvoker
constexpr auto scheduler = ConvConfig::Scheduler;
constexpr auto memory_operation = memory_operation_.value;
using UniversalGemmProblem =
ck_tile::UniversalGemmPipelineProblem<InDataType,
WeiDataType,
AccDataType,
GemmShape,
GemmUniversalTraits,
scheduler,
has_hot_loop_v,
tail_number_v,
ck_tile::element_wise::PassThrough,
ck_tile::element_wise::PassThrough,
OutDataType,
true,
VectorSizeA,
VectorSizeB>;
using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem<
InDataType,
WeiDataType,
AccDataType,
GemmShape,
GemmUniversalTraits,
scheduler,
has_hot_loop_v,
tail_number_v,
ck_tile::element_wise::PassThrough,
ck_tile::element_wise::PassThrough,
OutDataType,
GroupedConvTraitsType::FixedGemmParams::FixedVectorSize,
GroupedConvTraitsType::VectorSizeA,
GroupedConvTraitsType::VectorSizeB>;
using GemmPipeline = typename PipelineTypeTraits<
ConvConfig::Pipeline>::template GemmPipeline<UniversalGemmProblem>;
@@ -142,7 +137,7 @@ struct GroupedConvolutionForwardInvoker
AccDataType,
OutDataType,
typename GroupedConvTraitsType::ImplicitGemmDsLayout,
ck_tile::tensor_layout::gemm::RowMajor,
typename GroupedConvTraitsType::FixedGemmParams::ELayout,
CDElementWise,
TilePartitioner::MPerBlock,
TilePartitioner::NPerBlock,
@@ -151,10 +146,10 @@ struct GroupedConvolutionForwardInvoker
ConvConfig::M_Warp_Tile,
ConvConfig::N_Warp_Tile,
ConvConfig::K_Warp_Tile,
ConvConfig::TransposeC,
GroupedConvTraitsType::FixedGemmParams::TransposeC,
memory_operation,
1,
true,
ConvConfig::NumWaveGroups,
GroupedConvTraitsType::FixedGemmParams::FixedVectorSize,
GroupedConvTraitsType::VectorSizeC>>;
using Kernel = ck_tile::GroupedConvolutionForwardKernel<GroupedConvTraitsType,
@@ -185,8 +180,9 @@ struct GroupedConvolutionForwardInvoker
<< ", Vector size C: " << ConvEpilogue::GetVectorSizeC() << std::endl;
}
ave_time = ck_tile::launch_kernel(
s, ck_tile::make_kernel<kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
ave_time = ck_tile::launch_kernel(s,
ck_tile::make_kernel<ConvConfig::kBlockPerCu>(
Kernel{}, grids, blocks, 0, kargs));
return ave_time;
};

View File

@@ -25,7 +25,6 @@ struct GroupedConvolutionForwardInvoker
{
std::cout << "[INVOKER] grouped_conv_fwd called, NDimSpatial=" << NDimSpatial << "\n";
}
constexpr int kBlockPerCu = 1;
// Implicit GEMM Traits
using GemmShape = ck_tile::TileGemmShape<
@@ -35,27 +34,18 @@ struct GroupedConvolutionForwardInvoker
ConvConfig::N_Warp_Tile,
ConvConfig::K_Warp_Tile>>;
constexpr ck_tile::index_t VectorSizeA = 8;
constexpr ck_tile::index_t VectorSizeB = 8;
constexpr ck_tile::index_t VectorSizeC = 8;
constexpr auto ConvSpec = ck_tile::ConvolutionSpecialization::Default;
using TilePartitioner =
ck_tile::GemmSpatiallyLocalTilePartitioner<GemmShape,
ConvConfig::TileParitionerGroupNum,
ConvConfig::TileParitionerM01>;
using GroupedConvTraitsTypeDefault = ck_tile::GroupedConvTraits<NDimSpatial,
ConvSpec,
InLayout,
WeiLayout,
DsLayout,
OutLayout,
VectorSizeA,
VectorSizeB,
VectorSizeC,
1, /*NumGroupsToMerge*/
false /*EnableSplitImage*/>;
using GroupedConvTraitsTypeDefault =
ck_tile::GroupedConvTraits<NDimSpatial,
ConvSpec,
InLayout,
WeiLayout,
DsLayout,
OutLayout,
ConvConfig::VectorSizeA,
ConvConfig::VectorSizeB,
ConvConfig::VectorSizeC,
ConvConfig::NumGroupsToMerge>;
using GroupedConvTraitsTypeLargeTensor =
ck_tile::GroupedConvTraits<NDimSpatial,
@@ -64,23 +54,28 @@ struct GroupedConvolutionForwardInvoker
WeiLayout,
DsLayout,
OutLayout,
VectorSizeA,
VectorSizeB,
VectorSizeC,
1, /*NumGroupsToMerge*/
ConvConfig::VectorSizeA,
ConvConfig::VectorSizeB,
ConvConfig::VectorSizeC,
ConvConfig::NumGroupsToMerge,
true /*EnableSplitImage*/>;
using TilePartitioner = ck_tile::GemmSpatiallyLocalTilePartitioner<
GemmShape,
GroupedConvTraitsTypeDefault::FixedGemmParams::TilePartitionerGroupNum,
GroupedConvTraitsTypeDefault::FixedGemmParams::TilePartitionerM01>;
using GemmUniversalTraits = ck_tile::TileGemmUniversalTraits<
ConvConfig::kPadM,
ConvConfig::kPadN,
ConvConfig::kPadK,
GroupedConvTraitsTypeDefault::FixedGemmParams::kPadM,
GroupedConvTraitsTypeDefault::FixedGemmParams::kPadN,
GroupedConvTraitsTypeDefault::FixedGemmParams::kPadK,
ConvConfig::DoubleSmemBuffer,
typename GroupedConvTraitsTypeDefault::GroupedConvImplicitGemmTraitsFwd::AsLayout,
typename GroupedConvTraitsTypeDefault::GroupedConvImplicitGemmTraitsFwd::BsLayout,
typename GroupedConvTraitsTypeDefault::GroupedConvImplicitGemmTraitsFwd::CLayout,
ConvConfig::TransposeC,
false,
false, // Persistent,
typename GroupedConvTraitsTypeDefault::AsLayoutFwd,
typename GroupedConvTraitsTypeDefault::BsLayoutFwd,
typename GroupedConvTraitsTypeDefault::CLayoutFwd,
GroupedConvTraitsTypeDefault::FixedGemmParams::TransposeC,
GroupedConvTraitsTypeDefault::FixedGemmParams::UseStructuredSparsity,
GroupedConvTraitsTypeDefault::FixedGemmParams::Persistent,
ConvConfig::NumWaveGroups>;
using GemmPipelineProblem = ck_tile::GemmPipelineProblem<
@@ -88,13 +83,14 @@ struct GroupedConvolutionForwardInvoker
WeiDataType,
AccDataType,
GemmShape,
typename GroupedConvTraitsTypeDefault::GroupedConvImplicitGemmTraitsFwd,
typename GroupedConvTraitsTypeDefault::template GroupedConvImplicitGemmTraitsFwd<
ConvConfig::NumWaveGroups>,
ck_tile::element_wise::PassThrough,
ck_tile::element_wise::PassThrough,
OutDataType,
true,
VectorSizeA,
VectorSizeB>;
GroupedConvTraitsTypeDefault::FixedGemmParams::FixedVectorSize,
GroupedConvTraitsTypeDefault::VectorSizeA,
GroupedConvTraitsTypeDefault::VectorSizeB>;
using BaseGemmPipeline = typename PipelineTypeTraits<
ConvConfig::Pipeline>::template UniversalGemmPipeline<GemmPipelineProblem>;
@@ -116,9 +112,9 @@ struct GroupedConvolutionForwardInvoker
using TransformType =
ck_tile::TransformConvFwdToGemm<NDimSpatial,
ck_tile::ConvolutionSpecialization::Default,
VectorSizeA,
VectorSizeB,
VectorSizeC,
GroupedConvTraitsTypeDefault::VectorSizeA,
GroupedConvTraitsTypeDefault::VectorSizeB,
GroupedConvTraitsTypeDefault::VectorSizeC,
1, // NumGroupsToMerge
false, // SplitN
InDataType,
@@ -264,21 +260,21 @@ struct GroupedConvolutionForwardInvoker
GroupedConvTraitsTypeLargeTensor,
GroupedConvTraitsTypeDefault>;
using UniversalGemmProblem =
ck_tile::UniversalGemmPipelineProblem<InDataType,
WeiDataType,
AccDataType,
GemmShape,
GemmUniversalTraits,
scheduler,
has_hot_loop_v,
tail_number_v,
ck_tile::element_wise::PassThrough,
ck_tile::element_wise::PassThrough,
OutDataType,
true,
VectorSizeA,
VectorSizeB>;
using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem<
InDataType,
WeiDataType,
AccDataType,
GemmShape,
GemmUniversalTraits,
scheduler,
has_hot_loop_v,
tail_number_v,
ck_tile::element_wise::PassThrough,
ck_tile::element_wise::PassThrough,
OutDataType,
GroupedConvTraitsType::FixedGemmParams::FixedVectorSize,
GroupedConvTraitsType::VectorSizeA,
GroupedConvTraitsType::VectorSizeB>;
using GemmPipeline = typename PipelineTypeTraits<
ConvConfig::Pipeline>::template GemmPipeline<UniversalGemmProblem>;
@@ -290,7 +286,7 @@ struct GroupedConvolutionForwardInvoker
AccDataType,
OutDataType,
typename GroupedConvTraitsType::ImplicitGemmDsLayout,
ck_tile::tensor_layout::gemm::RowMajor,
typename GroupedConvTraitsType::FixedGemmParams::ELayout,
CDEElementWise,
TilePartitioner::MPerBlock,
TilePartitioner::NPerBlock,
@@ -299,10 +295,10 @@ struct GroupedConvolutionForwardInvoker
ConvConfig::M_Warp_Tile,
ConvConfig::N_Warp_Tile,
ConvConfig::K_Warp_Tile,
ConvConfig::TransposeC,
GroupedConvTraitsType::FixedGemmParams::TransposeC,
memory_operation,
1,
true,
ConvConfig::NumWaveGroups,
GroupedConvTraitsType::FixedGemmParams::FixedVectorSize,
GroupedConvTraitsType::VectorSizeC>>;
// Use split-image kernel if layout supports it, otherwise use regular kernel
@@ -368,7 +364,8 @@ struct GroupedConvolutionForwardInvoker
}
ave_time = ck_tile::launch_kernel(
s, ck_tile::make_kernel<kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
s,
ck_tile::make_kernel<ConvConfig::kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
return ave_time;
};

View File

@@ -78,66 +78,4 @@ struct UnsupportedEnumValue
{
};
// Helper functions to convert enums to strings
// Returns the display name for a convolution direction.
// Any value other than the three directions tested below yields "Unknown".
constexpr std::string_view ConvDirectionToString(ConvDirection dir)
{
    if(dir == ConvDirection::FORWARD)
        return "Forward";
    if(dir == ConvDirection::BACKWARD_DATA)
        return "Backward Data";
    if(dir == ConvDirection::BACKWARD_WEIGHT)
        return "Backward Weight";
    return "Unknown";
}
// Returns the display name for a tensor data type.
// Any value other than the six types tested below yields "Unknown".
constexpr std::string_view DataTypeToString(DataType dt)
{
    if(dt == DataType::FP16)
        return "FP16";
    if(dt == DataType::FP32)
        return "FP32";
    if(dt == DataType::BF16)
        return "BF16";
    if(dt == DataType::FP8)
        return "FP8";
    if(dt == DataType::I8)
        return "I8";
    if(dt == DataType::U8)
        return "U8";
    return "Unknown";
}
// Returns the enumerator spelling of a 1D grouped-convolution layout.
// Any value other than the four layouts tested below yields "Unknown".
constexpr std::string_view LayoutToString(GroupConvLayout1D layout)
{
    if(layout == GroupConvLayout1D::GNWC_GKXC_GNWK)
        return "GNWC_GKXC_GNWK";
    if(layout == GroupConvLayout1D::NWGC_GKXC_NWGK)
        return "NWGC_GKXC_NWGK";
    if(layout == GroupConvLayout1D::NGCW_GKXC_NGKW)
        return "NGCW_GKXC_NGKW";
    if(layout == GroupConvLayout1D::NGCW_GKCX_NGKW)
        return "NGCW_GKCX_NGKW";
    return "Unknown";
}
// Returns the enumerator spelling of a 2D grouped-convolution layout.
// Any value other than the four layouts tested below yields "Unknown".
constexpr std::string_view LayoutToString(GroupConvLayout2D layout)
{
    if(layout == GroupConvLayout2D::GNHWC_GKYXC_GNHWK)
        return "GNHWC_GKYXC_GNHWK";
    if(layout == GroupConvLayout2D::NHWGC_GKYXC_NHWGK)
        return "NHWGC_GKYXC_NHWGK";
    if(layout == GroupConvLayout2D::NGCHW_GKYXC_NGKHW)
        return "NGCHW_GKYXC_NGKHW";
    if(layout == GroupConvLayout2D::NGCHW_GKCYX_NGKHW)
        return "NGCHW_GKCYX_NGKHW";
    return "Unknown";
}
// Returns the enumerator spelling of a 3D grouped-convolution layout.
// Any value other than the four layouts tested below yields "Unknown".
constexpr std::string_view LayoutToString(GroupConvLayout3D layout)
{
    if(layout == GroupConvLayout3D::GNDHWC_GKZYXC_GNDHWK)
        return "GNDHWC_GKZYXC_GNDHWK";
    if(layout == GroupConvLayout3D::NDHWGC_GKZYXC_NDHWGK)
        return "NDHWGC_GKZYXC_NDHWGK";
    if(layout == GroupConvLayout3D::NGCDHW_GKZYXC_NGKDHW)
        return "NGCDHW_GKZYXC_NGKDHW";
    if(layout == GroupConvLayout3D::NGCDHW_GKCZYX_NGKDHW)
        return "NGCDHW_GKCZYX_NGKDHW";
    return "Unknown";
}
} // namespace ck_tile::builder

View File

@@ -38,8 +38,8 @@ concept GridwiseXdlGemmDescriptor = requires(T t) {
// Concept for parameter that describe block GEMM problem.
// Satisfied by a type T whose objects expose `pipeline_version` and `scheduler` members
// that convert to the builder's pipeline-version and pipeline-scheduler enums.
// NOTE(review): each member is constrained twice below, once against the BlockGemmPipeline*
// enum names and once against the Pipeline* names. This looks like an unmarked before/after
// pair from a diff -- confirm which pair is meant to remain.
template <typename T>
concept BlockGemmDescriptor = requires(T t) {
{ t.pipeline_version } -> std::convertible_to<BlockGemmPipelineVersion>;
{ t.scheduler } -> std::convertible_to<BlockGemmPipelineScheduler>;
{ t.pipeline_version } -> std::convertible_to<PipelineVersion>;
{ t.scheduler } -> std::convertible_to<PipelineScheduler>;
};
// Concept for parameters that describe a gridwise WMMA GEMM problem.
@@ -50,7 +50,7 @@ concept GridwiseWmmaGemmDescriptor = requires(T t) {
{ t.n_per_wmma } -> std::convertible_to<size_t>;
{ t.m_wmma_per_wave } -> std::convertible_to<size_t>;
{ t.n_wmma_per_wave } -> std::convertible_to<size_t>;
{ t.pipeline_version } -> std::convertible_to<GridwiseGemmPipelineVersion>;
{ t.pipeline_version } -> std::convertible_to<PipelineVersion>;
};
// Concept for vectorized data transfer for convolution input tensors.
@@ -154,8 +154,8 @@ concept SpecifiesSourceAccessOrder = requires(T t) {
// Concept to check if struct specifies block GEMM.
// Satisfied when T has a static `block_gemm` member whose `pipeline_version` and
// `scheduler` fields convert to the builder's pipeline enums.
// NOTE(review): each field is constrained twice below (BlockGemmPipeline* names and
// Pipeline* names). This looks like an unmarked before/after pair from a diff -- confirm
// which pair is meant to remain.
template <typename T>
concept SpecifiesBlockGemm = requires {
{ T::block_gemm.pipeline_version } -> std::convertible_to<BlockGemmPipelineVersion>;
{ T::block_gemm.scheduler } -> std::convertible_to<BlockGemmPipelineScheduler>;
{ T::block_gemm.pipeline_version } -> std::convertible_to<PipelineVersion>;
{ T::block_gemm.scheduler } -> std::convertible_to<PipelineScheduler>;
};
template <typename T>
@@ -180,7 +180,90 @@ concept SpecifiesNumGroupsToMerge = requires {
// Satisfied when T has a static `loop_scheduler` member convertible to the scheduler enum.
// NOTE(review): the member is constrained against both `LoopScheduler` and
// `PipelineScheduler` below. This looks like an unmarked before/after pair from a diff --
// confirm which one is meant to remain.
template <typename T>
concept SpecifiesLoopScheduler = requires {
{ T::loop_scheduler } -> std::convertible_to<LoopScheduler>;
{ T::loop_scheduler } -> std::convertible_to<PipelineScheduler>;
};
/******************************************** */
/* DL-specific descriptors and requirements */
/******************************************** */
// Concept for DL thread configuration
// Satisfied by a type T whose objects expose the five members listed below, each
// convertible to size_t. Only member presence and convertibility are checked; no
// value ranges are enforced here.
template <typename T>
concept DlThreadConfigDescriptor = requires(T t) {
{ t.k0_per_block } -> std::convertible_to<size_t>;
{ t.k1 } -> std::convertible_to<size_t>;
{ t.m1_per_thread } -> std::convertible_to<size_t>;
{ t.n1_per_thread } -> std::convertible_to<size_t>;
{ t.k_per_thread } -> std::convertible_to<size_t>;
};
// Concept for DL thread cluster
// Satisfied by a type T whose objects expose `m1_xs` and `n1_xs`, each convertible to a
// two-element std::array<size_t, 2>.
template <typename T>
concept DlThreadClusterDescriptor = requires(T t) {
{ t.m1_xs } -> std::convertible_to<std::array<size_t, 2>>;
{ t.n1_xs } -> std::convertible_to<std::array<size_t, 2>>;
};
// Concept for DL block transfer K0_M0_M1_K1 format
// Satisfied by a type T whose objects expose the seven members listed below, each
// convertible to a four-element std::array<size_t, 4>.
// The requirement list is member-for-member the same as DlBlockTransferK0N0N1K1Descriptor;
// the two concepts differ only in name.
template <typename T>
concept DlBlockTransferK0M0M1K1Descriptor = requires(T t) {
{ t.thread_slice_lengths } -> std::convertible_to<std::array<size_t, 4>>;
{ t.thread_cluster_lengths } -> std::convertible_to<std::array<size_t, 4>>;
{ t.thread_cluster_arrange_order } -> std::convertible_to<std::array<size_t, 4>>;
{ t.src_access_order } -> std::convertible_to<std::array<size_t, 4>>;
{ t.src_vector_tensor_lengths } -> std::convertible_to<std::array<size_t, 4>>;
{ t.src_vector_tensor_contiguous_dim_order } -> std::convertible_to<std::array<size_t, 4>>;
{ t.dst_vector_tensor_lengths } -> std::convertible_to<std::array<size_t, 4>>;
};
// Concept for DL block transfer K0_N0_N1_K1 format
// Satisfied by a type T whose objects expose the seven members listed below, each
// convertible to a four-element std::array<size_t, 4>.
// The requirement list is member-for-member the same as DlBlockTransferK0M0M1K1Descriptor;
// the two concepts differ only in name.
template <typename T>
concept DlBlockTransferK0N0N1K1Descriptor = requires(T t) {
{ t.thread_slice_lengths } -> std::convertible_to<std::array<size_t, 4>>;
{ t.thread_cluster_lengths } -> std::convertible_to<std::array<size_t, 4>>;
{ t.thread_cluster_arrange_order } -> std::convertible_to<std::array<size_t, 4>>;
{ t.src_access_order } -> std::convertible_to<std::array<size_t, 4>>;
{ t.src_vector_tensor_lengths } -> std::convertible_to<std::array<size_t, 4>>;
{ t.src_vector_tensor_contiguous_dim_order } -> std::convertible_to<std::array<size_t, 4>>;
{ t.dst_vector_tensor_lengths } -> std::convertible_to<std::array<size_t, 4>>;
};
// Concept for DL C thread transfer
// Satisfied by a type T whose objects expose a six-element access order
// (`src_dst_access_order`) plus two size_t-convertible scalars
// (`src_dst_vector_dim`, `dst_scalar_per_vector`).
template <typename T>
concept DlCThreadTransferDescriptor = requires(T t) {
{ t.src_dst_access_order } -> std::convertible_to<std::array<size_t, 6>>;
{ t.src_dst_vector_dim } -> std::convertible_to<size_t>;
{ t.dst_scalar_per_vector } -> std::convertible_to<size_t>;
};
// Concept to check if algorithm specifies DL thread config
// Satisfied when T has a static member `dl_thread_config` whose type models
// DlThreadConfigDescriptor.
template <typename T>
concept SpecifiesDlThreadConfig = requires {
{ T::dl_thread_config } -> DlThreadConfigDescriptor;
};
// Concept to check if algorithm specifies DL thread cluster
// Satisfied when T has a static member `dl_thread_cluster` whose type models
// DlThreadClusterDescriptor.
template <typename T>
concept SpecifiesDlThreadCluster = requires {
{ T::dl_thread_cluster } -> DlThreadClusterDescriptor;
};
// Concept to check if algorithm specifies DL A block transfer
// Satisfied when T has a static member `dl_block_transfer_a` whose type models
// DlBlockTransferK0M0M1K1Descriptor.
template <typename T>
concept SpecifiesDlBlockTransferA = requires {
{ T::dl_block_transfer_a } -> DlBlockTransferK0M0M1K1Descriptor;
};
// Concept to check if algorithm specifies DL B block transfer
// Satisfied when T has a static member `dl_block_transfer_b` whose type models
// DlBlockTransferK0N0N1K1Descriptor.
template <typename T>
concept SpecifiesDlBlockTransferB = requires {
{ T::dl_block_transfer_b } -> DlBlockTransferK0N0N1K1Descriptor;
};
// Concept to check if algorithm specifies DL C thread transfer
// Satisfied when T has a static member `dl_c_thread_transfer` whose type models
// DlCThreadTransferDescriptor.
template <typename T>
concept SpecifiesDlCThreadTransfer = requires {
{ T::dl_c_thread_transfer } -> DlCThreadTransferDescriptor;
};
} // namespace ck_tile::builder

View File

@@ -36,9 +36,21 @@
#pragma once
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp"
// WORKAROUND: Macro namespace collision in upstream CK device operation headers.
// device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp (line 41) and
// device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp (line 51) both define
// GridwiseGemmTemplateParameters macro without #undef, causing redefinition errors.
// Use pragma push/pop to isolate the Large_Tensor header's macro scope.
#pragma push_macro("GridwiseGemmTemplateParameters")
#ifdef GridwiseGemmTemplateParameters
#undef GridwiseGemmTemplateParameters
#endif
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp"
#pragma pop_macro("GridwiseGemmTemplateParameters")
#include "ck_tile/builder/conv_signature_concepts.hpp"
#include "ck_tile/builder/conv_algorithm_concepts.hpp"
#include "ck_tile/builder/conv_algorithm_limits.hpp"
@@ -297,42 +309,42 @@ constexpr BlockGemmSpec SetBlockGemm()
ck::BlockGemmPipelineScheduler scheduler;
ck::BlockGemmPipelineVersion version;
if constexpr(BG.scheduler == BlockGemmPipelineScheduler::INTRAWAVE)
if constexpr(BG.scheduler == PipelineScheduler::INTRAWAVE)
{
scheduler = ck::BlockGemmPipelineScheduler::Intrawave;
}
else if constexpr(BG.scheduler == BlockGemmPipelineScheduler::INTERWAVE)
else if constexpr(BG.scheduler == PipelineScheduler::INTERWAVE)
{
scheduler = ck::BlockGemmPipelineScheduler::Interwave;
}
else
{
static_assert(false, "Unknown BlockGemmPipelineScheduler");
static_assert(false, "Unknown PipelineScheduler");
}
if constexpr(BG.pipeline_version == BlockGemmPipelineVersion::V1)
if constexpr(BG.pipeline_version == PipelineVersion::V1)
{
version = ck::BlockGemmPipelineVersion::v1;
}
else if constexpr(BG.pipeline_version == BlockGemmPipelineVersion::V2)
else if constexpr(BG.pipeline_version == PipelineVersion::V2)
{
version = ck::BlockGemmPipelineVersion::v2;
}
else if constexpr(BG.pipeline_version == BlockGemmPipelineVersion::V3)
else if constexpr(BG.pipeline_version == PipelineVersion::V3)
{
version = ck::BlockGemmPipelineVersion::v3;
}
else if constexpr(BG.pipeline_version == BlockGemmPipelineVersion::V4)
else if constexpr(BG.pipeline_version == PipelineVersion::V4)
{
version = ck::BlockGemmPipelineVersion::v4;
}
else if constexpr(BG.pipeline_version == BlockGemmPipelineVersion::V5)
else if constexpr(BG.pipeline_version == PipelineVersion::V5)
{
version = ck::BlockGemmPipelineVersion::v5;
}
else
{
static_assert(false, "Unknown BlockGemmPipelineVersion");
static_assert(false, "Unknown PipelineVersion");
}
return BlockGemmSpec{.pipeline_version = version, .scheduler = scheduler};
@@ -442,17 +454,17 @@ consteval ck::LoopScheduler SetLoopScheduler()
{
constexpr auto loop_scheduler = ALGORITHM.loop_scheduler;
if constexpr(loop_scheduler == LoopScheduler::DEFAULT)
if constexpr(loop_scheduler == PipelineScheduler::DEFAULT)
{
return ck::LoopScheduler::Default;
}
else if constexpr(loop_scheduler == LoopScheduler::INTERWAVE)
else if constexpr(loop_scheduler == PipelineScheduler::INTERWAVE)
{
return ck::LoopScheduler::Interwave;
}
else
{
static_assert(false, "Unknown LoopScheduler");
static_assert(false, "Unknown PipelineScheduler");
}
}
@@ -460,29 +472,29 @@ template <ConvAlgorithmDescriptor auto ALGORITHM>
consteval ck::PipelineVersion SetGridwiseGemmPipelineVersion()
{
constexpr auto pipeline_version = ALGORITHM.gridwise_gemm.pipeline_version;
if constexpr(pipeline_version == GridwiseGemmPipelineVersion::V1)
if constexpr(pipeline_version == PipelineVersion::V1)
{
return ck::PipelineVersion::v1;
}
else if constexpr(pipeline_version == GridwiseGemmPipelineVersion::V2)
else if constexpr(pipeline_version == PipelineVersion::V2)
{
return ck::PipelineVersion::v2;
}
else if constexpr(pipeline_version == GridwiseGemmPipelineVersion::V3)
else if constexpr(pipeline_version == PipelineVersion::V3)
{
static_assert(false, "V3 is used only for stream-K.");
}
else if constexpr(pipeline_version == GridwiseGemmPipelineVersion::V4)
else if constexpr(pipeline_version == PipelineVersion::V4)
{
return ck::PipelineVersion::v4;
}
else if constexpr(pipeline_version == GridwiseGemmPipelineVersion::WEIGHT_ONLY)
else if constexpr(pipeline_version == PipelineVersion::WEIGHT_ONLY)
{
return ck::PipelineVersion::weight_only;
}
else
{
static_assert(false, "Unknown GridwiseGemmPipelineVersion");
static_assert(false, "Unknown PipelineVersion");
}
}
@@ -566,29 +578,29 @@ consteval ck::BlockGemmPipelineVersion SetBlockGemmPipelineVersion()
{
constexpr auto version = ALGORITHM.pipeline_version;
if constexpr(version == BlockGemmPipelineVersion::V1)
if constexpr(version == PipelineVersion::V1)
{
return ck::BlockGemmPipelineVersion::v1;
}
else if constexpr(version == BlockGemmPipelineVersion::V2)
else if constexpr(version == PipelineVersion::V2)
{
return ck::BlockGemmPipelineVersion::v2;
}
else if constexpr(version == BlockGemmPipelineVersion::V3)
else if constexpr(version == PipelineVersion::V3)
{
return ck::BlockGemmPipelineVersion::v3;
}
else if constexpr(version == BlockGemmPipelineVersion::V4)
else if constexpr(version == PipelineVersion::V4)
{
return ck::BlockGemmPipelineVersion::v4;
}
else if constexpr(version == BlockGemmPipelineVersion::V5)
else if constexpr(version == PipelineVersion::V5)
{
return ck::BlockGemmPipelineVersion::v5;
}
else
{
static_assert(false, "Unknown BlockGemmPipelineVersion");
static_assert(false, "Unknown PipelineVersion");
}
}
@@ -990,4 +1002,263 @@ struct ConvFactory<SIGNATURE, ALGORITHM, VERSION>
GRIDWISE_GEMM_PIPELINE_VERSION>;
};
// Factory specialization for DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK instance
// of a grouped forward convolution kernel using Direct Load (DL) approach.
// NOTE(review): nothing in this block establishes that "DL" stands for "Direct Load";
// confirm the expansion against the device-op header before relying on it.
//
// Selected when SIGNATURE is a forward convolution whose device operation is
// DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK. The specialization translates the fields
// of SIGNATURE and ALGORITHM into the template arguments of that device op and exposes the
// result as `Instance`. VERSION takes part in the specialization but is not referenced in
// the body.
template <ConvSignatureDescriptor auto SIGNATURE,
ConvAlgorithmDescriptor auto ALGORITHM,
StringLiteral VERSION>
requires ConvDirectionIsForward<SIGNATURE> &&
ConvDeviceOpIs_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK<SIGNATURE>
struct ConvFactory<SIGNATURE, ALGORITHM, VERSION>
{
// Signature-derived building blocks: spatial rank, tensor layouts, data types and
// elementwise operators.
static constexpr size_t SPATIAL_DIM = SIGNATURE.spatial_dim;
using Layouts = decltype(factory_internal::GetTensorLayout<SIGNATURE.layout,
SPATIAL_DIM,
ConvDirection::FORWARD>());
using Types = factory_internal::ConvTensorTypes<SIGNATURE.data_type>;
using Ops = factory_internal::ElementwiseOps<SIGNATURE.elementwise_operation>;
using AlgorithmType = decltype(ALGORITHM);
// Fail early, with a readable message, when the algorithm descriptor lacks any field
// that is read further down.
static_assert(SpecifiesThreadBlock<AlgorithmType>,
"The convolution algorithm descriptor must specify thread block info.");
static_assert(SpecifiesFwdConcSpecialization<AlgorithmType>,
"The convolution algorithm descriptor must specify forward convolution "
"specialization.");
static_assert(SpecifiesGemmSpecialization<AlgorithmType>,
"The convolution algorithm descriptor must specify gemm specialization.");
static_assert(SpecifiesDlThreadConfig<AlgorithmType>,
"DL algorithm must specify thread config.");
static_assert(SpecifiesDlThreadCluster<AlgorithmType>,
"DL algorithm must specify thread cluster.");
static_assert(SpecifiesDlBlockTransferA<AlgorithmType>,
"DL algorithm must specify A block transfer.");
static_assert(SpecifiesDlBlockTransferB<AlgorithmType>,
"DL algorithm must specify B block transfer.");
static_assert(SpecifiesDlCThreadTransfer<AlgorithmType>,
"DL algorithm must specify C thread transfer.");
// Builder-level specializations and thread-block info converted by factory_internal helpers.
static constexpr auto FWD_CONV_SPECIALIZATION =
factory_internal::SetFwdConvSpecialization<ALGORITHM>();
static constexpr auto GEMM_SPECIALIZATION =
factory_internal::SetGemmSpecialization<ALGORITHM>();
static constexpr auto BLOCK = factory_internal::SetThreadBlockInfo<ALGORITHM>();
// DL-specific parameters from algorithm descriptor
static constexpr auto DL_THREAD_CFG = ALGORITHM.dl_thread_config;
static constexpr ck::index_t K0PerBlock = DL_THREAD_CFG.k0_per_block;
static constexpr ck::index_t K1 = DL_THREAD_CFG.k1;
static constexpr ck::index_t M1PerThread = DL_THREAD_CFG.m1_per_thread;
static constexpr ck::index_t N1PerThread = DL_THREAD_CFG.n1_per_thread;
static constexpr ck::index_t KPerThread = DL_THREAD_CFG.k_per_thread;
// Thread cluster from descriptor
static constexpr auto DL_CLUSTER = ALGORITHM.dl_thread_cluster;
using M1N1ThreadClusterM1Xs = to_sequence_v<DL_CLUSTER.m1_xs>;
using M1N1ThreadClusterN1Xs = to_sequence_v<DL_CLUSTER.n1_xs>;
// A Block Transfer from descriptor - K0_M0_M1_K1 tensor format
// Each descriptor array is turned into a type via to_sequence_v so it can be passed as a
// template argument of the device op.
static constexpr auto DL_A_TRANSFER = ALGORITHM.dl_block_transfer_a;
using ABlockTransferThreadSliceLengths_K0_M0_M1_K1 =
to_sequence_v<DL_A_TRANSFER.thread_slice_lengths>;
using ABlockTransferThreadClusterLengths_K0_M0_M1_K1 =
to_sequence_v<DL_A_TRANSFER.thread_cluster_lengths>;
using ABlockTransferThreadClusterArrangeOrder =
to_sequence_v<DL_A_TRANSFER.thread_cluster_arrange_order>;
using ABlockTransferSrcAccessOrder = to_sequence_v<DL_A_TRANSFER.src_access_order>;
using ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1 =
to_sequence_v<DL_A_TRANSFER.src_vector_tensor_lengths>;
using ABlockTransferSrcVectorTensorContiguousDimOrder =
to_sequence_v<DL_A_TRANSFER.src_vector_tensor_contiguous_dim_order>;
using ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1 =
to_sequence_v<DL_A_TRANSFER.dst_vector_tensor_lengths>;
// B Block Transfer from descriptor - K0_N0_N1_K1 tensor format
static constexpr auto DL_B_TRANSFER = ALGORITHM.dl_block_transfer_b;
using BBlockTransferThreadSliceLengths_K0_N0_N1_K1 =
to_sequence_v<DL_B_TRANSFER.thread_slice_lengths>;
using BBlockTransferThreadClusterLengths_K0_N0_N1_K1 =
to_sequence_v<DL_B_TRANSFER.thread_cluster_lengths>;
using BBlockTransferThreadClusterArrangeOrder =
to_sequence_v<DL_B_TRANSFER.thread_cluster_arrange_order>;
using BBlockTransferSrcAccessOrder = to_sequence_v<DL_B_TRANSFER.src_access_order>;
using BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1 =
to_sequence_v<DL_B_TRANSFER.src_vector_tensor_lengths>;
using BBlockTransferSrcVectorTensorContiguousDimOrder =
to_sequence_v<DL_B_TRANSFER.src_vector_tensor_contiguous_dim_order>;
using BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1 =
to_sequence_v<DL_B_TRANSFER.dst_vector_tensor_lengths>;
// C Thread Transfer from descriptor
static constexpr auto DL_C_TRANSFER = ALGORITHM.dl_c_thread_transfer;
using CThreadTransferSrcDstAccessOrder = to_sequence_v<DL_C_TRANSFER.src_dst_access_order>;
static constexpr ck::index_t CThreadTransferSrcDstVectorDim = DL_C_TRANSFER.src_dst_vector_dim;
static constexpr ck::index_t CThreadTransferDstScalarPerVector =
DL_C_TRANSFER.dst_scalar_per_vector;
// The DL forward convolution kernel class instance
// Template arguments are positional, so the order below has to match the parameter list of
// the device op declaration exactly.
using Instance = ck::tensor_operation::device::DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK<
SPATIAL_DIM,
typename Types::ADataType,
typename Types::BDataType,
typename Types::DsDataTypes,
typename Types::EDataType,
typename Types::AccDataType,
typename Layouts::ALayout,
typename Layouts::BLayout,
typename Layouts::DsLayout,
typename Layouts::ELayout,
typename Ops::AElementwiseOp,
typename Ops::BElementwiseOp,
typename Ops::CDEElementwiseOp,
FWD_CONV_SPECIALIZATION,
GEMM_SPECIALIZATION,
BLOCK.block_size,
BLOCK.per_block.m,
BLOCK.per_block.n,
K0PerBlock,
K1,
M1PerThread,
N1PerThread,
KPerThread,
M1N1ThreadClusterM1Xs,
M1N1ThreadClusterN1Xs,
ABlockTransferThreadSliceLengths_K0_M0_M1_K1,
ABlockTransferThreadClusterLengths_K0_M0_M1_K1,
ABlockTransferThreadClusterArrangeOrder,
ABlockTransferSrcAccessOrder,
ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1,
ABlockTransferSrcVectorTensorContiguousDimOrder,
ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1,
BBlockTransferThreadSliceLengths_K0_N0_N1_K1,
BBlockTransferThreadClusterLengths_K0_N0_N1_K1,
BBlockTransferThreadClusterArrangeOrder,
BBlockTransferSrcAccessOrder,
BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1,
BBlockTransferSrcVectorTensorContiguousDimOrder,
BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1,
CThreadTransferSrcDstAccessOrder,
CThreadTransferSrcDstVectorDim,
CThreadTransferDstScalarPerVector>;
};
// Factory specialization for DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor instance
// of a grouped forward convolution kernel with large tensor support (N-splitting).
//
// Selected when SIGNATURE is a forward convolution whose device operation is
// DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor. The specialization translates
// the fields of SIGNATURE and ALGORITHM into the template arguments of that device op and
// exposes the result as `Instance`. VERSION takes part in the specialization but is not
// referenced in the body.
template <ConvSignatureDescriptor auto SIGNATURE,
ConvAlgorithmDescriptor auto ALGORITHM,
StringLiteral VERSION>
requires ConvDirectionIsForward<SIGNATURE> &&
ConvDeviceOpIs_DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor<SIGNATURE>
struct ConvFactory<SIGNATURE, ALGORITHM, VERSION>
{
// Signature-derived building blocks: spatial rank, tensor layouts, data types and
// elementwise operators.
static constexpr size_t SPATIAL_DIM = SIGNATURE.spatial_dim;
using Layouts = decltype(factory_internal::GetTensorLayout<SIGNATURE.layout,
SPATIAL_DIM,
ConvDirection::FORWARD>());
using Types = factory_internal::ConvTensorTypes<SIGNATURE.data_type>;
using Ops = factory_internal::ElementwiseOps<SIGNATURE.elementwise_operation>;
using AlgorithmType = decltype(ALGORITHM);
// Fail early, with a readable message, when the algorithm descriptor lacks any field
// that is read further down.
static_assert(SpecifiesThreadBlock<AlgorithmType>,
"The convolution algorithm descriptor must specify thread block info.");
static_assert(SpecifiesGridwiseXdlGemm<AlgorithmType>,
"The convolution algorithm descriptor must specify gridwise GEMM info.");
static_assert(SpecifiesBlockTransfer<AlgorithmType>,
"The convolution algorithm descriptor must specify block transfer info.");
static_assert(SpecifiesLdsTransfer<AlgorithmType>,
"The convolution algorithm descriptor must specify LDS transfer info.");
static_assert(
SpecifiesThreadClusterAccessOrder<AlgorithmType>,
"The convolution algorithm descriptor must specify thread cluster access order info.");
static_assert(SpecifiesSourceAccessOrder<AlgorithmType>,
"The convolution algorithm descriptor must specify source access order info.");
static_assert(SpecifiesFwdConcSpecialization<AlgorithmType>,
"The convolution algorithm descriptor must specify forward convolution "
"specialization.");
static_assert(SpecifiesGemmSpecialization<AlgorithmType>,
"The convolution algorithm descriptor must specify gemm specialization.");
static_assert(SpecifiesNumPrefetchStages<AlgorithmType>,
"The convolution algorithm descriptor must specify number of prefetch stages.");
static_assert(SpecifiesLoopScheduler<AlgorithmType>,
"The convolution algorithm descriptor must specify loop scheduler.");
// Builder-level settings converted by factory_internal helpers; the convolution and GEMM
// specializations are additionally bundled into one ConvSpec value.
static constexpr auto FWD_CONV_SPECIALIZATION =
factory_internal::SetFwdConvSpecialization<ALGORITHM>();
static constexpr auto GEMM_SPECIALIZATION =
factory_internal::SetGemmSpecialization<ALGORITHM>();
static constexpr factory_internal::ConvSpec SPECIALIZATION{.conv_spec = FWD_CONV_SPECIALIZATION,
.gemm_spec = GEMM_SPECIALIZATION};
static constexpr auto LOOP_SCHEDULER = factory_internal::SetLoopScheduler<ALGORITHM>();
static constexpr auto BLOCK = factory_internal::SetThreadBlockInfo<ALGORITHM>();
// The gridwise GEMM descriptor is used as-is, without a conversion helper.
static constexpr auto GRIDWISE_GEMM = ALGORITHM.gridwise_gemm;
static constexpr auto A_BLOCK_TRANSFER =
factory_internal::SetFwdConvABlockTransfer<ALGORITHM>();
static constexpr auto B_BLOCK_TRANSFER =
factory_internal::SetFwdConvBBlockTransfer<ALGORITHM>();
static constexpr auto C_BLOCK_TRANSFER =
factory_internal::SetCBlockTransfer<SIGNATURE, ALGORITHM>();
// Check limits for the algorithm parameters.
static_assert(InputVectorTransferLimits<A_BLOCK_TRANSFER>);
static_assert(InputVectorTransferLimits<B_BLOCK_TRANSFER>);
static_assert(OutputVectorTransferLimits<C_BLOCK_TRANSFER>);
static_assert(AccessOrderLimits<A_BLOCK_TRANSFER.thread_cluster_order>);
static_assert(AccessOrderLimits<B_BLOCK_TRANSFER.thread_cluster_order>);
static_assert(AccessOrderLimits<A_BLOCK_TRANSFER.src_access_order>);
static_assert(AccessOrderLimits<B_BLOCK_TRANSFER.src_access_order>);
// The forward convolution kernel class instance with large tensor support.
// Template arguments are positional, so the order below has to match the parameter list of
// the device op declaration exactly.
using Instance =
ck::tensor_operation::device::DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor<
SPATIAL_DIM,
typename Layouts::ALayout,
typename Layouts::BLayout,
typename Layouts::DsLayout,
typename Layouts::ELayout,
typename Types::ADataType,
typename Types::BDataType,
typename Types::AccDataType,
typename Types::CShuffleDataType,
typename Types::DsDataTypes,
typename Types::EDataType,
typename Ops::AElementwiseOp,
typename Ops::BElementwiseOp,
typename Ops::CDEElementwiseOp,
SPECIALIZATION.conv_spec,
SPECIALIZATION.gemm_spec,
ALGORITHM.num_gemm_k_prefetch_stages,
BLOCK.block_size,
BLOCK.per_block.m,
BLOCK.per_block.n,
BLOCK.per_block.k,
GRIDWISE_GEMM.ak1,
GRIDWISE_GEMM.bk1,
GRIDWISE_GEMM.m_per_xdl,
GRIDWISE_GEMM.n_per_xdl,
GRIDWISE_GEMM.m_xdl_per_wave,
GRIDWISE_GEMM.n_xdl_per_wave,
to_sequence_v<A_BLOCK_TRANSFER.thread_cluster_dims>,
to_sequence_v<A_BLOCK_TRANSFER.thread_cluster_order>,
to_sequence_v<A_BLOCK_TRANSFER.src_access_order>,
A_BLOCK_TRANSFER.src_vector_dim,
A_BLOCK_TRANSFER.src_scalar_per_vector,
A_BLOCK_TRANSFER.lds_dst_scalar_per_vector,
A_BLOCK_TRANSFER.lds_padding,
to_sequence_v<B_BLOCK_TRANSFER.thread_cluster_dims>,
to_sequence_v<B_BLOCK_TRANSFER.thread_cluster_order>,
to_sequence_v<B_BLOCK_TRANSFER.src_access_order>,
B_BLOCK_TRANSFER.src_vector_dim,
B_BLOCK_TRANSFER.src_scalar_per_vector,
B_BLOCK_TRANSFER.lds_dst_scalar_per_vector,
B_BLOCK_TRANSFER.lds_padding,
C_BLOCK_TRANSFER.m_per_wave_per_shuffle,
C_BLOCK_TRANSFER.n_per_wave_per_shuffle,
to_sequence_v<C_BLOCK_TRANSFER.thread_cluster_dims>,
C_BLOCK_TRANSFER.scalar_per_vector,
typename Types::AComputeType,
typename Types::BComputeType,
LOOP_SCHEDULER>;
};
} // namespace ck_tile::builder

View File

@@ -33,30 +33,35 @@ concept ConvDirectionIsBackwardWeight = (Sig.direction == ConvDirection::BACKWAR
// Predicate for DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 operation.
// True only when Sig is a forward signature and its `device_operation._fwd` equals the
// enumerator named below. The direction test is the left operand of `&&`, so `_fwd` is
// compared only once the direction is known to be forward.
// NOTE(review): presumably `device_operation` stores only the member matching the
// direction, which would make this ordering necessary -- confirm against its definition.
template <auto Sig>
concept ConvDeviceOpIs_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 =
ConvDirectionIsForward<Sig> &&
(Sig.device_operation._fwd ==
FwdGroupConvDeviceOperation::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3);
// Predicate for DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK operation.
// True only when Sig is a forward signature and its `device_operation._fwd` equals the
// enumerator named below. The direction test is the left operand of `&&`, so `_fwd` is
// compared only once the direction is known to be forward.
// NOTE(review): presumably `device_operation` stores only the member matching the
// direction, which would make this ordering necessary -- confirm against its definition.
template <auto Sig>
concept ConvDeviceOpIs_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK =
ConvDirectionIsForward<Sig> &&
(Sig.device_operation._fwd ==
FwdGroupConvDeviceOperation::DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK);
// Predicate for DeviceGroupedConvFwdMultipleD_Wmma_CShuffle operation.
// True only when Sig is a forward signature and its `device_operation._fwd` equals the
// enumerator named below. The direction test is the left operand of `&&`, so `_fwd` is
// compared only once the direction is known to be forward.
// NOTE(review): presumably `device_operation` stores only the member matching the
// direction, which would make this ordering necessary -- confirm against its definition.
template <auto Sig>
concept ConvDeviceOpIs_DeviceGroupedConvFwdMultipleD_Wmma_CShuffle =
ConvDirectionIsForward<Sig> &&
(Sig.device_operation._fwd ==
FwdGroupConvDeviceOperation::DeviceGroupedConvFwdMultipleD_Wmma_CShuffle);
// Predicate for DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle operation.
// True only when Sig is a forward signature and its `device_operation._fwd` equals the
// enumerator named below. The direction test is the left operand of `&&`, so `_fwd` is
// compared only once the direction is known to be forward.
// NOTE(review): presumably `device_operation` stores only the member matching the
// direction, which would make this ordering necessary -- confirm against its definition.
template <auto Sig>
concept ConvDeviceOpIs_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle =
ConvDirectionIsForward<Sig> &&
(Sig.device_operation._fwd ==
FwdGroupConvDeviceOperation::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle);
// Predicate for DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor operation.
// True only when Sig is a forward signature and its `device_operation._fwd` equals the
// enumerator named below. The direction test is the left operand of `&&`, so `_fwd` is
// compared only once the direction is known to be forward.
// NOTE(review): presumably `device_operation` stores only the member matching the
// direction, which would make this ordering necessary -- confirm against its definition.
template <auto Sig>
concept ConvDeviceOpIs_DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor =
ConvDirectionIsForward<Sig> &&
(Sig.device_operation._fwd ==
FwdGroupConvDeviceOperation::DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor);
@@ -76,48 +81,56 @@ concept ConvDeviceOpIsForward =
// Predicate for DeviceGroupedConvBwdWeight operation.
// True only when Sig is a backward-weight signature and its `device_operation._bwd_weight`
// equals the enumerator named below. The direction test is the left operand of `&&`, so
// `_bwd_weight` is compared only once the direction is known to be backward-weight.
// NOTE(review): presumably `device_operation` stores only the member matching the
// direction, which would make this ordering necessary -- confirm against its definition.
template <auto Sig>
concept ConvDeviceOpIs_DeviceGroupedConvBwdWeight =
ConvDirectionIsBackwardWeight<Sig> &&
(Sig.device_operation._bwd_weight ==
BwdWeightGroupConvDeviceOperation::DeviceGroupedConvBwdWeight);
// Predicate for DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle operation.
// True only when Sig is a backward-weight signature and its `device_operation._bwd_weight`
// equals the enumerator named below. The direction test is the left operand of `&&`, so
// `_bwd_weight` is compared only once the direction is known to be backward-weight.
// NOTE(review): presumably `device_operation` stores only the member matching the
// direction, which would make this ordering necessary -- confirm against its definition.
template <auto Sig>
concept ConvDeviceOpIs_DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle =
ConvDirectionIsBackwardWeight<Sig> &&
(Sig.device_operation._bwd_weight ==
BwdWeightGroupConvDeviceOperation::DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle);
// Predicate for DeviceGroupedConvBwdWeight_Xdl_CShuffle operation.
// True only when Sig is a backward-weight signature and its `device_operation._bwd_weight`
// equals the enumerator named below. The direction test is the left operand of `&&`, so
// `_bwd_weight` is compared only once the direction is known to be backward-weight.
// NOTE(review): presumably `device_operation` stores only the member matching the
// direction, which would make this ordering necessary -- confirm against its definition.
template <auto Sig>
concept ConvDeviceOpIs_DeviceGroupedConvBwdWeight_Xdl_CShuffle =
ConvDirectionIsBackwardWeight<Sig> &&
(Sig.device_operation._bwd_weight ==
BwdWeightGroupConvDeviceOperation::DeviceGroupedConvBwdWeight_Xdl_CShuffle);
// Predicate for DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle operation.
// True only when Sig is a backward-weight signature and its `device_operation._bwd_weight`
// equals the enumerator named below. The direction test is the left operand of `&&`, so
// `_bwd_weight` is compared only once the direction is known to be backward-weight.
// NOTE(review): presumably `device_operation` stores only the member matching the
// direction, which would make this ordering necessary -- confirm against its definition.
template <auto Sig>
concept ConvDeviceOpIs_DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle =
ConvDirectionIsBackwardWeight<Sig> &&
(Sig.device_operation._bwd_weight ==
BwdWeightGroupConvDeviceOperation::DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle);
// Predicate for DeviceGroupedConvBwdWeight_Wmma_CShuffle operation.
// True only when Sig is a backward-weight signature and its `device_operation._bwd_weight`
// equals the enumerator named below. The direction test is the left operand of `&&`, so
// `_bwd_weight` is compared only once the direction is known to be backward-weight.
// NOTE(review): presumably `device_operation` stores only the member matching the
// direction, which would make this ordering necessary -- confirm against its definition.
template <auto Sig>
concept ConvDeviceOpIs_DeviceGroupedConvBwdWeight_Wmma_CShuffle =
ConvDirectionIsBackwardWeight<Sig> &&
(Sig.device_operation._bwd_weight ==
BwdWeightGroupConvDeviceOperation::DeviceGroupedConvBwdWeight_Wmma_CShuffle);
// Predicate for DeviceGroupedConvBwdWeight_Xdl_CShuffleV3 operation.
// True only when Sig is a backward-weight signature and its `device_operation._bwd_weight`
// equals the enumerator named below. The direction test is the left operand of `&&`, so
// `_bwd_weight` is compared only once the direction is known to be backward-weight.
// NOTE(review): presumably `device_operation` stores only the member matching the
// direction, which would make this ordering necessary -- confirm against its definition.
template <auto Sig>
concept ConvDeviceOpIs_DeviceGroupedConvBwdWeight_Xdl_CShuffleV3 =
ConvDirectionIsBackwardWeight<Sig> &&
(Sig.device_operation._bwd_weight ==
BwdWeightGroupConvDeviceOperation::DeviceGroupedConvBwdWeight_Xdl_CShuffleV3);
// Predicate for DeviceGroupedConvBwdWeightMultipleD operation.
// True only when Sig is a backward-weight signature and its `device_operation._bwd_weight`
// equals the enumerator named below. The direction test is the left operand of `&&`, so
// `_bwd_weight` is compared only once the direction is known to be backward-weight.
// NOTE(review): presumably `device_operation` stores only the member matching the
// direction, which would make this ordering necessary -- confirm against its definition.
template <auto Sig>
concept ConvDeviceOpIs_DeviceGroupedConvBwdWeightMultipleD =
ConvDirectionIsBackwardWeight<Sig> &&
(Sig.device_operation._bwd_weight ==
BwdWeightGroupConvDeviceOperation::DeviceGroupedConvBwdWeightMultipleD);
// Predicate for DeviceGroupedConvBwdWeight_Dl operation.
// True only when Sig is a backward-weight signature and its `device_operation._bwd_weight`
// equals the enumerator named below. The direction test is the left operand of `&&`, so
// `_bwd_weight` is compared only once the direction is known to be backward-weight.
// NOTE(review): presumably `device_operation` stores only the member matching the
// direction, which would make this ordering necessary -- confirm against its definition.
template <auto Sig>
concept ConvDeviceOpIs_DeviceGroupedConvBwdWeight_Dl =
ConvDirectionIsBackwardWeight<Sig> &&
(Sig.device_operation._bwd_weight ==
BwdWeightGroupConvDeviceOperation::DeviceGroupedConvBwdWeight_Dl);
@@ -140,18 +153,21 @@ concept ConvDeviceOpIsBackwardWeight =
// Predicate for DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 operation.
// True only when Sig is a backward-data signature and its `device_operation._bwd_data`
// equals the enumerator named below. The direction test is the left operand of `&&`, so
// `_bwd_data` is compared only once the direction is known to be backward-data.
// NOTE(review): presumably `device_operation` stores only the member matching the
// direction, which would make this ordering necessary -- confirm against its definition.
template <auto Sig>
concept ConvDeviceOpIs_DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 =
ConvDirectionIsBackwardData<Sig> &&
(Sig.device_operation._bwd_data ==
BwdDataGroupConvDeviceOperation::DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1);
// Predicate for DeviceGroupedConvBwdDataMultipleD operation.
// True only when Sig is a backward-data signature and its `device_operation._bwd_data`
// equals the enumerator named below. The direction test is the left operand of `&&`, so
// `_bwd_data` is compared only once the direction is known to be backward-data.
// NOTE(review): presumably `device_operation` stores only the member matching the
// direction, which would make this ordering necessary -- confirm against its definition.
template <auto Sig>
concept ConvDeviceOpIs_DeviceGroupedConvBwdDataMultipleD =
ConvDirectionIsBackwardData<Sig> &&
(Sig.device_operation._bwd_data ==
BwdDataGroupConvDeviceOperation::DeviceGroupedConvBwdDataMultipleD);
// Predicate for DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle operation.
// True only when Sig is a backward-data signature and its `device_operation._bwd_data`
// equals the enumerator named below. The direction test is the left operand of `&&`, so
// `_bwd_data` is compared only once the direction is known to be backward-data.
// NOTE(review): presumably `device_operation` stores only the member matching the
// direction, which would make this ordering necessary -- confirm against its definition.
template <auto Sig>
concept ConvDeviceOpIs_DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle =
ConvDirectionIsBackwardData<Sig> &&
(Sig.device_operation._bwd_data ==
BwdDataGroupConvDeviceOperation::DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle);

View File

@@ -0,0 +1,22 @@
// Copyright (C) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
namespace ck_tile::builder {
// Enumeration for CK Device Operation types.
// This allows the builder to select which device operation template to instantiate
// based on the user's requirements.
enum class DeviceOpType
{
// Forward Convolution - Non-grouped
CONV_FWD, // Maps to: DeviceConvFwd (TODO: No implementation with tuning params exists yet)
// Forward Convolution - Grouped
GROUPED_CONV_FWD_MULTIPLE_ABD, // Maps to: DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
GROUPED_CONV_FWD_MULTIPLE_ABD_XDL_CSHUFFLE_V3, // Maps to:
// DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
};
} // namespace ck_tile::builder

View File

@@ -0,0 +1,268 @@
// SPDX-License-Identifier: MIT
// Copyright (c) Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <concepts>
#include <string_view>
#include <sstream>
#include <type_traits>
#include <variant>
#include <ck_tile/builder/conv_signature_concepts.hpp>
#include <ck_tile/builder/reflect/conv_traits.hpp>
#include <ck_tile/builder/reflect/tree_formatter.hpp>
/// @file conv_description.hpp
/// @brief Provides human-readable descriptions of ConvBuilder configurations
namespace ck_tile::reflect::conv {
// Signature information - groups the problem-level description of a convolution:
// dimensionality, direction, tensor layout, data type and elementwise operations.
struct ConvSignatureInfo
{
// Number of spatial dimensions; ConvDescription prints it as "<spatial_dim>D".
int spatial_dim;
// Convolution direction.
builder::ConvDirection direction;
// Tensor layout; holds exactly one of the 1D, 2D or 3D grouped-convolution layout enums.
std::variant<builder::GroupConvLayout1D, builder::GroupConvLayout2D, builder::GroupConvLayout3D>
layout;
// Tensor element data type.
builder::DataType data_type;
// Elementwise operations for the input, weight and output tensors respectively.
builder::ElementwiseOperation input_element_op;
builder::ElementwiseOperation weight_element_op;
builder::ElementwiseOperation output_element_op;
};
// Algorithm information - groups all algorithm-related configuration
struct GemmAlgorithmInfo
{
// Thread block size, as reported by the kernel traits.
int thread_block_size;
// M/N/K dimensions of the data tile.
DataTileInfo tile_dims;
// Warp-level GEMM sub-tile size and iteration counts.
WarpGemmParams warp_gemm;
// Tile transfer descriptions for the A and B inputs and the C output.
InputTileTransferInfo a_tile_transfer;
InputTileTransferInfo b_tile_transfer;
OutputTileTransferInfo c_tile_transfer;
// Pipeline version and scheduler.
builder::PipelineVersion pipeline_version;
builder::PipelineScheduler pipeline_scheduler;
// Convolution specialization; holds exactly one of the forward, backward-data or
// backward-weight specialization enums.
std::variant<builder::ConvFwdSpecialization,
builder::ConvBwdDataSpecialization,
builder::ConvBwdWeightSpecialization>
conv_specialization;
// GEMM padding mode.
builder::GemmPadding padding;
};
// Provides human-readable descriptions of ConvBuilder configurations.
struct ConvDescription
{
ConvSignatureInfo signature;
GemmAlgorithmInfo algorithm;
// Brief one-line summary
std::string brief() const
{
std::ostringstream oss;
oss << signature.spatial_dim << "D " << signature.direction << " convolution";
return oss.str();
}
// Detailed hierarchical description
std::string detailed() const
{
TreeFormatter f;
f.writeLine(0, signature.spatial_dim, "D ", signature.direction, " Convolution Kernel");
f.writeLine(1, "Signature");
f.writeLine(2, "Tensor Type: ", signature.data_type);
f.writeLine(2, "Memory Layout: ", signature.layout);
f.writeLine(2, "Input elementwise operation: ", signature.input_element_op);
f.writeLine(2, "Weights elementwise operation: ", signature.weight_element_op);
f.writeLast(2, "Output elementwise operation: ", signature.output_element_op);
f.writeLine(1, "Algorithm");
// Compute Block section
f.writeLine(2, "Thread block size: ", algorithm.thread_block_size);
f.writeLine(2,
"Data tile size: ",
algorithm.tile_dims.m,
"×",
algorithm.tile_dims.n,
"×",
algorithm.tile_dims.k);
f.writeLine(2, "Gemm padding: ", algorithm.padding);
f.writeLine(2, "Convolution specialization: ", algorithm.conv_specialization);
// Pipeline section
f.writeLine(2, "Pipeline version: ", algorithm.pipeline_version);
f.writeLine(2, "Pipeline scheduler: ", algorithm.pipeline_scheduler);
f.writeLine(2, "Warp Gemm parameters: ");
f.writeLine(
3, "subtile size: ", algorithm.warp_gemm.gemm_m, "×", algorithm.warp_gemm.gemm_n);
f.writeLast(3,
"Number of warp gemm iterations: ",
algorithm.warp_gemm.m_iter,
"×",
algorithm.warp_gemm.n_iter);
// Memory Access section
f.writeLine(2, "Memory access:");
f.writeLine(3, "A Tile transfer: ");
f.writeLine(4,
"Tile dimensions: ",
algorithm.a_tile_transfer.tile_dimensions.k0,
"×",
algorithm.a_tile_transfer.tile_dimensions.m_or_n,
"×",
algorithm.a_tile_transfer.tile_dimensions.k1,
"×");
f.writeLine(
4, "The innermost K subdimension size: ", algorithm.a_tile_transfer.transfer_params.k1);
f.writeLine(4,
"Spatial thread distribution over the data tile: ",
algorithm.a_tile_transfer.transfer_params.thread_cluster_order[0],
"×",
algorithm.a_tile_transfer.transfer_params.thread_cluster_order[1],
"×",
algorithm.a_tile_transfer.transfer_params.thread_cluster_order[2]);
f.writeLine(4,
"The order of accessing data tile axes: ",
algorithm.a_tile_transfer.transfer_params.src_access_order[0],
"×",
algorithm.a_tile_transfer.transfer_params.src_access_order[1],
"×",
algorithm.a_tile_transfer.transfer_params.src_access_order[2]);
f.writeLine(4,
"Vectorized memory access axis index (with contiguous memory): ",
algorithm.a_tile_transfer.transfer_params.src_vector_dim);
f.writeLine(4,
"Vector access (GMEM read) instruction size: ",
algorithm.a_tile_transfer.transfer_params.src_scalar_per_vector);
f.writeLine(4,
"Vector access (LDS write) instruction size: ",
algorithm.a_tile_transfer.transfer_params.dst_scalar_per_vector_k1);
f.writeLast(4,
"LDS data layout padding (to prevent bank conflicts): ",
algorithm.a_tile_transfer.transfer_params.dst_scalar_per_vector_k1);
f.writeLine(3, "B Tile transfer: ");
f.writeLine(4,
"Tile dimensions: ",
algorithm.b_tile_transfer.tile_dimensions.k0,
"×",
algorithm.b_tile_transfer.tile_dimensions.m_or_n,
"×",
algorithm.b_tile_transfer.tile_dimensions.k1,
"×");
f.writeLine(
4, "The innermost K subdimension size: ", algorithm.b_tile_transfer.transfer_params.k1);
f.writeLine(4,
"Spatial thread distribution over the data tile: ",
algorithm.b_tile_transfer.transfer_params.thread_cluster_order[0],
"×",
algorithm.b_tile_transfer.transfer_params.thread_cluster_order[1],
"×",
algorithm.b_tile_transfer.transfer_params.thread_cluster_order[2]);
f.writeLine(4,
"The order of accessing data tile axes: ",
algorithm.b_tile_transfer.transfer_params.src_access_order[0],
"×",
algorithm.b_tile_transfer.transfer_params.src_access_order[1],
"×",
algorithm.b_tile_transfer.transfer_params.src_access_order[2]);
f.writeLine(4,
"Vectorized memory access axis index (with contiguous memory): ",
algorithm.b_tile_transfer.transfer_params.src_vector_dim);
f.writeLine(4,
"Vector access (GMEM read) instruction size: ",
algorithm.b_tile_transfer.transfer_params.src_scalar_per_vector);
f.writeLine(4,
"Vector access (LDS write) instruction size: ",
algorithm.b_tile_transfer.transfer_params.dst_scalar_per_vector_k1);
f.writeLast(4,
"LDS data layout padding (to prevent bank conflicts): ",
algorithm.b_tile_transfer.transfer_params.dst_scalar_per_vector_k1);
f.writeLast(3, "C Tile transfer: ");
f.writeLine(4,
"Data shuffle (number of gemm instructions per iteration): ",
algorithm.c_tile_transfer.shuffle_params.m_gemms_per_shuffle,
"×",
algorithm.c_tile_transfer.shuffle_params.n_gemms_per_shuffle);
f.writeLine(4,
"Spatial thread distribution used to store data: ",
algorithm.c_tile_transfer.thread_cluster_dims[0],
"×",
algorithm.c_tile_transfer.thread_cluster_dims[1],
"×",
algorithm.c_tile_transfer.thread_cluster_dims[2],
"×",
algorithm.c_tile_transfer.thread_cluster_dims[3]);
f.writeLast(4,
"Vector access (GMEM write) instruction size: ",
algorithm.c_tile_transfer.scalar_per_vector);
f.writeLast(2);
f.writeLast(1);
return f.getString();
}
// Educational explanation of optimization choices
std::string explain() const
{
std::ostringstream oss;
// Placeholder for future implementation
return oss.str();
}
// Performance characteristics and use case guidance
std::string suggest() const
{
std::ostringstream oss;
// Placeholder for future implementation
return oss.str();
}
};
// Helper concept to detect if a type has InstanceTraits specialization
// The requirement only checks that the type name `InstanceTraits<T>` is well-formed.
// NOTE(review): naming a class-template specialization does not require it to be complete,
// so this may be satisfied for every T if InstanceTraits has a declared primary template —
// confirm against instance_traits.hpp.
template <typename T>
concept HasInstanceTraits = requires { typename InstanceTraits<T>; };
// Helper concept to detect ConvBuilder types
// Satisfied by any type that exposes the nested type names `Factory` and `Instance`;
// nothing else about those nested types is checked.
template <typename T>
concept IsConvBuilder = requires {
typename T::Factory;
typename T::Instance;
};
// Primary factory function: builds a ConvDescription straight from a kernel Instance type.
// All values are read from the static members of ConvTraits<Instance>.
template <typename Instance>
    requires HasInstanceTraits<Instance>
ConvDescription Describe()
{
    using Reflected = ConvTraits<Instance>;

    // Problem-level properties.
    const ConvSignatureInfo signature_info{.spatial_dim       = Reflected::spatial_dim,
                                           .direction         = Reflected::direction,
                                           .layout            = Reflected::layout,
                                           .data_type         = Reflected::data_type,
                                           .input_element_op  = Reflected::input_element_op,
                                           .weight_element_op = Reflected::weight_element_op,
                                           .output_element_op = Reflected::output_element_op};

    // Kernel tuning properties.
    const GemmAlgorithmInfo algorithm_info{.thread_block_size   = Reflected::thread_block_size,
                                           .tile_dims           = Reflected::tile_dims,
                                           .warp_gemm           = Reflected::warp_gemm,
                                           .a_tile_transfer     = Reflected::a_tile_transfer,
                                           .b_tile_transfer     = Reflected::b_tile_transfer,
                                           .c_tile_transfer     = Reflected::c_tile_transfer,
                                           .pipeline_version    = Reflected::pipeline_version,
                                           .pipeline_scheduler  = Reflected::pipeline_scheduler,
                                           .conv_specialization = Reflected::conv_specialization,
                                           .padding             = Reflected::gemm_padding};

    return ConvDescription{.signature = signature_info, .algorithm = algorithm_info};
}
// Backward compatibility: builds a ConvDescription from a Builder type by forwarding
// to the Instance-based overload with the builder's nested `Instance` type.
template <typename Builder>
    requires IsConvBuilder<Builder> && (!HasInstanceTraits<Builder>)
ConvDescription Describe()
{
    return Describe<typename Builder::Instance>();
}
} // namespace ck_tile::reflect::conv

View File

@@ -0,0 +1,719 @@
// SPDX-License-Identifier: MIT
// Copyright (c) Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <concepts>
#include <ck_tile/builder/conv_builder.hpp>
#include <ck_tile/builder/conv_factory.hpp>
#include <ck_tile/builder/conv_signature_concepts.hpp>
#include <ck_tile/builder/reflect/instance_traits.hpp>
#include <ck_tile/builder/types.hpp>
#include <ck/tensor_operation/gpu/device/tensor_layout.hpp>
#include <ck/tensor_operation/gpu/device/convolution_backward_data_specialization.hpp>
#include <ck/tensor_operation/gpu/device/convolution_backward_weight_specialization.hpp>
#include <ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp>
#include <ck/utility/loop_scheduler.hpp>
#include <ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp>
namespace ck_tile::reflect::conv {
// Helper metafunctions to convert from ck enums to builder enums
/// @brief Converts a CK BlockGemmPipelineVersion enum to a builder PipelineVersion enum.
/// @tparam ck_ver The CK BlockGemmPipelineVersion enum value to convert.
/// @return The corresponding builder::PipelineVersion enum value (V1, V2, V3, V4, or V5).
/// @details This function maps CK's block GEMM pipeline version identifiers to the
/// builder framework's standardized pipeline version enum. The pipeline version
/// determines the strategy used for data movement and computation overlap in the
/// GEMM kernel's main loop.
/// @note Instantiating this function with a value outside the handled set is rejected
/// with a static_assert; previously the deduced return type silently became `void`.
template <ck::BlockGemmPipelineVersion ck_ver>
constexpr auto convert_pipeline_version()
{
    using enum ck::BlockGemmPipelineVersion;
    using enum builder::PipelineVersion;
    // Fail with a readable message if CK grows a version this mapping does not know about.
    static_assert(ck_ver == v1 || ck_ver == v2 || ck_ver == v3 || ck_ver == v4 || ck_ver == v5,
                  "convert_pipeline_version: unsupported ck::BlockGemmPipelineVersion value");
    if constexpr(ck_ver == v1)
    {
        return V1;
    }
    else if constexpr(ck_ver == v2)
    {
        return V2;
    }
    else if constexpr(ck_ver == v3)
    {
        return V3;
    }
    else if constexpr(ck_ver == v4)
    {
        return V4;
    }
    else if constexpr(ck_ver == v5)
    {
        return V5;
    }
}
/// @brief Converts a CK PipelineVersion enum to a builder PipelineVersion enum.
/// @tparam ck_ver The CK PipelineVersion enum value to convert.
/// @return The corresponding builder::PipelineVersion enum value (V1, V2, V4, or WEIGHT_ONLY).
/// @details This function maps CK's general pipeline version identifiers to the
/// builder framework's standardized pipeline version enum. Note that this overload
/// handles a different set of pipeline versions compared to the BlockGemmPipelineVersion
/// variant, including support for specialized weight-only pipelines.
/// @note Instantiating this function with a value outside the handled set is rejected
/// with a static_assert; previously the deduced return type silently became `void`.
template <ck::PipelineVersion ck_ver>
constexpr auto convert_pipeline_version()
{
    using enum ck::PipelineVersion;
    using enum builder::PipelineVersion;
    // Fail with a readable message if CK grows a version this mapping does not know about.
    static_assert(ck_ver == v1 || ck_ver == v2 || ck_ver == v4 || ck_ver == weight_only,
                  "convert_pipeline_version: unsupported ck::PipelineVersion value");
    if constexpr(ck_ver == v1)
    {
        return V1;
    }
    else if constexpr(ck_ver == v2)
    {
        return V2;
    }
    else if constexpr(ck_ver == v4)
    {
        return V4;
    }
    else if constexpr(ck_ver == weight_only)
    {
        return WEIGHT_ONLY;
    }
}
/// @brief Converts a CK BlockGemmPipelineScheduler enum to a builder PipelineScheduler enum.
/// @tparam ck_sched The CK BlockGemmPipelineScheduler enum value to convert.
/// @return The corresponding builder::PipelineScheduler enum value (INTRAWAVE or INTERWAVE).
/// @details This function maps CK's block GEMM pipeline scheduler identifiers to the
/// builder framework's standardized scheduler enum. The scheduler determines how work
/// is distributed and synchronized within and across wavefronts during pipeline execution.
/// INTRAWAVE scheduling operates within a single wavefront, while INTERWAVE coordinates
/// across multiple wavefronts.
/// @note Instantiating this function with a value outside the handled set is rejected
/// with a static_assert; previously the deduced return type silently became `void`.
template <ck::BlockGemmPipelineScheduler ck_sched>
constexpr auto convert_pipeline_scheduler()
{
    using enum ck::BlockGemmPipelineScheduler;
    using enum builder::PipelineScheduler;
    // Fail with a readable message if CK grows a scheduler this mapping does not know about.
    static_assert(ck_sched == Intrawave || ck_sched == Interwave,
                  "convert_pipeline_scheduler: unsupported ck::BlockGemmPipelineScheduler value");
    if constexpr(ck_sched == Intrawave)
    {
        return INTRAWAVE;
    }
    else if constexpr(ck_sched == Interwave)
    {
        return INTERWAVE;
    }
}
/// @brief Converts a CK LoopScheduler enum to a builder PipelineScheduler enum.
/// @tparam ck_sched The CK LoopScheduler enum value to convert.
/// @return The corresponding builder::PipelineScheduler enum value (DEFAULT or INTERWAVE).
/// @details This function maps CK's loop scheduler identifiers to the builder framework's
/// standardized pipeline scheduler enum. The loop scheduler controls how iterations of
/// the main computational loop are scheduled across threads. DEFAULT uses the standard
/// scheduling strategy, while INTERWAVE enables cross-wavefront coordination for improved
/// performance in certain scenarios.
/// @note Instantiating this function with a value outside the handled set is rejected
/// with a static_assert; previously the deduced return type silently became `void`.
template <ck::LoopScheduler ck_sched>
constexpr auto convert_pipeline_scheduler()
{
    using enum ck::LoopScheduler;
    using enum builder::PipelineScheduler;
    // Fail with a readable message if CK grows a scheduler this mapping does not know about.
    static_assert(ck_sched == Default || ck_sched == Interwave,
                  "convert_pipeline_scheduler: unsupported ck::LoopScheduler value");
    if constexpr(ck_sched == Default)
    {
        return DEFAULT;
    }
    else if constexpr(ck_sched == Interwave)
    {
        return INTERWAVE;
    }
}
// Helper structures for organizing trait data with domain-specific naming.
//
/// @brief Data tile dimensions processed by a workgroup.
/// @details This struct defines the M, N, and K dimensions of the data tile
/// that a single workgroup (thread block) is responsible for processing in the
/// underlying GEMM computation.
struct DataTileInfo
{
int m; ///< M dimension of the tile processed by the workgroup (MPerBlock).
int n; ///< N dimension of the tile processed by the workgroup (NPerBlock).
int k; ///< K dimension of the tile processed by the workgroup (KPerBlock).
};
/// @brief Dimensions for an input data tile transfer.
/// @details Defines the shape of the input tile (A or B matrix) as it is
/// transferred from global memory to LDS. The tile is conceptually divided
/// into k0 and k1 dimensions.
/// @note ConvTraits in this header fills k0 with KPerBlock / K1, m_or_n with
/// MPerBlock (A) or NPerBlock (B), and k1 with the AK1 / BK1 value.
struct InputTileTransferDimensions
{
int k0; ///< The outer dimension of K, where K = k0 * k1.
int m_or_n; ///< The M dimension for the A matrix transfer, or the N dimension for the B matrix.
int k1; ///< The inner dimension of K, often corresponding to the vector load size from global
///< memory.
};
/// @brief Parameters governing the transfer of an input tile.
/// @details This struct holds configuration details for how an input tile is
/// loaded from global memory into LDS, including thread clustering, memory
/// access patterns, and vectorization settings.
/// @note `thread_cluster_dims` (how many threads per axis) and `thread_cluster_order`
/// (the axis ordering) are distinct three-element arrays; take care not to mix them up
/// when reporting these values.
struct InputTileTransferParams
{
int k1; ///< The inner K dimension size, often matching the vectorization width.
std::array<int, 3>
thread_cluster_dims; ///< Spatial thread distribution over the input data tile; defines how
///< many threads are arranged on each axis.
std::array<int, 3> thread_cluster_order; ///< The order of thread spatial distribution over the
///< input tensor dimensions.
std::array<int, 3> src_access_order; ///< The order of accessing input tensor axes (e.g., which
///< dimension to read first).
int src_vector_dim; ///< The index of the axis on which vectorized memory access is performed
///< (the contiguous dimension).
int src_scalar_per_vector; ///< The size of the vector access instruction; the number of
///< elements accessed per thread per instruction.
int dst_scalar_per_vector_k1; ///< The size of the vectorized store into LDS memory along the K1
///< dimension.
bool lds_padding; ///< Flag indicating if padding is used for the LDS tensor to prevent bank
///< conflicts. ConvTraits in this header derives it by casting the
///< kABlockLdsExtraM / kBBlockLdsExtraN trait to bool.
};
/// @brief Complete information for an input tile transfer.
/// @details Combines the dimensional information and transfer parameters for
/// a full description of an input tile's journey from global memory to LDS.
/// One instance is used for the A tile and one for the B tile.
struct InputTileTransferInfo
{
InputTileTransferDimensions tile_dimensions; ///< The shape and layout of the tile.
InputTileTransferParams transfer_params; ///< The parameters for the memory transfer operation.
};
/// @brief Parameters for the warp-level GEMM computation.
/// @details Defines the configuration of the GEMM operation performed by each
/// warp using hardware MFMA (Matrix Fused Multiply-Add) instructions.
/// @note ConvTraits in this header fills these from the kMPerXDL, kNPerXDL,
/// kMXdlPerWave and kNXdlPerWave instance traits, in that order.
struct WarpGemmParams
{
int gemm_m; ///< The M dimension of a single MFMA instruction (MPerXdl).
int gemm_n; ///< The N dimension of a single MFMA instruction (NPerXdl).
int m_iter; ///< The number of MFMA iterations along the M dimension of the output tile per
///< wavefront (MXdlPerWave).
int n_iter; ///< The number of MFMA iterations along the N dimension of the output tile per
///< wavefront (NXdlPerWave).
};
/// @brief Parameters for shuffling data between warps (CShuffle optimization).
/// @details Configures how many MFMA instruction results are processed per
/// wave in each iteration of the CShuffle routine.
/// @note ConvTraits in this header fills these from the kCShuffleMXdlPerWavePerShuffle
/// and kCShuffleNXdlPerWavePerShuffle instance traits.
struct WarpShuffleParams
{
int m_gemms_per_shuffle; ///< Number of MFMA results along the M dimension to process per wave
///< per shuffle iteration.
int n_gemms_per_shuffle; ///< Number of MFMA results along the N dimension to process per wave
///< per shuffle iteration.
};
/// @brief Information for the output tile transfer (CShuffle).
/// @details Describes how the final computed tile (C matrix) is written out from
/// LDS to global memory, including shuffling, thread clustering, and vectorization.
struct OutputTileTransferInfo
{
WarpShuffleParams shuffle_params; ///< Configuration for cross-warp data shuffling.
// Element order of thread_cluster_dims below:
// m_block, m_wave_per_xdl, n_block, n_wave_per_xdl
std::array<int, 4> thread_cluster_dims; ///< The spatial thread distribution used for storing
///< data into the output tensor.
int scalar_per_vector; ///< The size of the vectorized memory access when storing data to the
///< output tensor.
};
// Helper metafunctions to derive signature information from Instance types

/// @brief Deduces the convolution direction of a device kernel `Instance`.
/// @tparam Instance The device kernel instance type.
/// @return `builder::ConvDirection::FORWARD`, `BACKWARD_DATA` or `BACKWARD_WEIGHT`.
/// @details The direction is inferred from which specialization constant
/// `InstanceTraits<Instance>` exposes; the checks run in the order forward,
/// backward-data, backward-weight. When none of them is present the function
/// falls back to FORWARD.
template <typename Instance>
constexpr builder::ConvDirection conv_direction()
{
    using Traits = InstanceTraits<Instance>;
    using enum builder::ConvDirection;
    if constexpr(requires { &Traits::kConvForwardSpecialization; })
        return FORWARD;
    else if constexpr(requires { &Traits::kConvBwdDataSpecialization; })
        return BACKWARD_DATA;
    else if constexpr(requires { &Traits::kConvBwdWeightSpecialization; })
        return BACKWARD_WEIGHT;
    else
        return FORWARD; // Default fallback
}
/// @brief Derives the convolution-specific specialization from a device kernel `Instance` type.
/// @tparam Instance The device kernel instance type.
/// @return A `builder::ConvFwdSpecialization`, `builder::ConvBwdDataSpecialization`, or
/// `builder::ConvBwdWeightSpecialization` enum value.
template <typename Instance>
constexpr auto conv_spec()
{
using InstTraits = InstanceTraits<Instance>;
if constexpr(requires { InstTraits::kConvForwardSpecialization; })
{
using enum ck::tensor_operation::device::ConvolutionForwardSpecialization;
if constexpr(InstTraits::kConvForwardSpecialization == Default)
{
return builder::ConvFwdSpecialization::DEFAULT;
}
else if constexpr(InstTraits::kConvForwardSpecialization == Filter1x1Pad0)
{
return builder::ConvFwdSpecialization::FILTER_1X1_PAD0;
}
else if constexpr(InstTraits::kConvForwardSpecialization == Filter1x1Stride1Pad0)
{
return builder::ConvFwdSpecialization::FILTER_1X1_STRIDE1_PAD0;
}
else if constexpr(InstTraits::kConvForwardSpecialization == Filter3x3)
{
return builder::ConvFwdSpecialization::FILTER_3x3;
}
}
else if constexpr(requires { InstTraits::kConvBwdDataSpecialization; })
{
using enum ck::tensor_operation::device::ConvolutionBackwardDataSpecialization;
if constexpr(InstTraits::kConvBwdDataSpecialization == Default)
{
return builder::ConvBwdDataSpecialization::DEFAULT;
}
else if constexpr(InstTraits::kConvBwdDataSpecialization == Filter1x1Stride1Pad0)
{
return builder::ConvBwdDataSpecialization::FILTER_1X1_STRIDE1_PAD0;
}
}
else if constexpr(requires { InstTraits::kConvBwdWeightSpecialization; })
{
using enum ck::tensor_operation::device::ConvolutionBackwardWeightSpecialization;
if constexpr(InstTraits::kConvBwdWeightSpecialization == Default)
{
return builder::ConvBwdWeightSpecialization::DEFAULT;
}
else if constexpr(InstTraits::kConvBwdWeightSpecialization == Filter1x1Stride1Pad0)
{
return builder::ConvBwdWeightSpecialization::FILTER_1X1_STRIDE1_PAD0;
}
else if constexpr(InstTraits::kConvBwdWeightSpecialization == Filter1x1Pad0)
{
return builder::ConvBwdWeightSpecialization::FILTER_1X1_PAD0;
}
else if constexpr(InstTraits::kConvBwdWeightSpecialization == OddC)
{
return builder::ConvBwdWeightSpecialization::ODD_C;
}
}
}
/// @brief Derives the grouped convolution layout from a device kernel `Instance` type.
/// @tparam Instance The device kernel instance type.
/// @return A `builder::GroupConvLayout{1D|2D|3D}` enum value corresponding to the tensor layouts.
template <typename Instance>
constexpr auto conv_layout()
{
using InstTraits = InstanceTraits<Instance>;
using ALayout = typename InstTraits::ALayout;
using BLayout = typename InstTraits::BLayout;
using ELayout = typename InstTraits::ELayout;
namespace ctc = ck::tensor_layout::convolution;
if constexpr(InstTraits::kSpatialDim == 1)
{
if constexpr(std::is_same_v<ALayout, ctc::GNWC> && std::is_same_v<BLayout, ctc::GKXC> &&
std::is_same_v<ELayout, ctc::GNWK>)
{
return builder::GroupConvLayout1D::GNWC_GKXC_GNWK;
}
else if constexpr(std::is_same_v<ALayout, ctc::NWGC> &&
std::is_same_v<BLayout, ctc::GKXC> && std::is_same_v<ELayout, ctc::NWGK>)
{
return builder::GroupConvLayout1D::NWGC_GKXC_NWGK;
}
else if constexpr(std::is_same_v<ALayout, ctc::NGCW> &&
std::is_same_v<BLayout, ctc::GKXC> && std::is_same_v<ELayout, ctc::NGKW>)
{
return builder::GroupConvLayout1D::NGCW_GKXC_NGKW;
}
else if constexpr(std::is_same_v<ALayout, ctc::NGCW> &&
std::is_same_v<BLayout, ctc::GKCX> && std::is_same_v<ELayout, ctc::NGKW>)
{
return builder::GroupConvLayout1D::NGCW_GKCX_NGKW;
}
}
else if constexpr(InstTraits::kSpatialDim == 2)
{
if constexpr(std::is_same_v<ALayout, ctc::GNHWC> && std::is_same_v<BLayout, ctc::GKYXC> &&
std::is_same_v<ELayout, ctc::GNHWK>)
{
return builder::GroupConvLayout2D::GNHWC_GKYXC_GNHWK;
}
else if constexpr(std::is_same_v<ALayout, ctc::NHWGC> &&
std::is_same_v<BLayout, ctc::GKYXC> &&
std::is_same_v<ELayout, ctc::NHWGK>)
{
return builder::GroupConvLayout2D::NHWGC_GKYXC_NHWGK;
}
else if constexpr(std::is_same_v<ALayout, ctc::NGCHW> &&
std::is_same_v<BLayout, ctc::GKYXC> &&
std::is_same_v<ELayout, ctc::NGKHW>)
{
return builder::GroupConvLayout2D::NGCHW_GKYXC_NGKHW;
}
else if constexpr(std::is_same_v<ALayout, ctc::NGCHW> &&
std::is_same_v<BLayout, ctc::GKCYX> &&
std::is_same_v<ELayout, ctc::NGKHW>)
{
return builder::GroupConvLayout2D::NGCHW_GKCYX_NGKHW;
}
}
else if constexpr(InstTraits::kSpatialDim == 3)
{
if constexpr(std::is_same_v<ALayout, ctc::GNDHWC> && std::is_same_v<BLayout, ctc::GKZYXC> &&
std::is_same_v<ELayout, ctc::GNDHWK>)
{
return builder::GroupConvLayout3D::GNDHWC_GKZYXC_GNDHWK;
}
else if constexpr(std::is_same_v<ALayout, ctc::NDHWGC> &&
std::is_same_v<BLayout, ctc::GKZYXC> &&
std::is_same_v<ELayout, ctc::NDHWGK>)
{
return builder::GroupConvLayout3D::NDHWGC_GKZYXC_NDHWGK;
}
else if constexpr(std::is_same_v<ALayout, ctc::NGCDHW> &&
std::is_same_v<BLayout, ctc::GKZYXC> &&
std::is_same_v<ELayout, ctc::NGKDHW>)
{
return builder::GroupConvLayout3D::NGCDHW_GKZYXC_NGKDHW;
}
else if constexpr(std::is_same_v<ALayout, ctc::NGCDHW> &&
std::is_same_v<BLayout, ctc::GKCZYX> &&
std::is_same_v<ELayout, ctc::NGKDHW>)
{
return builder::GroupConvLayout3D::NGCDHW_GKCZYX_NGKDHW;
}
}
}
/// @brief Maps the A-tensor element type of a device kernel `Instance` to a builder data type.
/// @tparam Instance The device kernel instance type.
/// @return A `builder::DataType` enum value (e.g., FP16, BF16, FP32).
/// @details Only `InstanceTraits<Instance>::ADataType` is inspected. Element types that are
/// not listed below are reported as FP32.
template <typename Instance>
constexpr builder::DataType conv_data_type()
{
    using Elem = typename InstanceTraits<Instance>::ADataType;
    using DT   = builder::DataType;
    if constexpr(std::is_same_v<Elem, ck::half_t>)
        return DT::FP16;
    else if constexpr(std::is_same_v<Elem, ck::bhalf_t>)
        return DT::BF16;
    else if constexpr(std::is_same_v<Elem, float>)
        return DT::FP32;
    else if constexpr(std::is_same_v<Elem, ck::f8_t>)
        return DT::FP8;
    else if constexpr(std::is_same_v<Elem, int8_t>)
        return DT::I8;
    else if constexpr(std::is_same_v<Elem, uint8_t>)
        return DT::U8;
    else
        return DT::FP32; // Default fallback
}
/// @brief Maps an elementwise operation functor type to the builder enum.
/// @tparam ElementwiseOp Elementwise operation functor type.
/// @return A `builder::ElementwiseOperation` enum value corresponding to elementwise operation.
/// @details The functor's name is obtained from `detail::elementwise_op_name` and compared
/// case-insensitively against the known operation names. A name that is not listed
/// produces no return statement for that instantiation.
template <typename ElementwiseOp>
constexpr builder::ElementwiseOperation elementwise_op()
{
    using Op                            = builder::ElementwiseOperation;
    constexpr std::string_view op_label = detail::elementwise_op_name<ElementwiseOp>();
    if constexpr(detail::case_insensitive_equal(op_label, "Bias"))
        return Op::BIAS;
    else if constexpr(detail::case_insensitive_equal(op_label, "BiasClamp"))
        return Op::BIAS_CLAMP;
    else if constexpr(detail::case_insensitive_equal(op_label, "BiasBnormClamp"))
        return Op::BIAS_BNORM_CLAMP;
    else if constexpr(detail::case_insensitive_equal(op_label, "Bilinear"))
        return Op::BILINEAR;
    else if constexpr(detail::case_insensitive_equal(op_label, "Clamp"))
        return Op::CLAMP;
    else if constexpr(detail::case_insensitive_equal(op_label, "Scale"))
        return Op::SCALE;
    else if constexpr(detail::case_insensitive_equal(op_label, "PassThrough"))
        return Op::PASS_THROUGH;
}
/// @brief Maps the CK GEMM specialization of a kernel instance to the builder padding enum.
/// @tparam Instance - A Device Kernel object type.
/// @return A `builder::GemmPadding` enum value corresponding to kernel padding.
/// @details Reads `InstanceTraits<Instance>::kGemmSpecialization` and translates each
/// CK enumerator (e.g. MNKPadding) to its builder counterpart (e.g. MNK_PADDING).
template <typename Instance>
constexpr builder::GemmPadding gemm_spec()
{
    using enum builder::GemmPadding;
    using enum ck::tensor_operation::device::GemmSpecialization;
    constexpr auto kind = InstanceTraits<Instance>::kGemmSpecialization;
    if constexpr(kind == Default)
        return DEFAULT;
    else if constexpr(kind == MPadding)
        return M_PADDING;
    else if constexpr(kind == NPadding)
        return N_PADDING;
    else if constexpr(kind == KPadding)
        return K_PADDING;
    else if constexpr(kind == MNPadding)
        return MN_PADDING;
    else if constexpr(kind == MKPadding)
        return MK_PADDING;
    else if constexpr(kind == NKPadding)
        return NK_PADDING;
    else if constexpr(kind == MNKPadding)
        return MNK_PADDING;
    else if constexpr(kind == OPadding)
        return O_PADDING;
    else if constexpr(kind == MOPadding)
        return MO_PADDING;
    else if constexpr(kind == NOPadding)
        return NO_PADDING;
    else if constexpr(kind == KOPadding)
        return KO_PADDING;
    else if constexpr(kind == MNOPadding)
        return MNO_PADDING;
    else if constexpr(kind == MKOPadding)
        return MKO_PADDING;
    else if constexpr(kind == NKOPadding)
        return NKO_PADDING;
    else if constexpr(kind == MNKOPadding)
        return MNKO_PADDING;
}
/// @brief Primary template for extracting convolution traits.
/// @details This struct is the main entry point for reflecting on a convolution
/// kernel's properties. It is specialized to handle different kinds of input types.
/// The primary template is only declared, never defined here; the specializations
/// in this header provide the actual members.
template <typename T>
struct ConvTraits;
/// @brief Specialization of `ConvTraits` for a direct device kernel `Instance`.
/// @details This is the primary specialization used to extract a comprehensive
/// set of traits directly from a fully-formed device kernel `Instance` type.
/// It uses `InstanceTraits` to access the kernel's template parameters.
template <typename Instance>
requires requires { typename InstanceTraits<Instance>; }
struct ConvTraits<Instance>
{
using InstTraits = InstanceTraits<Instance>;
// --- Signature Information ---
/// @brief The number of spatial dimensions in the convolution (1, 2, or 3).
static constexpr int spatial_dim = InstTraits::kSpatialDim;
/// @brief The direction of the convolution (Forward, Backward Data, or Backward Weight).
static constexpr builder::ConvDirection direction = conv_direction<Instance>();
/// @brief The memory layout of the convolution tensors (e.g., GNHWC_GKYXC_GNHWK).
static constexpr auto layout = conv_layout<Instance>();
/// @brief The primary data type used in the computation (e.g., FP16, FP32).
static constexpr builder::DataType data_type = conv_data_type<Instance>();
static constexpr builder::ElementwiseOperation input_element_op =
elementwise_op<typename InstTraits::AElementwiseOperation>();
static constexpr builder::ElementwiseOperation weight_element_op =
elementwise_op<typename InstTraits::BElementwiseOperation>();
static constexpr builder::ElementwiseOperation output_element_op =
elementwise_op<typename InstTraits::CDEElementwiseOperation>();
/// @brief The GEMM specialization used by the kernel - padding
static constexpr auto gemm_padding = gemm_spec<Instance>();
/// @brief The convolution-specific specialization (e.g., Default, 1x1).
static constexpr auto conv_specialization = conv_spec<Instance>();
// --- Algorithm Information ---
/// @brief The total number of threads in a thread block (workgroup).
static constexpr int thread_block_size = InstTraits::kBlockSize;
/// @brief The dimensions of the data tile processed by the thread block.
static constexpr DataTileInfo tile_dims = {
.m = InstTraits::kMPerBlock, .n = InstTraits::kNPerBlock, .k = InstTraits::kKPerBlock};
/// @brief Configuration for the A-matrix (input) tile transfer.
static constexpr InputTileTransferInfo a_tile_transfer = {
.tile_dimensions = {.k0 = InstTraits::kKPerBlock / InstTraits::kAK1,
.m_or_n = InstTraits::kMPerBlock,
.k1 = InstTraits::kAK1},
.transfer_params = {.k1 = InstTraits::kAK1,
.thread_cluster_dims = InstTraits::kAThreadClusterLengths,
.thread_cluster_order = InstTraits::kAThreadClusterArrangeOrder,
.src_access_order = InstTraits::kABlockTransferSrcAccessOrder,
.src_vector_dim = InstTraits::kABlockTransferSrcVectorDim,
.src_scalar_per_vector = InstTraits::kABlockTransferSrcScalarPerVector,
.dst_scalar_per_vector_k1 =
InstTraits::kABlockTransferDstScalarPerVectorK1,
.lds_padding = static_cast<bool>(InstTraits::kABlockLdsExtraM)}};
/// @brief Configuration for the B-matrix (weights) tile transfer.
static constexpr InputTileTransferInfo b_tile_transfer = {
.tile_dimensions = {.k0 = InstTraits::kKPerBlock / InstTraits::kBK1,
.m_or_n = InstTraits::kNPerBlock,
.k1 = InstTraits::kBK1},
.transfer_params = {.k1 = InstTraits::kBK1,
.thread_cluster_dims = InstTraits::kBThreadClusterLengths,
.thread_cluster_order = InstTraits::kBThreadClusterArrangeOrder,
.src_access_order = InstTraits::kBBlockTransferSrcAccessOrder,
.src_vector_dim = InstTraits::kBBlockTransferSrcVectorDim,
.src_scalar_per_vector = InstTraits::kBBlockTransferSrcScalarPerVector,
.dst_scalar_per_vector_k1 =
InstTraits::kBBlockTransferDstScalarPerVectorK1,
.lds_padding = static_cast<bool>(InstTraits::kBBlockLdsExtraN)}};
/// @brief Parameters for the warp-level GEMM computation.
static constexpr WarpGemmParams warp_gemm = {.gemm_m = InstTraits::kMPerXDL,
.gemm_n = InstTraits::kNPerXDL,
.m_iter = InstTraits::kMXdlPerWave,
.n_iter = InstTraits::kNXdlPerWave};
/// @brief Configuration for the C-matrix (output) tile transfer.
static constexpr OutputTileTransferInfo c_tile_transfer = {
.shuffle_params = {.m_gemms_per_shuffle = InstTraits::kCShuffleMXdlPerWavePerShuffle,
.n_gemms_per_shuffle = InstTraits::kCShuffleNXdlPerWavePerShuffle},
.thread_cluster_dims = {InstTraits::kCThreadClusterLengths[0],
InstTraits::kCThreadClusterLengths[1],
InstTraits::kCThreadClusterLengths[2],
InstTraits::kCThreadClusterLengths[3]},
.scalar_per_vector = InstTraits::kCBlockTransferScalarPerVector};
/// @brief Resolves the pipeline version reported for this instance.
/// @details Only some instance traits (e.g. forward convolutions) declare
/// `kPipelineVersion`. When the member exists it is translated with
/// `convert_pipeline_version`; otherwise `builder::PipelineVersion::V1`
/// is reported as the default.
/// @tparam T Traits type to probe; defaults to `InstTraits`.
template <typename T = InstTraits>
static constexpr auto get_pipeline_version()
{
    // Probe once, then branch on the named result.
    constexpr bool has_pipeline_version = requires { T::kPipelineVersion; };

    if constexpr(!has_pipeline_version)
    {
        // No version information available: fall back to the default.
        return builder::PipelineVersion::V1;
    }
    else
    {
        return convert_pipeline_version<T::kPipelineVersion>();
    }
}
/// @brief The block GEMM pipeline version used by the kernel.
static constexpr auto pipeline_version = get_pipeline_version();
/// @brief Resolves the pipeline scheduler reported for this instance.
/// @details The traits are probed in priority order: `kPipelineScheduler` first,
/// then `kLoopScheduler`. Whichever is found is translated with
/// `convert_pipeline_scheduler`. If neither member exists,
/// `builder::PipelineScheduler::DEFAULT` is reported.
/// @tparam T Traits type to probe; defaults to `InstTraits`.
template <typename T = InstTraits>
static constexpr auto get_pipeline_scheduler()
{
    constexpr bool has_pipeline_scheduler = requires { T::kPipelineScheduler; };
    constexpr bool has_loop_scheduler     = requires { T::kLoopScheduler; };

    if constexpr(has_pipeline_scheduler)
    {
        return convert_pipeline_scheduler<T::kPipelineScheduler>();
    }
    else if constexpr(has_loop_scheduler)
    {
        return convert_pipeline_scheduler<T::kLoopScheduler>();
    }
    else
    {
        // Neither scheduler member is present: report the default.
        return builder::PipelineScheduler::DEFAULT;
    }
}
/// @brief The pipeline scheduler used by the kernel.
static constexpr auto pipeline_scheduler = get_pipeline_scheduler();
};
/// @brief Specialization of `ConvTraits` for a `ConvBuilder` type.
/// @details This specialization provides backward compatibility for reflecting
/// on kernels defined via the `ConvBuilder` interface. It works by first
/// creating the `Instance` via the builder's factory, and then delegating
/// all trait extraction to the `ConvTraits<Instance>` specialization.
/// @tparam SIGNATURE Convolution signature value forwarded to `builder::ConvFactory`.
/// @tparam ALGORITHM Convolution algorithm value forwarded to `builder::ConvFactory`.
/// @tparam VERSION   Version string literal forwarded to `builder::ConvFactory`.
template <builder::ConvSignatureDescriptor auto SIGNATURE,
builder::ConvAlgorithmDescriptor auto ALGORITHM,
builder::StringLiteral VERSION>
struct ConvTraits<builder::ConvBuilder<SIGNATURE, ALGORITHM, VERSION>>
{
// The factory is instantiated with exactly the builder's template arguments.
using Factory = builder::ConvFactory<SIGNATURE, ALGORITHM, VERSION>;
// Concrete kernel instance type produced by the factory.
using Instance = typename Factory::Instance;
// Delegate to Instance-based ConvTraits
using InstanceConvTraits = ConvTraits<Instance>;
// Forward all members from Instance-based traits
// Members declared with `auto` take whatever type the instance-based traits use;
// the explicitly typed ones additionally pin the expected type at compile time.
static constexpr int spatial_dim = InstanceConvTraits::spatial_dim;
static constexpr builder::ConvDirection direction = InstanceConvTraits::direction;
static constexpr auto layout = InstanceConvTraits::layout;
static constexpr builder::DataType data_type = InstanceConvTraits::data_type;
static constexpr builder::ElementwiseOperation input_element_op =
InstanceConvTraits::input_element_op;
static constexpr builder::ElementwiseOperation weight_element_op =
InstanceConvTraits::weight_element_op;
static constexpr builder::ElementwiseOperation output_element_op =
InstanceConvTraits::output_element_op;
static constexpr auto gemm_padding = InstanceConvTraits::gemm_padding;
static constexpr auto conv_specialization = InstanceConvTraits::conv_specialization;
static constexpr int thread_block_size = InstanceConvTraits::thread_block_size;
static constexpr DataTileInfo tile_dims = InstanceConvTraits::tile_dims;
static constexpr InputTileTransferInfo a_tile_transfer = InstanceConvTraits::a_tile_transfer;
static constexpr InputTileTransferInfo b_tile_transfer = InstanceConvTraits::b_tile_transfer;
static constexpr WarpGemmParams warp_gemm = InstanceConvTraits::warp_gemm;
static constexpr OutputTileTransferInfo c_tile_transfer = InstanceConvTraits::c_tile_transfer;
static constexpr auto pipeline_version = InstanceConvTraits::pipeline_version;
static constexpr auto pipeline_scheduler = InstanceConvTraits::pipeline_scheduler;
};
} // namespace ck_tile::reflect::conv

View File

@@ -14,18 +14,9 @@
#pragma once
#include <array>
#include <string>
#include <sstream>
#include <type_traits>
#include <ck/utility/data_type.hpp>
#include <ck/utility/sequence.hpp>
#include <ck/utility/blkgemmpipe_scheduler.hpp>
#include <ck/tensor_operation/gpu/device/tensor_layout.hpp>
#include <ck/tensor_operation/gpu/element/element_wise_operation.hpp>
#include <ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp>
#include <ck/tensor_operation/gpu/device/gemm_specialization.hpp>
#include "instance_traits_util.hpp"
#include <concepts>
namespace ck_tile::reflect {

View File

@@ -15,6 +15,7 @@
#pragma once
#include "instance_traits.hpp"
#include "instance_traits_util.hpp"
// Forward declaration to avoid circular dependency.
// This file will be included by the device implementation header, so we cannot include

View File

@@ -14,6 +14,7 @@
#pragma once
#include "instance_traits.hpp"
#include "instance_traits_util.hpp"
// Forward declaration to avoid circular dependency.
// This file will be included by the device implementation header, so we cannot include

View File

@@ -9,9 +9,14 @@
#include <array>
#include <string>
#include <concepts>
#include <string_view>
#include <sstream>
#include <type_traits>
#include <limits.h>
#include <cmath>
#include <ostream>
#include <iostream>
#include <ck/utility/data_type.hpp>
#include <ck/utility/sequence.hpp>
#include <ck/utility/blkgemmpipe_scheduler.hpp>
@@ -371,4 +376,30 @@ constexpr std::string type_or_type_tuple_name()
}
}
/// @brief Compares two string views for equality, ignoring ASCII letter case.
/// @details Only the letters 'A'..'Z' are folded to lowercase; every other
/// character must match exactly. Views of different length never compare equal.
/// @param a First string view
/// @param b Second string view
/// @return true when both views are equal under ASCII case folding, false otherwise
constexpr bool case_insensitive_equal(std::string_view a, std::string_view b)
{
    // Distance between an uppercase ASCII letter and its lowercase counterpart.
    constexpr char kAsciiCaseGap = 32;

    // Lowercases 'A'..'Z' and leaves every other character untouched.
    constexpr auto fold = [](char ch) constexpr -> char {
        const bool is_upper = (ch >= 'A') && (ch <= 'Z');
        return is_upper ? static_cast<char>(ch + kAsciiCaseGap) : ch;
    };

    const std::size_t length = a.size();
    if(length != b.size())
    {
        return false;
    }

    std::size_t pos = 0;
    while(pos < length)
    {
        if(fold(a[pos]) != fold(b[pos]))
        {
            return false;
        }
        ++pos;
    }
    return true;
}
} // namespace ck_tile::reflect::detail

View File

@@ -0,0 +1,106 @@
#pragma once
#include <sstream>
#include <string>
#include <type_traits>
#include <vector>
namespace ck_tile::reflect {
// Helper class for formatting hierarchical tree structures with proper indentation
// and tree-drawing characters (├─, └─, │, etc.)
//
// Example Usage:
//
// TreeFormatter f;
// f.writeLine(0, "Root");
// f.writeLine(1, "Branch 1");
// f.writeLine(2, "Item 1a");
// f.writeLast(2, "Item 1b");
// f.writeLast(1, "Branch 2");
// f.writeLast(2, "Item 2a");
// std::cout << f.getString() << "\n";
//
// Generated Output:
//
// Root
// ├─ Branch 1
// │  ├─ Item 1a
// │  └─ Item 1b
// └─ Branch 2
//    └─ Item 2a
/// @brief Accumulates lines of text and renders them as an indented tree.
/// @details Level 0 is the root and is written without any tree symbols. Every
/// deeper level is prefixed by one three-column cell per ancestor level ("│  "
/// while that ancestor still has siblings to come, "   " once it has ended),
/// followed by the branch symbol of the line itself ("├─ " or "└─ ").
class TreeFormatter
{
    public:
    TreeFormatter() = default;

    /// @brief Writes a line whose level continues afterwards (drawn with "├─ ").
    /// @param indent_level Depth of the line; 0 is the root. Negative values are treated as 0.
    /// @param args Values streamed, in order, into the line via `operator<<`.
    template <typename... Args>
    void writeLine(int indent_level, Args&&... args)
    {
        writeLineImpl(indent_level, false, std::forward<Args>(args)...);
    }

    /// @brief Writes the final line of its level (drawn with "└─ ").
    /// @param indent_level Depth of the line; 0 is the root. Negative values are treated as 0.
    /// @param args Values streamed, in order, into the line via `operator<<`.
    template <typename... Args>
    void writeLast(int indent_level, Args&&... args)
    {
        writeLineImpl(indent_level, true, std::forward<Args>(args)...);
    }

    /// @brief Returns everything written so far.
    /// @return The accumulated text with a single trailing newline removed, if present.
    std::string getString() const
    {
        std::string result = oss_.str();
        if(!result.empty() && result.back() == '\n')
        {
            result.pop_back();
        }
        return result;
    }

    private:
    std::ostringstream oss_;
    std::vector<bool> is_last_at_level_; // Per level: was the most recent line the last one?

    /// @brief Emits one line: ancestor cells, branch symbol, content, newline.
    template <typename... Args>
    void writeLineImpl(int indent_level, bool is_last, Args&&... args)
    {
        // A negative level would index out of bounds below; treat it as the root.
        if(indent_level < 0)
        {
            indent_level = 0;
        }
        const auto level = static_cast<std::size_t>(indent_level);

        // Grow the tracking table on demand; unseen levels count as "not ended".
        if(level >= is_last_at_level_.size())
        {
            is_last_at_level_.resize(level + 1, false);
        }

        // One cell per ancestor level (level 0 is the root and has no cell).
        // Each cell is as wide as a branch symbol so nested items line up:
        // a vertical bar while that ancestor still has siblings to come,
        // blanks once its last entry has been written.
        for(std::size_t ancestor = 1; ancestor < level; ++ancestor)
        {
            oss_ << (is_last_at_level_[ancestor] ? "   " : "│  ");
        }

        // Branch symbol for the line itself; the root gets none.
        if(level > 0)
        {
            oss_ << (is_last ? "└─ " : "├─ ");
        }

        // Stream every argument in order.
        ((oss_ << std::forward<Args>(args)), ...);
        oss_ << '\n';

        // Record the state AFTER writing so deeper lines that follow know
        // whether this level has ended.
        is_last_at_level_[level] = is_last;
    }
};
} // namespace ck_tile::reflect

View File

@@ -3,6 +3,10 @@
#pragma once
#include <ostream>
#include <string_view>
#include <variant>
namespace ck_tile::builder {
enum class DataType
@@ -128,29 +132,14 @@ enum class ElementwiseOperation
PASS_THROUGH
};
// Enums for the current block GEMM pipeline versions.
enum class BlockGemmPipelineVersion
// Enums for pipeline versions & schedulers
enum class PipelineVersion
{
V1,
V2,
V3,
V4,
V5
};
enum struct BlockGemmPipelineScheduler
{
INTRAWAVE,
INTERWAVE,
};
// Enums for the gridwise GEMM pipeline versions.
enum class GridwiseGemmPipelineVersion
{
V1,
V2,
V3, // Only used in stream-K implementation
V4,
V5,
WEIGHT_ONLY
};
@@ -186,10 +175,319 @@ enum class ConvFwdSpecialization
FILTER_3x3
};
enum class LoopScheduler
// Enums for the backward data convolution specialization.
// The stream operator for this type prints each enumerator under its own name.
// NOTE(review): presumably mirrors the backward-data specializations of the underlying
// device operations - confirm against the code that converts between them.
enum class ConvBwdDataSpecialization
{
DEFAULT,
FILTER_1X1_STRIDE1_PAD0,
};
// Enums for the backward weight convolution specialization.
// The stream operator for this type prints each enumerator under its own name.
// NOTE(review): the exact meaning of ODD_C is not evident from this file - confirm and document.
enum class ConvBwdWeightSpecialization
{
DEFAULT,
FILTER_1X1_STRIDE1_PAD0,
FILTER_1X1_PAD0,
ODD_C,
};
// Enums for the Gemm padding.
// Each enumerator names the set of dimensions (M, N, K, O) it refers to; DEFAULT names none.
// The enumerator list parallels, name for name, the cases handled by the stream operator
// for GemmSpecialization (Default, MPadding, ..., MNKOPadding).
// NOTE(review): which tensor dimension "O" denotes is not evident from this file - confirm.
enum class GemmPadding
{
DEFAULT,
M_PADDING,
N_PADDING,
K_PADDING,
MN_PADDING,
MK_PADDING,
NK_PADDING,
MNK_PADDING,
O_PADDING,
MO_PADDING,
NO_PADDING,
KO_PADDING,
MNO_PADDING,
MKO_PADDING,
NKO_PADDING,
MNKO_PADDING,
};
// Scheduler selection for the GEMM pipeline.
// DEFAULT is also the value reported by ConvTraits::get_pipeline_scheduler() when the
// instance traits expose neither kPipelineScheduler nor kLoopScheduler.
enum class PipelineScheduler
{
DEFAULT,
INTRAWAVE,
INTERWAVE
};
// ostream operator overloads for enum classes
/// @brief Streams the name of a DataType value.
/// @param os Destination stream.
/// @param dt Value to print; values without a case are printed as "Unknown".
/// @return `os`, to allow chaining.
inline std::ostream& operator<<(std::ostream& os, DataType dt)
{
    // Resolve the label first, then perform a single insertion.
    const char* label = "Unknown";
    switch(dt)
    {
    case DataType::FP16: label = "FP16"; break;
    case DataType::FP32: label = "FP32"; break;
    case DataType::BF16: label = "BF16"; break;
    case DataType::FP8: label = "FP8"; break;
    case DataType::I8: label = "I8"; break;
    case DataType::U8: label = "U8"; break;
    default: break;
    }
    return os << label;
}
/// @brief Streams a human-readable name for a ConvDirection value.
/// @param os Destination stream.
/// @param dir Value to print; values without a case are printed as "Unknown".
/// @return `os`, to allow chaining.
inline std::ostream& operator<<(std::ostream& os, ConvDirection dir)
{
    const char* label = "Unknown";
    switch(dir)
    {
    case ConvDirection::FORWARD: label = "Forward"; break;
    case ConvDirection::BACKWARD_DATA: label = "Backward Data"; break;
    case ConvDirection::BACKWARD_WEIGHT: label = "Backward Weight"; break;
    default: break;
    }
    return os << label;
}
/// @brief Streams the name of a 1D grouped-convolution layout.
/// @param os Destination stream.
/// @param layout Value to print; values without a case are printed as "Unknown".
/// @return `os`, to allow chaining.
inline std::ostream& operator<<(std::ostream& os, GroupConvLayout1D layout)
{
    const char* label = "Unknown";
    switch(layout)
    {
    case GroupConvLayout1D::GNWC_GKXC_GNWK: label = "GNWC_GKXC_GNWK"; break;
    case GroupConvLayout1D::NWGC_GKXC_NWGK: label = "NWGC_GKXC_NWGK"; break;
    case GroupConvLayout1D::NGCW_GKXC_NGKW: label = "NGCW_GKXC_NGKW"; break;
    case GroupConvLayout1D::NGCW_GKCX_NGKW: label = "NGCW_GKCX_NGKW"; break;
    default: break;
    }
    return os << label;
}
/// @brief Streams the name of a 2D grouped-convolution layout.
/// @param os Destination stream.
/// @param layout Value to print; values without a case are printed as "Unknown".
/// @return `os`, to allow chaining.
inline std::ostream& operator<<(std::ostream& os, GroupConvLayout2D layout)
{
    const char* label = "Unknown";
    switch(layout)
    {
    case GroupConvLayout2D::GNHWC_GKYXC_GNHWK: label = "GNHWC_GKYXC_GNHWK"; break;
    case GroupConvLayout2D::NHWGC_GKYXC_NHWGK: label = "NHWGC_GKYXC_NHWGK"; break;
    case GroupConvLayout2D::NGCHW_GKYXC_NGKHW: label = "NGCHW_GKYXC_NGKHW"; break;
    case GroupConvLayout2D::NGCHW_GKCYX_NGKHW: label = "NGCHW_GKCYX_NGKHW"; break;
    default: break;
    }
    return os << label;
}
/// @brief Streams the name of a 3D grouped-convolution layout.
/// @param os Destination stream.
/// @param layout Value to print; values without a case are printed as "Unknown".
/// @return `os`, to allow chaining.
inline std::ostream& operator<<(std::ostream& os, GroupConvLayout3D layout)
{
    const char* label = "Unknown";
    switch(layout)
    {
    case GroupConvLayout3D::GNDHWC_GKZYXC_GNDHWK: label = "GNDHWC_GKZYXC_GNDHWK"; break;
    case GroupConvLayout3D::NDHWGC_GKZYXC_NDHWGK: label = "NDHWGC_GKZYXC_NDHWGK"; break;
    case GroupConvLayout3D::NGCDHW_GKZYXC_NGKDHW: label = "NGCDHW_GKZYXC_NGKDHW"; break;
    case GroupConvLayout3D::NGCDHW_GKCZYX_NGKDHW: label = "NGCDHW_GKCZYX_NGKDHW"; break;
    default: break;
    }
    return os << label;
}
inline std::ostream& operator<<(std::ostream& os, FwdGroupConvDeviceOperation op)
{
using enum FwdGroupConvDeviceOperation;
switch(op)
{
case DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK:
return os << "DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK";
case DeviceGroupedConvFwdMultipleD_Wmma_CShuffle:
return os << "DeviceGroupedConvFwdMultipleD_Wmma_CShuffle";
case DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle:
return os << "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle";
case DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3:
return os << "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3";
case DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor:
return os << "DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor";
default: return os << "Unknown";
}
}
inline std::ostream& operator<<(std::ostream& os, BwdDataGroupConvDeviceOperation op)
{
using enum BwdDataGroupConvDeviceOperation;
switch(op)
{
case DeviceGroupedConvBwdDataMultipleD: return os << "DeviceGroupedConvBwdDataMultipleD";
case DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle:
return os << "DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle";
case DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1:
return os << "DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1";
default: return os << "Unknown";
}
}
inline std::ostream& operator<<(std::ostream& os, BwdWeightGroupConvDeviceOperation op)
{
using enum BwdWeightGroupConvDeviceOperation;
switch(op)
{
case DeviceGroupedConvBwdWeight: return os << "DeviceGroupedConvBwdWeight";
case DeviceGroupedConvBwdWeight_Dl: return os << "DeviceGroupedConvBwdWeight_Dl";
case DeviceGroupedConvBwdWeight_Xdl_CShuffle:
return os << "DeviceGroupedConvBwdWeight_Xdl_CShuffle";
case DeviceGroupedConvBwdWeight_Xdl_CShuffleV3:
return os << "DeviceGroupedConvBwdWeight_Xdl_CShuffleV3";
case DeviceGroupedConvBwdWeight_Wmma_CShuffle:
return os << "DeviceGroupedConvBwdWeight_Wmma_CShuffle";
case DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle:
return os << "DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle";
case DeviceGroupedConvBwdWeightMultipleD: return os << "DeviceGroupedConvBwdWeightMultipleD";
case DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle:
return os << "DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle";
default: return os << "Unknown";
}
}
/// @brief Streams the name of an ElementwiseOperation value.
/// @param os Destination stream.
/// @param op Value to print; values without a case are printed as "Unknown".
/// @return `os`, to allow chaining.
inline std::ostream& operator<<(std::ostream& os, ElementwiseOperation op)
{
    const char* label = "Unknown";
    switch(op)
    {
    case ElementwiseOperation::BIAS: label = "BIAS"; break;
    case ElementwiseOperation::BIAS_CLAMP: label = "BIAS_CLAMP"; break;
    case ElementwiseOperation::BIAS_BNORM_CLAMP: label = "BIAS_BNORM_CLAMP"; break;
    case ElementwiseOperation::BILINEAR: label = "BILINEAR"; break;
    case ElementwiseOperation::CLAMP: label = "CLAMP"; break;
    case ElementwiseOperation::SCALE: label = "SCALE"; break;
    case ElementwiseOperation::PASS_THROUGH: label = "PASS_THROUGH"; break;
    default: break;
    }
    return os << label;
}
/// @brief Streams the name of a PipelineVersion value.
/// @param os Destination stream.
/// @param ver Value to print; values without a case are printed as "Unknown".
/// @return `os`, to allow chaining.
inline std::ostream& operator<<(std::ostream& os, PipelineVersion ver)
{
    const char* label = "Unknown";
    switch(ver)
    {
    case PipelineVersion::V1: label = "V1"; break;
    case PipelineVersion::V2: label = "V2"; break;
    case PipelineVersion::V3: label = "V3"; break;
    case PipelineVersion::V4: label = "V4"; break;
    case PipelineVersion::V5: label = "V5"; break;
    case PipelineVersion::WEIGHT_ONLY: label = "WEIGHT_ONLY"; break;
    default: break;
    }
    return os << label;
}
/// @brief Streams the name of a GemmSpecialization value.
/// @param os Destination stream.
/// @param spec Value to print; values without a case are printed as "Unknown".
/// @return `os`, to allow chaining.
inline std::ostream& operator<<(std::ostream& os, GemmSpecialization spec)
{
    const char* label = "Unknown";
    switch(spec)
    {
    case GemmSpecialization::Default: label = "Default"; break;
    case GemmSpecialization::MPadding: label = "MPadding"; break;
    case GemmSpecialization::NPadding: label = "NPadding"; break;
    case GemmSpecialization::KPadding: label = "KPadding"; break;
    case GemmSpecialization::MNPadding: label = "MNPadding"; break;
    case GemmSpecialization::MKPadding: label = "MKPadding"; break;
    case GemmSpecialization::NKPadding: label = "NKPadding"; break;
    case GemmSpecialization::MNKPadding: label = "MNKPadding"; break;
    case GemmSpecialization::OPadding: label = "OPadding"; break;
    case GemmSpecialization::MOPadding: label = "MOPadding"; break;
    case GemmSpecialization::NOPadding: label = "NOPadding"; break;
    case GemmSpecialization::KOPadding: label = "KOPadding"; break;
    case GemmSpecialization::MNOPadding: label = "MNOPadding"; break;
    case GemmSpecialization::MKOPadding: label = "MKOPadding"; break;
    case GemmSpecialization::NKOPadding: label = "NKOPadding"; break;
    case GemmSpecialization::MNKOPadding: label = "MNKOPadding"; break;
    default: break;
    }
    return os << label;
}
/// @brief Streams the name of a ConvFwdSpecialization value.
/// @param os Destination stream.
/// @param spec Value to print; values without a case are printed as "Unknown".
/// @return `os`, to allow chaining.
inline std::ostream& operator<<(std::ostream& os, ConvFwdSpecialization spec)
{
    const char* label = "Unknown";
    switch(spec)
    {
    case ConvFwdSpecialization::DEFAULT: label = "DEFAULT"; break;
    case ConvFwdSpecialization::FILTER_1X1_PAD0: label = "FILTER_1X1_PAD0"; break;
    case ConvFwdSpecialization::FILTER_1X1_STRIDE1_PAD0:
        label = "FILTER_1X1_STRIDE1_PAD0";
        break;
    case ConvFwdSpecialization::FILTER_3x3: label = "FILTER_3x3"; break;
    default: break;
    }
    return os << label;
}
/// @brief Streams the name of a ConvBwdDataSpecialization value.
/// @param os Destination stream.
/// @param spec Value to print; values without a case are printed as "Unknown".
/// @return `os`, to allow chaining.
inline std::ostream& operator<<(std::ostream& os, ConvBwdDataSpecialization spec)
{
    const char* label = "Unknown";
    switch(spec)
    {
    case ConvBwdDataSpecialization::DEFAULT: label = "DEFAULT"; break;
    case ConvBwdDataSpecialization::FILTER_1X1_STRIDE1_PAD0:
        label = "FILTER_1X1_STRIDE1_PAD0";
        break;
    default: break;
    }
    return os << label;
}
/// @brief Streams the name of a ConvBwdWeightSpecialization value.
/// @param os Destination stream.
/// @param spec Value to print; values without a case are printed as "Unknown".
/// @return `os`, to allow chaining.
inline std::ostream& operator<<(std::ostream& os, ConvBwdWeightSpecialization spec)
{
    const char* label = "Unknown";
    switch(spec)
    {
    case ConvBwdWeightSpecialization::DEFAULT: label = "DEFAULT"; break;
    case ConvBwdWeightSpecialization::FILTER_1X1_STRIDE1_PAD0:
        label = "FILTER_1X1_STRIDE1_PAD0";
        break;
    case ConvBwdWeightSpecialization::FILTER_1X1_PAD0: label = "FILTER_1X1_PAD0"; break;
    case ConvBwdWeightSpecialization::ODD_C: label = "ODD_C"; break;
    default: break;
    }
    return os << label;
}
/// @brief Streams the name of a GemmPadding value.
/// @param os Destination stream.
/// @param padding Value to print; values without a case are printed as "Unknown".
/// @return `os`, to allow chaining.
inline std::ostream& operator<<(std::ostream& os, GemmPadding padding)
{
    const char* label = "Unknown";
    switch(padding)
    {
    case GemmPadding::DEFAULT: label = "DEFAULT"; break;
    case GemmPadding::M_PADDING: label = "M_PADDING"; break;
    case GemmPadding::N_PADDING: label = "N_PADDING"; break;
    case GemmPadding::K_PADDING: label = "K_PADDING"; break;
    case GemmPadding::MN_PADDING: label = "MN_PADDING"; break;
    case GemmPadding::MK_PADDING: label = "MK_PADDING"; break;
    case GemmPadding::NK_PADDING: label = "NK_PADDING"; break;
    case GemmPadding::MNK_PADDING: label = "MNK_PADDING"; break;
    case GemmPadding::O_PADDING: label = "O_PADDING"; break;
    case GemmPadding::MO_PADDING: label = "MO_PADDING"; break;
    case GemmPadding::NO_PADDING: label = "NO_PADDING"; break;
    case GemmPadding::KO_PADDING: label = "KO_PADDING"; break;
    case GemmPadding::MNO_PADDING: label = "MNO_PADDING"; break;
    case GemmPadding::MKO_PADDING: label = "MKO_PADDING"; break;
    case GemmPadding::NKO_PADDING: label = "NKO_PADDING"; break;
    case GemmPadding::MNKO_PADDING: label = "MNKO_PADDING"; break;
    default: break;
    }
    return os << label;
}
/// @brief Streams the name of a PipelineScheduler value.
/// @param os Destination stream.
/// @param sched Value to print; values without a case are printed as "Unknown".
/// @return `os`, to allow chaining.
inline std::ostream& operator<<(std::ostream& os, PipelineScheduler sched)
{
    const char* label = "Unknown";
    switch(sched)
    {
    case PipelineScheduler::DEFAULT: label = "DEFAULT"; break;
    case PipelineScheduler::INTRAWAVE: label = "INTRAWAVE"; break;
    case PipelineScheduler::INTERWAVE: label = "INTERWAVE"; break;
    default: break;
    }
    return os << label;
}
// ostream operator overload for std::variant of layout types
inline std::ostream&
operator<<(std::ostream& os,
const std::variant<GroupConvLayout1D, GroupConvLayout2D, GroupConvLayout3D>& layout)
{
std::visit([&os](const auto& l) { os << l; }, layout);
return os;
}
// ostream operator overload for std::variant of convolution specializations
inline std::ostream& operator<<(std::ostream& os,
const std::variant<ConvFwdSpecialization,
ConvBwdDataSpecialization,
ConvBwdWeightSpecialization>& spec)
{
std::visit([&os](const auto& s) { os << s; }, spec);
return os;
}
} // namespace ck_tile::builder

View File

@@ -43,6 +43,8 @@ add_ck_builder_test(test_ckb_build_fwd_instances
conv/test_ckb_conv_fwd_2d_bf16.cpp
conv/test_ckb_conv_fwd_2d_fp16.cpp
conv/test_ckb_conv_fwd_2d_fp32.cpp
conv/test_ckb_conv_fwd_2d_dl_fp16.cpp
conv/test_ckb_conv_fwd_2d_large_tensor_fp16.cpp
conv/test_ckb_conv_fwd_3d_bf16.cpp
conv/test_ckb_conv_fwd_3d_fp16.cpp
conv/test_ckb_conv_fwd_3d_fp32.cpp)
@@ -64,6 +66,12 @@ add_ck_factory_test(test_ckb_factory_grouped_convolution_forward_bias_bnorm_clam
add_ck_factory_test(test_ckb_factory_grouped_convolution_forward_scaleadd_scaleadd_relu test_ck_factory_grouped_convolution_forward_scaleadd_scaleadd_relu.cpp)
add_ck_factory_test(test_ckb_factory_grouped_convolution_forward_dynamic_op test_ck_factory_grouped_convolution_forward_dynamic_op.cpp)
add_ck_builder_test(test_conv_traits
conv/test_conv_traits.cpp)
add_ck_builder_test(test_conv_description
test_conv_description.cpp)
# Function to add all test_ckb targets to a list
function(collect_test_ckb_targets result_var)
# Get all targets in current directory

View File

@@ -27,7 +27,7 @@ TEST(FwdConvInstances,
run_test_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<
FwdConvSignature,
FwdThreadBlock,
BlockGemmPipelineVersion::V2,
PipelineVersion::V2,
ConvFwdSpecialization::FILTER_1X1_STRIDE1_PAD0>();
}

View File

@@ -25,7 +25,7 @@ TEST(FwdConvInstances,
run_test_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<FwdConvSignature,
FwdThreadBlock,
BlockGemmPipelineVersion::V1,
PipelineVersion::V1,
ConvFwdSpecialization::DEFAULT>();
}
@@ -47,7 +47,7 @@ TEST(FwdConvInstances,
run_test_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<FwdConvSignature,
FwdThreadBlock,
BlockGemmPipelineVersion::V5,
PipelineVersion::V5,
ConvFwdSpecialization::FILTER_3x3>();
}

View File

@@ -0,0 +1,69 @@
// Copyright (C) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "utils/ckb_conv_test_common.hpp"
using namespace ck_tile::builder::test_utils;
namespace ck_tile::builder::testing {
TEST(FwdConvInstances, Create_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK_Instance_2D_FP16_GNHWC)
{
    // 2D forward FP16 convolution in the GNHWC/GKYXC/GNHWK layout, selecting the
    // DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK device operation.
    constexpr ConvSignature kSignature{
        .spatial_dim           = 2,
        .direction             = ConvDirection::FORWARD,
        .layout                = GroupConvLayout2D::GNHWC_GKYXC_GNHWK,
        .data_type             = DataType::FP16,
        .elementwise_operation = ElementwiseOperation::PASS_THROUGH,
        .device_operation =
            FwdGroupConvDeviceOperation::DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK};

    // 256-thread block with a 128 x 128 x 16 (m, n, k) tile.
    constexpr ThreadBlock kThreadBlock{.block_size = 256,
                                       .tile_size  = {.m = 128, .n = 128, .k = 16}};

    constexpr auto kSpecialization = ConvFwdSpecialization::DEFAULT;

    run_test_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK<kSignature,
                                                            kThreadBlock,
                                                            kSpecialization>();
}
TEST(FwdConvInstances, Create_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK_Instance_2D_FP16_NHWGC)
{
    // 2D forward FP16 convolution in the NHWGC/GKYXC/NHWGK layout, selecting the
    // DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK device operation.
    constexpr ConvSignature kSignature{
        .spatial_dim           = 2,
        .direction             = ConvDirection::FORWARD,
        .layout                = GroupConvLayout2D::NHWGC_GKYXC_NHWGK,
        .data_type             = DataType::FP16,
        .elementwise_operation = ElementwiseOperation::PASS_THROUGH,
        .device_operation =
            FwdGroupConvDeviceOperation::DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK};

    // 256-thread block with a 128 x 128 x 16 (m, n, k) tile.
    constexpr ThreadBlock kThreadBlock{.block_size = 256,
                                       .tile_size  = {.m = 128, .n = 128, .k = 16}};

    constexpr auto kSpecialization = ConvFwdSpecialization::DEFAULT;

    run_test_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK<kSignature,
                                                            kThreadBlock,
                                                            kSpecialization>();
}
TEST(FwdConvInstances,
     Create_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK_Instance_2D_FP16_FILTER_1X1_PAD0)
{
    // Same signature as the GNHWC default case; only the forward specialization
    // passed to the runner differs (FILTER_1X1_PAD0).
    constexpr ConvSignature kSignature{
        .spatial_dim           = 2,
        .direction             = ConvDirection::FORWARD,
        .layout                = GroupConvLayout2D::GNHWC_GKYXC_GNHWK,
        .data_type             = DataType::FP16,
        .elementwise_operation = ElementwiseOperation::PASS_THROUGH,
        .device_operation =
            FwdGroupConvDeviceOperation::DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK};

    // 256-thread block with a 128 x 128 x 16 (m, n, k) tile.
    constexpr ThreadBlock kThreadBlock{.block_size = 256,
                                       .tile_size  = {.m = 128, .n = 128, .k = 16}};

    constexpr auto kSpecialization = ConvFwdSpecialization::FILTER_1X1_PAD0;

    run_test_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK<kSignature,
                                                            kThreadBlock,
                                                            kSpecialization>();
}
} // namespace ck_tile::builder::testing

View File

@@ -25,7 +25,7 @@ TEST(FwdConvInstances,
run_test_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<
FwdConvSignature,
FwdThreadBlock,
BlockGemmPipelineVersion::V3,
PipelineVersion::V3,
ConvFwdSpecialization::FILTER_1X1_PAD0>();
}

View File

@@ -25,7 +25,7 @@ TEST(FwdConvInstances,
run_test_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<
FwdConvSignature,
FwdThreadBlock,
BlockGemmPipelineVersion::V4,
PipelineVersion::V4,
ConvFwdSpecialization::FILTER_1X1_STRIDE1_PAD0>();
}

View File

@@ -0,0 +1,53 @@
// Copyright (C) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "utils/ckb_conv_test_common.hpp"
using namespace ck_tile::builder::test_utils;
namespace ck_tile::builder::testing {
TEST(FwdConvInstances,
     Create_DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor_Instance_2D_FP16_GNHWC)
{
    // 2D forward FP16 convolution in the GNHWC/GKYXC/GNHWK layout, selecting the
    // DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor device operation.
    constexpr ConvSignature kSignature{
        .spatial_dim           = 2,
        .direction             = ConvDirection::FORWARD,
        .layout                = GroupConvLayout2D::GNHWC_GKYXC_GNHWK,
        .data_type             = DataType::FP16,
        .elementwise_operation = ElementwiseOperation::PASS_THROUGH,
        .device_operation =
            FwdGroupConvDeviceOperation::DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor};

    // 256-thread block with a 256 x 128 x 32 (m, n, k) tile.
    constexpr ThreadBlock kThreadBlock{.block_size = 256,
                                       .tile_size  = {.m = 256, .n = 128, .k = 32}};

    constexpr auto kSpecialization = ConvFwdSpecialization::DEFAULT;

    run_test_DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor<kSignature,
                                                                     kThreadBlock,
                                                                     kSpecialization>();
}
TEST(
    FwdConvInstances,
    Create_DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor_Instance_2D_FP16_GNHWC_Filter1x1Pad0)
{
    // Large-tensor device operation again, this time with a smaller thread block
    // and the FILTER_1X1_PAD0 forward specialization.
    constexpr ConvSignature kSignature{
        .spatial_dim           = 2,
        .direction             = ConvDirection::FORWARD,
        .layout                = GroupConvLayout2D::GNHWC_GKYXC_GNHWK,
        .data_type             = DataType::FP16,
        .elementwise_operation = ElementwiseOperation::PASS_THROUGH,
        .device_operation =
            FwdGroupConvDeviceOperation::DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor};

    // 128-thread block with a 128 x 128 x 32 (m, n, k) tile.
    constexpr ThreadBlock kThreadBlock{.block_size = 128,
                                       .tile_size  = {.m = 128, .n = 128, .k = 32}};

    constexpr auto kSpecialization = ConvFwdSpecialization::FILTER_1X1_PAD0;

    run_test_DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor<kSignature,
                                                                     kThreadBlock,
                                                                     kSpecialization>();
}
} // namespace ck_tile::builder::testing

View File

@@ -25,7 +25,7 @@ TEST(FwdConvInstances,
run_test_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<FwdConvSignature,
FwdThreadBlock,
BlockGemmPipelineVersion::V3,
PipelineVersion::V3,
ConvFwdSpecialization::DEFAULT>();
}

View File

@@ -26,7 +26,7 @@ TEST(FwdConvInstances,
run_test_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<
FwdConvSignature,
FwdThreadBlock,
BlockGemmPipelineVersion::V4,
PipelineVersion::V4,
ConvFwdSpecialization::FILTER_1X1_PAD0>();
}

View File

@@ -26,7 +26,7 @@ TEST(FwdConvInstances,
run_test_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<
FwdConvSignature,
FwdThreadBlock,
BlockGemmPipelineVersion::V1,
PipelineVersion::V1,
ConvFwdSpecialization::FILTER_1X1_PAD0>();
}

View File

@@ -0,0 +1,316 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include <gtest/gtest.h>
#include <gmock/gmock.h>
#include <concepts>
#include <ck_tile/builder/reflect/conv_traits.hpp>
#include <ck_tile/builder/reflect/instance_traits_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp>
#include <ck_tile/builder/reflect/instance_traits_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp>
#include <ck_tile/builder/reflect/instance_traits_device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp>
namespace {
using ::testing::ElementsAre;
// Test fixture for ConvTraits tests
// It declares no members and no SetUp/TearDown logic; it only gives the TEST_F cases
// that use it a common suite name.
class ConvTraitsTest : public ::testing::Test
{
};
// Test ConvTraits with DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
// The test spells out one concrete instantiation of the device operation and then checks
// that every field exposed by ConvTraits matches the template argument it was built from.
TEST_F(ConvTraitsTest, ConvFwdTraitsExtraction)
{
// Define a concrete instance type with specific template parameters
// The arguments are positional; the trailing comment on each line names the parameter.
using DeviceInstance =
ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<
2, // NDimSpatial
ck::tensor_layout::convolution::GNHWC, // ALayout
ck::tensor_layout::convolution::GKYXC, // BLayout
ck::Tuple<>, // DsLayout
ck::tensor_layout::convolution::GNHWK, // ELayout
ck::half_t, // ADataType
ck::half_t, // BDataType
float, // AccDataType
ck::half_t, // CShuffleDataType
ck::Tuple<>, // DsDataType
ck::half_t, // EDataType
ck::tensor_operation::element_wise::PassThrough, // AElementwiseOperation
ck::tensor_operation::element_wise::PassThrough, // BElementwiseOperation
ck::tensor_operation::element_wise::PassThrough, // CDEElementwiseOperation
ck::tensor_operation::device::ConvolutionForwardSpecialization::
Default, // ConvForwardSpecialization
ck::tensor_operation::device::GemmSpecialization::Default, // GemmSpec
256, // BlockSize
128, // MPerBlock
128, // NPerBlock
16, // KPerBlock
8, // AK1
8, // BK1
32, // MPerXDL
32, // NPerXDL
4, // MXdlPerWave
4, // NXdlPerWave
ck::Sequence<4, 64, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1
ck::Sequence<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder
ck::Sequence<1, 0, 2>, // ABlockTransferSrcAccessOrder
2, // ABlockTransferSrcVectorDim
8, // ABlockTransferSrcScalarPerVector
8, // ABlockTransferDstScalarPerVector_AK1
1, // ABlockLdsExtraM
ck::Sequence<4, 64, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1
ck::Sequence<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder
ck::Sequence<1, 0, 2>, // BBlockTransferSrcAccessOrder
2, // BBlockTransferSrcVectorDim
8, // BBlockTransferSrcScalarPerVector
8, // BBlockTransferDstScalarPerVector_BK1
1, // BBlockLdsExtraN
1, // CShuffleMXdlPerWavePerShuffle
1, // CShuffleNXdlPerWavePerShuffle
ck::Sequence<1,
32,
1,
8>, // CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
8, // CDEBlockTransferScalarPerVector_NPerBlock
ck::BlockGemmPipelineScheduler::Intrawave, // BlkGemmPipeSched
// NOTE(review): the scheduler argument above uses a ck::BlockGemm* enum; confirm that
// BlkGemmPipelineVer really takes ck::PipelineVersion and not ck::BlockGemmPipelineVersion.
ck::PipelineVersion::v1, // BlkGemmPipelineVer
ck::half_t, // AComputeDataType
ck::half_t, // BComputeDataType
false>; // DirectLoad
// Use ConvTraits to extract compile-time information
using Traits = ck_tile::reflect::conv::ConvTraits<DeviceInstance>;
// Verify signature information
EXPECT_EQ(Traits::spatial_dim, 2);
EXPECT_EQ(Traits::direction, ck_tile::builder::ConvDirection::FORWARD);
EXPECT_EQ(Traits::layout, ck_tile::builder::GroupConvLayout2D::GNHWC_GKYXC_GNHWK);
EXPECT_EQ(Traits::data_type, ck_tile::builder::DataType::FP16);
EXPECT_EQ(Traits::input_element_op, ck_tile::builder::ElementwiseOperation::PASS_THROUGH);
EXPECT_EQ(Traits::weight_element_op, ck_tile::builder::ElementwiseOperation::PASS_THROUGH);
EXPECT_EQ(Traits::output_element_op, ck_tile::builder::ElementwiseOperation::PASS_THROUGH);
// Verify specializations
EXPECT_EQ(Traits::gemm_padding, ck_tile::builder::GemmPadding::DEFAULT);
EXPECT_EQ(Traits::conv_specialization, ck_tile::builder::ConvFwdSpecialization::DEFAULT);
// Verify algorithm information
EXPECT_EQ(Traits::thread_block_size, 256);
// Verify tile dimensions
EXPECT_EQ(Traits::tile_dims.m, 128);
EXPECT_EQ(Traits::tile_dims.n, 128);
EXPECT_EQ(Traits::tile_dims.k, 16);
// Verify A tile transfer info
// Expected k0 of 2 is consistent with KPerBlock / AK1 = 16 / 8.
EXPECT_EQ(Traits::a_tile_transfer.tile_dimensions.k0, 2);
EXPECT_EQ(Traits::a_tile_transfer.tile_dimensions.m_or_n, 128);
EXPECT_EQ(Traits::a_tile_transfer.tile_dimensions.k1, 8);
EXPECT_EQ(Traits::a_tile_transfer.transfer_params.k1, 8);
EXPECT_THAT(Traits::a_tile_transfer.transfer_params.thread_cluster_dims, ElementsAre(4, 64, 1));
EXPECT_THAT(Traits::a_tile_transfer.transfer_params.thread_cluster_order, ElementsAre(1, 0, 2));
EXPECT_THAT(Traits::a_tile_transfer.transfer_params.src_access_order, ElementsAre(1, 0, 2));
EXPECT_EQ(Traits::a_tile_transfer.transfer_params.src_vector_dim, 2);
EXPECT_EQ(Traits::a_tile_transfer.transfer_params.src_scalar_per_vector, 8);
EXPECT_EQ(Traits::a_tile_transfer.transfer_params.dst_scalar_per_vector_k1, 8);
// ABlockLdsExtraM is 1 above, which ConvTraits reports as a boolean.
EXPECT_TRUE(Traits::a_tile_transfer.transfer_params.lds_padding);
// Verify B tile transfer info
// k0 is KPerBlock / BK1 = 16 / 8.
EXPECT_EQ(Traits::b_tile_transfer.tile_dimensions.k0, 2);
EXPECT_EQ(Traits::b_tile_transfer.tile_dimensions.m_or_n, 128);
EXPECT_EQ(Traits::b_tile_transfer.tile_dimensions.k1, 8);
EXPECT_EQ(Traits::b_tile_transfer.transfer_params.k1, 8);
EXPECT_THAT(Traits::b_tile_transfer.transfer_params.thread_cluster_dims, ElementsAre(4, 64, 1));
EXPECT_THAT(Traits::b_tile_transfer.transfer_params.thread_cluster_order, ElementsAre(1, 0, 2));
EXPECT_THAT(Traits::b_tile_transfer.transfer_params.src_access_order, ElementsAre(1, 0, 2));
EXPECT_EQ(Traits::b_tile_transfer.transfer_params.src_vector_dim, 2);
EXPECT_EQ(Traits::b_tile_transfer.transfer_params.src_scalar_per_vector, 8);
EXPECT_EQ(Traits::b_tile_transfer.transfer_params.dst_scalar_per_vector_k1, 8);
// BBlockLdsExtraN is 1 above, which ConvTraits reports as a boolean.
EXPECT_TRUE(Traits::b_tile_transfer.transfer_params.lds_padding);
// Verify warp GEMM params
EXPECT_EQ(Traits::warp_gemm.gemm_m, 32);
EXPECT_EQ(Traits::warp_gemm.gemm_n, 32);
EXPECT_EQ(Traits::warp_gemm.m_iter, 4);
EXPECT_EQ(Traits::warp_gemm.n_iter, 4);
// Verify output tile transfer info
EXPECT_EQ(Traits::c_tile_transfer.shuffle_params.m_gemms_per_shuffle, 1);
EXPECT_EQ(Traits::c_tile_transfer.shuffle_params.n_gemms_per_shuffle, 1);
EXPECT_THAT(Traits::c_tile_transfer.thread_cluster_dims, ElementsAre(1, 32, 1, 8));
EXPECT_EQ(Traits::c_tile_transfer.scalar_per_vector, 8);
// Verify pipeline configuration
EXPECT_EQ(Traits::pipeline_scheduler, ck_tile::builder::PipelineScheduler::INTRAWAVE);
EXPECT_EQ(Traits::pipeline_version, ck_tile::builder::PipelineVersion::V1);
}
// ConvTraits reflection over a DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle instance.
// The expectations at the bottom mirror the template arguments spelled out below.
TEST_F(ConvTraitsTest, ConvFwdBaseTraitsExtraction)
{
    // Local shorthands for the long CK namespaces and the repeated argument types.
    namespace ck_device  = ck::tensor_operation::device;
    namespace ck_layouts = ck::tensor_layout::convolution;
    namespace builder    = ck_tile::builder;
    using PassThrough    = ck::tensor_operation::element_wise::PassThrough;
    using EmptyTuple     = ck::Tuple<>;
    using Half           = ck::half_t;
    using AbCluster      = ck::Sequence<4, 64, 1>;
    using AbOrder        = ck::Sequence<1, 0, 2>;
    using CdeCluster     = ck::Sequence<1, 32, 1, 8>;

    // The concrete device operation under inspection.
    using DeviceInstance = ck_device::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<
        2,                                                   // NDimSpatial
        ck_layouts::GNHWC,                                   // ALayout
        ck_layouts::GKYXC,                                   // BLayout
        EmptyTuple,                                          // DsLayout
        ck_layouts::GNHWK,                                   // ELayout
        Half,                                                // ADataType
        Half,                                                // BDataType
        float,                                               // AccDataType
        Half,                                                // CShuffleDataType
        EmptyTuple,                                          // DsDataType
        Half,                                                // EDataType
        PassThrough,                                         // AElementwiseOperation
        PassThrough,                                         // BElementwiseOperation
        PassThrough,                                         // CDEElementwiseOperation
        ck_device::ConvolutionForwardSpecialization::Default, // ConvForwardSpecialization
        ck_device::GemmSpecialization::Default,              // GemmSpec
        1,                                                   // NumGemmKPrefetchStage
        256,                                                 // BlockSize
        128,                                                 // MPerBlock
        128,                                                 // NPerBlock
        16,                                                  // KPerBlock
        8,                                                   // AK1
        8,                                                   // BK1
        32,                                                  // MPerXDL
        32,                                                  // NPerXDL
        4,                                                   // MXdlPerWave
        4,                                                   // NXdlPerWave
        AbCluster,  // ABlockTransferThreadClusterLengths_AK0_M_AK1
        AbOrder,    // ABlockTransferThreadClusterArrangeOrder
        AbOrder,    // ABlockTransferSrcAccessOrder
        2,          // ABlockTransferSrcVectorDim
        8,          // ABlockTransferSrcScalarPerVector
        8,          // ABlockTransferDstScalarPerVector_AK1
        1,          // ABlockLdsExtraM
        AbCluster,  // BBlockTransferThreadClusterLengths_BK0_N_BK1
        AbOrder,    // BBlockTransferThreadClusterArrangeOrder
        AbOrder,    // BBlockTransferSrcAccessOrder
        2,          // BBlockTransferSrcVectorDim
        8,          // BBlockTransferSrcScalarPerVector
        8,          // BBlockTransferDstScalarPerVector_BK1
        1,          // BBlockLdsExtraN
        1,          // CShuffleMXdlPerWavePerShuffle
        1,          // CShuffleNXdlPerWavePerShuffle
        CdeCluster, // CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
        8,          // CDEBlockTransferScalarPerVector_NPerBlock
        Half,       // AComputeDataType
        Half,       // BComputeDataType
        ck::LoopScheduler::Default, // LoopSched
        1>;                         // NumGroupsToMerge

    // Compile-time view of the instance produced by the reflection layer.
    using Traits = ck_tile::reflect::conv::ConvTraits<DeviceInstance>;

    // Signature: dimensionality, direction, layout, data type, elementwise ops.
    EXPECT_EQ(Traits::spatial_dim, 2);
    EXPECT_EQ(Traits::direction, builder::ConvDirection::FORWARD);
    EXPECT_EQ(Traits::layout, builder::GroupConvLayout2D::GNHWC_GKYXC_GNHWK);
    EXPECT_EQ(Traits::data_type, builder::DataType::FP16);
    EXPECT_EQ(Traits::input_element_op, builder::ElementwiseOperation::PASS_THROUGH);
    EXPECT_EQ(Traits::weight_element_op, builder::ElementwiseOperation::PASS_THROUGH);
    EXPECT_EQ(Traits::output_element_op, builder::ElementwiseOperation::PASS_THROUGH);

    // Specializations.
    EXPECT_EQ(Traits::gemm_padding, builder::GemmPadding::DEFAULT);
    EXPECT_EQ(Traits::conv_specialization, builder::ConvFwdSpecialization::DEFAULT);

    // Algorithm: thread block size and block tile extents.
    EXPECT_EQ(Traits::thread_block_size, 256);
    EXPECT_EQ(Traits::tile_dims.m, 128);
    EXPECT_EQ(Traits::tile_dims.n, 128);
    EXPECT_EQ(Traits::tile_dims.k, 16);
}
// ConvTraits reflection over a DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor
// instance. The expectations at the bottom mirror the template arguments spelled out below.
TEST_F(ConvTraitsTest, ConvFwdLargeTensorTraitsExtraction)
{
    // Local shorthands for the long CK namespaces and the repeated argument types.
    namespace ck_device  = ck::tensor_operation::device;
    namespace ck_layouts = ck::tensor_layout::convolution;
    namespace builder    = ck_tile::builder;
    using PassThrough    = ck::tensor_operation::element_wise::PassThrough;
    using EmptyTuple     = ck::Tuple<>;
    using Half           = ck::half_t;
    using AbCluster      = ck::Sequence<4, 64, 1>;
    using AbOrder        = ck::Sequence<1, 0, 2>;
    using CdeCluster     = ck::Sequence<1, 32, 1, 8>;

    // The concrete device operation under inspection. This instantiation
    // ends at LoopSched, with no NumGroupsToMerge argument.
    using DeviceInstance = ck_device::DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor<
        2,                                                   // NDimSpatial
        ck_layouts::GNHWC,                                   // ALayout
        ck_layouts::GKYXC,                                   // BLayout
        EmptyTuple,                                          // DsLayout
        ck_layouts::GNHWK,                                   // ELayout
        Half,                                                // ADataType
        Half,                                                // BDataType
        float,                                               // AccDataType
        Half,                                                // CShuffleDataType
        EmptyTuple,                                          // DsDataType
        Half,                                                // EDataType
        PassThrough,                                         // AElementwiseOperation
        PassThrough,                                         // BElementwiseOperation
        PassThrough,                                         // CDEElementwiseOperation
        ck_device::ConvolutionForwardSpecialization::Default, // ConvForwardSpecialization
        ck_device::GemmSpecialization::Default,              // GemmSpec
        1,                                                   // NumGemmKPrefetchStage
        256,                                                 // BlockSize
        128,                                                 // MPerBlock
        128,                                                 // NPerBlock
        16,                                                  // KPerBlock
        8,                                                   // AK1
        8,                                                   // BK1
        32,                                                  // MPerXDL
        32,                                                  // NPerXDL
        4,                                                   // MXdlPerWave
        4,                                                   // NXdlPerWave
        AbCluster,  // ABlockTransferThreadClusterLengths_AK0_M_AK1
        AbOrder,    // ABlockTransferThreadClusterArrangeOrder
        AbOrder,    // ABlockTransferSrcAccessOrder
        2,          // ABlockTransferSrcVectorDim
        8,          // ABlockTransferSrcScalarPerVector
        8,          // ABlockTransferDstScalarPerVector_AK1
        1,          // ABlockLdsExtraM
        AbCluster,  // BBlockTransferThreadClusterLengths_BK0_N_BK1
        AbOrder,    // BBlockTransferThreadClusterArrangeOrder
        AbOrder,    // BBlockTransferSrcAccessOrder
        2,          // BBlockTransferSrcVectorDim
        8,          // BBlockTransferSrcScalarPerVector
        8,          // BBlockTransferDstScalarPerVector_BK1
        1,          // BBlockLdsExtraN
        1,          // CShuffleMXdlPerWavePerShuffle
        1,          // CShuffleNXdlPerWavePerShuffle
        CdeCluster, // CDEBlockTransferClusterLengths
        8,          // CDEBlockTransferScalarPerVector_NPerBlock
        Half,       // AComputeDataType
        Half,       // BComputeDataType
        ck::LoopScheduler::Default>; // LoopSched

    // Compile-time view of the instance produced by the reflection layer.
    using Traits = ck_tile::reflect::conv::ConvTraits<DeviceInstance>;

    // Signature: dimensionality, direction, layout, data type, elementwise ops.
    EXPECT_EQ(Traits::spatial_dim, 2);
    EXPECT_EQ(Traits::direction, builder::ConvDirection::FORWARD);
    EXPECT_EQ(Traits::layout, builder::GroupConvLayout2D::GNHWC_GKYXC_GNHWK);
    EXPECT_EQ(Traits::data_type, builder::DataType::FP16);
    EXPECT_EQ(Traits::input_element_op, builder::ElementwiseOperation::PASS_THROUGH);
    EXPECT_EQ(Traits::weight_element_op, builder::ElementwiseOperation::PASS_THROUGH);
    EXPECT_EQ(Traits::output_element_op, builder::ElementwiseOperation::PASS_THROUGH);

    // Specializations.
    EXPECT_EQ(Traits::gemm_padding, builder::GemmPadding::DEFAULT);
    EXPECT_EQ(Traits::conv_specialization, builder::ConvFwdSpecialization::DEFAULT);

    // Algorithm: thread block size and block tile extents.
    EXPECT_EQ(Traits::thread_block_size, 256);
    EXPECT_EQ(Traits::tile_dims.m, 128);
    EXPECT_EQ(Traits::tile_dims.n, 128);
    EXPECT_EQ(Traits::tile_dims.k, 16);
}
} // anonymous namespace

View File

@@ -49,14 +49,14 @@ struct GridwiseWmmaGemm
size_t n_per_wmma = 0;
size_t m_wmma_per_wave = 0;
size_t n_wmma_per_wave = 0;
GridwiseGemmPipelineVersion pipeline_version;
PipelineVersion pipeline_version;
};
static_assert(ckb::GridwiseWmmaGemmDescriptor<GridwiseWmmaGemm>);
struct BlockGemm
{
BlockGemmPipelineVersion pipeline_version;
BlockGemmPipelineScheduler scheduler;
PipelineVersion pipeline_version;
PipelineScheduler scheduler;
};
static_assert(ckb::BlockGemmDescriptor<BlockGemm>);
@@ -156,7 +156,7 @@ struct ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
GemmSpecialization gemm_specialization;
size_t num_gemm_k_prefetch_stages;
size_t num_groups_to_merge;
LoopScheduler loop_scheduler;
PipelineScheduler loop_scheduler;
};
static_assert(
ckb::ConvAlgorithmDescriptor<ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle>);
@@ -191,7 +191,7 @@ struct ConvAlgorithm_DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
ConvFwdSpecialization fwd_specialization;
GemmSpecialization gemm_specialization;
size_t num_gemm_k_prefetch_stages;
LoopScheduler loop_scheduler;
PipelineScheduler loop_scheduler;
};
static_assert(
ckb::ConvAlgorithmDescriptor<ConvAlgorithm_DeviceGroupedConvFwdMultipleD_Wmma_CShuffle>);
@@ -214,4 +214,84 @@ static_assert(
static_assert(
ckb::SpecifiesLoopScheduler<ConvAlgorithm_DeviceGroupedConvFwdMultipleD_Wmma_CShuffle>);
// DL-specific descriptors
// Test-side aggregate holding the thread-level configuration values of a DL
// convolution algorithm. The static_assert below checks that it satisfies the
// ckb::DlThreadConfigDescriptor concept.
// Members are filled with designated initializers, so their names and
// declaration order must stay as they are.
// NOTE(review): the per-field meanings are only implied by the names — confirm
// against the concept definition before documenting them further.
struct DlThreadConfig
{
size_t k0_per_block;
size_t k1;
size_t m1_per_thread;
size_t n1_per_thread;
size_t k_per_thread;
};
static_assert(ckb::DlThreadConfigDescriptor<DlThreadConfig>);
// Test-side aggregate holding the two 2-element thread-cluster arrays of a DL
// convolution algorithm. The static_assert below checks that it satisfies the
// ckb::DlThreadClusterDescriptor concept.
struct DlThreadCluster
{
std::array<size_t, 2> m1_xs; // e.g., {8, 2}
std::array<size_t, 2> n1_xs; // e.g., {8, 2}
};
static_assert(ckb::DlThreadClusterDescriptor<DlThreadCluster>);
// Test-side aggregate describing a DL block transfer with seven 4-element
// arrays. The static_assert below checks that it satisfies the
// ckb::DlBlockTransferK0M0M1K1Descriptor concept.
// The member list is identical to DlBlockTransferK0N0N1K1; the two types are
// kept separate because each is checked against its own concept.
// NOTE(review): presumably the four array slots follow the K0/M0/M1/K1 order
// in the type name — confirm against the concept definition.
struct DlBlockTransferK0M0M1K1
{
std::array<size_t, 4> thread_slice_lengths;
std::array<size_t, 4> thread_cluster_lengths;
std::array<size_t, 4> thread_cluster_arrange_order;
std::array<size_t, 4> src_access_order;
std::array<size_t, 4> src_vector_tensor_lengths;
std::array<size_t, 4> src_vector_tensor_contiguous_dim_order;
std::array<size_t, 4> dst_vector_tensor_lengths;
};
static_assert(ckb::DlBlockTransferK0M0M1K1Descriptor<DlBlockTransferK0M0M1K1>);
// Test-side aggregate describing a DL block transfer with seven 4-element
// arrays. The static_assert below checks that it satisfies the
// ckb::DlBlockTransferK0N0N1K1Descriptor concept.
// The member list is identical to DlBlockTransferK0M0M1K1; the two types are
// kept separate because each is checked against its own concept.
// NOTE(review): presumably the four array slots follow the K0/N0/N1/K1 order
// in the type name — confirm against the concept definition.
struct DlBlockTransferK0N0N1K1
{
std::array<size_t, 4> thread_slice_lengths;
std::array<size_t, 4> thread_cluster_lengths;
std::array<size_t, 4> thread_cluster_arrange_order;
std::array<size_t, 4> src_access_order;
std::array<size_t, 4> src_vector_tensor_lengths;
std::array<size_t, 4> src_vector_tensor_contiguous_dim_order;
std::array<size_t, 4> dst_vector_tensor_lengths;
};
static_assert(ckb::DlBlockTransferK0N0N1K1Descriptor<DlBlockTransferK0N0N1K1>);
// Test-side aggregate describing the DL C-thread transfer: a 6-element access
// order plus two scalar settings. The static_assert below checks that it
// satisfies the ckb::DlCThreadTransferDescriptor concept.
// NOTE(review): src_dst_vector_dim presumably indexes into the 6 entries of
// src_dst_access_order — confirm against the concept definition.
struct DlCThreadTransfer
{
std::array<size_t, 6> src_dst_access_order;
size_t src_dst_vector_dim;
size_t dst_scalar_per_vector;
};
static_assert(ckb::DlCThreadTransferDescriptor<DlCThreadTransfer>);
// Test-side algorithm descriptor for DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK.
// It bundles the thread block, the forward-convolution and GEMM
// specializations, and the five DL-specific sub-descriptors.
// Members are filled with designated initializers, so their names and
// declaration order must stay as they are.
struct ConvAlgorithm_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK
{
ThreadBlock thread_block;
ConvFwdSpecialization fwd_specialization;
GemmSpecialization gemm_specialization;
DlThreadConfig dl_thread_config;
DlThreadCluster dl_thread_cluster;
DlBlockTransferK0M0M1K1 dl_block_transfer_a;
DlBlockTransferK0N0N1K1 dl_block_transfer_b;
DlCThreadTransfer dl_c_thread_transfer;
};
// Compile-time checks: the descriptor must model the generic algorithm concept
// and every "Specifies*" concept that corresponds to one of its members.
static_assert(
ckb::ConvAlgorithmDescriptor<ConvAlgorithm_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK>);
static_assert(
ckb::SpecifiesThreadBlock<ConvAlgorithm_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK>);
// NOTE(review): "Conc" in the concept name below looks like a typo for "Conv",
// but the spelling has to match the concept's declaration elsewhere — confirm
// there before renaming.
static_assert(ckb::SpecifiesFwdConcSpecialization<
ConvAlgorithm_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK>);
static_assert(
ckb::SpecifiesGemmSpecialization<ConvAlgorithm_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK>);
static_assert(
ckb::SpecifiesDlThreadConfig<ConvAlgorithm_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK>);
static_assert(
ckb::SpecifiesDlThreadCluster<ConvAlgorithm_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK>);
static_assert(
ckb::SpecifiesDlBlockTransferA<ConvAlgorithm_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK>);
static_assert(
ckb::SpecifiesDlBlockTransferB<ConvAlgorithm_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK>);
static_assert(
ckb::SpecifiesDlCThreadTransfer<ConvAlgorithm_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK>);
} // namespace ck_tile::builder::test

View File

@@ -0,0 +1,169 @@
// SPDX-License-Identifier: MIT
// Copyright (c) Advanced Micro Devices, Inc. All rights reserved.
#include <gtest/gtest.h>
#include <gmock/gmock.h>
#include <ck_tile/builder/conv_builder.hpp>
#include <ck_tile/builder/reflect/conv_description.hpp>
#include "testing_utils.hpp"
#include "impl/conv_signature_types.hpp"
#include "impl/conv_algorithm_types.hpp"
namespace {
namespace ckb = ck_tile::builder;
namespace ckr = ck_tile::reflect::conv;
namespace ckt = ck_tile::test;
// Defines the signature of the convolution operation to be tested.
// This includes dimensionality, direction, data layout, data type, the
// elementwise operation and the device operation to build.
// Every member has a default, so a default-constructed ConvSignature describes
// a 2D forward FP16 convolution in GNHWC_GKYXC_GNHWK layout with pass-through
// elementwise operation, built with
// DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3.
// The static_assert below checks the type against ckb::ConvSignatureDescriptor.
struct ConvSignature
{
int spatial_dim = 2;
ckb::ConvDirection direction = ckb::ConvDirection::FORWARD;
ckb::GroupConvLayout layout = ckb::GroupConvLayout2D::GNHWC_GKYXC_GNHWK;
ckb::DataType data_type = ckb::DataType::FP16;
ckb::ElementwiseOperation elementwise_operation = ckb::ElementwiseOperation::PASS_THROUGH;
ckb::GroupConvDeviceOp device_operation =
ckb::FwdGroupConvDeviceOperation::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3;
};
static_assert(ckb::ConvSignatureDescriptor<ConvSignature>);
// Algorithm descriptor shared by the description tests in this file.
// Every member carries a default, so a default-constructed DefaultAlgorithm is
// a complete configuration: 256-thread block, 256x256x32 tile, 16x16 XDL
// subtiles with 4x4 iterations per wave, and pipeline V4 with INTRAWAVE
// scheduling.
// The values set here are echoed in the expected text of the
// detailed-description test, so changing one requires updating that text too.
// The static_assert below checks the type against ckb::ConvAlgorithmDescriptor.
struct DefaultAlgorithm
{
ckb::test::ThreadBlock thread_block{.block_size = 256,
.tile_size = {.m = 256, .n = 256, .k = 32}};
ckb::test::GridwiseXdlGemm gridwise_gemm{.ak1 = 8,
.bk1 = 8,
.m_per_xdl = 16,
.n_per_xdl = 16,
.m_xdl_per_wave = 4,
.n_xdl_per_wave = 4};
// A/B/C block transfer settings; A and B are configured identically.
ckb::test::BlockTransferABC block_transfer{
.block_transfer_a = {.k0 = 4, .m_n = 256, .k1 = 8},
.block_transfer_b = {.k0 = 4, .m_n = 256, .k1 = 8},
.thread_cluster_dims_c = {.m_block = 1,
.m_wave_per_xdl = 32,
.n_block = 1,
.n_wave_per_xdl = 8},
.lds_transfer_a = {.src_vector_dim = 2,
.src_scalar_per_vector = 8,
.lds_dst_scalar_per_vector = 8,
.is_direct_load = true,
.lds_padding = false},
.lds_transfer_b = {.src_vector_dim = 2,
.src_scalar_per_vector = 8,
.lds_dst_scalar_per_vector = 8,
.is_direct_load = true,
.lds_padding = false},
.epilogue_c = {.m_per_wave_per_shuffle = 1,
.n_per_wave_per_shuffle = 1,
.scalar_per_vector = 8},
.block_transfer_access_order_a = {.order = {0, 1, 2}},
.block_transfer_access_order_b = {.order = {0, 1, 2}},
.src_access_order_a = {.order = {0, 1, 2}},
.src_access_order_b = {.order = {0, 1, 2}}};
ckb::ConvFwdSpecialization fwd_specialization = ckb::ConvFwdSpecialization::DEFAULT;
ckb::GemmSpecialization gemm_specialization = ckb::GemmSpecialization::Default;
ckb::test::BlockGemm block_gemm{.pipeline_version = ckb::PipelineVersion::V4,
.scheduler = ckb::PipelineScheduler::INTRAWAVE};
};
static_assert(ckb::ConvAlgorithmDescriptor<DefaultAlgorithm>);
// The one-line summary of the default signature/algorithm pair must read
// exactly "2D Forward convolution".
TEST(ConvDescriptionTest, DefaultInstanceHasBriefDescription)
{
    // Static storage so the objects can be used as non-type template arguments;
    // constexpr already makes them const.
    static constexpr ConvSignature kSignature{};
    static constexpr DefaultAlgorithm kAlgorithm{};

    using DescribedBuilder = ckb::ConvBuilder<kSignature, kAlgorithm>;

    EXPECT_THAT(ckr::Describe<DescribedBuilder>().brief(),
                ckt::StringEqWithDiff("2D Forward convolution"));
}
// Pins the complete tree rendering returned by detailed() for the default
// signature/algorithm pair. The numbers in the expected text come from the
// defaults of ConvSignature and DefaultAlgorithm.
// NOTE(review): a few expected entries look like formatter artifacts rather
// than intended output: the trailing "×" after the A/B tile dimensions, the
// "LDS data layout padding" value of 8 although DefaultAlgorithm sets
// lds_padding = false, and the two empty tree nodes at the end. They are kept
// verbatim because this test records the current output — confirm whether the
// describer should be fixed and this text updated along with it.
TEST(ConvDescriptionTest, DefaultInstanceHasDetailedDescription)
{
// Static storage so the objects can be used as non-type template arguments.
static constexpr const ConvSignature SIGNATURE;
static constexpr const DefaultAlgorithm ALGORITHM;
using Builder = ckb::ConvBuilder<SIGNATURE, ALGORITHM>;
// The expected text is one literal built from adjacent string pieces, one
// piece per output line.
EXPECT_THAT(ckr::Describe<Builder>().detailed(),
ckt::StringEqWithDiff( //
"2D Forward Convolution Kernel\n"
"├─ Signature\n"
"│ ├─ Tensor Type: FP16\n"
"│ ├─ Memory Layout: GNHWC_GKYXC_GNHWK\n"
"│ ├─ Input elementwise operation: PASS_THROUGH\n"
"│ ├─ Weights elementwise operation: PASS_THROUGH\n"
"│ └─ Output elementwise operation: PASS_THROUGH\n"
"├─ Algorithm\n"
"│ ├─ Thread block size: 256\n"
"│ ├─ Data tile size: 256×256×32\n"
"│ ├─ Gemm padding: DEFAULT\n"
"│ ├─ Convolution specialization: DEFAULT\n"
"│ ├─ Pipeline version: V4\n"
"│ ├─ Pipeline scheduler: INTRAWAVE\n"
"│ ├─ Warp Gemm parameters: \n"
"│ │ ├─ subtile size: 16×16\n"
"│ │ └─ Number of warp gemm iterations: 4×4\n"
"│ ├─ Memory access:\n"
"│ │ ├─ A Tile transfer: \n"
"│ │ │ ├─ Tile dimensions: 4×256×8×\n"
"│ │ │ ├─ The innermost K subdimension size: 8\n"
"│ │ │ ├─ Spatial thread distribution over the data tile: 0×1×2\n"
"│ │ │ ├─ The order of accessing data tile axes: 0×1×2\n"
"│ │ │ ├─ Vectorized memory access axis index (with contiguous memory): 2\n"
"│ │ │ ├─ Vector access (GMEM read) instruction size: 8\n"
"│ │ │ ├─ Vector access (LDS write) instruction size: 8\n"
"│ │ │ └─ LDS data layout padding (to prevent bank conflicts): 8\n"
"│ │ ├─ B Tile transfer: \n"
"│ │ │ ├─ Tile dimensions: 4×256×8×\n"
"│ │ │ ├─ The innermost K subdimension size: 8\n"
"│ │ │ ├─ Spatial thread distribution over the data tile: 0×1×2\n"
"│ │ │ ├─ The order of accessing data tile axes: 0×1×2\n"
"│ │ │ ├─ Vectorized memory access axis index (with contiguous memory): 2\n"
"│ │ │ ├─ Vector access (GMEM read) instruction size: 8\n"
"│ │ │ ├─ Vector access (LDS write) instruction size: 8\n"
"│ │ │ └─ LDS data layout padding (to prevent bank conflicts): 8\n"
"│ │ └─ C Tile transfer: \n"
"│ │ ├─ Data shuffle (number of gemm instructions per iteration): 1×1\n"
"│ │ ├─ Spatial thread distribution used to store data: 1×32×1×8\n"
"│ │ └─ Vector access (GMEM write) instruction size: 8\n"
"│ └─ \n"
"└─ "));
}
// NOTE: BackwardDataInstanceHasDetailedDescription test is disabled because ConvFactory
// does not have a specialization for backward data convolutions. The test fails with:
// "implicit instantiation of undefined template 'ck_tile::builder::ConvFactory<...>'"
//
// To enable this test, a ConvFactory specialization for backward data operations must be
// implemented first.
//
// TEST(ConvDescriptionTest, BackwardDataInstanceHasDetailedDescription)
// {
// struct BackwardDataSignature
// {
// int spatial_dim = 2;
// ckb::ConvDirection direction = ckb::ConvDirection::BACKWARD_DATA;
// ckb::GroupConvLayout layout = ckb::GroupConvLayout2D::GNHWC_GKYXC_GNHWK;
// ckb::DataType data_type = ckb::DataType::FP16;
// ckb::ElementwiseOperation elementwise_operation =
// ckb::ElementwiseOperation::PASS_THROUGH; ckb::GroupConvDeviceOp device_operation =
// ckb::BwdDataGroupConvDeviceOperation::DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1;
// };
// static_assert(ckb::ConvSignatureDescriptor<BackwardDataSignature>);
//
// static constexpr const BackwardDataSignature SIGNATURE;
// static constexpr const DefaultAlgorithm ALGORITHM;
// using Builder = ckb::ConvBuilder<SIGNATURE, ALGORITHM>;
//
// // Verify Brief works
// EXPECT_THAT(ckr::Describe<Builder>().brief(),
// ckt::StringEqWithDiff("2D Backward Data convolution"));
//
// // Verify detailed works - to be updated once ConvFactory is implemented
// EXPECT_THAT(ckr::Describe<Builder>().detailed(),
// ckt::StringEqWithDiff("PLACEHOLDER"));
// }
} // namespace

View File

@@ -16,7 +16,7 @@ using namespace test;
// Common test implementation
template <ConvSignature FwdConvSignature,
ThreadBlock FwdThreadBlock,
BlockGemmPipelineVersion FwdPipelineVersion,
PipelineVersion FwdPipelineVersion,
ConvFwdSpecialization FwdConvSpecialization>
constexpr void run_test_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3()
{
@@ -52,7 +52,7 @@ constexpr void run_test_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3()
.src_access_order_b = {1, 0, 2}};
constexpr BlockGemm BlockGemmDesc = {.pipeline_version = FwdPipelineVersion,
.scheduler = BlockGemmPipelineScheduler::INTRAWAVE};
.scheduler = PipelineScheduler::INTRAWAVE};
constexpr ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 FwdConvAlgorithm{
.thread_block = FwdThreadBlock,
@@ -73,13 +73,13 @@ constexpr void run_test_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3()
EXPECT_TRUE(kernel_string.starts_with("DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3"));
// Verify pipeline version is correct
if(FwdPipelineVersion == BlockGemmPipelineVersion::V1)
if(FwdPipelineVersion == PipelineVersion::V1)
EXPECT_TRUE(kernel_string.find("BlkGemmPipelineVersion: v1") != std::string::npos);
else if(FwdPipelineVersion == BlockGemmPipelineVersion::V3)
else if(FwdPipelineVersion == PipelineVersion::V3)
EXPECT_TRUE(kernel_string.find("BlkGemmPipelineVersion: v3") != std::string::npos);
else if(FwdPipelineVersion == BlockGemmPipelineVersion::V4)
else if(FwdPipelineVersion == PipelineVersion::V4)
EXPECT_TRUE(kernel_string.find("BlkGemmPipelineVersion: v4") != std::string::npos);
else if(FwdPipelineVersion == BlockGemmPipelineVersion::V5)
else if(FwdPipelineVersion == PipelineVersion::V5)
EXPECT_TRUE(kernel_string.find("BlkGemmPipelineVersion: v5") != std::string::npos);
// Verify specialization is correct
@@ -140,7 +140,7 @@ constexpr void run_test_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle()
.gemm_specialization = GemmSpecialization::MNKPadding,
.num_gemm_k_prefetch_stages = 1,
.num_groups_to_merge = 2,
.loop_scheduler = LoopScheduler::DEFAULT};
.loop_scheduler = PipelineScheduler::DEFAULT};
using Builder = ConvBuilder<FwdConvSignature, FwdConvAlgorithm>;
@@ -176,7 +176,7 @@ constexpr void run_test_DeviceGroupedConvFwdMultipleD_Wmma_CShuffle()
.n_per_wmma = 32,
.m_wmma_per_wave = 2,
.n_wmma_per_wave = 1,
.pipeline_version = GridwiseGemmPipelineVersion::V1};
.pipeline_version = PipelineVersion::V1};
constexpr BlockTransferABC FwdBlockTransfer{.block_transfer_a = {.k0 = 4, .m_n = 32, .k1 = 1},
.block_transfer_b = {.k0 = 4, .m_n = 32, .k1 = 1},
@@ -209,7 +209,7 @@ constexpr void run_test_DeviceGroupedConvFwdMultipleD_Wmma_CShuffle()
.fwd_specialization = FwdConvSpecialization,
.gemm_specialization = GemmSpecialization::MNKPadding,
.num_gemm_k_prefetch_stages = 1,
.loop_scheduler = LoopScheduler::DEFAULT};
.loop_scheduler = PipelineScheduler::DEFAULT};
using Builder = ConvBuilder<FwdConvSignature, FwdConvAlgorithm>;
@@ -235,4 +235,149 @@ constexpr void run_test_DeviceGroupedConvFwdMultipleD_Wmma_CShuffle()
EXPECT_NE(invoker_ptr, nullptr);
}
template <ConvSignature FwdConvSignature,
ThreadBlock FwdThreadBlock,
ConvFwdSpecialization FwdConvSpecialization>
constexpr void run_test_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK()
{
// DL thread configuration
constexpr DlThreadConfig DlThreadCfg{
.k0_per_block = 16, .k1 = 2, .m1_per_thread = 4, .n1_per_thread = 4, .k_per_thread = 1};
// DL thread cluster
constexpr DlThreadCluster DlCluster{.m1_xs = {8, 2}, .n1_xs = {8, 2}};
// DL A block transfer - K0_M0_M1_K1 format
constexpr DlBlockTransferK0M0M1K1 DlBlockTransferA{
.thread_slice_lengths = {8, 1, 1, 2},
.thread_cluster_lengths = {2, 1, 128, 1},
.thread_cluster_arrange_order = {1, 2, 0, 3},
.src_access_order = {1, 2, 0, 3},
.src_vector_tensor_lengths = {4, 1, 1, 2},
.src_vector_tensor_contiguous_dim_order = {1, 2, 0, 3},
.dst_vector_tensor_lengths = {1, 1, 1, 2}};
// DL B block transfer - K0_N0_N1_K1 format
constexpr DlBlockTransferK0N0N1K1 DlBlockTransferB{
.thread_slice_lengths = {8, 1, 1, 2},
.thread_cluster_lengths = {2, 1, 128, 1},
.thread_cluster_arrange_order = {1, 2, 0, 3},
.src_access_order = {1, 2, 0, 3},
.src_vector_tensor_lengths = {4, 1, 1, 2},
.src_vector_tensor_contiguous_dim_order = {1, 2, 0, 3},
.dst_vector_tensor_lengths = {1, 1, 1, 2}};
// DL C thread transfer
constexpr DlCThreadTransfer DlCTransfer{.src_dst_access_order = {0, 1, 2, 3, 4, 5},
.src_dst_vector_dim = 5,
.dst_scalar_per_vector = 4};
constexpr ConvAlgorithm_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK FwdConvAlgorithm{
.thread_block = FwdThreadBlock,
.fwd_specialization = FwdConvSpecialization,
.gemm_specialization = GemmSpecialization::MNKPadding,
.dl_thread_config = DlThreadCfg,
.dl_thread_cluster = DlCluster,
.dl_block_transfer_a = DlBlockTransferA,
.dl_block_transfer_b = DlBlockTransferB,
.dl_c_thread_transfer = DlCTransfer};
using Builder = ConvBuilder<FwdConvSignature, FwdConvAlgorithm>;
auto instance = typename Builder::Instance{};
const auto kernel_string = instance.GetTypeString();
std::cout << "Generated kernel: " << kernel_string << std::endl;
EXPECT_GT(kernel_string.size(), 0);
EXPECT_TRUE(kernel_string.starts_with("DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK"));
// Verify specialization is correct
if(FwdConvSpecialization == ConvFwdSpecialization::DEFAULT)
EXPECT_TRUE(kernel_string.find("Default") != std::string::npos);
else if(FwdConvSpecialization == ConvFwdSpecialization::FILTER_1X1_PAD0)
EXPECT_TRUE(kernel_string.find("Filter1x1Pad0") != std::string::npos);
else if(FwdConvSpecialization == ConvFwdSpecialization::FILTER_1X1_STRIDE1_PAD0)
EXPECT_TRUE(kernel_string.find("Filter1x1Stride1Pad0") != std::string::npos);
else if(FwdConvSpecialization == ConvFwdSpecialization::FILTER_3x3)
EXPECT_TRUE(kernel_string.find("Filter3x3") != std::string::npos);
const auto invoker_ptr = instance.MakeInvokerPointer();
EXPECT_NE(invoker_ptr, nullptr);
}
// Test helper for DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor
// Note: Large_Tensor has identical parameters to regular XDL CShuffle
//
// Builds the instance for the given signature, thread block and forward
// specialization, prints its type string, and checks the kernel name prefix,
// the specialization name embedded in the type string, and that an invoker can
// be created.
template <ConvSignature FwdConvSignature,
          ThreadBlock FwdThreadBlock,
          ConvFwdSpecialization FwdConvSpecialization>
constexpr void run_test_DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor()
{
    constexpr GridwiseXdlGemm FwdGemmParams{.ak1            = 8,
                                            .bk1            = 8,
                                            .m_per_xdl      = 32,
                                            .n_per_xdl      = 32,
                                            .m_xdl_per_wave = 2,
                                            .n_xdl_per_wave = 1};

    constexpr BlockTransferABC FwdBlockTransfer{.block_transfer_a = {.k0 = 4, .m_n = 16, .k1 = 1},
                                                .block_transfer_b = {.k0 = 4, .m_n = 16, .k1 = 1},
                                                .thread_cluster_dims_c = {.m_block        = 1,
                                                                          .m_wave_per_xdl = 16,
                                                                          .n_block        = 1,
                                                                          .n_wave_per_xdl = 4},
                                                .lds_transfer_a = {.src_vector_dim            = 2,
                                                                   .src_scalar_per_vector     = 8,
                                                                   .lds_dst_scalar_per_vector = 8,
                                                                   .is_direct_load            = false,
                                                                   .lds_padding               = true},
                                                .lds_transfer_b = {.src_vector_dim            = 2,
                                                                   .src_scalar_per_vector     = 8,
                                                                   .lds_dst_scalar_per_vector = 8,
                                                                   .is_direct_load            = false,
                                                                   .lds_padding               = true},
                                                .epilogue_c     = {.m_per_wave_per_shuffle = 1,
                                                                   .n_per_wave_per_shuffle = 1,
                                                                   .scalar_per_vector      = 8},
                                                .block_transfer_access_order_a = {1, 0, 2},
                                                .block_transfer_access_order_b = {1, 0, 2},
                                                .src_access_order_a            = {1, 0, 2},
                                                .src_access_order_b            = {1, 0, 2}};

    // Large_Tensor uses the same descriptor as regular XDL CShuffle.
    // The descriptor's loop_scheduler member is a PipelineScheduler, so it is
    // initialized with PipelineScheduler::DEFAULT, matching the regular XDL
    // CShuffle helper that fills the same descriptor type.
    constexpr ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle FwdConvAlgorithm{
        .thread_block               = FwdThreadBlock,
        .gridwise_gemm              = FwdGemmParams,
        .block_transfer             = FwdBlockTransfer,
        .fwd_specialization         = FwdConvSpecialization,
        .gemm_specialization        = GemmSpecialization::MNKPadding,
        .num_gemm_k_prefetch_stages = 1,
        .num_groups_to_merge        = 1,
        .loop_scheduler             = PipelineScheduler::DEFAULT};

    using Builder = ConvBuilder<FwdConvSignature, FwdConvAlgorithm>;

    auto instance = typename Builder::Instance{};

    const auto kernel_string = instance.GetTypeString();
    std::cout << "Generated kernel: " << kernel_string << std::endl;
    EXPECT_GT(kernel_string.size(), 0);

    EXPECT_TRUE(
        kernel_string.starts_with("DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor"));

    // Verify specialization is correct; any other specialization value is not
    // checked.
    if(FwdConvSpecialization == ConvFwdSpecialization::DEFAULT)
        EXPECT_TRUE(kernel_string.find("Default") != std::string::npos);
    else if(FwdConvSpecialization == ConvFwdSpecialization::FILTER_1X1_PAD0)
        EXPECT_TRUE(kernel_string.find("Filter1x1Pad0") != std::string::npos);
    else if(FwdConvSpecialization == ConvFwdSpecialization::FILTER_1X1_STRIDE1_PAD0)
        EXPECT_TRUE(kernel_string.find("Filter1x1Stride1Pad0") != std::string::npos);
    else if(FwdConvSpecialization == ConvFwdSpecialization::FILTER_3x3)
        EXPECT_TRUE(kernel_string.find("Filter3x3") != std::string::npos);

    const auto invoker_ptr = instance.MakeInvokerPointer();
    EXPECT_NE(invoker_ptr, nullptr);
}
} // namespace ck_tile::builder::test_utils

View File

@@ -727,7 +727,8 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_mx_moe_gufusion_v3<
});
});
HotLoopScheduler();
if constexpr(MPerBlock >= 64)
HotLoopScheduler();
__builtin_amdgcn_sched_barrier(0);
};

View File

@@ -4,6 +4,7 @@
#pragma once
#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_mx_moe_v3.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_mx_moe_v1.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_mx_moe_gufusion_v3.hpp"
namespace ck {
@@ -45,7 +46,28 @@ constexpr auto BlockGemmMXBPreshufflePipeline_Selector()
}
else
{
return nullptr;
return BlockwiseGemmXdlops_pipeline_bpreshuffle_mx_moe_v1<
BlkGemmPipeSche,
ThreadBlockSize,
ScaleBlockSize,
ADataType,
AScaleDataType,
BDataType,
BScaleDataType,
ATileDesc,
BTileDesc,
AMmaTileDesc,
BMmaTileDesc,
ABlockTransferSrcScalarPerVector,
BBlockTransferSrcScalarPerVector,
MPerBlock,
NPerBlock,
KPerBlock,
MPerXDL,
NPerXDL,
MRepeat,
NRepeat,
KPack>{};
}
}
else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)

View File

@@ -0,0 +1,891 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/tensor_operation/gpu/block/blockwise_gemm_mx_pipeline_xdlops_base.hpp"
namespace ck {
// Primary template of the v1 B-preshuffle MX MoE blockwise GEMM pipeline.
// It is intentionally empty: the implementation lives in the partial
// specialization selected by the scheduler argument (first parameter).
//
// Naive pipeline with lowest resource request per WGP
//   GlobalPrefetchStages:    2
//   LocalPreFillStages:      1
//   LocalPreFetchStages:     1
//   LocalSharedMemoryBuffer: 1
template <BlockGemmPipelineScheduler BlkGemmPipeSche,
          index_t ThreadBlockSize,
          index_t ScaleBlockSize,
          typename ADataType,
          typename AScaleDataType,
          typename BDataType,
          typename BScaleDataType,
          typename ATileDesc,
          typename BTileDesc,
          typename AMmaTileDesc,
          typename BMmaTileDesc,
          index_t ABlockTransferSrcScalarPerVector,
          index_t BBlockTransferSrcScalarPerVector,
          index_t MPerBlock,
          index_t NPerBlock,
          index_t KPerBlock,
          index_t MPerXDL,
          index_t NPerXDL,
          index_t MRepeat, // MXdlPerWave
          index_t NRepeat, // NXdlPerWave
          index_t KPack>
struct BlockwiseGemmXdlops_pipeline_bpreshuffle_mx_moe_v1
{
};
template <index_t ThreadBlockSize,
index_t ScaleBlockSize,
typename ADataType,
typename AScaleDataType,
typename BDataType,
typename BScaleDataType,
typename ATileDesc,
typename BTileDesc,
typename AMmaTileDesc,
typename BMmaTileDesc,
index_t ABlockTransferSrcScalarPerVector,
index_t BBlockTransferSrcScalarPerVector,
index_t MPerBlock,
index_t NPerBlock,
index_t KPerBlock,
index_t MPerXDL,
index_t NPerXDL,
index_t MRepeat, // MXdlPerWave
index_t NRepeat, // NXdlPerWave
index_t KPack>
struct BlockwiseGemmXdlops_pipeline_bpreshuffle_mx_moe_v1<BlockGemmPipelineScheduler::Intrawave,
ThreadBlockSize,
ScaleBlockSize,
ADataType,
AScaleDataType,
BDataType,
BScaleDataType,
ATileDesc,
BTileDesc,
AMmaTileDesc,
BMmaTileDesc,
ABlockTransferSrcScalarPerVector,
BBlockTransferSrcScalarPerVector,
MPerBlock,
NPerBlock,
KPerBlock,
MPerXDL,
NPerXDL,
MRepeat,
NRepeat,
KPack>
: BlockwiseGemmXdlops_mx_pipeline_base<ThreadBlockSize,
ADataType,
BDataType,
ATileDesc,
BTileDesc,
AMmaTileDesc,
BMmaTileDesc,
ABlockTransferSrcScalarPerVector,
BBlockTransferSrcScalarPerVector,
MPerBlock,
NPerBlock,
KPerBlock,
MPerXDL,
NPerXDL,
MRepeat,
NRepeat,
KPack>
{
using Base = BlockwiseGemmXdlops_mx_pipeline_base<ThreadBlockSize,
ADataType,
BDataType,
ATileDesc,
BTileDesc,
AMmaTileDesc,
BMmaTileDesc,
ABlockTransferSrcScalarPerVector,
BBlockTransferSrcScalarPerVector,
MPerBlock,
NPerBlock,
KPerBlock,
MPerXDL,
NPerXDL,
MRepeat,
NRepeat,
KPack>;
using Base::A_K1;
using Base::I0;
using Base::I1;
using Base::KRepeat;
using Base::MWaves;
using Base::NWaves;
using Base::WaveSize;
using Base::xdlops_gemm;
using typename Base::HotLoopInstList;
using Base::CalculateCThreadOriginDataIndex;
using Base::GetCBlockDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2;
// Names re-exported from Base so they can be used unqualified in this pipeline.
using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2;
using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4;
using Base::GetCThreadBuffer;
using Base::GetCThreadDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2;
using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2;
using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4;
using Base::GetWaveIdx;
using Base::MakeCGridDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2;
using Base::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2;
using Base::a_block_desc_m0_m1_m2_m3_k;
using Base::b_block_desc_n0_n1_n2_n3_k;
using Base::AMmaKStride;
using Base::APackedSize;
using Base::BMmaKStride;
using Base::BPackedSize;
using Base::KThreadChunk;
using Base::KXdlPack;
using Base::MXdlPack;
using Base::NXdlPack;
// Type aliases forwarded from Base.
using AccType = typename Base::AccType;
using Tuple5 = typename Base::Tuple5;
using ComputeTypeA = typename Base::ComputeTypeA;
using ComputeTypeB = typename Base::ComputeTypeB;
// BlockHasHotloop() reports a hot loop only when num_loop exceeds PrefetchStages.
static constexpr index_t PrefetchStages = 2;
static constexpr index_t PrefillStages = 1;
static constexpr index_t GlobalBufferNum = 2;
// 0 when MRepeat is even, 1 when it is odd.
static constexpr index_t HotloopLocalBufSwitch = MRepeat % 2 == 0 ? 0 : 1;
// Number of a_scale_thread_copy.Run() / b_scale_thread_copy.Run() calls issued per K block:
// Run() performs one call per (MRepeat / MXdlPack, KRepeat / KXdlPack) pair for A and one per
// (NRepeat / NXdlPack, KRepeat / KXdlPack) pair for B.
static constexpr auto num_buffer_load_a_scale = MRepeat / MXdlPack * KRepeat / KXdlPack;
static constexpr auto num_buffer_load_b_scale = NRepeat / NXdlPack * KRepeat / KXdlPack;
// Scale loads plus B buffer loads; this count is folded into async_vmcnt_encoding below.
// NOTE(review): presumably the number of vector-memory loads that may still be outstanding
// when Run() waits for the A tile, i.e. everything issued after the A copy - confirm.
static constexpr auto async_vmcnt =
num_buffer_load_a_scale + num_buffer_load_b_scale + HotLoopInstList::B_Buffer_Load_Inst_Num;
// Immediate operand passed to __builtin_amdgcn_s_waitcnt() in Run().
// NOTE(review): 3952 == 0xF70 and 16384 == 1 << 14; this looks like the gfx9 s_waitcnt layout
// (low 4 bits of the counter in bits [3:0], upper bits starting at bit 14, the other counters
// left at their maximum) - confirm against the ISA documentation. The formula does not guard
// against async_vmcnt >= 64.
static constexpr auto async_vmcnt_encoding = 3952 + async_vmcnt % 16 + async_vmcnt / 16 * 16384;
static constexpr auto ScalesPerKBlockSize =
KPerBlock / ScaleBlockSize; // How many mx-vectors per K block
//> How many mx-vectors in each row/col is processed in one call to xdlops_gemm.Run()
static constexpr auto ScalesPerXdlopsRun =
(APackedSize * KPack * xdlops_gemm.K0PerXdlops) / ScaleBlockSize;
//> How many scales a thread must read to accommodate one call to xdlops_gemm.Run()
static constexpr auto ScalesPerXdlopsRunPerThread =
ScalesPerXdlopsRun / xdlops_gemm.mfma_instr.num_input_blks;
// Element type of a single (unpacked) scale.
using mx_scale_t = e8m0_bexp_t;
// How many mx_scale_t values fit in one AScaleDataType / BScaleDataType element.
static constexpr auto scale_pack_size_a = sizeof(AScaleDataType) / sizeof(mx_scale_t);
static constexpr auto scale_pack_size_b = sizeof(BScaleDataType) / sizeof(mx_scale_t);
// The KXdlPack x MXdlPack (resp. NXdlPack) scales must fill a whole number of packed elements.
static_assert(KXdlPack * MXdlPack % scale_pack_size_a == 0,
"A scale pack data type too large!");
static_assert(KXdlPack * NXdlPack % scale_pack_size_b == 0,
"B scale pack data type too large!");
// Number of AScaleDataType / BScaleDataType elements handed to each xdlops_gemm.Run() call.
static constexpr auto a_scale_thread_vec_size = KXdlPack * MXdlPack / scale_pack_size_a;
static constexpr auto b_scale_thread_vec_size = KXdlPack * NXdlPack / scale_pack_size_b;
// Reports whether the pipeline needs its main (hot) loop for `num_loop` K-block iterations.
// True only when `num_loop` is strictly greater than PrefetchStages.
__host__ static constexpr bool BlockHasHotloop(index_t num_loop)
{
    return PrefetchStages < num_loop;
}
// Selects the tail variant of Run() from the parity of `num_loop`:
// TailNumber::Even for an even iteration count, TailNumber::Odd otherwise.
__host__ static constexpr TailNumber BlockLoopTailNum(index_t num_loop)
{
    if(num_loop % 2 == 0)
    {
        return TailNumber::Even;
    }
    return TailNumber::Odd;
}
// Emits the sched_group_barrier sequence that tells the compiler how to interleave the
// instructions of one hot-loop step. Mask values used below: 0x008 = MFMA, 0x020 = VMEM read,
// 0x100 = DS read, 0x200 = DS write. The sequence has three phases, emitted in this order:
//   1. one group per B / scale buffer load: MFMA(s) followed by one VMEM read,
//   2. one group per A buffer load: MFMA, DS write, MFMA, VMEM read,
//   3. one group per A LDS read (paired when MPerXDL == 32): MFMA followed by DS read(s).
__device__ static constexpr auto HotLoopScheduler()
{
    constexpr auto a_lds_read_count    = HotLoopInstList::A_LDS_Read_Inst_Num;
    constexpr auto a_global_load_count = HotLoopInstList::A_Buffer_Load_Inst_Num;
    // B loads are counted together with the A-scale and B-scale loads.
    constexpr auto b_and_scale_load_count = HotLoopInstList::B_Buffer_Load_Inst_Num * MWaves +
                                            num_buffer_load_a_scale + num_buffer_load_b_scale;

    constexpr auto mfma_interleave = MPerXDL == 32 ? 1 : 2;
    // Tiles of at least 128x128 place twice as many MFMAs in front of each B-phase load.
    constexpr auto mfma_per_b_load =
        (MPerBlock >= 128 && NPerBlock >= 128) ? 2 * mfma_interleave : mfma_interleave;

    // With MPerXDL == 32 the LDS reads are scheduled two at a time, in half as many groups.
    constexpr auto ds_reads_per_group = MPerXDL == 32 ? 2 : 1;
    constexpr auto a_lds_read_groups  = MPerXDL == 32 ? a_lds_read_count / 2 : a_lds_read_count;

    // B global
    static_for<0, b_and_scale_load_count, 1>{}([&](auto /*i*/) {
        __builtin_amdgcn_sched_group_barrier(0x008, mfma_per_b_load, 0); // MFMA
        __builtin_amdgcn_sched_group_barrier(0x020, 1, 0);               // VMEM read
    });

    // A global
    static_for<0, a_global_load_count, 1>{}([&](auto /*i*/) {
        __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
        __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
        __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
        __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
    });

    // A local
    static_for<0, a_lds_read_groups, 1>{}([&](auto /*i*/) {
        __builtin_amdgcn_sched_group_barrier(0x008, 1, 0);                  // MFMA
        __builtin_amdgcn_sched_group_barrier(0x100, ds_reads_per_group, 0); // DS read
    });
}
// Runs the blockwise scaled GEMM over `num_loop` K blocks, accumulating into `c_thread_buf`.
//
// Data flow, as implemented below:
//   * A tile:  a_blockwise_copy writes into a_block_buf; a_thread_copy_ then reads a_block_buf
//              into a per-thread register buffer (a_thread_buf).
//              NOTE(review): a_block_buf is presumably LDS (block_sync_lds() guards it) and the
//              copy is asynchronous (see the s_waitcnt before the first read) - confirm.
//   * B tile:  b_blockwise_copy writes straight into one of two per-thread register buffers
//              (b_thread_bufs); `b_block_bufs` is accepted for interface compatibility but unused.
//   * Scales:  a_scale_thread_copy / b_scale_thread_copy load into one of two per-thread register
//              buffers each (a_scale_thread_bufs / b_scale_thread_bufs).
// B and scale buffers are ping-ponged: while buffer X feeds the MFMAs, buffer 1 - X is filled.
//
// Template parameters:
//   HasMainLoop  - when true the do/while hot loop is compiled in; each pass performs two steps
//                  (buffers 0->1 then 1->0) and the loop repeats while i < num_loop - 2.
//   TailNum      - TailNumber::Even finishes with one more load plus two compute steps;
//                  TailNumber::Odd finishes with a single compute step on buffer 0.
//
// Parameters:
//   a_*/b_* descriptors, copies, buffers and steps - forwarded to the respective blockwise copy.
//   c_thread_buf                                   - accumulator; cleared before the first MFMA.
//   a_scale_* / b_scale_*                          - grid descriptor, thread copy and grid buffer
//                                                    of the A / B scales.
//   num_loop                                       - number of K-block iterations.
template <bool HasMainLoop,
          TailNumber TailNum,
          typename AGridDesc,
          typename ABlockDesc,
          typename ABlockTransfer,
          typename AGridBuffer,
          typename ABlockBuffer,
          typename ABlockTransferStep,
          typename BGridDesc,
          typename BBlockDesc,
          typename BBlockTransfer,
          typename BGridBuffer,
          typename BBlockBuffer,
          typename BBlockTransferStep,
          typename CThreadBuffer,
          typename AScaleGridBuffer,
          typename AScaleGridDesc,
          typename AScaleThreadTransfer,
          typename BScaleGridBuffer,
          typename BScaleGridDesc,
          typename BScaleThreadTransfer>
__device__ void Run(
    // ABlockCopy
    const AGridDesc& a_grid_desc,
    const ABlockDesc& a_block_desc,
    ABlockTransfer& a_blockwise_copy,
    const AGridBuffer& a_grid_buf,
    ABlockBuffer& a_block_buf,
    const ABlockTransferStep& a_block_copy_step,
    // BBlockCopy
    const BGridDesc& b_grid_desc,
    const BBlockDesc& b_block_desc,
    BBlockTransfer& b_blockwise_copy,
    const BGridBuffer& b_grid_buf,
    BBlockBuffer& b_block_bufs,
    const BBlockTransferStep& b_block_copy_step,
    // CThread
    CThreadBuffer& c_thread_buf,
    // A and B scales
    const AScaleGridDesc& a_scale_grid_desc,
    AScaleThreadTransfer& a_scale_thread_copy,
    const AScaleGridBuffer& a_scale_grid_buf,
    const BScaleGridDesc& b_scale_grid_desc,
    BScaleThreadTransfer& b_scale_thread_copy,
    const BScaleGridBuffer& b_scale_grid_buf,
    index_t num_loop) const
{
    ignore = b_block_bufs;
    __builtin_amdgcn_sched_barrier(0);

    auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeTypeA>(
        a_thread_desc_.GetElementSpaceSize());
    auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeTypeB>(
        b_thread_desc_.GetElementSpaceSize());

    StaticallyIndexedArray<decltype(b_thread_buf), Number<2>{}> b_thread_bufs;
    constexpr auto b_block_origin_idx = make_tuple(I0, I0, I0, I0, I0);

    auto a_scale_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, AScaleDataType>(
        a_scale_thread_desc.GetElementSpaceSize());
    auto b_scale_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, BScaleDataType>(
        b_scale_thread_desc.GetElementSpaceSize());

    StaticallyIndexedArray<decltype(a_scale_thread_buf), Number<2>{}> a_scale_thread_bufs;
    StaticallyIndexedArray<decltype(b_scale_thread_buf), Number<2>{}> b_scale_thread_bufs;

    // ---------------------------------------------------------------------------------------
    // Local helpers. They only name sequences that the pipeline repeats verbatim; the order in
    // which they are invoked below is the order of the original straight-line code.
    // ---------------------------------------------------------------------------------------

    // Loads the A scales of the current K block into a_scale_thread_bufs(buf). The source
    // window walks along K inside a row group and moves on by MWaves rows between groups.
    auto load_a_scales = [&](auto buf) {
        static_for<0, MRepeat / MXdlPack, 1>{}([&](auto m0) {
            static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) {
                a_scale_thread_copy.Run(a_scale_grid_desc,
                                        a_scale_grid_buf,
                                        a_scale_thread_desc,
                                        make_tuple(m0, k0, I0),
                                        a_scale_thread_bufs(buf));

                a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc,
                                                       make_multi_index(0, I1, 0));
            });
            a_scale_thread_copy.MoveSrcSliceWindow(
                a_scale_grid_desc, make_multi_index(MWaves, -KRepeat / KXdlPack, 0));
        });
    };

    // restore row id and advance to the next set of scales
    auto advance_a_scale_window = [&]() {
        a_scale_thread_copy.MoveSrcSliceWindow(
            a_scale_grid_desc,
            make_multi_index(-MWaves * MRepeat / MXdlPack, KRepeat / KXdlPack, 0));
    };

    // Loads the B scales of the current K block into b_scale_thread_bufs(buf); mirrors
    // load_a_scales with N in place of M.
    auto load_b_scales = [&](auto buf) {
        static_for<0, NRepeat / NXdlPack, 1>{}([&](auto n0) {
            static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) {
                b_scale_thread_copy.Run(b_scale_grid_desc,
                                        b_scale_grid_buf,
                                        b_scale_thread_desc,
                                        make_tuple(n0, k0, I0),
                                        b_scale_thread_bufs(buf));

                b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc,
                                                       make_multi_index(0, I1, 0));
            });
            b_scale_thread_copy.MoveSrcSliceWindow(
                b_scale_grid_desc, make_multi_index(NWaves, -KRepeat / KXdlPack, 0));
        });
    };

    // restore col id and advance to the next set of scales
    // NWaves * NPerXDL * NRepeat == NPerBlock
    auto advance_b_scale_window = [&]() {
        b_scale_thread_copy.MoveSrcSliceWindow(
            b_scale_grid_desc,
            make_multi_index(-NWaves * NRepeat / NXdlPack, KRepeat / KXdlPack, 0));
    };

    // Copies the A tile held in a_block_buf into a_thread_buf, KThreadChunk elements at a time.
    auto local_prefetch_a = [&]() {
        static_for<0, MRepeat, 1>{}([&](auto m0) {
            static_for<0, KRepeat, 1>{}([&](auto k) {
                constexpr auto k_step = k * xdlops_gemm.KPerXdlops / APackedSize *
                                        (APackedSize * KPack / xdlops_gemm.K1PerXdlops);
                static_for<0, xdlops_gemm.K1PerXdlops / (APackedSize * KThreadChunk), 1>{}(
                    [&](auto chunk) {
                        constexpr auto a_k_step_chunk =
                            k_step +
                            chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks;
                        a_thread_copy_.Run(a_block_desc_m0_m1_m2_m3_k,
                                           make_tuple(Number<m0 / MXdlPack>{},
                                                      I0,
                                                      Number<m0 % MXdlPack>{},
                                                      I0,
                                                      Number<a_k_step_chunk>{}),
                                           a_block_buf,
                                           a_thread_desc_,
                                           make_tuple(Number<m0 / MXdlPack>{},
                                                      I0,
                                                      Number<m0 % MXdlPack>{},
                                                      k,
                                                      Number<chunk * KThreadChunk>{}),
                                           a_thread_buf);
                    });
            });
        });
    };

    // Runs every MFMA of one K block. A always comes from a_thread_buf; B and both scale
    // vectors come from ping-pong buffer `comp_buf`.
    auto mfma_accumulate = [&](auto comp_buf) {
        static_for<0, MRepeat, 1>{}([&](auto m0) {
            constexpr auto im_major = m0 / MXdlPack;
            constexpr auto im_minor = m0 % MXdlPack;
            static_for<0, KRepeat, 1>{}([&](auto k0) {
                constexpr auto ik_major = k0 / KXdlPack;
                constexpr auto ik_minor = k0 % KXdlPack;
                static_for<0, NRepeat, 1>{}([&](auto n0) {
                    constexpr auto in_major = n0 / NXdlPack;
                    constexpr auto in_minor = n0 % NXdlPack;

                    constexpr index_t a_scale_offset =
                        a_scale_thread_desc.CalculateOffset(make_tuple(im_major, ik_major, I0));
                    constexpr index_t b_scale_offset =
                        b_scale_thread_desc.CalculateOffset(make_tuple(in_major, ik_major, I0));

                    static_assert(0 < ScalesPerXdlopsRunPerThread,
                                  "Must have at least one scale per Xdlops "
                                  "per Thread.");

                    vector_type<AScaleDataType, a_scale_thread_vec_size> a_scale_thread_vec;
                    vector_type<BScaleDataType, b_scale_thread_vec_size> b_scale_thread_vec;

                    // Pack scale_thread_buf into scale_thread_vec
                    static_for<0, a_scale_thread_vec_size, 1>{}([&](auto s) {
                        a_scale_thread_vec.template AsType<AScaleDataType>()(s) =
                            a_scale_thread_bufs(comp_buf)[Number<a_scale_offset + s>{}];
                    });

                    static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) {
                        b_scale_thread_vec.template AsType<BScaleDataType>()(s) =
                            b_scale_thread_bufs(comp_buf)[Number<b_scale_offset + s>{}];
                    });

                    vector_type<ComputeTypeA, KPack> a_thread_vec;
                    vector_type<ComputeTypeB, KPack> b_thread_vec;

                    static_for<0, KPack, 1>{}([&](auto ik) {
                        a_thread_vec.template AsType<ComputeTypeA>()(ik) =
                            a_thread_buf[Number<a_thread_desc_.CalculateOffset(
                                make_tuple(im_major, I0, im_minor, k0, ik))>{}];
                        b_thread_vec.template AsType<ComputeTypeB>()(ik) =
                            b_thread_bufs[comp_buf][Number<b_thread_desc_.CalculateOffset(
                                make_tuple(in_major, I0, in_minor, k0, ik))>{}];
                    });

                    using mfma_input_type_a =
                        typename vector_type<ComputeTypeA,
                                             xdlops_gemm.K1PerXdlops / APackedSize>::type;

                    using mfma_input_type_b =
                        typename vector_type<ComputeTypeB,
                                             xdlops_gemm.K1PerXdlops / BPackedSize>::type;

                    using mfma_scale_input_type_a =
                        typename vector_type<AScaleDataType, a_scale_thread_vec_size>::type;
                    using mfma_scale_input_type_b =
                        typename vector_type<BScaleDataType, b_scale_thread_vec_size>::type;

                    constexpr index_t c_offset = c_thread_desc_.CalculateOffset(
                        make_tuple(im_major, in_major, im_minor, in_minor, 0));

                    // MFMA accumulation
                    xdlops_gemm.template Run<ik_minor * MXdlPack + im_minor,
                                             ik_minor * NXdlPack + in_minor>(
                        a_thread_vec.template AsType<mfma_input_type_a>(),
                        a_scale_thread_vec.template AsType<mfma_scale_input_type_a>(),
                        b_thread_vec.template AsType<mfma_input_type_b>(),
                        b_scale_thread_vec.template AsType<mfma_scale_input_type_b>(),
                        c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
                });
            });
        });
    };

    // ---------------------------------------------------------------------------------------
    // Prologue
    // ---------------------------------------------------------------------------------------

    // Global prefetch 1
    a_blockwise_copy.Run(a_grid_desc, a_grid_buf, a_block_desc, a_block_buf);
    b_blockwise_copy.Run(
        b_grid_desc, b_grid_buf, b_block_desc, b_block_origin_idx, b_thread_bufs(I0));

    a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
    b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);

    __builtin_amdgcn_sched_barrier(0);

    // Prefetch a_scales / b_scales into buffer 0
    load_a_scales(I0);
    advance_a_scale_window();
    load_b_scales(I0);
    advance_b_scale_window();

    // Local prefetch 1, sync the async load
    __builtin_amdgcn_s_waitcnt(async_vmcnt_encoding);
    block_sync_lds();
    local_prefetch_a();

    // Initialize C
    c_thread_buf.Clear();
    __builtin_amdgcn_sched_barrier(0);

    // ---------------------------------------------------------------------------------------
    // main body
    // ---------------------------------------------------------------------------------------
    if constexpr(HasMainLoop)
    {
        // One hot-loop step: fetch the next K block into `scale_mem_buf` while the MFMAs
        // consume `scale_comp_buf`, then stage the next A tile into registers.
        auto LoopFunc = [&](auto scale_comp_buf, auto scale_mem_buf) {
            b_blockwise_copy.Run(b_grid_desc,
                                 b_grid_buf,
                                 b_block_desc,
                                 b_block_origin_idx,
                                 b_thread_bufs(scale_mem_buf));

            // All waves must be done reading a_block_buf before it is overwritten.
            block_sync_lds();
            a_blockwise_copy.Run(a_grid_desc, a_grid_buf, a_block_desc, a_block_buf);

            load_a_scales(scale_mem_buf);
            advance_a_scale_window();
            load_b_scales(scale_mem_buf);
            advance_b_scale_window();

            a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
            b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);

            mfma_accumulate(scale_comp_buf);

            block_sync_lds();
            local_prefetch_a();

            HotLoopScheduler();
            __builtin_amdgcn_sched_barrier(0);
        };

        // loop over k with the step KPerBlock
        index_t i = 0;
        do
        {
            LoopFunc(I0, I1);
            LoopFunc(I1, I0);

            i += 2;
        } while(i < (num_loop - 2));
    }

    // ---------------------------------------------------------------------------------------
    // tail
    // ---------------------------------------------------------------------------------------
    if constexpr(TailNum == TailNumber::Even)
    {
        // Last load: fill buffer 1. No window advance follows because nothing is read after it.
        b_blockwise_copy.Run(
            b_grid_desc, b_grid_buf, b_block_desc, b_block_origin_idx, b_thread_bufs(I1));

        block_sync_lds();
        a_blockwise_copy.Run(a_grid_desc, a_grid_buf, a_block_desc, a_block_buf);

        load_a_scales(I1);
        load_b_scales(I1);

        mfma_accumulate(I0);

        __builtin_amdgcn_s_waitcnt(async_vmcnt_encoding);
        block_sync_lds();
        local_prefetch_a();

        __builtin_amdgcn_sched_barrier(0);

        mfma_accumulate(I1);
    }
    else if constexpr(TailNum == TailNumber::Odd)
    {
        mfma_accumulate(I0);
    }
}
// Packed per-thread layout of the A scales held in registers:
// (MRepeat / MXdlPack) x (KRepeat / KXdlPack) x
// (ScalesPerXdlopsRunPerThread * a_scale_thread_vec_size) elements.
// Run() sizes its A-scale register buffers from this descriptor and uses it to compute the
// offset of the scales for each (m, k) pack.
// TODO: make this field protected when a_scale_thread_copy_ is moved
// here
static constexpr auto a_scale_thread_desc = make_naive_tensor_descriptor_packed(
make_tuple(Number<MRepeat / MXdlPack>{},
Number<KRepeat / KXdlPack>{},
Number<ScalesPerXdlopsRunPerThread * a_scale_thread_vec_size>{}));
// Packed per-thread layout of the B scales held in registers; same shape as the A descriptor
// with NRepeat / NXdlPack and b_scale_thread_vec_size in place of the M-side quantities.
// TODO: make this field protected when b_scale_thread_copy_ is moved
// here
static constexpr auto b_scale_thread_desc = make_naive_tensor_descriptor_packed(
make_tuple(Number<NRepeat / NXdlPack>{},
Number<KRepeat / KXdlPack>{},
Number<ScalesPerXdlopsRunPerThread * b_scale_thread_vec_size>{}));
protected:
// Thread-level copy objects and descriptors inherited from Base. Run() uses a_thread_copy_,
// a_thread_desc_, b_thread_desc_ and c_thread_desc_ to move data into registers and to compute
// register offsets.
using Base::a_thread_copy_;
using Base::a_thread_desc_;
using Base::b_thread_copy_;
using Base::b_thread_desc_;
using Base::c_thread_desc_;
};
} // namespace ck

View File

@@ -226,85 +226,197 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_mx_moe_v3<BlockGemmPipelineSched
// constexpr auto num_dsread_a_mfma =
// (num_ds_read_inst_a + ds_read_a_mfma_rate - 1) / ds_read_a_mfma_rate;
constexpr auto num_total_stages = MRepeat;
constexpr auto num_total_stages = std::max(2, MRepeat);
if constexpr(num_total_stages > 2)
{
// Group num_mfma_perstage num_ds_read_a_perstage
// since we want to reuse a local register buffer
constexpr auto num_mfma_perstage = num_mfma_inst / num_total_stages;
constexpr auto num_ds_read_a_perstage = num_ds_read_inst_a / num_total_stages;
// Group num_mfma_perstage num_ds_read_a_perstage
// since we want to reuse a local register buffer
constexpr auto num_mfma_perstage = num_mfma_inst / num_total_stages;
constexpr auto num_ds_read_a_perstage = num_ds_read_inst_a / num_total_stages;
constexpr auto num_ds_read_a_mfma_perstage =
math::integer_divide_ceil(num_ds_read_a_perstage, ds_read_a_mfma_rate);
constexpr auto num_ds_read_a_mfma_perstage =
math::integer_divide_ceil(num_ds_read_a_perstage, ds_read_a_mfma_rate);
constexpr auto num_ds_read_a_prefetch_stages = 2;
constexpr auto num_ds_read_a_prefetch_stages = 2;
constexpr auto buffer_load_perstage_more =
math::integer_divide_ceil((num_buffer_load_stage1), (num_total_stages - 2));
constexpr auto buffer_load_perstage_less =
math::integer_divide_floor((num_buffer_load_stage1), (num_total_stages - 2));
constexpr auto buffer_load_perstage_stage2 =
math::integer_divide_floor((num_buffer_load_stage2), 2);
constexpr auto buffer_load_perstage_more =
math::integer_divide_ceil((num_buffer_load_stage1), (num_total_stages - 2));
constexpr auto buffer_load_perstage_less =
math::integer_divide_floor((num_buffer_load_stage1), (num_total_stages - 2));
constexpr auto buffer_load_perstage_stage2 =
math::integer_divide_floor((num_buffer_load_stage2), 2);
constexpr auto buffer_load_stages_more =
num_buffer_load_stage1 -
math::integer_divide_floor(num_buffer_load_stage1, (num_total_stages - 2)) *
((num_total_stages - 2));
constexpr auto buffer_load_stages_more =
num_buffer_load_stage1 -
math::integer_divide_floor(num_buffer_load_stage1, (num_total_stages - 2)) *
((num_total_stages - 2));
constexpr auto buffer_load_issue_point_interval_more =
num_mfma_perstage / buffer_load_perstage_more;
constexpr auto buffer_load_issue_point_interval_less =
num_mfma_perstage / buffer_load_perstage_less;
constexpr auto buffer_load_issue_point_interval_stage2 =
num_mfma_perstage / buffer_load_perstage_stage2;
constexpr auto buffer_load_issue_point_interval_more =
num_mfma_perstage / buffer_load_perstage_more;
constexpr auto buffer_load_issue_point_interval_less =
num_mfma_perstage / buffer_load_perstage_less;
constexpr auto buffer_load_issue_point_interval_stage2 =
num_mfma_perstage / buffer_load_perstage_stage2;
// Stage 1
// global read more
static_for<0, buffer_load_stages_more, 1>{}([&](auto /*i*/) {
static_for<0, num_mfma_perstage, 1>{}([&](auto imfma) {
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
// Stage 1
// global read more
static_for<0, buffer_load_stages_more, 1>{}([&](auto /*i*/) {
static_for<0, num_mfma_perstage, 1>{}([&](auto imfma) {
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
if constexpr(imfma % buffer_load_issue_point_interval_more == 0)
if constexpr(imfma % buffer_load_issue_point_interval_more == 0)
{
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
}
if constexpr(imfma >= (num_mfma_perstage - num_ds_read_a_mfma_perstage))
{
__builtin_amdgcn_sched_group_barrier(
0x100, ds_read_a_mfma_rate, 0); // DS read
}
});
});
// global read less
static_for<0, (num_total_stages - 2 - buffer_load_stages_more), 1>{}([&](auto /*i*/) {
static_for<0, num_mfma_perstage, 1>{}([&](auto imfma) {
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
if constexpr(imfma % buffer_load_issue_point_interval_less == 0)
{
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
}
if constexpr(imfma >= (num_mfma_perstage - num_ds_read_a_mfma_perstage))
{
__builtin_amdgcn_sched_group_barrier(
0x100, ds_read_a_mfma_rate, 0); // DS read
}
});
});
// Stage 2, Sync
// lds synchronization, prefetch next loop local A
static_for<0, num_ds_read_a_prefetch_stages, 1>{}([&](auto /*i*/) {
static_for<0, num_mfma_perstage, 1>{}([&](auto imfma) {
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
if constexpr(imfma % buffer_load_issue_point_interval_stage2 == 0)
{
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
}
if constexpr(imfma >= (num_mfma_perstage - num_ds_read_a_mfma_perstage))
{
__builtin_amdgcn_sched_group_barrier(
0x100, ds_read_a_mfma_rate, 0); // DS read
}
});
});
}
else
{
constexpr auto num_buffer_load_total = num_buffer_load_inst_a + num_buffer_load_inst_b +
num_buffer_load_a_scale +
num_buffer_load_b_scale;
constexpr auto num_dsread_a_mfma = math::integer_divide_ceil(
num_ds_read_inst_a, ds_read_a_mfma_rate); // how many mfma per dsread_a
// stage 1
constexpr auto num_mfma_stage1 = num_mfma_inst - num_dsread_a_mfma;
constexpr auto mfma_perstage_more =
math::integer_divide_ceil(num_mfma_stage1, num_buffer_load_total);
constexpr auto mfma_perstage_less =
math::integer_divide_floor(num_mfma_stage1, num_buffer_load_total);
constexpr auto mfma_stages_more =
num_mfma_stage1 - mfma_perstage_less * num_buffer_load_total;
static_for<0, num_buffer_load_inst_a, 1>{}([&](auto i) {
if constexpr(i < mfma_stages_more)
{
static_for<0, mfma_perstage_more, 1>{}([&](auto) {
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
});
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
}
else
{
static_for<0, mfma_perstage_less, 1>{}([&](auto) {
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
});
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
}
});
if constexpr(imfma >= (num_mfma_perstage - num_ds_read_a_mfma_perstage))
static_for<0, num_buffer_load_inst_b, 1>{}([&](auto i) {
if constexpr((i + num_buffer_load_inst_a) < mfma_stages_more)
{
static_for<0, mfma_perstage_more, 1>{}([&](auto /*imfma*/) {
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
});
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
}
else
{
static_for<0, mfma_perstage_less, 1>{}([&](auto /*imfma*/) {
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
});
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
}
});
static_for<0, num_buffer_load_a_scale, 1>{}([&](auto i) {
if constexpr((i + num_buffer_load_inst_a + num_buffer_load_inst_b) <
mfma_stages_more)
{
static_for<0, mfma_perstage_more, 1>{}([&](auto /*imfma*/) {
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
});
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
}
else
{
static_for<0, mfma_perstage_less, 1>{}([&](auto /*imfma*/) {
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
});
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
}
});
static_for<0, num_buffer_load_b_scale, 1>{}([&](auto i) {
if constexpr((i + num_buffer_load_inst_a + num_buffer_load_inst_b +
num_buffer_load_a_scale) < mfma_stages_more)
{
static_for<0, mfma_perstage_more, 1>{}([&](auto /*imfma*/) {
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
});
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
}
else
{
static_for<0, mfma_perstage_less, 1>{}([&](auto /*imfma*/) {
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
});
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
}
});
// stage 2
static_for<0, num_dsread_a_mfma, 1>{}([&](auto i) {
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
if constexpr((num_ds_read_inst_a - (i + 1) * ds_read_a_mfma_rate) >=
ds_read_a_mfma_rate)
{
__builtin_amdgcn_sched_group_barrier(0x100, ds_read_a_mfma_rate, 0); // DS read
}
});
});
// global read less
static_for<0, (num_total_stages - 2 - buffer_load_stages_more), 1>{}([&](auto /*i*/) {
static_for<0, num_mfma_perstage, 1>{}([&](auto imfma) {
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
if constexpr(imfma % buffer_load_issue_point_interval_less == 0)
else
{
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
}
if constexpr(imfma >= (num_mfma_perstage - num_ds_read_a_mfma_perstage))
{
__builtin_amdgcn_sched_group_barrier(0x100, ds_read_a_mfma_rate, 0); // DS read
__builtin_amdgcn_sched_group_barrier(
0x100,
num_ds_read_inst_a - (num_dsread_a_mfma - 1) * ds_read_a_mfma_rate,
0); // DS read
}
});
});
// Stage 2, Sync
// lds synchronization, prefetch next loop local A
static_for<0, num_ds_read_a_prefetch_stages, 1>{}([&](auto /*i*/) {
static_for<0, num_mfma_perstage, 1>{}([&](auto imfma) {
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
if constexpr(imfma % buffer_load_issue_point_interval_stage2 == 0)
{
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
}
if constexpr(imfma >= (num_mfma_perstage - num_ds_read_a_mfma_perstage))
{
__builtin_amdgcn_sched_group_barrier(0x100, ds_read_a_mfma_rate, 0); // DS read
}
});
});
}
}
template <bool HasMainLoop,

View File

@@ -3,6 +3,8 @@
#pragma once
#include <string>
namespace ck {
namespace tensor_operation {
namespace device {

View File

@@ -122,7 +122,7 @@ struct DeviceMoeGemmMXBPreShuffle : public DeviceMoEGemmMXBPreShuffle<ALayout,
MPerXDL,
NPerXDL,
MXdlPerWave,
NXdlPerWave_,
math::max(2, NXdlPerWave_),
ABlockTransferThreadClusterLengths_AK0_M_AK1,
ABlockTransferThreadClusterArrangeOrder,
ABlockTransferSrcAccessOrder,

View File

@@ -429,8 +429,8 @@ struct GridwiseGemmMX_xdl_cshuffle_v3_bpreshuffle
constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
constexpr index_t WaveSize = BlockSize / (MWave * NWave);
constexpr index_t NkSwizzleNumber = Number<WaveSize * KPack>{};
return make_naive_tensor_descriptor_packed(
make_tuple(N0 / NWave / NXdlPack, NWave, NXdlPack, K0, NkSwizzleNumber));
return make_naive_tensor_descriptor_packed(make_tuple(
math::integer_divide_ceil(N0, NWave * NXdlPack), NWave, NXdlPack, K0, NkSwizzleNumber));
}
__host__ __device__ static auto MakeBGridDescriptor_BK0_N_BK1(

View File

@@ -48,28 +48,25 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
kernel_moe_mxgemm(typename GridwiseGemm::Argument karg)
{
#if defined(__gfx9__)
if constexpr(GridwiseGemm::template IsValidCompilationParameter<CGlobalMemoryDataOperation>())
{
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z);
auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z);
GridwiseGemm::template Run<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
karg.p_sorted_token_ids,
karg.p_sorted_expert_ids,
karg.p_max_token_id,
karg.p_a_grid + splitk_batch_offset.a_k_split_offset,
karg.p_a_scale_grid + splitk_batch_offset.a_k_split_offset,
karg.p_b_grid + splitk_batch_offset.b_k_split_offset,
karg.p_b_scale_grid + splitk_batch_offset.b_k_split_offset,
karg.p_ds_grid,
karg.p_c_grid,
p_shared,
karg,
karg.a_element_op,
karg.b_element_op,
karg.c_element_op);
}
GridwiseGemm::template Run<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
karg.p_sorted_token_ids,
karg.p_sorted_expert_ids,
karg.p_max_token_id,
karg.p_a_grid + splitk_batch_offset.a_k_split_offset,
karg.p_a_scale_grid + splitk_batch_offset.a_scale_k_split_offset,
karg.p_b_grid + splitk_batch_offset.b_k_split_offset,
karg.p_b_scale_grid + splitk_batch_offset.b_scale_k_split_offset,
karg.p_ds_grid,
karg.p_c_grid,
p_shared,
karg,
karg.a_element_op,
karg.b_element_op,
karg.c_element_op);
#else
ignore = karg;
#endif // end of if (defined(__gfx9__))
@@ -1249,7 +1246,6 @@ struct GridwiseMoeGemmMX_BPreshuffle
__host__ static constexpr bool CalculateHasMainKBlockLoop(index_t K)
{
const index_t num_loop = K / KPerBlock;
return BlockwiseGemmPipe::BlockHasHotloop(num_loop);
}
@@ -1279,7 +1275,6 @@ struct GridwiseMoeGemmMX_BPreshuffle
// using Block2CTileMapDefault = BlockToCTileMap_Grouped_M00_N0_M01Adapt<8, MPerBlock,
// NPerBlock>;
#if 0
template <bool HasMainKBlockLoop,
InMemoryDataOperationEnum CGlobalMemoryDataOperation,
TailNumber TailNum = TailNumber::Odd>
@@ -1298,9 +1293,10 @@ struct GridwiseMoeGemmMX_BPreshuffle
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op)
{
ignore = a_element_op;
ignore = b_element_op;
index_t BN0Shuffled = CalculateBN0Shuffled(problem.N);
index_t BK0Shuffled = CalculateBK0Shuffled(problem.K);
index_t BN0Shuffled = CalculateBN0Shuffled(problem.N);
index_t BK0Shuffled = CalculateBK0Shuffled(problem.K);
const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1(
IsInputGemm ? problem.NumTokens : problem.NumTokens * problem.TopK,
problem.MPadded,
@@ -1317,29 +1313,41 @@ struct GridwiseMoeGemmMX_BPreshuffle
problem.NPadded,
problem.StrideC);
const auto a_scale_grid_desc_am_ak = make_naive_tensor_descriptor_packed(
make_tuple((IsInputGemm ? problem.NumTokens : problem.M) / (MXdlPack * MPerBlock),
// We pad M unconditionally for the scale tensor
const auto Padded_Scale_M =
math::integer_divide_ceil(problem.M, ScaleBlockSize) * ScaleBlockSize;
const auto a_scale_grid_desc_am_ak = make_naive_tensor_descriptor(
make_tuple(Padded_Scale_M / (MXdlPack * MPerXdl),
math::integer_divide_ceil(problem.K, (ScaleBlockSize / APackedSize)) /
(KXdlPack * 64 / MPerXdl),
64 * KXdlPack * MXdlPack / scale_pack_size_a));
64 * KXdlPack * MXdlPack / scale_pack_size_a),
make_tuple(math::integer_divide_ceil(problem.K * problem.KBatch,
(ScaleBlockSize / APackedSize)) *
MPerXdl * MXdlPack / scale_pack_size_a,
64 * KXdlPack * MXdlPack / scale_pack_size_a,
1));
const auto b_scale_grid_desc_bn_ak = make_naive_tensor_descriptor_packed(
const auto b_scale_grid_desc_bn_ak = make_naive_tensor_descriptor(
make_tuple(problem.N / (NXdlPack * NPerXdl),
math::integer_divide_ceil(problem.K, (ScaleBlockSize / BPackedSize)) /
(KXdlPack * 64 / NPerXdl),
64 * KXdlPack * NXdlPack / scale_pack_size_b));
64 * KXdlPack * NXdlPack / scale_pack_size_b),
make_tuple(math::integer_divide_ceil(problem.K * problem.KBatch,
(ScaleBlockSize / BPackedSize)) *
NPerXdl * NXdlPack / scale_pack_size_b,
64 * KXdlPack * NXdlPack / scale_pack_size_b,
1));
const auto c_grid_desc_mblock_mperblock_nblock_nperblock =
MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
c_grid_desc_m_n, problem.MBlock, problem.NBlock);
const index_t max_token_id = __builtin_amdgcn_readfirstlane(p_max_token_id[0]);
// static_assert(NSwizzle == false, "to do fix: need another pr in sorting merged");
const index_t max_token_id = __builtin_amdgcn_readfirstlane(p_max_token_id[0]);
const index_t expert_block_id = NSwizzle ? blockIdx.x / problem.NBlock : blockIdx.y;
if(expert_block_id * MPerBlock >= max_token_id)
return;
const index_t expert_id =
__builtin_amdgcn_readfirstlane(p_sorted_expert_ids[expert_block_id]);
const auto block_mn = [&]() -> std::pair<int, int> {
if constexpr(NSwizzle)
{
@@ -1372,86 +1380,78 @@ struct GridwiseMoeGemmMX_BPreshuffle
constexpr auto AK1Threads = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I2);
constexpr auto AKThreads = AK0Threads * AK1Threads;
constexpr auto AMRepeats = MPerBlock / AMThreads;
const index_t token_pos = block_m_id * MPerBlock + threadIdx.x / AKThreads * AMRepeats;
const index_t token_pos = block_m_id * MPerBlock + threadIdx.x / AKThreads;
if(token_pos >= max_token_id || token0 >= problem.NumTokens)
return;
StaticallyIndexedArray<IndexType, AMRepeats> gather_offsets;
static_for<0, AMRepeats, 1>{}([&](auto m0) {
const index_t fused_token = p_sorted_token_ids[token_pos + m0];
const index_t fused_token = p_sorted_token_ids[token_pos + m0 * AMThreads];
index_t token_offset = fused_token & 0xffffff;
if constexpr(!IsInputGemm)
{
token_offset = token_offset * problem.TopK + (fused_token >> 24);
}
gather_offsets(m0) = static_cast<IndexType>(token_offset) * problem.K / APackedSize;
gather_offsets(m0) = static_cast<IndexType>(token_offset) * problem.K;
});
const index_t expert_stride =
__builtin_amdgcn_readfirstlane(problem.N * problem.K * (IsInputGemm ? 2 : 1));
const index_t expert_scale_stride =
__builtin_amdgcn_readfirstlane(problem.N * (IsInputGemm ? 2 : 1) *
math::integer_divide_ceil(problem.K, ScaleBlockSize));
const index_t expert_scale_stride = __builtin_amdgcn_readfirstlane(
problem.N * (IsInputGemm ? 2 : 1) *
math::integer_divide_ceil(problem.K, ScaleBlockSize / BPackedSize));
// N0, K0, Blocksize*KPack
const index_t n_block_data_idx_on_grid =
__builtin_amdgcn_readfirstlane(block_n_id * NXdlPerWave);
__builtin_amdgcn_readfirstlane(block_n_id * NXdlPerWave / NXdlPack);
// Grid buffer creation
const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_b_grid + expert_id * expert_stride / BPackedSize,
b_grid_desc_bpreshuffled.GetElementSpaceSize());
p_b_grid + expert_id * expert_stride, b_grid_desc_bpreshuffled.GetElementSpaceSize());
// A, B scale buffer
const auto a_scale_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_a_scale_grid, a_scale_grid_desc_am_ak.GetElementSpaceSize());
const auto b_scale_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_b_scale_grid + expert_id * expert_scale_stride,
p_b_scale_grid + (expert_id * expert_scale_stride) / sizeof(BScaleDataType),
b_scale_grid_desc_bn_ak.GetElementSpaceSize());
// A matrix in LDS memory, dst of blockwise copy
constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
// B matrix in LDS memory, dst of blockwise copy
// dummy
constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
// A matrix blockwise copy
auto a_blockwise_copy = ThreadGroupTensorSliceTransfer_v4r1_gather<
// A matrix blockwise direct to LDS copy
auto a_blockwise_copy = ThreadGroupTensorSliceTransfer_Gather_DirectLoad<
ThisThreadBlock,
AElementwiseOperation,
ck::tensor_operation::element_wise::PassThrough,
InMemoryDataOperationEnum::Set,
Sequence<AK0Number, MPerBlock, AK1Number>,
ABlockTransferThreadClusterLengths_AK0_M_AK1,
ABlockTransferThreadClusterArrangeOrder,
ADataType,
LDSTypeA,
ADataType,
decltype(a_grid_desc_ak0_m_ak1),
decltype(a_block_desc_ak0_m_ak1),
ABlockTransferSrcAccessOrder,
Sequence<0, 1, 2>,
ABlockTransferSrcVectorDim,
2,
ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_AK1,
1,
1,
AThreadTransferSrcResetCoordinateAfterRun,
true,
IndexType,
1,
BlockwiseGemmPipe::GlobalBufferNum>(a_grid_desc_ak0_m_ak1,
make_multi_index(0, 0, 0),
a_element_op,
a_block_desc_ak0_m_ak1,
make_multi_index(0, 0, 0),
ck::tensor_operation::element_wise::PassThrough{},
gather_offsets);
1>(a_grid_desc_ak0_m_ak1,
make_multi_index(0, 0, 0),
a_block_desc_ak0_m_ak1,
make_multi_index(0, 0, 0),
gather_offsets);
// Thread-wise copy
// K0 -> N0/NWave -> NWave -> KLane -> NLane -> KPack
auto b_block_buf = make_static_buffer<AddressSpaceEnum::Vgpr, BDataType>(
auto b_block_buf_ping = make_static_buffer<AddressSpaceEnum::Vgpr, BDataType>(
b_block_desc_bk0_n_bk1.GetElementSpaceSize());
auto b_block_buf_pong = make_static_buffer<AddressSpaceEnum::Vgpr, BDataType>(
b_block_desc_bk0_n_bk1.GetElementSpaceSize());
auto b_block_bufs = make_tuple(b_block_buf_ping, b_block_buf_pong);
auto b_blockwise_copy =
ThreadwiseTensorSliceTransfer_v2<BDataType,
@@ -1463,7 +1463,7 @@ struct GridwiseMoeGemmMX_BPreshuffle
Number<NXdlPack>{},
Number<KRepeat>{},
Number<BK1Value>{}>,
Sequence<1, 2, 0, 3>,
Sequence<0, 1, 2, 3, 4>,
4,
BBlockTransferSrcScalarPerVector,
BThreadTransferSrcResetCoordinateAfterRun,
@@ -1472,16 +1472,16 @@ struct GridwiseMoeGemmMX_BPreshuffle
make_multi_index(n_block_data_idx_on_grid,
get_warp_local_1d_id() % NWave,
0,
KPack / KGroup * (get_thread_local_1d_id() % WarpSize)));
0,
KPack * (get_thread_local_1d_id() % WarpSize)));
// LDS allocation for A and B: be careful of alignment
// Cast after lds
auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<LDSTypeA*>(p_shared),
a_block_desc_ak0_m_ak1.GetElementSpaceSize() / APackedSize);
static_cast<ADataType*>(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1Number, 0, 0);
constexpr auto b_block_slice_copy_step = make_multi_index(0, 0, KRepeat, 0);
constexpr auto b_block_slice_copy_step = make_multi_index(0, 0, 0, KRepeat, 0);
// Blockwise GEMM pipeline
static_assert(std::is_default_constructible_v<BlockwiseGemmPipe>);
@@ -1505,13 +1505,16 @@ struct GridwiseMoeGemmMX_BPreshuffle
const auto waveId_m = wave_idx[I0];
const auto waveId_n = wave_idx[I1];
static constexpr auto mfma = BlockwiseGemmPipe::xdlops_gemm.mfma;
auto thread_offset_shuffled =
get_thread_local_1d_id() % BlockwiseGemmPipe::WaveSize * KXdlPack * MXdlPack;
auto a_thread_offset_m = waveId_m;
// get each thread's offset in the scale tensor
const index_t token_scale_pos = block_m_id * MPerBlock;
if(token_scale_pos >= max_token_id || token0 >= problem.NumTokens)
return;
auto a_scale_thread_copy = ThreadwiseTensorSliceTransfer_v2<
AScaleDataType,
AScaleDataType,
@@ -1538,7 +1541,7 @@ struct GridwiseMoeGemmMX_BPreshuffle
Sequence<1, 1, KXdlPack * NXdlPack / scale_pack_size_b>, // SliceLengths
Sequence<0, 1, 2>, // DimAccessOrder
2, // SrcVectorDim
KXdlPack * MXdlPack / scale_pack_size_b, // SrcScalarPerVector
KXdlPack * NXdlPack / scale_pack_size_b, // SrcScalarPerVector
1, // SrcScalarStrideInVector
true>(b_scale_grid_desc_bn_ak,
make_multi_index(block_n_id * NPerBlock / NPerXdl / NXdlPack + b_thread_offset_n,
@@ -1547,29 +1550,37 @@ struct GridwiseMoeGemmMX_BPreshuffle
if constexpr(IsInputGemm)
{
const BDataType* p_b_grid_up = p_b_grid + expert_stride / 2 / BPackedSize;
const BDataType* p_b_grid_up = p_b_grid + expert_stride / 2;
const auto b_grid_buf_up = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_b_grid_up + expert_id * expert_stride / BPackedSize,
p_b_grid_up + expert_id * expert_stride,
b_grid_desc_bpreshuffled.GetElementSpaceSize());
auto b_blockwise_copy_up = ThreadwiseTensorSliceTransfer_v2<
BDataType,
BDataType,
decltype(b_grid_desc_bpreshuffled),
decltype(b_block_desc_bk0_n_bk1),
Sequence<Number<NXdlPerWave>{}, I1, Number<KRepeat>{}, Number<BK1Value>{}>,
Sequence<1, 2, 0, 3>,
3,
BBlockTransferSrcScalarPerVector,
BThreadTransferSrcResetCoordinateAfterRun,
true>(b_grid_desc_bpreshuffled,
make_multi_index(n_block_data_idx_on_grid,
get_warp_local_1d_id() % NWave,
0,
KPack / KGroup * (get_thread_local_1d_id() % WarpSize)));
const BScaleDataType* p_b_scale_grid_up = p_b_scale_grid + expert_scale_stride / 2;
const auto b_scale_grid_buf_up = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_b_scale_grid_up + expert_id * expert_scale_stride,
auto b_blockwise_copy_up =
ThreadwiseTensorSliceTransfer_v2<BDataType,
BDataType,
decltype(b_grid_desc_bpreshuffled),
decltype(b_block_desc_bk0_n_bk1),
Sequence<Number<NXdlPerWave / NXdlPack>{},
I1,
Number<NXdlPack>{},
Number<KRepeat>{},
Number<BK1Value>{}>,
Sequence<0, 1, 2, 3, 4>,
4,
BBlockTransferSrcScalarPerVector,
BThreadTransferSrcResetCoordinateAfterRun,
true>(
b_grid_desc_bpreshuffled,
make_multi_index(n_block_data_idx_on_grid,
get_warp_local_1d_id() % NWave,
0,
0,
KPack * (get_thread_local_1d_id() % WarpSize)));
const BScaleDataType* p_b_scale_grid_up =
p_b_scale_grid + expert_scale_stride / 2 / sizeof(BScaleDataType);
const auto b_scale_grid_buf_up = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_b_scale_grid_up + expert_id * expert_scale_stride / sizeof(BScaleDataType),
b_scale_grid_desc_bn_ak.GetElementSpaceSize());
auto b_scale_thread_copy_up = ThreadwiseTensorSliceTransfer_v2<
BScaleDataType,
BScaleDataType,
@@ -1587,25 +1598,30 @@ struct GridwiseMoeGemmMX_BPreshuffle
thread_offset_shuffled / scale_pack_size_b));
blockwise_gemm_pipeline.template Run<HasMainKBlockLoop, TailNum>(
// A
a_grid_desc_ak0_m_ak1,
a_block_desc_ak0_m_ak1,
a_blockwise_copy,
a_grid_buf,
a_block_buf,
a_block_slice_copy_step,
// Gate and Up
b_grid_desc_bpreshuffled,
b_block_desc_bk0_n_bk1,
b_blockwise_copy,
b_blockwise_copy_up,
b_grid_buf,
b_grid_buf_up,
b_block_buf,
b_block_bufs,
b_block_slice_copy_step,
// C
c_thread_buf,
c_thread_buf_up,
// A scale
a_scale_grid_desc_am_ak,
a_scale_thread_copy,
a_scale_grid_buf,
// B scale
b_scale_grid_desc_bn_ak,
b_scale_thread_copy,
b_scale_thread_copy_up,
@@ -1616,23 +1632,23 @@ struct GridwiseMoeGemmMX_BPreshuffle
else
{
blockwise_gemm_pipeline.template Run<HasMainKBlockLoop, TailNum>(
a_grid_desc_ak0_m_ak1,
a_grid_desc_ak0_m_ak1, // A
a_block_desc_ak0_m_ak1,
a_blockwise_copy,
a_grid_buf,
a_block_buf,
a_block_slice_copy_step,
b_grid_desc_bpreshuffled,
b_grid_desc_bpreshuffled, // B
b_block_desc_bk0_n_bk1,
b_blockwise_copy,
b_grid_buf,
b_block_buf,
b_block_bufs,
b_block_slice_copy_step,
c_thread_buf,
a_scale_grid_desc_am_ak,
c_thread_buf, // C
a_scale_grid_desc_am_ak, // A scale
a_scale_thread_copy,
a_scale_grid_buf,
b_scale_grid_desc_bn_ak,
b_scale_grid_desc_bn_ak, // B scale
b_scale_thread_copy,
b_scale_grid_buf,
num_k_block_main_loop);
@@ -1643,84 +1659,101 @@ struct GridwiseMoeGemmMX_BPreshuffle
static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 &&
NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0,
"wrong!");
static_assert(CShuffleMXdlPerWavePerShuffle % MXdlPack == 0 &&
CShuffleNXdlPerWavePerShuffle % NXdlPack == 0,
"wrong!");
constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
// TODO: hacky, fix it!
constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
blockwise_gemm_pipeline.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
blockwise_gemm_pipeline.GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3();
// TODO: hacky, fix it!
// c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths
constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp =
blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3();
constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0);
constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1);
constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2);
constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3);
constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4);
constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5);
constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6);
constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7);
constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5);
constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6);
constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7);
constexpr auto M5 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I8);
constexpr auto N3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I9);
// mul scales
static_assert(M0 * M1 * M2 * M3 * M4 == MPerBlock);
static_assert(M4 == 4);
static_assert(M0 * M1 * M2 * M3 * M4 * M5 == MPerBlock);
static_assert(M5 == 4);
const index_t m1 = get_warp_local_1d_id() / NWave;
const index_t m3 = threadIdx.x % get_warp_size() / MPerXdl;
const index_t m4 = threadIdx.x % get_warp_size() / MPerXdl;
vector_type<float, 4> topk_weights; // for gemm2 only
static_for<0, NXdlPerWave, 1>{}([&](auto n0) {
static_for<0, MXdlPerWave, 1>{}([&](auto m0) { // MXDLPerWave
static_for<0, M2, 1>{}([&](auto m2) { // m_inst_num_groups_per_blk
const index_t m_pos = block_m_id * MPerBlock + m0 * M1 * M2 * M3 * M4 +
m1 * M2 * M3 * M4 + m2 * M3 * M4 + m3 * M4;
if constexpr(MulRoutedWeight)
{
topk_weights = *c_style_pointer_cast<const vector_type<float, M4>*>(
p_ds_grid[I2] + m_pos);
}
static_for<0, M4, 1>{}([&](auto m4) { // m_inst_group_size
constexpr index_t c_offset =
blockwise_gemm_pipeline.GetCThreadDesc().CalculateOffset(
make_tuple(m0, n0, m2 * M4 + m4));
constexpr auto cidx = Number<c_offset>{};
if constexpr(IsInputGemm) // gu fusion
{
if constexpr(ActivationOperation == Activation::silu_and_mul)
{
float gate = c_thread_buf[cidx];
float up = c_thread_buf_up[cidx];
if constexpr(MulRoutedWeight)
{
gate = gate * topk_weights.AsType<float>()[m4];
up = up * topk_weights.AsType<float>()[m4];
}
tensor_operation::element_wise::Silu{}(gate, gate);
c_thread_buf_fp32(cidx) = gate * up;
}
else if(ActivationOperation == Activation::gelu_and_mul)
{
float gate = c_thread_buf[cidx];
float up = c_thread_buf_up[cidx];
if constexpr(MulRoutedWeight)
{
gate = gate * topk_weights.AsType<float>()[m4];
up = up * topk_weights.AsType<float>()[m4];
}
tensor_operation::element_wise::Gelu{}(gate, gate);
c_thread_buf_fp32(cidx) = gate * up;
}
}
else
{
c_thread_buf_fp32(cidx) = c_thread_buf[cidx];
static_for<0, NXdlPerWave / NXdlPack, 1>{}([&](auto n0) {
static_for<0, NXdlPack, 1>{}([&](auto inxdl) { // NXdlPack
static_for<0, MXdlPerWave / MXdlPack, 1>{}([&](auto m0) { // MXDLPerWave
static_for<0, MXdlPack, 1>{}([&](auto imxdl) { // MXdlPack
static_for<0, M3, 1>{}([&](auto m3) { // m_inst_num_groups_per_blk
const index_t m_pos = block_m_id * MPerBlock +
m0 * M2 * M1 * M3 * M4 * M5 +
m1 * M2 * M3 * M4 * M5 +
imxdl * M3 * M4 * M5 + m3 * M4 * M5 + m4 * M5;
if constexpr(MulRoutedWeight)
{
c_thread_buf_fp32(cidx) =
topk_weights.AsType<float>()[m4] * c_thread_buf_fp32[cidx];
topk_weights =
*c_style_pointer_cast<const vector_type<float, M5>*>(
p_ds_grid[I2] + m_pos);
}
}
static_for<0, M5, 1>{}([&](auto m5) { // m_inst_group_size
constexpr index_t c_offset =
blockwise_gemm_pipeline.GetCThreadDesc().CalculateOffset(
make_tuple(m0, n0, imxdl, inxdl, m3 * M5 + m5));
constexpr auto cidx = Number<c_offset>{};
if constexpr(IsInputGemm) // gu fusion
{
if constexpr(ActivationOperation ==
Activation::silu_and_mul)
{
float gate = c_thread_buf[cidx];
float up = c_thread_buf_up[cidx];
if constexpr(MulRoutedWeight)
{
gate = gate * topk_weights.AsType<float>()[m5];
up = up * topk_weights.AsType<float>()[m5];
}
tensor_operation::element_wise::Silu{}(gate, gate);
c_thread_buf_fp32(cidx) = gate * up;
}
else if(ActivationOperation == Activation::gelu_and_mul)
{
float gate = c_thread_buf[cidx];
float up = c_thread_buf_up[cidx];
if constexpr(MulRoutedWeight)
{
gate = gate * topk_weights.AsType<float>()[m5];
up = up * topk_weights.AsType<float>()[m5];
}
tensor_operation::element_wise::Gelu{}(gate, gate);
c_thread_buf_fp32(cidx) = gate * up;
}
}
else
{
c_thread_buf_fp32(cidx) = c_thread_buf[cidx];
if constexpr(MulRoutedWeight)
{
c_thread_buf_fp32(cidx) =
topk_weights.AsType<float>()[m5] *
c_thread_buf_fp32[cidx];
}
}
});
});
});
});
});
@@ -1738,19 +1771,25 @@ struct GridwiseMoeGemmMX_BPreshuffle
make_tuple(
make_freeze_transform(I0),
make_unmerge_transform(make_tuple(
Number<CShuffleMXdlPerWavePerShuffle>{}, // M0 (MXdlPerWave) per shuffle
M1, // M1 = MWave
M2, // M2 * M3 * M4 = MPerXdl
Number<CShuffleMXdlPerWavePerShuffle / MXdlPack>{}, // M0 (MXdlPerWave) per
// shuffle
M1, // M1 = MWave
M2, // M2 * M3 * M4 = MPerXdl
M3,
M4)),
M4,
M5)),
make_freeze_transform(I0),
make_unmerge_transform(make_tuple(
Number<CShuffleNXdlPerWavePerShuffle>{}, // N0 (NXdlPerWave) per shuffle
N1, // N1 = NWave
N2))), // N2 = NPerXdl
Number<CShuffleNXdlPerWavePerShuffle / NXdlPack>{}, // N0 (NXdlPerWave)
// per shuffle
N1, // N1 = NWave
N2, // N2 = NXdlPack
N3))), // N3 = NPerXdl
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(
Sequence<>{}, Sequence<0, 2, 4, 5, 6>{}, Sequence<>{}, Sequence<1, 3, 7>{}));
make_tuple(Sequence<>{},
Sequence<0, 2, 4, 6, 7, 8>{},
Sequence<>{},
Sequence<1, 3, 5, 9>{}));
// calculate origin of thread output tensor on global memory
// blockwise GEMM c matrix starting index
@@ -1762,8 +1801,8 @@ struct GridwiseMoeGemmMX_BPreshuffle
const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor =
make_single_stage_tensor_adaptor(
make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))),
make_tuple(Sequence<0, 1, 2, 3, 4>{}),
make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4, M5))),
make_tuple(Sequence<0, 1, 2, 3, 4, 5>{}),
make_tuple(Sequence<0>{}));
const auto m_thread_data_on_block_idx =
@@ -1772,8 +1811,8 @@ struct GridwiseMoeGemmMX_BPreshuffle
const auto n_thread_data_on_block_to_n0_n1_n2_adaptor =
make_single_stage_tensor_adaptor(
make_tuple(make_merge_transform(make_tuple(N0, N1, N2))),
make_tuple(Sequence<0, 1, 2>{}),
make_tuple(make_merge_transform(make_tuple(N0, N1, N2, N3))),
make_tuple(Sequence<0, 1, 2, 3>{}),
make_tuple(Sequence<0>{}));
const auto n_thread_data_on_block_idx =
@@ -1781,36 +1820,39 @@ struct GridwiseMoeGemmMX_BPreshuffle
make_multi_index(n_thread_data_on_block));
// shuffle: threadwise copy C from VGPR to LDS
auto c_thread_copy_vgpr_to_lds =
ThreadwiseTensorSliceTransfer_v1r3<AccDataType,
CShuffleDataType,
decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2),
decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2),
ck::tensor_operation::element_wise::PassThrough,
Sequence<CShuffleMXdlPerWavePerShuffle,
CShuffleNXdlPerWavePerShuffle,
I1,
I1,
M2,
I1,
M4,
I1>,
Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
7,
1,
InMemoryDataOperationEnum::Set,
1,
true>{
c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
make_multi_index(0,
0,
m_thread_data_on_block_idx[I1],
n_thread_data_on_block_idx[I1],
m_thread_data_on_block_idx[I2],
m_thread_data_on_block_idx[I3],
m_thread_data_on_block_idx[I4],
n_thread_data_on_block_idx[I2]),
ck::tensor_operation::element_wise::PassThrough{}};
auto c_thread_copy_vgpr_to_lds = ThreadwiseTensorSliceTransfer_v1r3<
AccDataType,
CShuffleDataType,
decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2),
decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2),
ck::tensor_operation::element_wise::PassThrough,
Sequence<CShuffleMXdlPerWavePerShuffle / MXdlPack,
CShuffleNXdlPerWavePerShuffle / NXdlPack,
I1,
I1,
M2,
N2,
M3,
I1,
M5,
I1>,
Sequence<0, 1, 2, 3, 4, 5, 6, 7, 8, 9>,
9,
1,
InMemoryDataOperationEnum::Set,
1,
true>{c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
make_multi_index(0,
0,
m_thread_data_on_block_idx[I1],
n_thread_data_on_block_idx[I1],
m_thread_data_on_block_idx[I2],
n_thread_data_on_block_idx[I2],
m_thread_data_on_block_idx[I3],
m_thread_data_on_block_idx[I4],
m_thread_data_on_block_idx[I5],
n_thread_data_on_block_idx[I3]),
ck::tensor_operation::element_wise::PassThrough{}};
using EDataType = CDataType;
@@ -1859,7 +1901,7 @@ struct GridwiseMoeGemmMX_BPreshuffle
using CDEBlockTransferCluster =
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock;
const auto EGlobalMemoryDataOperation = CGlobalMemoryDataOperation;
constexpr index_t scatter_weight_idx = 1; // hack fix felix
constexpr index_t scatter_weight_idx = 3; // hack fix felix
auto cde_block_copy_lds_and_global = ThreadGroupTensorSliceTransfer_v7r3_scatter<
ThisThreadBlock,
decltype(container_concat(make_tuple(CShuffleDataType{}), DsDataType{})),
@@ -1867,8 +1909,9 @@ struct GridwiseMoeGemmMX_BPreshuffle
decltype(c_ds_desc_refs),
decltype(tie(e_grid_desc_mblock_mperblock_nblock_nperblock)),
CElementwiseOperation,
Sequence<static_cast<index_t>(EGlobalMemoryDataOperation)>, // FIXME: make Sequence
// support arbitray type
Sequence<static_cast<index_t>(EGlobalMemoryDataOperation)>, // FIXME: make
// Sequence support
// arbitrary type
Sequence<1,
CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
1,
@@ -1898,13 +1941,25 @@ struct GridwiseMoeGemmMX_BPreshuffle
auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
constexpr auto sfc_c_vgpr =
SpaceFillingCurve<Sequence<MXdlPerWave, NXdlPerWave, 1, 1, M2, 1, M4, 1>,
Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
Sequence<CShuffleMXdlPerWavePerShuffle,
CShuffleNXdlPerWavePerShuffle,
SpaceFillingCurve<Sequence<MXdlPerWave / MXdlPack,
NXdlPerWave / NXdlPack,
1,
1,
MXdlPack,
NXdlPack,
M2,
1,
M4,
1>,
Sequence<0, 1, 2, 3, 4, 5, 6, 7, 8, 9>,
Sequence<CShuffleMXdlPerWavePerShuffle / MXdlPack,
CShuffleNXdlPerWavePerShuffle / NXdlPack,
1,
1,
MXdlPack,
NXdlPack,
M2,
1,
M4,
@@ -1984,7 +2039,6 @@ struct GridwiseMoeGemmMX_BPreshuffle
});
}
}
#endif
template <bool HasMainKBlockLoop,
InMemoryDataOperationEnum CGlobalMemoryDataOperation,

129
include/ck_tile/core/arch/arch.hpp Normal file → Executable file
View File

@@ -136,66 +136,103 @@ CK_TILE_DEVICE void block_sync_load_raw(index_t cnt = 0)
#endif
}
// https://llvm.org/docs/AMDGPU/gfx9_waitcnt.html
// Wait-count operand layout for gfx12 (s_wait_loadcnt_dscnt):
// the memory counter sits in bits [13:8] and the DS counter in bits [5:0].
// This layout carries no export-counter field, so pack_exp() contributes nothing.
struct WaitcntLayoutGfx12
{
    CK_TILE_DEVICE static constexpr index_t VM_MASK   = 0x3F; // mem
    CK_TILE_DEVICE static constexpr index_t LGKM_MASK = 0x3F; // ds
    CK_TILE_DEVICE static constexpr bool HAS_EXP      = false;

    // Mask the memory count to VM_MASK and move it up to bit 8.
    CK_TILE_DEVICE static constexpr index_t pack_vm(index_t cnt)
    {
        return (cnt & VM_MASK) << 8;
    }

    // Mask the DS count to LGKM_MASK; the field starts at bit 0, so no shift is needed.
    CK_TILE_DEVICE static constexpr index_t pack_lgkm(index_t cnt)
    {
        return cnt & LGKM_MASK;
    }

    // The argument is ignored; the result is always 0.
    CK_TILE_DEVICE static constexpr index_t pack_exp(index_t /*cnt*/) { return 0; }
};
// Wait-count operand layout for gfx11:
// vm occupies bits [15:10] (6 bits), lgkm occupies bits [9:4] (6 bits), exp is unused.
struct WaitcntLayoutGfx11
{
    CK_TILE_DEVICE static constexpr index_t VM_MASK   = 0x3F;
    CK_TILE_DEVICE static constexpr index_t LGKM_MASK = 0x3F;
    CK_TILE_DEVICE static constexpr bool HAS_EXP      = false;

    // Mask the vm count to VM_MASK and move it up to bit 10.
    CK_TILE_DEVICE static constexpr index_t pack_vm(index_t cnt)
    {
        return (cnt & VM_MASK) << 10;
    }

    // Mask the lgkm count to LGKM_MASK and move it up to bit 4.
    CK_TILE_DEVICE static constexpr index_t pack_lgkm(index_t cnt)
    {
        return (cnt & LGKM_MASK) << 4;
    }

    // The argument is ignored; the result is always 0.
    CK_TILE_DEVICE static constexpr index_t pack_exp(index_t /*cnt*/) { return 0; }
};
// Wait-count operand layout used when neither the gfx12 nor the gfx11 layout applies.
//
// bit numbers (hex):                          FE'DC'BA98'7'654'3210
// [V]M, [E]XP, [L]GKM counters and [U]NUSED:  VV'UU'LLLL'U'EEE'VVVV
struct WaitcntLayoutLegacy
{
    CK_TILE_DEVICE static constexpr index_t VM_MASK   = 0x3F; // split: low4 + hi2
    CK_TILE_DEVICE static constexpr index_t LGKM_MASK = 0x0F; // [11:8]
    CK_TILE_DEVICE static constexpr index_t EXP_MASK  = 0x07; // [6:4]
    CK_TILE_DEVICE static constexpr bool HAS_EXP      = true;

    // The 6-bit vm count is stored in two pieces: its low four bits stay at
    // [3:0], while its top two bits (mask 0x30) are shifted up by 10 to [15:14].
    CK_TILE_DEVICE static constexpr index_t pack_vm(index_t cnt)
    {
        const index_t vm        = cnt & VM_MASK;
        const index_t low_part  = vm & 0xF;
        const index_t high_part = (vm & 0x30) << 10;
        return low_part | high_part;
    }

    // Mask the lgkm count to LGKM_MASK and move it up to bit 8.
    CK_TILE_DEVICE static constexpr index_t pack_lgkm(index_t cnt)
    {
        return (cnt & LGKM_MASK) << 8;
    }

    // Mask the exp count to EXP_MASK and move it up to bit 4.
    CK_TILE_DEVICE static constexpr index_t pack_exp(index_t cnt)
    {
        return (cnt & EXP_MASK) << 4;
    }
};
// Select the active layout for the architecture being compiled:
// __gfx12__ picks WaitcntLayoutGfx12, __gfx11__ picks WaitcntLayoutGfx11,
// and every other target falls back to WaitcntLayoutLegacy.
#if defined(__gfx12__)
using Waitcnt = WaitcntLayoutGfx12;
#elif defined(__gfx11__)
using Waitcnt = WaitcntLayoutGfx11;
#else
using Waitcnt = WaitcntLayoutLegacy;
#endif
//----------------------------------------------
// Public API: only from_* (constexpr templates)
//----------------------------------------------
struct waitcnt_arg
{
#if defined(__gfx12__)
// use s_wait_loadcnt_dscnt in this instruction; in this instruction, ds [5:0]; mem [13:8]
CK_TILE_DEVICE static constexpr index_t MAX = 0b00'111111'00'111111;
CK_TILE_DEVICE static constexpr index_t kMaxVmCnt = 0b111111;
CK_TILE_DEVICE static constexpr index_t kMaxExpCnt = 0b111;
CK_TILE_DEVICE static constexpr index_t kMaxLgkmCnt = 0b111111;
template <index_t cnt>
CK_TILE_DEVICE static constexpr index_t from_vmcnt()
{
static_assert(cnt >= 0 && !(cnt >> 6), "valid range is [0..63]");
return MAX & (cnt << 8);
}
template <index_t cnt>
CK_TILE_DEVICE static constexpr index_t from_expcnt()
{
return 0; // no export in MI series
}
template <index_t cnt>
CK_TILE_DEVICE static constexpr index_t from_lgkmcnt()
{
static_assert(cnt >= 0 && !(cnt >> 6), "valid range is [0..63]");
return MAX & cnt;
}
// kMax* exposed for callers; match field widths per-arch
#if defined(__gfx12__) || defined(__gfx11__)
CK_TILE_DEVICE static constexpr index_t kMaxVmCnt = 0x3F; // 6 bits
CK_TILE_DEVICE static constexpr index_t kMaxLgkmCnt = 0x3F; // 6 bits
CK_TILE_DEVICE static constexpr index_t kMaxExpCnt = 0x0; // none
#else
// bit numbers (hex) -------------------------> FE'DC'BA98'7'654'3210
// [V]M [E]XP [L]GKM counters and [U]NUSED ---> VV'UU'LLLL'U'EEE'VVVV
CK_TILE_DEVICE static constexpr index_t MAX = 0b11'00'1111'0'111'1111;
CK_TILE_DEVICE static constexpr index_t kMaxVmCnt = 0b111111;
CK_TILE_DEVICE static constexpr index_t kMaxExpCnt = 0b111;
CK_TILE_DEVICE static constexpr index_t kMaxLgkmCnt = 0b1111;
CK_TILE_DEVICE static constexpr index_t kMaxVmCnt = 0x3F; // 6 bits (split)
CK_TILE_DEVICE static constexpr index_t kMaxLgkmCnt = 0x0F; // 4 bits
CK_TILE_DEVICE static constexpr index_t kMaxExpCnt = 0x07; // 3 bits
#endif
template <index_t cnt>
CK_TILE_DEVICE static constexpr index_t from_vmcnt()
{
static_assert(cnt >= 0 && !(cnt >> 6), "valid range is [0..63]");
return MAX & ((cnt & 0b1111) | ((cnt & 0b110000) << 10));
}
template <index_t cnt>
CK_TILE_DEVICE static constexpr index_t from_expcnt()
{
static_assert(cnt >= 0 && !(cnt >> 3), "valid range is [0..7]");
return MAX & (cnt << 4);
static_assert((cnt & ~Waitcnt::VM_MASK) == 0, "vmcnt out of range");
return Waitcnt::pack_vm(cnt);
}
template <index_t cnt>
CK_TILE_DEVICE static constexpr index_t from_lgkmcnt()
{
static_assert(cnt >= 0 && !(cnt >> 4), "valid range is [0..15]");
return MAX & (cnt << 8);
static_assert((cnt & ~Waitcnt::LGKM_MASK) == 0, "lgkmcnt out of range");
return Waitcnt::pack_lgkm(cnt);
}
// Encodes a compile-time expcnt value through the Waitcnt helper.
//  - Waitcnt::HAS_EXP true, target neither gfx11 nor gfx12: cnt is checked against
//    Waitcnt::EXP_MASK and the result of Waitcnt::pack_exp(cnt) is returned.
//  - Waitcnt::HAS_EXP true, target gfx11/gfx12: cnt is ignored and 0 is returned.
//  - Waitcnt::HAS_EXP false: the static_assert rejects any cnt other than 0; returns 0.
template <index_t cnt>
CK_TILE_DEVICE static constexpr index_t from_expcnt()
{
if constexpr(Waitcnt::HAS_EXP)
{
// EXP_MASK only exists on legacy
#if !defined(__gfx12__) && !defined(__gfx11__)
static_assert((cnt & ~Waitcnt::EXP_MASK) == 0, "expcnt out of range");
return Waitcnt::pack_exp(cnt);
#else
// Explicitly discard cnt so it does not count as unused in this configuration.
(void)cnt;
return 0;
#endif
}
else
{
static_assert(cnt == 0, "expcnt unsupported on this arch");
return 0;
}
}
};
template <index_t vmcnt = waitcnt_arg::kMaxVmCnt,

View File

@@ -102,11 +102,14 @@ struct static_counter_uniq_;
}
// Convenience macros for creating and advancing a ck_tile::static_counter. MAKE_SC and
// MAKE_SC_WITH build a counter whose tag type is static_counter_uniq_<__COUNTER__>;
// NEXT_SC and NEXT_SCI call next<> with a __COUNTER__-derived template argument.
// NOTE(review): every macro appears twice below, once without and once with a leading
// __extension__, and the second MAKE_SC body has no #define line of its own. This looks like
// a rendered diff showing removed and added lines together - confirm which set is live.
// Presumably __extension__ is there to silence pedantic diagnostics about the non-standard
// __COUNTER__ - confirm.
#define MAKE_SC() \
ck_tile::static_counter<ck_tile::impl::static_counter_uniq_<__COUNTER__>> {}
#define MAKE_SC_WITH(start_, step_) \
ck_tile::static_counter<ck_tile::impl::static_counter_uniq_<__COUNTER__>, start_, step_> {}
#define NEXT_SC(c_) c_.next<__COUNTER__>()
#define NEXT_SCI(c_, static_i_) c_.next<__COUNTER__ + static_i_>()
__extension__ ck_tile::static_counter<ck_tile::impl::static_counter_uniq_<__COUNTER__>> {}
#define MAKE_SC_WITH(start_, step_) \
__extension__ ck_tile:: \
static_counter<ck_tile::impl::static_counter_uniq_<__COUNTER__>, start_, step_> \
{ \
}
#define NEXT_SC(c_) __extension__ c_.next<__COUNTER__>()
#define NEXT_SCI(c_, static_i_) __extension__ c_.next<__COUNTER__ + static_i_>()
// Usage:
// constexpr auto c = MAKE_SC()

View File

@@ -74,6 +74,21 @@ struct GroupedConvTraits
}
public:
// Constants that are fixed for the implicit-GEMM formulation (not configurable per trait).
struct FixedGemmParams
{
// Layout tag exposed as ELayout.
using ELayout = ck_tile::tensor_layout::gemm::RowMajor;

// Tile partitioner settings.
static constexpr ck_tile::index_t TilePartitionerM01      = 4;
static constexpr ck_tile::index_t TilePartitionerGroupNum = 8;

// Padding flags: enabled for M, N and K.
static constexpr bool kPadK = true;
static constexpr bool kPadN = true;
static constexpr bool kPadM = true;

// Feature switches.
static constexpr bool FixedVectorSize       = true;
static constexpr bool TransposeC            = false;
static constexpr bool UseStructuredSparsity = false;
static constexpr bool Persistent            = false;
};
// Compile time parameters
// Each constant re-exports the identically named entity with a trailing underscore
// (presumably the template parameters of the enclosing struct - its header is not visible
// here, confirm).
static constexpr bool EnableSplitImage = EnableSplitImage_;
static constexpr index_t NumGroupsToMerge = NumGroupsToMerge_;
static constexpr index_t NDimSpatial = NDimSpatial_;
@@ -82,31 +97,43 @@ struct GroupedConvTraits
using WeiLayout = WeiLayout_;
using DsLayout = DsLayout_;
using OutLayout = OutLayout_;
// Layout tags for the A, B and C operands of the implicit GEMM, one triple per convolution
// direction. C is RowMajor in all three; A and B differ per direction as listed below.
// Forward Gemm Layouts
using AsLayoutFwd = ck_tile::tensor_layout::gemm::RowMajor;
using BsLayoutFwd = ck_tile::tensor_layout::gemm::ColumnMajor;
using CLayoutFwd = ck_tile::tensor_layout::gemm::RowMajor;
// Backward Data Gemm Layouts
using AsLayoutBwdData = ck_tile::tensor_layout::gemm::RowMajor;
using BsLayoutBwdData = ck_tile::tensor_layout::gemm::RowMajor;
using CLayoutBwdData = ck_tile::tensor_layout::gemm::RowMajor;
// Backward Weight Gemm Layouts
using AsLayoutBwdWeight = ck_tile::tensor_layout::gemm::ColumnMajor;
using BsLayoutBwdWeight = ck_tile::tensor_layout::gemm::RowMajor;
using CLayoutBwdWeight = ck_tile::tensor_layout::gemm::RowMajor;
// Implicit-GEMM trait aliases for forward, backward-data and backward-weight convolution.
// Each is a TileGemmTraits whose first three (bool) arguments are true, followed by the
// A/B/C layout tags for that direction.
// NOTE(review): this span holds two spellings of every alias - one with the layouts written
// out and no NumWaveGroups argument, and one templated on NumWaveGroups that passes
// layout aliases (AsLayoutFwd, BsLayoutBwdData, ...) instead. The layouts agree
// between the two spellings, but the templated Fwd body is detached from its `using` line.
// This looks like removed and added diff lines interleaved - confirm against the real header.
template <ck_tile::index_t NumWaveGroups = 1>
using GroupedConvImplicitGemmTraitsFwd =
TileGemmTraits<true,
true,
true,
ck_tile::tensor_layout::gemm::RowMajor,
ck_tile::tensor_layout::gemm::ColumnMajor,
ck_tile::tensor_layout::gemm::RowMajor>;
using GroupedConvImplicitGemmTraitsBwdData =
TileGemmTraits<true,
true,
true,
ck_tile::tensor_layout::gemm::RowMajor,
ck_tile::tensor_layout::gemm::RowMajor,
ck_tile::tensor_layout::gemm::RowMajor>;
using GroupedConvImplicitGemmTraitsBwdWeight =
TileGemmTraits<true,
true,
true,
ck_tile::tensor_layout::gemm::ColumnMajor,
ck_tile::tensor_layout::gemm::RowMajor,
ck_tile::tensor_layout::gemm::RowMajor>;
TileGemmTraits<true, true, true, AsLayoutFwd, BsLayoutFwd, CLayoutFwd, NumWaveGroups>;
template <ck_tile::index_t NumWaveGroups = 1>
using GroupedConvImplicitGemmTraitsBwdData = TileGemmTraits<true,
true,
true,
AsLayoutBwdData,
BsLayoutBwdData,
CLayoutBwdData,
NumWaveGroups>;
template <ck_tile::index_t NumWaveGroups = 1>
using GroupedConvImplicitGemmTraitsBwdWeight = TileGemmTraits<true,
true,
true,
AsLayoutBwdWeight,
BsLayoutBwdWeight,
CLayoutBwdWeight,
NumWaveGroups>;
// Vector sizes for A, B and C, re-exported from the identically named entities with a
// trailing underscore.
static constexpr ck_tile::index_t VectorSizeA = VectorSizeA_;
static constexpr ck_tile::index_t VectorSizeB = VectorSizeB_;
static constexpr ck_tile::index_t VectorSizeC = VectorSizeC_;
// Number of D tensors, taken from DsLayout::size().
// NOTE(review): NumDTensor is declared twice, once as index_t and once as
// ck_tile::index_t - looks like the removed and added side of a diff; confirm.
static constexpr index_t NumDTensor = DsLayout::size();
static constexpr ck_tile::index_t NumDTensor = DsLayout::size();
// Type returned by generate_implicit_gemm_layout().
using ImplicitGemmDsLayout = decltype(generate_implicit_gemm_layout());
};

View File

@@ -74,6 +74,6 @@ class ProfilerOperationRegistry final
// Token-pasting helpers plus the profiler registration macro.
// PP_CONCAT forwards to PP_CONCAT_IMPL so that its arguments (here __COUNTER__) are
// macro-expanded before being pasted together with ##.
// REGISTER_PROFILER_OPERATION defines a static const bool named
// operation_registration_result_<N>, with N supplied by __COUNTER__, initialised from
// ::ProfilerOperationRegistry::GetInstance().Add(name, description, operation).
// NOTE(review): the REGISTER_PROFILER_OPERATION #define line and its first body line appear
// twice (without and with __extension__), the second #define sitting inside the first one's
// line continuation. This looks like removed and added diff lines - confirm which is live.
#define PP_CONCAT(x, y) PP_CONCAT_IMPL(x, y)
#define PP_CONCAT_IMPL(x, y) x##y
#define REGISTER_PROFILER_OPERATION(name, description, operation) \
static const bool PP_CONCAT(operation_registration_result_, __COUNTER__) = \
#define REGISTER_PROFILER_OPERATION(name, description, operation) \
__extension__ static const bool PP_CONCAT(operation_registration_result_, __COUNTER__) = \
::ProfilerOperationRegistry::GetInstance().Add(name, description, operation)