[CK_BUILDER] Add compile-time reflection for a convolution instance (#3065)

* [CK_BILDER] Add compile-time reflection for a convolution instance

Introduce InstanceTraits template metaprogramming framework to enable runtime introspection of device kernel template parameters without requiring implementation knowledge. This reflection system extracts configuration details (block sizes, data types, layouts, tuning parameters) directly from kernel specializations through template
pattern matching. In particular, the GetInstanceString method returns a string that uniquely idenitfies the kernel, by explicitly serializing all template paramter values.

This provides critical functionality for MIOpen integration, since the existing GetTypeString method is ambiguous, and only captures some of the template paramters.

The implementation uses a two-level design: a primary InstanceTraits template declaration in instance_traits.hpp serves as the interface, while kernel-specific specializations (e.g., for DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3) provide the actual extraction logic. This separation allows the reflection system to scale to additional kernel types without modifying the core interface.

Key architectural decisions:

- Forward-declare device kernels in instance_traits.hpp to avoid  circular dependencies, since device implementation headers will  include the reflection headers

- Use compile-time constants and type aliases to expose kernel  parameters, enabling zero-overhead introspection

- Provide a templated instance_string() function that generates human-readable  kernel configuration strings by serializing all template parameters  in order, useful for debugging and kernel identification

- Guard reflection integration with preprocessor definition CK_EXPERIMENTAL_BUILDER to keep  it opt-in until the API stabilizes

- Add GetInstanceString() virtual method to BaseOperator, allowing  runtime polymorphic access to compile-time kernel information

This infrastructure also enables upcoming higher-level semantic reflection abstractions (like ConvTraits) to query kernel configurations programmatically.

Includes unit tests validating both the trait extraction accuracy and the string generation format.

[ROCm/composable_kernel commit: 37dff024c1]
This commit is contained in:
John Shumway
2025-10-21 21:10:19 -07:00
committed by GitHub
parent 4f83a3d745
commit 8f48205046
9 changed files with 1005 additions and 2 deletions

View File

@@ -40,6 +40,11 @@ option(MIOPEN_REQ_LIBS_ONLY "Build only the MIOpen required libraries" OFF)
option(CK_EXPERIMENTAL_BUILDER "Enable experimental builder" OFF)
option(BUILD_MHA_LIB "Build the static library for flash attention" OFF)
if(CK_EXPERIMENTAL_BUILDER)
add_definitions(-DCK_EXPERIMENTAL_BUILDER)
include_directories(${PROJECT_SOURCE_DIR}/experimental/builder/include)
endif()
# Usage: for customized Python location cmake -DCK_USE_ALTERNATIVE_PYTHON="/opt/Python-3.8.13/bin/python3.8"
# CK Codegen requires dataclass which is added in Python 3.7
# Python version 3.8 is required for general good practice as it is default for Ubuntu 20.04

View File

@@ -0,0 +1,58 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
// Compile-time reflection for CK device kernel instances.
//
// - This is the Lowest-level reflection primitive for higher-level semantic abstractions (e.g.,
// ConvTraits).
// - Extracts raw template parameters (block sizes, data types, layouts, tuning params) from kernel
// specializations.
// - Provides uniform interface to query kernel configuration without implementation knowledge
// - Other details about the device kernels can be manually added to template specializations.
// - Currently supports:
// - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
#pragma once
#include <array>
#include <string>
#include <sstream>
#include <type_traits>
#include <ck/utility/data_type.hpp>
#include <ck/utility/sequence.hpp>
#include <ck/utility/blkgemmpipe_scheduler.hpp>
#include <ck/tensor_operation/gpu/device/tensor_layout.hpp>
#include <ck/tensor_operation/gpu/element/element_wise_operation.hpp>
#include <ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp>
#include <ck/tensor_operation/gpu/device/gemm_specialization.hpp>
#include "instance_traits_util.hpp"
namespace ck_tile::reflect {
// Primary template for InstanceTraits - extracts compile-time information directly from
// device kernel instances (e.g., DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3)
//
// This is an unspecialized template declaration. Actual specializations for specific
// device kernels are provided in separate header files (e.g.,
// instance_traits_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp).
template <typename Instance>
struct InstanceTraits;
// Concept-based helper to detect if InstanceTraits<T> is specialized
// (i.e., has the instance_string() member function).
// This can be used for an informative static_assert in the device-op GetInstanceString in case the
// instance_string() template is broken.
template <typename T>
concept HasInstanceTraits = requires {
{ InstanceTraits<T>::instance_string() } -> std::convertible_to<std::string>;
};
// Free function that delegates to InstanceTraits static member function.
// Each InstanceTraits specialization provides its own instance_string() implementation.
template <typename T>
inline std::string instance_string()
{
return InstanceTraits<T>::instance_string();
}
} // namespace ck_tile::reflect

View File

@@ -0,0 +1,345 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
// InstanceTraits specialization for DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
//
// CRITICAL MAINTENANCE NOTE:
// This InstanceTraits file MUST be kept strictly in sync with the device implementation header:
// ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp
// "In sync" means that the template parameter order, names, and types in the declaration below
// MUST EXACTLY MATCH those in the device implementation. If these diverge, you may encounter
// compilation errors, subtle template instantiation mismatches, or silent runtime bugs that are
// difficult to diagnose. Always update both files together and review changes carefully.
// ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp
#pragma once
#include "instance_traits.hpp"
// Forward declaration to avoid circular dependency.
// This file will be included by the device implementation header, so we cannot include
// the implementation header here. We only need the template signature to pattern-match
// on template parameters - we don't need any implementation details.
namespace ck::tensor_operation::device {
template <ck::index_t NDimSpatial,
typename ALayout,
typename BLayout,
typename DsLayout,
typename ELayout,
typename ADataType,
typename BDataType,
typename AccDataType,
typename CShuffleDataType,
typename DsDataType,
typename EDataType,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CDEElementwiseOperation,
ConvolutionForwardSpecialization ConvForwardSpecialization,
GemmSpecialization GemmSpec,
ck::index_t BlockSize,
ck::index_t MPerBlock,
ck::index_t NPerBlock,
ck::index_t KPerBlock,
ck::index_t AK1,
ck::index_t BK1,
ck::index_t MPerXDL,
ck::index_t NPerXDL,
ck::index_t MXdlPerWave,
ck::index_t NXdlPerWave,
typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
typename ABlockTransferThreadClusterArrangeOrder,
typename ABlockTransferSrcAccessOrder,
ck::index_t ABlockTransferSrcVectorDim,
ck::index_t ABlockTransferSrcScalarPerVector,
ck::index_t ABlockTransferDstScalarPerVector_AK1,
ck::index_t ABlockLdsExtraM,
typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
typename BBlockTransferThreadClusterArrangeOrder,
typename BBlockTransferSrcAccessOrder,
ck::index_t BBlockTransferSrcVectorDim,
ck::index_t BBlockTransferSrcScalarPerVector,
ck::index_t BBlockTransferDstScalarPerVector_BK1,
ck::index_t BBlockLdsExtraN,
ck::index_t CShuffleMXdlPerWavePerShuffle,
ck::index_t CShuffleNXdlPerWavePerShuffle,
typename CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
ck::index_t CDEBlockTransferScalarPerVector_NPerBlock,
ck::BlockGemmPipelineScheduler BlkGemmPipeSched,
ck::BlockGemmPipelineVersion BlkGemmPipelineVer,
typename AComputeDataType,
typename BComputeDataType>
struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3;
} // namespace ck::tensor_operation::device
namespace ck_tile::reflect {
// Specialization for DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
template <ck::index_t NDimSpatial,
typename ALayout_,
typename BLayout_,
typename DsLayout_,
typename ELayout_,
typename ADataType_,
typename BDataType_,
typename AccDataType_,
typename CShuffleDataType_,
typename DsDataType_,
typename EDataType_,
typename AElementwiseOperation_,
typename BElementwiseOperation_,
typename CDEElementwiseOperation_,
ck::tensor_operation::device::ConvolutionForwardSpecialization ConvForwardSpecialization,
ck::tensor_operation::device::GemmSpecialization GemmSpec,
ck::index_t BlockSize,
ck::index_t MPerBlock,
ck::index_t NPerBlock,
ck::index_t KPerBlock,
ck::index_t AK1,
ck::index_t BK1,
ck::index_t MPerXDL,
ck::index_t NPerXDL,
ck::index_t MXdlPerWave,
ck::index_t NXdlPerWave,
typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
typename ABlockTransferThreadClusterArrangeOrder,
typename ABlockTransferSrcAccessOrder_,
ck::index_t ABlockTransferSrcVectorDim,
ck::index_t ABlockTransferSrcScalarPerVector,
ck::index_t ABlockTransferDstScalarPerVector_AK1,
ck::index_t ABlockLdsExtraM,
typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
typename BBlockTransferThreadClusterArrangeOrder,
typename BBlockTransferSrcAccessOrder_,
ck::index_t BBlockTransferSrcVectorDim,
ck::index_t BBlockTransferSrcScalarPerVector,
ck::index_t BBlockTransferDstScalarPerVector_BK1,
ck::index_t BBlockLdsExtraN,
ck::index_t CShuffleMXdlPerWavePerShuffle,
ck::index_t CShuffleNXdlPerWavePerShuffle,
typename CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
ck::index_t CDEBlockTransferScalarPerVector_NPerBlock,
ck::BlockGemmPipelineScheduler BlkGemmPipeSched,
ck::BlockGemmPipelineVersion BlkGemmPipelineVer,
typename AComputeDataType_,
typename BComputeDataType_>
struct InstanceTraits<ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<
NDimSpatial,
ALayout_,
BLayout_,
DsLayout_,
ELayout_,
ADataType_,
BDataType_,
AccDataType_,
CShuffleDataType_,
DsDataType_,
EDataType_,
AElementwiseOperation_,
BElementwiseOperation_,
CDEElementwiseOperation_,
ConvForwardSpecialization,
GemmSpec,
BlockSize,
MPerBlock,
NPerBlock,
KPerBlock,
AK1,
BK1,
MPerXDL,
NPerXDL,
MXdlPerWave,
NXdlPerWave,
ABlockTransferThreadClusterLengths_AK0_M_AK1,
ABlockTransferThreadClusterArrangeOrder,
ABlockTransferSrcAccessOrder_,
ABlockTransferSrcVectorDim,
ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_AK1,
ABlockLdsExtraM,
BBlockTransferThreadClusterLengths_BK0_N_BK1,
BBlockTransferThreadClusterArrangeOrder,
BBlockTransferSrcAccessOrder_,
BBlockTransferSrcVectorDim,
BBlockTransferSrcScalarPerVector,
BBlockTransferDstScalarPerVector_BK1,
BBlockLdsExtraN,
CShuffleMXdlPerWavePerShuffle,
CShuffleNXdlPerWavePerShuffle,
CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
CDEBlockTransferScalarPerVector_NPerBlock,
BlkGemmPipeSched,
BlkGemmPipelineVer,
AComputeDataType_,
BComputeDataType_>>
{
// Spatial dimension
static constexpr int kSpatialDim = NDimSpatial;
// Layout types
using ALayout = ALayout_;
using BLayout = BLayout_;
using DsLayout = DsLayout_;
using ELayout = ELayout_;
// Data types
using ADataType = ADataType_;
using BDataType = BDataType_;
using AccDataType = AccDataType_;
using CShuffleDataType = CShuffleDataType_;
using DsDataType = DsDataType_;
using EDataType = EDataType_;
// Element-wise operations
using AElementwiseOperation = AElementwiseOperation_;
using BElementwiseOperation = BElementwiseOperation_;
using CDEElementwiseOperation = CDEElementwiseOperation_;
// Specialization
static constexpr ck::tensor_operation::device::ConvolutionForwardSpecialization
kConvForwardSpecialization = ConvForwardSpecialization;
static constexpr ck::tensor_operation::device::GemmSpecialization kGemmSpecialization =
GemmSpec;
// Block configuration
static constexpr int kBlockSize = BlockSize;
static constexpr int kMPerBlock = MPerBlock;
static constexpr int kNPerBlock = NPerBlock;
static constexpr int kKPerBlock = KPerBlock;
// Tuning parameters
static constexpr int kAK1 = AK1;
static constexpr int kBK1 = BK1;
static constexpr int kMPerXDL = MPerXDL;
static constexpr int kNPerXDL = NPerXDL;
static constexpr int kMXdlPerWave = MXdlPerWave;
static constexpr int kNXdlPerWave = NXdlPerWave;
// A block transfer thread cluster dimensions (converted to std::array)
static constexpr auto kAThreadClusterLengths =
detail::SequenceToArray<ABlockTransferThreadClusterLengths_AK0_M_AK1>::value;
static constexpr auto kAThreadClusterArrangeOrder =
detail::SequenceToArray<ABlockTransferThreadClusterArrangeOrder>::value;
static constexpr auto kABlockTransferSrcAccessOrder =
detail::SequenceToArray<ABlockTransferSrcAccessOrder_>::value;
static constexpr int kABlockTransferSrcVectorDim = ABlockTransferSrcVectorDim;
static constexpr int kABlockTransferSrcScalarPerVector = ABlockTransferSrcScalarPerVector;
static constexpr int kABlockTransferDstScalarPerVectorK1 = ABlockTransferDstScalarPerVector_AK1;
static constexpr int kABlockLdsExtraM = ABlockLdsExtraM;
// B block transfer thread cluster dimensions (converted to std::array)
static constexpr auto kBThreadClusterLengths =
detail::SequenceToArray<BBlockTransferThreadClusterLengths_BK0_N_BK1>::value;
static constexpr auto kBThreadClusterArrangeOrder =
detail::SequenceToArray<BBlockTransferThreadClusterArrangeOrder>::value;
static constexpr auto kBBlockTransferSrcAccessOrder =
detail::SequenceToArray<BBlockTransferSrcAccessOrder_>::value;
static constexpr int kBBlockTransferSrcVectorDim = BBlockTransferSrcVectorDim;
static constexpr int kBBlockTransferSrcScalarPerVector = BBlockTransferSrcScalarPerVector;
static constexpr int kBBlockTransferDstScalarPerVectorK1 = BBlockTransferDstScalarPerVector_BK1;
static constexpr int kBBlockLdsExtraN = BBlockLdsExtraN;
// C shuffle parameters (converted to std::array)
static constexpr int kCShuffleMXdlPerWavePerShuffle = CShuffleMXdlPerWavePerShuffle;
static constexpr int kCShuffleNXdlPerWavePerShuffle = CShuffleNXdlPerWavePerShuffle;
static constexpr auto kCThreadClusterLengths = detail::SequenceToArray<
CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock>::value;
static constexpr int kCBlockTransferScalarPerVector = CDEBlockTransferScalarPerVector_NPerBlock;
// Pipeline configuration
static constexpr ck::BlockGemmPipelineScheduler kPipelineScheduler = BlkGemmPipeSched;
static constexpr ck::BlockGemmPipelineVersion kPipelineVersion = BlkGemmPipelineVer;
// Compute data types
using AComputeDataType = AComputeDataType_;
using BComputeDataType = BComputeDataType_;
// Static member function to generate instance string
static std::string instance_string()
{
std::ostringstream oss;
// Kernel type name
oss << "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3";
// Template parameters in exact order matching InstanceTraits member order
oss << "<" << kSpatialDim; // 1. NDimSpatial
oss << "," << detail::layout_name<ALayout>(); // 2. ALayout
oss << "," << detail::layout_name<BLayout>(); // 3. BLayout
oss << "," << detail::tuple_name<DsLayout>(); // 4. DsLayout
oss << "," << detail::layout_name<ELayout>(); // 5. ELayout
oss << "," << detail::type_name<ADataType>(); // 6. ADataType
oss << "," << detail::type_name<BDataType>(); // 7. BDataType
oss << "," << detail::type_name<AccDataType>(); // 8. AccDataType
oss << "," << detail::type_name<CShuffleDataType>(); // 9. CShuffleDataType
oss << "," << detail::tuple_name<DsDataType>(); // 10. DsDataType
oss << "," << detail::type_name<EDataType>(); // 11. EDataType
oss << ","
<< detail::elementwise_op_name<AElementwiseOperation>(); // 12. AElementwiseOperation
oss << ","
<< detail::elementwise_op_name<BElementwiseOperation>(); // 13. BElementwiseOperation
oss << ","
<< detail::elementwise_op_name<CDEElementwiseOperation>(); // 14.
// CDEElementwiseOperation
oss << ","
<< detail::conv_fwd_spec_name(
kConvForwardSpecialization); // 15. ConvForwardSpecialization
oss << "," << detail::gemm_spec_name(kGemmSpecialization); // 16. GemmSpec
oss << "," << kBlockSize; // 17. BlockSize
oss << "," << kMPerBlock; // 18. MPerBlock
oss << "," << kNPerBlock; // 19. NPerBlock
oss << "," << kKPerBlock; // 20. KPerBlock
oss << "," << kAK1; // 21. AK1
oss << "," << kBK1; // 22. BK1
oss << "," << kMPerXDL; // 23. MPerXDL
oss << "," << kNPerXDL; // 24. NPerXDL
oss << "," << kMXdlPerWave; // 25. MXdlPerWave
oss << "," << kNXdlPerWave; // 26. NXdlPerWave
oss << ","
<< detail::array_to_string(
kAThreadClusterLengths); // 27. ABlockTransferThreadClusterLengths
oss << ","
<< detail::array_to_string(
kAThreadClusterArrangeOrder); // 28. ABlockTransferThreadClusterArrangeOrder
oss << ","
<< detail::array_to_string(
kABlockTransferSrcAccessOrder); // 29. ABlockTransferSrcAccessOrder
oss << "," << kABlockTransferSrcVectorDim; // 30. ABlockTransferSrcVectorDim
oss << "," << kABlockTransferSrcScalarPerVector; // 31. ABlockTransferSrcScalarPerVector
oss << ","
<< kABlockTransferDstScalarPerVectorK1; // 32. ABlockTransferDstScalarPerVector_AK1
oss << "," << kABlockLdsExtraM; // 33. ABlockLdsExtraM
oss << ","
<< detail::array_to_string(
kBThreadClusterLengths); // 34. BBlockTransferThreadClusterLengths
oss << ","
<< detail::array_to_string(
kBThreadClusterArrangeOrder); // 35. BBlockTransferThreadClusterArrangeOrder
oss << ","
<< detail::array_to_string(
kBBlockTransferSrcAccessOrder); // 36. BBlockTransferSrcAccessOrder
oss << "," << kBBlockTransferSrcVectorDim; // 37. BBlockTransferSrcVectorDim
oss << "," << kBBlockTransferSrcScalarPerVector; // 38. BBlockTransferSrcScalarPerVector
oss << ","
<< kBBlockTransferDstScalarPerVectorK1; // 39. BBlockTransferDstScalarPerVector_BK1
oss << "," << kBBlockLdsExtraN; // 40. BBlockLdsExtraN
oss << "," << kCShuffleMXdlPerWavePerShuffle; // 41. CShuffleMXdlPerWavePerShuffle
oss << "," << kCShuffleNXdlPerWavePerShuffle; // 42. CShuffleNXdlPerWavePerShuffle
oss << ","
<< detail::array_to_string(
kCThreadClusterLengths); // 43. CDEBlockTransferClusterLengths
oss << ","
<< kCBlockTransferScalarPerVector; // 44. CDEBlockTransferScalarPerVector_NPerBlock
oss << "," << detail::pipeline_scheduler_name(kPipelineScheduler); // 45. BlkGemmPipeSched
oss << "," << detail::pipeline_version_name(kPipelineVersion); // 46. BlkGemmPipelineVer
oss << "," << detail::type_name<AComputeDataType>(); // 47. AComputeDataType
oss << "," << detail::type_name<BComputeDataType>(); // 48. BComputeDataType
oss << ">";
return oss.str();
}
};
} // namespace ck_tile::reflect

View File

@@ -0,0 +1,195 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
// Utility functions and helpers for instance_traits.hpp
// Contains helper functions to convert types, enums, and sequences to string representations.
// The helper function are consteval so that unknown cases cause compile-time errors.
#pragma once
#include <array>
#include <string>
#include <string_view>
#include <sstream>
#include <type_traits>
#include <ck/utility/data_type.hpp>
#include <ck/utility/sequence.hpp>
#include <ck/utility/blkgemmpipe_scheduler.hpp>
#include <ck/tensor_operation/gpu/device/tensor_layout.hpp>
#include <ck/tensor_operation/gpu/element/element_wise_operation.hpp>
#include <ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp>
#include <ck/tensor_operation/gpu/device/gemm_specialization.hpp>
namespace ck_tile::reflect::detail {
// Metaprogramming helper to convert ck::Sequence to constexpr std::array
template <typename Seq>
struct SequenceToArray;
template <ck::index_t... Is>
struct SequenceToArray<ck::Sequence<Is...>>
{
static constexpr std::array<int, sizeof...(Is)> value = {static_cast<int>(Is)...};
};
// Convert data types to string names
template <typename T>
consteval std::string_view type_name()
{
if constexpr(std::is_same_v<T, ck::half_t>)
return "fp16";
else if constexpr(std::is_same_v<T, float>)
return "fp32";
else if constexpr(std::is_same_v<T, double>)
return "fp64";
else if constexpr(std::is_same_v<T, int8_t>)
return "s8";
else if constexpr(std::is_same_v<T, int32_t>)
return "s32";
else if constexpr(std::is_same_v<T, ck::bhalf_t>)
return "bf16";
else if constexpr(std::is_same_v<T, ck::f8_t>)
return "fp8";
else if constexpr(std::is_same_v<T, ck::bf8_t>)
return "bf8";
else
static_assert(false, "unknown_type");
}
// Convert layout types to string names
template <typename T>
constexpr std::string_view layout_name()
{
// Convolution layouts
if constexpr(std::is_same_v<T, ck::tensor_layout::convolution::GNHWC>)
return "GNHWC";
else if constexpr(std::is_same_v<T, ck::tensor_layout::convolution::GKYXC>)
return "GKYXC";
else if constexpr(std::is_same_v<T, ck::tensor_layout::convolution::GNHWK>)
return "GNHWK";
else if constexpr(std::is_same_v<T, ck::tensor_layout::convolution::GKZYXC>)
return "GKZYXC";
else if constexpr(std::is_same_v<T, ck::tensor_layout::convolution::GNDHWC>)
return "GNDHWC";
else if constexpr(std::is_same_v<T, ck::tensor_layout::convolution::GNDHWK>)
return "GNDHWK";
else if constexpr(std::is_same_v<T, ck::tensor_layout::convolution::NHWGC>)
return "NHWGC";
else if constexpr(std::is_same_v<T, ck::tensor_layout::convolution::KYXGC>)
return "KYXGC";
else if constexpr(std::is_same_v<T, ck::tensor_layout::convolution::NHWGK>)
return "NHWGK";
else
static_assert(false, "unknown_layout");
}
// Convert element-wise operation types to string names
template <typename T>
constexpr std::string_view elementwise_op_name()
{
if constexpr(std::is_same_v<T, ck::tensor_operation::element_wise::PassThrough>)
return "PassThrough";
else if constexpr(std::is_same_v<T, ck::tensor_operation::element_wise::Scale>)
return "Scale";
else if constexpr(std::is_same_v<T, ck::tensor_operation::element_wise::Bilinear>)
return "Bilinear";
else if constexpr(std::is_same_v<T, ck::tensor_operation::element_wise::Add>)
return "Add";
else if constexpr(std::is_same_v<T, ck::tensor_operation::element_wise::AddRelu>)
return "AddRelu";
else if constexpr(std::is_same_v<T, ck::tensor_operation::element_wise::Relu>)
return "Relu";
else
static_assert(false, "unknown_op");
}
// Convert ConvolutionForwardSpecialization enum to string
constexpr std::string_view
conv_fwd_spec_name(ck::tensor_operation::device::ConvolutionForwardSpecialization spec)
{
using ck::tensor_operation::device::ConvolutionForwardSpecialization;
switch(spec)
{
case ConvolutionForwardSpecialization::Default: return "Default";
case ConvolutionForwardSpecialization::Filter1x1Stride1Pad0: return "Filter1x1Stride1Pad0";
case ConvolutionForwardSpecialization::Filter1x1Pad0: return "Filter1x1Pad0";
case ConvolutionForwardSpecialization::Filter3x3: return "Filter3x3";
case ConvolutionForwardSpecialization::OddC: return "OddC";
}
}
// Convert GemmSpecialization enum to string
constexpr std::string_view gemm_spec_name(ck::tensor_operation::device::GemmSpecialization spec)
{
using ck::tensor_operation::device::GemmSpecialization;
switch(spec)
{
case GemmSpecialization::Default: return "Default";
case GemmSpecialization::MPadding: return "MPadding";
case GemmSpecialization::NPadding: return "NPadding";
case GemmSpecialization::KPadding: return "KPadding";
case GemmSpecialization::MNPadding: return "MNPadding";
case GemmSpecialization::MKPadding: return "MKPadding";
case GemmSpecialization::NKPadding: return "NKPadding";
case GemmSpecialization::MNKPadding: return "MNKPadding";
case GemmSpecialization::OPadding: return "OPadding";
case GemmSpecialization::MOPadding: return "MOPadding";
case GemmSpecialization::NOPadding: return "NOPadding";
case GemmSpecialization::KOPadding: return "KOPadding";
case GemmSpecialization::MNOPadding: return "MNOPadding";
case GemmSpecialization::MKOPadding: return "MKOPadding";
case GemmSpecialization::NKOPadding: return "NKOPadding";
case GemmSpecialization::MNKOPadding: return "MNKOPadding";
}
}
// Convert BlockGemmPipelineScheduler enum to string
constexpr std::string_view pipeline_scheduler_name(ck::BlockGemmPipelineScheduler sched)
{
using ck::BlockGemmPipelineScheduler;
switch(sched)
{
case BlockGemmPipelineScheduler::Intrawave: return "Intrawave";
case BlockGemmPipelineScheduler::Interwave: return "Interwave";
}
}
// Convert BlockGemmPipelineVersion enum to string
constexpr std::string_view pipeline_version_name(ck::BlockGemmPipelineVersion ver)
{
using ck::BlockGemmPipelineVersion;
switch(ver)
{
case BlockGemmPipelineVersion::v1: return "v1";
case BlockGemmPipelineVersion::v2: return "v2";
case BlockGemmPipelineVersion::v3: return "v3";
case BlockGemmPipelineVersion::v4: return "v4";
case BlockGemmPipelineVersion::v5: return "v5";
}
}
// Convert std::array to string
template <typename T, std::size_t N>
inline std::string array_to_string(const std::array<T, N>& arr)
{
std::ostringstream oss;
oss << "Seq(";
for(std::size_t i = 0; i < arr.size(); ++i)
{
if(i > 0)
oss << ",";
oss << arr[i];
}
oss << ")";
return oss.str();
}
// Handle ck::Tuple (empty tuple for DsLayout/DsDataType)
template <typename T>
constexpr std::string_view tuple_name()
{
// For now, just check if it's an empty tuple
return "EmptyTuple";
}
} // namespace ck_tile::reflect::detail

View File

@@ -1,4 +1,3 @@
include(gtest)
# Helper function to create a gtest executable with common properties
@@ -17,4 +16,8 @@ function(add_ck_builder_test test_name)
endfunction()
add_ck_builder_test(test_conv_builder
test_conv_builder.cpp)
test_conv_builder.cpp
test_instance_traits.cpp)
add_ck_builder_test(test_get_instance_string
test_get_instance_string.cpp)

