Add the last two forward instance traits. (#3134)

* Add InstanceTraits for DeviceGroupedConvFwdMultipleD_Wmma_CShuffle

* Add InstanceTraits for kernel_grouped_conv_fwd_dl_multiple_d

* A few small changes to fix broken instance traits.

[ROCm/composable_kernel commit: 5ed2046bee]
This commit is contained in:
John Shumway
2025-10-31 07:52:42 -07:00
committed by GitHub
parent d2474f5396
commit a8a377ca53
17 changed files with 1207 additions and 82 deletions

View File

@@ -0,0 +1,341 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
// InstanceTraits specialization for DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK
//
// CRITICAL MAINTENANCE NOTE:
// This InstanceTraits file MUST be kept strictly in sync with the device implementation header:
// ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp
// "In sync" means that the template parameter order, names, and types in the declaration below
// MUST EXACTLY MATCH those in the device implementation. If these diverge, you may encounter
// compilation errors, subtle template instantiation mismatches, or silent runtime bugs that are
// difficult to diagnose. Always update both files together and review changes carefully.
#pragma once
#include "instance_traits.hpp"
// Forward declaration to avoid circular dependency.
// This file will be included by the device implementation header, so we cannot include
// the implementation header here. We only need the template signature to pattern-match
// on template parameters - we don't need any implementation details.
namespace ck::tensor_operation::device {
template <index_t NDimSpatial,
typename ADataType,
typename BDataType,
typename DsDataType,
typename EDataType,
typename AccDataType,
typename ALayout,
typename BLayout,
typename DsLayout,
typename ELayout,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CDEElementwiseOperation,
ck::tensor_operation::device::ConvolutionForwardSpecialization ConvForwardSpecialization,
ck::tensor_operation::device::GemmSpecialization GemmSpec,
ck::index_t BlockSize,
ck::index_t MPerBlock,
ck::index_t NPerBlock,
ck::index_t K0PerBlock,
ck::index_t K1,
ck::index_t M1PerThread,
ck::index_t N1PerThread,
ck::index_t KPerThread,
typename M1N1ThreadClusterM1Xs,
typename M1N1ThreadClusterN1Xs,
typename ABlockTransferThreadSliceLengths_K0_M0_M1_K1,
typename ABlockTransferThreadClusterLengths_K0_M0_M1_K1,
typename ABlockTransferThreadClusterArrangeOrder,
typename ABlockTransferSrcAccessOrder,
typename ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1,
typename ABlockTransferSrcVectorTensorContiguousDimOrder,
typename ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1,
typename BBlockTransferThreadSliceLengths_K0_N0_N1_K1,
typename BBlockTransferThreadClusterLengths_K0_N0_N1_K1,
typename BBlockTransferThreadClusterArrangeOrder,
typename BBlockTransferSrcAccessOrder,
typename BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1,
typename BBlockTransferSrcVectorTensorContiguousDimOrder,
typename BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1,
typename CThreadTransferSrcDstAccessOrder,
ck::index_t CThreadTransferSrcDstVectorDim,
ck::index_t CThreadTransferDstScalarPerVector>
struct DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK;
} // namespace ck::tensor_operation::device
namespace ck_tile::reflect {
// Specialization for DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK
template <ck::index_t NDimSpatial,
typename ADataType_,
typename BDataType_,
typename DsDataType_,
typename EDataType_,
typename AccDataType_,
typename ALayout_,
typename BLayout_,
typename DsLayout_,
typename ELayout_,
typename AElementwiseOperation_,
typename BElementwiseOperation_,
typename CDEElementwiseOperation_,
ck::tensor_operation::device::ConvolutionForwardSpecialization ConvForwardSpecialization,
ck::tensor_operation::device::GemmSpecialization GemmSpec,
ck::index_t BlockSize,
ck::index_t MPerBlock,
ck::index_t NPerBlock,
ck::index_t K0PerBlock,
ck::index_t K1,
ck::index_t M1PerThread,
ck::index_t N1PerThread,
ck::index_t KPerThread,
typename M1N1ThreadClusterM1Xs_,
typename M1N1ThreadClusterN1Xs_,
typename ABlockTransferThreadSliceLengths_K0_M0_M1_K1_,
typename ABlockTransferThreadClusterLengths_K0_M0_M1_K1_,
typename ABlockTransferThreadClusterArrangeOrder_,
typename ABlockTransferSrcAccessOrder_,
typename ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1_,
typename ABlockTransferSrcVectorTensorContiguousDimOrder_,
typename ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1_,
typename BBlockTransferThreadSliceLengths_K0_N0_N1_K1_,
typename BBlockTransferThreadClusterLengths_K0_N0_N1_K1_,
typename BBlockTransferThreadClusterArrangeOrder_,
typename BBlockTransferSrcAccessOrder_,
typename BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1_,
typename BBlockTransferSrcVectorTensorContiguousDimOrder_,
typename BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1_,
typename CThreadTransferSrcDstAccessOrder_,
ck::index_t CThreadTransferSrcDstVectorDim,
ck::index_t CThreadTransferDstScalarPerVector>
struct InstanceTraits<ck::tensor_operation::device::DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK<
NDimSpatial,
ADataType_,
BDataType_,
DsDataType_,
EDataType_,
AccDataType_,
ALayout_,
BLayout_,
DsLayout_,
ELayout_,
AElementwiseOperation_,
BElementwiseOperation_,
CDEElementwiseOperation_,
ConvForwardSpecialization,
GemmSpec,
BlockSize,
MPerBlock,
NPerBlock,
K0PerBlock,
K1,
M1PerThread,
N1PerThread,
KPerThread,
M1N1ThreadClusterM1Xs_,
M1N1ThreadClusterN1Xs_,
ABlockTransferThreadSliceLengths_K0_M0_M1_K1_,
ABlockTransferThreadClusterLengths_K0_M0_M1_K1_,
ABlockTransferThreadClusterArrangeOrder_,
ABlockTransferSrcAccessOrder_,
ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1_,
ABlockTransferSrcVectorTensorContiguousDimOrder_,
ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1_,
BBlockTransferThreadSliceLengths_K0_N0_N1_K1_,
BBlockTransferThreadClusterLengths_K0_N0_N1_K1_,
BBlockTransferThreadClusterArrangeOrder_,
BBlockTransferSrcAccessOrder_,
BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1_,
BBlockTransferSrcVectorTensorContiguousDimOrder_,
BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1_,
CThreadTransferSrcDstAccessOrder_,
CThreadTransferSrcDstVectorDim,
CThreadTransferDstScalarPerVector>>
{
// Spatial dimension
static constexpr int kSpatialDim = NDimSpatial;
// Data types
using ADataType = ADataType_;
using BDataType = BDataType_;
using DsDataType = DsDataType_;
using EDataType = EDataType_;
using AccDataType = AccDataType_;
// Layout types
using ALayout = ALayout_;
using BLayout = BLayout_;
using DsLayout = DsLayout_;
using ELayout = ELayout_;
// Element-wise operations
using AElementwiseOperation = AElementwiseOperation_;
using BElementwiseOperation = BElementwiseOperation_;
using CDEElementwiseOperation = CDEElementwiseOperation_;
// Specialization
static constexpr ck::tensor_operation::device::ConvolutionForwardSpecialization
kConvForwardSpecialization = ConvForwardSpecialization;
static constexpr ck::tensor_operation::device::GemmSpecialization kGemmSpecialization =
GemmSpec;
// Block configuration
static constexpr int kBlockSize = BlockSize;
static constexpr int kMPerBlock = MPerBlock;
static constexpr int kNPerBlock = NPerBlock;
static constexpr int kK0PerBlock = K0PerBlock;
// Tuning parameters
static constexpr int kK1 = K1;
static constexpr int kM1PerThread = M1PerThread;
static constexpr int kN1PerThread = N1PerThread;
static constexpr int kKPerThread = KPerThread;
// Thread cluster configurations
using M1N1ThreadClusterM1Xs = M1N1ThreadClusterM1Xs_;
using M1N1ThreadClusterN1Xs = M1N1ThreadClusterN1Xs_;
// A block transfer parameters
using ABlockTransferThreadSliceLengths_K0_M0_M1_K1 =
ABlockTransferThreadSliceLengths_K0_M0_M1_K1_;
using ABlockTransferThreadClusterLengths_K0_M0_M1_K1 =
ABlockTransferThreadClusterLengths_K0_M0_M1_K1_;
using ABlockTransferThreadClusterArrangeOrder = ABlockTransferThreadClusterArrangeOrder_;
using ABlockTransferSrcAccessOrder = ABlockTransferSrcAccessOrder_;
using ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1 =
ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1_;
using ABlockTransferSrcVectorTensorContiguousDimOrder =
ABlockTransferSrcVectorTensorContiguousDimOrder_;
using ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1 =
ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1_;
// B block transfer parameters
using BBlockTransferThreadSliceLengths_K0_N0_N1_K1 =
BBlockTransferThreadSliceLengths_K0_N0_N1_K1_;
using BBlockTransferThreadClusterLengths_K0_N0_N1_K1 =
BBlockTransferThreadClusterLengths_K0_N0_N1_K1_;
using BBlockTransferThreadClusterArrangeOrder = BBlockTransferThreadClusterArrangeOrder_;
using BBlockTransferSrcAccessOrder = BBlockTransferSrcAccessOrder_;
using BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1 =
BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1_;
using BBlockTransferSrcVectorTensorContiguousDimOrder =
BBlockTransferSrcVectorTensorContiguousDimOrder_;
using BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1 =
BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1_;
// C thread transfer parameters
using CThreadTransferSrcDstAccessOrder = CThreadTransferSrcDstAccessOrder_;
static constexpr int kCThreadTransferSrcDstVectorDim = CThreadTransferSrcDstVectorDim;
static constexpr int kCThreadTransferDstScalarPerVector = CThreadTransferDstScalarPerVector;
// Static member function to generate instance string
static std::string instance_string()
{
std::ostringstream oss;
// Kernel type name
oss << "DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK";
// Template parameters in exact order matching the device implementation
oss << "<" << kSpatialDim; // 1. NDimSpatial
oss << "," << detail::type_name<ADataType>(); // 2. ADataType
oss << "," << detail::type_name<BDataType>(); // 3. BDataType
oss << "," << detail::tuple_name<DsDataType>(); // 4. DsDataType
oss << "," << detail::type_name<EDataType>(); // 5. EDataType
oss << "," << detail::type_name<AccDataType>(); // 6. AccDataType
oss << "," << detail::layout_name<ALayout>(); // 7. ALayout
oss << "," << detail::layout_name<BLayout>(); // 8. BLayout
oss << "," << detail::tuple_name<DsLayout>(); // 9. DsLayout
oss << "," << detail::layout_name<ELayout>(); // 10. ELayout
oss << ","
<< detail::elementwise_op_name<AElementwiseOperation>(); // 11. AElementwiseOperation
oss << ","
<< detail::elementwise_op_name<BElementwiseOperation>(); // 12. BElementwiseOperation
oss << ","
<< detail::elementwise_op_name<CDEElementwiseOperation>(); // 13.
// CDEElementwiseOperation
oss << ","
<< detail::conv_fwd_spec_name(
kConvForwardSpecialization); // 14. ConvForwardSpecialization
oss << "," << detail::gemm_spec_name(kGemmSpecialization); // 15. GemmSpec
oss << "," << kBlockSize; // 16. BlockSize
oss << "," << kMPerBlock; // 17. MPerBlock
oss << "," << kNPerBlock; // 18. NPerBlock
oss << "," << kK0PerBlock; // 19. K0PerBlock
oss << "," << kK1; // 20. K1
oss << "," << kM1PerThread; // 21. M1PerThread
oss << "," << kN1PerThread; // 22. N1PerThread
oss << "," << kKPerThread; // 23. KPerThread
oss << "," << detail::sequence_name<M1N1ThreadClusterM1Xs>(); // 24. M1N1ThreadClusterM1Xs
oss << "," << detail::sequence_name<M1N1ThreadClusterN1Xs>(); // 25. M1N1ThreadClusterN1Xs
oss << ","
<< detail::sequence_name<
ABlockTransferThreadSliceLengths_K0_M0_M1_K1>(); // 26.
// ABlockTransferThreadSliceLengths_K0_M0_M1_K1
oss << ","
<< detail::sequence_name<
ABlockTransferThreadClusterLengths_K0_M0_M1_K1>(); // 27.
// ABlockTransferThreadClusterLengths_K0_M0_M1_K1
oss << ","
<< detail::sequence_name<
ABlockTransferThreadClusterArrangeOrder>(); // 28.
// ABlockTransferThreadClusterArrangeOrder
oss << ","
<< detail::sequence_name<
ABlockTransferSrcAccessOrder>(); // 29. ABlockTransferSrcAccessOrder
oss << ","
<< detail::sequence_name<
ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1>(); // 30.
// ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1
oss << ","
<< detail::sequence_name<
ABlockTransferSrcVectorTensorContiguousDimOrder>(); // 31.
// ABlockTransferSrcVectorTensorContiguousDimOrder
oss << ","
<< detail::sequence_name<
ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1>(); // 32.
// ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1
oss << ","
<< detail::sequence_name<
BBlockTransferThreadSliceLengths_K0_N0_N1_K1>(); // 33.
// BBlockTransferThreadSliceLengths_K0_N0_N1_K1
oss << ","
<< detail::sequence_name<
BBlockTransferThreadClusterLengths_K0_N0_N1_K1>(); // 34.
// BBlockTransferThreadClusterLengths_K0_N0_N1_K1
oss << ","
<< detail::sequence_name<
BBlockTransferThreadClusterArrangeOrder>(); // 35.
// BBlockTransferThreadClusterArrangeOrder
oss << ","
<< detail::sequence_name<
BBlockTransferSrcAccessOrder>(); // 36. BBlockTransferSrcAccessOrder
oss << ","
<< detail::sequence_name<
BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1>(); // 37.
// BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1
oss << ","
<< detail::sequence_name<
BBlockTransferSrcVectorTensorContiguousDimOrder>(); // 38.
// BBlockTransferSrcVectorTensorContiguousDimOrder
oss << ","
<< detail::sequence_name<
BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1>(); // 39.
// BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1
oss << ","
<< detail::sequence_name<
CThreadTransferSrcDstAccessOrder>(); // 40. CThreadTransferSrcDstAccessOrder
oss << "," << kCThreadTransferSrcDstVectorDim; // 41. CThreadTransferSrcDstVectorDim
oss << "," << kCThreadTransferDstScalarPerVector; // 42. CThreadTransferDstScalarPerVector
oss << ">";
return oss.str();
}
};
} // namespace ck_tile::reflect

View File

@@ -32,8 +32,8 @@ template <ck::index_t NDimSpatial,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CDEElementwiseOperation,
ConvolutionForwardSpecialization ConvForwardSpecialization,
GemmSpecialization GemmSpec,
ck::tensor_operation::device::ConvolutionForwardSpecialization ConvForwardSpecialization,
ck::tensor_operation::device::GemmSpecialization GemmSpec,
ck::index_t NumGemmKPrefetchStage,
ck::index_t BlockSize,
ck::index_t MPerBlock,
@@ -65,7 +65,7 @@ template <ck::index_t NDimSpatial,
ck::index_t CDEBlockTransferScalarPerVector_NPerBlock,
typename AComputeDataType,
typename BComputeDataType,
LoopScheduler LoopSched,
ck::LoopScheduler LoopSched,
ck::index_t NumGroupsToMerge>
struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle;
@@ -269,17 +269,17 @@ struct InstanceTraits<ck::tensor_operation::device::DeviceGroupedConvFwdMultiple
oss << "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle";
// Template parameters in exact order matching InstanceTraits member order
oss << "<" << kSpatialDim; // 1. NDimSpatial
oss << "," << detail::layout_name<ALayout>(); // 2. ALayout
oss << "," << detail::layout_name<BLayout>(); // 3. BLayout
oss << "," << detail::tuple_name<DsLayout>(); // 4. DsLayout
oss << "," << detail::layout_name<ELayout>(); // 5. ELayout
oss << "," << detail::type_name<ADataType>(); // 6. ADataType
oss << "," << detail::type_name<BDataType>(); // 7. BDataType
oss << "," << detail::type_name<AccDataType>(); // 8. AccDataType
oss << "," << detail::type_name<CShuffleDataType>(); // 9. CShuffleDataType
oss << "," << detail::tuple_name<DsDataType>(); // 10. DsDataType
oss << "," << detail::type_name<EDataType>(); // 11. EDataType
oss << "<" << kSpatialDim; // 1. NDimSpatial
oss << "," << detail::layout_name<ALayout>(); // 2. ALayout
oss << "," << detail::layout_name<BLayout>(); // 3. BLayout
oss << "," << detail::tuple_name<DsLayout>(); // 4. DsLayout
oss << "," << detail::layout_name<ELayout>(); // 5. ELayout
oss << "," << detail::type_or_type_tuple_name<ADataType>(); // 6. ADataType
oss << "," << detail::type_or_type_tuple_name<BDataType>(); // 7. BDataType
oss << "," << detail::type_name<AccDataType>(); // 8. AccDataType
oss << "," << detail::type_name<CShuffleDataType>(); // 9. CShuffleDataType
oss << "," << detail::tuple_name<DsDataType>(); // 10. DsDataType
oss << "," << detail::type_name<EDataType>(); // 11. EDataType
oss << ","
<< detail::elementwise_op_name<AElementwiseOperation>(); // 12. AElementwiseOperation
oss << ","

View File

@@ -22,7 +22,7 @@
// on template parameters - we don't need any implementation details.
namespace ck::tensor_operation::device {
template <ck::index_t NDimSpatial,
template <index_t NDimSpatial,
typename ALayout,
typename BLayout,
typename DsLayout,
@@ -36,8 +36,8 @@ template <ck::index_t NDimSpatial,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CDEElementwiseOperation,
ConvolutionForwardSpecialization ConvForwardSpecialization,
GemmSpecialization GemmSpec,
ck::tensor_operation::device::ConvolutionForwardSpecialization ConvForwardSpecialization,
ck::tensor_operation::device::GemmSpecialization GemmSpec,
ck::index_t BlockSize,
ck::index_t MPerBlock,
ck::index_t NPerBlock,
@@ -259,6 +259,8 @@ struct InstanceTraits<ck::tensor_operation::device::DeviceGroupedConvFwdMultiple
using AComputeDataType = AComputeDataType_;
using BComputeDataType = BComputeDataType_;
static constexpr bool kDirectLoad = DirectLoad;
// Static member function to generate instance string
static std::string instance_string()
{

View File

@@ -0,0 +1,343 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
// InstanceTraits specialization for DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
//
// CRITICAL MAINTENANCE NOTE:
// This InstanceTraits file MUST be kept strictly in sync with the device implementation header:
// ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp
// "In sync" means that the template parameter order, names, and types in the declaration below
// MUST EXACTLY MATCH those in the device implementation. If these diverge, you may encounter
// compilation errors, subtle template instantiation mismatches, or silent runtime bugs that are
// difficult to diagnose. Always update both files together and review changes carefully.
#pragma once
#include "instance_traits.hpp"
// Forward declaration to avoid circular dependency.
// This file will be included by the device implementation header, so we cannot include
// the implementation header here. We only need the template signature to pattern-match
// on template parameters - we don't need any implementation details.
// Forward declare types from ck namespace that are used in the template parameters
namespace ck {
enum struct PipelineVersion;
enum struct LoopScheduler;
} // namespace ck
namespace ck::tensor_operation::device {
template <ck::index_t NDimSpatial,
typename ALayout,
typename BLayout,
typename DsLayout,
typename ELayout,
typename ADataType,
typename BDataType,
typename AccDataType,
typename CShuffleDataType,
typename DsDataType,
typename EDataType,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CDEElementwiseOperation,
ck::tensor_operation::device::ConvolutionForwardSpecialization ConvForwardSpecialization,
ck::tensor_operation::device::GemmSpecialization GemmSpec,
ck::index_t NumGemmKPrefetchStage,
ck::index_t BlockSize,
ck::index_t MPerBlock,
ck::index_t NPerBlock,
ck::index_t KPerBlock,
ck::index_t K1,
ck::index_t MPerWmma,
ck::index_t NPerWmma,
ck::index_t MRepeat,
ck::index_t NRepeat,
typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
typename ABlockTransferThreadClusterArrangeOrder,
typename ABlockTransferSrcAccessOrder,
ck::index_t ABlockTransferSrcVectorDim,
ck::index_t ABlockTransferSrcScalarPerVector,
ck::index_t ABlockTransferDstScalarPerVector_AK1,
bool ABlockLdsExtraM,
typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
typename BBlockTransferThreadClusterArrangeOrder,
typename BBlockTransferSrcAccessOrder,
ck::index_t BBlockTransferSrcVectorDim,
ck::index_t BBlockTransferSrcScalarPerVector,
ck::index_t BBlockTransferDstScalarPerVector_BK1,
bool BBlockLdsExtraN,
ck::index_t CShuffleMRepeatPerShuffle,
ck::index_t CShuffleNRepeatPerShuffle,
typename CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
ck::index_t CDEShuffleBlockTransferScalarPerVector_NPerBlock,
ck::LoopScheduler LoopSched,
ck::PipelineVersion PipelineVer>
struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle;
} // namespace ck::tensor_operation::device
namespace ck_tile::reflect {
// Specialization for DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
template <ck::index_t NDimSpatial,
typename ALayout_,
typename BLayout_,
typename DsLayout_,
typename ELayout_,
typename ADataType_,
typename BDataType_,
typename AccDataType_,
typename CShuffleDataType_,
typename DsDataType_,
typename EDataType_,
typename AElementwiseOperation_,
typename BElementwiseOperation_,
typename CDEElementwiseOperation_,
ck::tensor_operation::device::ConvolutionForwardSpecialization ConvForwardSpecialization,
ck::tensor_operation::device::GemmSpecialization GemmSpec,
ck::index_t NumGemmKPrefetchStage,
ck::index_t BlockSize,
ck::index_t MPerBlock,
ck::index_t NPerBlock,
ck::index_t KPerBlock,
ck::index_t K1,
ck::index_t MPerWmma,
ck::index_t NPerWmma,
ck::index_t MRepeat,
ck::index_t NRepeat,
typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
typename ABlockTransferThreadClusterArrangeOrder,
typename ABlockTransferSrcAccessOrder_,
ck::index_t ABlockTransferSrcVectorDim,
ck::index_t ABlockTransferSrcScalarPerVector,
ck::index_t ABlockTransferDstScalarPerVector_AK1,
bool ABlockLdsExtraM,
typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
typename BBlockTransferThreadClusterArrangeOrder,
typename BBlockTransferSrcAccessOrder_,
ck::index_t BBlockTransferSrcVectorDim,
ck::index_t BBlockTransferSrcScalarPerVector,
ck::index_t BBlockTransferDstScalarPerVector_BK1,
bool BBlockLdsExtraN,
ck::index_t CShuffleMRepeatPerShuffle,
ck::index_t CShuffleNRepeatPerShuffle,
typename CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
ck::index_t CDEShuffleBlockTransferScalarPerVector_NPerBlock,
ck::LoopScheduler LoopSched,
ck::PipelineVersion PipelineVer>
struct InstanceTraits<ck::tensor_operation::device::DeviceGroupedConvFwdMultipleD_Wmma_CShuffle<
NDimSpatial,
ALayout_,
BLayout_,
DsLayout_,
ELayout_,
ADataType_,
BDataType_,
AccDataType_,
CShuffleDataType_,
DsDataType_,
EDataType_,
AElementwiseOperation_,
BElementwiseOperation_,
CDEElementwiseOperation_,
ConvForwardSpecialization,
GemmSpec,
NumGemmKPrefetchStage,
BlockSize,
MPerBlock,
NPerBlock,
KPerBlock,
K1,
MPerWmma,
NPerWmma,
MRepeat,
NRepeat,
ABlockTransferThreadClusterLengths_AK0_M_AK1,
ABlockTransferThreadClusterArrangeOrder,
ABlockTransferSrcAccessOrder_,
ABlockTransferSrcVectorDim,
ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_AK1,
ABlockLdsExtraM,
BBlockTransferThreadClusterLengths_BK0_N_BK1,
BBlockTransferThreadClusterArrangeOrder,
BBlockTransferSrcAccessOrder_,
BBlockTransferSrcVectorDim,
BBlockTransferSrcScalarPerVector,
BBlockTransferDstScalarPerVector_BK1,
BBlockLdsExtraN,
CShuffleMRepeatPerShuffle,
CShuffleNRepeatPerShuffle,
CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
CDEShuffleBlockTransferScalarPerVector_NPerBlock,
LoopSched,
PipelineVer>>
{
// Spatial dimension
static constexpr int kSpatialDim = NDimSpatial;
// Layout types
using ALayout = ALayout_;
using BLayout = BLayout_;
using DsLayout = DsLayout_;
using ELayout = ELayout_;
// Data types
using ADataType = ADataType_;
using BDataType = BDataType_;
using AccDataType = AccDataType_;
using CShuffleDataType = CShuffleDataType_;
using DsDataType = DsDataType_;
using EDataType = EDataType_;
// Element-wise operations
using AElementwiseOperation = AElementwiseOperation_;
using BElementwiseOperation = BElementwiseOperation_;
using CDEElementwiseOperation = CDEElementwiseOperation_;
// Specialization
static constexpr ck::tensor_operation::device::ConvolutionForwardSpecialization
kConvForwardSpecialization = ConvForwardSpecialization;
static constexpr ck::tensor_operation::device::GemmSpecialization kGemmSpecialization =
GemmSpec;
// Prefetch stage
static constexpr int kNumGemmKPrefetchStage = NumGemmKPrefetchStage;
// Block configuration
static constexpr int kBlockSize = BlockSize;
static constexpr int kMPerBlock = MPerBlock;
static constexpr int kNPerBlock = NPerBlock;
static constexpr int kKPerBlock = KPerBlock;
// Tuning parameters
static constexpr int kK1 = K1;
static constexpr int kMPerWmma = MPerWmma;
static constexpr int kNPerWmma = NPerWmma;
static constexpr int kMRepeat = MRepeat;
static constexpr int kNRepeat = NRepeat;
// A block transfer thread cluster dimensions (converted to std::array)
static constexpr auto kAThreadClusterLengths =
detail::SequenceToArray<ABlockTransferThreadClusterLengths_AK0_M_AK1>::value;
static constexpr auto kAThreadClusterArrangeOrder =
detail::SequenceToArray<ABlockTransferThreadClusterArrangeOrder>::value;
static constexpr auto kABlockTransferSrcAccessOrder =
detail::SequenceToArray<ABlockTransferSrcAccessOrder_>::value;
static constexpr int kABlockTransferSrcVectorDim = ABlockTransferSrcVectorDim;
static constexpr int kABlockTransferSrcScalarPerVector = ABlockTransferSrcScalarPerVector;
static constexpr int kABlockTransferDstScalarPerVectorK1 = ABlockTransferDstScalarPerVector_AK1;
static constexpr bool kABlockLdsExtraM = ABlockLdsExtraM;
// B block transfer thread cluster dimensions (converted to std::array)
static constexpr auto kBThreadClusterLengths =
detail::SequenceToArray<BBlockTransferThreadClusterLengths_BK0_N_BK1>::value;
static constexpr auto kBThreadClusterArrangeOrder =
detail::SequenceToArray<BBlockTransferThreadClusterArrangeOrder>::value;
static constexpr auto kBBlockTransferSrcAccessOrder =
detail::SequenceToArray<BBlockTransferSrcAccessOrder_>::value;
static constexpr int kBBlockTransferSrcVectorDim = BBlockTransferSrcVectorDim;
static constexpr int kBBlockTransferSrcScalarPerVector = BBlockTransferSrcScalarPerVector;
static constexpr int kBBlockTransferDstScalarPerVectorK1 = BBlockTransferDstScalarPerVector_BK1;
static constexpr bool kBBlockLdsExtraN = BBlockLdsExtraN;
// C shuffle parameters (converted to std::array)
static constexpr int kCShuffleMRepeatPerShuffle = CShuffleMRepeatPerShuffle;
static constexpr int kCShuffleNRepeatPerShuffle = CShuffleNRepeatPerShuffle;
static constexpr auto kCDEThreadClusterLengths = detail::SequenceToArray<
CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock>::value;
static constexpr int kCDEBlockTransferScalarPerVector =
CDEShuffleBlockTransferScalarPerVector_NPerBlock;
// Pipeline configuration
static constexpr ck::LoopScheduler kLoopScheduler = LoopSched;
static constexpr ck::PipelineVersion kPipelineVersion = PipelineVer;
// Static member function to generate instance string
static std::string instance_string()
{
std::ostringstream oss;
// Kernel type name
oss << "DeviceGroupedConvFwdMultipleD_Wmma_CShuffle";
// Template parameters in exact order matching InstanceTraits member order
oss << "<" << kSpatialDim; // 1. NDimSpatial
oss << "," << detail::layout_name<ALayout>(); // 2. ALayout
oss << "," << detail::layout_name<BLayout>(); // 3. BLayout
oss << "," << detail::tuple_name<DsLayout>(); // 4. DsLayout
oss << "," << detail::layout_name<ELayout>(); // 5. ELayout
oss << "," << detail::type_name<ADataType>(); // 6. ADataType
oss << "," << detail::type_name<BDataType>(); // 7. BDataType
oss << "," << detail::type_name<AccDataType>(); // 8. AccDataType
oss << "," << detail::type_name<CShuffleDataType>(); // 9. CShuffleDataType
oss << "," << detail::tuple_name<DsDataType>(); // 10. DsDataType
oss << "," << detail::type_name<EDataType>(); // 11. EDataType
oss << ","
<< detail::elementwise_op_name<AElementwiseOperation>(); // 12. AElementwiseOperation
oss << ","
<< detail::elementwise_op_name<BElementwiseOperation>(); // 13. BElementwiseOperation
oss << ","
<< detail::elementwise_op_name<CDEElementwiseOperation>(); // 14.
// CDEElementwiseOperation
oss << ","
<< detail::conv_fwd_spec_name(
kConvForwardSpecialization); // 15. ConvForwardSpecialization
oss << "," << detail::gemm_spec_name(kGemmSpecialization); // 16. GemmSpec
oss << "," << kNumGemmKPrefetchStage; // 17. NumGemmKPrefetchStage
oss << "," << kBlockSize; // 18. BlockSize
oss << "," << kMPerBlock; // 19. MPerBlock
oss << "," << kNPerBlock; // 20. NPerBlock
oss << "," << kKPerBlock; // 21. KPerBlock
oss << "," << kK1; // 22. K1
oss << "," << kMPerWmma; // 23. MPerWmma
oss << "," << kNPerWmma; // 24. NPerWmma
oss << "," << kMRepeat; // 25. MRepeat
oss << "," << kNRepeat; // 26. NRepeat
oss << ","
<< detail::array_to_string(
kAThreadClusterLengths); // 27. ABlockTransferThreadClusterLengths
oss << ","
<< detail::array_to_string(
kAThreadClusterArrangeOrder); // 28. ABlockTransferThreadClusterArrangeOrder
oss << ","
<< detail::array_to_string(
kABlockTransferSrcAccessOrder); // 29. ABlockTransferSrcAccessOrder
oss << "," << kABlockTransferSrcVectorDim; // 30. ABlockTransferSrcVectorDim
oss << "," << kABlockTransferSrcScalarPerVector; // 31. ABlockTransferSrcScalarPerVector
oss << ","
<< kABlockTransferDstScalarPerVectorK1; // 32. ABlockTransferDstScalarPerVector_AK1
oss << "," << (kABlockLdsExtraM ? "true" : "false"); // 33. ABlockLdsExtraM
oss << ","
<< detail::array_to_string(
kBThreadClusterLengths); // 34. BBlockTransferThreadClusterLengths
oss << ","
<< detail::array_to_string(
kBThreadClusterArrangeOrder); // 35. BBlockTransferThreadClusterArrangeOrder
oss << ","
<< detail::array_to_string(
kBBlockTransferSrcAccessOrder); // 36. BBlockTransferSrcAccessOrder
oss << "," << kBBlockTransferSrcVectorDim; // 37. BBlockTransferSrcVectorDim
oss << "," << kBBlockTransferSrcScalarPerVector; // 38. BBlockTransferSrcScalarPerVector
oss << ","
<< kBBlockTransferDstScalarPerVectorK1; // 39. BBlockTransferDstScalarPerVector_BK1
oss << "," << (kBBlockLdsExtraN ? "true" : "false"); // 40. BBlockLdsExtraN
oss << "," << kCShuffleMRepeatPerShuffle; // 41. CShuffleMRepeatPerShuffle
oss << "," << kCShuffleNRepeatPerShuffle; // 42. CShuffleNRepeatPerShuffle
oss << ","
<< detail::array_to_string(
kCDEThreadClusterLengths); // 43. CDEShuffleBlockTransferClusterLengths
oss << ","
<< kCDEBlockTransferScalarPerVector; // 44.
// CDEShuffleBlockTransferScalarPerVector_NPerBlock
oss << "," << detail::loop_scheduler_name(kLoopScheduler); // 45. LoopSched
oss << "," << detail::pipeline_version_name(kPipelineVersion); // 46. PipelineVer
oss << ">";
return oss.str();
}
};
} // namespace ck_tile::reflect

View File

@@ -32,8 +32,8 @@ template <ck::index_t NDimSpatial,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CDEElementwiseOperation,
ConvolutionForwardSpecialization ConvForwardSpecialization,
GemmSpecialization GemmSpec,
ck::tensor_operation::device::ConvolutionForwardSpecialization ConvForwardSpecialization,
ck::tensor_operation::device::GemmSpecialization GemmSpec,
ck::index_t NumGemmKPrefetchStage,
ck::index_t BlockSize,
ck::index_t MPerBlock,
@@ -65,7 +65,7 @@ template <ck::index_t NDimSpatial,
ck::index_t CDEBlockTransferScalarPerVector_NPerBlock,
typename AComputeDataType,
typename BComputeDataType,
LoopScheduler LoopSched>
ck::LoopScheduler LoopSched>
struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor;
} // namespace ck::tensor_operation::device

View File

@@ -16,6 +16,7 @@
#include <ck/utility/sequence.hpp>
#include <ck/utility/blkgemmpipe_scheduler.hpp>
#include <ck/utility/loop_scheduler.hpp>
#include <ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp>
#include <ck/tensor_operation/gpu/device/tensor_layout.hpp>
#include <ck_tile/ops/common/tensor_layout.hpp>
#include <ck/tensor_operation/gpu/element/element_wise_operation.hpp>
@@ -161,6 +162,19 @@ constexpr std::string_view pipeline_version_name(ck::BlockGemmPipelineVersion ve
}
}
// Convert PipelineVersion enum to string (for Wmma kernels)
constexpr std::string_view pipeline_version_name(ck::PipelineVersion ver)
{
using enum ck::PipelineVersion;
switch(ver)
{
case v1: return "v1";
case v2: return "v2";
case v4: return "v4";
case weight_only: return "weight_only";
}
}
// Convert LoopScheduler enum to string
constexpr std::string_view loop_scheduler_name(ck::LoopScheduler sched)
{
@@ -322,4 +336,24 @@ constexpr std::string tuple_name()
}(static_cast<T*>(nullptr));
}
// Concept to check if a type is a ck::Tuple
template <typename T>
concept IsCkTuple =
requires { []<typename... Ts>(ck::Tuple<Ts...>*) {}(static_cast<T*>(nullptr)); };
// Deduces whether to use tuple_name or type_name
// Handles both scalar data types and ck::Tuple types
template <typename T>
constexpr std::string type_or_type_tuple_name()
{
if constexpr(IsCkTuple<T>)
{
return tuple_name<T>();
}
else
{
return std::string(type_name<T>());
}
}
} // namespace ck_tile::reflect::detail

