mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-04 05:31:24 +00:00
[CK_BUILDER] ckb: add remaining fwd conv device ops (#3155)
* Add device operation to conv signature. Use unions to hold conv layouts and device operations. * Add predicates for all device op instances. * Use the device op signature for validation. * Fix ckb CMakeLists.txt file for tests. * Fix building CK Builder instance traits after the introduction of direct load template parameter in CK. * Fix clang-formatting. * add device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk * Add full DL configurability with Option A implementation - Added 5 DL descriptor structs (39 configurable parameters) - Added 10 C++20 concepts for type-safe validation - Updated factory to read all parameters from descriptors - Updated test helper to populate all descriptors - All tests passing (13/13 including 3 new DL tests) * Add factory and test support for DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor - Add factory specialization for Large_Tensor device operation (conv_factory.hpp lines 1145-1265) - Add macro collision workaround using pragma push/pop (conv_factory.hpp lines 43-51) - Add test helper function run_test_DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor - Add builder test file test_ckb_conv_fwd_2d_large_tensor_fp16.cpp with 2 test cases - Update CMakeLists.txt to include new test file - Reuse existing ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle descriptor - Map all 42 template parameters identical to regular XDL CShuffle - All 15 builder tests passing including 2 new Large_Tensor tests Completes Task 350: All 4 forward convolution device operations now supported in CK Builder. * Update copyright headers to new format - Change copyright format to: Copyright (C) Advanced Micro Devices, Inc., or its affiliates. 
- Reorder headers: Copyright first, then SPDX-License-Identifier - Updated files: * experimental/builder/test/conv/test_ckb_conv_fwd_2d_dl_fp16.cpp * experimental/builder/test/conv/test_ckb_conv_fwd_2d_large_tensor_fp16.cpp * experimental/builder/include/ck_tile/builder/device_op_types.hpp * fix c++ 18 format * Fix clang-format-18 error in device_op_types.hpp --------- Co-authored-by: Ville Pietilä <ville.pietila@amd.com> Co-authored-by: Ville Pietilä <188998872+vpietila-amd@users.noreply.github.com>
This commit is contained in:
@@ -183,4 +183,87 @@ concept SpecifiesLoopScheduler = requires {
|
||||
{ T::loop_scheduler } -> std::convertible_to<PipelineScheduler>;
|
||||
};
|
||||
|
||||
/******************************************** */
|
||||
/* DL-specific descriptors and requirements */
|
||||
/******************************************** */
|
||||
|
||||
// Concept for DL thread configuration
|
||||
template <typename T>
|
||||
concept DlThreadConfigDescriptor = requires(T t) {
|
||||
{ t.k0_per_block } -> std::convertible_to<size_t>;
|
||||
{ t.k1 } -> std::convertible_to<size_t>;
|
||||
{ t.m1_per_thread } -> std::convertible_to<size_t>;
|
||||
{ t.n1_per_thread } -> std::convertible_to<size_t>;
|
||||
{ t.k_per_thread } -> std::convertible_to<size_t>;
|
||||
};
|
||||
|
||||
// Concept for DL thread cluster
|
||||
template <typename T>
|
||||
concept DlThreadClusterDescriptor = requires(T t) {
|
||||
{ t.m1_xs } -> std::convertible_to<std::array<size_t, 2>>;
|
||||
{ t.n1_xs } -> std::convertible_to<std::array<size_t, 2>>;
|
||||
};
|
||||
|
||||
// Concept for DL block transfer K0_M0_M1_K1 format
|
||||
template <typename T>
|
||||
concept DlBlockTransferK0M0M1K1Descriptor = requires(T t) {
|
||||
{ t.thread_slice_lengths } -> std::convertible_to<std::array<size_t, 4>>;
|
||||
{ t.thread_cluster_lengths } -> std::convertible_to<std::array<size_t, 4>>;
|
||||
{ t.thread_cluster_arrange_order } -> std::convertible_to<std::array<size_t, 4>>;
|
||||
{ t.src_access_order } -> std::convertible_to<std::array<size_t, 4>>;
|
||||
{ t.src_vector_tensor_lengths } -> std::convertible_to<std::array<size_t, 4>>;
|
||||
{ t.src_vector_tensor_contiguous_dim_order } -> std::convertible_to<std::array<size_t, 4>>;
|
||||
{ t.dst_vector_tensor_lengths } -> std::convertible_to<std::array<size_t, 4>>;
|
||||
};
|
||||
|
||||
// Concept for DL block transfer K0_N0_N1_K1 format
|
||||
template <typename T>
|
||||
concept DlBlockTransferK0N0N1K1Descriptor = requires(T t) {
|
||||
{ t.thread_slice_lengths } -> std::convertible_to<std::array<size_t, 4>>;
|
||||
{ t.thread_cluster_lengths } -> std::convertible_to<std::array<size_t, 4>>;
|
||||
{ t.thread_cluster_arrange_order } -> std::convertible_to<std::array<size_t, 4>>;
|
||||
{ t.src_access_order } -> std::convertible_to<std::array<size_t, 4>>;
|
||||
{ t.src_vector_tensor_lengths } -> std::convertible_to<std::array<size_t, 4>>;
|
||||
{ t.src_vector_tensor_contiguous_dim_order } -> std::convertible_to<std::array<size_t, 4>>;
|
||||
{ t.dst_vector_tensor_lengths } -> std::convertible_to<std::array<size_t, 4>>;
|
||||
};
|
||||
|
||||
// Concept for DL C thread transfer
|
||||
template <typename T>
|
||||
concept DlCThreadTransferDescriptor = requires(T t) {
|
||||
{ t.src_dst_access_order } -> std::convertible_to<std::array<size_t, 6>>;
|
||||
{ t.src_dst_vector_dim } -> std::convertible_to<size_t>;
|
||||
{ t.dst_scalar_per_vector } -> std::convertible_to<size_t>;
|
||||
};
|
||||
|
||||
// Concept to check if algorithm specifies DL thread config
|
||||
template <typename T>
|
||||
concept SpecifiesDlThreadConfig = requires {
|
||||
{ T::dl_thread_config } -> DlThreadConfigDescriptor;
|
||||
};
|
||||
|
||||
// Concept to check if algorithm specifies DL thread cluster
|
||||
template <typename T>
|
||||
concept SpecifiesDlThreadCluster = requires {
|
||||
{ T::dl_thread_cluster } -> DlThreadClusterDescriptor;
|
||||
};
|
||||
|
||||
// Concept to check if algorithm specifies DL A block transfer
|
||||
template <typename T>
|
||||
concept SpecifiesDlBlockTransferA = requires {
|
||||
{ T::dl_block_transfer_a } -> DlBlockTransferK0M0M1K1Descriptor;
|
||||
};
|
||||
|
||||
// Concept to check if algorithm specifies DL B block transfer
|
||||
template <typename T>
|
||||
concept SpecifiesDlBlockTransferB = requires {
|
||||
{ T::dl_block_transfer_b } -> DlBlockTransferK0N0N1K1Descriptor;
|
||||
};
|
||||
|
||||
// Concept to check if algorithm specifies DL C thread transfer
|
||||
template <typename T>
|
||||
concept SpecifiesDlCThreadTransfer = requires {
|
||||
{ T::dl_c_thread_transfer } -> DlCThreadTransferDescriptor;
|
||||
};
|
||||
|
||||
} // namespace ck_tile::builder
|
||||
|
||||
@@ -36,9 +36,21 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp"
|
||||
// WORKAROUND: Macro namespace collision in upstream CK device operation headers.
|
||||
// device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp (line 41) and
|
||||
// device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp (line 51) both define
|
||||
// GridwiseGemmTemplateParameters macro without #undef, causing redefinition errors.
|
||||
// Use pragma push/pop to isolate the Large_Tensor header's macro scope.
|
||||
#pragma push_macro("GridwiseGemmTemplateParameters")
|
||||
#ifdef GridwiseGemmTemplateParameters
|
||||
#undef GridwiseGemmTemplateParameters
|
||||
#endif
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp"
|
||||
#pragma pop_macro("GridwiseGemmTemplateParameters")
|
||||
#include "ck_tile/builder/conv_signature_concepts.hpp"
|
||||
#include "ck_tile/builder/conv_algorithm_concepts.hpp"
|
||||
#include "ck_tile/builder/conv_algorithm_limits.hpp"
|
||||
@@ -990,4 +1002,263 @@ struct ConvFactory<SIGNATURE, ALGORITHM, VERSION>
|
||||
GRIDWISE_GEMM_PIPELINE_VERSION>;
|
||||
};
|
||||
|
||||
// Factory specialization for DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK instance
|
||||
// of a grouped forward convolution kernel using Direct Load (DL) approach.
|
||||
template <ConvSignatureDescriptor auto SIGNATURE,
|
||||
ConvAlgorithmDescriptor auto ALGORITHM,
|
||||
StringLiteral VERSION>
|
||||
requires ConvDirectionIsForward<SIGNATURE> &&
|
||||
ConvDeviceOpIs_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK<SIGNATURE>
|
||||
struct ConvFactory<SIGNATURE, ALGORITHM, VERSION>
|
||||
{
|
||||
static constexpr size_t SPATIAL_DIM = SIGNATURE.spatial_dim;
|
||||
using Layouts = decltype(factory_internal::GetTensorLayout<SIGNATURE.layout,
|
||||
SPATIAL_DIM,
|
||||
ConvDirection::FORWARD>());
|
||||
using Types = factory_internal::ConvTensorTypes<SIGNATURE.data_type>;
|
||||
using Ops = factory_internal::ElementwiseOps<SIGNATURE.elementwise_operation>;
|
||||
using AlgorithmType = decltype(ALGORITHM);
|
||||
|
||||
static_assert(SpecifiesThreadBlock<AlgorithmType>,
|
||||
"The convolution algorithm descriptor must specify thread block info.");
|
||||
static_assert(SpecifiesFwdConcSpecialization<AlgorithmType>,
|
||||
"The convolution algorithm descriptor must specify forward convolution "
|
||||
"specialization.");
|
||||
static_assert(SpecifiesGemmSpecialization<AlgorithmType>,
|
||||
"The convolution algorithm descriptor must specify gemm specialization.");
|
||||
static_assert(SpecifiesDlThreadConfig<AlgorithmType>,
|
||||
"DL algorithm must specify thread config.");
|
||||
static_assert(SpecifiesDlThreadCluster<AlgorithmType>,
|
||||
"DL algorithm must specify thread cluster.");
|
||||
static_assert(SpecifiesDlBlockTransferA<AlgorithmType>,
|
||||
"DL algorithm must specify A block transfer.");
|
||||
static_assert(SpecifiesDlBlockTransferB<AlgorithmType>,
|
||||
"DL algorithm must specify B block transfer.");
|
||||
static_assert(SpecifiesDlCThreadTransfer<AlgorithmType>,
|
||||
"DL algorithm must specify C thread transfer.");
|
||||
|
||||
static constexpr auto FWD_CONV_SPECIALIZATION =
|
||||
factory_internal::SetFwdConvSpecialization<ALGORITHM>();
|
||||
static constexpr auto GEMM_SPECIALIZATION =
|
||||
factory_internal::SetGemmSpecialization<ALGORITHM>();
|
||||
|
||||
static constexpr auto BLOCK = factory_internal::SetThreadBlockInfo<ALGORITHM>();
|
||||
|
||||
// DL-specific parameters from algorithm descriptor
|
||||
static constexpr auto DL_THREAD_CFG = ALGORITHM.dl_thread_config;
|
||||
static constexpr ck::index_t K0PerBlock = DL_THREAD_CFG.k0_per_block;
|
||||
static constexpr ck::index_t K1 = DL_THREAD_CFG.k1;
|
||||
static constexpr ck::index_t M1PerThread = DL_THREAD_CFG.m1_per_thread;
|
||||
static constexpr ck::index_t N1PerThread = DL_THREAD_CFG.n1_per_thread;
|
||||
static constexpr ck::index_t KPerThread = DL_THREAD_CFG.k_per_thread;
|
||||
|
||||
// Thread cluster from descriptor
|
||||
static constexpr auto DL_CLUSTER = ALGORITHM.dl_thread_cluster;
|
||||
using M1N1ThreadClusterM1Xs = to_sequence_v<DL_CLUSTER.m1_xs>;
|
||||
using M1N1ThreadClusterN1Xs = to_sequence_v<DL_CLUSTER.n1_xs>;
|
||||
|
||||
// A Block Transfer from descriptor - K0_M0_M1_K1 tensor format
|
||||
static constexpr auto DL_A_TRANSFER = ALGORITHM.dl_block_transfer_a;
|
||||
using ABlockTransferThreadSliceLengths_K0_M0_M1_K1 =
|
||||
to_sequence_v<DL_A_TRANSFER.thread_slice_lengths>;
|
||||
using ABlockTransferThreadClusterLengths_K0_M0_M1_K1 =
|
||||
to_sequence_v<DL_A_TRANSFER.thread_cluster_lengths>;
|
||||
using ABlockTransferThreadClusterArrangeOrder =
|
||||
to_sequence_v<DL_A_TRANSFER.thread_cluster_arrange_order>;
|
||||
using ABlockTransferSrcAccessOrder = to_sequence_v<DL_A_TRANSFER.src_access_order>;
|
||||
using ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1 =
|
||||
to_sequence_v<DL_A_TRANSFER.src_vector_tensor_lengths>;
|
||||
using ABlockTransferSrcVectorTensorContiguousDimOrder =
|
||||
to_sequence_v<DL_A_TRANSFER.src_vector_tensor_contiguous_dim_order>;
|
||||
using ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1 =
|
||||
to_sequence_v<DL_A_TRANSFER.dst_vector_tensor_lengths>;
|
||||
|
||||
// B Block Transfer from descriptor - K0_N0_N1_K1 tensor format
|
||||
static constexpr auto DL_B_TRANSFER = ALGORITHM.dl_block_transfer_b;
|
||||
using BBlockTransferThreadSliceLengths_K0_N0_N1_K1 =
|
||||
to_sequence_v<DL_B_TRANSFER.thread_slice_lengths>;
|
||||
using BBlockTransferThreadClusterLengths_K0_N0_N1_K1 =
|
||||
to_sequence_v<DL_B_TRANSFER.thread_cluster_lengths>;
|
||||
using BBlockTransferThreadClusterArrangeOrder =
|
||||
to_sequence_v<DL_B_TRANSFER.thread_cluster_arrange_order>;
|
||||
using BBlockTransferSrcAccessOrder = to_sequence_v<DL_B_TRANSFER.src_access_order>;
|
||||
using BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1 =
|
||||
to_sequence_v<DL_B_TRANSFER.src_vector_tensor_lengths>;
|
||||
using BBlockTransferSrcVectorTensorContiguousDimOrder =
|
||||
to_sequence_v<DL_B_TRANSFER.src_vector_tensor_contiguous_dim_order>;
|
||||
using BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1 =
|
||||
to_sequence_v<DL_B_TRANSFER.dst_vector_tensor_lengths>;
|
||||
|
||||
// C Thread Transfer from descriptor
|
||||
static constexpr auto DL_C_TRANSFER = ALGORITHM.dl_c_thread_transfer;
|
||||
using CThreadTransferSrcDstAccessOrder = to_sequence_v<DL_C_TRANSFER.src_dst_access_order>;
|
||||
static constexpr ck::index_t CThreadTransferSrcDstVectorDim = DL_C_TRANSFER.src_dst_vector_dim;
|
||||
static constexpr ck::index_t CThreadTransferDstScalarPerVector =
|
||||
DL_C_TRANSFER.dst_scalar_per_vector;
|
||||
|
||||
// The DL forward convolution kernel class instance
|
||||
using Instance = ck::tensor_operation::device::DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK<
|
||||
SPATIAL_DIM,
|
||||
typename Types::ADataType,
|
||||
typename Types::BDataType,
|
||||
typename Types::DsDataTypes,
|
||||
typename Types::EDataType,
|
||||
typename Types::AccDataType,
|
||||
typename Layouts::ALayout,
|
||||
typename Layouts::BLayout,
|
||||
typename Layouts::DsLayout,
|
||||
typename Layouts::ELayout,
|
||||
typename Ops::AElementwiseOp,
|
||||
typename Ops::BElementwiseOp,
|
||||
typename Ops::CDEElementwiseOp,
|
||||
FWD_CONV_SPECIALIZATION,
|
||||
GEMM_SPECIALIZATION,
|
||||
BLOCK.block_size,
|
||||
BLOCK.per_block.m,
|
||||
BLOCK.per_block.n,
|
||||
K0PerBlock,
|
||||
K1,
|
||||
M1PerThread,
|
||||
N1PerThread,
|
||||
KPerThread,
|
||||
M1N1ThreadClusterM1Xs,
|
||||
M1N1ThreadClusterN1Xs,
|
||||
ABlockTransferThreadSliceLengths_K0_M0_M1_K1,
|
||||
ABlockTransferThreadClusterLengths_K0_M0_M1_K1,
|
||||
ABlockTransferThreadClusterArrangeOrder,
|
||||
ABlockTransferSrcAccessOrder,
|
||||
ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1,
|
||||
ABlockTransferSrcVectorTensorContiguousDimOrder,
|
||||
ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1,
|
||||
BBlockTransferThreadSliceLengths_K0_N0_N1_K1,
|
||||
BBlockTransferThreadClusterLengths_K0_N0_N1_K1,
|
||||
BBlockTransferThreadClusterArrangeOrder,
|
||||
BBlockTransferSrcAccessOrder,
|
||||
BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1,
|
||||
BBlockTransferSrcVectorTensorContiguousDimOrder,
|
||||
BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1,
|
||||
CThreadTransferSrcDstAccessOrder,
|
||||
CThreadTransferSrcDstVectorDim,
|
||||
CThreadTransferDstScalarPerVector>;
|
||||
};
|
||||
|
||||
// Factory specialization for DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor instance
|
||||
// of a grouped forward convolution kernel with large tensor support (N-splitting).
|
||||
template <ConvSignatureDescriptor auto SIGNATURE,
|
||||
ConvAlgorithmDescriptor auto ALGORITHM,
|
||||
StringLiteral VERSION>
|
||||
requires ConvDirectionIsForward<SIGNATURE> &&
|
||||
ConvDeviceOpIs_DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor<SIGNATURE>
|
||||
struct ConvFactory<SIGNATURE, ALGORITHM, VERSION>
|
||||
{
|
||||
static constexpr size_t SPATIAL_DIM = SIGNATURE.spatial_dim;
|
||||
using Layouts = decltype(factory_internal::GetTensorLayout<SIGNATURE.layout,
|
||||
SPATIAL_DIM,
|
||||
ConvDirection::FORWARD>());
|
||||
using Types = factory_internal::ConvTensorTypes<SIGNATURE.data_type>;
|
||||
using Ops = factory_internal::ElementwiseOps<SIGNATURE.elementwise_operation>;
|
||||
using AlgorithmType = decltype(ALGORITHM);
|
||||
|
||||
static_assert(SpecifiesThreadBlock<AlgorithmType>,
|
||||
"The convolution algorithm descriptor must specify thread block info.");
|
||||
static_assert(SpecifiesGridwiseXdlGemm<AlgorithmType>,
|
||||
"The convolution algorithm descriptor must specify gridwise GEMM info.");
|
||||
static_assert(SpecifiesBlockTransfer<AlgorithmType>,
|
||||
"The convolution algorithm descriptor must specify block transfer info.");
|
||||
static_assert(SpecifiesLdsTransfer<AlgorithmType>,
|
||||
"The convolution algorithm descriptor must specify LDS transfer info.");
|
||||
static_assert(
|
||||
SpecifiesThreadClusterAccessOrder<AlgorithmType>,
|
||||
"The convolution algorithm descriptor must specify thread cluster access order info.");
|
||||
static_assert(SpecifiesSourceAccessOrder<AlgorithmType>,
|
||||
"The convolution algorithm descriptor must specify source access order info.");
|
||||
static_assert(SpecifiesFwdConcSpecialization<AlgorithmType>,
|
||||
"The convolution algorithm descriptor must specify forward convolution "
|
||||
"specialization.");
|
||||
static_assert(SpecifiesGemmSpecialization<AlgorithmType>,
|
||||
"The convolution algorithm descriptor must specify gemm specialization.");
|
||||
static_assert(SpecifiesNumPrefetchStages<AlgorithmType>,
|
||||
"The convolution algorithm descriptor must specify number of prefetch stages.");
|
||||
static_assert(SpecifiesLoopScheduler<AlgorithmType>,
|
||||
"The convolution algorithm descriptor must specify loop scheduler.");
|
||||
|
||||
static constexpr auto FWD_CONV_SPECIALIZATION =
|
||||
factory_internal::SetFwdConvSpecialization<ALGORITHM>();
|
||||
static constexpr auto GEMM_SPECIALIZATION =
|
||||
factory_internal::SetGemmSpecialization<ALGORITHM>();
|
||||
static constexpr factory_internal::ConvSpec SPECIALIZATION{.conv_spec = FWD_CONV_SPECIALIZATION,
|
||||
.gemm_spec = GEMM_SPECIALIZATION};
|
||||
|
||||
static constexpr auto LOOP_SCHEDULER = factory_internal::SetLoopScheduler<ALGORITHM>();
|
||||
static constexpr auto BLOCK = factory_internal::SetThreadBlockInfo<ALGORITHM>();
|
||||
static constexpr auto GRIDWISE_GEMM = ALGORITHM.gridwise_gemm;
|
||||
static constexpr auto A_BLOCK_TRANSFER =
|
||||
factory_internal::SetFwdConvABlockTransfer<ALGORITHM>();
|
||||
static constexpr auto B_BLOCK_TRANSFER =
|
||||
factory_internal::SetFwdConvBBlockTransfer<ALGORITHM>();
|
||||
static constexpr auto C_BLOCK_TRANSFER =
|
||||
factory_internal::SetCBlockTransfer<SIGNATURE, ALGORITHM>();
|
||||
|
||||
// Check limits for the algorithm parameters.
|
||||
static_assert(InputVectorTransferLimits<A_BLOCK_TRANSFER>);
|
||||
static_assert(InputVectorTransferLimits<B_BLOCK_TRANSFER>);
|
||||
static_assert(OutputVectorTransferLimits<C_BLOCK_TRANSFER>);
|
||||
static_assert(AccessOrderLimits<A_BLOCK_TRANSFER.thread_cluster_order>);
|
||||
static_assert(AccessOrderLimits<B_BLOCK_TRANSFER.thread_cluster_order>);
|
||||
static_assert(AccessOrderLimits<A_BLOCK_TRANSFER.src_access_order>);
|
||||
static_assert(AccessOrderLimits<B_BLOCK_TRANSFER.src_access_order>);
|
||||
|
||||
// The forward convolution kernel class instance with large tensor support.
|
||||
using Instance =
|
||||
ck::tensor_operation::device::DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor<
|
||||
SPATIAL_DIM,
|
||||
typename Layouts::ALayout,
|
||||
typename Layouts::BLayout,
|
||||
typename Layouts::DsLayout,
|
||||
typename Layouts::ELayout,
|
||||
typename Types::ADataType,
|
||||
typename Types::BDataType,
|
||||
typename Types::AccDataType,
|
||||
typename Types::CShuffleDataType,
|
||||
typename Types::DsDataTypes,
|
||||
typename Types::EDataType,
|
||||
typename Ops::AElementwiseOp,
|
||||
typename Ops::BElementwiseOp,
|
||||
typename Ops::CDEElementwiseOp,
|
||||
SPECIALIZATION.conv_spec,
|
||||
SPECIALIZATION.gemm_spec,
|
||||
ALGORITHM.num_gemm_k_prefetch_stages,
|
||||
BLOCK.block_size,
|
||||
BLOCK.per_block.m,
|
||||
BLOCK.per_block.n,
|
||||
BLOCK.per_block.k,
|
||||
GRIDWISE_GEMM.ak1,
|
||||
GRIDWISE_GEMM.bk1,
|
||||
GRIDWISE_GEMM.m_per_xdl,
|
||||
GRIDWISE_GEMM.n_per_xdl,
|
||||
GRIDWISE_GEMM.m_xdl_per_wave,
|
||||
GRIDWISE_GEMM.n_xdl_per_wave,
|
||||
to_sequence_v<A_BLOCK_TRANSFER.thread_cluster_dims>,
|
||||
to_sequence_v<A_BLOCK_TRANSFER.thread_cluster_order>,
|
||||
to_sequence_v<A_BLOCK_TRANSFER.src_access_order>,
|
||||
A_BLOCK_TRANSFER.src_vector_dim,
|
||||
A_BLOCK_TRANSFER.src_scalar_per_vector,
|
||||
A_BLOCK_TRANSFER.lds_dst_scalar_per_vector,
|
||||
A_BLOCK_TRANSFER.lds_padding,
|
||||
to_sequence_v<B_BLOCK_TRANSFER.thread_cluster_dims>,
|
||||
to_sequence_v<B_BLOCK_TRANSFER.thread_cluster_order>,
|
||||
to_sequence_v<B_BLOCK_TRANSFER.src_access_order>,
|
||||
B_BLOCK_TRANSFER.src_vector_dim,
|
||||
B_BLOCK_TRANSFER.src_scalar_per_vector,
|
||||
B_BLOCK_TRANSFER.lds_dst_scalar_per_vector,
|
||||
B_BLOCK_TRANSFER.lds_padding,
|
||||
C_BLOCK_TRANSFER.m_per_wave_per_shuffle,
|
||||
C_BLOCK_TRANSFER.n_per_wave_per_shuffle,
|
||||
to_sequence_v<C_BLOCK_TRANSFER.thread_cluster_dims>,
|
||||
C_BLOCK_TRANSFER.scalar_per_vector,
|
||||
typename Types::AComputeType,
|
||||
typename Types::BComputeType,
|
||||
LOOP_SCHEDULER>;
|
||||
};
|
||||
|
||||
} // namespace ck_tile::builder
|
||||
|
||||
@@ -0,0 +1,22 @@
|
||||
// Copyright (C) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
namespace ck_tile::builder {
|
||||
|
||||
// Identifies a CK device operation type.
// The builder uses this value to pick which device operation template to
// instantiate for the user's requirements. Enumerator values are implicit
// (declaration order, starting at 0).
enum class DeviceOpType
{
    // ---- Forward convolution, non-grouped ----

    // Maps to: DeviceConvFwd
    // TODO: No implementation with tuning params exists yet.
    CONV_FWD,

    // ---- Forward convolution, grouped ----

    // Maps to: DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
    GROUPED_CONV_FWD_MULTIPLE_ABD,

    // Maps to: DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
    GROUPED_CONV_FWD_MULTIPLE_ABD_XDL_CSHUFFLE_V3,
};
|
||||
|
||||
} // namespace ck_tile::builder
|
||||
Reference in New Issue
Block a user