View File

@@ -0,0 +1,104 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include <gtest/gtest.h>
#include <ck_tile/builder/reflect/instance_traits.hpp>
#include <ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_abd.hpp>
#include <ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_comp_instance.hpp>
// Test GetInstanceString through base class pointer
TEST(GetInstanceStringTest, GetInstanceStringThroughBaseClass)
{
// Use the template helper to get a working instance configuration
using InstanceTuple =
ck::tensor_operation::device::instance::device_grouped_conv_fwd_xdl_f16_comp_instances<
2, // NDimSpatial
ck::tensor_operation::device::instance::GNHWC, // ALayout
ck::tensor_operation::device::instance::GKYXC, // BLayout
ck::tensor_operation::device::instance::Empty_Tuple, // DsLayout
ck::tensor_operation::device::instance::GNHWK, // ELayout
ck::tensor_operation::device::instance::ConvFwdDefault>; // ConvForwardSpecialization
// Get the first instance from the tuple
using DeviceInstance = typename std::tuple_element<0, InstanceTuple>::type;
// Define the base class type using DeviceGroupedConvFwdMultipleABD
using BaseClass = ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD<
2, // NDimSpatial
ck::tensor_operation::device::instance::GNHWC, // ALayout
ck::tensor_operation::device::instance::GKYXC, // BLayout
ck::tensor_operation::device::instance::Empty_Tuple, // DsLayout
ck::tensor_operation::device::instance::GNHWK, // ELayout
ck::half_t, // ADataType
ck::half_t, // BDataType
ck::Tuple<>, // DsDataType
ck::half_t, // EDataType
ck::tensor_operation::element_wise::PassThrough, // AElementwiseOperation
ck::tensor_operation::element_wise::PassThrough, // BElementwiseOperation
ck::tensor_operation::element_wise::PassThrough, // CDEElementwiseOperation
ck::half_t, // AComputeType
ck::half_t>; // BComputeType
// Create an instance of the derived class
DeviceInstance device_instance;
// Get a pointer to the base class
BaseClass* base_ptr = &device_instance;
// Call GetInstanceString through the base class pointer
std::string instance_str = base_ptr->GetInstanceString();
// Expected complete instance string based on the first instance from
// device_grouped_conv_fwd_xdl_f16_comp_instances This corresponds to the configuration with
// BlockSize=256, MPerBlock=128, NPerBlock=128, KPerBlock=64, etc.
std::string expected_str = "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3"
"<2" // NDimSpatial
",GNHWC" // ALayout
",GKYXC" // BLayout
",EmptyTuple" // DsLayout
",GNHWK" // ELayout
",fp16" // ADataType
",fp16" // BDataType
",fp32" // AccDataType
",fp16" // CShuffleDataType
",EmptyTuple" // DsDataType
",fp16" // EDataType
",PassThrough" // AElementwiseOperation
",PassThrough" // BElementwiseOperation
",PassThrough" // CDEElementwiseOperation
",Default" // ConvForwardSpecialization
",MNKPadding" // GemmSpec
",256" // BlockSize
",128" // MPerBlock
",128" // NPerBlock
",64" // KPerBlock
",8" // AK1
",8" // BK1
",32" // MPerXDL
",32" // NPerXDL
",2" // MXdlPerWave
",2" // NXdlPerWave
",Seq(8,32,1)" // ABlockTransferThreadClusterLengths
",Seq(1,0,2)" // ABlockTransferThreadClusterArrangeOrder
",Seq(1,0,2)" // ABlockTransferSrcAccessOrder
",2" // ABlockTransferSrcVectorDim
",8" // ABlockTransferSrcScalarPerVector
",8" // ABlockTransferDstScalarPerVector_AK1
",0" // ABlockLdsExtraM
",Seq(8,32,1)" // BBlockTransferThreadClusterLengths
",Seq(1,0,2)" // BBlockTransferThreadClusterArrangeOrder
",Seq(1,0,2)" // BBlockTransferSrcAccessOrder
",2" // BBlockTransferSrcVectorDim
",8" // BBlockTransferSrcScalarPerVector
",8" // BBlockTransferDstScalarPerVector_BK1
",0" // BBlockLdsExtraN
",1" // CShuffleMXdlPerWavePerShuffle
",1" // CShuffleNXdlPerWavePerShuffle
",Seq(1,32,1,8)" // CDEBlockTransferClusterLengths
",8" // CDEBlockTransferScalarPerVector_NPerBlock
",Intrawave" // BlkGemmPipeSched
",v4" // BlkGemmPipelineVer
",fp16" // AComputeDataType
",fp16>"; // BComputeDataType
EXPECT_EQ(instance_str, expected_str);
}

