From 2234ff830b2f4ce8026c50b2d81f95f38f7117e5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bart=C5=82omiej=20Kocot?= Date: Thu, 6 Nov 2025 11:26:30 +0100 Subject: [PATCH 1/6] [CK TILE] Convolution remove magic values (#3160) * [CK TILE] Refactor Conv configs and Conv Elementwise * fix * [CK TILE] Convolution remove magix values * fix partitioner --- .../20_grouped_convolution/conv_configs.hpp | 16 +- ...uped_convolution_backward_data_invoker.hpp | 209 +++++++++--------- ...ed_convolution_backward_weight_invoker.hpp | 90 ++++---- ...tion_backward_weight_two_stage_invoker.hpp | 108 +++++---- .../grouped_convolution_forward_invoker.hpp | 96 ++++---- ...nvolution_forward_large_tensor_invoker.hpp | 119 +++++----- .../utils/grouped_convolution_utils.hpp | 69 ++++-- 7 files changed, 355 insertions(+), 352 deletions(-) diff --git a/example/ck_tile/20_grouped_convolution/conv_configs.hpp b/example/ck_tile/20_grouped_convolution/conv_configs.hpp index b9edf247cc..8a2a60a197 100644 --- a/example/ck_tile/20_grouped_convolution/conv_configs.hpp +++ b/example/ck_tile/20_grouped_convolution/conv_configs.hpp @@ -14,19 +14,11 @@ struct ConvConfigBase { - static constexpr bool kPadM = true; - static constexpr bool kPadN = true; - static constexpr bool kPadK = true; - - static constexpr bool TransposeC = false; - static constexpr ck_tile::index_t VectorSizeA = 4; static constexpr ck_tile::index_t VectorSizeB = 8; static constexpr ck_tile::index_t VectorSizeC = 8; - static constexpr int kBlockPerCu = 1; - static constexpr ck_tile::index_t TileParitionerGroupNum = 8; - static constexpr ck_tile::index_t TileParitionerM01 = 4; + static constexpr int kBlockPerCu = 1; static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave; static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V3; static constexpr ck_tile::index_t NumWaveGroups = 1; @@ -210,9 +202,9 @@ struct ConvConfigComputeV5 : public ConvConfigBase static constexpr ck_tile::index_t N_Warp_Tile = 32; static constexpr ck_tile::index_t K_Warp_Tile = 16; - static constexpr bool DoubleSmemBuffer = false; - static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V5; - static constexpr ck_tile::index_t NumWaNumWaveGroups = 2; + static constexpr bool DoubleSmemBuffer = false; + static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V5; + static constexpr ck_tile::index_t NumWaveGroups = 2; }; template diff --git a/example/ck_tile/20_grouped_convolution/grouped_convolution_backward_data_invoker.hpp b/example/ck_tile/20_grouped_convolution/grouped_convolution_backward_data_invoker.hpp index 14a533ffc9..d19d3ac8ec 100644 --- a/example/ck_tile/20_grouped_convolution/grouped_convolution_backward_data_invoker.hpp +++ b/example/ck_tile/20_grouped_convolution/grouped_convolution_backward_data_invoker.hpp @@ -22,8 +22,6 @@ struct GroupedConvolutionBackwardDataInvoker static float grouped_conv_bwd_data(const ck_tile::GroupedConvBwdDataHostArgs& args, const ck_tile::stream_config& s) { - constexpr int kBlockPerCu = 1; - // Implicit GEMM Traits using GemmShape = ck_tile::TileGemmShape< ck_tile::sequence, @@ -32,36 +30,33 @@ struct GroupedConvolutionBackwardDataInvoker ConvConfig::N_Warp_Tile, ConvConfig::K_Warp_Tile>>; - constexpr ck_tile::index_t VectorSizeA = 8; - constexpr ck_tile::index_t VectorSizeB = 8; - constexpr ck_tile::index_t VectorSizeC = 8; - - constexpr auto ConvSpec = ck_tile::ConvolutionSpecialization::Default; - using TilePartitioner = - ck_tile::GemmSpatiallyLocalTilePartitioner; + constexpr auto ConvSpec = ck_tile::ConvolutionSpecialization::Default; using GroupedConvTraitsType = ck_tile::GroupedConvTraits; + ConvConfig::VectorSizeA, + ConvConfig::VectorSizeB, + ConvConfig::VectorSizeC>; + + using TilePartitioner = ck_tile::GemmSpatiallyLocalTilePartitioner< + GemmShape, + GroupedConvTraitsType::FixedGemmParams::TilePartitionerGroupNum, + GroupedConvTraitsType::FixedGemmParams::TilePartitionerM01>; using GemmUniversalTraits = ck_tile::TileGemmUniversalTraits< - ConvConfig::kPadM, - ConvConfig::kPadN, - ConvConfig::kPadK, + GroupedConvTraitsType::FixedGemmParams::kPadM, + GroupedConvTraitsType::FixedGemmParams::kPadN, + GroupedConvTraitsType::FixedGemmParams::kPadK, ConvConfig::DoubleSmemBuffer, - typename GroupedConvTraitsType::GroupedConvImplicitGemmTraitsBwdData::AsLayout, - typename GroupedConvTraitsType::GroupedConvImplicitGemmTraitsBwdData::BsLayout, - typename GroupedConvTraitsType::GroupedConvImplicitGemmTraitsBwdData::CLayout, - ConvConfig::TransposeC, - false, - false, // Persistent, + typename GroupedConvTraitsType::AsLayoutBwdData, + typename GroupedConvTraitsType::BsLayoutBwdData, + typename GroupedConvTraitsType::CLayoutBwdData, + GroupedConvTraitsType::FixedGemmParams::TransposeC, + GroupedConvTraitsType::FixedGemmParams::UseStructuredSparsity, + GroupedConvTraitsType::FixedGemmParams::Persistent, ConvConfig::NumWaveGroups>; using GemmPipelineProblem = ck_tile::GemmPipelineProblem< @@ -69,13 +64,14 @@ struct GroupedConvolutionBackwardDataInvoker WeiDataType, AccDataType, GemmShape, - typename GroupedConvTraitsType::GroupedConvImplicitGemmTraitsBwdData, + typename GroupedConvTraitsType::template GroupedConvImplicitGemmTraitsBwdData< + ConvConfig::NumWaveGroups>, ck_tile::element_wise::PassThrough, ck_tile::element_wise::PassThrough, InDataType, - true, - VectorSizeA, - VectorSizeB>; + GroupedConvTraitsType::FixedGemmParams::FixedVectorSize, + GroupedConvTraitsType::VectorSizeA, + GroupedConvTraitsType::VectorSizeB>; using BaseGemmPipeline = typename PipelineTypeTraits< ConvConfig::Pipeline>::template UniversalGemmPipeline; @@ -93,95 +89,96 @@ struct GroupedConvolutionBackwardDataInvoker const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop); float ave_time{0}; - const auto Run = - [&](const auto has_hot_loop_, const auto tail_number_, const auto memory_operation_) { - constexpr bool has_hot_loop_v = has_hot_loop_.value; - constexpr auto tail_number_v = tail_number_.value; - constexpr auto scheduler = ConvConfig::Scheduler; - constexpr auto memory_operation = memory_operation_.value; + const auto Run = [&](const auto has_hot_loop_, + const auto tail_number_, + const auto memory_operation_) { + constexpr bool has_hot_loop_v = has_hot_loop_.value; + constexpr auto tail_number_v = tail_number_.value; + constexpr auto scheduler = ConvConfig::Scheduler; + constexpr auto memory_operation = memory_operation_.value; - using UniversalGemmProblem = - ck_tile::UniversalGemmPipelineProblem; + using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem< + OutDataType, + WeiDataType, + AccDataType, + GemmShape, + GemmUniversalTraits, + scheduler, + has_hot_loop_v, + tail_number_v, + ck_tile::element_wise::PassThrough, + ck_tile::element_wise::PassThrough, + InDataType, + GroupedConvTraitsType::FixedGemmParams::FixedVectorSize, + GroupedConvTraitsType::VectorSizeA, + GroupedConvTraitsType::VectorSizeB>; - using GemmPipeline = typename PipelineTypeTraits< - ConvConfig::Pipeline>::template GemmPipeline; + using GemmPipeline = typename PipelineTypeTraits< + ConvConfig::Pipeline>::template GemmPipeline; - using ConvEpilogue = ck_tile::CShuffleEpilogue>; + using ConvEpilogue = ck_tile::CShuffleEpilogue>; - using Kernel = ck_tile::GroupedConvolutionBackwardDataKernel; - auto kargs = Kernel::MakeKernelArgs(args); + using Kernel = ck_tile::GroupedConvolutionBackwardDataKernel; + auto kargs = Kernel::MakeKernelArgs(args); - const dim3 grids = Kernel::GridSize(args); - const dim3 blocks = Kernel::BlockSize(); + const dim3 grids = Kernel::GridSize(args); + const dim3 blocks = Kernel::BlockSize(); - if(!Kernel::IsSupportedArgument(kargs)) - { - throw std::runtime_error("Wrong! Arguments not supported! Skipping conv!\n"); - } + if(!Kernel::IsSupportedArgument(kargs)) + { + throw std::runtime_error("Wrong! Arguments not supported! Skipping conv!\n"); + } - if(s.log_level_ > 0) - { - std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n' - << "shape: " << GemmShape::GetName() << '\n' - << "problem: " << UniversalGemmProblem::GetName() << '\n' - << "pipeline: " << GemmPipeline::GetName() << '\n' - << "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}" - << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z - << "}" << '\n' - << "Vector size A: " << GemmPipeline::GetVectorSizeA() - << ", Vector size B: " << GemmPipeline::GetVectorSizeB() - << ", Vector size C: " << ConvEpilogue::GetVectorSizeC() << std::endl; - } + if(s.log_level_ > 0) + { + std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n' + << "shape: " << GemmShape::GetName() << '\n' + << "problem: " << UniversalGemmProblem::GetName() << '\n' + << "pipeline: " << GemmPipeline::GetName() << '\n' + << "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}" + << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z + << "}" << '\n' + << "Vector size A: " << GemmPipeline::GetVectorSizeA() + << ", Vector size B: " << GemmPipeline::GetVectorSizeB() + << ", Vector size C: " << ConvEpilogue::GetVectorSizeC() << std::endl; + } - auto preprocess = [&]() { - ck_tile::hip_check_error(hipMemsetAsync( - kargs.in_ptr, 0, args.template GetInputByte(), s.stream_id_)); - }; - - ave_time = ck_tile::launch_kernel_time_mask( - s, - preprocess, - ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); - - return ave_time; + auto preprocess = [&]() { + ck_tile::hip_check_error(hipMemsetAsync( + kargs.in_ptr, 0, args.template GetInputByte(), s.stream_id_)); }; + ave_time = ck_tile::launch_kernel_time_mask( + s, + preprocess, + ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); + + return ave_time; + }; + const auto RunSplitk = [&](const auto has_hot_loop_, const auto tail_number_) { if(args.k_batch == 1) { diff --git a/example/ck_tile/20_grouped_convolution/grouped_convolution_backward_weight_invoker.hpp b/example/ck_tile/20_grouped_convolution/grouped_convolution_backward_weight_invoker.hpp index 0e777c5f8a..81b9d402ce 100644 --- a/example/ck_tile/20_grouped_convolution/grouped_convolution_backward_weight_invoker.hpp +++ b/example/ck_tile/20_grouped_convolution/grouped_convolution_backward_weight_invoker.hpp @@ -21,8 +21,6 @@ struct GroupedConvolutionBackwardWeightInvoker static float grouped_conv_bwd_weight(const ck_tile::GroupedConvBwdWeightHostArgs& args, const ck_tile::stream_config& s) { - constexpr int kBlockPerCu = 1; - // Implicit GEMM Traits using GemmShape = ck_tile::TileGemmShape< ck_tile::sequence, @@ -31,37 +29,34 @@ struct GroupedConvolutionBackwardWeightInvoker ConvConfig::N_Warp_Tile, ConvConfig::K_Warp_Tile>>; - constexpr ck_tile::index_t VectorSizeA = ConvConfig::VectorSizeA; - constexpr ck_tile::index_t VectorSizeB = ConvConfig::VectorSizeB; - constexpr ck_tile::index_t VectorSizeC = ConvConfig::VectorSizeC; - - constexpr auto ConvSpec = ck_tile::ConvolutionSpecialization::Default; - using TilePartitioner = - ck_tile::GemmSpatiallyLocalTilePartitioner; + constexpr auto ConvSpec = ck_tile::ConvolutionSpecialization::Default; using GroupedConvTraitsType = ck_tile::GroupedConvTraits; + using TilePartitioner = ck_tile::GemmSpatiallyLocalTilePartitioner< + GemmShape, + GroupedConvTraitsType::FixedGemmParams::TilePartitionerGroupNum, + GroupedConvTraitsType::FixedGemmParams::TilePartitionerM01>; + using GemmUniversalTraits = ck_tile::TileGemmUniversalTraits< - ConvConfig::kPadM, - ConvConfig::kPadN, - ConvConfig::kPadK, + GroupedConvTraitsType::FixedGemmParams::kPadM, + GroupedConvTraitsType::FixedGemmParams::kPadN, + GroupedConvTraitsType::FixedGemmParams::kPadK, ConvConfig::DoubleSmemBuffer, - typename GroupedConvTraitsType::GroupedConvImplicitGemmTraitsBwdWeight::AsLayout, - typename GroupedConvTraitsType::GroupedConvImplicitGemmTraitsBwdWeight::BsLayout, - typename GroupedConvTraitsType::GroupedConvImplicitGemmTraitsBwdWeight::CLayout, - ConvConfig::TransposeC, - false, - false, // Persistent, + typename GroupedConvTraitsType::AsLayoutBwdWeight, + typename GroupedConvTraitsType::BsLayoutBwdWeight, + typename GroupedConvTraitsType::CLayoutBwdWeight, + GroupedConvTraitsType::FixedGemmParams::TransposeC, + GroupedConvTraitsType::FixedGemmParams::UseStructuredSparsity, + GroupedConvTraitsType::FixedGemmParams::Persistent, ConvConfig::NumWaveGroups>; using GemmPipelineProblem = ck_tile::GemmPipelineProblem< @@ -69,13 +64,14 @@ struct GroupedConvolutionBackwardWeightInvoker InDataType, AccDataType, GemmShape, - typename GroupedConvTraitsType::GroupedConvImplicitGemmTraitsBwdWeight, + typename GroupedConvTraitsType::template GroupedConvImplicitGemmTraitsBwdWeight< + ConvConfig::NumWaveGroups>, ck_tile::element_wise::PassThrough, ck_tile::element_wise::PassThrough, WeiDataType, - true, - VectorSizeA, - VectorSizeB>; + GroupedConvTraitsType::FixedGemmParams::FixedVectorSize, + GroupedConvTraitsType::VectorSizeA, + GroupedConvTraitsType::VectorSizeB>; using BaseGemmPipeline = typename PipelineTypeTraits< ConvConfig::Pipeline>::template UniversalGemmPipeline; @@ -101,21 +97,21 @@ struct GroupedConvolutionBackwardWeightInvoker constexpr auto scheduler = ConvConfig::Scheduler; constexpr auto memory_operation = memory_operation_.value; - using UniversalGemmProblem = - ck_tile::UniversalGemmPipelineProblem; + using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem< + OutDataType, + InDataType, + AccDataType, + GemmShape, + GemmUniversalTraits, + scheduler, + has_hot_loop_v, + tail_number_v, + ck_tile::element_wise::PassThrough, + ck_tile::element_wise::PassThrough, + WeiDataType, + GroupedConvTraitsType::FixedGemmParams::FixedVectorSize, + GroupedConvTraitsType::VectorSizeA, + GroupedConvTraitsType::VectorSizeB>; using GemmPipeline = typename PipelineTypeTraits< ConvConfig::Pipeline>::template GemmPipeline; @@ -127,7 +123,7 @@ struct GroupedConvolutionBackwardWeightInvoker AccDataType, WeiDataType, typename GroupedConvTraitsType::ImplicitGemmDsLayout, - ck_tile::tensor_layout::gemm::RowMajor, + typename GroupedConvTraitsType::FixedGemmParams::ELayout, CDEElementWise, TilePartitioner::MPerBlock, TilePartitioner::NPerBlock, @@ -136,10 +132,10 @@ struct GroupedConvolutionBackwardWeightInvoker ConvConfig::M_Warp_Tile, ConvConfig::N_Warp_Tile, ConvConfig::K_Warp_Tile, - ConvConfig::TransposeC, + GroupedConvTraitsType::FixedGemmParams::TransposeC, memory_operation, - 1, - true, + ConvConfig::NumWaveGroups, + GroupedConvTraitsType::FixedGemmParams::FixedVectorSize, GroupedConvTraitsType::VectorSizeC>>; using Kernel = ck_tile::GroupedConvolutionBackwardWeightKernel(Kernel{}, grids, blocks, 0, kargs)); + ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); return ave_time; }; diff --git a/example/ck_tile/20_grouped_convolution/grouped_convolution_backward_weight_two_stage_invoker.hpp b/example/ck_tile/20_grouped_convolution/grouped_convolution_backward_weight_two_stage_invoker.hpp index a8e41438c8..8cef2bde65 100644 --- a/example/ck_tile/20_grouped_convolution/grouped_convolution_backward_weight_two_stage_invoker.hpp +++ b/example/ck_tile/20_grouped_convolution/grouped_convolution_backward_weight_two_stage_invoker.hpp @@ -23,8 +23,6 @@ struct GroupedConvolutionBackwardWeightTwoStageInvoker { using WorkspaceDataType = float; - constexpr int kBlockPerCu = 1; - // Implicit GEMM Traits using GemmShape = ck_tile::TileGemmShape< ck_tile::sequence, @@ -33,36 +31,34 @@ struct GroupedConvolutionBackwardWeightTwoStageInvoker ConvConfig::N_Warp_Tile, ConvConfig::K_Warp_Tile>>; - constexpr ck_tile::index_t VectorSizeA = 4; - constexpr ck_tile::index_t VectorSizeB = 8; - constexpr ck_tile::index_t VectorSizeC = 8; - - constexpr auto ConvSpec = ck_tile::ConvolutionSpecialization::Default; - using TilePartitioner = - ck_tile::GemmSpatiallyLocalTilePartitioner; + constexpr auto ConvSpec = ck_tile::ConvolutionSpecialization::Default; using GroupedConvTraitsType = ck_tile::GroupedConvTraits; + ConvConfig::VectorSizeA, + ConvConfig::VectorSizeB, + ConvConfig::VectorSizeC, + ConvConfig::NumGroupsToMerge>; + + using TilePartitioner = ck_tile::GemmSpatiallyLocalTilePartitioner< + GemmShape, + GroupedConvTraitsType::FixedGemmParams::TilePartitionerGroupNum, + GroupedConvTraitsType::FixedGemmParams::TilePartitionerM01>; using GemmUniversalTraits = ck_tile::TileGemmUniversalTraits< - ConvConfig::kPadM, - ConvConfig::kPadN, - ConvConfig::kPadK, + GroupedConvTraitsType::FixedGemmParams::kPadM, + GroupedConvTraitsType::FixedGemmParams::kPadN, + GroupedConvTraitsType::FixedGemmParams::kPadK, ConvConfig::DoubleSmemBuffer, - typename GroupedConvTraitsType::GroupedConvImplicitGemmTraitsBwdWeight::AsLayout, - typename GroupedConvTraitsType::GroupedConvImplicitGemmTraitsBwdWeight::BsLayout, - typename GroupedConvTraitsType::GroupedConvImplicitGemmTraitsBwdWeight::CLayout, - ConvConfig::TransposeC, - false, - false, // Persistent, + typename GroupedConvTraitsType::AsLayoutBwdWeight, + typename GroupedConvTraitsType::BsLayoutBwdWeight, + typename GroupedConvTraitsType::CLayoutBwdWeight, + GroupedConvTraitsType::FixedGemmParams::TransposeC, + GroupedConvTraitsType::FixedGemmParams::UseStructuredSparsity, + GroupedConvTraitsType::FixedGemmParams::Persistent, ConvConfig::NumWaveGroups>; using GemmPipelineProblem = ck_tile::GemmPipelineProblem< @@ -70,13 +66,14 @@ struct GroupedConvolutionBackwardWeightTwoStageInvoker InDataType, AccDataType, GemmShape, - typename GroupedConvTraitsType::GroupedConvImplicitGemmTraitsBwdWeight, + typename GroupedConvTraitsType::template GroupedConvImplicitGemmTraitsBwdWeight< + ConvConfig::NumWaveGroups>, ck_tile::element_wise::PassThrough, ck_tile::element_wise::PassThrough, WeiDataType, - true, - VectorSizeA, - VectorSizeB>; + GroupedConvTraitsType::FixedGemmParams::FixedVectorSize, + GroupedConvTraitsType::VectorSizeA, + GroupedConvTraitsType::VectorSizeB>; using BaseGemmPipeline = typename PipelineTypeTraits< ConvConfig::Pipeline>::template UniversalGemmPipeline; @@ -102,21 +99,21 @@ struct GroupedConvolutionBackwardWeightTwoStageInvoker constexpr auto scheduler = ConvConfig::Scheduler; constexpr auto memory_operation = memory_operation_.value; - using UniversalGemmProblem = - ck_tile::UniversalGemmPipelineProblem; + using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem< + OutDataType, + InDataType, + AccDataType, + GemmShape, + GemmUniversalTraits, + scheduler, + has_hot_loop_v, + tail_number_v, + ck_tile::element_wise::PassThrough, + ck_tile::element_wise::PassThrough, + WeiDataType, + GroupedConvTraitsType::FixedGemmParams::FixedVectorSize, + GroupedConvTraitsType::VectorSizeA, + GroupedConvTraitsType::VectorSizeB>; using GemmPipeline = typename PipelineTypeTraits< ConvConfig::Pipeline>::template GemmPipeline; @@ -128,7 +125,7 @@ struct GroupedConvolutionBackwardWeightTwoStageInvoker AccDataType, WorkspaceDataType, // C: Workspace normally Out typename GroupedConvTraitsType::ImplicitGemmDsLayout, - ck_tile::tensor_layout::gemm::RowMajor, + typename GroupedConvTraitsType::FixedGemmParams::ELayout, CDEElementWise, TilePartitioner::MPerBlock, TilePartitioner::NPerBlock, @@ -139,8 +136,8 @@ struct GroupedConvolutionBackwardWeightTwoStageInvoker ConvConfig::K_Warp_Tile, GemmPipelineProblem::TransposeC, memory_operation, - 1, - true, + ConvConfig::NumWaveGroups, + GroupedConvTraitsType::FixedGemmParams::FixedVectorSize, GroupedConvTraitsType::VectorSizeC>>; using Kernel = ck_tile::GroupedConvolutionBackwardWeightKernel(Kernel{}, grids, blocks, 0, kargs), - ck_tile::make_kernel(ElementwiseKernel{}, - kGridSize, - kBlockSize, - 0, - input_size, - ck_tile::make_tuple(shape[1], 1), // Input Stride - ck_tile::make_tuple(shape[1], 1), // Output Stride - input_tensors, - static_cast(c_ptr))); + ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs), + ck_tile::make_kernel( + ElementwiseKernel{}, + kGridSize, + kBlockSize, + 0, + input_size, + ck_tile::make_tuple(shape[1], 1), // Input Stride + ck_tile::make_tuple(shape[1], 1), // Output Stride + input_tensors, + static_cast(c_ptr))); return ave_time; }; diff --git a/example/ck_tile/20_grouped_convolution/grouped_convolution_forward_invoker.hpp b/example/ck_tile/20_grouped_convolution/grouped_convolution_forward_invoker.hpp index 2290f60d1f..7c8269d13c 100644 --- a/example/ck_tile/20_grouped_convolution/grouped_convolution_forward_invoker.hpp +++ b/example/ck_tile/20_grouped_convolution/grouped_convolution_forward_invoker.hpp @@ -32,8 +32,6 @@ struct GroupedConvolutionForwardInvoker { std::cout << "[INVOKER] grouped_conv_fwd called, NDimSpatial=" << NDimSpatial << "\n"; } - constexpr int kBlockPerCu = 1; - // Implicit GEMM Traits using GemmShape = ck_tile::TileGemmShape< ck_tile::sequence, @@ -42,38 +40,34 @@ struct GroupedConvolutionForwardInvoker ConvConfig::N_Warp_Tile, ConvConfig::K_Warp_Tile>>; - constexpr ck_tile::index_t VectorSizeA = 8; - constexpr ck_tile::index_t VectorSizeB = 8; - constexpr ck_tile::index_t VectorSizeC = 8; - constexpr ck_tile::index_t NumGroupsToMerge = 1; - - constexpr auto ConvSpec = ck_tile::ConvolutionSpecialization::Default; - using TilePartitioner = - ck_tile::GemmSpatiallyLocalTilePartitioner; + constexpr auto ConvSpec = ck_tile::ConvolutionSpecialization::Default; using GroupedConvTraitsType = ck_tile::GroupedConvTraits; + ConvConfig::VectorSizeA, + ConvConfig::VectorSizeB, + ConvConfig::VectorSizeC, + ConvConfig::NumGroupsToMerge>; + + using TilePartitioner = ck_tile::GemmSpatiallyLocalTilePartitioner< + GemmShape, + GroupedConvTraitsType::FixedGemmParams::TilePartitionerGroupNum, + GroupedConvTraitsType::FixedGemmParams::TilePartitionerM01>; using GemmUniversalTraits = ck_tile::TileGemmUniversalTraits< - ConvConfig::kPadM, - ConvConfig::kPadN, - ConvConfig::kPadK, + GroupedConvTraitsType::FixedGemmParams::kPadM, + GroupedConvTraitsType::FixedGemmParams::kPadN, + GroupedConvTraitsType::FixedGemmParams::kPadK, ConvConfig::DoubleSmemBuffer, - typename GroupedConvTraitsType::GroupedConvImplicitGemmTraitsFwd::AsLayout, - typename GroupedConvTraitsType::GroupedConvImplicitGemmTraitsFwd::BsLayout, - typename GroupedConvTraitsType::GroupedConvImplicitGemmTraitsFwd::CLayout, - ConvConfig::TransposeC, - false, - false, // Persistent, + typename GroupedConvTraitsType::AsLayoutFwd, + typename GroupedConvTraitsType::BsLayoutFwd, + typename GroupedConvTraitsType::CLayoutFwd, + GroupedConvTraitsType::FixedGemmParams::TransposeC, + GroupedConvTraitsType::FixedGemmParams::UseStructuredSparsity, + GroupedConvTraitsType::FixedGemmParams::Persistent, ConvConfig::NumWaveGroups>; using GemmPipelineProblem = ck_tile::GemmPipelineProblem< @@ -81,13 +75,14 @@ struct GroupedConvolutionForwardInvoker WeiDataType, AccDataType, GemmShape, - typename GroupedConvTraitsType::GroupedConvImplicitGemmTraitsFwd, + typename GroupedConvTraitsType::template GroupedConvImplicitGemmTraitsFwd< + ConvConfig::NumWaveGroups>, ck_tile::element_wise::PassThrough, ck_tile::element_wise::PassThrough, OutDataType, - true, - VectorSizeA, - VectorSizeB>; + GroupedConvTraitsType::FixedGemmParams::FixedVectorSize, + GroupedConvTraitsType::VectorSizeA, + GroupedConvTraitsType::VectorSizeB>; using BaseGemmPipeline = typename PipelineTypeTraits< ConvConfig::Pipeline>::template UniversalGemmPipeline; @@ -116,21 +111,21 @@ struct GroupedConvolutionForwardInvoker constexpr auto scheduler = ConvConfig::Scheduler; constexpr auto memory_operation = memory_operation_.value; - using UniversalGemmProblem = - ck_tile::UniversalGemmPipelineProblem; + using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem< + InDataType, + WeiDataType, + AccDataType, + GemmShape, + GemmUniversalTraits, + scheduler, + has_hot_loop_v, + tail_number_v, + ck_tile::element_wise::PassThrough, + ck_tile::element_wise::PassThrough, + OutDataType, + GroupedConvTraitsType::FixedGemmParams::FixedVectorSize, + GroupedConvTraitsType::VectorSizeA, + GroupedConvTraitsType::VectorSizeB>; using GemmPipeline = typename PipelineTypeTraits< ConvConfig::Pipeline>::template GemmPipeline; @@ -142,7 +137,7 @@ struct GroupedConvolutionForwardInvoker AccDataType, OutDataType, typename GroupedConvTraitsType::ImplicitGemmDsLayout, - ck_tile::tensor_layout::gemm::RowMajor, + typename GroupedConvTraitsType::FixedGemmParams::ELayout, CDElementWise, TilePartitioner::MPerBlock, TilePartitioner::NPerBlock, @@ -151,10 +146,10 @@ struct GroupedConvolutionForwardInvoker ConvConfig::M_Warp_Tile, ConvConfig::N_Warp_Tile, ConvConfig::K_Warp_Tile, - ConvConfig::TransposeC, + GroupedConvTraitsType::FixedGemmParams::TransposeC, memory_operation, - 1, - true, + ConvConfig::NumWaveGroups, + GroupedConvTraitsType::FixedGemmParams::FixedVectorSize, GroupedConvTraitsType::VectorSizeC>>; using Kernel = ck_tile::GroupedConvolutionForwardKernel(Kernel{}, grids, blocks, 0, kargs)); + ave_time = ck_tile::launch_kernel(s, + ck_tile::make_kernel( + Kernel{}, grids, blocks, 0, kargs)); return ave_time; }; diff --git a/example/ck_tile/20_grouped_convolution/grouped_convolution_forward_large_tensor_invoker.hpp b/example/ck_tile/20_grouped_convolution/grouped_convolution_forward_large_tensor_invoker.hpp index 4d983baac5..9d2752727c 100644 --- a/example/ck_tile/20_grouped_convolution/grouped_convolution_forward_large_tensor_invoker.hpp +++ b/example/ck_tile/20_grouped_convolution/grouped_convolution_forward_large_tensor_invoker.hpp @@ -25,7 +25,6 @@ struct GroupedConvolutionForwardInvoker { std::cout << "[INVOKER] grouped_conv_fwd called, NDimSpatial=" << NDimSpatial << "\n"; } - constexpr int kBlockPerCu = 1; // Implicit GEMM Traits using GemmShape = ck_tile::TileGemmShape< @@ -35,27 +34,18 @@ struct GroupedConvolutionForwardInvoker ConvConfig::N_Warp_Tile, ConvConfig::K_Warp_Tile>>; - constexpr ck_tile::index_t VectorSizeA = 8; - constexpr ck_tile::index_t VectorSizeB = 8; - constexpr ck_tile::index_t VectorSizeC = 8; - constexpr auto ConvSpec = ck_tile::ConvolutionSpecialization::Default; - using TilePartitioner = - ck_tile::GemmSpatiallyLocalTilePartitioner; - - using GroupedConvTraitsTypeDefault = ck_tile::GroupedConvTraits; + using GroupedConvTraitsTypeDefault = + ck_tile::GroupedConvTraits; using GroupedConvTraitsTypeLargeTensor = ck_tile::GroupedConvTraits; + using TilePartitioner = ck_tile::GemmSpatiallyLocalTilePartitioner< + GemmShape, + GroupedConvTraitsTypeDefault::FixedGemmParams::TilePartitionerGroupNum, + GroupedConvTraitsTypeDefault::FixedGemmParams::TilePartitionerM01>; + using GemmUniversalTraits = ck_tile::TileGemmUniversalTraits< - ConvConfig::kPadM, - ConvConfig::kPadN, - ConvConfig::kPadK, + GroupedConvTraitsTypeDefault::FixedGemmParams::kPadM, + GroupedConvTraitsTypeDefault::FixedGemmParams::kPadN, + GroupedConvTraitsTypeDefault::FixedGemmParams::kPadK, ConvConfig::DoubleSmemBuffer, - typename GroupedConvTraitsTypeDefault::GroupedConvImplicitGemmTraitsFwd::AsLayout, - typename GroupedConvTraitsTypeDefault::GroupedConvImplicitGemmTraitsFwd::BsLayout, - typename GroupedConvTraitsTypeDefault::GroupedConvImplicitGemmTraitsFwd::CLayout, - ConvConfig::TransposeC, - false, - false, // Persistent, + typename GroupedConvTraitsTypeDefault::AsLayoutFwd, + typename GroupedConvTraitsTypeDefault::BsLayoutFwd, + typename GroupedConvTraitsTypeDefault::CLayoutFwd, + GroupedConvTraitsTypeDefault::FixedGemmParams::TransposeC, + GroupedConvTraitsTypeDefault::FixedGemmParams::UseStructuredSparsity, + GroupedConvTraitsTypeDefault::FixedGemmParams::Persistent, ConvConfig::NumWaveGroups>; using GemmPipelineProblem = ck_tile::GemmPipelineProblem< @@ -88,13 +83,14 @@ struct GroupedConvolutionForwardInvoker WeiDataType, AccDataType, GemmShape, - typename GroupedConvTraitsTypeDefault::GroupedConvImplicitGemmTraitsFwd, + typename GroupedConvTraitsTypeDefault::template GroupedConvImplicitGemmTraitsFwd< + ConvConfig::NumWaveGroups>, ck_tile::element_wise::PassThrough, ck_tile::element_wise::PassThrough, OutDataType, - true, - VectorSizeA, - VectorSizeB>; + GroupedConvTraitsTypeDefault::FixedGemmParams::FixedVectorSize, + GroupedConvTraitsTypeDefault::VectorSizeA, + GroupedConvTraitsTypeDefault::VectorSizeB>; using BaseGemmPipeline = typename PipelineTypeTraits< ConvConfig::Pipeline>::template UniversalGemmPipeline; @@ -116,9 +112,9 @@ struct GroupedConvolutionForwardInvoker using TransformType = ck_tile::TransformConvFwdToGemm; - using UniversalGemmProblem = - ck_tile::UniversalGemmPipelineProblem; + using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem< + InDataType, + WeiDataType, + AccDataType, + GemmShape, + GemmUniversalTraits, + scheduler, + has_hot_loop_v, + tail_number_v, + ck_tile::element_wise::PassThrough, + ck_tile::element_wise::PassThrough, + OutDataType, + GroupedConvTraitsType::FixedGemmParams::FixedVectorSize, + GroupedConvTraitsType::VectorSizeA, + GroupedConvTraitsType::VectorSizeB>; using GemmPipeline = typename PipelineTypeTraits< ConvConfig::Pipeline>::template GemmPipeline; @@ -290,7 +286,7 @@ struct GroupedConvolutionForwardInvoker AccDataType, OutDataType, typename GroupedConvTraitsType::ImplicitGemmDsLayout, - ck_tile::tensor_layout::gemm::RowMajor, + typename GroupedConvTraitsType::FixedGemmParams::ELayout, CDEElementWise, TilePartitioner::MPerBlock, TilePartitioner::NPerBlock, @@ -299,10 +295,10 @@ struct GroupedConvolutionForwardInvoker ConvConfig::M_Warp_Tile, ConvConfig::N_Warp_Tile, ConvConfig::K_Warp_Tile, - ConvConfig::TransposeC, + GroupedConvTraitsType::FixedGemmParams::TransposeC, memory_operation, - 1, - true, + ConvConfig::NumWaveGroups, + GroupedConvTraitsType::FixedGemmParams::FixedVectorSize, GroupedConvTraitsType::VectorSizeC>>; // Use split-image kernel if layout supports it, otherwise use regular kernel @@ -368,7 +364,8 @@ struct GroupedConvolutionForwardInvoker } ave_time = ck_tile::launch_kernel( - s, ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); + s, + ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); return ave_time; }; diff --git a/include/ck_tile/ops/grouped_convolution/utils/grouped_convolution_utils.hpp b/include/ck_tile/ops/grouped_convolution/utils/grouped_convolution_utils.hpp index 9b5a60ee1f..8ea6cffa7d 100644 --- a/include/ck_tile/ops/grouped_convolution/utils/grouped_convolution_utils.hpp +++ b/include/ck_tile/ops/grouped_convolution/utils/grouped_convolution_utils.hpp @@ -74,6 +74,21 @@ struct GroupedConvTraits } public: + // Fixed values for Implicit GEMM + struct FixedGemmParams + { + static constexpr ck_tile::index_t TilePartitionerGroupNum = 8; + static constexpr ck_tile::index_t TilePartitionerM01 = 4; + static constexpr bool kPadM = true; + static constexpr bool kPadN = true; + static constexpr bool kPadK = true; + static constexpr bool TransposeC = false; + static constexpr bool FixedVectorSize = true; + static constexpr bool UseStructuredSparsity = false; + static constexpr bool Persistent = false; + using ELayout = ck_tile::tensor_layout::gemm::RowMajor; + }; + // Compile time parameters static constexpr bool EnableSplitImage = EnableSplitImage_; static constexpr index_t NumGroupsToMerge = NumGroupsToMerge_; static constexpr index_t NDimSpatial = NDimSpatial_; @@ -82,31 +97,43 @@ struct GroupedConvTraits using WeiLayout = WeiLayout_; using DsLayout = DsLayout_; using OutLayout = OutLayout_; + + // Forward Gemm Layouts + using AsLayoutFwd = ck_tile::tensor_layout::gemm::RowMajor; + using BsLayoutFwd = ck_tile::tensor_layout::gemm::ColumnMajor; + using CLayoutFwd = ck_tile::tensor_layout::gemm::RowMajor; + // Backward Data Gemm Layouts + using AsLayoutBwdData = ck_tile::tensor_layout::gemm::RowMajor; + using BsLayoutBwdData = ck_tile::tensor_layout::gemm::RowMajor; + using CLayoutBwdData = ck_tile::tensor_layout::gemm::RowMajor; + // Backward Weight Gemm Layouts + using AsLayoutBwdWeight = ck_tile::tensor_layout::gemm::ColumnMajor; + using BsLayoutBwdWeight = ck_tile::tensor_layout::gemm::RowMajor; + using CLayoutBwdWeight = ck_tile::tensor_layout::gemm::RowMajor; + + template using GroupedConvImplicitGemmTraitsFwd = - TileGemmTraits; - using GroupedConvImplicitGemmTraitsBwdData = - TileGemmTraits; - using GroupedConvImplicitGemmTraitsBwdWeight = - TileGemmTraits; + TileGemmTraits; + template + using GroupedConvImplicitGemmTraitsBwdData = TileGemmTraits; + template + using GroupedConvImplicitGemmTraitsBwdWeight = TileGemmTraits; static constexpr ck_tile::index_t VectorSizeA = VectorSizeA_; static constexpr ck_tile::index_t VectorSizeB = VectorSizeB_; static constexpr ck_tile::index_t VectorSizeC = VectorSizeC_; - static constexpr index_t NumDTensor = DsLayout::size(); + static constexpr ck_tile::index_t NumDTensor = DsLayout::size(); using ImplicitGemmDsLayout = decltype(generate_implicit_gemm_layout()); }; From 18e083003fa25a661015542c39b1979200f361cf Mon Sep 17 00:00:00 2001 From: Adam Osewski <19374865+aosewski@users.noreply.github.com> Date: Thu, 6 Nov 2025 15:46:26 +0100 Subject: [PATCH 2/6] [CK_BUILDER] Convolution description (#3163) * Add DirectLoad tparam & clean up headers. * Add convolution traits. * Update inline documentation. * Add more convolution specialization and gemm padding types. * Add additional helper functions & more tests to conv traits. * Fix tests cmake file. * Add case insensitive string comparison * Fix function name overlapping with variable name. * Unify pipeline version and scheduler enums. * Fix includes. * Update test conv traits with unified enums. * Update concepts etc with update unified enum * Fix ckb conv fwd test - unified enum usage. * Dump changes. * Add ostream overloads for all enum classes. * Update detailed() function in ConvDescription * Fix handling union based conv direction. * Add test & update conv description. * Refine tree view. * Update copyrights * Fix merge artifacts * Update detailed tree conv description * Fix clang-format --- .../include/ck_tile/builder/builder_utils.hpp | 62 ---- .../builder/conv_signature_predicates.hpp | 16 + .../builder/reflect/conv_description.hpp | 268 +++++++++++++++++ .../ck_tile/builder/reflect/conv_traits.hpp | 2 +- .../builder/reflect/instance_traits_util.hpp | 5 +- .../builder/reflect/tree_formatter.hpp | 106 +++++++ .../builder/include/ck_tile/builder/types.hpp | 275 ++++++++++++++++++ experimental/builder/test/CMakeLists.txt | 3 + .../builder/test/test_conv_description.cpp | 169 +++++++++++ 9 files changed, 842 insertions(+), 64 deletions(-) create mode 100644 experimental/builder/include/ck_tile/builder/reflect/conv_description.hpp create mode 100644 experimental/builder/include/ck_tile/builder/reflect/tree_formatter.hpp create mode 100644 experimental/builder/test/test_conv_description.cpp diff --git a/experimental/builder/include/ck_tile/builder/builder_utils.hpp b/experimental/builder/include/ck_tile/builder/builder_utils.hpp index 5b4981c630..f16d96bec6 100644 --- a/experimental/builder/include/ck_tile/builder/builder_utils.hpp +++ b/experimental/builder/include/ck_tile/builder/builder_utils.hpp @@ -78,66 +78,4 @@ struct UnsupportedEnumValue { }; -// Helper functions to convert enums to strings -constexpr std::string_view ConvDirectionToString(ConvDirection dir) -{ - switch(dir) - { - case ConvDirection::FORWARD: return "Forward"; - case ConvDirection::BACKWARD_DATA: return "Backward Data"; - case ConvDirection::BACKWARD_WEIGHT: return "Backward Weight"; - default: return "Unknown"; - } -} - -constexpr std::string_view DataTypeToString(DataType dt) -{ - switch(dt) - { - case DataType::FP16: return "FP16"; - case DataType::FP32: return "FP32"; - case DataType::BF16: return "BF16"; - case DataType::FP8: return "FP8"; - case DataType::I8: return "I8"; - case DataType::U8: return "U8"; - default: return "Unknown"; - } -} - -constexpr std::string_view LayoutToString(GroupConvLayout1D layout) -{ - switch(layout) - { - case GroupConvLayout1D::GNWC_GKXC_GNWK: return "GNWC_GKXC_GNWK"; - case GroupConvLayout1D::NWGC_GKXC_NWGK: return "NWGC_GKXC_NWGK"; - case GroupConvLayout1D::NGCW_GKXC_NGKW: return "NGCW_GKXC_NGKW"; - case GroupConvLayout1D::NGCW_GKCX_NGKW: return "NGCW_GKCX_NGKW"; - default: return "Unknown"; - } -} - -constexpr std::string_view LayoutToString(GroupConvLayout2D layout) -{ - switch(layout) - { - case GroupConvLayout2D::GNHWC_GKYXC_GNHWK: return "GNHWC_GKYXC_GNHWK"; - case GroupConvLayout2D::NHWGC_GKYXC_NHWGK: return "NHWGC_GKYXC_NHWGK"; - case GroupConvLayout2D::NGCHW_GKYXC_NGKHW: return "NGCHW_GKYXC_NGKHW"; - case GroupConvLayout2D::NGCHW_GKCYX_NGKHW: return "NGCHW_GKCYX_NGKHW"; - default: return "Unknown"; - } -} - -constexpr std::string_view LayoutToString(GroupConvLayout3D layout) -{ - switch(layout) - { - case GroupConvLayout3D::GNDHWC_GKZYXC_GNDHWK: return "GNDHWC_GKZYXC_GNDHWK"; - case GroupConvLayout3D::NDHWGC_GKZYXC_NDHWGK: return "NDHWGC_GKZYXC_NDHWGK"; - case GroupConvLayout3D::NGCDHW_GKZYXC_NGKDHW: return "NGCDHW_GKZYXC_NGKDHW"; - case GroupConvLayout3D::NGCDHW_GKCZYX_NGKDHW: return "NGCDHW_GKCZYX_NGKDHW"; - default: return "Unknown"; - } -} - } // namespace ck_tile::builder diff --git a/experimental/builder/include/ck_tile/builder/conv_signature_predicates.hpp b/experimental/builder/include/ck_tile/builder/conv_signature_predicates.hpp index f016a342d3..3869c7b538 100644 --- a/experimental/builder/include/ck_tile/builder/conv_signature_predicates.hpp +++ b/experimental/builder/include/ck_tile/builder/conv_signature_predicates.hpp @@ -33,30 +33,35 @@ concept ConvDirectionIsBackwardWeight = (Sig.direction == ConvDirection::BACKWAR // Predicate for DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 operation. template concept ConvDeviceOpIs_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 = + ConvDirectionIsForward && (Sig.device_operation._fwd == FwdGroupConvDeviceOperation::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3); // Predicate for DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK operation. template concept ConvDeviceOpIs_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK = + ConvDirectionIsForward && (Sig.device_operation._fwd == FwdGroupConvDeviceOperation::DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK); // Predicate for DeviceGroupedConvFwdMultipleD_Wmma_CShuffle operation. template concept ConvDeviceOpIs_DeviceGroupedConvFwdMultipleD_Wmma_CShuffle = + ConvDirectionIsForward && (Sig.device_operation._fwd == FwdGroupConvDeviceOperation::DeviceGroupedConvFwdMultipleD_Wmma_CShuffle); // Predicate for DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle operation. template concept ConvDeviceOpIs_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle = + ConvDirectionIsForward && (Sig.device_operation._fwd == FwdGroupConvDeviceOperation::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle); // Predicate for DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor operation. template concept ConvDeviceOpIs_DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor = + ConvDirectionIsForward && (Sig.device_operation._fwd == FwdGroupConvDeviceOperation::DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor); @@ -76,48 +81,56 @@ concept ConvDeviceOpIsForward = // Predicate for DeviceGroupedConvBwdWeight operation. template concept ConvDeviceOpIs_DeviceGroupedConvBwdWeight = + ConvDirectionIsBackwardWeight && (Sig.device_operation._bwd_weight == BwdWeightGroupConvDeviceOperation::DeviceGroupedConvBwdWeight); // Predicate for DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle operation. template concept ConvDeviceOpIs_DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle = + ConvDirectionIsBackwardWeight && (Sig.device_operation._bwd_weight == BwdWeightGroupConvDeviceOperation::DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle); // Predicate for DeviceGroupedConvBwdWeight_Xdl_CShuffle operation. template concept ConvDeviceOpIs_DeviceGroupedConvBwdWeight_Xdl_CShuffle = + ConvDirectionIsBackwardWeight && (Sig.device_operation._bwd_weight == BwdWeightGroupConvDeviceOperation::DeviceGroupedConvBwdWeight_Xdl_CShuffle); // Predicate for DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle operation. template concept ConvDeviceOpIs_DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle = + ConvDirectionIsBackwardWeight && (Sig.device_operation._bwd_weight == BwdWeightGroupConvDeviceOperation::DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle); // Predicate for DeviceGroupedConvBwdWeight_Wmma_CShuffle operation. template concept ConvDeviceOpIs_DeviceGroupedConvBwdWeight_Wmma_CShuffle = + ConvDirectionIsBackwardWeight && (Sig.device_operation._bwd_weight == BwdWeightGroupConvDeviceOperation::DeviceGroupedConvBwdWeight_Wmma_CShuffle); // Predicate for DeviceGroupedConvBwdWeight_Xdl_CShuffleV3 operation. template concept ConvDeviceOpIs_DeviceGroupedConvBwdWeight_Xdl_CShuffleV3 = + ConvDirectionIsBackwardWeight && (Sig.device_operation._bwd_weight == BwdWeightGroupConvDeviceOperation::DeviceGroupedConvBwdWeight_Xdl_CShuffleV3); // Predicate for DeviceGroupedConvBwdWeightMultipleD operation. template concept ConvDeviceOpIs_DeviceGroupedConvBwdWeightMultipleD = + ConvDirectionIsBackwardWeight && (Sig.device_operation._bwd_weight == BwdWeightGroupConvDeviceOperation::DeviceGroupedConvBwdWeightMultipleD); // Predicate for DeviceGroupedConvBwdWeight_Dl operation. template concept ConvDeviceOpIs_DeviceGroupedConvBwdWeight_Dl = + ConvDirectionIsBackwardWeight && (Sig.device_operation._bwd_weight == BwdWeightGroupConvDeviceOperation::DeviceGroupedConvBwdWeight_Dl); @@ -140,18 +153,21 @@ concept ConvDeviceOpIsBackwardWeight = // Predicate for DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 operation. template concept ConvDeviceOpIs_DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 = + ConvDirectionIsBackwardData && (Sig.device_operation._bwd_data == BwdDataGroupConvDeviceOperation::DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1); // Predicate for DeviceGroupedConvBwdDataMultipleD operation. template concept ConvDeviceOpIs_DeviceGroupedConvBwdDataMultipleD = + ConvDirectionIsBackwardData && (Sig.device_operation._bwd_data == BwdDataGroupConvDeviceOperation::DeviceGroupedConvBwdDataMultipleD); // Predicate for DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle operation. template concept ConvDeviceOpIs_DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle = + ConvDirectionIsBackwardData && (Sig.device_operation._bwd_data == BwdDataGroupConvDeviceOperation::DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle); diff --git a/experimental/builder/include/ck_tile/builder/reflect/conv_description.hpp b/experimental/builder/include/ck_tile/builder/reflect/conv_description.hpp new file mode 100644 index 0000000000..0b58f5a3b7 --- /dev/null +++ b/experimental/builder/include/ck_tile/builder/reflect/conv_description.hpp @@ -0,0 +1,268 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include +#include +#include +#include + +#include +#include +#include + +/// @file conv_description.hpp +/// @brief Provides human-readable descriptions of ConvBuilder configurations + +namespace ck_tile::reflect::conv { + +struct ConvSignatureInfo +{ + int spatial_dim; + builder::ConvDirection direction; + std::variant + layout; + builder::DataType data_type; + builder::ElementwiseOperation input_element_op; + builder::ElementwiseOperation weight_element_op; + builder::ElementwiseOperation output_element_op; +}; + +// Algorithm information - groups all algorithm-related configuration +struct GemmAlgorithmInfo +{ + int thread_block_size; + DataTileInfo tile_dims; + WarpGemmParams warp_gemm; + InputTileTransferInfo a_tile_transfer; + InputTileTransferInfo b_tile_transfer; + OutputTileTransferInfo c_tile_transfer; + builder::PipelineVersion pipeline_version; + builder::PipelineScheduler pipeline_scheduler; + std::variant + conv_specialization; + builder::GemmPadding padding; +}; + +// Provides human-readable descriptions of ConvBuilder configurations. +struct ConvDescription +{ + ConvSignatureInfo signature; + GemmAlgorithmInfo algorithm; + + // Brief one-line summary + std::string brief() const + { + std::ostringstream oss; + oss << signature.spatial_dim << "D " << signature.direction << " convolution"; + return oss.str(); + } + + // Detailed hierarchical description + std::string detailed() const + { + TreeFormatter f; + f.writeLine(0, signature.spatial_dim, "D ", signature.direction, " Convolution Kernel"); + f.writeLine(1, "Signature"); + f.writeLine(2, "Tensor Type: ", signature.data_type); + f.writeLine(2, "Memory Layout: ", signature.layout); + f.writeLine(2, "Input elementwise operation: ", signature.input_element_op); + f.writeLine(2, "Weights elementwise operation: ", signature.weight_element_op); + f.writeLast(2, "Output elementwise operation: ", signature.output_element_op); + + f.writeLine(1, "Algorithm"); + // Compute Block section + f.writeLine(2, "Thread block size: ", algorithm.thread_block_size); + f.writeLine(2, + "Data tile size: ", + algorithm.tile_dims.m, + "×", + algorithm.tile_dims.n, + "×", + algorithm.tile_dims.k); + f.writeLine(2, "Gemm padding: ", algorithm.padding); + f.writeLine(2, "Convolution specialization: ", algorithm.conv_specialization); + // Pipeline section + f.writeLine(2, "Pipeline version: ", algorithm.pipeline_version); + f.writeLine(2, "Pipeline scheduler: ", algorithm.pipeline_scheduler); + f.writeLine(2, "Warp Gemm parameters: "); + f.writeLine( + 3, "subtile size: ", algorithm.warp_gemm.gemm_m, "×", algorithm.warp_gemm.gemm_n); + f.writeLast(3, + "Number of warp gemm iterations: ", + algorithm.warp_gemm.m_iter, + "×", + algorithm.warp_gemm.n_iter); + + // Memory Access section + f.writeLine(2, "Memory access:"); + + f.writeLine(3, "A Tile transfer: "); + f.writeLine(4, + "Tile dimensions: ", + algorithm.a_tile_transfer.tile_dimensions.k0, + "×", + algorithm.a_tile_transfer.tile_dimensions.m_or_n, + "×", + algorithm.a_tile_transfer.tile_dimensions.k1, + "×"); + f.writeLine( + 4, "The innermost K subdimension size: ", algorithm.a_tile_transfer.transfer_params.k1); + f.writeLine(4, + "Spatial thread distribution over the data tile: ", + algorithm.a_tile_transfer.transfer_params.thread_cluster_order[0], + "×", + algorithm.a_tile_transfer.transfer_params.thread_cluster_order[1], + "×", + algorithm.a_tile_transfer.transfer_params.thread_cluster_order[2]); + f.writeLine(4, + "The order of accessing data tile axes: ", + algorithm.a_tile_transfer.transfer_params.src_access_order[0], + "×", + algorithm.a_tile_transfer.transfer_params.src_access_order[1], + "×", + algorithm.a_tile_transfer.transfer_params.src_access_order[2]); + f.writeLine(4, + "Vectorized memory access axis index (with contiguous memory): ", + algorithm.a_tile_transfer.transfer_params.src_vector_dim); + f.writeLine(4, + "Vector access (GMEM read) instruction size: ", + algorithm.a_tile_transfer.transfer_params.src_scalar_per_vector); + f.writeLine(4, + "Vector access (LDS write) instruction size: ", + algorithm.a_tile_transfer.transfer_params.dst_scalar_per_vector_k1); + f.writeLast(4, + "LDS data layout padding (to prevent bank conflicts): ", + algorithm.a_tile_transfer.transfer_params.dst_scalar_per_vector_k1); + + f.writeLine(3, "B Tile transfer: "); + f.writeLine(4, + "Tile dimensions: ", + algorithm.b_tile_transfer.tile_dimensions.k0, + "×", + algorithm.b_tile_transfer.tile_dimensions.m_or_n, + "×", + algorithm.b_tile_transfer.tile_dimensions.k1, + "×"); + f.writeLine( + 4, "The innermost K subdimension size: ", algorithm.b_tile_transfer.transfer_params.k1); + f.writeLine(4, + "Spatial thread distribution over the data tile: ", + algorithm.b_tile_transfer.transfer_params.thread_cluster_order[0], + "×", + algorithm.b_tile_transfer.transfer_params.thread_cluster_order[1], + "×", + algorithm.b_tile_transfer.transfer_params.thread_cluster_order[2]); + f.writeLine(4, + "The order of accessing data tile axes: ", + algorithm.b_tile_transfer.transfer_params.src_access_order[0], + "×", + algorithm.b_tile_transfer.transfer_params.src_access_order[1], + "×", + algorithm.b_tile_transfer.transfer_params.src_access_order[2]); + f.writeLine(4, + "Vectorized memory access axis index (with contiguous memory): ", + algorithm.b_tile_transfer.transfer_params.src_vector_dim); + f.writeLine(4, + "Vector access (GMEM read) instruction size: ", + algorithm.b_tile_transfer.transfer_params.src_scalar_per_vector); + f.writeLine(4, + "Vector access (LDS write) instruction size: ", + algorithm.b_tile_transfer.transfer_params.dst_scalar_per_vector_k1); + f.writeLast(4, + "LDS data layout padding (to prevent bank conflicts): ", + algorithm.b_tile_transfer.transfer_params.dst_scalar_per_vector_k1); + + f.writeLast(3, "C Tile transfer: "); + f.writeLine(4, + "Data shuffle (number of gemm instructions per iteration): ", + algorithm.c_tile_transfer.shuffle_params.m_gemms_per_shuffle, + "×", + algorithm.c_tile_transfer.shuffle_params.n_gemms_per_shuffle); + f.writeLine(4, + "Spatial thread distribution used to store data: ", + algorithm.c_tile_transfer.thread_cluster_dims[0], + "×", + algorithm.c_tile_transfer.thread_cluster_dims[1], + "×", + algorithm.c_tile_transfer.thread_cluster_dims[2], + "×", + algorithm.c_tile_transfer.thread_cluster_dims[3]); + f.writeLast(4, + "Vector access (GMEM write) instruction size: ", + algorithm.c_tile_transfer.scalar_per_vector); + f.writeLast(2); + f.writeLast(1); + return f.getString(); + } + + // Educational explanation of optimization choices + std::string explain() const + { + std::ostringstream oss; + // Placeholder for future implementation + return oss.str(); + } + + // Performance characteristics and use case guidance + std::string suggest() const + { + std::ostringstream oss; + // Placeholder for future implementation + return oss.str(); + } +}; + +// Helper concept to detect if a type has InstanceTraits specialization +template +concept HasInstanceTraits = requires { typename InstanceTraits; }; + +// Helper concept to detect ConvBuilder types +template +concept IsConvBuilder = requires { + typename T::Factory; + typename T::Instance; +}; + +// Primary factory function: Create ConvDescription from Instance type directly +template + requires HasInstanceTraits +ConvDescription Describe() +{ + using Traits = ConvTraits; + + return ConvDescription{ + .signature = ConvSignatureInfo{.spatial_dim = Traits::spatial_dim, + .direction = Traits::direction, + .layout = Traits::layout, + .data_type = Traits::data_type, + .input_element_op = Traits::input_element_op, + .weight_element_op = Traits::weight_element_op, + .output_element_op = Traits::output_element_op}, + .algorithm = GemmAlgorithmInfo{.thread_block_size = Traits::thread_block_size, + .tile_dims = Traits::tile_dims, + .warp_gemm = Traits::warp_gemm, + .a_tile_transfer = Traits::a_tile_transfer, + .b_tile_transfer = Traits::b_tile_transfer, + .c_tile_transfer = Traits::c_tile_transfer, + .pipeline_version = Traits::pipeline_version, + .pipeline_scheduler = Traits::pipeline_scheduler, + .conv_specialization = Traits::conv_specialization, + .padding = Traits::gemm_padding}}; +} + +// Backward compatibility: Create ConvDescription from Builder type +template + requires IsConvBuilder && (!HasInstanceTraits) +ConvDescription Describe() +{ + // Delegate to Instance-based version + using Instance = typename Builder::Instance; + return Describe(); +} + +} // namespace ck_tile::reflect::conv diff --git a/experimental/builder/include/ck_tile/builder/reflect/conv_traits.hpp b/experimental/builder/include/ck_tile/builder/reflect/conv_traits.hpp index a74d77d155..86cf11f647 100644 --- a/experimental/builder/include/ck_tile/builder/reflect/conv_traits.hpp +++ b/experimental/builder/include/ck_tile/builder/reflect/conv_traits.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/experimental/builder/include/ck_tile/builder/reflect/instance_traits_util.hpp b/experimental/builder/include/ck_tile/builder/reflect/instance_traits_util.hpp index c863d2306c..e4d154ae10 100644 --- a/experimental/builder/include/ck_tile/builder/reflect/instance_traits_util.hpp +++ b/experimental/builder/include/ck_tile/builder/reflect/instance_traits_util.hpp @@ -13,7 +13,10 @@ #include #include #include -#include +#include +#include +#include +#include #include #include #include diff --git a/experimental/builder/include/ck_tile/builder/reflect/tree_formatter.hpp b/experimental/builder/include/ck_tile/builder/reflect/tree_formatter.hpp new file mode 100644 index 0000000000..6a80a994ee --- /dev/null +++ b/experimental/builder/include/ck_tile/builder/reflect/tree_formatter.hpp @@ -0,0 +1,106 @@ +#pragma once + +#include +#include +#include +#include + +namespace ck_tile::reflect { + +// Helper class for formatting hierarchical tree structures with proper indentation +// and tree-drawing characters (├─, └─, │, etc.) +// +// Example Usage: +// +// TreeFormatter f; +// f.writeLine(0, "Root"); +// f.writeLine(1, "Branch 1"); +// f.writeLine(2, "Item 1a"); +// f.writeLast(2, "Item 1b"); +// f.writeLast(1, "Branch 2"); +// f.writeLast(2, "Item 2a"); +// std::cout << f.getString() << "\n"; +// +// Generated Output: +// +// Root +// ├─ Branch 1 +// │ ├─ Item 1a +// │ └─ Item 1b +// └─ Branch 2 +// └─ Item 2a +class TreeFormatter +{ + public: + TreeFormatter() = default; + + // Write a line at the specified indentation level (branch continues after this) + template + void writeLine(int indent_level, Args&&... args) + { + writeLineImpl(indent_level, false, std::forward(args)...); + } + + // Write the last line at the specified indentation level (branch ends) + template + void writeLast(int indent_level, Args&&... args) + { + writeLineImpl(indent_level, true, std::forward(args)...); + } + + // Get the formatted string (removes trailing newline if present) + std::string getString() const + { + std::string result = oss_.str(); + if(!result.empty() && result.back() == '\n') + { + result.pop_back(); + } + return result; + } + + private: + std::ostringstream oss_; + std::vector is_last_at_level_; // Tracks which levels have ended + + // Implementation of line writing with tree symbols + template + void writeLineImpl(int indent_level, bool is_last, Args&&... args) + { + // Ensure we have enough tracking space + if(static_cast(indent_level) >= is_last_at_level_.size()) + { + is_last_at_level_.resize(indent_level + 1, false); + // Level 0 (root) should always be treated as "last" since it has no tree symbols + if(is_last_at_level_.size() > 0) + { + is_last_at_level_[0] = true; + } + } + + // Draw the tree structure + // Start from level 1 (skip level 0 which is the root with no symbols) + for(int i = 1; i < indent_level; ++i) + { + // For all parent levels, draw vertical line or space based on whether they ended + oss_ << (is_last_at_level_[i] ? " " : "│ "); + } + + // Draw the branch symbol for the current level + if(indent_level > 0) + { + oss_ << (is_last ? "└─ " : "├─ "); + } + + // Write the content using fold expression with direct stream insertion + ((oss_ << std::forward(args)), ...); + + oss_ << '\n'; + + // Update tracking for this level AFTER writing the line + // This ensures future lines at deeper levels know if this level ended + is_last_at_level_[indent_level] = is_last; + } +}; + +} // namespace ck_tile::reflect diff --git a/experimental/builder/include/ck_tile/builder/types.hpp b/experimental/builder/include/ck_tile/builder/types.hpp index 2af10346e5..a58c994288 100644 --- a/experimental/builder/include/ck_tile/builder/types.hpp +++ b/experimental/builder/include/ck_tile/builder/types.hpp @@ -3,6 +3,10 @@ #pragma once +#include +#include +#include + namespace ck_tile::builder { enum class DataType @@ -215,4 +219,275 @@ enum class PipelineScheduler INTERWAVE }; +// ostream operator overloads for enum classes +inline std::ostream& operator<<(std::ostream& os, DataType dt) +{ + using enum DataType; + switch(dt) + { + case FP16: return os << "FP16"; + case FP32: return os << "FP32"; + case BF16: return os << "BF16"; + case FP8: return os << "FP8"; + case I8: return os << "I8"; + case U8: return os << "U8"; + default: return os << "Unknown"; + } +} + +inline std::ostream& operator<<(std::ostream& os, ConvDirection dir) +{ + using enum ConvDirection; + switch(dir) + { + case FORWARD: return os << "Forward"; + case BACKWARD_DATA: return os << "Backward Data"; + case BACKWARD_WEIGHT: return os << "Backward Weight"; + default: return os << "Unknown"; + } +} + +inline std::ostream& operator<<(std::ostream& os, GroupConvLayout1D layout) +{ + using enum GroupConvLayout1D; + switch(layout) + { + case GNWC_GKXC_GNWK: return os << "GNWC_GKXC_GNWK"; + case NWGC_GKXC_NWGK: return os << "NWGC_GKXC_NWGK"; + case NGCW_GKXC_NGKW: return os << "NGCW_GKXC_NGKW"; + case NGCW_GKCX_NGKW: return os << "NGCW_GKCX_NGKW"; + default: return os << "Unknown"; + } +} + +inline std::ostream& operator<<(std::ostream& os, GroupConvLayout2D layout) +{ + using enum GroupConvLayout2D; + switch(layout) + { + case GNHWC_GKYXC_GNHWK: return os << "GNHWC_GKYXC_GNHWK"; + case NHWGC_GKYXC_NHWGK: return os << "NHWGC_GKYXC_NHWGK"; + case NGCHW_GKYXC_NGKHW: return os << "NGCHW_GKYXC_NGKHW"; + case NGCHW_GKCYX_NGKHW: return os << "NGCHW_GKCYX_NGKHW"; + default: return os << "Unknown"; + } +} + +inline std::ostream& operator<<(std::ostream& os, GroupConvLayout3D layout) +{ + using enum GroupConvLayout3D; + switch(layout) + { + case GNDHWC_GKZYXC_GNDHWK: return os << "GNDHWC_GKZYXC_GNDHWK"; + case NDHWGC_GKZYXC_NDHWGK: return os << "NDHWGC_GKZYXC_NDHWGK"; + case NGCDHW_GKZYXC_NGKDHW: return os << "NGCDHW_GKZYXC_NGKDHW"; + case NGCDHW_GKCZYX_NGKDHW: return os << "NGCDHW_GKCZYX_NGKDHW"; + default: return os << "Unknown"; + } +} + +inline std::ostream& operator<<(std::ostream& os, FwdGroupConvDeviceOperation op) +{ + using enum FwdGroupConvDeviceOperation; + switch(op) + { + case DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK: + return os << "DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK"; + case DeviceGroupedConvFwdMultipleD_Wmma_CShuffle: + return os << "DeviceGroupedConvFwdMultipleD_Wmma_CShuffle"; + case DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle: + return os << "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle"; + case DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3: + return os << "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3"; + case DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor: + return os << "DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor"; + default: return os << "Unknown"; + } +} + +inline std::ostream& operator<<(std::ostream& os, BwdDataGroupConvDeviceOperation op) +{ + using enum BwdDataGroupConvDeviceOperation; + switch(op) + { + case DeviceGroupedConvBwdDataMultipleD: return os << "DeviceGroupedConvBwdDataMultipleD"; + case DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle: + return os << "DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle"; + case DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1: + return os << "DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1"; + default: return os << "Unknown"; + } +} + +inline std::ostream& operator<<(std::ostream& os, BwdWeightGroupConvDeviceOperation op) +{ + using enum BwdWeightGroupConvDeviceOperation; + switch(op) + { + case DeviceGroupedConvBwdWeight: return os << "DeviceGroupedConvBwdWeight"; + case DeviceGroupedConvBwdWeight_Dl: return os << "DeviceGroupedConvBwdWeight_Dl"; + case DeviceGroupedConvBwdWeight_Xdl_CShuffle: + return os << "DeviceGroupedConvBwdWeight_Xdl_CShuffle"; + case DeviceGroupedConvBwdWeight_Xdl_CShuffleV3: + return os << "DeviceGroupedConvBwdWeight_Xdl_CShuffleV3"; + case DeviceGroupedConvBwdWeight_Wmma_CShuffle: + return os << "DeviceGroupedConvBwdWeight_Wmma_CShuffle"; + case DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle: + return os << "DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle"; + case DeviceGroupedConvBwdWeightMultipleD: return os << "DeviceGroupedConvBwdWeightMultipleD"; + case DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle: + return os << "DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle"; + default: return os << "Unknown"; + } +} + +inline std::ostream& operator<<(std::ostream& os, ElementwiseOperation op) +{ + using enum ElementwiseOperation; + switch(op) + { + case BIAS: return os << "BIAS"; + case BIAS_CLAMP: return os << "BIAS_CLAMP"; + case BIAS_BNORM_CLAMP: return os << "BIAS_BNORM_CLAMP"; + case BILINEAR: return os << "BILINEAR"; + case CLAMP: return os << "CLAMP"; + case SCALE: return os << "SCALE"; + case PASS_THROUGH: return os << "PASS_THROUGH"; + default: return os << "Unknown"; + } +} + +inline std::ostream& operator<<(std::ostream& os, PipelineVersion ver) +{ + using enum PipelineVersion; + switch(ver) + { + case V1: return os << "V1"; + case V2: return os << "V2"; + case V3: return os << "V3"; + case V4: return os << "V4"; + case V5: return os << "V5"; + case WEIGHT_ONLY: return os << "WEIGHT_ONLY"; + default: return os << "Unknown"; + } +} + +inline std::ostream& operator<<(std::ostream& os, GemmSpecialization spec) +{ + using enum GemmSpecialization; + switch(spec) + { + case Default: return os << "Default"; + case MPadding: return os << "MPadding"; + case NPadding: return os << "NPadding"; + case KPadding: return os << "KPadding"; + case MNPadding: return os << "MNPadding"; + case MKPadding: return os << "MKPadding"; + case NKPadding: return os << "NKPadding"; + case MNKPadding: return os << "MNKPadding"; + case OPadding: return os << "OPadding"; + case MOPadding: return os << "MOPadding"; + case NOPadding: return os << "NOPadding"; + case KOPadding: return os << "KOPadding"; + case MNOPadding: return os << "MNOPadding"; + case MKOPadding: return os << "MKOPadding"; + case NKOPadding: return os << "NKOPadding"; + case MNKOPadding: return os << "MNKOPadding"; + default: return os << "Unknown"; + } +} + +inline std::ostream& operator<<(std::ostream& os, ConvFwdSpecialization spec) +{ + using enum ConvFwdSpecialization; + switch(spec) + { + case DEFAULT: return os << "DEFAULT"; + case FILTER_1X1_PAD0: return os << "FILTER_1X1_PAD0"; + case FILTER_1X1_STRIDE1_PAD0: return os << "FILTER_1X1_STRIDE1_PAD0"; + case FILTER_3x3: return os << "FILTER_3x3"; + default: return os << "Unknown"; + } +} + +inline std::ostream& operator<<(std::ostream& os, ConvBwdDataSpecialization spec) +{ + using enum ConvBwdDataSpecialization; + switch(spec) + { + case DEFAULT: return os << "DEFAULT"; + case FILTER_1X1_STRIDE1_PAD0: return os << "FILTER_1X1_STRIDE1_PAD0"; + default: return os << "Unknown"; + } +} + +inline std::ostream& operator<<(std::ostream& os, ConvBwdWeightSpecialization spec) +{ + using enum ConvBwdWeightSpecialization; + switch(spec) + { + case DEFAULT: return os << "DEFAULT"; + case FILTER_1X1_STRIDE1_PAD0: return os << "FILTER_1X1_STRIDE1_PAD0"; + case FILTER_1X1_PAD0: return os << "FILTER_1X1_PAD0"; + case ODD_C: return os << "ODD_C"; + default: return os << "Unknown"; + } +} + +inline std::ostream& operator<<(std::ostream& os, GemmPadding padding) +{ + using enum GemmPadding; + switch(padding) + { + case DEFAULT: return os << "DEFAULT"; + case M_PADDING: return os << "M_PADDING"; + case N_PADDING: return os << "N_PADDING"; + case K_PADDING: return os << "K_PADDING"; + case MN_PADDING: return os << "MN_PADDING"; + case MK_PADDING: return os << "MK_PADDING"; + case NK_PADDING: return os << "NK_PADDING"; + case MNK_PADDING: return os << "MNK_PADDING"; + case O_PADDING: return os << "O_PADDING"; + case MO_PADDING: return os << "MO_PADDING"; + case NO_PADDING: return os << "NO_PADDING"; + case KO_PADDING: return os << "KO_PADDING"; + case MNO_PADDING: return os << "MNO_PADDING"; + case MKO_PADDING: return os << "MKO_PADDING"; + case NKO_PADDING: return os << "NKO_PADDING"; + case MNKO_PADDING: return os << "MNKO_PADDING"; + default: return os << "Unknown"; + } +} + +inline std::ostream& operator<<(std::ostream& os, PipelineScheduler sched) +{ + using enum PipelineScheduler; + switch(sched) + { + case DEFAULT: return os << "DEFAULT"; + case INTRAWAVE: return os << "INTRAWAVE"; + case INTERWAVE: return os << "INTERWAVE"; + default: return os << "Unknown"; + } +} + +// ostream operator overload for std::variant of layout types +inline std::ostream& +operator<<(std::ostream& os, + const std::variant& layout) +{ + std::visit([&os](const auto& l) { os << l; }, layout); + return os; +} + +// ostream operator overload for std::variant of convolution specializations +inline std::ostream& operator<<(std::ostream& os, + const std::variant& spec) +{ + std::visit([&os](const auto& s) { os << s; }, spec); + return os; +} + } // namespace ck_tile::builder diff --git a/experimental/builder/test/CMakeLists.txt b/experimental/builder/test/CMakeLists.txt index 0cb3237f8c..b776edbcde 100644 --- a/experimental/builder/test/CMakeLists.txt +++ b/experimental/builder/test/CMakeLists.txt @@ -67,6 +67,9 @@ add_ck_factory_test(test_ckb_factory_grouped_convolution_forward_dynamic_op test add_ck_builder_test(test_conv_traits conv/test_conv_traits.cpp) +add_ck_builder_test(test_conv_description + test_conv_description.cpp) + # Function to add all test_ckb targets to a list function(collect_test_ckb_targets result_var) # Get all targets in current directory diff --git a/experimental/builder/test/test_conv_description.cpp b/experimental/builder/test/test_conv_description.cpp new file mode 100644 index 0000000000..97af4af795 --- /dev/null +++ b/experimental/builder/test/test_conv_description.cpp @@ -0,0 +1,169 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) Advanced Micro Devices, Inc. All rights reserved. + +#include +#include + +#include +#include +#include "testing_utils.hpp" +#include "impl/conv_signature_types.hpp" +#include "impl/conv_algorithm_types.hpp" + +namespace { + +namespace ckb = ck_tile::builder; +namespace ckr = ck_tile::reflect::conv; +namespace ckt = ck_tile::test; + +// Defines the signature of the convolution operation to be tested. +// This includes dimensionality, direction, data layout, and data type. +struct ConvSignature +{ + int spatial_dim = 2; + ckb::ConvDirection direction = ckb::ConvDirection::FORWARD; + ckb::GroupConvLayout layout = ckb::GroupConvLayout2D::GNHWC_GKYXC_GNHWK; + ckb::DataType data_type = ckb::DataType::FP16; + ckb::ElementwiseOperation elementwise_operation = ckb::ElementwiseOperation::PASS_THROUGH; + ckb::GroupConvDeviceOp device_operation = + ckb::FwdGroupConvDeviceOperation::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3; +}; +static_assert(ckb::ConvSignatureDescriptor); + +struct DefaultAlgorithm +{ + ckb::test::ThreadBlock thread_block{.block_size = 256, + .tile_size = {.m = 256, .n = 256, .k = 32}}; + + ckb::test::GridwiseXdlGemm gridwise_gemm{.ak1 = 8, + .bk1 = 8, + .m_per_xdl = 16, + .n_per_xdl = 16, + .m_xdl_per_wave = 4, + .n_xdl_per_wave = 4}; + + ckb::test::BlockTransferABC block_transfer{ + .block_transfer_a = {.k0 = 4, .m_n = 256, .k1 = 8}, + .block_transfer_b = {.k0 = 4, .m_n = 256, .k1 = 8}, + .thread_cluster_dims_c = {.m_block = 1, + .m_wave_per_xdl = 32, + .n_block = 1, + .n_wave_per_xdl = 8}, + .lds_transfer_a = {.src_vector_dim = 2, + .src_scalar_per_vector = 8, + .lds_dst_scalar_per_vector = 8, + .is_direct_load = true, + .lds_padding = false}, + .lds_transfer_b = {.src_vector_dim = 2, + .src_scalar_per_vector = 8, + .lds_dst_scalar_per_vector = 8, + .is_direct_load = true, + .lds_padding = false}, + .epilogue_c = {.m_per_wave_per_shuffle = 1, + .n_per_wave_per_shuffle = 1, + .scalar_per_vector = 8}, + .block_transfer_access_order_a = {.order = {0, 1, 2}}, + .block_transfer_access_order_b = {.order = {0, 1, 2}}, + .src_access_order_a = {.order = {0, 1, 2}}, + .src_access_order_b = {.order = {0, 1, 2}}}; + + ckb::ConvFwdSpecialization fwd_specialization = ckb::ConvFwdSpecialization::DEFAULT; + ckb::GemmSpecialization gemm_specialization = ckb::GemmSpecialization::Default; + ckb::test::BlockGemm block_gemm{.pipeline_version = ckb::PipelineVersion::V4, + .scheduler = ckb::PipelineScheduler::INTRAWAVE}; +}; +static_assert(ckb::ConvAlgorithmDescriptor); + +TEST(ConvDescriptionTest, DefaultInstanceHasBriefDescription) +{ + static constexpr const ConvSignature SIGNATURE; + static constexpr const DefaultAlgorithm ALGORITHM; + using Builder = ckb::ConvBuilder; + EXPECT_THAT(ckr::Describe().brief(), ckt::StringEqWithDiff("2D Forward convolution")); +} + +TEST(ConvDescriptionTest, DefaultInstanceHasDetailedDescription) +{ + static constexpr const ConvSignature SIGNATURE; + static constexpr const DefaultAlgorithm ALGORITHM; + using Builder = ckb::ConvBuilder; + EXPECT_THAT(ckr::Describe().detailed(), + ckt::StringEqWithDiff( // + "2D Forward Convolution Kernel\n" + "├─ Signature\n" + "│ ├─ Tensor Type: FP16\n" + "│ ├─ Memory Layout: GNHWC_GKYXC_GNHWK\n" + "│ ├─ Input elementwise operation: PASS_THROUGH\n" + "│ ├─ Weights elementwise operation: PASS_THROUGH\n" + "│ └─ Output elementwise operation: PASS_THROUGH\n" + "├─ Algorithm\n" + "│ ├─ Thread block size: 256\n" + "│ ├─ Data tile size: 256×256×32\n" + "│ ├─ Gemm padding: DEFAULT\n" + "│ ├─ Convolution specialization: DEFAULT\n" + "│ ├─ Pipeline version: V4\n" + "│ ├─ Pipeline scheduler: INTRAWAVE\n" + "│ ├─ Warp Gemm parameters: \n" + "│ │ ├─ subtile size: 16×16\n" + "│ │ └─ Number of warp gemm iterations: 4×4\n" + "│ ├─ Memory access:\n" + "│ │ ├─ A Tile transfer: \n" + "│ │ │ ├─ Tile dimensions: 4×256×8×\n" + "│ │ │ ├─ The innermost K subdimension size: 8\n" + "│ │ │ ├─ Spatial thread distribution over the data tile: 0×1×2\n" + "│ │ │ ├─ The order of accessing data tile axes: 0×1×2\n" + "│ │ │ ├─ Vectorized memory access axis index (with contiguous memory): 2\n" + "│ │ │ ├─ Vector access (GMEM read) instruction size: 8\n" + "│ │ │ ├─ Vector access (LDS write) instruction size: 8\n" + "│ │ │ └─ LDS data layout padding (to prevent bank conflicts): 8\n" + "│ │ ├─ B Tile transfer: \n" + "│ │ │ ├─ Tile dimensions: 4×256×8×\n" + "│ │ │ ├─ The innermost K subdimension size: 8\n" + "│ │ │ ├─ Spatial thread distribution over the data tile: 0×1×2\n" + "│ │ │ ├─ The order of accessing data tile axes: 0×1×2\n" + "│ │ │ ├─ Vectorized memory access axis index (with contiguous memory): 2\n" + "│ │ │ ├─ Vector access (GMEM read) instruction size: 8\n" + "│ │ │ ├─ Vector access (LDS write) instruction size: 8\n" + "│ │ │ └─ LDS data layout padding (to prevent bank conflicts): 8\n" + "│ │ └─ C Tile transfer: \n" + "│ │ ├─ Data shuffle (number of gemm instructions per iteration): 1×1\n" + "│ │ ├─ Spatial thread distribution used to store data: 1×32×1×8\n" + "│ │ └─ Vector access (GMEM write) instruction size: 8\n" + "│ └─ \n" + "└─ ")); +} + +// NOTE: BackwardDataInstanceHasDetailedDescription test is disabled because ConvFactory +// does not have a specialization for backward data convolutions. The test fails with: +// "implicit instantiation of undefined template 'ck_tile::builder::ConvFactory<...>'" +// +// To enable this test, a ConvFactory specialization for backward data operations must be +// implemented first. +// +// TEST(ConvDescriptionTest, BackwardDataInstanceHasDetailedDescription) +// { +// struct BackwardDataSignature +// { +// int spatial_dim = 2; +// ckb::ConvDirection direction = ckb::ConvDirection::BACKWARD_DATA; +// ckb::GroupConvLayout layout = ckb::GroupConvLayout2D::GNHWC_GKYXC_GNHWK; +// ckb::DataType data_type = ckb::DataType::FP16; +// ckb::ElementwiseOperation elementwise_operation = +// ckb::ElementwiseOperation::PASS_THROUGH; ckb::GroupConvDeviceOp device_operation = +// ckb::BwdDataGroupConvDeviceOperation::DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1; +// }; +// static_assert(ckb::ConvSignatureDescriptor); +// +// static constexpr const BackwardDataSignature SIGNATURE; +// static constexpr const DefaultAlgorithm ALGORITHM; +// using Builder = ckb::ConvBuilder; +// +// // Verify Brief works +// EXPECT_THAT(ckr::Describe().brief(), +// ckt::StringEqWithDiff("2D Backward Data convolution")); +// +// // Verify detailed works - to be updated once ConvFactory is implemented +// EXPECT_THAT(ckr::Describe().detailed(), +// ckt::StringEqWithDiff("PLACEHOLDER")); +// } +} // namespace From 76c4c12f5959adcd56d1627a1d1ce885deb9d096 Mon Sep 17 00:00:00 2001 From: Johannes Graner Date: Fri, 7 Nov 2025 00:07:39 +0100 Subject: [PATCH 3/6] Add .clangd and CMakeUserPresets.json to .gitignore (#3171) --- .gitignore | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/.gitignore b/.gitignore index 6641e5bc58..2641a661d8 100644 --- a/.gitignore +++ b/.gitignore @@ -66,6 +66,12 @@ docs/doxygen/xml cmake-build*/ build*/ +# LSP configuration +.clangd + +# User-defined CMake presets +CMakeUserPresets.json + # Python virtualenv .venv/ From 5f3cae3e28a042e411afcd2e54b16cc6909c5bbb Mon Sep 17 00:00:00 2001 From: JH-Leon-KIM-AMD Date: Fri, 7 Nov 2025 02:29:48 +0200 Subject: [PATCH 4/6] [CK_BUILDER]ckb add remining fwd conv device ops (#3155) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Add device operation to conv signature. Use unions to hold conv layouts and device operations. * Add predicates for all device op instances. * Use the device op signature for validation. * Fix ckb CMakeLists.txt file for tests. * Fix building CK Builder instance traits after the introduction of direct load template parameter in CK. * Fix clang-formatting. * add device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk * Add full DL configurability with Option A implementation - Added 5 DL descriptor structs (39 configurable parameters) - Added 10 C++20 concepts for type-safe validation - Updated factory to read all parameters from descriptors - Updated test helper to populate all descriptors - All tests passing (13/13 including 3 new DL tests) * Add factory and test support for DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor - Add factory specialization for Large_Tensor device operation (conv_factory.hpp lines 1145-1265) - Add macro collision workaround using pragma push/pop (conv_factory.hpp lines 43-51) - Add test helper function run_test_DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor - Add builder test file test_ckb_conv_fwd_2d_large_tensor_fp16.cpp with 2 test cases - Update CMakeLists.txt to include new test file - Reuse existing ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle descriptor - Map all 42 template parameters identical to regular XDL CShuffle - All 15 builder tests passing including 2 new Large_Tensor tests Completes Task 350: All 4 forward convolution device operations now supported in CK Builder. * Update copyright headers to new format - Change copyright format to: Copyright (C) Advanced Micro Devices, Inc., or its affiliates. - Reorder headers: Copyright first, then SPDX-License-Identifier - Updated files: * experimental/builder/test/conv/test_ckb_conv_fwd_2d_dl_fp16.cpp * experimental/builder/test/conv/test_ckb_conv_fwd_2d_large_tensor_fp16.cpp * experimental/builder/include/ck_tile/builder/device_op_types.hpp * fix c++ 18 format * Fix clang-format-18 error in device_op_types.hpp --------- Co-authored-by: Ville Pietilä Co-authored-by: Ville Pietilä <188998872+vpietila-amd@users.noreply.github.com> --- .../builder/conv_algorithm_concepts.hpp | 83 ++++++ .../include/ck_tile/builder/conv_factory.hpp | 271 ++++++++++++++++++ .../ck_tile/builder/device_op_types.hpp | 22 ++ experimental/builder/test/CMakeLists.txt | 2 + .../conv/test_ckb_conv_fwd_2d_dl_fp16.cpp | 69 +++++ ...test_ckb_conv_fwd_2d_large_tensor_fp16.cpp | 53 ++++ .../test/impl/conv_algorithm_types.hpp | 80 ++++++ .../test/utils/ckb_conv_test_common.hpp | 145 ++++++++++ 8 files changed, 725 insertions(+) create mode 100644 experimental/builder/include/ck_tile/builder/device_op_types.hpp create mode 100644 experimental/builder/test/conv/test_ckb_conv_fwd_2d_dl_fp16.cpp create mode 100644 experimental/builder/test/conv/test_ckb_conv_fwd_2d_large_tensor_fp16.cpp diff --git a/experimental/builder/include/ck_tile/builder/conv_algorithm_concepts.hpp b/experimental/builder/include/ck_tile/builder/conv_algorithm_concepts.hpp index e43f910a73..6006efe4f8 100644 --- a/experimental/builder/include/ck_tile/builder/conv_algorithm_concepts.hpp +++ b/experimental/builder/include/ck_tile/builder/conv_algorithm_concepts.hpp @@ -183,4 +183,87 @@ concept SpecifiesLoopScheduler = requires { { T::loop_scheduler } -> std::convertible_to; }; +/******************************************** */ +/* DL-specific descriptors and requirements */ +/******************************************** */ + +// Concept for DL thread configuration +template +concept DlThreadConfigDescriptor = requires(T t) { + { t.k0_per_block } -> std::convertible_to; + { t.k1 } -> std::convertible_to; + { t.m1_per_thread } -> std::convertible_to; + { t.n1_per_thread } -> std::convertible_to; + { t.k_per_thread } -> std::convertible_to; +}; + +// Concept for DL thread cluster +template +concept DlThreadClusterDescriptor = requires(T t) { + { t.m1_xs } -> std::convertible_to>; + { t.n1_xs } -> std::convertible_to>; +}; + +// Concept for DL block transfer K0_M0_M1_K1 format +template +concept DlBlockTransferK0M0M1K1Descriptor = requires(T t) { + { t.thread_slice_lengths } -> std::convertible_to>; + { t.thread_cluster_lengths } -> std::convertible_to>; + { t.thread_cluster_arrange_order } -> std::convertible_to>; + { t.src_access_order } -> std::convertible_to>; + { t.src_vector_tensor_lengths } -> std::convertible_to>; + { t.src_vector_tensor_contiguous_dim_order } -> std::convertible_to>; + { t.dst_vector_tensor_lengths } -> std::convertible_to>; +}; + +// Concept for DL block transfer K0_N0_N1_K1 format +template +concept DlBlockTransferK0N0N1K1Descriptor = requires(T t) { + { t.thread_slice_lengths } -> std::convertible_to>; + { t.thread_cluster_lengths } -> std::convertible_to>; + { t.thread_cluster_arrange_order } -> std::convertible_to>; + { t.src_access_order } -> std::convertible_to>; + { t.src_vector_tensor_lengths } -> std::convertible_to>; + { t.src_vector_tensor_contiguous_dim_order } -> std::convertible_to>; + { t.dst_vector_tensor_lengths } -> std::convertible_to>; +}; + +// Concept for DL C thread transfer +template +concept DlCThreadTransferDescriptor = requires(T t) { + { t.src_dst_access_order } -> std::convertible_to>; + { t.src_dst_vector_dim } -> std::convertible_to; + { t.dst_scalar_per_vector } -> std::convertible_to; +}; + +// Concept to check if algorithm specifies DL thread config +template +concept SpecifiesDlThreadConfig = requires { + { T::dl_thread_config } -> DlThreadConfigDescriptor; +}; + +// Concept to check if algorithm specifies DL thread cluster +template +concept SpecifiesDlThreadCluster = requires { + { T::dl_thread_cluster } -> DlThreadClusterDescriptor; +}; + +// Concept to check if algorithm specifies DL A block transfer +template +concept SpecifiesDlBlockTransferA = requires { + { T::dl_block_transfer_a } -> DlBlockTransferK0M0M1K1Descriptor; +}; + +// Concept to check if algorithm specifies DL B block transfer +template +concept SpecifiesDlBlockTransferB = requires { + { T::dl_block_transfer_b } -> DlBlockTransferK0N0N1K1Descriptor; +}; + +// Concept to check if algorithm specifies DL C thread transfer +template +concept SpecifiesDlCThreadTransfer = requires { + { T::dl_c_thread_transfer } -> DlCThreadTransferDescriptor; +}; + } // namespace ck_tile::builder diff --git a/experimental/builder/include/ck_tile/builder/conv_factory.hpp b/experimental/builder/include/ck_tile/builder/conv_factory.hpp index 1ccc190ba2..e40199987d 100644 --- a/experimental/builder/include/ck_tile/builder/conv_factory.hpp +++ b/experimental/builder/include/ck_tile/builder/conv_factory.hpp @@ -36,9 +36,21 @@ #pragma once +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp" #include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp" #include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp" #include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp" +// WORKAROUND: Macro namespace collision in upstream CK device operation headers. +// device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp (line 41) and +// device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp (line 51) both define +// GridwiseGemmTemplateParameters macro without #undef, causing redefinition errors. +// Use pragma push/pop to isolate the Large_Tensor header's macro scope. +#pragma push_macro("GridwiseGemmTemplateParameters") +#ifdef GridwiseGemmTemplateParameters +#undef GridwiseGemmTemplateParameters +#endif +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp" +#pragma pop_macro("GridwiseGemmTemplateParameters") #include "ck_tile/builder/conv_signature_concepts.hpp" #include "ck_tile/builder/conv_algorithm_concepts.hpp" #include "ck_tile/builder/conv_algorithm_limits.hpp" @@ -990,4 +1002,263 @@ struct ConvFactory GRIDWISE_GEMM_PIPELINE_VERSION>; }; +// Factory specialization for DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK instance +// of a grouped forward convolution kernel using Direct Load (DL) approach. +template + requires ConvDirectionIsForward && + ConvDeviceOpIs_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK +struct ConvFactory +{ + static constexpr size_t SPATIAL_DIM = SIGNATURE.spatial_dim; + using Layouts = decltype(factory_internal::GetTensorLayout()); + using Types = factory_internal::ConvTensorTypes; + using Ops = factory_internal::ElementwiseOps; + using AlgorithmType = decltype(ALGORITHM); + + static_assert(SpecifiesThreadBlock, + "The convolution algorithm descriptor must specify thread block info."); + static_assert(SpecifiesFwdConcSpecialization, + "The convolution algorithm descriptor must specify forward convolution " + "specialization."); + static_assert(SpecifiesGemmSpecialization, + "The convolution algorithm descriptor must specify gemm specialization."); + static_assert(SpecifiesDlThreadConfig, + "DL algorithm must specify thread config."); + static_assert(SpecifiesDlThreadCluster, + "DL algorithm must specify thread cluster."); + static_assert(SpecifiesDlBlockTransferA, + "DL algorithm must specify A block transfer."); + static_assert(SpecifiesDlBlockTransferB, + "DL algorithm must specify B block transfer."); + static_assert(SpecifiesDlCThreadTransfer, + "DL algorithm must specify C thread transfer."); + + static constexpr auto FWD_CONV_SPECIALIZATION = + factory_internal::SetFwdConvSpecialization(); + static constexpr auto GEMM_SPECIALIZATION = + factory_internal::SetGemmSpecialization(); + + static constexpr auto BLOCK = factory_internal::SetThreadBlockInfo(); + + // DL-specific parameters from algorithm descriptor + static constexpr auto DL_THREAD_CFG = ALGORITHM.dl_thread_config; + static constexpr ck::index_t K0PerBlock = DL_THREAD_CFG.k0_per_block; + static constexpr ck::index_t K1 = DL_THREAD_CFG.k1; + static constexpr ck::index_t M1PerThread = DL_THREAD_CFG.m1_per_thread; + static constexpr ck::index_t N1PerThread = DL_THREAD_CFG.n1_per_thread; + static constexpr ck::index_t KPerThread = DL_THREAD_CFG.k_per_thread; + + // Thread cluster from descriptor + static constexpr auto DL_CLUSTER = ALGORITHM.dl_thread_cluster; + using M1N1ThreadClusterM1Xs = to_sequence_v; + using M1N1ThreadClusterN1Xs = to_sequence_v; + + // A Block Transfer from descriptor - K0_M0_M1_K1 tensor format + static constexpr auto DL_A_TRANSFER = ALGORITHM.dl_block_transfer_a; + using ABlockTransferThreadSliceLengths_K0_M0_M1_K1 = + to_sequence_v; + using ABlockTransferThreadClusterLengths_K0_M0_M1_K1 = + to_sequence_v; + using ABlockTransferThreadClusterArrangeOrder = + to_sequence_v; + using ABlockTransferSrcAccessOrder = to_sequence_v; + using ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1 = + to_sequence_v; + using ABlockTransferSrcVectorTensorContiguousDimOrder = + to_sequence_v; + using ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1 = + to_sequence_v; + + // B Block Transfer from descriptor - K0_N0_N1_K1 tensor format + static constexpr auto DL_B_TRANSFER = ALGORITHM.dl_block_transfer_b; + using BBlockTransferThreadSliceLengths_K0_N0_N1_K1 = + to_sequence_v; + using BBlockTransferThreadClusterLengths_K0_N0_N1_K1 = + to_sequence_v; + using BBlockTransferThreadClusterArrangeOrder = + to_sequence_v; + using BBlockTransferSrcAccessOrder = to_sequence_v; + using BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1 = + to_sequence_v; + using BBlockTransferSrcVectorTensorContiguousDimOrder = + to_sequence_v; + using BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1 = + to_sequence_v; + + // C Thread Transfer from descriptor + static constexpr auto DL_C_TRANSFER = ALGORITHM.dl_c_thread_transfer; + using CThreadTransferSrcDstAccessOrder = to_sequence_v; + static constexpr ck::index_t CThreadTransferSrcDstVectorDim = DL_C_TRANSFER.src_dst_vector_dim; + static constexpr ck::index_t CThreadTransferDstScalarPerVector = + DL_C_TRANSFER.dst_scalar_per_vector; + + // The DL forward convolution kernel class instance + using Instance = ck::tensor_operation::device::DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK< + SPATIAL_DIM, + typename Types::ADataType, + typename Types::BDataType, + typename Types::DsDataTypes, + typename Types::EDataType, + typename Types::AccDataType, + typename Layouts::ALayout, + typename Layouts::BLayout, + typename Layouts::DsLayout, + typename Layouts::ELayout, + typename Ops::AElementwiseOp, + typename Ops::BElementwiseOp, + typename Ops::CDEElementwiseOp, + FWD_CONV_SPECIALIZATION, + GEMM_SPECIALIZATION, + BLOCK.block_size, + BLOCK.per_block.m, + BLOCK.per_block.n, + K0PerBlock, + K1, + M1PerThread, + N1PerThread, + KPerThread, + M1N1ThreadClusterM1Xs, + M1N1ThreadClusterN1Xs, + ABlockTransferThreadSliceLengths_K0_M0_M1_K1, + ABlockTransferThreadClusterLengths_K0_M0_M1_K1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1, + ABlockTransferSrcVectorTensorContiguousDimOrder, + ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1, + BBlockTransferThreadSliceLengths_K0_N0_N1_K1, + BBlockTransferThreadClusterLengths_K0_N0_N1_K1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1, + BBlockTransferSrcVectorTensorContiguousDimOrder, + BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1, + CThreadTransferSrcDstAccessOrder, + CThreadTransferSrcDstVectorDim, + CThreadTransferDstScalarPerVector>; +}; + +// Factory specialization for DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor instance +// of a grouped forward convolution kernel with large tensor support (N-splitting). +template + requires ConvDirectionIsForward && + ConvDeviceOpIs_DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor +struct ConvFactory +{ + static constexpr size_t SPATIAL_DIM = SIGNATURE.spatial_dim; + using Layouts = decltype(factory_internal::GetTensorLayout()); + using Types = factory_internal::ConvTensorTypes; + using Ops = factory_internal::ElementwiseOps; + using AlgorithmType = decltype(ALGORITHM); + + static_assert(SpecifiesThreadBlock, + "The convolution algorithm descriptor must specify thread block info."); + static_assert(SpecifiesGridwiseXdlGemm, + "The convolution algorithm descriptor must specify gridwise GEMM info."); + static_assert(SpecifiesBlockTransfer, + "The convolution algorithm descriptor must specify block transfer info."); + static_assert(SpecifiesLdsTransfer, + "The convolution algorithm descriptor must specify LDS transfer info."); + static_assert( + SpecifiesThreadClusterAccessOrder, + "The convolution algorithm descriptor must specify thread cluster access order info."); + static_assert(SpecifiesSourceAccessOrder, + "The convolution algorithm descriptor must specify source access order info."); + static_assert(SpecifiesFwdConcSpecialization, + "The convolution algorithm descriptor must specify forward convolution " + "specialization."); + static_assert(SpecifiesGemmSpecialization, + "The convolution algorithm descriptor must specify gemm specialization."); + static_assert(SpecifiesNumPrefetchStages, + "The convolution algorithm descriptor must specify number of prefetch stages."); + static_assert(SpecifiesLoopScheduler, + "The convolution algorithm descriptor must specify loop scheduler."); + + static constexpr auto FWD_CONV_SPECIALIZATION = + factory_internal::SetFwdConvSpecialization(); + static constexpr auto GEMM_SPECIALIZATION = + factory_internal::SetGemmSpecialization(); + static constexpr factory_internal::ConvSpec SPECIALIZATION{.conv_spec = FWD_CONV_SPECIALIZATION, + .gemm_spec = GEMM_SPECIALIZATION}; + + static constexpr auto LOOP_SCHEDULER = factory_internal::SetLoopScheduler(); + static constexpr auto BLOCK = factory_internal::SetThreadBlockInfo(); + static constexpr auto GRIDWISE_GEMM = ALGORITHM.gridwise_gemm; + static constexpr auto A_BLOCK_TRANSFER = + factory_internal::SetFwdConvABlockTransfer(); + static constexpr auto B_BLOCK_TRANSFER = + factory_internal::SetFwdConvBBlockTransfer(); + static constexpr auto C_BLOCK_TRANSFER = + factory_internal::SetCBlockTransfer(); + + // Check limits for the algorithm parameters. + static_assert(InputVectorTransferLimits); + static_assert(InputVectorTransferLimits); + static_assert(OutputVectorTransferLimits); + static_assert(AccessOrderLimits); + static_assert(AccessOrderLimits); + static_assert(AccessOrderLimits); + static_assert(AccessOrderLimits); + + // The forward convolution kernel class instance with large tensor support. + using Instance = + ck::tensor_operation::device::DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor< + SPATIAL_DIM, + typename Layouts::ALayout, + typename Layouts::BLayout, + typename Layouts::DsLayout, + typename Layouts::ELayout, + typename Types::ADataType, + typename Types::BDataType, + typename Types::AccDataType, + typename Types::CShuffleDataType, + typename Types::DsDataTypes, + typename Types::EDataType, + typename Ops::AElementwiseOp, + typename Ops::BElementwiseOp, + typename Ops::CDEElementwiseOp, + SPECIALIZATION.conv_spec, + SPECIALIZATION.gemm_spec, + ALGORITHM.num_gemm_k_prefetch_stages, + BLOCK.block_size, + BLOCK.per_block.m, + BLOCK.per_block.n, + BLOCK.per_block.k, + GRIDWISE_GEMM.ak1, + GRIDWISE_GEMM.bk1, + GRIDWISE_GEMM.m_per_xdl, + GRIDWISE_GEMM.n_per_xdl, + GRIDWISE_GEMM.m_xdl_per_wave, + GRIDWISE_GEMM.n_xdl_per_wave, + to_sequence_v, + to_sequence_v, + to_sequence_v, + A_BLOCK_TRANSFER.src_vector_dim, + A_BLOCK_TRANSFER.src_scalar_per_vector, + A_BLOCK_TRANSFER.lds_dst_scalar_per_vector, + A_BLOCK_TRANSFER.lds_padding, + to_sequence_v, + to_sequence_v, + to_sequence_v, + B_BLOCK_TRANSFER.src_vector_dim, + B_BLOCK_TRANSFER.src_scalar_per_vector, + B_BLOCK_TRANSFER.lds_dst_scalar_per_vector, + B_BLOCK_TRANSFER.lds_padding, + C_BLOCK_TRANSFER.m_per_wave_per_shuffle, + C_BLOCK_TRANSFER.n_per_wave_per_shuffle, + to_sequence_v, + C_BLOCK_TRANSFER.scalar_per_vector, + typename Types::AComputeType, + typename Types::BComputeType, + LOOP_SCHEDULER>; +}; + } // namespace ck_tile::builder diff --git a/experimental/builder/include/ck_tile/builder/device_op_types.hpp b/experimental/builder/include/ck_tile/builder/device_op_types.hpp new file mode 100644 index 0000000000..0e779fdf4e --- /dev/null +++ b/experimental/builder/include/ck_tile/builder/device_op_types.hpp @@ -0,0 +1,22 @@ +// Copyright (C) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +namespace ck_tile::builder { + +// Enumeration for CK Device Operation types. +// This allows the builder to select which device operation template to instantiate +// based on the user's requirements. +enum class DeviceOpType +{ + // Forward Convolution - Non-grouped + CONV_FWD, // Maps to: DeviceConvFwd (TODO: No implementation with tuning params exists yet) + + // Forward Convolution - Grouped + GROUPED_CONV_FWD_MULTIPLE_ABD, // Maps to: DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle + GROUPED_CONV_FWD_MULTIPLE_ABD_XDL_CSHUFFLE_V3, // Maps to: + // DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 +}; + +} // namespace ck_tile::builder diff --git a/experimental/builder/test/CMakeLists.txt b/experimental/builder/test/CMakeLists.txt index b776edbcde..43c4fd4857 100644 --- a/experimental/builder/test/CMakeLists.txt +++ b/experimental/builder/test/CMakeLists.txt @@ -43,6 +43,8 @@ add_ck_builder_test(test_ckb_build_fwd_instances conv/test_ckb_conv_fwd_2d_bf16.cpp conv/test_ckb_conv_fwd_2d_fp16.cpp conv/test_ckb_conv_fwd_2d_fp32.cpp + conv/test_ckb_conv_fwd_2d_dl_fp16.cpp + conv/test_ckb_conv_fwd_2d_large_tensor_fp16.cpp conv/test_ckb_conv_fwd_3d_bf16.cpp conv/test_ckb_conv_fwd_3d_fp16.cpp conv/test_ckb_conv_fwd_3d_fp32.cpp) diff --git a/experimental/builder/test/conv/test_ckb_conv_fwd_2d_dl_fp16.cpp b/experimental/builder/test/conv/test_ckb_conv_fwd_2d_dl_fp16.cpp new file mode 100644 index 0000000000..12730bab19 --- /dev/null +++ b/experimental/builder/test/conv/test_ckb_conv_fwd_2d_dl_fp16.cpp @@ -0,0 +1,69 @@ +// Copyright (C) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "utils/ckb_conv_test_common.hpp" + +using namespace ck_tile::builder::test_utils; + +namespace ck_tile::builder::testing { + +TEST(FwdConvInstances, Create_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK_Instance_2D_FP16_GNHWC) +{ + constexpr ConvSignature FwdConvSignature{ + .spatial_dim = 2, + .direction = ConvDirection::FORWARD, + .layout = GroupConvLayout2D::GNHWC_GKYXC_GNHWK, + .data_type = DataType::FP16, + .elementwise_operation = ElementwiseOperation::PASS_THROUGH, + .device_operation = + FwdGroupConvDeviceOperation::DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK}; + + constexpr ThreadBlock FwdThreadBlock{.block_size = 256, + .tile_size = {.m = 128, .n = 128, .k = 16}}; + + run_test_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK(); +} + +TEST(FwdConvInstances, Create_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK_Instance_2D_FP16_NHWGC) +{ + constexpr ConvSignature FwdConvSignature{ + .spatial_dim = 2, + .direction = ConvDirection::FORWARD, + .layout = GroupConvLayout2D::NHWGC_GKYXC_NHWGK, + .data_type = DataType::FP16, + .elementwise_operation = ElementwiseOperation::PASS_THROUGH, + .device_operation = + FwdGroupConvDeviceOperation::DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK}; + + constexpr ThreadBlock FwdThreadBlock{.block_size = 256, + .tile_size = {.m = 128, .n = 128, .k = 16}}; + + run_test_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK(); +} + +TEST(FwdConvInstances, + Create_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK_Instance_2D_FP16_FILTER_1X1_PAD0) +{ + constexpr ConvSignature FwdConvSignature{ + .spatial_dim = 2, + .direction = ConvDirection::FORWARD, + .layout = GroupConvLayout2D::GNHWC_GKYXC_GNHWK, + .data_type = DataType::FP16, + .elementwise_operation = ElementwiseOperation::PASS_THROUGH, + .device_operation = + FwdGroupConvDeviceOperation::DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK}; + + constexpr ThreadBlock FwdThreadBlock{.block_size = 256, + .tile_size = {.m = 128, .n = 128, .k = 16}}; + + run_test_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK< + FwdConvSignature, + FwdThreadBlock, + ConvFwdSpecialization::FILTER_1X1_PAD0>(); +} + +} // namespace ck_tile::builder::testing diff --git a/experimental/builder/test/conv/test_ckb_conv_fwd_2d_large_tensor_fp16.cpp b/experimental/builder/test/conv/test_ckb_conv_fwd_2d_large_tensor_fp16.cpp new file mode 100644 index 0000000000..0216c5907d --- /dev/null +++ b/experimental/builder/test/conv/test_ckb_conv_fwd_2d_large_tensor_fp16.cpp @@ -0,0 +1,53 @@ +// Copyright (C) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "utils/ckb_conv_test_common.hpp" + +using namespace ck_tile::builder::test_utils; + +namespace ck_tile::builder::testing { + +TEST(FwdConvInstances, + Create_DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor_Instance_2D_FP16_GNHWC) +{ + constexpr ConvSignature FwdConvSignature{ + .spatial_dim = 2, + .direction = ConvDirection::FORWARD, + .layout = GroupConvLayout2D::GNHWC_GKYXC_GNHWK, + .data_type = DataType::FP16, + .elementwise_operation = ElementwiseOperation::PASS_THROUGH, + .device_operation = + FwdGroupConvDeviceOperation::DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor}; + + constexpr ThreadBlock FwdThreadBlock{.block_size = 256, + .tile_size = {.m = 256, .n = 128, .k = 32}}; + + run_test_DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor< + FwdConvSignature, + FwdThreadBlock, + ConvFwdSpecialization::DEFAULT>(); +} + +TEST( + FwdConvInstances, + Create_DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor_Instance_2D_FP16_GNHWC_Filter1x1Pad0) +{ + constexpr ConvSignature FwdConvSignature{ + .spatial_dim = 2, + .direction = ConvDirection::FORWARD, + .layout = GroupConvLayout2D::GNHWC_GKYXC_GNHWK, + .data_type = DataType::FP16, + .elementwise_operation = ElementwiseOperation::PASS_THROUGH, + .device_operation = + FwdGroupConvDeviceOperation::DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor}; + + constexpr ThreadBlock FwdThreadBlock{.block_size = 128, + .tile_size = {.m = 128, .n = 128, .k = 32}}; + + run_test_DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor< + FwdConvSignature, + FwdThreadBlock, + ConvFwdSpecialization::FILTER_1X1_PAD0>(); +} + +} // namespace ck_tile::builder::testing diff --git a/experimental/builder/test/impl/conv_algorithm_types.hpp b/experimental/builder/test/impl/conv_algorithm_types.hpp index accc4048dc..88c5b5787a 100644 --- a/experimental/builder/test/impl/conv_algorithm_types.hpp +++ b/experimental/builder/test/impl/conv_algorithm_types.hpp @@ -214,4 +214,84 @@ static_assert( static_assert( ckb::SpecifiesLoopScheduler); +// DL-specific descriptors +struct DlThreadConfig +{ + size_t k0_per_block; + size_t k1; + size_t m1_per_thread; + size_t n1_per_thread; + size_t k_per_thread; +}; +static_assert(ckb::DlThreadConfigDescriptor); + +struct DlThreadCluster +{ + std::array m1_xs; // e.g., {8, 2} + std::array n1_xs; // e.g., {8, 2} +}; +static_assert(ckb::DlThreadClusterDescriptor); + +struct DlBlockTransferK0M0M1K1 +{ + std::array thread_slice_lengths; + std::array thread_cluster_lengths; + std::array thread_cluster_arrange_order; + std::array src_access_order; + std::array src_vector_tensor_lengths; + std::array src_vector_tensor_contiguous_dim_order; + std::array dst_vector_tensor_lengths; +}; +static_assert(ckb::DlBlockTransferK0M0M1K1Descriptor); + +struct DlBlockTransferK0N0N1K1 +{ + std::array thread_slice_lengths; + std::array thread_cluster_lengths; + std::array thread_cluster_arrange_order; + std::array src_access_order; + std::array src_vector_tensor_lengths; + std::array src_vector_tensor_contiguous_dim_order; + std::array dst_vector_tensor_lengths; +}; +static_assert(ckb::DlBlockTransferK0N0N1K1Descriptor); + +struct DlCThreadTransfer +{ + std::array src_dst_access_order; + size_t src_dst_vector_dim; + size_t dst_scalar_per_vector; +}; +static_assert(ckb::DlCThreadTransferDescriptor); + +struct ConvAlgorithm_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK +{ + ThreadBlock thread_block; + ConvFwdSpecialization fwd_specialization; + GemmSpecialization gemm_specialization; + DlThreadConfig dl_thread_config; + DlThreadCluster dl_thread_cluster; + DlBlockTransferK0M0M1K1 dl_block_transfer_a; + DlBlockTransferK0N0N1K1 dl_block_transfer_b; + DlCThreadTransfer dl_c_thread_transfer; +}; +static_assert( + ckb::ConvAlgorithmDescriptor); +static_assert( + ckb::SpecifiesThreadBlock); +static_assert(ckb::SpecifiesFwdConcSpecialization< + ConvAlgorithm_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK>); +static_assert( + ckb::SpecifiesGemmSpecialization); +static_assert( + ckb::SpecifiesDlThreadConfig); +static_assert( + ckb::SpecifiesDlThreadCluster); +static_assert( + ckb::SpecifiesDlBlockTransferA); +static_assert( + ckb::SpecifiesDlBlockTransferB); +static_assert( + ckb::SpecifiesDlCThreadTransfer); + } // namespace ck_tile::builder::test diff --git a/experimental/builder/test/utils/ckb_conv_test_common.hpp b/experimental/builder/test/utils/ckb_conv_test_common.hpp index 7fd02a56f7..14fae566f6 100644 --- a/experimental/builder/test/utils/ckb_conv_test_common.hpp +++ b/experimental/builder/test/utils/ckb_conv_test_common.hpp @@ -235,4 +235,149 @@ constexpr void run_test_DeviceGroupedConvFwdMultipleD_Wmma_CShuffle() EXPECT_NE(invoker_ptr, nullptr); } +template +constexpr void run_test_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK() +{ + // DL thread configuration + constexpr DlThreadConfig DlThreadCfg{ + .k0_per_block = 16, .k1 = 2, .m1_per_thread = 4, .n1_per_thread = 4, .k_per_thread = 1}; + + // DL thread cluster + constexpr DlThreadCluster DlCluster{.m1_xs = {8, 2}, .n1_xs = {8, 2}}; + + // DL A block transfer - K0_M0_M1_K1 format + constexpr DlBlockTransferK0M0M1K1 DlBlockTransferA{ + .thread_slice_lengths = {8, 1, 1, 2}, + .thread_cluster_lengths = {2, 1, 128, 1}, + .thread_cluster_arrange_order = {1, 2, 0, 3}, + .src_access_order = {1, 2, 0, 3}, + .src_vector_tensor_lengths = {4, 1, 1, 2}, + .src_vector_tensor_contiguous_dim_order = {1, 2, 0, 3}, + .dst_vector_tensor_lengths = {1, 1, 1, 2}}; + + // DL B block transfer - K0_N0_N1_K1 format + constexpr DlBlockTransferK0N0N1K1 DlBlockTransferB{ + .thread_slice_lengths = {8, 1, 1, 2}, + .thread_cluster_lengths = {2, 1, 128, 1}, + .thread_cluster_arrange_order = {1, 2, 0, 3}, + .src_access_order = {1, 2, 0, 3}, + .src_vector_tensor_lengths = {4, 1, 1, 2}, + .src_vector_tensor_contiguous_dim_order = {1, 2, 0, 3}, + .dst_vector_tensor_lengths = {1, 1, 1, 2}}; + + // DL C thread transfer + constexpr DlCThreadTransfer DlCTransfer{.src_dst_access_order = {0, 1, 2, 3, 4, 5}, + .src_dst_vector_dim = 5, + .dst_scalar_per_vector = 4}; + + constexpr ConvAlgorithm_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK FwdConvAlgorithm{ + .thread_block = FwdThreadBlock, + .fwd_specialization = FwdConvSpecialization, + .gemm_specialization = GemmSpecialization::MNKPadding, + .dl_thread_config = DlThreadCfg, + .dl_thread_cluster = DlCluster, + .dl_block_transfer_a = DlBlockTransferA, + .dl_block_transfer_b = DlBlockTransferB, + .dl_c_thread_transfer = DlCTransfer}; + + using Builder = ConvBuilder; + + auto instance = typename Builder::Instance{}; + + const auto kernel_string = instance.GetTypeString(); + std::cout << "Generated kernel: " << kernel_string << std::endl; + EXPECT_GT(kernel_string.size(), 0); + + EXPECT_TRUE(kernel_string.starts_with("DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK")); + + // Verify specialization is correct + if(FwdConvSpecialization == ConvFwdSpecialization::DEFAULT) + EXPECT_TRUE(kernel_string.find("Default") != std::string::npos); + else if(FwdConvSpecialization == ConvFwdSpecialization::FILTER_1X1_PAD0) + EXPECT_TRUE(kernel_string.find("Filter1x1Pad0") != std::string::npos); + else if(FwdConvSpecialization == ConvFwdSpecialization::FILTER_1X1_STRIDE1_PAD0) + EXPECT_TRUE(kernel_string.find("Filter1x1Stride1Pad0") != std::string::npos); + else if(FwdConvSpecialization == ConvFwdSpecialization::FILTER_3x3) + EXPECT_TRUE(kernel_string.find("Filter3x3") != std::string::npos); + + const auto invoker_ptr = instance.MakeInvokerPointer(); + EXPECT_NE(invoker_ptr, nullptr); +} + +// Test helper for DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor +// Note: Large_Tensor has identical parameters to regular XDL CShuffle +template +constexpr void run_test_DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor() +{ + constexpr GridwiseXdlGemm FwdGemmParams{.ak1 = 8, + .bk1 = 8, + .m_per_xdl = 32, + .n_per_xdl = 32, + .m_xdl_per_wave = 2, + .n_xdl_per_wave = 1}; + + constexpr BlockTransferABC FwdBlockTransfer{.block_transfer_a = {.k0 = 4, .m_n = 16, .k1 = 1}, + .block_transfer_b = {.k0 = 4, .m_n = 16, .k1 = 1}, + .thread_cluster_dims_c = {.m_block = 1, + .m_wave_per_xdl = 16, + .n_block = 1, + .n_wave_per_xdl = 4}, + .lds_transfer_a = {.src_vector_dim = 2, + .src_scalar_per_vector = 8, + .lds_dst_scalar_per_vector = 8, + .is_direct_load = false, + .lds_padding = true}, + .lds_transfer_b = {.src_vector_dim = 2, + .src_scalar_per_vector = 8, + .lds_dst_scalar_per_vector = 8, + .is_direct_load = false, + .lds_padding = true}, + .epilogue_c = {.m_per_wave_per_shuffle = 1, + .n_per_wave_per_shuffle = 1, + .scalar_per_vector = 8}, + .block_transfer_access_order_a = {1, 0, 2}, + .block_transfer_access_order_b = {1, 0, 2}, + .src_access_order_a = {1, 0, 2}, + .src_access_order_b = {1, 0, 2}}; + + // Large_Tensor uses the same descriptor as regular XDL CShuffle + constexpr ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle FwdConvAlgorithm{ + .thread_block = FwdThreadBlock, + .gridwise_gemm = FwdGemmParams, + .block_transfer = FwdBlockTransfer, + .fwd_specialization = FwdConvSpecialization, + .gemm_specialization = GemmSpecialization::MNKPadding, + .num_gemm_k_prefetch_stages = 1, + .num_groups_to_merge = 1, + .loop_scheduler = LoopScheduler::DEFAULT}; + + using Builder = ConvBuilder; + + auto instance = typename Builder::Instance{}; + + const auto kernel_string = instance.GetTypeString(); + std::cout << "Generated kernel: " << kernel_string << std::endl; + EXPECT_GT(kernel_string.size(), 0); + + EXPECT_TRUE( + kernel_string.starts_with("DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor")); + + // Verify specialization is correct + if(FwdConvSpecialization == ConvFwdSpecialization::DEFAULT) + EXPECT_TRUE(kernel_string.find("Default") != std::string::npos); + else if(FwdConvSpecialization == ConvFwdSpecialization::FILTER_1X1_PAD0) + EXPECT_TRUE(kernel_string.find("Filter1x1Pad0") != std::string::npos); + else if(FwdConvSpecialization == ConvFwdSpecialization::FILTER_1X1_STRIDE1_PAD0) + EXPECT_TRUE(kernel_string.find("Filter1x1Stride1Pad0") != std::string::npos); + else if(FwdConvSpecialization == ConvFwdSpecialization::FILTER_3x3) + EXPECT_TRUE(kernel_string.find("Filter3x3") != std::string::npos); + + const auto invoker_ptr = instance.MakeInvokerPointer(); + EXPECT_NE(invoker_ptr, nullptr); +} + } // namespace ck_tile::builder::test_utils From d04eba4ae37c8c2d40855f02aa861e1ac1ec7b3f Mon Sep 17 00:00:00 2001 From: Xudong Yuan Date: Fri, 7 Nov 2025 08:45:41 +0800 Subject: [PATCH 5/6] Ck moe mxfp4 blockm32 (#3098) * block_m = 32 * ck block_m = 32 * aiter/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_mx_moe_v3.hpp format * mxfp4_moe v1 pipe * update format --------- Co-authored-by: zhimding Co-authored-by: lalala-sh Co-authored-by: felix --- .../moe_gemm1_xdl_mx_fp4_bpreshuffle.cpp | 12 +- ...xdlops_b_preshuffle_mx_moe_gufusion_v3.hpp | 3 +- ...ne_xdlops_b_preshuffle_mx_moe_selector.hpp | 24 +- ...pipeline_xdlops_b_preshuffle_mx_moe_v1.hpp | 891 ++++++++++++++++++ ...pipeline_xdlops_b_preshuffle_mx_moe_v3.hpp | 234 +++-- .../impl/device_moe_mx_gemm_bpreshuffle.hpp | 2 +- .../grid/gridwise_moe_mx_gemm_bpreshuffle.hpp | 468 +++++---- 7 files changed, 1357 insertions(+), 277 deletions(-) create mode 100644 include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_mx_moe_v1.hpp diff --git a/example/67_gemm_microscaling/moe_gemm1_xdl_mx_fp4_bpreshuffle.cpp b/example/67_gemm_microscaling/moe_gemm1_xdl_mx_fp4_bpreshuffle.cpp index 1adf039b70..ebb73ca7e0 100644 --- a/example/67_gemm_microscaling/moe_gemm1_xdl_mx_fp4_bpreshuffle.cpp +++ b/example/67_gemm_microscaling/moe_gemm1_xdl_mx_fp4_bpreshuffle.cpp @@ -181,7 +181,7 @@ constexpr ck::index_t ScaleBlockSize = 32; // scaling block constexpr ck::index_t KPerBlock = 256 / DataPackedSize; // 256 f4 = 128 fp4x2 static constexpr ck::index_t Nswizzle = false; static constexpr ck::index_t ActOP = 0; // 0: gelu_and_mul, 1: silu_and_mul -static constexpr ck::index_t MPerBlock = 128; +static constexpr ck::index_t MPerBlock = 32; static constexpr bool MulRoutedWeight = true; // clang-format off @@ -190,10 +190,10 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemmMXBPreShuffl A0DataType, A1DataType, B0DataType, B1DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, ScaleBlockSize, 256, - MPerBlock, 64, KPerBlock, + MPerBlock, 128, KPerBlock, 16, 16, 16, 16, - 4, 2, + 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 2, 2, S<1, 32, 1, 8>, S<8, 1, 1, 1>, @@ -213,10 +213,10 @@ int main(int argc, char* argv[]) ck::index_t sorted_size = sorted_tile_num * MPerBlock; ck::index_t valid_size = valid_tile_num * MPerBlock; - ck::index_t N = 6144; - ck::index_t K = 4096; + ck::index_t N = 7168; + ck::index_t K = 256; ck::index_t experts = 8; - ck::index_t tokens = 832; + ck::index_t tokens = 208; ck::index_t topk = 2; if(argc == 1) diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_mx_moe_gufusion_v3.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_mx_moe_gufusion_v3.hpp index b3b3d312c7..b621c3a93d 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_mx_moe_gufusion_v3.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_mx_moe_gufusion_v3.hpp @@ -727,7 +727,8 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_mx_moe_gufusion_v3< }); }); - HotLoopScheduler(); + if constexpr(MPerBlock >= 64) + HotLoopScheduler(); __builtin_amdgcn_sched_barrier(0); }; diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_mx_moe_selector.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_mx_moe_selector.hpp index 6789d26a45..5223993671 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_mx_moe_selector.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_mx_moe_selector.hpp @@ -4,6 +4,7 @@ #pragma once #include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_mx_moe_v3.hpp" +#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_mx_moe_v1.hpp" #include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_mx_moe_gufusion_v3.hpp" namespace ck { @@ -45,7 +46,28 @@ constexpr auto BlockGemmMXBPreshufflePipeline_Selector() } else { - return nullptr; + return BlockwiseGemmXdlops_pipeline_bpreshuffle_mx_moe_v1< + BlkGemmPipeSche, + ThreadBlockSize, + ScaleBlockSize, + ADataType, + AScaleDataType, + BDataType, + BScaleDataType, + ATileDesc, + BTileDesc, + AMmaTileDesc, + BMmaTileDesc, + ABlockTransferSrcScalarPerVector, + BBlockTransferSrcScalarPerVector, + MPerBlock, + NPerBlock, + KPerBlock, + MPerXDL, + NPerXDL, + MRepeat, + NRepeat, + KPack>{}; } } else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v3) diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_mx_moe_v1.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_mx_moe_v1.hpp new file mode 100644 index 0000000000..fc5cb60c37 --- /dev/null +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_mx_moe_v1.hpp @@ -0,0 +1,891 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/tensor_operation/gpu/block/blockwise_gemm_mx_pipeline_xdlops_base.hpp" + +namespace ck { + +// Naive pipeline with lowest resource request per WGP +// GlobalPrefetchStages: 2 +// LocalPreFillStages: 1 +// LocalPreFetchStages: 1 +// LocalSharedMemoryBuffer: 1 + +template +struct BlockwiseGemmXdlops_pipeline_bpreshuffle_mx_moe_v1 +{ +}; + +template +struct BlockwiseGemmXdlops_pipeline_bpreshuffle_mx_moe_v1 + : BlockwiseGemmXdlops_mx_pipeline_base + +{ + + using Base = BlockwiseGemmXdlops_mx_pipeline_base; + using Base::A_K1; + using Base::I0; + using Base::I1; + using Base::KRepeat; + using Base::MWaves; + using Base::NWaves; + using Base::WaveSize; + using Base::xdlops_gemm; + using typename Base::HotLoopInstList; + + using Base::CalculateCThreadOriginDataIndex; + using Base::GetCBlockDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4; + using Base::GetCThreadBuffer; + using Base::GetCThreadDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4; + using Base::GetWaveIdx; + using Base::MakeCGridDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2; + + using Base::a_block_desc_m0_m1_m2_m3_k; + using Base::b_block_desc_n0_n1_n2_n3_k; + + using Base::AMmaKStride; + using Base::APackedSize; + using Base::BMmaKStride; + using Base::BPackedSize; + using Base::KThreadChunk; + + using Base::KXdlPack; + using Base::MXdlPack; + using Base::NXdlPack; + + using AccType = typename Base::AccType; + using Tuple5 = typename Base::Tuple5; + using ComputeTypeA = typename Base::ComputeTypeA; + using ComputeTypeB = typename Base::ComputeTypeB; + + static constexpr index_t PrefetchStages = 2; + static constexpr index_t PrefillStages = 1; + static constexpr index_t GlobalBufferNum = 2; + static constexpr index_t HotloopLocalBufSwitch = MRepeat % 2 == 0 ? 0 : 1; + + static constexpr auto num_buffer_load_a_scale = MRepeat / MXdlPack * KRepeat / KXdlPack; + static constexpr auto num_buffer_load_b_scale = NRepeat / NXdlPack * KRepeat / KXdlPack; + static constexpr auto async_vmcnt = + num_buffer_load_a_scale + num_buffer_load_b_scale + HotLoopInstList::B_Buffer_Load_Inst_Num; + static constexpr auto async_vmcnt_encoding = 3952 + async_vmcnt % 16 + async_vmcnt / 16 * 16384; + + static constexpr auto ScalesPerKBlockSize = + KPerBlock / ScaleBlockSize; // How many mx-vectors per K block + + //> How many mx-vectors in each row/col is processed in one call to xdlops_gemm.Run() + static constexpr auto ScalesPerXdlopsRun = + (APackedSize * KPack * xdlops_gemm.K0PerXdlops) / ScaleBlockSize; + + //> How many scales a thread must read to accommodate one call to xdlops_gemm.Run() + static constexpr auto ScalesPerXdlopsRunPerThread = + ScalesPerXdlopsRun / xdlops_gemm.mfma_instr.num_input_blks; + + using mx_scale_t = e8m0_bexp_t; + static constexpr auto scale_pack_size_a = sizeof(AScaleDataType) / sizeof(mx_scale_t); + static constexpr auto scale_pack_size_b = sizeof(BScaleDataType) / sizeof(mx_scale_t); + static_assert(KXdlPack * MXdlPack % scale_pack_size_a == 0, + "A scale pack data type too large!"); + static_assert(KXdlPack * NXdlPack % scale_pack_size_b == 0, + "B scale pack data type too large!"); + static constexpr auto a_scale_thread_vec_size = KXdlPack * MXdlPack / scale_pack_size_a; + static constexpr auto b_scale_thread_vec_size = KXdlPack * NXdlPack / scale_pack_size_b; + + __host__ static constexpr bool BlockHasHotloop(index_t num_loop) + { + return num_loop > PrefetchStages; + } + + __host__ static constexpr TailNumber BlockLoopTailNum(index_t num_loop) + { + return num_loop % 2 == 0 ? TailNumber::Even : TailNumber::Odd; + } + + __device__ static constexpr auto HotLoopScheduler() + { + constexpr auto num_ds_read_inst_a = HotLoopInstList::A_LDS_Read_Inst_Num; + constexpr auto num_buffer_load_inst_a = HotLoopInstList::A_Buffer_Load_Inst_Num; + constexpr auto num_buffer_load_inst_b = HotLoopInstList::B_Buffer_Load_Inst_Num * MWaves + + num_buffer_load_a_scale + num_buffer_load_b_scale; + constexpr auto mfma_interleave = MPerXDL == 32 ? 1 : 2; + // B global + static_for<0, num_buffer_load_inst_b, 1>{}([&](auto i) { + ignore = i; + if constexpr(MPerBlock >= 128 && NPerBlock >= 128) + { + __builtin_amdgcn_sched_group_barrier(0x008, 2 * mfma_interleave, 0); + } + else + { + __builtin_amdgcn_sched_group_barrier(0x008, mfma_interleave, 0); + } + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + }); + + // A global + static_for<0, num_buffer_load_inst_a, 1>{}([&](auto i) { + ignore = i; + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + }); + + // A local + static_for<0, MPerXDL == 32 ? num_ds_read_inst_a / 2 : num_ds_read_inst_a, 1>{}( + [&](auto i) { + ignore = i; + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x100, MPerXDL == 32 ? 2 : 1, 0); // DS read + }); + } + + template + __device__ void Run( + // ABlockCopy + const AGridDesc& a_grid_desc, + const ABlockDesc& a_block_desc, + ABlockTransfer& a_blockwise_copy, + const AGridBuffer& a_grid_buf, + ABlockBuffer& a_block_buf, + const ABlockTransferStep& a_block_copy_step, + // BBlockCopy + const BGridDesc& b_grid_desc, + const BBlockDesc& b_block_desc, + BBlockTransfer& b_blockwise_copy, + const BGridBuffer& b_grid_buf, + BBlockBuffer& b_block_bufs, + const BBlockTransferStep& b_block_copy_step, + // CThread + CThreadBuffer& c_thread_buf, + // A and B scales + const AScaleGridDesc& a_scale_grid_desc, + AScaleThreadTransfer& a_scale_thread_copy, + const AScaleGridBuffer& a_scale_grid_buf, + const BScaleGridDesc& b_scale_grid_desc, + BScaleThreadTransfer& b_scale_thread_copy, + const BScaleGridBuffer& b_scale_grid_buf, + index_t num_loop) const + { + ignore = b_block_bufs; + __builtin_amdgcn_sched_barrier(0); + auto a_thread_buf = make_static_buffer( + a_thread_desc_.GetElementSpaceSize()); + auto b_thread_buf = make_static_buffer( + b_thread_desc_.GetElementSpaceSize()); + + StaticallyIndexedArray{}> b_thread_bufs; + constexpr auto b_block_origin_idx = make_tuple(I0, I0, I0, I0, I0); + + auto a_scale_thread_buf = make_static_buffer( + a_scale_thread_desc.GetElementSpaceSize()); + auto b_scale_thread_buf = make_static_buffer( + b_scale_thread_desc.GetElementSpaceSize()); + + StaticallyIndexedArray{}> a_scale_thread_bufs; + StaticallyIndexedArray{}> b_scale_thread_bufs; + + // Global prefetch 1 + a_blockwise_copy.Run(a_grid_desc, a_grid_buf, a_block_desc, a_block_buf); + b_blockwise_copy.Run( + b_grid_desc, b_grid_buf, b_block_desc, b_block_origin_idx, b_thread_bufs(I0)); + + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + __builtin_amdgcn_sched_barrier(0); + + // Prefetch a_scales + static_for<0, MRepeat / MXdlPack, 1>{}([&](auto m0) { + static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) { + a_scale_thread_copy.Run(a_scale_grid_desc, + a_scale_grid_buf, + a_scale_thread_desc, + make_tuple(m0, k0, I0), + a_scale_thread_bufs(I0)); + + a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc, + make_multi_index(0, I1, 0)); + }); + a_scale_thread_copy.MoveSrcSliceWindow( + a_scale_grid_desc, make_multi_index(MWaves, -KRepeat / KXdlPack, 0)); + }); + + // restore row id and advance to the next set of scales + a_scale_thread_copy.MoveSrcSliceWindow( + a_scale_grid_desc, + make_multi_index(-MWaves * MRepeat / MXdlPack, KRepeat / KXdlPack, 0)); + + // Prefetch b_scales + static_for<0, NRepeat / NXdlPack, 1>{}([&](auto n0) { + static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) { + b_scale_thread_copy.Run(b_scale_grid_desc, + b_scale_grid_buf, + b_scale_thread_desc, + make_tuple(n0, k0, I0), + b_scale_thread_bufs(I0)); + + b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc, + make_multi_index(0, I1, 0)); + }); + b_scale_thread_copy.MoveSrcSliceWindow( + b_scale_grid_desc, make_multi_index(NWaves, -KRepeat / KXdlPack, 0)); + }); + + // restore col id and advance to the next set of scales + // NWaves * NPerXDL * NRepeat == NPerBlock + b_scale_thread_copy.MoveSrcSliceWindow( + b_scale_grid_desc, + make_multi_index(-NWaves * NRepeat / NXdlPack, KRepeat / KXdlPack, 0)); + + // Local prefetch 1, sync the async load + __builtin_amdgcn_s_waitcnt(async_vmcnt_encoding); + block_sync_lds(); + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, KRepeat, 1>{}([&](auto k) { + constexpr auto k_step = k * xdlops_gemm.KPerXdlops / APackedSize * + (APackedSize * KPack / xdlops_gemm.K1PerXdlops); + static_for<0, xdlops_gemm.K1PerXdlops / (APackedSize * KThreadChunk), 1>{}( + [&](auto chunk) { + constexpr auto a_k_step_chunk = + k_step + chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks; + a_thread_copy_.Run(a_block_desc_m0_m1_m2_m3_k, + make_tuple(Number{}, + I0, + Number{}, + I0, + Number{}), + a_block_buf, + a_thread_desc_, + make_tuple(Number{}, + I0, + Number{}, + k, + Number{}), + a_thread_buf); + }); + }); + }); + + // Initialize C + c_thread_buf.Clear(); + __builtin_amdgcn_sched_barrier(0); + // main body + if constexpr(HasMainLoop) + { + // loop over k with the step KPerBlock + index_t i = 0; + do + { + auto LoopFunc = [&](auto scale_comp_buf, auto scale_mem_buf) { + b_blockwise_copy.Run(b_grid_desc, + b_grid_buf, + b_block_desc, + b_block_origin_idx, + b_thread_bufs(scale_mem_buf)); + + block_sync_lds(); + a_blockwise_copy.Run(a_grid_desc, a_grid_buf, a_block_desc, a_block_buf); + // Prefetch a_scales + static_for<0, MRepeat / MXdlPack, 1>{}([&](auto m0) { + static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) { + a_scale_thread_copy.Run(a_scale_grid_desc, + a_scale_grid_buf, + a_scale_thread_desc, + make_tuple(m0, k0, I0), + a_scale_thread_bufs(scale_mem_buf)); + + a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc, + make_multi_index(0, I1, 0)); + }); + a_scale_thread_copy.MoveSrcSliceWindow( + a_scale_grid_desc, make_multi_index(MWaves, -KRepeat / KXdlPack, 0)); + }); + + // restore row id and advance to the next set of scales + a_scale_thread_copy.MoveSrcSliceWindow( + a_scale_grid_desc, + make_multi_index(-MWaves * MRepeat / MXdlPack, KRepeat / KXdlPack, 0)); + + // Prefetch b_scales + static_for<0, NRepeat / NXdlPack, 1>{}([&](auto n0) { + static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) { + b_scale_thread_copy.Run(b_scale_grid_desc, + b_scale_grid_buf, + b_scale_thread_desc, + make_tuple(n0, k0, I0), + b_scale_thread_bufs(scale_mem_buf)); + + b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc, + make_multi_index(0, I1, 0)); + }); + b_scale_thread_copy.MoveSrcSliceWindow( + b_scale_grid_desc, make_multi_index(NWaves, -KRepeat / KXdlPack, 0)); + }); + + // restore col id and advance to the next set of scales + // NWaves * NPerXDL * NRepeat == NPerBlock + b_scale_thread_copy.MoveSrcSliceWindow( + b_scale_grid_desc, + make_multi_index(-NWaves * NRepeat / NXdlPack, KRepeat / KXdlPack, 0)); + + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + static_for<0, MRepeat, 1>{}([&](auto m0) { + constexpr auto im_major = m0 / MXdlPack; + constexpr auto im_minor = m0 % MXdlPack; + static_for<0, KRepeat, 1>{}([&](auto k0) { + constexpr auto ik_major = k0 / KXdlPack; + constexpr auto ik_minor = k0 % KXdlPack; + static_for<0, NRepeat, 1>{}([&](auto n0) { + constexpr auto in_major = n0 / NXdlPack; + constexpr auto in_minor = n0 % NXdlPack; + + constexpr index_t a_scale_offset = + a_scale_thread_desc.CalculateOffset( + make_tuple(im_major, ik_major, I0)); + constexpr index_t b_scale_offset = + b_scale_thread_desc.CalculateOffset( + make_tuple(in_major, ik_major, I0)); + + static_assert(0 < ScalesPerXdlopsRunPerThread, + "Must have at least one scale per Xdlops " + "per Thread."); + + vector_type + a_scale_thread_vec; + vector_type + b_scale_thread_vec; + + // Pack scale_thread_buf into scale_thread_vec + static_for<0, a_scale_thread_vec_size, 1>{}([&](auto s) { + a_scale_thread_vec.template AsType()(s) = + a_scale_thread_bufs( + scale_comp_buf)[Number{}]; + }); + + static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) { + b_scale_thread_vec.template AsType()(s) = + b_scale_thread_bufs( + scale_comp_buf)[Number{}]; + }); + + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = b_thread_bufs + [scale_comp_buf][Number{}]; + }); + + using mfma_input_type_a = + typename vector_type::type; + + using mfma_input_type_b = + typename vector_type::type; + + using mfma_scale_input_type_a = + typename vector_type::type; + using mfma_scale_input_type_b = + typename vector_type::type; + + constexpr index_t c_offset = c_thread_desc_.CalculateOffset( + make_tuple(im_major, in_major, im_minor, in_minor, 0)); + + // MFMA accumulation + xdlops_gemm.template Run( + a_thread_vec.template AsType(), + a_scale_thread_vec.template AsType(), + b_thread_vec.template AsType(), + b_scale_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + }); + }); + }); + + block_sync_lds(); + + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, KRepeat, 1>{}([&](auto k) { + constexpr auto k_step = k * xdlops_gemm.KPerXdlops / APackedSize * + (APackedSize * KPack / xdlops_gemm.K1PerXdlops); + static_for<0, + xdlops_gemm.K1PerXdlops / (APackedSize * KThreadChunk), + 1>{}([&](auto chunk) { + constexpr auto a_k_step_chunk = + k_step + + chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks; + a_thread_copy_.Run(a_block_desc_m0_m1_m2_m3_k, + make_tuple(Number{}, + I0, + Number{}, + I0, + Number{}), + a_block_buf, + a_thread_desc_, + make_tuple(Number{}, + I0, + Number{}, + k, + Number{}), + a_thread_buf); + }); + }); + }); + HotLoopScheduler(); + __builtin_amdgcn_sched_barrier(0); + }; + + LoopFunc(I0, I1); + LoopFunc(I1, I0); + + i += 2; + } while(i < (num_loop - 2)); + } + + // tail + if constexpr(TailNum == TailNumber::Even) + { + b_blockwise_copy.Run( + b_grid_desc, b_grid_buf, b_block_desc, b_block_origin_idx, b_thread_bufs(I1)); + + block_sync_lds(); + a_blockwise_copy.Run(a_grid_desc, a_grid_buf, a_block_desc, a_block_buf); + // Prefetch a_scales + static_for<0, MRepeat / MXdlPack, 1>{}([&](auto m0) { + static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) { + a_scale_thread_copy.Run(a_scale_grid_desc, + a_scale_grid_buf, + a_scale_thread_desc, + make_tuple(m0, k0, I0), + a_scale_thread_bufs(I1)); + + a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc, + make_multi_index(0, I1, 0)); + }); + a_scale_thread_copy.MoveSrcSliceWindow( + a_scale_grid_desc, make_multi_index(MWaves, -KRepeat / KXdlPack, 0)); + }); + + // Prefetch b_scales + static_for<0, NRepeat / NXdlPack, 1>{}([&](auto n0) { + static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) { + b_scale_thread_copy.Run(b_scale_grid_desc, + b_scale_grid_buf, + b_scale_thread_desc, + make_tuple(n0, k0, I0), + b_scale_thread_bufs(I1)); + + b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc, + make_multi_index(0, I1, 0)); + }); + b_scale_thread_copy.MoveSrcSliceWindow( + b_scale_grid_desc, make_multi_index(NWaves, -KRepeat / KXdlPack, 0)); + }); + + static_for<0, MRepeat, 1>{}([&](auto m0) { + constexpr auto im_major = m0 / MXdlPack; + constexpr auto im_minor = m0 % MXdlPack; + static_for<0, KRepeat, 1>{}([&](auto k0) { + constexpr auto ik_major = k0 / KXdlPack; + constexpr auto ik_minor = k0 % KXdlPack; + static_for<0, NRepeat, 1>{}([&](auto n0) { + constexpr auto in_major = n0 / NXdlPack; + constexpr auto in_minor = n0 % NXdlPack; + + constexpr index_t a_scale_offset = + a_scale_thread_desc.CalculateOffset(make_tuple(im_major, ik_major, I0)); + constexpr index_t b_scale_offset = + b_scale_thread_desc.CalculateOffset(make_tuple(in_major, ik_major, I0)); + + static_assert(0 < ScalesPerXdlopsRunPerThread, + "Must have at least one scale per Xdlops " + "per Thread."); + + vector_type a_scale_thread_vec; + vector_type b_scale_thread_vec; + + // Pack scale_thread_buf into scale_thread_vec + static_for<0, a_scale_thread_vec_size, 1>{}([&](auto s) { + a_scale_thread_vec.template AsType()(s) = + a_scale_thread_bufs(I0)[Number{}]; + }); + + static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) { + b_scale_thread_vec.template AsType()(s) = + b_scale_thread_bufs(I0)[Number{}]; + }); + + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_bufs[I0][Number{}]; + }); + + using mfma_input_type_a = + typename vector_type::type; + + using mfma_input_type_b = + typename vector_type::type; + + using mfma_scale_input_type_a = + typename vector_type::type; + using mfma_scale_input_type_b = + typename vector_type::type; + + constexpr index_t c_offset = c_thread_desc_.CalculateOffset( + make_tuple(im_major, in_major, im_minor, in_minor, 0)); + + // MFMA accumulation + xdlops_gemm.template Run( + a_thread_vec.template AsType(), + a_scale_thread_vec.template AsType(), + b_thread_vec.template AsType(), + b_scale_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + }); + }); + + // constexpr auto lds_buf = m0.value >= SwitchM ? I1 : I0; + }); + __builtin_amdgcn_s_waitcnt(async_vmcnt_encoding); + block_sync_lds(); + + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, KRepeat, 1>{}([&](auto k) { + constexpr auto k_step = k * xdlops_gemm.KPerXdlops / APackedSize * + (APackedSize * KPack / xdlops_gemm.K1PerXdlops); + static_for<0, xdlops_gemm.K1PerXdlops / (APackedSize * KThreadChunk), 1>{}( + [&](auto chunk) { + constexpr auto a_k_step_chunk = + k_step + + chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks; + a_thread_copy_.Run(a_block_desc_m0_m1_m2_m3_k, + make_tuple(Number{}, + I0, + Number{}, + I0, + Number{}), + a_block_buf, + a_thread_desc_, + make_tuple(Number{}, + I0, + Number{}, + k, + Number{}), + a_thread_buf); + }); + }); + }); + __builtin_amdgcn_sched_barrier(0); + + static_for<0, MRepeat, 1>{}([&](auto m0) { + constexpr auto im_major = m0 / MXdlPack; + constexpr auto im_minor = m0 % MXdlPack; + static_for<0, KRepeat, 1>{}([&](auto k0) { + constexpr auto ik_major = k0 / KXdlPack; + constexpr auto ik_minor = k0 % KXdlPack; + static_for<0, NRepeat, 1>{}([&](auto n0) { + constexpr auto in_major = n0 / NXdlPack; + constexpr auto in_minor = n0 % NXdlPack; + + constexpr index_t a_scale_offset = + a_scale_thread_desc.CalculateOffset(make_tuple(im_major, ik_major, I0)); + constexpr index_t b_scale_offset = + b_scale_thread_desc.CalculateOffset(make_tuple(in_major, ik_major, I0)); + + static_assert(0 < ScalesPerXdlopsRunPerThread, + "Must have at least one scale per Xdlops " + "per Thread."); + + vector_type a_scale_thread_vec; + vector_type b_scale_thread_vec; + + // Pack scale_thread_buf into scale_thread_vec + static_for<0, a_scale_thread_vec_size, 1>{}([&](auto s) { + a_scale_thread_vec.template AsType()(s) = + a_scale_thread_bufs(I1)[Number{}]; + }); + + static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) { + b_scale_thread_vec.template AsType()(s) = + b_scale_thread_bufs(I1)[Number{}]; + }); + + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_bufs[I1][Number{}]; + }); + + using mfma_input_type_a = + typename vector_type::type; + + using mfma_input_type_b = + typename vector_type::type; + + using mfma_scale_input_type_a = + typename vector_type::type; + using mfma_scale_input_type_b = + typename vector_type::type; + + constexpr index_t c_offset = c_thread_desc_.CalculateOffset( + make_tuple(im_major, in_major, im_minor, in_minor, 0)); + + // MFMA accumulation + xdlops_gemm.template Run( + a_thread_vec.template AsType(), + a_scale_thread_vec.template AsType(), + b_thread_vec.template AsType(), + b_scale_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + }); + }); + }); + } + else if constexpr(TailNum == TailNumber::Odd) + { + static_for<0, MRepeat, 1>{}([&](auto m0) { + constexpr auto im_major = m0 / MXdlPack; + constexpr auto im_minor = m0 % MXdlPack; + static_for<0, KRepeat, 1>{}([&](auto k0) { + constexpr auto ik_major = k0 / KXdlPack; + constexpr auto ik_minor = k0 % KXdlPack; + static_for<0, NRepeat, 1>{}([&](auto n0) { + constexpr auto in_major = n0 / NXdlPack; + constexpr auto in_minor = n0 % NXdlPack; + + constexpr index_t a_scale_offset = + a_scale_thread_desc.CalculateOffset(make_tuple(im_major, ik_major, I0)); + constexpr index_t b_scale_offset = + b_scale_thread_desc.CalculateOffset(make_tuple(in_major, ik_major, I0)); + + static_assert(0 < ScalesPerXdlopsRunPerThread, + "Must have at least one scale per Xdlops " + "per Thread."); + + vector_type a_scale_thread_vec; + vector_type b_scale_thread_vec; + + // Pack scale_thread_buf into scale_thread_vec + static_for<0, a_scale_thread_vec_size, 1>{}([&](auto s) { + a_scale_thread_vec.template AsType()(s) = + a_scale_thread_bufs(I0)[Number{}]; + }); + + static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) { + b_scale_thread_vec.template AsType()(s) = + b_scale_thread_bufs(I0)[Number{}]; + }); + + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_bufs[I0][Number{}]; + }); + + using mfma_input_type_a = + typename vector_type::type; + + using mfma_input_type_b = + typename vector_type::type; + + using mfma_scale_input_type_a = + typename vector_type::type; + using mfma_scale_input_type_b = + typename vector_type::type; + + constexpr index_t c_offset = c_thread_desc_.CalculateOffset( + make_tuple(im_major, in_major, im_minor, in_minor, 0)); + + // MFMA accumulation + xdlops_gemm.template Run( + a_thread_vec.template AsType(), + a_scale_thread_vec.template AsType(), + b_thread_vec.template AsType(), + b_scale_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + }); + }); + }); + } + } + + // TODO: make this field protected when a_scale_thread_copy_ is moved + // here + static constexpr auto a_scale_thread_desc = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, + Number{}, + Number{})); + + // TODO: make this field protected when b_scale_thread_copy_ is moved + // here + static constexpr auto b_scale_thread_desc = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, + Number{}, + Number{})); + + protected: + using Base::a_thread_copy_; + using Base::a_thread_desc_; + using Base::b_thread_copy_; + using Base::b_thread_desc_; + using Base::c_thread_desc_; +}; + +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_mx_moe_v3.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_mx_moe_v3.hpp index 2b936c8d25..7473d2f2e7 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_mx_moe_v3.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_mx_moe_v3.hpp @@ -226,85 +226,197 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_mx_moe_v3 2) + { - // Group num_mfma_perstage num_ds_read_a_perstage - // since we want to reuse a local register buffer - constexpr auto num_mfma_perstage = num_mfma_inst / num_total_stages; - constexpr auto num_ds_read_a_perstage = num_ds_read_inst_a / num_total_stages; + // Group num_mfma_perstage num_ds_read_a_perstage + // since we want to reuse a local register buffer + constexpr auto num_mfma_perstage = num_mfma_inst / num_total_stages; + constexpr auto num_ds_read_a_perstage = num_ds_read_inst_a / num_total_stages; - constexpr auto num_ds_read_a_mfma_perstage = - math::integer_divide_ceil(num_ds_read_a_perstage, ds_read_a_mfma_rate); + constexpr auto num_ds_read_a_mfma_perstage = + math::integer_divide_ceil(num_ds_read_a_perstage, ds_read_a_mfma_rate); - constexpr auto num_ds_read_a_prefetch_stages = 2; + constexpr auto num_ds_read_a_prefetch_stages = 2; - constexpr auto buffer_load_perstage_more = - math::integer_divide_ceil((num_buffer_load_stage1), (num_total_stages - 2)); - constexpr auto buffer_load_perstage_less = - math::integer_divide_floor((num_buffer_load_stage1), (num_total_stages - 2)); - constexpr auto buffer_load_perstage_stage2 = - math::integer_divide_floor((num_buffer_load_stage2), 2); + constexpr auto buffer_load_perstage_more = + math::integer_divide_ceil((num_buffer_load_stage1), (num_total_stages - 2)); + constexpr auto buffer_load_perstage_less = + math::integer_divide_floor((num_buffer_load_stage1), (num_total_stages - 2)); + constexpr auto buffer_load_perstage_stage2 = + math::integer_divide_floor((num_buffer_load_stage2), 2); - constexpr auto buffer_load_stages_more = - num_buffer_load_stage1 - - math::integer_divide_floor(num_buffer_load_stage1, (num_total_stages - 2)) * - ((num_total_stages - 2)); + constexpr auto buffer_load_stages_more = + num_buffer_load_stage1 - + math::integer_divide_floor(num_buffer_load_stage1, (num_total_stages - 2)) * + ((num_total_stages - 2)); - constexpr auto buffer_load_issue_point_interval_more = - num_mfma_perstage / buffer_load_perstage_more; - constexpr auto buffer_load_issue_point_interval_less = - num_mfma_perstage / buffer_load_perstage_less; - constexpr auto buffer_load_issue_point_interval_stage2 = - num_mfma_perstage / buffer_load_perstage_stage2; + constexpr auto buffer_load_issue_point_interval_more = + num_mfma_perstage / buffer_load_perstage_more; + constexpr auto buffer_load_issue_point_interval_less = + num_mfma_perstage / buffer_load_perstage_less; + constexpr auto buffer_load_issue_point_interval_stage2 = + num_mfma_perstage / buffer_load_perstage_stage2; - // Stage 1 - // global read more - static_for<0, buffer_load_stages_more, 1>{}([&](auto /*i*/) { - static_for<0, num_mfma_perstage, 1>{}([&](auto imfma) { - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + // Stage 1 + // global read more + static_for<0, buffer_load_stages_more, 1>{}([&](auto /*i*/) { + static_for<0, num_mfma_perstage, 1>{}([&](auto imfma) { + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - if constexpr(imfma % buffer_load_issue_point_interval_more == 0) + if constexpr(imfma % buffer_load_issue_point_interval_more == 0) + { + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + } + + if constexpr(imfma >= (num_mfma_perstage - num_ds_read_a_mfma_perstage)) + { + __builtin_amdgcn_sched_group_barrier( + 0x100, ds_read_a_mfma_rate, 0); // DS read + } + }); + }); + + // global read less + static_for<0, (num_total_stages - 2 - buffer_load_stages_more), 1>{}([&](auto /*i*/) { + static_for<0, num_mfma_perstage, 1>{}([&](auto imfma) { + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + if constexpr(imfma % buffer_load_issue_point_interval_less == 0) + { + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + } + if constexpr(imfma >= (num_mfma_perstage - num_ds_read_a_mfma_perstage)) + { + __builtin_amdgcn_sched_group_barrier( + 0x100, ds_read_a_mfma_rate, 0); // DS read + } + }); + }); + + // Stage 2, Sync + // lds synchronization, prefetch next loop local A + static_for<0, num_ds_read_a_prefetch_stages, 1>{}([&](auto /*i*/) { + static_for<0, num_mfma_perstage, 1>{}([&](auto imfma) { + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + if constexpr(imfma % buffer_load_issue_point_interval_stage2 == 0) + { + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + } + if constexpr(imfma >= (num_mfma_perstage - num_ds_read_a_mfma_perstage)) + { + __builtin_amdgcn_sched_group_barrier( + 0x100, ds_read_a_mfma_rate, 0); // DS read + } + }); + }); + } + else + { + constexpr auto num_buffer_load_total = num_buffer_load_inst_a + num_buffer_load_inst_b + + num_buffer_load_a_scale + + num_buffer_load_b_scale; + constexpr auto num_dsread_a_mfma = math::integer_divide_ceil( + num_ds_read_inst_a, ds_read_a_mfma_rate); // how many mfma per dsread_a + + // stage 1 + constexpr auto num_mfma_stage1 = num_mfma_inst - num_dsread_a_mfma; + + constexpr auto mfma_perstage_more = + math::integer_divide_ceil(num_mfma_stage1, num_buffer_load_total); + constexpr auto mfma_perstage_less = + math::integer_divide_floor(num_mfma_stage1, num_buffer_load_total); + + constexpr auto mfma_stages_more = + num_mfma_stage1 - mfma_perstage_less * num_buffer_load_total; + + static_for<0, num_buffer_load_inst_a, 1>{}([&](auto i) { + if constexpr(i < mfma_stages_more) { + static_for<0, mfma_perstage_more, 1>{}([&](auto) { + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + }); __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read } + else + { + static_for<0, mfma_perstage_less, 1>{}([&](auto) { + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + }); + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + } + }); - if constexpr(imfma >= (num_mfma_perstage - num_ds_read_a_mfma_perstage)) + static_for<0, num_buffer_load_inst_b, 1>{}([&](auto i) { + if constexpr((i + num_buffer_load_inst_a) < mfma_stages_more) + { + static_for<0, mfma_perstage_more, 1>{}([&](auto /*imfma*/) { + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + }); + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + } + else + { + static_for<0, mfma_perstage_less, 1>{}([&](auto /*imfma*/) { + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + }); + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + } + }); + + static_for<0, num_buffer_load_a_scale, 1>{}([&](auto i) { + if constexpr((i + num_buffer_load_inst_a + num_buffer_load_inst_b) < + mfma_stages_more) + { + static_for<0, mfma_perstage_more, 1>{}([&](auto /*imfma*/) { + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + }); + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + } + else + { + static_for<0, mfma_perstage_less, 1>{}([&](auto /*imfma*/) { + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + }); + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + } + }); + + static_for<0, num_buffer_load_b_scale, 1>{}([&](auto i) { + if constexpr((i + num_buffer_load_inst_a + num_buffer_load_inst_b + + num_buffer_load_a_scale) < mfma_stages_more) + { + static_for<0, mfma_perstage_more, 1>{}([&](auto /*imfma*/) { + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + }); + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + } + else + { + static_for<0, mfma_perstage_less, 1>{}([&](auto /*imfma*/) { + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + }); + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + } + }); + + // stage 2 + static_for<0, num_dsread_a_mfma, 1>{}([&](auto i) { + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + if constexpr((num_ds_read_inst_a - (i + 1) * ds_read_a_mfma_rate) >= + ds_read_a_mfma_rate) { __builtin_amdgcn_sched_group_barrier(0x100, ds_read_a_mfma_rate, 0); // DS read } - }); - }); - - // global read less - static_for<0, (num_total_stages - 2 - buffer_load_stages_more), 1>{}([&](auto /*i*/) { - static_for<0, num_mfma_perstage, 1>{}([&](auto imfma) { - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - if constexpr(imfma % buffer_load_issue_point_interval_less == 0) + else { - __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read - } - if constexpr(imfma >= (num_mfma_perstage - num_ds_read_a_mfma_perstage)) - { - __builtin_amdgcn_sched_group_barrier(0x100, ds_read_a_mfma_rate, 0); // DS read + __builtin_amdgcn_sched_group_barrier( + 0x100, + num_ds_read_inst_a - (num_dsread_a_mfma - 1) * ds_read_a_mfma_rate, + 0); // DS read } }); - }); - - // Stage 2, Sync - // lds synchronization, prefetch next loop local A - static_for<0, num_ds_read_a_prefetch_stages, 1>{}([&](auto /*i*/) { - static_for<0, num_mfma_perstage, 1>{}([&](auto imfma) { - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - if constexpr(imfma % buffer_load_issue_point_interval_stage2 == 0) - { - __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read - } - if constexpr(imfma >= (num_mfma_perstage - num_ds_read_a_mfma_perstage)) - { - __builtin_amdgcn_sched_group_barrier(0x100, ds_read_a_mfma_rate, 0); // DS read - } - }); - }); + } } template ()) - { - __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; - auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z); + auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z); - GridwiseGemm::template Run( - karg.p_sorted_token_ids, - karg.p_sorted_expert_ids, - karg.p_max_token_id, - karg.p_a_grid + splitk_batch_offset.a_k_split_offset, - karg.p_a_scale_grid + splitk_batch_offset.a_k_split_offset, - karg.p_b_grid + splitk_batch_offset.b_k_split_offset, - karg.p_b_scale_grid + splitk_batch_offset.b_k_split_offset, - karg.p_ds_grid, - karg.p_c_grid, - p_shared, - karg, - karg.a_element_op, - karg.b_element_op, - karg.c_element_op); - } + GridwiseGemm::template Run( + karg.p_sorted_token_ids, + karg.p_sorted_expert_ids, + karg.p_max_token_id, + karg.p_a_grid + splitk_batch_offset.a_k_split_offset, + karg.p_a_scale_grid + splitk_batch_offset.a_scale_k_split_offset, + karg.p_b_grid + splitk_batch_offset.b_k_split_offset, + karg.p_b_scale_grid + splitk_batch_offset.b_scale_k_split_offset, + karg.p_ds_grid, + karg.p_c_grid, + p_shared, + karg, + karg.a_element_op, + karg.b_element_op, + karg.c_element_op); #else ignore = karg; #endif // end of if (defined(__gfx9__)) @@ -1249,7 +1246,6 @@ struct GridwiseMoeGemmMX_BPreshuffle __host__ static constexpr bool CalculateHasMainKBlockLoop(index_t K) { const index_t num_loop = K / KPerBlock; - return BlockwiseGemmPipe::BlockHasHotloop(num_loop); } @@ -1279,7 +1275,6 @@ struct GridwiseMoeGemmMX_BPreshuffle // using Block2CTileMapDefault = BlockToCTileMap_Grouped_M00_N0_M01Adapt<8, MPerBlock, // NPerBlock>; -#if 0 template @@ -1298,9 +1293,10 @@ struct GridwiseMoeGemmMX_BPreshuffle BElementwiseOperation b_element_op, CElementwiseOperation c_element_op) { + ignore = a_element_op; ignore = b_element_op; - index_t BN0Shuffled = CalculateBN0Shuffled(problem.N); - index_t BK0Shuffled = CalculateBK0Shuffled(problem.K); + index_t BN0Shuffled = CalculateBN0Shuffled(problem.N); + index_t BK0Shuffled = CalculateBK0Shuffled(problem.K); const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1( IsInputGemm ? problem.NumTokens : problem.NumTokens * problem.TopK, problem.MPadded, @@ -1317,29 +1313,41 @@ struct GridwiseMoeGemmMX_BPreshuffle problem.NPadded, problem.StrideC); - const auto a_scale_grid_desc_am_ak = make_naive_tensor_descriptor_packed( - make_tuple((IsInputGemm ? problem.NumTokens : problem.M) / (MXdlPack * MPerBlock), + // We pad the M unconditionaly for Scale + const auto Padded_Scale_M = + math::integer_divide_ceil(problem.M, ScaleBlockSize) * ScaleBlockSize; + const auto a_scale_grid_desc_am_ak = make_naive_tensor_descriptor( + make_tuple(Padded_Scale_M / (MXdlPack * MPerXdl), math::integer_divide_ceil(problem.K, (ScaleBlockSize / APackedSize)) / (KXdlPack * 64 / MPerXdl), - 64 * KXdlPack * MXdlPack / scale_pack_size_a)); + 64 * KXdlPack * MXdlPack / scale_pack_size_a), + make_tuple(math::integer_divide_ceil(problem.K * problem.KBatch, + (ScaleBlockSize / APackedSize)) * + MPerXdl * MXdlPack / scale_pack_size_a, + 64 * KXdlPack * MXdlPack / scale_pack_size_a, + 1)); - const auto b_scale_grid_desc_bn_ak = make_naive_tensor_descriptor_packed( + const auto b_scale_grid_desc_bn_ak = make_naive_tensor_descriptor( make_tuple(problem.N / (NXdlPack * NPerXdl), math::integer_divide_ceil(problem.K, (ScaleBlockSize / BPackedSize)) / (KXdlPack * 64 / NPerXdl), - 64 * KXdlPack * NXdlPack / scale_pack_size_b)); + 64 * KXdlPack * NXdlPack / scale_pack_size_b), + make_tuple(math::integer_divide_ceil(problem.K * problem.KBatch, + (ScaleBlockSize / BPackedSize)) * + NPerXdl * NXdlPack / scale_pack_size_b, + 64 * KXdlPack * NXdlPack / scale_pack_size_b, + 1)); const auto c_grid_desc_mblock_mperblock_nblock_nperblock = MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( c_grid_desc_m_n, problem.MBlock, problem.NBlock); - const index_t max_token_id = __builtin_amdgcn_readfirstlane(p_max_token_id[0]); - // static_assert(NSwizzle == false, "to do fix: need another pr in sorting merged"); + + const index_t max_token_id = __builtin_amdgcn_readfirstlane(p_max_token_id[0]); const index_t expert_block_id = NSwizzle ? blockIdx.x / problem.NBlock : blockIdx.y; if(expert_block_id * MPerBlock >= max_token_id) return; const index_t expert_id = __builtin_amdgcn_readfirstlane(p_sorted_expert_ids[expert_block_id]); - const auto block_mn = [&]() -> std::pair { if constexpr(NSwizzle) { @@ -1372,86 +1380,78 @@ struct GridwiseMoeGemmMX_BPreshuffle constexpr auto AK1Threads = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I2); constexpr auto AKThreads = AK0Threads * AK1Threads; constexpr auto AMRepeats = MPerBlock / AMThreads; - const index_t token_pos = block_m_id * MPerBlock + threadIdx.x / AKThreads * AMRepeats; + const index_t token_pos = block_m_id * MPerBlock + threadIdx.x / AKThreads; if(token_pos >= max_token_id || token0 >= problem.NumTokens) return; StaticallyIndexedArray gather_offsets; static_for<0, AMRepeats, 1>{}([&](auto m0) { - const index_t fused_token = p_sorted_token_ids[token_pos + m0]; + const index_t fused_token = p_sorted_token_ids[token_pos + m0 * AMThreads]; index_t token_offset = fused_token & 0xffffff; if constexpr(!IsInputGemm) { token_offset = token_offset * problem.TopK + (fused_token >> 24); } - gather_offsets(m0) = static_cast(token_offset) * problem.K / APackedSize; + gather_offsets(m0) = static_cast(token_offset) * problem.K; }); + const index_t expert_stride = __builtin_amdgcn_readfirstlane(problem.N * problem.K * (IsInputGemm ? 2 : 1)); - const index_t expert_scale_stride = - __builtin_amdgcn_readfirstlane(problem.N * (IsInputGemm ? 2 : 1) * - math::integer_divide_ceil(problem.K, ScaleBlockSize)); + const index_t expert_scale_stride = __builtin_amdgcn_readfirstlane( + problem.N * (IsInputGemm ? 2 : 1) * + math::integer_divide_ceil(problem.K, ScaleBlockSize / BPackedSize)); // N0, K0, Blocksize*KPack const index_t n_block_data_idx_on_grid = - __builtin_amdgcn_readfirstlane(block_n_id * NXdlPerWave); + __builtin_amdgcn_readfirstlane(block_n_id * NXdlPerWave / NXdlPack); + // Gride buffer creation const auto a_grid_buf = make_dynamic_buffer( p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize()); const auto b_grid_buf = make_dynamic_buffer( - p_b_grid + expert_id * expert_stride / BPackedSize, - b_grid_desc_bpreshuffled.GetElementSpaceSize()); + p_b_grid + expert_id * expert_stride, b_grid_desc_bpreshuffled.GetElementSpaceSize()); // A, B scale buffer const auto a_scale_grid_buf = make_dynamic_buffer( p_a_scale_grid, a_scale_grid_desc_am_ak.GetElementSpaceSize()); const auto b_scale_grid_buf = make_dynamic_buffer( - p_b_scale_grid + expert_id * expert_scale_stride, + p_b_scale_grid + (expert_id * expert_scale_stride) / sizeof(BScaleDataType), b_scale_grid_desc_bn_ak.GetElementSpaceSize()); // A matrix in LDS memory, dst of blockwise copy constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); // B matrix in LDS memory, dst of blockwise copy - // dummy constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); - // A matrix blockwise copy - auto a_blockwise_copy = ThreadGroupTensorSliceTransfer_v4r1_gather< + + // A matrix blockwise direct to LDS copy + auto a_blockwise_copy = ThreadGroupTensorSliceTransfer_Gather_DirectLoad< ThisThreadBlock, - AElementwiseOperation, - ck::tensor_operation::element_wise::PassThrough, - InMemoryDataOperationEnum::Set, Sequence, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ADataType, - LDSTypeA, + ADataType, decltype(a_grid_desc_ak0_m_ak1), decltype(a_block_desc_ak0_m_ak1), ABlockTransferSrcAccessOrder, - Sequence<0, 1, 2>, ABlockTransferSrcVectorDim, 2, ABlockTransferSrcScalarPerVector, - ABlockTransferDstScalarPerVector_AK1, - 1, - 1, - AThreadTransferSrcResetCoordinateAfterRun, - true, IndexType, - 1, - BlockwiseGemmPipe::GlobalBufferNum>(a_grid_desc_ak0_m_ak1, - make_multi_index(0, 0, 0), - a_element_op, - a_block_desc_ak0_m_ak1, - make_multi_index(0, 0, 0), - ck::tensor_operation::element_wise::PassThrough{}, - gather_offsets); + 1>(a_grid_desc_ak0_m_ak1, + make_multi_index(0, 0, 0), + a_block_desc_ak0_m_ak1, + make_multi_index(0, 0, 0), + gather_offsets); // Thread-wise copy // K0 -> N0/NWave -> NWave -> KLane -> NLane -> KPack - auto b_block_buf = make_static_buffer( + auto b_block_buf_ping = make_static_buffer( b_block_desc_bk0_n_bk1.GetElementSpaceSize()); + auto b_block_buf_pong = make_static_buffer( + b_block_desc_bk0_n_bk1.GetElementSpaceSize()); + auto b_block_bufs = make_tuple(b_block_buf_ping, b_block_buf_pong); auto b_blockwise_copy = ThreadwiseTensorSliceTransfer_v2{}, Number{}, Number{}>, - Sequence<1, 2, 0, 3>, + Sequence<0, 1, 2, 3, 4>, 4, BBlockTransferSrcScalarPerVector, BThreadTransferSrcResetCoordinateAfterRun, @@ -1472,16 +1472,16 @@ struct GridwiseMoeGemmMX_BPreshuffle make_multi_index(n_block_data_idx_on_grid, get_warp_local_1d_id() % NWave, 0, - KPack / KGroup * (get_thread_local_1d_id() % WarpSize))); + 0, + KPack * (get_thread_local_1d_id() % WarpSize))); // LDS allocation for A and B: be careful of alignment // Cast after lds auto a_block_buf = make_dynamic_buffer( - static_cast(p_shared), - a_block_desc_ak0_m_ak1.GetElementSpaceSize() / APackedSize); + static_cast(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize()); constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1Number, 0, 0); - constexpr auto b_block_slice_copy_step = make_multi_index(0, 0, KRepeat, 0); + constexpr auto b_block_slice_copy_step = make_multi_index(0, 0, 0, KRepeat, 0); // Blockwise GEMM pipeline static_assert(std::is_default_constructible_v); @@ -1505,13 +1505,16 @@ struct GridwiseMoeGemmMX_BPreshuffle const auto waveId_m = wave_idx[I0]; const auto waveId_n = wave_idx[I1]; - static constexpr auto mfma = BlockwiseGemmPipe::xdlops_gemm.mfma; - auto thread_offset_shuffled = get_thread_local_1d_id() % BlockwiseGemmPipe::WaveSize * KXdlPack * MXdlPack; auto a_thread_offset_m = waveId_m; + // get each thread's offset int the scale tensor + const index_t token_scale_pos = block_m_id * MPerBlock; + if(token_scale_pos >= max_token_id || token0 >= problem.NumTokens) + return; + auto a_scale_thread_copy = ThreadwiseTensorSliceTransfer_v2< AScaleDataType, AScaleDataType, @@ -1538,7 +1541,7 @@ struct GridwiseMoeGemmMX_BPreshuffle Sequence<1, 1, KXdlPack * NXdlPack / scale_pack_size_b>, // SliceLengths Sequence<0, 1, 2>, // DimAccessOrder 2, // SrcVectorDim - KXdlPack * MXdlPack / scale_pack_size_b, // SrcScalarPerVector + KXdlPack * NXdlPack / scale_pack_size_b, // SrcScalarPerVector 1, // SrcScalarStrideInVector true>(b_scale_grid_desc_bn_ak, make_multi_index(block_n_id * NPerBlock / NPerXdl / NXdlPack + b_thread_offset_n, @@ -1547,29 +1550,37 @@ struct GridwiseMoeGemmMX_BPreshuffle if constexpr(IsInputGemm) { - const BDataType* p_b_grid_up = p_b_grid + expert_stride / 2 / BPackedSize; + const BDataType* p_b_grid_up = p_b_grid + expert_stride / 2; const auto b_grid_buf_up = make_dynamic_buffer( - p_b_grid_up + expert_id * expert_stride / BPackedSize, + p_b_grid_up + expert_id * expert_stride, b_grid_desc_bpreshuffled.GetElementSpaceSize()); - auto b_blockwise_copy_up = ThreadwiseTensorSliceTransfer_v2< - BDataType, - BDataType, - decltype(b_grid_desc_bpreshuffled), - decltype(b_block_desc_bk0_n_bk1), - Sequence{}, I1, Number{}, Number{}>, - Sequence<1, 2, 0, 3>, - 3, - BBlockTransferSrcScalarPerVector, - BThreadTransferSrcResetCoordinateAfterRun, - true>(b_grid_desc_bpreshuffled, - make_multi_index(n_block_data_idx_on_grid, - get_warp_local_1d_id() % NWave, - 0, - KPack / KGroup * (get_thread_local_1d_id() % WarpSize))); - const BScaleDataType* p_b_scale_grid_up = p_b_scale_grid + expert_scale_stride / 2; - const auto b_scale_grid_buf_up = make_dynamic_buffer( - p_b_scale_grid_up + expert_id * expert_scale_stride, + auto b_blockwise_copy_up = + ThreadwiseTensorSliceTransfer_v2{}, + I1, + Number{}, + Number{}, + Number{}>, + Sequence<0, 1, 2, 3, 4>, + 4, + BBlockTransferSrcScalarPerVector, + BThreadTransferSrcResetCoordinateAfterRun, + true>( + b_grid_desc_bpreshuffled, + make_multi_index(n_block_data_idx_on_grid, + get_warp_local_1d_id() % NWave, + 0, + 0, + KPack * (get_thread_local_1d_id() % WarpSize))); + const BScaleDataType* p_b_scale_grid_up = + p_b_scale_grid + expert_scale_stride / 2 / sizeof(BScaleDataType); + const auto b_scale_grid_buf_up = make_dynamic_buffer( + p_b_scale_grid_up + expert_id * expert_scale_stride / sizeof(BScaleDataType), b_scale_grid_desc_bn_ak.GetElementSpaceSize()); + auto b_scale_thread_copy_up = ThreadwiseTensorSliceTransfer_v2< BScaleDataType, BScaleDataType, @@ -1587,25 +1598,30 @@ struct GridwiseMoeGemmMX_BPreshuffle thread_offset_shuffled / scale_pack_size_b)); blockwise_gemm_pipeline.template Run( + // A a_grid_desc_ak0_m_ak1, a_block_desc_ak0_m_ak1, a_blockwise_copy, a_grid_buf, a_block_buf, a_block_slice_copy_step, + // Gate and Up b_grid_desc_bpreshuffled, b_block_desc_bk0_n_bk1, b_blockwise_copy, b_blockwise_copy_up, b_grid_buf, b_grid_buf_up, - b_block_buf, + b_block_bufs, b_block_slice_copy_step, + // C c_thread_buf, c_thread_buf_up, + // A scale a_scale_grid_desc_am_ak, a_scale_thread_copy, a_scale_grid_buf, + // B scale b_scale_grid_desc_bn_ak, b_scale_thread_copy, b_scale_thread_copy_up, @@ -1616,23 +1632,23 @@ struct GridwiseMoeGemmMX_BPreshuffle else { blockwise_gemm_pipeline.template Run( - a_grid_desc_ak0_m_ak1, + a_grid_desc_ak0_m_ak1, // A a_block_desc_ak0_m_ak1, a_blockwise_copy, a_grid_buf, a_block_buf, a_block_slice_copy_step, - b_grid_desc_bpreshuffled, + b_grid_desc_bpreshuffled, // B b_block_desc_bk0_n_bk1, b_blockwise_copy, b_grid_buf, - b_block_buf, + b_block_bufs, b_block_slice_copy_step, - c_thread_buf, - a_scale_grid_desc_am_ak, + c_thread_buf, // C + a_scale_grid_desc_am_ak, // A scale a_scale_thread_copy, a_scale_grid_buf, - b_scale_grid_desc_bn_ak, + b_scale_grid_desc_bn_ak, // B scale b_scale_thread_copy, b_scale_grid_buf, num_k_block_main_loop); @@ -1643,84 +1659,101 @@ struct GridwiseMoeGemmMX_BPreshuffle static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 && NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0, "wrong!"); + static_assert(CShuffleMXdlPerWavePerShuffle % MXdlPack == 0 && + CShuffleNXdlPerWavePerShuffle % NXdlPack == 0, + "wrong!"); + + constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); // TODO: hacky, fix it! constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 = - blockwise_gemm_pipeline.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + blockwise_gemm_pipeline.GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3(); // TODO: hacky, fix it! // c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp = - blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3(); constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0); constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1); constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2); constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3); constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4); - constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5); - constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6); - constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7); + constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5); + constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6); + constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7); + constexpr auto M5 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I8); + constexpr auto N3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I9); // mul scales - static_assert(M0 * M1 * M2 * M3 * M4 == MPerBlock); - static_assert(M4 == 4); + + static_assert(M0 * M1 * M2 * M3 * M4 * M5 == MPerBlock); + static_assert(M5 == 4); const index_t m1 = get_warp_local_1d_id() / NWave; - const index_t m3 = threadIdx.x % get_warp_size() / MPerXdl; + const index_t m4 = threadIdx.x % get_warp_size() / MPerXdl; vector_type topk_weights; // for gemm2 only - static_for<0, NXdlPerWave, 1>{}([&](auto n0) { - static_for<0, MXdlPerWave, 1>{}([&](auto m0) { // MXDLPerWave - static_for<0, M2, 1>{}([&](auto m2) { // m_inst_num_groups_per_blk - const index_t m_pos = block_m_id * MPerBlock + m0 * M1 * M2 * M3 * M4 + - m1 * M2 * M3 * M4 + m2 * M3 * M4 + m3 * M4; - if constexpr(MulRoutedWeight) - { - topk_weights = *c_style_pointer_cast*>( - p_ds_grid[I2] + m_pos); - } - static_for<0, M4, 1>{}([&](auto m4) { // m_inst_group_size - constexpr index_t c_offset = - blockwise_gemm_pipeline.GetCThreadDesc().CalculateOffset( - make_tuple(m0, n0, m2 * M4 + m4)); - constexpr auto cidx = Number{}; - - if constexpr(IsInputGemm) // gu fusion - { - if constexpr(ActivationOperation == Activation::silu_and_mul) - { - float gate = c_thread_buf[cidx]; - float up = c_thread_buf_up[cidx]; - if constexpr(MulRoutedWeight) - { - gate = gate * topk_weights.AsType()[m4]; - up = up * topk_weights.AsType()[m4]; - } - tensor_operation::element_wise::Silu{}(gate, gate); - c_thread_buf_fp32(cidx) = gate * up; - } - else if(ActivationOperation == Activation::gelu_and_mul) - { - float gate = c_thread_buf[cidx]; - float up = c_thread_buf_up[cidx]; - if constexpr(MulRoutedWeight) - { - gate = gate * topk_weights.AsType()[m4]; - up = up * topk_weights.AsType()[m4]; - } - tensor_operation::element_wise::Gelu{}(gate, gate); - c_thread_buf_fp32(cidx) = gate * up; - } - } - else - { - c_thread_buf_fp32(cidx) = c_thread_buf[cidx]; + static_for<0, NXdlPerWave / NXdlPack, 1>{}([&](auto n0) { + static_for<0, NXdlPack, 1>{}([&](auto inxdl) { // NXdlPack + static_for<0, MXdlPerWave / MXdlPack, 1>{}([&](auto m0) { // MXDLPerWave + static_for<0, MXdlPack, 1>{}([&](auto imxdl) { // MXdlPack + static_for<0, M3, 1>{}([&](auto m3) { // m_inst_num_groups_per_blk + const index_t m_pos = block_m_id * MPerBlock + + m0 * M2 * M1 * M3 * M4 * M5 + + m1 * M2 * M3 * M4 * M5 + + imxdl * M3 * M4 * M5 + m3 * M4 * M5 + m4 * M5; if constexpr(MulRoutedWeight) { - c_thread_buf_fp32(cidx) = - topk_weights.AsType()[m4] * c_thread_buf_fp32[cidx]; + topk_weights = + *c_style_pointer_cast*>( + p_ds_grid[I2] + m_pos); } - } + static_for<0, M5, 1>{}([&](auto m5) { // m_inst_group_size + constexpr index_t c_offset = + blockwise_gemm_pipeline.GetCThreadDesc().CalculateOffset( + make_tuple(m0, n0, imxdl, inxdl, m3 * M5 + m5)); + constexpr auto cidx = Number{}; + + if constexpr(IsInputGemm) // gu fusion + { + if constexpr(ActivationOperation == + Activation::silu_and_mul) + { + float gate = c_thread_buf[cidx]; + float up = c_thread_buf_up[cidx]; + if constexpr(MulRoutedWeight) + { + gate = gate * topk_weights.AsType()[m5]; + up = up * topk_weights.AsType()[m5]; + } + tensor_operation::element_wise::Silu{}(gate, gate); + c_thread_buf_fp32(cidx) = gate * up; + } + else if(ActivationOperation == Activation::gelu_and_mul) + { + float gate = c_thread_buf[cidx]; + float up = c_thread_buf_up[cidx]; + if constexpr(MulRoutedWeight) + { + gate = gate * topk_weights.AsType()[m5]; + up = up * topk_weights.AsType()[m5]; + } + tensor_operation::element_wise::Gelu{}(gate, gate); + c_thread_buf_fp32(cidx) = gate * up; + } + } + else + { + c_thread_buf_fp32(cidx) = c_thread_buf[cidx]; + if constexpr(MulRoutedWeight) + { + c_thread_buf_fp32(cidx) = + topk_weights.AsType()[m5] * + c_thread_buf_fp32[cidx]; + } + } + }); + }); }); }); }); @@ -1738,19 +1771,25 @@ struct GridwiseMoeGemmMX_BPreshuffle make_tuple( make_freeze_transform(I0), make_unmerge_transform(make_tuple( - Number{}, // M0 (MXdlPerWave) per shuffle - M1, // M1 = MWave - M2, // M2 * M3 * M4 = MPerXdl + Number{}, // M0 (MXdlPerWave) per + // shuffle + M1, // M1 = MWave + M2, // M2 * M3 * M4 = MPerXdl M3, - M4)), + M4, + M5)), make_freeze_transform(I0), make_unmerge_transform(make_tuple( - Number{}, // N0 (NXdlPerWave) per shuffle - N1, // N1 = NWave - N2))), // N2 = NPerXdl + Number{}, // N0 (NXdlPerWave) + // per shuffle + N1, // N1 = NWave + N2, // N2 = NXdlPack + N3))), // N3 = NPerXdl make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), - make_tuple( - Sequence<>{}, Sequence<0, 2, 4, 5, 6>{}, Sequence<>{}, Sequence<1, 3, 7>{})); + make_tuple(Sequence<>{}, + Sequence<0, 2, 4, 6, 7, 8>{}, + Sequence<>{}, + Sequence<1, 3, 5, 9>{})); // calculate origin of thread output tensor on global memory // blockwise GEMM c matrix starting index @@ -1762,8 +1801,8 @@ struct GridwiseMoeGemmMX_BPreshuffle const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor = make_single_stage_tensor_adaptor( - make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))), - make_tuple(Sequence<0, 1, 2, 3, 4>{}), + make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4, M5))), + make_tuple(Sequence<0, 1, 2, 3, 4, 5>{}), make_tuple(Sequence<0>{})); const auto m_thread_data_on_block_idx = @@ -1772,8 +1811,8 @@ struct GridwiseMoeGemmMX_BPreshuffle const auto n_thread_data_on_block_to_n0_n1_n2_adaptor = make_single_stage_tensor_adaptor( - make_tuple(make_merge_transform(make_tuple(N0, N1, N2))), - make_tuple(Sequence<0, 1, 2>{}), + make_tuple(make_merge_transform(make_tuple(N0, N1, N2, N3))), + make_tuple(Sequence<0, 1, 2, 3>{}), make_tuple(Sequence<0>{})); const auto n_thread_data_on_block_idx = @@ -1781,36 +1820,39 @@ struct GridwiseMoeGemmMX_BPreshuffle make_multi_index(n_thread_data_on_block)); // shuffle: threadwise copy C from VGPR to LDS - auto c_thread_copy_vgpr_to_lds = - ThreadwiseTensorSliceTransfer_v1r3, - Sequence<0, 1, 2, 3, 4, 5, 6, 7>, - 7, - 1, - InMemoryDataOperationEnum::Set, - 1, - true>{ - c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, - make_multi_index(0, - 0, - m_thread_data_on_block_idx[I1], - n_thread_data_on_block_idx[I1], - m_thread_data_on_block_idx[I2], - m_thread_data_on_block_idx[I3], - m_thread_data_on_block_idx[I4], - n_thread_data_on_block_idx[I2]), - ck::tensor_operation::element_wise::PassThrough{}}; + auto c_thread_copy_vgpr_to_lds = ThreadwiseTensorSliceTransfer_v1r3< + AccDataType, + CShuffleDataType, + decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2), + decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2), + ck::tensor_operation::element_wise::PassThrough, + Sequence, + Sequence<0, 1, 2, 3, 4, 5, 6, 7, 8, 9>, + 9, + 1, + InMemoryDataOperationEnum::Set, + 1, + true>{c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + make_multi_index(0, + 0, + m_thread_data_on_block_idx[I1], + n_thread_data_on_block_idx[I1], + m_thread_data_on_block_idx[I2], + n_thread_data_on_block_idx[I2], + m_thread_data_on_block_idx[I3], + m_thread_data_on_block_idx[I4], + m_thread_data_on_block_idx[I5], + n_thread_data_on_block_idx[I3]), + ck::tensor_operation::element_wise::PassThrough{}}; using EDataType = CDataType; @@ -1859,7 +1901,7 @@ struct GridwiseMoeGemmMX_BPreshuffle using CDEBlockTransferCluster = CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock; const auto EGlobalMemoryDataOperation = CGlobalMemoryDataOperation; - constexpr index_t scatter_weight_idx = 1; // hack fix felix + constexpr index_t scatter_weight_idx = 3; // hack fix felix auto cde_block_copy_lds_and_global = ThreadGroupTensorSliceTransfer_v7r3_scatter< ThisThreadBlock, decltype(container_concat(make_tuple(CShuffleDataType{}), DsDataType{})), @@ -1867,8 +1909,9 @@ struct GridwiseMoeGemmMX_BPreshuffle decltype(c_ds_desc_refs), decltype(tie(e_grid_desc_mblock_mperblock_nblock_nperblock)), CElementwiseOperation, - Sequence(EGlobalMemoryDataOperation)>, // FIXME: make Sequence - // support arbitray type + Sequence(EGlobalMemoryDataOperation)>, // FIXME: make + // Sequence support + // arbitray type Sequence<1, CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, 1, @@ -1898,13 +1941,25 @@ struct GridwiseMoeGemmMX_BPreshuffle auto c_grid_buf = make_dynamic_buffer( p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + constexpr auto sfc_c_vgpr = - SpaceFillingCurve, - Sequence<0, 1, 2, 3, 4, 5, 6, 7>, - Sequence, + Sequence<0, 1, 2, 3, 4, 5, 6, 7, 8, 9>, + Sequence Date: Fri, 7 Nov 2025 11:42:39 +0800 Subject: [PATCH 6/6] fix MX bpreshuffle gemm B grid descriptor dimension error. (#3170) --- .../gpu/grid/gridwise_gemm_xdl_cshuffle_v3_mx_bpreshuffle.hpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_mx_bpreshuffle.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_mx_bpreshuffle.hpp index 3d2ef9b6c4..7c5bd606b2 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_mx_bpreshuffle.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_mx_bpreshuffle.hpp @@ -429,8 +429,8 @@ struct GridwiseGemmMX_xdl_cshuffle_v3_bpreshuffle constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); constexpr index_t WaveSize = BlockSize / (MWave * NWave); constexpr index_t NkSwizzleNumber = Number{}; - return make_naive_tensor_descriptor_packed( - make_tuple(N0 / NWave / NXdlPack, NWave, NXdlPack, K0, NkSwizzleNumber)); + return make_naive_tensor_descriptor_packed(make_tuple( + math::integer_divide_ceil(N0, NWave * NXdlPack), NWave, NXdlPack, K0, NkSwizzleNumber)); } __host__ __device__ static auto MakeBGridDescriptor_BK0_N_BK1(