From 2f0242c5ab380f1cf63dac8cc4cbab9ff141b409 Mon Sep 17 00:00:00 2001 From: John Shumway Date: Wed, 29 Oct 2025 08:04:13 -0700 Subject: [PATCH] Add instance traits for two more grouped forward convolutions (#3112) [ROCm/composable_kernel commit: cafaeb6b7bac4e18b0a5341cd14f54224292a0c9] --- ...ped_conv_fwd_multiple_abd_xdl_cshuffle.hpp | 350 ++++++++++++++++++ ...d_multiple_d_xdl_large_tensor_cshuffle.hpp | 344 +++++++++++++++++ .../builder/reflect/instance_traits_util.hpp | 12 + experimental/builder/test/CMakeLists.txt | 4 +- ...raits.cpp => test_fwd_instance_traits.cpp} | 244 +++++++++++- .../test_get_instance_string_fwd_grp_conv.cpp | 104 ++++++ ...tance_string_fwd_grp_conv_large_tensor.cpp | 103 ++++++ ...t_get_instance_string_fwd_grp_conv_v3.cpp} | 4 +- .../test/test_instance_traits_util.cpp | 8 + ...ped_conv_fwd_multiple_abd_xdl_cshuffle.hpp | 16 + ...d_multiple_d_xdl_large_tensor_cshuffle.hpp | 17 + 11 files changed, 1193 insertions(+), 13 deletions(-) create mode 100644 experimental/builder/include/ck_tile/builder/reflect/instance_traits_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp create mode 100644 experimental/builder/include/ck_tile/builder/reflect/instance_traits_device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp rename experimental/builder/test/{test_instance_traits.cpp => test_fwd_instance_traits.cpp} (50%) create mode 100644 experimental/builder/test/test_get_instance_string_fwd_grp_conv.cpp create mode 100644 experimental/builder/test/test_get_instance_string_fwd_grp_conv_large_tensor.cpp rename experimental/builder/test/{test_get_instance_string.cpp => test_get_instance_string_fwd_grp_conv_v3.cpp} (98%) diff --git a/experimental/builder/include/ck_tile/builder/reflect/instance_traits_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp b/experimental/builder/include/ck_tile/builder/reflect/instance_traits_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp new file mode 100644 index 0000000000..462269884e --- /dev/null +++ b/experimental/builder/include/ck_tile/builder/reflect/instance_traits_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp @@ -0,0 +1,350 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +// InstanceTraits specialization for DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle +// +// CRITICAL MAINTENANCE NOTE: +// This InstanceTraits file MUST be kept strictly in sync with the device implementation header: +// ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp +// "In sync" means that the template parameter order, names, and types in the declaration below +// MUST EXACTLY MATCH those in the device implementation. If these diverge, you may encounter +// compilation errors, subtle template instantiation mismatches, or silent runtime bugs that are +// difficult to diagnose. Always update both files together and review changes carefully. + +#pragma once + +#include "instance_traits.hpp" + +// Forward declaration to avoid circular dependency. +namespace ck::tensor_operation::device { + +template +struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle; + +} // namespace ck::tensor_operation::device + +namespace ck_tile::reflect { + +// Specialization for DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle +template +struct InstanceTraits> +{ + // Spatial dimension + static constexpr int kSpatialDim = NDimSpatial; + + // Layout types + using ALayout = ALayout_; + using BLayout = BLayout_; + using DsLayout = DsLayout_; + using ELayout = ELayout_; + + // Data types + using ADataType = ADataType_; + using BDataType = BDataType_; + using AccDataType = AccDataType_; + using CShuffleDataType = CShuffleDataType_; + using DsDataType = DsDataType_; + using EDataType = EDataType_; + + // Element-wise operations + using AElementwiseOperation = AElementwiseOperation_; + using BElementwiseOperation = BElementwiseOperation_; + using CDEElementwiseOperation = CDEElementwiseOperation_; + + // Specialization + static constexpr ck::tensor_operation::device::ConvolutionForwardSpecialization + kConvForwardSpecialization = ConvForwardSpecialization; + static constexpr ck::tensor_operation::device::GemmSpecialization kGemmSpecialization = + GemmSpec; + + // Prefetch stage + static constexpr int kNumGemmKPrefetchStage = NumGemmKPrefetchStage; + + // Block configuration + static constexpr int kBlockSize = BlockSize; + static constexpr int kMPerBlock = MPerBlock; + static constexpr int kNPerBlock = NPerBlock; + static constexpr int kKPerBlock = KPerBlock; + + // Tuning parameters + static constexpr int kAK1 = AK1; + static constexpr int kBK1 = BK1; + static constexpr int kMPerXDL = MPerXDL; + static constexpr int kNPerXDL = NPerXDL; + static constexpr int kMXdlPerWave = MXdlPerWave; + static constexpr int kNXdlPerWave = NXdlPerWave; + + // A block transfer thread cluster dimensions (converted to std::array) + static constexpr auto kAThreadClusterLengths = + detail::SequenceToArray::value; + static constexpr auto kAThreadClusterArrangeOrder = + detail::SequenceToArray::value; + static constexpr auto kABlockTransferSrcAccessOrder = + detail::SequenceToArray::value; + static constexpr int kABlockTransferSrcVectorDim = ABlockTransferSrcVectorDim; + static constexpr int kABlockTransferSrcScalarPerVector = ABlockTransferSrcScalarPerVector; + static constexpr int kABlockTransferDstScalarPerVectorK1 = ABlockTransferDstScalarPerVector_AK1; + static constexpr int kABlockLdsExtraM = ABlockLdsExtraM; + + // B block transfer thread cluster dimensions (converted to std::array) + static constexpr auto kBThreadClusterLengths = + detail::SequenceToArray::value; + static constexpr auto kBThreadClusterArrangeOrder = + detail::SequenceToArray::value; + static constexpr auto kBBlockTransferSrcAccessOrder = + detail::SequenceToArray::value; + static constexpr int kBBlockTransferSrcVectorDim = BBlockTransferSrcVectorDim; + static constexpr int kBBlockTransferSrcScalarPerVector = BBlockTransferSrcScalarPerVector; + static constexpr int kBBlockTransferDstScalarPerVectorK1 = BBlockTransferDstScalarPerVector_BK1; + static constexpr int kBBlockLdsExtraN = BBlockLdsExtraN; + + // C shuffle parameters (converted to std::array) + static constexpr int kCShuffleMXdlPerWavePerShuffle = CShuffleMXdlPerWavePerShuffle; + static constexpr int kCShuffleNXdlPerWavePerShuffle = CShuffleNXdlPerWavePerShuffle; + static constexpr auto kCThreadClusterLengths = detail::SequenceToArray< + CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock>::value; + static constexpr int kCBlockTransferScalarPerVector = CDEBlockTransferScalarPerVector_NPerBlock; + + // Compute data types + using AComputeDataType = AComputeDataType_; + using BComputeDataType = BComputeDataType_; + + // Loop scheduler + static constexpr ck::LoopScheduler kLoopScheduler = LoopSched; + + // Groups to merge + static constexpr int kNumGroupsToMerge = NumGroupsToMerge; + + // Static member function to generate instance string + static std::string instance_string() + { + std::ostringstream oss; + + // Kernel type name + oss << "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle"; + + // Template parameters in exact order matching InstanceTraits member order + oss << "<" << kSpatialDim; // 1. NDimSpatial + oss << "," << detail::layout_name(); // 2. ALayout + oss << "," << detail::layout_name(); // 3. BLayout + oss << "," << detail::tuple_name(); // 4. DsLayout + oss << "," << detail::layout_name(); // 5. ELayout + oss << "," << detail::type_name(); // 6. ADataType + oss << "," << detail::type_name(); // 7. BDataType + oss << "," << detail::type_name(); // 8. AccDataType + oss << "," << detail::type_name(); // 9. CShuffleDataType + oss << "," << detail::tuple_name(); // 10. DsDataType + oss << "," << detail::type_name(); // 11. EDataType + oss << "," + << detail::elementwise_op_name(); // 12. AElementwiseOperation + oss << "," + << detail::elementwise_op_name(); // 13. BElementwiseOperation + oss << "," + << detail::elementwise_op_name(); // 14. + // CDEElementwiseOperation + oss << "," + << detail::conv_fwd_spec_name( + kConvForwardSpecialization); // 15. ConvForwardSpecialization + oss << "," << detail::gemm_spec_name(kGemmSpecialization); // 16. GemmSpec + oss << "," << kNumGemmKPrefetchStage; // 17. NumGemmKPrefetchStage + oss << "," << kBlockSize; // 18. BlockSize + oss << "," << kMPerBlock; // 19. MPerBlock + oss << "," << kNPerBlock; // 20. NPerBlock + oss << "," << kKPerBlock; // 21. KPerBlock + oss << "," << kAK1; // 22. AK1 + oss << "," << kBK1; // 23. BK1 + oss << "," << kMPerXDL; // 24. MPerXDL + oss << "," << kNPerXDL; // 25. NPerXDL + oss << "," << kMXdlPerWave; // 26. MXdlPerWave + oss << "," << kNXdlPerWave; // 27. NXdlPerWave + oss << "," + << detail::array_to_string( + kAThreadClusterLengths); // 28. ABlockTransferThreadClusterLengths + oss << "," + << detail::array_to_string( + kAThreadClusterArrangeOrder); // 29. ABlockTransferThreadClusterArrangeOrder + oss << "," + << detail::array_to_string( + kABlockTransferSrcAccessOrder); // 30. ABlockTransferSrcAccessOrder + oss << "," << kABlockTransferSrcVectorDim; // 31. ABlockTransferSrcVectorDim + oss << "," << kABlockTransferSrcScalarPerVector; // 32. ABlockTransferSrcScalarPerVector + oss << "," + << kABlockTransferDstScalarPerVectorK1; // 33. ABlockTransferDstScalarPerVector_AK1 + oss << "," << kABlockLdsExtraM; // 34. ABlockLdsExtraM + oss << "," + << detail::array_to_string( + kBThreadClusterLengths); // 35. BBlockTransferThreadClusterLengths + oss << "," + << detail::array_to_string( + kBThreadClusterArrangeOrder); // 36. BBlockTransferThreadClusterArrangeOrder + oss << "," + << detail::array_to_string( + kBBlockTransferSrcAccessOrder); // 37. BBlockTransferSrcAccessOrder + oss << "," << kBBlockTransferSrcVectorDim; // 38. BBlockTransferSrcVectorDim + oss << "," << kBBlockTransferSrcScalarPerVector; // 39. BBlockTransferSrcScalarPerVector + oss << "," + << kBBlockTransferDstScalarPerVectorK1; // 40. BBlockTransferDstScalarPerVector_BK1 + oss << "," << kBBlockLdsExtraN; // 41. BBlockLdsExtraN + oss << "," << kCShuffleMXdlPerWavePerShuffle; // 42. CShuffleMXdlPerWavePerShuffle + oss << "," << kCShuffleNXdlPerWavePerShuffle; // 43. CShuffleNXdlPerWavePerShuffle + oss << "," + << detail::array_to_string( + kCThreadClusterLengths); // 44. CDEBlockTransferClusterLengths + oss << "," + << kCBlockTransferScalarPerVector; // 45. CDEBlockTransferScalarPerVector_NPerBlock + oss << "," << detail::type_name(); // 46. AComputeDataType + oss << "," << detail::type_name(); // 47. BComputeDataType + oss << "," << detail::loop_scheduler_name(kLoopScheduler); // 48. LoopSched + oss << "," << kNumGroupsToMerge; // 49. NumGroupsToMerge + oss << ">"; + + return oss.str(); + } +}; + +} // namespace ck_tile::reflect diff --git a/experimental/builder/include/ck_tile/builder/reflect/instance_traits_device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp b/experimental/builder/include/ck_tile/builder/reflect/instance_traits_device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp new file mode 100644 index 0000000000..0896a9daf9 --- /dev/null +++ b/experimental/builder/include/ck_tile/builder/reflect/instance_traits_device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp @@ -0,0 +1,344 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +// InstanceTraits specialization for DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor +// +// CRITICAL MAINTENANCE NOTE: +// This InstanceTraits file MUST be kept strictly in sync with the device implementation header: +// ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp +// "In sync" means that the template parameter order, names, and types in the declaration below +// MUST EXACTLY MATCH those in the device implementation. If these diverge, you may encounter +// compilation errors, subtle template instantiation mismatches, or silent runtime bugs that are +// difficult to diagnose. Always update both files together and review changes carefully. + +#pragma once + +#include "instance_traits.hpp" + +// Forward declaration to avoid circular dependency. +namespace ck::tensor_operation::device { + +template +struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor; + +} // namespace ck::tensor_operation::device + +namespace ck_tile::reflect { + +// Specialization for DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor +template +struct InstanceTraits< + ck::tensor_operation::device::DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor< + NDimSpatial, + ALayout_, + BLayout_, + DsLayout_, + ELayout_, + ADataType_, + BDataType_, + AccDataType_, + CShuffleDataType_, + DsDataType_, + EDataType_, + AElementwiseOperation_, + BElementwiseOperation_, + CDEElementwiseOperation_, + ConvForwardSpecialization, + GemmSpec, + NumGemmKPrefetchStage, + BlockSize, + MPerBlock, + NPerBlock, + KPerBlock, + AK1, + BK1, + MPerXDL, + NPerXDL, + MXdlPerWave, + NXdlPerWave, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder_, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + ABlockLdsExtraM, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder_, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + BBlockLdsExtraN, + CShuffleMXdlPerWavePerShuffle, + CShuffleNXdlPerWavePerShuffle, + CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + CDEBlockTransferScalarPerVector_NPerBlock, + AComputeDataType_, + BComputeDataType_, + LoopSched>> +{ + // Spatial dimension + static constexpr int kSpatialDim = NDimSpatial; + + // Layout types + using ALayout = ALayout_; + using BLayout = BLayout_; + using DsLayout = DsLayout_; + using ELayout = ELayout_; + + // Data types + using ADataType = ADataType_; + using BDataType = BDataType_; + using AccDataType = AccDataType_; + using CShuffleDataType = CShuffleDataType_; + using DsDataType = DsDataType_; + using EDataType = EDataType_; + + // Element-wise operations + using AElementwiseOperation = AElementwiseOperation_; + using BElementwiseOperation = BElementwiseOperation_; + using CDEElementwiseOperation = CDEElementwiseOperation_; + + // Specialization + static constexpr ck::tensor_operation::device::ConvolutionForwardSpecialization + kConvForwardSpecialization = ConvForwardSpecialization; + static constexpr ck::tensor_operation::device::GemmSpecialization kGemmSpecialization = + GemmSpec; + + // Prefetch stage + static constexpr int kNumGemmKPrefetchStage = NumGemmKPrefetchStage; + + // Block configuration + static constexpr int kBlockSize = BlockSize; + static constexpr int kMPerBlock = MPerBlock; + static constexpr int kNPerBlock = NPerBlock; + static constexpr int kKPerBlock = KPerBlock; + + // Tuning parameters + static constexpr int kAK1 = AK1; + static constexpr int kBK1 = BK1; + static constexpr int kMPerXDL = MPerXDL; + static constexpr int kNPerXDL = NPerXDL; + static constexpr int kMXdlPerWave = MXdlPerWave; + static constexpr int kNXdlPerWave = NXdlPerWave; + + // A block transfer thread cluster dimensions (converted to std::array) + static constexpr auto kAThreadClusterLengths = + detail::SequenceToArray::value; + static constexpr auto kAThreadClusterArrangeOrder = + detail::SequenceToArray::value; + static constexpr auto kABlockTransferSrcAccessOrder = + detail::SequenceToArray::value; + static constexpr int kABlockTransferSrcVectorDim = ABlockTransferSrcVectorDim; + static constexpr int kABlockTransferSrcScalarPerVector = ABlockTransferSrcScalarPerVector; + static constexpr int kABlockTransferDstScalarPerVectorK1 = ABlockTransferDstScalarPerVector_AK1; + static constexpr int kABlockLdsExtraM = ABlockLdsExtraM; + + // B block transfer thread cluster dimensions (converted to std::array) + static constexpr auto kBThreadClusterLengths = + detail::SequenceToArray::value; + static constexpr auto kBThreadClusterArrangeOrder = + detail::SequenceToArray::value; + static constexpr auto kBBlockTransferSrcAccessOrder = + detail::SequenceToArray::value; + static constexpr int kBBlockTransferSrcVectorDim = BBlockTransferSrcVectorDim; + static constexpr int kBBlockTransferSrcScalarPerVector = BBlockTransferSrcScalarPerVector; + static constexpr int kBBlockTransferDstScalarPerVectorK1 = BBlockTransferDstScalarPerVector_BK1; + static constexpr int kBBlockLdsExtraN = BBlockLdsExtraN; + + // C shuffle parameters (converted to std::array) + static constexpr int kCShuffleMXdlPerWavePerShuffle = CShuffleMXdlPerWavePerShuffle; + static constexpr int kCShuffleNXdlPerWavePerShuffle = CShuffleNXdlPerWavePerShuffle; + static constexpr auto kCThreadClusterLengths = detail::SequenceToArray< + CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock>::value; + static constexpr int kCBlockTransferScalarPerVector = CDEBlockTransferScalarPerVector_NPerBlock; + + // Compute data types + using AComputeDataType = AComputeDataType_; + using BComputeDataType = BComputeDataType_; + + // Loop scheduler + static constexpr ck::LoopScheduler kLoopScheduler = LoopSched; + + // Static member function to generate instance string + static std::string instance_string() + { + std::ostringstream oss; + + // Kernel type name + oss << "DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor"; + + // Template parameters in exact order matching InstanceTraits member order + oss << "<" << kSpatialDim; // 1. NDimSpatial + oss << "," << detail::layout_name(); // 2. ALayout + oss << "," << detail::layout_name(); // 3. BLayout + oss << "," << detail::tuple_name(); // 4. DsLayout + oss << "," << detail::layout_name(); // 5. ELayout + oss << "," << detail::type_name(); // 6. ADataType + oss << "," << detail::type_name(); // 7. BDataType + oss << "," << detail::type_name(); // 8. AccDataType + oss << "," << detail::type_name(); // 9. CShuffleDataType + oss << "," << detail::tuple_name(); // 10. DsDataType + oss << "," << detail::type_name(); // 11. EDataType + oss << "," + << detail::elementwise_op_name(); // 12. AElementwiseOperation + oss << "," + << detail::elementwise_op_name(); // 13. BElementwiseOperation + oss << "," + << detail::elementwise_op_name(); // 14. + // CDEElementwiseOperation + oss << "," + << detail::conv_fwd_spec_name( + kConvForwardSpecialization); // 15. ConvForwardSpecialization + oss << "," << detail::gemm_spec_name(kGemmSpecialization); // 16. GemmSpec + oss << "," << kNumGemmKPrefetchStage; // 17. NumGemmKPrefetchStage + oss << "," << kBlockSize; // 18. BlockSize + oss << "," << kMPerBlock; // 19. MPerBlock + oss << "," << kNPerBlock; // 20. NPerBlock + oss << "," << kKPerBlock; // 21. KPerBlock + oss << "," << kAK1; // 22. AK1 + oss << "," << kBK1; // 23. BK1 + oss << "," << kMPerXDL; // 24. MPerXDL + oss << "," << kNPerXDL; // 25. NPerXDL + oss << "," << kMXdlPerWave; // 26. MXdlPerWave + oss << "," << kNXdlPerWave; // 27. NXdlPerWave + oss << "," + << detail::array_to_string( + kAThreadClusterLengths); // 28. ABlockTransferThreadClusterLengths + oss << "," + << detail::array_to_string( + kAThreadClusterArrangeOrder); // 29. ABlockTransferThreadClusterArrangeOrder + oss << "," + << detail::array_to_string( + kABlockTransferSrcAccessOrder); // 30. ABlockTransferSrcAccessOrder + oss << "," << kABlockTransferSrcVectorDim; // 31. ABlockTransferSrcVectorDim + oss << "," << kABlockTransferSrcScalarPerVector; // 32. ABlockTransferSrcScalarPerVector + oss << "," + << kABlockTransferDstScalarPerVectorK1; // 33. ABlockTransferDstScalarPerVector_AK1 + oss << "," << kABlockLdsExtraM; // 34. ABlockLdsExtraM + oss << "," + << detail::array_to_string( + kBThreadClusterLengths); // 35. BBlockTransferThreadClusterLengths + oss << "," + << detail::array_to_string( + kBThreadClusterArrangeOrder); // 36. BBlockTransferThreadClusterArrangeOrder + oss << "," + << detail::array_to_string( + kBBlockTransferSrcAccessOrder); // 37. BBlockTransferSrcAccessOrder + oss << "," << kBBlockTransferSrcVectorDim; // 38. BBlockTransferSrcVectorDim + oss << "," << kBBlockTransferSrcScalarPerVector; // 39. BBlockTransferSrcScalarPerVector + oss << "," + << kBBlockTransferDstScalarPerVectorK1; // 40. BBlockTransferDstScalarPerVector_BK1 + oss << "," << kBBlockLdsExtraN; // 41. BBlockLdsExtraN + oss << "," << kCShuffleMXdlPerWavePerShuffle; // 42. CShuffleMXdlPerWavePerShuffle + oss << "," << kCShuffleNXdlPerWavePerShuffle; // 43. CShuffleNXdlPerWavePerShuffle + oss << "," + << detail::array_to_string( + kCThreadClusterLengths); // 44. CDEBlockTransferClusterLengths + oss << "," + << kCBlockTransferScalarPerVector; // 45. CDEBlockTransferScalarPerVector_NPerBlock + oss << "," << detail::type_name(); // 46. AComputeDataType + oss << "," << detail::type_name(); // 47. BComputeDataType + oss << "," << detail::loop_scheduler_name(kLoopScheduler); // 48. LoopSched + oss << ">"; + + return oss.str(); + } +}; + +} // namespace ck_tile::reflect diff --git a/experimental/builder/include/ck_tile/builder/reflect/instance_traits_util.hpp b/experimental/builder/include/ck_tile/builder/reflect/instance_traits_util.hpp index 902c3b3579..545441fd90 100644 --- a/experimental/builder/include/ck_tile/builder/reflect/instance_traits_util.hpp +++ b/experimental/builder/include/ck_tile/builder/reflect/instance_traits_util.hpp @@ -15,6 +15,7 @@ #include #include #include +#include #include #include #include @@ -160,6 +161,17 @@ constexpr std::string_view pipeline_version_name(ck::BlockGemmPipelineVersion ve } } +// Convert LoopScheduler enum to string +constexpr std::string_view loop_scheduler_name(ck::LoopScheduler sched) +{ + using enum ck::LoopScheduler; + switch(sched) + { + case Default: return "Default"; + case Interwave: return "Interwave"; + } +} + // Convert std::array to string template inline std::string array_to_string(const std::array& arr) diff --git a/experimental/builder/test/CMakeLists.txt b/experimental/builder/test/CMakeLists.txt index 3a13f7239f..b7adbc116a 100644 --- a/experimental/builder/test/CMakeLists.txt +++ b/experimental/builder/test/CMakeLists.txt @@ -26,7 +26,9 @@ add_ck_builder_test(test_inline_diff test_inline_diff.cpp) # Testing the virtual GetInstanceString methods requires kernel compilation. add_ck_builder_test(test_get_instance_string - test_get_instance_string.cpp) + test_get_instance_string_fwd_grp_conv_v3.cpp + test_get_instance_string_fwd_grp_conv.cpp + test_get_instance_string_fwd_grp_conv_large_tensor.cpp) # Testing the fwd convolution builder requires kernel compilation. # To enable parallel compilation, the individual tests are split into separate files. diff --git a/experimental/builder/test/test_instance_traits.cpp b/experimental/builder/test/test_fwd_instance_traits.cpp similarity index 50% rename from experimental/builder/test/test_instance_traits.cpp rename to experimental/builder/test/test_fwd_instance_traits.cpp index f6a8fd28c2..181319bc18 100644 --- a/experimental/builder/test/test_instance_traits.cpp +++ b/experimental/builder/test/test_fwd_instance_traits.cpp @@ -3,19 +3,18 @@ #include #include +#include +#include #include #include +#include +#include namespace { using ::testing::ElementsAre; -// Test fixture for InstanceTraits tests -class InstanceTraitsTest : public ::testing::Test -{ -}; -// Test InstanceTraits with DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 -TEST_F(InstanceTraitsTest, ConvFwdInstanceTraitsExtraction) +TEST(InstanceTraitsTest, ConvFwdInstanceTraitsExtraction) { // Define a concrete instance type with specific template parameters using DeviceInstance = @@ -156,8 +155,7 @@ TEST_F(InstanceTraitsTest, ConvFwdInstanceTraitsExtraction) ck::tensor_operation::element_wise::PassThrough>::value)); } -// Test instance_string function -TEST_F(InstanceTraitsTest, InstanceStringGeneration) +TEST(InstanceTraitsTest, V3InstanceStringGeneration) { // Define a concrete instance type with specific template parameters using DeviceInstance = @@ -215,10 +213,8 @@ TEST_F(InstanceTraitsTest, InstanceStringGeneration) ck::half_t, // AComputeDataType ck::half_t>; // BComputeDataType - // Generate instance string std::string instance_str = ck_tile::reflect::instance_string(); - // Expected string with all template parameters in exact order std::string expected_str = "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3" "<2" // NDimSpatial ",GNHWC" // ALayout @@ -269,6 +265,234 @@ TEST_F(InstanceTraitsTest, InstanceStringGeneration) ",fp16" // AComputeDataType ",fp16>"; // BComputeDataType + EXPECT_EQ(instance_str, expected_str); +} + +TEST(InstanceTraitsTest, BaseInstanceStringGeneration) +{ + using DeviceInstance = + ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle< + 2, // NDimSpatial + ck::tensor_layout::convolution::GNHWC, // ALayout + ck::tensor_layout::convolution::GKYXC, // BLayout + ck::Tuple<>, // DsLayout + ck::tensor_layout::convolution::GNHWK, // ELayout + ck::half_t, // ADataType + ck::half_t, // BDataType + float, // AccDataType + ck::half_t, // CShuffleDataType + ck::Tuple<>, // DsDataType + ck::half_t, // EDataType + ck::tensor_operation::element_wise::PassThrough, // AElementwiseOperation + ck::tensor_operation::element_wise::PassThrough, // BElementwiseOperation + ck::tensor_operation::element_wise::PassThrough, // CDEElementwiseOperation + ck::tensor_operation::device::ConvolutionForwardSpecialization:: + Default, // ConvForwardSpecialization + ck::tensor_operation::device::GemmSpecialization::Default, // GemmSpec + 1, // NumGemmKPrefetchStage + 256, // BlockSize + 128, // MPerBlock + 128, // NPerBlock + 16, // KPerBlock + 8, // AK1 + 8, // BK1 + 32, // MPerXDL + 32, // NPerXDL + 4, // MXdlPerWave + 4, // NXdlPerWave + ck::Sequence<4, 64, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1 + ck::Sequence<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder + ck::Sequence<1, 0, 2>, // ABlockTransferSrcAccessOrder + 2, // ABlockTransferSrcVectorDim + 8, // ABlockTransferSrcScalarPerVector + 8, // ABlockTransferDstScalarPerVector_AK1 + 1, // ABlockLdsExtraM + ck::Sequence<4, 64, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1 + ck::Sequence<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder + ck::Sequence<1, 0, 2>, // BBlockTransferSrcAccessOrder + 2, // BBlockTransferSrcVectorDim + 8, // BBlockTransferSrcScalarPerVector + 8, // BBlockTransferDstScalarPerVector_BK1 + 1, // BBlockLdsExtraN + 1, // CShuffleMXdlPerWavePerShuffle + 1, // CShuffleNXdlPerWavePerShuffle + ck::Sequence<1, + 32, + 1, + 8>, // CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock + 8, // CDEBlockTransferScalarPerVector_NPerBlock + ck::half_t, // AComputeDataType + ck::half_t, // BComputeDataType + ck::LoopScheduler::Default, // LoopSched + 1>; // NumGroupsToMerge + + std::string instance_str = ck_tile::reflect::instance_string(); + + std::string expected_str = "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle" + "<2" // NDimSpatial + ",GNHWC" // ALayout + ",GKYXC" // BLayout + ",EmptyTuple" // DsLayout + ",GNHWK" // ELayout + ",fp16" // ADataType + ",fp16" // BDataType + ",fp32" // AccDataType + ",fp16" // CShuffleDataType + ",EmptyTuple" // DsDataType + ",fp16" // EDataType + ",PassThrough" // AElementwiseOperation + ",PassThrough" // BElementwiseOperation + ",PassThrough" // CDEElementwiseOperation + ",Default" // ConvForwardSpecialization + ",Default" // GemmSpec + ",1" // NumGemmKPrefetchStage + ",256" // BlockSize + ",128" // MPerBlock + ",128" // NPerBlock + ",16" // KPerBlock + ",8" // AK1 + ",8" // BK1 + ",32" // MPerXDL + ",32" // NPerXDL + ",4" // MXdlPerWave + ",4" // NXdlPerWave + ",Seq(4,64,1)" // ABlockTransferThreadClusterLengths + ",Seq(1,0,2)" // ABlockTransferThreadClusterArrangeOrder + ",Seq(1,0,2)" // ABlockTransferSrcAccessOrder + ",2" // ABlockTransferSrcVectorDim + ",8" // ABlockTransferSrcScalarPerVector + ",8" // ABlockTransferDstScalarPerVector_AK1 + ",1" // ABlockLdsExtraM + ",Seq(4,64,1)" // BBlockTransferThreadClusterLengths + ",Seq(1,0,2)" // BBlockTransferThreadClusterArrangeOrder + ",Seq(1,0,2)" // BBlockTransferSrcAccessOrder + ",2" // BBlockTransferSrcVectorDim + ",8" // BBlockTransferSrcScalarPerVector + ",8" // BBlockTransferDstScalarPerVector_BK1 + ",1" // BBlockLdsExtraN + ",1" // CShuffleMXdlPerWavePerShuffle + ",1" // CShuffleNXdlPerWavePerShuffle + ",Seq(1,32,1,8)" // CDEBlockTransferClusterLengths + ",8" // CDEBlockTransferScalarPerVector_NPerBlock + ",fp16" // AComputeDataType + ",fp16" // BComputeDataType + ",Default" // LoopSched + ",1>"; // NumGroupsToMerge + + EXPECT_EQ(instance_str, expected_str); +} + +TEST(InstanceTraitsTest, LargeTensorInstanceStringGeneration) +{ + using DeviceInstance = + ck::tensor_operation::device::DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor< + 2, // NDimSpatial + ck::tensor_layout::convolution::GNHWC, // ALayout + ck::tensor_layout::convolution::GKYXC, // BLayout + ck::Tuple<>, // DsLayout + ck::tensor_layout::convolution::GNHWK, // ELayout + ck::half_t, // ADataType + ck::half_t, // BDataType + float, // AccDataType + ck::half_t, // CShuffleDataType + ck::Tuple<>, // DsDataType + ck::half_t, // EDataType + ck::tensor_operation::element_wise::PassThrough, // AElementwiseOperation + ck::tensor_operation::element_wise::PassThrough, // BElementwiseOperation + ck::tensor_operation::element_wise::PassThrough, // CDEElementwiseOperation + ck::tensor_operation::device::ConvolutionForwardSpecialization:: + Default, // ConvForwardSpecialization + ck::tensor_operation::device::GemmSpecialization::Default, // GemmSpec + 1, // NumGemmKPrefetchStage + 256, // BlockSize + 128, // MPerBlock + 128, // NPerBlock + 16, // KPerBlock + 8, // AK1 + 8, // BK1 + 32, // MPerXDL + 32, // NPerXDL + 4, // MXdlPerWave + 4, // NXdlPerWave + ck::Sequence<4, 64, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1 + ck::Sequence<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder + ck::Sequence<1, 0, 2>, // ABlockTransferSrcAccessOrder + 2, // ABlockTransferSrcVectorDim + 8, // ABlockTransferSrcScalarPerVector + 8, // ABlockTransferDstScalarPerVector_AK1 + 1, // ABlockLdsExtraM + ck::Sequence<4, 64, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1 + ck::Sequence<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder + ck::Sequence<1, 0, 2>, // BBlockTransferSrcAccessOrder + 2, // BBlockTransferSrcVectorDim + 8, // BBlockTransferSrcScalarPerVector + 8, // BBlockTransferDstScalarPerVector_BK1 + 1, // BBlockLdsExtraN + 1, // CShuffleMXdlPerWavePerShuffle + 1, // CShuffleNXdlPerWavePerShuffle + ck::Sequence<1, + 32, + 1, + 8>, // CDEBlockTransferClusterLengths + 8, // CDEBlockTransferScalarPerVector_NPerBlock + ck::half_t, // AComputeDataType + ck::half_t, // BComputeDataType + ck::LoopScheduler::Default>; // LoopSched + + // Generate instance string + std::string instance_str = ck_tile::reflect::instance_string(); + + // Expected string with all 48 template parameters + std::string expected_str = "DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor" + "<2" // NDimSpatial + ",GNHWC" // ALayout + ",GKYXC" // BLayout + ",EmptyTuple" // DsLayout + ",GNHWK" // ELayout + ",fp16" // ADataType + ",fp16" // BDataType + ",fp32" // AccDataType + ",fp16" // CShuffleDataType + ",EmptyTuple" // DsDataType + ",fp16" // EDataType + ",PassThrough" // AElementwiseOperation + ",PassThrough" // BElementwiseOperation + ",PassThrough" // CDEElementwiseOperation + ",Default" // ConvForwardSpecialization + ",Default" // GemmSpec + ",1" // NumGemmKPrefetchStage + ",256" // BlockSize + ",128" // MPerBlock + ",128" // NPerBlock + ",16" // KPerBlock + ",8" // AK1 + ",8" // BK1 + ",32" // MPerXDL + ",32" // NPerXDL + ",4" // MXdlPerWave + ",4" // NXdlPerWave + ",Seq(4,64,1)" // ABlockTransferThreadClusterLengths + ",Seq(1,0,2)" // ABlockTransferThreadClusterArrangeOrder + ",Seq(1,0,2)" // ABlockTransferSrcAccessOrder + ",2" // ABlockTransferSrcVectorDim + ",8" // ABlockTransferSrcScalarPerVector + ",8" // ABlockTransferDstScalarPerVector_AK1 + ",1" // ABlockLdsExtraM + ",Seq(4,64,1)" // BBlockTransferThreadClusterLengths + ",Seq(1,0,2)" // BBlockTransferThreadClusterArrangeOrder + ",Seq(1,0,2)" // BBlockTransferSrcAccessOrder + ",2" // BBlockTransferSrcVectorDim + ",8" // BBlockTransferSrcScalarPerVector + ",8" // BBlockTransferDstScalarPerVector_BK1 + ",1" // BBlockLdsExtraN + ",1" // CShuffleMXdlPerWavePerShuffle + ",1" // CShuffleNXdlPerWavePerShuffle + ",Seq(1,32,1,8)" // CDEBlockTransferClusterLengths + ",8" // CDEBlockTransferScalarPerVector_NPerBlock + ",fp16" // AComputeDataType + ",fp16" // BComputeDataType + ",Default>"; // LoopSched + // Verify the generated string matches exactly EXPECT_EQ(instance_str, expected_str); } diff --git a/experimental/builder/test/test_get_instance_string_fwd_grp_conv.cpp b/experimental/builder/test/test_get_instance_string_fwd_grp_conv.cpp new file mode 100644 index 0000000000..b2b0fb2389 --- /dev/null +++ b/experimental/builder/test/test_get_instance_string_fwd_grp_conv.cpp @@ -0,0 +1,104 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include + +// Test GetInstanceString through base class pointer for non-V3 variant +TEST(GetInstanceString, ReturnsStringForFwdGrpConvInstance) +{ + // Use the template helper to get a working instance configuration + using InstanceTuple = + ck::tensor_operation::device::instance::device_grouped_conv_fwd_xdl_f16_instances< + 2, // NDimSpatial + ck::tensor_operation::device::instance::GNHWC, // ALayout + ck::tensor_operation::device::instance::GKYXC, // BLayout + ck::tensor_operation::device::instance::Empty_Tuple, // DsLayout + ck::tensor_operation::device::instance::GNHWK, // ELayout + ck::tensor_operation::device::instance::ConvFwdDefault>; // ConvForwardSpecialization + + // Get the first instance from the tuple + using DeviceInstance = typename std::tuple_element<0, InstanceTuple>::type; + + // Define the base class type using DeviceGroupedConvFwdMultipleABD + using BaseClass = ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD< + 2, // NDimSpatial + ck::tensor_operation::device::instance::GNHWC, // ALayout + ck::tensor_operation::device::instance::GKYXC, // BLayout + ck::tensor_operation::device::instance::Empty_Tuple, // DsLayout + ck::tensor_operation::device::instance::GNHWK, // ELayout + ck::half_t, // ADataType + ck::half_t, // BDataType + ck::Tuple<>, // DsDataType + ck::half_t, // EDataType + ck::tensor_operation::element_wise::PassThrough, // AElementwiseOperation + ck::tensor_operation::element_wise::PassThrough, // BElementwiseOperation + ck::tensor_operation::element_wise::PassThrough, // CDEElementwiseOperation + ck::half_t, // AComputeType + ck::half_t>; // BComputeType + + // Create an instance of the derived class + DeviceInstance device_instance; + + // Get a pointer to the base class + BaseClass* base_ptr = &device_instance; + + // Call GetInstanceString through the base class pointer + std::string instance_str = base_ptr->GetInstanceString(); + + // Expected complete instance string based on the first instance from + // device_grouped_conv_fwd_xdl_f16_instances + std::string expected_str = "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle" + "<2" // NDimSpatial + ",GNHWC" // ALayout + ",GKYXC" // BLayout + ",EmptyTuple" // DsLayout + ",GNHWK" // ELayout + ",fp16" // ADataType + ",fp16" // BDataType + ",fp32" // AccDataType + ",fp16" // CShuffleDataType + ",EmptyTuple" // DsDataType + ",fp16" // EDataType + ",PassThrough" // AElementwiseOperation + ",PassThrough" // BElementwiseOperation + ",PassThrough" // CDEElementwiseOperation + ",Default" // ConvForwardSpecialization + ",MNKPadding" // GemmSpec + ",1" // NumGemmKPrefetchStage + ",64" // BlockSize + ",64" // MPerBlock + ",64" // NPerBlock + ",32" // KPerBlock + ",8" // AK1 + ",8" // BK1 + ",32" // MPerXDL + ",32" // NPerXDL + ",2" // MXdlPerWave + ",2" // NXdlPerWave + ",Seq(4,16,1)" // ABlockTransferThreadClusterLengths + ",Seq(1,0,2)" // ABlockTransferThreadClusterArrangeOrder + ",Seq(1,0,2)" // ABlockTransferSrcAccessOrder + ",2" // ABlockTransferSrcVectorDim + ",1" // ABlockTransferSrcScalarPerVector + ",8" // ABlockTransferDstScalarPerVector_AK1 + ",1" // ABlockLdsExtraM + ",Seq(4,16,1)" // BBlockTransferThreadClusterLengths + ",Seq(1,0,2)" // BBlockTransferThreadClusterArrangeOrder + ",Seq(1,0,2)" // BBlockTransferSrcAccessOrder + ",2" // BBlockTransferSrcVectorDim + ",1" // BBlockTransferSrcScalarPerVector + ",8" // BBlockTransferDstScalarPerVector_BK1 + ",1" // BBlockLdsExtraN + ",1" // CShuffleMXdlPerWavePerShuffle + ",1" // CShuffleNXdlPerWavePerShuffle + ",Seq(1,16,1,4)" // CDEBlockTransferClusterLengths + ",1" // CDEBlockTransferScalarPerVector_NPerBlock + ",fp16" // AComputeDataType + ",fp16" // BComputeDataType + ",Default" // LoopScheduler + ",1>"; // NumGroupsToMerge + EXPECT_EQ(instance_str, expected_str); +} diff --git a/experimental/builder/test/test_get_instance_string_fwd_grp_conv_large_tensor.cpp b/experimental/builder/test/test_get_instance_string_fwd_grp_conv_large_tensor.cpp new file mode 100644 index 0000000000..4d50c34ea3 --- /dev/null +++ b/experimental/builder/test/test_get_instance_string_fwd_grp_conv_large_tensor.cpp @@ -0,0 +1,103 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include + +// Test GetInstanceString through base class pointer for large tensor variant +TEST(GetInstanceString, ReturnsStringForFwdGrpConvLargeTensorInstance) +{ + // Use the template helper to get a working instance configuration + using InstanceTuple = ck::tensor_operation::device::instance:: + device_grouped_conv_fwd_xdl_large_tensor_f16_instances< + 2, // NDimSpatial + ck::tensor_operation::device::instance::GNHWC, // ALayout + ck::tensor_operation::device::instance::GKYXC, // BLayout + ck::tensor_operation::device::instance::Empty_Tuple, // DsLayout + ck::tensor_operation::device::instance::GNHWK, // ELayout + ck::tensor_operation::device::instance::ConvFwdDefault>; // ConvForwardSpecialization + + // Get the first instance from the tuple + using DeviceInstance = typename std::tuple_element<0, InstanceTuple>::type; + + // Define the base class type using DeviceGroupedConvFwdMultipleABD + using BaseClass = ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD< + 2, // NDimSpatial + ck::tensor_operation::device::instance::GNHWC, // ALayout + ck::tensor_operation::device::instance::GKYXC, // BLayout + ck::tensor_operation::device::instance::Empty_Tuple, // DsLayout + ck::tensor_operation::device::instance::GNHWK, // ELayout + ck::half_t, // ADataType + ck::half_t, // BDataType + ck::Tuple<>, // DsDataType + ck::half_t, // EDataType + ck::tensor_operation::element_wise::PassThrough, // AElementwiseOperation + ck::tensor_operation::element_wise::PassThrough, // BElementwiseOperation + ck::tensor_operation::element_wise::PassThrough, // CDEElementwiseOperation + ck::half_t, // AComputeType + ck::half_t>; // BComputeType + + // Create an instance of the derived class + DeviceInstance device_instance; + + // Get a pointer to the base class + BaseClass* base_ptr = &device_instance; + + // Call GetInstanceString through the base class pointer + std::string instance_str = base_ptr->GetInstanceString(); + + // Expected complete instance string based on the first instance from + // device_grouped_conv_fwd_xdl_large_tensor_f16_instances + std::string expected_str = "DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor" + "<2" // NDimSpatial + ",GNHWC" // ALayout + ",GKYXC" // BLayout + ",EmptyTuple" // DsLayout + ",GNHWK" // ELayout + ",fp16" // ADataType + ",fp16" // BDataType + ",fp32" // AccDataType + ",fp16" // CShuffleDataType + ",EmptyTuple" // DsDataType + ",fp16" // EDataType + ",PassThrough" // AElementwiseOperation + ",PassThrough" // BElementwiseOperation + ",PassThrough" // CDEElementwiseOperation + ",Default" // ConvForwardSpecialization + ",MNKPadding" // GemmSpec + ",1" // NumGemmKPrefetchStage + ",64" // BlockSize + ",64" // MPerBlock + ",64" // NPerBlock + ",32" // KPerBlock + ",8" // AK1 + ",8" // BK1 + ",32" // MPerXDL + ",32" // NPerXDL + ",2" // MXdlPerWave + ",2" // NXdlPerWave + ",Seq(4,16,1)" // ABlockTransferThreadClusterLengths + ",Seq(1,0,2)" // ABlockTransferThreadClusterArrangeOrder + ",Seq(1,0,2)" // ABlockTransferSrcAccessOrder + ",2" // ABlockTransferSrcVectorDim + ",1" // ABlockTransferSrcScalarPerVector + ",8" // ABlockTransferDstScalarPerVector_AK1 + ",1" // ABlockLdsExtraM + ",Seq(4,16,1)" // BBlockTransferThreadClusterLengths + ",Seq(1,0,2)" // BBlockTransferThreadClusterArrangeOrder + ",Seq(1,0,2)" // BBlockTransferSrcAccessOrder + ",2" // BBlockTransferSrcVectorDim + ",1" // BBlockTransferSrcScalarPerVector + ",8" // BBlockTransferDstScalarPerVector_BK1 + ",1" // BBlockLdsExtraN + ",1" // CShuffleMXdlPerWavePerShuffle + ",1" // CShuffleNXdlPerWavePerShuffle + ",Seq(1,16,1,4)" // CDEBlockTransferClusterLengths + ",1" // CDEBlockTransferScalarPerVector_NPerBlock + ",fp16" // AComputeDataType + ",fp16" // BComputeDataType + ",Default>"; // LoopScheduler + EXPECT_EQ(instance_str, expected_str); +} diff --git a/experimental/builder/test/test_get_instance_string.cpp b/experimental/builder/test/test_get_instance_string_fwd_grp_conv_v3.cpp similarity index 98% rename from experimental/builder/test/test_get_instance_string.cpp rename to experimental/builder/test/test_get_instance_string_fwd_grp_conv_v3.cpp index 5ccd17a5f1..6870f6e5d0 100644 --- a/experimental/builder/test/test_get_instance_string.cpp +++ b/experimental/builder/test/test_get_instance_string_fwd_grp_conv_v3.cpp @@ -6,8 +6,8 @@ #include #include -// Test GetInstanceString through base class pointer -TEST(GetInstanceStringTest, GetInstanceStringThroughBaseClass) +// Test GetInstanceString through base class pointer for V3 variant +TEST(GetInstanceString, ReturnsStringForFwdGrpConvV3Instance) { // Use the template helper to get a working instance configuration using InstanceTuple = diff --git a/experimental/builder/test/test_instance_traits_util.cpp b/experimental/builder/test/test_instance_traits_util.cpp index 4aa5ebf25e..fe31e04e89 100644 --- a/experimental/builder/test/test_instance_traits_util.cpp +++ b/experimental/builder/test/test_instance_traits_util.cpp @@ -199,6 +199,14 @@ TEST(InstanceTraitsUtil, PipelineVersionNameReturnsCorrectStrings) ElementsAre("v1", "v2", "v3", "v4", "v5")); } +TEST(InstanceTraitsUtil, LoopSchedulerNameReturnsCorrectStrings) +{ + using enum ck::LoopScheduler; + EXPECT_THAT(std::vector names = {loop_scheduler_name(Default), + loop_scheduler_name(Interwave)}, + ElementsAre("Default", "Interwave")); +} + TEST(InstanceTraitsUtil, TupleNameReturnsEmptyTupleForEmptyTuple) { EXPECT_EQ(tuple_name>(), "EmptyTuple"); diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp index 42f51acce9..b8d35907fc 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp @@ -28,6 +28,9 @@ #include "ck/host_utility/device_prop.hpp" #include "ck/host_utility/kernel_launch.hpp" #include "ck/host_utility/io.hpp" +#ifdef CK_EXPERIMENTAL_BUILDER +#include "ck_tile/builder/reflect/instance_traits_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp" +#endif namespace ck { namespace tensor_operation { @@ -2063,6 +2066,19 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle return str.str(); } +#ifdef CK_EXPERIMENTAL_BUILDER + std::string GetInstanceString() const override + { + static_assert(ck_tile::reflect::HasInstanceTraits, + "Specialization of instance_traits not found. Please check that a " + "specialization exists in file " + "ck_tile/builder/reflect/" + "instance_traits_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp " + "for the given template parameters."); + return ck_tile::reflect::instance_string(); + } +#endif + size_t GetWorkSpaceSize(const BaseArgument* p_arg) const override { auto arg = dynamic_cast(p_arg); diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp index 020b3dc5a6..a4b1d96629 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp @@ -24,6 +24,9 @@ #include "ck/host_utility/device_prop.hpp" #include "ck/host_utility/kernel_launch.hpp" #include "ck/host_utility/io.hpp" +#ifdef CK_EXPERIMENTAL_BUILDER +#include "ck_tile/builder/reflect/instance_traits_device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp" +#endif namespace ck { namespace tensor_operation { @@ -1220,6 +1223,20 @@ struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor return str.str(); } + +#ifdef CK_EXPERIMENTAL_BUILDER + std::string GetInstanceString() const override + { + static_assert( + ck_tile::reflect::HasInstanceTraits, + "Specialization of instance_traits not found. Please check that a " + "specialization exists in file " + "ck_tile/builder/reflect/" + "instance_traits_device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp " + "for the given template parameters."); + return ck_tile::reflect::instance_string(); + } +#endif }; } // namespace device