mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 06:49:15 +00:00
Add instance traits for two more grouped forward convolutions (#3112)
This commit is contained in:
@@ -0,0 +1,350 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
// InstanceTraits specialization for DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
|
||||
//
|
||||
// CRITICAL MAINTENANCE NOTE:
|
||||
// This InstanceTraits file MUST be kept strictly in sync with the device implementation header:
|
||||
// ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp
|
||||
// "In sync" means that the template parameter order, names, and types in the declaration below
|
||||
// MUST EXACTLY MATCH those in the device implementation. If these diverge, you may encounter
|
||||
// compilation errors, subtle template instantiation mismatches, or silent runtime bugs that are
|
||||
// difficult to diagnose. Always update both files together and review changes carefully.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "instance_traits.hpp"
|
||||
|
||||
// Forward declaration to avoid circular dependency.
|
||||
namespace ck::tensor_operation::device {
|
||||
|
||||
template <ck::index_t NDimSpatial,
|
||||
typename ALayout,
|
||||
typename BLayout,
|
||||
typename DsLayout,
|
||||
typename ELayout,
|
||||
typename ADataType,
|
||||
typename BDataType,
|
||||
typename AccDataType,
|
||||
typename CShuffleDataType,
|
||||
typename DsDataType,
|
||||
typename EDataType,
|
||||
typename AElementwiseOperation,
|
||||
typename BElementwiseOperation,
|
||||
typename CDEElementwiseOperation,
|
||||
ConvolutionForwardSpecialization ConvForwardSpecialization,
|
||||
GemmSpecialization GemmSpec,
|
||||
ck::index_t NumGemmKPrefetchStage,
|
||||
ck::index_t BlockSize,
|
||||
ck::index_t MPerBlock,
|
||||
ck::index_t NPerBlock,
|
||||
ck::index_t KPerBlock,
|
||||
ck::index_t AK1,
|
||||
ck::index_t BK1,
|
||||
ck::index_t MPerXDL,
|
||||
ck::index_t NPerXDL,
|
||||
ck::index_t MXdlPerWave,
|
||||
ck::index_t NXdlPerWave,
|
||||
typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
|
||||
typename ABlockTransferThreadClusterArrangeOrder,
|
||||
typename ABlockTransferSrcAccessOrder,
|
||||
ck::index_t ABlockTransferSrcVectorDim,
|
||||
ck::index_t ABlockTransferSrcScalarPerVector,
|
||||
ck::index_t ABlockTransferDstScalarPerVector_AK1,
|
||||
ck::index_t ABlockLdsExtraM,
|
||||
typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
|
||||
typename BBlockTransferThreadClusterArrangeOrder,
|
||||
typename BBlockTransferSrcAccessOrder,
|
||||
ck::index_t BBlockTransferSrcVectorDim,
|
||||
ck::index_t BBlockTransferSrcScalarPerVector,
|
||||
ck::index_t BBlockTransferDstScalarPerVector_BK1,
|
||||
ck::index_t BBlockLdsExtraN,
|
||||
ck::index_t CShuffleMXdlPerWavePerShuffle,
|
||||
ck::index_t CShuffleNXdlPerWavePerShuffle,
|
||||
typename CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
ck::index_t CDEBlockTransferScalarPerVector_NPerBlock,
|
||||
typename AComputeDataType,
|
||||
typename BComputeDataType,
|
||||
LoopScheduler LoopSched,
|
||||
ck::index_t NumGroupsToMerge>
|
||||
struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle;
|
||||
|
||||
} // namespace ck::tensor_operation::device
|
||||
|
||||
namespace ck_tile::reflect {
|
||||
|
||||
// Specialization for DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
|
||||
template <ck::index_t NDimSpatial,
|
||||
typename ALayout_,
|
||||
typename BLayout_,
|
||||
typename DsLayout_,
|
||||
typename ELayout_,
|
||||
typename ADataType_,
|
||||
typename BDataType_,
|
||||
typename AccDataType_,
|
||||
typename CShuffleDataType_,
|
||||
typename DsDataType_,
|
||||
typename EDataType_,
|
||||
typename AElementwiseOperation_,
|
||||
typename BElementwiseOperation_,
|
||||
typename CDEElementwiseOperation_,
|
||||
ck::tensor_operation::device::ConvolutionForwardSpecialization ConvForwardSpecialization,
|
||||
ck::tensor_operation::device::GemmSpecialization GemmSpec,
|
||||
ck::index_t NumGemmKPrefetchStage,
|
||||
ck::index_t BlockSize,
|
||||
ck::index_t MPerBlock,
|
||||
ck::index_t NPerBlock,
|
||||
ck::index_t KPerBlock,
|
||||
ck::index_t AK1,
|
||||
ck::index_t BK1,
|
||||
ck::index_t MPerXDL,
|
||||
ck::index_t NPerXDL,
|
||||
ck::index_t MXdlPerWave,
|
||||
ck::index_t NXdlPerWave,
|
||||
typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
|
||||
typename ABlockTransferThreadClusterArrangeOrder,
|
||||
typename ABlockTransferSrcAccessOrder_,
|
||||
ck::index_t ABlockTransferSrcVectorDim,
|
||||
ck::index_t ABlockTransferSrcScalarPerVector,
|
||||
ck::index_t ABlockTransferDstScalarPerVector_AK1,
|
||||
ck::index_t ABlockLdsExtraM,
|
||||
typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
|
||||
typename BBlockTransferThreadClusterArrangeOrder,
|
||||
typename BBlockTransferSrcAccessOrder_,
|
||||
ck::index_t BBlockTransferSrcVectorDim,
|
||||
ck::index_t BBlockTransferSrcScalarPerVector,
|
||||
ck::index_t BBlockTransferDstScalarPerVector_BK1,
|
||||
ck::index_t BBlockLdsExtraN,
|
||||
ck::index_t CShuffleMXdlPerWavePerShuffle,
|
||||
ck::index_t CShuffleNXdlPerWavePerShuffle,
|
||||
typename CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
ck::index_t CDEBlockTransferScalarPerVector_NPerBlock,
|
||||
typename AComputeDataType_,
|
||||
typename BComputeDataType_,
|
||||
ck::LoopScheduler LoopSched,
|
||||
ck::index_t NumGroupsToMerge>
|
||||
struct InstanceTraits<ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<
|
||||
NDimSpatial,
|
||||
ALayout_,
|
||||
BLayout_,
|
||||
DsLayout_,
|
||||
ELayout_,
|
||||
ADataType_,
|
||||
BDataType_,
|
||||
AccDataType_,
|
||||
CShuffleDataType_,
|
||||
DsDataType_,
|
||||
EDataType_,
|
||||
AElementwiseOperation_,
|
||||
BElementwiseOperation_,
|
||||
CDEElementwiseOperation_,
|
||||
ConvForwardSpecialization,
|
||||
GemmSpec,
|
||||
NumGemmKPrefetchStage,
|
||||
BlockSize,
|
||||
MPerBlock,
|
||||
NPerBlock,
|
||||
KPerBlock,
|
||||
AK1,
|
||||
BK1,
|
||||
MPerXDL,
|
||||
NPerXDL,
|
||||
MXdlPerWave,
|
||||
NXdlPerWave,
|
||||
ABlockTransferThreadClusterLengths_AK0_M_AK1,
|
||||
ABlockTransferThreadClusterArrangeOrder,
|
||||
ABlockTransferSrcAccessOrder_,
|
||||
ABlockTransferSrcVectorDim,
|
||||
ABlockTransferSrcScalarPerVector,
|
||||
ABlockTransferDstScalarPerVector_AK1,
|
||||
ABlockLdsExtraM,
|
||||
BBlockTransferThreadClusterLengths_BK0_N_BK1,
|
||||
BBlockTransferThreadClusterArrangeOrder,
|
||||
BBlockTransferSrcAccessOrder_,
|
||||
BBlockTransferSrcVectorDim,
|
||||
BBlockTransferSrcScalarPerVector,
|
||||
BBlockTransferDstScalarPerVector_BK1,
|
||||
BBlockLdsExtraN,
|
||||
CShuffleMXdlPerWavePerShuffle,
|
||||
CShuffleNXdlPerWavePerShuffle,
|
||||
CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
CDEBlockTransferScalarPerVector_NPerBlock,
|
||||
AComputeDataType_,
|
||||
BComputeDataType_,
|
||||
LoopSched,
|
||||
NumGroupsToMerge>>
|
||||
{
|
||||
// Spatial dimension
|
||||
static constexpr int kSpatialDim = NDimSpatial;
|
||||
|
||||
// Layout types
|
||||
using ALayout = ALayout_;
|
||||
using BLayout = BLayout_;
|
||||
using DsLayout = DsLayout_;
|
||||
using ELayout = ELayout_;
|
||||
|
||||
// Data types
|
||||
using ADataType = ADataType_;
|
||||
using BDataType = BDataType_;
|
||||
using AccDataType = AccDataType_;
|
||||
using CShuffleDataType = CShuffleDataType_;
|
||||
using DsDataType = DsDataType_;
|
||||
using EDataType = EDataType_;
|
||||
|
||||
// Element-wise operations
|
||||
using AElementwiseOperation = AElementwiseOperation_;
|
||||
using BElementwiseOperation = BElementwiseOperation_;
|
||||
using CDEElementwiseOperation = CDEElementwiseOperation_;
|
||||
|
||||
// Specialization
|
||||
static constexpr ck::tensor_operation::device::ConvolutionForwardSpecialization
|
||||
kConvForwardSpecialization = ConvForwardSpecialization;
|
||||
static constexpr ck::tensor_operation::device::GemmSpecialization kGemmSpecialization =
|
||||
GemmSpec;
|
||||
|
||||
// Prefetch stage
|
||||
static constexpr int kNumGemmKPrefetchStage = NumGemmKPrefetchStage;
|
||||
|
||||
// Block configuration
|
||||
static constexpr int kBlockSize = BlockSize;
|
||||
static constexpr int kMPerBlock = MPerBlock;
|
||||
static constexpr int kNPerBlock = NPerBlock;
|
||||
static constexpr int kKPerBlock = KPerBlock;
|
||||
|
||||
// Tuning parameters
|
||||
static constexpr int kAK1 = AK1;
|
||||
static constexpr int kBK1 = BK1;
|
||||
static constexpr int kMPerXDL = MPerXDL;
|
||||
static constexpr int kNPerXDL = NPerXDL;
|
||||
static constexpr int kMXdlPerWave = MXdlPerWave;
|
||||
static constexpr int kNXdlPerWave = NXdlPerWave;
|
||||
|
||||
// A block transfer thread cluster dimensions (converted to std::array)
|
||||
static constexpr auto kAThreadClusterLengths =
|
||||
detail::SequenceToArray<ABlockTransferThreadClusterLengths_AK0_M_AK1>::value;
|
||||
static constexpr auto kAThreadClusterArrangeOrder =
|
||||
detail::SequenceToArray<ABlockTransferThreadClusterArrangeOrder>::value;
|
||||
static constexpr auto kABlockTransferSrcAccessOrder =
|
||||
detail::SequenceToArray<ABlockTransferSrcAccessOrder_>::value;
|
||||
static constexpr int kABlockTransferSrcVectorDim = ABlockTransferSrcVectorDim;
|
||||
static constexpr int kABlockTransferSrcScalarPerVector = ABlockTransferSrcScalarPerVector;
|
||||
static constexpr int kABlockTransferDstScalarPerVectorK1 = ABlockTransferDstScalarPerVector_AK1;
|
||||
static constexpr int kABlockLdsExtraM = ABlockLdsExtraM;
|
||||
|
||||
// B block transfer thread cluster dimensions (converted to std::array)
|
||||
static constexpr auto kBThreadClusterLengths =
|
||||
detail::SequenceToArray<BBlockTransferThreadClusterLengths_BK0_N_BK1>::value;
|
||||
static constexpr auto kBThreadClusterArrangeOrder =
|
||||
detail::SequenceToArray<BBlockTransferThreadClusterArrangeOrder>::value;
|
||||
static constexpr auto kBBlockTransferSrcAccessOrder =
|
||||
detail::SequenceToArray<BBlockTransferSrcAccessOrder_>::value;
|
||||
static constexpr int kBBlockTransferSrcVectorDim = BBlockTransferSrcVectorDim;
|
||||
static constexpr int kBBlockTransferSrcScalarPerVector = BBlockTransferSrcScalarPerVector;
|
||||
static constexpr int kBBlockTransferDstScalarPerVectorK1 = BBlockTransferDstScalarPerVector_BK1;
|
||||
static constexpr int kBBlockLdsExtraN = BBlockLdsExtraN;
|
||||
|
||||
// C shuffle parameters (converted to std::array)
|
||||
static constexpr int kCShuffleMXdlPerWavePerShuffle = CShuffleMXdlPerWavePerShuffle;
|
||||
static constexpr int kCShuffleNXdlPerWavePerShuffle = CShuffleNXdlPerWavePerShuffle;
|
||||
static constexpr auto kCThreadClusterLengths = detail::SequenceToArray<
|
||||
CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock>::value;
|
||||
static constexpr int kCBlockTransferScalarPerVector = CDEBlockTransferScalarPerVector_NPerBlock;
|
||||
|
||||
// Compute data types
|
||||
using AComputeDataType = AComputeDataType_;
|
||||
using BComputeDataType = BComputeDataType_;
|
||||
|
||||
// Loop scheduler
|
||||
static constexpr ck::LoopScheduler kLoopScheduler = LoopSched;
|
||||
|
||||
// Groups to merge
|
||||
static constexpr int kNumGroupsToMerge = NumGroupsToMerge;
|
||||
|
||||
// Static member function to generate instance string
|
||||
static std::string instance_string()
|
||||
{
|
||||
std::ostringstream oss;
|
||||
|
||||
// Kernel type name
|
||||
oss << "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle";
|
||||
|
||||
// Template parameters in exact order matching InstanceTraits member order
|
||||
oss << "<" << kSpatialDim; // 1. NDimSpatial
|
||||
oss << "," << detail::layout_name<ALayout>(); // 2. ALayout
|
||||
oss << "," << detail::layout_name<BLayout>(); // 3. BLayout
|
||||
oss << "," << detail::tuple_name<DsLayout>(); // 4. DsLayout
|
||||
oss << "," << detail::layout_name<ELayout>(); // 5. ELayout
|
||||
oss << "," << detail::type_name<ADataType>(); // 6. ADataType
|
||||
oss << "," << detail::type_name<BDataType>(); // 7. BDataType
|
||||
oss << "," << detail::type_name<AccDataType>(); // 8. AccDataType
|
||||
oss << "," << detail::type_name<CShuffleDataType>(); // 9. CShuffleDataType
|
||||
oss << "," << detail::tuple_name<DsDataType>(); // 10. DsDataType
|
||||
oss << "," << detail::type_name<EDataType>(); // 11. EDataType
|
||||
oss << ","
|
||||
<< detail::elementwise_op_name<AElementwiseOperation>(); // 12. AElementwiseOperation
|
||||
oss << ","
|
||||
<< detail::elementwise_op_name<BElementwiseOperation>(); // 13. BElementwiseOperation
|
||||
oss << ","
|
||||
<< detail::elementwise_op_name<CDEElementwiseOperation>(); // 14.
|
||||
// CDEElementwiseOperation
|
||||
oss << ","
|
||||
<< detail::conv_fwd_spec_name(
|
||||
kConvForwardSpecialization); // 15. ConvForwardSpecialization
|
||||
oss << "," << detail::gemm_spec_name(kGemmSpecialization); // 16. GemmSpec
|
||||
oss << "," << kNumGemmKPrefetchStage; // 17. NumGemmKPrefetchStage
|
||||
oss << "," << kBlockSize; // 18. BlockSize
|
||||
oss << "," << kMPerBlock; // 19. MPerBlock
|
||||
oss << "," << kNPerBlock; // 20. NPerBlock
|
||||
oss << "," << kKPerBlock; // 21. KPerBlock
|
||||
oss << "," << kAK1; // 22. AK1
|
||||
oss << "," << kBK1; // 23. BK1
|
||||
oss << "," << kMPerXDL; // 24. MPerXDL
|
||||
oss << "," << kNPerXDL; // 25. NPerXDL
|
||||
oss << "," << kMXdlPerWave; // 26. MXdlPerWave
|
||||
oss << "," << kNXdlPerWave; // 27. NXdlPerWave
|
||||
oss << ","
|
||||
<< detail::array_to_string(
|
||||
kAThreadClusterLengths); // 28. ABlockTransferThreadClusterLengths
|
||||
oss << ","
|
||||
<< detail::array_to_string(
|
||||
kAThreadClusterArrangeOrder); // 29. ABlockTransferThreadClusterArrangeOrder
|
||||
oss << ","
|
||||
<< detail::array_to_string(
|
||||
kABlockTransferSrcAccessOrder); // 30. ABlockTransferSrcAccessOrder
|
||||
oss << "," << kABlockTransferSrcVectorDim; // 31. ABlockTransferSrcVectorDim
|
||||
oss << "," << kABlockTransferSrcScalarPerVector; // 32. ABlockTransferSrcScalarPerVector
|
||||
oss << ","
|
||||
<< kABlockTransferDstScalarPerVectorK1; // 33. ABlockTransferDstScalarPerVector_AK1
|
||||
oss << "," << kABlockLdsExtraM; // 34. ABlockLdsExtraM
|
||||
oss << ","
|
||||
<< detail::array_to_string(
|
||||
kBThreadClusterLengths); // 35. BBlockTransferThreadClusterLengths
|
||||
oss << ","
|
||||
<< detail::array_to_string(
|
||||
kBThreadClusterArrangeOrder); // 36. BBlockTransferThreadClusterArrangeOrder
|
||||
oss << ","
|
||||
<< detail::array_to_string(
|
||||
kBBlockTransferSrcAccessOrder); // 37. BBlockTransferSrcAccessOrder
|
||||
oss << "," << kBBlockTransferSrcVectorDim; // 38. BBlockTransferSrcVectorDim
|
||||
oss << "," << kBBlockTransferSrcScalarPerVector; // 39. BBlockTransferSrcScalarPerVector
|
||||
oss << ","
|
||||
<< kBBlockTransferDstScalarPerVectorK1; // 40. BBlockTransferDstScalarPerVector_BK1
|
||||
oss << "," << kBBlockLdsExtraN; // 41. BBlockLdsExtraN
|
||||
oss << "," << kCShuffleMXdlPerWavePerShuffle; // 42. CShuffleMXdlPerWavePerShuffle
|
||||
oss << "," << kCShuffleNXdlPerWavePerShuffle; // 43. CShuffleNXdlPerWavePerShuffle
|
||||
oss << ","
|
||||
<< detail::array_to_string(
|
||||
kCThreadClusterLengths); // 44. CDEBlockTransferClusterLengths
|
||||
oss << ","
|
||||
<< kCBlockTransferScalarPerVector; // 45. CDEBlockTransferScalarPerVector_NPerBlock
|
||||
oss << "," << detail::type_name<AComputeDataType>(); // 46. AComputeDataType
|
||||
oss << "," << detail::type_name<BComputeDataType>(); // 47. BComputeDataType
|
||||
oss << "," << detail::loop_scheduler_name(kLoopScheduler); // 48. LoopSched
|
||||
oss << "," << kNumGroupsToMerge; // 49. NumGroupsToMerge
|
||||
oss << ">";
|
||||
|
||||
return oss.str();
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile::reflect
|
||||
@@ -0,0 +1,344 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
// InstanceTraits specialization for DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor
|
||||
//
|
||||
// CRITICAL MAINTENANCE NOTE:
|
||||
// This InstanceTraits file MUST be kept strictly in sync with the device implementation header:
|
||||
// ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp
|
||||
// "In sync" means that the template parameter order, names, and types in the declaration below
|
||||
// MUST EXACTLY MATCH those in the device implementation. If these diverge, you may encounter
|
||||
// compilation errors, subtle template instantiation mismatches, or silent runtime bugs that are
|
||||
// difficult to diagnose. Always update both files together and review changes carefully.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "instance_traits.hpp"
|
||||
|
||||
// Forward declaration to avoid circular dependency.
|
||||
namespace ck::tensor_operation::device {
|
||||
|
||||
template <ck::index_t NDimSpatial,
|
||||
typename ALayout,
|
||||
typename BLayout,
|
||||
typename DsLayout,
|
||||
typename ELayout,
|
||||
typename ADataType,
|
||||
typename BDataType,
|
||||
typename AccDataType,
|
||||
typename CShuffleDataType,
|
||||
typename DsDataType,
|
||||
typename EDataType,
|
||||
typename AElementwiseOperation,
|
||||
typename BElementwiseOperation,
|
||||
typename CDEElementwiseOperation,
|
||||
ConvolutionForwardSpecialization ConvForwardSpecialization,
|
||||
GemmSpecialization GemmSpec,
|
||||
ck::index_t NumGemmKPrefetchStage,
|
||||
ck::index_t BlockSize,
|
||||
ck::index_t MPerBlock,
|
||||
ck::index_t NPerBlock,
|
||||
ck::index_t KPerBlock,
|
||||
ck::index_t AK1,
|
||||
ck::index_t BK1,
|
||||
ck::index_t MPerXDL,
|
||||
ck::index_t NPerXDL,
|
||||
ck::index_t MXdlPerWave,
|
||||
ck::index_t NXdlPerWave,
|
||||
typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
|
||||
typename ABlockTransferThreadClusterArrangeOrder,
|
||||
typename ABlockTransferSrcAccessOrder,
|
||||
ck::index_t ABlockTransferSrcVectorDim,
|
||||
ck::index_t ABlockTransferSrcScalarPerVector,
|
||||
ck::index_t ABlockTransferDstScalarPerVector_AK1,
|
||||
ck::index_t ABlockLdsExtraM,
|
||||
typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
|
||||
typename BBlockTransferThreadClusterArrangeOrder,
|
||||
typename BBlockTransferSrcAccessOrder,
|
||||
ck::index_t BBlockTransferSrcVectorDim,
|
||||
ck::index_t BBlockTransferSrcScalarPerVector,
|
||||
ck::index_t BBlockTransferDstScalarPerVector_BK1,
|
||||
ck::index_t BBlockLdsExtraN,
|
||||
ck::index_t CShuffleMXdlPerWavePerShuffle,
|
||||
ck::index_t CShuffleNXdlPerWavePerShuffle,
|
||||
typename CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
ck::index_t CDEBlockTransferScalarPerVector_NPerBlock,
|
||||
typename AComputeDataType,
|
||||
typename BComputeDataType,
|
||||
LoopScheduler LoopSched>
|
||||
struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor;
|
||||
|
||||
} // namespace ck::tensor_operation::device
|
||||
|
||||
namespace ck_tile::reflect {
|
||||
|
||||
// Specialization for DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor
|
||||
template <ck::index_t NDimSpatial,
|
||||
typename ALayout_,
|
||||
typename BLayout_,
|
||||
typename DsLayout_,
|
||||
typename ELayout_,
|
||||
typename ADataType_,
|
||||
typename BDataType_,
|
||||
typename AccDataType_,
|
||||
typename CShuffleDataType_,
|
||||
typename DsDataType_,
|
||||
typename EDataType_,
|
||||
typename AElementwiseOperation_,
|
||||
typename BElementwiseOperation_,
|
||||
typename CDEElementwiseOperation_,
|
||||
ck::tensor_operation::device::ConvolutionForwardSpecialization ConvForwardSpecialization,
|
||||
ck::tensor_operation::device::GemmSpecialization GemmSpec,
|
||||
ck::index_t NumGemmKPrefetchStage,
|
||||
ck::index_t BlockSize,
|
||||
ck::index_t MPerBlock,
|
||||
ck::index_t NPerBlock,
|
||||
ck::index_t KPerBlock,
|
||||
ck::index_t AK1,
|
||||
ck::index_t BK1,
|
||||
ck::index_t MPerXDL,
|
||||
ck::index_t NPerXDL,
|
||||
ck::index_t MXdlPerWave,
|
||||
ck::index_t NXdlPerWave,
|
||||
typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
|
||||
typename ABlockTransferThreadClusterArrangeOrder,
|
||||
typename ABlockTransferSrcAccessOrder_,
|
||||
ck::index_t ABlockTransferSrcVectorDim,
|
||||
ck::index_t ABlockTransferSrcScalarPerVector,
|
||||
ck::index_t ABlockTransferDstScalarPerVector_AK1,
|
||||
ck::index_t ABlockLdsExtraM,
|
||||
typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
|
||||
typename BBlockTransferThreadClusterArrangeOrder,
|
||||
typename BBlockTransferSrcAccessOrder_,
|
||||
ck::index_t BBlockTransferSrcVectorDim,
|
||||
ck::index_t BBlockTransferSrcScalarPerVector,
|
||||
ck::index_t BBlockTransferDstScalarPerVector_BK1,
|
||||
ck::index_t BBlockLdsExtraN,
|
||||
ck::index_t CShuffleMXdlPerWavePerShuffle,
|
||||
ck::index_t CShuffleNXdlPerWavePerShuffle,
|
||||
typename CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
ck::index_t CDEBlockTransferScalarPerVector_NPerBlock,
|
||||
typename AComputeDataType_,
|
||||
typename BComputeDataType_,
|
||||
ck::LoopScheduler LoopSched>
|
||||
struct InstanceTraits<
|
||||
ck::tensor_operation::device::DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor<
|
||||
NDimSpatial,
|
||||
ALayout_,
|
||||
BLayout_,
|
||||
DsLayout_,
|
||||
ELayout_,
|
||||
ADataType_,
|
||||
BDataType_,
|
||||
AccDataType_,
|
||||
CShuffleDataType_,
|
||||
DsDataType_,
|
||||
EDataType_,
|
||||
AElementwiseOperation_,
|
||||
BElementwiseOperation_,
|
||||
CDEElementwiseOperation_,
|
||||
ConvForwardSpecialization,
|
||||
GemmSpec,
|
||||
NumGemmKPrefetchStage,
|
||||
BlockSize,
|
||||
MPerBlock,
|
||||
NPerBlock,
|
||||
KPerBlock,
|
||||
AK1,
|
||||
BK1,
|
||||
MPerXDL,
|
||||
NPerXDL,
|
||||
MXdlPerWave,
|
||||
NXdlPerWave,
|
||||
ABlockTransferThreadClusterLengths_AK0_M_AK1,
|
||||
ABlockTransferThreadClusterArrangeOrder,
|
||||
ABlockTransferSrcAccessOrder_,
|
||||
ABlockTransferSrcVectorDim,
|
||||
ABlockTransferSrcScalarPerVector,
|
||||
ABlockTransferDstScalarPerVector_AK1,
|
||||
ABlockLdsExtraM,
|
||||
BBlockTransferThreadClusterLengths_BK0_N_BK1,
|
||||
BBlockTransferThreadClusterArrangeOrder,
|
||||
BBlockTransferSrcAccessOrder_,
|
||||
BBlockTransferSrcVectorDim,
|
||||
BBlockTransferSrcScalarPerVector,
|
||||
BBlockTransferDstScalarPerVector_BK1,
|
||||
BBlockLdsExtraN,
|
||||
CShuffleMXdlPerWavePerShuffle,
|
||||
CShuffleNXdlPerWavePerShuffle,
|
||||
CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
CDEBlockTransferScalarPerVector_NPerBlock,
|
||||
AComputeDataType_,
|
||||
BComputeDataType_,
|
||||
LoopSched>>
|
||||
{
|
||||
// Spatial dimension
|
||||
static constexpr int kSpatialDim = NDimSpatial;
|
||||
|
||||
// Layout types
|
||||
using ALayout = ALayout_;
|
||||
using BLayout = BLayout_;
|
||||
using DsLayout = DsLayout_;
|
||||
using ELayout = ELayout_;
|
||||
|
||||
// Data types
|
||||
using ADataType = ADataType_;
|
||||
using BDataType = BDataType_;
|
||||
using AccDataType = AccDataType_;
|
||||
using CShuffleDataType = CShuffleDataType_;
|
||||
using DsDataType = DsDataType_;
|
||||
using EDataType = EDataType_;
|
||||
|
||||
// Element-wise operations
|
||||
using AElementwiseOperation = AElementwiseOperation_;
|
||||
using BElementwiseOperation = BElementwiseOperation_;
|
||||
using CDEElementwiseOperation = CDEElementwiseOperation_;
|
||||
|
||||
// Specialization
|
||||
static constexpr ck::tensor_operation::device::ConvolutionForwardSpecialization
|
||||
kConvForwardSpecialization = ConvForwardSpecialization;
|
||||
static constexpr ck::tensor_operation::device::GemmSpecialization kGemmSpecialization =
|
||||
GemmSpec;
|
||||
|
||||
// Prefetch stage
|
||||
static constexpr int kNumGemmKPrefetchStage = NumGemmKPrefetchStage;
|
||||
|
||||
// Block configuration
|
||||
static constexpr int kBlockSize = BlockSize;
|
||||
static constexpr int kMPerBlock = MPerBlock;
|
||||
static constexpr int kNPerBlock = NPerBlock;
|
||||
static constexpr int kKPerBlock = KPerBlock;
|
||||
|
||||
// Tuning parameters
|
||||
static constexpr int kAK1 = AK1;
|
||||
static constexpr int kBK1 = BK1;
|
||||
static constexpr int kMPerXDL = MPerXDL;
|
||||
static constexpr int kNPerXDL = NPerXDL;
|
||||
static constexpr int kMXdlPerWave = MXdlPerWave;
|
||||
static constexpr int kNXdlPerWave = NXdlPerWave;
|
||||
|
||||
// A block transfer thread cluster dimensions (converted to std::array)
|
||||
static constexpr auto kAThreadClusterLengths =
|
||||
detail::SequenceToArray<ABlockTransferThreadClusterLengths_AK0_M_AK1>::value;
|
||||
static constexpr auto kAThreadClusterArrangeOrder =
|
||||
detail::SequenceToArray<ABlockTransferThreadClusterArrangeOrder>::value;
|
||||
static constexpr auto kABlockTransferSrcAccessOrder =
|
||||
detail::SequenceToArray<ABlockTransferSrcAccessOrder_>::value;
|
||||
static constexpr int kABlockTransferSrcVectorDim = ABlockTransferSrcVectorDim;
|
||||
static constexpr int kABlockTransferSrcScalarPerVector = ABlockTransferSrcScalarPerVector;
|
||||
static constexpr int kABlockTransferDstScalarPerVectorK1 = ABlockTransferDstScalarPerVector_AK1;
|
||||
static constexpr int kABlockLdsExtraM = ABlockLdsExtraM;
|
||||
|
||||
// B block transfer thread cluster dimensions (converted to std::array)
|
||||
static constexpr auto kBThreadClusterLengths =
|
||||
detail::SequenceToArray<BBlockTransferThreadClusterLengths_BK0_N_BK1>::value;
|
||||
static constexpr auto kBThreadClusterArrangeOrder =
|
||||
detail::SequenceToArray<BBlockTransferThreadClusterArrangeOrder>::value;
|
||||
static constexpr auto kBBlockTransferSrcAccessOrder =
|
||||
detail::SequenceToArray<BBlockTransferSrcAccessOrder_>::value;
|
||||
static constexpr int kBBlockTransferSrcVectorDim = BBlockTransferSrcVectorDim;
|
||||
static constexpr int kBBlockTransferSrcScalarPerVector = BBlockTransferSrcScalarPerVector;
|
||||
static constexpr int kBBlockTransferDstScalarPerVectorK1 = BBlockTransferDstScalarPerVector_BK1;
|
||||
static constexpr int kBBlockLdsExtraN = BBlockLdsExtraN;
|
||||
|
||||
// C shuffle parameters (converted to std::array)
|
||||
static constexpr int kCShuffleMXdlPerWavePerShuffle = CShuffleMXdlPerWavePerShuffle;
|
||||
static constexpr int kCShuffleNXdlPerWavePerShuffle = CShuffleNXdlPerWavePerShuffle;
|
||||
static constexpr auto kCThreadClusterLengths = detail::SequenceToArray<
|
||||
CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock>::value;
|
||||
static constexpr int kCBlockTransferScalarPerVector = CDEBlockTransferScalarPerVector_NPerBlock;
|
||||
|
||||
// Compute data types
|
||||
using AComputeDataType = AComputeDataType_;
|
||||
using BComputeDataType = BComputeDataType_;
|
||||
|
||||
// Loop scheduler
|
||||
static constexpr ck::LoopScheduler kLoopScheduler = LoopSched;
|
||||
|
||||
// Static member function to generate instance string
|
||||
static std::string instance_string()
|
||||
{
|
||||
std::ostringstream oss;
|
||||
|
||||
// Kernel type name
|
||||
oss << "DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor";
|
||||
|
||||
// Template parameters in exact order matching InstanceTraits member order
|
||||
oss << "<" << kSpatialDim; // 1. NDimSpatial
|
||||
oss << "," << detail::layout_name<ALayout>(); // 2. ALayout
|
||||
oss << "," << detail::layout_name<BLayout>(); // 3. BLayout
|
||||
oss << "," << detail::tuple_name<DsLayout>(); // 4. DsLayout
|
||||
oss << "," << detail::layout_name<ELayout>(); // 5. ELayout
|
||||
oss << "," << detail::type_name<ADataType>(); // 6. ADataType
|
||||
oss << "," << detail::type_name<BDataType>(); // 7. BDataType
|
||||
oss << "," << detail::type_name<AccDataType>(); // 8. AccDataType
|
||||
oss << "," << detail::type_name<CShuffleDataType>(); // 9. CShuffleDataType
|
||||
oss << "," << detail::tuple_name<DsDataType>(); // 10. DsDataType
|
||||
oss << "," << detail::type_name<EDataType>(); // 11. EDataType
|
||||
oss << ","
|
||||
<< detail::elementwise_op_name<AElementwiseOperation>(); // 12. AElementwiseOperation
|
||||
oss << ","
|
||||
<< detail::elementwise_op_name<BElementwiseOperation>(); // 13. BElementwiseOperation
|
||||
oss << ","
|
||||
<< detail::elementwise_op_name<CDEElementwiseOperation>(); // 14.
|
||||
// CDEElementwiseOperation
|
||||
oss << ","
|
||||
<< detail::conv_fwd_spec_name(
|
||||
kConvForwardSpecialization); // 15. ConvForwardSpecialization
|
||||
oss << "," << detail::gemm_spec_name(kGemmSpecialization); // 16. GemmSpec
|
||||
oss << "," << kNumGemmKPrefetchStage; // 17. NumGemmKPrefetchStage
|
||||
oss << "," << kBlockSize; // 18. BlockSize
|
||||
oss << "," << kMPerBlock; // 19. MPerBlock
|
||||
oss << "," << kNPerBlock; // 20. NPerBlock
|
||||
oss << "," << kKPerBlock; // 21. KPerBlock
|
||||
oss << "," << kAK1; // 22. AK1
|
||||
oss << "," << kBK1; // 23. BK1
|
||||
oss << "," << kMPerXDL; // 24. MPerXDL
|
||||
oss << "," << kNPerXDL; // 25. NPerXDL
|
||||
oss << "," << kMXdlPerWave; // 26. MXdlPerWave
|
||||
oss << "," << kNXdlPerWave; // 27. NXdlPerWave
|
||||
oss << ","
|
||||
<< detail::array_to_string(
|
||||
kAThreadClusterLengths); // 28. ABlockTransferThreadClusterLengths
|
||||
oss << ","
|
||||
<< detail::array_to_string(
|
||||
kAThreadClusterArrangeOrder); // 29. ABlockTransferThreadClusterArrangeOrder
|
||||
oss << ","
|
||||
<< detail::array_to_string(
|
||||
kABlockTransferSrcAccessOrder); // 30. ABlockTransferSrcAccessOrder
|
||||
oss << "," << kABlockTransferSrcVectorDim; // 31. ABlockTransferSrcVectorDim
|
||||
oss << "," << kABlockTransferSrcScalarPerVector; // 32. ABlockTransferSrcScalarPerVector
|
||||
oss << ","
|
||||
<< kABlockTransferDstScalarPerVectorK1; // 33. ABlockTransferDstScalarPerVector_AK1
|
||||
oss << "," << kABlockLdsExtraM; // 34. ABlockLdsExtraM
|
||||
oss << ","
|
||||
<< detail::array_to_string(
|
||||
kBThreadClusterLengths); // 35. BBlockTransferThreadClusterLengths
|
||||
oss << ","
|
||||
<< detail::array_to_string(
|
||||
kBThreadClusterArrangeOrder); // 36. BBlockTransferThreadClusterArrangeOrder
|
||||
oss << ","
|
||||
<< detail::array_to_string(
|
||||
kBBlockTransferSrcAccessOrder); // 37. BBlockTransferSrcAccessOrder
|
||||
oss << "," << kBBlockTransferSrcVectorDim; // 38. BBlockTransferSrcVectorDim
|
||||
oss << "," << kBBlockTransferSrcScalarPerVector; // 39. BBlockTransferSrcScalarPerVector
|
||||
oss << ","
|
||||
<< kBBlockTransferDstScalarPerVectorK1; // 40. BBlockTransferDstScalarPerVector_BK1
|
||||
oss << "," << kBBlockLdsExtraN; // 41. BBlockLdsExtraN
|
||||
oss << "," << kCShuffleMXdlPerWavePerShuffle; // 42. CShuffleMXdlPerWavePerShuffle
|
||||
oss << "," << kCShuffleNXdlPerWavePerShuffle; // 43. CShuffleNXdlPerWavePerShuffle
|
||||
oss << ","
|
||||
<< detail::array_to_string(
|
||||
kCThreadClusterLengths); // 44. CDEBlockTransferClusterLengths
|
||||
oss << ","
|
||||
<< kCBlockTransferScalarPerVector; // 45. CDEBlockTransferScalarPerVector_NPerBlock
|
||||
oss << "," << detail::type_name<AComputeDataType>(); // 46. AComputeDataType
|
||||
oss << "," << detail::type_name<BComputeDataType>(); // 47. BComputeDataType
|
||||
oss << "," << detail::loop_scheduler_name(kLoopScheduler); // 48. LoopSched
|
||||
oss << ">";
|
||||
|
||||
return oss.str();
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile::reflect
|
||||
@@ -15,6 +15,7 @@
|
||||
#include <ck/utility/data_type.hpp>
|
||||
#include <ck/utility/sequence.hpp>
|
||||
#include <ck/utility/blkgemmpipe_scheduler.hpp>
|
||||
#include <ck/utility/loop_scheduler.hpp>
|
||||
#include <ck/tensor_operation/gpu/device/tensor_layout.hpp>
|
||||
#include <ck_tile/ops/common/tensor_layout.hpp>
|
||||
#include <ck/tensor_operation/gpu/element/element_wise_operation.hpp>
|
||||
@@ -160,6 +161,17 @@ constexpr std::string_view pipeline_version_name(ck::BlockGemmPipelineVersion ve
|
||||
}
|
||||
}
|
||||
|
||||
// Convert LoopScheduler enum to string
|
||||
constexpr std::string_view loop_scheduler_name(ck::LoopScheduler sched)
|
||||
{
|
||||
using enum ck::LoopScheduler;
|
||||
switch(sched)
|
||||
{
|
||||
case Default: return "Default";
|
||||
case Interwave: return "Interwave";
|
||||
}
|
||||
}
|
||||
|
||||
// Convert std::array to string
|
||||
template <typename T, std::size_t N>
|
||||
inline std::string array_to_string(const std::array<T, N>& arr)
|
||||
|
||||
@@ -26,7 +26,9 @@ add_ck_builder_test(test_inline_diff test_inline_diff.cpp)
|
||||
|
||||
# Testing the virtual GetInstanceString methods requires kernel compilation.
|
||||
add_ck_builder_test(test_get_instance_string
|
||||
test_get_instance_string.cpp)
|
||||
test_get_instance_string_fwd_grp_conv_v3.cpp
|
||||
test_get_instance_string_fwd_grp_conv.cpp
|
||||
test_get_instance_string_fwd_grp_conv_large_tensor.cpp)
|
||||
|
||||
# Testing the fwd convolution builder requires kernel compilation.
|
||||
# To enable parallel compilation, the individual tests are split into separate files.
|
||||
|
||||
@@ -3,19 +3,18 @@
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
#include <gmock/gmock.h>
|
||||
#include <ck/ck.hpp>
|
||||
#include <ck/utility/reduction_operator.hpp>
|
||||
#include <ck_tile/builder/reflect/instance_traits.hpp>
|
||||
#include <ck_tile/builder/reflect/instance_traits_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp>
|
||||
#include <ck_tile/builder/reflect/instance_traits_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp>
|
||||
#include <ck_tile/builder/reflect/instance_traits_device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp>
|
||||
|
||||
namespace {
|
||||
|
||||
using ::testing::ElementsAre;
|
||||
// Test fixture for InstanceTraits tests
|
||||
class InstanceTraitsTest : public ::testing::Test
|
||||
{
|
||||
};
|
||||
|
||||
// Test InstanceTraits with DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
|
||||
TEST_F(InstanceTraitsTest, ConvFwdInstanceTraitsExtraction)
|
||||
TEST(InstanceTraitsTest, ConvFwdInstanceTraitsExtraction)
|
||||
{
|
||||
// Define a concrete instance type with specific template parameters
|
||||
using DeviceInstance =
|
||||
@@ -156,8 +155,7 @@ TEST_F(InstanceTraitsTest, ConvFwdInstanceTraitsExtraction)
|
||||
ck::tensor_operation::element_wise::PassThrough>::value));
|
||||
}
|
||||
|
||||
// Test instance_string function
|
||||
TEST_F(InstanceTraitsTest, InstanceStringGeneration)
|
||||
TEST(InstanceTraitsTest, V3InstanceStringGeneration)
|
||||
{
|
||||
// Define a concrete instance type with specific template parameters
|
||||
using DeviceInstance =
|
||||
@@ -215,10 +213,8 @@ TEST_F(InstanceTraitsTest, InstanceStringGeneration)
|
||||
ck::half_t, // AComputeDataType
|
||||
ck::half_t>; // BComputeDataType
|
||||
|
||||
// Generate instance string
|
||||
std::string instance_str = ck_tile::reflect::instance_string<DeviceInstance>();
|
||||
|
||||
// Expected string with all template parameters in exact order
|
||||
std::string expected_str = "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3"
|
||||
"<2" // NDimSpatial
|
||||
",GNHWC" // ALayout
|
||||
@@ -269,6 +265,234 @@ TEST_F(InstanceTraitsTest, InstanceStringGeneration)
|
||||
",fp16" // AComputeDataType
|
||||
",fp16>"; // BComputeDataType
|
||||
|
||||
EXPECT_EQ(instance_str, expected_str);
|
||||
}
|
||||
|
||||
TEST(InstanceTraitsTest, BaseInstanceStringGeneration)
|
||||
{
|
||||
using DeviceInstance =
|
||||
ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<
|
||||
2, // NDimSpatial
|
||||
ck::tensor_layout::convolution::GNHWC, // ALayout
|
||||
ck::tensor_layout::convolution::GKYXC, // BLayout
|
||||
ck::Tuple<>, // DsLayout
|
||||
ck::tensor_layout::convolution::GNHWK, // ELayout
|
||||
ck::half_t, // ADataType
|
||||
ck::half_t, // BDataType
|
||||
float, // AccDataType
|
||||
ck::half_t, // CShuffleDataType
|
||||
ck::Tuple<>, // DsDataType
|
||||
ck::half_t, // EDataType
|
||||
ck::tensor_operation::element_wise::PassThrough, // AElementwiseOperation
|
||||
ck::tensor_operation::element_wise::PassThrough, // BElementwiseOperation
|
||||
ck::tensor_operation::element_wise::PassThrough, // CDEElementwiseOperation
|
||||
ck::tensor_operation::device::ConvolutionForwardSpecialization::
|
||||
Default, // ConvForwardSpecialization
|
||||
ck::tensor_operation::device::GemmSpecialization::Default, // GemmSpec
|
||||
1, // NumGemmKPrefetchStage
|
||||
256, // BlockSize
|
||||
128, // MPerBlock
|
||||
128, // NPerBlock
|
||||
16, // KPerBlock
|
||||
8, // AK1
|
||||
8, // BK1
|
||||
32, // MPerXDL
|
||||
32, // NPerXDL
|
||||
4, // MXdlPerWave
|
||||
4, // NXdlPerWave
|
||||
ck::Sequence<4, 64, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1
|
||||
ck::Sequence<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder
|
||||
ck::Sequence<1, 0, 2>, // ABlockTransferSrcAccessOrder
|
||||
2, // ABlockTransferSrcVectorDim
|
||||
8, // ABlockTransferSrcScalarPerVector
|
||||
8, // ABlockTransferDstScalarPerVector_AK1
|
||||
1, // ABlockLdsExtraM
|
||||
ck::Sequence<4, 64, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1
|
||||
ck::Sequence<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder
|
||||
ck::Sequence<1, 0, 2>, // BBlockTransferSrcAccessOrder
|
||||
2, // BBlockTransferSrcVectorDim
|
||||
8, // BBlockTransferSrcScalarPerVector
|
||||
8, // BBlockTransferDstScalarPerVector_BK1
|
||||
1, // BBlockLdsExtraN
|
||||
1, // CShuffleMXdlPerWavePerShuffle
|
||||
1, // CShuffleNXdlPerWavePerShuffle
|
||||
ck::Sequence<1,
|
||||
32,
|
||||
1,
|
||||
8>, // CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
|
||||
8, // CDEBlockTransferScalarPerVector_NPerBlock
|
||||
ck::half_t, // AComputeDataType
|
||||
ck::half_t, // BComputeDataType
|
||||
ck::LoopScheduler::Default, // LoopSched
|
||||
1>; // NumGroupsToMerge
|
||||
|
||||
std::string instance_str = ck_tile::reflect::instance_string<DeviceInstance>();
|
||||
|
||||
std::string expected_str = "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle"
|
||||
"<2" // NDimSpatial
|
||||
",GNHWC" // ALayout
|
||||
",GKYXC" // BLayout
|
||||
",EmptyTuple" // DsLayout
|
||||
",GNHWK" // ELayout
|
||||
",fp16" // ADataType
|
||||
",fp16" // BDataType
|
||||
",fp32" // AccDataType
|
||||
",fp16" // CShuffleDataType
|
||||
",EmptyTuple" // DsDataType
|
||||
",fp16" // EDataType
|
||||
",PassThrough" // AElementwiseOperation
|
||||
",PassThrough" // BElementwiseOperation
|
||||
",PassThrough" // CDEElementwiseOperation
|
||||
",Default" // ConvForwardSpecialization
|
||||
",Default" // GemmSpec
|
||||
",1" // NumGemmKPrefetchStage
|
||||
",256" // BlockSize
|
||||
",128" // MPerBlock
|
||||
",128" // NPerBlock
|
||||
",16" // KPerBlock
|
||||
",8" // AK1
|
||||
",8" // BK1
|
||||
",32" // MPerXDL
|
||||
",32" // NPerXDL
|
||||
",4" // MXdlPerWave
|
||||
",4" // NXdlPerWave
|
||||
",Seq(4,64,1)" // ABlockTransferThreadClusterLengths
|
||||
",Seq(1,0,2)" // ABlockTransferThreadClusterArrangeOrder
|
||||
",Seq(1,0,2)" // ABlockTransferSrcAccessOrder
|
||||
",2" // ABlockTransferSrcVectorDim
|
||||
",8" // ABlockTransferSrcScalarPerVector
|
||||
",8" // ABlockTransferDstScalarPerVector_AK1
|
||||
",1" // ABlockLdsExtraM
|
||||
",Seq(4,64,1)" // BBlockTransferThreadClusterLengths
|
||||
",Seq(1,0,2)" // BBlockTransferThreadClusterArrangeOrder
|
||||
",Seq(1,0,2)" // BBlockTransferSrcAccessOrder
|
||||
",2" // BBlockTransferSrcVectorDim
|
||||
",8" // BBlockTransferSrcScalarPerVector
|
||||
",8" // BBlockTransferDstScalarPerVector_BK1
|
||||
",1" // BBlockLdsExtraN
|
||||
",1" // CShuffleMXdlPerWavePerShuffle
|
||||
",1" // CShuffleNXdlPerWavePerShuffle
|
||||
",Seq(1,32,1,8)" // CDEBlockTransferClusterLengths
|
||||
",8" // CDEBlockTransferScalarPerVector_NPerBlock
|
||||
",fp16" // AComputeDataType
|
||||
",fp16" // BComputeDataType
|
||||
",Default" // LoopSched
|
||||
",1>"; // NumGroupsToMerge
|
||||
|
||||
EXPECT_EQ(instance_str, expected_str);
|
||||
}
|
||||
|
||||
TEST(InstanceTraitsTest, LargeTensorInstanceStringGeneration)
|
||||
{
|
||||
using DeviceInstance =
|
||||
ck::tensor_operation::device::DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor<
|
||||
2, // NDimSpatial
|
||||
ck::tensor_layout::convolution::GNHWC, // ALayout
|
||||
ck::tensor_layout::convolution::GKYXC, // BLayout
|
||||
ck::Tuple<>, // DsLayout
|
||||
ck::tensor_layout::convolution::GNHWK, // ELayout
|
||||
ck::half_t, // ADataType
|
||||
ck::half_t, // BDataType
|
||||
float, // AccDataType
|
||||
ck::half_t, // CShuffleDataType
|
||||
ck::Tuple<>, // DsDataType
|
||||
ck::half_t, // EDataType
|
||||
ck::tensor_operation::element_wise::PassThrough, // AElementwiseOperation
|
||||
ck::tensor_operation::element_wise::PassThrough, // BElementwiseOperation
|
||||
ck::tensor_operation::element_wise::PassThrough, // CDEElementwiseOperation
|
||||
ck::tensor_operation::device::ConvolutionForwardSpecialization::
|
||||
Default, // ConvForwardSpecialization
|
||||
ck::tensor_operation::device::GemmSpecialization::Default, // GemmSpec
|
||||
1, // NumGemmKPrefetchStage
|
||||
256, // BlockSize
|
||||
128, // MPerBlock
|
||||
128, // NPerBlock
|
||||
16, // KPerBlock
|
||||
8, // AK1
|
||||
8, // BK1
|
||||
32, // MPerXDL
|
||||
32, // NPerXDL
|
||||
4, // MXdlPerWave
|
||||
4, // NXdlPerWave
|
||||
ck::Sequence<4, 64, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1
|
||||
ck::Sequence<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder
|
||||
ck::Sequence<1, 0, 2>, // ABlockTransferSrcAccessOrder
|
||||
2, // ABlockTransferSrcVectorDim
|
||||
8, // ABlockTransferSrcScalarPerVector
|
||||
8, // ABlockTransferDstScalarPerVector_AK1
|
||||
1, // ABlockLdsExtraM
|
||||
ck::Sequence<4, 64, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1
|
||||
ck::Sequence<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder
|
||||
ck::Sequence<1, 0, 2>, // BBlockTransferSrcAccessOrder
|
||||
2, // BBlockTransferSrcVectorDim
|
||||
8, // BBlockTransferSrcScalarPerVector
|
||||
8, // BBlockTransferDstScalarPerVector_BK1
|
||||
1, // BBlockLdsExtraN
|
||||
1, // CShuffleMXdlPerWavePerShuffle
|
||||
1, // CShuffleNXdlPerWavePerShuffle
|
||||
ck::Sequence<1,
|
||||
32,
|
||||
1,
|
||||
8>, // CDEBlockTransferClusterLengths
|
||||
8, // CDEBlockTransferScalarPerVector_NPerBlock
|
||||
ck::half_t, // AComputeDataType
|
||||
ck::half_t, // BComputeDataType
|
||||
ck::LoopScheduler::Default>; // LoopSched
|
||||
|
||||
// Generate instance string
|
||||
std::string instance_str = ck_tile::reflect::instance_string<DeviceInstance>();
|
||||
|
||||
// Expected string with all 48 template parameters
|
||||
std::string expected_str = "DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor"
|
||||
"<2" // NDimSpatial
|
||||
",GNHWC" // ALayout
|
||||
",GKYXC" // BLayout
|
||||
",EmptyTuple" // DsLayout
|
||||
",GNHWK" // ELayout
|
||||
",fp16" // ADataType
|
||||
",fp16" // BDataType
|
||||
",fp32" // AccDataType
|
||||
",fp16" // CShuffleDataType
|
||||
",EmptyTuple" // DsDataType
|
||||
",fp16" // EDataType
|
||||
",PassThrough" // AElementwiseOperation
|
||||
",PassThrough" // BElementwiseOperation
|
||||
",PassThrough" // CDEElementwiseOperation
|
||||
",Default" // ConvForwardSpecialization
|
||||
",Default" // GemmSpec
|
||||
",1" // NumGemmKPrefetchStage
|
||||
",256" // BlockSize
|
||||
",128" // MPerBlock
|
||||
",128" // NPerBlock
|
||||
",16" // KPerBlock
|
||||
",8" // AK1
|
||||
",8" // BK1
|
||||
",32" // MPerXDL
|
||||
",32" // NPerXDL
|
||||
",4" // MXdlPerWave
|
||||
",4" // NXdlPerWave
|
||||
",Seq(4,64,1)" // ABlockTransferThreadClusterLengths
|
||||
",Seq(1,0,2)" // ABlockTransferThreadClusterArrangeOrder
|
||||
",Seq(1,0,2)" // ABlockTransferSrcAccessOrder
|
||||
",2" // ABlockTransferSrcVectorDim
|
||||
",8" // ABlockTransferSrcScalarPerVector
|
||||
",8" // ABlockTransferDstScalarPerVector_AK1
|
||||
",1" // ABlockLdsExtraM
|
||||
",Seq(4,64,1)" // BBlockTransferThreadClusterLengths
|
||||
",Seq(1,0,2)" // BBlockTransferThreadClusterArrangeOrder
|
||||
",Seq(1,0,2)" // BBlockTransferSrcAccessOrder
|
||||
",2" // BBlockTransferSrcVectorDim
|
||||
",8" // BBlockTransferSrcScalarPerVector
|
||||
",8" // BBlockTransferDstScalarPerVector_BK1
|
||||
",1" // BBlockLdsExtraN
|
||||
",1" // CShuffleMXdlPerWavePerShuffle
|
||||
",1" // CShuffleNXdlPerWavePerShuffle
|
||||
",Seq(1,32,1,8)" // CDEBlockTransferClusterLengths
|
||||
",8" // CDEBlockTransferScalarPerVector_NPerBlock
|
||||
",fp16" // AComputeDataType
|
||||
",fp16" // BComputeDataType
|
||||
",Default>"; // LoopSched
|
||||
|
||||
// Verify the generated string matches exactly
|
||||
EXPECT_EQ(instance_str, expected_str);
|
||||
}
|
||||
@@ -0,0 +1,104 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
#include <ck_tile/builder/reflect/instance_traits.hpp>
|
||||
#include <ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_abd.hpp>
|
||||
#include <ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_instance.hpp>
|
||||
|
||||
// Test GetInstanceString through base class pointer for non-V3 variant
|
||||
TEST(GetInstanceString, ReturnsStringForFwdGrpConvInstance)
|
||||
{
|
||||
// Use the template helper to get a working instance configuration
|
||||
using InstanceTuple =
|
||||
ck::tensor_operation::device::instance::device_grouped_conv_fwd_xdl_f16_instances<
|
||||
2, // NDimSpatial
|
||||
ck::tensor_operation::device::instance::GNHWC, // ALayout
|
||||
ck::tensor_operation::device::instance::GKYXC, // BLayout
|
||||
ck::tensor_operation::device::instance::Empty_Tuple, // DsLayout
|
||||
ck::tensor_operation::device::instance::GNHWK, // ELayout
|
||||
ck::tensor_operation::device::instance::ConvFwdDefault>; // ConvForwardSpecialization
|
||||
|
||||
// Get the first instance from the tuple
|
||||
using DeviceInstance = typename std::tuple_element<0, InstanceTuple>::type;
|
||||
|
||||
// Define the base class type using DeviceGroupedConvFwdMultipleABD
|
||||
using BaseClass = ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD<
|
||||
2, // NDimSpatial
|
||||
ck::tensor_operation::device::instance::GNHWC, // ALayout
|
||||
ck::tensor_operation::device::instance::GKYXC, // BLayout
|
||||
ck::tensor_operation::device::instance::Empty_Tuple, // DsLayout
|
||||
ck::tensor_operation::device::instance::GNHWK, // ELayout
|
||||
ck::half_t, // ADataType
|
||||
ck::half_t, // BDataType
|
||||
ck::Tuple<>, // DsDataType
|
||||
ck::half_t, // EDataType
|
||||
ck::tensor_operation::element_wise::PassThrough, // AElementwiseOperation
|
||||
ck::tensor_operation::element_wise::PassThrough, // BElementwiseOperation
|
||||
ck::tensor_operation::element_wise::PassThrough, // CDEElementwiseOperation
|
||||
ck::half_t, // AComputeType
|
||||
ck::half_t>; // BComputeType
|
||||
|
||||
// Create an instance of the derived class
|
||||
DeviceInstance device_instance;
|
||||
|
||||
// Get a pointer to the base class
|
||||
BaseClass* base_ptr = &device_instance;
|
||||
|
||||
// Call GetInstanceString through the base class pointer
|
||||
std::string instance_str = base_ptr->GetInstanceString();
|
||||
|
||||
// Expected complete instance string based on the first instance from
|
||||
// device_grouped_conv_fwd_xdl_f16_instances
|
||||
std::string expected_str = "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle"
|
||||
"<2" // NDimSpatial
|
||||
",GNHWC" // ALayout
|
||||
",GKYXC" // BLayout
|
||||
",EmptyTuple" // DsLayout
|
||||
",GNHWK" // ELayout
|
||||
",fp16" // ADataType
|
||||
",fp16" // BDataType
|
||||
",fp32" // AccDataType
|
||||
",fp16" // CShuffleDataType
|
||||
",EmptyTuple" // DsDataType
|
||||
",fp16" // EDataType
|
||||
",PassThrough" // AElementwiseOperation
|
||||
",PassThrough" // BElementwiseOperation
|
||||
",PassThrough" // CDEElementwiseOperation
|
||||
",Default" // ConvForwardSpecialization
|
||||
",MNKPadding" // GemmSpec
|
||||
",1" // NumGemmKPrefetchStage
|
||||
",64" // BlockSize
|
||||
",64" // MPerBlock
|
||||
",64" // NPerBlock
|
||||
",32" // KPerBlock
|
||||
",8" // AK1
|
||||
",8" // BK1
|
||||
",32" // MPerXDL
|
||||
",32" // NPerXDL
|
||||
",2" // MXdlPerWave
|
||||
",2" // NXdlPerWave
|
||||
",Seq(4,16,1)" // ABlockTransferThreadClusterLengths
|
||||
",Seq(1,0,2)" // ABlockTransferThreadClusterArrangeOrder
|
||||
",Seq(1,0,2)" // ABlockTransferSrcAccessOrder
|
||||
",2" // ABlockTransferSrcVectorDim
|
||||
",1" // ABlockTransferSrcScalarPerVector
|
||||
",8" // ABlockTransferDstScalarPerVector_AK1
|
||||
",1" // ABlockLdsExtraM
|
||||
",Seq(4,16,1)" // BBlockTransferThreadClusterLengths
|
||||
",Seq(1,0,2)" // BBlockTransferThreadClusterArrangeOrder
|
||||
",Seq(1,0,2)" // BBlockTransferSrcAccessOrder
|
||||
",2" // BBlockTransferSrcVectorDim
|
||||
",1" // BBlockTransferSrcScalarPerVector
|
||||
",8" // BBlockTransferDstScalarPerVector_BK1
|
||||
",1" // BBlockLdsExtraN
|
||||
",1" // CShuffleMXdlPerWavePerShuffle
|
||||
",1" // CShuffleNXdlPerWavePerShuffle
|
||||
",Seq(1,16,1,4)" // CDEBlockTransferClusterLengths
|
||||
",1" // CDEBlockTransferScalarPerVector_NPerBlock
|
||||
",fp16" // AComputeDataType
|
||||
",fp16" // BComputeDataType
|
||||
",Default" // LoopScheduler
|
||||
",1>"; // NumGroupsToMerge
|
||||
EXPECT_EQ(instance_str, expected_str);
|
||||
}
|
||||
@@ -0,0 +1,103 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
#include <ck_tile/builder/reflect/instance_traits.hpp>
|
||||
#include <ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_abd.hpp>
|
||||
#include <ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_large_tensor_instance.hpp>
|
||||
|
||||
// Test GetInstanceString through base class pointer for large tensor variant
|
||||
TEST(GetInstanceString, ReturnsStringForFwdGrpConvLargeTensorInstance)
|
||||
{
|
||||
// Use the template helper to get a working instance configuration
|
||||
using InstanceTuple = ck::tensor_operation::device::instance::
|
||||
device_grouped_conv_fwd_xdl_large_tensor_f16_instances<
|
||||
2, // NDimSpatial
|
||||
ck::tensor_operation::device::instance::GNHWC, // ALayout
|
||||
ck::tensor_operation::device::instance::GKYXC, // BLayout
|
||||
ck::tensor_operation::device::instance::Empty_Tuple, // DsLayout
|
||||
ck::tensor_operation::device::instance::GNHWK, // ELayout
|
||||
ck::tensor_operation::device::instance::ConvFwdDefault>; // ConvForwardSpecialization
|
||||
|
||||
// Get the first instance from the tuple
|
||||
using DeviceInstance = typename std::tuple_element<0, InstanceTuple>::type;
|
||||
|
||||
// Define the base class type using DeviceGroupedConvFwdMultipleABD
|
||||
using BaseClass = ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD<
|
||||
2, // NDimSpatial
|
||||
ck::tensor_operation::device::instance::GNHWC, // ALayout
|
||||
ck::tensor_operation::device::instance::GKYXC, // BLayout
|
||||
ck::tensor_operation::device::instance::Empty_Tuple, // DsLayout
|
||||
ck::tensor_operation::device::instance::GNHWK, // ELayout
|
||||
ck::half_t, // ADataType
|
||||
ck::half_t, // BDataType
|
||||
ck::Tuple<>, // DsDataType
|
||||
ck::half_t, // EDataType
|
||||
ck::tensor_operation::element_wise::PassThrough, // AElementwiseOperation
|
||||
ck::tensor_operation::element_wise::PassThrough, // BElementwiseOperation
|
||||
ck::tensor_operation::element_wise::PassThrough, // CDEElementwiseOperation
|
||||
ck::half_t, // AComputeType
|
||||
ck::half_t>; // BComputeType
|
||||
|
||||
// Create an instance of the derived class
|
||||
DeviceInstance device_instance;
|
||||
|
||||
// Get a pointer to the base class
|
||||
BaseClass* base_ptr = &device_instance;
|
||||
|
||||
// Call GetInstanceString through the base class pointer
|
||||
std::string instance_str = base_ptr->GetInstanceString();
|
||||
|
||||
// Expected complete instance string based on the first instance from
|
||||
// device_grouped_conv_fwd_xdl_large_tensor_f16_instances
|
||||
std::string expected_str = "DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor"
|
||||
"<2" // NDimSpatial
|
||||
",GNHWC" // ALayout
|
||||
",GKYXC" // BLayout
|
||||
",EmptyTuple" // DsLayout
|
||||
",GNHWK" // ELayout
|
||||
",fp16" // ADataType
|
||||
",fp16" // BDataType
|
||||
",fp32" // AccDataType
|
||||
",fp16" // CShuffleDataType
|
||||
",EmptyTuple" // DsDataType
|
||||
",fp16" // EDataType
|
||||
",PassThrough" // AElementwiseOperation
|
||||
",PassThrough" // BElementwiseOperation
|
||||
",PassThrough" // CDEElementwiseOperation
|
||||
",Default" // ConvForwardSpecialization
|
||||
",MNKPadding" // GemmSpec
|
||||
",1" // NumGemmKPrefetchStage
|
||||
",64" // BlockSize
|
||||
",64" // MPerBlock
|
||||
",64" // NPerBlock
|
||||
",32" // KPerBlock
|
||||
",8" // AK1
|
||||
",8" // BK1
|
||||
",32" // MPerXDL
|
||||
",32" // NPerXDL
|
||||
",2" // MXdlPerWave
|
||||
",2" // NXdlPerWave
|
||||
",Seq(4,16,1)" // ABlockTransferThreadClusterLengths
|
||||
",Seq(1,0,2)" // ABlockTransferThreadClusterArrangeOrder
|
||||
",Seq(1,0,2)" // ABlockTransferSrcAccessOrder
|
||||
",2" // ABlockTransferSrcVectorDim
|
||||
",1" // ABlockTransferSrcScalarPerVector
|
||||
",8" // ABlockTransferDstScalarPerVector_AK1
|
||||
",1" // ABlockLdsExtraM
|
||||
",Seq(4,16,1)" // BBlockTransferThreadClusterLengths
|
||||
",Seq(1,0,2)" // BBlockTransferThreadClusterArrangeOrder
|
||||
",Seq(1,0,2)" // BBlockTransferSrcAccessOrder
|
||||
",2" // BBlockTransferSrcVectorDim
|
||||
",1" // BBlockTransferSrcScalarPerVector
|
||||
",8" // BBlockTransferDstScalarPerVector_BK1
|
||||
",1" // BBlockLdsExtraN
|
||||
",1" // CShuffleMXdlPerWavePerShuffle
|
||||
",1" // CShuffleNXdlPerWavePerShuffle
|
||||
",Seq(1,16,1,4)" // CDEBlockTransferClusterLengths
|
||||
",1" // CDEBlockTransferScalarPerVector_NPerBlock
|
||||
",fp16" // AComputeDataType
|
||||
",fp16" // BComputeDataType
|
||||
",Default>"; // LoopScheduler
|
||||
EXPECT_EQ(instance_str, expected_str);
|
||||
}
|
||||
@@ -6,8 +6,8 @@
|
||||
#include <ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_abd.hpp>
|
||||
#include <ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_comp_instance.hpp>
|
||||
|
||||
// Test GetInstanceString through base class pointer
|
||||
TEST(GetInstanceStringTest, GetInstanceStringThroughBaseClass)
|
||||
// Test GetInstanceString through base class pointer for V3 variant
|
||||
TEST(GetInstanceString, ReturnsStringForFwdGrpConvV3Instance)
|
||||
{
|
||||
// Use the template helper to get a working instance configuration
|
||||
using InstanceTuple =
|
||||
@@ -199,6 +199,14 @@ TEST(InstanceTraitsUtil, PipelineVersionNameReturnsCorrectStrings)
|
||||
ElementsAre("v1", "v2", "v3", "v4", "v5"));
|
||||
}
|
||||
|
||||
TEST(InstanceTraitsUtil, LoopSchedulerNameReturnsCorrectStrings)
|
||||
{
|
||||
using enum ck::LoopScheduler;
|
||||
EXPECT_THAT(std::vector<std::string_view> names = {loop_scheduler_name(Default),
|
||||
loop_scheduler_name(Interwave)},
|
||||
ElementsAre("Default", "Interwave"));
|
||||
}
|
||||
|
||||
TEST(InstanceTraitsUtil, TupleNameReturnsEmptyTupleForEmptyTuple)
|
||||
{
|
||||
EXPECT_EQ(tuple_name<ck::Tuple<>>(), "EmptyTuple");
|
||||
|
||||
Reference in New Issue
Block a user