diff --git a/CMakeLists.txt b/CMakeLists.txt index 310e2a6576..f58dff8e15 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -40,6 +40,11 @@ option(MIOPEN_REQ_LIBS_ONLY "Build only the MIOpen required libraries" OFF) option(CK_EXPERIMENTAL_BUILDER "Enable experimental builder" OFF) option(BUILD_MHA_LIB "Build the static library for flash attention" OFF) +if(CK_EXPERIMENTAL_BUILDER) + add_definitions(-DCK_EXPERIMENTAL_BUILDER) + include_directories(${PROJECT_SOURCE_DIR}/experimental/builder/include) +endif() + # Usage: for customized Python location cmake -DCK_USE_ALTERNATIVE_PYTHON="/opt/Python-3.8.13/bin/python3.8" # CK Codegen requires dataclass which is added in Python 3.7 # Python version 3.8 is required for general good practice as it is default for Ubuntu 20.04 diff --git a/experimental/builder/include/ck_tile/builder/reflect/instance_traits.hpp b/experimental/builder/include/ck_tile/builder/reflect/instance_traits.hpp new file mode 100644 index 0000000000..a47ad0ef57 --- /dev/null +++ b/experimental/builder/include/ck_tile/builder/reflect/instance_traits.hpp @@ -0,0 +1,58 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +// Compile-time reflection for CK device kernel instances. +// +// - This is the Lowest-level reflection primitive for higher-level semantic abstractions (e.g., +// ConvTraits). +// - Extracts raw template parameters (block sizes, data types, layouts, tuning params) from kernel +// specializations. +// - Provides uniform interface to query kernel configuration without implementation knowledge +// - Other details about the device kernels can be manually added to template specializations. +// - Currently supports: +// - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 + +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include "instance_traits_util.hpp" + +namespace ck_tile::reflect { + +// Primary template for InstanceTraits - extracts compile-time information directly from +// device kernel instances (e.g., DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3) +// +// This is an unspecialized template declaration. Actual specializations for specific +// device kernels are provided in separate header files (e.g., +// instance_traits_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp). +template +struct InstanceTraits; + +// Concept-based helper to detect if InstanceTraits is specialized +// (i.e., has the instance_string() member function). +// This can be used for an informative static_assert in the device-op GetInstanceString in case the +// instance_string() template is broken. +template +concept HasInstanceTraits = requires { + { InstanceTraits::instance_string() } -> std::convertible_to; +}; + +// Free function that delegates to InstanceTraits static member function. +// Each InstanceTraits specialization provides its own instance_string() implementation. +template +inline std::string instance_string() +{ + return InstanceTraits::instance_string(); +} + +} // namespace ck_tile::reflect diff --git a/experimental/builder/include/ck_tile/builder/reflect/instance_traits_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp b/experimental/builder/include/ck_tile/builder/reflect/instance_traits_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp new file mode 100644 index 0000000000..21201b8d50 --- /dev/null +++ b/experimental/builder/include/ck_tile/builder/reflect/instance_traits_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp @@ -0,0 +1,345 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +// InstanceTraits specialization for DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 +// +// CRITICAL MAINTENANCE NOTE: +// This InstanceTraits file MUST be kept strictly in sync with the device implementation header: +// ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp +// "In sync" means that the template parameter order, names, and types in the declaration below +// MUST EXACTLY MATCH those in the device implementation. If these diverge, you may encounter +// compilation errors, subtle template instantiation mismatches, or silent runtime bugs that are +// difficult to diagnose. Always update both files together and review changes carefully. +// ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp + +#pragma once + +#include "instance_traits.hpp" + +// Forward declaration to avoid circular dependency. +// This file will be included by the device implementation header, so we cannot include +// the implementation header here. We only need the template signature to pattern-match +// on template parameters - we don't need any implementation details. +namespace ck::tensor_operation::device { + +template +struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3; + +} // namespace ck::tensor_operation::device + +namespace ck_tile::reflect { + +// Specialization for DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 +template +struct InstanceTraits> +{ + // Spatial dimension + static constexpr int kSpatialDim = NDimSpatial; + + // Layout types + using ALayout = ALayout_; + using BLayout = BLayout_; + using DsLayout = DsLayout_; + using ELayout = ELayout_; + + // Data types + using ADataType = ADataType_; + using BDataType = BDataType_; + using AccDataType = AccDataType_; + using CShuffleDataType = CShuffleDataType_; + using DsDataType = DsDataType_; + using EDataType = EDataType_; + + // Element-wise operations + using AElementwiseOperation = AElementwiseOperation_; + using BElementwiseOperation = BElementwiseOperation_; + using CDEElementwiseOperation = CDEElementwiseOperation_; + + // Specialization + static constexpr ck::tensor_operation::device::ConvolutionForwardSpecialization + kConvForwardSpecialization = ConvForwardSpecialization; + static constexpr ck::tensor_operation::device::GemmSpecialization kGemmSpecialization = + GemmSpec; + + // Block configuration + static constexpr int kBlockSize = BlockSize; + static constexpr int kMPerBlock = MPerBlock; + static constexpr int kNPerBlock = NPerBlock; + static constexpr int kKPerBlock = KPerBlock; + + // Tuning parameters + static constexpr int kAK1 = AK1; + static constexpr int kBK1 = BK1; + static constexpr int kMPerXDL = MPerXDL; + static constexpr int kNPerXDL = NPerXDL; + static constexpr int kMXdlPerWave = MXdlPerWave; + static constexpr int kNXdlPerWave = NXdlPerWave; + + // A block transfer thread cluster dimensions (converted to std::array) + static constexpr auto kAThreadClusterLengths = + detail::SequenceToArray::value; + static constexpr auto kAThreadClusterArrangeOrder = + detail::SequenceToArray::value; + static constexpr auto kABlockTransferSrcAccessOrder = + detail::SequenceToArray::value; + static constexpr int kABlockTransferSrcVectorDim = ABlockTransferSrcVectorDim; + static constexpr int kABlockTransferSrcScalarPerVector = ABlockTransferSrcScalarPerVector; + static constexpr int kABlockTransferDstScalarPerVectorK1 = ABlockTransferDstScalarPerVector_AK1; + static constexpr int kABlockLdsExtraM = ABlockLdsExtraM; + + // B block transfer thread cluster dimensions (converted to std::array) + static constexpr auto kBThreadClusterLengths = + detail::SequenceToArray::value; + static constexpr auto kBThreadClusterArrangeOrder = + detail::SequenceToArray::value; + static constexpr auto kBBlockTransferSrcAccessOrder = + detail::SequenceToArray::value; + static constexpr int kBBlockTransferSrcVectorDim = BBlockTransferSrcVectorDim; + static constexpr int kBBlockTransferSrcScalarPerVector = BBlockTransferSrcScalarPerVector; + static constexpr int kBBlockTransferDstScalarPerVectorK1 = BBlockTransferDstScalarPerVector_BK1; + static constexpr int kBBlockLdsExtraN = BBlockLdsExtraN; + + // C shuffle parameters (converted to std::array) + static constexpr int kCShuffleMXdlPerWavePerShuffle = CShuffleMXdlPerWavePerShuffle; + static constexpr int kCShuffleNXdlPerWavePerShuffle = CShuffleNXdlPerWavePerShuffle; + static constexpr auto kCThreadClusterLengths = detail::SequenceToArray< + CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock>::value; + static constexpr int kCBlockTransferScalarPerVector = CDEBlockTransferScalarPerVector_NPerBlock; + + // Pipeline configuration + static constexpr ck::BlockGemmPipelineScheduler kPipelineScheduler = BlkGemmPipeSched; + static constexpr ck::BlockGemmPipelineVersion kPipelineVersion = BlkGemmPipelineVer; + + // Compute data types + using AComputeDataType = AComputeDataType_; + using BComputeDataType = BComputeDataType_; + + // Static member function to generate instance string + static std::string instance_string() + { + std::ostringstream oss; + + // Kernel type name + oss << "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3"; + + // Template parameters in exact order matching InstanceTraits member order + oss << "<" << kSpatialDim; // 1. NDimSpatial + oss << "," << detail::layout_name(); // 2. ALayout + oss << "," << detail::layout_name(); // 3. BLayout + oss << "," << detail::tuple_name(); // 4. DsLayout + oss << "," << detail::layout_name(); // 5. ELayout + oss << "," << detail::type_name(); // 6. ADataType + oss << "," << detail::type_name(); // 7. BDataType + oss << "," << detail::type_name(); // 8. AccDataType + oss << "," << detail::type_name(); // 9. CShuffleDataType + oss << "," << detail::tuple_name(); // 10. DsDataType + oss << "," << detail::type_name(); // 11. EDataType + oss << "," + << detail::elementwise_op_name(); // 12. AElementwiseOperation + oss << "," + << detail::elementwise_op_name(); // 13. BElementwiseOperation + oss << "," + << detail::elementwise_op_name(); // 14. + // CDEElementwiseOperation + oss << "," + << detail::conv_fwd_spec_name( + kConvForwardSpecialization); // 15. ConvForwardSpecialization + oss << "," << detail::gemm_spec_name(kGemmSpecialization); // 16. GemmSpec + oss << "," << kBlockSize; // 17. BlockSize + oss << "," << kMPerBlock; // 18. MPerBlock + oss << "," << kNPerBlock; // 19. NPerBlock + oss << "," << kKPerBlock; // 20. KPerBlock + oss << "," << kAK1; // 21. AK1 + oss << "," << kBK1; // 22. BK1 + oss << "," << kMPerXDL; // 23. MPerXDL + oss << "," << kNPerXDL; // 24. NPerXDL + oss << "," << kMXdlPerWave; // 25. MXdlPerWave + oss << "," << kNXdlPerWave; // 26. NXdlPerWave + oss << "," + << detail::array_to_string( + kAThreadClusterLengths); // 27. ABlockTransferThreadClusterLengths + oss << "," + << detail::array_to_string( + kAThreadClusterArrangeOrder); // 28. ABlockTransferThreadClusterArrangeOrder + oss << "," + << detail::array_to_string( + kABlockTransferSrcAccessOrder); // 29. ABlockTransferSrcAccessOrder + oss << "," << kABlockTransferSrcVectorDim; // 30. ABlockTransferSrcVectorDim + oss << "," << kABlockTransferSrcScalarPerVector; // 31. ABlockTransferSrcScalarPerVector + oss << "," + << kABlockTransferDstScalarPerVectorK1; // 32. ABlockTransferDstScalarPerVector_AK1 + oss << "," << kABlockLdsExtraM; // 33. ABlockLdsExtraM + oss << "," + << detail::array_to_string( + kBThreadClusterLengths); // 34. BBlockTransferThreadClusterLengths + oss << "," + << detail::array_to_string( + kBThreadClusterArrangeOrder); // 35. BBlockTransferThreadClusterArrangeOrder + oss << "," + << detail::array_to_string( + kBBlockTransferSrcAccessOrder); // 36. BBlockTransferSrcAccessOrder + oss << "," << kBBlockTransferSrcVectorDim; // 37. BBlockTransferSrcVectorDim + oss << "," << kBBlockTransferSrcScalarPerVector; // 38. BBlockTransferSrcScalarPerVector + oss << "," + << kBBlockTransferDstScalarPerVectorK1; // 39. BBlockTransferDstScalarPerVector_BK1 + oss << "," << kBBlockLdsExtraN; // 40. BBlockLdsExtraN + oss << "," << kCShuffleMXdlPerWavePerShuffle; // 41. CShuffleMXdlPerWavePerShuffle + oss << "," << kCShuffleNXdlPerWavePerShuffle; // 42. CShuffleNXdlPerWavePerShuffle + oss << "," + << detail::array_to_string( + kCThreadClusterLengths); // 43. CDEBlockTransferClusterLengths + oss << "," + << kCBlockTransferScalarPerVector; // 44. CDEBlockTransferScalarPerVector_NPerBlock + oss << "," << detail::pipeline_scheduler_name(kPipelineScheduler); // 45. BlkGemmPipeSched + oss << "," << detail::pipeline_version_name(kPipelineVersion); // 46. BlkGemmPipelineVer + oss << "," << detail::type_name(); // 47. AComputeDataType + oss << "," << detail::type_name(); // 48. BComputeDataType + oss << ">"; + + return oss.str(); + } +}; + +} // namespace ck_tile::reflect diff --git a/experimental/builder/include/ck_tile/builder/reflect/instance_traits_util.hpp b/experimental/builder/include/ck_tile/builder/reflect/instance_traits_util.hpp new file mode 100644 index 0000000000..160a560529 --- /dev/null +++ b/experimental/builder/include/ck_tile/builder/reflect/instance_traits_util.hpp @@ -0,0 +1,195 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +// Utility functions and helpers for instance_traits.hpp +// Contains helper functions to convert types, enums, and sequences to string representations. +// The helper function are consteval so that unknown cases cause compile-time errors. + +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace ck_tile::reflect::detail { + +// Metaprogramming helper to convert ck::Sequence to constexpr std::array +template +struct SequenceToArray; + +template +struct SequenceToArray> +{ + static constexpr std::array value = {static_cast(Is)...}; +}; + +// Convert data types to string names +template +consteval std::string_view type_name() +{ + if constexpr(std::is_same_v) + return "fp16"; + else if constexpr(std::is_same_v) + return "fp32"; + else if constexpr(std::is_same_v) + return "fp64"; + else if constexpr(std::is_same_v) + return "s8"; + else if constexpr(std::is_same_v) + return "s32"; + else if constexpr(std::is_same_v) + return "bf16"; + else if constexpr(std::is_same_v) + return "fp8"; + else if constexpr(std::is_same_v) + return "bf8"; + else + static_assert(false, "unknown_type"); +} + +// Convert layout types to string names +template +constexpr std::string_view layout_name() +{ + // Convolution layouts + if constexpr(std::is_same_v) + return "GNHWC"; + else if constexpr(std::is_same_v) + return "GKYXC"; + else if constexpr(std::is_same_v) + return "GNHWK"; + else if constexpr(std::is_same_v) + return "GKZYXC"; + else if constexpr(std::is_same_v) + return "GNDHWC"; + else if constexpr(std::is_same_v) + return "GNDHWK"; + else if constexpr(std::is_same_v) + return "NHWGC"; + else if constexpr(std::is_same_v) + return "KYXGC"; + else if constexpr(std::is_same_v) + return "NHWGK"; + else + static_assert(false, "unknown_layout"); +} + +// Convert element-wise operation types to string names +template +constexpr std::string_view elementwise_op_name() +{ + if constexpr(std::is_same_v) + return "PassThrough"; + else if constexpr(std::is_same_v) + return "Scale"; + else if constexpr(std::is_same_v) + return "Bilinear"; + else if constexpr(std::is_same_v) + return "Add"; + else if constexpr(std::is_same_v) + return "AddRelu"; + else if constexpr(std::is_same_v) + return "Relu"; + else + static_assert(false, "unknown_op"); +} + +// Convert ConvolutionForwardSpecialization enum to string +constexpr std::string_view +conv_fwd_spec_name(ck::tensor_operation::device::ConvolutionForwardSpecialization spec) +{ + using ck::tensor_operation::device::ConvolutionForwardSpecialization; + switch(spec) + { + case ConvolutionForwardSpecialization::Default: return "Default"; + case ConvolutionForwardSpecialization::Filter1x1Stride1Pad0: return "Filter1x1Stride1Pad0"; + case ConvolutionForwardSpecialization::Filter1x1Pad0: return "Filter1x1Pad0"; + case ConvolutionForwardSpecialization::Filter3x3: return "Filter3x3"; + case ConvolutionForwardSpecialization::OddC: return "OddC"; + } +} + +// Convert GemmSpecialization enum to string +constexpr std::string_view gemm_spec_name(ck::tensor_operation::device::GemmSpecialization spec) +{ + using ck::tensor_operation::device::GemmSpecialization; + switch(spec) + { + case GemmSpecialization::Default: return "Default"; + case GemmSpecialization::MPadding: return "MPadding"; + case GemmSpecialization::NPadding: return "NPadding"; + case GemmSpecialization::KPadding: return "KPadding"; + case GemmSpecialization::MNPadding: return "MNPadding"; + case GemmSpecialization::MKPadding: return "MKPadding"; + case GemmSpecialization::NKPadding: return "NKPadding"; + case GemmSpecialization::MNKPadding: return "MNKPadding"; + case GemmSpecialization::OPadding: return "OPadding"; + case GemmSpecialization::MOPadding: return "MOPadding"; + case GemmSpecialization::NOPadding: return "NOPadding"; + case GemmSpecialization::KOPadding: return "KOPadding"; + case GemmSpecialization::MNOPadding: return "MNOPadding"; + case GemmSpecialization::MKOPadding: return "MKOPadding"; + case GemmSpecialization::NKOPadding: return "NKOPadding"; + case GemmSpecialization::MNKOPadding: return "MNKOPadding"; + } +} + +// Convert BlockGemmPipelineScheduler enum to string +constexpr std::string_view pipeline_scheduler_name(ck::BlockGemmPipelineScheduler sched) +{ + using ck::BlockGemmPipelineScheduler; + switch(sched) + { + case BlockGemmPipelineScheduler::Intrawave: return "Intrawave"; + case BlockGemmPipelineScheduler::Interwave: return "Interwave"; + } +} + +// Convert BlockGemmPipelineVersion enum to string +constexpr std::string_view pipeline_version_name(ck::BlockGemmPipelineVersion ver) +{ + using ck::BlockGemmPipelineVersion; + switch(ver) + { + case BlockGemmPipelineVersion::v1: return "v1"; + case BlockGemmPipelineVersion::v2: return "v2"; + case BlockGemmPipelineVersion::v3: return "v3"; + case BlockGemmPipelineVersion::v4: return "v4"; + case BlockGemmPipelineVersion::v5: return "v5"; + } +} + +// Convert std::array to string +template +inline std::string array_to_string(const std::array& arr) +{ + std::ostringstream oss; + oss << "Seq("; + for(std::size_t i = 0; i < arr.size(); ++i) + { + if(i > 0) + oss << ","; + oss << arr[i]; + } + oss << ")"; + return oss.str(); +} + +// Handle ck::Tuple (empty tuple for DsLayout/DsDataType) +template +constexpr std::string_view tuple_name() +{ + // For now, just check if it's an empty tuple + return "EmptyTuple"; +} + +} // namespace ck_tile::reflect::detail diff --git a/experimental/builder/test/CMakeLists.txt b/experimental/builder/test/CMakeLists.txt index 5890aa8dcd..04b63b7823 100644 --- a/experimental/builder/test/CMakeLists.txt +++ b/experimental/builder/test/CMakeLists.txt @@ -1,4 +1,3 @@ - include(gtest) # Helper function to create a gtest executable with common properties @@ -17,4 +16,8 @@ function(add_ck_builder_test test_name) endfunction() add_ck_builder_test(test_conv_builder - test_conv_builder.cpp) + test_conv_builder.cpp + test_instance_traits.cpp) + +add_ck_builder_test(test_get_instance_string + test_get_instance_string.cpp) diff --git a/experimental/builder/test/test_get_instance_string.cpp b/experimental/builder/test/test_get_instance_string.cpp new file mode 100644 index 0000000000..5ccd17a5f1 --- /dev/null +++ b/experimental/builder/test/test_get_instance_string.cpp @@ -0,0 +1,104 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include + +// Test GetInstanceString through base class pointer +TEST(GetInstanceStringTest, GetInstanceStringThroughBaseClass) +{ + // Use the template helper to get a working instance configuration + using InstanceTuple = + ck::tensor_operation::device::instance::device_grouped_conv_fwd_xdl_f16_comp_instances< + 2, // NDimSpatial + ck::tensor_operation::device::instance::GNHWC, // ALayout + ck::tensor_operation::device::instance::GKYXC, // BLayout + ck::tensor_operation::device::instance::Empty_Tuple, // DsLayout + ck::tensor_operation::device::instance::GNHWK, // ELayout + ck::tensor_operation::device::instance::ConvFwdDefault>; // ConvForwardSpecialization + + // Get the first instance from the tuple + using DeviceInstance = typename std::tuple_element<0, InstanceTuple>::type; + + // Define the base class type using DeviceGroupedConvFwdMultipleABD + using BaseClass = ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD< + 2, // NDimSpatial + ck::tensor_operation::device::instance::GNHWC, // ALayout + ck::tensor_operation::device::instance::GKYXC, // BLayout + ck::tensor_operation::device::instance::Empty_Tuple, // DsLayout + ck::tensor_operation::device::instance::GNHWK, // ELayout + ck::half_t, // ADataType + ck::half_t, // BDataType + ck::Tuple<>, // DsDataType + ck::half_t, // EDataType + ck::tensor_operation::element_wise::PassThrough, // AElementwiseOperation + ck::tensor_operation::element_wise::PassThrough, // BElementwiseOperation + ck::tensor_operation::element_wise::PassThrough, // CDEElementwiseOperation + ck::half_t, // AComputeType + ck::half_t>; // BComputeType + + // Create an instance of the derived class + DeviceInstance device_instance; + + // Get a pointer to the base class + BaseClass* base_ptr = &device_instance; + + // Call GetInstanceString through the base class pointer + std::string instance_str = base_ptr->GetInstanceString(); + + // Expected complete instance string based on the first instance from + // device_grouped_conv_fwd_xdl_f16_comp_instances This corresponds to the configuration with + // BlockSize=256, MPerBlock=128, NPerBlock=128, KPerBlock=64, etc. + std::string expected_str = "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3" + "<2" // NDimSpatial + ",GNHWC" // ALayout + ",GKYXC" // BLayout + ",EmptyTuple" // DsLayout + ",GNHWK" // ELayout + ",fp16" // ADataType + ",fp16" // BDataType + ",fp32" // AccDataType + ",fp16" // CShuffleDataType + ",EmptyTuple" // DsDataType + ",fp16" // EDataType + ",PassThrough" // AElementwiseOperation + ",PassThrough" // BElementwiseOperation + ",PassThrough" // CDEElementwiseOperation + ",Default" // ConvForwardSpecialization + ",MNKPadding" // GemmSpec + ",256" // BlockSize + ",128" // MPerBlock + ",128" // NPerBlock + ",64" // KPerBlock + ",8" // AK1 + ",8" // BK1 + ",32" // MPerXDL + ",32" // NPerXDL + ",2" // MXdlPerWave + ",2" // NXdlPerWave + ",Seq(8,32,1)" // ABlockTransferThreadClusterLengths + ",Seq(1,0,2)" // ABlockTransferThreadClusterArrangeOrder + ",Seq(1,0,2)" // ABlockTransferSrcAccessOrder + ",2" // ABlockTransferSrcVectorDim + ",8" // ABlockTransferSrcScalarPerVector + ",8" // ABlockTransferDstScalarPerVector_AK1 + ",0" // ABlockLdsExtraM + ",Seq(8,32,1)" // BBlockTransferThreadClusterLengths + ",Seq(1,0,2)" // BBlockTransferThreadClusterArrangeOrder + ",Seq(1,0,2)" // BBlockTransferSrcAccessOrder + ",2" // BBlockTransferSrcVectorDim + ",8" // BBlockTransferSrcScalarPerVector + ",8" // BBlockTransferDstScalarPerVector_BK1 + ",0" // BBlockLdsExtraN + ",1" // CShuffleMXdlPerWavePerShuffle + ",1" // CShuffleNXdlPerWavePerShuffle + ",Seq(1,32,1,8)" // CDEBlockTransferClusterLengths + ",8" // CDEBlockTransferScalarPerVector_NPerBlock + ",Intrawave" // BlkGemmPipeSched + ",v4" // BlkGemmPipelineVer + ",fp16" // AComputeDataType + ",fp16>"; // BComputeDataType + EXPECT_EQ(instance_str, expected_str); +} diff --git a/experimental/builder/test/test_instance_traits.cpp b/experimental/builder/test/test_instance_traits.cpp new file mode 100644 index 0000000000..f6a8fd28c2 --- /dev/null +++ b/experimental/builder/test/test_instance_traits.cpp @@ -0,0 +1,276 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include + +namespace { + +using ::testing::ElementsAre; +// Test fixture for InstanceTraits tests +class InstanceTraitsTest : public ::testing::Test +{ +}; + +// Test InstanceTraits with DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 +TEST_F(InstanceTraitsTest, ConvFwdInstanceTraitsExtraction) +{ + // Define a concrete instance type with specific template parameters + using DeviceInstance = + ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3< + 2, // NDimSpatial + ck::tensor_layout::convolution::GNHWC, // ALayout + ck::tensor_layout::convolution::GKYXC, // BLayout + ck::Tuple<>, // DsLayout + ck::tensor_layout::convolution::GNHWK, // ELayout + ck::half_t, // ADataType + ck::half_t, // BDataType + float, // AccDataType + ck::half_t, // CShuffleDataType + ck::Tuple<>, // DsDataType + ck::half_t, // EDataType + ck::tensor_operation::element_wise::PassThrough, // AElementwiseOperation + ck::tensor_operation::element_wise::PassThrough, // BElementwiseOperation + ck::tensor_operation::element_wise::PassThrough, // CDEElementwiseOperation + ck::tensor_operation::device::ConvolutionForwardSpecialization:: + Default, // ConvForwardSpecialization + ck::tensor_operation::device::GemmSpecialization::Default, // GemmSpec + 256, // BlockSize + 128, // MPerBlock + 128, // NPerBlock + 16, // KPerBlock + 8, // AK1 + 8, // BK1 + 32, // MPerXDL + 32, // NPerXDL + 4, // MXdlPerWave + 4, // NXdlPerWave + ck::Sequence<4, 64, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1 + ck::Sequence<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder + ck::Sequence<1, 0, 2>, // ABlockTransferSrcAccessOrder + 2, // ABlockTransferSrcVectorDim + 8, // ABlockTransferSrcScalarPerVector + 8, // ABlockTransferDstScalarPerVector_AK1 + 1, // ABlockLdsExtraM + ck::Sequence<4, 64, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1 + ck::Sequence<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder + ck::Sequence<1, 0, 2>, // BBlockTransferSrcAccessOrder + 2, // BBlockTransferSrcVectorDim + 8, // BBlockTransferSrcScalarPerVector + 8, // BBlockTransferDstScalarPerVector_BK1 + 1, // BBlockLdsExtraN + 1, // CShuffleMXdlPerWavePerShuffle + 1, // CShuffleNXdlPerWavePerShuffle + ck::Sequence<1, + 32, + 1, + 8>, // CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock + 8, // CDEBlockTransferScalarPerVector_NPerBlock + ck::BlockGemmPipelineScheduler::Intrawave, // BlkGemmPipeSched + ck::BlockGemmPipelineVersion::v1, // BlkGemmPipelineVer + ck::half_t, // AComputeDataType + ck::half_t>; // BComputeDataType + + // Use InstanceTraits to extract compile-time information + using Traits = ck_tile::reflect::InstanceTraits; + + // Verify spatial dimension + EXPECT_EQ(Traits::kSpatialDim, 2); + + // Verify block configuration + EXPECT_EQ(Traits::kBlockSize, 256); + EXPECT_EQ(Traits::kMPerBlock, 128); + EXPECT_EQ(Traits::kNPerBlock, 128); + EXPECT_EQ(Traits::kKPerBlock, 16); + + // Verify tuning parameters + EXPECT_EQ(Traits::kAK1, 8); + EXPECT_EQ(Traits::kBK1, 8); + EXPECT_EQ(Traits::kMPerXDL, 32); + EXPECT_EQ(Traits::kNPerXDL, 32); + EXPECT_EQ(Traits::kMXdlPerWave, 4); + EXPECT_EQ(Traits::kNXdlPerWave, 4); + + // Verify A block transfer parameters + EXPECT_EQ(Traits::kABlockTransferSrcVectorDim, 2); + EXPECT_EQ(Traits::kABlockTransferSrcScalarPerVector, 8); + EXPECT_EQ(Traits::kABlockTransferDstScalarPerVectorK1, 8); + EXPECT_EQ(Traits::kABlockLdsExtraM, 1); + + // Verify B block transfer parameters + EXPECT_EQ(Traits::kBBlockTransferSrcVectorDim, 2); + EXPECT_EQ(Traits::kBBlockTransferSrcScalarPerVector, 8); + EXPECT_EQ(Traits::kBBlockTransferDstScalarPerVectorK1, 8); + EXPECT_EQ(Traits::kBBlockLdsExtraN, 1); + + // Verify C shuffle parameters + EXPECT_EQ(Traits::kCShuffleMXdlPerWavePerShuffle, 1); + EXPECT_EQ(Traits::kCShuffleNXdlPerWavePerShuffle, 1); + EXPECT_EQ(Traits::kCBlockTransferScalarPerVector, 8); + + // Verify pipeline configuration + EXPECT_EQ(Traits::kPipelineScheduler, ck::BlockGemmPipelineScheduler::Intrawave); + EXPECT_EQ(Traits::kPipelineVersion, ck::BlockGemmPipelineVersion::v1); + + // Verify data types using std::is_same + EXPECT_TRUE((std::is_same::value)); + EXPECT_TRUE((std::is_same::value)); + EXPECT_TRUE((std::is_same::value)); + EXPECT_TRUE((std::is_same::value)); + + // Verify layout types + EXPECT_TRUE((std::is_same::value)); + EXPECT_TRUE((std::is_same::value)); + EXPECT_TRUE((std::is_same::value)); + + // Verify all array values for thread cluster lengths using googlemock matchers + EXPECT_THAT(Traits::kAThreadClusterLengths, ElementsAre(4, 64, 1)); + EXPECT_THAT(Traits::kBThreadClusterLengths, ElementsAre(4, 64, 1)); + EXPECT_THAT(Traits::kCThreadClusterLengths, ElementsAre(1, 32, 1, 8)); + + // Verify A block transfer arrange order and access order arrays + EXPECT_THAT(Traits::kAThreadClusterArrangeOrder, ElementsAre(1, 0, 2)); + EXPECT_THAT(Traits::kABlockTransferSrcAccessOrder, ElementsAre(1, 0, 2)); + + // Verify B block transfer arrange order and access order arrays + EXPECT_THAT(Traits::kBThreadClusterArrangeOrder, ElementsAre(1, 0, 2)); + EXPECT_THAT(Traits::kBBlockTransferSrcAccessOrder, ElementsAre(1, 0, 2)); + + // Verify additional data types + EXPECT_TRUE((std::is_same::value)); + EXPECT_TRUE((std::is_same>::value)); + EXPECT_TRUE((std::is_same::value)); + EXPECT_TRUE((std::is_same::value)); + + // Verify additional layout types + EXPECT_TRUE((std::is_same>::value)); + + // Verify element-wise operations + EXPECT_TRUE((std::is_same::value)); + EXPECT_TRUE((std::is_same::value)); + EXPECT_TRUE((std::is_same::value)); +} + +// Test instance_string function +TEST_F(InstanceTraitsTest, InstanceStringGeneration) +{ + // Define a concrete instance type with specific template parameters + using DeviceInstance = + ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3< + 2, // NDimSpatial + ck::tensor_layout::convolution::GNHWC, // ALayout + ck::tensor_layout::convolution::GKYXC, // BLayout + ck::Tuple<>, // DsLayout + ck::tensor_layout::convolution::GNHWK, // ELayout + ck::half_t, // ADataType + ck::half_t, // BDataType + float, // AccDataType + ck::half_t, // CShuffleDataType + ck::Tuple<>, // DsDataType + ck::half_t, // EDataType + ck::tensor_operation::element_wise::PassThrough, // AElementwiseOperation + ck::tensor_operation::element_wise::PassThrough, // BElementwiseOperation + ck::tensor_operation::element_wise::PassThrough, // CDEElementwiseOperation + ck::tensor_operation::device::ConvolutionForwardSpecialization:: + Default, // ConvForwardSpecialization + ck::tensor_operation::device::GemmSpecialization::Default, // GemmSpec + 256, // BlockSize + 128, // MPerBlock + 128, // NPerBlock + 16, // KPerBlock + 8, // AK1 + 8, // BK1 + 32, // MPerXDL + 32, // NPerXDL + 4, // MXdlPerWave + 4, // NXdlPerWave + ck::Sequence<4, 64, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1 + ck::Sequence<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder + ck::Sequence<1, 0, 2>, // ABlockTransferSrcAccessOrder + 2, // ABlockTransferSrcVectorDim + 8, // ABlockTransferSrcScalarPerVector + 8, // ABlockTransferDstScalarPerVector_AK1 + 1, // ABlockLdsExtraM + ck::Sequence<4, 64, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1 + ck::Sequence<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder + ck::Sequence<1, 0, 2>, // BBlockTransferSrcAccessOrder + 2, // BBlockTransferSrcVectorDim + 8, // BBlockTransferSrcScalarPerVector + 8, // BBlockTransferDstScalarPerVector_BK1 + 1, // BBlockLdsExtraN + 1, // CShuffleMXdlPerWavePerShuffle + 1, // CShuffleNXdlPerWavePerShuffle + ck::Sequence<1, + 32, + 1, + 8>, // CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock + 8, // CDEBlockTransferScalarPerVector_NPerBlock + ck::BlockGemmPipelineScheduler::Intrawave, // BlkGemmPipeSched + ck::BlockGemmPipelineVersion::v1, // BlkGemmPipelineVer + ck::half_t, // AComputeDataType + ck::half_t>; // BComputeDataType + + // Generate instance string + std::string instance_str = ck_tile::reflect::instance_string(); + + // Expected string with all template parameters in exact order + std::string expected_str = "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3" + "<2" // NDimSpatial + ",GNHWC" // ALayout + ",GKYXC" // BLayout + ",EmptyTuple" // DsLayout + ",GNHWK" // ELayout + ",fp16" // ADataType + ",fp16" // BDataType + ",fp32" // AccDataType + ",fp16" // CShuffleDataType + ",EmptyTuple" // DsDataType + ",fp16" // EDataType + ",PassThrough" // AElementwiseOperation + ",PassThrough" // BElementwiseOperation + ",PassThrough" // CDEElementwiseOperation + ",Default" // ConvForwardSpecialization + ",Default" // GemmSpec + ",256" // BlockSize + ",128" // MPerBlock + ",128" // NPerBlock + ",16" // KPerBlock + ",8" // AK1 + ",8" // BK1 + ",32" // MPerXDL + ",32" // NPerXDL + ",4" // MXdlPerWave + ",4" // NXdlPerWave + ",Seq(4,64,1)" // ABlockTransferThreadClusterLengths + ",Seq(1,0,2)" // ABlockTransferThreadClusterArrangeOrder + ",Seq(1,0,2)" // ABlockTransferSrcAccessOrder + ",2" // ABlockTransferSrcVectorDim + ",8" // ABlockTransferSrcScalarPerVector + ",8" // ABlockTransferDstScalarPerVector_AK1 + ",1" // ABlockLdsExtraM + ",Seq(4,64,1)" // BBlockTransferThreadClusterLengths + ",Seq(1,0,2)" // BBlockTransferThreadClusterArrangeOrder + ",Seq(1,0,2)" // BBlockTransferSrcAccessOrder + ",2" // BBlockTransferSrcVectorDim + ",8" // BBlockTransferSrcScalarPerVector + ",8" // BBlockTransferDstScalarPerVector_BK1 + ",1" // BBlockLdsExtraN + ",1" // CShuffleMXdlPerWavePerShuffle + ",1" // CShuffleNXdlPerWavePerShuffle + ",Seq(1,32,1,8)" // CDEBlockTransferClusterLengths + ",8" // CDEBlockTransferScalarPerVector_NPerBlock + ",Intrawave" // BlkGemmPipeSched + ",v1" // BlkGemmPipelineVer + ",fp16" // AComputeDataType + ",fp16>"; // BComputeDataType + + // Verify the generated string matches exactly + EXPECT_EQ(instance_str, expected_str); +} + +} // anonymous namespace diff --git a/include/ck/tensor_operation/gpu/device/device_base.hpp b/include/ck/tensor_operation/gpu/device/device_base.hpp index e7ce7cbcf5..2ce0452544 100644 --- a/include/ck/tensor_operation/gpu/device/device_base.hpp +++ b/include/ck/tensor_operation/gpu/device/device_base.hpp @@ -227,6 +227,7 @@ struct BaseOperator #if !defined(__HIPCC_RTC__) || !defined(CK_CODE_GEN_RTC) virtual bool IsSupportedArgument(const BaseArgument*) { return false; } virtual std::string GetTypeString() const { return ""; } + virtual std::string GetInstanceString() const { return ""; } virtual std::string GetTypeIdName() const { return typeid(*this).name(); } diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp index dbc60e3fdc..ebcefa226b 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp @@ -28,6 +28,9 @@ #include "ck/host_utility/kernel_launch.hpp" #include "ck/host_utility/flush_cache.hpp" #include "ck/host_utility/io.hpp" +#ifdef CK_EXPERIMENTAL_BUILDER +#include "ck_tile/builder/reflect/instance_traits_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp" +#endif namespace ck { namespace tensor_operation { @@ -1994,6 +1997,19 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 return str.str(); } +#ifdef CK_EXPERIMENTAL_BUILDER + std::string GetInstanceString() const override + { + static_assert(ck_tile::reflect::HasInstanceTraits, + "Specialization of instance_traits not found. Please check that a " + "specialization exists in file " + "ck_tile/builder/reflect/" + "instance_traits_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp " + "for the given template parameters."); + return ck_tile::reflect::instance_string(); + } +#endif + size_t GetWorkSpaceSize(const BaseArgument* p_arg) const override { auto arg = dynamic_cast(p_arg);