From 083ea723a045fcb88fcdca0048c9cea5c02af637 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bart=C5=82omiej=20Kocot?= Date: Tue, 25 Nov 2025 14:57:43 +0100 Subject: [PATCH] [CK_BUILDER] Add grouped conv bwd ck tile traits (#3281) * [CK_BUILDER] Add grouped conv bwd ck tile traits * copilot fixes [ROCm/composable_kernel commit: 9ac2666d5b48efc3743ce073aab0a68833accf5c] --- ...tion_backward_weight_two_stage_invoker.hpp | 2 +- ...tile_grouped_convolution_backward_data.hpp | 140 ++++++++++++++++++ ...le_grouped_convolution_backward_weight.hpp | 140 ++++++++++++++++++ ...raits_tile_grouped_convolution_forward.hpp | 14 +- experimental/builder/test/CMakeLists.txt | 1 + .../test/test_bwd_data_instance_traits.cpp | 133 +++++++++++++++++ .../test/test_bwd_weight_instance_traits.cpp | 123 +++++++++++++++ .../builder/test/test_fwd_instance_traits.cpp | 9 +- ...ouped_convolution_backward_data_kernel.hpp | 17 +++ ...ped_convolution_backward_weight_kernel.hpp | 17 +++ 10 files changed, 583 insertions(+), 13 deletions(-) create mode 100644 experimental/builder/include/ck_tile/builder/reflect/instance_traits_tile_grouped_convolution_backward_data.hpp create mode 100644 experimental/builder/include/ck_tile/builder/reflect/instance_traits_tile_grouped_convolution_backward_weight.hpp create mode 100644 experimental/builder/test/test_bwd_data_instance_traits.cpp diff --git a/example/ck_tile/20_grouped_convolution/grouped_convolution_backward_weight_two_stage_invoker.hpp b/example/ck_tile/20_grouped_convolution/grouped_convolution_backward_weight_two_stage_invoker.hpp index 0bc481814a..5d78bc4739 100644 --- a/example/ck_tile/20_grouped_convolution/grouped_convolution_backward_weight_two_stage_invoker.hpp +++ b/example/ck_tile/20_grouped_convolution/grouped_convolution_backward_weight_two_stage_invoker.hpp @@ -133,7 +133,7 @@ struct GroupedConvolutionBackwardWeightTwoStageInvoker ConvConfig::M_Warp_Tile, ConvConfig::N_Warp_Tile, ConvConfig::K_Warp_Tile, - GemmPipelineProblem::TransposeC, + 
GroupedConvTraitsType::FixedGemmParams::TransposeC, memory_operation, ConvConfig::NumWaveGroups, GroupedConvTraitsType::FixedGemmParams::FixedVectorSize, diff --git a/experimental/builder/include/ck_tile/builder/reflect/instance_traits_tile_grouped_convolution_backward_data.hpp b/experimental/builder/include/ck_tile/builder/reflect/instance_traits_tile_grouped_convolution_backward_data.hpp new file mode 100644 index 0000000000..80283a0467 --- /dev/null +++ b/experimental/builder/include/ck_tile/builder/reflect/instance_traits_tile_grouped_convolution_backward_data.hpp @@ -0,0 +1,140 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT +// InstanceTraits specialization for GroupedConvolutionBackwardDataKernel +// +// CRITICAL MAINTENANCE NOTE: +// This InstanceTraits file MUST be kept strictly in sync with the device implementation header: +// ck_tile/ops/grouped_convolution/kernel/grouped_convolution_backward_data_kernel.hpp +// "In sync" means that the template parameter order, names, and types in the declaration below +// MUST EXACTLY MATCH those in the device implementation. If these diverge, you may encounter +// compilation errors, subtle template instantiation mismatches, or silent runtime bugs that are +// difficult to diagnose. Always update both files together and review changes carefully. + +#pragma once + +#include "instance_traits.hpp" +#include "instance_traits_util.hpp" + +// Forward declaration of the backward data kernel to avoid circular dependency. 
+namespace ck_tile { + +template +struct GroupedConvolutionBackwardDataKernel; + +} // namespace ck_tile + +namespace ck_tile { +namespace reflect { + +// Specialization for GroupedConvolutionBackwardDataKernel +template +struct InstanceTraits> +{ + // CK Tile Conv Traits + // Spatial dimension + static constexpr int kSpatialDim = GroupedConvTraitsType_::NDimSpatial; + // Specialization + static constexpr ck_tile::ConvolutionSpecialization ConvSpecialization = + GroupedConvTraitsType_::ConvSpecialization; + // Layout types + using InLayout = typename GroupedConvTraitsType_::InLayout; + using WeiLayout = typename GroupedConvTraitsType_::WeiLayout; + using DsLayout = typename GroupedConvTraitsType_::DsLayout; + using OutLayout = typename GroupedConvTraitsType_::OutLayout; + // Vector size + static constexpr int kVectorSizeA = GroupedConvTraitsType_::VectorSizeA; + static constexpr int kVectorSizeB = GroupedConvTraitsType_::VectorSizeB; + static constexpr int kVectorSizeC = GroupedConvTraitsType_::VectorSizeC; + // Num Groups To Merge + static constexpr int kNumGroupsToMerge = GroupedConvTraitsType_::NumGroupsToMerge; + // Split image (large tensors) + static constexpr bool kEnableSplitImage = GroupedConvTraitsType_::EnableSplitImage; + + // TilePartitioner + // Block configuration + static constexpr int kMPerBlock = TilePartitioner_::MPerBlock; + static constexpr int kNPerBlock = TilePartitioner_::NPerBlock; + static constexpr int kKPerBlock = TilePartitioner_::KPerBlock; + + static constexpr int kMWarp = TilePartitioner_::BlockGemmShape::BlockWarps::at(number<0>{}); + static constexpr int kNWarp = TilePartitioner_::BlockGemmShape::BlockWarps::at(number<1>{}); + static constexpr int kKWarp = TilePartitioner_::BlockGemmShape::BlockWarps::at(number<2>{}); + + static constexpr int kMWarpTile = TilePartitioner_::BlockGemmShape::WarpTile::at(number<0>{}); + static constexpr int kNWarpTile = TilePartitioner_::BlockGemmShape::WarpTile::at(number<1>{}); + static constexpr 
int kKWarpTile = TilePartitioner_::BlockGemmShape::WarpTile::at(number<2>{}); + + // Data types + using ADataType = typename GemmPipeline_::ADataType; + using BDataType = typename GemmPipeline_::BDataType; + // Gemm Pipeline + using GemmPipeline = GemmPipeline_; + static constexpr ck_tile::GemmPipelineScheduler kPipelineScheduler = GemmPipeline_::Scheduler; + static constexpr bool kDoubleSmemBuffer = GemmPipeline_::DoubleSmemBuffer; + static constexpr int kNumWaveGroups = GemmPipeline_::NumWaveGroups; + + // Epilogue Pipeline + using AccDataType = typename EpiloguePipeline_::AccDataType; + using EDataType = typename EpiloguePipeline_::ODataType; + using DsDataType = typename EpiloguePipeline_::DsDataType; + using CDEElementwiseOperation = typename EpiloguePipeline_::CDElementwise; + + // Static member function to generate instance string + static std::string instance_string() + { + std::ostringstream oss; + + // Kernel type name + oss << "GroupedConvolutionBackwardDataKernel"; + + // Template parameters in exact order matching InstanceTraits member order + oss << "<" << kSpatialDim; // 1. NDimSpatial + oss << "," + << ck_tile::getConvSpecializationString(ConvSpecialization); // 2. ConvSpecialization + oss << "," << detail::layout_name(); // 3. InLayout + oss << "," << detail::layout_name(); // 4. WeiLayout + oss << "," << detail::tuple_name(); // 5. DsLayout + oss << "," << detail::layout_name(); // 6. OutLayout + oss << "," << kVectorSizeA; // 7. VectorSizeA + oss << "," << kVectorSizeB; // 8. VectorSizeB + oss << "," << kVectorSizeC; // 9. VectorSizeC + oss << "," << kNumGroupsToMerge; // 10. NumGroupsToMerge + oss << "," << kEnableSplitImage; // 11. EnableSplitImage + oss << "," << kMPerBlock; // 12. MPerBlock + oss << "," << kNPerBlock; // 13. NPerBlock + oss << "," << kKPerBlock; // 14. KPerBlock + oss << "," << kMWarp; // 15. MWarp + oss << "," << kNWarp; // 16. NWarp + oss << "," << kKWarp; // 17. KWarp + oss << "," << kMWarpTile; // 18. 
MWarpTile + oss << "," << kNWarpTile; // 19. NWarpTile + oss << "," << kKWarpTile; // 20. KWarpTile + oss << "," << detail::type_name(); // 21. ADataType + oss << "," << detail::type_name(); // 22. BDataType + oss << "," << GemmPipeline::GetPipelineName(); // 23. BlkGemmPipelineVer + oss << "," << detail::pipeline_scheduler_name(kPipelineScheduler); // 24. BlkGemmPipeSched + oss << "," << kDoubleSmemBuffer; // 25. DoubleSmemBuffer + oss << "," << kNumWaveGroups; // 26. NumWaveGroups + oss << "," << detail::type_name(); // 27. AccDataType + oss << "," << detail::type_name(); // 28. EDataType + oss << "," << detail::tuple_name(); // 29. DsDataType + oss << "," + << detail::elementwise_op_name(); // 30. + // CDEElementwiseOperation + oss << ">"; + + return oss.str(); + } +}; + +} // namespace reflect +} // namespace ck_tile diff --git a/experimental/builder/include/ck_tile/builder/reflect/instance_traits_tile_grouped_convolution_backward_weight.hpp b/experimental/builder/include/ck_tile/builder/reflect/instance_traits_tile_grouped_convolution_backward_weight.hpp new file mode 100644 index 0000000000..f856a48e59 --- /dev/null +++ b/experimental/builder/include/ck_tile/builder/reflect/instance_traits_tile_grouped_convolution_backward_weight.hpp @@ -0,0 +1,140 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT +// InstanceTraits specialization for GroupedConvolutionBackwardWeightKernel +// +// CRITICAL MAINTENANCE NOTE: +// This InstanceTraits file MUST be kept strictly in sync with the device implementation header: +// ck_tile/ops/grouped_convolution/kernel/grouped_convolution_backward_weight_kernel.hpp +// "In sync" means that the template parameter order, names, and types in the declaration below +// MUST EXACTLY MATCH those in the device implementation. If these diverge, you may encounter +// compilation errors, subtle template instantiation mismatches, or silent runtime bugs that are +// difficult to diagnose. 
Always update both files together and review changes carefully. + +#pragma once + +#include "instance_traits.hpp" +#include "instance_traits_util.hpp" + +// Forward declaration of the backward weight kernel to avoid circular dependency. +namespace ck_tile { + +template +struct GroupedConvolutionBackwardWeightKernel; + +} // namespace ck_tile + +namespace ck_tile { +namespace reflect { + +// Specialization for GroupedConvolutionBackwardWeightKernel +template +struct InstanceTraits> +{ + // CK Tile Conv Traits + // Spatial dimension + static constexpr int kSpatialDim = GroupedConvTraitsType_::NDimSpatial; + // Specialization + static constexpr ck_tile::ConvolutionSpecialization ConvSpecialization = + GroupedConvTraitsType_::ConvSpecialization; + // Layout types + using InLayout = typename GroupedConvTraitsType_::InLayout; + using WeiLayout = typename GroupedConvTraitsType_::WeiLayout; + using DsLayout = typename GroupedConvTraitsType_::DsLayout; + using OutLayout = typename GroupedConvTraitsType_::OutLayout; + // Vector size + static constexpr int kVectorSizeA = GroupedConvTraitsType_::VectorSizeA; + static constexpr int kVectorSizeB = GroupedConvTraitsType_::VectorSizeB; + static constexpr int kVectorSizeC = GroupedConvTraitsType_::VectorSizeC; + // Num Groups To Merge + static constexpr int kNumGroupsToMerge = GroupedConvTraitsType_::NumGroupsToMerge; + // Split image (large tensors) + static constexpr bool kEnableSplitImage = GroupedConvTraitsType_::EnableSplitImage; + + // TilePartitioner + // Block configuration + static constexpr int kMPerBlock = TilePartitioner_::MPerBlock; + static constexpr int kNPerBlock = TilePartitioner_::NPerBlock; + static constexpr int kKPerBlock = TilePartitioner_::KPerBlock; + + static constexpr int kMWarp = TilePartitioner_::BlockGemmShape::BlockWarps::at(number<0>{}); + static constexpr int kNWarp = TilePartitioner_::BlockGemmShape::BlockWarps::at(number<1>{}); + static constexpr int kKWarp = 
TilePartitioner_::BlockGemmShape::BlockWarps::at(number<2>{}); + + static constexpr int kMWarpTile = TilePartitioner_::BlockGemmShape::WarpTile::at(number<0>{}); + static constexpr int kNWarpTile = TilePartitioner_::BlockGemmShape::WarpTile::at(number<1>{}); + static constexpr int kKWarpTile = TilePartitioner_::BlockGemmShape::WarpTile::at(number<2>{}); + + // Data types + using ADataType = typename GemmPipeline_::ADataType; + using BDataType = typename GemmPipeline_::BDataType; + // Gemm Pipeline + using GemmPipeline = GemmPipeline_; + static constexpr ck_tile::GemmPipelineScheduler kPipelineScheduler = GemmPipeline_::Scheduler; + static constexpr bool kDoubleSmemBuffer = GemmPipeline_::DoubleSmemBuffer; + static constexpr int kNumWaveGroups = GemmPipeline_::NumWaveGroups; + + // Epilogue Pipeline + using AccDataType = typename EpiloguePipeline_::AccDataType; + using EDataType = typename EpiloguePipeline_::ODataType; + using DsDataType = typename EpiloguePipeline_::DsDataType; + using CDEElementwiseOperation = typename EpiloguePipeline_::CDElementwise; + + // Static member function to generate instance string + static std::string instance_string() + { + std::ostringstream oss; + + // Kernel type name + oss << "GroupedConvolutionBackwardWeightKernel"; + + // Template parameters in exact order matching InstanceTraits member order + oss << "<" << kSpatialDim; // 1. NDimSpatial + oss << "," + << ck_tile::getConvSpecializationString(ConvSpecialization); // 2. ConvSpecialization + oss << "," << detail::layout_name(); // 3. InLayout + oss << "," << detail::layout_name(); // 4. WeiLayout + oss << "," << detail::tuple_name(); // 5. DsLayout + oss << "," << detail::layout_name(); // 6. OutLayout + oss << "," << kVectorSizeA; // 7. VectorSizeA + oss << "," << kVectorSizeB; // 8. VectorSizeB + oss << "," << kVectorSizeC; // 9. VectorSizeC + oss << "," << kNumGroupsToMerge; // 10. NumGroupsToMerge + oss << "," << kEnableSplitImage; // 11. 
EnableSplitImage + oss << "," << kMPerBlock; // 12. MPerBlock + oss << "," << kNPerBlock; // 13. NPerBlock + oss << "," << kKPerBlock; // 14. KPerBlock + oss << "," << kMWarp; // 15. MWarp + oss << "," << kNWarp; // 16. NWarp + oss << "," << kKWarp; // 17. KWarp + oss << "," << kMWarpTile; // 18. MWarpTile + oss << "," << kNWarpTile; // 19. NWarpTile + oss << "," << kKWarpTile; // 20. KWarpTile + oss << "," << detail::type_name(); // 21. ADataType + oss << "," << detail::type_name(); // 22. BDataType + oss << "," << GemmPipeline::GetPipelineName(); // 23. BlkGemmPipelineVer + oss << "," << detail::pipeline_scheduler_name(kPipelineScheduler); // 24. BlkGemmPipeSched + oss << "," << kDoubleSmemBuffer; // 25. DoubleSmemBuffer + oss << "," << kNumWaveGroups; // 26. NumWaveGroups + oss << "," << detail::type_name(); // 27. AccDataType + oss << "," << detail::type_name(); // 28. EDataType + oss << "," << detail::tuple_name(); // 29. DsDataType + oss << "," + << detail::elementwise_op_name(); // 30. + // CDEElementwiseOperation + oss << ">"; + + return oss.str(); + } +}; + +} // namespace reflect +} // namespace ck_tile diff --git a/experimental/builder/include/ck_tile/builder/reflect/instance_traits_tile_grouped_convolution_forward.hpp b/experimental/builder/include/ck_tile/builder/reflect/instance_traits_tile_grouped_convolution_forward.hpp index e488b714dd..c42a4f44dd 100644 --- a/experimental/builder/include/ck_tile/builder/reflect/instance_traits_tile_grouped_convolution_forward.hpp +++ b/experimental/builder/include/ck_tile/builder/reflect/instance_traits_tile_grouped_convolution_forward.hpp @@ -16,7 +16,7 @@ #include "instance_traits_util.hpp" // Forward declaration to avoid circular dependency. 
-namespace ck_tile::device { +namespace ck_tile { template struct GroupedConvolutionForwardKernel; -} // namespace ck_tile::device +} // namespace ck_tile namespace ck_tile { namespace reflect { @@ -34,10 +34,10 @@ template -struct InstanceTraits> +struct InstanceTraits> { // CK Tile Conv Traits // Spatial dimension @@ -122,7 +122,7 @@ struct InstanceTraits(); // 22. BDataType oss << "," << GemmPipeline::GetPipelineName(); // 23. BlkGemmPipelineVer oss << "," << detail::pipeline_scheduler_name(kPipelineScheduler); // 24. BlkGemmPipeSched - oss << "," << kDoubleSmemBuffer; // 25. NumWaveGroups + oss << "," << kDoubleSmemBuffer; // 25. DoubleSmemBuffer oss << "," << kNumWaveGroups; // 26. NumWaveGroups oss << "," << detail::type_name(); // 27. AccDataType oss << "," << detail::type_name(); // 28. EDataType diff --git a/experimental/builder/test/CMakeLists.txt b/experimental/builder/test/CMakeLists.txt index 6ea06e4575..1089befe51 100644 --- a/experimental/builder/test/CMakeLists.txt +++ b/experimental/builder/test/CMakeLists.txt @@ -21,6 +21,7 @@ add_ck_builder_test(test_ckb_conv_builder test_conv_builder.cpp test_fwd_instance_traits.cpp test_bwd_weight_instance_traits.cpp + test_bwd_data_instance_traits.cpp test_instance_traits_util.cpp) add_ck_builder_test(test_ckb_inline_diff test_inline_diff.cpp) diff --git a/experimental/builder/test/test_bwd_data_instance_traits.cpp b/experimental/builder/test/test_bwd_data_instance_traits.cpp new file mode 100644 index 0000000000..d6d4749db7 --- /dev/null +++ b/experimental/builder/test/test_bwd_data_instance_traits.cpp @@ -0,0 +1,133 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +#include +#include +#include +#include + +namespace { + +TEST(InstanceTraits, BwdDataTileInstanceStringReturnsCorrectFormat) +{ + using GroupedConvTraitsType = + ck_tile::GroupedConvTraits<2 /*NDimSpatial*/, + ck_tile::ConvolutionSpecialization::Default /*ConvSpec*/, + ck_tile::tensor_layout::convolution::NHWGC /*InLayout*/, + ck_tile::tensor_layout::convolution::GKYXC /*WeiLayout*/, + ck_tile::tuple<> /*DsLayout*/, + ck_tile::tensor_layout::convolution::NHWGK /*OutLayout*/, + 4 /*VectorSizeA*/, + 4 /*VectorSizeB*/, + 4 /*VectorSizeC*/, + 1 /*NumGroupsToMerge*/, + false /*EnableSplitImage*/>; + + using GemmShape = ck_tile::TileGemmShape< + ck_tile::sequence<128 /*M_Tile*/, 128 /*N_Tile*/, 32 /*K_Tile*/>, + ck_tile::sequence<4 /*M_Warp*/, 1 /*N_Warp*/, 1 /*K_Warp*/>, + ck_tile::sequence<16 /*M_Warp_Tile*/, 16 /*N_Warp_Tile*/, 16 /*K_Warp_Tile*/>>; + + using TilePartitioner = ck_tile::GemmSpatiallyLocalTilePartitioner< + GemmShape, + GroupedConvTraitsType::FixedGemmParams::TilePartitionerGroupNum, + GroupedConvTraitsType::FixedGemmParams::TilePartitionerM01>; + + using GemmUniversalTraits = ck_tile::TileGemmUniversalTraits< + GroupedConvTraitsType::FixedGemmParams::kPadM, + GroupedConvTraitsType::FixedGemmParams::kPadN, + GroupedConvTraitsType::FixedGemmParams::kPadK, + false /*DoubleSmemBuffer*/, + typename GroupedConvTraitsType::AsLayoutBwdData, + typename GroupedConvTraitsType::BsLayoutBwdData, + typename GroupedConvTraitsType::CLayoutBwdData, + GroupedConvTraitsType::FixedGemmParams::TransposeC, + GroupedConvTraitsType::FixedGemmParams::UseStructuredSparsity, + GroupedConvTraitsType::FixedGemmParams::Persistent, + 1 /*NumWaveGroups*/>; + + using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem< + ck_tile::bf16_t /*OutDataType*/, + ck_tile::bf16_t /*WeiDataType*/, + float /*AccDataType*/, + GemmShape, + GemmUniversalTraits, + ck_tile::GemmPipelineScheduler::Intrawave /*scheduler*/, + true /*has_hot_loop_v*/, + 
ck_tile::TailNumber::Full /*tail_number_v*/, + ck_tile::element_wise::PassThrough /*AElementwiseOperation*/, + ck_tile::element_wise::PassThrough /*BElementwiseOperation*/, + ck_tile::bf16_t /*InDataType*/, + GroupedConvTraitsType::FixedGemmParams::FixedVectorSize, + GroupedConvTraitsType::VectorSizeA, + GroupedConvTraitsType::VectorSizeB>; + + using GemmPipeline = typename ck_tile::GemmPipelineAgBgCrCompV3; + + using ConvEpilogue = ck_tile::CShuffleEpilogue< + ck_tile::CShuffleEpilogueProblem /*DsDataType*/, + float /*AccDataType*/, + ck_tile::bf16_t /*InDataType*/, + typename GroupedConvTraitsType::ImplicitGemmDsLayout, + typename GroupedConvTraitsType::FixedGemmParams::ELayout, + ck_tile::element_wise::PassThrough /*CDElementWise*/, + 128 /*MPerBlock*/, + 128 /*NPerBlock*/, + 4 /*M_Warp*/, + 1 /*N_Warp*/, + 16 /*M_Warp_Tile*/, + 16 /*N_Warp_Tile*/, + 16 /*K_Warp_Tile*/, + GroupedConvTraitsType::FixedGemmParams::TransposeC, + ck_tile::memory_operation_enum::set /*memory_operation*/, + 1 /*kNumWaveGroups*/, + GroupedConvTraitsType::FixedGemmParams::FixedVectorSize, + GroupedConvTraitsType::VectorSizeC>>; + + using GroupedConvBwdDataKernel = + ck_tile::GroupedConvolutionBackwardDataKernel; + + std::string instance_str = ck_tile::reflect::instance_string(); + + std::string expected_str = "GroupedConvolutionBackwardDataKernel" + "<2" // NDimSpatial + ",Default" // ConvSpecialization + ",NHWGC" // InLayout + ",GKYXC" // WeiLayout + ",EmptyTuple" // DsLayout + ",NHWGK" // OutLayout + ",4" // VectorSizeA + ",4" // VectorSizeB + ",4" // VectorSizeC + ",1" // NumGroupsToMerge + ",0" // EnableSplitImage + ",128" // MPerBlock + ",128" // NPerBlock + ",32" // KPerBlock + ",4" // MWarp + ",1" // NWarp + ",1" // KWarp + ",16" // MWarpTile + ",16" // NWarpTile + ",16" // KWarpTile + ",bf16" // ADataType + ",bf16" // BDataType + ",COMPUTE_V3" // BlkGemmPipelineVer + ",Intrawave" // BlkGemmPipeSched + ",0" // DoubleSmemBuffer + ",1" // NumWaveGroups + ",fp32" // AccDataType + 
",bf16" // EDataType + ",EmptyTuple" // DsDataType + ",PassThrough" // CDEElementwiseOperation + ">"; + + EXPECT_EQ(instance_str, expected_str); +} + +} // anonymous namespace diff --git a/experimental/builder/test/test_bwd_weight_instance_traits.cpp b/experimental/builder/test/test_bwd_weight_instance_traits.cpp index e1b89a6d49..a6aee7b210 100644 --- a/experimental/builder/test/test_bwd_weight_instance_traits.cpp +++ b/experimental/builder/test/test_bwd_weight_instance_traits.cpp @@ -5,6 +5,7 @@ #include #include #include +#include namespace { @@ -109,4 +110,126 @@ TEST(InstanceTraits, BwdWeightXdlCShuffleInstanceStringReturnsCorrectFormat) EXPECT_EQ(instance_str, expected_str); } +TEST(InstanceTraits, BwdWeightTileInstanceStringReturnsCorrectFormat) +{ + using GroupedConvTraitsType = + ck_tile::GroupedConvTraits<2 /*NDimSpatial*/, + ck_tile::ConvolutionSpecialization::Default /*ConvSpec*/, + ck_tile::tensor_layout::convolution::NHWGC /*InLayout*/, + ck_tile::tensor_layout::convolution::GKYXC /*WeiLayout*/, + ck_tile::tuple<> /*DsLayout*/, + ck_tile::tensor_layout::convolution::NHWGK /*OutLayout*/, + 4 /*VectorSizeA*/, + 4 /*VectorSizeB*/, + 4 /*VectorSizeC*/, + 1 /*NumGroupsToMerge*/, + false /*EnableSplitImage*/>; + + using GemmShape = ck_tile::TileGemmShape< + ck_tile::sequence<128 /*M_Tile*/, 128 /*N_Tile*/, 32 /*K_Tile*/>, + ck_tile::sequence<4 /*M_Warp*/, 1 /*N_Warp*/, 1 /*K_Warp*/>, + ck_tile::sequence<16 /*M_Warp_Tile*/, 16 /*N_Warp_Tile*/, 16 /*K_Warp_Tile*/>>; + + using TilePartitioner = ck_tile::GemmSpatiallyLocalTilePartitioner< + GemmShape, + GroupedConvTraitsType::FixedGemmParams::TilePartitionerGroupNum, + GroupedConvTraitsType::FixedGemmParams::TilePartitionerM01>; + + using GemmUniversalTraits = ck_tile::TileGemmUniversalTraits< + GroupedConvTraitsType::FixedGemmParams::kPadM, + GroupedConvTraitsType::FixedGemmParams::kPadN, + GroupedConvTraitsType::FixedGemmParams::kPadK, + false /*DoubleSmemBuffer*/, + typename GroupedConvTraitsType::AsLayoutBwdWeight, 
+ typename GroupedConvTraitsType::BsLayoutBwdWeight, + typename GroupedConvTraitsType::CLayoutBwdWeight, + GroupedConvTraitsType::FixedGemmParams::TransposeC, + GroupedConvTraitsType::FixedGemmParams::UseStructuredSparsity, + GroupedConvTraitsType::FixedGemmParams::Persistent, + 1 /*NumWaveGroups*/>; + + using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem< + ck_tile::bf16_t /*OutDataType*/, + ck_tile::bf16_t /*InDataType*/, + float /*AccDataType*/, + GemmShape, + GemmUniversalTraits, + ck_tile::GemmPipelineScheduler::Intrawave /*scheduler*/, + true /*has_hot_loop_v*/, + ck_tile::TailNumber::Full /*tail_number_v*/, + ck_tile::element_wise::PassThrough /*AElementwiseOperation*/, + ck_tile::element_wise::PassThrough /*BElementwiseOperation*/, + ck_tile::bf16_t /*WeiDataType*/, + GroupedConvTraitsType::FixedGemmParams::FixedVectorSize, + GroupedConvTraitsType::VectorSizeA, + GroupedConvTraitsType::VectorSizeB>; + + using GemmPipeline = typename ck_tile::GemmPipelineAgBgCrCompV3; + + using ConvEpilogue = ck_tile::CShuffleEpilogue< + ck_tile::CShuffleEpilogueProblem /*DsDataType*/, + float /*AccDataType*/, + ck_tile::bf16_t /*WeiDataType*/, + typename GroupedConvTraitsType::ImplicitGemmDsLayout, + typename GroupedConvTraitsType::FixedGemmParams::ELayout, + ck_tile::element_wise::PassThrough /*CDElementWise*/, + 128 /*MPerBlock*/, + 128 /*NPerBlock*/, + 4 /*M_Warp*/, + 1 /*N_Warp*/, + 16 /*M_Warp_Tile*/, + 16 /*N_Warp_Tile*/, + 16 /*K_Warp_Tile*/, + GroupedConvTraitsType::FixedGemmParams::TransposeC, + ck_tile::memory_operation_enum::set /*memory_operation*/, + 1 /*kNumWaveGroups*/, + GroupedConvTraitsType::FixedGemmParams::FixedVectorSize, + GroupedConvTraitsType::VectorSizeC>>; + + using GroupedConvBwdWeiKernel = + ck_tile::GroupedConvolutionBackwardWeightKernel; + + std::string instance_str = ck_tile::reflect::instance_string(); + + std::string expected_str = "GroupedConvolutionBackwardWeightKernel" + "<2" // NDimSpatial + ",Default" // 
ConvSpecialization + ",NHWGC" // InLayout + ",GKYXC" // WeiLayout + ",EmptyTuple" // DsLayout + ",NHWGK" // OutLayout + ",4" // VectorSizeA + ",4" // VectorSizeB + ",4" // VectorSizeC + ",1" // NumGroupsToMerge + ",0" // EnableSplitImage + ",128" // MPerBlock + ",128" // NPerBlock + ",32" // KPerBlock + ",4" // MWarp + ",1" // NWarp + ",1" // KWarp + ",16" // MWarpTile + ",16" // NWarpTile + ",16" // KWarpTile + ",bf16" // ADataType + ",bf16" // BDataType + ",COMPUTE_V3" // BlkGemmPipelineVer + ",Intrawave" // BlkGemmPipeSched + ",0" // DoubleSmemBuffer + ",1" // NumWaveGroups + ",fp32" // AccDataType + ",bf16" // EDataType + ",EmptyTuple" // DsDataType + ",PassThrough" // CDEElementwiseOperation + ">"; + + EXPECT_EQ(instance_str, expected_str); +} + } // anonymous namespace diff --git a/experimental/builder/test/test_fwd_instance_traits.cpp b/experimental/builder/test/test_fwd_instance_traits.cpp index c414da7458..1203686f6c 100644 --- a/experimental/builder/test/test_fwd_instance_traits.cpp +++ b/experimental/builder/test/test_fwd_instance_traits.cpp @@ -799,11 +799,10 @@ TEST(InstanceTraits, TileInstanceStringReturnsCorrectFormat) GroupedConvTraitsType::FixedGemmParams::FixedVectorSize, GroupedConvTraitsType::VectorSizeC>>; - using GroupedConvFwdKernel = - ck_tile::device::GroupedConvolutionForwardKernel; + using GroupedConvFwdKernel = ck_tile::GroupedConvolutionForwardKernel; std::string instance_str = ck_tile::reflect::instance_string(); diff --git a/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_backward_data_kernel.hpp b/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_backward_data_kernel.hpp index b1ed80b5ea..86f5684e73 100644 --- a/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_backward_data_kernel.hpp +++ b/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_backward_data_kernel.hpp @@ -14,6 +14,10 @@ #include 
"ck_tile/ops/grouped_convolution/utils/transform_conv_bwd_data_to_gemm.hpp" #include "ck_tile/ops/grouped_convolution/utils/grouped_convolution_utils.hpp" +#ifdef CK_EXPERIMENTAL_BUILDER +#include "ck_tile/builder/reflect/instance_traits_tile_grouped_convolution_backward_data.hpp" +#endif + namespace ck_tile { /// @brief The Grouped Convolution kernel device arguments. @@ -565,6 +569,19 @@ struct GroupedConvolutionBackwardDataKernel // clang-format on } +#ifdef CK_EXPERIMENTAL_BUILDER + CK_TILE_HOST std::string GetInstanceString() const + { + static_assert(ck_tile::reflect::HasInstanceTraits, + "Specialization of instance_traits not found. Please check that a " + "specialization exists in file " + "ck_tile/builder/reflect/" + "instance_traits_tile_grouped_convolution_backward_data.hpp " + "for the given template parameters."); + return ck_tile::reflect::instance_string(); + } +#endif + CK_TILE_HOST static auto GridSize(const GroupedConvBwdDataKernelArgsSpecialized& kargs) { // enable batched grouped gemm diff --git a/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_backward_weight_kernel.hpp b/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_backward_weight_kernel.hpp index 3407c67ad1..0143afae7a 100644 --- a/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_backward_weight_kernel.hpp +++ b/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_backward_weight_kernel.hpp @@ -14,6 +14,10 @@ #include "ck_tile/ops/grouped_convolution/utils/transform_conv_bwd_weight_to_gemm.hpp" #include "ck_tile/ops/grouped_convolution/utils/grouped_convolution_utils.hpp" +#ifdef CK_EXPERIMENTAL_BUILDER +#include "ck_tile/builder/reflect/instance_traits_tile_grouped_convolution_backward_weight.hpp" +#endif + namespace ck_tile { /// @brief The Grouped Convolution kernel device arguments. 
@@ -430,6 +434,19 @@ struct GroupedConvolutionBackwardWeightKernel // clang-format on } +#ifdef CK_EXPERIMENTAL_BUILDER + CK_TILE_HOST std::string GetInstanceString() const + { + static_assert(ck_tile::reflect::HasInstanceTraits, + "Specialization of instance_traits not found. Please check that a " + "specialization exists in file " + "ck_tile/builder/reflect/" + "instance_traits_tile_grouped_convolution_backward_weight.hpp " + "for the given template parameters."); + return ck_tile::reflect::instance_string(); + } +#endif + CK_TILE_HOST static constexpr auto GridSize(const GroupedConvBwdWeightKernelArgsSpecialized& kargs) {