View File

@@ -28,7 +28,9 @@ add_ck_builder_test(test_ckb_inline_diff test_inline_diff.cpp)
add_ck_builder_test(test_ckb_get_instance_string
test_get_instance_string_fwd_grp_conv_v3.cpp
test_get_instance_string_fwd_grp_conv.cpp
test_get_instance_string_fwd_grp_conv_large_tensor.cpp)
test_get_instance_string_fwd_grp_conv_large_tensor.cpp
test_get_instance_string_fwd_grp_conv_wmma.cpp
test_get_instance_string_fwd_grp_conv_dl.cpp)
# Testing the fwd convolution builder requires kernel compilation.
# To enable parallel compilation, the individual tests are split into separate files.

View File

@@ -9,12 +9,21 @@
#include <ck_tile/builder/reflect/instance_traits_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp>
#include <ck_tile/builder/reflect/instance_traits_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp>
#include <ck_tile/builder/reflect/instance_traits_device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp>
#include <ck_tile/builder/reflect/instance_traits_device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp>
#include <ck_tile/builder/reflect/instance_traits_device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp>
namespace {
using ::testing::ElementsAre;
TEST(InstanceTraitsTest, ConvFwdInstanceTraitsExtraction)
// NOTE: The V3ExtractsAllFieldsCorrectly test below performs detailed field extraction testing
// for the V3 variant as a reference implementation. For new InstanceTraits specializations,
// only the instance_string() functionality needs to be tested. Each new specialization should have:
// 1. A test using instance_string<T>() directly (in this file)
// 2. A test using GetInstanceString() through base class pointer (in separate
// test_get_instance_string_*.cpp file) This prevents test duplication while ensuring both access
// methods work correctly.
TEST(InstanceTraits, V3ExtractsAllFieldsCorrectly)
{
// Define a concrete instance type with specific template parameters
using DeviceInstance =
@@ -70,7 +79,8 @@ TEST(InstanceTraitsTest, ConvFwdInstanceTraitsExtraction)
ck::BlockGemmPipelineScheduler::Intrawave, // BlkGemmPipeSched
ck::BlockGemmPipelineVersion::v1, // BlkGemmPipelineVer
ck::half_t, // AComputeDataType
ck::half_t>; // BComputeDataType
ck::half_t, // BComputeDataType
false>;
// Use InstanceTraits to extract compile-time information
using Traits = ck_tile::reflect::InstanceTraits<DeviceInstance>;
@@ -155,7 +165,7 @@ TEST(InstanceTraitsTest, ConvFwdInstanceTraitsExtraction)
ck::tensor_operation::element_wise::PassThrough>::value));
}
TEST(InstanceTraitsTest, V3InstanceStringGeneration)
TEST(InstanceTraits, V3InstanceStringReturnsCorrectFormat)
{
// Define a concrete instance type with specific template parameters
using DeviceInstance =
@@ -211,7 +221,8 @@ TEST(InstanceTraitsTest, V3InstanceStringGeneration)
ck::BlockGemmPipelineScheduler::Intrawave, // BlkGemmPipeSched
ck::BlockGemmPipelineVersion::v1, // BlkGemmPipelineVer
ck::half_t, // AComputeDataType
ck::half_t>; // BComputeDataType
ck::half_t, // BComputeDataType
false>; // DirectLoad
std::string instance_str = ck_tile::reflect::instance_string<DeviceInstance>();
@@ -263,12 +274,13 @@ TEST(InstanceTraitsTest, V3InstanceStringGeneration)
",Intrawave" // BlkGemmPipeSched
",v1" // BlkGemmPipelineVer
",fp16" // AComputeDataType
",fp16>"; // BComputeDataType
",fp16" // BComputeDataType
",false>"; // DirectLoad
EXPECT_EQ(instance_str, expected_str);
}
TEST(InstanceTraitsTest, BaseInstanceStringGeneration)
TEST(InstanceTraits, BaseInstanceStringReturnsCorrectFormat)
{
using DeviceInstance =
ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<
@@ -382,7 +394,7 @@ TEST(InstanceTraitsTest, BaseInstanceStringGeneration)
EXPECT_EQ(instance_str, expected_str);
}
TEST(InstanceTraitsTest, LargeTensorInstanceStringGeneration)
TEST(InstanceTraits, LargeTensorInstanceStringReturnsCorrectFormat)
{
using DeviceInstance =
ck::tensor_operation::device::DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor<
@@ -497,4 +509,215 @@ TEST(InstanceTraitsTest, LargeTensorInstanceStringGeneration)
EXPECT_EQ(instance_str, expected_str);
}
TEST(InstanceTraits, WmmaInstanceStringReturnsCorrectFormat)
{
using DeviceInstance =
ck::tensor_operation::device::DeviceGroupedConvFwdMultipleD_Wmma_CShuffle<
2, // NDimSpatial
ck::tensor_layout::convolution::GNHWC, // ALayout
ck::tensor_layout::convolution::GKYXC, // BLayout
ck::Tuple<>, // DsLayout
ck::tensor_layout::convolution::GNHWK, // ELayout
ck::half_t, // ADataType
ck::half_t, // BDataType
float, // AccDataType
ck::half_t, // CShuffleDataType
ck::Tuple<>, // DsDataType
ck::half_t, // EDataType
ck::tensor_operation::element_wise::PassThrough, // AElementwiseOperation
ck::tensor_operation::element_wise::PassThrough, // BElementwiseOperation
ck::tensor_operation::element_wise::PassThrough, // CDEElementwiseOperation
ck::tensor_operation::device::ConvolutionForwardSpecialization::
Default, // ConvForwardSpecialization
ck::tensor_operation::device::GemmSpecialization::MNKPadding, // GemmSpec
1, // NumGemmKPrefetchStage
128, // BlockSize
64, // MPerBlock
64, // NPerBlock
32, // KPerBlock
8, // K1
16, // MPerWmma
16, // NPerWmma
2, // MRepeat
2, // NRepeat
ck::Sequence<4, 32, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1
ck::Sequence<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder
ck::Sequence<1, 0, 2>, // ABlockTransferSrcAccessOrder
2, // ABlockTransferSrcVectorDim
1, // ABlockTransferSrcScalarPerVector
8, // ABlockTransferDstScalarPerVector_AK1
1, // ABlockLdsExtraM
ck::Sequence<4, 32, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1
ck::Sequence<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder
ck::Sequence<1, 0, 2>, // BBlockTransferSrcAccessOrder
2, // BBlockTransferSrcVectorDim
1, // BBlockTransferSrcScalarPerVector
8, // BBlockTransferDstScalarPerVector_BK1
1, // BBlockLdsExtraN
1, // CShuffleMRepeatPerShuffle
1, // CShuffleNRepeatPerShuffle
ck::Sequence<1,
32,
1,
4>, // CDEShuffleBlockTransferClusterLengths
1, // CDEShuffleBlockTransferScalarPerVector_NPerBlock
ck::LoopScheduler::Default, // LoopSched
ck::PipelineVersion::v1>; // PipelineVer
// Generate instance string
std::string instance_str = ck_tile::reflect::instance_string<DeviceInstance>();
// Expected string with all 46 template parameters
std::string expected_str = "DeviceGroupedConvFwdMultipleD_Wmma_CShuffle"
"<2" // NDimSpatial
",GNHWC" // ALayout
",GKYXC" // BLayout
",EmptyTuple" // DsLayout
",GNHWK" // ELayout
",fp16" // ADataType
",fp16" // BDataType
",fp32" // AccDataType
",fp16" // CShuffleDataType
",EmptyTuple" // DsDataType
",fp16" // EDataType
",PassThrough" // AElementwiseOperation
",PassThrough" // BElementwiseOperation
",PassThrough" // CDEElementwiseOperation
",Default" // ConvForwardSpecialization
",MNKPadding" // GemmSpec
",1" // NumGemmKPrefetchStage
",128" // BlockSize
",64" // MPerBlock
",64" // NPerBlock
",32" // KPerBlock
",8" // K1
",16" // MPerWmma
",16" // NPerWmma
",2" // MRepeat
",2" // NRepeat
",Seq(4,32,1)" // ABlockTransferThreadClusterLengths
",Seq(1,0,2)" // ABlockTransferThreadClusterArrangeOrder
",Seq(1,0,2)" // ABlockTransferSrcAccessOrder
",2" // ABlockTransferSrcVectorDim
",1" // ABlockTransferSrcScalarPerVector
",8" // ABlockTransferDstScalarPerVector_AK1
",true" // ABlockLdsExtraM
",Seq(4,32,1)" // BBlockTransferThreadClusterLengths
",Seq(1,0,2)" // BBlockTransferThreadClusterArrangeOrder
",Seq(1,0,2)" // BBlockTransferSrcAccessOrder
",2" // BBlockTransferSrcVectorDim
",1" // BBlockTransferSrcScalarPerVector
",8" // BBlockTransferDstScalarPerVector_BK1
",true" // BBlockLdsExtraN
",1" // CShuffleMRepeatPerShuffle
",1" // CShuffleNRepeatPerShuffle
",Seq(1,32,1,4)" // CDEShuffleBlockTransferClusterLengths
",1" // CDEShuffleBlockTransferScalarPerVector_NPerBlock
",Default" // LoopSched
",v1>"; // PipelineVer
// Verify the generated string matches exactly
EXPECT_EQ(instance_str, expected_str);
}
TEST(InstanceTraits, DlInstanceStringReturnsCorrectFormat)
{
using DeviceInstance =
ck::tensor_operation::device::DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK<
2, // NDimSpatial
ck::half_t, // ADataType
ck::half_t, // BDataType
ck::Tuple<>, // DsDataType
ck::half_t, // EDataType
float, // AccDataType
ck::tensor_layout::convolution::GNHWC, // ALayout
ck::tensor_layout::convolution::GKYXC, // BLayout
ck::Tuple<>, // DsLayout
ck::tensor_layout::convolution::GNHWK, // ELayout
ck::tensor_operation::element_wise::PassThrough, // AElementwiseOperation
ck::tensor_operation::element_wise::PassThrough, // BElementwiseOperation
ck::tensor_operation::element_wise::PassThrough, // CDEElementwiseOperation
ck::tensor_operation::device::ConvolutionForwardSpecialization::
Default, // ConvForwardSpecialization
ck::tensor_operation::device::GemmSpecialization::MNKPadding, // GemmSpec
8, // BlockSize
16, // MPerBlock
4, // NPerBlock
2, // K0PerBlock
1, // K1
1, // M1PerThread
2, // N1PerThread
1, // KPerThread
ck::Sequence<4, 2>, // M1N1ThreadClusterM1Xs
ck::Sequence<1, 1>, // M1N1ThreadClusterN1Xs
ck::Sequence<2, 1, 2, 1>, // ABlockTransferThreadSliceLengths_K0_M0_M1_K1
ck::Sequence<1, 1, 8, 1>, // ABlockTransferThreadClusterLengths_K0_M0_M1_K1
ck::Sequence<1, 2, 0, 3>, // ABlockTransferThreadClusterArrangeOrder
ck::Sequence<1, 2, 0, 3>, // ABlockTransferSrcAccessOrder
ck::Sequence<1, 1, 1, 1>, // ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1
ck::Sequence<1, 2, 0, 3>, // ABlockTransferSrcVectorTensorContiguousDimOrder
ck::Sequence<1, 1, 1, 1>, // ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1
ck::Sequence<1, 1, 1, 1>, // BBlockTransferThreadSliceLengths_K0_N0_N1_K1
ck::Sequence<2, 1, 4, 1>, // BBlockTransferThreadClusterLengths_K0_N0_N1_K1
ck::Sequence<1, 2, 0, 3>, // BBlockTransferThreadClusterArrangeOrder
ck::Sequence<1, 2, 0, 3>, // BBlockTransferSrcAccessOrder
ck::Sequence<1, 1, 1, 1>, // BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1
ck::Sequence<1, 2, 0, 3>, // BBlockTransferSrcVectorTensorContiguousDimOrder
ck::Sequence<1, 1, 1, 1>, // BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1
ck::Sequence<0, 1, 2, 3, 4, 5>, // CThreadTransferSrcDstAccessOrder
5, // CThreadTransferSrcDstVectorDim
1>; // CThreadTransferDstScalarPerVector
// Generate instance string
std::string instance_str = ck_tile::reflect::instance_string<DeviceInstance>();
// Expected string with all 42 template parameters
std::string expected_str = "DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK"
"<2" // NDimSpatial
",fp16" // ADataType
",fp16" // BDataType
",EmptyTuple" // DsDataType
",fp16" // EDataType
",fp32" // AccDataType
",GNHWC" // ALayout
",GKYXC" // BLayout
",EmptyTuple" // DsLayout
",GNHWK" // ELayout
",PassThrough" // AElementwiseOperation
",PassThrough" // BElementwiseOperation
",PassThrough" // CDEElementwiseOperation
",Default" // ConvForwardSpecialization
",MNKPadding" // GemmSpec
",8" // BlockSize
",16" // MPerBlock
",4" // NPerBlock
",2" // K0PerBlock
",1" // K1
",1" // M1PerThread
",2" // N1PerThread
",1" // KPerThread
",Seq(4,2)" // M1N1ThreadClusterM1Xs
",Seq(1,1)" // M1N1ThreadClusterN1Xs
",Seq(2,1,2,1)" // ABlockTransferThreadSliceLengths_K0_M0_M1_K1
",Seq(1,1,8,1)" // ABlockTransferThreadClusterLengths_K0_M0_M1_K1
",Seq(1,2,0,3)" // ABlockTransferThreadClusterArrangeOrder
",Seq(1,2,0,3)" // ABlockTransferSrcAccessOrder
",Seq(1,1,1,1)" // ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1
",Seq(1,2,0,3)" // ABlockTransferSrcVectorTensorContiguousDimOrder
",Seq(1,1,1,1)" // ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1
",Seq(1,1,1,1)" // BBlockTransferThreadSliceLengths_K0_N0_N1_K1
",Seq(2,1,4,1)" // BBlockTransferThreadClusterLengths_K0_N0_N1_K1
",Seq(1,2,0,3)" // BBlockTransferThreadClusterArrangeOrder
",Seq(1,2,0,3)" // BBlockTransferSrcAccessOrder
",Seq(1,1,1,1)" // BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1
",Seq(1,2,0,3)" // BBlockTransferSrcVectorTensorContiguousDimOrder
",Seq(1,1,1,1)" // BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1
",Seq(0,1,2,3,4,5)" // CThreadTransferSrcDstAccessOrder
",5" // CThreadTransferSrcDstVectorDim
",1>"; // CThreadTransferDstScalarPerVector
// Verify the generated string matches exactly
EXPECT_EQ(instance_str, expected_str);
}
} // anonymous namespace

