[CK_BUILDER] Add compile-time reflection for a convolution instance (#3065)

* [CK_BILDER] Add compile-time reflection for a convolution instance

Introduce InstanceTraits template metaprogramming framework to enable runtime introspection of device kernel template parameters without requiring implementation knowledge. This reflection system extracts configuration details (block sizes, data types, layouts, tuning parameters) directly from kernel specializations through template
pattern matching. In particular, the GetInstanceString method returns a string that uniquely idenitfies the kernel, by explicitly serializing all template paramter values.

This provides critical functionality for MIOpen integration, since the existing GetTypeString method is ambiguous, and only captures some of the template paramters.

The implementation uses a two-level design: a primary InstanceTraits template declaration in instance_traits.hpp serves as the interface, while kernel-specific specializations (e.g., for DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3) provide the actual extraction logic. This separation allows the reflection system to scale to additional kernel types without modifying the core interface.

Key architectural decisions:

- Forward-declare device kernels in instance_traits.hpp to avoid  circular dependencies, since device implementation headers will  include the reflection headers

- Use compile-time constants and type aliases to expose kernel  parameters, enabling zero-overhead introspection

- Provide a templated instance_string() function that generates human-readable  kernel configuration strings by serializing all template parameters  in order, useful for debugging and kernel identification

- Guard reflection integration with preprocessor definition CK_EXPERIMENTAL_BUILDER to keep  it opt-in until the API stabilizes

- Add GetInstanceString() virtual method to BaseOperator, allowing  runtime polymorphic access to compile-time kernel information

This infrastructure also enables upcoming higher-level semantic reflection abstractions (like ConvTraits) to query kernel configurations programmatically.

Includes unit tests validating both the trait extraction accuracy and the string generation format.
This commit is contained in:
John Shumway
2025-10-21 21:10:19 -07:00
committed by GitHub
parent 3a28632b20
commit 37dff024c1
9 changed files with 1005 additions and 2 deletions

View File

@@ -0,0 +1,58 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
// Compile-time reflection for CK device kernel instances.
//
// - This is the Lowest-level reflection primitive for higher-level semantic abstractions (e.g.,
// ConvTraits).
// - Extracts raw template parameters (block sizes, data types, layouts, tuning params) from kernel
// specializations.
// - Provides uniform interface to query kernel configuration without implementation knowledge
// - Other details about the device kernels can be manually added to template specializations.
// - Currently supports:
// - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
#pragma once
#include <array>
#include <string>
#include <sstream>
#include <type_traits>
#include <ck/utility/data_type.hpp>
#include <ck/utility/sequence.hpp>
#include <ck/utility/blkgemmpipe_scheduler.hpp>
#include <ck/tensor_operation/gpu/device/tensor_layout.hpp>
#include <ck/tensor_operation/gpu/element/element_wise_operation.hpp>
#include <ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp>
#include <ck/tensor_operation/gpu/device/gemm_specialization.hpp>
#include "instance_traits_util.hpp"
namespace ck_tile::reflect {
// Primary template for InstanceTraits - extracts compile-time information directly from
// device kernel instances (e.g., DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3)
//
// This is an unspecialized template declaration. Actual specializations for specific
// device kernels are provided in separate header files (e.g.,
// instance_traits_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp).
template <typename Instance>
struct InstanceTraits;
// Concept-based helper to detect if InstanceTraits<T> is specialized
// (i.e., has the instance_string() member function).
// This can be used for an informative static_assert in the device-op GetInstanceString in case the
// instance_string() template is broken.
template <typename T>
concept HasInstanceTraits = requires {
{ InstanceTraits<T>::instance_string() } -> std::convertible_to<std::string>;
};
// Free function that delegates to InstanceTraits static member function.
// Each InstanceTraits specialization provides its own instance_string() implementation.
template <typename T>
inline std::string instance_string()
{
return InstanceTraits<T>::instance_string();
}
} // namespace ck_tile::reflect

View File

@@ -0,0 +1,345 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
// InstanceTraits specialization for DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
//
// CRITICAL MAINTENANCE NOTE:
// This InstanceTraits file MUST be kept strictly in sync with the device implementation header:
// ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp
// "In sync" means that the template parameter order, names, and types in the declaration below
// MUST EXACTLY MATCH those in the device implementation. If these diverge, you may encounter
// compilation errors, subtle template instantiation mismatches, or silent runtime bugs that are
// difficult to diagnose. Always update both files together and review changes carefully.
// ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp
#pragma once
#include "instance_traits.hpp"
// Forward declaration to avoid circular dependency.
// This file will be included by the device implementation header, so we cannot include
// the implementation header here. We only need the template signature to pattern-match
// on template parameters - we don't need any implementation details.
namespace ck::tensor_operation::device {
template <ck::index_t NDimSpatial,
typename ALayout,
typename BLayout,
typename DsLayout,
typename ELayout,
typename ADataType,
typename BDataType,
typename AccDataType,
typename CShuffleDataType,
typename DsDataType,
typename EDataType,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CDEElementwiseOperation,
ConvolutionForwardSpecialization ConvForwardSpecialization,
GemmSpecialization GemmSpec,
ck::index_t BlockSize,
ck::index_t MPerBlock,
ck::index_t NPerBlock,
ck::index_t KPerBlock,
ck::index_t AK1,
ck::index_t BK1,
ck::index_t MPerXDL,
ck::index_t NPerXDL,
ck::index_t MXdlPerWave,
ck::index_t NXdlPerWave,
typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
typename ABlockTransferThreadClusterArrangeOrder,
typename ABlockTransferSrcAccessOrder,
ck::index_t ABlockTransferSrcVectorDim,
ck::index_t ABlockTransferSrcScalarPerVector,
ck::index_t ABlockTransferDstScalarPerVector_AK1,
ck::index_t ABlockLdsExtraM,
typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
typename BBlockTransferThreadClusterArrangeOrder,
typename BBlockTransferSrcAccessOrder,
ck::index_t BBlockTransferSrcVectorDim,
ck::index_t BBlockTransferSrcScalarPerVector,
ck::index_t BBlockTransferDstScalarPerVector_BK1,
ck::index_t BBlockLdsExtraN,
ck::index_t CShuffleMXdlPerWavePerShuffle,
ck::index_t CShuffleNXdlPerWavePerShuffle,
typename CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
ck::index_t CDEBlockTransferScalarPerVector_NPerBlock,
ck::BlockGemmPipelineScheduler BlkGemmPipeSched,
ck::BlockGemmPipelineVersion BlkGemmPipelineVer,
typename AComputeDataType,
typename BComputeDataType>
struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3;
} // namespace ck::tensor_operation::device
namespace ck_tile::reflect {
// Specialization for DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
template <ck::index_t NDimSpatial,
typename ALayout_,
typename BLayout_,
typename DsLayout_,
typename ELayout_,
typename ADataType_,
typename BDataType_,
typename AccDataType_,
typename CShuffleDataType_,
typename DsDataType_,
typename EDataType_,
typename AElementwiseOperation_,
typename BElementwiseOperation_,
typename CDEElementwiseOperation_,
ck::tensor_operation::device::ConvolutionForwardSpecialization ConvForwardSpecialization,
ck::tensor_operation::device::GemmSpecialization GemmSpec,
ck::index_t BlockSize,
ck::index_t MPerBlock,
ck::index_t NPerBlock,
ck::index_t KPerBlock,
ck::index_t AK1,
ck::index_t BK1,
ck::index_t MPerXDL,
ck::index_t NPerXDL,
ck::index_t MXdlPerWave,
ck::index_t NXdlPerWave,
typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
typename ABlockTransferThreadClusterArrangeOrder,
typename ABlockTransferSrcAccessOrder_,
ck::index_t ABlockTransferSrcVectorDim,
ck::index_t ABlockTransferSrcScalarPerVector,
ck::index_t ABlockTransferDstScalarPerVector_AK1,
ck::index_t ABlockLdsExtraM,
typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
typename BBlockTransferThreadClusterArrangeOrder,
typename BBlockTransferSrcAccessOrder_,
ck::index_t BBlockTransferSrcVectorDim,
ck::index_t BBlockTransferSrcScalarPerVector,
ck::index_t BBlockTransferDstScalarPerVector_BK1,
ck::index_t BBlockLdsExtraN,
ck::index_t CShuffleMXdlPerWavePerShuffle,
ck::index_t CShuffleNXdlPerWavePerShuffle,
typename CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
ck::index_t CDEBlockTransferScalarPerVector_NPerBlock,
ck::BlockGemmPipelineScheduler BlkGemmPipeSched,
ck::BlockGemmPipelineVersion BlkGemmPipelineVer,
typename AComputeDataType_,
typename BComputeDataType_>
struct InstanceTraits<ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<
NDimSpatial,
ALayout_,
BLayout_,
DsLayout_,
ELayout_,
ADataType_,
BDataType_,
AccDataType_,
CShuffleDataType_,
DsDataType_,
EDataType_,
AElementwiseOperation_,
BElementwiseOperation_,
CDEElementwiseOperation_,
ConvForwardSpecialization,
GemmSpec,
BlockSize,
MPerBlock,
NPerBlock,
KPerBlock,
AK1,
BK1,
MPerXDL,
NPerXDL,
MXdlPerWave,
NXdlPerWave,
ABlockTransferThreadClusterLengths_AK0_M_AK1,
ABlockTransferThreadClusterArrangeOrder,
ABlockTransferSrcAccessOrder_,
ABlockTransferSrcVectorDim,
ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_AK1,
ABlockLdsExtraM,
BBlockTransferThreadClusterLengths_BK0_N_BK1,
BBlockTransferThreadClusterArrangeOrder,
BBlockTransferSrcAccessOrder_,
BBlockTransferSrcVectorDim,
BBlockTransferSrcScalarPerVector,
BBlockTransferDstScalarPerVector_BK1,
BBlockLdsExtraN,
CShuffleMXdlPerWavePerShuffle,
CShuffleNXdlPerWavePerShuffle,
CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
CDEBlockTransferScalarPerVector_NPerBlock,
BlkGemmPipeSched,
BlkGemmPipelineVer,
AComputeDataType_,
BComputeDataType_>>
{
// Spatial dimension
static constexpr int kSpatialDim = NDimSpatial;
// Layout types
using ALayout = ALayout_;
using BLayout = BLayout_;
using DsLayout = DsLayout_;
using ELayout = ELayout_;
// Data types
using ADataType = ADataType_;
using BDataType = BDataType_;
using AccDataType = AccDataType_;
using CShuffleDataType = CShuffleDataType_;
using DsDataType = DsDataType_;
using EDataType = EDataType_;
// Element-wise operations
using AElementwiseOperation = AElementwiseOperation_;
using BElementwiseOperation = BElementwiseOperation_;
using CDEElementwiseOperation = CDEElementwiseOperation_;
// Specialization
static constexpr ck::tensor_operation::device::ConvolutionForwardSpecialization
kConvForwardSpecialization = ConvForwardSpecialization;
static constexpr ck::tensor_operation::device::GemmSpecialization kGemmSpecialization =
GemmSpec;
// Block configuration
static constexpr int kBlockSize = BlockSize;
static constexpr int kMPerBlock = MPerBlock;
static constexpr int kNPerBlock = NPerBlock;
static constexpr int kKPerBlock = KPerBlock;
// Tuning parameters
static constexpr int kAK1 = AK1;
static constexpr int kBK1 = BK1;
static constexpr int kMPerXDL = MPerXDL;
static constexpr int kNPerXDL = NPerXDL;
static constexpr int kMXdlPerWave = MXdlPerWave;
static constexpr int kNXdlPerWave = NXdlPerWave;
// A block transfer thread cluster dimensions (converted to std::array)
static constexpr auto kAThreadClusterLengths =
detail::SequenceToArray<ABlockTransferThreadClusterLengths_AK0_M_AK1>::value;
static constexpr auto kAThreadClusterArrangeOrder =
detail::SequenceToArray<ABlockTransferThreadClusterArrangeOrder>::value;
static constexpr auto kABlockTransferSrcAccessOrder =
detail::SequenceToArray<ABlockTransferSrcAccessOrder_>::value;
static constexpr int kABlockTransferSrcVectorDim = ABlockTransferSrcVectorDim;
static constexpr int kABlockTransferSrcScalarPerVector = ABlockTransferSrcScalarPerVector;
static constexpr int kABlockTransferDstScalarPerVectorK1 = ABlockTransferDstScalarPerVector_AK1;
static constexpr int kABlockLdsExtraM = ABlockLdsExtraM;
// B block transfer thread cluster dimensions (converted to std::array)
static constexpr auto kBThreadClusterLengths =
detail::SequenceToArray<BBlockTransferThreadClusterLengths_BK0_N_BK1>::value;
static constexpr auto kBThreadClusterArrangeOrder =
detail::SequenceToArray<BBlockTransferThreadClusterArrangeOrder>::value;
static constexpr auto kBBlockTransferSrcAccessOrder =
detail::SequenceToArray<BBlockTransferSrcAccessOrder_>::value;
static constexpr int kBBlockTransferSrcVectorDim = BBlockTransferSrcVectorDim;
static constexpr int kBBlockTransferSrcScalarPerVector = BBlockTransferSrcScalarPerVector;
static constexpr int kBBlockTransferDstScalarPerVectorK1 = BBlockTransferDstScalarPerVector_BK1;
static constexpr int kBBlockLdsExtraN = BBlockLdsExtraN;
// C shuffle parameters (converted to std::array)
static constexpr int kCShuffleMXdlPerWavePerShuffle = CShuffleMXdlPerWavePerShuffle;
static constexpr int kCShuffleNXdlPerWavePerShuffle = CShuffleNXdlPerWavePerShuffle;
static constexpr auto kCThreadClusterLengths = detail::SequenceToArray<
CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock>::value;
static constexpr int kCBlockTransferScalarPerVector = CDEBlockTransferScalarPerVector_NPerBlock;
// Pipeline configuration
static constexpr ck::BlockGemmPipelineScheduler kPipelineScheduler = BlkGemmPipeSched;
static constexpr ck::BlockGemmPipelineVersion kPipelineVersion = BlkGemmPipelineVer;
// Compute data types
using AComputeDataType = AComputeDataType_;
using BComputeDataType = BComputeDataType_;
// Static member function to generate instance string
static std::string instance_string()
{
std::ostringstream oss;
// Kernel type name
oss << "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3";
// Template parameters in exact order matching InstanceTraits member order
oss << "<" << kSpatialDim; // 1. NDimSpatial
oss << "," << detail::layout_name<ALayout>(); // 2. ALayout
oss << "," << detail::layout_name<BLayout>(); // 3. BLayout
oss << "," << detail::tuple_name<DsLayout>(); // 4. DsLayout
oss << "," << detail::layout_name<ELayout>(); // 5. ELayout
oss << "," << detail::type_name<ADataType>(); // 6. ADataType
oss << "," << detail::type_name<BDataType>(); // 7. BDataType
oss << "," << detail::type_name<AccDataType>(); // 8. AccDataType
oss << "," << detail::type_name<CShuffleDataType>(); // 9. CShuffleDataType
oss << "," << detail::tuple_name<DsDataType>(); // 10. DsDataType
oss << "," << detail::type_name<EDataType>(); // 11. EDataType
oss << ","
<< detail::elementwise_op_name<AElementwiseOperation>(); // 12. AElementwiseOperation
oss << ","
<< detail::elementwise_op_name<BElementwiseOperation>(); // 13. BElementwiseOperation
oss << ","
<< detail::elementwise_op_name<CDEElementwiseOperation>(); // 14.
// CDEElementwiseOperation
oss << ","
<< detail::conv_fwd_spec_name(
kConvForwardSpecialization); // 15. ConvForwardSpecialization
oss << "," << detail::gemm_spec_name(kGemmSpecialization); // 16. GemmSpec
oss << "," << kBlockSize; // 17. BlockSize
oss << "," << kMPerBlock; // 18. MPerBlock
oss << "," << kNPerBlock; // 19. NPerBlock
oss << "," << kKPerBlock; // 20. KPerBlock
oss << "," << kAK1; // 21. AK1
oss << "," << kBK1; // 22. BK1
oss << "," << kMPerXDL; // 23. MPerXDL
oss << "," << kNPerXDL; // 24. NPerXDL
oss << "," << kMXdlPerWave; // 25. MXdlPerWave
oss << "," << kNXdlPerWave; // 26. NXdlPerWave
oss << ","
<< detail::array_to_string(
kAThreadClusterLengths); // 27. ABlockTransferThreadClusterLengths
oss << ","
<< detail::array_to_string(
kAThreadClusterArrangeOrder); // 28. ABlockTransferThreadClusterArrangeOrder
oss << ","
<< detail::array_to_string(
kABlockTransferSrcAccessOrder); // 29. ABlockTransferSrcAccessOrder
oss << "," << kABlockTransferSrcVectorDim; // 30. ABlockTransferSrcVectorDim
oss << "," << kABlockTransferSrcScalarPerVector; // 31. ABlockTransferSrcScalarPerVector
oss << ","
<< kABlockTransferDstScalarPerVectorK1; // 32. ABlockTransferDstScalarPerVector_AK1
oss << "," << kABlockLdsExtraM; // 33. ABlockLdsExtraM
oss << ","
<< detail::array_to_string(
kBThreadClusterLengths); // 34. BBlockTransferThreadClusterLengths
oss << ","
<< detail::array_to_string(
kBThreadClusterArrangeOrder); // 35. BBlockTransferThreadClusterArrangeOrder
oss << ","
<< detail::array_to_string(
kBBlockTransferSrcAccessOrder); // 36. BBlockTransferSrcAccessOrder
oss << "," << kBBlockTransferSrcVectorDim; // 37. BBlockTransferSrcVectorDim
oss << "," << kBBlockTransferSrcScalarPerVector; // 38. BBlockTransferSrcScalarPerVector
oss << ","
<< kBBlockTransferDstScalarPerVectorK1; // 39. BBlockTransferDstScalarPerVector_BK1
oss << "," << kBBlockLdsExtraN; // 40. BBlockLdsExtraN
oss << "," << kCShuffleMXdlPerWavePerShuffle; // 41. CShuffleMXdlPerWavePerShuffle
oss << "," << kCShuffleNXdlPerWavePerShuffle; // 42. CShuffleNXdlPerWavePerShuffle
oss << ","
<< detail::array_to_string(
kCThreadClusterLengths); // 43. CDEBlockTransferClusterLengths
oss << ","
<< kCBlockTransferScalarPerVector; // 44. CDEBlockTransferScalarPerVector_NPerBlock
oss << "," << detail::pipeline_scheduler_name(kPipelineScheduler); // 45. BlkGemmPipeSched
oss << "," << detail::pipeline_version_name(kPipelineVersion); // 46. BlkGemmPipelineVer
oss << "," << detail::type_name<AComputeDataType>(); // 47. AComputeDataType
oss << "," << detail::type_name<BComputeDataType>(); // 48. BComputeDataType
oss << ">";
return oss.str();
}
};
} // namespace ck_tile::reflect

View File

@@ -0,0 +1,195 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
// Utility functions and helpers for instance_traits.hpp
// Contains helper functions to convert types, enums, and sequences to string representations.
// The helper function are consteval so that unknown cases cause compile-time errors.
#pragma once
#include <array>
#include <string>
#include <string_view>
#include <sstream>
#include <type_traits>
#include <ck/utility/data_type.hpp>
#include <ck/utility/sequence.hpp>
#include <ck/utility/blkgemmpipe_scheduler.hpp>
#include <ck/tensor_operation/gpu/device/tensor_layout.hpp>
#include <ck/tensor_operation/gpu/element/element_wise_operation.hpp>
#include <ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp>
#include <ck/tensor_operation/gpu/device/gemm_specialization.hpp>
namespace ck_tile::reflect::detail {
// Metaprogramming helper to convert ck::Sequence to constexpr std::array
template <typename Seq>
struct SequenceToArray;
template <ck::index_t... Is>
struct SequenceToArray<ck::Sequence<Is...>>
{
static constexpr std::array<int, sizeof...(Is)> value = {static_cast<int>(Is)...};
};
// Convert data types to string names
template <typename T>
consteval std::string_view type_name()
{
if constexpr(std::is_same_v<T, ck::half_t>)
return "fp16";
else if constexpr(std::is_same_v<T, float>)
return "fp32";
else if constexpr(std::is_same_v<T, double>)
return "fp64";
else if constexpr(std::is_same_v<T, int8_t>)
return "s8";
else if constexpr(std::is_same_v<T, int32_t>)
return "s32";
else if constexpr(std::is_same_v<T, ck::bhalf_t>)
return "bf16";
else if constexpr(std::is_same_v<T, ck::f8_t>)
return "fp8";
else if constexpr(std::is_same_v<T, ck::bf8_t>)
return "bf8";
else
static_assert(false, "unknown_type");
}
// Convert layout types to string names
template <typename T>
constexpr std::string_view layout_name()
{
// Convolution layouts
if constexpr(std::is_same_v<T, ck::tensor_layout::convolution::GNHWC>)
return "GNHWC";
else if constexpr(std::is_same_v<T, ck::tensor_layout::convolution::GKYXC>)
return "GKYXC";
else if constexpr(std::is_same_v<T, ck::tensor_layout::convolution::GNHWK>)
return "GNHWK";
else if constexpr(std::is_same_v<T, ck::tensor_layout::convolution::GKZYXC>)
return "GKZYXC";
else if constexpr(std::is_same_v<T, ck::tensor_layout::convolution::GNDHWC>)
return "GNDHWC";
else if constexpr(std::is_same_v<T, ck::tensor_layout::convolution::GNDHWK>)
return "GNDHWK";
else if constexpr(std::is_same_v<T, ck::tensor_layout::convolution::NHWGC>)
return "NHWGC";
else if constexpr(std::is_same_v<T, ck::tensor_layout::convolution::KYXGC>)
return "KYXGC";
else if constexpr(std::is_same_v<T, ck::tensor_layout::convolution::NHWGK>)
return "NHWGK";
else
static_assert(false, "unknown_layout");
}
// Convert element-wise operation types to string names
template <typename T>
constexpr std::string_view elementwise_op_name()
{
if constexpr(std::is_same_v<T, ck::tensor_operation::element_wise::PassThrough>)
return "PassThrough";
else if constexpr(std::is_same_v<T, ck::tensor_operation::element_wise::Scale>)
return "Scale";
else if constexpr(std::is_same_v<T, ck::tensor_operation::element_wise::Bilinear>)
return "Bilinear";
else if constexpr(std::is_same_v<T, ck::tensor_operation::element_wise::Add>)
return "Add";
else if constexpr(std::is_same_v<T, ck::tensor_operation::element_wise::AddRelu>)
return "AddRelu";
else if constexpr(std::is_same_v<T, ck::tensor_operation::element_wise::Relu>)
return "Relu";
else
static_assert(false, "unknown_op");
}
// Convert ConvolutionForwardSpecialization enum to string
constexpr std::string_view
conv_fwd_spec_name(ck::tensor_operation::device::ConvolutionForwardSpecialization spec)
{
using ck::tensor_operation::device::ConvolutionForwardSpecialization;
switch(spec)
{
case ConvolutionForwardSpecialization::Default: return "Default";
case ConvolutionForwardSpecialization::Filter1x1Stride1Pad0: return "Filter1x1Stride1Pad0";
case ConvolutionForwardSpecialization::Filter1x1Pad0: return "Filter1x1Pad0";
case ConvolutionForwardSpecialization::Filter3x3: return "Filter3x3";
case ConvolutionForwardSpecialization::OddC: return "OddC";
}
}
// Convert GemmSpecialization enum to string
constexpr std::string_view gemm_spec_name(ck::tensor_operation::device::GemmSpecialization spec)
{
using ck::tensor_operation::device::GemmSpecialization;
switch(spec)
{
case GemmSpecialization::Default: return "Default";
case GemmSpecialization::MPadding: return "MPadding";
case GemmSpecialization::NPadding: return "NPadding";
case GemmSpecialization::KPadding: return "KPadding";
case GemmSpecialization::MNPadding: return "MNPadding";
case GemmSpecialization::MKPadding: return "MKPadding";
case GemmSpecialization::NKPadding: return "NKPadding";
case GemmSpecialization::MNKPadding: return "MNKPadding";
case GemmSpecialization::OPadding: return "OPadding";
case GemmSpecialization::MOPadding: return "MOPadding";
case GemmSpecialization::NOPadding: return "NOPadding";
case GemmSpecialization::KOPadding: return "KOPadding";
case GemmSpecialization::MNOPadding: return "MNOPadding";
case GemmSpecialization::MKOPadding: return "MKOPadding";
case GemmSpecialization::NKOPadding: return "NKOPadding";
case GemmSpecialization::MNKOPadding: return "MNKOPadding";
}
}
// Convert BlockGemmPipelineScheduler enum to string
constexpr std::string_view pipeline_scheduler_name(ck::BlockGemmPipelineScheduler sched)
{
using ck::BlockGemmPipelineScheduler;
switch(sched)
{
case BlockGemmPipelineScheduler::Intrawave: return "Intrawave";
case BlockGemmPipelineScheduler::Interwave: return "Interwave";
}
}
// Convert BlockGemmPipelineVersion enum to string
constexpr std::string_view pipeline_version_name(ck::BlockGemmPipelineVersion ver)
{
using ck::BlockGemmPipelineVersion;
switch(ver)
{
case BlockGemmPipelineVersion::v1: return "v1";
case BlockGemmPipelineVersion::v2: return "v2";
case BlockGemmPipelineVersion::v3: return "v3";
case BlockGemmPipelineVersion::v4: return "v4";
case BlockGemmPipelineVersion::v5: return "v5";
}
}
// Convert std::array to string
template <typename T, std::size_t N>
inline std::string array_to_string(const std::array<T, N>& arr)
{
std::ostringstream oss;
oss << "Seq(";
for(std::size_t i = 0; i < arr.size(); ++i)
{
if(i > 0)
oss << ",";
oss << arr[i];
}
oss << ")";
return oss.str();
}
// Handle ck::Tuple (empty tuple for DsLayout/DsDataType)
template <typename T>
constexpr std::string_view tuple_name()
{
// For now, just check if it's an empty tuple
return "EmptyTuple";
}
} // namespace ck_tile::reflect::detail