mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-19 22:39:03 +00:00
[rocm-libraries] ROCm/rocm-libraries#5284 (commit 76b5b15)
[CK_BUILDER] Add DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3 to CK Builder (#5284) Add factory, InstanceTraits, and conv traits support for the WMMA V3 forward convolution kernel, enabling the CK Builder to generate and dispatch this kernel variant used by MIOpen on gfx11/gfx12 GPUs. ## Motivation As reported in issue #4944, MIOpen includes WMMA V3 forward convolution kernels, so this PR adds support for those kernels similarly to other supported kernels. ## Technical Details This follows the same implementation as the other kernels. I added some support for reflection, but I left a few todos since we need to generalize our convolution traits to generalize across WMMA/MFMA and CK/CKTile. ## Test Plan Added faster tests to `ninja smoke-builder` that check the instance-traits logic, and I added longer tests that instantiate kernels, following the existing pattern in other kernals. ## Test Result I tested all code with `ninja check-builder` on a gfx1101 build and ran on gfx1101. Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
committed by
assistant-librarian[bot]
parent
26d29374e5
commit
9f47b8a63d
@@ -76,6 +76,13 @@ concept FwdXdlV3Algorithm =
|
||||
SpecifiesGridwiseFwdXdlGemm<T> && SpecifiesFwdConvSpecialization<T> &&
|
||||
SpecifiesGemmSpecialization<T> && SpecifiesBlockGemm<T> && SpecifiesNumGroupsToMerge<T>;
|
||||
|
||||
// FWD WMMA V3 algorithm concept
|
||||
template <typename T>
|
||||
concept FwdWmmaV3Algorithm =
|
||||
ConvAlgorithmDescriptor<T> && SpecifiesThreadBlock<T> && SpecifiesTileTransferParameters3D<T> &&
|
||||
SpecifiesGridwiseWmmaGemm<T> && SpecifiesFwdConvSpecialization<T> &&
|
||||
SpecifiesGemmSpecialization<T> && SpecifiesBlockGemm<T> && SpecifiesNumGroupsToMerge<T>;
|
||||
|
||||
// FWD WMMA algorithm concepts
|
||||
template <typename T>
|
||||
concept FwdWmmaAlgorithm =
|
||||
|
||||
@@ -64,6 +64,7 @@
|
||||
#include "ck_tile/builder/factory/conv_fwd_v3_factory.hpp"
|
||||
#include "ck_tile/builder/factory/conv_fwd_xdl_factory.hpp"
|
||||
#include "ck_tile/builder/factory/conv_fwd_wmma_factory.hpp"
|
||||
#include "ck_tile/builder/factory/conv_fwd_wmma_v3_factory.hpp"
|
||||
#include "ck_tile/builder/factory/conv_fwd_dl_factory.hpp"
|
||||
#include "ck_tile/builder/factory/conv_fwd_large_tensor_factory.hpp"
|
||||
#include "ck_tile/builder/factory/reference_factory.hpp"
|
||||
@@ -130,6 +131,10 @@ constexpr auto make_conv_instance()
|
||||
{
|
||||
return typename ConvFwdXdlFactory<SIGNATURE, ALGORITHM, VERSION>::Instance{};
|
||||
}
|
||||
else if constexpr(FwdWmmaV3Algorithm<AlgoType>)
|
||||
{
|
||||
return typename ConvFwdWmmaV3Factory<SIGNATURE, ALGORITHM, VERSION>::Instance{};
|
||||
}
|
||||
else if constexpr(FwdWmmaAlgorithm<AlgoType>)
|
||||
{
|
||||
return typename ConvFwdWmmaFactory<SIGNATURE, ALGORITHM, VERSION>::Instance{};
|
||||
|
||||
@@ -0,0 +1,159 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_wmma_cshuffle_v3.hpp"
|
||||
#include "ck_tile/builder/conv_signature_concepts.hpp"
|
||||
#include "ck_tile/builder/conv_algorithm_concepts.hpp"
|
||||
#include "ck_tile/builder/conv_algorithm_limits.hpp"
|
||||
#include "ck_tile/builder/builder_utils.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/ck/conv_tensor_layout.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/ck/conv_tensor_type.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/ck/conv_elementwise_op.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/ck/conv_tuning_params.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/ck/conv_block_transfer.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/ck/conv_thread_block.hpp"
|
||||
|
||||
namespace ck_tile::builder::factory {
|
||||
|
||||
// Factory for DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3 instance
|
||||
// of a grouped forward convolution kernel.
|
||||
template <ConvSignatureDescriptor auto SIGNATURE,
|
||||
ConvAlgorithmDescriptor auto ALGORITHM,
|
||||
StringLiteral VERSION>
|
||||
requires ConvDirectionIsForward<SIGNATURE>
|
||||
struct ConvFwdWmmaV3Factory
|
||||
{
|
||||
static constexpr size_t SPATIAL_DIM = SIGNATURE.spatial_dim;
|
||||
using Layouts = internal::ConvTensorLayouts<SIGNATURE>;
|
||||
using Types = internal::ConvTensorDataTypes<SIGNATURE>;
|
||||
using Ops = internal::ConvElementwiseOps<SIGNATURE>;
|
||||
using AlgorithmType = decltype(ALGORITHM);
|
||||
|
||||
static_assert(ALGORITHM.transfer.a.lds_transfer.is_direct_load ==
|
||||
ALGORITHM.transfer.b.lds_transfer.is_direct_load,
|
||||
"A and B block transfers must both be direct load or not.");
|
||||
|
||||
static constexpr auto FWD_CONV_SPECIALIZATION = internal::SetFwdConvSpecialization<ALGORITHM>();
|
||||
static constexpr auto GEMM_SPECIALIZATION = internal::SetGemmSpecialization<ALGORITHM>();
|
||||
static constexpr internal::ConvSpec SPECIALIZATION{.conv_spec = FWD_CONV_SPECIALIZATION,
|
||||
.gemm_spec = GEMM_SPECIALIZATION};
|
||||
|
||||
static constexpr auto BLOCK = internal::SetThreadBlockInfo<ALGORITHM>();
|
||||
static constexpr auto GRIDWISE_GEMM = ALGORITHM.gridwise_gemm;
|
||||
static constexpr auto A_BLOCK_TRANSFER =
|
||||
internal::SetFwdConvBlockTransfer<ALGORITHM.transfer.a>();
|
||||
static constexpr auto B_BLOCK_TRANSFER =
|
||||
internal::SetFwdConvBlockTransfer<ALGORITHM.transfer.b>();
|
||||
static constexpr auto C_BLOCK_TRANSFER = internal::SetCBlockTransfer<SIGNATURE, ALGORITHM>();
|
||||
static constexpr auto BLOCK_GEMM = internal::SetBlockGemm<ALGORITHM>();
|
||||
|
||||
// Check limits for the algorithm parameters.
|
||||
static_assert(ValidABlockTransfer<A_BLOCK_TRANSFER,
|
||||
Types::input_types.first,
|
||||
sizeof(typename Types::InDataType),
|
||||
BLOCK.block_size,
|
||||
BLOCK.per_block>);
|
||||
static_assert(ValidBBlockTransfer<B_BLOCK_TRANSFER,
|
||||
Types::weight_types.first,
|
||||
sizeof(typename Types::WeiDataType),
|
||||
BLOCK.block_size,
|
||||
BLOCK.per_block>);
|
||||
static_assert(ValidCBlockTransfer<C_BLOCK_TRANSFER,
|
||||
Types::output_types.first,
|
||||
BLOCK.block_size,
|
||||
BLOCK.per_block>);
|
||||
|
||||
// Layout validations
|
||||
using enum TensorLayout;
|
||||
static_assert(IsValidLayout<SIGNATURE.input.config.layout,
|
||||
G_NW_C_strided,
|
||||
G_NHW_C_strided,
|
||||
G_NDHW_C_strided,
|
||||
GNWC,
|
||||
GNHWC,
|
||||
GNDHWC,
|
||||
NWGC,
|
||||
NHWGC,
|
||||
NDHWGC> &&
|
||||
A_BLOCK_TRANSFER.src_vector_dim == 2);
|
||||
|
||||
static_assert(IsValidLayout<SIGNATURE.weight.config.layout,
|
||||
G_K_X_C_strided,
|
||||
G_K_YX_C_strided,
|
||||
G_K_ZYX_C_strided,
|
||||
GKXC,
|
||||
GKYXC,
|
||||
GKZYXC,
|
||||
KXGC,
|
||||
KYXGC,
|
||||
KZYXGC> &&
|
||||
B_BLOCK_TRANSFER.src_vector_dim == 2);
|
||||
|
||||
static_assert(IsValidLayout<SIGNATURE.output.config.layout,
|
||||
G_NW_K_strided,
|
||||
G_NHW_K_strided,
|
||||
G_NDHW_K_strided,
|
||||
GNWK,
|
||||
GNHWK,
|
||||
GNDHWK,
|
||||
NWGK,
|
||||
NHWGK,
|
||||
NDHWGK>);
|
||||
|
||||
// The forward convolution kernel class instance.
|
||||
using Instance = ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3<
|
||||
SPATIAL_DIM,
|
||||
typename Layouts::InLayout,
|
||||
typename Layouts::WeiLayout,
|
||||
typename Layouts::DsLayout,
|
||||
typename Layouts::OutLayout,
|
||||
typename Types::InDataType,
|
||||
typename Types::WeiDataType,
|
||||
typename Types::AccDataType,
|
||||
typename Types::OutComputeType,
|
||||
typename Types::DsDataType,
|
||||
typename Types::OutDataType,
|
||||
typename Ops::InElementwiseOp,
|
||||
typename Ops::WeiElementwiseOp,
|
||||
typename Ops::OutElementwiseOp,
|
||||
SPECIALIZATION.conv_spec,
|
||||
SPECIALIZATION.gemm_spec,
|
||||
BLOCK.block_size,
|
||||
BLOCK.per_block.m,
|
||||
BLOCK.per_block.n,
|
||||
BLOCK.per_block.k,
|
||||
GRIDWISE_GEMM.ak1,
|
||||
GRIDWISE_GEMM.bk1,
|
||||
GRIDWISE_GEMM.m_per_wmma,
|
||||
GRIDWISE_GEMM.n_per_wmma,
|
||||
GRIDWISE_GEMM.m_wmma_per_wave,
|
||||
GRIDWISE_GEMM.n_wmma_per_wave,
|
||||
to_sequence_v<A_BLOCK_TRANSFER.thread_cluster_dims>,
|
||||
to_sequence_v<A_BLOCK_TRANSFER.thread_cluster_order>,
|
||||
to_sequence_v<A_BLOCK_TRANSFER.src_access_order>,
|
||||
A_BLOCK_TRANSFER.src_vector_dim,
|
||||
A_BLOCK_TRANSFER.src_scalar_per_vector,
|
||||
A_BLOCK_TRANSFER.lds_dst_scalar_per_vector,
|
||||
static_cast<ck::index_t>(A_BLOCK_TRANSFER.lds_padding),
|
||||
to_sequence_v<B_BLOCK_TRANSFER.thread_cluster_dims>,
|
||||
to_sequence_v<B_BLOCK_TRANSFER.thread_cluster_order>,
|
||||
to_sequence_v<B_BLOCK_TRANSFER.src_access_order>,
|
||||
B_BLOCK_TRANSFER.src_vector_dim,
|
||||
B_BLOCK_TRANSFER.src_scalar_per_vector,
|
||||
B_BLOCK_TRANSFER.lds_dst_scalar_per_vector,
|
||||
static_cast<ck::index_t>(B_BLOCK_TRANSFER.lds_padding),
|
||||
C_BLOCK_TRANSFER.m_xdl_per_wave_per_shuffle,
|
||||
C_BLOCK_TRANSFER.n_xdl_per_wave_per_shuffle,
|
||||
to_sequence_v<C_BLOCK_TRANSFER.thread_cluster_dims>,
|
||||
C_BLOCK_TRANSFER.scalar_per_vector,
|
||||
BLOCK_GEMM.scheduler,
|
||||
BLOCK_GEMM.pipeline_version,
|
||||
true, // UseThreadTileTransfer
|
||||
typename Types::InComputeType,
|
||||
typename Types::WeiComputeType,
|
||||
ALGORITHM.num_conv_groups_to_merge>;
|
||||
};
|
||||
|
||||
} // namespace ck_tile::builder::factory
|
||||
@@ -0,0 +1,48 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <concepts>
|
||||
|
||||
#include "ck_tile/builder/reflect/conv_traits.hpp"
|
||||
#include "ck_tile/builder/reflect/conv_traits_helpers.hpp"
|
||||
#include "ck_tile/builder/reflect/instance_traits.hpp"
|
||||
#include "ck_tile/builder/reflect/instance_traits_device_grouped_conv_fwd_multiple_abd_wmma_cshuffle_v3.hpp"
|
||||
|
||||
namespace ck_tile::reflect::conv {
|
||||
|
||||
/// @brief Tag dispatch implementation for DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3
|
||||
template <typename Instance>
|
||||
requires HasInstanceTraits<Instance> &&
|
||||
std::same_as<typename InstanceTraits<Instance>::device_kernel_tag,
|
||||
DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3_Tag>
|
||||
constexpr ConvTraits instance_to_conv_traits()
|
||||
{
|
||||
using InstTraits = InstanceTraits<Instance>;
|
||||
|
||||
return ConvTraits{
|
||||
.spatial_dim = InstTraits::kSpatialDim,
|
||||
.direction = conv_direction<Instance>(),
|
||||
.layout = fwd_conv_layout<Instance>(),
|
||||
.data_type = conv_data_type<typename InstTraits::ADataType>(),
|
||||
.input_element_op = elementwise_op<typename InstTraits::AElementwiseOperation>(),
|
||||
.weight_element_op = elementwise_op<typename InstTraits::BElementwiseOperation>(),
|
||||
.output_element_op = elementwise_op<typename InstTraits::CDEElementwiseOperation>(),
|
||||
.gemm_padding = gemm_spec<Instance>(),
|
||||
.conv_specialization = conv_spec<Instance>(),
|
||||
.thread_block_size = InstTraits::kBlockSize,
|
||||
.tile_dims = conv_traits_data_tile<InstTraits>(),
|
||||
.a_tile_transfer = conv_traits_a_transfer_params<InstTraits>(InstTraits::kAK1),
|
||||
.b_tile_transfer = conv_traits_b_transfer_params<InstTraits>(InstTraits::kBK1),
|
||||
.warp_gemm = conv_traits_wmma_warp_gemm_params<InstTraits>(),
|
||||
.c_tile_transfer = conv_traits_wmma_c_tile_transfer<InstTraits>(),
|
||||
// TODO: Add compute types (AComputeDataType, BComputeDataType) when ConvTraits supports
|
||||
// them
|
||||
// TODO: Add NumGroupsToMerge when ConvTraits supports it
|
||||
.pipeline_version = get_pipeline_version<InstTraits>(),
|
||||
.pipeline_scheduler = get_pipeline_scheduler<InstTraits>(),
|
||||
};
|
||||
}
|
||||
|
||||
} // namespace ck_tile::reflect::conv
|
||||
@@ -8,6 +8,7 @@
|
||||
#include "ck_tile/builder/reflect/conv_traits_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp"
|
||||
#include "ck_tile/builder/reflect/conv_traits_device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp"
|
||||
#include "ck_tile/builder/reflect/conv_traits_device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp"
|
||||
#include "ck_tile/builder/reflect/conv_traits_device_grouped_conv_fwd_multiple_abd_wmma_cshuffle_v3.hpp"
|
||||
|
||||
// Bwd weight instances
|
||||
#include "ck_tile/builder/reflect/conv_traits_device_grouped_conv_bwd_weight_xdl_cshuffle.hpp"
|
||||
|
||||
@@ -0,0 +1,16 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "instance_traits.hpp"
|
||||
#include "instance_traits_util.hpp"
|
||||
|
||||
namespace ck_tile::reflect {
|
||||
|
||||
/// @brief Tag type for DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3 device kernel
|
||||
struct DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3_Tag
|
||||
{
|
||||
};
|
||||
|
||||
} // namespace ck_tile::reflect
|
||||
@@ -0,0 +1,302 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
// InstanceTraits specialization for DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3
|
||||
//
|
||||
// This .inc file is #included at the bottom of the device op header
|
||||
// (device_grouped_conv_fwd_multiple_abd_wmma_cshuffle_v3.hpp) under
|
||||
// #ifdef CK_EXPERIMENTAL_BUILDER, AFTER the struct is fully defined.
|
||||
// This eliminates the need for forward declarations.
|
||||
//
|
||||
// CRITICAL MAINTENANCE NOTE:
|
||||
// This file MUST be kept strictly in sync with the device implementation header:
|
||||
// ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_wmma_cshuffle_v3.hpp
|
||||
// The template parameter order, names, and types MUST EXACTLY MATCH those in the device
|
||||
// implementation.
|
||||
|
||||
#include "ck_tile/builder/reflect/instance_traits_device_grouped_conv_fwd_multiple_abd_wmma_cshuffle_v3.hpp"
|
||||
|
||||
namespace ck_tile::reflect {
|
||||
|
||||
// Specialization for DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3
|
||||
template <ck::index_t NDimSpatial,
|
||||
typename ALayout_,
|
||||
typename BLayout_,
|
||||
typename DsLayout_,
|
||||
typename ELayout_,
|
||||
typename ADataType_,
|
||||
typename BDataType_,
|
||||
typename AccDataType_,
|
||||
typename CShuffleDataType_,
|
||||
typename DsDataType_,
|
||||
typename EDataType_,
|
||||
typename AElementwiseOperation_,
|
||||
typename BElementwiseOperation_,
|
||||
typename CDEElementwiseOperation_,
|
||||
ck::tensor_operation::device::ConvolutionForwardSpecialization ConvForwardSpecialization,
|
||||
ck::tensor_operation::device::GemmSpecialization GemmSpec,
|
||||
ck::index_t BlockSize,
|
||||
ck::index_t MPerBlock,
|
||||
ck::index_t NPerBlock,
|
||||
ck::index_t KPerBlock,
|
||||
ck::index_t AK1,
|
||||
ck::index_t BK1,
|
||||
ck::index_t MPerWmma,
|
||||
ck::index_t NPerWmma,
|
||||
ck::index_t MRepeat,
|
||||
ck::index_t NRepeat,
|
||||
typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
|
||||
typename ABlockTransferThreadClusterArrangeOrder,
|
||||
typename ABlockTransferSrcAccessOrder_,
|
||||
ck::index_t ABlockTransferSrcVectorDim,
|
||||
ck::index_t ABlockTransferSrcScalarPerVector,
|
||||
ck::index_t ABlockTransferDstScalarPerVector_AK1,
|
||||
ck::index_t ABlockLdsExtraM,
|
||||
typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
|
||||
typename BBlockTransferThreadClusterArrangeOrder,
|
||||
typename BBlockTransferSrcAccessOrder_,
|
||||
ck::index_t BBlockTransferSrcVectorDim,
|
||||
ck::index_t BBlockTransferSrcScalarPerVector,
|
||||
ck::index_t BBlockTransferDstScalarPerVector_BK1,
|
||||
ck::index_t BBlockLdsExtraN,
|
||||
ck::index_t CShuffleMRepeatPerShuffle,
|
||||
ck::index_t CShuffleNRepeatPerShuffle,
|
||||
typename CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
ck::index_t CDEBlockTransferScalarPerVector_NPerBlock,
|
||||
ck::BlockGemmPipelineScheduler BlkGemmPipeSched,
|
||||
ck::BlockGemmPipelineVersion BlkGemmPipelineVer,
|
||||
bool UseThreadTileTransfer,
|
||||
typename AComputeDataType_,
|
||||
typename BComputeDataType_,
|
||||
ck::index_t NumGroupsToMerge>
|
||||
struct InstanceTraits<
|
||||
ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3<
|
||||
NDimSpatial,
|
||||
ALayout_,
|
||||
BLayout_,
|
||||
DsLayout_,
|
||||
ELayout_,
|
||||
ADataType_,
|
||||
BDataType_,
|
||||
AccDataType_,
|
||||
CShuffleDataType_,
|
||||
DsDataType_,
|
||||
EDataType_,
|
||||
AElementwiseOperation_,
|
||||
BElementwiseOperation_,
|
||||
CDEElementwiseOperation_,
|
||||
ConvForwardSpecialization,
|
||||
GemmSpec,
|
||||
BlockSize,
|
||||
MPerBlock,
|
||||
NPerBlock,
|
||||
KPerBlock,
|
||||
AK1,
|
||||
BK1,
|
||||
MPerWmma,
|
||||
NPerWmma,
|
||||
MRepeat,
|
||||
NRepeat,
|
||||
ABlockTransferThreadClusterLengths_AK0_M_AK1,
|
||||
ABlockTransferThreadClusterArrangeOrder,
|
||||
ABlockTransferSrcAccessOrder_,
|
||||
ABlockTransferSrcVectorDim,
|
||||
ABlockTransferSrcScalarPerVector,
|
||||
ABlockTransferDstScalarPerVector_AK1,
|
||||
ABlockLdsExtraM,
|
||||
BBlockTransferThreadClusterLengths_BK0_N_BK1,
|
||||
BBlockTransferThreadClusterArrangeOrder,
|
||||
BBlockTransferSrcAccessOrder_,
|
||||
BBlockTransferSrcVectorDim,
|
||||
BBlockTransferSrcScalarPerVector,
|
||||
BBlockTransferDstScalarPerVector_BK1,
|
||||
BBlockLdsExtraN,
|
||||
CShuffleMRepeatPerShuffle,
|
||||
CShuffleNRepeatPerShuffle,
|
||||
CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
CDEBlockTransferScalarPerVector_NPerBlock,
|
||||
BlkGemmPipeSched,
|
||||
BlkGemmPipelineVer,
|
||||
UseThreadTileTransfer,
|
||||
AComputeDataType_,
|
||||
BComputeDataType_,
|
||||
NumGroupsToMerge>>
|
||||
{
|
||||
/// @brief Tag type identifying this device kernel variant
|
||||
using device_kernel_tag = DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3_Tag;
|
||||
|
||||
// Spatial dimension
|
||||
static constexpr int kSpatialDim = NDimSpatial;
|
||||
|
||||
// Layout types
|
||||
using ALayout = ALayout_;
|
||||
using BLayout = BLayout_;
|
||||
using DsLayout = DsLayout_;
|
||||
using ELayout = ELayout_;
|
||||
|
||||
// Data types
|
||||
using ADataType = ADataType_;
|
||||
using BDataType = BDataType_;
|
||||
using AccDataType = AccDataType_;
|
||||
using CShuffleDataType = CShuffleDataType_;
|
||||
using DsDataType = DsDataType_;
|
||||
using EDataType = EDataType_;
|
||||
|
||||
// Element-wise operations
|
||||
using AElementwiseOperation = AElementwiseOperation_;
|
||||
using BElementwiseOperation = BElementwiseOperation_;
|
||||
using CDEElementwiseOperation = CDEElementwiseOperation_;
|
||||
|
||||
// Specialization
|
||||
static constexpr ck::tensor_operation::device::ConvolutionForwardSpecialization
|
||||
kConvForwardSpecialization = ConvForwardSpecialization;
|
||||
static constexpr ck::tensor_operation::device::GemmSpecialization kGemmSpecialization =
|
||||
GemmSpec;
|
||||
|
||||
// Block configuration
|
||||
static constexpr int kBlockSize = BlockSize;
|
||||
static constexpr int kMPerBlock = MPerBlock;
|
||||
static constexpr int kNPerBlock = NPerBlock;
|
||||
static constexpr int kKPerBlock = KPerBlock;
|
||||
|
||||
// Tuning parameters
|
||||
static constexpr int kAK1 = AK1;
|
||||
static constexpr int kBK1 = BK1;
|
||||
static constexpr int kMPerWmma = MPerWmma;
|
||||
static constexpr int kNPerWmma = NPerWmma;
|
||||
static constexpr int kMRepeat = MRepeat;
|
||||
static constexpr int kNRepeat = NRepeat;
|
||||
|
||||
// A block transfer thread cluster dimensions (converted to std::array)
|
||||
static constexpr auto kAThreadClusterLengths =
|
||||
detail::SequenceToArray<ABlockTransferThreadClusterLengths_AK0_M_AK1>::value;
|
||||
static constexpr auto kAThreadClusterArrangeOrder =
|
||||
detail::SequenceToArray<ABlockTransferThreadClusterArrangeOrder>::value;
|
||||
static constexpr auto kABlockTransferSrcAccessOrder =
|
||||
detail::SequenceToArray<ABlockTransferSrcAccessOrder_>::value;
|
||||
static constexpr int kABlockTransferSrcVectorDim = ABlockTransferSrcVectorDim;
|
||||
static constexpr int kABlockTransferSrcScalarPerVector = ABlockTransferSrcScalarPerVector;
|
||||
static constexpr int kABlockTransferDstScalarPerVectorK1 = ABlockTransferDstScalarPerVector_AK1;
|
||||
static constexpr int kABlockLdsExtraM = ABlockLdsExtraM;
|
||||
|
||||
// B block transfer thread cluster dimensions (converted to std::array)
|
||||
static constexpr auto kBThreadClusterLengths =
|
||||
detail::SequenceToArray<BBlockTransferThreadClusterLengths_BK0_N_BK1>::value;
|
||||
static constexpr auto kBThreadClusterArrangeOrder =
|
||||
detail::SequenceToArray<BBlockTransferThreadClusterArrangeOrder>::value;
|
||||
static constexpr auto kBBlockTransferSrcAccessOrder =
|
||||
detail::SequenceToArray<BBlockTransferSrcAccessOrder_>::value;
|
||||
static constexpr int kBBlockTransferSrcVectorDim = BBlockTransferSrcVectorDim;
|
||||
static constexpr int kBBlockTransferSrcScalarPerVector = BBlockTransferSrcScalarPerVector;
|
||||
static constexpr int kBBlockTransferDstScalarPerVectorK1 = BBlockTransferDstScalarPerVector_BK1;
|
||||
static constexpr int kBBlockLdsExtraN = BBlockLdsExtraN;
|
||||
|
||||
// C shuffle parameters (converted to std::array)
|
||||
static constexpr int kCShuffleMRepeatPerShuffle = CShuffleMRepeatPerShuffle;
|
||||
static constexpr int kCShuffleNRepeatPerShuffle = CShuffleNRepeatPerShuffle;
|
||||
static constexpr auto kCDEThreadClusterLengths = detail::SequenceToArray<
|
||||
CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock>::value;
|
||||
static constexpr int kCDEBlockTransferScalarPerVector =
|
||||
CDEBlockTransferScalarPerVector_NPerBlock;
|
||||
|
||||
// Pipeline configuration
|
||||
static constexpr ck::BlockGemmPipelineScheduler kPipelineScheduler = BlkGemmPipeSched;
|
||||
static constexpr ck::BlockGemmPipelineVersion kPipelineVersion = BlkGemmPipelineVer;
|
||||
|
||||
static constexpr bool kUseThreadTileTransfer = UseThreadTileTransfer;
|
||||
|
||||
// Compute data types
|
||||
using AComputeDataType = AComputeDataType_;
|
||||
using BComputeDataType = BComputeDataType_;
|
||||
|
||||
static constexpr int kNumGroupsToMerge = NumGroupsToMerge;
|
||||
|
||||
// Static member function to generate instance string
|
||||
static std::string instance_string()
|
||||
{
|
||||
std::ostringstream oss;
|
||||
|
||||
// Kernel type name
|
||||
oss << "DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3";
|
||||
|
||||
// Template parameters in exact order
|
||||
oss << "<" << kSpatialDim; // 1. NDimSpatial
|
||||
oss << "," << detail::layout_name<ALayout>(); // 2. ALayout
|
||||
oss << "," << detail::layout_name<BLayout>(); // 3. BLayout
|
||||
oss << "," << detail::tuple_name<DsLayout>(); // 4. DsLayout
|
||||
oss << "," << detail::layout_name<ELayout>(); // 5. ELayout
|
||||
oss << "," << detail::type_or_type_tuple_name<ADataType>(); // 6. ADataType
|
||||
oss << "," << detail::type_or_type_tuple_name<BDataType>(); // 7. BDataType
|
||||
oss << "," << detail::type_name<AccDataType>(); // 8. AccDataType
|
||||
oss << "," << detail::type_name<CShuffleDataType>(); // 9. CShuffleDataType
|
||||
oss << "," << detail::tuple_name<DsDataType>(); // 10. DsDataType
|
||||
oss << "," << detail::type_name<EDataType>(); // 11. EDataType
|
||||
oss << ","
|
||||
<< detail::elementwise_op_name<AElementwiseOperation>(); // 12. AElementwiseOperation
|
||||
oss << ","
|
||||
<< detail::elementwise_op_name<BElementwiseOperation>(); // 13. BElementwiseOperation
|
||||
oss << ","
|
||||
<< detail::elementwise_op_name<CDEElementwiseOperation>(); // 14.
|
||||
// CDEElementwiseOperation
|
||||
oss << ","
|
||||
<< detail::conv_fwd_spec_name(
|
||||
kConvForwardSpecialization); // 15. ConvForwardSpecialization
|
||||
oss << "," << detail::gemm_spec_name(kGemmSpecialization); // 16. GemmSpec
|
||||
oss << "," << kBlockSize; // 17. BlockSize
|
||||
oss << "," << kMPerBlock; // 18. MPerBlock
|
||||
oss << "," << kNPerBlock; // 19. NPerBlock
|
||||
oss << "," << kKPerBlock; // 20. KPerBlock
|
||||
oss << "," << kAK1; // 21. AK1
|
||||
oss << "," << kBK1; // 22. BK1
|
||||
oss << "," << kMPerWmma; // 23. MPerWmma
|
||||
oss << "," << kNPerWmma; // 24. NPerWmma
|
||||
oss << "," << kMRepeat; // 25. MRepeat
|
||||
oss << "," << kNRepeat; // 26. NRepeat
|
||||
oss << ","
|
||||
<< detail::array_to_string(
|
||||
kAThreadClusterLengths); // 27. ABlockTransferThreadClusterLengths
|
||||
oss << ","
|
||||
<< detail::array_to_string(
|
||||
kAThreadClusterArrangeOrder); // 28. ABlockTransferThreadClusterArrangeOrder
|
||||
oss << ","
|
||||
<< detail::array_to_string(
|
||||
kABlockTransferSrcAccessOrder); // 29. ABlockTransferSrcAccessOrder
|
||||
oss << "," << kABlockTransferSrcVectorDim; // 30. ABlockTransferSrcVectorDim
|
||||
oss << "," << kABlockTransferSrcScalarPerVector; // 31. ABlockTransferSrcScalarPerVector
|
||||
oss << ","
|
||||
<< kABlockTransferDstScalarPerVectorK1; // 32. ABlockTransferDstScalarPerVector_AK1
|
||||
oss << "," << (kABlockLdsExtraM ? "true" : "false"); // 33. ABlockLdsExtraM
|
||||
oss << ","
|
||||
<< detail::array_to_string(
|
||||
kBThreadClusterLengths); // 34. BBlockTransferThreadClusterLengths
|
||||
oss << ","
|
||||
<< detail::array_to_string(
|
||||
kBThreadClusterArrangeOrder); // 35. BBlockTransferThreadClusterArrangeOrder
|
||||
oss << ","
|
||||
<< detail::array_to_string(
|
||||
kBBlockTransferSrcAccessOrder); // 36. BBlockTransferSrcAccessOrder
|
||||
oss << "," << kBBlockTransferSrcVectorDim; // 37. BBlockTransferSrcVectorDim
|
||||
oss << "," << kBBlockTransferSrcScalarPerVector; // 38. BBlockTransferSrcScalarPerVector
|
||||
oss << ","
|
||||
<< kBBlockTransferDstScalarPerVectorK1; // 39. BBlockTransferDstScalarPerVector_BK1
|
||||
oss << "," << (kBBlockLdsExtraN ? "true" : "false"); // 40. BBlockLdsExtraN
|
||||
oss << "," << kCShuffleMRepeatPerShuffle; // 41. CShuffleMRepeatPerShuffle
|
||||
oss << "," << kCShuffleNRepeatPerShuffle; // 42. CShuffleNRepeatPerShuffle
|
||||
oss << ","
|
||||
<< detail::array_to_string(
|
||||
kCDEThreadClusterLengths); // 43. CDEBlockTransferClusterLengths
|
||||
oss << ","
|
||||
<< kCDEBlockTransferScalarPerVector; // 44. CDEBlockTransferScalarPerVector_NPerBlock
|
||||
oss << "," << detail::pipeline_scheduler_name(kPipelineScheduler); // 45. BlkGemmPipeSched
|
||||
oss << "," << detail::pipeline_version_name(kPipelineVersion); // 46. BlkGemmPipelineVer
|
||||
oss << "," << (kUseThreadTileTransfer ? "true" : "false"); // 47. UseThreadTileTransfer
|
||||
oss << "," << detail::type_name<AComputeDataType>(); // 48. AComputeDataType
|
||||
oss << "," << detail::type_name<BComputeDataType>(); // 49. BComputeDataType
|
||||
oss << "," << kNumGroupsToMerge; // 50. NumGroupsToMerge
|
||||
oss << ">";
|
||||
|
||||
return oss.str();
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile::reflect
|
||||
@@ -146,6 +146,7 @@ set(INSTANCE_STRING_TESTS
|
||||
|
||||
if (CK_USE_WMMA)
|
||||
list(APPEND INSTANCE_STRING_TESTS
|
||||
test_instance_string_fwd_grp_conv_wmma_v3.cpp
|
||||
test_instance_string_bwd_weight_grp_conv_wmma_v3.cpp
|
||||
test_instance_string_bwd_weight_grp_conv_multiple_d_wmma_v3.cpp
|
||||
test_instance_string_bwd_weight_grp_conv_two_stage_wmma_v3.cpp
|
||||
@@ -172,6 +173,13 @@ add_ck_builder_test(test_ckb_build_fwd_instances
|
||||
conv/ck/test_ckb_conv_fwd_3d_fp32.cpp
|
||||
conv/ck_tile/test_ckb_conv_fwd_2d_fp16_v3.cpp
|
||||
)
|
||||
|
||||
if (CK_USE_WMMA)
|
||||
target_sources(test_ckb_build_fwd_instances PRIVATE
|
||||
conv/ck/test_ckb_conv_fwd_2d_wmma_v3_fp16.cpp
|
||||
)
|
||||
endif()
|
||||
|
||||
target_link_libraries(test_ckb_build_fwd_instances PRIVATE utility)
|
||||
|
||||
set(BWD_WEIGHT_TESTS
|
||||
|
||||
@@ -0,0 +1,105 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "utils/ckb_conv_test_configs.hpp"
|
||||
#include "utils/ckb_conv_test_utils.hpp"
|
||||
#include "utils/conv_algorithm_type_utils.hpp"
|
||||
#include "ck_tile/builder/testing/conv/fwd.hpp"
|
||||
#include "ck_tile/builder/testing/conv/fwd_ck.hpp"
|
||||
#include "ck_tile/builder/testing/conv/reference.hpp"
|
||||
#include "ck_tile/host/device_prop.hpp"
|
||||
#include "testing_utils.hpp"
|
||||
|
||||
namespace ckb = ck_tile::builder;
|
||||
namespace ckt = ck_tile::builder::test;
|
||||
namespace cku = ck_tile::builder::test_utils;
|
||||
|
||||
using ck_tile::test::MatchesReference;
|
||||
using ck_tile::test::SuccessfulRun;
|
||||
|
||||
constexpr auto SIGNATURE =
|
||||
ckt::ConvSignature{.spatial_dim = 2,
|
||||
.direction = ckb::ConvDirection::FORWARD,
|
||||
.data_type = ckb::DataType::FP16,
|
||||
.accumulation_data_type = ckb::DataType::FP32,
|
||||
.input = {.config = {.layout = ckb::TensorLayout::GNHWC}},
|
||||
.weight = {.config = {.layout = ckb::TensorLayout::GKYXC}},
|
||||
.output = {.config = {.layout = ckb::TensorLayout::GNHWK}}};
|
||||
|
||||
constexpr auto ALGORITHM = cku::ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3{}
|
||||
.with_thread_block(cku::ThreadBlock_64_64x64x32)
|
||||
.with_gemm_config(cku::GemmParamsABK1_Wmma_16x16_4x2_per_wave)
|
||||
.with_transfer(cku::Transfer_4x16x1)
|
||||
.with_fwd_specializations(ckb::ConvSpecialization::DEFAULT,
|
||||
ckb::GemmSpecialization::MNKPadding)
|
||||
.with_block_gemm(cku::BlockGemmDesc_v1_intrawave)
|
||||
.with_num_conv_groups_to_merge(1);
|
||||
|
||||
using Builder = ckb::ConvBuilder<SIGNATURE, ALGORITHM>;
|
||||
using Instance = Builder::Instance;
|
||||
|
||||
using Reference = ckb::ConvBuilder<SIGNATURE, ckt::ConvAlgorithm_Reference{}>::Instance;
|
||||
|
||||
TEST(Fwd2DFp16_WmmaV3_GNHWC, Create)
|
||||
{
|
||||
const auto expected_transfer_parameters = to_string(ALGORITHM);
|
||||
cku::run_test<Builder>({"DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3",
|
||||
expected_transfer_parameters,
|
||||
"Default",
|
||||
"Intrawave",
|
||||
"v1",
|
||||
"GNHWC,GKYXC,EmptyTuple,GNHWK",
|
||||
"PassThrough,PassThrough,PassThrough",
|
||||
"MNKPadding"});
|
||||
}
|
||||
|
||||
TEST(Fwd2DFp16_WmmaV3_GNHWC, Execution)
|
||||
{
|
||||
if(!ck_tile::get_device_name().starts_with("gfx11") &&
|
||||
!ck_tile::get_device_name().starts_with("gfx12"))
|
||||
{
|
||||
// Note: WMMA kernel requires gfx11 or gfx12
|
||||
GTEST_SKIP() << "unsupported architecture";
|
||||
}
|
||||
|
||||
ckt::Args<SIGNATURE> args = {
|
||||
.lengths =
|
||||
{
|
||||
.batch_size = 16,
|
||||
.groups = 1,
|
||||
.input_channels = 32,
|
||||
.output_channels = 48,
|
||||
.image =
|
||||
{
|
||||
.width = 56,
|
||||
.height = 64,
|
||||
},
|
||||
.filter =
|
||||
{
|
||||
.width = 3,
|
||||
.height = 5,
|
||||
},
|
||||
},
|
||||
.filter_strides = {.width = 1, .height = 1},
|
||||
.filter_dilation = {.width = 1, .height = 1},
|
||||
.input_left_pad = {.width = 0, .height = 0},
|
||||
.input_right_pad = {.width = 0, .height = 0},
|
||||
.a_elementwise_op = {},
|
||||
.b_elementwise_op = {},
|
||||
.cde_elementwise_op = {},
|
||||
};
|
||||
|
||||
auto inputs = ckt::alloc_inputs(args);
|
||||
auto outputs = ckt::alloc_outputs(args);
|
||||
auto reference = ckt::alloc_outputs(args);
|
||||
|
||||
ckt::init_inputs(args, inputs.get());
|
||||
|
||||
auto conv = Instance{};
|
||||
EXPECT_THAT(ckt::run(conv, args, inputs.get(), outputs.get()), SuccessfulRun());
|
||||
|
||||
auto ref_conv = Reference{};
|
||||
EXPECT_THAT(ckt::run(ref_conv, args, inputs.get(), reference.get()), SuccessfulRun());
|
||||
|
||||
EXPECT_THAT(outputs.get(), MatchesReference(args, reference.get()));
|
||||
}
|
||||
@@ -632,6 +632,14 @@ using ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 =
|
||||
BlockGemm_,
|
||||
GemmBatchOptions_>;
|
||||
|
||||
using ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3 =
|
||||
ConvAlgorithmTemplate<ThreadBlock_,
|
||||
WmmaGemmABK1_,
|
||||
Transfer_<>,
|
||||
ConvSpecializationFwd_,
|
||||
BlockGemm_,
|
||||
GemmBatchOptions_>;
|
||||
|
||||
using ConvAlgorithm_DeviceGroupedConvFwdMultipleD_Wmma_CShuffle =
|
||||
ConvAlgorithmTemplate<ThreadBlock_,
|
||||
WmmaGemm_,
|
||||
|
||||
@@ -12,6 +12,7 @@
|
||||
#include "ck_tile/builder/reflect/instance_traits_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp"
|
||||
#include "ck_tile/builder/reflect/instance_traits_device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp"
|
||||
#include "ck_tile/builder/reflect/instance_traits_device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_wmma_cshuffle_v3.hpp"
|
||||
#include "ck_tile/builder/reflect/instance_traits_tile_grouped_convolution_forward.hpp"
|
||||
#include "ck_tile/ops/epilogue/cshuffle_epilogue.hpp"
|
||||
|
||||
@@ -626,6 +627,118 @@ TEST(InstanceTraits, WmmaInstanceStringReturnsCorrectFormat)
|
||||
EXPECT_EQ(instance_str, expected_str);
|
||||
}
|
||||
|
||||
TEST(InstanceTraits, WmmaV3InstanceStringReturnsCorrectFormat)
|
||||
{
|
||||
using DeviceInstance =
|
||||
ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3<
|
||||
2, // NDimSpatial
|
||||
ck::tensor_layout::convolution::GNHWC, // ALayout
|
||||
ck::tensor_layout::convolution::GKYXC, // BLayout
|
||||
ck::Tuple<>, // DsLayout
|
||||
ck::tensor_layout::convolution::GNHWK, // ELayout
|
||||
ck::half_t, // ADataType
|
||||
ck::half_t, // BDataType
|
||||
float, // AccDataType
|
||||
ck::half_t, // CShuffleDataType
|
||||
ck::Tuple<>, // DsDataType
|
||||
ck::half_t, // EDataType
|
||||
ck::tensor_operation::element_wise::PassThrough, // AElementwiseOperation
|
||||
ck::tensor_operation::element_wise::PassThrough, // BElementwiseOperation
|
||||
ck::tensor_operation::element_wise::PassThrough, // CDEElementwiseOperation
|
||||
ck::tensor_operation::device::ConvolutionForwardSpecialization::
|
||||
Default, // ConvForwardSpec
|
||||
ck::tensor_operation::device::GemmSpecialization::MNKPadding, // GemmSpec
|
||||
64, // BlockSize
|
||||
64, // MPerBlock
|
||||
64, // NPerBlock
|
||||
32, // KPerBlock
|
||||
8, // AK1
|
||||
8, // BK1
|
||||
16, // MPerWmma
|
||||
16, // NPerWmma
|
||||
4, // MRepeat
|
||||
2, // NRepeat
|
||||
ck::Sequence<4, 16, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1
|
||||
ck::Sequence<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder
|
||||
ck::Sequence<1, 0, 2>, // ABlockTransferSrcAccessOrder
|
||||
2, // ABlockTransferSrcVectorDim
|
||||
1, // ABlockTransferSrcScalarPerVector
|
||||
8, // ABlockTransferDstScalarPerVector_AK1
|
||||
1, // ABlockLdsExtraM
|
||||
ck::Sequence<4, 16, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1
|
||||
ck::Sequence<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder
|
||||
ck::Sequence<1, 0, 2>, // BBlockTransferSrcAccessOrder
|
||||
2, // BBlockTransferSrcVectorDim
|
||||
1, // BBlockTransferSrcScalarPerVector
|
||||
8, // BBlockTransferDstScalarPerVector_BK1
|
||||
1, // BBlockLdsExtraN
|
||||
1, // CShuffleMRepeatPerShuffle
|
||||
1, // CShuffleNRepeatPerShuffle
|
||||
ck::Sequence<1, 16, 1, 4>, // CDEBlockTransferClusterLengths
|
||||
1, // CDEBlockTransferScalarPerVector_NPerBlock
|
||||
ck::BlockGemmPipelineScheduler::Intrawave, // BlkGemmPipeSched
|
||||
ck::BlockGemmPipelineVersion::v1>; // BlkGemmPipelineVer
|
||||
|
||||
// Generate instance string
|
||||
std::string instance_str = ck_tile::reflect::instance_string<DeviceInstance>();
|
||||
|
||||
// Expected string with all template parameters
|
||||
std::string expected_str = "DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3"
|
||||
"<2" // NDimSpatial
|
||||
",GNHWC" // ALayout
|
||||
",GKYXC" // BLayout
|
||||
",EmptyTuple" // DsLayout
|
||||
",GNHWK" // ELayout
|
||||
",fp16" // ADataType
|
||||
",fp16" // BDataType
|
||||
",fp32" // AccDataType
|
||||
",fp16" // CShuffleDataType
|
||||
",EmptyTuple" // DsDataType
|
||||
",fp16" // EDataType
|
||||
",PassThrough" // AElementwiseOperation
|
||||
",PassThrough" // BElementwiseOperation
|
||||
",PassThrough" // CDEElementwiseOperation
|
||||
",Default" // ConvForwardSpecialization
|
||||
",MNKPadding" // GemmSpec
|
||||
",64" // BlockSize
|
||||
",64" // MPerBlock
|
||||
",64" // NPerBlock
|
||||
",32" // KPerBlock
|
||||
",8" // AK1
|
||||
",8" // BK1
|
||||
",16" // MPerWmma
|
||||
",16" // NPerWmma
|
||||
",4" // MRepeat
|
||||
",2" // NRepeat
|
||||
",Seq(4,16,1)" // ABlockTransferThreadClusterLengths
|
||||
",Seq(1,0,2)" // ABlockTransferThreadClusterArrangeOrder
|
||||
",Seq(1,0,2)" // ABlockTransferSrcAccessOrder
|
||||
",2" // ABlockTransferSrcVectorDim
|
||||
",1" // ABlockTransferSrcScalarPerVector
|
||||
",8" // ABlockTransferDstScalarPerVector_AK1
|
||||
",true" // ABlockLdsExtraM
|
||||
",Seq(4,16,1)" // BBlockTransferThreadClusterLengths
|
||||
",Seq(1,0,2)" // BBlockTransferThreadClusterArrangeOrder
|
||||
",Seq(1,0,2)" // BBlockTransferSrcAccessOrder
|
||||
",2" // BBlockTransferSrcVectorDim
|
||||
",1" // BBlockTransferSrcScalarPerVector
|
||||
",8" // BBlockTransferDstScalarPerVector_BK1
|
||||
",true" // BBlockLdsExtraN
|
||||
",1" // CShuffleMRepeatPerShuffle
|
||||
",1" // CShuffleNRepeatPerShuffle
|
||||
",Seq(1,16,1,4)" // CDEBlockTransferClusterLengths
|
||||
",1" // CDEBlockTransferScalarPerVector_NPerBlock
|
||||
",Intrawave" // BlkGemmPipeSched
|
||||
",v1" // BlkGemmPipelineVer
|
||||
",true" // UseThreadTileTransfer
|
||||
",fp16" // AComputeDataType
|
||||
",fp16" // BComputeDataType
|
||||
",1>"; // NumGroupsToMerge
|
||||
|
||||
// Verify the generated string matches exactly
|
||||
EXPECT_EQ(instance_str, expected_str);
|
||||
}
|
||||
|
||||
TEST(InstanceTraits, DlInstanceStringReturnsCorrectFormat)
|
||||
{
|
||||
using DeviceInstance =
|
||||
|
||||
@@ -0,0 +1,98 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
#include <ck_tile/builder/reflect/instance_traits.hpp>
|
||||
#include <ck_tile/builder/reflect/conv_describe.hpp>
|
||||
#include <ck/tensor_operation/gpu/device/device_base.hpp>
|
||||
#include <ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_cshufflev3_instance.hpp>
|
||||
|
||||
namespace {
|
||||
|
||||
namespace ckr = ck_tile::reflect;
|
||||
|
||||
// Use the template helper to get a working instance configuration
|
||||
using InstanceTuple = ck::tensor_operation::device::instance::
|
||||
device_grouped_conv_fwd_wmma_cshufflev3_f16_generic_instances<
|
||||
2, // NDimSpatial
|
||||
ck::tensor_operation::device::instance::GNHWC, // ALayout
|
||||
ck::tensor_operation::device::instance::GKYXC, // BLayout
|
||||
ck::tensor_operation::device::instance::Empty_Tuple, // DsLayout
|
||||
ck::tensor_operation::device::instance::GNHWK, // ELayout
|
||||
ck::tensor_operation::device::instance::ConvFwdDefault>; // ConvForwardSpecialization
|
||||
|
||||
// Get the first instance from the tuple
|
||||
using DeviceInstance = typename std::tuple_element<0, InstanceTuple>::type;
|
||||
|
||||
// Expected complete instance string based on the first instance from
|
||||
// device_grouped_conv_fwd_wmma_cshufflev3_f16_instances
|
||||
std::string expected_str = "DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3"
|
||||
"<2" // NDimSpatial
|
||||
",GNHWC" // ALayout
|
||||
",GKYXC" // BLayout
|
||||
",EmptyTuple" // DsLayout
|
||||
",GNHWK" // ELayout
|
||||
",fp16" // ADataType
|
||||
",fp16" // BDataType
|
||||
",fp32" // AccDataType
|
||||
",fp16" // CShuffleDataType
|
||||
",EmptyTuple" // DsDataType
|
||||
",fp16" // EDataType
|
||||
",PassThrough" // AElementwiseOperation
|
||||
",PassThrough" // BElementwiseOperation
|
||||
",PassThrough" // CDEElementwiseOperation
|
||||
",Default" // ConvForwardSpecialization
|
||||
",MNKPadding" // GemmSpec
|
||||
",64" // BlockSize
|
||||
",64" // MPerBlock
|
||||
",64" // NPerBlock
|
||||
",32" // KPerBlock
|
||||
",8" // AK1
|
||||
",8" // BK1
|
||||
",16" // MPerWmma
|
||||
",16" // NPerWmma
|
||||
",4" // MRepeat
|
||||
",2" // NRepeat
|
||||
",Seq(4,16,1)" // ABlockTransferThreadClusterLengths
|
||||
",Seq(1,0,2)" // ABlockTransferThreadClusterArrangeOrder
|
||||
",Seq(1,0,2)" // ABlockTransferSrcAccessOrder
|
||||
",2" // ABlockTransferSrcVectorDim
|
||||
",1" // ABlockTransferSrcScalarPerVector
|
||||
",8" // ABlockTransferDstScalarPerVector_AK1
|
||||
",true" // ABlockLdsExtraM
|
||||
",Seq(4,16,1)" // BBlockTransferThreadClusterLengths
|
||||
",Seq(1,0,2)" // BBlockTransferThreadClusterArrangeOrder
|
||||
",Seq(1,0,2)" // BBlockTransferSrcAccessOrder
|
||||
",2" // BBlockTransferSrcVectorDim
|
||||
",1" // BBlockTransferSrcScalarPerVector
|
||||
",8" // BBlockTransferDstScalarPerVector_BK1
|
||||
",true" // BBlockLdsExtraN
|
||||
",1" // CShuffleMRepeatPerShuffle
|
||||
",1" // CShuffleNRepeatPerShuffle
|
||||
",Seq(1,16,1,4)" // CDEBlockTransferClusterLengths
|
||||
",1" // CDEBlockTransferScalarPerVector_NPerBlock
|
||||
",Intrawave" // BlkGemmPipeSched
|
||||
",v1" // BlkGemmPipelineVer
|
||||
",true" // UseThreadTileTransfer
|
||||
",fp16" // AComputeDataType
|
||||
",fp16" // BComputeDataType
|
||||
",1>"; // NumGroupsToMerge
|
||||
|
||||
// Test describe() through base class pointer for WMMA V3 variant
|
||||
TEST(InstanceString, DescribeReturnsCorrectValueForFwdGrpConvWmmaV3)
|
||||
{
|
||||
using BaseClass = ck::tensor_operation::device::BaseOperator;
|
||||
DeviceInstance device_instance;
|
||||
BaseClass* base_ptr = &device_instance;
|
||||
|
||||
auto desc = base_ptr->describe();
|
||||
ASSERT_NE(desc, nullptr);
|
||||
EXPECT_EQ(desc->instance_string(), expected_str);
|
||||
}
|
||||
|
||||
TEST(InstanceString, DescriptionReturnsCorrectValueForFwdGrpConvWmmaV3)
|
||||
{
|
||||
EXPECT_EQ(ckr::describe<DeviceInstance>().instance_string(), expected_str);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
@@ -344,6 +344,16 @@ constexpr GridwiseWmmaGemmABK1 GemmParamsABK1_Wmma_16x16_2x1_per_wave{.ak1
|
||||
.m_wmma_per_wave = 2,
|
||||
.n_wmma_per_wave = 1};
|
||||
|
||||
constexpr GridwiseWmmaGemmABK1 GemmParamsABK1_Wmma_16x16_4x2_per_wave{.ak1 = 8,
|
||||
.bk1 = 8,
|
||||
.m_per_wmma = 16,
|
||||
.n_per_wmma = 16,
|
||||
.m_wmma_per_wave = 4,
|
||||
.n_wmma_per_wave = 2};
|
||||
|
||||
constexpr ThreadBlock ThreadBlock_64_64x64x32{.block_size = 64,
|
||||
.tile_size = {.m = 64, .n = 64, .k = 32}};
|
||||
|
||||
constexpr ThreadBlock ThreadBlock_256_256x256x32{.block_size = 256,
|
||||
.tile_size = {.m = 256, .n = 256, .k = 32}};
|
||||
|
||||
|
||||
@@ -409,6 +409,17 @@ inline std::string to_string<ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_C
|
||||
return oss.str();
|
||||
}
|
||||
|
||||
template <>
|
||||
inline std::string to_string<ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3>(
|
||||
ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3 t)
|
||||
{
|
||||
std::ostringstream oss;
|
||||
oss << to_string(static_cast<ThreadBlock_>(t)) << ","
|
||||
<< to_string(static_cast<WmmaGemmABK1_>(t)) << ","
|
||||
<< to_string(static_cast<Transfer_<>>(t));
|
||||
return oss.str();
|
||||
}
|
||||
|
||||
template <>
|
||||
inline std::string to_string<ConvAlgorithm_DeviceGroupedConvFwdMultipleD_Wmma_CShuffle>(
|
||||
ConvAlgorithm_DeviceGroupedConvFwdMultipleD_Wmma_CShuffle t)
|
||||
|
||||
@@ -29,6 +29,11 @@
|
||||
#include "ck/host_utility/flush_cache.hpp"
|
||||
#include "ck/host_utility/io.hpp"
|
||||
|
||||
#ifdef CK_EXPERIMENTAL_BUILDER
|
||||
#include "ck_tile/builder/reflect/description.hpp"
|
||||
#include "ck_tile/builder/reflect/instance_traits_device_grouped_conv_fwd_multiple_abd_wmma_cshuffle_v3.hpp"
|
||||
#endif
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
@@ -2341,8 +2346,28 @@ struct DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3
|
||||
"The argument pointer is not an object of "
|
||||
"DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle::Argument structure!");
|
||||
}
|
||||
|
||||
#ifdef CK_EXPERIMENTAL_BUILDER
|
||||
std::string GetInstanceString() const override
|
||||
{
|
||||
static_assert(
|
||||
ck_tile::reflect::HasInstanceTraits<DeviceOp>,
|
||||
"InstanceTraits specialization is required. Include the .inc file for this device op.");
|
||||
return ck_tile::reflect::instance_string<DeviceOp>();
|
||||
}
|
||||
|
||||
std::unique_ptr<ck_tile::reflect::Description> describe() const override
|
||||
{
|
||||
return std::make_unique<ck_tile::reflect::InstanceStringDescription>(
|
||||
ck_tile::reflect::instance_string<DeviceOp>());
|
||||
}
|
||||
#endif
|
||||
};
|
||||
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
|
||||
#ifdef CK_EXPERIMENTAL_BUILDER
|
||||
#include "ck_tile/builder/reflect/reflect_device_grouped_conv_fwd_multiple_abd_wmma_cshuffle_v3.inc"
|
||||
#endif
|
||||
|
||||
Reference in New Issue
Block a user