View File

@@ -3,7 +3,7 @@
#include <gtest/gtest.h>
#include <ck_tile/builder/reflect/instance_traits.hpp>
#include <ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_abd.hpp>
#include <ck/tensor_operation/gpu/device/device_base.hpp>
#include <ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_instance.hpp>
// Test GetInstanceString through base class pointer for non-V3 variant
@@ -22,22 +22,8 @@ TEST(GetInstanceString, ReturnsStringForFwdGrpConvInstance)
// Get the first instance from the tuple
using DeviceInstance = typename std::tuple_element<0, InstanceTuple>::type;
// Define the base class type using DeviceGroupedConvFwdMultipleABD
using BaseClass = ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD<
2, // NDimSpatial
ck::tensor_operation::device::instance::GNHWC, // ALayout
ck::tensor_operation::device::instance::GKYXC, // BLayout
ck::tensor_operation::device::instance::Empty_Tuple, // DsLayout
ck::tensor_operation::device::instance::GNHWK, // ELayout
ck::half_t, // ADataType
ck::half_t, // BDataType
ck::Tuple<>, // DsDataType
ck::half_t, // EDataType
ck::tensor_operation::element_wise::PassThrough, // AElementwiseOperation
ck::tensor_operation::element_wise::PassThrough, // BElementwiseOperation
ck::tensor_operation::element_wise::PassThrough, // CDEElementwiseOperation
ck::half_t, // AComputeType
ck::half_t>; // BComputeType
// Define the base class type using the most general operator base
using BaseClass = ck::tensor_operation::device::BaseOperator;
// Create an instance of the derived class
DeviceInstance device_instance;

