[CK_BUILDER] Add bwd weight factories (#3509)

* Add placeholder test.

* Initial conv bwd weight factory.

* Conv builder test refactoring.

* Add missing pieces to bwd weight factory.

* Improve compile-time error message when no matching factory is found.

* Use a macro to ensure automatic matching between concepts and their string representations.

* Improve compile time diagnostics.

* Small improvements.

* Improve missing member/wrong type compile-time errors.

* Improve compile time diagnostics.

* Concept bug fixes.

* Remove debug assert.

* Update algorithm signature diagnostics.

* Factory bug fixes.

* First functional version of bwd weight conv factory.

* Refactor handling of GEMM-K batch template parameter in conv bwd weight factory.

* Concept improvements.

* Improve concept diagnostics.

* Introduce a common size type for concepts.

* Update compile-time diagnostics to use the size type.

* Update conv specialization enum.

* Fix fwd conv builder tests.

* Fix smoke tests.

* Separate bwd weight and bwd data tests into separate targets.

* Clean-up CK Tile builder tests.

* Add bwd weight XDL CShuffle V3 factory.

* Build conv bwd weight v3 instances successfully.

* Add instance traits for DeviceGroupedConvBwdWeight_Xdl_CShuffleV3.

* Test fix.

* Add instance traits for bwd weight algorithms.

* Add unit tests for instance strings.

* Build new instance traits unit tests but exclude WMMA for now.

* Added factory for DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle.

* Conv bwd weight DL factory.

* Final implementation for bwd weight DL factory.

* Add test for creating DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle instance.

* Add factory for DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle

* Treat ref algorithm the same way as real algorithms in the dispatcher.

* Refactor large tensor support and WMMA configuration.

* Add factory and tests for DeviceGroupedConvBwdWeight_Wmma_CShuffleV3.

* Update Readme.

* Fix WMMA bwd weight tests.

* Added factory and tests for DeviceGroupedConvBwdWeightTwoStage_Wmma_CShuffleV3.

* Factory and tests for DeviceGroupedConvBwdWeight_Wmma_CShuffle.

* Dispatching for DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffle.

* Add factory for DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3

* Fix DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3 factory and compute types for input and output tensors in bwd weight convs.

* Fix fwd factories after refactoring.

* clang-format

* Move compile-time diagnostics to a separate branch.

* Fix ref algorithm dispatching.

* Fix smoke tests.

* clang-format

* Fix factory for regular WMMA conv bwd weight.

* Clarify builder Readme.

* Remove obsolete test file.

* Fix test after merge.

* clang-format

* Remove the C++26 extensions.

* Unify conv elementwise ops and layout definitions for fwd and bwd directions.

* Remove old layout and elementwise ops.

* Unify handling of conv tensor types between fwd and bwd directions.

* Unify block transfer for fwd and bwd directions. Rename ThreadSliceDim to ThreadClusterRank.

* Make BlockTransferDescriptor concept parametrized. Introduce a common TileTransferParameters concept for conv algorithms.

* clang-format

---------

Co-authored-by: Ville Pietilä <>
This commit is contained in:
Ville Pietilä
2026-01-13 18:12:38 +02:00
committed by GitHub
parent 710fa1fd3d
commit 9908a87c31
69 changed files with 2956 additions and 832 deletions


@@ -15,29 +15,31 @@ namespace ck_tile::builder {
/* Descriptors for individual elements of the algorithm description */
/********************************************************************/
// Common concept for size-related fields
template <typename T>
concept SizeType = std::unsigned_integral<std::remove_cvref_t<T>>;
// Concept for thread block dimensions for a GEMM problem.
template <typename T>
concept ThreadBlockDescriptor = requires(T t) {
{ t.block_size } -> std::convertible_to<size_t>;
{ t.tile_size.m } -> std::convertible_to<size_t>;
{ t.tile_size.n } -> std::convertible_to<size_t>;
{ t.tile_size.k } -> std::convertible_to<size_t>;
{ t.block_size } -> SizeType;
{ t.tile_size.m } -> SizeType;
{ t.tile_size.n } -> SizeType;
{ t.tile_size.k } -> SizeType;
};
// Concept for parameters that describe a gridwise XDL GEMM problem.
template <typename T>
concept GridwiseXdlGemmDescriptor = requires(T t) {
{ t.ak1 } -> std::convertible_to<size_t>;
{ t.bk1 } -> std::convertible_to<size_t>;
{ t.m_per_xdl } -> std::convertible_to<size_t>;
{ t.n_per_xdl } -> std::convertible_to<size_t>;
{ t.m_xdl_per_wave } -> std::convertible_to<size_t>;
{ t.n_xdl_per_wave } -> std::convertible_to<size_t>;
{ t.m_per_xdl } -> SizeType;
{ t.n_per_xdl } -> SizeType;
{ t.m_xdl_per_wave } -> SizeType;
{ t.n_xdl_per_wave } -> SizeType;
};
// Concept for parameters that describe a block GEMM pipeline.
template <typename T>
concept BlockGemmDescriptor = requires(T t) {
concept BlockGemmPipelineDescriptor = requires(T t) {
{ t.pipeline_version } -> std::convertible_to<PipelineVersion>;
{ t.scheduler } -> std::convertible_to<PipelineScheduler>;
};
@@ -45,37 +47,48 @@ concept BlockGemmDescriptor = requires(T t) {
// Concept for parameters that describe a gridwise WMMA GEMM problem.
template <typename T>
concept GridwiseWmmaGemmDescriptor = requires(T t) {
{ t.k1 } -> std::convertible_to<size_t>;
{ t.m_per_wmma } -> std::convertible_to<size_t>;
{ t.n_per_wmma } -> std::convertible_to<size_t>;
{ t.m_wmma_per_wave } -> std::convertible_to<size_t>;
{ t.n_wmma_per_wave } -> std::convertible_to<size_t>;
{ t.pipeline_version } -> std::convertible_to<PipelineVersion>;
{ t.k1 } -> SizeType;
{ t.m_per_wmma } -> SizeType;
{ t.n_per_wmma } -> SizeType;
{ t.m_wmma_per_wave } -> SizeType;
{ t.n_wmma_per_wave } -> SizeType;
};
// Concept for vectorized data transfer for convolution input tensors.
template <typename T>
concept BlockTransferDescriptor = requires(T t) {
{ t.k0 } -> std::convertible_to<size_t>;
{ t.m_n } -> std::convertible_to<size_t>;
{ t.k1 } -> std::convertible_to<size_t>;
concept BlockTransferDescriptor3D = requires(T t) {
{ t.k0 } -> SizeType;
{ t.m_n } -> SizeType;
{ t.k1 } -> SizeType;
};
template <typename T>
concept BlockTransferDescriptor4D = requires(T t) {
{ t.k0 } -> SizeType;
{ t.m_n } -> SizeType;
{ t.k1 } -> SizeType;
{ t.k_batch_size } -> SizeType;
};
template <typename T, size_t ThreadClusterRank>
concept BlockTransferDescriptor = (ThreadClusterRank == 3 && BlockTransferDescriptor3D<T>) ||
(ThreadClusterRank == 4 && BlockTransferDescriptor4D<T>);
// Concept for thread cluster dimensions for GEMM output tensor.
template <typename T>
concept ThreadClusterDescriptor = requires(T t) {
{ t.m_block } -> std::convertible_to<size_t>;
{ t.m_wave_per_xdl } -> std::convertible_to<size_t>;
{ t.n_block } -> std::convertible_to<size_t>;
{ t.n_wave_per_xdl } -> std::convertible_to<size_t>;
{ t.m_block } -> SizeType;
{ t.m_wave_per_xdl } -> SizeType;
{ t.n_block } -> SizeType;
{ t.n_wave_per_xdl } -> SizeType;
};
// Concept for the LDS transfer for the convolution input tensors.
template <typename T>
concept LdsTransferDescriptor = requires(T t) {
{ t.src_vector_dim } -> std::convertible_to<size_t>;
{ t.src_scalar_per_vector } -> std::convertible_to<size_t>;
{ t.lds_dst_scalar_per_vector } -> std::convertible_to<size_t>;
{ t.src_vector_dim } -> SizeType;
{ t.src_scalar_per_vector } -> SizeType;
{ t.lds_dst_scalar_per_vector } -> SizeType;
{ t.is_direct_load } -> std::convertible_to<bool>;
{ t.lds_padding } -> std::convertible_to<bool>;
};
@@ -84,33 +97,35 @@ concept LdsTransferDescriptor = requires(T t) {
// LDS).
template <typename T>
concept EpilogueDescriptor = requires(T t) {
{ t.m_xdl_per_wave_per_shuffle } -> std::convertible_to<size_t>;
{ t.n_per_wave_per_shuffle } -> std::convertible_to<size_t>;
{ t.scalar_per_vector } -> std::convertible_to<size_t>;
{ t.m_xdl_per_wave_per_shuffle } -> SizeType;
{ t.n_per_wave_per_shuffle } -> SizeType;
{ t.scalar_per_vector } -> SizeType;
};
// Concept for the thread cluster access order
template <typename T>
concept AccessOrderDescriptor = requires(T t) {
{ t.order } -> std::convertible_to<std::array<size_t, 3>>;
} || requires(T t) {
{ t.order } -> std::convertible_to<std::array<size_t, 4>>;
};
// Concept for thread block dimensions for a GEMM problem for CK Tile (Block
// size is deduced from block gemm structure).
template <typename T>
concept TileThreadBlockDescriptor = requires(T t) {
{ t.tile_size.m } -> std::convertible_to<size_t>;
{ t.tile_size.n } -> std::convertible_to<size_t>;
{ t.tile_size.k } -> std::convertible_to<size_t>;
{ t.tile_size.m } -> SizeType;
{ t.tile_size.n } -> SizeType;
{ t.tile_size.k } -> SizeType;
};
// Concept for scalar-per-vector transfer widths for a GEMM problem for CK Tile.
template <typename T>
concept TileTransferDescriptor = requires(T t) {
{ t.a_scalar_per_vector } -> std::convertible_to<size_t>;
{ t.b_scalar_per_vector } -> std::convertible_to<size_t>;
{ t.c_scalar_per_vector } -> std::convertible_to<size_t>;
{ t.a_scalar_per_vector } -> SizeType;
{ t.b_scalar_per_vector } -> SizeType;
{ t.c_scalar_per_vector } -> SizeType;
};
// Concept to check if struct specifies block GEMM (CK Tile).
@@ -159,30 +174,51 @@ concept SpecifiesTileThreadBlock = requires {
// Concept to check if a struct specifies gridwise XDL GEMM info.
template <typename T>
concept SpecifiesGridwiseXdlGemm = requires {
{ T::gridwise_gemm } -> GridwiseXdlGemmDescriptor;
concept GridwiseFwdXdlGemmDescriptor = requires(T t) {
{ t.ak1 } -> SizeType;
{ t.bk1 } -> SizeType;
{ t.xdl_params } -> GridwiseXdlGemmDescriptor;
};
// Concept for parameters that describe a gridwise bwd XDL GEMM problem.
template <typename T>
concept GridwiseBwdXdlGemmDescriptor = requires(T t) {
{ t.k1 } -> SizeType;
{ t.xdl_params } -> GridwiseXdlGemmDescriptor;
};
// Concept to check if a struct specifies gridwise fwd XDL GEMM info.
template <typename T>
concept SpecifiesGridwiseFwdXdlGemm = requires(T t) {
{ t.gridwise_gemm } -> GridwiseFwdXdlGemmDescriptor;
};
// Concept to check if a struct specifies gridwise bwd XDL GEMM info.
template <typename T>
concept SpecifiesGridwiseBwdXdlGemm = requires(T t) {
{ t.gridwise_gemm } -> GridwiseBwdXdlGemmDescriptor;
};
// Concept to check if a struct specifies gridwise WMMA GEMM info.
template <typename T>
concept SpecifiesGridwiseWmmaGemm = requires {
{ T::gridwise_gemm } -> GridwiseWmmaGemmDescriptor;
concept SpecifiesGridwiseWmmaGemm = requires(T t) {
{ t.gridwise_gemm } -> GridwiseWmmaGemmDescriptor;
};
// Concept to check if a struct specifies convolution input and output block transfer info.
template <typename T>
template <typename T, size_t ThreadClusterRank = 3>
concept SpecifiesBlockTransfer = requires(T t) {
{ T::transfer.a.block_transfer } -> BlockTransferDescriptor;
{ T::transfer.b.block_transfer } -> BlockTransferDescriptor;
{ T::transfer.a.block_transfer } -> BlockTransferDescriptor<ThreadClusterRank>;
{ T::transfer.b.block_transfer } -> BlockTransferDescriptor<ThreadClusterRank>;
{ T::transfer.c.thread_cluster_dims } -> ThreadClusterDescriptor;
};
// Concept to check if a struct specifies convolution scalar-per-vector info for A, B, and C.
template <typename T>
concept SpecifiesTileTransfer = requires(T t) {
{ T::transfer.a_scalar_per_vector } -> std::convertible_to<size_t>;
{ T::transfer.b_scalar_per_vector } -> std::convertible_to<size_t>;
{ T::transfer.c_scalar_per_vector } -> std::convertible_to<size_t>;
{ T::transfer.a_scalar_per_vector } -> SizeType;
{ T::transfer.b_scalar_per_vector } -> SizeType;
{ T::transfer.c_scalar_per_vector } -> SizeType;
};
// Concept to check if a struct specifies LDS transfer info for tensors A, B, and C.
@@ -210,8 +246,12 @@ concept SpecifiesSourceAccessOrder = requires(T t) {
// Concept to check if struct specifies block GEMM.
template <typename T>
concept SpecifiesBlockGemm = requires {
{ T::block_gemm.pipeline_version } -> std::convertible_to<PipelineVersion>;
{ T::block_gemm.scheduler } -> std::convertible_to<PipelineScheduler>;
{ T::block_gemm_pipeline } -> BlockGemmPipelineDescriptor;
};
template <typename T>
concept SpecifiesGridwiseGemmPipeline = requires {
{ T::pipeline_version } -> std::convertible_to<PipelineVersion>;
};
// Concept to check if struct specifies block GEMM (CK Tile).
@@ -244,7 +284,12 @@ concept SpecifiesTileConvSpecialization = requires {
template <typename T>
concept SpecifiesFwdConvSpecialization = requires {
{ T::fwd_specialization } -> std::convertible_to<ConvFwdSpecialization>;
{ T::fwd_specialization } -> std::convertible_to<ConvSpecialization>;
};
template <typename T>
concept SpecifiesBwdWeightConvSpecialization = requires {
{ T::bwd_weight_specialization } -> std::convertible_to<ConvSpecialization>;
};
template <typename T>
@@ -254,12 +299,12 @@ concept SpecifiesGemmSpecialization = requires {
template <typename T>
concept SpecifiesNumPrefetchStages = requires {
{ T::num_gemm_k_prefetch_stages } -> std::convertible_to<size_t>;
{ T::num_gemm_k_prefetch_stages } -> SizeType;
};
template <typename T>
concept SpecifiesNumGroupsToMerge = requires {
{ T::num_groups_to_merge } -> std::convertible_to<size_t>;
{ T::num_conv_groups_to_merge } -> SizeType;
};
template <typename T>
@@ -267,12 +312,59 @@ concept SpecifiesLoopScheduler = requires {
{ T::loop_scheduler } -> std::convertible_to<PipelineScheduler>;
};
template <typename T>
concept SpecifiesGenericInstance = !requires {
{ T::specialization };
};
template <typename T>
concept SpecifiesTransposeTransfer = requires {
{ T::max_transpose_transfer_src_scalar_per_vector } -> SizeType;
{ T::max_transpose_transfer_dst_scalar_per_vector } -> SizeType;
};
template <typename T>
concept HasTransposeTransfer = requires {
{ T::max_transpose_transfer_src_scalar_per_vector };
{ T::max_transpose_transfer_dst_scalar_per_vector };
};
template <typename T>
concept TransposeTransferWellDefinedIfProvided =
!HasTransposeTransfer<T> || SpecifiesTransposeTransfer<T>;
template <typename T>
concept SpecifiesGemmBatchOptions = requires {
{ T::num_conv_groups_to_merge } -> SizeType;
};
/******************************************** */
/* Algorithm specialization concepts */
/******************************************** */
template <typename T>
concept SpecifiesLargeTensorSupport = requires {
{ T::specialization } -> std::convertible_to<ConvAlgorithmSpecialization>;
requires T::specialization == ConvAlgorithmSpecialization::LARGE_TENSOR;
};
template <typename T>
concept SpecifiesReferenceAlgorithm = requires {
{ T::specialization } -> std::convertible_to<ConvAlgorithmSpecialization>;
requires T::specialization == ConvAlgorithmSpecialization::REFERENCE;
};
template <typename T>
concept SpecifiesTwoStageSupport = requires {
{ T::specialization } -> std::convertible_to<ConvAlgorithmSpecialization>;
requires T::specialization == ConvAlgorithmSpecialization::TWO_STAGE;
};
template <typename T>
concept SpecifiesMultipleDSupport = requires {
{ T::specialization } -> std::convertible_to<ConvAlgorithmSpecialization>;
requires T::specialization == ConvAlgorithmSpecialization::MULTIPLE_D;
};
/******************************************** */
/* DL-specific descriptors and requirements */
/******************************************** */
@@ -280,11 +372,11 @@ concept SpecifiesLargeTensorSupport = requires {
// Concept for DL thread configuration
template <typename T>
concept DlThreadConfigDescriptor = requires(T t) {
{ t.k0_per_block } -> std::convertible_to<size_t>;
{ t.k1 } -> std::convertible_to<size_t>;
{ t.m1_per_thread } -> std::convertible_to<size_t>;
{ t.n1_per_thread } -> std::convertible_to<size_t>;
{ t.k_per_thread } -> std::convertible_to<size_t>;
{ t.k0_per_block } -> SizeType;
{ t.k1 } -> SizeType;
{ t.m1_per_thread } -> SizeType;
{ t.n1_per_thread } -> SizeType;
{ t.k_per_thread } -> SizeType;
};
// Concept for DL thread cluster
@@ -295,23 +387,29 @@ concept DlThreadClusterDescriptor = requires(T t) {
};
// Concept for DL block transfer
template <typename T>
template <typename T, size_t N>
concept DlBlockTransferDescriptor = requires(T t) {
{ t.thread_slice_lengths } -> std::convertible_to<std::array<size_t, 4>>;
{ t.thread_cluster_lengths } -> std::convertible_to<std::array<size_t, 4>>;
{ t.thread_cluster_arrange_order } -> std::convertible_to<std::array<size_t, 4>>;
{ t.src_access_order } -> std::convertible_to<std::array<size_t, 4>>;
{ t.src_vector_tensor_lengths } -> std::convertible_to<std::array<size_t, 4>>;
{ t.src_vector_tensor_contiguous_dim_order } -> std::convertible_to<std::array<size_t, 4>>;
{ t.dst_vector_tensor_lengths } -> std::convertible_to<std::array<size_t, 4>>;
{ t.thread_slice_lengths } -> std::convertible_to<std::array<size_t, N>>;
{ t.thread_cluster_lengths } -> std::convertible_to<std::array<size_t, N>>;
{ t.thread_cluster_arrange_order } -> std::convertible_to<std::array<size_t, N>>;
{ t.src_access_order } -> std::convertible_to<std::array<size_t, N>>;
{ t.src_vector_tensor_lengths } -> std::convertible_to<std::array<size_t, N>>;
{ t.src_vector_tensor_contiguous_dim_order } -> std::convertible_to<std::array<size_t, N>>;
{ t.dst_vector_tensor_lengths } -> std::convertible_to<std::array<size_t, N>>;
};
template <typename T>
concept DlBlockTransferDescriptor4D = DlBlockTransferDescriptor<T, 4>;
template <typename T>
concept DlBlockTransferDescriptor5D = DlBlockTransferDescriptor<T, 5>;
// Concept for DL epilogue
template <typename T>
concept DlEpilogueDescriptor = requires(T t) {
{ t.src_dst_access_order } -> std::convertible_to<std::array<size_t, 6>>;
{ t.src_dst_vector_dim } -> std::convertible_to<size_t>;
{ t.dst_scalar_per_vector } -> std::convertible_to<size_t>;
{ t.src_dst_vector_dim } -> SizeType;
{ t.dst_scalar_per_vector } -> SizeType;
};
// Concept to check if algorithm specifies DL thread config
@@ -328,15 +426,21 @@ concept SpecifiesDlThreadCluster = requires {
// Concept to check if algorithm specifies DL block transfer
template <typename T>
concept SpecifiesDlBlockTransfer = requires {
{ T::transfer.a.block_transfer } -> DlBlockTransferDescriptor;
{ T::transfer.b.block_transfer } -> DlBlockTransferDescriptor;
concept SpecifiesDlFwdBlockTransfer = requires {
{ T::transfer.a } -> DlBlockTransferDescriptor4D;
{ T::transfer.b } -> DlBlockTransferDescriptor4D;
};
template <typename T>
concept SpecifiesDlBwdBlockTransfer = requires {
{ T::transfer.a } -> DlBlockTransferDescriptor5D;
{ T::transfer.b } -> DlBlockTransferDescriptor5D;
};
// Concept to check if algorithm specifies DL C thread transfer
template <typename T>
concept SpecifiesDlEpilogue = requires {
{ T::transfer.c.epilogue } -> DlEpilogueDescriptor;
{ T::transfer.c } -> DlEpilogueDescriptor;
};
} // namespace ck_tile::builder


@@ -29,10 +29,20 @@ concept OutputVectorTransferLimits = requires {
// Limits for access order. Must be a permutation of {0, 1, 2}.
template <auto Value>
concept AccessOrderLimits = requires {
concept AccessOrderLimits3D = requires {
requires((Value[0] != Value[1]) && (Value[0] != Value[2]) && (Value[1] != Value[2]) &&
(Value[0] >= 0 && Value[0] < 3) && (Value[1] >= 0 && Value[1] < 3) &&
(Value[2] >= 0 && Value[2] < 3));
(Value[2] >= 0 && Value[2] < 3) && (Value.Size() == 3));
};
// Limits for access order. Must be a permutation of {0, 1, 2, 3}.
template <auto Value>
concept AccessOrderLimits4D = requires {
requires((Value[0] != Value[1]) && (Value[0] != Value[2]) && (Value[0] != Value[3]) &&
(Value[1] != Value[2]) && (Value[1] != Value[3]) && (Value[2] != Value[3]) &&
(Value[0] >= 0 && Value[0] < 4) && (Value[1] >= 0 && Value[1] < 4) &&
(Value[2] >= 0 && Value[2] < 4) && (Value[3] >= 0 && Value[3] < 4) &&
(Value.Size() == 4));
};
} // namespace ck_tile::builder


@@ -228,4 +228,13 @@ concept ValidConvWeightLayoutForSpatialDim =
(SpatialDim == 1 && ConvWeightLayout1D<L>) || (SpatialDim == 2 && ConvWeightLayout2D<L>) ||
(SpatialDim == 3 && ConvWeightLayout3D<L>);
// Constraint for 3D conv signature.
template <auto Sig>
concept Is3D = requires {
requires Sig.spatial_dim == 3;
requires ConvInputLayout3D<Sig.input.config.layout>;
requires ConvOutputLayout3D<Sig.output.config.layout>;
requires ConvWeightLayout3D<Sig.weight.config.layout>;
};
} // namespace ck_tile::builder


@@ -0,0 +1,128 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include "ck_tile/builder/conv_algorithm_concepts.hpp"
namespace ck_tile::builder::factory {
// Base algorithm concepts
template <typename T, size_t ThreadClusterRank = 3>
concept TileTransferParameters =
SpecifiesBlockTransfer<T, ThreadClusterRank> && SpecifiesLdsTransfer<T> &&
SpecifiesThreadClusterAccessOrder<T> && SpecifiesSourceAccessOrder<T>;
template <typename T>
concept SpecifiesTileTransferParameters3D = TileTransferParameters<T, 3>;
template <typename T>
concept SpecifiesTileTransferParameters4D = TileTransferParameters<T, 4>;
template <typename T>
concept FwdXdlAlgorithmBase =
ConvAlgorithmDescriptor<T> && SpecifiesThreadBlock<T> && SpecifiesTileTransferParameters3D<T> &&
SpecifiesGridwiseFwdXdlGemm<T> && SpecifiesFwdConvSpecialization<T> &&
SpecifiesGemmSpecialization<T> && SpecifiesNumPrefetchStages<T> &&
SpecifiesNumGroupsToMerge<T> && SpecifiesLoopScheduler<T>;
template <typename T>
concept BwdXdlAlgorithmBase =
ConvAlgorithmDescriptor<T> && SpecifiesThreadBlock<T> && SpecifiesTileTransferParameters4D<T> &&
SpecifiesGridwiseBwdXdlGemm<T> && SpecifiesBwdWeightConvSpecialization<T>;
template <typename T>
concept BwdXdlV3AlgorithmBase =
ConvAlgorithmDescriptor<T> && SpecifiesThreadBlock<T> && SpecifiesTileTransferParameters3D<T> &&
SpecifiesGridwiseBwdXdlGemm<T> && SpecifiesBwdWeightConvSpecialization<T> &&
SpecifiesBlockGemm<T>;
template <typename T>
concept BwdWmmaAlgorithmBase =
ConvAlgorithmDescriptor<T> && SpecifiesThreadBlock<T> && SpecifiesTileTransferParameters3D<T> &&
SpecifiesGridwiseWmmaGemm<T> && SpecifiesBwdWeightConvSpecialization<T>;
template <typename T>
concept BwdWmmaV3AlgorithmBase =
ConvAlgorithmDescriptor<T> && SpecifiesThreadBlock<T> && SpecifiesTileTransferParameters3D<T> &&
SpecifiesGridwiseWmmaGemm<T> && SpecifiesBwdWeightConvSpecialization<T> &&
SpecifiesBlockGemm<T>;
// Reference algorithm concept
template <typename T>
concept ReferenceAlgorithm = ConvAlgorithmDescriptor<T> && SpecifiesReferenceAlgorithm<T>;
// Tile-based algorithm concept
template <typename T>
concept TileAlgorithm = ConvAlgorithmDescriptor<T> && SpecifiesTileThreadBlock<T> &&
SpecifiesTileTransfer<T> && SpecifiesTileConvSpecialization<T> &&
SpecifiesTileBlockGemm<T> && SpecifiesTileOptimizations<T>;
// FWD XDL algorithm concepts
template <typename T>
concept FwdXdlAlgorithm = FwdXdlAlgorithmBase<T> && SpecifiesGenericInstance<T>;
template <typename T>
concept LargeTensorAlgorithm = FwdXdlAlgorithmBase<T> && SpecifiesLargeTensorSupport<T>;
template <typename T>
concept FwdXdlV3Algorithm =
ConvAlgorithmDescriptor<T> && SpecifiesThreadBlock<T> && SpecifiesTileTransferParameters3D<T> &&
SpecifiesGridwiseFwdXdlGemm<T> && SpecifiesFwdConvSpecialization<T> &&
SpecifiesGemmSpecialization<T> && SpecifiesBlockGemm<T>;
// FWD WMMA algorithm concepts
template <typename T>
concept FwdWmmaAlgorithm =
ConvAlgorithmDescriptor<T> && SpecifiesThreadBlock<T> && SpecifiesTileTransferParameters3D<T> &&
SpecifiesGridwiseWmmaGemm<T> && SpecifiesFwdConvSpecialization<T> &&
SpecifiesGemmSpecialization<T> && SpecifiesNumPrefetchStages<T> && SpecifiesLoopScheduler<T> &&
SpecifiesGridwiseGemmPipeline<T>;
// FWD DL algorithms
template <typename T>
concept FwdDlAlgorithm =
ConvAlgorithmDescriptor<T> && SpecifiesThreadBlock<T> && SpecifiesFwdConvSpecialization<T> &&
SpecifiesGemmSpecialization<T> && SpecifiesDlThreadConfig<T> && SpecifiesDlThreadCluster<T> &&
SpecifiesDlFwdBlockTransfer<T> && SpecifiesDlEpilogue<T>;
// BWD weight XDL algorithm concepts
template <typename T>
concept BwdXdlAlgorithm =
BwdXdlAlgorithmBase<T> && SpecifiesTransposeTransfer<T> && SpecifiesGenericInstance<T>;
template <typename T>
concept BwdMultiDXdlAlgorithm = BwdXdlAlgorithmBase<T> && SpecifiesMultipleDSupport<T>;
template <typename T>
concept BwdXdlV3Algorithm = BwdXdlV3AlgorithmBase<T> && SpecifiesGenericInstance<T>;
template <typename T>
concept BwdTwoStageXdlAlgorithm = BwdXdlV3AlgorithmBase<T> && SpecifiesTransposeTransfer<T> &&
SpecifiesGemmBatchOptions<T> && SpecifiesTwoStageSupport<T>;
// BWD weight WMMA algorithm concepts
template <typename T>
concept BwdWmmaAlgorithm =
BwdWmmaAlgorithmBase<T> && SpecifiesNumPrefetchStages<T> && SpecifiesLoopScheduler<T> &&
SpecifiesGridwiseGemmPipeline<T> && SpecifiesGenericInstance<T>;
template <typename T>
concept BwdMultiDWmmaV3Algorithm = BwdWmmaV3AlgorithmBase<T> && SpecifiesMultipleDSupport<T>;
template <typename T>
concept BwdWmmaV3Algorithm =
BwdWmmaV3AlgorithmBase<T> && SpecifiesTransposeTransfer<T> && SpecifiesGenericInstance<T>;
template <typename T>
concept BwdTwoStageWmmaV3Algorithm = BwdWmmaV3AlgorithmBase<T> && SpecifiesTransposeTransfer<T> &&
SpecifiesGemmBatchOptions<T> && SpecifiesTwoStageSupport<T>;
// BWD weight DL algorithms
template <typename T>
concept BwdDlAlgorithm =
ConvAlgorithmDescriptor<T> && SpecifiesThreadBlock<T> &&
SpecifiesBwdWeightConvSpecialization<T> && SpecifiesDlThreadConfig<T> &&
SpecifiesDlThreadCluster<T> && SpecifiesDlBwdBlockTransfer<T> && SpecifiesDlEpilogue<T>;
} // namespace ck_tile::builder::factory


@@ -0,0 +1,131 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_dl.hpp"
#include "ck_tile/builder/conv_signature_concepts.hpp"
#include "ck_tile/builder/conv_algorithm_concepts.hpp"
#include "ck_tile/builder/builder_utils.hpp"
#include "ck_tile/builder/factory/helpers/ck/conv_tensor_layout.hpp"
#include "ck_tile/builder/factory/helpers/ck/conv_tensor_type.hpp"
#include "ck_tile/builder/factory/helpers/ck/conv_elementwise_op.hpp"
#include "ck_tile/builder/factory/helpers/ck/conv_thread_block.hpp"
namespace ck_tile::builder::factory {
// Factory for DeviceGroupedConvBwdWeight_Dl instance
// of a grouped bwd weight convolution kernel.
template <ConvSignatureDescriptor auto SIGNATURE,
ConvAlgorithmDescriptor auto ALGORITHM,
StringLiteral VERSION>
requires ConvDirectionIsBackwardWeight<SIGNATURE>
struct ConvBwdWeightDlFactory
{
static constexpr size_t SPATIAL_DIM = SIGNATURE.spatial_dim;
using Layouts = internal::ConvTensorLayouts<SIGNATURE, SPATIAL_DIM>;
using Types = internal::ConvTensorDataTypes<SIGNATURE>;
using Ops = internal::ConvElementwiseOps<SIGNATURE>;
using AlgorithmType = decltype(ALGORITHM);
static constexpr auto BWD_CONV_SPECIALIZATION =
internal::SetBwdWeightConvSpecialization<ALGORITHM>();
static constexpr auto BLOCK = internal::SetThreadBlockInfo<ALGORITHM>();
// DL-specific parameters from algorithm descriptor
static constexpr auto DL_THREAD_CFG = ALGORITHM.thread_config;
static constexpr ck::index_t K0PerBlock = DL_THREAD_CFG.k0_per_block;
static constexpr ck::index_t K1 = DL_THREAD_CFG.k1;
static constexpr ck::index_t M1PerThread = DL_THREAD_CFG.m1_per_thread;
static constexpr ck::index_t N1PerThread = DL_THREAD_CFG.n1_per_thread;
static constexpr ck::index_t KPerThread = DL_THREAD_CFG.k_per_thread;
// Thread cluster from descriptor
static constexpr auto DL_CLUSTER = ALGORITHM.thread_cluster;
using M1N1ThreadClusterM1Xs = to_sequence_v<DL_CLUSTER.m1_xs>;
using M1N1ThreadClusterN1Xs = to_sequence_v<DL_CLUSTER.n1_xs>;
// A Block Transfer from descriptor - K0_M0_M1_K1 tensor format
static constexpr auto DL_A_TRANSFER = ALGORITHM.transfer.a;
using ABlockTransferThreadSliceLengths_K0_M0_M1_K1 =
to_sequence_v<DL_A_TRANSFER.thread_slice_lengths>;
using ABlockTransferThreadClusterLengths_K0_M0_M1_K1 =
to_sequence_v<DL_A_TRANSFER.thread_cluster_lengths>;
using ABlockTransferThreadClusterArrangeOrder =
to_sequence_v<DL_A_TRANSFER.thread_cluster_arrange_order>;
using ABlockTransferSrcAccessOrder = to_sequence_v<DL_A_TRANSFER.src_access_order>;
using ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1 =
to_sequence_v<DL_A_TRANSFER.src_vector_tensor_lengths>;
using ABlockTransferSrcVectorTensorContiguousDimOrder =
to_sequence_v<DL_A_TRANSFER.src_vector_tensor_contiguous_dim_order>;
using ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1 =
to_sequence_v<DL_A_TRANSFER.dst_vector_tensor_lengths>;
// B Block Transfer from descriptor - K0_N0_N1_K1 tensor format
static constexpr auto DL_B_TRANSFER = ALGORITHM.transfer.b;
using BBlockTransferThreadSliceLengths_K0_N0_N1_K1 =
to_sequence_v<DL_B_TRANSFER.thread_slice_lengths>;
using BBlockTransferThreadClusterLengths_K0_N0_N1_K1 =
to_sequence_v<DL_B_TRANSFER.thread_cluster_lengths>;
using BBlockTransferThreadClusterArrangeOrder =
to_sequence_v<DL_B_TRANSFER.thread_cluster_arrange_order>;
using BBlockTransferSrcAccessOrder = to_sequence_v<DL_B_TRANSFER.src_access_order>;
using BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1 =
to_sequence_v<DL_B_TRANSFER.src_vector_tensor_lengths>;
using BBlockTransferSrcVectorTensorContiguousDimOrder =
to_sequence_v<DL_B_TRANSFER.src_vector_tensor_contiguous_dim_order>;
using BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1 =
to_sequence_v<DL_B_TRANSFER.dst_vector_tensor_lengths>;
// C Thread Transfer from descriptor
static constexpr auto DL_C_TRANSFER = ALGORITHM.transfer.c;
using CThreadTransferSrcDstAccessOrder = to_sequence_v<DL_C_TRANSFER.src_dst_access_order>;
static constexpr ck::index_t CThreadTransferSrcDstVectorDim = DL_C_TRANSFER.src_dst_vector_dim;
static constexpr ck::index_t CThreadTransferDstScalarPerVector =
DL_C_TRANSFER.dst_scalar_per_vector;
// The DL bwd weight convolution kernel class instance
using Instance = ck::tensor_operation::device::DeviceGroupedConvBwdWeight_Dl<
SPATIAL_DIM,
typename Layouts::InLayout,
typename Layouts::WeiLayout,
typename Layouts::OutLayout,
typename Types::InDataType,
typename Types::WeiDataType,
typename Types::OutDataType,
typename Types::AccDataType,
typename Ops::InElementwiseOp,
typename Ops::WeiElementwiseOp,
typename Ops::OutElementwiseOp,
BWD_CONV_SPECIALIZATION,
BLOCK.block_size,
BLOCK.per_block.m,
BLOCK.per_block.n,
K0PerBlock,
K1,
M1PerThread,
N1PerThread,
KPerThread,
M1N1ThreadClusterM1Xs,
M1N1ThreadClusterN1Xs,
ABlockTransferThreadSliceLengths_K0_M0_M1_K1,
ABlockTransferThreadClusterLengths_K0_M0_M1_K1,
ABlockTransferThreadClusterArrangeOrder,
ABlockTransferSrcAccessOrder,
ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1,
ABlockTransferSrcVectorTensorContiguousDimOrder,
ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1,
BBlockTransferThreadSliceLengths_K0_N0_N1_K1,
BBlockTransferThreadClusterLengths_K0_N0_N1_K1,
BBlockTransferThreadClusterArrangeOrder,
BBlockTransferSrcAccessOrder,
BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1,
BBlockTransferSrcVectorTensorContiguousDimOrder,
BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1,
CThreadTransferSrcDstAccessOrder,
CThreadTransferSrcDstVectorDim,
CThreadTransferDstScalarPerVector>;
};
} // namespace ck_tile::builder::factory


@@ -0,0 +1,110 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_multiple_d_wmma_cshuffle_v3.hpp"
#include "ck_tile/builder/conv_signature_concepts.hpp"
#include "ck_tile/builder/conv_algorithm_concepts.hpp"
#include "ck_tile/builder/conv_algorithm_limits.hpp"
#include "ck_tile/builder/builder_utils.hpp"
#include "ck_tile/builder/factory/helpers/ck/conv_tensor_layout.hpp"
#include "ck_tile/builder/factory/helpers/ck/conv_tensor_type.hpp"
#include "ck_tile/builder/factory/helpers/ck/conv_elementwise_op.hpp"
#include "ck_tile/builder/factory/helpers/ck/conv_tuning_params.hpp"
#include "ck_tile/builder/factory/helpers/ck/conv_block_transfer.hpp"
#include "ck_tile/builder/factory/helpers/ck/conv_thread_block.hpp"
namespace ck_tile::builder::factory {
// Factory for DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3 instance
// of a grouped bwd weight convolution kernel.
template <ConvSignatureDescriptor auto SIGNATURE,
ConvAlgorithmDescriptor auto ALGORITHM,
StringLiteral VERSION>
requires ConvDirectionIsBackwardWeight<SIGNATURE> && Is3D<SIGNATURE>
struct ConvBwdWeightMultiDWmmaV3Factory
{
static constexpr size_t SPATIAL_DIM = SIGNATURE.spatial_dim;
using Layouts = internal::ConvTensorLayouts<SIGNATURE, SPATIAL_DIM>;
using Types = internal::ConvTensorDataTypes<SIGNATURE>;
using Ops = internal::ConvElementwiseOps<SIGNATURE>;
using AlgorithmType = decltype(ALGORITHM);
static constexpr auto BWD_CONV_SPECIALIZATION =
internal::SetBwdWeightConvSpecialization<ALGORITHM>();
static constexpr auto BLOCK = internal::SetThreadBlockInfo<ALGORITHM>();
static constexpr auto GRIDWISE_GEMM = ALGORITHM.gridwise_gemm;
static constexpr auto A_BLOCK_TRANSFER =
internal::SetBwdConvBlockTransfer<ALGORITHM.transfer.a>();
static constexpr auto B_BLOCK_TRANSFER =
internal::SetBwdConvBlockTransfer<ALGORITHM.transfer.b>();
static constexpr auto C_BLOCK_TRANSFER = internal::SetCBlockTransfer<SIGNATURE, ALGORITHM>();
static constexpr auto BLOCK_GEMM = internal::SetBlockGemm<ALGORITHM>();
// Check limits for the algorithm parameters.
// TODO: Add more limits checks as needed.
static_assert(InputVectorTransferLimits<A_BLOCK_TRANSFER>, "Invalid A block transfer config");
static_assert(InputVectorTransferLimits<B_BLOCK_TRANSFER>, "Invalid B block transfer config");
static_assert(OutputVectorTransferLimits<C_BLOCK_TRANSFER>, "Invalid C block transfer config");
static_assert(AccessOrderLimits3D<A_BLOCK_TRANSFER.thread_cluster_order>,
"Invalid A thread cluster access order");
static_assert(AccessOrderLimits3D<B_BLOCK_TRANSFER.thread_cluster_order>,
"Invalid B thread cluster access order");
static_assert(AccessOrderLimits3D<A_BLOCK_TRANSFER.src_access_order>,
"Invalid A source access order");
static_assert(AccessOrderLimits3D<B_BLOCK_TRANSFER.src_access_order>,
"Invalid B source access order");
// The backward weight convolution kernel class instance.
using Instance =
ck::tensor_operation::device::DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3<
SPATIAL_DIM,
typename Layouts::InLayout,
typename Layouts::WeiLayout,
typename Layouts::OutLayout,
typename Layouts::DsLayout,
typename Types::InDataType,
typename Types::WeiDataType,
typename Types::OutDataType,
typename Types::AccDataType,
typename Types::DsDataType,
typename Ops::InElementwiseOp,
typename Ops::WeiElementwiseOp,
typename Ops::OutElementwiseOp,
BWD_CONV_SPECIALIZATION,
BLOCK.block_size,
BLOCK.per_block.m,
BLOCK.per_block.n,
BLOCK.per_block.k,
GRIDWISE_GEMM.k1,
GRIDWISE_GEMM.m_per_wmma,
GRIDWISE_GEMM.n_per_wmma,
GRIDWISE_GEMM.m_wmma_per_wave,
GRIDWISE_GEMM.n_wmma_per_wave,
to_sequence_v<A_BLOCK_TRANSFER.thread_cluster_dims>,
to_sequence_v<A_BLOCK_TRANSFER.thread_cluster_order>,
to_sequence_v<A_BLOCK_TRANSFER.src_access_order>,
A_BLOCK_TRANSFER.src_vector_dim,
A_BLOCK_TRANSFER.src_scalar_per_vector,
A_BLOCK_TRANSFER.lds_dst_scalar_per_vector,
A_BLOCK_TRANSFER.lds_padding,
to_sequence_v<B_BLOCK_TRANSFER.thread_cluster_dims>,
to_sequence_v<B_BLOCK_TRANSFER.thread_cluster_order>,
to_sequence_v<B_BLOCK_TRANSFER.src_access_order>,
B_BLOCK_TRANSFER.src_vector_dim,
B_BLOCK_TRANSFER.src_scalar_per_vector,
B_BLOCK_TRANSFER.lds_dst_scalar_per_vector,
B_BLOCK_TRANSFER.lds_padding,
C_BLOCK_TRANSFER.m_xdl_per_wave_per_shuffle,
C_BLOCK_TRANSFER.n_xdl_per_wave_per_shuffle,
to_sequence_v<C_BLOCK_TRANSFER.thread_cluster_dims>,
C_BLOCK_TRANSFER.scalar_per_vector,
BLOCK_GEMM.scheduler,
BLOCK_GEMM.pipeline_version,
typename Types::OutComputeType,
typename Types::InComputeType>;
};
} // namespace ck_tile::builder::factory


@@ -0,0 +1,103 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_multiple_d_xdl_cshuffle.hpp"
#include "ck_tile/builder/conv_signature_concepts.hpp"
#include "ck_tile/builder/conv_algorithm_concepts.hpp"
#include "ck_tile/builder/conv_algorithm_limits.hpp"
#include "ck_tile/builder/builder_utils.hpp"
#include "ck_tile/builder/factory/helpers/ck/conv_tensor_layout.hpp"
#include "ck_tile/builder/factory/helpers/ck/conv_tensor_type.hpp"
#include "ck_tile/builder/factory/helpers/ck/conv_elementwise_op.hpp"
#include "ck_tile/builder/factory/helpers/ck/conv_tuning_params.hpp"
#include "ck_tile/builder/factory/helpers/ck/conv_block_transfer.hpp"
#include "ck_tile/builder/factory/helpers/ck/conv_thread_block.hpp"
namespace ck_tile::builder::factory {
// Factory for DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle instance
// of a grouped bwd weight convolution kernel.
template <ConvSignatureDescriptor auto SIGNATURE,
ConvAlgorithmDescriptor auto ALGORITHM,
StringLiteral VERSION>
requires ConvDirectionIsBackwardWeight<SIGNATURE>
struct ConvBwdWeightMultiDXdlFactory
{
static constexpr size_t SPATIAL_DIM = SIGNATURE.spatial_dim;
using Layouts = internal::ConvTensorLayouts<SIGNATURE, SPATIAL_DIM>;
using Types = internal::ConvTensorDataTypes<SIGNATURE>;
using Ops = internal::ConvElementwiseOps<SIGNATURE>;
using AlgorithmType = decltype(ALGORITHM);
static constexpr auto BWD_CONV_SPECIALIZATION =
internal::SetBwdWeightConvSpecialization<ALGORITHM>();
static constexpr auto BLOCK = internal::SetThreadBlockInfo<ALGORITHM>();
static constexpr auto GRIDWISE_GEMM = ALGORITHM.gridwise_gemm;
static constexpr auto XDL_PARAMS = GRIDWISE_GEMM.xdl_params;
static constexpr auto A_BLOCK_TRANSFER =
internal::SetBwdConvBlockTransfer<ALGORITHM.transfer.a>();
static constexpr auto B_BLOCK_TRANSFER =
internal::SetBwdConvBlockTransfer<ALGORITHM.transfer.b>();
static constexpr auto C_BLOCK_TRANSFER = internal::SetCBlockTransfer<SIGNATURE, ALGORITHM>();
// Check limits for the algorithm parameters.
// TODO: Add more limits checks as needed.
static_assert(InputVectorTransferLimits<A_BLOCK_TRANSFER>);
static_assert(InputVectorTransferLimits<B_BLOCK_TRANSFER>);
static_assert(OutputVectorTransferLimits<C_BLOCK_TRANSFER>);
static_assert(AccessOrderLimits4D<A_BLOCK_TRANSFER.thread_cluster_order>);
static_assert(AccessOrderLimits4D<B_BLOCK_TRANSFER.thread_cluster_order>);
static_assert(AccessOrderLimits4D<A_BLOCK_TRANSFER.src_access_order>);
static_assert(AccessOrderLimits4D<B_BLOCK_TRANSFER.src_access_order>);
// The backward weight convolution kernel class instance.
using Instance = ck::tensor_operation::device::DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle<
SPATIAL_DIM,
typename Layouts::InLayout,
typename Layouts::WeiLayout,
typename Layouts::OutLayout,
typename Layouts::DsLayout,
typename Types::InDataType,
typename Types::WeiDataType,
typename Types::OutDataType,
typename Types::AccDataType,
typename Types::DsDataType,
typename Ops::InElementwiseOp,
typename Ops::WeiElementwiseOp,
typename Ops::OutElementwiseOp,
BWD_CONV_SPECIALIZATION,
BLOCK.block_size,
BLOCK.per_block.m,
BLOCK.per_block.n,
BLOCK.per_block.k,
GRIDWISE_GEMM.k1,
XDL_PARAMS.m_per_xdl,
XDL_PARAMS.n_per_xdl,
XDL_PARAMS.m_xdl_per_wave,
XDL_PARAMS.n_xdl_per_wave,
to_sequence_v<A_BLOCK_TRANSFER.thread_cluster_dims>,
to_sequence_v<A_BLOCK_TRANSFER.thread_cluster_order>,
to_sequence_v<A_BLOCK_TRANSFER.src_access_order>,
A_BLOCK_TRANSFER.src_vector_dim,
A_BLOCK_TRANSFER.src_scalar_per_vector,
A_BLOCK_TRANSFER.lds_dst_scalar_per_vector,
A_BLOCK_TRANSFER.lds_padding,
to_sequence_v<B_BLOCK_TRANSFER.thread_cluster_dims>,
to_sequence_v<B_BLOCK_TRANSFER.thread_cluster_order>,
to_sequence_v<B_BLOCK_TRANSFER.src_access_order>,
B_BLOCK_TRANSFER.src_vector_dim,
B_BLOCK_TRANSFER.src_scalar_per_vector,
B_BLOCK_TRANSFER.lds_dst_scalar_per_vector,
B_BLOCK_TRANSFER.lds_padding,
C_BLOCK_TRANSFER.m_xdl_per_wave_per_shuffle,
C_BLOCK_TRANSFER.n_xdl_per_wave_per_shuffle,
to_sequence_v<C_BLOCK_TRANSFER.thread_cluster_dims>,
C_BLOCK_TRANSFER.scalar_per_vector,
typename Types::OutComputeType,
typename Types::InComputeType>;
};
} // namespace ck_tile::builder::factory


@@ -0,0 +1,111 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_wmma_cshuffle_v3.hpp"
#include "ck_tile/builder/conv_signature_concepts.hpp"
#include "ck_tile/builder/conv_algorithm_concepts.hpp"
#include "ck_tile/builder/conv_algorithm_limits.hpp"
#include "ck_tile/builder/builder_utils.hpp"
#include "ck_tile/builder/factory/helpers/ck/conv_tensor_layout.hpp"
#include "ck_tile/builder/factory/helpers/ck/conv_tensor_type.hpp"
#include "ck_tile/builder/factory/helpers/ck/conv_elementwise_op.hpp"
#include "ck_tile/builder/factory/helpers/ck/conv_tuning_params.hpp"
#include "ck_tile/builder/factory/helpers/ck/conv_block_transfer.hpp"
#include "ck_tile/builder/factory/helpers/ck/conv_thread_block.hpp"
namespace ck_tile::builder::factory {
// Factory for DeviceGroupedConvBwdWeightTwoStage_Wmma_CShuffleV3 instance
// of a grouped bwd weight convolution kernel.
template <ConvSignatureDescriptor auto SIGNATURE,
ConvAlgorithmDescriptor auto ALGORITHM,
StringLiteral VERSION>
requires ConvDirectionIsBackwardWeight<SIGNATURE>
struct ConvBwdWeightTwoStageWmmaV3Factory
{
static constexpr size_t SPATIAL_DIM = SIGNATURE.spatial_dim;
using Layouts = internal::ConvTensorLayouts<SIGNATURE, SPATIAL_DIM>;
using Types = internal::ConvTensorDataTypes<SIGNATURE>;
using Ops = internal::ConvElementwiseOps<SIGNATURE>;
using AlgorithmType = decltype(ALGORITHM);
static constexpr auto BWD_CONV_SPECIALIZATION =
internal::SetBwdWeightConvSpecialization<ALGORITHM>();
static constexpr auto BLOCK = internal::SetThreadBlockInfo<ALGORITHM>();
static constexpr auto GRIDWISE_GEMM = ALGORITHM.gridwise_gemm;
static constexpr auto A_BLOCK_TRANSFER =
internal::SetBwdConvBlockTransfer<ALGORITHM.transfer.a>();
static constexpr auto B_BLOCK_TRANSFER =
internal::SetBwdConvBlockTransfer<ALGORITHM.transfer.b>();
static constexpr auto C_BLOCK_TRANSFER = internal::SetCBlockTransfer<SIGNATURE, ALGORITHM>();
static constexpr auto BLOCK_GEMM = internal::SetBlockGemm<ALGORITHM>();
// Check limits for the algorithm parameters.
// TODO: Add more limits checks as needed.
static_assert(InputVectorTransferLimits<A_BLOCK_TRANSFER>, "Invalid A block transfer config");
static_assert(InputVectorTransferLimits<B_BLOCK_TRANSFER>, "Invalid B block transfer config");
static_assert(OutputVectorTransferLimits<C_BLOCK_TRANSFER>, "Invalid C block transfer config");
static_assert(AccessOrderLimits3D<A_BLOCK_TRANSFER.thread_cluster_order>,
"Invalid A thread cluster access order");
static_assert(AccessOrderLimits3D<B_BLOCK_TRANSFER.thread_cluster_order>,
"Invalid B thread cluster access order");
static_assert(AccessOrderLimits3D<A_BLOCK_TRANSFER.src_access_order>,
"Invalid A source access order");
static_assert(AccessOrderLimits3D<B_BLOCK_TRANSFER.src_access_order>,
"Invalid B source access order");
// The backward weight convolution kernel class instance.
using Instance =
ck::tensor_operation::device::DeviceGroupedConvBwdWeightTwoStage_Wmma_CShuffleV3<
SPATIAL_DIM,
typename Layouts::InLayout,
typename Layouts::WeiLayout,
typename Layouts::OutLayout,
typename Types::InDataType,
typename Types::WeiDataType,
typename Types::OutDataType,
typename Types::AccDataType,
typename Ops::InElementwiseOp,
typename Ops::WeiElementwiseOp,
typename Ops::OutElementwiseOp,
BWD_CONV_SPECIALIZATION,
BLOCK.block_size,
BLOCK.per_block.m,
BLOCK.per_block.n,
BLOCK.per_block.k,
GRIDWISE_GEMM.k1,
GRIDWISE_GEMM.m_per_wmma,
GRIDWISE_GEMM.n_per_wmma,
GRIDWISE_GEMM.m_wmma_per_wave,
GRIDWISE_GEMM.n_wmma_per_wave,
to_sequence_v<A_BLOCK_TRANSFER.thread_cluster_dims>,
to_sequence_v<A_BLOCK_TRANSFER.thread_cluster_order>,
to_sequence_v<A_BLOCK_TRANSFER.src_access_order>,
A_BLOCK_TRANSFER.src_vector_dim,
A_BLOCK_TRANSFER.src_scalar_per_vector,
A_BLOCK_TRANSFER.lds_dst_scalar_per_vector,
A_BLOCK_TRANSFER.lds_padding,
to_sequence_v<B_BLOCK_TRANSFER.thread_cluster_dims>,
to_sequence_v<B_BLOCK_TRANSFER.thread_cluster_order>,
to_sequence_v<B_BLOCK_TRANSFER.src_access_order>,
B_BLOCK_TRANSFER.src_vector_dim,
B_BLOCK_TRANSFER.src_scalar_per_vector,
B_BLOCK_TRANSFER.lds_dst_scalar_per_vector,
B_BLOCK_TRANSFER.lds_padding,
C_BLOCK_TRANSFER.m_xdl_per_wave_per_shuffle,
C_BLOCK_TRANSFER.n_xdl_per_wave_per_shuffle,
to_sequence_v<C_BLOCK_TRANSFER.thread_cluster_dims>,
C_BLOCK_TRANSFER.scalar_per_vector,
BLOCK_GEMM.scheduler,
BLOCK_GEMM.pipeline_version,
ALGORITHM.num_conv_groups_to_merge,
typename Types::OutComputeType,
typename Types::InComputeType,
ALGORITHM.max_transpose_transfer_src_scalar_per_vector,
ALGORITHM.max_transpose_transfer_dst_scalar_per_vector>;
};
} // namespace ck_tile::builder::factory


@@ -0,0 +1,111 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp"
#include "ck_tile/builder/conv_signature_concepts.hpp"
#include "ck_tile/builder/conv_algorithm_concepts.hpp"
#include "ck_tile/builder/conv_algorithm_limits.hpp"
#include "ck_tile/builder/builder_utils.hpp"
#include "ck_tile/builder/factory/helpers/ck/conv_tensor_layout.hpp"
#include "ck_tile/builder/factory/helpers/ck/conv_tensor_type.hpp"
#include "ck_tile/builder/factory/helpers/ck/conv_elementwise_op.hpp"
#include "ck_tile/builder/factory/helpers/ck/conv_tuning_params.hpp"
#include "ck_tile/builder/factory/helpers/ck/conv_block_transfer.hpp"
#include "ck_tile/builder/factory/helpers/ck/conv_thread_block.hpp"
namespace ck_tile::builder::factory {
// Factory for DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle instance
// of a grouped bwd weight convolution kernel.
template <ConvSignatureDescriptor auto SIGNATURE,
ConvAlgorithmDescriptor auto ALGORITHM,
StringLiteral VERSION>
requires ConvDirectionIsBackwardWeight<SIGNATURE>
struct ConvBwdWeightTwoStageXdlFactory
{
static constexpr size_t SPATIAL_DIM = SIGNATURE.spatial_dim;
using Layouts = internal::ConvTensorLayouts<SIGNATURE, SPATIAL_DIM>;
using Types = internal::ConvTensorDataTypes<SIGNATURE>;
using Ops = internal::ConvElementwiseOps<SIGNATURE>;
using AlgorithmType = decltype(ALGORITHM);
static constexpr auto BWD_CONV_SPECIALIZATION =
internal::SetBwdWeightConvSpecialization<ALGORITHM>();
static constexpr auto BLOCK = internal::SetThreadBlockInfo<ALGORITHM>();
static constexpr auto GRIDWISE_GEMM = ALGORITHM.gridwise_gemm;
static constexpr auto XDL_PARAMS = GRIDWISE_GEMM.xdl_params;
static constexpr auto A_BLOCK_TRANSFER =
internal::SetBwdConvBlockTransfer<ALGORITHM.transfer.a>();
static constexpr auto B_BLOCK_TRANSFER =
internal::SetBwdConvBlockTransfer<ALGORITHM.transfer.b>();
static constexpr auto C_BLOCK_TRANSFER = internal::SetCBlockTransfer<SIGNATURE, ALGORITHM>();
static constexpr auto BLOCK_GEMM = internal::SetBlockGemm<ALGORITHM>();
// Check limits for the algorithm parameters.
// TODO: Add more limits checks as needed.
static_assert(InputVectorTransferLimits<A_BLOCK_TRANSFER>, "Invalid A block transfer config");
static_assert(InputVectorTransferLimits<B_BLOCK_TRANSFER>, "Invalid B block transfer config");
static_assert(OutputVectorTransferLimits<C_BLOCK_TRANSFER>, "Invalid C block transfer config");
static_assert(AccessOrderLimits3D<A_BLOCK_TRANSFER.thread_cluster_order>,
"Invalid A thread cluster access order");
static_assert(AccessOrderLimits3D<B_BLOCK_TRANSFER.thread_cluster_order>,
"Invalid B thread cluster access order");
static_assert(AccessOrderLimits3D<A_BLOCK_TRANSFER.src_access_order>,
"Invalid A source access order");
static_assert(AccessOrderLimits3D<B_BLOCK_TRANSFER.src_access_order>,
"Invalid B source access order");
// The backward weight convolution kernel class instance.
using Instance = ck::tensor_operation::device::DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle<
SPATIAL_DIM,
typename Layouts::InLayout,
typename Layouts::WeiLayout,
typename Layouts::OutLayout,
typename Types::InDataType,
typename Types::WeiDataType,
typename Types::OutDataType,
typename Types::AccDataType,
typename Ops::InElementwiseOp,
typename Ops::WeiElementwiseOp,
typename Ops::OutElementwiseOp,
BWD_CONV_SPECIALIZATION,
BLOCK.block_size,
BLOCK.per_block.m,
BLOCK.per_block.n,
BLOCK.per_block.k,
GRIDWISE_GEMM.k1,
XDL_PARAMS.m_per_xdl,
XDL_PARAMS.n_per_xdl,
XDL_PARAMS.m_xdl_per_wave,
XDL_PARAMS.n_xdl_per_wave,
to_sequence_v<A_BLOCK_TRANSFER.thread_cluster_dims>,
to_sequence_v<A_BLOCK_TRANSFER.thread_cluster_order>,
to_sequence_v<A_BLOCK_TRANSFER.src_access_order>,
A_BLOCK_TRANSFER.src_vector_dim,
A_BLOCK_TRANSFER.src_scalar_per_vector,
A_BLOCK_TRANSFER.lds_dst_scalar_per_vector,
A_BLOCK_TRANSFER.lds_padding,
to_sequence_v<B_BLOCK_TRANSFER.thread_cluster_dims>,
to_sequence_v<B_BLOCK_TRANSFER.thread_cluster_order>,
to_sequence_v<B_BLOCK_TRANSFER.src_access_order>,
B_BLOCK_TRANSFER.src_vector_dim,
B_BLOCK_TRANSFER.src_scalar_per_vector,
B_BLOCK_TRANSFER.lds_dst_scalar_per_vector,
B_BLOCK_TRANSFER.lds_padding,
C_BLOCK_TRANSFER.m_xdl_per_wave_per_shuffle,
C_BLOCK_TRANSFER.n_xdl_per_wave_per_shuffle,
to_sequence_v<C_BLOCK_TRANSFER.thread_cluster_dims>,
C_BLOCK_TRANSFER.scalar_per_vector,
BLOCK_GEMM.scheduler,
BLOCK_GEMM.pipeline_version,
ALGORITHM.num_conv_groups_to_merge,
typename Types::OutComputeType,
typename Types::InComputeType,
ALGORITHM.max_transpose_transfer_src_scalar_per_vector,
ALGORITHM.max_transpose_transfer_dst_scalar_per_vector>;
};
} // namespace ck_tile::builder::factory


@@ -0,0 +1,109 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_wmma_cshuffle.hpp"
#include "ck_tile/builder/conv_signature_concepts.hpp"
#include "ck_tile/builder/conv_algorithm_concepts.hpp"
#include "ck_tile/builder/conv_algorithm_limits.hpp"
#include "ck_tile/builder/builder_utils.hpp"
#include "ck_tile/builder/factory/helpers/ck/conv_tensor_layout.hpp"
#include "ck_tile/builder/factory/helpers/ck/conv_tensor_type.hpp"
#include "ck_tile/builder/factory/helpers/ck/conv_elementwise_op.hpp"
#include "ck_tile/builder/factory/helpers/ck/conv_tuning_params.hpp"
#include "ck_tile/builder/factory/helpers/ck/conv_block_transfer.hpp"
#include "ck_tile/builder/factory/helpers/ck/conv_thread_block.hpp"
namespace ck_tile::builder::factory {
// Factory for DeviceGroupedConvBwdWeight_Wmma_CShuffle instance
// of a grouped bwd weight convolution kernel.
template <ConvSignatureDescriptor auto SIGNATURE,
ConvAlgorithmDescriptor auto ALGORITHM,
StringLiteral VERSION>
requires ConvDirectionIsBackwardWeight<SIGNATURE> && Is3D<SIGNATURE>
struct ConvBwdWeightWmmaFactory
{
static constexpr size_t SPATIAL_DIM = SIGNATURE.spatial_dim;
using Layouts = internal::ConvTensorLayouts<SIGNATURE, SPATIAL_DIM>;
using Types = internal::ConvTensorDataTypes<SIGNATURE>;
using Ops = internal::ConvElementwiseOps<SIGNATURE>;
using AlgorithmType = decltype(ALGORITHM);
static constexpr auto BWD_CONV_SPECIALIZATION =
internal::SetBwdWeightConvSpecialization<ALGORITHM>();
static constexpr auto BLOCK = internal::SetThreadBlockInfo<ALGORITHM>();
static constexpr auto GRIDWISE_GEMM = ALGORITHM.gridwise_gemm;
static constexpr auto GRIDWISE_GEMM_PIPELINE_VERSION =
internal::SetGridwiseGemmPipelineVersion<ALGORITHM>();
static constexpr auto LOOP_SCHEDULER = internal::SetLoopScheduler<ALGORITHM>();
static constexpr auto A_BLOCK_TRANSFER =
internal::SetBwdConvBlockTransfer<ALGORITHM.transfer.a>();
static constexpr auto B_BLOCK_TRANSFER =
internal::SetBwdConvBlockTransfer<ALGORITHM.transfer.b>();
static constexpr auto C_BLOCK_TRANSFER = internal::SetCBlockTransfer<SIGNATURE, ALGORITHM>();
// Check limits for the algorithm parameters.
// TODO: Add more limits checks as needed.
static_assert(InputVectorTransferLimits<A_BLOCK_TRANSFER>, "Invalid A block transfer config");
static_assert(InputVectorTransferLimits<B_BLOCK_TRANSFER>, "Invalid B block transfer config");
static_assert(OutputVectorTransferLimits<C_BLOCK_TRANSFER>, "Invalid C block transfer config");
static_assert(AccessOrderLimits3D<A_BLOCK_TRANSFER.thread_cluster_order>,
"Invalid A thread cluster access order");
static_assert(AccessOrderLimits3D<B_BLOCK_TRANSFER.thread_cluster_order>,
"Invalid B thread cluster access order");
static_assert(AccessOrderLimits3D<A_BLOCK_TRANSFER.src_access_order>,
"Invalid A source access order");
static_assert(AccessOrderLimits3D<B_BLOCK_TRANSFER.src_access_order>,
"Invalid B source access order");
// The backward weight convolution kernel class instance.
using Instance = ck::tensor_operation::device::DeviceGroupedConvBwdWeight_Wmma_CShuffle<
SPATIAL_DIM,
typename Layouts::InLayout,
typename Layouts::WeiLayout,
typename Layouts::OutLayout,
typename Types::InDataType,
typename Types::WeiDataType,
typename Types::OutDataType,
typename Types::AccDataType,
typename Ops::InElementwiseOp,
typename Ops::WeiElementwiseOp,
typename Ops::OutElementwiseOp,
BWD_CONV_SPECIALIZATION,
BLOCK.block_size,
BLOCK.per_block.m,
BLOCK.per_block.n,
BLOCK.per_block.k,
GRIDWISE_GEMM.k1,
GRIDWISE_GEMM.m_per_wmma,
GRIDWISE_GEMM.n_per_wmma,
GRIDWISE_GEMM.m_wmma_per_wave,
GRIDWISE_GEMM.n_wmma_per_wave,
to_sequence_v<A_BLOCK_TRANSFER.thread_cluster_dims>,
to_sequence_v<A_BLOCK_TRANSFER.thread_cluster_order>,
to_sequence_v<A_BLOCK_TRANSFER.src_access_order>,
A_BLOCK_TRANSFER.src_vector_dim,
A_BLOCK_TRANSFER.src_scalar_per_vector,
A_BLOCK_TRANSFER.lds_dst_scalar_per_vector,
A_BLOCK_TRANSFER.lds_padding,
to_sequence_v<B_BLOCK_TRANSFER.thread_cluster_dims>,
to_sequence_v<B_BLOCK_TRANSFER.thread_cluster_order>,
to_sequence_v<B_BLOCK_TRANSFER.src_access_order>,
B_BLOCK_TRANSFER.src_vector_dim,
B_BLOCK_TRANSFER.src_scalar_per_vector,
B_BLOCK_TRANSFER.lds_dst_scalar_per_vector,
B_BLOCK_TRANSFER.lds_padding,
C_BLOCK_TRANSFER.m_xdl_per_wave_per_shuffle,
C_BLOCK_TRANSFER.n_xdl_per_wave_per_shuffle,
to_sequence_v<C_BLOCK_TRANSFER.thread_cluster_dims>,
C_BLOCK_TRANSFER.scalar_per_vector,
ALGORITHM.num_gemm_k_prefetch_stages,
LOOP_SCHEDULER,
GRIDWISE_GEMM_PIPELINE_VERSION>;
};
} // namespace ck_tile::builder::factory


@@ -0,0 +1,109 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_wmma_cshuffle_v3.hpp"
#include "ck_tile/builder/conv_signature_concepts.hpp"
#include "ck_tile/builder/conv_algorithm_concepts.hpp"
#include "ck_tile/builder/conv_algorithm_limits.hpp"
#include "ck_tile/builder/builder_utils.hpp"
#include "ck_tile/builder/factory/helpers/ck/conv_tensor_layout.hpp"
#include "ck_tile/builder/factory/helpers/ck/conv_tensor_type.hpp"
#include "ck_tile/builder/factory/helpers/ck/conv_elementwise_op.hpp"
#include "ck_tile/builder/factory/helpers/ck/conv_tuning_params.hpp"
#include "ck_tile/builder/factory/helpers/ck/conv_block_transfer.hpp"
#include "ck_tile/builder/factory/helpers/ck/conv_thread_block.hpp"
namespace ck_tile::builder::factory {
// Factory for DeviceGroupedConvBwdWeight_Wmma_CShuffleV3 instance
// of a grouped bwd weight convolution kernel.
template <ConvSignatureDescriptor auto SIGNATURE,
ConvAlgorithmDescriptor auto ALGORITHM,
StringLiteral VERSION>
requires ConvDirectionIsBackwardWeight<SIGNATURE>
struct ConvBwdWeightWmmaV3Factory
{
static constexpr size_t SPATIAL_DIM = SIGNATURE.spatial_dim;
using Layouts = internal::ConvTensorLayouts<SIGNATURE, SPATIAL_DIM>;
using Types = internal::ConvTensorDataTypes<SIGNATURE>;
using Ops = internal::ConvElementwiseOps<SIGNATURE>;
using AlgorithmType = decltype(ALGORITHM);
static constexpr auto BWD_CONV_SPECIALIZATION =
internal::SetBwdWeightConvSpecialization<ALGORITHM>();
static constexpr auto BLOCK = internal::SetThreadBlockInfo<ALGORITHM>();
static constexpr auto GRIDWISE_GEMM = ALGORITHM.gridwise_gemm;
static constexpr auto A_BLOCK_TRANSFER =
internal::SetBwdConvBlockTransfer<ALGORITHM.transfer.a>();
static constexpr auto B_BLOCK_TRANSFER =
internal::SetBwdConvBlockTransfer<ALGORITHM.transfer.b>();
static constexpr auto C_BLOCK_TRANSFER = internal::SetCBlockTransfer<SIGNATURE, ALGORITHM>();
static constexpr auto BLOCK_GEMM = internal::SetBlockGemm<ALGORITHM>();
// Check limits for the algorithm parameters.
// TODO: Add more limits checks as needed.
static_assert(InputVectorTransferLimits<A_BLOCK_TRANSFER>, "Invalid A block transfer config");
static_assert(InputVectorTransferLimits<B_BLOCK_TRANSFER>, "Invalid B block transfer config");
static_assert(OutputVectorTransferLimits<C_BLOCK_TRANSFER>, "Invalid C block transfer config");
static_assert(AccessOrderLimits3D<A_BLOCK_TRANSFER.thread_cluster_order>,
"Invalid A thread cluster access order");
static_assert(AccessOrderLimits3D<B_BLOCK_TRANSFER.thread_cluster_order>,
"Invalid B thread cluster access order");
static_assert(AccessOrderLimits3D<A_BLOCK_TRANSFER.src_access_order>,
"Invalid A source access order");
static_assert(AccessOrderLimits3D<B_BLOCK_TRANSFER.src_access_order>,
"Invalid B source access order");
// The backward weight convolution kernel class instance.
using Instance = ck::tensor_operation::device::DeviceGroupedConvBwdWeight_Wmma_CShuffleV3<
SPATIAL_DIM,
typename Layouts::InLayout,
typename Layouts::WeiLayout,
typename Layouts::OutLayout,
typename Types::InDataType,
typename Types::WeiDataType,
typename Types::OutDataType,
typename Types::AccDataType,
typename Ops::InElementwiseOp,
typename Ops::WeiElementwiseOp,
typename Ops::OutElementwiseOp,
BWD_CONV_SPECIALIZATION,
BLOCK.block_size,
BLOCK.per_block.m,
BLOCK.per_block.n,
BLOCK.per_block.k,
GRIDWISE_GEMM.k1,
GRIDWISE_GEMM.m_per_wmma,
GRIDWISE_GEMM.n_per_wmma,
GRIDWISE_GEMM.m_wmma_per_wave,
GRIDWISE_GEMM.n_wmma_per_wave,
to_sequence_v<A_BLOCK_TRANSFER.thread_cluster_dims>,
to_sequence_v<A_BLOCK_TRANSFER.thread_cluster_order>,
to_sequence_v<A_BLOCK_TRANSFER.src_access_order>,
A_BLOCK_TRANSFER.src_vector_dim,
A_BLOCK_TRANSFER.src_scalar_per_vector,
A_BLOCK_TRANSFER.lds_dst_scalar_per_vector,
A_BLOCK_TRANSFER.lds_padding,
to_sequence_v<B_BLOCK_TRANSFER.thread_cluster_dims>,
to_sequence_v<B_BLOCK_TRANSFER.thread_cluster_order>,
to_sequence_v<B_BLOCK_TRANSFER.src_access_order>,
B_BLOCK_TRANSFER.src_vector_dim,
B_BLOCK_TRANSFER.src_scalar_per_vector,
B_BLOCK_TRANSFER.lds_dst_scalar_per_vector,
B_BLOCK_TRANSFER.lds_padding,
C_BLOCK_TRANSFER.m_xdl_per_wave_per_shuffle,
C_BLOCK_TRANSFER.n_xdl_per_wave_per_shuffle,
to_sequence_v<C_BLOCK_TRANSFER.thread_cluster_dims>,
C_BLOCK_TRANSFER.scalar_per_vector,
BLOCK_GEMM.scheduler,
BLOCK_GEMM.pipeline_version,
typename Types::OutComputeType,
typename Types::InComputeType,
ALGORITHM.max_transpose_transfer_src_scalar_per_vector,
ALGORITHM.max_transpose_transfer_dst_scalar_per_vector>;
};
} // namespace ck_tile::builder::factory


@@ -0,0 +1,103 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle.hpp"
#include "ck_tile/builder/conv_signature_concepts.hpp"
#include "ck_tile/builder/conv_algorithm_concepts.hpp"
#include "ck_tile/builder/conv_algorithm_limits.hpp"
#include "ck_tile/builder/builder_utils.hpp"
#include "ck_tile/builder/factory/helpers/ck/conv_tensor_layout.hpp"
#include "ck_tile/builder/factory/helpers/ck/conv_tensor_type.hpp"
#include "ck_tile/builder/factory/helpers/ck/conv_elementwise_op.hpp"
#include "ck_tile/builder/factory/helpers/ck/conv_tuning_params.hpp"
#include "ck_tile/builder/factory/helpers/ck/conv_block_transfer.hpp"
#include "ck_tile/builder/factory/helpers/ck/conv_thread_block.hpp"
namespace ck_tile::builder::factory {
// Factory for DeviceGroupedConvBwdWeight_Xdl_CShuffle instance
// of a grouped bwd weight convolution kernel.
template <ConvSignatureDescriptor auto SIGNATURE,
ConvAlgorithmDescriptor auto ALGORITHM,
StringLiteral VERSION>
requires ConvDirectionIsBackwardWeight<SIGNATURE>
struct ConvBwdWeightXdlFactory
{
static constexpr size_t SPATIAL_DIM = SIGNATURE.spatial_dim;
using Layouts = internal::ConvTensorLayouts<SIGNATURE, SPATIAL_DIM>;
using Types = internal::ConvTensorDataTypes<SIGNATURE>;
using Ops = internal::ConvElementwiseOps<SIGNATURE>;
using AlgorithmType = decltype(ALGORITHM);
static constexpr auto BWD_CONV_SPECIALIZATION =
internal::SetBwdWeightConvSpecialization<ALGORITHM>();
static constexpr auto BLOCK = internal::SetThreadBlockInfo<ALGORITHM>();
static constexpr auto GRIDWISE_GEMM = ALGORITHM.gridwise_gemm;
static constexpr auto XDL_PARAMS = GRIDWISE_GEMM.xdl_params;
static constexpr auto A_BLOCK_TRANSFER =
internal::SetBwdConvBlockTransfer<ALGORITHM.transfer.a>();
static constexpr auto B_BLOCK_TRANSFER =
internal::SetBwdConvBlockTransfer<ALGORITHM.transfer.b>();
static constexpr auto C_BLOCK_TRANSFER = internal::SetCBlockTransfer<SIGNATURE, ALGORITHM>();
// Check limits for the algorithm parameters.
// TODO: Add more limits checks as needed.
static_assert(InputVectorTransferLimits<A_BLOCK_TRANSFER>);
static_assert(InputVectorTransferLimits<B_BLOCK_TRANSFER>);
static_assert(OutputVectorTransferLimits<C_BLOCK_TRANSFER>);
static_assert(AccessOrderLimits4D<A_BLOCK_TRANSFER.thread_cluster_order>);
static_assert(AccessOrderLimits4D<B_BLOCK_TRANSFER.thread_cluster_order>);
static_assert(AccessOrderLimits4D<A_BLOCK_TRANSFER.src_access_order>);
static_assert(AccessOrderLimits4D<B_BLOCK_TRANSFER.src_access_order>);
// The backward weight convolution kernel class instance.
using Instance = ck::tensor_operation::device::DeviceGroupedConvBwdWeight_Xdl_CShuffle<
SPATIAL_DIM,
typename Layouts::InLayout,
typename Layouts::WeiLayout,
typename Layouts::OutLayout,
typename Types::InDataType,
typename Types::WeiDataType,
typename Types::OutDataType,
typename Types::AccDataType,
typename Ops::InElementwiseOp,
typename Ops::WeiElementwiseOp,
typename Ops::OutElementwiseOp,
BWD_CONV_SPECIALIZATION,
BLOCK.block_size,
BLOCK.per_block.m,
BLOCK.per_block.n,
BLOCK.per_block.k,
GRIDWISE_GEMM.k1,
XDL_PARAMS.m_per_xdl,
XDL_PARAMS.n_per_xdl,
XDL_PARAMS.m_xdl_per_wave,
XDL_PARAMS.n_xdl_per_wave,
to_sequence_v<A_BLOCK_TRANSFER.thread_cluster_dims>,
to_sequence_v<A_BLOCK_TRANSFER.thread_cluster_order>,
to_sequence_v<A_BLOCK_TRANSFER.src_access_order>,
A_BLOCK_TRANSFER.src_vector_dim,
A_BLOCK_TRANSFER.src_scalar_per_vector,
A_BLOCK_TRANSFER.lds_dst_scalar_per_vector,
A_BLOCK_TRANSFER.lds_padding,
to_sequence_v<B_BLOCK_TRANSFER.thread_cluster_dims>,
to_sequence_v<B_BLOCK_TRANSFER.thread_cluster_order>,
to_sequence_v<B_BLOCK_TRANSFER.src_access_order>,
B_BLOCK_TRANSFER.src_vector_dim,
B_BLOCK_TRANSFER.src_scalar_per_vector,
B_BLOCK_TRANSFER.lds_dst_scalar_per_vector,
B_BLOCK_TRANSFER.lds_padding,
C_BLOCK_TRANSFER.m_xdl_per_wave_per_shuffle,
C_BLOCK_TRANSFER.n_xdl_per_wave_per_shuffle,
to_sequence_v<C_BLOCK_TRANSFER.thread_cluster_dims>,
C_BLOCK_TRANSFER.scalar_per_vector,
typename Types::OutComputeType,
typename Types::InComputeType,
ALGORITHM.max_transpose_transfer_src_scalar_per_vector,
ALGORITHM.max_transpose_transfer_dst_scalar_per_vector>;
};
} // namespace ck_tile::builder::factory


@@ -0,0 +1,108 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp"
#include "ck_tile/builder/conv_signature_concepts.hpp"
#include "ck_tile/builder/conv_algorithm_concepts.hpp"
#include "ck_tile/builder/conv_algorithm_limits.hpp"
#include "ck_tile/builder/builder_utils.hpp"
#include "ck_tile/builder/factory/helpers/ck/conv_tensor_layout.hpp"
#include "ck_tile/builder/factory/helpers/ck/conv_tensor_type.hpp"
#include "ck_tile/builder/factory/helpers/ck/conv_elementwise_op.hpp"
#include "ck_tile/builder/factory/helpers/ck/conv_tuning_params.hpp"
#include "ck_tile/builder/factory/helpers/ck/conv_block_transfer.hpp"
#include "ck_tile/builder/factory/helpers/ck/conv_thread_block.hpp"
namespace ck_tile::builder::factory {
// Factory for DeviceGroupedConvBwdWeight_Xdl_CShuffleV3 instance
// of a grouped bwd weight convolution kernel.
template <ConvSignatureDescriptor auto SIGNATURE,
ConvAlgorithmDescriptor auto ALGORITHM,
StringLiteral VERSION>
requires ConvDirectionIsBackwardWeight<SIGNATURE>
struct ConvBwdWeightXdlV3Factory
{
static constexpr size_t SPATIAL_DIM = SIGNATURE.spatial_dim;
using Layouts = internal::ConvTensorLayouts<SIGNATURE, SPATIAL_DIM>;
using Types = internal::ConvTensorDataTypes<SIGNATURE>;
using Ops = internal::ConvElementwiseOps<SIGNATURE>;
using AlgorithmType = decltype(ALGORITHM);
static constexpr auto BWD_CONV_SPECIALIZATION =
internal::SetBwdWeightConvSpecialization<ALGORITHM>();
static constexpr auto BLOCK = internal::SetThreadBlockInfo<ALGORITHM>();
static constexpr auto GRIDWISE_GEMM = ALGORITHM.gridwise_gemm;
static constexpr auto XDL_PARAMS = GRIDWISE_GEMM.xdl_params;
static constexpr auto A_BLOCK_TRANSFER =
internal::SetBwdConvBlockTransfer<ALGORITHM.transfer.a>();
static constexpr auto B_BLOCK_TRANSFER =
internal::SetBwdConvBlockTransfer<ALGORITHM.transfer.b>();
static constexpr auto C_BLOCK_TRANSFER = internal::SetCBlockTransfer<SIGNATURE, ALGORITHM>();
static constexpr auto BLOCK_GEMM = internal::SetBlockGemm<ALGORITHM>();
// Check limits for the algorithm parameters.
// TODO: Add more limits checks as needed.
static_assert(InputVectorTransferLimits<A_BLOCK_TRANSFER>, "Invalid A block transfer config");
static_assert(InputVectorTransferLimits<B_BLOCK_TRANSFER>, "Invalid B block transfer config");
static_assert(OutputVectorTransferLimits<C_BLOCK_TRANSFER>, "Invalid C block transfer config");
static_assert(AccessOrderLimits3D<A_BLOCK_TRANSFER.thread_cluster_order>,
"Invalid A thread cluster access order");
static_assert(AccessOrderLimits3D<B_BLOCK_TRANSFER.thread_cluster_order>,
"Invalid B thread cluster access order");
static_assert(AccessOrderLimits3D<A_BLOCK_TRANSFER.src_access_order>,
"Invalid A source access order");
static_assert(AccessOrderLimits3D<B_BLOCK_TRANSFER.src_access_order>,
"Invalid B source access order");
// The backward weight convolution kernel class instance.
using Instance = ck::tensor_operation::device::DeviceGroupedConvBwdWeight_Xdl_CShuffleV3<
SPATIAL_DIM,
typename Layouts::InLayout,
typename Layouts::WeiLayout,
typename Layouts::OutLayout,
typename Types::InDataType,
typename Types::WeiDataType,
typename Types::OutDataType,
typename Types::AccDataType,
typename Ops::InElementwiseOp,
typename Ops::WeiElementwiseOp,
typename Ops::OutElementwiseOp,
BWD_CONV_SPECIALIZATION,
BLOCK.block_size,
BLOCK.per_block.m,
BLOCK.per_block.n,
BLOCK.per_block.k,
GRIDWISE_GEMM.k1,
XDL_PARAMS.m_per_xdl,
XDL_PARAMS.n_per_xdl,
XDL_PARAMS.m_xdl_per_wave,
XDL_PARAMS.n_xdl_per_wave,
to_sequence_v<A_BLOCK_TRANSFER.thread_cluster_dims>,
to_sequence_v<A_BLOCK_TRANSFER.thread_cluster_order>,
to_sequence_v<A_BLOCK_TRANSFER.src_access_order>,
A_BLOCK_TRANSFER.src_vector_dim,
A_BLOCK_TRANSFER.src_scalar_per_vector,
A_BLOCK_TRANSFER.lds_dst_scalar_per_vector,
A_BLOCK_TRANSFER.lds_padding,
to_sequence_v<B_BLOCK_TRANSFER.thread_cluster_dims>,
to_sequence_v<B_BLOCK_TRANSFER.thread_cluster_order>,
to_sequence_v<B_BLOCK_TRANSFER.src_access_order>,
B_BLOCK_TRANSFER.src_vector_dim,
B_BLOCK_TRANSFER.src_scalar_per_vector,
B_BLOCK_TRANSFER.lds_dst_scalar_per_vector,
B_BLOCK_TRANSFER.lds_padding,
C_BLOCK_TRANSFER.m_xdl_per_wave_per_shuffle,
C_BLOCK_TRANSFER.n_xdl_per_wave_per_shuffle,
to_sequence_v<C_BLOCK_TRANSFER.thread_cluster_dims>,
C_BLOCK_TRANSFER.scalar_per_vector,
BLOCK_GEMM.scheduler,
BLOCK_GEMM.pipeline_version,
typename Types::OutComputeType,
typename Types::InComputeType>;
};
} // namespace ck_tile::builder::factory


@@ -57,6 +57,9 @@
#include "ck_tile/builder/conv_algorithm_concepts.hpp"
#include "ck_tile/builder/types.hpp"
// Compile time diagnostics
#include "ck_tile/builder/factory/conv_algorithms.hpp"
// Include all factory implementations
#include "ck_tile/builder/factory/conv_fwd_v3_factory.hpp"
#include "ck_tile/builder/factory/conv_fwd_xdl_factory.hpp"
@@ -65,6 +68,15 @@
#include "ck_tile/builder/factory/conv_fwd_large_tensor_factory.hpp"
#include "ck_tile/builder/factory/reference_factory.hpp"
#include "ck_tile/builder/factory/conv_tile_factory.hpp"
#include "ck_tile/builder/factory/conv_bwd_weight_xdl_factory.hpp"
#include "ck_tile/builder/factory/conv_bwd_weight_xdl_v3_factory.hpp"
#include "ck_tile/builder/factory/conv_bwd_weight_two_stage_xdl_factory.hpp"
#include "ck_tile/builder/factory/conv_bwd_weight_dl_factory.hpp"
#include "ck_tile/builder/factory/conv_bwd_weight_multi_d_xdl_factory.hpp"
#include "ck_tile/builder/factory/conv_bwd_weight_wmma_v3_factory.hpp"
#include "ck_tile/builder/factory/conv_bwd_weight_two_stage_wmma_v3_factory.hpp"
#include "ck_tile/builder/factory/conv_bwd_weight_wmma_factory.hpp"
#include "ck_tile/builder/factory/conv_bwd_weight_multi_d_wmma_v3_factory.hpp"
namespace ck_tile::builder::factory {
@@ -87,56 +99,6 @@ namespace ck_tile::builder::factory {
//
// TODO: Make this dispatch logic much more robust and clear for users.
// Reference algorithm (simplest implementation for validation)
template <typename T>
concept IsReferenceAlgorithm = ConvAlgorithmDescriptor<T> && requires {
{ T::specialization } -> std::convertible_to<ConvAlgorithmSpecialization>;
requires T::specialization == ConvAlgorithmSpecialization::REFERENCE;
};
// CK Tile kernel
template <typename T>
concept IsTileAlgorithm = ConvAlgorithmDescriptor<T> && SpecifiesTileThreadBlock<T> &&
SpecifiesTileTransfer<T> && SpecifiesTileConvSpecialization<T> &&
SpecifiesTileBlockGemm<T> && SpecifiesTileOptimizations<T>;
// XDL-based kernel with V3 pipeline structure (newer block GEMM pipeline)
template <typename T>
concept IsXdlV3Algorithm =
ConvAlgorithmDescriptor<T> && SpecifiesThreadBlock<T> && SpecifiesGridwiseXdlGemm<T> &&
SpecifiesBlockTransfer<T> && SpecifiesLdsTransfer<T> && SpecifiesThreadClusterAccessOrder<T> &&
SpecifiesSourceAccessOrder<T> && SpecifiesFwdConvSpecialization<T> &&
SpecifiesGemmSpecialization<T> && SpecifiesBlockGemm<T>;
// Standard XDL-based kernel (uses XDLops hardware instructions for matrix multiply)
template <typename T>
concept IsXdlAlgorithm =
ConvAlgorithmDescriptor<T> && SpecifiesThreadBlock<T> && SpecifiesGridwiseXdlGemm<T> &&
SpecifiesBlockTransfer<T> && SpecifiesLdsTransfer<T> && SpecifiesThreadClusterAccessOrder<T> &&
SpecifiesSourceAccessOrder<T> && SpecifiesFwdConvSpecialization<T> &&
SpecifiesGemmSpecialization<T> && SpecifiesNumPrefetchStages<T> &&
SpecifiesNumGroupsToMerge<T> && SpecifiesLoopScheduler<T>;
// WMMA-based kernel (uses Wavefront Matrix-Matrix Accumulate instructions)
template <typename T>
concept IsWmmaAlgorithm =
ConvAlgorithmDescriptor<T> && SpecifiesThreadBlock<T> && SpecifiesGridwiseWmmaGemm<T> &&
SpecifiesBlockTransfer<T> && SpecifiesLdsTransfer<T> && SpecifiesThreadClusterAccessOrder<T> &&
SpecifiesSourceAccessOrder<T> && SpecifiesFwdConvSpecialization<T> &&
SpecifiesGemmSpecialization<T> && SpecifiesNumPrefetchStages<T> && SpecifiesLoopScheduler<T>;
// Specialized DL kernel for specific NHWC/KYXC/NHWK data layouts
template <typename T>
concept IsDlAlgorithm =
ConvAlgorithmDescriptor<T> && SpecifiesThreadBlock<T> && SpecifiesFwdConvSpecialization<T> &&
SpecifiesGemmSpecialization<T> && SpecifiesDlThreadConfig<T> && SpecifiesDlThreadCluster<T> &&
SpecifiesDlBlockTransfer<T> && SpecifiesDlEpilogue<T>;
// XDL-based kernel with large tensor support
template <typename T>
concept IsLargeTensorAlgorithm =
IsXdlAlgorithm<decltype(T::base_algorithm)> && SpecifiesLargeTensorSupport<T>;
template <ConvSignatureDescriptor auto SIGNATURE,
ConvAlgorithmDescriptor auto ALGORITHM,
StringLiteral VERSION>
@@ -145,35 +107,35 @@ constexpr auto make_conv_instance()
using AlgoType = std::remove_const_t<decltype(ALGORITHM)>;
// Reference algorithm supports all directions
if constexpr(IsReferenceAlgorithm<AlgoType>)
if constexpr(ReferenceAlgorithm<AlgoType>)
{
return typename ReferenceFactory<SIGNATURE, ALGORITHM, VERSION>::Instance{};
}
// CK Tile supports common factory for each direction
else if constexpr(IsTileAlgorithm<AlgoType>)
else if constexpr(TileAlgorithm<AlgoType>)
{
return typename ConvTileFactory<SIGNATURE, ALGORITHM, VERSION>::Instance{};
}
// Forward direction (supports most algorithm variants)
else if constexpr(ConvDirectionIsForward<SIGNATURE>)
{
if constexpr(IsXdlV3Algorithm<AlgoType>)
if constexpr(FwdXdlV3Algorithm<AlgoType>)
{
return typename ConvFwdXdlV3Factory<SIGNATURE, ALGORITHM, VERSION>::Instance{};
}
else if constexpr(IsXdlAlgorithm<AlgoType>)
else if constexpr(FwdXdlAlgorithm<AlgoType>)
{
return typename ConvFwdXdlFactory<SIGNATURE, ALGORITHM, VERSION>::Instance{};
}
else if constexpr(IsWmmaAlgorithm<AlgoType>)
else if constexpr(FwdWmmaAlgorithm<AlgoType>)
{
return typename ConvFwdWmmaFactory<SIGNATURE, ALGORITHM, VERSION>::Instance{};
}
else if constexpr(IsDlAlgorithm<AlgoType>)
else if constexpr(FwdDlAlgorithm<AlgoType>)
{
return typename ConvFwdDlFactory<SIGNATURE, ALGORITHM, VERSION>::Instance{};
}
else if constexpr(IsLargeTensorAlgorithm<AlgoType>)
else if constexpr(LargeTensorAlgorithm<AlgoType>)
{
return typename ConvFwdLargeTensorFactory<SIGNATURE, ALGORITHM, VERSION>::Instance{};
}
@@ -197,10 +159,55 @@ constexpr auto make_conv_instance()
// Backward weight direction (will expand with more algorithms in the future)
else if constexpr(ConvDirectionIsBackwardWeight<SIGNATURE>)
{
static_assert(false,
"Backward weight convolution: Only reference and tile algorithms "
"supported currently. "
"Optimized kernels (XDL, WMMA, etc.) not yet implemented.");
if constexpr(BwdXdlAlgorithm<AlgoType>)
{
return typename ConvBwdWeightXdlFactory<SIGNATURE, ALGORITHM, VERSION>::Instance{};
}
else if constexpr(BwdXdlV3Algorithm<AlgoType>)
{
return typename ConvBwdWeightXdlV3Factory<SIGNATURE, ALGORITHM, VERSION>::Instance{};
}
else if constexpr(BwdTwoStageXdlAlgorithm<AlgoType>)
{
return
typename ConvBwdWeightTwoStageXdlFactory<SIGNATURE, ALGORITHM, VERSION>::Instance{};
}
else if constexpr(BwdDlAlgorithm<AlgoType>)
{
return typename ConvBwdWeightDlFactory<SIGNATURE, ALGORITHM, VERSION>::Instance{};
}
else if constexpr(BwdMultiDXdlAlgorithm<AlgoType>)
{
return
typename ConvBwdWeightMultiDXdlFactory<SIGNATURE, ALGORITHM, VERSION>::Instance{};
}
else if constexpr(BwdWmmaV3Algorithm<AlgoType>)
{
return typename ConvBwdWeightWmmaV3Factory<SIGNATURE, ALGORITHM, VERSION>::Instance{};
}
else if constexpr(BwdTwoStageWmmaV3Algorithm<AlgoType>)
{
return typename ConvBwdWeightTwoStageWmmaV3Factory<SIGNATURE, ALGORITHM, VERSION>::
Instance{};
}
else if constexpr(BwdWmmaAlgorithm<AlgoType>)
{
return typename ConvBwdWeightWmmaFactory<SIGNATURE, ALGORITHM, VERSION>::Instance{};
}
else if constexpr(BwdMultiDWmmaV3Algorithm<AlgoType>)
{
return typename ConvBwdWeightMultiDWmmaV3Factory<SIGNATURE, ALGORITHM, VERSION>::
Instance{};
}
else
{
static_assert(
false,
"No suitable backward weight convolution kernel factory found for the provided "
"ALGORITHM. The ALGORITHM must satisfy requirements for one of: Reference, Tile, "
"XDL, XDL V3, Two-Stage XDL, DL, Multi-D XDL, WMMA V3, Two-Stage "
"WMMA V3, WMMA, or Multi-D WMMA V3 variant.");
}
}
else
{


@@ -24,10 +24,10 @@ template <ConvSignatureDescriptor auto SIGNATURE,
struct ConvFwdDlFactory
{
static constexpr size_t SPATIAL_DIM = SIGNATURE.spatial_dim;
using Layouts = internal::ConvTensorLayouts<SIGNATURE, SPATIAL_DIM, ConvDirection::FORWARD>;
using Types = internal::FwdConvTensorDataTypes<SIGNATURE>;
using Ops = internal::ElementwiseOps<SIGNATURE>;
using AlgorithmType = decltype(ALGORITHM);
using Layouts = internal::ConvTensorLayouts<SIGNATURE, SPATIAL_DIM>;
using Types = internal::ConvTensorDataTypes<SIGNATURE>;
using Ops = internal::ConvElementwiseOps<SIGNATURE>;
using AlgorithmType = decltype(ALGORITHM);
static constexpr auto FWD_CONV_SPECIALIZATION = internal::SetFwdConvSpecialization<ALGORITHM>();
static constexpr auto GEMM_SPECIALIZATION = internal::SetGemmSpecialization<ALGORITHM>();
@@ -48,7 +48,7 @@ struct ConvFwdDlFactory
using M1N1ThreadClusterN1Xs = to_sequence_v<DL_CLUSTER.n1_xs>;
// A Block Transfer from descriptor - K0_M0_M1_K1 tensor format
static constexpr auto DL_A_TRANSFER = ALGORITHM.transfer.a.block_transfer;
static constexpr auto DL_A_TRANSFER = ALGORITHM.transfer.a;
using ABlockTransferThreadSliceLengths_K0_M0_M1_K1 =
to_sequence_v<DL_A_TRANSFER.thread_slice_lengths>;
using ABlockTransferThreadClusterLengths_K0_M0_M1_K1 =
@@ -64,7 +64,7 @@ struct ConvFwdDlFactory
to_sequence_v<DL_A_TRANSFER.dst_vector_tensor_lengths>;
// B Block Transfer from descriptor - K0_N0_N1_K1 tensor format
static constexpr auto DL_B_TRANSFER = ALGORITHM.transfer.b.block_transfer;
static constexpr auto DL_B_TRANSFER = ALGORITHM.transfer.b;
using BBlockTransferThreadSliceLengths_K0_N0_N1_K1 =
to_sequence_v<DL_B_TRANSFER.thread_slice_lengths>;
using BBlockTransferThreadClusterLengths_K0_N0_N1_K1 =
@@ -80,7 +80,7 @@ struct ConvFwdDlFactory
to_sequence_v<DL_B_TRANSFER.dst_vector_tensor_lengths>;
// C Thread Transfer from descriptor
static constexpr auto DL_C_TRANSFER = ALGORITHM.transfer.c.epilogue;
static constexpr auto DL_C_TRANSFER = ALGORITHM.transfer.c;
using CThreadTransferSrcDstAccessOrder = to_sequence_v<DL_C_TRANSFER.src_dst_access_order>;
static constexpr ck::index_t CThreadTransferSrcDstVectorDim = DL_C_TRANSFER.src_dst_vector_dim;
static constexpr ck::index_t CThreadTransferDstScalarPerVector =
@@ -89,18 +89,18 @@ struct ConvFwdDlFactory
// The DL forward convolution kernel class instance
using Instance = ck::tensor_operation::device::DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK<
SPATIAL_DIM,
typename Types::ADataType,
typename Types::BDataType,
typename Types::DsDataTypes,
typename Types::EDataType,
typename Types::InDataType,
typename Types::WeiDataType,
typename Types::DsDataType,
typename Types::OutDataType,
typename Types::AccDataType,
typename Layouts::ALayout,
typename Layouts::BLayout,
typename Layouts::InLayout,
typename Layouts::WeiLayout,
typename Layouts::DsLayout,
typename Layouts::ELayout,
typename Ops::AElementwiseOp,
typename Ops::BElementwiseOp,
typename Ops::CDEElementwiseOp,
typename Layouts::OutLayout,
typename Ops::InElementwiseOp,
typename Ops::WeiElementwiseOp,
typename Ops::OutElementwiseOp,
FWD_CONV_SPECIALIZATION,
GEMM_SPECIALIZATION,
BLOCK.block_size,


@@ -26,68 +26,65 @@ template <ConvSignatureDescriptor auto SIGNATURE,
struct ConvFwdLargeTensorFactory
{
static constexpr size_t SPATIAL_DIM = SIGNATURE.spatial_dim;
using Layouts = internal::ConvTensorLayouts<SIGNATURE, SPATIAL_DIM, ConvDirection::FORWARD>;
using Types = internal::FwdConvTensorDataTypes<SIGNATURE>;
using Ops = internal::ElementwiseOps<SIGNATURE>;
using AlgorithmType = decltype(ALGORITHM);
using Layouts = internal::ConvTensorLayouts<SIGNATURE, SPATIAL_DIM>;
using Types = internal::ConvTensorDataTypes<SIGNATURE>;
using Ops = internal::ConvElementwiseOps<SIGNATURE>;
using AlgorithmType = decltype(ALGORITHM);
static constexpr auto BASE_ALGORITHM = ALGORITHM.base_algorithm;
static constexpr auto FWD_CONV_SPECIALIZATION =
internal::SetFwdConvSpecialization<BASE_ALGORITHM>();
static constexpr auto GEMM_SPECIALIZATION = internal::SetGemmSpecialization<BASE_ALGORITHM>();
static constexpr auto FWD_CONV_SPECIALIZATION = internal::SetFwdConvSpecialization<ALGORITHM>();
static constexpr auto GEMM_SPECIALIZATION = internal::SetGemmSpecialization<ALGORITHM>();
static constexpr internal::ConvSpec SPECIALIZATION{.conv_spec = FWD_CONV_SPECIALIZATION,
.gemm_spec = GEMM_SPECIALIZATION};
static constexpr auto LOOP_SCHEDULER = internal::SetLoopScheduler<BASE_ALGORITHM>();
static constexpr auto BLOCK = internal::SetThreadBlockInfo<BASE_ALGORITHM>();
static constexpr auto GRIDWISE_GEMM = BASE_ALGORITHM.gridwise_gemm;
static constexpr auto LOOP_SCHEDULER = internal::SetLoopScheduler<ALGORITHM>();
static constexpr auto BLOCK = internal::SetThreadBlockInfo<ALGORITHM>();
static constexpr auto GRIDWISE_GEMM = ALGORITHM.gridwise_gemm;
static constexpr auto XDL_PARAMS = GRIDWISE_GEMM.xdl_params;
static constexpr auto A_BLOCK_TRANSFER =
internal::SetFwdConvBlockTransfer<BASE_ALGORITHM.transfer.a>();
internal::SetFwdConvBlockTransfer<ALGORITHM.transfer.a>();
static constexpr auto B_BLOCK_TRANSFER =
internal::SetFwdConvBlockTransfer<BASE_ALGORITHM.transfer.b>();
static constexpr auto C_BLOCK_TRANSFER =
internal::SetCBlockTransfer<SIGNATURE, BASE_ALGORITHM>();
internal::SetFwdConvBlockTransfer<ALGORITHM.transfer.b>();
static constexpr auto C_BLOCK_TRANSFER = internal::SetCBlockTransfer<SIGNATURE, ALGORITHM>();
// Check limits for the algorithm parameters.
static_assert(InputVectorTransferLimits<A_BLOCK_TRANSFER>);
static_assert(InputVectorTransferLimits<B_BLOCK_TRANSFER>);
static_assert(OutputVectorTransferLimits<C_BLOCK_TRANSFER>);
static_assert(AccessOrderLimits<A_BLOCK_TRANSFER.thread_cluster_order>);
static_assert(AccessOrderLimits<B_BLOCK_TRANSFER.thread_cluster_order>);
static_assert(AccessOrderLimits<A_BLOCK_TRANSFER.src_access_order>);
static_assert(AccessOrderLimits<B_BLOCK_TRANSFER.src_access_order>);
static_assert(AccessOrderLimits3D<A_BLOCK_TRANSFER.thread_cluster_order>);
static_assert(AccessOrderLimits3D<B_BLOCK_TRANSFER.thread_cluster_order>);
static_assert(AccessOrderLimits3D<A_BLOCK_TRANSFER.src_access_order>);
static_assert(AccessOrderLimits3D<B_BLOCK_TRANSFER.src_access_order>);
// The forward convolution kernel class instance with large tensor support.
using Instance =
ck::tensor_operation::device::DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor<
SPATIAL_DIM,
typename Layouts::ALayout,
typename Layouts::BLayout,
typename Layouts::InLayout,
typename Layouts::WeiLayout,
typename Layouts::DsLayout,
typename Layouts::ELayout,
typename Types::ADataType,
typename Types::BDataType,
typename Layouts::OutLayout,
typename Types::InDataType,
typename Types::WeiDataType,
typename Types::AccDataType,
typename Types::CShuffleDataType,
typename Types::DsDataTypes,
typename Types::EDataType,
typename Ops::AElementwiseOp,
typename Ops::BElementwiseOp,
typename Ops::CDEElementwiseOp,
typename Types::OutComputeType,
typename Types::DsDataType,
typename Types::OutDataType,
typename Ops::InElementwiseOp,
typename Ops::WeiElementwiseOp,
typename Ops::OutElementwiseOp,
SPECIALIZATION.conv_spec,
SPECIALIZATION.gemm_spec,
BASE_ALGORITHM.num_gemm_k_prefetch_stages,
ALGORITHM.num_gemm_k_prefetch_stages,
BLOCK.block_size,
BLOCK.per_block.m,
BLOCK.per_block.n,
BLOCK.per_block.k,
GRIDWISE_GEMM.ak1,
GRIDWISE_GEMM.bk1,
GRIDWISE_GEMM.m_per_xdl,
GRIDWISE_GEMM.n_per_xdl,
GRIDWISE_GEMM.m_xdl_per_wave,
GRIDWISE_GEMM.n_xdl_per_wave,
XDL_PARAMS.m_per_xdl,
XDL_PARAMS.n_per_xdl,
XDL_PARAMS.m_xdl_per_wave,
XDL_PARAMS.n_xdl_per_wave,
to_sequence_v<A_BLOCK_TRANSFER.thread_cluster_dims>,
to_sequence_v<A_BLOCK_TRANSFER.thread_cluster_order>,
to_sequence_v<A_BLOCK_TRANSFER.src_access_order>,
@@ -106,8 +103,8 @@ struct ConvFwdLargeTensorFactory
C_BLOCK_TRANSFER.n_xdl_per_wave_per_shuffle,
to_sequence_v<C_BLOCK_TRANSFER.thread_cluster_dims>,
C_BLOCK_TRANSFER.scalar_per_vector,
typename Types::AComputeType,
typename Types::BComputeType,
typename Types::InComputeType,
typename Types::WeiComputeType,
LOOP_SCHEDULER>;
};


@@ -26,10 +26,10 @@ template <ConvSignatureDescriptor auto SIGNATURE,
struct ConvFwdXdlV3Factory
{
static constexpr size_t SPATIAL_DIM = SIGNATURE.spatial_dim;
using Layouts = internal::ConvTensorLayouts<SIGNATURE, SPATIAL_DIM, ConvDirection::FORWARD>;
using Types = internal::FwdConvTensorDataTypes<SIGNATURE>;
using Ops = internal::ElementwiseOps<SIGNATURE>;
using AlgorithmType = decltype(ALGORITHM);
using Layouts = internal::ConvTensorLayouts<SIGNATURE, SPATIAL_DIM>;
using Types = internal::ConvTensorDataTypes<SIGNATURE>;
using Ops = internal::ConvElementwiseOps<SIGNATURE>;
using AlgorithmType = decltype(ALGORITHM);
static_assert(ALGORITHM.transfer.a.lds_transfer.is_direct_load ==
ALGORITHM.transfer.b.lds_transfer.is_direct_load,
@@ -43,6 +43,7 @@ struct ConvFwdXdlV3Factory
static constexpr auto BLOCK = internal::SetThreadBlockInfo<ALGORITHM>();
static constexpr auto GRIDWISE_GEMM = ALGORITHM.gridwise_gemm;
static constexpr auto XDL_PARAMS = GRIDWISE_GEMM.xdl_params;
static constexpr auto A_BLOCK_TRANSFER =
internal::SetFwdConvBlockTransfer<ALGORITHM.transfer.a>();
static constexpr auto B_BLOCK_TRANSFER =
@@ -55,27 +56,27 @@ struct ConvFwdXdlV3Factory
static_assert(InputVectorTransferLimits<A_BLOCK_TRANSFER>);
static_assert(InputVectorTransferLimits<B_BLOCK_TRANSFER>);
static_assert(OutputVectorTransferLimits<C_BLOCK_TRANSFER>);
static_assert(AccessOrderLimits<A_BLOCK_TRANSFER.thread_cluster_order>);
static_assert(AccessOrderLimits<B_BLOCK_TRANSFER.thread_cluster_order>);
static_assert(AccessOrderLimits<A_BLOCK_TRANSFER.src_access_order>);
static_assert(AccessOrderLimits<B_BLOCK_TRANSFER.src_access_order>);
static_assert(AccessOrderLimits3D<A_BLOCK_TRANSFER.thread_cluster_order>);
static_assert(AccessOrderLimits3D<B_BLOCK_TRANSFER.thread_cluster_order>);
static_assert(AccessOrderLimits3D<A_BLOCK_TRANSFER.src_access_order>);
static_assert(AccessOrderLimits3D<B_BLOCK_TRANSFER.src_access_order>);
// The forward convolution kernel class instance.
using Instance = ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<
SPATIAL_DIM,
typename Layouts::ALayout,
typename Layouts::BLayout,
typename Layouts::InLayout,
typename Layouts::WeiLayout,
typename Layouts::DsLayout,
typename Layouts::ELayout,
typename Types::ADataType,
typename Types::BDataType,
typename Layouts::OutLayout,
typename Types::InDataType,
typename Types::WeiDataType,
typename Types::AccDataType,
typename Types::CShuffleDataType,
typename Types::DsDataTypes,
typename Types::EDataType,
typename Ops::AElementwiseOp,
typename Ops::BElementwiseOp,
typename Ops::CDEElementwiseOp,
typename Types::OutComputeType,
typename Types::DsDataType,
typename Types::OutDataType,
typename Ops::InElementwiseOp,
typename Ops::WeiElementwiseOp,
typename Ops::OutElementwiseOp,
SPECIALIZATION.conv_spec,
SPECIALIZATION.gemm_spec,
BLOCK.block_size,
@@ -84,10 +85,10 @@ struct ConvFwdXdlV3Factory
BLOCK.per_block.k,
GRIDWISE_GEMM.ak1,
GRIDWISE_GEMM.bk1,
GRIDWISE_GEMM.m_per_xdl,
GRIDWISE_GEMM.n_per_xdl,
GRIDWISE_GEMM.m_xdl_per_wave,
GRIDWISE_GEMM.n_xdl_per_wave,
XDL_PARAMS.m_per_xdl,
XDL_PARAMS.n_per_xdl,
XDL_PARAMS.m_xdl_per_wave,
XDL_PARAMS.n_xdl_per_wave,
to_sequence_v<A_BLOCK_TRANSFER.thread_cluster_dims>,
to_sequence_v<A_BLOCK_TRANSFER.thread_cluster_order>,
to_sequence_v<A_BLOCK_TRANSFER.src_access_order>,
@@ -108,8 +109,8 @@ struct ConvFwdXdlV3Factory
C_BLOCK_TRANSFER.scalar_per_vector,
BLOCK_GEMM.scheduler,
BLOCK_GEMM.pipeline_version,
typename Types::AComputeType,
typename Types::BComputeType,
typename Types::InComputeType,
typename Types::WeiComputeType,
IS_DIRECT_LOAD>;
};


@@ -26,10 +26,10 @@ template <ConvSignatureDescriptor auto SIGNATURE,
struct ConvFwdWmmaFactory
{
static constexpr size_t SPATIAL_DIM = SIGNATURE.spatial_dim;
using Layouts = internal::ConvTensorLayouts<SIGNATURE, SPATIAL_DIM, ConvDirection::FORWARD>;
using Types = internal::FwdConvTensorDataTypes<SIGNATURE>;
using Ops = internal::ElementwiseOps<SIGNATURE>;
using AlgorithmType = decltype(ALGORITHM);
using Layouts = internal::ConvTensorLayouts<SIGNATURE, SPATIAL_DIM>;
using Types = internal::ConvTensorDataTypes<SIGNATURE>;
using Ops = internal::ConvElementwiseOps<SIGNATURE>;
using AlgorithmType = decltype(ALGORITHM);
static constexpr auto FWD_CONV_SPECIALIZATION = internal::SetFwdConvSpecialization<ALGORITHM>();
static constexpr auto GEMM_SPECIALIZATION = internal::SetGemmSpecialization<ALGORITHM>();
@@ -52,27 +52,27 @@ struct ConvFwdWmmaFactory
static_assert(InputVectorTransferLimits<A_BLOCK_TRANSFER>);
static_assert(InputVectorTransferLimits<B_BLOCK_TRANSFER>);
static_assert(OutputVectorTransferLimits<C_BLOCK_TRANSFER>);
static_assert(AccessOrderLimits<A_BLOCK_TRANSFER.thread_cluster_order>);
static_assert(AccessOrderLimits<B_BLOCK_TRANSFER.thread_cluster_order>);
static_assert(AccessOrderLimits<A_BLOCK_TRANSFER.src_access_order>);
static_assert(AccessOrderLimits<B_BLOCK_TRANSFER.src_access_order>);
static_assert(AccessOrderLimits3D<A_BLOCK_TRANSFER.thread_cluster_order>);
static_assert(AccessOrderLimits3D<B_BLOCK_TRANSFER.thread_cluster_order>);
static_assert(AccessOrderLimits3D<A_BLOCK_TRANSFER.src_access_order>);
static_assert(AccessOrderLimits3D<B_BLOCK_TRANSFER.src_access_order>);
// The forward convolution kernel class instance.
using Instance = ck::tensor_operation::device::DeviceGroupedConvFwdMultipleD_Wmma_CShuffle<
SPATIAL_DIM,
typename Layouts::ALayout,
typename Layouts::BLayout,
typename Layouts::InLayout,
typename Layouts::WeiLayout,
typename Layouts::DsLayout,
typename Layouts::ELayout,
typename Types::ADataType,
typename Types::BDataType,
typename Layouts::OutLayout,
typename Types::InDataType,
typename Types::WeiDataType,
typename Types::AccDataType,
typename Types::CShuffleDataType,
typename Types::DsDataTypes,
typename Types::EDataType,
typename Ops::AElementwiseOp,
typename Ops::BElementwiseOp,
typename Ops::CDEElementwiseOp,
typename Types::OutComputeType,
typename Types::DsDataType,
typename Types::OutDataType,
typename Ops::InElementwiseOp,
typename Ops::WeiElementwiseOp,
typename Ops::OutElementwiseOp,
SPECIALIZATION.conv_spec,
SPECIALIZATION.gemm_spec,
ALGORITHM.num_gemm_k_prefetch_stages,


@@ -26,10 +26,10 @@ template <ConvSignatureDescriptor auto SIGNATURE,
struct ConvFwdXdlFactory
{
static constexpr size_t SPATIAL_DIM = SIGNATURE.spatial_dim;
using Layouts = internal::ConvTensorLayouts<SIGNATURE, SPATIAL_DIM, ConvDirection::FORWARD>;
using Types = internal::FwdConvTensorDataTypes<SIGNATURE>;
using Ops = internal::ElementwiseOps<SIGNATURE>;
using AlgorithmType = decltype(ALGORITHM);
using Layouts = internal::ConvTensorLayouts<SIGNATURE, SPATIAL_DIM>;
using Types = internal::ConvTensorDataTypes<SIGNATURE>;
using Ops = internal::ConvElementwiseOps<SIGNATURE>;
using AlgorithmType = decltype(ALGORITHM);
static constexpr auto FWD_CONV_SPECIALIZATION = internal::SetFwdConvSpecialization<ALGORITHM>();
static constexpr auto GEMM_SPECIALIZATION = internal::SetGemmSpecialization<ALGORITHM>();
@@ -39,6 +39,7 @@ struct ConvFwdXdlFactory
static constexpr auto LOOP_SCHEDULER = internal::SetLoopScheduler<ALGORITHM>();
static constexpr auto BLOCK = internal::SetThreadBlockInfo<ALGORITHM>();
static constexpr auto GRIDWISE_GEMM = ALGORITHM.gridwise_gemm;
static constexpr auto XDL_PARAMS = GRIDWISE_GEMM.xdl_params;
static constexpr auto A_BLOCK_TRANSFER =
internal::SetFwdConvBlockTransfer<ALGORITHM.transfer.a>();
static constexpr auto B_BLOCK_TRANSFER =
@@ -50,27 +51,27 @@ struct ConvFwdXdlFactory
static_assert(InputVectorTransferLimits<A_BLOCK_TRANSFER>);
static_assert(InputVectorTransferLimits<B_BLOCK_TRANSFER>);
static_assert(OutputVectorTransferLimits<C_BLOCK_TRANSFER>);
static_assert(AccessOrderLimits<A_BLOCK_TRANSFER.thread_cluster_order>);
static_assert(AccessOrderLimits<B_BLOCK_TRANSFER.thread_cluster_order>);
static_assert(AccessOrderLimits<A_BLOCK_TRANSFER.src_access_order>);
static_assert(AccessOrderLimits<B_BLOCK_TRANSFER.src_access_order>);
static_assert(AccessOrderLimits3D<A_BLOCK_TRANSFER.thread_cluster_order>);
static_assert(AccessOrderLimits3D<B_BLOCK_TRANSFER.thread_cluster_order>);
static_assert(AccessOrderLimits3D<A_BLOCK_TRANSFER.src_access_order>);
static_assert(AccessOrderLimits3D<B_BLOCK_TRANSFER.src_access_order>);
// The forward convolution kernel class instance.
using Instance = ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<
SPATIAL_DIM,
typename Layouts::ALayout,
typename Layouts::BLayout,
typename Layouts::InLayout,
typename Layouts::WeiLayout,
typename Layouts::DsLayout,
typename Layouts::ELayout,
typename Types::ADataType,
typename Types::BDataType,
typename Layouts::OutLayout,
typename Types::InDataType,
typename Types::WeiDataType,
typename Types::AccDataType,
typename Types::CShuffleDataType,
typename Types::DsDataTypes,
typename Types::EDataType,
typename Ops::AElementwiseOp,
typename Ops::BElementwiseOp,
typename Ops::CDEElementwiseOp,
typename Types::OutComputeType,
typename Types::DsDataType,
typename Types::OutDataType,
typename Ops::InElementwiseOp,
typename Ops::WeiElementwiseOp,
typename Ops::OutElementwiseOp,
SPECIALIZATION.conv_spec,
SPECIALIZATION.gemm_spec,
ALGORITHM.num_gemm_k_prefetch_stages,
@@ -80,10 +81,10 @@ struct ConvFwdXdlFactory
BLOCK.per_block.k,
GRIDWISE_GEMM.ak1,
GRIDWISE_GEMM.bk1,
GRIDWISE_GEMM.m_per_xdl,
GRIDWISE_GEMM.n_per_xdl,
GRIDWISE_GEMM.m_xdl_per_wave,
GRIDWISE_GEMM.n_xdl_per_wave,
XDL_PARAMS.m_per_xdl,
XDL_PARAMS.n_per_xdl,
XDL_PARAMS.m_xdl_per_wave,
XDL_PARAMS.n_xdl_per_wave,
to_sequence_v<A_BLOCK_TRANSFER.thread_cluster_dims>,
to_sequence_v<A_BLOCK_TRANSFER.thread_cluster_order>,
to_sequence_v<A_BLOCK_TRANSFER.src_access_order>,
@@ -102,10 +103,10 @@ struct ConvFwdXdlFactory
C_BLOCK_TRANSFER.n_xdl_per_wave_per_shuffle,
to_sequence_v<C_BLOCK_TRANSFER.thread_cluster_dims>,
C_BLOCK_TRANSFER.scalar_per_vector,
typename Types::AComputeType,
typename Types::BComputeType,
typename Types::InComputeType,
typename Types::WeiComputeType,
LOOP_SCHEDULER,
ALGORITHM.num_groups_to_merge>;
ALGORITHM.num_conv_groups_to_merge>;
};
} // namespace ck_tile::builder::factory


@@ -10,27 +10,28 @@
namespace ck_tile::builder::factory::internal {
// Block transfer parameters for A or B tensor.
template <size_t ThreadClusterRank = 3>
struct BlockTransfer
{
ck::Array<size_t, 3> thread_cluster_dims = {0, 0, 0}; // k0, m, k1
ck::Array<size_t, 3> thread_cluster_order = {0, 0, 0};
ck::Array<size_t, 3> src_access_order = {0, 0, 0};
size_t src_vector_dim = 0;
size_t src_scalar_per_vector = 0;
size_t lds_dst_scalar_per_vector = 0;
bool is_direct_load = false;
bool lds_padding = false;
ck::Array<size_t, ThreadClusterRank> thread_cluster_dims{};
ck::Array<size_t, ThreadClusterRank> thread_cluster_order{};
ck::Array<size_t, ThreadClusterRank> src_access_order{};
size_t src_vector_dim = 0;
size_t src_scalar_per_vector = 0;
size_t lds_dst_scalar_per_vector = 0;
bool is_direct_load = false;
bool lds_padding = false;
};
template <auto TRANSFER>
constexpr BlockTransfer SetFwdConvBlockTransfer()
constexpr BlockTransfer<> SetFwdConvBlockTransfer()
{
auto& block_xfer = TRANSFER.block_transfer;
auto& block_order = TRANSFER.block_transfer_access_order;
auto& src_order = TRANSFER.src_access_order;
auto& lds_cfg = TRANSFER.lds_transfer;
return BlockTransfer{
return BlockTransfer<>{
.thread_cluster_dims = {block_xfer.k0, block_xfer.m_n, block_xfer.k1},
.thread_cluster_order = {block_order.order[0], block_order.order[1], block_order.order[2]},
.src_access_order = {src_order.order[0], src_order.order[1], src_order.order[2]},
@@ -42,6 +43,59 @@ constexpr BlockTransfer SetFwdConvBlockTransfer()
};
}
template <auto TRANSFER>
constexpr auto SetBwdConvBlockTransfer()
{
auto& block_xfer = TRANSFER.block_transfer;
auto& block_order = TRANSFER.block_transfer_access_order;
auto& src_order = TRANSFER.src_access_order;
auto& lds_cfg = TRANSFER.lds_transfer;
constexpr auto array_length = block_order.order.size();
static_assert(block_order.order.size() == src_order.order.size(),
"Mismatched size between block order and src order");
if constexpr(array_length == 3)
{
return BlockTransfer<3>{
.thread_cluster_dims = {block_xfer.k0, block_xfer.m_n, block_xfer.k1},
.thread_cluster_order = {block_order.order[0],
block_order.order[1],
block_order.order[2]},
.src_access_order = {src_order.order[0], src_order.order[1], src_order.order[2]},
.src_vector_dim = lds_cfg.src_vector_dim,
.src_scalar_per_vector = lds_cfg.src_scalar_per_vector,
.lds_dst_scalar_per_vector = lds_cfg.lds_dst_scalar_per_vector,
.lds_padding = lds_cfg.lds_padding,
};
}
else if constexpr(array_length == 4)
{
return BlockTransfer<4>{
.thread_cluster_dims = {block_xfer.k_batch_size,
block_xfer.k0,
block_xfer.m_n,
block_xfer.k1},
.thread_cluster_order = {block_order.order[0],
block_order.order[1],
block_order.order[2],
block_order.order[3]},
.src_access_order = {src_order.order[0],
src_order.order[1],
src_order.order[2],
src_order.order[3]},
.src_vector_dim = lds_cfg.src_vector_dim,
.src_scalar_per_vector = lds_cfg.src_scalar_per_vector,
.lds_dst_scalar_per_vector = lds_cfg.lds_dst_scalar_per_vector,
.lds_padding = lds_cfg.lds_padding,
};
}
else
{
static_assert(false, "Internal error: Unsupported array length");
}
}
// Block transfer parameters for C tensor.
struct CBlockTransfer
{

View File

@@ -62,14 +62,15 @@ consteval auto GetElementwiseOp()
}
template <auto Sig>
struct ConvElementwiseOps
{
static constexpr auto input_op = GetElementwiseOp<Sig.input>();
static constexpr auto weight_op = GetElementwiseOp<Sig.weight>();
static constexpr auto output_op = GetElementwiseOp<Sig.output>();
using InElementwiseOp = typename decltype(input_op)::Op;
using WeiElementwiseOp = typename decltype(weight_op)::Op;
using OutElementwiseOp = typename decltype(output_op)::Op;
};
} // namespace ck_tile::builder::factory::internal

View File

@@ -190,7 +190,7 @@ consteval auto GetAuxiliaryTensorLayoutTuple(std::index_sequence<Indices...>)
decltype(TensorLayoutToCK<AuxiliaryTensorConfigsArray[Indices].layout>())...>{};
}
template <auto AuxiliaryTensorConfigsValue, size_t SPATIAL_DIM>
requires(ConvSpatialDim<SPATIAL_DIM>)
struct AuxiliaryTensorLayouts
{
@@ -200,34 +200,32 @@ struct AuxiliaryTensorLayouts
};
// TODO: Currently only the output tensor can have auxiliary tensors (e.g., bias).
template <auto Signature, size_t SPATIAL_DIM>
requires(HasElementwiseOpWithAuxiliaryOperands<decltype(Signature.output)>)
consteval auto GetAuxiliaryTensorLayouts()
{
return AuxiliaryTensorLayouts<Signature.output.operation.auxiliary_operand_configs,
SPATIAL_DIM>{};
}
template <auto Signature, size_t SPATIAL_DIM>
requires(!HasElementwiseOpWithAuxiliaryOperands<decltype(Signature.output)>)
consteval auto GetAuxiliaryTensorLayouts()
{
return EmptyAuxiliaryTensorLayout{};
}
template <auto Signature, size_t SPATIAL_DIM>
requires(ConvSpatialDim<SPATIAL_DIM> &&
ValidConvInputLayoutForSpatialDim<Signature.input.config.layout, SPATIAL_DIM> &&
ValidConvWeightLayoutForSpatialDim<Signature.weight.config.layout, SPATIAL_DIM> &&
ValidConvOutputLayoutForSpatialDim<Signature.output.config.layout, SPATIAL_DIM>)
struct ConvTensorLayouts
{
using InLayout = decltype(TensorLayoutToCK<Signature.input.config.layout>());
using WeiLayout = decltype(TensorLayoutToCK<Signature.weight.config.layout>());
using OutLayout = decltype(TensorLayoutToCK<Signature.output.config.layout>());
using DsLayout = decltype(GetAuxiliaryTensorLayouts<Signature, SPATIAL_DIM>())::type;
};
} // namespace ck_tile::builder::factory::internal

View File

@@ -156,7 +156,7 @@ consteval auto GetAuxiliaryTensorDataTypes()
}
template <auto Signature>
struct ConvTensorDataTypes
{
static constexpr auto input_types =
GetTensorDataAndComputeTypes<Signature.input.config, Signature.data_type>();
@@ -165,20 +165,17 @@ struct FwdConvTensorDataTypes
static constexpr auto output_types =
GetTensorDataAndComputeTypes<Signature.output.config, Signature.data_type>();
using InDataType = typename decltype(input_types.first)::type;
using InComputeType = typename decltype(input_types.second)::type;
using WeiDataType = typename decltype(weight_types.first)::type;
using WeiComputeType = typename decltype(weight_types.second)::type;
using OutDataType = typename decltype(output_types.first)::type;
using OutComputeType = typename decltype(output_types.second)::type;
using AccDataType =
typename decltype(GetTensorAccumulationType<Signature.accumulation_data_type,
Signature.data_type>())::type;
// Data types for the auxiliary tensors (e.g., bias).
using DsDataType = typename decltype(GetAuxiliaryTensorDataTypes<Signature>())::type;
};
} // namespace ck_tile::builder::factory::internal

View File

@@ -4,6 +4,7 @@
#pragma once
#include "ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp"
#include "ck/tensor_operation/gpu/device/convolution_backward_weight_specialization.hpp"
#include "ck/tensor_operation/gpu/device/device_base.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp"
@@ -37,7 +38,7 @@ struct BlockGemmSpec
template <ConvAlgorithmDescriptor auto ALGORITHM>
consteval BlockGemmSpec SetBlockGemm()
{
constexpr auto& BG = ALGORITHM.block_gemm_pipeline;
ck::BlockGemmPipelineScheduler scheduler;
ck::BlockGemmPipelineVersion version;
@@ -82,7 +83,7 @@ consteval ck::LoopScheduler SetLoopScheduler()
template <ConvAlgorithmDescriptor auto ALGORITHM>
consteval ck::PipelineVersion SetGridwiseGemmPipelineVersion()
{
constexpr auto pipeline_version = ALGORITHM.pipeline_version;
using ck_pipeline = ck::PipelineVersion;
switch(pipeline_version)
{
@@ -149,12 +150,30 @@ consteval ck::tensor_operation::device::ConvolutionForwardSpecialization SetFwdC
using ck_conv_spec = ck::tensor_operation::device::ConvolutionForwardSpecialization;
switch(specialization)
{
case ConvSpecialization::DEFAULT: return ck_conv_spec::Default;
case ConvSpecialization::FILTER_1X1_PAD0: return ck_conv_spec::Filter1x1Pad0;
case ConvSpecialization::FILTER_1X1_STRIDE1_PAD0: return ck_conv_spec::Filter1x1Stride1Pad0;
case ConvSpecialization::FILTER_3x3: return ck_conv_spec::Filter3x3;
case ConvSpecialization::ODD_C: return ck_conv_spec::OddC;
default: throw "Unsupported ConvSpecialization";
}
}
template <ConvAlgorithmDescriptor auto ALGORITHM>
consteval ck::tensor_operation::device::ConvolutionBackwardWeightSpecialization
SetBwdWeightConvSpecialization()
{
constexpr auto specialization = ALGORITHM.bwd_weight_specialization;
using ck_conv_spec = ck::tensor_operation::device::ConvolutionBackwardWeightSpecialization;
switch(specialization)
{
case ConvSpecialization::DEFAULT: return ck_conv_spec::Default;
case ConvSpecialization::FILTER_1X1_PAD0: return ck_conv_spec::Filter1x1Pad0;
case ConvSpecialization::FILTER_1X1_STRIDE1_PAD0: return ck_conv_spec::Filter1x1Stride1Pad0;
case ConvSpecialization::ODD_C: return ck_conv_spec::OddC;
case ConvSpecialization::FILTER_3x3:
throw "FILTER_3x3 is not supported for backward weight convolution.";
default: throw "Unsupported ConvSpecialization";
}
}

View File

@@ -26,11 +26,11 @@ struct ReferenceFactory
static constexpr auto kValidation = (internal::ValidateReferenceSignature<SIGNATURE>(), 0);
static constexpr size_t SPATIAL_DIM = SIGNATURE.spatial_dim;
using Types = internal::ConvTensorDataTypes<SIGNATURE>;
using InDataType = typename Types::InDataType;
using WeiDataType = typename Types::WeiDataType;
using OutDataType = typename Types::OutDataType;
struct Instance
{

View File

@@ -63,10 +63,7 @@ struct GemmAlgorithmInfo
OutputTileTransferInfo c_tile_transfer;
builder::PipelineVersion pipeline_version;
builder::PipelineScheduler pipeline_scheduler;
builder::ConvSpecialization conv_specialization;
builder::GemmPadding padding;
};

View File

@@ -197,18 +197,16 @@ constexpr builder::ConvDirection conv_direction()
/// @brief Derives the convolution-specific specialization from a device kernel `Instance` type.
/// @tparam Instance The device kernel instance type.
/// @return A `builder::ConvSpecialization` enum value.
template <typename Instance>
constexpr auto conv_spec()
{
using InstTraits = InstanceTraits<Instance>;
using enum builder::ConvSpecialization;
if constexpr(requires { InstTraits::kConvForwardSpecialization; })
{
using enum ck::tensor_operation::device::ConvolutionForwardSpecialization;
switch(InstTraits::kConvForwardSpecialization)
{
case Default: return DEFAULT;
@@ -221,8 +219,6 @@ constexpr auto conv_spec()
else if constexpr(requires { InstTraits::kConvBwdDataSpecialization; })
{
using enum ck::tensor_operation::device::ConvolutionBackwardDataSpecialization;
switch(InstTraits::kConvBwdDataSpecialization)
{
case Default: return DEFAULT;
@@ -232,8 +228,6 @@ constexpr auto conv_spec()
else if constexpr(requires { InstTraits::kConvBwdWeightSpecialization; })
{
using enum ck::tensor_operation::device::ConvolutionBackwardWeightSpecialization;
switch(InstTraits::kConvBwdWeightSpecialization)
{
case Default: return DEFAULT;

View File

@@ -35,10 +35,10 @@ struct ReferenceCommonTraits
typename builder::factory::internal::LayoutToCK<SIGNATURE.output.config.layout>::type;
// Data types - extract from factory's type helper
using Types = builder::factory::internal::ConvTensorDataTypes<SIGNATURE>;
using ADataType = typename Types::InDataType;
using BDataType = typename Types::WeiDataType;
using EDataType = typename Types::OutDataType;
using AccDataType = float; // Reference uses float accumulation
// Elementwise operations - reference only supports PassThrough

View File

@@ -72,11 +72,10 @@ struct Args<SIGNATURE>
using OutputDescriptor = TensorDescriptor<OUTPUT_TYPE, OUTPUT_RANK>;
// TODO: We shouldn't need to call into an internal namespace here.
using Ops = factory::internal::ConvElementwiseOps<SIGNATURE>;
// TODO: We shouldn't need to call into an internal namespace here.
using Layouts = factory::internal::ConvTensorLayouts<SIGNATURE, SPATIAL_DIM>;
ConvTensorLengths<SPATIAL_DIM> lengths;
@@ -90,9 +89,9 @@ struct Args<SIGNATURE>
FilterExtent<SPATIAL_DIM> input_left_pad;
FilterExtent<SPATIAL_DIM> input_right_pad;
Ops::InElementwiseOp a_elementwise_op;
Ops::WeiElementwiseOp b_elementwise_op;
Ops::OutElementwiseOp cde_elementwise_op;
/// This function returns the `TensorDescriptor` corresponding to
/// the input-tensor of the convolution problem. This can then
@@ -107,7 +106,7 @@ struct Args<SIGNATURE>
// function.
const auto param = to_ck_conv_param();
const auto desc = ck::utils::conv::make_input_host_tensor_descriptor_g_n_c_wis_packed<
typename Layouts::InLayout>(param);
using Extent = typename InputDescriptor::Extent;
return InputDescriptor(Extent::from_vector(desc.GetLengths()),
Extent::from_vector(desc.GetStrides()));
@@ -121,7 +120,7 @@ struct Args<SIGNATURE>
// See note in implementation of `make_input_descriptor`.
const auto param = to_ck_conv_param();
const auto desc = ck::utils::conv::make_weight_host_tensor_descriptor_g_k_c_xs_packed<
typename Layouts::WeiLayout>(param);
using Extent = typename WeightDescriptor::Extent;
return WeightDescriptor(Extent::from_vector(desc.GetLengths()),
Extent::from_vector(desc.GetStrides()));
@@ -135,7 +134,7 @@ struct Args<SIGNATURE>
// See note in implementation of `make_input_descriptor`.
const auto param = to_ck_conv_param();
const auto desc = ck::utils::conv::make_output_host_tensor_descriptor_g_n_k_wos_packed<
typename Layouts::OutLayout>(param);
using Extent = typename OutputDescriptor::Extent;
return OutputDescriptor(Extent::from_vector(desc.GetLengths()),
Extent::from_vector(desc.GetStrides()));

View File

@@ -27,7 +27,7 @@ template <typename Conv,
auto SIGNATURE,
size_t SPATIAL_DIM = SIGNATURE.spatial_dim,
// TODO: We shouldn't need to call into an internal namespace here.
typename Ops = factory::internal::ConvElementwiseOps<SIGNATURE>>
concept CkConvInstance = requires(Conv& conv,
// TODO: This should be changed depending on IsMultiA etc.
// Currently that is not yet supported elsewhere anyway.
@@ -37,9 +37,9 @@ concept CkConvInstance = requires(Conv& conv,
std::array<index_t, SPATIAL_DIM + 3> lengths,
std::array<index_t, SPATIAL_DIM + 3> strides,
std::array<index_t, SPATIAL_DIM> filter,
Ops::InElementwiseOp elementwise_a,
Ops::WeiElementwiseOp elementwise_b,
Ops::OutElementwiseOp elementwise_cde) {
{
conv.MakeArgument(p_a,
p_b,

View File

@@ -192,8 +192,8 @@ enum class TileConvSpecialization
FILTER_3x3
};
// Enums for the convolution specializations.
enum class ConvSpecialization
{
DEFAULT,
FILTER_1X1_PAD0,
@@ -202,22 +202,6 @@ enum class ConvFwdSpecialization
ODD_C
};
// Enums for the Gemm padding.
enum class GemmPadding
{
@@ -249,7 +233,9 @@ enum class PipelineScheduler
enum class ConvAlgorithmSpecialization
{
LARGE_TENSOR,
REFERENCE, // GPU reference implementation for validation
TWO_STAGE,
MULTIPLE_D
};
// to_string methods for enum classes
@@ -372,9 +358,9 @@ inline std::string_view to_string(GemmSpecialization spec)
}
}
inline std::string_view to_string(ConvSpecialization spec)
{
using enum ConvSpecialization;
switch(spec)
{
case DEFAULT: return "DEFAULT";
@@ -386,30 +372,6 @@ inline std::string_view to_string(ConvFwdSpecialization spec)
}
}
inline std::string_view to_string(GemmPadding padding)
{
using enum GemmPadding;
@@ -525,17 +487,7 @@ inline std::ostream& operator<<(std::ostream& os, GemmSpecialization spec)
return os << to_string(spec);
}
inline std::ostream& operator<<(std::ostream& os, ConvSpecialization spec)
{
return os << to_string(spec);
}
@@ -555,14 +507,4 @@ inline std::ostream& operator<<(std::ostream& os, TensorLayout layout)
return os << to_string(layout);
}
} // namespace ck_tile::builder