diff --git a/experimental/builder/include/ck_tile/builder/factory/conv_algorithms.hpp b/experimental/builder/include/ck_tile/builder/factory/conv_algorithms.hpp index 84f8b688ad..a4cbe55eeb 100644 --- a/experimental/builder/include/ck_tile/builder/factory/conv_algorithms.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/conv_algorithms.hpp @@ -76,6 +76,13 @@ concept FwdXdlV3Algorithm = SpecifiesGridwiseFwdXdlGemm && SpecifiesFwdConvSpecialization && SpecifiesGemmSpecialization && SpecifiesBlockGemm && SpecifiesNumGroupsToMerge; +// FWD WMMA V3 algorithm concept +template +concept FwdWmmaV3Algorithm = + ConvAlgorithmDescriptor && SpecifiesThreadBlock && SpecifiesTileTransferParameters3D && + SpecifiesGridwiseWmmaGemm && SpecifiesFwdConvSpecialization && + SpecifiesGemmSpecialization && SpecifiesBlockGemm && SpecifiesNumGroupsToMerge; + // FWD WMMA algorithm concepts template concept FwdWmmaAlgorithm = diff --git a/experimental/builder/include/ck_tile/builder/factory/conv_dispatcher.hpp b/experimental/builder/include/ck_tile/builder/factory/conv_dispatcher.hpp index 857bc4b7c2..a5d9844419 100644 --- a/experimental/builder/include/ck_tile/builder/factory/conv_dispatcher.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/conv_dispatcher.hpp @@ -64,6 +64,7 @@ #include "ck_tile/builder/factory/conv_fwd_v3_factory.hpp" #include "ck_tile/builder/factory/conv_fwd_xdl_factory.hpp" #include "ck_tile/builder/factory/conv_fwd_wmma_factory.hpp" +#include "ck_tile/builder/factory/conv_fwd_wmma_v3_factory.hpp" #include "ck_tile/builder/factory/conv_fwd_dl_factory.hpp" #include "ck_tile/builder/factory/conv_fwd_large_tensor_factory.hpp" #include "ck_tile/builder/factory/reference_factory.hpp" @@ -130,6 +131,10 @@ constexpr auto make_conv_instance() { return typename ConvFwdXdlFactory::Instance{}; } + else if constexpr(FwdWmmaV3Algorithm) + { + return typename ConvFwdWmmaV3Factory::Instance{}; + } else if constexpr(FwdWmmaAlgorithm) { return typename ConvFwdWmmaFactory::Instance{}; diff --git a/experimental/builder/include/ck_tile/builder/factory/conv_fwd_wmma_v3_factory.hpp b/experimental/builder/include/ck_tile/builder/factory/conv_fwd_wmma_v3_factory.hpp new file mode 100644 index 0000000000..c5597f3c7d --- /dev/null +++ b/experimental/builder/include/ck_tile/builder/factory/conv_fwd_wmma_v3_factory.hpp @@ -0,0 +1,159 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_wmma_cshuffle_v3.hpp" +#include "ck_tile/builder/conv_signature_concepts.hpp" +#include "ck_tile/builder/conv_algorithm_concepts.hpp" +#include "ck_tile/builder/conv_algorithm_limits.hpp" +#include "ck_tile/builder/builder_utils.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_tensor_layout.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_tensor_type.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_elementwise_op.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_tuning_params.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_block_transfer.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_thread_block.hpp" + +namespace ck_tile::builder::factory { + +// Factory for DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3 instance +// of a grouped forward convolution kernel. +template + requires ConvDirectionIsForward +struct ConvFwdWmmaV3Factory +{ + static constexpr size_t SPATIAL_DIM = SIGNATURE.spatial_dim; + using Layouts = internal::ConvTensorLayouts; + using Types = internal::ConvTensorDataTypes; + using Ops = internal::ConvElementwiseOps; + using AlgorithmType = decltype(ALGORITHM); + + static_assert(ALGORITHM.transfer.a.lds_transfer.is_direct_load == + ALGORITHM.transfer.b.lds_transfer.is_direct_load, + "A and B block transfers must both be direct load or not."); + + static constexpr auto FWD_CONV_SPECIALIZATION = internal::SetFwdConvSpecialization(); + static constexpr auto GEMM_SPECIALIZATION = internal::SetGemmSpecialization(); + static constexpr internal::ConvSpec SPECIALIZATION{.conv_spec = FWD_CONV_SPECIALIZATION, + .gemm_spec = GEMM_SPECIALIZATION}; + + static constexpr auto BLOCK = internal::SetThreadBlockInfo(); + static constexpr auto GRIDWISE_GEMM = ALGORITHM.gridwise_gemm; + static constexpr auto A_BLOCK_TRANSFER = + internal::SetFwdConvBlockTransfer(); + static constexpr auto B_BLOCK_TRANSFER = + internal::SetFwdConvBlockTransfer(); + static constexpr auto C_BLOCK_TRANSFER = internal::SetCBlockTransfer(); + static constexpr auto BLOCK_GEMM = internal::SetBlockGemm(); + + // Check limits for the algorithm parameters. + static_assert(ValidABlockTransfer); + static_assert(ValidBBlockTransfer); + static_assert(ValidCBlockTransfer); + + // Layout validations + using enum TensorLayout; + static_assert(IsValidLayout && + A_BLOCK_TRANSFER.src_vector_dim == 2); + + static_assert(IsValidLayout && + B_BLOCK_TRANSFER.src_vector_dim == 2); + + static_assert(IsValidLayout); + + // The forward convolution kernel class instance. + using Instance = ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3< + SPATIAL_DIM, + typename Layouts::InLayout, + typename Layouts::WeiLayout, + typename Layouts::DsLayout, + typename Layouts::OutLayout, + typename Types::InDataType, + typename Types::WeiDataType, + typename Types::AccDataType, + typename Types::OutComputeType, + typename Types::DsDataType, + typename Types::OutDataType, + typename Ops::InElementwiseOp, + typename Ops::WeiElementwiseOp, + typename Ops::OutElementwiseOp, + SPECIALIZATION.conv_spec, + SPECIALIZATION.gemm_spec, + BLOCK.block_size, + BLOCK.per_block.m, + BLOCK.per_block.n, + BLOCK.per_block.k, + GRIDWISE_GEMM.ak1, + GRIDWISE_GEMM.bk1, + GRIDWISE_GEMM.m_per_wmma, + GRIDWISE_GEMM.n_per_wmma, + GRIDWISE_GEMM.m_wmma_per_wave, + GRIDWISE_GEMM.n_wmma_per_wave, + to_sequence_v, + to_sequence_v, + to_sequence_v, + A_BLOCK_TRANSFER.src_vector_dim, + A_BLOCK_TRANSFER.src_scalar_per_vector, + A_BLOCK_TRANSFER.lds_dst_scalar_per_vector, + static_cast(A_BLOCK_TRANSFER.lds_padding), + to_sequence_v, + to_sequence_v, + to_sequence_v, + B_BLOCK_TRANSFER.src_vector_dim, + B_BLOCK_TRANSFER.src_scalar_per_vector, + B_BLOCK_TRANSFER.lds_dst_scalar_per_vector, + static_cast(B_BLOCK_TRANSFER.lds_padding), + C_BLOCK_TRANSFER.m_xdl_per_wave_per_shuffle, + C_BLOCK_TRANSFER.n_xdl_per_wave_per_shuffle, + to_sequence_v, + C_BLOCK_TRANSFER.scalar_per_vector, + BLOCK_GEMM.scheduler, + BLOCK_GEMM.pipeline_version, + true, // UseThreadTileTransfer + typename Types::InComputeType, + typename Types::WeiComputeType, + ALGORITHM.num_conv_groups_to_merge>; +}; + +} // namespace ck_tile::builder::factory diff --git a/experimental/builder/include/ck_tile/builder/reflect/conv_traits_device_grouped_conv_fwd_multiple_abd_wmma_cshuffle_v3.hpp b/experimental/builder/include/ck_tile/builder/reflect/conv_traits_device_grouped_conv_fwd_multiple_abd_wmma_cshuffle_v3.hpp new file mode 100644 index 0000000000..865686666b --- /dev/null +++ b/experimental/builder/include/ck_tile/builder/reflect/conv_traits_device_grouped_conv_fwd_multiple_abd_wmma_cshuffle_v3.hpp @@ -0,0 +1,48 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#include + +#include "ck_tile/builder/reflect/conv_traits.hpp" +#include "ck_tile/builder/reflect/conv_traits_helpers.hpp" +#include "ck_tile/builder/reflect/instance_traits.hpp" +#include "ck_tile/builder/reflect/instance_traits_device_grouped_conv_fwd_multiple_abd_wmma_cshuffle_v3.hpp" + +namespace ck_tile::reflect::conv { + +/// @brief Tag dispatch implementation for DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3 +template + requires HasInstanceTraits && + std::same_as::device_kernel_tag, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3_Tag> +constexpr ConvTraits instance_to_conv_traits() +{ + using InstTraits = InstanceTraits; + + return ConvTraits{ + .spatial_dim = InstTraits::kSpatialDim, + .direction = conv_direction(), + .layout = fwd_conv_layout(), + .data_type = conv_data_type(), + .input_element_op = elementwise_op(), + .weight_element_op = elementwise_op(), + .output_element_op = elementwise_op(), + .gemm_padding = gemm_spec(), + .conv_specialization = conv_spec(), + .thread_block_size = InstTraits::kBlockSize, + .tile_dims = conv_traits_data_tile(), + .a_tile_transfer = conv_traits_a_transfer_params(InstTraits::kAK1), + .b_tile_transfer = conv_traits_b_transfer_params(InstTraits::kBK1), + .warp_gemm = conv_traits_wmma_warp_gemm_params(), + .c_tile_transfer = conv_traits_wmma_c_tile_transfer(), + // TODO: Add compute types (AComputeDataType, BComputeDataType) when ConvTraits supports + // them + // TODO: Add NumGroupsToMerge when ConvTraits supports it + .pipeline_version = get_pipeline_version(), + .pipeline_scheduler = get_pipeline_scheduler(), + }; +} + +} // namespace ck_tile::reflect::conv diff --git a/experimental/builder/include/ck_tile/builder/reflect/instance_to_conv_traits.hpp b/experimental/builder/include/ck_tile/builder/reflect/instance_to_conv_traits.hpp index cb4b3b2175..271c826d5c 100644 --- a/experimental/builder/include/ck_tile/builder/reflect/instance_to_conv_traits.hpp +++ b/experimental/builder/include/ck_tile/builder/reflect/instance_to_conv_traits.hpp @@ -8,6 +8,7 @@ #include "ck_tile/builder/reflect/conv_traits_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp" #include "ck_tile/builder/reflect/conv_traits_device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp" #include "ck_tile/builder/reflect/conv_traits_device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp" +#include "ck_tile/builder/reflect/conv_traits_device_grouped_conv_fwd_multiple_abd_wmma_cshuffle_v3.hpp" // Bwd weight instances #include "ck_tile/builder/reflect/conv_traits_device_grouped_conv_bwd_weight_xdl_cshuffle.hpp" diff --git a/experimental/builder/include/ck_tile/builder/reflect/instance_traits_device_grouped_conv_fwd_multiple_abd_wmma_cshuffle_v3.hpp b/experimental/builder/include/ck_tile/builder/reflect/instance_traits_device_grouped_conv_fwd_multiple_abd_wmma_cshuffle_v3.hpp new file mode 100644 index 0000000000..993e8db8b0 --- /dev/null +++ b/experimental/builder/include/ck_tile/builder/reflect/instance_traits_device_grouped_conv_fwd_multiple_abd_wmma_cshuffle_v3.hpp @@ -0,0 +1,16 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#include "instance_traits.hpp" +#include "instance_traits_util.hpp" + +namespace ck_tile::reflect { + +/// @brief Tag type for DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3 device kernel +struct DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3_Tag +{ +}; + +} // namespace ck_tile::reflect diff --git a/experimental/builder/include/ck_tile/builder/reflect/reflect_device_grouped_conv_fwd_multiple_abd_wmma_cshuffle_v3.inc b/experimental/builder/include/ck_tile/builder/reflect/reflect_device_grouped_conv_fwd_multiple_abd_wmma_cshuffle_v3.inc new file mode 100644 index 0000000000..09f8811f3e --- /dev/null +++ b/experimental/builder/include/ck_tile/builder/reflect/reflect_device_grouped_conv_fwd_multiple_abd_wmma_cshuffle_v3.inc @@ -0,0 +1,302 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +// InstanceTraits specialization for DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3 +// +// This .inc file is #included at the bottom of the device op header +// (device_grouped_conv_fwd_multiple_abd_wmma_cshuffle_v3.hpp) under +// #ifdef CK_EXPERIMENTAL_BUILDER, AFTER the struct is fully defined. +// This eliminates the need for forward declarations. +// +// CRITICAL MAINTENANCE NOTE: +// This file MUST be kept strictly in sync with the device implementation header: +// ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_wmma_cshuffle_v3.hpp +// The template parameter order, names, and types MUST EXACTLY MATCH those in the device +// implementation. + +#include "ck_tile/builder/reflect/instance_traits_device_grouped_conv_fwd_multiple_abd_wmma_cshuffle_v3.hpp" + +namespace ck_tile::reflect { + +// Specialization for DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3 +template +struct InstanceTraits< + ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3< + NDimSpatial, + ALayout_, + BLayout_, + DsLayout_, + ELayout_, + ADataType_, + BDataType_, + AccDataType_, + CShuffleDataType_, + DsDataType_, + EDataType_, + AElementwiseOperation_, + BElementwiseOperation_, + CDEElementwiseOperation_, + ConvForwardSpecialization, + GemmSpec, + BlockSize, + MPerBlock, + NPerBlock, + KPerBlock, + AK1, + BK1, + MPerWmma, + NPerWmma, + MRepeat, + NRepeat, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder_, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + ABlockLdsExtraM, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder_, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + BBlockLdsExtraN, + CShuffleMRepeatPerShuffle, + CShuffleNRepeatPerShuffle, + CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + CDEBlockTransferScalarPerVector_NPerBlock, + BlkGemmPipeSched, + BlkGemmPipelineVer, + UseThreadTileTransfer, + AComputeDataType_, + BComputeDataType_, + NumGroupsToMerge>> +{ + /// @brief Tag type identifying this device kernel variant + using device_kernel_tag = DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3_Tag; + + // Spatial dimension + static constexpr int kSpatialDim = NDimSpatial; + + // Layout types + using ALayout = ALayout_; + using BLayout = BLayout_; + using DsLayout = DsLayout_; + using ELayout = ELayout_; + + // Data types + using ADataType = ADataType_; + using BDataType = BDataType_; + using AccDataType = AccDataType_; + using CShuffleDataType = CShuffleDataType_; + using DsDataType = DsDataType_; + using EDataType = EDataType_; + + // Element-wise operations + using AElementwiseOperation = AElementwiseOperation_; + using BElementwiseOperation = BElementwiseOperation_; + using CDEElementwiseOperation = CDEElementwiseOperation_; + + // Specialization + static constexpr ck::tensor_operation::device::ConvolutionForwardSpecialization + kConvForwardSpecialization = ConvForwardSpecialization; + static constexpr ck::tensor_operation::device::GemmSpecialization kGemmSpecialization = + GemmSpec; + + // Block configuration + static constexpr int kBlockSize = BlockSize; + static constexpr int kMPerBlock = MPerBlock; + static constexpr int kNPerBlock = NPerBlock; + static constexpr int kKPerBlock = KPerBlock; + + // Tuning parameters + static constexpr int kAK1 = AK1; + static constexpr int kBK1 = BK1; + static constexpr int kMPerWmma = MPerWmma; + static constexpr int kNPerWmma = NPerWmma; + static constexpr int kMRepeat = MRepeat; + static constexpr int kNRepeat = NRepeat; + + // A block transfer thread cluster dimensions (converted to std::array) + static constexpr auto kAThreadClusterLengths = + detail::SequenceToArray::value; + static constexpr auto kAThreadClusterArrangeOrder = + detail::SequenceToArray::value; + static constexpr auto kABlockTransferSrcAccessOrder = + detail::SequenceToArray::value; + static constexpr int kABlockTransferSrcVectorDim = ABlockTransferSrcVectorDim; + static constexpr int kABlockTransferSrcScalarPerVector = ABlockTransferSrcScalarPerVector; + static constexpr int kABlockTransferDstScalarPerVectorK1 = ABlockTransferDstScalarPerVector_AK1; + static constexpr int kABlockLdsExtraM = ABlockLdsExtraM; + + // B block transfer thread cluster dimensions (converted to std::array) + static constexpr auto kBThreadClusterLengths = + detail::SequenceToArray::value; + static constexpr auto kBThreadClusterArrangeOrder = + detail::SequenceToArray::value; + static constexpr auto kBBlockTransferSrcAccessOrder = + detail::SequenceToArray::value; + static constexpr int kBBlockTransferSrcVectorDim = BBlockTransferSrcVectorDim; + static constexpr int kBBlockTransferSrcScalarPerVector = BBlockTransferSrcScalarPerVector; + static constexpr int kBBlockTransferDstScalarPerVectorK1 = BBlockTransferDstScalarPerVector_BK1; + static constexpr int kBBlockLdsExtraN = BBlockLdsExtraN; + + // C shuffle parameters (converted to std::array) + static constexpr int kCShuffleMRepeatPerShuffle = CShuffleMRepeatPerShuffle; + static constexpr int kCShuffleNRepeatPerShuffle = CShuffleNRepeatPerShuffle; + static constexpr auto kCDEThreadClusterLengths = detail::SequenceToArray< + CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock>::value; + static constexpr int kCDEBlockTransferScalarPerVector = + CDEBlockTransferScalarPerVector_NPerBlock; + + // Pipeline configuration + static constexpr ck::BlockGemmPipelineScheduler kPipelineScheduler = BlkGemmPipeSched; + static constexpr ck::BlockGemmPipelineVersion kPipelineVersion = BlkGemmPipelineVer; + + static constexpr bool kUseThreadTileTransfer = UseThreadTileTransfer; + + // Compute data types + using AComputeDataType = AComputeDataType_; + using BComputeDataType = BComputeDataType_; + + static constexpr int kNumGroupsToMerge = NumGroupsToMerge; + + // Static member function to generate instance string + static std::string instance_string() + { + std::ostringstream oss; + + // Kernel type name + oss << "DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3"; + + // Template parameters in exact order + oss << "<" << kSpatialDim; // 1. NDimSpatial + oss << "," << detail::layout_name(); // 2. ALayout + oss << "," << detail::layout_name(); // 3. BLayout + oss << "," << detail::tuple_name(); // 4. DsLayout + oss << "," << detail::layout_name(); // 5. ELayout + oss << "," << detail::type_or_type_tuple_name(); // 6. ADataType + oss << "," << detail::type_or_type_tuple_name(); // 7. BDataType + oss << "," << detail::type_name(); // 8. AccDataType + oss << "," << detail::type_name(); // 9. CShuffleDataType + oss << "," << detail::tuple_name(); // 10. DsDataType + oss << "," << detail::type_name(); // 11. EDataType + oss << "," + << detail::elementwise_op_name(); // 12. AElementwiseOperation + oss << "," + << detail::elementwise_op_name(); // 13. BElementwiseOperation + oss << "," + << detail::elementwise_op_name(); // 14. + // CDEElementwiseOperation + oss << "," + << detail::conv_fwd_spec_name( + kConvForwardSpecialization); // 15. ConvForwardSpecialization + oss << "," << detail::gemm_spec_name(kGemmSpecialization); // 16. GemmSpec + oss << "," << kBlockSize; // 17. BlockSize + oss << "," << kMPerBlock; // 18. MPerBlock + oss << "," << kNPerBlock; // 19. NPerBlock + oss << "," << kKPerBlock; // 20. KPerBlock + oss << "," << kAK1; // 21. AK1 + oss << "," << kBK1; // 22. BK1 + oss << "," << kMPerWmma; // 23. MPerWmma + oss << "," << kNPerWmma; // 24. NPerWmma + oss << "," << kMRepeat; // 25. MRepeat + oss << "," << kNRepeat; // 26. NRepeat + oss << "," + << detail::array_to_string( + kAThreadClusterLengths); // 27. ABlockTransferThreadClusterLengths + oss << "," + << detail::array_to_string( + kAThreadClusterArrangeOrder); // 28. ABlockTransferThreadClusterArrangeOrder + oss << "," + << detail::array_to_string( + kABlockTransferSrcAccessOrder); // 29. ABlockTransferSrcAccessOrder + oss << "," << kABlockTransferSrcVectorDim; // 30. ABlockTransferSrcVectorDim + oss << "," << kABlockTransferSrcScalarPerVector; // 31. ABlockTransferSrcScalarPerVector + oss << "," + << kABlockTransferDstScalarPerVectorK1; // 32. ABlockTransferDstScalarPerVector_AK1 + oss << "," << (kABlockLdsExtraM ? "true" : "false"); // 33. ABlockLdsExtraM + oss << "," + << detail::array_to_string( + kBThreadClusterLengths); // 34. BBlockTransferThreadClusterLengths + oss << "," + << detail::array_to_string( + kBThreadClusterArrangeOrder); // 35. BBlockTransferThreadClusterArrangeOrder + oss << "," + << detail::array_to_string( + kBBlockTransferSrcAccessOrder); // 36. BBlockTransferSrcAccessOrder + oss << "," << kBBlockTransferSrcVectorDim; // 37. BBlockTransferSrcVectorDim + oss << "," << kBBlockTransferSrcScalarPerVector; // 38. BBlockTransferSrcScalarPerVector + oss << "," + << kBBlockTransferDstScalarPerVectorK1; // 39. BBlockTransferDstScalarPerVector_BK1 + oss << "," << (kBBlockLdsExtraN ? "true" : "false"); // 40. BBlockLdsExtraN + oss << "," << kCShuffleMRepeatPerShuffle; // 41. CShuffleMRepeatPerShuffle + oss << "," << kCShuffleNRepeatPerShuffle; // 42. CShuffleNRepeatPerShuffle + oss << "," + << detail::array_to_string( + kCDEThreadClusterLengths); // 43. CDEBlockTransferClusterLengths + oss << "," + << kCDEBlockTransferScalarPerVector; // 44. CDEBlockTransferScalarPerVector_NPerBlock + oss << "," << detail::pipeline_scheduler_name(kPipelineScheduler); // 45. BlkGemmPipeSched + oss << "," << detail::pipeline_version_name(kPipelineVersion); // 46. BlkGemmPipelineVer + oss << "," << (kUseThreadTileTransfer ? "true" : "false"); // 47. UseThreadTileTransfer + oss << "," << detail::type_name(); // 48. AComputeDataType + oss << "," << detail::type_name(); // 49. BComputeDataType + oss << "," << kNumGroupsToMerge; // 50. NumGroupsToMerge + oss << ">"; + + return oss.str(); + } +}; + +} // namespace ck_tile::reflect diff --git a/experimental/builder/test/CMakeLists.txt b/experimental/builder/test/CMakeLists.txt index 57fc4cc779..c12375eeb3 100644 --- a/experimental/builder/test/CMakeLists.txt +++ b/experimental/builder/test/CMakeLists.txt @@ -146,6 +146,7 @@ set(INSTANCE_STRING_TESTS if (CK_USE_WMMA) list(APPEND INSTANCE_STRING_TESTS + test_instance_string_fwd_grp_conv_wmma_v3.cpp test_instance_string_bwd_weight_grp_conv_wmma_v3.cpp test_instance_string_bwd_weight_grp_conv_multiple_d_wmma_v3.cpp test_instance_string_bwd_weight_grp_conv_two_stage_wmma_v3.cpp @@ -172,6 +173,13 @@ add_ck_builder_test(test_ckb_build_fwd_instances conv/ck/test_ckb_conv_fwd_3d_fp32.cpp conv/ck_tile/test_ckb_conv_fwd_2d_fp16_v3.cpp ) + +if (CK_USE_WMMA) + target_sources(test_ckb_build_fwd_instances PRIVATE + conv/ck/test_ckb_conv_fwd_2d_wmma_v3_fp16.cpp + ) +endif() + target_link_libraries(test_ckb_build_fwd_instances PRIVATE utility) set(BWD_WEIGHT_TESTS diff --git a/experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_wmma_v3_fp16.cpp b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_wmma_v3_fp16.cpp new file mode 100644 index 0000000000..ca84b89b3d --- /dev/null +++ b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_wmma_v3_fp16.cpp @@ -0,0 +1,105 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "utils/ckb_conv_test_configs.hpp" +#include "utils/ckb_conv_test_utils.hpp" +#include "utils/conv_algorithm_type_utils.hpp" +#include "ck_tile/builder/testing/conv/fwd.hpp" +#include "ck_tile/builder/testing/conv/fwd_ck.hpp" +#include "ck_tile/builder/testing/conv/reference.hpp" +#include "ck_tile/host/device_prop.hpp" +#include "testing_utils.hpp" + +namespace ckb = ck_tile::builder; +namespace ckt = ck_tile::builder::test; +namespace cku = ck_tile::builder::test_utils; + +using ck_tile::test::MatchesReference; +using ck_tile::test::SuccessfulRun; + +constexpr auto SIGNATURE = + ckt::ConvSignature{.spatial_dim = 2, + .direction = ckb::ConvDirection::FORWARD, + .data_type = ckb::DataType::FP16, + .accumulation_data_type = ckb::DataType::FP32, + .input = {.config = {.layout = ckb::TensorLayout::GNHWC}}, + .weight = {.config = {.layout = ckb::TensorLayout::GKYXC}}, + .output = {.config = {.layout = ckb::TensorLayout::GNHWK}}}; + +constexpr auto ALGORITHM = cku::ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3{} + .with_thread_block(cku::ThreadBlock_64_64x64x32) + .with_gemm_config(cku::GemmParamsABK1_Wmma_16x16_4x2_per_wave) + .with_transfer(cku::Transfer_4x16x1) + .with_fwd_specializations(ckb::ConvSpecialization::DEFAULT, + ckb::GemmSpecialization::MNKPadding) + .with_block_gemm(cku::BlockGemmDesc_v1_intrawave) + .with_num_conv_groups_to_merge(1); + +using Builder = ckb::ConvBuilder; +using Instance = Builder::Instance; + +using Reference = ckb::ConvBuilder::Instance; + +TEST(Fwd2DFp16_WmmaV3_GNHWC, Create) +{ + const auto expected_transfer_parameters = to_string(ALGORITHM); + cku::run_test({"DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3", + expected_transfer_parameters, + "Default", + "Intrawave", + "v1", + "GNHWC,GKYXC,EmptyTuple,GNHWK", + "PassThrough,PassThrough,PassThrough", + "MNKPadding"}); +} + +TEST(Fwd2DFp16_WmmaV3_GNHWC, Execution) +{ + if(!ck_tile::get_device_name().starts_with("gfx11") && + !ck_tile::get_device_name().starts_with("gfx12")) + { + // Note: WMMA kernel requires gfx11 or gfx12 + GTEST_SKIP() << "unsupported architecture"; + } + + ckt::Args args = { + .lengths = + { + .batch_size = 16, + .groups = 1, + .input_channels = 32, + .output_channels = 48, + .image = + { + .width = 56, + .height = 64, + }, + .filter = + { + .width = 3, + .height = 5, + }, + }, + .filter_strides = {.width = 1, .height = 1}, + .filter_dilation = {.width = 1, .height = 1}, + .input_left_pad = {.width = 0, .height = 0}, + .input_right_pad = {.width = 0, .height = 0}, + .a_elementwise_op = {}, + .b_elementwise_op = {}, + .cde_elementwise_op = {}, + }; + + auto inputs = ckt::alloc_inputs(args); + auto outputs = ckt::alloc_outputs(args); + auto reference = ckt::alloc_outputs(args); + + ckt::init_inputs(args, inputs.get()); + + auto conv = Instance{}; + EXPECT_THAT(ckt::run(conv, args, inputs.get(), outputs.get()), SuccessfulRun()); + + auto ref_conv = Reference{}; + EXPECT_THAT(ckt::run(ref_conv, args, inputs.get(), reference.get()), SuccessfulRun()); + + EXPECT_THAT(outputs.get(), MatchesReference(args, reference.get())); +} diff --git a/experimental/builder/test/impl/conv_algorithm_types.hpp b/experimental/builder/test/impl/conv_algorithm_types.hpp index 59d29b1280..4b99fd8100 100644 --- a/experimental/builder/test/impl/conv_algorithm_types.hpp +++ b/experimental/builder/test/impl/conv_algorithm_types.hpp @@ -632,6 +632,14 @@ using ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 = BlockGemm_, GemmBatchOptions_>; +using ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3 = + ConvAlgorithmTemplate, + ConvSpecializationFwd_, + BlockGemm_, + GemmBatchOptions_>; + using ConvAlgorithm_DeviceGroupedConvFwdMultipleD_Wmma_CShuffle = ConvAlgorithmTemplate, // DsLayout + ck::tensor_layout::convolution::GNHWK, // ELayout + ck::half_t, // ADataType + ck::half_t, // BDataType + float, // AccDataType + ck::half_t, // CShuffleDataType + ck::Tuple<>, // DsDataType + ck::half_t, // EDataType + ck::tensor_operation::element_wise::PassThrough, // AElementwiseOperation + ck::tensor_operation::element_wise::PassThrough, // BElementwiseOperation + ck::tensor_operation::element_wise::PassThrough, // CDEElementwiseOperation + ck::tensor_operation::device::ConvolutionForwardSpecialization:: + Default, // ConvForwardSpec + ck::tensor_operation::device::GemmSpecialization::MNKPadding, // GemmSpec + 64, // BlockSize + 64, // MPerBlock + 64, // NPerBlock + 32, // KPerBlock + 8, // AK1 + 8, // BK1 + 16, // MPerWmma + 16, // NPerWmma + 4, // MRepeat + 2, // NRepeat + ck::Sequence<4, 16, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1 + ck::Sequence<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder + ck::Sequence<1, 0, 2>, // ABlockTransferSrcAccessOrder + 2, // ABlockTransferSrcVectorDim + 1, // ABlockTransferSrcScalarPerVector + 8, // ABlockTransferDstScalarPerVector_AK1 + 1, // ABlockLdsExtraM + ck::Sequence<4, 16, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1 + ck::Sequence<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder + ck::Sequence<1, 0, 2>, // BBlockTransferSrcAccessOrder + 2, // BBlockTransferSrcVectorDim + 1, // BBlockTransferSrcScalarPerVector + 8, // BBlockTransferDstScalarPerVector_BK1 + 1, // BBlockLdsExtraN + 1, // CShuffleMRepeatPerShuffle + 1, // CShuffleNRepeatPerShuffle + ck::Sequence<1, 16, 1, 4>, // CDEBlockTransferClusterLengths + 1, // CDEBlockTransferScalarPerVector_NPerBlock + ck::BlockGemmPipelineScheduler::Intrawave, // BlkGemmPipeSched + ck::BlockGemmPipelineVersion::v1>; // BlkGemmPipelineVer + + // Generate instance string + std::string instance_str = ck_tile::reflect::instance_string(); + + // Expected string with all template parameters + std::string expected_str = "DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3" + "<2" // NDimSpatial + ",GNHWC" // ALayout + ",GKYXC" // BLayout + ",EmptyTuple" // DsLayout + ",GNHWK" // ELayout + ",fp16" // ADataType + ",fp16" // BDataType + ",fp32" // AccDataType + ",fp16" // CShuffleDataType + ",EmptyTuple" // DsDataType + ",fp16" // EDataType + ",PassThrough" // AElementwiseOperation + ",PassThrough" // BElementwiseOperation + ",PassThrough" // CDEElementwiseOperation + ",Default" // ConvForwardSpecialization + ",MNKPadding" // GemmSpec + ",64" // BlockSize + ",64" // MPerBlock + ",64" // NPerBlock + ",32" // KPerBlock + ",8" // AK1 + ",8" // BK1 + ",16" // MPerWmma + ",16" // NPerWmma + ",4" // MRepeat + ",2" // NRepeat + ",Seq(4,16,1)" // ABlockTransferThreadClusterLengths + ",Seq(1,0,2)" // ABlockTransferThreadClusterArrangeOrder + ",Seq(1,0,2)" // ABlockTransferSrcAccessOrder + ",2" // ABlockTransferSrcVectorDim + ",1" // ABlockTransferSrcScalarPerVector + ",8" // ABlockTransferDstScalarPerVector_AK1 + ",true" // ABlockLdsExtraM + ",Seq(4,16,1)" // BBlockTransferThreadClusterLengths + ",Seq(1,0,2)" // BBlockTransferThreadClusterArrangeOrder + ",Seq(1,0,2)" // BBlockTransferSrcAccessOrder + ",2" // BBlockTransferSrcVectorDim + ",1" // BBlockTransferSrcScalarPerVector + ",8" // BBlockTransferDstScalarPerVector_BK1 + ",true" // BBlockLdsExtraN + ",1" // CShuffleMRepeatPerShuffle + ",1" // CShuffleNRepeatPerShuffle + ",Seq(1,16,1,4)" // CDEBlockTransferClusterLengths + ",1" // CDEBlockTransferScalarPerVector_NPerBlock + ",Intrawave" // BlkGemmPipeSched + ",v1" // BlkGemmPipelineVer + ",true" // UseThreadTileTransfer + ",fp16" // AComputeDataType + ",fp16" // BComputeDataType + ",1>"; // NumGroupsToMerge + + // Verify the generated string matches exactly + EXPECT_EQ(instance_str, expected_str); +} + TEST(InstanceTraits, DlInstanceStringReturnsCorrectFormat) { using DeviceInstance = diff --git a/experimental/builder/test/test_instance_string_fwd_grp_conv_wmma_v3.cpp b/experimental/builder/test/test_instance_string_fwd_grp_conv_wmma_v3.cpp new file mode 100644 index 0000000000..894908a1e9 --- /dev/null +++ b/experimental/builder/test/test_instance_string_fwd_grp_conv_wmma_v3.cpp @@ -0,0 +1,98 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include +#include +#include +#include +#include + +namespace { + +namespace ckr = ck_tile::reflect; + +// Use the template helper to get a working instance configuration +using InstanceTuple = ck::tensor_operation::device::instance:: + device_grouped_conv_fwd_wmma_cshufflev3_f16_generic_instances< + 2, // NDimSpatial + ck::tensor_operation::device::instance::GNHWC, // ALayout + ck::tensor_operation::device::instance::GKYXC, // BLayout + ck::tensor_operation::device::instance::Empty_Tuple, // DsLayout + ck::tensor_operation::device::instance::GNHWK, // ELayout + ck::tensor_operation::device::instance::ConvFwdDefault>; // ConvForwardSpecialization + +// Get the first instance from the tuple +using DeviceInstance = typename std::tuple_element<0, InstanceTuple>::type; + +// Expected complete instance string based on the first instance from +// device_grouped_conv_fwd_wmma_cshufflev3_f16_instances +std::string expected_str = "DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3" + "<2" // NDimSpatial + ",GNHWC" // ALayout + ",GKYXC" // BLayout + ",EmptyTuple" // DsLayout + ",GNHWK" // ELayout + ",fp16" // ADataType + ",fp16" // BDataType + ",fp32" // AccDataType + ",fp16" // CShuffleDataType + ",EmptyTuple" // DsDataType + ",fp16" // EDataType + ",PassThrough" // AElementwiseOperation + ",PassThrough" // BElementwiseOperation + ",PassThrough" // CDEElementwiseOperation + ",Default" // ConvForwardSpecialization + ",MNKPadding" // GemmSpec + ",64" // BlockSize + ",64" // MPerBlock + ",64" // NPerBlock + ",32" // KPerBlock + ",8" // AK1 + ",8" // BK1 + ",16" // MPerWmma + ",16" // NPerWmma + ",4" // MRepeat + ",2" // NRepeat + ",Seq(4,16,1)" // ABlockTransferThreadClusterLengths + ",Seq(1,0,2)" // ABlockTransferThreadClusterArrangeOrder + ",Seq(1,0,2)" // ABlockTransferSrcAccessOrder + ",2" // ABlockTransferSrcVectorDim + ",1" // ABlockTransferSrcScalarPerVector + ",8" // ABlockTransferDstScalarPerVector_AK1 + ",true" // ABlockLdsExtraM + ",Seq(4,16,1)" // BBlockTransferThreadClusterLengths + ",Seq(1,0,2)" // BBlockTransferThreadClusterArrangeOrder + ",Seq(1,0,2)" // BBlockTransferSrcAccessOrder + ",2" // BBlockTransferSrcVectorDim + ",1" // BBlockTransferSrcScalarPerVector + ",8" // BBlockTransferDstScalarPerVector_BK1 + ",true" // BBlockLdsExtraN + ",1" // CShuffleMRepeatPerShuffle + ",1" // CShuffleNRepeatPerShuffle + ",Seq(1,16,1,4)" // CDEBlockTransferClusterLengths + ",1" // CDEBlockTransferScalarPerVector_NPerBlock + ",Intrawave" // BlkGemmPipeSched + ",v1" // BlkGemmPipelineVer + ",true" // UseThreadTileTransfer + ",fp16" // AComputeDataType + ",fp16" // BComputeDataType + ",1>"; // NumGroupsToMerge + +// Test describe() through base class pointer for WMMA V3 variant +TEST(InstanceString, DescribeReturnsCorrectValueForFwdGrpConvWmmaV3) +{ + using BaseClass = ck::tensor_operation::device::BaseOperator; + DeviceInstance device_instance; + BaseClass* base_ptr = &device_instance; + + auto desc = base_ptr->describe(); + ASSERT_NE(desc, nullptr); + EXPECT_EQ(desc->instance_string(), expected_str); +} + +TEST(InstanceString, DescriptionReturnsCorrectValueForFwdGrpConvWmmaV3) +{ + EXPECT_EQ(ckr::describe().instance_string(), expected_str); +} + +} // namespace diff --git a/experimental/builder/test/utils/ckb_conv_test_configs.hpp b/experimental/builder/test/utils/ckb_conv_test_configs.hpp index 641787f7df..c8609f3e64 100644 --- a/experimental/builder/test/utils/ckb_conv_test_configs.hpp +++ b/experimental/builder/test/utils/ckb_conv_test_configs.hpp @@ -344,6 +344,16 @@ constexpr GridwiseWmmaGemmABK1 GemmParamsABK1_Wmma_16x16_2x1_per_wave{.ak1 .m_wmma_per_wave = 2, .n_wmma_per_wave = 1}; +constexpr GridwiseWmmaGemmABK1 GemmParamsABK1_Wmma_16x16_4x2_per_wave{.ak1 = 8, + .bk1 = 8, + .m_per_wmma = 16, + .n_per_wmma = 16, + .m_wmma_per_wave = 4, + .n_wmma_per_wave = 2}; + +constexpr ThreadBlock ThreadBlock_64_64x64x32{.block_size = 64, + .tile_size = {.m = 64, .n = 64, .k = 32}}; + constexpr ThreadBlock ThreadBlock_256_256x256x32{.block_size = 256, .tile_size = {.m = 256, .n = 256, .k = 32}}; diff --git a/experimental/builder/test/utils/conv_algorithm_type_utils.hpp b/experimental/builder/test/utils/conv_algorithm_type_utils.hpp index ccf1b8da2f..3be5d16f91 100644 --- a/experimental/builder/test/utils/conv_algorithm_type_utils.hpp +++ b/experimental/builder/test/utils/conv_algorithm_type_utils.hpp @@ -409,6 +409,17 @@ inline std::string to_string +inline std::string to_string( + ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3 t) +{ + std::ostringstream oss; + oss << to_string(static_cast(t)) << "," + << to_string(static_cast(t)) << "," + << to_string(static_cast>(t)); + return oss.str(); +} + template <> inline std::string to_string( ConvAlgorithm_DeviceGroupedConvFwdMultipleD_Wmma_CShuffle t) diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_wmma_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_wmma_cshuffle_v3.hpp index ee05c7c6a4..df252da8b4 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_wmma_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_wmma_cshuffle_v3.hpp @@ -29,6 +29,11 @@ #include "ck/host_utility/flush_cache.hpp" #include "ck/host_utility/io.hpp" +#ifdef CK_EXPERIMENTAL_BUILDER +#include "ck_tile/builder/reflect/description.hpp" +#include "ck_tile/builder/reflect/instance_traits_device_grouped_conv_fwd_multiple_abd_wmma_cshuffle_v3.hpp" +#endif + namespace ck { namespace tensor_operation { namespace device { @@ -2341,8 +2346,28 @@ struct DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3 "The argument pointer is not an object of " "DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle::Argument structure!"); } + +#ifdef CK_EXPERIMENTAL_BUILDER + std::string GetInstanceString() const override + { + static_assert( + ck_tile::reflect::HasInstanceTraits, + "InstanceTraits specialization is required. Include the .inc file for this device op."); + return ck_tile::reflect::instance_string(); + } + + std::unique_ptr describe() const override + { + return std::make_unique( + ck_tile::reflect::instance_string()); + } +#endif }; } // namespace device } // namespace tensor_operation } // namespace ck + +#ifdef CK_EXPERIMENTAL_BUILDER +#include "ck_tile/builder/reflect/reflect_device_grouped_conv_fwd_multiple_abd_wmma_cshuffle_v3.inc" +#endif