From 5d2cbd111730f44949448afd46a1a1846d471be9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ville=20Pietil=C3=A4?= <188998872+vpietila-amd@users.noreply.github.com> Date: Fri, 13 Mar 2026 03:20:15 +0200 Subject: [PATCH] [CK_TILE, CK_BUILDER] Add two-stage bwd weight kernels to CK Tile profiler (#5237) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## Motivation PR #4797 added CK Tile bwd weight kernels to the CK Profiler. The two-stage kernels were not supported in the initial PR. This PR adds the missing bwd weight two-stage kernels to the CK Profiler. ## Technical Details Extended the CK Tile conv builder factory to also build the elementwise ops required for the two-stage kernels. Extended the CK Builder for CK Tile instance to accept the two-stage flag as part of the algorithm configuration. ## Test Plan Added unit tests for CK Builder that verify the two-stage kernel construction. ## Test Result If CI passes, the added unit tests are passing. ## Submission Checklist - [x] Look over the contributing guidelines at https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests. 
--------- Co-authored-by: Ville Pietilä <> --- .../builder/conv_algorithm_concepts.hpp | 2 + .../builder/factory/conv_tile_factory.hpp | 37 +++++- .../ck_tile/conv_tile_tuning_params.hpp | 4 +- .../ck_tile/builder/testing/conv/ck_tile.hpp | 118 ++++++++++++++++++ .../test_ckb_conv_bwd_data_2d_fp16_v3.cpp | 6 +- .../test_ckb_conv_bwd_weight_2d_fp16_v3.cpp | 103 ++++++++++----- .../ck_tile/test_ckb_conv_fwd_2d_fp16_v3.cpp | 6 +- .../conv/ck_tile/test_ckb_conv_fwd_e2e.cpp | 6 +- .../test/impl/conv_algorithm_types.hpp | 2 + .../generate_instances.py | 36 ++++-- .../instances/instance_includes.inc | 1 + .../instances/instance_run.inc | 22 +++- .../elementwise/kernel/elementwise_kernel.hpp | 14 +++ .../elementwise_pipeline_default_policy.hpp | 5 + .../pipeline/elementwise_pipeline_problem.hpp | 14 +++ .../pipeline/elementwise_shape.hpp | 10 ++ 16 files changed, 336 insertions(+), 50 deletions(-) diff --git a/experimental/builder/include/ck_tile/builder/conv_algorithm_concepts.hpp b/experimental/builder/include/ck_tile/builder/conv_algorithm_concepts.hpp index b045fb04fe..6b883ecfc9 100644 --- a/experimental/builder/include/ck_tile/builder/conv_algorithm_concepts.hpp +++ b/experimental/builder/include/ck_tile/builder/conv_algorithm_concepts.hpp @@ -155,6 +155,7 @@ concept TileOptimizationsDescriptor = requires(T t) { { t.num_groups_to_merge } -> std::convertible_to; { t.split_image } -> std::convertible_to; { t.explicit_gemm } -> std::convertible_to; + { t.two_stage } -> std::convertible_to; }; // Base requirement for all ConvAlgorithm concepts, i.e., all conv algorithm concepts must meet this @@ -295,6 +296,7 @@ concept SpecifiesTileOptimizations = requires { { T::optimizations.num_groups_to_merge } -> std::convertible_to; { T::optimizations.split_image } -> std::convertible_to; { T::optimizations.explicit_gemm } -> std::convertible_to; + { T::optimizations.two_stage } -> std::convertible_to; }; template diff --git 
a/experimental/builder/include/ck_tile/builder/factory/conv_tile_factory.hpp b/experimental/builder/include/ck_tile/builder/factory/conv_tile_factory.hpp index b1f9136eed..05f91f4e9f 100644 --- a/experimental/builder/include/ck_tile/builder/factory/conv_tile_factory.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/conv_tile_factory.hpp @@ -8,6 +8,8 @@ #include "ck_tile/ops/epilogue.hpp" #include "ck_tile/ops/gemm.hpp" #include "ck_tile/ops/grouped_convolution.hpp" +#include "ck_tile/ops/elementwise/unary_element_wise_operation.hpp" +#include "ck_tile/builder/versions.hpp" #include "ck_tile/builder/conv_signature_concepts.hpp" #include "ck_tile/builder/conv_algorithm_concepts.hpp" #include "ck_tile/builder/conv_algorithm_limits.hpp" @@ -68,6 +70,10 @@ struct ConvTileFactory GroupedConvTraitsType::FixedGemmParams::TilePartitionerGroupNum, GroupedConvTraitsType::FixedGemmParams::TilePartitionerM01>; + using ConvOutDataType = std::conditional_t; + using GemmUniversalTraits = ck_tile::TileGemmUniversalTraits< GroupedConvTraitsType::FixedGemmParams::kPadM, GroupedConvTraitsType::FixedGemmParams::kPadN, @@ -103,7 +109,7 @@ struct ConvTileFactory typename Types::BDataType, typename Types::DsDataTypes, typename Types::AccDataType, - typename Types::EDataType, + ConvOutDataType, typename GroupedConvTraitsType::ImplicitGemmDsLayout, typename GroupedConvTraitsType::FixedGemmParams::ELayout, typename Ops::CDEElementwiseOp, @@ -126,4 +132,33 @@ struct ConvTileFactory ConvEpilogue>::Instance; }; +template +struct ElementwiseOpTileFactory +{ + static constexpr auto BLOCK = internal::SetTileThreadBlockInfo(); + static constexpr auto BLOCK_GEMM = internal::SetTileBlockGemm(); + + using Types = internal::TileConvTensorTypes; + using XDataType = Types::AccDataType; + using WorkspaceDataType = Types::AccDataType; + using XElementwiseOperation = ck_tile::element_wise::UnaryConvert; + using YDataType = Types::EDataType; + using BlockTile = ck_tile::sequence; + using 
BlockWarps = ck_tile::sequence; + using WarpTile = ck_tile::sequence; + using ElementwiseShape = + ck_tile::ElementWiseShape; + + // Conversion from X -> Y. + using Problem = ck_tile::ElementWisePipelineProblem; + + using Instance = ck_tile::ElementWiseKernel; +}; + } // namespace ck_tile::builder::factory diff --git a/experimental/builder/include/ck_tile/builder/factory/helpers/ck_tile/conv_tile_tuning_params.hpp b/experimental/builder/include/ck_tile/builder/factory/helpers/ck_tile/conv_tile_tuning_params.hpp index 8bc7de633a..0f2fbed216 100644 --- a/experimental/builder/include/ck_tile/builder/factory/helpers/ck_tile/conv_tile_tuning_params.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/helpers/ck_tile/conv_tile_tuning_params.hpp @@ -34,6 +34,7 @@ struct TileOptimizations int num_groups_to_merge = 1; bool split_image = false; bool explicit_gemm = false; + bool two_stage = false; }; template @@ -181,7 +182,8 @@ consteval TileOptimizations SetTileOptimizations() return TileOptimizations{.num_groups_to_merge = OPT.num_groups_to_merge, .split_image = OPT.split_image, - .explicit_gemm = OPT.explicit_gemm}; + .explicit_gemm = OPT.explicit_gemm, + .two_stage = OPT.two_stage}; } } // namespace ck_tile::builder::factory::internal diff --git a/experimental/builder/include/ck_tile/builder/testing/conv/ck_tile.hpp b/experimental/builder/include/ck_tile/builder/testing/conv/ck_tile.hpp index da771dfd19..ac6241af85 100644 --- a/experimental/builder/include/ck_tile/builder/testing/conv/ck_tile.hpp +++ b/experimental/builder/include/ck_tile/builder/testing/conv/ck_tile.hpp @@ -91,6 +91,100 @@ template +[[nodiscard]] RunResult run(CkTileConvInstance auto& conv, + auto& elementwise_op, + const Args& args, + InDataType* input, + WeiDataType* weight, + OutDataType* output, + const ck_tile::stream_config s_conf) +{ + using Conv = std::remove_reference_t; + using ElementwiseOp = std::remove_reference_t; + using WorkspaceDataType = typename 
ElementwiseOp::ComputeDataType; + using CDataType = typename ElementwiseOp::YDataType; + using BlockShape = typename ElementwiseOp::Problem::BlockShape; + + const auto param = args.to_ck_tile_conv_param(); + + ck_tile::GroupedConvHostArgs + host_args(param, input, weight, {}, output, args.k_batch); + + // Set-up for elementwise op kernel. + const ck_tile::index_t spatial_lengths_accum = + std::accumulate(host_args.filter_spatial_lengths_.begin(), + host_args.filter_spatial_lengths_.end(), + 1, + std::multiplies()); + ck_tile::DeviceMem ws_m_n_dev_buf(host_args.G_ * host_args.K_ * host_args.C_ * + spatial_lengths_accum * sizeof(WorkspaceDataType)); + ck_tile::GroupedConvBwdWeightHostArgs ws_args = + ck_tile::GroupedConvBwdWeightHostArgs(host_args); + auto c_ptr = ws_args.wei_ptr; + ws_args.wei_ptr = ws_m_n_dev_buf.GetDeviceBuffer(); + + auto kargs = Conv::MakeKernelArgs(ws_args); + const dim3 grids = Conv::GridSize(kargs); + const dim3 blocks = Conv::BlockSize(); + + if(!Conv::IsSupportedArgument(kargs)) + return RunResult::not_supported("unsupported ck_tile arguments"); + + ck_tile::index_t total_elements = 1; + std::vector shape = { + static_cast(host_args.G_ * host_args.K_), + static_cast(host_args.C_ * spatial_lengths_accum)}; + + for(auto d : shape) + total_elements *= d; + + const ck_tile::index_t kBlockSize = ElementwiseOp::BlockSize(); + + constexpr ck_tile::index_t elements_per_block = BlockShape::kBlockM; + ck_tile::index_t kGridSize = (total_elements + elements_per_block - 1) / elements_per_block; + + auto input_tensors = ck_tile::make_tuple(static_cast(ws_args.wei_ptr)); + auto input_size = ck_tile::make_tuple(shape[0], shape[1]); + + // Check if the kernel configuration is supported + if(!ElementwiseOp::IsSupportedArgument(input_size)) + { + return RunResult::not_supported("unsupported ck_tile arguments for elementwise op"); + } + + auto preprocess = [&]() { + if constexpr(ConvDirectionIsBackwardWeight) + { + if(args.k_batch > 1) + { + 
ck_tile::hip_check_error( + hipMemsetAsync(ws_args.wei_ptr, + 0, + shape[0] * shape[1] * sizeof(WorkspaceDataType), + s_conf.stream_id_)); + } + } + }; + + constexpr index_t minimum_occupancy = + Conv::GemmPipeline::Scheduler == ck_tile::GemmPipelineScheduler::Intrawave ? 1 : 2; + + return RunResult::from_runtime(ck_tile::launch_kernel_time_mask( + s_conf, + preprocess, + ck_tile::make_kernel(conv, grids, blocks, 0, kargs), + ck_tile::make_kernel(elementwise_op, + kGridSize, + kBlockSize, + 0, + input_size, + ck_tile::make_tuple(shape[1], 1), // Input Stride + ck_tile::make_tuple(shape[1], 1), // Output Stride + input_tensors, + static_cast(c_ptr)))); +} + } // namespace detail /// @brief Concept for checking whether a convolution is invoked like CK Tile. @@ -149,4 +243,28 @@ template s_conf); } +/// @brief `run()` specialization for two-stage backwards weight convolution and CK Tile. +/// +/// @tparam SIGNATURE Backwards weight convolution signature. +/// @returns RunResult about how the operation completed (or not). 
+/// +/// @see run() +template + requires ConvDirectionIsBackwardWeight +[[nodiscard]] RunResult run(CkTileConvInstance auto& conv, + auto& elementwise_op, + const Args& args, + const Inputs& inputs, + const Outputs& outputs, + const ck_tile::stream_config s_conf = {}) +{ + return detail::run(conv, + elementwise_op, + args, + static_cast(inputs.input), + static_cast(outputs.weight), + static_cast(inputs.output), + s_conf); +} + } // namespace ck_tile::builder::test diff --git a/experimental/builder/test/conv/ck_tile/test_ckb_conv_bwd_data_2d_fp16_v3.cpp b/experimental/builder/test/conv/ck_tile/test_ckb_conv_bwd_data_2d_fp16_v3.cpp index 89baf9b51b..fef339e102 100644 --- a/experimental/builder/test/conv/ck_tile/test_ckb_conv_bwd_data_2d_fp16_v3.cpp +++ b/experimental/builder/test/conv/ck_tile/test_ckb_conv_bwd_data_2d_fp16_v3.cpp @@ -25,8 +25,10 @@ TEST(BwdDataConvInstances, Create_ConvAlgorithm_Tile_GroupedConvolutionKernel_2D .with_tile_thread_block(TileThreadBlock_64x64x64) .with_tile_block_gemm(TileBlockGemmDesc_16x16_v3_intrawave) .with_tile_transfer(TileTransfer_4x4x4) - .with_tile_optimizations(TileOptimizations{ - .num_groups_to_merge = 1, .split_image = false, .explicit_gemm = false}); + .with_tile_optimizations(TileOptimizations{.num_groups_to_merge = 1, + .split_image = false, + .explicit_gemm = false, + .two_stage = false}); using Builder = ConvBuilder; run_ck_tile_test({ diff --git a/experimental/builder/test/conv/ck_tile/test_ckb_conv_bwd_weight_2d_fp16_v3.cpp b/experimental/builder/test/conv/ck_tile/test_ckb_conv_bwd_weight_2d_fp16_v3.cpp index 60dc45545f..1600e4fbb7 100644 --- a/experimental/builder/test/conv/ck_tile/test_ckb_conv_bwd_weight_2d_fp16_v3.cpp +++ b/experimental/builder/test/conv/ck_tile/test_ckb_conv_bwd_weight_2d_fp16_v3.cpp @@ -12,6 +12,7 @@ namespace ckb = ck_tile::builder; namespace ckt = ck_tile::builder::test; namespace cku = ck_tile::builder::test_utils; +namespace ckf = ck_tile::builder::factory; using enum 
ck_tile::builder::TensorLayout; using ck_tile::test::MatchesReference; @@ -31,12 +32,49 @@ constexpr auto ALGORITHM = .with_tile_thread_block(cku::TileThreadBlock_64x64x64) .with_tile_block_gemm(cku::TileBlockGemmDesc_16x16_v3_intrawave) .with_tile_transfer(cku::TileTransfer_4x4x4) - .with_tile_optimizations(ckt::TileOptimizations{ - .num_groups_to_merge = 1, .split_image = false, .explicit_gemm = false}); + .with_tile_optimizations(ckt::TileOptimizations{.num_groups_to_merge = 1, + .split_image = false, + .explicit_gemm = false, + .two_stage = false}); + +constexpr auto TWO_STAGE_ALGORITHM = + cku::ConvAlgorithm_Tile_GroupedConvolutionKernel{} + .with_tile_specializations(ckb::TileConvSpecialization::DEFAULT) + .with_tile_thread_block(cku::TileThreadBlock_64x64x64) + .with_tile_block_gemm(cku::TileBlockGemmDesc_16x16_v3_intrawave) + .with_tile_transfer(cku::TileTransfer_4x4x4) + .with_tile_optimizations(ckt::TileOptimizations{.num_groups_to_merge = 1, + .split_image = false, + .explicit_gemm = false, + .two_stage = true}); + +constexpr ckt::Args Args = { + .lengths = + { + .batch_size = 2, + .groups = 4, + .input_channels = 32, + .output_channels = 48, + .image = {.width = 32, .height = 56}, + .filter = {.width = 3, .height = 3}, + }, + .filter_strides = {.width = 1, .height = 1}, + .filter_dilation = {.width = 1, .height = 1}, + .input_left_pad = {.width = 0, .height = 0}, + .input_right_pad = {.width = 0, .height = 0}, + .a_elementwise_op = {}, + .b_elementwise_op = {}, + .cde_elementwise_op = {}, +}; using Builder = ckb::ConvBuilder; using Instance = Builder::Instance; +using TwoStageBuilder = ckb::ConvBuilder; +using TwoStageInstance = TwoStageBuilder::Instance; +using ElementwiseOpBuilder = ckf::ElementwiseOpTileFactory; +using ElementwiseOpInstance = ElementwiseOpBuilder::Instance; + using Reference = ckb::ConvBuilder::Instance; TEST(BwdWeight_2D_FP16_NHWGC, Create) @@ -61,38 +99,47 @@ TEST(BwdWeight_2D_FP16_NHWGC, Create) }); } +TEST(ElementWiseOp, 
CreateBwdWeightTwoStageElementwiseOp) +{ + cku::run_ck_tile_test({"elementwise_kernel", + "4096_256_4_4_64_4_256", + "UnaryConvert", + "kPad_1", + "ElementWiseDefaultPolicy"}); +} + TEST(BwdWeight_2D_FP16_NHWGC, Execution) { - ckt::Args args = { - .lengths = - { - .batch_size = 2, - .groups = 4, - .input_channels = 32, - .output_channels = 48, - .image = {.width = 32, .height = 56}, - .filter = {.width = 3, .height = 3}, - }, - .filter_strides = {.width = 1, .height = 1}, - .filter_dilation = {.width = 1, .height = 1}, - .input_left_pad = {.width = 0, .height = 0}, - .input_right_pad = {.width = 0, .height = 0}, - .a_elementwise_op = {}, - .b_elementwise_op = {}, - .cde_elementwise_op = {}, - }; + auto inputs = ckt::alloc_inputs(Args); + auto outputs = ckt::alloc_outputs(Args); + auto reference = ckt::alloc_outputs(Args); - auto inputs = ckt::alloc_inputs(args); - auto outputs = ckt::alloc_outputs(args); - auto reference = ckt::alloc_outputs(args); - - ckt::init_inputs(args, inputs.get()); + ckt::init_inputs(Args, inputs.get()); auto conv = Instance{}; - EXPECT_THAT(ckt::run(conv, args, inputs.get(), outputs.get()), SuccessfulRun()); + EXPECT_THAT(ckt::run(conv, Args, inputs.get(), outputs.get()), SuccessfulRun()); auto ref_conv = Reference{}; - EXPECT_THAT(ckt::run(ref_conv, args, inputs.get(), reference.get()), SuccessfulRun()); + EXPECT_THAT(ckt::run(ref_conv, Args, inputs.get(), reference.get()), SuccessfulRun()); - EXPECT_THAT(outputs.get(), MatchesReference(args, reference.get())); + EXPECT_THAT(outputs.get(), MatchesReference(Args, reference.get())); +} + +TEST(BwdWeight_TwoStage_2D_FP16_NHWGC, Execution) +{ + auto inputs = ckt::alloc_inputs(Args); + auto outputs = ckt::alloc_outputs(Args); + auto reference = ckt::alloc_outputs(Args); + + ckt::init_inputs(Args, inputs.get()); + + auto conv = TwoStageInstance{}; + auto elementwise_op = ElementwiseOpInstance{}; + + EXPECT_THAT(ckt::run(conv, elementwise_op, Args, inputs.get(), outputs.get()), SuccessfulRun()); 
+ + auto ref_conv = Reference{}; + EXPECT_THAT(ckt::run(ref_conv, Args, inputs.get(), reference.get()), SuccessfulRun()); + + EXPECT_THAT(outputs.get(), MatchesReference(Args, reference.get())); } diff --git a/experimental/builder/test/conv/ck_tile/test_ckb_conv_fwd_2d_fp16_v3.cpp b/experimental/builder/test/conv/ck_tile/test_ckb_conv_fwd_2d_fp16_v3.cpp index 2c35fb5076..f19a8533e3 100644 --- a/experimental/builder/test/conv/ck_tile/test_ckb_conv_fwd_2d_fp16_v3.cpp +++ b/experimental/builder/test/conv/ck_tile/test_ckb_conv_fwd_2d_fp16_v3.cpp @@ -24,8 +24,10 @@ TEST(FwdConvInstances, Create_ConvAlgorithm_Tile_GroupedConvolutionKernel_2D_FP1 .with_tile_thread_block(TileThreadBlock_64x64x64) .with_tile_block_gemm(TileBlockGemmDesc_16x16_v3_intrawave) .with_tile_transfer(TileTransfer_4x4x4) - .with_tile_optimizations(TileOptimizations{ - .num_groups_to_merge = 1, .split_image = false, .explicit_gemm = false}); + .with_tile_optimizations(TileOptimizations{.num_groups_to_merge = 1, + .split_image = false, + .explicit_gemm = false, + .two_stage = false}); using Builder = ConvBuilder; run_ck_tile_test({ diff --git a/experimental/builder/test/conv/ck_tile/test_ckb_conv_fwd_e2e.cpp b/experimental/builder/test/conv/ck_tile/test_ckb_conv_fwd_e2e.cpp index 650c217b71..70ebd32164 100644 --- a/experimental/builder/test/conv/ck_tile/test_ckb_conv_fwd_e2e.cpp +++ b/experimental/builder/test/conv/ck_tile/test_ckb_conv_fwd_e2e.cpp @@ -31,8 +31,10 @@ constexpr auto ALGORITHM = .with_tile_thread_block(cku::FwdTileThreadBlock_64x64x64) .with_tile_block_gemm(cku::TileBlockGemmDesc_16x16_v3_intrawave) .with_tile_transfer(cku::FwdTileTransfer_4x4x4) - .with_tile_optimizations(ckt::TileOptimizations{ - .num_groups_to_merge = 1, .split_image = false, .explicit_gemm = false}); + .with_tile_optimizations(ckt::TileOptimizations{.num_groups_to_merge = 1, + .split_image = false, + .explicit_gemm = false, + .two_stage = false}); using Builder = ckb::ConvBuilder; using Instance = Builder::Instance; 
diff --git a/experimental/builder/test/impl/conv_algorithm_types.hpp b/experimental/builder/test/impl/conv_algorithm_types.hpp index 4b99fd8100..2b9db31fa2 100644 --- a/experimental/builder/test/impl/conv_algorithm_types.hpp +++ b/experimental/builder/test/impl/conv_algorithm_types.hpp @@ -377,6 +377,8 @@ struct TileOptimizations bool split_image; // Explicit gemm for 1x1, stride=0, pad=0 cases bool explicit_gemm; + // Two-stage kernels + bool two_stage; }; static_assert(ckb::TileOptimizationsDescriptor); diff --git a/experimental/grouped_convolution_tile_instances/generate_instances.py b/experimental/grouped_convolution_tile_instances/generate_instances.py index 925c9fe700..b78c60c105 100755 --- a/experimental/grouped_convolution_tile_instances/generate_instances.py +++ b/experimental/grouped_convolution_tile_instances/generate_instances.py @@ -13,6 +13,7 @@ class ConvInstanceTemplateParams: warp_tile, double_smem_buffer, num_wave_groups, + is_two_stage_instance, pipeline_version, scheduler, scalar_per_vector, @@ -27,6 +28,7 @@ class ConvInstanceTemplateParams: self.warp_tile = warp_tile self.double_smem_buffer = double_smem_buffer self.num_wave_groups = num_wave_groups + self.is_two_stage_instance = is_two_stage_instance self.pipeline_version = pipeline_version self.scheduler = scheduler self.scalar_per_vector = scalar_per_vector @@ -39,7 +41,8 @@ class ConvInstanceTemplateParams: explicit_gemm = "true" if self.explicit_gemm else "false" split_image = "true" if self.split_image else "false" num_groups_to_merge = str(self.num_groups_to_merge) - return f"ckt::TileOptimizations{{.num_groups_to_merge = {num_groups_to_merge}, .split_image = {split_image}, .explicit_gemm = {explicit_gemm}}}" + two_stage_instance = "true" if self.is_two_stage_instance else "false" + return f"ckt::TileOptimizations{{.num_groups_to_merge = {num_groups_to_merge}, .split_image = {split_image}, .explicit_gemm = {explicit_gemm}, .two_stage = {two_stage_instance}}}" def 
get_specialization(self): namespace = "ckb::TileConvSpecialization::" @@ -270,6 +273,8 @@ def parse_fwd_instances(instances, problem_name): print(f"Skipping instance {instance_id} with ASYNC_V4 since it's not supported yet.") continue + is_two_stage = False + conv = ConvInstanceTemplateParams( spec, [m_per_block, n_per_block, k_per_block], @@ -277,6 +282,7 @@ def parse_fwd_instances(instances, problem_name): [m_per_xdl, n_per_xdl, k_per_xdl], double_smem_buffer, num_wave_groups, + is_two_stage, pipeline_version, scheduler, [a_scalar_per_vector, b_scalar_per_vector, c_scalar_per_vector], @@ -345,7 +351,7 @@ def parse_bwd_weight_instances(instances, problem_name): num_groups_to_merge = 1 # Block GEMM pipeline parameters - blk_gemm_pipeline_schduler = args[6] + block_gemm_pipeline_scheduler = args[6] blk_gemm_pipeline_version = args[7] else: spec = args[11] @@ -375,20 +381,29 @@ def parse_bwd_weight_instances(instances, problem_name): num_groups_to_merge = int(args[44]) # Block GEMM pipeline parameters - blk_gemm_pipeline_schduler = args[39] + block_gemm_pipeline_scheduler = args[39] blk_gemm_pipeline_version = args[40] elif is_two_stage_instance: - print(f"Skipping instance {instance_id} with device op {device_op_name} since it's not supported yet.") - continue + if len(args) != 46: + raise RuntimeError(f"Wrong number of parameters in the TwoStage instance string: {instance}\n" + + f"Expected 46 parameters for TwoStage instance. Found {len(args)} parameters.") + + num_groups_to_merge = args[41] + + # Block GEMM pipeline parameters + block_gemm_pipeline_scheduler = args[39] + blk_gemm_pipeline_version = args[40] + else: # Regular V1 XDL CShuffle instance if len(args) != 43: - raise RuntimeError(f"Wrong number of parameters in the XDL CShuffle instance string: {instance}") + raise RuntimeError(f"Wrong number of parameters in the XDL CShuffle instance string: {instance}\n" + + f"Expected 43 parameters for V1 instance. 
Found {len(args)} parameters.") num_groups_to_merge = 1 # Block GEMM pipeline parameters - blk_gemm_pipeline_schduler = "Intrawave" + block_gemm_pipeline_scheduler = "Intrawave" blk_gemm_pipeline_version = "v1" # Common part to all solvers. @@ -396,15 +411,15 @@ def parse_bwd_weight_instances(instances, problem_name): # Sanity check for Block GEMM pipeline parameters # Scheduler must be either Intrawave or Interwave. # Version must be from v1 to v5 - if blk_gemm_pipeline_schduler not in ["Intrawave", "Interwave"]: - raise RuntimeError(f"Invalid Block GEMM pipeline scheduler: {blk_gemm_pipeline_schduler} in instance: {instance}") + if block_gemm_pipeline_scheduler not in ["Intrawave", "Interwave"]: + raise RuntimeError(f"Invalid Block GEMM pipeline scheduler: {block_gemm_pipeline_scheduler} in instance: {instance}") if blk_gemm_pipeline_version not in ["v1", "v2", "v3", "v4", "v5"]: raise RuntimeError(f"Invalid Block GEMM pipeline version: {blk_gemm_pipeline_version} in instance: {instance}") split_image = instance.find("Large") != -1 double_smem_buffer = blk_gemm_pipeline_version == "v4" num_wave_groups = 1 - scheduler = blk_gemm_pipeline_schduler + scheduler = block_gemm_pipeline_scheduler pipeline_version = blk_gemm_pipeline_version.upper() # OLd CK pipeline version V5 maps to V6 for CK Tile @@ -441,6 +456,7 @@ def parse_bwd_weight_instances(instances, problem_name): [m_per_xdl, n_per_xdl, k_per_xdl], double_smem_buffer, num_wave_groups, + is_two_stage_instance, pipeline_version, scheduler, [a_scalar_per_vector, b_scalar_per_vector, c_scalar_per_vector], diff --git a/experimental/grouped_convolution_tile_instances/instances/instance_includes.inc b/experimental/grouped_convolution_tile_instances/instances/instance_includes.inc index 2b391ea12a..b5e0216bd6 100644 --- a/experimental/grouped_convolution_tile_instances/instances/instance_includes.inc +++ b/experimental/grouped_convolution_tile_instances/instances/instance_includes.inc @@ -6,6 +6,7 @@ namespace ckb = 
ck_tile::builder; namespace ckt = ck_tile::builder::test; namespace cku = ck_tile::builder::test_utils; +namespace ckf = ck_tile::builder::factory; namespace ck_tile::builder::profiling { diff --git a/experimental/grouped_convolution_tile_instances/instances/instance_run.inc b/experimental/grouped_convolution_tile_instances/instances/instance_run.inc index 016ef3e653..14311bbb83 100644 --- a/experimental/grouped_convolution_tile_instances/instances/instance_run.inc +++ b/experimental/grouped_convolution_tile_instances/instances/instance_run.inc @@ -1,7 +1,21 @@ -using Builder = ckb::ConvBuilder; -using Instance = Builder::Instance; +using Builder = ckb::ConvBuilder; +using ConvInstance = Builder::Instance; + +auto conv = ConvInstance{}; + +auto result = [&]() { + if constexpr(ConvDirectionIsBackwardWeight && Alg.optimizations.two_stage) + { + using ElementwiseOpBuilder = ckf::ElementwiseOpTileFactory; + using ElementwiseOpInstance = ElementwiseOpBuilder::Instance; + auto elementwise_op = ElementwiseOpInstance{}; + return ckt::run(conv, elementwise_op, args, inputs, outputs, s_conf); + } + else + { + return ckt::run(conv, args, inputs, outputs, s_conf); + } +}.template operator()(); -auto conv = Instance{}; -ckt::RunResult result = ckt::run(conv, args, inputs, outputs, s_conf); return std::make_tuple(result.is_supported(), result.runtime, conv.GetInstanceString()); diff --git a/include/ck_tile/ops/elementwise/kernel/elementwise_kernel.hpp b/include/ck_tile/ops/elementwise/kernel/elementwise_kernel.hpp index 2078a69546..a4dd791b83 100644 --- a/include/ck_tile/ops/elementwise/kernel/elementwise_kernel.hpp +++ b/include/ck_tile/ops/elementwise/kernel/elementwise_kernel.hpp @@ -4,6 +4,7 @@ #pragma once #include "ck_tile/core.hpp" +#include "ck_tile/host/concat.hpp" #include "ck_tile/ops/common.hpp" #include "ck_tile/ops/elementwise/pipeline/elementwise_pipeline_problem.hpp" #include "ck_tile/ops/elementwise/pipeline/elementwise_pipeline_default_policy.hpp" @@ -108,6 
+109,19 @@ struct ElementWiseKernel ignore = input_sizes; return true; } + + [[nodiscard]] CK_TILE_HOST static const std::string GetName() + { + // clang-format off + return concat('_', "elementwise_kernel", + Problem::GetName(), + "policy", + Policy::GetName() + ); + // clang-format on + } + + [[nodiscard]] CK_TILE_HOST static const std::string GetTypeString() { return GetName(); } }; } // namespace ck_tile diff --git a/include/ck_tile/ops/elementwise/pipeline/elementwise_pipeline_default_policy.hpp b/include/ck_tile/ops/elementwise/pipeline/elementwise_pipeline_default_policy.hpp index f719fd8182..ae01e9fb51 100644 --- a/include/ck_tile/ops/elementwise/pipeline/elementwise_pipeline_default_policy.hpp +++ b/include/ck_tile/ops/elementwise/pipeline/elementwise_pipeline_default_policy.hpp @@ -24,6 +24,11 @@ struct ElementWiseDefaultPolicy sequence<0, 3>>{} // Yield ); } + + [[nodiscard]] CK_TILE_HOST static const std::string GetName() + { + return "ElementWiseDefaultPolicy"; + } }; } // namespace ck_tile diff --git a/include/ck_tile/ops/elementwise/pipeline/elementwise_pipeline_problem.hpp b/include/ck_tile/ops/elementwise/pipeline/elementwise_pipeline_problem.hpp index a4edd95970..4e43bd0e89 100644 --- a/include/ck_tile/ops/elementwise/pipeline/elementwise_pipeline_problem.hpp +++ b/include/ck_tile/ops/elementwise/pipeline/elementwise_pipeline_problem.hpp @@ -4,6 +4,7 @@ #pragma once #include "ck_tile/core/utility/type_traits.hpp" +#include "ck_tile/host/concat.hpp" namespace ck_tile { @@ -21,6 +22,19 @@ struct ElementWisePipelineProblem using BlockShape = remove_cvref_t; using ElementWiseOperation = remove_cvref_t; static constexpr bool kPad = kPad_; + + [[nodiscard]] CK_TILE_HOST static const std::string GetName() + { + // clang-format off + return concat('_', + BlockShape::GetName(), + "op", + ElementWiseOperation::name, + "kPad", + kPad + ); + // clang-format on + } }; } // namespace ck_tile diff --git 
a/include/ck_tile/ops/elementwise/pipeline/elementwise_shape.hpp b/include/ck_tile/ops/elementwise/pipeline/elementwise_shape.hpp index 82d68f1883..75e2f70afe 100644 --- a/include/ck_tile/ops/elementwise/pipeline/elementwise_shape.hpp +++ b/include/ck_tile/ops/elementwise/pipeline/elementwise_shape.hpp @@ -4,6 +4,7 @@ #pragma once #include "ck_tile/core/utility/type_traits.hpp" +#include "ck_tile/host/concat.hpp" namespace ck_tile { @@ -25,6 +26,15 @@ struct ElementWiseShape static constexpr index_t kBlockSize = ck_tile::get_warp_size() * reduce_on_sequence(BlockWarps{}, multiplies<>{}, number<1>{}); + + [[nodiscard]] CK_TILE_HOST static const std::string GetName() + { + // clang-format off + return concat('_', "shape", + kBlockM, kWarpM, kVectorM, kWarpPerBlockM, kThreadPerWarpM, kRepeatM, kBlockSize + ); + // clang-format on + } }; } // namespace ck_tile