View File

@@ -0,0 +1,276 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include <gtest/gtest.h>
#include <gmock/gmock.h>
#include <ck_tile/builder/reflect/instance_traits.hpp>
#include <ck_tile/builder/reflect/instance_traits_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp>
namespace {
using ::testing::ElementsAre;
// Test fixture for InstanceTraits tests
class InstanceTraitsTest : public ::testing::Test
{
};
// Test InstanceTraits with DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
TEST_F(InstanceTraitsTest, ConvFwdInstanceTraitsExtraction)
{
// Define a concrete instance type with specific template parameters
using DeviceInstance =
ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<
2, // NDimSpatial
ck::tensor_layout::convolution::GNHWC, // ALayout
ck::tensor_layout::convolution::GKYXC, // BLayout
ck::Tuple<>, // DsLayout
ck::tensor_layout::convolution::GNHWK, // ELayout
ck::half_t, // ADataType
ck::half_t, // BDataType
float, // AccDataType
ck::half_t, // CShuffleDataType
ck::Tuple<>, // DsDataType
ck::half_t, // EDataType
ck::tensor_operation::element_wise::PassThrough, // AElementwiseOperation
ck::tensor_operation::element_wise::PassThrough, // BElementwiseOperation
ck::tensor_operation::element_wise::PassThrough, // CDEElementwiseOperation
ck::tensor_operation::device::ConvolutionForwardSpecialization::
Default, // ConvForwardSpecialization
ck::tensor_operation::device::GemmSpecialization::Default, // GemmSpec
256, // BlockSize
128, // MPerBlock
128, // NPerBlock
16, // KPerBlock
8, // AK1
8, // BK1
32, // MPerXDL
32, // NPerXDL
4, // MXdlPerWave
4, // NXdlPerWave
ck::Sequence<4, 64, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1
ck::Sequence<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder
ck::Sequence<1, 0, 2>, // ABlockTransferSrcAccessOrder
2, // ABlockTransferSrcVectorDim
8, // ABlockTransferSrcScalarPerVector
8, // ABlockTransferDstScalarPerVector_AK1
1, // ABlockLdsExtraM
ck::Sequence<4, 64, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1
ck::Sequence<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder
ck::Sequence<1, 0, 2>, // BBlockTransferSrcAccessOrder
2, // BBlockTransferSrcVectorDim
8, // BBlockTransferSrcScalarPerVector
8, // BBlockTransferDstScalarPerVector_BK1
1, // BBlockLdsExtraN
1, // CShuffleMXdlPerWavePerShuffle
1, // CShuffleNXdlPerWavePerShuffle
ck::Sequence<1,
32,
1,
8>, // CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
8, // CDEBlockTransferScalarPerVector_NPerBlock
ck::BlockGemmPipelineScheduler::Intrawave, // BlkGemmPipeSched
ck::BlockGemmPipelineVersion::v1, // BlkGemmPipelineVer
ck::half_t, // AComputeDataType
ck::half_t>; // BComputeDataType
// Use InstanceTraits to extract compile-time information
using Traits = ck_tile::reflect::InstanceTraits<DeviceInstance>;
// Verify spatial dimension
EXPECT_EQ(Traits::kSpatialDim, 2);
// Verify block configuration
EXPECT_EQ(Traits::kBlockSize, 256);
EXPECT_EQ(Traits::kMPerBlock, 128);
EXPECT_EQ(Traits::kNPerBlock, 128);
EXPECT_EQ(Traits::kKPerBlock, 16);
// Verify tuning parameters
EXPECT_EQ(Traits::kAK1, 8);
EXPECT_EQ(Traits::kBK1, 8);
EXPECT_EQ(Traits::kMPerXDL, 32);
EXPECT_EQ(Traits::kNPerXDL, 32);
EXPECT_EQ(Traits::kMXdlPerWave, 4);
EXPECT_EQ(Traits::kNXdlPerWave, 4);
// Verify A block transfer parameters
EXPECT_EQ(Traits::kABlockTransferSrcVectorDim, 2);
EXPECT_EQ(Traits::kABlockTransferSrcScalarPerVector, 8);
EXPECT_EQ(Traits::kABlockTransferDstScalarPerVectorK1, 8);
EXPECT_EQ(Traits::kABlockLdsExtraM, 1);
// Verify B block transfer parameters
EXPECT_EQ(Traits::kBBlockTransferSrcVectorDim, 2);
EXPECT_EQ(Traits::kBBlockTransferSrcScalarPerVector, 8);
EXPECT_EQ(Traits::kBBlockTransferDstScalarPerVectorK1, 8);
EXPECT_EQ(Traits::kBBlockLdsExtraN, 1);
// Verify C shuffle parameters
EXPECT_EQ(Traits::kCShuffleMXdlPerWavePerShuffle, 1);
EXPECT_EQ(Traits::kCShuffleNXdlPerWavePerShuffle, 1);
EXPECT_EQ(Traits::kCBlockTransferScalarPerVector, 8);
// Verify pipeline configuration
EXPECT_EQ(Traits::kPipelineScheduler, ck::BlockGemmPipelineScheduler::Intrawave);
EXPECT_EQ(Traits::kPipelineVersion, ck::BlockGemmPipelineVersion::v1);
// Verify data types using std::is_same
EXPECT_TRUE((std::is_same<Traits::ADataType, ck::half_t>::value));
EXPECT_TRUE((std::is_same<Traits::BDataType, ck::half_t>::value));
EXPECT_TRUE((std::is_same<Traits::AccDataType, float>::value));
EXPECT_TRUE((std::is_same<Traits::EDataType, ck::half_t>::value));
// Verify layout types
EXPECT_TRUE((std::is_same<Traits::ALayout, ck::tensor_layout::convolution::GNHWC>::value));
EXPECT_TRUE((std::is_same<Traits::BLayout, ck::tensor_layout::convolution::GKYXC>::value));
EXPECT_TRUE((std::is_same<Traits::ELayout, ck::tensor_layout::convolution::GNHWK>::value));
// Verify all array values for thread cluster lengths using googlemock matchers
EXPECT_THAT(Traits::kAThreadClusterLengths, ElementsAre(4, 64, 1));
EXPECT_THAT(Traits::kBThreadClusterLengths, ElementsAre(4, 64, 1));
EXPECT_THAT(Traits::kCThreadClusterLengths, ElementsAre(1, 32, 1, 8));
// Verify A block transfer arrange order and access order arrays
EXPECT_THAT(Traits::kAThreadClusterArrangeOrder, ElementsAre(1, 0, 2));
EXPECT_THAT(Traits::kABlockTransferSrcAccessOrder, ElementsAre(1, 0, 2));
// Verify B block transfer arrange order and access order arrays
EXPECT_THAT(Traits::kBThreadClusterArrangeOrder, ElementsAre(1, 0, 2));
EXPECT_THAT(Traits::kBBlockTransferSrcAccessOrder, ElementsAre(1, 0, 2));
// Verify additional data types
EXPECT_TRUE((std::is_same<Traits::CShuffleDataType, ck::half_t>::value));
EXPECT_TRUE((std::is_same<Traits::DsDataType, ck::Tuple<>>::value));
EXPECT_TRUE((std::is_same<Traits::AComputeDataType, ck::half_t>::value));
EXPECT_TRUE((std::is_same<Traits::BComputeDataType, ck::half_t>::value));
// Verify additional layout types
EXPECT_TRUE((std::is_same<Traits::DsLayout, ck::Tuple<>>::value));
// Verify element-wise operations
EXPECT_TRUE((std::is_same<Traits::AElementwiseOperation,
ck::tensor_operation::element_wise::PassThrough>::value));
EXPECT_TRUE((std::is_same<Traits::BElementwiseOperation,
ck::tensor_operation::element_wise::PassThrough>::value));
EXPECT_TRUE((std::is_same<Traits::CDEElementwiseOperation,
ck::tensor_operation::element_wise::PassThrough>::value));
}
// Test instance_string function
TEST_F(InstanceTraitsTest, InstanceStringGeneration)
{
// Define a concrete instance type with specific template parameters
using DeviceInstance =
ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<
2, // NDimSpatial
ck::tensor_layout::convolution::GNHWC, // ALayout
ck::tensor_layout::convolution::GKYXC, // BLayout
ck::Tuple<>, // DsLayout
ck::tensor_layout::convolution::GNHWK, // ELayout
ck::half_t, // ADataType
ck::half_t, // BDataType
float, // AccDataType
ck::half_t, // CShuffleDataType
ck::Tuple<>, // DsDataType
ck::half_t, // EDataType
ck::tensor_operation::element_wise::PassThrough, // AElementwiseOperation
ck::tensor_operation::element_wise::PassThrough, // BElementwiseOperation
ck::tensor_operation::element_wise::PassThrough, // CDEElementwiseOperation
ck::tensor_operation::device::ConvolutionForwardSpecialization::
Default, // ConvForwardSpecialization
ck::tensor_operation::device::GemmSpecialization::Default, // GemmSpec
256, // BlockSize
128, // MPerBlock
128, // NPerBlock
16, // KPerBlock
8, // AK1
8, // BK1
32, // MPerXDL
32, // NPerXDL
4, // MXdlPerWave
4, // NXdlPerWave
ck::Sequence<4, 64, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1
ck::Sequence<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder
ck::Sequence<1, 0, 2>, // ABlockTransferSrcAccessOrder
2, // ABlockTransferSrcVectorDim
8, // ABlockTransferSrcScalarPerVector
8, // ABlockTransferDstScalarPerVector_AK1
1, // ABlockLdsExtraM
ck::Sequence<4, 64, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1
ck::Sequence<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder
ck::Sequence<1, 0, 2>, // BBlockTransferSrcAccessOrder
2, // BBlockTransferSrcVectorDim
8, // BBlockTransferSrcScalarPerVector
8, // BBlockTransferDstScalarPerVector_BK1
1, // BBlockLdsExtraN
1, // CShuffleMXdlPerWavePerShuffle
1, // CShuffleNXdlPerWavePerShuffle
ck::Sequence<1,
32,
1,
8>, // CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
8, // CDEBlockTransferScalarPerVector_NPerBlock
ck::BlockGemmPipelineScheduler::Intrawave, // BlkGemmPipeSched
ck::BlockGemmPipelineVersion::v1, // BlkGemmPipelineVer
ck::half_t, // AComputeDataType
ck::half_t>; // BComputeDataType
// Generate instance string
std::string instance_str = ck_tile::reflect::instance_string<DeviceInstance>();
// Expected string with all template parameters in exact order
std::string expected_str = "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3"
"<2" // NDimSpatial
",GNHWC" // ALayout
",GKYXC" // BLayout
",EmptyTuple" // DsLayout
",GNHWK" // ELayout
",fp16" // ADataType
",fp16" // BDataType
",fp32" // AccDataType
",fp16" // CShuffleDataType
",EmptyTuple" // DsDataType
",fp16" // EDataType
",PassThrough" // AElementwiseOperation
",PassThrough" // BElementwiseOperation
",PassThrough" // CDEElementwiseOperation
",Default" // ConvForwardSpecialization
",Default" // GemmSpec
",256" // BlockSize
",128" // MPerBlock
",128" // NPerBlock
",16" // KPerBlock
",8" // AK1
",8" // BK1
",32" // MPerXDL
",32" // NPerXDL
",4" // MXdlPerWave
",4" // NXdlPerWave
",Seq(4,64,1)" // ABlockTransferThreadClusterLengths
",Seq(1,0,2)" // ABlockTransferThreadClusterArrangeOrder
",Seq(1,0,2)" // ABlockTransferSrcAccessOrder
",2" // ABlockTransferSrcVectorDim
",8" // ABlockTransferSrcScalarPerVector
",8" // ABlockTransferDstScalarPerVector_AK1
",1" // ABlockLdsExtraM
",Seq(4,64,1)" // BBlockTransferThreadClusterLengths
",Seq(1,0,2)" // BBlockTransferThreadClusterArrangeOrder
",Seq(1,0,2)" // BBlockTransferSrcAccessOrder
",2" // BBlockTransferSrcVectorDim
",8" // BBlockTransferSrcScalarPerVector
",8" // BBlockTransferDstScalarPerVector_BK1
",1" // BBlockLdsExtraN
",1" // CShuffleMXdlPerWavePerShuffle
",1" // CShuffleNXdlPerWavePerShuffle
",Seq(1,32,1,8)" // CDEBlockTransferClusterLengths
",8" // CDEBlockTransferScalarPerVector_NPerBlock
",Intrawave" // BlkGemmPipeSched
",v1" // BlkGemmPipelineVer
",fp16" // AComputeDataType
",fp16>"; // BComputeDataType
// Verify the generated string matches exactly
EXPECT_EQ(instance_str, expected_str);
}
} // anonymous namespace