View File

@@ -0,0 +1,85 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include <gtest/gtest.h>
#include <ck_tile/builder/reflect/instance_traits.hpp>
#include <ck/tensor_operation/gpu/device/device_base.hpp>
#include <ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_dl_instance.hpp>
// Test GetInstanceString through base class pointer for DL variant
TEST(GetInstanceString, ReturnsStringForFwdGrpConvDlInstance)
{
// Use the template helper to get a working instance configuration
using InstanceTuple =
ck::tensor_operation::device::instance::device_grouped_conv2d_fwd_dl_f16_instances<
ck::tensor_operation::device::instance::GNHWC, // InLayout
ck::tensor_operation::device::instance::GKYXC, // WeiLayout
ck::tensor_operation::device::instance::Empty_Tuple, // DsLayout
ck::tensor_operation::device::instance::GNHWK, // OutLayout
ck::Tuple<>, // DsDatatype
ck::tensor_operation::element_wise::PassThrough, // CDEElementOp
ck::tensor_operation::device::instance::ConvFwdDefault>; // ConvSpec
// Get the first instance from the tuple
using DeviceInstance = typename std::tuple_element<0, InstanceTuple>::type;
// Define the base class type using the most general operator base
using BaseClass = ck::tensor_operation::device::BaseOperator;
// Create an instance of the derived class
DeviceInstance device_instance;
// Get a pointer to the base class
BaseClass* base_ptr = &device_instance;
// Call GetInstanceString through the base class pointer
std::string instance_str = base_ptr->GetInstanceString();
// Expected complete instance string based on the first instance from
// device_grouped_conv2d_fwd_dl_f16_instances
std::string expected_str = "DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK"
"<2" // NDimSpatial
",fp16" // ADataType
",fp16" // BDataType
",EmptyTuple" // DsDataType
",fp16" // EDataType
",fp32" // AccDataType
",GNHWC" // ALayout
",GKYXC" // BLayout
",EmptyTuple" // DsLayout
",GNHWK" // ELayout
",PassThrough" // AElementwiseOperation
",PassThrough" // BElementwiseOperation
",PassThrough" // CDEElementwiseOperation
",Default" // ConvForwardSpecialization
",MNKPadding" // GemmSpec
",8" // BlockSize
",16" // MPerBlock
",4" // NPerBlock
",2" // K0PerBlock
",1" // K1
",1" // M1PerThread
",2" // N1PerThread
",1" // KPerThread
",Seq(4,2)" // M1N1ThreadClusterM1Xs
",Seq(1,1)" // M1N1ThreadClusterN1Xs
",Seq(2,1,2,1)" // ABlockTransferThreadSliceLengths_K0_M0_M1_K1
",Seq(1,1,8,1)" // ABlockTransferThreadClusterLengths_K0_M0_M1_K1
",Seq(1,2,0,3)" // ABlockTransferThreadClusterArrangeOrder
",Seq(1,2,0,3)" // ABlockTransferSrcAccessOrder
",Seq(1,1,1,1)" // ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1
",Seq(1,2,0,3)" // ABlockTransferSrcVectorTensorContiguousDimOrder
",Seq(1,1,1,1)" // ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1
",Seq(1,1,1,1)" // BBlockTransferThreadSliceLengths_K0_N0_N1_K1
",Seq(2,1,4,1)" // BBlockTransferThreadClusterLengths_K0_N0_N1_K1
",Seq(1,2,0,3)" // BBlockTransferThreadClusterArrangeOrder
",Seq(1,2,0,3)" // BBlockTransferSrcAccessOrder
",Seq(1,1,1,1)" // BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1
",Seq(1,2,0,3)" // BBlockTransferSrcVectorTensorContiguousDimOrder
",Seq(1,1,1,1)" // BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1
",Seq(0,1,2,3,4,5)" // CThreadTransferSrcDstAccessOrder
",5" // CThreadTransferSrcDstVectorDim
",1>"; // CThreadTransferDstScalarPerVector
EXPECT_EQ(instance_str, expected_str);
}

