From a862155c9e931bd233bbecc3c76fd34d887fea66 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?M=C3=A1rton=20Bidlek?= Date: Mon, 9 Mar 2026 17:34:18 +0100 Subject: [PATCH] Proof of concept for removing forward declarations (#5135) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## Motivation Currently, we forward declare CK device operation templates in CK-Builder's reflection code: https://github.com/ROCm/composable_kernel/blob/9b168082b7aa19bcf50fd9991baf10a0c77d105b/experimental/builder/include/ck_tile/builder/reflect/instance_traits_device_grouped_conv_bwd_weight_xdl_cshuffle.hpp#L13-L57 This is mainly required to break a circular dependency in reflection. The architecture of that is as follows: MyDeviceOp implements GetInstanceString(). This is typically defined directly in the class definition (no forward declaration). GetInstanceString() calls instance_string() instance_string() calls InstanceTraits::instance_string() InstanceTraits has a specialization for MyDeviceOp which implements instance_string() So order for GetInstanceString() to work properly, InstanceTraits must already be defined. And for InstanceTraits to be defined, the device op needs to be defined. In order to do that, we are currently using aforementioned forward declaration. ## Technical Details C++'s lazy template evaluation is used by calling into an as-of-yet undefined function static member function of `InstanceTraits` in `GetInstanceString()`, and then specializing `InstanceTraits` only _after that_. The caveat here is that both the device op itself as well as the instance traits specialization must be in scope, otherwise there would be an undefined function error. In practise, we can solve that either by placing the instance traits directly into the file that defines `MyDeviceOp`, or possibly by using a `.inc` file to keep the concerns separated. ## Test Plan The results were verified by running the existing regression tests for CK Builder ## Submission Checklist - [ ] Look over the contributing guidelines at https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests. --------- Co-authored-by: Márton Bidlek --- ...e_grouped_conv_bwd_weight_xdl_cshuffle.hpp | 1 + ...e_grouped_conv_bwd_weight_xdl_cshuffle.hpp | 312 +----------------- ...e_grouped_conv_bwd_weight_xdl_cshuffle.inc | 267 +++++++++++++++ .../builder/test/conv/ck/test_conv_traits.cpp | 1 + .../test/test_bwd_weight_instance_traits.cpp | 2 +- ...e_grouped_conv_bwd_weight_xdl_cshuffle.hpp | 6 +- 6 files changed, 288 insertions(+), 301 deletions(-) create mode 100644 experimental/builder/include/ck_tile/builder/reflect/reflect_device_grouped_conv_bwd_weight_xdl_cshuffle.inc diff --git a/experimental/builder/include/ck_tile/builder/reflect/conv_traits_device_grouped_conv_bwd_weight_xdl_cshuffle.hpp b/experimental/builder/include/ck_tile/builder/reflect/conv_traits_device_grouped_conv_bwd_weight_xdl_cshuffle.hpp index d47b2ee4d3..a4e31539f7 100644 --- a/experimental/builder/include/ck_tile/builder/reflect/conv_traits_device_grouped_conv_bwd_weight_xdl_cshuffle.hpp +++ b/experimental/builder/include/ck_tile/builder/reflect/conv_traits_device_grouped_conv_bwd_weight_xdl_cshuffle.hpp @@ -5,6 +5,7 @@ #include +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle.hpp" #include "ck_tile/builder/reflect/conv_traits.hpp" #include "ck_tile/builder/reflect/conv_traits_helpers.hpp" #include "ck_tile/builder/reflect/instance_traits.hpp" diff --git a/experimental/builder/include/ck_tile/builder/reflect/instance_traits_device_grouped_conv_bwd_weight_xdl_cshuffle.hpp b/experimental/builder/include/ck_tile/builder/reflect/instance_traits_device_grouped_conv_bwd_weight_xdl_cshuffle.hpp index 1edf03740f..c49a773a27 100644 --- a/experimental/builder/include/ck_tile/builder/reflect/instance_traits_device_grouped_conv_bwd_weight_xdl_cshuffle.hpp +++ b/experimental/builder/include/ck_tile/builder/reflect/instance_traits_device_grouped_conv_bwd_weight_xdl_cshuffle.hpp @@ -1,312 +1,26 @@ // Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT +// The full InstanceTraits specialization is in: +// ck_tile/builder/reflect/reflect_device_grouped_conv_bwd_weight_xdl_cshuffle.inc +// +// CRITICAL MAINTENANCE NOTE: +// Keep the specialization in the .inc file strictly in sync with: +// ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle.hpp. +// "In sync" means that the template parameter order, names, and types MUST EXACTLY MATCH. If they +// diverge, you may encounter compilation errors, subtle template instantiation mismatches, or +// silent runtime bugs that are difficult to diagnose. Always update both files together and review +// changes carefully. + #pragma once #include "instance_traits.hpp" #include "instance_traits_util.hpp" -#include "ck/tensor_operation/gpu/device/convolution_backward_weight_specialization.hpp" -// Forward declaration to avoid circular dependency -namespace ck::tensor_operation::device { +namespace ck_tile::reflect { -template -struct DeviceGroupedConvBwdWeight_Xdl_CShuffle; - -} // namespace ck::tensor_operation::device - -namespace ck_tile { -namespace reflect { - -/// @brief Tag type for DeviceGroupedConvBwdWeight_Xdl_CShuffle device kernel struct DeviceGroupedConvBwdWeight_Xdl_CShuffle_Tag { }; -template -struct InstanceTraits> -{ - static constexpr auto kTensorOpName = "DeviceGroupedConvBwdWeight_Xdl_CShuffle"; - - static constexpr ck::index_t kSpatialDim = NDimSpatial; - using device_kernel_tag = DeviceGroupedConvBwdWeight_Xdl_CShuffle_Tag; - - using InLayout = InLayout_; - using WeiLayout = WeiLayout_; - using OutLayout = OutLayout_; - - using InDataType = InDataType_; - using WeiDataType = WeiDataType_; - using OutDataType = OutDataType_; - using AccDataType = AccDataType_; - - using InElementwiseOperation = InElementwiseOperation_; - using WeiElementwiseOperation = WeiElementwiseOperation_; - using OutElementwiseOperation = OutElementwiseOperation_; - - static constexpr auto kConvBwdWeightSpecialization = ConvBackwardWeightSpecialization; - static constexpr ck::index_t kBlockSize = BlockSize; - static constexpr ck::index_t kMPerBlock = MPerBlock; - static constexpr ck::index_t kNPerBlock = NPerBlock; - static constexpr ck::index_t kK0PerBlock = K0PerBlock; - static constexpr ck::index_t kK1 = K1; - static constexpr ck::index_t kMPerXDL = MPerXDL; - static constexpr ck::index_t kNPerXDL = NPerXDL; - static constexpr ck::index_t kMXdlPerWave = MXdlPerWave; - static constexpr ck::index_t kNXdlPerWave = NXdlPerWave; - - using ABlockTransferThreadClusterLengths_K0_M_K1 = ABlockTransferThreadClusterLengths_K0_M_K1_; - using ABlockTransferThreadClusterArrangeOrder = ABlockTransferThreadClusterArrangeOrder_; - using ABlockTransferSrcAccessOrder = ABlockTransferSrcAccessOrder_; - - // A block transfer thread cluster dimensions (converted to std::array) - static constexpr auto kAThreadClusterLengths = - detail::SequenceToArray::value; - static constexpr auto kAThreadClusterArrangeOrder = - detail::SequenceToArray::value; - static constexpr auto kABlockTransferSrcAccessOrder = - detail::SequenceToArray::value; - - static constexpr ck::index_t kABlockTransferSrcVectorDim = ABlockTransferSrcVectorDim; - static constexpr ck::index_t kABlockTransferSrcScalarPerVector = - ABlockTransferSrcScalarPerVector; - static constexpr ck::index_t kABlockTransferDstScalarPerVectorK1 = - ABlockTransferDstScalarPerVector_K1; - static constexpr bool kABlockLdsExtraM = ABlockLdsAddExtraM; - - using BBlockTransferThreadClusterLengths_K0_N_K1 = BBlockTransferThreadClusterLengths_K0_N_K1_; - using BBlockTransferThreadClusterArrangeOrder = BBlockTransferThreadClusterArrangeOrder_; - using BBlockTransferSrcAccessOrder = BBlockTransferSrcAccessOrder_; - - // B block transfer thread cluster dimensions (converted to std::array) - static constexpr auto kBThreadClusterLengths = - detail::SequenceToArray::value; - static constexpr auto kBThreadClusterArrangeOrder = - detail::SequenceToArray::value; - static constexpr auto kBBlockTransferSrcAccessOrder = - detail::SequenceToArray::value; - - static constexpr ck::index_t kBBlockTransferSrcVectorDim = BBlockTransferSrcVectorDim; - static constexpr ck::index_t kBBlockTransferSrcScalarPerVector = - BBlockTransferSrcScalarPerVector; - static constexpr ck::index_t kBBlockTransferDstScalarPerVectorK1 = - BBlockTransferDstScalarPerVector_K1; - static constexpr bool kBBlockLdsExtraN = BBlockLdsAddExtraN; - - static constexpr ck::index_t kCShuffleMXdlPerWavePerShuffle = CShuffleMXdlPerWavePerShuffle; - static constexpr ck::index_t kCShuffleNXdlPerWavePerShuffle = CShuffleNXdlPerWavePerShuffle; - - using CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock = - CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock_; - static constexpr auto kCThreadClusterLengths = detail::SequenceToArray< - CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock>::value; - - static constexpr ck::index_t kCBlockTransferScalarPerVector_NWaveNPerXdl = - CBlockTransferScalarPerVector_NWaveNPerXdl; - - using ComputeTypeA = ComputeTypeA_; - using ComputeTypeB = ComputeTypeB_; - - static constexpr ck::index_t kMaxTransposeTransferSrcScalarPerVector = - MaxTransposeTransferSrcScalarPerVector; - static constexpr ck::index_t kMaxTransposeTransferDstScalarPerVector = - MaxTransposeTransferDstScalarPerVector; - - // Static member function to generate instance string - static std::string instance_string() - { - std::ostringstream oss; - - // Kernel type name - oss << "DeviceGroupedConvBwdWeight_Xdl_CShuffle"; - - // Template parameters in exact order - oss << "<" << kSpatialDim; // 1. NDimSpatial - oss << "," << detail::layout_name(); // 2. InLayout - oss << "," << detail::layout_name(); // 3. WeiLayout - oss << "," << detail::layout_name(); // 4. OutLayout - oss << "," << detail::type_name(); // 5. InDataType - oss << "," << detail::type_name(); // 6. WeiDataType - oss << "," << detail::type_name(); // 7. OutDataType - oss << "," << detail::type_name(); // 8. AccDataType - oss << "," - << detail::elementwise_op_name(); // 9. InElementwiseOperation - oss << "," - << detail::elementwise_op_name(); // 10. - // WeiElementwiseOperation - oss << "," - << detail::elementwise_op_name(); // 11. - // OutElementwiseOperation - oss << "," - << detail::conv_bwd_weight_spec_name( - kConvBwdWeightSpecialization); // 12. ConvBackwardWeightSpecialization - oss << "," << kBlockSize; // 13. BlockSize - oss << "," << kMPerBlock; // 14. MPerBlock - oss << "," << kNPerBlock; // 15. NPerBlock - oss << "," << kK0PerBlock; // 16. K0PerBlock - oss << "," << kK1; // 17. K1 - oss << "," << kMPerXDL; // 18. MPerXDL - oss << "," << kNPerXDL; // 19. NPerXDL - oss << "," << kMXdlPerWave; // 20. MXdlPerWave - oss << "," << kNXdlPerWave; // 21. NXdlPerWave - oss << "," << detail::sequence_name(); // 22. - oss << "," << detail::sequence_name(); // 23. - oss << "," << detail::sequence_name(); // 24. - oss << "," << kABlockTransferSrcVectorDim; // 25. - oss << "," << kABlockTransferSrcScalarPerVector; // 26. - oss << "," << kABlockTransferDstScalarPerVectorK1; // 27. - oss << "," << (kABlockLdsExtraM ? "true" : "false"); // 28. - oss << "," << detail::sequence_name(); // 29. - oss << "," << detail::sequence_name(); // 30. - oss << "," << detail::sequence_name(); // 31. - oss << "," << kBBlockTransferSrcVectorDim; // 32. - oss << "," << kBBlockTransferSrcScalarPerVector; // 33. - oss << "," << kBBlockTransferDstScalarPerVectorK1; // 34. - oss << "," << (kBBlockLdsExtraN ? "true" : "false"); // 35. - oss << "," << kCShuffleMXdlPerWavePerShuffle; // 36. - oss << "," << kCShuffleNXdlPerWavePerShuffle; // 37. - oss << "," - << detail::sequence_name< - CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock>(); // 38. - oss << "," << kCBlockTransferScalarPerVector_NWaveNPerXdl; // 39. - oss << "," << detail::type_name(); // 40. - oss << "," << detail::type_name(); // 41. - oss << "," << kMaxTransposeTransferSrcScalarPerVector; // 42. - oss << "," << kMaxTransposeTransferDstScalarPerVector; // 43. - oss << ">"; - - return oss.str(); - } -}; - -} // namespace reflect -} // namespace ck_tile +} // namespace ck_tile::reflect diff --git a/experimental/builder/include/ck_tile/builder/reflect/reflect_device_grouped_conv_bwd_weight_xdl_cshuffle.inc b/experimental/builder/include/ck_tile/builder/reflect/reflect_device_grouped_conv_bwd_weight_xdl_cshuffle.inc new file mode 100644 index 0000000000..09691deb8c --- /dev/null +++ b/experimental/builder/include/ck_tile/builder/reflect/reflect_device_grouped_conv_bwd_weight_xdl_cshuffle.inc @@ -0,0 +1,267 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +/// @file +/// @brief Reflection `InstanceTraits` specialization for +/// `DeviceGroupedConvBwdWeight_Xdl_CShuffle`. + +#include +#include + +#include "ck/tensor_operation/gpu/device/convolution_backward_weight_specialization.hpp" +#include "ck_tile/builder/reflect/instance_traits.hpp" +#include "ck_tile/builder/reflect/instance_traits_util.hpp" + +// CRITICAL MAINTENANCE NOTE: +// Keep this template parameter list strictly in sync with: +// ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle.hpp. +// "In sync" means that the template parameter order, names, and types MUST EXACTLY MATCH. If they +// diverge, you may encounter compilation errors, subtle template instantiation mismatches, or +// silent runtime bugs that are difficult to diagnose. Always update both files together and review +// changes carefully. + +namespace ck_tile::reflect { + +template +struct InstanceTraits> +{ + static constexpr auto kTensorOpName = "DeviceGroupedConvBwdWeight_Xdl_CShuffle"; + + static constexpr ck::index_t kSpatialDim = NDimSpatial; + using device_kernel_tag = DeviceGroupedConvBwdWeight_Xdl_CShuffle_Tag; + + using InLayout = InLayout_; + using WeiLayout = WeiLayout_; + using OutLayout = OutLayout_; + + using InDataType = InDataType_; + using WeiDataType = WeiDataType_; + using OutDataType = OutDataType_; + using AccDataType = AccDataType_; + + using InElementwiseOperation = InElementwiseOperation_; + using WeiElementwiseOperation = WeiElementwiseOperation_; + using OutElementwiseOperation = OutElementwiseOperation_; + + static constexpr auto kConvBwdWeightSpecialization = ConvBackwardWeightSpecialization; + static constexpr ck::index_t kBlockSize = BlockSize; + static constexpr ck::index_t kMPerBlock = MPerBlock; + static constexpr ck::index_t kNPerBlock = NPerBlock; + static constexpr ck::index_t kK0PerBlock = K0PerBlock; + static constexpr ck::index_t kK1 = K1; + static constexpr ck::index_t kMPerXDL = MPerXDL; + static constexpr ck::index_t kNPerXDL = NPerXDL; + static constexpr ck::index_t kMXdlPerWave = MXdlPerWave; + static constexpr ck::index_t kNXdlPerWave = NXdlPerWave; + + using ABlockTransferThreadClusterLengths_K0_M_K1 = ABlockTransferThreadClusterLengths_K0_M_K1_; + using ABlockTransferThreadClusterArrangeOrder = ABlockTransferThreadClusterArrangeOrder_; + using ABlockTransferSrcAccessOrder = ABlockTransferSrcAccessOrder_; + + // A block transfer thread cluster dimensions (converted to std::array) + static constexpr auto kAThreadClusterLengths = + detail::SequenceToArray::value; + static constexpr auto kAThreadClusterArrangeOrder = + detail::SequenceToArray::value; + static constexpr auto kABlockTransferSrcAccessOrder = + detail::SequenceToArray::value; + + static constexpr ck::index_t kABlockTransferSrcVectorDim = ABlockTransferSrcVectorDim; + static constexpr ck::index_t kABlockTransferSrcScalarPerVector = + ABlockTransferSrcScalarPerVector; + static constexpr ck::index_t kABlockTransferDstScalarPerVectorK1 = + ABlockTransferDstScalarPerVector_K1; + static constexpr bool kABlockLdsExtraM = ABlockLdsAddExtraM; + + using BBlockTransferThreadClusterLengths_K0_N_K1 = BBlockTransferThreadClusterLengths_K0_N_K1_; + using BBlockTransferThreadClusterArrangeOrder = BBlockTransferThreadClusterArrangeOrder_; + using BBlockTransferSrcAccessOrder = BBlockTransferSrcAccessOrder_; + + // B block transfer thread cluster dimensions (converted to std::array) + static constexpr auto kBThreadClusterLengths = + detail::SequenceToArray::value; + static constexpr auto kBThreadClusterArrangeOrder = + detail::SequenceToArray::value; + static constexpr auto kBBlockTransferSrcAccessOrder = + detail::SequenceToArray::value; + + static constexpr ck::index_t kBBlockTransferSrcVectorDim = BBlockTransferSrcVectorDim; + static constexpr ck::index_t kBBlockTransferSrcScalarPerVector = + BBlockTransferSrcScalarPerVector; + static constexpr ck::index_t kBBlockTransferDstScalarPerVectorK1 = + BBlockTransferDstScalarPerVector_K1; + static constexpr bool kBBlockLdsExtraN = BBlockLdsAddExtraN; + + static constexpr ck::index_t kCShuffleMXdlPerWavePerShuffle = CShuffleMXdlPerWavePerShuffle; + static constexpr ck::index_t kCShuffleNXdlPerWavePerShuffle = CShuffleNXdlPerWavePerShuffle; + + using CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock = + CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock_; + static constexpr auto kCThreadClusterLengths = detail::SequenceToArray< + CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock>::value; + + static constexpr ck::index_t kCBlockTransferScalarPerVector_NWaveNPerXdl = + CBlockTransferScalarPerVector_NWaveNPerXdl; + + using ComputeTypeA = ComputeTypeA_; + using ComputeTypeB = ComputeTypeB_; + + static constexpr ck::index_t kMaxTransposeTransferSrcScalarPerVector = + MaxTransposeTransferSrcScalarPerVector; + static constexpr ck::index_t kMaxTransposeTransferDstScalarPerVector = + MaxTransposeTransferDstScalarPerVector; + + // Static member function to generate instance string + static std::string instance_string() + { + std::ostringstream oss; + + // Kernel type name + oss << "DeviceGroupedConvBwdWeight_Xdl_CShuffle"; + + // Template parameters in exact order + oss << "<" << kSpatialDim; // 1. NDimSpatial + oss << "," << detail::layout_name(); // 2. InLayout + oss << "," << detail::layout_name(); // 3. WeiLayout + oss << "," << detail::layout_name(); // 4. OutLayout + oss << "," << detail::type_name(); // 5. InDataType + oss << "," << detail::type_name(); // 6. WeiDataType + oss << "," << detail::type_name(); // 7. OutDataType + oss << "," << detail::type_name(); // 8. AccDataType + oss << "," + << detail::elementwise_op_name(); // 9. InElementwiseOperation + oss << "," + << detail::elementwise_op_name(); // 10. + // WeiElementwiseOperation + oss << "," + << detail::elementwise_op_name(); // 11. + // OutElementwiseOperation + oss << "," + << detail::conv_bwd_weight_spec_name( + kConvBwdWeightSpecialization); // 12. ConvBackwardWeightSpecialization + oss << "," << kBlockSize; // 13. BlockSize + oss << "," << kMPerBlock; // 14. MPerBlock + oss << "," << kNPerBlock; // 15. NPerBlock + oss << "," << kK0PerBlock; // 16. K0PerBlock + oss << "," << kK1; // 17. K1 + oss << "," << kMPerXDL; // 18. MPerXDL + oss << "," << kNPerXDL; // 19. NPerXDL + oss << "," << kMXdlPerWave; // 20. MXdlPerWave + oss << "," << kNXdlPerWave; // 21. NXdlPerWave + oss << "," << detail::sequence_name(); // 22. + oss << "," << detail::sequence_name(); // 23. + oss << "," << detail::sequence_name(); // 24. + oss << "," << kABlockTransferSrcVectorDim; // 25. + oss << "," << kABlockTransferSrcScalarPerVector; // 26. + oss << "," << kABlockTransferDstScalarPerVectorK1; // 27. + oss << "," << (kABlockLdsExtraM ? "true" : "false"); // 28. + oss << "," << detail::sequence_name(); // 29. + oss << "," << detail::sequence_name(); // 30. + oss << "," << detail::sequence_name(); // 31. + oss << "," << kBBlockTransferSrcVectorDim; // 32. + oss << "," << kBBlockTransferSrcScalarPerVector; // 33. + oss << "," << kBBlockTransferDstScalarPerVectorK1; // 34. + oss << "," << (kBBlockLdsExtraN ? "true" : "false"); // 35. + oss << "," << kCShuffleMXdlPerWavePerShuffle; // 36. + oss << "," << kCShuffleNXdlPerWavePerShuffle; // 37. + oss << "," + << detail::sequence_name< + CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock>(); // 38. + oss << "," << kCBlockTransferScalarPerVector_NWaveNPerXdl; // 39. + oss << "," << detail::type_name(); // 40. + oss << "," << detail::type_name(); // 41. + oss << "," << kMaxTransposeTransferSrcScalarPerVector; // 42. + oss << "," << kMaxTransposeTransferDstScalarPerVector; // 43. + oss << ">"; + + return oss.str(); + } +}; + +} // namespace ck_tile::reflect diff --git a/experimental/builder/test/conv/ck/test_conv_traits.cpp b/experimental/builder/test/conv/ck/test_conv_traits.cpp index a171627753..f01fe35b5d 100644 --- a/experimental/builder/test/conv/ck/test_conv_traits.cpp +++ b/experimental/builder/test/conv/ck/test_conv_traits.cpp @@ -6,6 +6,7 @@ #include #include +#include #include #include #include diff --git a/experimental/builder/test/test_bwd_weight_instance_traits.cpp b/experimental/builder/test/test_bwd_weight_instance_traits.cpp index dbb3a0a8fc..63317315b3 100644 --- a/experimental/builder/test/test_bwd_weight_instance_traits.cpp +++ b/experimental/builder/test/test_bwd_weight_instance_traits.cpp @@ -5,7 +5,7 @@ #include "ck/ck.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" #include "ck_tile/builder/reflect/instance_traits.hpp" -#include "ck_tile/builder/reflect/instance_traits_device_grouped_conv_bwd_weight_xdl_cshuffle.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle.hpp" #include "ck_tile/builder/reflect/instance_traits_tile_grouped_convolution_backward_weight.hpp" #include "ck_tile/ops/epilogue/cshuffle_epilogue.hpp" diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle.hpp index 1f6f2fb789..585454221a 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle.hpp @@ -1374,7 +1374,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle "Specialization of instance_traits not found. Please check that a " "specialization exists in file " "ck_tile/builder/reflect/" - "instance_traits_device_grouped_conv_bwd_weight_xdl_cshuffle.hpp " + "reflect_device_grouped_conv_bwd_weight_xdl_cshuffle.inc " "for the given template parameters."); return ck_tile::reflect::instance_string(); } @@ -1417,3 +1417,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle } // namespace device } // namespace tensor_operation } // namespace ck + +#ifdef CK_EXPERIMENTAL_BUILDER +#include "ck_tile/builder/reflect/reflect_device_grouped_conv_bwd_weight_xdl_cshuffle.inc" +#endif