View File

@@ -227,6 +227,7 @@ struct BaseOperator
#if !defined(__HIPCC_RTC__) || !defined(CK_CODE_GEN_RTC)
virtual bool IsSupportedArgument(const BaseArgument*) { return false; }
virtual std::string GetTypeString() const { return ""; }
virtual std::string GetInstanceString() const { return ""; }
virtual std::string GetTypeIdName() const { return typeid(*this).name(); }

View File

@@ -28,6 +28,9 @@
#include "ck/host_utility/kernel_launch.hpp"
#include "ck/host_utility/flush_cache.hpp"
#include "ck/host_utility/io.hpp"
#ifdef CK_EXPERIMENTAL_BUILDER
#include "ck_tile/builder/reflect/instance_traits_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp"
#endif
namespace ck {
namespace tensor_operation {
@@ -1994,6 +1997,19 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
return str.str();
}
#ifdef CK_EXPERIMENTAL_BUILDER
std::string GetInstanceString() const override
{
static_assert(ck_tile::reflect::HasInstanceTraits<DeviceOp>,
"Specialization of instance_traits not found. Please check that a "
"specialization exists in file "
"ck_tile/builder/reflect/"
"instance_traits_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp "
"for the given template parameters.");
return ck_tile::reflect::instance_string<DeviceOp>();
}
#endif
size_t GetWorkSpaceSize(const BaseArgument* p_arg) const override
{
auto arg = dynamic_cast<const Argument*>(p_arg);