View File

@@ -3,7 +3,7 @@
#include <gtest/gtest.h>
#include <ck_tile/builder/reflect/instance_traits.hpp>
#include <ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_abd.hpp>
#include <ck/tensor_operation/gpu/device/device_base.hpp>
#include <ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_large_tensor_instance.hpp>
// Test GetInstanceString through base class pointer for large tensor variant
@@ -22,22 +22,8 @@ TEST(GetInstanceString, ReturnsStringForFwdGrpConvLargeTensorInstance)
// Get the first instance from the tuple
using DeviceInstance = typename std::tuple_element<0, InstanceTuple>::type;
// Define the base class type using DeviceGroupedConvFwdMultipleABD
using BaseClass = ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD<
2, // NDimSpatial
ck::tensor_operation::device::instance::GNHWC, // ALayout
ck::tensor_operation::device::instance::GKYXC, // BLayout
ck::tensor_operation::device::instance::Empty_Tuple, // DsLayout
ck::tensor_operation::device::instance::GNHWK, // ELayout
ck::half_t, // ADataType
ck::half_t, // BDataType
ck::Tuple<>, // DsDataType
ck::half_t, // EDataType
ck::tensor_operation::element_wise::PassThrough, // AElementwiseOperation
ck::tensor_operation::element_wise::PassThrough, // BElementwiseOperation
ck::tensor_operation::element_wise::PassThrough, // CDEElementwiseOperation
ck::half_t, // AComputeType
ck::half_t>; // BComputeType
// Define the base class type using the most general operator base
using BaseClass = ck::tensor_operation::device::BaseOperator;
// Create an instance of the derived class
DeviceInstance device_instance;

