mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-19 04:19:36 +00:00
Merge commit '37dff024c1d2c6420a91d9a4b0801b350db3eede' into develop
This commit is contained in:
@@ -0,0 +1,58 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
// Compile-time reflection for CK device kernel instances.
|
||||
//
|
||||
// - This is the Lowest-level reflection primitive for higher-level semantic abstractions (e.g.,
|
||||
// ConvTraits).
|
||||
// - Extracts raw template parameters (block sizes, data types, layouts, tuning params) from kernel
|
||||
// specializations.
|
||||
// - Provides uniform interface to query kernel configuration without implementation knowledge
|
||||
// - Other details about the device kernels can be manually added to template specializations.
|
||||
// - Currently supports:
|
||||
// - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <array>
|
||||
#include <string>
|
||||
#include <sstream>
|
||||
#include <type_traits>
|
||||
#include <ck/utility/data_type.hpp>
|
||||
#include <ck/utility/sequence.hpp>
|
||||
#include <ck/utility/blkgemmpipe_scheduler.hpp>
|
||||
#include <ck/tensor_operation/gpu/device/tensor_layout.hpp>
|
||||
#include <ck/tensor_operation/gpu/element/element_wise_operation.hpp>
|
||||
#include <ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp>
|
||||
#include <ck/tensor_operation/gpu/device/gemm_specialization.hpp>
|
||||
#include "instance_traits_util.hpp"
|
||||
|
||||
namespace ck_tile::reflect {
|
||||
|
||||
// Primary template for InstanceTraits - extracts compile-time information directly from
|
||||
// device kernel instances (e.g., DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3)
|
||||
//
|
||||
// This is an unspecialized template declaration. Actual specializations for specific
|
||||
// device kernels are provided in separate header files (e.g.,
|
||||
// instance_traits_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp).
|
||||
template <typename Instance>
|
||||
struct InstanceTraits;
|
||||
|
||||
// Concept-based helper to detect if InstanceTraits<T> is specialized
|
||||
// (i.e., has the instance_string() member function).
|
||||
// This can be used for an informative static_assert in the device-op GetInstanceString in case the
|
||||
// instance_string() template is broken.
|
||||
template <typename T>
|
||||
concept HasInstanceTraits = requires {
|
||||
{ InstanceTraits<T>::instance_string() } -> std::convertible_to<std::string>;
|
||||
};
|
||||
|
||||
// Free function that delegates to InstanceTraits static member function.
|
||||
// Each InstanceTraits specialization provides its own instance_string() implementation.
|
||||
template <typename T>
|
||||
inline std::string instance_string()
|
||||
{
|
||||
return InstanceTraits<T>::instance_string();
|
||||
}
|
||||
|
||||
} // namespace ck_tile::reflect
|
||||
@@ -0,0 +1,345 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
// InstanceTraits specialization for DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
|
||||
//
|
||||
// CRITICAL MAINTENANCE NOTE:
|
||||
// This InstanceTraits file MUST be kept strictly in sync with the device implementation header:
|
||||
// ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp
|
||||
// "In sync" means that the template parameter order, names, and types in the declaration below
|
||||
// MUST EXACTLY MATCH those in the device implementation. If these diverge, you may encounter
|
||||
// compilation errors, subtle template instantiation mismatches, or silent runtime bugs that are
|
||||
// difficult to diagnose. Always update both files together and review changes carefully.
|
||||
// ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "instance_traits.hpp"
|
||||
|
||||
// Forward declaration to avoid circular dependency.
|
||||
// This file will be included by the device implementation header, so we cannot include
|
||||
// the implementation header here. We only need the template signature to pattern-match
|
||||
// on template parameters - we don't need any implementation details.
|
||||
namespace ck::tensor_operation::device {
|
||||
|
||||
template <ck::index_t NDimSpatial,
|
||||
typename ALayout,
|
||||
typename BLayout,
|
||||
typename DsLayout,
|
||||
typename ELayout,
|
||||
typename ADataType,
|
||||
typename BDataType,
|
||||
typename AccDataType,
|
||||
typename CShuffleDataType,
|
||||
typename DsDataType,
|
||||
typename EDataType,
|
||||
typename AElementwiseOperation,
|
||||
typename BElementwiseOperation,
|
||||
typename CDEElementwiseOperation,
|
||||
ConvolutionForwardSpecialization ConvForwardSpecialization,
|
||||
GemmSpecialization GemmSpec,
|
||||
ck::index_t BlockSize,
|
||||
ck::index_t MPerBlock,
|
||||
ck::index_t NPerBlock,
|
||||
ck::index_t KPerBlock,
|
||||
ck::index_t AK1,
|
||||
ck::index_t BK1,
|
||||
ck::index_t MPerXDL,
|
||||
ck::index_t NPerXDL,
|
||||
ck::index_t MXdlPerWave,
|
||||
ck::index_t NXdlPerWave,
|
||||
typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
|
||||
typename ABlockTransferThreadClusterArrangeOrder,
|
||||
typename ABlockTransferSrcAccessOrder,
|
||||
ck::index_t ABlockTransferSrcVectorDim,
|
||||
ck::index_t ABlockTransferSrcScalarPerVector,
|
||||
ck::index_t ABlockTransferDstScalarPerVector_AK1,
|
||||
ck::index_t ABlockLdsExtraM,
|
||||
typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
|
||||
typename BBlockTransferThreadClusterArrangeOrder,
|
||||
typename BBlockTransferSrcAccessOrder,
|
||||
ck::index_t BBlockTransferSrcVectorDim,
|
||||
ck::index_t BBlockTransferSrcScalarPerVector,
|
||||
ck::index_t BBlockTransferDstScalarPerVector_BK1,
|
||||
ck::index_t BBlockLdsExtraN,
|
||||
ck::index_t CShuffleMXdlPerWavePerShuffle,
|
||||
ck::index_t CShuffleNXdlPerWavePerShuffle,
|
||||
typename CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
ck::index_t CDEBlockTransferScalarPerVector_NPerBlock,
|
||||
ck::BlockGemmPipelineScheduler BlkGemmPipeSched,
|
||||
ck::BlockGemmPipelineVersion BlkGemmPipelineVer,
|
||||
typename AComputeDataType,
|
||||
typename BComputeDataType>
|
||||
struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3;
|
||||
|
||||
} // namespace ck::tensor_operation::device
|
||||
|
||||
namespace ck_tile::reflect {
|
||||
|
||||
// Specialization for DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
|
||||
template <ck::index_t NDimSpatial,
|
||||
typename ALayout_,
|
||||
typename BLayout_,
|
||||
typename DsLayout_,
|
||||
typename ELayout_,
|
||||
typename ADataType_,
|
||||
typename BDataType_,
|
||||
typename AccDataType_,
|
||||
typename CShuffleDataType_,
|
||||
typename DsDataType_,
|
||||
typename EDataType_,
|
||||
typename AElementwiseOperation_,
|
||||
typename BElementwiseOperation_,
|
||||
typename CDEElementwiseOperation_,
|
||||
ck::tensor_operation::device::ConvolutionForwardSpecialization ConvForwardSpecialization,
|
||||
ck::tensor_operation::device::GemmSpecialization GemmSpec,
|
||||
ck::index_t BlockSize,
|
||||
ck::index_t MPerBlock,
|
||||
ck::index_t NPerBlock,
|
||||
ck::index_t KPerBlock,
|
||||
ck::index_t AK1,
|
||||
ck::index_t BK1,
|
||||
ck::index_t MPerXDL,
|
||||
ck::index_t NPerXDL,
|
||||
ck::index_t MXdlPerWave,
|
||||
ck::index_t NXdlPerWave,
|
||||
typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
|
||||
typename ABlockTransferThreadClusterArrangeOrder,
|
||||
typename ABlockTransferSrcAccessOrder_,
|
||||
ck::index_t ABlockTransferSrcVectorDim,
|
||||
ck::index_t ABlockTransferSrcScalarPerVector,
|
||||
ck::index_t ABlockTransferDstScalarPerVector_AK1,
|
||||
ck::index_t ABlockLdsExtraM,
|
||||
typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
|
||||
typename BBlockTransferThreadClusterArrangeOrder,
|
||||
typename BBlockTransferSrcAccessOrder_,
|
||||
ck::index_t BBlockTransferSrcVectorDim,
|
||||
ck::index_t BBlockTransferSrcScalarPerVector,
|
||||
ck::index_t BBlockTransferDstScalarPerVector_BK1,
|
||||
ck::index_t BBlockLdsExtraN,
|
||||
ck::index_t CShuffleMXdlPerWavePerShuffle,
|
||||
ck::index_t CShuffleNXdlPerWavePerShuffle,
|
||||
typename CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
ck::index_t CDEBlockTransferScalarPerVector_NPerBlock,
|
||||
ck::BlockGemmPipelineScheduler BlkGemmPipeSched,
|
||||
ck::BlockGemmPipelineVersion BlkGemmPipelineVer,
|
||||
typename AComputeDataType_,
|
||||
typename BComputeDataType_>
|
||||
struct InstanceTraits<ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<
|
||||
NDimSpatial,
|
||||
ALayout_,
|
||||
BLayout_,
|
||||
DsLayout_,
|
||||
ELayout_,
|
||||
ADataType_,
|
||||
BDataType_,
|
||||
AccDataType_,
|
||||
CShuffleDataType_,
|
||||
DsDataType_,
|
||||
EDataType_,
|
||||
AElementwiseOperation_,
|
||||
BElementwiseOperation_,
|
||||
CDEElementwiseOperation_,
|
||||
ConvForwardSpecialization,
|
||||
GemmSpec,
|
||||
BlockSize,
|
||||
MPerBlock,
|
||||
NPerBlock,
|
||||
KPerBlock,
|
||||
AK1,
|
||||
BK1,
|
||||
MPerXDL,
|
||||
NPerXDL,
|
||||
MXdlPerWave,
|
||||
NXdlPerWave,
|
||||
ABlockTransferThreadClusterLengths_AK0_M_AK1,
|
||||
ABlockTransferThreadClusterArrangeOrder,
|
||||
ABlockTransferSrcAccessOrder_,
|
||||
ABlockTransferSrcVectorDim,
|
||||
ABlockTransferSrcScalarPerVector,
|
||||
ABlockTransferDstScalarPerVector_AK1,
|
||||
ABlockLdsExtraM,
|
||||
BBlockTransferThreadClusterLengths_BK0_N_BK1,
|
||||
BBlockTransferThreadClusterArrangeOrder,
|
||||
BBlockTransferSrcAccessOrder_,
|
||||
BBlockTransferSrcVectorDim,
|
||||
BBlockTransferSrcScalarPerVector,
|
||||
BBlockTransferDstScalarPerVector_BK1,
|
||||
BBlockLdsExtraN,
|
||||
CShuffleMXdlPerWavePerShuffle,
|
||||
CShuffleNXdlPerWavePerShuffle,
|
||||
CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
CDEBlockTransferScalarPerVector_NPerBlock,
|
||||
BlkGemmPipeSched,
|
||||
BlkGemmPipelineVer,
|
||||
AComputeDataType_,
|
||||
BComputeDataType_>>
|
||||
{
|
||||
// Spatial dimension
|
||||
static constexpr int kSpatialDim = NDimSpatial;
|
||||
|
||||
// Layout types
|
||||
using ALayout = ALayout_;
|
||||
using BLayout = BLayout_;
|
||||
using DsLayout = DsLayout_;
|
||||
using ELayout = ELayout_;
|
||||
|
||||
// Data types
|
||||
using ADataType = ADataType_;
|
||||
using BDataType = BDataType_;
|
||||
using AccDataType = AccDataType_;
|
||||
using CShuffleDataType = CShuffleDataType_;
|
||||
using DsDataType = DsDataType_;
|
||||
using EDataType = EDataType_;
|
||||
|
||||
// Element-wise operations
|
||||
using AElementwiseOperation = AElementwiseOperation_;
|
||||
using BElementwiseOperation = BElementwiseOperation_;
|
||||
using CDEElementwiseOperation = CDEElementwiseOperation_;
|
||||
|
||||
// Specialization
|
||||
static constexpr ck::tensor_operation::device::ConvolutionForwardSpecialization
|
||||
kConvForwardSpecialization = ConvForwardSpecialization;
|
||||
static constexpr ck::tensor_operation::device::GemmSpecialization kGemmSpecialization =
|
||||
GemmSpec;
|
||||
|
||||
// Block configuration
|
||||
static constexpr int kBlockSize = BlockSize;
|
||||
static constexpr int kMPerBlock = MPerBlock;
|
||||
static constexpr int kNPerBlock = NPerBlock;
|
||||
static constexpr int kKPerBlock = KPerBlock;
|
||||
|
||||
// Tuning parameters
|
||||
static constexpr int kAK1 = AK1;
|
||||
static constexpr int kBK1 = BK1;
|
||||
static constexpr int kMPerXDL = MPerXDL;
|
||||
static constexpr int kNPerXDL = NPerXDL;
|
||||
static constexpr int kMXdlPerWave = MXdlPerWave;
|
||||
static constexpr int kNXdlPerWave = NXdlPerWave;
|
||||
|
||||
// A block transfer thread cluster dimensions (converted to std::array)
|
||||
static constexpr auto kAThreadClusterLengths =
|
||||
detail::SequenceToArray<ABlockTransferThreadClusterLengths_AK0_M_AK1>::value;
|
||||
static constexpr auto kAThreadClusterArrangeOrder =
|
||||
detail::SequenceToArray<ABlockTransferThreadClusterArrangeOrder>::value;
|
||||
static constexpr auto kABlockTransferSrcAccessOrder =
|
||||
detail::SequenceToArray<ABlockTransferSrcAccessOrder_>::value;
|
||||
static constexpr int kABlockTransferSrcVectorDim = ABlockTransferSrcVectorDim;
|
||||
static constexpr int kABlockTransferSrcScalarPerVector = ABlockTransferSrcScalarPerVector;
|
||||
static constexpr int kABlockTransferDstScalarPerVectorK1 = ABlockTransferDstScalarPerVector_AK1;
|
||||
static constexpr int kABlockLdsExtraM = ABlockLdsExtraM;
|
||||
|
||||
// B block transfer thread cluster dimensions (converted to std::array)
|
||||
static constexpr auto kBThreadClusterLengths =
|
||||
detail::SequenceToArray<BBlockTransferThreadClusterLengths_BK0_N_BK1>::value;
|
||||
static constexpr auto kBThreadClusterArrangeOrder =
|
||||
detail::SequenceToArray<BBlockTransferThreadClusterArrangeOrder>::value;
|
||||
static constexpr auto kBBlockTransferSrcAccessOrder =
|
||||
detail::SequenceToArray<BBlockTransferSrcAccessOrder_>::value;
|
||||
static constexpr int kBBlockTransferSrcVectorDim = BBlockTransferSrcVectorDim;
|
||||
static constexpr int kBBlockTransferSrcScalarPerVector = BBlockTransferSrcScalarPerVector;
|
||||
static constexpr int kBBlockTransferDstScalarPerVectorK1 = BBlockTransferDstScalarPerVector_BK1;
|
||||
static constexpr int kBBlockLdsExtraN = BBlockLdsExtraN;
|
||||
|
||||
// C shuffle parameters (converted to std::array)
|
||||
static constexpr int kCShuffleMXdlPerWavePerShuffle = CShuffleMXdlPerWavePerShuffle;
|
||||
static constexpr int kCShuffleNXdlPerWavePerShuffle = CShuffleNXdlPerWavePerShuffle;
|
||||
static constexpr auto kCThreadClusterLengths = detail::SequenceToArray<
|
||||
CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock>::value;
|
||||
static constexpr int kCBlockTransferScalarPerVector = CDEBlockTransferScalarPerVector_NPerBlock;
|
||||
|
||||
// Pipeline configuration
|
||||
static constexpr ck::BlockGemmPipelineScheduler kPipelineScheduler = BlkGemmPipeSched;
|
||||
static constexpr ck::BlockGemmPipelineVersion kPipelineVersion = BlkGemmPipelineVer;
|
||||
|
||||
// Compute data types
|
||||
using AComputeDataType = AComputeDataType_;
|
||||
using BComputeDataType = BComputeDataType_;
|
||||
|
||||
// Static member function to generate instance string
|
||||
static std::string instance_string()
|
||||
{
|
||||
std::ostringstream oss;
|
||||
|
||||
// Kernel type name
|
||||
oss << "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3";
|
||||
|
||||
// Template parameters in exact order matching InstanceTraits member order
|
||||
oss << "<" << kSpatialDim; // 1. NDimSpatial
|
||||
oss << "," << detail::layout_name<ALayout>(); // 2. ALayout
|
||||
oss << "," << detail::layout_name<BLayout>(); // 3. BLayout
|
||||
oss << "," << detail::tuple_name<DsLayout>(); // 4. DsLayout
|
||||
oss << "," << detail::layout_name<ELayout>(); // 5. ELayout
|
||||
oss << "," << detail::type_name<ADataType>(); // 6. ADataType
|
||||
oss << "," << detail::type_name<BDataType>(); // 7. BDataType
|
||||
oss << "," << detail::type_name<AccDataType>(); // 8. AccDataType
|
||||
oss << "," << detail::type_name<CShuffleDataType>(); // 9. CShuffleDataType
|
||||
oss << "," << detail::tuple_name<DsDataType>(); // 10. DsDataType
|
||||
oss << "," << detail::type_name<EDataType>(); // 11. EDataType
|
||||
oss << ","
|
||||
<< detail::elementwise_op_name<AElementwiseOperation>(); // 12. AElementwiseOperation
|
||||
oss << ","
|
||||
<< detail::elementwise_op_name<BElementwiseOperation>(); // 13. BElementwiseOperation
|
||||
oss << ","
|
||||
<< detail::elementwise_op_name<CDEElementwiseOperation>(); // 14.
|
||||
// CDEElementwiseOperation
|
||||
oss << ","
|
||||
<< detail::conv_fwd_spec_name(
|
||||
kConvForwardSpecialization); // 15. ConvForwardSpecialization
|
||||
oss << "," << detail::gemm_spec_name(kGemmSpecialization); // 16. GemmSpec
|
||||
oss << "," << kBlockSize; // 17. BlockSize
|
||||
oss << "," << kMPerBlock; // 18. MPerBlock
|
||||
oss << "," << kNPerBlock; // 19. NPerBlock
|
||||
oss << "," << kKPerBlock; // 20. KPerBlock
|
||||
oss << "," << kAK1; // 21. AK1
|
||||
oss << "," << kBK1; // 22. BK1
|
||||
oss << "," << kMPerXDL; // 23. MPerXDL
|
||||
oss << "," << kNPerXDL; // 24. NPerXDL
|
||||
oss << "," << kMXdlPerWave; // 25. MXdlPerWave
|
||||
oss << "," << kNXdlPerWave; // 26. NXdlPerWave
|
||||
oss << ","
|
||||
<< detail::array_to_string(
|
||||
kAThreadClusterLengths); // 27. ABlockTransferThreadClusterLengths
|
||||
oss << ","
|
||||
<< detail::array_to_string(
|
||||
kAThreadClusterArrangeOrder); // 28. ABlockTransferThreadClusterArrangeOrder
|
||||
oss << ","
|
||||
<< detail::array_to_string(
|
||||
kABlockTransferSrcAccessOrder); // 29. ABlockTransferSrcAccessOrder
|
||||
oss << "," << kABlockTransferSrcVectorDim; // 30. ABlockTransferSrcVectorDim
|
||||
oss << "," << kABlockTransferSrcScalarPerVector; // 31. ABlockTransferSrcScalarPerVector
|
||||
oss << ","
|
||||
<< kABlockTransferDstScalarPerVectorK1; // 32. ABlockTransferDstScalarPerVector_AK1
|
||||
oss << "," << kABlockLdsExtraM; // 33. ABlockLdsExtraM
|
||||
oss << ","
|
||||
<< detail::array_to_string(
|
||||
kBThreadClusterLengths); // 34. BBlockTransferThreadClusterLengths
|
||||
oss << ","
|
||||
<< detail::array_to_string(
|
||||
kBThreadClusterArrangeOrder); // 35. BBlockTransferThreadClusterArrangeOrder
|
||||
oss << ","
|
||||
<< detail::array_to_string(
|
||||
kBBlockTransferSrcAccessOrder); // 36. BBlockTransferSrcAccessOrder
|
||||
oss << "," << kBBlockTransferSrcVectorDim; // 37. BBlockTransferSrcVectorDim
|
||||
oss << "," << kBBlockTransferSrcScalarPerVector; // 38. BBlockTransferSrcScalarPerVector
|
||||
oss << ","
|
||||
<< kBBlockTransferDstScalarPerVectorK1; // 39. BBlockTransferDstScalarPerVector_BK1
|
||||
oss << "," << kBBlockLdsExtraN; // 40. BBlockLdsExtraN
|
||||
oss << "," << kCShuffleMXdlPerWavePerShuffle; // 41. CShuffleMXdlPerWavePerShuffle
|
||||
oss << "," << kCShuffleNXdlPerWavePerShuffle; // 42. CShuffleNXdlPerWavePerShuffle
|
||||
oss << ","
|
||||
<< detail::array_to_string(
|
||||
kCThreadClusterLengths); // 43. CDEBlockTransferClusterLengths
|
||||
oss << ","
|
||||
<< kCBlockTransferScalarPerVector; // 44. CDEBlockTransferScalarPerVector_NPerBlock
|
||||
oss << "," << detail::pipeline_scheduler_name(kPipelineScheduler); // 45. BlkGemmPipeSched
|
||||
oss << "," << detail::pipeline_version_name(kPipelineVersion); // 46. BlkGemmPipelineVer
|
||||
oss << "," << detail::type_name<AComputeDataType>(); // 47. AComputeDataType
|
||||
oss << "," << detail::type_name<BComputeDataType>(); // 48. BComputeDataType
|
||||
oss << ">";
|
||||
|
||||
return oss.str();
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile::reflect
|
||||
@@ -0,0 +1,195 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
// Utility functions and helpers for instance_traits.hpp
|
||||
// Contains helper functions to convert types, enums, and sequences to string representations.
|
||||
// The helper function are consteval so that unknown cases cause compile-time errors.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <array>
|
||||
#include <string>
|
||||
#include <string_view>
|
||||
#include <sstream>
|
||||
#include <type_traits>
|
||||
#include <ck/utility/data_type.hpp>
|
||||
#include <ck/utility/sequence.hpp>
|
||||
#include <ck/utility/blkgemmpipe_scheduler.hpp>
|
||||
#include <ck/tensor_operation/gpu/device/tensor_layout.hpp>
|
||||
#include <ck/tensor_operation/gpu/element/element_wise_operation.hpp>
|
||||
#include <ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp>
|
||||
#include <ck/tensor_operation/gpu/device/gemm_specialization.hpp>
|
||||
|
||||
namespace ck_tile::reflect::detail {
|
||||
|
||||
// Metaprogramming helper to convert ck::Sequence to constexpr std::array
|
||||
template <typename Seq>
|
||||
struct SequenceToArray;
|
||||
|
||||
template <ck::index_t... Is>
|
||||
struct SequenceToArray<ck::Sequence<Is...>>
|
||||
{
|
||||
static constexpr std::array<int, sizeof...(Is)> value = {static_cast<int>(Is)...};
|
||||
};
|
||||
|
||||
// Convert data types to string names
|
||||
template <typename T>
|
||||
consteval std::string_view type_name()
|
||||
{
|
||||
if constexpr(std::is_same_v<T, ck::half_t>)
|
||||
return "fp16";
|
||||
else if constexpr(std::is_same_v<T, float>)
|
||||
return "fp32";
|
||||
else if constexpr(std::is_same_v<T, double>)
|
||||
return "fp64";
|
||||
else if constexpr(std::is_same_v<T, int8_t>)
|
||||
return "s8";
|
||||
else if constexpr(std::is_same_v<T, int32_t>)
|
||||
return "s32";
|
||||
else if constexpr(std::is_same_v<T, ck::bhalf_t>)
|
||||
return "bf16";
|
||||
else if constexpr(std::is_same_v<T, ck::f8_t>)
|
||||
return "fp8";
|
||||
else if constexpr(std::is_same_v<T, ck::bf8_t>)
|
||||
return "bf8";
|
||||
else
|
||||
static_assert(false, "unknown_type");
|
||||
}
|
||||
|
||||
// Convert layout types to string names
|
||||
template <typename T>
|
||||
constexpr std::string_view layout_name()
|
||||
{
|
||||
// Convolution layouts
|
||||
if constexpr(std::is_same_v<T, ck::tensor_layout::convolution::GNHWC>)
|
||||
return "GNHWC";
|
||||
else if constexpr(std::is_same_v<T, ck::tensor_layout::convolution::GKYXC>)
|
||||
return "GKYXC";
|
||||
else if constexpr(std::is_same_v<T, ck::tensor_layout::convolution::GNHWK>)
|
||||
return "GNHWK";
|
||||
else if constexpr(std::is_same_v<T, ck::tensor_layout::convolution::GKZYXC>)
|
||||
return "GKZYXC";
|
||||
else if constexpr(std::is_same_v<T, ck::tensor_layout::convolution::GNDHWC>)
|
||||
return "GNDHWC";
|
||||
else if constexpr(std::is_same_v<T, ck::tensor_layout::convolution::GNDHWK>)
|
||||
return "GNDHWK";
|
||||
else if constexpr(std::is_same_v<T, ck::tensor_layout::convolution::NHWGC>)
|
||||
return "NHWGC";
|
||||
else if constexpr(std::is_same_v<T, ck::tensor_layout::convolution::KYXGC>)
|
||||
return "KYXGC";
|
||||
else if constexpr(std::is_same_v<T, ck::tensor_layout::convolution::NHWGK>)
|
||||
return "NHWGK";
|
||||
else
|
||||
static_assert(false, "unknown_layout");
|
||||
}
|
||||
|
||||
// Convert element-wise operation types to string names
|
||||
template <typename T>
|
||||
constexpr std::string_view elementwise_op_name()
|
||||
{
|
||||
if constexpr(std::is_same_v<T, ck::tensor_operation::element_wise::PassThrough>)
|
||||
return "PassThrough";
|
||||
else if constexpr(std::is_same_v<T, ck::tensor_operation::element_wise::Scale>)
|
||||
return "Scale";
|
||||
else if constexpr(std::is_same_v<T, ck::tensor_operation::element_wise::Bilinear>)
|
||||
return "Bilinear";
|
||||
else if constexpr(std::is_same_v<T, ck::tensor_operation::element_wise::Add>)
|
||||
return "Add";
|
||||
else if constexpr(std::is_same_v<T, ck::tensor_operation::element_wise::AddRelu>)
|
||||
return "AddRelu";
|
||||
else if constexpr(std::is_same_v<T, ck::tensor_operation::element_wise::Relu>)
|
||||
return "Relu";
|
||||
else
|
||||
static_assert(false, "unknown_op");
|
||||
}
|
||||
|
||||
// Convert ConvolutionForwardSpecialization enum to string
|
||||
constexpr std::string_view
|
||||
conv_fwd_spec_name(ck::tensor_operation::device::ConvolutionForwardSpecialization spec)
|
||||
{
|
||||
using ck::tensor_operation::device::ConvolutionForwardSpecialization;
|
||||
switch(spec)
|
||||
{
|
||||
case ConvolutionForwardSpecialization::Default: return "Default";
|
||||
case ConvolutionForwardSpecialization::Filter1x1Stride1Pad0: return "Filter1x1Stride1Pad0";
|
||||
case ConvolutionForwardSpecialization::Filter1x1Pad0: return "Filter1x1Pad0";
|
||||
case ConvolutionForwardSpecialization::Filter3x3: return "Filter3x3";
|
||||
case ConvolutionForwardSpecialization::OddC: return "OddC";
|
||||
}
|
||||
}
|
||||
|
||||
// Convert GemmSpecialization enum to string
|
||||
constexpr std::string_view gemm_spec_name(ck::tensor_operation::device::GemmSpecialization spec)
|
||||
{
|
||||
using ck::tensor_operation::device::GemmSpecialization;
|
||||
switch(spec)
|
||||
{
|
||||
case GemmSpecialization::Default: return "Default";
|
||||
case GemmSpecialization::MPadding: return "MPadding";
|
||||
case GemmSpecialization::NPadding: return "NPadding";
|
||||
case GemmSpecialization::KPadding: return "KPadding";
|
||||
case GemmSpecialization::MNPadding: return "MNPadding";
|
||||
case GemmSpecialization::MKPadding: return "MKPadding";
|
||||
case GemmSpecialization::NKPadding: return "NKPadding";
|
||||
case GemmSpecialization::MNKPadding: return "MNKPadding";
|
||||
case GemmSpecialization::OPadding: return "OPadding";
|
||||
case GemmSpecialization::MOPadding: return "MOPadding";
|
||||
case GemmSpecialization::NOPadding: return "NOPadding";
|
||||
case GemmSpecialization::KOPadding: return "KOPadding";
|
||||
case GemmSpecialization::MNOPadding: return "MNOPadding";
|
||||
case GemmSpecialization::MKOPadding: return "MKOPadding";
|
||||
case GemmSpecialization::NKOPadding: return "NKOPadding";
|
||||
case GemmSpecialization::MNKOPadding: return "MNKOPadding";
|
||||
}
|
||||
}
|
||||
|
||||
// Convert BlockGemmPipelineScheduler enum to string
|
||||
constexpr std::string_view pipeline_scheduler_name(ck::BlockGemmPipelineScheduler sched)
|
||||
{
|
||||
using ck::BlockGemmPipelineScheduler;
|
||||
switch(sched)
|
||||
{
|
||||
case BlockGemmPipelineScheduler::Intrawave: return "Intrawave";
|
||||
case BlockGemmPipelineScheduler::Interwave: return "Interwave";
|
||||
}
|
||||
}
|
||||
|
||||
// Convert BlockGemmPipelineVersion enum to string
|
||||
constexpr std::string_view pipeline_version_name(ck::BlockGemmPipelineVersion ver)
|
||||
{
|
||||
using ck::BlockGemmPipelineVersion;
|
||||
switch(ver)
|
||||
{
|
||||
case BlockGemmPipelineVersion::v1: return "v1";
|
||||
case BlockGemmPipelineVersion::v2: return "v2";
|
||||
case BlockGemmPipelineVersion::v3: return "v3";
|
||||
case BlockGemmPipelineVersion::v4: return "v4";
|
||||
case BlockGemmPipelineVersion::v5: return "v5";
|
||||
}
|
||||
}
|
||||
|
||||
// Convert std::array to string
|
||||
template <typename T, std::size_t N>
|
||||
inline std::string array_to_string(const std::array<T, N>& arr)
|
||||
{
|
||||
std::ostringstream oss;
|
||||
oss << "Seq(";
|
||||
for(std::size_t i = 0; i < arr.size(); ++i)
|
||||
{
|
||||
if(i > 0)
|
||||
oss << ",";
|
||||
oss << arr[i];
|
||||
}
|
||||
oss << ")";
|
||||
return oss.str();
|
||||
}
|
||||
|
||||
// Handle ck::Tuple (empty tuple for DsLayout/DsDataType)
|
||||
template <typename T>
|
||||
constexpr std::string_view tuple_name()
|
||||
{
|
||||
// For now, just check if it's an empty tuple
|
||||
return "EmptyTuple";
|
||||
}
|
||||
|
||||
} // namespace ck_tile::reflect::detail
|
||||
Reference in New Issue
Block a user