diff --git a/experimental/builder/include/ck_tile/builder/reflect/instance_traits_device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp b/experimental/builder/include/ck_tile/builder/reflect/instance_traits_device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp new file mode 100644 index 0000000000..7970b1ced5 --- /dev/null +++ b/experimental/builder/include/ck_tile/builder/reflect/instance_traits_device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp @@ -0,0 +1,284 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#include "instance_traits.hpp" +#include "instance_traits_util.hpp" +#include "ck/tensor_operation/gpu/device/convolution_backward_weight_specialization.hpp" + +// Forward declaration to avoid circular dependency +namespace ck::tensor_operation::device { + +template +struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3; + +} // namespace ck::tensor_operation::device + +namespace ck_tile { +namespace reflect { + +template +struct InstanceTraits> +{ + static constexpr auto kTensorOpName = "DeviceGroupedConvBwdWeight_Xdl_CShuffleV3"; + + static constexpr ck::index_t kNDimSpatial = NDimSpatial; + + using InLayout = InLayout_; + using WeiLayout = WeiLayout_; + using OutLayout = OutLayout_; + + using InDataType = InDataType_; + using WeiDataType = WeiDataType_; + using OutDataType = OutDataType_; + using AccDataType = AccDataType_; + + using InElementwiseOperation = InElementwiseOperation_; + using WeiElementwiseOperation = WeiElementwiseOperation_; + using OutElementwiseOperation = OutElementwiseOperation_; + + static constexpr auto kConvBackwardWeightSpecialization = ConvBackwardWeightSpecialization; + + static constexpr ck::index_t kBlockSize = BlockSize; + static constexpr ck::index_t kMPerBlock = MPerBlock; + static constexpr ck::index_t kNPerBlock = NPerBlock; + static constexpr ck::index_t kK0PerBlock = K0PerBlock; + static constexpr ck::index_t kK1 = K1; + static constexpr ck::index_t kMPerXDL = MPerXDL; + static constexpr ck::index_t kNPerXDL = NPerXDL; + static constexpr ck::index_t kMXdlPerWave = MXdlPerWave; + static constexpr ck::index_t kNXdlPerWave = NXdlPerWave; + + using ABlockTransferThreadClusterLengths_K0_M_K1 = ABlockTransferThreadClusterLengths_K0_M_K1_; + using ABlockTransferThreadClusterArrangeOrder = ABlockTransferThreadClusterArrangeOrder_; + using ABlockTransferSrcAccessOrder = ABlockTransferSrcAccessOrder_; + static constexpr ck::index_t kABlockTransferSrcVectorDim = ABlockTransferSrcVectorDim; + static constexpr ck::index_t kABlockTransferSrcScalarPerVector = + ABlockTransferSrcScalarPerVector; + static constexpr ck::index_t kABlockTransferDstScalarPerVector_K1 = + ABlockTransferDstScalarPerVector_K1; + static constexpr bool kABlockLdsAddExtraM = ABlockLdsAddExtraM; + + using BBlockTransferThreadClusterLengths_K0_N_K1 = BBlockTransferThreadClusterLengths_K0_N_K1_; + using BBlockTransferThreadClusterArrangeOrder = BBlockTransferThreadClusterArrangeOrder_; + using BBlockTransferSrcAccessOrder = BBlockTransferSrcAccessOrder_; + static constexpr ck::index_t kBBlockTransferSrcVectorDim = BBlockTransferSrcVectorDim; + static constexpr ck::index_t kBBlockTransferSrcScalarPerVector = + BBlockTransferSrcScalarPerVector; + static constexpr ck::index_t kBBlockTransferDstScalarPerVector_K1 = + BBlockTransferDstScalarPerVector_K1; + static constexpr bool kBBlockLdsAddExtraN = BBlockLdsAddExtraN; + + static constexpr ck::index_t kCShuffleMXdlPerWavePerShuffle = CShuffleMXdlPerWavePerShuffle; + static constexpr ck::index_t kCShuffleNXdlPerWavePerShuffle = CShuffleNXdlPerWavePerShuffle; + + using CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock = + CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock_; + static constexpr ck::index_t kCBlockTransferScalarPerVector_NWaveNPerXdl = + CBlockTransferScalarPerVector_NWaveNPerXdl; + + static constexpr ck::BlockGemmPipelineScheduler kBlkGemmPipeSched = BlkGemmPipeSched; + static constexpr ck::BlockGemmPipelineVersion kBlkGemmPipelineVer = BlkGemmPipelineVer; + + using ComputeTypeA = ComputeTypeA_; + using ComputeTypeB = ComputeTypeB_; + + // Static member function to generate instance string + static std::string instance_string() + { + std::ostringstream oss; + + // Kernel type name + oss << "DeviceGroupedConvBwdWeight_Xdl_CShuffleV3"; + + // Template parameters in exact order + oss << "<" << kNDimSpatial; // 1. NDimSpatial + oss << "," << detail::layout_name(); // 2. InLayout + oss << "," << detail::layout_name(); // 3. WeiLayout + oss << "," << detail::layout_name(); // 4. OutLayout + oss << "," << detail::type_name(); // 5. InDataType + oss << "," << detail::type_name(); // 6. WeiDataType + oss << "," << detail::type_name(); // 7. OutDataType + oss << "," << detail::type_name(); // 8. AccDataType + oss << "," + << detail::elementwise_op_name(); // 9. InElementwiseOperation + oss << "," + << detail::elementwise_op_name(); // 10. + // WeiElementwiseOperation + oss << "," + << detail::elementwise_op_name(); // 11. + // OutElementwiseOperation + oss << "," + << detail::conv_bwd_weight_spec_name( + kConvBackwardWeightSpecialization); // 12. ConvBackwardWeightSpecialization + oss << "," << kBlockSize; // 13. BlockSize + oss << "," << kMPerBlock; // 14. MPerBlock + oss << "," << kNPerBlock; // 15. NPerBlock + oss << "," << kK0PerBlock; // 16. K0PerBlock + oss << "," << kK1; // 17. K1 + oss << "," << kMPerXDL; // 18. MPerXDL + oss << "," << kNPerXDL; // 19. NPerXDL + oss << "," << kMXdlPerWave; // 20. MXdlPerWave + oss << "," << kNXdlPerWave; // 21. NXdlPerWave + oss << "," << detail::sequence_name(); // 22. + oss << "," << detail::sequence_name(); // 23. + oss << "," << detail::sequence_name(); // 24. + oss << "," << kABlockTransferSrcVectorDim; // 25. + oss << "," << kABlockTransferSrcScalarPerVector; // 26. + oss << "," << kABlockTransferDstScalarPerVector_K1; // 27. + oss << "," << (kABlockLdsAddExtraM ? "true" : "false"); // 28. + oss << "," << detail::sequence_name(); // 29. + oss << "," << detail::sequence_name(); // 30. + oss << "," << detail::sequence_name(); // 31. + oss << "," << kBBlockTransferSrcVectorDim; // 32. + oss << "," << kBBlockTransferSrcScalarPerVector; // 33. + oss << "," << kBBlockTransferDstScalarPerVector_K1; // 34. + oss << "," << (kBBlockLdsAddExtraN ? "true" : "false"); // 35. + oss << "," << kCShuffleMXdlPerWavePerShuffle; // 36. + oss << "," << kCShuffleNXdlPerWavePerShuffle; // 37. + oss << "," + << detail::sequence_name< + CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock>(); // 38. + oss << "," << kCBlockTransferScalarPerVector_NWaveNPerXdl; // 39. + oss << "," << detail::pipeline_scheduler_name(kBlkGemmPipeSched); // 40. + oss << "," << detail::pipeline_version_name(kBlkGemmPipelineVer); // 41. + oss << "," << detail::type_name(); // 42. + oss << "," << detail::type_name(); // 43. + oss << ">"; + + return oss.str(); + } +}; + +} // namespace reflect +} // namespace ck_tile diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp index a0e06a81d6..1e3177729d 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp @@ -29,6 +29,11 @@ #include "ck/host_utility/kernel_launch.hpp" #include "ck/host_utility/flush_cache.hpp" +#ifdef CK_EXPERIMENTAL_BUILDER +#include "ck_tile/builder/reflect/description.hpp" +#include "ck_tile/builder/reflect/instance_traits_device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp" +#endif + namespace ck { namespace tensor_operation { namespace device { @@ -1548,6 +1553,25 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3 return str.str(); } + +#ifdef CK_EXPERIMENTAL_BUILDER + std::string GetInstanceString() const override + { + static_assert(ck_tile::reflect::HasInstanceTraits, + "Specialization of instance_traits not found. Please check that a " + "specialization exists in file " + "ck_tile/builder/reflect/" + "instance_traits_device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp " + "for the given template parameters."); + return ck_tile::reflect::instance_string(); + } + + std::unique_ptr describe() const override + { + return std::make_unique(GetInstanceString()); + } +#endif + }; } // namespace device