View File

@@ -3,7 +3,7 @@
#include <gtest/gtest.h>
#include <ck_tile/builder/reflect/instance_traits.hpp>
#include <ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_abd.hpp>
#include <ck/tensor_operation/gpu/device/device_base.hpp>
#include <ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_comp_instance.hpp>
// Test GetInstanceString through base class pointer for V3 variant
@@ -22,22 +22,8 @@ TEST(GetInstanceString, ReturnsStringForFwdGrpConvV3Instance)
// Get the first instance from the tuple
using DeviceInstance = typename std::tuple_element<0, InstanceTuple>::type;
// Define the base class type using DeviceGroupedConvFwdMultipleABD
using BaseClass = ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD<
2, // NDimSpatial
ck::tensor_operation::device::instance::GNHWC, // ALayout
ck::tensor_operation::device::instance::GKYXC, // BLayout
ck::tensor_operation::device::instance::Empty_Tuple, // DsLayout
ck::tensor_operation::device::instance::GNHWK, // ELayout
ck::half_t, // ADataType
ck::half_t, // BDataType
ck::Tuple<>, // DsDataType
ck::half_t, // EDataType
ck::tensor_operation::element_wise::PassThrough, // AElementwiseOperation
ck::tensor_operation::element_wise::PassThrough, // BElementwiseOperation
ck::tensor_operation::element_wise::PassThrough, // CDEElementwiseOperation
ck::half_t, // AComputeType
ck::half_t>; // BComputeType
// Define the base class type using the most general operator base
using BaseClass = ck::tensor_operation::device::BaseOperator;
// Create an instance of the derived class
DeviceInstance device_instance;
@@ -99,6 +85,7 @@ TEST(GetInstanceString, ReturnsStringForFwdGrpConvV3Instance)
",Intrawave" // BlkGemmPipeSched
",v4" // BlkGemmPipelineVer
",fp16" // AComputeDataType
",fp16>"; // BComputeDataType
",fp16" // BComputeDataType
",false>"; // DirectLoad
EXPECT_EQ(instance_str, expected_str);
}

