[rocm-libraries] ROCm/rocm-libraries#5284 (commit 76b5b15)

[CK_BUILDER] Add
 DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3 to CK Builder (#5284)

Add factory, InstanceTraits, and conv traits support for the WMMA V3
forward convolution kernel, enabling the CK Builder to generate and
dispatch this kernel variant used by MIOpen on gfx11/gfx12 GPUs.

## Motivation

As reported in issue #4944, MIOpen includes WMMA V3 forward convolution
kernels, so this PR adds support for those kernels similarly to other
supported kernels.

## Technical Details

This follows the same implementation pattern as the other kernels. I added
some support for reflection, but I left a few todos since we need to
generalize our convolution traits across WMMA/MFMA and CK/CKTile.

## Test Plan

Added faster tests to `ninja smoke-builder` that check the
instance-traits logic, and I added longer tests that instantiate
kernels, following the existing pattern in other kernels.

## Test Result

I tested all code with `ninja check-builder` on a gfx1101 build and ran
on gfx1101.

Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
John Shumway
2026-03-10 23:43:03 +00:00
committed by assistant-librarian[bot]
parent 26d29374e5
commit 9f47b8a63d
15 changed files with 916 additions and 0 deletions

View File

@@ -76,6 +76,13 @@ concept FwdXdlV3Algorithm =
SpecifiesGridwiseFwdXdlGemm<T> && SpecifiesFwdConvSpecialization<T> &&
SpecifiesGemmSpecialization<T> && SpecifiesBlockGemm<T> && SpecifiesNumGroupsToMerge<T>;
// FWD WMMA V3 algorithm concept.
// Satisfied by algorithm descriptors that carry everything the
// DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3 factory consumes: a thread
// block, 3D tile-transfer parameters, a gridwise WMMA GEMM configuration,
// forward-conv and GEMM specializations, a block-GEMM pipeline, and the
// number of conv groups to merge.
template <typename T>
concept FwdWmmaV3Algorithm =
ConvAlgorithmDescriptor<T> && SpecifiesThreadBlock<T> && SpecifiesTileTransferParameters3D<T> &&
SpecifiesGridwiseWmmaGemm<T> && SpecifiesFwdConvSpecialization<T> &&
SpecifiesGemmSpecialization<T> && SpecifiesBlockGemm<T> && SpecifiesNumGroupsToMerge<T>;
// FWD WMMA algorithm concepts
template <typename T>
concept FwdWmmaAlgorithm =

View File

@@ -64,6 +64,7 @@
#include "ck_tile/builder/factory/conv_fwd_v3_factory.hpp"
#include "ck_tile/builder/factory/conv_fwd_xdl_factory.hpp"
#include "ck_tile/builder/factory/conv_fwd_wmma_factory.hpp"
#include "ck_tile/builder/factory/conv_fwd_wmma_v3_factory.hpp"
#include "ck_tile/builder/factory/conv_fwd_dl_factory.hpp"
#include "ck_tile/builder/factory/conv_fwd_large_tensor_factory.hpp"
#include "ck_tile/builder/factory/reference_factory.hpp"
@@ -130,6 +131,10 @@ constexpr auto make_conv_instance()
{
return typename ConvFwdXdlFactory<SIGNATURE, ALGORITHM, VERSION>::Instance{};
}
else if constexpr(FwdWmmaV3Algorithm<AlgoType>)
{
return typename ConvFwdWmmaV3Factory<SIGNATURE, ALGORITHM, VERSION>::Instance{};
}
else if constexpr(FwdWmmaAlgorithm<AlgoType>)
{
return typename ConvFwdWmmaFactory<SIGNATURE, ALGORITHM, VERSION>::Instance{};

View File

@@ -0,0 +1,159 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_wmma_cshuffle_v3.hpp"
#include "ck_tile/builder/conv_signature_concepts.hpp"
#include "ck_tile/builder/conv_algorithm_concepts.hpp"
#include "ck_tile/builder/conv_algorithm_limits.hpp"
#include "ck_tile/builder/builder_utils.hpp"
#include "ck_tile/builder/factory/helpers/ck/conv_tensor_layout.hpp"
#include "ck_tile/builder/factory/helpers/ck/conv_tensor_type.hpp"
#include "ck_tile/builder/factory/helpers/ck/conv_elementwise_op.hpp"
#include "ck_tile/builder/factory/helpers/ck/conv_tuning_params.hpp"
#include "ck_tile/builder/factory/helpers/ck/conv_block_transfer.hpp"
#include "ck_tile/builder/factory/helpers/ck/conv_thread_block.hpp"
namespace ck_tile::builder::factory {
// Factory for DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3 instance
// of a grouped forward convolution kernel.
//
// SIGNATURE describes the problem (spatial dim, tensor layouts, data types,
// elementwise ops); ALGORITHM describes the tuning choices (thread block,
// block transfers, gridwise/block GEMM, group merging). The concrete device
// kernel type is exposed as the nested `Instance` alias.
template <ConvSignatureDescriptor auto SIGNATURE,
ConvAlgorithmDescriptor auto ALGORITHM,
StringLiteral VERSION>
requires ConvDirectionIsForward<SIGNATURE>
struct ConvFwdWmmaV3Factory
{
static constexpr size_t SPATIAL_DIM = SIGNATURE.spatial_dim;
// Bundles derived from the signature: tensor layouts, tensor data types, and
// elementwise operations for input (A), weight (B), and output (E).
using Layouts = internal::ConvTensorLayouts<SIGNATURE>;
using Types = internal::ConvTensorDataTypes<SIGNATURE>;
using Ops = internal::ConvElementwiseOps<SIGNATURE>;
using AlgorithmType = decltype(ALGORITHM);
// Direct-load must be chosen consistently for both LDS transfers.
static_assert(ALGORITHM.transfer.a.lds_transfer.is_direct_load ==
ALGORITHM.transfer.b.lds_transfer.is_direct_load,
"A and B block transfers must both be direct load or not.");
// Specializations selected from the algorithm descriptor.
static constexpr auto FWD_CONV_SPECIALIZATION = internal::SetFwdConvSpecialization<ALGORITHM>();
static constexpr auto GEMM_SPECIALIZATION = internal::SetGemmSpecialization<ALGORITHM>();
static constexpr internal::ConvSpec SPECIALIZATION{.conv_spec = FWD_CONV_SPECIALIZATION,
.gemm_spec = GEMM_SPECIALIZATION};
// Thread-block, gridwise-GEMM, and per-tensor block-transfer parameters.
static constexpr auto BLOCK = internal::SetThreadBlockInfo<ALGORITHM>();
static constexpr auto GRIDWISE_GEMM = ALGORITHM.gridwise_gemm;
static constexpr auto A_BLOCK_TRANSFER =
internal::SetFwdConvBlockTransfer<ALGORITHM.transfer.a>();
static constexpr auto B_BLOCK_TRANSFER =
internal::SetFwdConvBlockTransfer<ALGORITHM.transfer.b>();
static constexpr auto C_BLOCK_TRANSFER = internal::SetCBlockTransfer<SIGNATURE, ALGORITHM>();
static constexpr auto BLOCK_GEMM = internal::SetBlockGemm<ALGORITHM>();
// Check limits for the algorithm parameters.
static_assert(ValidABlockTransfer<A_BLOCK_TRANSFER,
Types::input_types.first,
sizeof(typename Types::InDataType),
BLOCK.block_size,
BLOCK.per_block>);
static_assert(ValidBBlockTransfer<B_BLOCK_TRANSFER,
Types::weight_types.first,
sizeof(typename Types::WeiDataType),
BLOCK.block_size,
BLOCK.per_block>);
static_assert(ValidCBlockTransfer<C_BLOCK_TRANSFER,
Types::output_types.first,
BLOCK.block_size,
BLOCK.per_block>);
// Layout validations
// Input/weight checks also pin the source vector read to dimension 2
// (presumably the contiguous C dimension for these channel-last layouts —
// TODO(review): confirm against the block-transfer helpers).
using enum TensorLayout;
static_assert(IsValidLayout<SIGNATURE.input.config.layout,
G_NW_C_strided,
G_NHW_C_strided,
G_NDHW_C_strided,
GNWC,
GNHWC,
GNDHWC,
NWGC,
NHWGC,
NDHWGC> &&
A_BLOCK_TRANSFER.src_vector_dim == 2);
static_assert(IsValidLayout<SIGNATURE.weight.config.layout,
G_K_X_C_strided,
G_K_YX_C_strided,
G_K_ZYX_C_strided,
GKXC,
GKYXC,
GKZYXC,
KXGC,
KYXGC,
KZYXGC> &&
B_BLOCK_TRANSFER.src_vector_dim == 2);
static_assert(IsValidLayout<SIGNATURE.output.config.layout,
G_NW_K_strided,
G_NHW_K_strided,
G_NDHW_K_strided,
GNWK,
GNHWK,
GNDHWK,
NWGK,
NHWGK,
NDHWGK>);
// The forward convolution kernel class instance.
// Argument order below must exactly match the template parameter list of
// DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3 in the device header.
// NOTE(review): the C_BLOCK_TRANSFER fields are named "xdl" but feed the
// WMMA kernel's CShuffle M/N-repeat-per-shuffle parameters here.
using Instance = ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3<
SPATIAL_DIM,
typename Layouts::InLayout,
typename Layouts::WeiLayout,
typename Layouts::DsLayout,
typename Layouts::OutLayout,
typename Types::InDataType,
typename Types::WeiDataType,
typename Types::AccDataType,
typename Types::OutComputeType,
typename Types::DsDataType,
typename Types::OutDataType,
typename Ops::InElementwiseOp,
typename Ops::WeiElementwiseOp,
typename Ops::OutElementwiseOp,
SPECIALIZATION.conv_spec,
SPECIALIZATION.gemm_spec,
BLOCK.block_size,
BLOCK.per_block.m,
BLOCK.per_block.n,
BLOCK.per_block.k,
GRIDWISE_GEMM.ak1,
GRIDWISE_GEMM.bk1,
GRIDWISE_GEMM.m_per_wmma,
GRIDWISE_GEMM.n_per_wmma,
GRIDWISE_GEMM.m_wmma_per_wave,
GRIDWISE_GEMM.n_wmma_per_wave,
to_sequence_v<A_BLOCK_TRANSFER.thread_cluster_dims>,
to_sequence_v<A_BLOCK_TRANSFER.thread_cluster_order>,
to_sequence_v<A_BLOCK_TRANSFER.src_access_order>,
A_BLOCK_TRANSFER.src_vector_dim,
A_BLOCK_TRANSFER.src_scalar_per_vector,
A_BLOCK_TRANSFER.lds_dst_scalar_per_vector,
static_cast<ck::index_t>(A_BLOCK_TRANSFER.lds_padding),
to_sequence_v<B_BLOCK_TRANSFER.thread_cluster_dims>,
to_sequence_v<B_BLOCK_TRANSFER.thread_cluster_order>,
to_sequence_v<B_BLOCK_TRANSFER.src_access_order>,
B_BLOCK_TRANSFER.src_vector_dim,
B_BLOCK_TRANSFER.src_scalar_per_vector,
B_BLOCK_TRANSFER.lds_dst_scalar_per_vector,
static_cast<ck::index_t>(B_BLOCK_TRANSFER.lds_padding),
C_BLOCK_TRANSFER.m_xdl_per_wave_per_shuffle,
C_BLOCK_TRANSFER.n_xdl_per_wave_per_shuffle,
to_sequence_v<C_BLOCK_TRANSFER.thread_cluster_dims>,
C_BLOCK_TRANSFER.scalar_per_vector,
BLOCK_GEMM.scheduler,
BLOCK_GEMM.pipeline_version,
true, // UseThreadTileTransfer
typename Types::InComputeType,
typename Types::WeiComputeType,
ALGORITHM.num_conv_groups_to_merge>;
};
} // namespace ck_tile::builder::factory

View File

@@ -0,0 +1,48 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include <concepts>
#include "ck_tile/builder/reflect/conv_traits.hpp"
#include "ck_tile/builder/reflect/conv_traits_helpers.hpp"
#include "ck_tile/builder/reflect/instance_traits.hpp"
#include "ck_tile/builder/reflect/instance_traits_device_grouped_conv_fwd_multiple_abd_wmma_cshuffle_v3.hpp"
namespace ck_tile::reflect::conv {
/// @brief Tag-dispatch overload converting a
/// DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3 instance into ConvTraits.
/// Selected when the instance's InstanceTraits exposes the matching kernel tag.
template <typename Instance>
requires HasInstanceTraits<Instance> &&
std::same_as<typename InstanceTraits<Instance>::device_kernel_tag,
DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3_Tag>
constexpr ConvTraits instance_to_conv_traits()
{
    using Traits = InstanceTraits<Instance>;
    // Designated initializers must follow ConvTraits' member declaration order.
    return ConvTraits{
        .spatial_dim       = Traits::kSpatialDim,
        .direction         = conv_direction<Instance>(),
        .layout            = fwd_conv_layout<Instance>(),
        .data_type         = conv_data_type<typename Traits::ADataType>(),
        .input_element_op  = elementwise_op<typename Traits::AElementwiseOperation>(),
        .weight_element_op = elementwise_op<typename Traits::BElementwiseOperation>(),
        .output_element_op = elementwise_op<typename Traits::CDEElementwiseOperation>(),
        .gemm_padding        = gemm_spec<Instance>(),
        .conv_specialization = conv_spec<Instance>(),
        .thread_block_size   = Traits::kBlockSize,
        .tile_dims           = conv_traits_data_tile<Traits>(),
        .a_tile_transfer     = conv_traits_a_transfer_params<Traits>(Traits::kAK1),
        .b_tile_transfer     = conv_traits_b_transfer_params<Traits>(Traits::kBK1),
        .warp_gemm           = conv_traits_wmma_warp_gemm_params<Traits>(),
        .c_tile_transfer     = conv_traits_wmma_c_tile_transfer<Traits>(),
        // TODO: Add compute types (AComputeDataType, BComputeDataType) when
        // ConvTraits supports them
        // TODO: Add NumGroupsToMerge when ConvTraits supports it
        .pipeline_version   = get_pipeline_version<Traits>(),
        .pipeline_scheduler = get_pipeline_scheduler<Traits>(),
    };
}
} // namespace ck_tile::reflect::conv

View File

@@ -8,6 +8,7 @@
#include "ck_tile/builder/reflect/conv_traits_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp"
#include "ck_tile/builder/reflect/conv_traits_device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp"
#include "ck_tile/builder/reflect/conv_traits_device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp"
#include "ck_tile/builder/reflect/conv_traits_device_grouped_conv_fwd_multiple_abd_wmma_cshuffle_v3.hpp"
// Bwd weight instances
#include "ck_tile/builder/reflect/conv_traits_device_grouped_conv_bwd_weight_xdl_cshuffle.hpp"

View File

@@ -0,0 +1,16 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include "instance_traits.hpp"
#include "instance_traits_util.hpp"
namespace ck_tile::reflect {
/// @brief Tag type for DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3 device kernel.
/// Intentionally empty: it is exposed by the kernel's InstanceTraits as
/// `device_kernel_tag` and exists purely to drive tag dispatch (e.g. selecting
/// the matching instance_to_conv_traits overload).
struct DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3_Tag
{
};
} // namespace ck_tile::reflect

View File

@@ -0,0 +1,302 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
// InstanceTraits specialization for DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3
//
// This .inc file is #included at the bottom of the device op header
// (device_grouped_conv_fwd_multiple_abd_wmma_cshuffle_v3.hpp) under
// #ifdef CK_EXPERIMENTAL_BUILDER, AFTER the struct is fully defined.
// This eliminates the need for forward declarations.
//
// CRITICAL MAINTENANCE NOTE:
// This file MUST be kept strictly in sync with the device implementation header:
// ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_wmma_cshuffle_v3.hpp
// The template parameter order, names, and types MUST EXACTLY MATCH those in the device
// implementation.
#include "ck_tile/builder/reflect/instance_traits_device_grouped_conv_fwd_multiple_abd_wmma_cshuffle_v3.hpp"
namespace ck_tile::reflect {
// Specialization for DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3
// The template parameter list below is a 1:1 mirror of the device op's own
// parameter list; names/order must stay in sync with the device header.
template <ck::index_t NDimSpatial,
typename ALayout_,
typename BLayout_,
typename DsLayout_,
typename ELayout_,
typename ADataType_,
typename BDataType_,
typename AccDataType_,
typename CShuffleDataType_,
typename DsDataType_,
typename EDataType_,
typename AElementwiseOperation_,
typename BElementwiseOperation_,
typename CDEElementwiseOperation_,
ck::tensor_operation::device::ConvolutionForwardSpecialization ConvForwardSpecialization,
ck::tensor_operation::device::GemmSpecialization GemmSpec,
ck::index_t BlockSize,
ck::index_t MPerBlock,
ck::index_t NPerBlock,
ck::index_t KPerBlock,
ck::index_t AK1,
ck::index_t BK1,
ck::index_t MPerWmma,
ck::index_t NPerWmma,
ck::index_t MRepeat,
ck::index_t NRepeat,
typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
typename ABlockTransferThreadClusterArrangeOrder,
typename ABlockTransferSrcAccessOrder_,
ck::index_t ABlockTransferSrcVectorDim,
ck::index_t ABlockTransferSrcScalarPerVector,
ck::index_t ABlockTransferDstScalarPerVector_AK1,
ck::index_t ABlockLdsExtraM,
typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
typename BBlockTransferThreadClusterArrangeOrder,
typename BBlockTransferSrcAccessOrder_,
ck::index_t BBlockTransferSrcVectorDim,
ck::index_t BBlockTransferSrcScalarPerVector,
ck::index_t BBlockTransferDstScalarPerVector_BK1,
ck::index_t BBlockLdsExtraN,
ck::index_t CShuffleMRepeatPerShuffle,
ck::index_t CShuffleNRepeatPerShuffle,
typename CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
ck::index_t CDEBlockTransferScalarPerVector_NPerBlock,
ck::BlockGemmPipelineScheduler BlkGemmPipeSched,
ck::BlockGemmPipelineVersion BlkGemmPipelineVer,
bool UseThreadTileTransfer,
typename AComputeDataType_,
typename BComputeDataType_,
ck::index_t NumGroupsToMerge>
struct InstanceTraits<
ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3<
NDimSpatial,
ALayout_,
BLayout_,
DsLayout_,
ELayout_,
ADataType_,
BDataType_,
AccDataType_,
CShuffleDataType_,
DsDataType_,
EDataType_,
AElementwiseOperation_,
BElementwiseOperation_,
CDEElementwiseOperation_,
ConvForwardSpecialization,
GemmSpec,
BlockSize,
MPerBlock,
NPerBlock,
KPerBlock,
AK1,
BK1,
MPerWmma,
NPerWmma,
MRepeat,
NRepeat,
ABlockTransferThreadClusterLengths_AK0_M_AK1,
ABlockTransferThreadClusterArrangeOrder,
ABlockTransferSrcAccessOrder_,
ABlockTransferSrcVectorDim,
ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_AK1,
ABlockLdsExtraM,
BBlockTransferThreadClusterLengths_BK0_N_BK1,
BBlockTransferThreadClusterArrangeOrder,
BBlockTransferSrcAccessOrder_,
BBlockTransferSrcVectorDim,
BBlockTransferSrcScalarPerVector,
BBlockTransferDstScalarPerVector_BK1,
BBlockLdsExtraN,
CShuffleMRepeatPerShuffle,
CShuffleNRepeatPerShuffle,
CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
CDEBlockTransferScalarPerVector_NPerBlock,
BlkGemmPipeSched,
BlkGemmPipelineVer,
UseThreadTileTransfer,
AComputeDataType_,
BComputeDataType_,
NumGroupsToMerge>>
{
/// @brief Tag type identifying this device kernel variant
using device_kernel_tag = DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3_Tag;
// Spatial dimension
static constexpr int kSpatialDim = NDimSpatial;
// Layout types
using ALayout = ALayout_;
using BLayout = BLayout_;
using DsLayout = DsLayout_;
using ELayout = ELayout_;
// Data types
using ADataType = ADataType_;
using BDataType = BDataType_;
using AccDataType = AccDataType_;
using CShuffleDataType = CShuffleDataType_;
using DsDataType = DsDataType_;
using EDataType = EDataType_;
// Element-wise operations
using AElementwiseOperation = AElementwiseOperation_;
using BElementwiseOperation = BElementwiseOperation_;
using CDEElementwiseOperation = CDEElementwiseOperation_;
// Specialization
static constexpr ck::tensor_operation::device::ConvolutionForwardSpecialization
kConvForwardSpecialization = ConvForwardSpecialization;
static constexpr ck::tensor_operation::device::GemmSpecialization kGemmSpecialization =
GemmSpec;
// Block configuration
static constexpr int kBlockSize = BlockSize;
static constexpr int kMPerBlock = MPerBlock;
static constexpr int kNPerBlock = NPerBlock;
static constexpr int kKPerBlock = KPerBlock;
// Tuning parameters
static constexpr int kAK1 = AK1;
static constexpr int kBK1 = BK1;
static constexpr int kMPerWmma = MPerWmma;
static constexpr int kNPerWmma = NPerWmma;
static constexpr int kMRepeat = MRepeat;
static constexpr int kNRepeat = NRepeat;
// A block transfer thread cluster dimensions (converted to std::array)
static constexpr auto kAThreadClusterLengths =
detail::SequenceToArray<ABlockTransferThreadClusterLengths_AK0_M_AK1>::value;
static constexpr auto kAThreadClusterArrangeOrder =
detail::SequenceToArray<ABlockTransferThreadClusterArrangeOrder>::value;
static constexpr auto kABlockTransferSrcAccessOrder =
detail::SequenceToArray<ABlockTransferSrcAccessOrder_>::value;
static constexpr int kABlockTransferSrcVectorDim = ABlockTransferSrcVectorDim;
static constexpr int kABlockTransferSrcScalarPerVector = ABlockTransferSrcScalarPerVector;
static constexpr int kABlockTransferDstScalarPerVectorK1 = ABlockTransferDstScalarPerVector_AK1;
static constexpr int kABlockLdsExtraM = ABlockLdsExtraM;
// B block transfer thread cluster dimensions (converted to std::array)
static constexpr auto kBThreadClusterLengths =
detail::SequenceToArray<BBlockTransferThreadClusterLengths_BK0_N_BK1>::value;
static constexpr auto kBThreadClusterArrangeOrder =
detail::SequenceToArray<BBlockTransferThreadClusterArrangeOrder>::value;
static constexpr auto kBBlockTransferSrcAccessOrder =
detail::SequenceToArray<BBlockTransferSrcAccessOrder_>::value;
static constexpr int kBBlockTransferSrcVectorDim = BBlockTransferSrcVectorDim;
static constexpr int kBBlockTransferSrcScalarPerVector = BBlockTransferSrcScalarPerVector;
static constexpr int kBBlockTransferDstScalarPerVectorK1 = BBlockTransferDstScalarPerVector_BK1;
static constexpr int kBBlockLdsExtraN = BBlockLdsExtraN;
// C shuffle parameters (converted to std::array)
static constexpr int kCShuffleMRepeatPerShuffle = CShuffleMRepeatPerShuffle;
static constexpr int kCShuffleNRepeatPerShuffle = CShuffleNRepeatPerShuffle;
static constexpr auto kCDEThreadClusterLengths = detail::SequenceToArray<
CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock>::value;
static constexpr int kCDEBlockTransferScalarPerVector =
CDEBlockTransferScalarPerVector_NPerBlock;
// Pipeline configuration
static constexpr ck::BlockGemmPipelineScheduler kPipelineScheduler = BlkGemmPipeSched;
static constexpr ck::BlockGemmPipelineVersion kPipelineVersion = BlkGemmPipelineVer;
static constexpr bool kUseThreadTileTransfer = UseThreadTileTransfer;
// Compute data types
using AComputeDataType = AComputeDataType_;
using BComputeDataType = BComputeDataType_;
static constexpr int kNumGroupsToMerge = NumGroupsToMerge;
// Static member function to generate instance string
// Reconstructs the full template-argument spelling of this instance (50
// parameters, numbered in-order below). Used for human-readable reflection;
// the output must track the device op's parameter order exactly.
static std::string instance_string()
{
std::ostringstream oss;
// Kernel type name
oss << "DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3";
// Template parameters in exact order
oss << "<" << kSpatialDim; // 1. NDimSpatial
oss << "," << detail::layout_name<ALayout>(); // 2. ALayout
oss << "," << detail::layout_name<BLayout>(); // 3. BLayout
oss << "," << detail::tuple_name<DsLayout>(); // 4. DsLayout
oss << "," << detail::layout_name<ELayout>(); // 5. ELayout
oss << "," << detail::type_or_type_tuple_name<ADataType>(); // 6. ADataType
oss << "," << detail::type_or_type_tuple_name<BDataType>(); // 7. BDataType
oss << "," << detail::type_name<AccDataType>(); // 8. AccDataType
oss << "," << detail::type_name<CShuffleDataType>(); // 9. CShuffleDataType
oss << "," << detail::tuple_name<DsDataType>(); // 10. DsDataType
oss << "," << detail::type_name<EDataType>(); // 11. EDataType
oss << ","
<< detail::elementwise_op_name<AElementwiseOperation>(); // 12. AElementwiseOperation
oss << ","
<< detail::elementwise_op_name<BElementwiseOperation>(); // 13. BElementwiseOperation
oss << ","
<< detail::elementwise_op_name<CDEElementwiseOperation>(); // 14.
// CDEElementwiseOperation
oss << ","
<< detail::conv_fwd_spec_name(
kConvForwardSpecialization); // 15. ConvForwardSpecialization
oss << "," << detail::gemm_spec_name(kGemmSpecialization); // 16. GemmSpec
oss << "," << kBlockSize; // 17. BlockSize
oss << "," << kMPerBlock; // 18. MPerBlock
oss << "," << kNPerBlock; // 19. NPerBlock
oss << "," << kKPerBlock; // 20. KPerBlock
oss << "," << kAK1; // 21. AK1
oss << "," << kBK1; // 22. BK1
oss << "," << kMPerWmma; // 23. MPerWmma
oss << "," << kNPerWmma; // 24. NPerWmma
oss << "," << kMRepeat; // 25. MRepeat
oss << "," << kNRepeat; // 26. NRepeat
oss << ","
<< detail::array_to_string(
kAThreadClusterLengths); // 27. ABlockTransferThreadClusterLengths
oss << ","
<< detail::array_to_string(
kAThreadClusterArrangeOrder); // 28. ABlockTransferThreadClusterArrangeOrder
oss << ","
<< detail::array_to_string(
kABlockTransferSrcAccessOrder); // 29. ABlockTransferSrcAccessOrder
oss << "," << kABlockTransferSrcVectorDim; // 30. ABlockTransferSrcVectorDim
oss << "," << kABlockTransferSrcScalarPerVector; // 31. ABlockTransferSrcScalarPerVector
oss << ","
<< kABlockTransferDstScalarPerVectorK1; // 32. ABlockTransferDstScalarPerVector_AK1
oss << "," << (kABlockLdsExtraM ? "true" : "false"); // 33. ABlockLdsExtraM
oss << ","
<< detail::array_to_string(
kBThreadClusterLengths); // 34. BBlockTransferThreadClusterLengths
oss << ","
<< detail::array_to_string(
kBThreadClusterArrangeOrder); // 35. BBlockTransferThreadClusterArrangeOrder
oss << ","
<< detail::array_to_string(
kBBlockTransferSrcAccessOrder); // 36. BBlockTransferSrcAccessOrder
oss << "," << kBBlockTransferSrcVectorDim; // 37. BBlockTransferSrcVectorDim
oss << "," << kBBlockTransferSrcScalarPerVector; // 38. BBlockTransferSrcScalarPerVector
oss << ","
<< kBBlockTransferDstScalarPerVectorK1; // 39. BBlockTransferDstScalarPerVector_BK1
oss << "," << (kBBlockLdsExtraN ? "true" : "false"); // 40. BBlockLdsExtraN
oss << "," << kCShuffleMRepeatPerShuffle; // 41. CShuffleMRepeatPerShuffle
oss << "," << kCShuffleNRepeatPerShuffle; // 42. CShuffleNRepeatPerShuffle
oss << ","
<< detail::array_to_string(
kCDEThreadClusterLengths); // 43. CDEBlockTransferClusterLengths
oss << ","
<< kCDEBlockTransferScalarPerVector; // 44. CDEBlockTransferScalarPerVector_NPerBlock
oss << "," << detail::pipeline_scheduler_name(kPipelineScheduler); // 45. BlkGemmPipeSched
oss << "," << detail::pipeline_version_name(kPipelineVersion); // 46. BlkGemmPipelineVer
oss << "," << (kUseThreadTileTransfer ? "true" : "false"); // 47. UseThreadTileTransfer
oss << "," << detail::type_name<AComputeDataType>(); // 48. AComputeDataType
oss << "," << detail::type_name<BComputeDataType>(); // 49. BComputeDataType
oss << "," << kNumGroupsToMerge; // 50. NumGroupsToMerge
oss << ">";
return oss.str();
}
};
} // namespace ck_tile::reflect