mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-03 21:21:22 +00:00
Add the last two forward instance traits. (#3134)
* Add InstanceTraits for DeviceGroupedConvFwdMultipleD_Wmma_CShuffle * Add InstanceTraits for kernel_grouped_conv_fwd_dl_multiple_d * A few small changes to fix broken instance traits.
This commit is contained in:
@@ -0,0 +1,341 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
// InstanceTraits specialization for DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK
|
||||
//
|
||||
// CRITICAL MAINTENANCE NOTE:
|
||||
// This InstanceTraits file MUST be kept strictly in sync with the device implementation header:
|
||||
// ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp
|
||||
// "In sync" means that the template parameter order, names, and types in the declaration below
|
||||
// MUST EXACTLY MATCH those in the device implementation. If these diverge, you may encounter
|
||||
// compilation errors, subtle template instantiation mismatches, or silent runtime bugs that are
|
||||
// difficult to diagnose. Always update both files together and review changes carefully.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "instance_traits.hpp"
|
||||
|
||||
// Forward declaration to avoid circular dependency.
|
||||
// This file will be included by the device implementation header, so we cannot include
|
||||
// the implementation header here. We only need the template signature to pattern-match
|
||||
// on template parameters - we don't need any implementation details.
|
||||
|
||||
namespace ck::tensor_operation::device {
|
||||
|
||||
template <index_t NDimSpatial,
|
||||
typename ADataType,
|
||||
typename BDataType,
|
||||
typename DsDataType,
|
||||
typename EDataType,
|
||||
typename AccDataType,
|
||||
typename ALayout,
|
||||
typename BLayout,
|
||||
typename DsLayout,
|
||||
typename ELayout,
|
||||
typename AElementwiseOperation,
|
||||
typename BElementwiseOperation,
|
||||
typename CDEElementwiseOperation,
|
||||
ck::tensor_operation::device::ConvolutionForwardSpecialization ConvForwardSpecialization,
|
||||
ck::tensor_operation::device::GemmSpecialization GemmSpec,
|
||||
ck::index_t BlockSize,
|
||||
ck::index_t MPerBlock,
|
||||
ck::index_t NPerBlock,
|
||||
ck::index_t K0PerBlock,
|
||||
ck::index_t K1,
|
||||
ck::index_t M1PerThread,
|
||||
ck::index_t N1PerThread,
|
||||
ck::index_t KPerThread,
|
||||
typename M1N1ThreadClusterM1Xs,
|
||||
typename M1N1ThreadClusterN1Xs,
|
||||
typename ABlockTransferThreadSliceLengths_K0_M0_M1_K1,
|
||||
typename ABlockTransferThreadClusterLengths_K0_M0_M1_K1,
|
||||
typename ABlockTransferThreadClusterArrangeOrder,
|
||||
typename ABlockTransferSrcAccessOrder,
|
||||
typename ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1,
|
||||
typename ABlockTransferSrcVectorTensorContiguousDimOrder,
|
||||
typename ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1,
|
||||
typename BBlockTransferThreadSliceLengths_K0_N0_N1_K1,
|
||||
typename BBlockTransferThreadClusterLengths_K0_N0_N1_K1,
|
||||
typename BBlockTransferThreadClusterArrangeOrder,
|
||||
typename BBlockTransferSrcAccessOrder,
|
||||
typename BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1,
|
||||
typename BBlockTransferSrcVectorTensorContiguousDimOrder,
|
||||
typename BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1,
|
||||
typename CThreadTransferSrcDstAccessOrder,
|
||||
ck::index_t CThreadTransferSrcDstVectorDim,
|
||||
ck::index_t CThreadTransferDstScalarPerVector>
|
||||
struct DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK;
|
||||
|
||||
} // namespace ck::tensor_operation::device
|
||||
|
||||
namespace ck_tile::reflect {
|
||||
|
||||
// Specialization for DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK
|
||||
template <ck::index_t NDimSpatial,
|
||||
typename ADataType_,
|
||||
typename BDataType_,
|
||||
typename DsDataType_,
|
||||
typename EDataType_,
|
||||
typename AccDataType_,
|
||||
typename ALayout_,
|
||||
typename BLayout_,
|
||||
typename DsLayout_,
|
||||
typename ELayout_,
|
||||
typename AElementwiseOperation_,
|
||||
typename BElementwiseOperation_,
|
||||
typename CDEElementwiseOperation_,
|
||||
ck::tensor_operation::device::ConvolutionForwardSpecialization ConvForwardSpecialization,
|
||||
ck::tensor_operation::device::GemmSpecialization GemmSpec,
|
||||
ck::index_t BlockSize,
|
||||
ck::index_t MPerBlock,
|
||||
ck::index_t NPerBlock,
|
||||
ck::index_t K0PerBlock,
|
||||
ck::index_t K1,
|
||||
ck::index_t M1PerThread,
|
||||
ck::index_t N1PerThread,
|
||||
ck::index_t KPerThread,
|
||||
typename M1N1ThreadClusterM1Xs_,
|
||||
typename M1N1ThreadClusterN1Xs_,
|
||||
typename ABlockTransferThreadSliceLengths_K0_M0_M1_K1_,
|
||||
typename ABlockTransferThreadClusterLengths_K0_M0_M1_K1_,
|
||||
typename ABlockTransferThreadClusterArrangeOrder_,
|
||||
typename ABlockTransferSrcAccessOrder_,
|
||||
typename ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1_,
|
||||
typename ABlockTransferSrcVectorTensorContiguousDimOrder_,
|
||||
typename ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1_,
|
||||
typename BBlockTransferThreadSliceLengths_K0_N0_N1_K1_,
|
||||
typename BBlockTransferThreadClusterLengths_K0_N0_N1_K1_,
|
||||
typename BBlockTransferThreadClusterArrangeOrder_,
|
||||
typename BBlockTransferSrcAccessOrder_,
|
||||
typename BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1_,
|
||||
typename BBlockTransferSrcVectorTensorContiguousDimOrder_,
|
||||
typename BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1_,
|
||||
typename CThreadTransferSrcDstAccessOrder_,
|
||||
ck::index_t CThreadTransferSrcDstVectorDim,
|
||||
ck::index_t CThreadTransferDstScalarPerVector>
|
||||
struct InstanceTraits<ck::tensor_operation::device::DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK<
|
||||
NDimSpatial,
|
||||
ADataType_,
|
||||
BDataType_,
|
||||
DsDataType_,
|
||||
EDataType_,
|
||||
AccDataType_,
|
||||
ALayout_,
|
||||
BLayout_,
|
||||
DsLayout_,
|
||||
ELayout_,
|
||||
AElementwiseOperation_,
|
||||
BElementwiseOperation_,
|
||||
CDEElementwiseOperation_,
|
||||
ConvForwardSpecialization,
|
||||
GemmSpec,
|
||||
BlockSize,
|
||||
MPerBlock,
|
||||
NPerBlock,
|
||||
K0PerBlock,
|
||||
K1,
|
||||
M1PerThread,
|
||||
N1PerThread,
|
||||
KPerThread,
|
||||
M1N1ThreadClusterM1Xs_,
|
||||
M1N1ThreadClusterN1Xs_,
|
||||
ABlockTransferThreadSliceLengths_K0_M0_M1_K1_,
|
||||
ABlockTransferThreadClusterLengths_K0_M0_M1_K1_,
|
||||
ABlockTransferThreadClusterArrangeOrder_,
|
||||
ABlockTransferSrcAccessOrder_,
|
||||
ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1_,
|
||||
ABlockTransferSrcVectorTensorContiguousDimOrder_,
|
||||
ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1_,
|
||||
BBlockTransferThreadSliceLengths_K0_N0_N1_K1_,
|
||||
BBlockTransferThreadClusterLengths_K0_N0_N1_K1_,
|
||||
BBlockTransferThreadClusterArrangeOrder_,
|
||||
BBlockTransferSrcAccessOrder_,
|
||||
BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1_,
|
||||
BBlockTransferSrcVectorTensorContiguousDimOrder_,
|
||||
BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1_,
|
||||
CThreadTransferSrcDstAccessOrder_,
|
||||
CThreadTransferSrcDstVectorDim,
|
||||
CThreadTransferDstScalarPerVector>>
|
||||
{
|
||||
// Spatial dimension
|
||||
static constexpr int kSpatialDim = NDimSpatial;
|
||||
|
||||
// Data types
|
||||
using ADataType = ADataType_;
|
||||
using BDataType = BDataType_;
|
||||
using DsDataType = DsDataType_;
|
||||
using EDataType = EDataType_;
|
||||
using AccDataType = AccDataType_;
|
||||
|
||||
// Layout types
|
||||
using ALayout = ALayout_;
|
||||
using BLayout = BLayout_;
|
||||
using DsLayout = DsLayout_;
|
||||
using ELayout = ELayout_;
|
||||
|
||||
// Element-wise operations
|
||||
using AElementwiseOperation = AElementwiseOperation_;
|
||||
using BElementwiseOperation = BElementwiseOperation_;
|
||||
using CDEElementwiseOperation = CDEElementwiseOperation_;
|
||||
|
||||
// Specialization
|
||||
static constexpr ck::tensor_operation::device::ConvolutionForwardSpecialization
|
||||
kConvForwardSpecialization = ConvForwardSpecialization;
|
||||
static constexpr ck::tensor_operation::device::GemmSpecialization kGemmSpecialization =
|
||||
GemmSpec;
|
||||
|
||||
// Block configuration
|
||||
static constexpr int kBlockSize = BlockSize;
|
||||
static constexpr int kMPerBlock = MPerBlock;
|
||||
static constexpr int kNPerBlock = NPerBlock;
|
||||
static constexpr int kK0PerBlock = K0PerBlock;
|
||||
|
||||
// Tuning parameters
|
||||
static constexpr int kK1 = K1;
|
||||
static constexpr int kM1PerThread = M1PerThread;
|
||||
static constexpr int kN1PerThread = N1PerThread;
|
||||
static constexpr int kKPerThread = KPerThread;
|
||||
|
||||
// Thread cluster configurations
|
||||
using M1N1ThreadClusterM1Xs = M1N1ThreadClusterM1Xs_;
|
||||
using M1N1ThreadClusterN1Xs = M1N1ThreadClusterN1Xs_;
|
||||
|
||||
// A block transfer parameters
|
||||
using ABlockTransferThreadSliceLengths_K0_M0_M1_K1 =
|
||||
ABlockTransferThreadSliceLengths_K0_M0_M1_K1_;
|
||||
using ABlockTransferThreadClusterLengths_K0_M0_M1_K1 =
|
||||
ABlockTransferThreadClusterLengths_K0_M0_M1_K1_;
|
||||
using ABlockTransferThreadClusterArrangeOrder = ABlockTransferThreadClusterArrangeOrder_;
|
||||
using ABlockTransferSrcAccessOrder = ABlockTransferSrcAccessOrder_;
|
||||
using ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1 =
|
||||
ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1_;
|
||||
using ABlockTransferSrcVectorTensorContiguousDimOrder =
|
||||
ABlockTransferSrcVectorTensorContiguousDimOrder_;
|
||||
using ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1 =
|
||||
ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1_;
|
||||
|
||||
// B block transfer parameters
|
||||
using BBlockTransferThreadSliceLengths_K0_N0_N1_K1 =
|
||||
BBlockTransferThreadSliceLengths_K0_N0_N1_K1_;
|
||||
using BBlockTransferThreadClusterLengths_K0_N0_N1_K1 =
|
||||
BBlockTransferThreadClusterLengths_K0_N0_N1_K1_;
|
||||
using BBlockTransferThreadClusterArrangeOrder = BBlockTransferThreadClusterArrangeOrder_;
|
||||
using BBlockTransferSrcAccessOrder = BBlockTransferSrcAccessOrder_;
|
||||
using BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1 =
|
||||
BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1_;
|
||||
using BBlockTransferSrcVectorTensorContiguousDimOrder =
|
||||
BBlockTransferSrcVectorTensorContiguousDimOrder_;
|
||||
using BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1 =
|
||||
BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1_;
|
||||
|
||||
// C thread transfer parameters
|
||||
using CThreadTransferSrcDstAccessOrder = CThreadTransferSrcDstAccessOrder_;
|
||||
static constexpr int kCThreadTransferSrcDstVectorDim = CThreadTransferSrcDstVectorDim;
|
||||
static constexpr int kCThreadTransferDstScalarPerVector = CThreadTransferDstScalarPerVector;
|
||||
|
||||
// Static member function to generate instance string
|
||||
static std::string instance_string()
|
||||
{
|
||||
std::ostringstream oss;
|
||||
|
||||
// Kernel type name
|
||||
oss << "DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK";
|
||||
|
||||
// Template parameters in exact order matching the device implementation
|
||||
oss << "<" << kSpatialDim; // 1. NDimSpatial
|
||||
oss << "," << detail::type_name<ADataType>(); // 2. ADataType
|
||||
oss << "," << detail::type_name<BDataType>(); // 3. BDataType
|
||||
oss << "," << detail::tuple_name<DsDataType>(); // 4. DsDataType
|
||||
oss << "," << detail::type_name<EDataType>(); // 5. EDataType
|
||||
oss << "," << detail::type_name<AccDataType>(); // 6. AccDataType
|
||||
oss << "," << detail::layout_name<ALayout>(); // 7. ALayout
|
||||
oss << "," << detail::layout_name<BLayout>(); // 8. BLayout
|
||||
oss << "," << detail::tuple_name<DsLayout>(); // 9. DsLayout
|
||||
oss << "," << detail::layout_name<ELayout>(); // 10. ELayout
|
||||
oss << ","
|
||||
<< detail::elementwise_op_name<AElementwiseOperation>(); // 11. AElementwiseOperation
|
||||
oss << ","
|
||||
<< detail::elementwise_op_name<BElementwiseOperation>(); // 12. BElementwiseOperation
|
||||
oss << ","
|
||||
<< detail::elementwise_op_name<CDEElementwiseOperation>(); // 13.
|
||||
// CDEElementwiseOperation
|
||||
oss << ","
|
||||
<< detail::conv_fwd_spec_name(
|
||||
kConvForwardSpecialization); // 14. ConvForwardSpecialization
|
||||
oss << "," << detail::gemm_spec_name(kGemmSpecialization); // 15. GemmSpec
|
||||
oss << "," << kBlockSize; // 16. BlockSize
|
||||
oss << "," << kMPerBlock; // 17. MPerBlock
|
||||
oss << "," << kNPerBlock; // 18. NPerBlock
|
||||
oss << "," << kK0PerBlock; // 19. K0PerBlock
|
||||
oss << "," << kK1; // 20. K1
|
||||
oss << "," << kM1PerThread; // 21. M1PerThread
|
||||
oss << "," << kN1PerThread; // 22. N1PerThread
|
||||
oss << "," << kKPerThread; // 23. KPerThread
|
||||
oss << "," << detail::sequence_name<M1N1ThreadClusterM1Xs>(); // 24. M1N1ThreadClusterM1Xs
|
||||
oss << "," << detail::sequence_name<M1N1ThreadClusterN1Xs>(); // 25. M1N1ThreadClusterN1Xs
|
||||
oss << ","
|
||||
<< detail::sequence_name<
|
||||
ABlockTransferThreadSliceLengths_K0_M0_M1_K1>(); // 26.
|
||||
// ABlockTransferThreadSliceLengths_K0_M0_M1_K1
|
||||
oss << ","
|
||||
<< detail::sequence_name<
|
||||
ABlockTransferThreadClusterLengths_K0_M0_M1_K1>(); // 27.
|
||||
// ABlockTransferThreadClusterLengths_K0_M0_M1_K1
|
||||
oss << ","
|
||||
<< detail::sequence_name<
|
||||
ABlockTransferThreadClusterArrangeOrder>(); // 28.
|
||||
// ABlockTransferThreadClusterArrangeOrder
|
||||
oss << ","
|
||||
<< detail::sequence_name<
|
||||
ABlockTransferSrcAccessOrder>(); // 29. ABlockTransferSrcAccessOrder
|
||||
oss << ","
|
||||
<< detail::sequence_name<
|
||||
ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1>(); // 30.
|
||||
// ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1
|
||||
oss << ","
|
||||
<< detail::sequence_name<
|
||||
ABlockTransferSrcVectorTensorContiguousDimOrder>(); // 31.
|
||||
// ABlockTransferSrcVectorTensorContiguousDimOrder
|
||||
oss << ","
|
||||
<< detail::sequence_name<
|
||||
ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1>(); // 32.
|
||||
// ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1
|
||||
oss << ","
|
||||
<< detail::sequence_name<
|
||||
BBlockTransferThreadSliceLengths_K0_N0_N1_K1>(); // 33.
|
||||
// BBlockTransferThreadSliceLengths_K0_N0_N1_K1
|
||||
oss << ","
|
||||
<< detail::sequence_name<
|
||||
BBlockTransferThreadClusterLengths_K0_N0_N1_K1>(); // 34.
|
||||
// BBlockTransferThreadClusterLengths_K0_N0_N1_K1
|
||||
oss << ","
|
||||
<< detail::sequence_name<
|
||||
BBlockTransferThreadClusterArrangeOrder>(); // 35.
|
||||
// BBlockTransferThreadClusterArrangeOrder
|
||||
oss << ","
|
||||
<< detail::sequence_name<
|
||||
BBlockTransferSrcAccessOrder>(); // 36. BBlockTransferSrcAccessOrder
|
||||
oss << ","
|
||||
<< detail::sequence_name<
|
||||
BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1>(); // 37.
|
||||
// BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1
|
||||
oss << ","
|
||||
<< detail::sequence_name<
|
||||
BBlockTransferSrcVectorTensorContiguousDimOrder>(); // 38.
|
||||
// BBlockTransferSrcVectorTensorContiguousDimOrder
|
||||
oss << ","
|
||||
<< detail::sequence_name<
|
||||
BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1>(); // 39.
|
||||
// BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1
|
||||
oss << ","
|
||||
<< detail::sequence_name<
|
||||
CThreadTransferSrcDstAccessOrder>(); // 40. CThreadTransferSrcDstAccessOrder
|
||||
oss << "," << kCThreadTransferSrcDstVectorDim; // 41. CThreadTransferSrcDstVectorDim
|
||||
oss << "," << kCThreadTransferDstScalarPerVector; // 42. CThreadTransferDstScalarPerVector
|
||||
oss << ">";
|
||||
|
||||
return oss.str();
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile::reflect
|
||||
@@ -32,8 +32,8 @@ template <ck::index_t NDimSpatial,
|
||||
typename AElementwiseOperation,
|
||||
typename BElementwiseOperation,
|
||||
typename CDEElementwiseOperation,
|
||||
ConvolutionForwardSpecialization ConvForwardSpecialization,
|
||||
GemmSpecialization GemmSpec,
|
||||
ck::tensor_operation::device::ConvolutionForwardSpecialization ConvForwardSpecialization,
|
||||
ck::tensor_operation::device::GemmSpecialization GemmSpec,
|
||||
ck::index_t NumGemmKPrefetchStage,
|
||||
ck::index_t BlockSize,
|
||||
ck::index_t MPerBlock,
|
||||
@@ -65,7 +65,7 @@ template <ck::index_t NDimSpatial,
|
||||
ck::index_t CDEBlockTransferScalarPerVector_NPerBlock,
|
||||
typename AComputeDataType,
|
||||
typename BComputeDataType,
|
||||
LoopScheduler LoopSched,
|
||||
ck::LoopScheduler LoopSched,
|
||||
ck::index_t NumGroupsToMerge>
|
||||
struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle;
|
||||
|
||||
@@ -269,17 +269,17 @@ struct InstanceTraits<ck::tensor_operation::device::DeviceGroupedConvFwdMultiple
|
||||
oss << "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle";
|
||||
|
||||
// Template parameters in exact order matching InstanceTraits member order
|
||||
oss << "<" << kSpatialDim; // 1. NDimSpatial
|
||||
oss << "," << detail::layout_name<ALayout>(); // 2. ALayout
|
||||
oss << "," << detail::layout_name<BLayout>(); // 3. BLayout
|
||||
oss << "," << detail::tuple_name<DsLayout>(); // 4. DsLayout
|
||||
oss << "," << detail::layout_name<ELayout>(); // 5. ELayout
|
||||
oss << "," << detail::type_name<ADataType>(); // 6. ADataType
|
||||
oss << "," << detail::type_name<BDataType>(); // 7. BDataType
|
||||
oss << "," << detail::type_name<AccDataType>(); // 8. AccDataType
|
||||
oss << "," << detail::type_name<CShuffleDataType>(); // 9. CShuffleDataType
|
||||
oss << "," << detail::tuple_name<DsDataType>(); // 10. DsDataType
|
||||
oss << "," << detail::type_name<EDataType>(); // 11. EDataType
|
||||
oss << "<" << kSpatialDim; // 1. NDimSpatial
|
||||
oss << "," << detail::layout_name<ALayout>(); // 2. ALayout
|
||||
oss << "," << detail::layout_name<BLayout>(); // 3. BLayout
|
||||
oss << "," << detail::tuple_name<DsLayout>(); // 4. DsLayout
|
||||
oss << "," << detail::layout_name<ELayout>(); // 5. ELayout
|
||||
oss << "," << detail::type_or_type_tuple_name<ADataType>(); // 6. ADataType
|
||||
oss << "," << detail::type_or_type_tuple_name<BDataType>(); // 7. BDataType
|
||||
oss << "," << detail::type_name<AccDataType>(); // 8. AccDataType
|
||||
oss << "," << detail::type_name<CShuffleDataType>(); // 9. CShuffleDataType
|
||||
oss << "," << detail::tuple_name<DsDataType>(); // 10. DsDataType
|
||||
oss << "," << detail::type_name<EDataType>(); // 11. EDataType
|
||||
oss << ","
|
||||
<< detail::elementwise_op_name<AElementwiseOperation>(); // 12. AElementwiseOperation
|
||||
oss << ","
|
||||
|
||||
@@ -22,7 +22,7 @@
|
||||
// on template parameters - we don't need any implementation details.
|
||||
namespace ck::tensor_operation::device {
|
||||
|
||||
template <ck::index_t NDimSpatial,
|
||||
template <index_t NDimSpatial,
|
||||
typename ALayout,
|
||||
typename BLayout,
|
||||
typename DsLayout,
|
||||
@@ -36,8 +36,8 @@ template <ck::index_t NDimSpatial,
|
||||
typename AElementwiseOperation,
|
||||
typename BElementwiseOperation,
|
||||
typename CDEElementwiseOperation,
|
||||
ConvolutionForwardSpecialization ConvForwardSpecialization,
|
||||
GemmSpecialization GemmSpec,
|
||||
ck::tensor_operation::device::ConvolutionForwardSpecialization ConvForwardSpecialization,
|
||||
ck::tensor_operation::device::GemmSpecialization GemmSpec,
|
||||
ck::index_t BlockSize,
|
||||
ck::index_t MPerBlock,
|
||||
ck::index_t NPerBlock,
|
||||
@@ -259,6 +259,8 @@ struct InstanceTraits<ck::tensor_operation::device::DeviceGroupedConvFwdMultiple
|
||||
using AComputeDataType = AComputeDataType_;
|
||||
using BComputeDataType = BComputeDataType_;
|
||||
|
||||
static constexpr bool kDirectLoad = DirectLoad;
|
||||
|
||||
// Static member function to generate instance string
|
||||
static std::string instance_string()
|
||||
{
|
||||
|
||||
@@ -0,0 +1,343 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
// InstanceTraits specialization for DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
|
||||
//
|
||||
// CRITICAL MAINTENANCE NOTE:
|
||||
// This InstanceTraits file MUST be kept strictly in sync with the device implementation header:
|
||||
// ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp
|
||||
// "In sync" means that the template parameter order, names, and types in the declaration below
|
||||
// MUST EXACTLY MATCH those in the device implementation. If these diverge, you may encounter
|
||||
// compilation errors, subtle template instantiation mismatches, or silent runtime bugs that are
|
||||
// difficult to diagnose. Always update both files together and review changes carefully.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "instance_traits.hpp"
|
||||
|
||||
// Forward declaration to avoid circular dependency.
|
||||
// This file will be included by the device implementation header, so we cannot include
|
||||
// the implementation header here. We only need the template signature to pattern-match
|
||||
// on template parameters - we don't need any implementation details.
|
||||
|
||||
// Forward declare types from ck namespace that are used in the template parameters
|
||||
namespace ck {
|
||||
enum struct PipelineVersion;
|
||||
enum struct LoopScheduler;
|
||||
} // namespace ck
|
||||
|
||||
namespace ck::tensor_operation::device {
|
||||
|
||||
template <ck::index_t NDimSpatial,
|
||||
typename ALayout,
|
||||
typename BLayout,
|
||||
typename DsLayout,
|
||||
typename ELayout,
|
||||
typename ADataType,
|
||||
typename BDataType,
|
||||
typename AccDataType,
|
||||
typename CShuffleDataType,
|
||||
typename DsDataType,
|
||||
typename EDataType,
|
||||
typename AElementwiseOperation,
|
||||
typename BElementwiseOperation,
|
||||
typename CDEElementwiseOperation,
|
||||
ck::tensor_operation::device::ConvolutionForwardSpecialization ConvForwardSpecialization,
|
||||
ck::tensor_operation::device::GemmSpecialization GemmSpec,
|
||||
ck::index_t NumGemmKPrefetchStage,
|
||||
ck::index_t BlockSize,
|
||||
ck::index_t MPerBlock,
|
||||
ck::index_t NPerBlock,
|
||||
ck::index_t KPerBlock,
|
||||
ck::index_t K1,
|
||||
ck::index_t MPerWmma,
|
||||
ck::index_t NPerWmma,
|
||||
ck::index_t MRepeat,
|
||||
ck::index_t NRepeat,
|
||||
typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
|
||||
typename ABlockTransferThreadClusterArrangeOrder,
|
||||
typename ABlockTransferSrcAccessOrder,
|
||||
ck::index_t ABlockTransferSrcVectorDim,
|
||||
ck::index_t ABlockTransferSrcScalarPerVector,
|
||||
ck::index_t ABlockTransferDstScalarPerVector_AK1,
|
||||
bool ABlockLdsExtraM,
|
||||
typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
|
||||
typename BBlockTransferThreadClusterArrangeOrder,
|
||||
typename BBlockTransferSrcAccessOrder,
|
||||
ck::index_t BBlockTransferSrcVectorDim,
|
||||
ck::index_t BBlockTransferSrcScalarPerVector,
|
||||
ck::index_t BBlockTransferDstScalarPerVector_BK1,
|
||||
bool BBlockLdsExtraN,
|
||||
ck::index_t CShuffleMRepeatPerShuffle,
|
||||
ck::index_t CShuffleNRepeatPerShuffle,
|
||||
typename CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
ck::index_t CDEShuffleBlockTransferScalarPerVector_NPerBlock,
|
||||
ck::LoopScheduler LoopSched,
|
||||
ck::PipelineVersion PipelineVer>
|
||||
struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle;
|
||||
|
||||
} // namespace ck::tensor_operation::device
|
||||
|
||||
namespace ck_tile::reflect {
|
||||
|
||||
// Specialization for DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
|
||||
template <ck::index_t NDimSpatial,
|
||||
typename ALayout_,
|
||||
typename BLayout_,
|
||||
typename DsLayout_,
|
||||
typename ELayout_,
|
||||
typename ADataType_,
|
||||
typename BDataType_,
|
||||
typename AccDataType_,
|
||||
typename CShuffleDataType_,
|
||||
typename DsDataType_,
|
||||
typename EDataType_,
|
||||
typename AElementwiseOperation_,
|
||||
typename BElementwiseOperation_,
|
||||
typename CDEElementwiseOperation_,
|
||||
ck::tensor_operation::device::ConvolutionForwardSpecialization ConvForwardSpecialization,
|
||||
ck::tensor_operation::device::GemmSpecialization GemmSpec,
|
||||
ck::index_t NumGemmKPrefetchStage,
|
||||
ck::index_t BlockSize,
|
||||
ck::index_t MPerBlock,
|
||||
ck::index_t NPerBlock,
|
||||
ck::index_t KPerBlock,
|
||||
ck::index_t K1,
|
||||
ck::index_t MPerWmma,
|
||||
ck::index_t NPerWmma,
|
||||
ck::index_t MRepeat,
|
||||
ck::index_t NRepeat,
|
||||
typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
|
||||
typename ABlockTransferThreadClusterArrangeOrder,
|
||||
typename ABlockTransferSrcAccessOrder_,
|
||||
ck::index_t ABlockTransferSrcVectorDim,
|
||||
ck::index_t ABlockTransferSrcScalarPerVector,
|
||||
ck::index_t ABlockTransferDstScalarPerVector_AK1,
|
||||
bool ABlockLdsExtraM,
|
||||
typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
|
||||
typename BBlockTransferThreadClusterArrangeOrder,
|
||||
typename BBlockTransferSrcAccessOrder_,
|
||||
ck::index_t BBlockTransferSrcVectorDim,
|
||||
ck::index_t BBlockTransferSrcScalarPerVector,
|
||||
ck::index_t BBlockTransferDstScalarPerVector_BK1,
|
||||
bool BBlockLdsExtraN,
|
||||
ck::index_t CShuffleMRepeatPerShuffle,
|
||||
ck::index_t CShuffleNRepeatPerShuffle,
|
||||
typename CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
ck::index_t CDEShuffleBlockTransferScalarPerVector_NPerBlock,
|
||||
ck::LoopScheduler LoopSched,
|
||||
ck::PipelineVersion PipelineVer>
|
||||
struct InstanceTraits<ck::tensor_operation::device::DeviceGroupedConvFwdMultipleD_Wmma_CShuffle<
|
||||
NDimSpatial,
|
||||
ALayout_,
|
||||
BLayout_,
|
||||
DsLayout_,
|
||||
ELayout_,
|
||||
ADataType_,
|
||||
BDataType_,
|
||||
AccDataType_,
|
||||
CShuffleDataType_,
|
||||
DsDataType_,
|
||||
EDataType_,
|
||||
AElementwiseOperation_,
|
||||
BElementwiseOperation_,
|
||||
CDEElementwiseOperation_,
|
||||
ConvForwardSpecialization,
|
||||
GemmSpec,
|
||||
NumGemmKPrefetchStage,
|
||||
BlockSize,
|
||||
MPerBlock,
|
||||
NPerBlock,
|
||||
KPerBlock,
|
||||
K1,
|
||||
MPerWmma,
|
||||
NPerWmma,
|
||||
MRepeat,
|
||||
NRepeat,
|
||||
ABlockTransferThreadClusterLengths_AK0_M_AK1,
|
||||
ABlockTransferThreadClusterArrangeOrder,
|
||||
ABlockTransferSrcAccessOrder_,
|
||||
ABlockTransferSrcVectorDim,
|
||||
ABlockTransferSrcScalarPerVector,
|
||||
ABlockTransferDstScalarPerVector_AK1,
|
||||
ABlockLdsExtraM,
|
||||
BBlockTransferThreadClusterLengths_BK0_N_BK1,
|
||||
BBlockTransferThreadClusterArrangeOrder,
|
||||
BBlockTransferSrcAccessOrder_,
|
||||
BBlockTransferSrcVectorDim,
|
||||
BBlockTransferSrcScalarPerVector,
|
||||
BBlockTransferDstScalarPerVector_BK1,
|
||||
BBlockLdsExtraN,
|
||||
CShuffleMRepeatPerShuffle,
|
||||
CShuffleNRepeatPerShuffle,
|
||||
CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
CDEShuffleBlockTransferScalarPerVector_NPerBlock,
|
||||
LoopSched,
|
||||
PipelineVer>>
|
||||
{
|
||||
// Spatial dimension
|
||||
static constexpr int kSpatialDim = NDimSpatial;
|
||||
|
||||
// Layout types
|
||||
using ALayout = ALayout_;
|
||||
using BLayout = BLayout_;
|
||||
using DsLayout = DsLayout_;
|
||||
using ELayout = ELayout_;
|
||||
|
||||
// Data types
|
||||
using ADataType = ADataType_;
|
||||
using BDataType = BDataType_;
|
||||
using AccDataType = AccDataType_;
|
||||
using CShuffleDataType = CShuffleDataType_;
|
||||
using DsDataType = DsDataType_;
|
||||
using EDataType = EDataType_;
|
||||
|
||||
// Element-wise operations
|
||||
using AElementwiseOperation = AElementwiseOperation_;
|
||||
using BElementwiseOperation = BElementwiseOperation_;
|
||||
using CDEElementwiseOperation = CDEElementwiseOperation_;
|
||||
|
||||
// Specialization
|
||||
static constexpr ck::tensor_operation::device::ConvolutionForwardSpecialization
|
||||
kConvForwardSpecialization = ConvForwardSpecialization;
|
||||
static constexpr ck::tensor_operation::device::GemmSpecialization kGemmSpecialization =
|
||||
GemmSpec;
|
||||
|
||||
// Prefetch stage
|
||||
static constexpr int kNumGemmKPrefetchStage = NumGemmKPrefetchStage;
|
||||
|
||||
// Block configuration
|
||||
static constexpr int kBlockSize = BlockSize;
|
||||
static constexpr int kMPerBlock = MPerBlock;
|
||||
static constexpr int kNPerBlock = NPerBlock;
|
||||
static constexpr int kKPerBlock = KPerBlock;
|
||||
|
||||
// Tuning parameters
|
||||
static constexpr int kK1 = K1;
|
||||
static constexpr int kMPerWmma = MPerWmma;
|
||||
static constexpr int kNPerWmma = NPerWmma;
|
||||
static constexpr int kMRepeat = MRepeat;
|
||||
static constexpr int kNRepeat = NRepeat;
|
||||
|
||||
// A block transfer thread cluster dimensions (converted to std::array)
|
||||
static constexpr auto kAThreadClusterLengths =
|
||||
detail::SequenceToArray<ABlockTransferThreadClusterLengths_AK0_M_AK1>::value;
|
||||
static constexpr auto kAThreadClusterArrangeOrder =
|
||||
detail::SequenceToArray<ABlockTransferThreadClusterArrangeOrder>::value;
|
||||
static constexpr auto kABlockTransferSrcAccessOrder =
|
||||
detail::SequenceToArray<ABlockTransferSrcAccessOrder_>::value;
|
||||
static constexpr int kABlockTransferSrcVectorDim = ABlockTransferSrcVectorDim;
|
||||
static constexpr int kABlockTransferSrcScalarPerVector = ABlockTransferSrcScalarPerVector;
|
||||
static constexpr int kABlockTransferDstScalarPerVectorK1 = ABlockTransferDstScalarPerVector_AK1;
|
||||
static constexpr bool kABlockLdsExtraM = ABlockLdsExtraM;
|
||||
|
||||
// B block transfer thread cluster dimensions (converted to std::array)
|
||||
static constexpr auto kBThreadClusterLengths =
|
||||
detail::SequenceToArray<BBlockTransferThreadClusterLengths_BK0_N_BK1>::value;
|
||||
static constexpr auto kBThreadClusterArrangeOrder =
|
||||
detail::SequenceToArray<BBlockTransferThreadClusterArrangeOrder>::value;
|
||||
static constexpr auto kBBlockTransferSrcAccessOrder =
|
||||
detail::SequenceToArray<BBlockTransferSrcAccessOrder_>::value;
|
||||
static constexpr int kBBlockTransferSrcVectorDim = BBlockTransferSrcVectorDim;
|
||||
static constexpr int kBBlockTransferSrcScalarPerVector = BBlockTransferSrcScalarPerVector;
|
||||
static constexpr int kBBlockTransferDstScalarPerVectorK1 = BBlockTransferDstScalarPerVector_BK1;
|
||||
static constexpr bool kBBlockLdsExtraN = BBlockLdsExtraN;
|
||||
|
||||
// C shuffle parameters (converted to std::array)
|
||||
static constexpr int kCShuffleMRepeatPerShuffle = CShuffleMRepeatPerShuffle;
|
||||
static constexpr int kCShuffleNRepeatPerShuffle = CShuffleNRepeatPerShuffle;
|
||||
static constexpr auto kCDEThreadClusterLengths = detail::SequenceToArray<
|
||||
CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock>::value;
|
||||
static constexpr int kCDEBlockTransferScalarPerVector =
|
||||
CDEShuffleBlockTransferScalarPerVector_NPerBlock;
|
||||
|
||||
// Pipeline configuration
|
||||
static constexpr ck::LoopScheduler kLoopScheduler = LoopSched;
|
||||
static constexpr ck::PipelineVersion kPipelineVersion = PipelineVer;
|
||||
|
||||
// Static member function to generate instance string
|
||||
static std::string instance_string()
|
||||
{
|
||||
std::ostringstream oss;
|
||||
|
||||
// Kernel type name
|
||||
oss << "DeviceGroupedConvFwdMultipleD_Wmma_CShuffle";
|
||||
|
||||
// Template parameters in exact order matching InstanceTraits member order
|
||||
oss << "<" << kSpatialDim; // 1. NDimSpatial
|
||||
oss << "," << detail::layout_name<ALayout>(); // 2. ALayout
|
||||
oss << "," << detail::layout_name<BLayout>(); // 3. BLayout
|
||||
oss << "," << detail::tuple_name<DsLayout>(); // 4. DsLayout
|
||||
oss << "," << detail::layout_name<ELayout>(); // 5. ELayout
|
||||
oss << "," << detail::type_name<ADataType>(); // 6. ADataType
|
||||
oss << "," << detail::type_name<BDataType>(); // 7. BDataType
|
||||
oss << "," << detail::type_name<AccDataType>(); // 8. AccDataType
|
||||
oss << "," << detail::type_name<CShuffleDataType>(); // 9. CShuffleDataType
|
||||
oss << "," << detail::tuple_name<DsDataType>(); // 10. DsDataType
|
||||
oss << "," << detail::type_name<EDataType>(); // 11. EDataType
|
||||
oss << ","
|
||||
<< detail::elementwise_op_name<AElementwiseOperation>(); // 12. AElementwiseOperation
|
||||
oss << ","
|
||||
<< detail::elementwise_op_name<BElementwiseOperation>(); // 13. BElementwiseOperation
|
||||
oss << ","
|
||||
<< detail::elementwise_op_name<CDEElementwiseOperation>(); // 14.
|
||||
// CDEElementwiseOperation
|
||||
oss << ","
|
||||
<< detail::conv_fwd_spec_name(
|
||||
kConvForwardSpecialization); // 15. ConvForwardSpecialization
|
||||
oss << "," << detail::gemm_spec_name(kGemmSpecialization); // 16. GemmSpec
|
||||
oss << "," << kNumGemmKPrefetchStage; // 17. NumGemmKPrefetchStage
|
||||
oss << "," << kBlockSize; // 18. BlockSize
|
||||
oss << "," << kMPerBlock; // 19. MPerBlock
|
||||
oss << "," << kNPerBlock; // 20. NPerBlock
|
||||
oss << "," << kKPerBlock; // 21. KPerBlock
|
||||
oss << "," << kK1; // 22. K1
|
||||
oss << "," << kMPerWmma; // 23. MPerWmma
|
||||
oss << "," << kNPerWmma; // 24. NPerWmma
|
||||
oss << "," << kMRepeat; // 25. MRepeat
|
||||
oss << "," << kNRepeat; // 26. NRepeat
|
||||
oss << ","
|
||||
<< detail::array_to_string(
|
||||
kAThreadClusterLengths); // 27. ABlockTransferThreadClusterLengths
|
||||
oss << ","
|
||||
<< detail::array_to_string(
|
||||
kAThreadClusterArrangeOrder); // 28. ABlockTransferThreadClusterArrangeOrder
|
||||
oss << ","
|
||||
<< detail::array_to_string(
|
||||
kABlockTransferSrcAccessOrder); // 29. ABlockTransferSrcAccessOrder
|
||||
oss << "," << kABlockTransferSrcVectorDim; // 30. ABlockTransferSrcVectorDim
|
||||
oss << "," << kABlockTransferSrcScalarPerVector; // 31. ABlockTransferSrcScalarPerVector
|
||||
oss << ","
|
||||
<< kABlockTransferDstScalarPerVectorK1; // 32. ABlockTransferDstScalarPerVector_AK1
|
||||
oss << "," << (kABlockLdsExtraM ? "true" : "false"); // 33. ABlockLdsExtraM
|
||||
oss << ","
|
||||
<< detail::array_to_string(
|
||||
kBThreadClusterLengths); // 34. BBlockTransferThreadClusterLengths
|
||||
oss << ","
|
||||
<< detail::array_to_string(
|
||||
kBThreadClusterArrangeOrder); // 35. BBlockTransferThreadClusterArrangeOrder
|
||||
oss << ","
|
||||
<< detail::array_to_string(
|
||||
kBBlockTransferSrcAccessOrder); // 36. BBlockTransferSrcAccessOrder
|
||||
oss << "," << kBBlockTransferSrcVectorDim; // 37. BBlockTransferSrcVectorDim
|
||||
oss << "," << kBBlockTransferSrcScalarPerVector; // 38. BBlockTransferSrcScalarPerVector
|
||||
oss << ","
|
||||
<< kBBlockTransferDstScalarPerVectorK1; // 39. BBlockTransferDstScalarPerVector_BK1
|
||||
oss << "," << (kBBlockLdsExtraN ? "true" : "false"); // 40. BBlockLdsExtraN
|
||||
oss << "," << kCShuffleMRepeatPerShuffle; // 41. CShuffleMRepeatPerShuffle
|
||||
oss << "," << kCShuffleNRepeatPerShuffle; // 42. CShuffleNRepeatPerShuffle
|
||||
oss << ","
|
||||
<< detail::array_to_string(
|
||||
kCDEThreadClusterLengths); // 43. CDEShuffleBlockTransferClusterLengths
|
||||
oss << ","
|
||||
<< kCDEBlockTransferScalarPerVector; // 44.
|
||||
// CDEShuffleBlockTransferScalarPerVector_NPerBlock
|
||||
oss << "," << detail::loop_scheduler_name(kLoopScheduler); // 45. LoopSched
|
||||
oss << "," << detail::pipeline_version_name(kPipelineVersion); // 46. PipelineVer
|
||||
oss << ">";
|
||||
|
||||
return oss.str();
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile::reflect
|
||||
@@ -32,8 +32,8 @@ template <ck::index_t NDimSpatial,
|
||||
typename AElementwiseOperation,
|
||||
typename BElementwiseOperation,
|
||||
typename CDEElementwiseOperation,
|
||||
ConvolutionForwardSpecialization ConvForwardSpecialization,
|
||||
GemmSpecialization GemmSpec,
|
||||
ck::tensor_operation::device::ConvolutionForwardSpecialization ConvForwardSpecialization,
|
||||
ck::tensor_operation::device::GemmSpecialization GemmSpec,
|
||||
ck::index_t NumGemmKPrefetchStage,
|
||||
ck::index_t BlockSize,
|
||||
ck::index_t MPerBlock,
|
||||
@@ -65,7 +65,7 @@ template <ck::index_t NDimSpatial,
|
||||
ck::index_t CDEBlockTransferScalarPerVector_NPerBlock,
|
||||
typename AComputeDataType,
|
||||
typename BComputeDataType,
|
||||
LoopScheduler LoopSched>
|
||||
ck::LoopScheduler LoopSched>
|
||||
struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor;
|
||||
|
||||
} // namespace ck::tensor_operation::device
|
||||
|
||||
@@ -16,6 +16,7 @@
|
||||
#include <ck/utility/sequence.hpp>
|
||||
#include <ck/utility/blkgemmpipe_scheduler.hpp>
|
||||
#include <ck/utility/loop_scheduler.hpp>
|
||||
#include <ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp>
|
||||
#include <ck/tensor_operation/gpu/device/tensor_layout.hpp>
|
||||
#include <ck_tile/ops/common/tensor_layout.hpp>
|
||||
#include <ck/tensor_operation/gpu/element/element_wise_operation.hpp>
|
||||
@@ -161,6 +162,19 @@ constexpr std::string_view pipeline_version_name(ck::BlockGemmPipelineVersion ve
|
||||
}
|
||||
}
|
||||
|
||||
// Convert PipelineVersion enum to string (for Wmma kernels)
|
||||
constexpr std::string_view pipeline_version_name(ck::PipelineVersion ver)
|
||||
{
|
||||
using enum ck::PipelineVersion;
|
||||
switch(ver)
|
||||
{
|
||||
case v1: return "v1";
|
||||
case v2: return "v2";
|
||||
case v4: return "v4";
|
||||
case weight_only: return "weight_only";
|
||||
}
|
||||
}
|
||||
|
||||
// Convert LoopScheduler enum to string
|
||||
constexpr std::string_view loop_scheduler_name(ck::LoopScheduler sched)
|
||||
{
|
||||
@@ -322,4 +336,24 @@ constexpr std::string tuple_name()
|
||||
}(static_cast<T*>(nullptr));
|
||||
}
|
||||
|
||||
// Concept to check if a type is a ck::Tuple
|
||||
template <typename T>
|
||||
concept IsCkTuple =
|
||||
requires { []<typename... Ts>(ck::Tuple<Ts...>*) {}(static_cast<T*>(nullptr)); };
|
||||
|
||||
// Deduces whether to use tuple_name or type_name
|
||||
// Handles both scalar data types and ck::Tuple types
|
||||
template <typename T>
|
||||
constexpr std::string type_or_type_tuple_name()
|
||||
{
|
||||
if constexpr(IsCkTuple<T>)
|
||||
{
|
||||
return tuple_name<T>();
|
||||
}
|
||||
else
|
||||
{
|
||||
return std::string(type_name<T>());
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace ck_tile::reflect::detail
|
||||
|
||||
Reference in New Issue
Block a user