View File

@@ -0,0 +1,90 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include <gtest/gtest.h>
#include <ck_tile/builder/reflect/instance_traits.hpp>
#include <ck/tensor_operation/gpu/device/device_base.hpp>
#include <ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_instance.hpp>
// Test GetInstanceString through base class pointer for Wmma variant
TEST(GetInstanceString, ReturnsStringForFwdGrpConvWmmaInstance)
{
// Use the template helper to get a working instance configuration
using InstanceTuple =
ck::tensor_operation::device::instance::device_grouped_conv_fwd_wmma_f16_instances<
2, // NDimSpatial
ck::tensor_operation::device::instance::GNHWC, // ALayout
ck::tensor_operation::device::instance::GKYXC, // BLayout
ck::tensor_operation::device::instance::Empty_Tuple, // DsLayout
ck::tensor_operation::device::instance::GNHWK, // ELayout
ck::Tuple<>, // DsDatatype
ck::tensor_operation::element_wise::PassThrough, // CDEElementOp
ck::tensor_operation::device::instance::ConvFwdDefault>; // ConvSpec
// Get the first instance from the tuple
using DeviceInstance = typename std::tuple_element<0, InstanceTuple>::type;
// Define the base class type using the most general operator base
using BaseClass = ck::tensor_operation::device::BaseOperator;
// Create an instance of the derived class
DeviceInstance device_instance;
// Get a pointer to the base class
BaseClass* base_ptr = &device_instance;
// Call GetInstanceString through the base class pointer
std::string instance_str = base_ptr->GetInstanceString();
// Expected complete instance string based on the first instance from
// device_grouped_conv_fwd_wmma_f16_instances
std::string expected_str = "DeviceGroupedConvFwdMultipleD_Wmma_CShuffle"
"<2" // NDimSpatial
",GNHWC" // ALayout
",GKYXC" // BLayout
",EmptyTuple" // DsLayout
",GNHWK" // ELayout
",fp16" // ADataType
",fp16" // BDataType
",fp32" // AccDataType
",fp16" // CShuffleDataType
",EmptyTuple" // DsDataType
",fp16" // EDataType
",PassThrough" // AElementwiseOperation
",PassThrough" // BElementwiseOperation
",PassThrough" // CDEElementwiseOperation
",Default" // ConvForwardSpecialization
",MNKPadding" // GemmSpec
",1" // NumGemmKPrefetchStage
",128" // BlockSize
",64" // MPerBlock
",64" // NPerBlock
",32" // KPerBlock
",8" // K1
",16" // MPerWmma
",16" // NPerWmma
",2" // MRepeat
",2" // NRepeat
",Seq(4,32,1)" // ABlockTransferThreadClusterLengths
",Seq(1,0,2)" // ABlockTransferThreadClusterArrangeOrder
",Seq(1,0,2)" // ABlockTransferSrcAccessOrder
",2" // ABlockTransferSrcVectorDim
",1" // ABlockTransferSrcScalarPerVector
",8" // ABlockTransferDstScalarPerVector_AK1
",true" // ABlockLdsExtraM
",Seq(4,32,1)" // BBlockTransferThreadClusterLengths
",Seq(1,0,2)" // BBlockTransferThreadClusterArrangeOrder
",Seq(1,0,2)" // BBlockTransferSrcAccessOrder
",2" // BBlockTransferSrcVectorDim
",1" // BBlockTransferSrcScalarPerVector
",8" // BBlockTransferDstScalarPerVector_BK1
",true" // BBlockLdsExtraN
",1" // CShuffleMRepeatPerShuffle
",1" // CShuffleNRepeatPerShuffle
",Seq(1,32,1,4)" // CDEShuffleBlockTransferClusterLengths
",1" // CDEShuffleBlockTransferScalarPerVector_NPerBlock
",Default" // LoopSched
",v1>"; // PipelineVer
EXPECT_EQ(instance_str, expected_str);
}

