mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-04 05:31:24 +00:00
* [CK_BUILDER] Add compile-time reflection for a convolution instance Introduce InstanceTraits template metaprogramming framework to enable runtime introspection of device kernel template parameters without requiring implementation knowledge. This reflection system extracts configuration details (block sizes, data types, layouts, tuning parameters) directly from kernel specializations through template pattern matching. In particular, the GetInstanceString method returns a string that uniquely identifies the kernel, by explicitly serializing all template parameter values. This provides critical functionality for MIOpen integration, since the existing GetTypeString method is ambiguous, and only captures some of the template parameters. The implementation uses a two-level design: a primary InstanceTraits template declaration in instance_traits.hpp serves as the interface, while kernel-specific specializations (e.g., for DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3) provide the actual extraction logic. This separation allows the reflection system to scale to additional kernel types without modifying the core interface. 
Key architectural decisions: - Forward-declare device kernels in instance_traits.hpp to avoid circular dependencies, since device implementation headers will include the reflection headers - Use compile-time constants and type aliases to expose kernel parameters, enabling zero-overhead introspection - Provide a templated instance_string() function that generates human-readable kernel configuration strings by serializing all template parameters in order, useful for debugging and kernel identification - Guard reflection integration with preprocessor definition CK_EXPERIMENTAL_BUILDER to keep it opt-in until the API stabilizes - Add GetInstanceString() virtual method to BaseOperator, allowing runtime polymorphic access to compile-time kernel information This infrastructure also enables upcoming higher-level semantic reflection abstractions (like ConvTraits) to query kernel configurations programmatically. Includes unit tests validating both the trait extraction accuracy and the string generation format.
105 lines
6.5 KiB
C++
105 lines
6.5 KiB
C++
// SPDX-License-Identifier: MIT
|
|
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
|
|
|
#include <gtest/gtest.h>
|
|
#include <ck_tile/builder/reflect/instance_traits.hpp>
|
|
#include <ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_abd.hpp>
|
|
#include <ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_comp_instance.hpp>
|
|
|
|
// Test GetInstanceString through base class pointer
|
|
TEST(GetInstanceStringTest, GetInstanceStringThroughBaseClass)
|
|
{
|
|
// Use the template helper to get a working instance configuration
|
|
using InstanceTuple =
|
|
ck::tensor_operation::device::instance::device_grouped_conv_fwd_xdl_f16_comp_instances<
|
|
2, // NDimSpatial
|
|
ck::tensor_operation::device::instance::GNHWC, // ALayout
|
|
ck::tensor_operation::device::instance::GKYXC, // BLayout
|
|
ck::tensor_operation::device::instance::Empty_Tuple, // DsLayout
|
|
ck::tensor_operation::device::instance::GNHWK, // ELayout
|
|
ck::tensor_operation::device::instance::ConvFwdDefault>; // ConvForwardSpecialization
|
|
|
|
// Get the first instance from the tuple
|
|
using DeviceInstance = typename std::tuple_element<0, InstanceTuple>::type;
|
|
|
|
// Define the base class type using DeviceGroupedConvFwdMultipleABD
|
|
using BaseClass = ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD<
|
|
2, // NDimSpatial
|
|
ck::tensor_operation::device::instance::GNHWC, // ALayout
|
|
ck::tensor_operation::device::instance::GKYXC, // BLayout
|
|
ck::tensor_operation::device::instance::Empty_Tuple, // DsLayout
|
|
ck::tensor_operation::device::instance::GNHWK, // ELayout
|
|
ck::half_t, // ADataType
|
|
ck::half_t, // BDataType
|
|
ck::Tuple<>, // DsDataType
|
|
ck::half_t, // EDataType
|
|
ck::tensor_operation::element_wise::PassThrough, // AElementwiseOperation
|
|
ck::tensor_operation::element_wise::PassThrough, // BElementwiseOperation
|
|
ck::tensor_operation::element_wise::PassThrough, // CDEElementwiseOperation
|
|
ck::half_t, // AComputeType
|
|
ck::half_t>; // BComputeType
|
|
|
|
// Create an instance of the derived class
|
|
DeviceInstance device_instance;
|
|
|
|
// Get a pointer to the base class
|
|
BaseClass* base_ptr = &device_instance;
|
|
|
|
// Call GetInstanceString through the base class pointer
|
|
std::string instance_str = base_ptr->GetInstanceString();
|
|
|
|
// Expected complete instance string based on the first instance from
|
|
// device_grouped_conv_fwd_xdl_f16_comp_instances This corresponds to the configuration with
|
|
// BlockSize=256, MPerBlock=128, NPerBlock=128, KPerBlock=64, etc.
|
|
std::string expected_str = "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3"
|
|
"<2" // NDimSpatial
|
|
",GNHWC" // ALayout
|
|
",GKYXC" // BLayout
|
|
",EmptyTuple" // DsLayout
|
|
",GNHWK" // ELayout
|
|
",fp16" // ADataType
|
|
",fp16" // BDataType
|
|
",fp32" // AccDataType
|
|
",fp16" // CShuffleDataType
|
|
",EmptyTuple" // DsDataType
|
|
",fp16" // EDataType
|
|
",PassThrough" // AElementwiseOperation
|
|
",PassThrough" // BElementwiseOperation
|
|
",PassThrough" // CDEElementwiseOperation
|
|
",Default" // ConvForwardSpecialization
|
|
",MNKPadding" // GemmSpec
|
|
",256" // BlockSize
|
|
",128" // MPerBlock
|
|
",128" // NPerBlock
|
|
",64" // KPerBlock
|
|
",8" // AK1
|
|
",8" // BK1
|
|
",32" // MPerXDL
|
|
",32" // NPerXDL
|
|
",2" // MXdlPerWave
|
|
",2" // NXdlPerWave
|
|
",Seq(8,32,1)" // ABlockTransferThreadClusterLengths
|
|
",Seq(1,0,2)" // ABlockTransferThreadClusterArrangeOrder
|
|
",Seq(1,0,2)" // ABlockTransferSrcAccessOrder
|
|
",2" // ABlockTransferSrcVectorDim
|
|
",8" // ABlockTransferSrcScalarPerVector
|
|
",8" // ABlockTransferDstScalarPerVector_AK1
|
|
",0" // ABlockLdsExtraM
|
|
",Seq(8,32,1)" // BBlockTransferThreadClusterLengths
|
|
",Seq(1,0,2)" // BBlockTransferThreadClusterArrangeOrder
|
|
",Seq(1,0,2)" // BBlockTransferSrcAccessOrder
|
|
",2" // BBlockTransferSrcVectorDim
|
|
",8" // BBlockTransferSrcScalarPerVector
|
|
",8" // BBlockTransferDstScalarPerVector_BK1
|
|
",0" // BBlockLdsExtraN
|
|
",1" // CShuffleMXdlPerWavePerShuffle
|
|
",1" // CShuffleNXdlPerWavePerShuffle
|
|
",Seq(1,32,1,8)" // CDEBlockTransferClusterLengths
|
|
",8" // CDEBlockTransferScalarPerVector_NPerBlock
|
|
",Intrawave" // BlkGemmPipeSched
|
|
",v4" // BlkGemmPipelineVer
|
|
",fp16" // AComputeDataType
|
|
",fp16>"; // BComputeDataType
|
|
EXPECT_EQ(instance_str, expected_str);
|
|
}
|