View File

@@ -202,8 +202,8 @@ TEST(InstanceTraitsUtil, PipelineVersionNameReturnsCorrectStrings)
TEST(InstanceTraitsUtil, LoopSchedulerNameReturnsCorrectStrings)
{
using enum ck::LoopScheduler;
EXPECT_THAT(std::vector<std::string_view> names = {loop_scheduler_name(Default),
loop_scheduler_name(Interwave)},
EXPECT_THAT((std::vector<std::string_view>{loop_scheduler_name(Default),
loop_scheduler_name(Interwave)}),
ElementsAre("Default", "Interwave"));
}
@@ -267,5 +267,15 @@ TEST(InstanceTraitsUtil, SequenceNameReturnsSeqStringForMultipleValueSequence)
EXPECT_EQ((sequence_name<ck::Sequence<256, 128, 64, 32, 16>>()), "Seq(256,128,64,32,16)");
}
TEST(InstanceTraitsUtil, TypeOrTypeTupleNameReturnsCorrectStringForScalarDataType)
{
EXPECT_EQ(type_or_type_tuple_name<float>(), "fp32");
}
TEST(InstanceTraitsUtil, TypeOrTypeTupleNameReturnsCorrectStringForTupleOfDataTypes)
{
EXPECT_EQ((type_or_type_tuple_name<ck::Tuple<ck::half_t, float>>()), "Tuple(fp16,fp32)");
}
} // namespace
} // namespace ck_tile::reflect::detail