mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-14 02:02:46 +00:00
[CK_BUILDER] Add bwd weight factories (#3509)
* Add placeholder test.
* Initial conv bwd weight factory.
* Conv builder test refactoring.
* Add missing pieces to bwd weight factory.
* Improve compile time erros message when no matching factory is found.
* Use amcro to ensure automatic macthing between concepts are their string representations.
* Improve compile time diagnostics.
* Small improvements.
* Improve missing member/wrong type compile-time errors.
* Improve compile time diagnostics.
* Concept bug fixes.
* Remove debug assert.
* Update algorithm signature diagnostics.
* Factory bug fixes.
* First functional version of bwd weight conv factory.
* Refactor handing of GEMM-K batch template parameter in conv bwd weight factory.
* Concept improvements.
* Improve concept diagnostics.
* Introduve a common size type for concepts.
* Update compiletime diagnostics to use the size type.
* Update conv specialization enum.
* Fix fwd conv builder tests.
* Fix smoke tests.
* Separate bwd weigth and bwd data tests into separate targets.
* Clean-up CK Tile builder tests.
* Add bwd weight XDL CShuffle V3 factory.
* Build conv bwd weigth v3 instances successfully.
* Add instance traits for DeviceGroupedConvBwdWeight_Xdl_CShuffleV3.
* Test fix.
* Add instance traits for bwd weight algorithms.
* Add unit tests for instance strings.
* Build new instance traits unit tests but exclude WMMA for now.
* Added factory for DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle.
* Conv bwd weight DL factory.
* Final implementation for bwd weight DL factory.
* Add test for creating DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle instance.
* Add factory for DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle
* Treat ref algorithm the same way as real algorithms in the dispatcher.
* Refactor large tensor support and WMMA configuration.
* Add factory and tests for DeviceGroupedConvBwdWeight_Wmma_CShuffleV3.
* Update Readme.
* Fix WMMA bwd weight tests.
* Added factory and tests for DeviceGroupedConvBwdWeightTwoStage_Wmma_CShuffleV3.
* Factory and tests for DeviceGroupedConvBwdWeight_Wmma_CShuffle.
* Dispatching for DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffle.
* Add factory for DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3
* Fix DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3 factory and compute types for input and output tensor in bwd weigth convs.
* Fix fwd factories after refactoring.
* clang-format
* Move compile-time diagnostics to a separate branch.
* Fix ref algorithm dispatching.
* Fix smoke tests.
* clang-format
* Fix factory for regular WMMA conv bwd weight.
* Clarify builder Readme.
* Remove obsolete test file.
* Fix test after merge.
* clang-format
* Remove the C++26 extensions.
* Unify conv elementwise ops and layout definitions for fwd and bwd directions.
* Remove old layout and elementwise ops.
* Unify handling of conv tensor types between fwd and bwd directions.
* Unify block transfer for fwd and bwd directions. Rename ThreadSliceDim to ThreadClusterRank.
* Make BlockTransferDescriptor concept parametrized. Introduce a common TileTransferParameters concept for conv algorithms.
* clang-format
---------
Co-authored-by: Ville Pietilä <>
[ROCm/composable_kernel commit: 9908a87c31]
This commit is contained in:
@@ -45,6 +45,11 @@ cmake
|
||||
..
|
||||
```
|
||||
|
||||
Note: The tests for WMMA builders are only built when `CK_USE_WMMA` is enabled. Add e.g.
|
||||
`gfx1121` or any of the other `gfx11`/`gfx12` architectures to the GPU targets. Alternatively,
|
||||
one can add flag `-D CK_USE_WMMA=ON` to build the tests. For the end-to-end tests that use
|
||||
the instances from builder, one needs an actual Navi card.
|
||||
|
||||
## Building and Testing
|
||||
|
||||
The builder test suite is organized into two main categories:
|
||||
|
||||
@@ -15,29 +15,31 @@ namespace ck_tile::builder {
|
||||
/* Descriptors for individual elements of the algorithm description */
|
||||
/********************************************************************/
|
||||
|
||||
// Common concept for size-related fields
|
||||
template <typename T>
|
||||
concept SizeType = std::unsigned_integral<std::remove_cvref_t<T>>;
|
||||
|
||||
// Concept for thread block dimensions for a GEMM problem.
|
||||
template <typename T>
|
||||
concept ThreadBlockDescriptor = requires(T t) {
|
||||
{ t.block_size } -> std::convertible_to<size_t>;
|
||||
{ t.tile_size.m } -> std::convertible_to<size_t>;
|
||||
{ t.tile_size.n } -> std::convertible_to<size_t>;
|
||||
{ t.tile_size.k } -> std::convertible_to<size_t>;
|
||||
{ t.block_size } -> SizeType;
|
||||
{ t.tile_size.m } -> SizeType;
|
||||
{ t.tile_size.n } -> SizeType;
|
||||
{ t.tile_size.k } -> SizeType;
|
||||
};
|
||||
|
||||
// Concept for parameters that describe a gridwise XDL GEMM problem.
|
||||
template <typename T>
|
||||
concept GridwiseXdlGemmDescriptor = requires(T t) {
|
||||
{ t.ak1 } -> std::convertible_to<size_t>;
|
||||
{ t.bk1 } -> std::convertible_to<size_t>;
|
||||
{ t.m_per_xdl } -> std::convertible_to<size_t>;
|
||||
{ t.n_per_xdl } -> std::convertible_to<size_t>;
|
||||
{ t.m_xdl_per_wave } -> std::convertible_to<size_t>;
|
||||
{ t.n_xdl_per_wave } -> std::convertible_to<size_t>;
|
||||
{ t.m_per_xdl } -> SizeType;
|
||||
{ t.n_per_xdl } -> SizeType;
|
||||
{ t.m_xdl_per_wave } -> SizeType;
|
||||
{ t.n_xdl_per_wave } -> SizeType;
|
||||
};
|
||||
|
||||
// Concept for parameter that describe block GEMM problem.
|
||||
template <typename T>
|
||||
concept BlockGemmDescriptor = requires(T t) {
|
||||
concept BlockGemmPipelineDescriptor = requires(T t) {
|
||||
{ t.pipeline_version } -> std::convertible_to<PipelineVersion>;
|
||||
{ t.scheduler } -> std::convertible_to<PipelineScheduler>;
|
||||
};
|
||||
@@ -45,37 +47,48 @@ concept BlockGemmDescriptor = requires(T t) {
|
||||
// Concept for parameters that describe a gridwise WMMA GEMM problem.
|
||||
template <typename T>
|
||||
concept GridwiseWmmaGemmDescriptor = requires(T t) {
|
||||
{ t.k1 } -> std::convertible_to<size_t>;
|
||||
{ t.m_per_wmma } -> std::convertible_to<size_t>;
|
||||
{ t.n_per_wmma } -> std::convertible_to<size_t>;
|
||||
{ t.m_wmma_per_wave } -> std::convertible_to<size_t>;
|
||||
{ t.n_wmma_per_wave } -> std::convertible_to<size_t>;
|
||||
{ t.pipeline_version } -> std::convertible_to<PipelineVersion>;
|
||||
{ t.k1 } -> SizeType;
|
||||
{ t.m_per_wmma } -> SizeType;
|
||||
{ t.n_per_wmma } -> SizeType;
|
||||
{ t.m_wmma_per_wave } -> SizeType;
|
||||
{ t.n_wmma_per_wave } -> SizeType;
|
||||
};
|
||||
|
||||
// Concept for vectorized data transfer for convolution input tensors.
|
||||
template <typename T>
|
||||
concept BlockTransferDescriptor = requires(T t) {
|
||||
{ t.k0 } -> std::convertible_to<size_t>;
|
||||
{ t.m_n } -> std::convertible_to<size_t>;
|
||||
{ t.k1 } -> std::convertible_to<size_t>;
|
||||
concept BlockTransferDescriptor3D = requires(T t) {
|
||||
{ t.k0 } -> SizeType;
|
||||
{ t.m_n } -> SizeType;
|
||||
{ t.k1 } -> SizeType;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
concept BlockTransferDescriptor4D = requires(T t) {
|
||||
{ t.k0 } -> SizeType;
|
||||
{ t.m_n } -> SizeType;
|
||||
{ t.k1 } -> SizeType;
|
||||
{ t.k_batch_size } -> SizeType;
|
||||
};
|
||||
|
||||
template <typename T, size_t ThreadClusterRank>
|
||||
concept BlockTransferDescriptor = (ThreadClusterRank == 3 && BlockTransferDescriptor3D<T>) ||
|
||||
(ThreadClusterRank == 4 && BlockTransferDescriptor4D<T>);
|
||||
|
||||
// Concept for thread cluster dimensions for GEMM output tensor.
|
||||
template <typename T>
|
||||
concept ThreadClusterDescriptor = requires(T t) {
|
||||
{ t.m_block } -> std::convertible_to<size_t>;
|
||||
{ t.m_wave_per_xdl } -> std::convertible_to<size_t>;
|
||||
{ t.n_block } -> std::convertible_to<size_t>;
|
||||
{ t.n_wave_per_xdl } -> std::convertible_to<size_t>;
|
||||
{ t.m_block } -> SizeType;
|
||||
{ t.m_wave_per_xdl } -> SizeType;
|
||||
{ t.n_block } -> SizeType;
|
||||
{ t.n_wave_per_xdl } -> SizeType;
|
||||
};
|
||||
|
||||
// Concept for the LDS transfer for the convolution input tensors.
|
||||
template <typename T>
|
||||
concept LdsTransferDescriptor = requires(T t) {
|
||||
{ t.src_vector_dim } -> std::convertible_to<size_t>;
|
||||
{ t.src_scalar_per_vector } -> std::convertible_to<size_t>;
|
||||
{ t.lds_dst_scalar_per_vector } -> std::convertible_to<size_t>;
|
||||
{ t.src_vector_dim } -> SizeType;
|
||||
{ t.src_scalar_per_vector } -> SizeType;
|
||||
{ t.lds_dst_scalar_per_vector } -> SizeType;
|
||||
{ t.is_direct_load } -> std::convertible_to<bool>;
|
||||
{ t.lds_padding } -> std::convertible_to<bool>;
|
||||
};
|
||||
@@ -84,33 +97,35 @@ concept LdsTransferDescriptor = requires(T t) {
|
||||
// LDS).
|
||||
template <typename T>
|
||||
concept EpilogueDescriptor = requires(T t) {
|
||||
{ t.m_xdl_per_wave_per_shuffle } -> std::convertible_to<size_t>;
|
||||
{ t.n_per_wave_per_shuffle } -> std::convertible_to<size_t>;
|
||||
{ t.scalar_per_vector } -> std::convertible_to<size_t>;
|
||||
{ t.m_xdl_per_wave_per_shuffle } -> SizeType;
|
||||
{ t.n_per_wave_per_shuffle } -> SizeType;
|
||||
{ t.scalar_per_vector } -> SizeType;
|
||||
};
|
||||
|
||||
// Concept for the thread cluster access order
|
||||
template <typename T>
|
||||
concept AccessOrderDescriptor = requires(T t) {
|
||||
{ t.order } -> std::convertible_to<std::array<size_t, 3>>;
|
||||
} || requires(T t) {
|
||||
{ t.order } -> std::convertible_to<std::array<size_t, 4>>;
|
||||
};
|
||||
|
||||
// Concept for thread block dimensions for a GEMM problem for CK Tile (Block
|
||||
// size is deduced from block gemm structure).
|
||||
template <typename T>
|
||||
concept TileThreadBlockDescriptor = requires(T t) {
|
||||
{ t.tile_size.m } -> std::convertible_to<size_t>;
|
||||
{ t.tile_size.n } -> std::convertible_to<size_t>;
|
||||
{ t.tile_size.k } -> std::convertible_to<size_t>;
|
||||
{ t.tile_size.m } -> SizeType;
|
||||
{ t.tile_size.n } -> SizeType;
|
||||
{ t.tile_size.k } -> SizeType;
|
||||
};
|
||||
|
||||
// Concept for thread block dimensions for a GEMM problem for CK Tile (Block
|
||||
// size is deduced from block gemm structure).
|
||||
template <typename T>
|
||||
concept TileTransferDescriptor = requires(T t) {
|
||||
{ t.a_scalar_per_vector } -> std::convertible_to<size_t>;
|
||||
{ t.b_scalar_per_vector } -> std::convertible_to<size_t>;
|
||||
{ t.c_scalar_per_vector } -> std::convertible_to<size_t>;
|
||||
{ t.a_scalar_per_vector } -> SizeType;
|
||||
{ t.b_scalar_per_vector } -> SizeType;
|
||||
{ t.c_scalar_per_vector } -> SizeType;
|
||||
};
|
||||
|
||||
// Concept to check if struct specifies block GEMM (CK Tile).
|
||||
@@ -159,30 +174,51 @@ concept SpecifiesTileThreadBlock = requires {
|
||||
|
||||
// Concept to check if a struct specifies gridwise XDL GEMM info.
|
||||
template <typename T>
|
||||
concept SpecifiesGridwiseXdlGemm = requires {
|
||||
{ T::gridwise_gemm } -> GridwiseXdlGemmDescriptor;
|
||||
concept GridwiseFwdXdlGemmDescriptor = requires(T t) {
|
||||
{ t.ak1 } -> SizeType;
|
||||
{ t.bk1 } -> SizeType;
|
||||
{ t.xdl_params } -> GridwiseXdlGemmDescriptor;
|
||||
};
|
||||
|
||||
// Concept to check if a struct specifies gridwise XDL GEMM info.
|
||||
template <typename T>
|
||||
concept GridwiseBwdXdlGemmDescriptor = requires(T t) {
|
||||
{ t.k1 } -> SizeType;
|
||||
{ t.xdl_params } -> GridwiseXdlGemmDescriptor;
|
||||
};
|
||||
|
||||
// Concept to check if a struct specifies gridwise XDL GEMM info.
|
||||
template <typename T>
|
||||
concept SpecifiesGridwiseFwdXdlGemm = requires(T t) {
|
||||
{ t.gridwise_gemm } -> GridwiseFwdXdlGemmDescriptor;
|
||||
};
|
||||
|
||||
// Concept to check if a struct specifies gridwise XDL GEMM info.
|
||||
template <typename T>
|
||||
concept SpecifiesGridwiseBwdXdlGemm = requires(T t) {
|
||||
{ t.gridwise_gemm } -> GridwiseBwdXdlGemmDescriptor;
|
||||
};
|
||||
|
||||
// Concept to check if a struct specifies gridwise WMMA GEMM info.
|
||||
template <typename T>
|
||||
concept SpecifiesGridwiseWmmaGemm = requires {
|
||||
{ T::gridwise_gemm } -> GridwiseWmmaGemmDescriptor;
|
||||
concept SpecifiesGridwiseWmmaGemm = requires(T t) {
|
||||
{ t.gridwise_gemm } -> GridwiseWmmaGemmDescriptor;
|
||||
};
|
||||
|
||||
// Concept to check if a struct specifies convolution input and output block transfer info.
|
||||
template <typename T>
|
||||
template <typename T, size_t ThreadClusterRank = 3>
|
||||
concept SpecifiesBlockTransfer = requires(T t) {
|
||||
{ T::transfer.a.block_transfer } -> BlockTransferDescriptor;
|
||||
{ T::transfer.b.block_transfer } -> BlockTransferDescriptor;
|
||||
{ T::transfer.a.block_transfer } -> BlockTransferDescriptor<ThreadClusterRank>;
|
||||
{ T::transfer.b.block_transfer } -> BlockTransferDescriptor<ThreadClusterRank>;
|
||||
{ T::transfer.c.thread_cluster_dims } -> ThreadClusterDescriptor;
|
||||
};
|
||||
|
||||
// Concept to check if a struct specifies convolution scalar per vector infor for A, B and C.
|
||||
template <typename T>
|
||||
concept SpecifiesTileTransfer = requires(T t) {
|
||||
{ T::transfer.a_scalar_per_vector } -> std::convertible_to<size_t>;
|
||||
{ T::transfer.b_scalar_per_vector } -> std::convertible_to<size_t>;
|
||||
{ T::transfer.c_scalar_per_vector } -> std::convertible_to<size_t>;
|
||||
{ T::transfer.a_scalar_per_vector } -> SizeType;
|
||||
{ T::transfer.b_scalar_per_vector } -> SizeType;
|
||||
{ T::transfer.c_scalar_per_vector } -> SizeType;
|
||||
};
|
||||
|
||||
// Concept to check if a struct specifies LDS transfer info for tensors A, B, and C.
|
||||
@@ -210,8 +246,12 @@ concept SpecifiesSourceAccessOrder = requires(T t) {
|
||||
// Concept to check if struct specifies block GEMM.
|
||||
template <typename T>
|
||||
concept SpecifiesBlockGemm = requires {
|
||||
{ T::block_gemm.pipeline_version } -> std::convertible_to<PipelineVersion>;
|
||||
{ T::block_gemm.scheduler } -> std::convertible_to<PipelineScheduler>;
|
||||
{ T::block_gemm_pipeline } -> BlockGemmPipelineDescriptor;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
concept SpecifiesGridwiseGemmPipeline = requires {
|
||||
{ T::pipeline_version } -> std::convertible_to<PipelineVersion>;
|
||||
};
|
||||
|
||||
// Concept to check if struct specifies block GEMM (CK Tile).
|
||||
@@ -244,7 +284,12 @@ concept SpecifiesTileConvSpecialization = requires {
|
||||
|
||||
template <typename T>
|
||||
concept SpecifiesFwdConvSpecialization = requires {
|
||||
{ T::fwd_specialization } -> std::convertible_to<ConvFwdSpecialization>;
|
||||
{ T::fwd_specialization } -> std::convertible_to<ConvSpecialization>;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
concept SpecifiesBwdWeightConvSpecialization = requires {
|
||||
{ T::bwd_weight_specialization } -> std::convertible_to<ConvSpecialization>;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
@@ -254,12 +299,12 @@ concept SpecifiesGemmSpecialization = requires {
|
||||
|
||||
template <typename T>
|
||||
concept SpecifiesNumPrefetchStages = requires {
|
||||
{ T::num_gemm_k_prefetch_stages } -> std::convertible_to<size_t>;
|
||||
{ T::num_gemm_k_prefetch_stages } -> SizeType;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
concept SpecifiesNumGroupsToMerge = requires {
|
||||
{ T::num_groups_to_merge } -> std::convertible_to<size_t>;
|
||||
{ T::num_conv_groups_to_merge } -> SizeType;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
@@ -267,12 +312,59 @@ concept SpecifiesLoopScheduler = requires {
|
||||
{ T::loop_scheduler } -> std::convertible_to<PipelineScheduler>;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
concept SpecifiesGenericInstance = !requires {
|
||||
{ T::specialization };
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
concept SpecifiesTransposeTransfer = requires {
|
||||
{ T::max_transpose_transfer_src_scalar_per_vector } -> SizeType;
|
||||
{ T::max_transpose_transfer_dst_scalar_per_vector } -> SizeType;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
concept HasTransposeTransfer = requires {
|
||||
{ T::max_transpose_transfer_src_scalar_per_vector };
|
||||
{ T::max_transpose_transfer_dst_scalar_per_vector };
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
concept TransposeTransferWellDefinedIfProvided =
|
||||
!HasTransposeTransfer<T> || SpecifiesTransposeTransfer<T>;
|
||||
|
||||
template <typename T>
|
||||
concept SpecifiesGemmBatchOptions = requires {
|
||||
{ T::num_conv_groups_to_merge } -> SizeType;
|
||||
};
|
||||
|
||||
/******************************************** */
|
||||
/* Algorithm specialization concepts */
|
||||
/******************************************** */
|
||||
template <typename T>
|
||||
concept SpecifiesLargeTensorSupport = requires {
|
||||
{ T::specialization } -> std::convertible_to<ConvAlgorithmSpecialization>;
|
||||
requires T::specialization == ConvAlgorithmSpecialization::LARGE_TENSOR;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
concept SpecifiesReferenceAlgorithm = requires {
|
||||
{ T::specialization } -> std::convertible_to<ConvAlgorithmSpecialization>;
|
||||
requires T::specialization == ConvAlgorithmSpecialization::REFERENCE;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
concept SpecifiesTwoStageSupport = requires {
|
||||
{ T::specialization } -> std::convertible_to<ConvAlgorithmSpecialization>;
|
||||
requires T::specialization == ConvAlgorithmSpecialization::TWO_STAGE;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
concept SpecifiesMultipleDSupport = requires {
|
||||
{ T::specialization } -> std::convertible_to<ConvAlgorithmSpecialization>;
|
||||
requires T::specialization == ConvAlgorithmSpecialization::MULTIPLE_D;
|
||||
};
|
||||
|
||||
/******************************************** */
|
||||
/* DL-specific descriptors and requirements */
|
||||
/******************************************** */
|
||||
@@ -280,11 +372,11 @@ concept SpecifiesLargeTensorSupport = requires {
|
||||
// Concept for DL thread configuration
|
||||
template <typename T>
|
||||
concept DlThreadConfigDescriptor = requires(T t) {
|
||||
{ t.k0_per_block } -> std::convertible_to<size_t>;
|
||||
{ t.k1 } -> std::convertible_to<size_t>;
|
||||
{ t.m1_per_thread } -> std::convertible_to<size_t>;
|
||||
{ t.n1_per_thread } -> std::convertible_to<size_t>;
|
||||
{ t.k_per_thread } -> std::convertible_to<size_t>;
|
||||
{ t.k0_per_block } -> SizeType;
|
||||
{ t.k1 } -> SizeType;
|
||||
{ t.m1_per_thread } -> SizeType;
|
||||
{ t.n1_per_thread } -> SizeType;
|
||||
{ t.k_per_thread } -> SizeType;
|
||||
};
|
||||
|
||||
// Concept for DL thread cluster
|
||||
@@ -295,23 +387,29 @@ concept DlThreadClusterDescriptor = requires(T t) {
|
||||
};
|
||||
|
||||
// Concept for DL block transfer
|
||||
template <typename T>
|
||||
template <typename T, size_t N>
|
||||
concept DlBlockTransferDescriptor = requires(T t) {
|
||||
{ t.thread_slice_lengths } -> std::convertible_to<std::array<size_t, 4>>;
|
||||
{ t.thread_cluster_lengths } -> std::convertible_to<std::array<size_t, 4>>;
|
||||
{ t.thread_cluster_arrange_order } -> std::convertible_to<std::array<size_t, 4>>;
|
||||
{ t.src_access_order } -> std::convertible_to<std::array<size_t, 4>>;
|
||||
{ t.src_vector_tensor_lengths } -> std::convertible_to<std::array<size_t, 4>>;
|
||||
{ t.src_vector_tensor_contiguous_dim_order } -> std::convertible_to<std::array<size_t, 4>>;
|
||||
{ t.dst_vector_tensor_lengths } -> std::convertible_to<std::array<size_t, 4>>;
|
||||
{ t.thread_slice_lengths } -> std::convertible_to<std::array<size_t, N>>;
|
||||
{ t.thread_cluster_lengths } -> std::convertible_to<std::array<size_t, N>>;
|
||||
{ t.thread_cluster_arrange_order } -> std::convertible_to<std::array<size_t, N>>;
|
||||
{ t.src_access_order } -> std::convertible_to<std::array<size_t, N>>;
|
||||
{ t.src_vector_tensor_lengths } -> std::convertible_to<std::array<size_t, N>>;
|
||||
{ t.src_vector_tensor_contiguous_dim_order } -> std::convertible_to<std::array<size_t, N>>;
|
||||
{ t.dst_vector_tensor_lengths } -> std::convertible_to<std::array<size_t, N>>;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
concept DlBlockTransferDescriptor4D = DlBlockTransferDescriptor<T, 4>;
|
||||
|
||||
template <typename T>
|
||||
concept DlBlockTransferDescriptor5D = DlBlockTransferDescriptor<T, 5>;
|
||||
|
||||
// Concept for DL epilogue
|
||||
template <typename T>
|
||||
concept DlEpilogueDescriptor = requires(T t) {
|
||||
{ t.src_dst_access_order } -> std::convertible_to<std::array<size_t, 6>>;
|
||||
{ t.src_dst_vector_dim } -> std::convertible_to<size_t>;
|
||||
{ t.dst_scalar_per_vector } -> std::convertible_to<size_t>;
|
||||
{ t.src_dst_vector_dim } -> SizeType;
|
||||
{ t.dst_scalar_per_vector } -> SizeType;
|
||||
};
|
||||
|
||||
// Concept to check if algorithm specifies DL thread config
|
||||
@@ -328,15 +426,21 @@ concept SpecifiesDlThreadCluster = requires {
|
||||
|
||||
// Concept to check if algorithm specifies DL block transfer
|
||||
template <typename T>
|
||||
concept SpecifiesDlBlockTransfer = requires {
|
||||
{ T::transfer.a.block_transfer } -> DlBlockTransferDescriptor;
|
||||
{ T::transfer.b.block_transfer } -> DlBlockTransferDescriptor;
|
||||
concept SpecifiesDlFwdBlockTransfer = requires {
|
||||
{ T::transfer.a } -> DlBlockTransferDescriptor4D;
|
||||
{ T::transfer.b } -> DlBlockTransferDescriptor4D;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
concept SpecifiesDlBwdBlockTransfer = requires {
|
||||
{ T::transfer.a } -> DlBlockTransferDescriptor5D;
|
||||
{ T::transfer.b } -> DlBlockTransferDescriptor5D;
|
||||
};
|
||||
|
||||
// Concept to check if algorithm specifies DL C thread transfer
|
||||
template <typename T>
|
||||
concept SpecifiesDlEpilogue = requires {
|
||||
{ T::transfer.c.epilogue } -> DlEpilogueDescriptor;
|
||||
{ T::transfer.c } -> DlEpilogueDescriptor;
|
||||
};
|
||||
|
||||
} // namespace ck_tile::builder
|
||||
|
||||
@@ -29,10 +29,20 @@ concept OutputVectorTransferLimits = requires {
|
||||
|
||||
// Limits for access order. Must be a permutation of {0, 1, 2}.
|
||||
template <auto Value>
|
||||
concept AccessOrderLimits = requires {
|
||||
concept AccessOrderLimits3D = requires {
|
||||
requires((Value[0] != Value[1]) && (Value[0] != Value[2]) && (Value[1] != Value[2]) &&
|
||||
(Value[0] >= 0 && Value[0] < 3) && (Value[1] >= 0 && Value[1] < 3) &&
|
||||
(Value[2] >= 0 && Value[2] < 3));
|
||||
(Value[2] >= 0 && Value[2] < 3) && (Value.Size() == 3));
|
||||
};
|
||||
|
||||
// Limits for access order. Must be a permutation of {0, 1, 2, 3}.
|
||||
template <auto Value>
|
||||
concept AccessOrderLimits4D = requires {
|
||||
requires((Value[0] != Value[1]) && (Value[0] != Value[2]) && (Value[0] != Value[3]) &&
|
||||
(Value[1] != Value[2]) && (Value[1] != Value[3]) && (Value[2] != Value[3]) &&
|
||||
(Value[0] >= 0 && Value[0] < 4) && (Value[1] >= 0 && Value[1] < 4) &&
|
||||
(Value[2] >= 0 && Value[2] < 4) && (Value[3] >= 0 && Value[3] < 4) &&
|
||||
(Value.Size() == 4));
|
||||
};
|
||||
|
||||
} // namespace ck_tile::builder
|
||||
|
||||
@@ -228,4 +228,13 @@ concept ValidConvWeightLayoutForSpatialDim =
|
||||
(SpatialDim == 1 && ConvWeightLayout1D<L>) || (SpatialDim == 2 && ConvWeightLayout2D<L>) ||
|
||||
(SpatialDim == 3 && ConvWeightLayout3D<L>);
|
||||
|
||||
// Constraint for 3D conv signature.
|
||||
template <auto Sig>
|
||||
concept Is3D = requires {
|
||||
requires Sig.spatial_dim == 3;
|
||||
requires ConvInputLayout3D<Sig.input.config.layout>;
|
||||
requires ConvOutputLayout3D<Sig.output.config.layout>;
|
||||
requires ConvWeightLayout3D<Sig.weight.config.layout>;
|
||||
};
|
||||
|
||||
} // namespace ck_tile::builder
|
||||
|
||||
@@ -0,0 +1,128 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/builder/conv_algorithm_concepts.hpp"
|
||||
|
||||
namespace ck_tile::builder::factory {
|
||||
|
||||
// Base algorithm concepts
|
||||
template <typename T, size_t ThreadClusterRank = 3>
|
||||
concept TileTransferParameters =
|
||||
SpecifiesBlockTransfer<T, ThreadClusterRank> && SpecifiesLdsTransfer<T> &&
|
||||
SpecifiesThreadClusterAccessOrder<T> && SpecifiesSourceAccessOrder<T>;
|
||||
|
||||
template <typename T>
|
||||
concept SpecifiesTileTransferParameters3D = TileTransferParameters<T, 3>;
|
||||
|
||||
template <typename T>
|
||||
concept SpecifiesTileTransferParameters4D = TileTransferParameters<T, 4>;
|
||||
|
||||
template <typename T>
|
||||
concept FwdXdlAlgorithmBase =
|
||||
ConvAlgorithmDescriptor<T> && SpecifiesThreadBlock<T> && SpecifiesTileTransferParameters3D<T> &&
|
||||
SpecifiesGridwiseFwdXdlGemm<T> && SpecifiesFwdConvSpecialization<T> &&
|
||||
SpecifiesGemmSpecialization<T> && SpecifiesNumPrefetchStages<T> &&
|
||||
SpecifiesNumGroupsToMerge<T> && SpecifiesLoopScheduler<T>;
|
||||
|
||||
template <typename T>
|
||||
concept BwdXdlAlgorithmBase =
|
||||
ConvAlgorithmDescriptor<T> && SpecifiesThreadBlock<T> && SpecifiesTileTransferParameters4D<T> &&
|
||||
SpecifiesGridwiseBwdXdlGemm<T> && SpecifiesBwdWeightConvSpecialization<T>;
|
||||
|
||||
template <typename T>
|
||||
concept BwdXdlV3AlgorithmBase =
|
||||
ConvAlgorithmDescriptor<T> && SpecifiesThreadBlock<T> && SpecifiesTileTransferParameters3D<T> &&
|
||||
SpecifiesGridwiseBwdXdlGemm<T> && SpecifiesBwdWeightConvSpecialization<T> &&
|
||||
SpecifiesBlockGemm<T>;
|
||||
|
||||
template <typename T>
|
||||
concept BwdWmmaAlgorithmBase =
|
||||
ConvAlgorithmDescriptor<T> && SpecifiesThreadBlock<T> && SpecifiesTileTransferParameters3D<T> &&
|
||||
SpecifiesGridwiseWmmaGemm<T> && SpecifiesBwdWeightConvSpecialization<T>;
|
||||
|
||||
template <typename T>
|
||||
concept BwdWmmaV3AlgorithmBase =
|
||||
ConvAlgorithmDescriptor<T> && SpecifiesThreadBlock<T> && SpecifiesTileTransferParameters3D<T> &&
|
||||
SpecifiesGridwiseWmmaGemm<T> && SpecifiesBwdWeightConvSpecialization<T> &&
|
||||
SpecifiesBlockGemm<T>;
|
||||
|
||||
// Reference algorithm concept
|
||||
template <typename T>
|
||||
concept ReferenceAlgorithm = ConvAlgorithmDescriptor<T> && SpecifiesReferenceAlgorithm<T>;
|
||||
|
||||
// Tile-based algorithm concept
|
||||
template <typename T>
|
||||
concept TileAlgorithm = ConvAlgorithmDescriptor<T> && SpecifiesTileThreadBlock<T> &&
|
||||
SpecifiesTileTransfer<T> && SpecifiesTileConvSpecialization<T> &&
|
||||
SpecifiesTileBlockGemm<T> && SpecifiesTileOptimizations<T>;
|
||||
|
||||
// FWD XDL algorithm concepts
|
||||
template <typename T>
|
||||
concept FwdXdlAlgorithm = FwdXdlAlgorithmBase<T> && SpecifiesGenericInstance<T>;
|
||||
|
||||
template <typename T>
|
||||
concept LargeTensorAlgorithm = FwdXdlAlgorithmBase<T> && SpecifiesLargeTensorSupport<T>;
|
||||
|
||||
template <typename T>
|
||||
concept FwdXdlV3Algorithm =
|
||||
ConvAlgorithmDescriptor<T> && SpecifiesThreadBlock<T> && SpecifiesTileTransferParameters3D<T> &&
|
||||
SpecifiesGridwiseFwdXdlGemm<T> && SpecifiesFwdConvSpecialization<T> &&
|
||||
SpecifiesGemmSpecialization<T> && SpecifiesBlockGemm<T>;
|
||||
|
||||
// FWD WMMA algorithm concepts
|
||||
template <typename T>
|
||||
concept FwdWmmaAlgorithm =
|
||||
ConvAlgorithmDescriptor<T> && SpecifiesThreadBlock<T> && SpecifiesTileTransferParameters3D<T> &&
|
||||
SpecifiesGridwiseWmmaGemm<T> && SpecifiesFwdConvSpecialization<T> &&
|
||||
SpecifiesGemmSpecialization<T> && SpecifiesNumPrefetchStages<T> && SpecifiesLoopScheduler<T> &&
|
||||
SpecifiesGridwiseGemmPipeline<T>;
|
||||
|
||||
// FWD DL algorithms
|
||||
template <typename T>
|
||||
concept FwdDlAlgorithm =
|
||||
ConvAlgorithmDescriptor<T> && SpecifiesThreadBlock<T> && SpecifiesFwdConvSpecialization<T> &&
|
||||
SpecifiesGemmSpecialization<T> && SpecifiesDlThreadConfig<T> && SpecifiesDlThreadCluster<T> &&
|
||||
SpecifiesDlFwdBlockTransfer<T> && SpecifiesDlEpilogue<T>;
|
||||
|
||||
// BWD weight XDL algorithm concepts
|
||||
template <typename T>
|
||||
concept BwdXdlAlgorithm =
|
||||
BwdXdlAlgorithmBase<T> && SpecifiesTransposeTransfer<T> && SpecifiesGenericInstance<T>;
|
||||
|
||||
template <typename T>
|
||||
concept BwdMultiDXdlAlgorithm = BwdXdlAlgorithmBase<T> && SpecifiesMultipleDSupport<T>;
|
||||
|
||||
template <typename T>
|
||||
concept BwdXdlV3Algorithm = BwdXdlV3AlgorithmBase<T> && SpecifiesGenericInstance<T>;
|
||||
|
||||
template <typename T>
|
||||
concept BwdTwoStageXdlAlgorithm = BwdXdlV3AlgorithmBase<T> && SpecifiesTransposeTransfer<T> &&
|
||||
SpecifiesGemmBatchOptions<T> && SpecifiesTwoStageSupport<T>;
|
||||
|
||||
// BWD weight WMMA algorithm concepts
|
||||
template <typename T>
|
||||
concept BwdWmmaAlgorithm =
|
||||
BwdWmmaAlgorithmBase<T> && SpecifiesNumPrefetchStages<T> && SpecifiesLoopScheduler<T> &&
|
||||
SpecifiesGridwiseGemmPipeline<T> && SpecifiesGenericInstance<T>;
|
||||
|
||||
template <typename T>
|
||||
concept BwdMultiDWmmaV3Algorithm = BwdWmmaV3AlgorithmBase<T> && SpecifiesMultipleDSupport<T>;
|
||||
|
||||
template <typename T>
|
||||
concept BwdWmmaV3Algorithm =
|
||||
BwdWmmaV3AlgorithmBase<T> && SpecifiesTransposeTransfer<T> && SpecifiesGenericInstance<T>;
|
||||
|
||||
template <typename T>
|
||||
concept BwdTwoStageWmmaV3Algorithm = BwdWmmaV3AlgorithmBase<T> && SpecifiesTransposeTransfer<T> &&
|
||||
SpecifiesGemmBatchOptions<T> && SpecifiesTwoStageSupport<T>;
|
||||
|
||||
// BWD weigth DL algorithms
|
||||
template <typename T>
|
||||
concept BwdDlAlgorithm =
|
||||
ConvAlgorithmDescriptor<T> && SpecifiesThreadBlock<T> &&
|
||||
SpecifiesBwdWeightConvSpecialization<T> && SpecifiesDlThreadConfig<T> &&
|
||||
SpecifiesDlThreadCluster<T> && SpecifiesDlBwdBlockTransfer<T> && SpecifiesDlEpilogue<T>;
|
||||
|
||||
} // namespace ck_tile::builder::factory
|
||||
@@ -0,0 +1,131 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_dl.hpp"
|
||||
#include "ck_tile/builder/conv_signature_concepts.hpp"
|
||||
#include "ck_tile/builder/conv_algorithm_concepts.hpp"
|
||||
#include "ck_tile/builder/builder_utils.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/ck/conv_tensor_layout.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/ck/conv_tensor_type.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/ck/conv_elementwise_op.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/ck/conv_thread_block.hpp"
|
||||
|
||||
namespace ck_tile::builder::factory {
|
||||
|
||||
// Factory for DeviceGroupedConvBwdWeight_Dl instance
|
||||
// of a grouped bwd weight convolution kernel.
|
||||
template <ConvSignatureDescriptor auto SIGNATURE,
|
||||
ConvAlgorithmDescriptor auto ALGORITHM,
|
||||
StringLiteral VERSION>
|
||||
requires ConvDirectionIsBackwardWeight<SIGNATURE>
|
||||
struct ConvBwdWeightDlFactory
|
||||
{
|
||||
static constexpr size_t SPATIAL_DIM = SIGNATURE.spatial_dim;
|
||||
using Layouts = internal::ConvTensorLayouts<SIGNATURE, SPATIAL_DIM>;
|
||||
using Types = internal::ConvTensorDataTypes<SIGNATURE>;
|
||||
using Ops = internal::ConvElementwiseOps<SIGNATURE>;
|
||||
using AlgorithmType = decltype(ALGORITHM);
|
||||
|
||||
static constexpr auto BWD_CONV_SPECIALIZATION =
|
||||
internal::SetBwdWeightConvSpecialization<ALGORITHM>();
|
||||
|
||||
static constexpr auto BLOCK = internal::SetThreadBlockInfo<ALGORITHM>();
|
||||
|
||||
// DL-specific parameters from algorithm descriptor
|
||||
static constexpr auto DL_THREAD_CFG = ALGORITHM.thread_config;
|
||||
static constexpr ck::index_t K0PerBlock = DL_THREAD_CFG.k0_per_block;
|
||||
static constexpr ck::index_t K1 = DL_THREAD_CFG.k1;
|
||||
static constexpr ck::index_t M1PerThread = DL_THREAD_CFG.m1_per_thread;
|
||||
static constexpr ck::index_t N1PerThread = DL_THREAD_CFG.n1_per_thread;
|
||||
static constexpr ck::index_t KPerThread = DL_THREAD_CFG.k_per_thread;
|
||||
|
||||
// Thread cluster from descriptor
|
||||
static constexpr auto DL_CLUSTER = ALGORITHM.thread_cluster;
|
||||
using M1N1ThreadClusterM1Xs = to_sequence_v<DL_CLUSTER.m1_xs>;
|
||||
using M1N1ThreadClusterN1Xs = to_sequence_v<DL_CLUSTER.n1_xs>;
|
||||
|
||||
// A Block Transfer from descriptor - K0_M0_M1_K1 tensor format
|
||||
static constexpr auto DL_A_TRANSFER = ALGORITHM.transfer.a;
|
||||
using ABlockTransferThreadSliceLengths_K0_M0_M1_K1 =
|
||||
to_sequence_v<DL_A_TRANSFER.thread_slice_lengths>;
|
||||
using ABlockTransferThreadClusterLengths_K0_M0_M1_K1 =
|
||||
to_sequence_v<DL_A_TRANSFER.thread_cluster_lengths>;
|
||||
using ABlockTransferThreadClusterArrangeOrder =
|
||||
to_sequence_v<DL_A_TRANSFER.thread_cluster_arrange_order>;
|
||||
using ABlockTransferSrcAccessOrder = to_sequence_v<DL_A_TRANSFER.src_access_order>;
|
||||
using ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1 =
|
||||
to_sequence_v<DL_A_TRANSFER.src_vector_tensor_lengths>;
|
||||
using ABlockTransferSrcVectorTensorContiguousDimOrder =
|
||||
to_sequence_v<DL_A_TRANSFER.src_vector_tensor_contiguous_dim_order>;
|
||||
using ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1 =
|
||||
to_sequence_v<DL_A_TRANSFER.dst_vector_tensor_lengths>;
|
||||
|
||||
// B Block Transfer from descriptor - K0_N0_N1_K1 tensor format
|
||||
static constexpr auto DL_B_TRANSFER = ALGORITHM.transfer.b;
|
||||
using BBlockTransferThreadSliceLengths_K0_N0_N1_K1 =
|
||||
to_sequence_v<DL_B_TRANSFER.thread_slice_lengths>;
|
||||
using BBlockTransferThreadClusterLengths_K0_N0_N1_K1 =
|
||||
to_sequence_v<DL_B_TRANSFER.thread_cluster_lengths>;
|
||||
using BBlockTransferThreadClusterArrangeOrder =
|
||||
to_sequence_v<DL_B_TRANSFER.thread_cluster_arrange_order>;
|
||||
using BBlockTransferSrcAccessOrder = to_sequence_v<DL_B_TRANSFER.src_access_order>;
|
||||
using BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1 =
|
||||
to_sequence_v<DL_B_TRANSFER.src_vector_tensor_lengths>;
|
||||
using BBlockTransferSrcVectorTensorContiguousDimOrder =
|
||||
to_sequence_v<DL_B_TRANSFER.src_vector_tensor_contiguous_dim_order>;
|
||||
using BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1 =
|
||||
to_sequence_v<DL_B_TRANSFER.dst_vector_tensor_lengths>;
|
||||
|
||||
// C Thread Transfer from descriptor
|
||||
static constexpr auto DL_C_TRANSFER = ALGORITHM.transfer.c;
|
||||
using CThreadTransferSrcDstAccessOrder = to_sequence_v<DL_C_TRANSFER.src_dst_access_order>;
|
||||
static constexpr ck::index_t CThreadTransferSrcDstVectorDim = DL_C_TRANSFER.src_dst_vector_dim;
|
||||
static constexpr ck::index_t CThreadTransferDstScalarPerVector =
|
||||
DL_C_TRANSFER.dst_scalar_per_vector;
|
||||
|
||||
// The DL forward convolution kernel class instance
|
||||
using Instance = ck::tensor_operation::device::DeviceGroupedConvBwdWeight_Dl<
|
||||
SPATIAL_DIM,
|
||||
typename Layouts::InLayout,
|
||||
typename Layouts::WeiLayout,
|
||||
typename Layouts::OutLayout,
|
||||
typename Types::InDataType,
|
||||
typename Types::WeiDataType,
|
||||
typename Types::OutDataType,
|
||||
typename Types::AccDataType,
|
||||
typename Ops::InElementwiseOp,
|
||||
typename Ops::WeiElementwiseOp,
|
||||
typename Ops::OutElementwiseOp,
|
||||
BWD_CONV_SPECIALIZATION,
|
||||
BLOCK.block_size,
|
||||
BLOCK.per_block.m,
|
||||
BLOCK.per_block.n,
|
||||
K0PerBlock,
|
||||
K1,
|
||||
M1PerThread,
|
||||
N1PerThread,
|
||||
KPerThread,
|
||||
M1N1ThreadClusterM1Xs,
|
||||
M1N1ThreadClusterN1Xs,
|
||||
ABlockTransferThreadSliceLengths_K0_M0_M1_K1,
|
||||
ABlockTransferThreadClusterLengths_K0_M0_M1_K1,
|
||||
ABlockTransferThreadClusterArrangeOrder,
|
||||
ABlockTransferSrcAccessOrder,
|
||||
ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1,
|
||||
ABlockTransferSrcVectorTensorContiguousDimOrder,
|
||||
ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1,
|
||||
BBlockTransferThreadSliceLengths_K0_N0_N1_K1,
|
||||
BBlockTransferThreadClusterLengths_K0_N0_N1_K1,
|
||||
BBlockTransferThreadClusterArrangeOrder,
|
||||
BBlockTransferSrcAccessOrder,
|
||||
BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1,
|
||||
BBlockTransferSrcVectorTensorContiguousDimOrder,
|
||||
BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1,
|
||||
CThreadTransferSrcDstAccessOrder,
|
||||
CThreadTransferSrcDstVectorDim,
|
||||
CThreadTransferDstScalarPerVector>;
|
||||
};
|
||||
|
||||
} // namespace ck_tile::builder::factory
|
||||
@@ -0,0 +1,110 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_multiple_d_wmma_cshuffle_v3.hpp"
|
||||
#include "ck_tile/builder/conv_signature_concepts.hpp"
|
||||
#include "ck_tile/builder/conv_algorithm_concepts.hpp"
|
||||
#include "ck_tile/builder/conv_algorithm_limits.hpp"
|
||||
#include "ck_tile/builder/builder_utils.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/ck/conv_tensor_layout.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/ck/conv_tensor_type.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/ck/conv_elementwise_op.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/ck/conv_tuning_params.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/ck/conv_block_transfer.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/ck/conv_thread_block.hpp"
|
||||
|
||||
namespace ck_tile::builder::factory {
|
||||
|
||||
// Factory for DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3 instance
|
||||
// of a grouped bwd weight convolution kernel.
|
||||
template <ConvSignatureDescriptor auto SIGNATURE,
|
||||
ConvAlgorithmDescriptor auto ALGORITHM,
|
||||
StringLiteral VERSION>
|
||||
requires ConvDirectionIsBackwardWeight<SIGNATURE> && Is3D<SIGNATURE>
|
||||
struct ConvBwdWeightMultiDWmmaV3Factory
|
||||
{
|
||||
static constexpr size_t SPATIAL_DIM = SIGNATURE.spatial_dim;
|
||||
using Layouts = internal::ConvTensorLayouts<SIGNATURE, SPATIAL_DIM>;
|
||||
using Types = internal::ConvTensorDataTypes<SIGNATURE>;
|
||||
using Ops = internal::ConvElementwiseOps<SIGNATURE>;
|
||||
using AlgorithmType = decltype(ALGORITHM);
|
||||
|
||||
static constexpr auto BWD_CONV_SPECIALIZATION =
|
||||
internal::SetBwdWeightConvSpecialization<ALGORITHM>();
|
||||
|
||||
static constexpr auto BLOCK = internal::SetThreadBlockInfo<ALGORITHM>();
|
||||
static constexpr auto GRIDWISE_GEMM = ALGORITHM.gridwise_gemm;
|
||||
static constexpr auto A_BLOCK_TRANSFER =
|
||||
internal::SetBwdConvBlockTransfer<ALGORITHM.transfer.a>();
|
||||
static constexpr auto B_BLOCK_TRANSFER =
|
||||
internal::SetBwdConvBlockTransfer<ALGORITHM.transfer.b>();
|
||||
static constexpr auto C_BLOCK_TRANSFER = internal::SetCBlockTransfer<SIGNATURE, ALGORITHM>();
|
||||
static constexpr auto BLOCK_GEMM = internal::SetBlockGemm<ALGORITHM>();
|
||||
|
||||
// Check limits for the algorithm parameters.
|
||||
// TODO: Add more limits checks as needed.
|
||||
static_assert(InputVectorTransferLimits<A_BLOCK_TRANSFER>, "Invalid A block transfer config");
|
||||
static_assert(InputVectorTransferLimits<B_BLOCK_TRANSFER>, "Invalid B block transfer config");
|
||||
static_assert(OutputVectorTransferLimits<C_BLOCK_TRANSFER>, "Invalid C block transfer config");
|
||||
static_assert(AccessOrderLimits3D<A_BLOCK_TRANSFER.thread_cluster_order>,
|
||||
"Invalid A thread cluster access order");
|
||||
static_assert(AccessOrderLimits3D<B_BLOCK_TRANSFER.thread_cluster_order>,
|
||||
"Invalid B thread cluster access order");
|
||||
static_assert(AccessOrderLimits3D<A_BLOCK_TRANSFER.src_access_order>,
|
||||
"Invalid A source access order");
|
||||
static_assert(AccessOrderLimits3D<B_BLOCK_TRANSFER.src_access_order>,
|
||||
"Invalid B source access order");
|
||||
|
||||
// The forward convolution kernel class instance.
|
||||
using Instance =
|
||||
ck::tensor_operation::device::DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3<
|
||||
SPATIAL_DIM,
|
||||
typename Layouts::InLayout,
|
||||
typename Layouts::WeiLayout,
|
||||
typename Layouts::OutLayout,
|
||||
typename Layouts::DsLayout,
|
||||
typename Types::InDataType,
|
||||
typename Types::WeiDataType,
|
||||
typename Types::OutDataType,
|
||||
typename Types::AccDataType,
|
||||
typename Types::DsDataType,
|
||||
typename Ops::InElementwiseOp,
|
||||
typename Ops::WeiElementwiseOp,
|
||||
typename Ops::OutElementwiseOp,
|
||||
BWD_CONV_SPECIALIZATION,
|
||||
BLOCK.block_size,
|
||||
BLOCK.per_block.m,
|
||||
BLOCK.per_block.n,
|
||||
BLOCK.per_block.k,
|
||||
GRIDWISE_GEMM.k1,
|
||||
GRIDWISE_GEMM.m_per_wmma,
|
||||
GRIDWISE_GEMM.n_per_wmma,
|
||||
GRIDWISE_GEMM.m_wmma_per_wave,
|
||||
GRIDWISE_GEMM.n_wmma_per_wave,
|
||||
to_sequence_v<A_BLOCK_TRANSFER.thread_cluster_dims>,
|
||||
to_sequence_v<A_BLOCK_TRANSFER.thread_cluster_order>,
|
||||
to_sequence_v<A_BLOCK_TRANSFER.src_access_order>,
|
||||
A_BLOCK_TRANSFER.src_vector_dim,
|
||||
A_BLOCK_TRANSFER.src_scalar_per_vector,
|
||||
A_BLOCK_TRANSFER.lds_dst_scalar_per_vector,
|
||||
A_BLOCK_TRANSFER.lds_padding,
|
||||
to_sequence_v<B_BLOCK_TRANSFER.thread_cluster_dims>,
|
||||
to_sequence_v<B_BLOCK_TRANSFER.thread_cluster_order>,
|
||||
to_sequence_v<B_BLOCK_TRANSFER.src_access_order>,
|
||||
B_BLOCK_TRANSFER.src_vector_dim,
|
||||
B_BLOCK_TRANSFER.src_scalar_per_vector,
|
||||
B_BLOCK_TRANSFER.lds_dst_scalar_per_vector,
|
||||
B_BLOCK_TRANSFER.lds_padding,
|
||||
C_BLOCK_TRANSFER.m_xdl_per_wave_per_shuffle,
|
||||
C_BLOCK_TRANSFER.n_xdl_per_wave_per_shuffle,
|
||||
to_sequence_v<C_BLOCK_TRANSFER.thread_cluster_dims>,
|
||||
C_BLOCK_TRANSFER.scalar_per_vector,
|
||||
BLOCK_GEMM.scheduler,
|
||||
BLOCK_GEMM.pipeline_version,
|
||||
typename Types::OutComputeType,
|
||||
typename Types::InComputeType>;
|
||||
};
|
||||
|
||||
} // namespace ck_tile::builder::factory
|
||||
@@ -0,0 +1,103 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_multiple_d_xdl_cshuffle.hpp"
|
||||
#include "ck_tile/builder/conv_signature_concepts.hpp"
|
||||
#include "ck_tile/builder/conv_algorithm_concepts.hpp"
|
||||
#include "ck_tile/builder/conv_algorithm_limits.hpp"
|
||||
#include "ck_tile/builder/builder_utils.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/ck/conv_tensor_layout.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/ck/conv_tensor_type.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/ck/conv_elementwise_op.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/ck/conv_tuning_params.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/ck/conv_block_transfer.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/ck/conv_thread_block.hpp"
|
||||
|
||||
namespace ck_tile::builder::factory {
|
||||
|
||||
// Factory for DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle instance
|
||||
// of a grouped bwd weight convolution kernel.
|
||||
template <ConvSignatureDescriptor auto SIGNATURE,
|
||||
ConvAlgorithmDescriptor auto ALGORITHM,
|
||||
StringLiteral VERSION>
|
||||
requires ConvDirectionIsBackwardWeight<SIGNATURE>
|
||||
struct ConvBwdWeightMultiDXdlFactory
|
||||
{
|
||||
static constexpr size_t SPATIAL_DIM = SIGNATURE.spatial_dim;
|
||||
using Layouts = internal::ConvTensorLayouts<SIGNATURE, SPATIAL_DIM>;
|
||||
using Types = internal::ConvTensorDataTypes<SIGNATURE>;
|
||||
using Ops = internal::ConvElementwiseOps<SIGNATURE>;
|
||||
using AlgorithmType = decltype(ALGORITHM);
|
||||
|
||||
static constexpr auto BWD_CONV_SPECIALIZATION =
|
||||
internal::SetBwdWeightConvSpecialization<ALGORITHM>();
|
||||
|
||||
static constexpr auto BLOCK = internal::SetThreadBlockInfo<ALGORITHM>();
|
||||
static constexpr auto GRIDWISE_GEMM = ALGORITHM.gridwise_gemm;
|
||||
static constexpr auto XDL_PARAMS = GRIDWISE_GEMM.xdl_params;
|
||||
static constexpr auto A_BLOCK_TRANSFER =
|
||||
internal::SetBwdConvBlockTransfer<ALGORITHM.transfer.a>();
|
||||
static constexpr auto B_BLOCK_TRANSFER =
|
||||
internal::SetBwdConvBlockTransfer<ALGORITHM.transfer.b>();
|
||||
static constexpr auto C_BLOCK_TRANSFER = internal::SetCBlockTransfer<SIGNATURE, ALGORITHM>();
|
||||
|
||||
// Check limits for the algorithm parameters.
|
||||
// TODO: Add more limits checks as needed.
|
||||
static_assert(InputVectorTransferLimits<A_BLOCK_TRANSFER>);
|
||||
static_assert(InputVectorTransferLimits<B_BLOCK_TRANSFER>);
|
||||
static_assert(OutputVectorTransferLimits<C_BLOCK_TRANSFER>);
|
||||
static_assert(AccessOrderLimits4D<A_BLOCK_TRANSFER.thread_cluster_order>);
|
||||
static_assert(AccessOrderLimits4D<B_BLOCK_TRANSFER.thread_cluster_order>);
|
||||
static_assert(AccessOrderLimits4D<A_BLOCK_TRANSFER.src_access_order>);
|
||||
static_assert(AccessOrderLimits4D<B_BLOCK_TRANSFER.src_access_order>);
|
||||
|
||||
// The forward convolution kernel class instance.
|
||||
using Instance = ck::tensor_operation::device::DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle<
|
||||
SPATIAL_DIM,
|
||||
typename Layouts::InLayout,
|
||||
typename Layouts::WeiLayout,
|
||||
typename Layouts::OutLayout,
|
||||
typename Layouts::DsLayout,
|
||||
typename Types::InDataType,
|
||||
typename Types::WeiDataType,
|
||||
typename Types::OutDataType,
|
||||
typename Types::AccDataType,
|
||||
typename Types::DsDataType,
|
||||
typename Ops::InElementwiseOp,
|
||||
typename Ops::WeiElementwiseOp,
|
||||
typename Ops::OutElementwiseOp,
|
||||
BWD_CONV_SPECIALIZATION,
|
||||
BLOCK.block_size,
|
||||
BLOCK.per_block.m,
|
||||
BLOCK.per_block.n,
|
||||
BLOCK.per_block.k,
|
||||
GRIDWISE_GEMM.k1,
|
||||
XDL_PARAMS.m_per_xdl,
|
||||
XDL_PARAMS.n_per_xdl,
|
||||
XDL_PARAMS.m_xdl_per_wave,
|
||||
XDL_PARAMS.n_xdl_per_wave,
|
||||
to_sequence_v<A_BLOCK_TRANSFER.thread_cluster_dims>,
|
||||
to_sequence_v<A_BLOCK_TRANSFER.thread_cluster_order>,
|
||||
to_sequence_v<A_BLOCK_TRANSFER.src_access_order>,
|
||||
A_BLOCK_TRANSFER.src_vector_dim,
|
||||
A_BLOCK_TRANSFER.src_scalar_per_vector,
|
||||
A_BLOCK_TRANSFER.lds_dst_scalar_per_vector,
|
||||
A_BLOCK_TRANSFER.lds_padding,
|
||||
to_sequence_v<B_BLOCK_TRANSFER.thread_cluster_dims>,
|
||||
to_sequence_v<B_BLOCK_TRANSFER.thread_cluster_order>,
|
||||
to_sequence_v<B_BLOCK_TRANSFER.src_access_order>,
|
||||
B_BLOCK_TRANSFER.src_vector_dim,
|
||||
B_BLOCK_TRANSFER.src_scalar_per_vector,
|
||||
B_BLOCK_TRANSFER.lds_dst_scalar_per_vector,
|
||||
B_BLOCK_TRANSFER.lds_padding,
|
||||
C_BLOCK_TRANSFER.m_xdl_per_wave_per_shuffle,
|
||||
C_BLOCK_TRANSFER.n_xdl_per_wave_per_shuffle,
|
||||
to_sequence_v<C_BLOCK_TRANSFER.thread_cluster_dims>,
|
||||
C_BLOCK_TRANSFER.scalar_per_vector,
|
||||
typename Types::OutComputeType,
|
||||
typename Types::InComputeType>;
|
||||
};
|
||||
|
||||
} // namespace ck_tile::builder::factory
|
||||
@@ -0,0 +1,111 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_wmma_cshuffle_v3.hpp"
|
||||
#include "ck_tile/builder/conv_signature_concepts.hpp"
|
||||
#include "ck_tile/builder/conv_algorithm_concepts.hpp"
|
||||
#include "ck_tile/builder/conv_algorithm_limits.hpp"
|
||||
#include "ck_tile/builder/builder_utils.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/ck/conv_tensor_layout.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/ck/conv_tensor_type.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/ck/conv_elementwise_op.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/ck/conv_tuning_params.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/ck/conv_block_transfer.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/ck/conv_thread_block.hpp"
|
||||
|
||||
namespace ck_tile::builder::factory {
|
||||
|
||||
// Factory for DeviceGroupedConvBwdWeightTwoStage_Wmma_CShuffle_V3 instance
|
||||
// of a grouped bwd weight convolution kernel.
|
||||
template <ConvSignatureDescriptor auto SIGNATURE,
|
||||
ConvAlgorithmDescriptor auto ALGORITHM,
|
||||
StringLiteral VERSION>
|
||||
requires ConvDirectionIsBackwardWeight<SIGNATURE>
|
||||
struct ConvBwdWeightTwoStageWmmaV3Factory
|
||||
{
|
||||
static constexpr size_t SPATIAL_DIM = SIGNATURE.spatial_dim;
|
||||
using Layouts = internal::ConvTensorLayouts<SIGNATURE, SPATIAL_DIM>;
|
||||
using Types = internal::ConvTensorDataTypes<SIGNATURE>;
|
||||
using Ops = internal::ConvElementwiseOps<SIGNATURE>;
|
||||
using AlgorithmType = decltype(ALGORITHM);
|
||||
|
||||
static constexpr auto BWD_CONV_SPECIALIZATION =
|
||||
internal::SetBwdWeightConvSpecialization<ALGORITHM>();
|
||||
|
||||
static constexpr auto BLOCK = internal::SetThreadBlockInfo<ALGORITHM>();
|
||||
static constexpr auto GRIDWISE_GEMM = ALGORITHM.gridwise_gemm;
|
||||
static constexpr auto A_BLOCK_TRANSFER =
|
||||
internal::SetBwdConvBlockTransfer<ALGORITHM.transfer.a>();
|
||||
static constexpr auto B_BLOCK_TRANSFER =
|
||||
internal::SetBwdConvBlockTransfer<ALGORITHM.transfer.b>();
|
||||
static constexpr auto C_BLOCK_TRANSFER = internal::SetCBlockTransfer<SIGNATURE, ALGORITHM>();
|
||||
static constexpr auto BLOCK_GEMM = internal::SetBlockGemm<ALGORITHM>();
|
||||
|
||||
// Check limits for the algorithm parameters.
|
||||
// TODO: Add more limits checks as needed.
|
||||
static_assert(InputVectorTransferLimits<A_BLOCK_TRANSFER>, "Invalid A block transfer config");
|
||||
static_assert(InputVectorTransferLimits<B_BLOCK_TRANSFER>, "Invalid B block transfer config");
|
||||
static_assert(OutputVectorTransferLimits<C_BLOCK_TRANSFER>, "Invalid C block transfer config");
|
||||
static_assert(AccessOrderLimits3D<A_BLOCK_TRANSFER.thread_cluster_order>,
|
||||
"Invalid A thread cluster access order");
|
||||
static_assert(AccessOrderLimits3D<B_BLOCK_TRANSFER.thread_cluster_order>,
|
||||
"Invalid B thread cluster access order");
|
||||
static_assert(AccessOrderLimits3D<A_BLOCK_TRANSFER.src_access_order>,
|
||||
"Invalid A source access order");
|
||||
static_assert(AccessOrderLimits3D<B_BLOCK_TRANSFER.src_access_order>,
|
||||
"Invalid B source access order");
|
||||
|
||||
// The forward convolution kernel class instance.
|
||||
using Instance =
|
||||
ck::tensor_operation::device::DeviceGroupedConvBwdWeightTwoStage_Wmma_CShuffleV3<
|
||||
SPATIAL_DIM,
|
||||
typename Layouts::InLayout,
|
||||
typename Layouts::WeiLayout,
|
||||
typename Layouts::OutLayout,
|
||||
typename Types::InDataType,
|
||||
typename Types::WeiDataType,
|
||||
typename Types::OutDataType,
|
||||
typename Types::AccDataType,
|
||||
typename Ops::InElementwiseOp,
|
||||
typename Ops::WeiElementwiseOp,
|
||||
typename Ops::OutElementwiseOp,
|
||||
BWD_CONV_SPECIALIZATION,
|
||||
BLOCK.block_size,
|
||||
BLOCK.per_block.m,
|
||||
BLOCK.per_block.n,
|
||||
BLOCK.per_block.k,
|
||||
GRIDWISE_GEMM.k1,
|
||||
GRIDWISE_GEMM.m_per_wmma,
|
||||
GRIDWISE_GEMM.n_per_wmma,
|
||||
GRIDWISE_GEMM.m_wmma_per_wave,
|
||||
GRIDWISE_GEMM.n_wmma_per_wave,
|
||||
to_sequence_v<A_BLOCK_TRANSFER.thread_cluster_dims>,
|
||||
to_sequence_v<A_BLOCK_TRANSFER.thread_cluster_order>,
|
||||
to_sequence_v<A_BLOCK_TRANSFER.src_access_order>,
|
||||
A_BLOCK_TRANSFER.src_vector_dim,
|
||||
A_BLOCK_TRANSFER.src_scalar_per_vector,
|
||||
A_BLOCK_TRANSFER.lds_dst_scalar_per_vector,
|
||||
A_BLOCK_TRANSFER.lds_padding,
|
||||
to_sequence_v<B_BLOCK_TRANSFER.thread_cluster_dims>,
|
||||
to_sequence_v<B_BLOCK_TRANSFER.thread_cluster_order>,
|
||||
to_sequence_v<B_BLOCK_TRANSFER.src_access_order>,
|
||||
B_BLOCK_TRANSFER.src_vector_dim,
|
||||
B_BLOCK_TRANSFER.src_scalar_per_vector,
|
||||
B_BLOCK_TRANSFER.lds_dst_scalar_per_vector,
|
||||
B_BLOCK_TRANSFER.lds_padding,
|
||||
C_BLOCK_TRANSFER.m_xdl_per_wave_per_shuffle,
|
||||
C_BLOCK_TRANSFER.n_xdl_per_wave_per_shuffle,
|
||||
to_sequence_v<C_BLOCK_TRANSFER.thread_cluster_dims>,
|
||||
C_BLOCK_TRANSFER.scalar_per_vector,
|
||||
BLOCK_GEMM.scheduler,
|
||||
BLOCK_GEMM.pipeline_version,
|
||||
ALGORITHM.num_conv_groups_to_merge,
|
||||
typename Types::OutComputeType,
|
||||
typename Types::InComputeType,
|
||||
ALGORITHM.max_transpose_transfer_src_scalar_per_vector,
|
||||
ALGORITHM.max_transpose_transfer_dst_scalar_per_vector>;
|
||||
};
|
||||
|
||||
} // namespace ck_tile::builder::factory
|
||||
@@ -0,0 +1,111 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp"
|
||||
#include "ck_tile/builder/conv_signature_concepts.hpp"
|
||||
#include "ck_tile/builder/conv_algorithm_concepts.hpp"
|
||||
#include "ck_tile/builder/conv_algorithm_limits.hpp"
|
||||
#include "ck_tile/builder/builder_utils.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/ck/conv_tensor_layout.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/ck/conv_tensor_type.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/ck/conv_elementwise_op.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/ck/conv_tuning_params.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/ck/conv_block_transfer.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/ck/conv_thread_block.hpp"
|
||||
|
||||
namespace ck_tile::builder::factory {
|
||||
|
||||
// Factory for DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle instance
|
||||
// of a grouped bwd weight convolution kernel.
|
||||
template <ConvSignatureDescriptor auto SIGNATURE,
|
||||
ConvAlgorithmDescriptor auto ALGORITHM,
|
||||
StringLiteral VERSION>
|
||||
requires ConvDirectionIsBackwardWeight<SIGNATURE>
|
||||
struct ConvBwdWeightTwoStageXdlFactory
|
||||
{
|
||||
static constexpr size_t SPATIAL_DIM = SIGNATURE.spatial_dim;
|
||||
using Layouts = internal::ConvTensorLayouts<SIGNATURE, SPATIAL_DIM>;
|
||||
using Types = internal::ConvTensorDataTypes<SIGNATURE>;
|
||||
using Ops = internal::ConvElementwiseOps<SIGNATURE>;
|
||||
using AlgorithmType = decltype(ALGORITHM);
|
||||
|
||||
static constexpr auto BWD_CONV_SPECIALIZATION =
|
||||
internal::SetBwdWeightConvSpecialization<ALGORITHM>();
|
||||
|
||||
static constexpr auto BLOCK = internal::SetThreadBlockInfo<ALGORITHM>();
|
||||
static constexpr auto GRIDWISE_GEMM = ALGORITHM.gridwise_gemm;
|
||||
static constexpr auto XDL_PARAMS = GRIDWISE_GEMM.xdl_params;
|
||||
static constexpr auto A_BLOCK_TRANSFER =
|
||||
internal::SetBwdConvBlockTransfer<ALGORITHM.transfer.a>();
|
||||
static constexpr auto B_BLOCK_TRANSFER =
|
||||
internal::SetBwdConvBlockTransfer<ALGORITHM.transfer.b>();
|
||||
static constexpr auto C_BLOCK_TRANSFER = internal::SetCBlockTransfer<SIGNATURE, ALGORITHM>();
|
||||
static constexpr auto BLOCK_GEMM = internal::SetBlockGemm<ALGORITHM>();
|
||||
|
||||
// Check limits for the algorithm parameters.
|
||||
// TODO: Add more limits checks as needed.
|
||||
static_assert(InputVectorTransferLimits<A_BLOCK_TRANSFER>, "Invalid A block transfer config");
|
||||
static_assert(InputVectorTransferLimits<B_BLOCK_TRANSFER>, "Invalid B block transfer config");
|
||||
static_assert(OutputVectorTransferLimits<C_BLOCK_TRANSFER>, "Invalid C block transfer config");
|
||||
static_assert(AccessOrderLimits3D<A_BLOCK_TRANSFER.thread_cluster_order>,
|
||||
"Invalid A thread cluster access order");
|
||||
static_assert(AccessOrderLimits3D<B_BLOCK_TRANSFER.thread_cluster_order>,
|
||||
"Invalid B thread cluster access order");
|
||||
static_assert(AccessOrderLimits3D<A_BLOCK_TRANSFER.src_access_order>,
|
||||
"Invalid A source access order");
|
||||
static_assert(AccessOrderLimits3D<B_BLOCK_TRANSFER.src_access_order>,
|
||||
"Invalid B source access order");
|
||||
|
||||
// The forward convolution kernel class instance.
|
||||
using Instance = ck::tensor_operation::device::DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle<
|
||||
SPATIAL_DIM,
|
||||
typename Layouts::InLayout,
|
||||
typename Layouts::WeiLayout,
|
||||
typename Layouts::OutLayout,
|
||||
typename Types::InDataType,
|
||||
typename Types::WeiDataType,
|
||||
typename Types::OutDataType,
|
||||
typename Types::AccDataType,
|
||||
typename Ops::InElementwiseOp,
|
||||
typename Ops::WeiElementwiseOp,
|
||||
typename Ops::OutElementwiseOp,
|
||||
BWD_CONV_SPECIALIZATION,
|
||||
BLOCK.block_size,
|
||||
BLOCK.per_block.m,
|
||||
BLOCK.per_block.n,
|
||||
BLOCK.per_block.k,
|
||||
GRIDWISE_GEMM.k1,
|
||||
XDL_PARAMS.m_per_xdl,
|
||||
XDL_PARAMS.n_per_xdl,
|
||||
XDL_PARAMS.m_xdl_per_wave,
|
||||
XDL_PARAMS.n_xdl_per_wave,
|
||||
to_sequence_v<A_BLOCK_TRANSFER.thread_cluster_dims>,
|
||||
to_sequence_v<A_BLOCK_TRANSFER.thread_cluster_order>,
|
||||
to_sequence_v<A_BLOCK_TRANSFER.src_access_order>,
|
||||
A_BLOCK_TRANSFER.src_vector_dim,
|
||||
A_BLOCK_TRANSFER.src_scalar_per_vector,
|
||||
A_BLOCK_TRANSFER.lds_dst_scalar_per_vector,
|
||||
A_BLOCK_TRANSFER.lds_padding,
|
||||
to_sequence_v<B_BLOCK_TRANSFER.thread_cluster_dims>,
|
||||
to_sequence_v<B_BLOCK_TRANSFER.thread_cluster_order>,
|
||||
to_sequence_v<B_BLOCK_TRANSFER.src_access_order>,
|
||||
B_BLOCK_TRANSFER.src_vector_dim,
|
||||
B_BLOCK_TRANSFER.src_scalar_per_vector,
|
||||
B_BLOCK_TRANSFER.lds_dst_scalar_per_vector,
|
||||
B_BLOCK_TRANSFER.lds_padding,
|
||||
C_BLOCK_TRANSFER.m_xdl_per_wave_per_shuffle,
|
||||
C_BLOCK_TRANSFER.n_xdl_per_wave_per_shuffle,
|
||||
to_sequence_v<C_BLOCK_TRANSFER.thread_cluster_dims>,
|
||||
C_BLOCK_TRANSFER.scalar_per_vector,
|
||||
BLOCK_GEMM.scheduler,
|
||||
BLOCK_GEMM.pipeline_version,
|
||||
ALGORITHM.num_conv_groups_to_merge,
|
||||
typename Types::OutComputeType,
|
||||
typename Types::InComputeType,
|
||||
ALGORITHM.max_transpose_transfer_src_scalar_per_vector,
|
||||
ALGORITHM.max_transpose_transfer_dst_scalar_per_vector>;
|
||||
};
|
||||
|
||||
} // namespace ck_tile::builder::factory
|
||||
@@ -0,0 +1,109 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_wmma_cshuffle.hpp"
|
||||
#include "ck_tile/builder/conv_signature_concepts.hpp"
|
||||
#include "ck_tile/builder/conv_algorithm_concepts.hpp"
|
||||
#include "ck_tile/builder/conv_algorithm_limits.hpp"
|
||||
#include "ck_tile/builder/builder_utils.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/ck/conv_tensor_layout.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/ck/conv_tensor_type.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/ck/conv_elementwise_op.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/ck/conv_tuning_params.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/ck/conv_block_transfer.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/ck/conv_thread_block.hpp"
|
||||
|
||||
namespace ck_tile::builder::factory {
|
||||
|
||||
// Factory for DeviceGroupedConvBwdWeight_Wmma_CShuffle instance
|
||||
// of a grouped bwd weight convolution kernel.
|
||||
template <ConvSignatureDescriptor auto SIGNATURE,
|
||||
ConvAlgorithmDescriptor auto ALGORITHM,
|
||||
StringLiteral VERSION>
|
||||
requires ConvDirectionIsBackwardWeight<SIGNATURE> && Is3D<SIGNATURE>
|
||||
struct ConvBwdWeightWmmaFactory
|
||||
{
|
||||
static constexpr size_t SPATIAL_DIM = SIGNATURE.spatial_dim;
|
||||
using Layouts = internal::ConvTensorLayouts<SIGNATURE, SPATIAL_DIM>;
|
||||
using Types = internal::ConvTensorDataTypes<SIGNATURE>;
|
||||
using Ops = internal::ConvElementwiseOps<SIGNATURE>;
|
||||
using AlgorithmType = decltype(ALGORITHM);
|
||||
|
||||
static constexpr auto BWD_CONV_SPECIALIZATION =
|
||||
internal::SetBwdWeightConvSpecialization<ALGORITHM>();
|
||||
|
||||
static constexpr auto BLOCK = internal::SetThreadBlockInfo<ALGORITHM>();
|
||||
static constexpr auto GRIDWISE_GEMM = ALGORITHM.gridwise_gemm;
|
||||
static constexpr auto GRIDWISE_GEMM_PIPELINE_VERSION =
|
||||
internal::SetGridwiseGemmPipelineVersion<ALGORITHM>();
|
||||
static constexpr auto LOOP_SCHEDULER = internal::SetLoopScheduler<ALGORITHM>();
|
||||
|
||||
static constexpr auto A_BLOCK_TRANSFER =
|
||||
internal::SetBwdConvBlockTransfer<ALGORITHM.transfer.a>();
|
||||
static constexpr auto B_BLOCK_TRANSFER =
|
||||
internal::SetBwdConvBlockTransfer<ALGORITHM.transfer.b>();
|
||||
static constexpr auto C_BLOCK_TRANSFER = internal::SetCBlockTransfer<SIGNATURE, ALGORITHM>();
|
||||
|
||||
// Check limits for the algorithm parameters.
|
||||
// TODO: Add more limits checks as needed.
|
||||
static_assert(InputVectorTransferLimits<A_BLOCK_TRANSFER>, "Invalid A block transfer config");
|
||||
static_assert(InputVectorTransferLimits<B_BLOCK_TRANSFER>, "Invalid B block transfer config");
|
||||
static_assert(OutputVectorTransferLimits<C_BLOCK_TRANSFER>, "Invalid C block transfer config");
|
||||
static_assert(AccessOrderLimits3D<A_BLOCK_TRANSFER.thread_cluster_order>,
|
||||
"Invalid A thread cluster access order");
|
||||
static_assert(AccessOrderLimits3D<B_BLOCK_TRANSFER.thread_cluster_order>,
|
||||
"Invalid B thread cluster access order");
|
||||
static_assert(AccessOrderLimits3D<A_BLOCK_TRANSFER.src_access_order>,
|
||||
"Invalid A source access order");
|
||||
static_assert(AccessOrderLimits3D<B_BLOCK_TRANSFER.src_access_order>,
|
||||
"Invalid B source access order");
|
||||
|
||||
// The forward convolution kernel class instance.
|
||||
using Instance = ck::tensor_operation::device::DeviceGroupedConvBwdWeight_Wmma_CShuffle<
|
||||
SPATIAL_DIM,
|
||||
typename Layouts::InLayout,
|
||||
typename Layouts::WeiLayout,
|
||||
typename Layouts::OutLayout,
|
||||
typename Types::InDataType,
|
||||
typename Types::WeiDataType,
|
||||
typename Types::OutDataType,
|
||||
typename Types::AccDataType,
|
||||
typename Ops::InElementwiseOp,
|
||||
typename Ops::WeiElementwiseOp,
|
||||
typename Ops::OutElementwiseOp,
|
||||
BWD_CONV_SPECIALIZATION,
|
||||
BLOCK.block_size,
|
||||
BLOCK.per_block.m,
|
||||
BLOCK.per_block.n,
|
||||
BLOCK.per_block.k,
|
||||
GRIDWISE_GEMM.k1,
|
||||
GRIDWISE_GEMM.m_per_wmma,
|
||||
GRIDWISE_GEMM.n_per_wmma,
|
||||
GRIDWISE_GEMM.m_wmma_per_wave,
|
||||
GRIDWISE_GEMM.n_wmma_per_wave,
|
||||
to_sequence_v<A_BLOCK_TRANSFER.thread_cluster_dims>,
|
||||
to_sequence_v<A_BLOCK_TRANSFER.thread_cluster_order>,
|
||||
to_sequence_v<A_BLOCK_TRANSFER.src_access_order>,
|
||||
A_BLOCK_TRANSFER.src_vector_dim,
|
||||
A_BLOCK_TRANSFER.src_scalar_per_vector,
|
||||
A_BLOCK_TRANSFER.lds_dst_scalar_per_vector,
|
||||
A_BLOCK_TRANSFER.lds_padding,
|
||||
to_sequence_v<B_BLOCK_TRANSFER.thread_cluster_dims>,
|
||||
to_sequence_v<B_BLOCK_TRANSFER.thread_cluster_order>,
|
||||
to_sequence_v<B_BLOCK_TRANSFER.src_access_order>,
|
||||
B_BLOCK_TRANSFER.src_vector_dim,
|
||||
B_BLOCK_TRANSFER.src_scalar_per_vector,
|
||||
B_BLOCK_TRANSFER.lds_dst_scalar_per_vector,
|
||||
B_BLOCK_TRANSFER.lds_padding,
|
||||
C_BLOCK_TRANSFER.m_xdl_per_wave_per_shuffle,
|
||||
C_BLOCK_TRANSFER.n_xdl_per_wave_per_shuffle,
|
||||
to_sequence_v<C_BLOCK_TRANSFER.thread_cluster_dims>,
|
||||
C_BLOCK_TRANSFER.scalar_per_vector,
|
||||
ALGORITHM.num_gemm_k_prefetch_stages,
|
||||
LOOP_SCHEDULER,
|
||||
GRIDWISE_GEMM_PIPELINE_VERSION>;
|
||||
};
|
||||
|
||||
} // namespace ck_tile::builder::factory
|
||||
@@ -0,0 +1,109 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_wmma_cshuffle_v3.hpp"
|
||||
#include "ck_tile/builder/conv_signature_concepts.hpp"
|
||||
#include "ck_tile/builder/conv_algorithm_concepts.hpp"
|
||||
#include "ck_tile/builder/conv_algorithm_limits.hpp"
|
||||
#include "ck_tile/builder/builder_utils.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/ck/conv_tensor_layout.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/ck/conv_tensor_type.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/ck/conv_elementwise_op.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/ck/conv_tuning_params.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/ck/conv_block_transfer.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/ck/conv_thread_block.hpp"
|
||||
|
||||
namespace ck_tile::builder::factory {
|
||||
|
||||
// Factory for DeviceGroupedConvBwdWeight_Wmma_CShuffle_V3 instance
|
||||
// of a grouped bwd weight convolution kernel.
|
||||
template <ConvSignatureDescriptor auto SIGNATURE,
|
||||
ConvAlgorithmDescriptor auto ALGORITHM,
|
||||
StringLiteral VERSION>
|
||||
requires ConvDirectionIsBackwardWeight<SIGNATURE>
|
||||
struct ConvBwdWeightWmmaV3Factory
|
||||
{
|
||||
static constexpr size_t SPATIAL_DIM = SIGNATURE.spatial_dim;
|
||||
using Layouts = internal::ConvTensorLayouts<SIGNATURE, SPATIAL_DIM>;
|
||||
using Types = internal::ConvTensorDataTypes<SIGNATURE>;
|
||||
using Ops = internal::ConvElementwiseOps<SIGNATURE>;
|
||||
using AlgorithmType = decltype(ALGORITHM);
|
||||
|
||||
static constexpr auto BWD_CONV_SPECIALIZATION =
|
||||
internal::SetBwdWeightConvSpecialization<ALGORITHM>();
|
||||
|
||||
static constexpr auto BLOCK = internal::SetThreadBlockInfo<ALGORITHM>();
|
||||
static constexpr auto GRIDWISE_GEMM = ALGORITHM.gridwise_gemm;
|
||||
static constexpr auto A_BLOCK_TRANSFER =
|
||||
internal::SetBwdConvBlockTransfer<ALGORITHM.transfer.a>();
|
||||
static constexpr auto B_BLOCK_TRANSFER =
|
||||
internal::SetBwdConvBlockTransfer<ALGORITHM.transfer.b>();
|
||||
static constexpr auto C_BLOCK_TRANSFER = internal::SetCBlockTransfer<SIGNATURE, ALGORITHM>();
|
||||
static constexpr auto BLOCK_GEMM = internal::SetBlockGemm<ALGORITHM>();
|
||||
|
||||
// Check limits for the algorithm parameters.
|
||||
// TODO: Add more limits checks as needed.
|
||||
static_assert(InputVectorTransferLimits<A_BLOCK_TRANSFER>, "Invalid A block transfer config");
|
||||
static_assert(InputVectorTransferLimits<B_BLOCK_TRANSFER>, "Invalid B block transfer config");
|
||||
static_assert(OutputVectorTransferLimits<C_BLOCK_TRANSFER>, "Invalid C block transfer config");
|
||||
static_assert(AccessOrderLimits3D<A_BLOCK_TRANSFER.thread_cluster_order>,
|
||||
"Invalid A thread cluster access order");
|
||||
static_assert(AccessOrderLimits3D<B_BLOCK_TRANSFER.thread_cluster_order>,
|
||||
"Invalid B thread cluster access order");
|
||||
static_assert(AccessOrderLimits3D<A_BLOCK_TRANSFER.src_access_order>,
|
||||
"Invalid A source access order");
|
||||
static_assert(AccessOrderLimits3D<B_BLOCK_TRANSFER.src_access_order>,
|
||||
"Invalid B source access order");
|
||||
|
||||
// The forward convolution kernel class instance.
|
||||
using Instance = ck::tensor_operation::device::DeviceGroupedConvBwdWeight_Wmma_CShuffleV3<
|
||||
SPATIAL_DIM,
|
||||
typename Layouts::InLayout,
|
||||
typename Layouts::WeiLayout,
|
||||
typename Layouts::OutLayout,
|
||||
typename Types::InDataType,
|
||||
typename Types::WeiDataType,
|
||||
typename Types::OutDataType,
|
||||
typename Types::AccDataType,
|
||||
typename Ops::InElementwiseOp,
|
||||
typename Ops::WeiElementwiseOp,
|
||||
typename Ops::OutElementwiseOp,
|
||||
BWD_CONV_SPECIALIZATION,
|
||||
BLOCK.block_size,
|
||||
BLOCK.per_block.m,
|
||||
BLOCK.per_block.n,
|
||||
BLOCK.per_block.k,
|
||||
GRIDWISE_GEMM.k1,
|
||||
GRIDWISE_GEMM.m_per_wmma,
|
||||
GRIDWISE_GEMM.n_per_wmma,
|
||||
GRIDWISE_GEMM.m_wmma_per_wave,
|
||||
GRIDWISE_GEMM.n_wmma_per_wave,
|
||||
to_sequence_v<A_BLOCK_TRANSFER.thread_cluster_dims>,
|
||||
to_sequence_v<A_BLOCK_TRANSFER.thread_cluster_order>,
|
||||
to_sequence_v<A_BLOCK_TRANSFER.src_access_order>,
|
||||
A_BLOCK_TRANSFER.src_vector_dim,
|
||||
A_BLOCK_TRANSFER.src_scalar_per_vector,
|
||||
A_BLOCK_TRANSFER.lds_dst_scalar_per_vector,
|
||||
A_BLOCK_TRANSFER.lds_padding,
|
||||
to_sequence_v<B_BLOCK_TRANSFER.thread_cluster_dims>,
|
||||
to_sequence_v<B_BLOCK_TRANSFER.thread_cluster_order>,
|
||||
to_sequence_v<B_BLOCK_TRANSFER.src_access_order>,
|
||||
B_BLOCK_TRANSFER.src_vector_dim,
|
||||
B_BLOCK_TRANSFER.src_scalar_per_vector,
|
||||
B_BLOCK_TRANSFER.lds_dst_scalar_per_vector,
|
||||
B_BLOCK_TRANSFER.lds_padding,
|
||||
C_BLOCK_TRANSFER.m_xdl_per_wave_per_shuffle,
|
||||
C_BLOCK_TRANSFER.n_xdl_per_wave_per_shuffle,
|
||||
to_sequence_v<C_BLOCK_TRANSFER.thread_cluster_dims>,
|
||||
C_BLOCK_TRANSFER.scalar_per_vector,
|
||||
BLOCK_GEMM.scheduler,
|
||||
BLOCK_GEMM.pipeline_version,
|
||||
typename Types::OutComputeType,
|
||||
typename Types::InComputeType,
|
||||
ALGORITHM.max_transpose_transfer_src_scalar_per_vector,
|
||||
ALGORITHM.max_transpose_transfer_dst_scalar_per_vector>;
|
||||
};
|
||||
|
||||
} // namespace ck_tile::builder::factory
|
||||
@@ -0,0 +1,103 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle.hpp"
|
||||
#include "ck_tile/builder/conv_signature_concepts.hpp"
|
||||
#include "ck_tile/builder/conv_algorithm_concepts.hpp"
|
||||
#include "ck_tile/builder/conv_algorithm_limits.hpp"
|
||||
#include "ck_tile/builder/builder_utils.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/ck/conv_tensor_layout.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/ck/conv_tensor_type.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/ck/conv_elementwise_op.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/ck/conv_tuning_params.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/ck/conv_block_transfer.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/ck/conv_thread_block.hpp"
|
||||
|
||||
namespace ck_tile::builder::factory {
|
||||
|
||||
// Factory for DeviceGroupedConvBwdWeight_Xdl_CShuffle instance
|
||||
// of a grouped bwd weight convolution kernel.
|
||||
template <ConvSignatureDescriptor auto SIGNATURE,
|
||||
ConvAlgorithmDescriptor auto ALGORITHM,
|
||||
StringLiteral VERSION>
|
||||
requires ConvDirectionIsBackwardWeight<SIGNATURE>
|
||||
struct ConvBwdWeightXdlFactory
|
||||
{
|
||||
static constexpr size_t SPATIAL_DIM = SIGNATURE.spatial_dim;
|
||||
using Layouts = internal::ConvTensorLayouts<SIGNATURE, SPATIAL_DIM>;
|
||||
using Types = internal::ConvTensorDataTypes<SIGNATURE>;
|
||||
using Ops = internal::ConvElementwiseOps<SIGNATURE>;
|
||||
using AlgorithmType = decltype(ALGORITHM);
|
||||
|
||||
static constexpr auto BWD_CONV_SPECIALIZATION =
|
||||
internal::SetBwdWeightConvSpecialization<ALGORITHM>();
|
||||
|
||||
static constexpr auto BLOCK = internal::SetThreadBlockInfo<ALGORITHM>();
|
||||
static constexpr auto GRIDWISE_GEMM = ALGORITHM.gridwise_gemm;
|
||||
static constexpr auto XDL_PARAMS = GRIDWISE_GEMM.xdl_params;
|
||||
static constexpr auto A_BLOCK_TRANSFER =
|
||||
internal::SetBwdConvBlockTransfer<ALGORITHM.transfer.a>();
|
||||
static constexpr auto B_BLOCK_TRANSFER =
|
||||
internal::SetBwdConvBlockTransfer<ALGORITHM.transfer.b>();
|
||||
static constexpr auto C_BLOCK_TRANSFER = internal::SetCBlockTransfer<SIGNATURE, ALGORITHM>();
|
||||
|
||||
// Check limits for the algorithm parameters.
|
||||
// TODO: Add more limits checks as needed.
|
||||
static_assert(InputVectorTransferLimits<A_BLOCK_TRANSFER>);
|
||||
static_assert(InputVectorTransferLimits<B_BLOCK_TRANSFER>);
|
||||
static_assert(OutputVectorTransferLimits<C_BLOCK_TRANSFER>);
|
||||
static_assert(AccessOrderLimits4D<A_BLOCK_TRANSFER.thread_cluster_order>);
|
||||
static_assert(AccessOrderLimits4D<B_BLOCK_TRANSFER.thread_cluster_order>);
|
||||
static_assert(AccessOrderLimits4D<A_BLOCK_TRANSFER.src_access_order>);
|
||||
static_assert(AccessOrderLimits4D<B_BLOCK_TRANSFER.src_access_order>);
|
||||
|
||||
// The forward convolution kernel class instance.
|
||||
using Instance = ck::tensor_operation::device::DeviceGroupedConvBwdWeight_Xdl_CShuffle<
|
||||
SPATIAL_DIM,
|
||||
typename Layouts::InLayout,
|
||||
typename Layouts::WeiLayout,
|
||||
typename Layouts::OutLayout,
|
||||
typename Types::InDataType,
|
||||
typename Types::WeiDataType,
|
||||
typename Types::OutDataType,
|
||||
typename Types::AccDataType,
|
||||
typename Ops::InElementwiseOp,
|
||||
typename Ops::WeiElementwiseOp,
|
||||
typename Ops::OutElementwiseOp,
|
||||
BWD_CONV_SPECIALIZATION,
|
||||
BLOCK.block_size,
|
||||
BLOCK.per_block.m,
|
||||
BLOCK.per_block.n,
|
||||
BLOCK.per_block.k,
|
||||
GRIDWISE_GEMM.k1,
|
||||
XDL_PARAMS.m_per_xdl,
|
||||
XDL_PARAMS.n_per_xdl,
|
||||
XDL_PARAMS.m_xdl_per_wave,
|
||||
XDL_PARAMS.n_xdl_per_wave,
|
||||
to_sequence_v<A_BLOCK_TRANSFER.thread_cluster_dims>,
|
||||
to_sequence_v<A_BLOCK_TRANSFER.thread_cluster_order>,
|
||||
to_sequence_v<A_BLOCK_TRANSFER.src_access_order>,
|
||||
A_BLOCK_TRANSFER.src_vector_dim,
|
||||
A_BLOCK_TRANSFER.src_scalar_per_vector,
|
||||
A_BLOCK_TRANSFER.lds_dst_scalar_per_vector,
|
||||
A_BLOCK_TRANSFER.lds_padding,
|
||||
to_sequence_v<B_BLOCK_TRANSFER.thread_cluster_dims>,
|
||||
to_sequence_v<B_BLOCK_TRANSFER.thread_cluster_order>,
|
||||
to_sequence_v<B_BLOCK_TRANSFER.src_access_order>,
|
||||
B_BLOCK_TRANSFER.src_vector_dim,
|
||||
B_BLOCK_TRANSFER.src_scalar_per_vector,
|
||||
B_BLOCK_TRANSFER.lds_dst_scalar_per_vector,
|
||||
B_BLOCK_TRANSFER.lds_padding,
|
||||
C_BLOCK_TRANSFER.m_xdl_per_wave_per_shuffle,
|
||||
C_BLOCK_TRANSFER.n_xdl_per_wave_per_shuffle,
|
||||
to_sequence_v<C_BLOCK_TRANSFER.thread_cluster_dims>,
|
||||
C_BLOCK_TRANSFER.scalar_per_vector,
|
||||
typename Types::OutComputeType,
|
||||
typename Types::InComputeType,
|
||||
ALGORITHM.max_transpose_transfer_src_scalar_per_vector,
|
||||
ALGORITHM.max_transpose_transfer_dst_scalar_per_vector>;
|
||||
};
|
||||
|
||||
} // namespace ck_tile::builder::factory
|
||||
@@ -0,0 +1,108 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp"
|
||||
#include "ck_tile/builder/conv_signature_concepts.hpp"
|
||||
#include "ck_tile/builder/conv_algorithm_concepts.hpp"
|
||||
#include "ck_tile/builder/conv_algorithm_limits.hpp"
|
||||
#include "ck_tile/builder/builder_utils.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/ck/conv_tensor_layout.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/ck/conv_tensor_type.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/ck/conv_elementwise_op.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/ck/conv_tuning_params.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/ck/conv_block_transfer.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/ck/conv_thread_block.hpp"
|
||||
|
||||
namespace ck_tile::builder::factory {
|
||||
|
||||
// Factory for DeviceGroupedConvBwdWeight_Xdl_CShuffleV3 instance
|
||||
// of a grouped bwd weight convolution kernel.
|
||||
template <ConvSignatureDescriptor auto SIGNATURE,
|
||||
ConvAlgorithmDescriptor auto ALGORITHM,
|
||||
StringLiteral VERSION>
|
||||
requires ConvDirectionIsBackwardWeight<SIGNATURE>
|
||||
struct ConvBwdWeightXdlV3Factory
|
||||
{
|
||||
static constexpr size_t SPATIAL_DIM = SIGNATURE.spatial_dim;
|
||||
using Layouts = internal::ConvTensorLayouts<SIGNATURE, SPATIAL_DIM>;
|
||||
using Types = internal::ConvTensorDataTypes<SIGNATURE>;
|
||||
using Ops = internal::ConvElementwiseOps<SIGNATURE>;
|
||||
using AlgorithmType = decltype(ALGORITHM);
|
||||
|
||||
static constexpr auto BWD_CONV_SPECIALIZATION =
|
||||
internal::SetBwdWeightConvSpecialization<ALGORITHM>();
|
||||
|
||||
static constexpr auto BLOCK = internal::SetThreadBlockInfo<ALGORITHM>();
|
||||
static constexpr auto GRIDWISE_GEMM = ALGORITHM.gridwise_gemm;
|
||||
static constexpr auto XDL_PARAMS = GRIDWISE_GEMM.xdl_params;
|
||||
static constexpr auto A_BLOCK_TRANSFER =
|
||||
internal::SetBwdConvBlockTransfer<ALGORITHM.transfer.a>();
|
||||
static constexpr auto B_BLOCK_TRANSFER =
|
||||
internal::SetBwdConvBlockTransfer<ALGORITHM.transfer.b>();
|
||||
static constexpr auto C_BLOCK_TRANSFER = internal::SetCBlockTransfer<SIGNATURE, ALGORITHM>();
|
||||
static constexpr auto BLOCK_GEMM = internal::SetBlockGemm<ALGORITHM>();
|
||||
|
||||
// Check limits for the algorithm parameters.
|
||||
// TODO: Add more limits checks as needed.
|
||||
static_assert(InputVectorTransferLimits<A_BLOCK_TRANSFER>, "Invalid A block transfer config");
|
||||
static_assert(InputVectorTransferLimits<B_BLOCK_TRANSFER>, "Invalid B block transfer config");
|
||||
static_assert(OutputVectorTransferLimits<C_BLOCK_TRANSFER>, "Invalid C block transfer config");
|
||||
static_assert(AccessOrderLimits3D<A_BLOCK_TRANSFER.thread_cluster_order>,
|
||||
"Invalid A thread cluster access order");
|
||||
static_assert(AccessOrderLimits3D<B_BLOCK_TRANSFER.thread_cluster_order>,
|
||||
"Invalid B thread cluster access order");
|
||||
static_assert(AccessOrderLimits3D<A_BLOCK_TRANSFER.src_access_order>,
|
||||
"Invalid A source access order");
|
||||
static_assert(AccessOrderLimits3D<B_BLOCK_TRANSFER.src_access_order>,
|
||||
"Invalid B source access order");
|
||||
|
||||
// The forward convolution kernel class instance.
|
||||
using Instance = ck::tensor_operation::device::DeviceGroupedConvBwdWeight_Xdl_CShuffleV3<
|
||||
SPATIAL_DIM,
|
||||
typename Layouts::InLayout,
|
||||
typename Layouts::WeiLayout,
|
||||
typename Layouts::OutLayout,
|
||||
typename Types::InDataType,
|
||||
typename Types::WeiDataType,
|
||||
typename Types::OutDataType,
|
||||
typename Types::AccDataType,
|
||||
typename Ops::InElementwiseOp,
|
||||
typename Ops::WeiElementwiseOp,
|
||||
typename Ops::OutElementwiseOp,
|
||||
BWD_CONV_SPECIALIZATION,
|
||||
BLOCK.block_size,
|
||||
BLOCK.per_block.m,
|
||||
BLOCK.per_block.n,
|
||||
BLOCK.per_block.k,
|
||||
GRIDWISE_GEMM.k1,
|
||||
XDL_PARAMS.m_per_xdl,
|
||||
XDL_PARAMS.n_per_xdl,
|
||||
XDL_PARAMS.m_xdl_per_wave,
|
||||
XDL_PARAMS.n_xdl_per_wave,
|
||||
to_sequence_v<A_BLOCK_TRANSFER.thread_cluster_dims>,
|
||||
to_sequence_v<A_BLOCK_TRANSFER.thread_cluster_order>,
|
||||
to_sequence_v<A_BLOCK_TRANSFER.src_access_order>,
|
||||
A_BLOCK_TRANSFER.src_vector_dim,
|
||||
A_BLOCK_TRANSFER.src_scalar_per_vector,
|
||||
A_BLOCK_TRANSFER.lds_dst_scalar_per_vector,
|
||||
A_BLOCK_TRANSFER.lds_padding,
|
||||
to_sequence_v<B_BLOCK_TRANSFER.thread_cluster_dims>,
|
||||
to_sequence_v<B_BLOCK_TRANSFER.thread_cluster_order>,
|
||||
to_sequence_v<B_BLOCK_TRANSFER.src_access_order>,
|
||||
B_BLOCK_TRANSFER.src_vector_dim,
|
||||
B_BLOCK_TRANSFER.src_scalar_per_vector,
|
||||
B_BLOCK_TRANSFER.lds_dst_scalar_per_vector,
|
||||
B_BLOCK_TRANSFER.lds_padding,
|
||||
C_BLOCK_TRANSFER.m_xdl_per_wave_per_shuffle,
|
||||
C_BLOCK_TRANSFER.n_xdl_per_wave_per_shuffle,
|
||||
to_sequence_v<C_BLOCK_TRANSFER.thread_cluster_dims>,
|
||||
C_BLOCK_TRANSFER.scalar_per_vector,
|
||||
BLOCK_GEMM.scheduler,
|
||||
BLOCK_GEMM.pipeline_version,
|
||||
typename Types::OutComputeType,
|
||||
typename Types::InComputeType>;
|
||||
};
|
||||
|
||||
} // namespace ck_tile::builder::factory
|
||||
@@ -57,6 +57,9 @@
|
||||
#include "ck_tile/builder/conv_algorithm_concepts.hpp"
|
||||
#include "ck_tile/builder/types.hpp"
|
||||
|
||||
// Compile time diagnostics
|
||||
#include "ck_tile/builder/factory/conv_algorithms.hpp"
|
||||
|
||||
// Include all factory implementations
|
||||
#include "ck_tile/builder/factory/conv_fwd_v3_factory.hpp"
|
||||
#include "ck_tile/builder/factory/conv_fwd_xdl_factory.hpp"
|
||||
@@ -65,6 +68,15 @@
|
||||
#include "ck_tile/builder/factory/conv_fwd_large_tensor_factory.hpp"
|
||||
#include "ck_tile/builder/factory/reference_factory.hpp"
|
||||
#include "ck_tile/builder/factory/conv_tile_factory.hpp"
|
||||
#include "ck_tile/builder/factory/conv_bwd_weight_xdl_factory.hpp"
|
||||
#include "ck_tile/builder/factory/conv_bwd_weight_xdl_v3_factory.hpp"
|
||||
#include "ck_tile/builder/factory/conv_bwd_weight_two_stage_xdl_factory.hpp"
|
||||
#include "ck_tile/builder/factory/conv_bwd_weight_dl_factory.hpp"
|
||||
#include "ck_tile/builder/factory/conv_bwd_weight_multi_d_xdl_factory.hpp"
|
||||
#include "ck_tile/builder/factory/conv_bwd_weight_wmma_v3_factory.hpp"
|
||||
#include "ck_tile/builder/factory/conv_bwd_weight_two_stage_wmma_v3_factory.hpp"
|
||||
#include "ck_tile/builder/factory/conv_bwd_weight_wmma_factory.hpp"
|
||||
#include "ck_tile/builder/factory/conv_bwd_weight_multi_d_wmma_v3_factory.hpp"
|
||||
|
||||
namespace ck_tile::builder::factory {
|
||||
|
||||
@@ -87,56 +99,6 @@ namespace ck_tile::builder::factory {
|
||||
//
|
||||
// TODO: Make this dispatch logic much more robust and clear for users.
|
||||
|
||||
// Reference algorithm (simplest implementation for validation)
|
||||
template <typename T>
|
||||
concept IsReferenceAlgorithm = ConvAlgorithmDescriptor<T> && requires {
|
||||
{ T::specialization } -> std::convertible_to<ConvAlgorithmSpecialization>;
|
||||
requires T::specialization == ConvAlgorithmSpecialization::REFERENCE;
|
||||
};
|
||||
|
||||
// CK Tile kernel
|
||||
template <typename T>
|
||||
concept IsTileAlgorithm = ConvAlgorithmDescriptor<T> && SpecifiesTileThreadBlock<T> &&
|
||||
SpecifiesTileTransfer<T> && SpecifiesTileConvSpecialization<T> &&
|
||||
SpecifiesTileBlockGemm<T> && SpecifiesTileOptimizations<T>;
|
||||
|
||||
// XDL-based kernel with V3 pipeline structure (newer block GEMM pipeline)
|
||||
template <typename T>
|
||||
concept IsXdlV3Algorithm =
|
||||
ConvAlgorithmDescriptor<T> && SpecifiesThreadBlock<T> && SpecifiesGridwiseXdlGemm<T> &&
|
||||
SpecifiesBlockTransfer<T> && SpecifiesLdsTransfer<T> && SpecifiesThreadClusterAccessOrder<T> &&
|
||||
SpecifiesSourceAccessOrder<T> && SpecifiesFwdConvSpecialization<T> &&
|
||||
SpecifiesGemmSpecialization<T> && SpecifiesBlockGemm<T>;
|
||||
|
||||
// Standard XDL-based kernel (uses XDLops hardware instructions for matrix multiply)
|
||||
template <typename T>
|
||||
concept IsXdlAlgorithm =
|
||||
ConvAlgorithmDescriptor<T> && SpecifiesThreadBlock<T> && SpecifiesGridwiseXdlGemm<T> &&
|
||||
SpecifiesBlockTransfer<T> && SpecifiesLdsTransfer<T> && SpecifiesThreadClusterAccessOrder<T> &&
|
||||
SpecifiesSourceAccessOrder<T> && SpecifiesFwdConvSpecialization<T> &&
|
||||
SpecifiesGemmSpecialization<T> && SpecifiesNumPrefetchStages<T> &&
|
||||
SpecifiesNumGroupsToMerge<T> && SpecifiesLoopScheduler<T>;
|
||||
|
||||
// WMMA-based kernel (uses Wavefront Matrix-Matrix Accumulate instructions)
|
||||
template <typename T>
|
||||
concept IsWmmaAlgorithm =
|
||||
ConvAlgorithmDescriptor<T> && SpecifiesThreadBlock<T> && SpecifiesGridwiseWmmaGemm<T> &&
|
||||
SpecifiesBlockTransfer<T> && SpecifiesLdsTransfer<T> && SpecifiesThreadClusterAccessOrder<T> &&
|
||||
SpecifiesSourceAccessOrder<T> && SpecifiesFwdConvSpecialization<T> &&
|
||||
SpecifiesGemmSpecialization<T> && SpecifiesNumPrefetchStages<T> && SpecifiesLoopScheduler<T>;
|
||||
|
||||
// Specialized DL kernel for specific NHWC/KYXC/NHWK data layouts
|
||||
template <typename T>
|
||||
concept IsDlAlgorithm =
|
||||
ConvAlgorithmDescriptor<T> && SpecifiesThreadBlock<T> && SpecifiesFwdConvSpecialization<T> &&
|
||||
SpecifiesGemmSpecialization<T> && SpecifiesDlThreadConfig<T> && SpecifiesDlThreadCluster<T> &&
|
||||
SpecifiesDlBlockTransfer<T> && SpecifiesDlEpilogue<T>;
|
||||
|
||||
// XDL-based kernel with large tensor support
|
||||
template <typename T>
|
||||
concept IsLargeTensorAlgorithm =
|
||||
IsXdlAlgorithm<decltype(T::base_algorithm)> && SpecifiesLargeTensorSupport<T>;
|
||||
|
||||
template <ConvSignatureDescriptor auto SIGNATURE,
|
||||
ConvAlgorithmDescriptor auto ALGORITHM,
|
||||
StringLiteral VERSION>
|
||||
@@ -145,35 +107,35 @@ constexpr auto make_conv_instance()
|
||||
using AlgoType = std::remove_const_t<decltype(ALGORITHM)>;
|
||||
|
||||
// Reference algorithm supports all directions
|
||||
if constexpr(IsReferenceAlgorithm<AlgoType>)
|
||||
if constexpr(ReferenceAlgorithm<AlgoType>)
|
||||
{
|
||||
return typename ReferenceFactory<SIGNATURE, ALGORITHM, VERSION>::Instance{};
|
||||
}
|
||||
// CK Tile supports common factory for each direction
|
||||
else if constexpr(IsTileAlgorithm<AlgoType>)
|
||||
else if constexpr(TileAlgorithm<AlgoType>)
|
||||
{
|
||||
return typename ConvTileFactory<SIGNATURE, ALGORITHM, VERSION>::Instance{};
|
||||
}
|
||||
// Forward direction (supports most algorithm variants)
|
||||
else if constexpr(ConvDirectionIsForward<SIGNATURE>)
|
||||
{
|
||||
if constexpr(IsXdlV3Algorithm<AlgoType>)
|
||||
if constexpr(FwdXdlV3Algorithm<AlgoType>)
|
||||
{
|
||||
return typename ConvFwdXdlV3Factory<SIGNATURE, ALGORITHM, VERSION>::Instance{};
|
||||
}
|
||||
else if constexpr(IsXdlAlgorithm<AlgoType>)
|
||||
else if constexpr(FwdXdlAlgorithm<AlgoType>)
|
||||
{
|
||||
return typename ConvFwdXdlFactory<SIGNATURE, ALGORITHM, VERSION>::Instance{};
|
||||
}
|
||||
else if constexpr(IsWmmaAlgorithm<AlgoType>)
|
||||
else if constexpr(FwdWmmaAlgorithm<AlgoType>)
|
||||
{
|
||||
return typename ConvFwdWmmaFactory<SIGNATURE, ALGORITHM, VERSION>::Instance{};
|
||||
}
|
||||
else if constexpr(IsDlAlgorithm<AlgoType>)
|
||||
else if constexpr(FwdDlAlgorithm<AlgoType>)
|
||||
{
|
||||
return typename ConvFwdDlFactory<SIGNATURE, ALGORITHM, VERSION>::Instance{};
|
||||
}
|
||||
else if constexpr(IsLargeTensorAlgorithm<AlgoType>)
|
||||
else if constexpr(LargeTensorAlgorithm<AlgoType>)
|
||||
{
|
||||
return typename ConvFwdLargeTensorFactory<SIGNATURE, ALGORITHM, VERSION>::Instance{};
|
||||
}
|
||||
@@ -197,10 +159,55 @@ constexpr auto make_conv_instance()
|
||||
// Backward weight direction (will expand with more algorithms in the future)
|
||||
else if constexpr(ConvDirectionIsBackwardWeight<SIGNATURE>)
|
||||
{
|
||||
static_assert(false,
|
||||
"Backward weight convolution: Only reference and tile algorithms "
|
||||
"supported currently. "
|
||||
"Optimized kernels (XDL, WMMA, etc.) not yet implemented.");
|
||||
if constexpr(BwdXdlAlgorithm<AlgoType>)
|
||||
{
|
||||
return typename ConvBwdWeightXdlFactory<SIGNATURE, ALGORITHM, VERSION>::Instance{};
|
||||
}
|
||||
else if constexpr(BwdXdlV3Algorithm<AlgoType>)
|
||||
{
|
||||
return typename ConvBwdWeightXdlV3Factory<SIGNATURE, ALGORITHM, VERSION>::Instance{};
|
||||
}
|
||||
else if constexpr(BwdTwoStageXdlAlgorithm<AlgoType>)
|
||||
{
|
||||
return
|
||||
typename ConvBwdWeightTwoStageXdlFactory<SIGNATURE, ALGORITHM, VERSION>::Instance{};
|
||||
}
|
||||
else if constexpr(BwdDlAlgorithm<AlgoType>)
|
||||
{
|
||||
return typename ConvBwdWeightDlFactory<SIGNATURE, ALGORITHM, VERSION>::Instance{};
|
||||
}
|
||||
else if constexpr(BwdMultiDXdlAlgorithm<AlgoType>)
|
||||
{
|
||||
return
|
||||
typename ConvBwdWeightMultiDXdlFactory<SIGNATURE, ALGORITHM, VERSION>::Instance{};
|
||||
}
|
||||
else if constexpr(BwdWmmaV3Algorithm<AlgoType>)
|
||||
{
|
||||
return typename ConvBwdWeightWmmaV3Factory<SIGNATURE, ALGORITHM, VERSION>::Instance{};
|
||||
}
|
||||
else if constexpr(BwdTwoStageWmmaV3Algorithm<AlgoType>)
|
||||
{
|
||||
return typename ConvBwdWeightTwoStageWmmaV3Factory<SIGNATURE, ALGORITHM, VERSION>::
|
||||
Instance{};
|
||||
}
|
||||
else if constexpr(BwdWmmaAlgorithm<AlgoType>)
|
||||
{
|
||||
return typename ConvBwdWeightWmmaFactory<SIGNATURE, ALGORITHM, VERSION>::Instance{};
|
||||
}
|
||||
else if constexpr(BwdMultiDWmmaV3Algorithm<AlgoType>)
|
||||
{
|
||||
return typename ConvBwdWeightMultiDWmmaV3Factory<SIGNATURE, ALGORITHM, VERSION>::
|
||||
Instance{};
|
||||
}
|
||||
else
|
||||
{
|
||||
static_assert(
|
||||
false,
|
||||
"No suitable backward weight convolution kernel factory found for the provided "
|
||||
"ALGORITHM. The ALGORITHM must satisfy requirements for one of: Reference, Tile, "
|
||||
"XDL, XDL V3, Two-Stage XDL, DL, Multi-D XDL, WMMA V3, Two-Stage "
|
||||
"WMMA V3, WMMA, or Multi-D WMMA V3 variant.");
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
|
||||
@@ -24,10 +24,10 @@ template <ConvSignatureDescriptor auto SIGNATURE,
|
||||
struct ConvFwdDlFactory
|
||||
{
|
||||
static constexpr size_t SPATIAL_DIM = SIGNATURE.spatial_dim;
|
||||
using Layouts = internal::ConvTensorLayouts<SIGNATURE, SPATIAL_DIM, ConvDirection::FORWARD>;
|
||||
using Types = internal::FwdConvTensorDataTypes<SIGNATURE>;
|
||||
using Ops = internal::ElementwiseOps<SIGNATURE>;
|
||||
using AlgorithmType = decltype(ALGORITHM);
|
||||
using Layouts = internal::ConvTensorLayouts<SIGNATURE, SPATIAL_DIM>;
|
||||
using Types = internal::ConvTensorDataTypes<SIGNATURE>;
|
||||
using Ops = internal::ConvElementwiseOps<SIGNATURE>;
|
||||
using AlgorithmType = decltype(ALGORITHM);
|
||||
|
||||
static constexpr auto FWD_CONV_SPECIALIZATION = internal::SetFwdConvSpecialization<ALGORITHM>();
|
||||
static constexpr auto GEMM_SPECIALIZATION = internal::SetGemmSpecialization<ALGORITHM>();
|
||||
@@ -48,7 +48,7 @@ struct ConvFwdDlFactory
|
||||
using M1N1ThreadClusterN1Xs = to_sequence_v<DL_CLUSTER.n1_xs>;
|
||||
|
||||
// A Block Transfer from descriptor - K0_M0_M1_K1 tensor format
|
||||
static constexpr auto DL_A_TRANSFER = ALGORITHM.transfer.a.block_transfer;
|
||||
static constexpr auto DL_A_TRANSFER = ALGORITHM.transfer.a;
|
||||
using ABlockTransferThreadSliceLengths_K0_M0_M1_K1 =
|
||||
to_sequence_v<DL_A_TRANSFER.thread_slice_lengths>;
|
||||
using ABlockTransferThreadClusterLengths_K0_M0_M1_K1 =
|
||||
@@ -64,7 +64,7 @@ struct ConvFwdDlFactory
|
||||
to_sequence_v<DL_A_TRANSFER.dst_vector_tensor_lengths>;
|
||||
|
||||
// B Block Transfer from descriptor - K0_N0_N1_K1 tensor format
|
||||
static constexpr auto DL_B_TRANSFER = ALGORITHM.transfer.b.block_transfer;
|
||||
static constexpr auto DL_B_TRANSFER = ALGORITHM.transfer.b;
|
||||
using BBlockTransferThreadSliceLengths_K0_N0_N1_K1 =
|
||||
to_sequence_v<DL_B_TRANSFER.thread_slice_lengths>;
|
||||
using BBlockTransferThreadClusterLengths_K0_N0_N1_K1 =
|
||||
@@ -80,7 +80,7 @@ struct ConvFwdDlFactory
|
||||
to_sequence_v<DL_B_TRANSFER.dst_vector_tensor_lengths>;
|
||||
|
||||
// C Thread Transfer from descriptor
|
||||
static constexpr auto DL_C_TRANSFER = ALGORITHM.transfer.c.epilogue;
|
||||
static constexpr auto DL_C_TRANSFER = ALGORITHM.transfer.c;
|
||||
using CThreadTransferSrcDstAccessOrder = to_sequence_v<DL_C_TRANSFER.src_dst_access_order>;
|
||||
static constexpr ck::index_t CThreadTransferSrcDstVectorDim = DL_C_TRANSFER.src_dst_vector_dim;
|
||||
static constexpr ck::index_t CThreadTransferDstScalarPerVector =
|
||||
@@ -89,18 +89,18 @@ struct ConvFwdDlFactory
|
||||
// The DL forward convolution kernel class instance
|
||||
using Instance = ck::tensor_operation::device::DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK<
|
||||
SPATIAL_DIM,
|
||||
typename Types::ADataType,
|
||||
typename Types::BDataType,
|
||||
typename Types::DsDataTypes,
|
||||
typename Types::EDataType,
|
||||
typename Types::InDataType,
|
||||
typename Types::WeiDataType,
|
||||
typename Types::DsDataType,
|
||||
typename Types::OutDataType,
|
||||
typename Types::AccDataType,
|
||||
typename Layouts::ALayout,
|
||||
typename Layouts::BLayout,
|
||||
typename Layouts::InLayout,
|
||||
typename Layouts::WeiLayout,
|
||||
typename Layouts::DsLayout,
|
||||
typename Layouts::ELayout,
|
||||
typename Ops::AElementwiseOp,
|
||||
typename Ops::BElementwiseOp,
|
||||
typename Ops::CDEElementwiseOp,
|
||||
typename Layouts::OutLayout,
|
||||
typename Ops::InElementwiseOp,
|
||||
typename Ops::WeiElementwiseOp,
|
||||
typename Ops::OutElementwiseOp,
|
||||
FWD_CONV_SPECIALIZATION,
|
||||
GEMM_SPECIALIZATION,
|
||||
BLOCK.block_size,
|
||||
|
||||
@@ -26,68 +26,65 @@ template <ConvSignatureDescriptor auto SIGNATURE,
|
||||
struct ConvFwdLargeTensorFactory
|
||||
{
|
||||
static constexpr size_t SPATIAL_DIM = SIGNATURE.spatial_dim;
|
||||
using Layouts = internal::ConvTensorLayouts<SIGNATURE, SPATIAL_DIM, ConvDirection::FORWARD>;
|
||||
using Types = internal::FwdConvTensorDataTypes<SIGNATURE>;
|
||||
using Ops = internal::ElementwiseOps<SIGNATURE>;
|
||||
using AlgorithmType = decltype(ALGORITHM);
|
||||
using Layouts = internal::ConvTensorLayouts<SIGNATURE, SPATIAL_DIM>;
|
||||
using Types = internal::ConvTensorDataTypes<SIGNATURE>;
|
||||
using Ops = internal::ConvElementwiseOps<SIGNATURE>;
|
||||
using AlgorithmType = decltype(ALGORITHM);
|
||||
|
||||
static constexpr auto BASE_ALGORITHM = ALGORITHM.base_algorithm;
|
||||
|
||||
static constexpr auto FWD_CONV_SPECIALIZATION =
|
||||
internal::SetFwdConvSpecialization<BASE_ALGORITHM>();
|
||||
static constexpr auto GEMM_SPECIALIZATION = internal::SetGemmSpecialization<BASE_ALGORITHM>();
|
||||
static constexpr auto FWD_CONV_SPECIALIZATION = internal::SetFwdConvSpecialization<ALGORITHM>();
|
||||
static constexpr auto GEMM_SPECIALIZATION = internal::SetGemmSpecialization<ALGORITHM>();
|
||||
static constexpr internal::ConvSpec SPECIALIZATION{.conv_spec = FWD_CONV_SPECIALIZATION,
|
||||
.gemm_spec = GEMM_SPECIALIZATION};
|
||||
|
||||
static constexpr auto LOOP_SCHEDULER = internal::SetLoopScheduler<BASE_ALGORITHM>();
|
||||
static constexpr auto BLOCK = internal::SetThreadBlockInfo<BASE_ALGORITHM>();
|
||||
static constexpr auto GRIDWISE_GEMM = BASE_ALGORITHM.gridwise_gemm;
|
||||
static constexpr auto LOOP_SCHEDULER = internal::SetLoopScheduler<ALGORITHM>();
|
||||
static constexpr auto BLOCK = internal::SetThreadBlockInfo<ALGORITHM>();
|
||||
static constexpr auto GRIDWISE_GEMM = ALGORITHM.gridwise_gemm;
|
||||
static constexpr auto XDL_PARAMS = GRIDWISE_GEMM.xdl_params;
|
||||
static constexpr auto A_BLOCK_TRANSFER =
|
||||
internal::SetFwdConvBlockTransfer<BASE_ALGORITHM.transfer.a>();
|
||||
internal::SetFwdConvBlockTransfer<ALGORITHM.transfer.a>();
|
||||
static constexpr auto B_BLOCK_TRANSFER =
|
||||
internal::SetFwdConvBlockTransfer<BASE_ALGORITHM.transfer.b>();
|
||||
static constexpr auto C_BLOCK_TRANSFER =
|
||||
internal::SetCBlockTransfer<SIGNATURE, BASE_ALGORITHM>();
|
||||
internal::SetFwdConvBlockTransfer<ALGORITHM.transfer.b>();
|
||||
static constexpr auto C_BLOCK_TRANSFER = internal::SetCBlockTransfer<SIGNATURE, ALGORITHM>();
|
||||
|
||||
// Check limits for the algorithm parameters.
|
||||
static_assert(InputVectorTransferLimits<A_BLOCK_TRANSFER>);
|
||||
static_assert(InputVectorTransferLimits<B_BLOCK_TRANSFER>);
|
||||
static_assert(OutputVectorTransferLimits<C_BLOCK_TRANSFER>);
|
||||
static_assert(AccessOrderLimits<A_BLOCK_TRANSFER.thread_cluster_order>);
|
||||
static_assert(AccessOrderLimits<B_BLOCK_TRANSFER.thread_cluster_order>);
|
||||
static_assert(AccessOrderLimits<A_BLOCK_TRANSFER.src_access_order>);
|
||||
static_assert(AccessOrderLimits<B_BLOCK_TRANSFER.src_access_order>);
|
||||
static_assert(AccessOrderLimits3D<A_BLOCK_TRANSFER.thread_cluster_order>);
|
||||
static_assert(AccessOrderLimits3D<B_BLOCK_TRANSFER.thread_cluster_order>);
|
||||
static_assert(AccessOrderLimits3D<A_BLOCK_TRANSFER.src_access_order>);
|
||||
static_assert(AccessOrderLimits3D<B_BLOCK_TRANSFER.src_access_order>);
|
||||
|
||||
// The forward convolution kernel class instance with large tensor support.
|
||||
using Instance =
|
||||
ck::tensor_operation::device::DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor<
|
||||
SPATIAL_DIM,
|
||||
typename Layouts::ALayout,
|
||||
typename Layouts::BLayout,
|
||||
typename Layouts::InLayout,
|
||||
typename Layouts::WeiLayout,
|
||||
typename Layouts::DsLayout,
|
||||
typename Layouts::ELayout,
|
||||
typename Types::ADataType,
|
||||
typename Types::BDataType,
|
||||
typename Layouts::OutLayout,
|
||||
typename Types::InDataType,
|
||||
typename Types::WeiDataType,
|
||||
typename Types::AccDataType,
|
||||
typename Types::CShuffleDataType,
|
||||
typename Types::DsDataTypes,
|
||||
typename Types::EDataType,
|
||||
typename Ops::AElementwiseOp,
|
||||
typename Ops::BElementwiseOp,
|
||||
typename Ops::CDEElementwiseOp,
|
||||
typename Types::OutComputeType,
|
||||
typename Types::DsDataType,
|
||||
typename Types::OutDataType,
|
||||
typename Ops::InElementwiseOp,
|
||||
typename Ops::WeiElementwiseOp,
|
||||
typename Ops::OutElementwiseOp,
|
||||
SPECIALIZATION.conv_spec,
|
||||
SPECIALIZATION.gemm_spec,
|
||||
BASE_ALGORITHM.num_gemm_k_prefetch_stages,
|
||||
ALGORITHM.num_gemm_k_prefetch_stages,
|
||||
BLOCK.block_size,
|
||||
BLOCK.per_block.m,
|
||||
BLOCK.per_block.n,
|
||||
BLOCK.per_block.k,
|
||||
GRIDWISE_GEMM.ak1,
|
||||
GRIDWISE_GEMM.bk1,
|
||||
GRIDWISE_GEMM.m_per_xdl,
|
||||
GRIDWISE_GEMM.n_per_xdl,
|
||||
GRIDWISE_GEMM.m_xdl_per_wave,
|
||||
GRIDWISE_GEMM.n_xdl_per_wave,
|
||||
XDL_PARAMS.m_per_xdl,
|
||||
XDL_PARAMS.n_per_xdl,
|
||||
XDL_PARAMS.m_xdl_per_wave,
|
||||
XDL_PARAMS.n_xdl_per_wave,
|
||||
to_sequence_v<A_BLOCK_TRANSFER.thread_cluster_dims>,
|
||||
to_sequence_v<A_BLOCK_TRANSFER.thread_cluster_order>,
|
||||
to_sequence_v<A_BLOCK_TRANSFER.src_access_order>,
|
||||
@@ -106,8 +103,8 @@ struct ConvFwdLargeTensorFactory
|
||||
C_BLOCK_TRANSFER.n_xdl_per_wave_per_shuffle,
|
||||
to_sequence_v<C_BLOCK_TRANSFER.thread_cluster_dims>,
|
||||
C_BLOCK_TRANSFER.scalar_per_vector,
|
||||
typename Types::AComputeType,
|
||||
typename Types::BComputeType,
|
||||
typename Types::InComputeType,
|
||||
typename Types::WeiComputeType,
|
||||
LOOP_SCHEDULER>;
|
||||
};
|
||||
|
||||
|
||||
@@ -26,10 +26,10 @@ template <ConvSignatureDescriptor auto SIGNATURE,
|
||||
struct ConvFwdXdlV3Factory
|
||||
{
|
||||
static constexpr size_t SPATIAL_DIM = SIGNATURE.spatial_dim;
|
||||
using Layouts = internal::ConvTensorLayouts<SIGNATURE, SPATIAL_DIM, ConvDirection::FORWARD>;
|
||||
using Types = internal::FwdConvTensorDataTypes<SIGNATURE>;
|
||||
using Ops = internal::ElementwiseOps<SIGNATURE>;
|
||||
using AlgorithmType = decltype(ALGORITHM);
|
||||
using Layouts = internal::ConvTensorLayouts<SIGNATURE, SPATIAL_DIM>;
|
||||
using Types = internal::ConvTensorDataTypes<SIGNATURE>;
|
||||
using Ops = internal::ConvElementwiseOps<SIGNATURE>;
|
||||
using AlgorithmType = decltype(ALGORITHM);
|
||||
|
||||
static_assert(ALGORITHM.transfer.a.lds_transfer.is_direct_load ==
|
||||
ALGORITHM.transfer.b.lds_transfer.is_direct_load,
|
||||
@@ -43,6 +43,7 @@ struct ConvFwdXdlV3Factory
|
||||
|
||||
static constexpr auto BLOCK = internal::SetThreadBlockInfo<ALGORITHM>();
|
||||
static constexpr auto GRIDWISE_GEMM = ALGORITHM.gridwise_gemm;
|
||||
static constexpr auto XDL_PARAMS = GRIDWISE_GEMM.xdl_params;
|
||||
static constexpr auto A_BLOCK_TRANSFER =
|
||||
internal::SetFwdConvBlockTransfer<ALGORITHM.transfer.a>();
|
||||
static constexpr auto B_BLOCK_TRANSFER =
|
||||
@@ -55,27 +56,27 @@ struct ConvFwdXdlV3Factory
|
||||
static_assert(InputVectorTransferLimits<A_BLOCK_TRANSFER>);
|
||||
static_assert(InputVectorTransferLimits<B_BLOCK_TRANSFER>);
|
||||
static_assert(OutputVectorTransferLimits<C_BLOCK_TRANSFER>);
|
||||
static_assert(AccessOrderLimits<A_BLOCK_TRANSFER.thread_cluster_order>);
|
||||
static_assert(AccessOrderLimits<B_BLOCK_TRANSFER.thread_cluster_order>);
|
||||
static_assert(AccessOrderLimits<A_BLOCK_TRANSFER.src_access_order>);
|
||||
static_assert(AccessOrderLimits<B_BLOCK_TRANSFER.src_access_order>);
|
||||
static_assert(AccessOrderLimits3D<A_BLOCK_TRANSFER.thread_cluster_order>);
|
||||
static_assert(AccessOrderLimits3D<B_BLOCK_TRANSFER.thread_cluster_order>);
|
||||
static_assert(AccessOrderLimits3D<A_BLOCK_TRANSFER.src_access_order>);
|
||||
static_assert(AccessOrderLimits3D<B_BLOCK_TRANSFER.src_access_order>);
|
||||
|
||||
// The forward convolution kernel class instance.
|
||||
using Instance = ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<
|
||||
SPATIAL_DIM,
|
||||
typename Layouts::ALayout,
|
||||
typename Layouts::BLayout,
|
||||
typename Layouts::InLayout,
|
||||
typename Layouts::WeiLayout,
|
||||
typename Layouts::DsLayout,
|
||||
typename Layouts::ELayout,
|
||||
typename Types::ADataType,
|
||||
typename Types::BDataType,
|
||||
typename Layouts::OutLayout,
|
||||
typename Types::InDataType,
|
||||
typename Types::WeiDataType,
|
||||
typename Types::AccDataType,
|
||||
typename Types::CShuffleDataType,
|
||||
typename Types::DsDataTypes,
|
||||
typename Types::EDataType,
|
||||
typename Ops::AElementwiseOp,
|
||||
typename Ops::BElementwiseOp,
|
||||
typename Ops::CDEElementwiseOp,
|
||||
typename Types::OutComputeType,
|
||||
typename Types::DsDataType,
|
||||
typename Types::OutDataType,
|
||||
typename Ops::InElementwiseOp,
|
||||
typename Ops::WeiElementwiseOp,
|
||||
typename Ops::OutElementwiseOp,
|
||||
SPECIALIZATION.conv_spec,
|
||||
SPECIALIZATION.gemm_spec,
|
||||
BLOCK.block_size,
|
||||
@@ -84,10 +85,10 @@ struct ConvFwdXdlV3Factory
|
||||
BLOCK.per_block.k,
|
||||
GRIDWISE_GEMM.ak1,
|
||||
GRIDWISE_GEMM.bk1,
|
||||
GRIDWISE_GEMM.m_per_xdl,
|
||||
GRIDWISE_GEMM.n_per_xdl,
|
||||
GRIDWISE_GEMM.m_xdl_per_wave,
|
||||
GRIDWISE_GEMM.n_xdl_per_wave,
|
||||
XDL_PARAMS.m_per_xdl,
|
||||
XDL_PARAMS.n_per_xdl,
|
||||
XDL_PARAMS.m_xdl_per_wave,
|
||||
XDL_PARAMS.n_xdl_per_wave,
|
||||
to_sequence_v<A_BLOCK_TRANSFER.thread_cluster_dims>,
|
||||
to_sequence_v<A_BLOCK_TRANSFER.thread_cluster_order>,
|
||||
to_sequence_v<A_BLOCK_TRANSFER.src_access_order>,
|
||||
@@ -108,8 +109,8 @@ struct ConvFwdXdlV3Factory
|
||||
C_BLOCK_TRANSFER.scalar_per_vector,
|
||||
BLOCK_GEMM.scheduler,
|
||||
BLOCK_GEMM.pipeline_version,
|
||||
typename Types::AComputeType,
|
||||
typename Types::BComputeType,
|
||||
typename Types::InComputeType,
|
||||
typename Types::WeiComputeType,
|
||||
IS_DIRECT_LOAD>;
|
||||
};
|
||||
|
||||
|
||||
@@ -26,10 +26,10 @@ template <ConvSignatureDescriptor auto SIGNATURE,
|
||||
struct ConvFwdWmmaFactory
|
||||
{
|
||||
static constexpr size_t SPATIAL_DIM = SIGNATURE.spatial_dim;
|
||||
using Layouts = internal::ConvTensorLayouts<SIGNATURE, SPATIAL_DIM, ConvDirection::FORWARD>;
|
||||
using Types = internal::FwdConvTensorDataTypes<SIGNATURE>;
|
||||
using Ops = internal::ElementwiseOps<SIGNATURE>;
|
||||
using AlgorithmType = decltype(ALGORITHM);
|
||||
using Layouts = internal::ConvTensorLayouts<SIGNATURE, SPATIAL_DIM>;
|
||||
using Types = internal::ConvTensorDataTypes<SIGNATURE>;
|
||||
using Ops = internal::ConvElementwiseOps<SIGNATURE>;
|
||||
using AlgorithmType = decltype(ALGORITHM);
|
||||
|
||||
static constexpr auto FWD_CONV_SPECIALIZATION = internal::SetFwdConvSpecialization<ALGORITHM>();
|
||||
static constexpr auto GEMM_SPECIALIZATION = internal::SetGemmSpecialization<ALGORITHM>();
|
||||
@@ -52,27 +52,27 @@ struct ConvFwdWmmaFactory
|
||||
static_assert(InputVectorTransferLimits<A_BLOCK_TRANSFER>);
|
||||
static_assert(InputVectorTransferLimits<B_BLOCK_TRANSFER>);
|
||||
static_assert(OutputVectorTransferLimits<C_BLOCK_TRANSFER>);
|
||||
static_assert(AccessOrderLimits<A_BLOCK_TRANSFER.thread_cluster_order>);
|
||||
static_assert(AccessOrderLimits<B_BLOCK_TRANSFER.thread_cluster_order>);
|
||||
static_assert(AccessOrderLimits<A_BLOCK_TRANSFER.src_access_order>);
|
||||
static_assert(AccessOrderLimits<B_BLOCK_TRANSFER.src_access_order>);
|
||||
static_assert(AccessOrderLimits3D<A_BLOCK_TRANSFER.thread_cluster_order>);
|
||||
static_assert(AccessOrderLimits3D<B_BLOCK_TRANSFER.thread_cluster_order>);
|
||||
static_assert(AccessOrderLimits3D<A_BLOCK_TRANSFER.src_access_order>);
|
||||
static_assert(AccessOrderLimits3D<B_BLOCK_TRANSFER.src_access_order>);
|
||||
|
||||
// The forward convolution kernel class instance.
|
||||
using Instance = ck::tensor_operation::device::DeviceGroupedConvFwdMultipleD_Wmma_CShuffle<
|
||||
SPATIAL_DIM,
|
||||
typename Layouts::ALayout,
|
||||
typename Layouts::BLayout,
|
||||
typename Layouts::InLayout,
|
||||
typename Layouts::WeiLayout,
|
||||
typename Layouts::DsLayout,
|
||||
typename Layouts::ELayout,
|
||||
typename Types::ADataType,
|
||||
typename Types::BDataType,
|
||||
typename Layouts::OutLayout,
|
||||
typename Types::InDataType,
|
||||
typename Types::WeiDataType,
|
||||
typename Types::AccDataType,
|
||||
typename Types::CShuffleDataType,
|
||||
typename Types::DsDataTypes,
|
||||
typename Types::EDataType,
|
||||
typename Ops::AElementwiseOp,
|
||||
typename Ops::BElementwiseOp,
|
||||
typename Ops::CDEElementwiseOp,
|
||||
typename Types::OutComputeType,
|
||||
typename Types::DsDataType,
|
||||
typename Types::OutDataType,
|
||||
typename Ops::InElementwiseOp,
|
||||
typename Ops::WeiElementwiseOp,
|
||||
typename Ops::OutElementwiseOp,
|
||||
SPECIALIZATION.conv_spec,
|
||||
SPECIALIZATION.gemm_spec,
|
||||
ALGORITHM.num_gemm_k_prefetch_stages,
|
||||
|
||||
@@ -26,10 +26,10 @@ template <ConvSignatureDescriptor auto SIGNATURE,
|
||||
struct ConvFwdXdlFactory
|
||||
{
|
||||
static constexpr size_t SPATIAL_DIM = SIGNATURE.spatial_dim;
|
||||
using Layouts = internal::ConvTensorLayouts<SIGNATURE, SPATIAL_DIM, ConvDirection::FORWARD>;
|
||||
using Types = internal::FwdConvTensorDataTypes<SIGNATURE>;
|
||||
using Ops = internal::ElementwiseOps<SIGNATURE>;
|
||||
using AlgorithmType = decltype(ALGORITHM);
|
||||
using Layouts = internal::ConvTensorLayouts<SIGNATURE, SPATIAL_DIM>;
|
||||
using Types = internal::ConvTensorDataTypes<SIGNATURE>;
|
||||
using Ops = internal::ConvElementwiseOps<SIGNATURE>;
|
||||
using AlgorithmType = decltype(ALGORITHM);
|
||||
|
||||
static constexpr auto FWD_CONV_SPECIALIZATION = internal::SetFwdConvSpecialization<ALGORITHM>();
|
||||
static constexpr auto GEMM_SPECIALIZATION = internal::SetGemmSpecialization<ALGORITHM>();
|
||||
@@ -39,6 +39,7 @@ struct ConvFwdXdlFactory
|
||||
static constexpr auto LOOP_SCHEDULER = internal::SetLoopScheduler<ALGORITHM>();
|
||||
static constexpr auto BLOCK = internal::SetThreadBlockInfo<ALGORITHM>();
|
||||
static constexpr auto GRIDWISE_GEMM = ALGORITHM.gridwise_gemm;
|
||||
static constexpr auto XDL_PARAMS = GRIDWISE_GEMM.xdl_params;
|
||||
static constexpr auto A_BLOCK_TRANSFER =
|
||||
internal::SetFwdConvBlockTransfer<ALGORITHM.transfer.a>();
|
||||
static constexpr auto B_BLOCK_TRANSFER =
|
||||
@@ -50,27 +51,27 @@ struct ConvFwdXdlFactory
|
||||
static_assert(InputVectorTransferLimits<A_BLOCK_TRANSFER>);
|
||||
static_assert(InputVectorTransferLimits<B_BLOCK_TRANSFER>);
|
||||
static_assert(OutputVectorTransferLimits<C_BLOCK_TRANSFER>);
|
||||
static_assert(AccessOrderLimits<A_BLOCK_TRANSFER.thread_cluster_order>);
|
||||
static_assert(AccessOrderLimits<B_BLOCK_TRANSFER.thread_cluster_order>);
|
||||
static_assert(AccessOrderLimits<A_BLOCK_TRANSFER.src_access_order>);
|
||||
static_assert(AccessOrderLimits<B_BLOCK_TRANSFER.src_access_order>);
|
||||
static_assert(AccessOrderLimits3D<A_BLOCK_TRANSFER.thread_cluster_order>);
|
||||
static_assert(AccessOrderLimits3D<B_BLOCK_TRANSFER.thread_cluster_order>);
|
||||
static_assert(AccessOrderLimits3D<A_BLOCK_TRANSFER.src_access_order>);
|
||||
static_assert(AccessOrderLimits3D<B_BLOCK_TRANSFER.src_access_order>);
|
||||
|
||||
// The forward convolution kernel class instance.
|
||||
using Instance = ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<
|
||||
SPATIAL_DIM,
|
||||
typename Layouts::ALayout,
|
||||
typename Layouts::BLayout,
|
||||
typename Layouts::InLayout,
|
||||
typename Layouts::WeiLayout,
|
||||
typename Layouts::DsLayout,
|
||||
typename Layouts::ELayout,
|
||||
typename Types::ADataType,
|
||||
typename Types::BDataType,
|
||||
typename Layouts::OutLayout,
|
||||
typename Types::InDataType,
|
||||
typename Types::WeiDataType,
|
||||
typename Types::AccDataType,
|
||||
typename Types::CShuffleDataType,
|
||||
typename Types::DsDataTypes,
|
||||
typename Types::EDataType,
|
||||
typename Ops::AElementwiseOp,
|
||||
typename Ops::BElementwiseOp,
|
||||
typename Ops::CDEElementwiseOp,
|
||||
typename Types::OutComputeType,
|
||||
typename Types::DsDataType,
|
||||
typename Types::OutDataType,
|
||||
typename Ops::InElementwiseOp,
|
||||
typename Ops::WeiElementwiseOp,
|
||||
typename Ops::OutElementwiseOp,
|
||||
SPECIALIZATION.conv_spec,
|
||||
SPECIALIZATION.gemm_spec,
|
||||
ALGORITHM.num_gemm_k_prefetch_stages,
|
||||
@@ -80,10 +81,10 @@ struct ConvFwdXdlFactory
|
||||
BLOCK.per_block.k,
|
||||
GRIDWISE_GEMM.ak1,
|
||||
GRIDWISE_GEMM.bk1,
|
||||
GRIDWISE_GEMM.m_per_xdl,
|
||||
GRIDWISE_GEMM.n_per_xdl,
|
||||
GRIDWISE_GEMM.m_xdl_per_wave,
|
||||
GRIDWISE_GEMM.n_xdl_per_wave,
|
||||
XDL_PARAMS.m_per_xdl,
|
||||
XDL_PARAMS.n_per_xdl,
|
||||
XDL_PARAMS.m_xdl_per_wave,
|
||||
XDL_PARAMS.n_xdl_per_wave,
|
||||
to_sequence_v<A_BLOCK_TRANSFER.thread_cluster_dims>,
|
||||
to_sequence_v<A_BLOCK_TRANSFER.thread_cluster_order>,
|
||||
to_sequence_v<A_BLOCK_TRANSFER.src_access_order>,
|
||||
@@ -102,10 +103,10 @@ struct ConvFwdXdlFactory
|
||||
C_BLOCK_TRANSFER.n_xdl_per_wave_per_shuffle,
|
||||
to_sequence_v<C_BLOCK_TRANSFER.thread_cluster_dims>,
|
||||
C_BLOCK_TRANSFER.scalar_per_vector,
|
||||
typename Types::AComputeType,
|
||||
typename Types::BComputeType,
|
||||
typename Types::InComputeType,
|
||||
typename Types::WeiComputeType,
|
||||
LOOP_SCHEDULER,
|
||||
ALGORITHM.num_groups_to_merge>;
|
||||
ALGORITHM.num_conv_groups_to_merge>;
|
||||
};
|
||||
|
||||
} // namespace ck_tile::builder::factory
|
||||
|
||||
@@ -10,27 +10,28 @@
|
||||
namespace ck_tile::builder::factory::internal {
|
||||
|
||||
// Block transfer parameters for A or B tensor.
|
||||
template <size_t ThreadClusterRank = 3>
|
||||
struct BlockTransfer
|
||||
{
|
||||
ck::Array<size_t, 3> thread_cluster_dims = {0, 0, 0}; // k0, m, k1
|
||||
ck::Array<size_t, 3> thread_cluster_order = {0, 0, 0};
|
||||
ck::Array<size_t, 3> src_access_order = {0, 0, 0};
|
||||
size_t src_vector_dim = 0;
|
||||
size_t src_scalar_per_vector = 0;
|
||||
size_t lds_dst_scalar_per_vector = 0;
|
||||
bool is_direct_load = false;
|
||||
bool lds_padding = false;
|
||||
ck::Array<size_t, ThreadClusterRank> thread_cluster_dims{};
|
||||
ck::Array<size_t, ThreadClusterRank> thread_cluster_order{};
|
||||
ck::Array<size_t, ThreadClusterRank> src_access_order{};
|
||||
size_t src_vector_dim = 0;
|
||||
size_t src_scalar_per_vector = 0;
|
||||
size_t lds_dst_scalar_per_vector = 0;
|
||||
bool is_direct_load = false;
|
||||
bool lds_padding = false;
|
||||
};
|
||||
|
||||
template <auto TRANSFER>
|
||||
constexpr BlockTransfer SetFwdConvBlockTransfer()
|
||||
constexpr BlockTransfer<> SetFwdConvBlockTransfer()
|
||||
{
|
||||
auto& block_xfer = TRANSFER.block_transfer;
|
||||
auto& block_order = TRANSFER.block_transfer_access_order;
|
||||
auto& src_order = TRANSFER.src_access_order;
|
||||
auto& lds_cfg = TRANSFER.lds_transfer;
|
||||
|
||||
return BlockTransfer{
|
||||
return BlockTransfer<>{
|
||||
.thread_cluster_dims = {block_xfer.k0, block_xfer.m_n, block_xfer.k1},
|
||||
.thread_cluster_order = {block_order.order[0], block_order.order[1], block_order.order[2]},
|
||||
.src_access_order = {src_order.order[0], src_order.order[1], src_order.order[2]},
|
||||
@@ -42,6 +43,59 @@ constexpr BlockTransfer SetFwdConvBlockTransfer()
|
||||
};
|
||||
}
|
||||
|
||||
template <auto TRANSFER>
|
||||
constexpr auto SetBwdConvBlockTransfer()
|
||||
{
|
||||
auto& block_xfer = TRANSFER.block_transfer;
|
||||
auto& block_order = TRANSFER.block_transfer_access_order;
|
||||
auto& src_order = TRANSFER.src_access_order;
|
||||
auto& lds_cfg = TRANSFER.lds_transfer;
|
||||
|
||||
constexpr auto array_length = block_order.order.size();
|
||||
static_assert(block_order.order.size() == src_order.order.size(),
|
||||
"Mismatched size between block order and src order");
|
||||
|
||||
if constexpr(array_length == 3)
|
||||
{
|
||||
return BlockTransfer<3>{
|
||||
.thread_cluster_dims = {block_xfer.k0, block_xfer.m_n, block_xfer.k1},
|
||||
.thread_cluster_order = {block_order.order[0],
|
||||
block_order.order[1],
|
||||
block_order.order[2]},
|
||||
.src_access_order = {src_order.order[0], src_order.order[1], src_order.order[2]},
|
||||
.src_vector_dim = lds_cfg.src_vector_dim,
|
||||
.src_scalar_per_vector = lds_cfg.src_scalar_per_vector,
|
||||
.lds_dst_scalar_per_vector = lds_cfg.lds_dst_scalar_per_vector,
|
||||
.lds_padding = lds_cfg.lds_padding,
|
||||
};
|
||||
}
|
||||
else if constexpr(array_length == 4)
|
||||
{
|
||||
return BlockTransfer<4>{
|
||||
.thread_cluster_dims = {block_xfer.k_batch_size,
|
||||
block_xfer.k0,
|
||||
block_xfer.m_n,
|
||||
block_xfer.k1},
|
||||
.thread_cluster_order = {block_order.order[0],
|
||||
block_order.order[1],
|
||||
block_order.order[2],
|
||||
block_order.order[3]},
|
||||
.src_access_order = {src_order.order[0],
|
||||
src_order.order[1],
|
||||
src_order.order[2],
|
||||
src_order.order[3]},
|
||||
.src_vector_dim = lds_cfg.src_vector_dim,
|
||||
.src_scalar_per_vector = lds_cfg.src_scalar_per_vector,
|
||||
.lds_dst_scalar_per_vector = lds_cfg.lds_dst_scalar_per_vector,
|
||||
.lds_padding = lds_cfg.lds_padding,
|
||||
};
|
||||
}
|
||||
else
|
||||
{
|
||||
static_assert(false, "Internal error: Unsupported array length");
|
||||
}
|
||||
}
|
||||
|
||||
// Block transfer parameters for C tensor.
|
||||
struct CBlockTransfer
|
||||
{
|
||||
|
||||
@@ -62,14 +62,15 @@ consteval auto GetElementwiseOp()
|
||||
}
|
||||
|
||||
template <auto Sig>
|
||||
struct ElementwiseOps
|
||||
struct ConvElementwiseOps
|
||||
{
|
||||
static constexpr auto input_op = GetElementwiseOp<Sig.input>();
|
||||
static constexpr auto weight_op = GetElementwiseOp<Sig.weight>();
|
||||
static constexpr auto output_op = GetElementwiseOp<Sig.output>();
|
||||
using AElementwiseOp = typename decltype(input_op)::Op;
|
||||
using BElementwiseOp = typename decltype(weight_op)::Op;
|
||||
using CDEElementwiseOp = typename decltype(output_op)::Op;
|
||||
|
||||
using InElementwiseOp = typename decltype(input_op)::Op;
|
||||
using WeiElementwiseOp = typename decltype(weight_op)::Op;
|
||||
using OutElementwiseOp = typename decltype(output_op)::Op;
|
||||
};
|
||||
|
||||
} // namespace ck_tile::builder::factory::internal
|
||||
|
||||
@@ -190,7 +190,7 @@ consteval auto GetAuxiliaryTensorLayoutTuple(std::index_sequence<Indices...>)
|
||||
decltype(TensorLayoutToCK<AuxiliaryTensorConfigsArray[Indices].layout>())...>{};
|
||||
}
|
||||
|
||||
template <auto AuxiliaryTensorConfigsValue, size_t SPATIAL_DIM, ConvDirection DIR>
|
||||
template <auto AuxiliaryTensorConfigsValue, size_t SPATIAL_DIM>
|
||||
requires(ConvSpatialDim<SPATIAL_DIM>)
|
||||
struct AuxiliaryTensorLayouts
|
||||
{
|
||||
@@ -200,34 +200,32 @@ struct AuxiliaryTensorLayouts
|
||||
};
|
||||
|
||||
// TODO: Currently only the ouput tensor can have auxiliary tensors (e.g., bias).
|
||||
template <auto Signature, size_t SPATIAL_DIM, ConvDirection DIR>
|
||||
template <auto Signature, size_t SPATIAL_DIM>
|
||||
requires(HasElementwiseOpWithAuxiliaryOperands<decltype(Signature.output)>)
|
||||
consteval auto GetAuxiliaryTensorLayouts()
|
||||
{
|
||||
return AuxiliaryTensorLayouts<Signature.output.operation.auxiliary_operand_configs,
|
||||
SPATIAL_DIM,
|
||||
DIR>{};
|
||||
SPATIAL_DIM>{};
|
||||
}
|
||||
|
||||
template <auto Signature, size_t SPATIAL_DIM, ConvDirection DIR>
|
||||
template <auto Signature, size_t SPATIAL_DIM>
|
||||
requires(!HasElementwiseOpWithAuxiliaryOperands<decltype(Signature.output)>)
|
||||
consteval auto GetAuxiliaryTensorLayouts()
|
||||
{
|
||||
return EmptyAuxiliaryTensorLayout{};
|
||||
}
|
||||
|
||||
template <auto Signature, size_t SPATIAL_DIM, ConvDirection DIR>
|
||||
template <auto Signature, size_t SPATIAL_DIM>
|
||||
requires(ConvSpatialDim<SPATIAL_DIM> &&
|
||||
ValidConvInputLayoutForSpatialDim<Signature.input.config.layout, SPATIAL_DIM> &&
|
||||
ValidConvWeightLayoutForSpatialDim<Signature.weight.config.layout, SPATIAL_DIM> &&
|
||||
ValidConvOutputLayoutForSpatialDim<Signature.output.config.layout, SPATIAL_DIM>)
|
||||
struct ConvTensorLayouts
|
||||
{
|
||||
static_assert(DIR == ConvDirection::FORWARD, "Only Forward convolution is supported.");
|
||||
using ALayout = decltype(TensorLayoutToCK<Signature.input.config.layout>());
|
||||
using BLayout = decltype(TensorLayoutToCK<Signature.weight.config.layout>());
|
||||
using ELayout = decltype(TensorLayoutToCK<Signature.output.config.layout>());
|
||||
using DsLayout = decltype(GetAuxiliaryTensorLayouts<Signature, SPATIAL_DIM, DIR>())::type;
|
||||
using InLayout = decltype(TensorLayoutToCK<Signature.input.config.layout>());
|
||||
using WeiLayout = decltype(TensorLayoutToCK<Signature.weight.config.layout>());
|
||||
using OutLayout = decltype(TensorLayoutToCK<Signature.output.config.layout>());
|
||||
using DsLayout = decltype(GetAuxiliaryTensorLayouts<Signature, SPATIAL_DIM>())::type;
|
||||
};
|
||||
|
||||
} // namespace ck_tile::builder::factory::internal
|
||||
|
||||
@@ -156,7 +156,7 @@ consteval auto GetAuxiliaryTensorDataTypes()
|
||||
}
|
||||
|
||||
template <auto Signature>
|
||||
struct FwdConvTensorDataTypes
|
||||
struct ConvTensorDataTypes
|
||||
{
|
||||
static constexpr auto input_types =
|
||||
GetTensorDataAndComputeTypes<Signature.input.config, Signature.data_type>();
|
||||
@@ -165,20 +165,17 @@ struct FwdConvTensorDataTypes
|
||||
static constexpr auto output_types =
|
||||
GetTensorDataAndComputeTypes<Signature.output.config, Signature.data_type>();
|
||||
|
||||
using ADataType = typename decltype(input_types.first)::type;
|
||||
using AComputeType = typename decltype(input_types.second)::type;
|
||||
using BDataType = typename decltype(weight_types.first)::type;
|
||||
using BComputeType = typename decltype(weight_types.second)::type;
|
||||
using InDataType = typename decltype(input_types.first)::type;
|
||||
using InComputeType = typename decltype(input_types.second)::type;
|
||||
using WeiDataType = typename decltype(weight_types.first)::type;
|
||||
using WeiComputeType = typename decltype(weight_types.second)::type;
|
||||
using OutDataType = typename decltype(output_types.first)::type;
|
||||
using OutComputeType = typename decltype(output_types.second)::type;
|
||||
using AccDataType =
|
||||
typename decltype(GetTensorAccumulationType<Signature.accumulation_data_type,
|
||||
Signature.data_type>())::type;
|
||||
using EDataType = typename decltype(output_types.first)::type;
|
||||
|
||||
// This is the "compute" type for output.
|
||||
using CShuffleDataType = typename decltype(output_types.second)::type;
|
||||
|
||||
// Data types for the auxiliary tensors (e.g., bias).
|
||||
using DsDataTypes = typename decltype(GetAuxiliaryTensorDataTypes<Signature>())::type;
|
||||
using DsDataType = typename decltype(GetAuxiliaryTensorDataTypes<Signature>())::type;
|
||||
};
|
||||
|
||||
} // namespace ck_tile::builder::factory::internal
|
||||
|
||||
@@ -4,6 +4,7 @@
|
||||
#pragma once
|
||||
|
||||
#include "ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/convolution_backward_weight_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/device_base.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp"
|
||||
@@ -37,7 +38,7 @@ struct BlockGemmSpec
|
||||
template <ConvAlgorithmDescriptor auto ALGORITHM>
|
||||
consteval BlockGemmSpec SetBlockGemm()
|
||||
{
|
||||
constexpr auto& BG = ALGORITHM.block_gemm;
|
||||
constexpr auto& BG = ALGORITHM.block_gemm_pipeline;
|
||||
|
||||
ck::BlockGemmPipelineScheduler scheduler;
|
||||
ck::BlockGemmPipelineVersion version;
|
||||
@@ -82,7 +83,7 @@ consteval ck::LoopScheduler SetLoopScheduler()
|
||||
template <ConvAlgorithmDescriptor auto ALGORITHM>
|
||||
consteval ck::PipelineVersion SetGridwiseGemmPipelineVersion()
|
||||
{
|
||||
constexpr auto pipeline_version = ALGORITHM.gridwise_gemm.pipeline_version;
|
||||
constexpr auto pipeline_version = ALGORITHM.pipeline_version;
|
||||
using ck_pipeline = ck::PipelineVersion;
|
||||
switch(pipeline_version)
|
||||
{
|
||||
@@ -149,12 +150,30 @@ consteval ck::tensor_operation::device::ConvolutionForwardSpecialization SetFwdC
|
||||
using ck_conv_spec = ck::tensor_operation::device::ConvolutionForwardSpecialization;
|
||||
switch(specialization)
|
||||
{
|
||||
case ConvFwdSpecialization::DEFAULT: return ck_conv_spec::Default;
|
||||
case ConvFwdSpecialization::FILTER_1X1_PAD0: return ck_conv_spec::Filter1x1Pad0;
|
||||
case ConvFwdSpecialization::FILTER_1X1_STRIDE1_PAD0: return ck_conv_spec::Filter1x1Stride1Pad0;
|
||||
case ConvFwdSpecialization::FILTER_3x3: return ck_conv_spec::Filter3x3;
|
||||
case ConvFwdSpecialization::ODD_C: return ck_conv_spec::OddC;
|
||||
default: throw "Unknown ConvFwdSpecialization";
|
||||
case ConvSpecialization::DEFAULT: return ck_conv_spec::Default;
|
||||
case ConvSpecialization::FILTER_1X1_PAD0: return ck_conv_spec::Filter1x1Pad0;
|
||||
case ConvSpecialization::FILTER_1X1_STRIDE1_PAD0: return ck_conv_spec::Filter1x1Stride1Pad0;
|
||||
case ConvSpecialization::FILTER_3x3: return ck_conv_spec::Filter3x3;
|
||||
case ConvSpecialization::ODD_C: return ck_conv_spec::OddC;
|
||||
default: throw "Unsupported ConvSpecialization";
|
||||
}
|
||||
}
|
||||
|
||||
template <ConvAlgorithmDescriptor auto ALGORITHM>
|
||||
consteval ck::tensor_operation::device::ConvolutionBackwardWeightSpecialization
|
||||
SetBwdWeightConvSpecialization()
|
||||
{
|
||||
constexpr auto specialization = ALGORITHM.bwd_weight_specialization;
|
||||
using ck_conv_spec = ck::tensor_operation::device::ConvolutionBackwardWeightSpecialization;
|
||||
switch(specialization)
|
||||
{
|
||||
case ConvSpecialization::DEFAULT: return ck_conv_spec::Default;
|
||||
case ConvSpecialization::FILTER_1X1_PAD0: return ck_conv_spec::Filter1x1Pad0;
|
||||
case ConvSpecialization::FILTER_1X1_STRIDE1_PAD0: return ck_conv_spec::Filter1x1Stride1Pad0;
|
||||
case ConvSpecialization::ODD_C: return ck_conv_spec::OddC;
|
||||
case ConvSpecialization::FILTER_3x3:
|
||||
throw "FILTER_3x3 is not supported for backward weight convolution.";
|
||||
default: throw "Unsupported ConvSpecialization";
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -26,11 +26,11 @@ struct ReferenceFactory
|
||||
static constexpr auto kValidation = (internal::ValidateReferenceSignature<SIGNATURE>(), 0);
|
||||
|
||||
static constexpr size_t SPATIAL_DIM = SIGNATURE.spatial_dim;
|
||||
using Types = internal::FwdConvTensorDataTypes<SIGNATURE>;
|
||||
using Types = internal::ConvTensorDataTypes<SIGNATURE>;
|
||||
|
||||
using InDataType = typename Types::ADataType;
|
||||
using WeiDataType = typename Types::BDataType;
|
||||
using OutDataType = typename Types::EDataType;
|
||||
using InDataType = typename Types::InDataType;
|
||||
using WeiDataType = typename Types::WeiDataType;
|
||||
using OutDataType = typename Types::OutDataType;
|
||||
|
||||
struct Instance
|
||||
{
|
||||
|
||||
@@ -63,10 +63,7 @@ struct GemmAlgorithmInfo
|
||||
OutputTileTransferInfo c_tile_transfer;
|
||||
builder::PipelineVersion pipeline_version;
|
||||
builder::PipelineScheduler pipeline_scheduler;
|
||||
std::variant<builder::ConvFwdSpecialization,
|
||||
builder::ConvBwdDataSpecialization,
|
||||
builder::ConvBwdWeightSpecialization>
|
||||
conv_specialization;
|
||||
builder::ConvSpecialization conv_specialization;
|
||||
builder::GemmPadding padding;
|
||||
};
|
||||
|
||||
|
||||
@@ -197,18 +197,16 @@ constexpr builder::ConvDirection conv_direction()
|
||||
|
||||
/// @brief Derives the convolution-specific specialization from a device kernel `Instance` type.
|
||||
/// @tparam Instance The device kernel instance type.
|
||||
/// @return A `builder::ConvFwdSpecialization`, `builder::ConvBwdDataSpecialization`, or
|
||||
/// `builder::ConvBwdWeightSpecialization` enum value.
|
||||
/// @return A `builder::ConvSpecialization` enum value.
|
||||
template <typename Instance>
|
||||
constexpr auto conv_spec()
|
||||
{
|
||||
using InstTraits = InstanceTraits<Instance>;
|
||||
using enum builder::ConvSpecialization;
|
||||
|
||||
if constexpr(requires { InstTraits::kConvForwardSpecialization; })
|
||||
{
|
||||
using enum ck::tensor_operation::device::ConvolutionForwardSpecialization;
|
||||
using enum builder::ConvFwdSpecialization;
|
||||
|
||||
switch(InstTraits::kConvForwardSpecialization)
|
||||
{
|
||||
case Default: return DEFAULT;
|
||||
@@ -221,8 +219,6 @@ constexpr auto conv_spec()
|
||||
else if constexpr(requires { InstTraits::kConvBwdDataSpecialization; })
|
||||
{
|
||||
using enum ck::tensor_operation::device::ConvolutionBackwardDataSpecialization;
|
||||
using enum builder::ConvBwdDataSpecialization;
|
||||
|
||||
switch(InstTraits::kConvBwdDataSpecialization)
|
||||
{
|
||||
case Default: return DEFAULT;
|
||||
@@ -232,8 +228,6 @@ constexpr auto conv_spec()
|
||||
else if constexpr(requires { InstTraits::kConvBwdWeightSpecialization; })
|
||||
{
|
||||
using enum ck::tensor_operation::device::ConvolutionBackwardWeightSpecialization;
|
||||
using enum builder::ConvBwdWeightSpecialization;
|
||||
|
||||
switch(InstTraits::kConvBwdWeightSpecialization)
|
||||
{
|
||||
case Default: return DEFAULT;
|
||||
|
||||
@@ -35,10 +35,10 @@ struct ReferenceCommonTraits
|
||||
typename builder::factory::internal::LayoutToCK<SIGNATURE.output.config.layout>::type;
|
||||
|
||||
// Data types - extract from factory's type helper
|
||||
using Types = builder::factory::internal::FwdConvTensorDataTypes<SIGNATURE>;
|
||||
using ADataType = typename Types::ADataType;
|
||||
using BDataType = typename Types::BDataType;
|
||||
using EDataType = typename Types::EDataType;
|
||||
using Types = builder::factory::internal::ConvTensorDataTypes<SIGNATURE>;
|
||||
using ADataType = typename Types::InDataType;
|
||||
using BDataType = typename Types::WeiDataType;
|
||||
using EDataType = typename Types::OutDataType;
|
||||
using AccDataType = float; // Reference uses float accumulation
|
||||
|
||||
// Elementwise operations - reference only supports PassThrough
|
||||
|
||||
@@ -72,11 +72,10 @@ struct Args<SIGNATURE>
|
||||
using OutputDescriptor = TensorDescriptor<OUTPUT_TYPE, OUTPUT_RANK>;
|
||||
|
||||
// TODO: We shouldn't need to call into an internal namespace here.
|
||||
using Ops = factory::internal::ElementwiseOps<SIGNATURE>;
|
||||
using Ops = factory::internal::ConvElementwiseOps<SIGNATURE>;
|
||||
|
||||
// TODO: We shouldn't need to call into an internal namespace here.
|
||||
using Layouts =
|
||||
factory::internal::ConvTensorLayouts<SIGNATURE, SPATIAL_DIM, ConvDirection::FORWARD>;
|
||||
using Layouts = factory::internal::ConvTensorLayouts<SIGNATURE, SPATIAL_DIM>;
|
||||
|
||||
ConvTensorLengths<SPATIAL_DIM> lengths;
|
||||
|
||||
@@ -90,9 +89,9 @@ struct Args<SIGNATURE>
|
||||
FilterExtent<SPATIAL_DIM> input_left_pad;
|
||||
FilterExtent<SPATIAL_DIM> input_right_pad;
|
||||
|
||||
Ops::AElementwiseOp a_elementwise_op;
|
||||
Ops::BElementwiseOp b_elementwise_op;
|
||||
Ops::CDEElementwiseOp cde_elementwise_op;
|
||||
Ops::InElementwiseOp a_elementwise_op;
|
||||
Ops::WeiElementwiseOp b_elementwise_op;
|
||||
Ops::OutElementwiseOp cde_elementwise_op;
|
||||
|
||||
/// This function returns the `TensorDescriptor` corresponding to
|
||||
/// the input-tensor of the convolution problem. This can then
|
||||
@@ -107,7 +106,7 @@ struct Args<SIGNATURE>
|
||||
// function.
|
||||
const auto param = to_ck_conv_param();
|
||||
const auto desc = ck::utils::conv::make_input_host_tensor_descriptor_g_n_c_wis_packed<
|
||||
typename Layouts::ALayout>(param);
|
||||
typename Layouts::InLayout>(param);
|
||||
using Extent = typename InputDescriptor::Extent;
|
||||
return InputDescriptor(Extent::from_vector(desc.GetLengths()),
|
||||
Extent::from_vector(desc.GetStrides()));
|
||||
@@ -121,7 +120,7 @@ struct Args<SIGNATURE>
|
||||
// See note in implementation of `make_input_descriptor`.
|
||||
const auto param = to_ck_conv_param();
|
||||
const auto desc = ck::utils::conv::make_weight_host_tensor_descriptor_g_k_c_xs_packed<
|
||||
typename Layouts::BLayout>(param);
|
||||
typename Layouts::WeiLayout>(param);
|
||||
using Extent = typename WeightDescriptor::Extent;
|
||||
return WeightDescriptor(Extent::from_vector(desc.GetLengths()),
|
||||
Extent::from_vector(desc.GetStrides()));
|
||||
@@ -135,7 +134,7 @@ struct Args<SIGNATURE>
|
||||
// See note in implementation of `make_input_descriptor`.
|
||||
const auto param = to_ck_conv_param();
|
||||
const auto desc = ck::utils::conv::make_output_host_tensor_descriptor_g_n_k_wos_packed<
|
||||
typename Layouts::ELayout>(param);
|
||||
typename Layouts::OutLayout>(param);
|
||||
using Extent = typename OutputDescriptor::Extent;
|
||||
return OutputDescriptor(Extent::from_vector(desc.GetLengths()),
|
||||
Extent::from_vector(desc.GetStrides()));
|
||||
|
||||
@@ -27,7 +27,7 @@ template <typename Conv,
|
||||
auto SIGNATURE,
|
||||
size_t SPATIAL_DIM = SIGNATURE.spatial_dim,
|
||||
// TODO: We shouldn't need to call into an internal namespace here.
|
||||
typename Ops = factory::internal::ElementwiseOps<SIGNATURE>>
|
||||
typename Ops = factory::internal::ConvElementwiseOps<SIGNATURE>>
|
||||
concept CkConvInstance = requires(Conv& conv,
|
||||
// TODO: This should be changed depending on IsMultiA etc.
|
||||
// Currently that is not yet supported elsewhere anyway.
|
||||
@@ -37,9 +37,9 @@ concept CkConvInstance = requires(Conv& conv,
|
||||
std::array<index_t, SPATIAL_DIM + 3> lengths,
|
||||
std::array<index_t, SPATIAL_DIM + 3> strides,
|
||||
std::array<index_t, SPATIAL_DIM> filter,
|
||||
Ops::AElementwiseOp elementwise_a,
|
||||
Ops::BElementwiseOp elementwise_b,
|
||||
Ops::CDEElementwiseOp elementwise_cde) {
|
||||
Ops::InElementwiseOp elementwise_a,
|
||||
Ops::WeiElementwiseOp elementwise_b,
|
||||
Ops::OutElementwiseOp elementwise_cde) {
|
||||
{
|
||||
conv.MakeArgument(p_a,
|
||||
p_b,
|
||||
|
||||
@@ -192,8 +192,8 @@ enum class TileConvSpecialization
|
||||
FILTER_3x3
|
||||
};
|
||||
|
||||
// Enums for the forward convolution specialization.
|
||||
enum class ConvFwdSpecialization
|
||||
// Enums for the convolution specializations.
|
||||
enum class ConvSpecialization
|
||||
{
|
||||
DEFAULT,
|
||||
FILTER_1X1_PAD0,
|
||||
@@ -202,22 +202,6 @@ enum class ConvFwdSpecialization
|
||||
ODD_C
|
||||
};
|
||||
|
||||
// Enums for the backward data convolution specialization.
|
||||
enum class ConvBwdDataSpecialization
|
||||
{
|
||||
DEFAULT,
|
||||
FILTER_1X1_STRIDE1_PAD0,
|
||||
};
|
||||
|
||||
// Enums for the backward weight convolution specialization.
|
||||
enum class ConvBwdWeightSpecialization
|
||||
{
|
||||
DEFAULT,
|
||||
FILTER_1X1_STRIDE1_PAD0,
|
||||
FILTER_1X1_PAD0,
|
||||
ODD_C,
|
||||
};
|
||||
|
||||
// Enums for the Gemm padding.
|
||||
enum class GemmPadding
|
||||
{
|
||||
@@ -249,7 +233,9 @@ enum class PipelineScheduler
|
||||
enum class ConvAlgorithmSpecialization
|
||||
{
|
||||
LARGE_TENSOR,
|
||||
REFERENCE // GPU reference implementation for validation
|
||||
REFERENCE, // GPU reference implementation for validation,
|
||||
TWO_STAGE,
|
||||
MULTIPLE_D
|
||||
};
|
||||
|
||||
// to_string methods for enum classes
|
||||
@@ -372,9 +358,9 @@ inline std::string_view to_string(GemmSpecialization spec)
|
||||
}
|
||||
}
|
||||
|
||||
inline std::string_view to_string(ConvFwdSpecialization spec)
|
||||
inline std::string_view to_string(ConvSpecialization spec)
|
||||
{
|
||||
using enum ConvFwdSpecialization;
|
||||
using enum ConvSpecialization;
|
||||
switch(spec)
|
||||
{
|
||||
case DEFAULT: return "DEFAULT";
|
||||
@@ -386,30 +372,6 @@ inline std::string_view to_string(ConvFwdSpecialization spec)
|
||||
}
|
||||
}
|
||||
|
||||
inline std::string_view to_string(ConvBwdDataSpecialization spec)
|
||||
{
|
||||
using enum ConvBwdDataSpecialization;
|
||||
switch(spec)
|
||||
{
|
||||
case DEFAULT: return "DEFAULT";
|
||||
case FILTER_1X1_STRIDE1_PAD0: return "FILTER_1X1_STRIDE1_PAD0";
|
||||
default: return "Unknown";
|
||||
}
|
||||
}
|
||||
|
||||
inline std::string_view to_string(ConvBwdWeightSpecialization spec)
|
||||
{
|
||||
using enum ConvBwdWeightSpecialization;
|
||||
switch(spec)
|
||||
{
|
||||
case DEFAULT: return "DEFAULT";
|
||||
case FILTER_1X1_STRIDE1_PAD0: return "FILTER_1X1_STRIDE1_PAD0";
|
||||
case FILTER_1X1_PAD0: return "FILTER_1X1_PAD0";
|
||||
case ODD_C: return "ODD_C";
|
||||
default: return "Unknown";
|
||||
}
|
||||
}
|
||||
|
||||
inline std::string_view to_string(GemmPadding padding)
|
||||
{
|
||||
using enum GemmPadding;
|
||||
@@ -525,17 +487,7 @@ inline std::ostream& operator<<(std::ostream& os, GemmSpecialization spec)
|
||||
return os << to_string(spec);
|
||||
}
|
||||
|
||||
inline std::ostream& operator<<(std::ostream& os, ConvFwdSpecialization spec)
|
||||
{
|
||||
return os << to_string(spec);
|
||||
}
|
||||
|
||||
inline std::ostream& operator<<(std::ostream& os, ConvBwdDataSpecialization spec)
|
||||
{
|
||||
return os << to_string(spec);
|
||||
}
|
||||
|
||||
inline std::ostream& operator<<(std::ostream& os, ConvBwdWeightSpecialization spec)
|
||||
inline std::ostream& operator<<(std::ostream& os, ConvSpecialization spec)
|
||||
{
|
||||
return os << to_string(spec);
|
||||
}
|
||||
@@ -555,14 +507,4 @@ inline std::ostream& operator<<(std::ostream& os, TensorLayout layout)
|
||||
return os << to_string(layout);
|
||||
}
|
||||
|
||||
// ostream operator overload for std::variant of convolution specializations
|
||||
inline std::ostream& operator<<(std::ostream& os,
|
||||
const std::variant<ConvFwdSpecialization,
|
||||
ConvBwdDataSpecialization,
|
||||
ConvBwdWeightSpecialization>& spec)
|
||||
{
|
||||
std::visit([&os](const auto& s) { os << s; }, spec);
|
||||
return os;
|
||||
}
|
||||
|
||||
} // namespace ck_tile::builder
|
||||
|
||||
@@ -124,7 +124,7 @@ add_ck_builder_test(test_ckb_conv_description
|
||||
# Verifies that GetInstanceString() methods and other functions produce valid kernel code.
|
||||
# Tests various convolution types:
|
||||
# - Group convolution (v3, standard, large tensor, WMMA, DL variants)
|
||||
# - Backward weight group convolution (XDL)
|
||||
# - Backward weight group convolution (XDL standard, XDL v3, WMMA, DL, multiple D, two-stage variants)
|
||||
# Requires kernel compilation to validate the generated strings through the base class.
|
||||
|
||||
set(INSTANCE_STRING_TESTS
|
||||
@@ -167,10 +167,35 @@ add_ck_builder_test(test_ckb_build_fwd_instances
|
||||
conv/ck/test_ckb_conv_fwd_3d_fp16.cpp
|
||||
conv/ck/test_ckb_conv_fwd_3d_fp32.cpp
|
||||
conv/ck_tile/test_ckb_conv_fwd_2d_fp16_v3.cpp
|
||||
conv/ck_tile/test_ckb_conv_bwd_weight_2d_fp16_v3.cpp
|
||||
conv/ck_tile/test_ckb_conv_bwd_data_2d_fp16_v3.cpp)
|
||||
)
|
||||
target_link_libraries(test_ckb_build_fwd_instances PRIVATE utility)
|
||||
|
||||
set(BWD_WEIGHT_TESTS
|
||||
conv/ck/test_ckb_conv_bwd_weight_xdl_cshuffle.cpp
|
||||
conv/ck/test_ckb_conv_bwd_weight_two_stage_xdl_cshuffle.cpp
|
||||
conv/ck/test_ckb_conv_bwd_weight_multi_d_xdl_cshuffle.cpp
|
||||
conv/ck/test_ckb_conv_bwd_weight_xdl_cshuffle_v3.cpp
|
||||
conv/ck/test_ckb_conv_bwd_weight_dl.cpp
|
||||
conv/ck_tile/test_ckb_conv_bwd_weight_2d_fp16_v3.cpp
|
||||
)
|
||||
|
||||
if (CK_USE_WMMA)
|
||||
list(APPEND BWD_WEIGHT_TESTS
|
||||
conv/ck/test_ckb_conv_bwd_weight_wmma_cshuffle_v3.cpp
|
||||
conv/ck/test_ckb_conv_bwd_weight_wmma_cshuffle.cpp
|
||||
conv/ck/test_ckb_conv_bwd_weight_two_stage_wmma_cshuffle_v3.cpp
|
||||
conv/ck/test_ckb_conv_bwd_weight_multi_d_wmma_cshuffle_v3.cpp
|
||||
)
|
||||
endif()
|
||||
|
||||
add_ck_builder_test(test_ckb_build_bwd_weight_instances ${BWD_WEIGHT_TESTS})
|
||||
target_link_libraries(test_ckb_build_bwd_weight_instances PRIVATE utility)
|
||||
|
||||
add_ck_builder_test(test_ckb_build_bwd_data_instances
|
||||
conv/ck_tile/test_ckb_conv_bwd_data_2d_fp16_v3.cpp
|
||||
)
|
||||
target_link_libraries(test_ckb_build_bwd_data_instances PRIVATE utility)
|
||||
|
||||
|
||||
################################################################################
|
||||
# FACTORY TESTS - Expensive Regression Tests (Full MIOpen Kernel Set)
|
||||
@@ -224,6 +249,8 @@ endforeach()
|
||||
set(CKB_REGRESSION_TESTS
|
||||
test_ckb_instance_string
|
||||
test_ckb_build_fwd_instances
|
||||
test_ckb_build_bwd_weight_instances
|
||||
test_ckb_build_bwd_data_instances
|
||||
test_ckb_testing_utils
|
||||
# test_ckb_factory_grouped_convolution_forward_convscale
|
||||
# test_ckb_factory_grouped_convolution_forward_scaleadd_ab
|
||||
|
||||
@@ -0,0 +1,40 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "utils/ckb_conv_test_configs.hpp"
|
||||
#include "utils/ckb_conv_test_utils.hpp"
|
||||
#include "utils/conv_algorithm_type_utils.hpp"
|
||||
|
||||
namespace ckb = ck_tile::builder;
|
||||
namespace ckt = ck_tile::builder::test;
|
||||
namespace cku = ck_tile::builder::test_utils;
|
||||
|
||||
constexpr auto SIGNATURE =
|
||||
ckt::ConvSignature{.spatial_dim = 2,
|
||||
.direction = ckb::ConvDirection::BACKWARD_WEIGHT,
|
||||
.data_type = ckb::DataType::BF16,
|
||||
.accumulation_data_type = ckb::DataType::FP32,
|
||||
.input = {.config = {.layout = ckb::TensorLayout::GNHWC}},
|
||||
.weight = {.config = {.layout = ckb::TensorLayout::GKYXC}},
|
||||
.output = {.config = {.layout = ckb::TensorLayout::GNHWK}}};
|
||||
|
||||
constexpr auto ALGORITHM = cku::ConvAlgorithm_DeviceGroupedConvBwdWeight_Dl{}
|
||||
.with_thread_block(cku::ThreadBlock_256_128x128x16)
|
||||
.with_bwd_specialization(cku::ConvSpecialization::DEFAULT)
|
||||
.with_dl_thread_config(cku::DlThreadConfig_16x1x4x4x1)
|
||||
.with_dl_thread_cluster(cku::DlThreadCluster_8x2)
|
||||
.with_dl_transfer(cku::DlTransfer5D);
|
||||
|
||||
using Builder = ckb::ConvBuilder<SIGNATURE, ALGORITHM>;
|
||||
using Instance = Builder::Instance;
|
||||
|
||||
TEST(BwdWeight_2DBf16_DL, Create)
|
||||
{
|
||||
const auto expected_transfer_parameters = to_string(ALGORITHM);
|
||||
std::cout << "Expected Transfer Parameters: " << expected_transfer_parameters << std::endl;
|
||||
cku::run_test<Builder>({"DeviceGroupedConvBwdWeight_Dl",
|
||||
expected_transfer_parameters,
|
||||
"Default",
|
||||
"GNHWC,GKYXC,GNHWK",
|
||||
"PassThrough,PassThrough,PassThrough"});
|
||||
}
|
||||
@@ -0,0 +1,42 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "utils/ckb_conv_test_configs.hpp"
|
||||
#include "utils/ckb_conv_test_utils.hpp"
|
||||
#include "utils/conv_algorithm_type_utils.hpp"
|
||||
#include "ck_tile/host/device_prop.hpp"
|
||||
|
||||
namespace ckb = ck_tile::builder;
|
||||
namespace ckt = ck_tile::builder::test;
|
||||
namespace cku = ck_tile::builder::test_utils;
|
||||
|
||||
constexpr auto SIGNATURE =
|
||||
ckt::ConvSignature{.spatial_dim = 3,
|
||||
.direction = ckb::ConvDirection::BACKWARD_WEIGHT,
|
||||
.data_type = ckb::DataType::FP16,
|
||||
.accumulation_data_type = ckb::DataType::FP32,
|
||||
.input = {.config = {.layout = ckb::TensorLayout::GNDHWC}},
|
||||
.weight = {.config = {.layout = ckb::TensorLayout::GKZYXC}},
|
||||
.output = {.config = {.layout = ckb::TensorLayout::GNDHWK}}};
|
||||
|
||||
constexpr auto ALGORITHM = cku::ConvAlgorithm_DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffle_V3{}
|
||||
.with_thread_block(cku::ThreadBlock_64_32x32x32)
|
||||
.with_gemm_config(cku::GemmParams_Wmma_16x16_2x1_per_wave)
|
||||
.with_transfer(cku::BwdTransfer_4x8x1_4x16x1_v3)
|
||||
.with_bwd_specialization(ckb::ConvSpecialization::DEFAULT)
|
||||
.with_block_gemm(cku::BlockGemmDesc_v1_intrawave);
|
||||
|
||||
using Builder = ckb::ConvBuilder<SIGNATURE, ALGORITHM>;
|
||||
using Instance = Builder::Instance;
|
||||
|
||||
TEST(BwdWeight_3DFp16_MultiD_Wmma_ShuffleV3_GNHWC, Create)
|
||||
{
|
||||
const auto expected_transfer_parameters = to_string(ALGORITHM);
|
||||
std::cout << "Expected Transfer Parameters: " << expected_transfer_parameters << std::endl;
|
||||
cku::run_test<Builder>({"DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3",
|
||||
expected_transfer_parameters,
|
||||
"Default",
|
||||
"GNDHWC,GKZYXC,GNDHWK",
|
||||
"PassThrough,PassThrough,PassThrough",
|
||||
"fp16,fp16>"}); // check compute types
|
||||
}
|
||||
@@ -0,0 +1,41 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "utils/ckb_conv_test_configs.hpp"
|
||||
#include "utils/ckb_conv_test_utils.hpp"
|
||||
#include "utils/conv_algorithm_type_utils.hpp"
|
||||
#include "ck_tile/host/device_prop.hpp"
|
||||
|
||||
namespace ckb = ck_tile::builder;
|
||||
namespace ckt = ck_tile::builder::test;
|
||||
namespace cku = ck_tile::builder::test_utils;
|
||||
|
||||
constexpr auto SIGNATURE =
|
||||
ckt::ConvSignature{.spatial_dim = 2,
|
||||
.direction = ckb::ConvDirection::BACKWARD_WEIGHT,
|
||||
.data_type = ckb::DataType::FP16,
|
||||
.accumulation_data_type = ckb::DataType::FP32,
|
||||
.input = {.config = {.layout = ckb::TensorLayout::GNHWC}},
|
||||
.weight = {.config = {.layout = ckb::TensorLayout::GKYXC}},
|
||||
.output = {.config = {.layout = ckb::TensorLayout::GNHWK}}};
|
||||
|
||||
constexpr auto ALGORITHM = cku::ConvAlgorithm_DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle{}
|
||||
.with_thread_block(cku::ThreadBlock_256_128x128x8)
|
||||
.with_gemm_config(cku::BwdGemmParams_Xdl_4x4_per_wave)
|
||||
.with_transfer(cku::BwdTransfer_4x64x1)
|
||||
.with_bwd_specialization(ckb::ConvSpecialization::DEFAULT);
|
||||
|
||||
using Builder = ckb::ConvBuilder<SIGNATURE, ALGORITHM>;
|
||||
using Instance = Builder::Instance;
|
||||
|
||||
TEST(BwdWeight_2DFp16_MultiD_CShuffle_GNHWC, Create)
|
||||
{
|
||||
const auto expected_transfer_parameters = to_string(ALGORITHM);
|
||||
std::cout << "Expected Transfer Parameters: " << expected_transfer_parameters << std::endl;
|
||||
cku::run_test<Builder>({"DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle",
|
||||
expected_transfer_parameters,
|
||||
"Default",
|
||||
"GNHWC,GKYXC,GNHWK",
|
||||
"PassThrough,PassThrough,PassThrough",
|
||||
"fp16,fp16>"}); // check compute types
|
||||
}
|
||||
@@ -0,0 +1,46 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "utils/ckb_conv_test_configs.hpp"
|
||||
#include "utils/ckb_conv_test_utils.hpp"
|
||||
#include "utils/conv_algorithm_type_utils.hpp"
|
||||
#include "ck_tile/host/device_prop.hpp"
|
||||
|
||||
namespace ckb = ck_tile::builder;
|
||||
namespace ckt = ck_tile::builder::test;
|
||||
namespace cku = ck_tile::builder::test_utils;
|
||||
using enum ck_tile::builder::TensorLayout;
|
||||
|
||||
constexpr auto SIGNATURE = ckt::ConvSignature{.spatial_dim = 2,
|
||||
.direction = ckb::ConvDirection::BACKWARD_WEIGHT,
|
||||
.data_type = ckb::DataType::FP16,
|
||||
.accumulation_data_type = ckb::DataType::FP32,
|
||||
.input = {.config = {.layout = NGCHW}},
|
||||
.weight = {.config = {.layout = GKYXC}},
|
||||
.output = {.config = {.layout = NGKHW}}};
|
||||
|
||||
constexpr auto ALGORITHM = cku::ConvAlgorithm_DeviceGroupedConvBwdWeight_TwoStage_Wmma_CShuffle_V3{}
|
||||
.with_thread_block(cku::ThreadBlock_64_32x32x32)
|
||||
.with_gemm_config(cku::GemmParams_Wmma_16x16_2x1_per_wave)
|
||||
.with_transfer(cku::BwdTransfer_4x8x1_4x16x1_v3)
|
||||
.with_bwd_specialization(ckb::ConvSpecialization::DEFAULT)
|
||||
.with_block_gemm(cku::BlockGemmDesc_v1_intrawave)
|
||||
.with_num_conv_groups_to_merge(2)
|
||||
.with_transpose_params(2, 2);
|
||||
|
||||
using Builder = ckb::ConvBuilder<SIGNATURE, ALGORITHM>;
|
||||
using Instance = Builder::Instance;
|
||||
|
||||
TEST(BwdWeight_2DFp16_TwoStage_Wmma_CShuffle_V3, Create)
|
||||
{
|
||||
const auto expected_transfer_parameters = to_string(ALGORITHM);
|
||||
std::cout << "Expected Transfer Parameters: " << expected_transfer_parameters << std::endl;
|
||||
cku::run_test<Builder>({"DeviceGroupedConvBwdWeightTwoStage_Wmma_CShuffleV3",
|
||||
expected_transfer_parameters,
|
||||
"Default",
|
||||
"NGCHW,GKYXC,NGKHW",
|
||||
"PassThrough,PassThrough,PassThrough",
|
||||
"Intrawave",
|
||||
"v1",
|
||||
"fp16,fp16,2,2>"}); // Check compute types and transpose params.
|
||||
}
|
||||
@@ -0,0 +1,44 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "utils/ckb_conv_test_configs.hpp"
|
||||
#include "utils/ckb_conv_test_utils.hpp"
|
||||
#include "utils/conv_algorithm_type_utils.hpp"
|
||||
#include "ck_tile/host/device_prop.hpp"
|
||||
|
||||
namespace ckb = ck_tile::builder;
|
||||
namespace ckt = ck_tile::builder::test;
|
||||
namespace cku = ck_tile::builder::test_utils;
|
||||
|
||||
constexpr auto SIGNATURE =
|
||||
ckt::ConvSignature{.spatial_dim = 2,
|
||||
.direction = ckb::ConvDirection::BACKWARD_WEIGHT,
|
||||
.data_type = ckb::DataType::BF16,
|
||||
.accumulation_data_type = ckb::DataType::FP32,
|
||||
.input = {.config = {.layout = ckb::TensorLayout::GNHWC}},
|
||||
.weight = {.config = {.layout = ckb::TensorLayout::GKYXC}},
|
||||
.output = {.config = {.layout = ckb::TensorLayout::GNHWK}}};
|
||||
|
||||
constexpr auto ALGORITHM = cku::ConvAlgorithm_DeviceGroupedConvBwdWeight_TwoStage_Xdl_CShuffle{}
|
||||
.with_thread_block(cku::ThreadBlock_64_32x32x32)
|
||||
.with_gemm_config(cku::BwdGemmParams_Xdl_1x1_per_wave)
|
||||
.with_transfer(cku::BwdTransfer_4x8x1_4x16x1_v3)
|
||||
.with_bwd_specialization(ckb::ConvSpecialization::DEFAULT)
|
||||
.with_block_gemm(cku::BlockGemmDesc_v2_intrawave)
|
||||
.with_num_conv_groups_to_merge(2)
|
||||
.with_transpose_params(2, 4);
|
||||
|
||||
using Builder = ckb::ConvBuilder<SIGNATURE, ALGORITHM>;
|
||||
using Instance = Builder::Instance;
|
||||
|
||||
TEST(BwdWeight_2DBf16_TwoStage_CShuffle, Create)
|
||||
{
|
||||
const auto expected_transfer_parameters = to_string(ALGORITHM);
|
||||
cku::run_test<Builder>({"DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle",
|
||||
expected_transfer_parameters,
|
||||
"Default",
|
||||
"GNHWC,GKYXC,GNHWK",
|
||||
"PassThrough,PassThrough,PassThrough",
|
||||
"Intrawave,v2", // pipeline versions
|
||||
"bf16,bf16,2,4>"}); // compute types and transpose params
|
||||
}
|
||||
@@ -0,0 +1,43 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "utils/ckb_conv_test_configs.hpp"
|
||||
#include "utils/ckb_conv_test_utils.hpp"
|
||||
#include "utils/conv_algorithm_type_utils.hpp"
|
||||
#include "ck_tile/host/device_prop.hpp"
|
||||
|
||||
namespace ckb = ck_tile::builder;
|
||||
namespace ckt = ck_tile::builder::test;
|
||||
namespace cku = ck_tile::builder::test_utils;
|
||||
using enum ck_tile::builder::TensorLayout;
|
||||
|
||||
constexpr auto SIGNATURE = ckt::ConvSignature{.spatial_dim = 3,
|
||||
.direction = ckb::ConvDirection::BACKWARD_WEIGHT,
|
||||
.data_type = ckb::DataType::BF16,
|
||||
.accumulation_data_type = ckb::DataType::FP32,
|
||||
.input = {.config = {.layout = NGCDHW}},
|
||||
.weight = {.config = {.layout = GKZYXC}},
|
||||
.output = {.config = {.layout = NGKDHW}}};
|
||||
|
||||
constexpr auto ALGORITHM = cku::ConvAlgorithm_DeviceGroupedConvBwdWeight_Wmma_CShuffle{}
|
||||
.with_thread_block(cku::ThreadBlock_64_32x32x32)
|
||||
.with_gemm_config(cku::GemmParams_Wmma_16x16_2x1_per_wave)
|
||||
.with_transfer(cku::BwdTransfer_4x8x1_4x16x1_v3)
|
||||
.with_bwd_specialization(ckb::ConvSpecialization::DEFAULT)
|
||||
.with_prefetch_config(1, ckb::PipelineScheduler::DEFAULT)
|
||||
.with_gridwise_gemm_pipeline(ckb::PipelineVersion::V1);
|
||||
|
||||
using Builder = ckb::ConvBuilder<SIGNATURE, ALGORITHM>;
|
||||
using Instance = Builder::Instance;
|
||||
|
||||
TEST(BwdWeight_3DBf16_Wmma_CShuffle, Create)
|
||||
{
|
||||
const auto expected_transfer_parameters = to_string(ALGORITHM);
|
||||
std::cout << "Expected Transfer Parameters: " << expected_transfer_parameters << std::endl;
|
||||
cku::run_test<Builder>({"DeviceGroupedConvBwdWeight_Wmma_CShuffle",
|
||||
expected_transfer_parameters,
|
||||
"Default",
|
||||
"NGCDHW,GKZYXC,NGKDHW",
|
||||
"PassThrough,PassThrough,PassThrough",
|
||||
"v1"});
|
||||
}
|
||||
@@ -0,0 +1,46 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "utils/ckb_conv_test_configs.hpp"
|
||||
#include "utils/ckb_conv_test_utils.hpp"
|
||||
#include "utils/conv_algorithm_type_utils.hpp"
|
||||
#include "ck_tile/host/device_prop.hpp"
|
||||
|
||||
namespace ckb = ck_tile::builder;
|
||||
namespace ckt = ck_tile::builder::test;
|
||||
namespace cku = ck_tile::builder::test_utils;
|
||||
using enum ck_tile::builder::TensorLayout;
|
||||
|
||||
constexpr auto SIGNATURE = ckt::ConvSignature{.spatial_dim = 1,
|
||||
.direction = ckb::ConvDirection::BACKWARD_WEIGHT,
|
||||
.data_type = ckb::DataType::BF16,
|
||||
.accumulation_data_type = ckb::DataType::FP32,
|
||||
.input = {.config = {.layout = NGCW}},
|
||||
.weight = {.config = {.layout = GKXC}},
|
||||
.output = {.config = {.layout = NGKW}}};
|
||||
|
||||
constexpr auto ALGORITHM =
|
||||
cku::ConvAlgorithm_DeviceGroupedConvBwdWeight_Wmma_CShuffle_V3{}
|
||||
.with_thread_block(cku::ThreadBlock_64_32x32x32)
|
||||
.with_gemm_config(cku::GemmParams_Wmma_16x16_2x1_per_wave)
|
||||
.with_transfer(cku::BwdTransfer_4x8x1_4x16x1_v3)
|
||||
.with_bwd_specialization(ckb::ConvSpecialization::FILTER_1X1_STRIDE1_PAD0)
|
||||
.with_block_gemm(cku::BlockGemmDesc_v1_intrawave)
|
||||
.with_transpose_params(4, 4);
|
||||
|
||||
using Builder = ckb::ConvBuilder<SIGNATURE, ALGORITHM>;
|
||||
using Instance = Builder::Instance;
|
||||
|
||||
TEST(BwdWeight_1DBf16_Wmma_CShuffle_V3, Create)
|
||||
{
|
||||
const auto expected_transfer_parameters = to_string(ALGORITHM);
|
||||
std::cout << "Expected Transfer Parameters: " << expected_transfer_parameters << std::endl;
|
||||
cku::run_test<Builder>({"DeviceGroupedConvBwdWeight_Wmma_CShuffleV3",
|
||||
expected_transfer_parameters,
|
||||
"Filter1x1Stride1Pad0",
|
||||
"NGCW,GKXC,NGKW",
|
||||
"PassThrough,PassThrough,PassThrough",
|
||||
"Intrawave",
|
||||
"v1",
|
||||
"bf16,bf16,4,4>"});
|
||||
}
|
||||
@@ -0,0 +1,41 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "utils/ckb_conv_test_configs.hpp"
|
||||
#include "utils/ckb_conv_test_utils.hpp"
|
||||
#include "utils/conv_algorithm_type_utils.hpp"
|
||||
#include "ck_tile/host/device_prop.hpp"
|
||||
|
||||
namespace ckb = ck_tile::builder;
|
||||
namespace ckt = ck_tile::builder::test;
|
||||
namespace cku = ck_tile::builder::test_utils;
|
||||
|
||||
constexpr auto SIGNATURE =
|
||||
ckt::ConvSignature{.spatial_dim = 2,
|
||||
.direction = ckb::ConvDirection::BACKWARD_WEIGHT,
|
||||
.data_type = ckb::DataType::FP16,
|
||||
.accumulation_data_type = ckb::DataType::FP32,
|
||||
.input = {.config = {.layout = ckb::TensorLayout::GNHWC}},
|
||||
.weight = {.config = {.layout = ckb::TensorLayout::GKYXC}},
|
||||
.output = {.config = {.layout = ckb::TensorLayout::GNHWK}}};
|
||||
|
||||
constexpr auto ALGORITHM = cku::ConvAlgorithm_DeviceGroupedConvBwdWeight_Xdl_CShuffle{}
|
||||
.with_thread_block(cku::ThreadBlock_256_128x128x8)
|
||||
.with_gemm_config(cku::BwdGemmParams_Xdl_4x4_per_wave)
|
||||
.with_transfer(cku::BwdTransfer_4x64x1)
|
||||
.with_bwd_specialization(ckb::ConvSpecialization::DEFAULT)
|
||||
.with_transpose_params(2, 2);
|
||||
|
||||
using Builder = ckb::ConvBuilder<SIGNATURE, ALGORITHM>;
|
||||
using Instance = Builder::Instance;
|
||||
|
||||
TEST(BwdWeight_2DFp16_CShuffle_GNHWC, Create)
|
||||
{
|
||||
const auto expected_transfer_parameters = to_string(ALGORITHM);
|
||||
cku::run_test<Builder>({"DeviceGroupedConvBwdWeight_Xdl_CShuffle",
|
||||
expected_transfer_parameters,
|
||||
"Default",
|
||||
"GNHWC,GKYXC,GNHWK",
|
||||
"PassThrough,PassThrough,PassThrough",
|
||||
"fp16,fp16,2,2>"}); // check compute types and transpose params
|
||||
}
|
||||
@@ -0,0 +1,43 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "utils/ckb_conv_test_configs.hpp"
|
||||
#include "utils/ckb_conv_test_utils.hpp"
|
||||
#include "utils/conv_algorithm_type_utils.hpp"
|
||||
#include "ck_tile/host/device_prop.hpp"
|
||||
|
||||
namespace ckb = ck_tile::builder;
|
||||
namespace ckt = ck_tile::builder::test;
|
||||
namespace cku = ck_tile::builder::test_utils;
|
||||
using enum ck_tile::builder::TensorLayout;
|
||||
|
||||
constexpr auto SIGNATURE = ckt::ConvSignature{.spatial_dim = 1,
|
||||
.direction = ckb::ConvDirection::BACKWARD_WEIGHT,
|
||||
.data_type = ckb::DataType::BF16,
|
||||
.accumulation_data_type = ckb::DataType::FP32,
|
||||
.input = {.config = {.layout = NGCW}},
|
||||
.weight = {.config = {.layout = GKXC}},
|
||||
.output = {.config = {.layout = NGKW}}};
|
||||
|
||||
constexpr auto ALGORITHM =
|
||||
cku::ConvAlgorithm_DeviceGroupedConvBwdWeight_Xdl_CShuffle_V3{}
|
||||
.with_thread_block(cku::ThreadBlock_64_32x32x32)
|
||||
.with_gemm_config(cku::BwdGemmParams_Xdl_1x1_per_wave)
|
||||
.with_transfer(cku::BwdTransfer_4x8x1_4x16x1_v3)
|
||||
.with_bwd_specialization(ckb::ConvSpecialization::FILTER_1X1_STRIDE1_PAD0)
|
||||
.with_block_gemm(cku::BlockGemmDesc_v2_intrawave);
|
||||
|
||||
using Builder = ckb::ConvBuilder<SIGNATURE, ALGORITHM>;
|
||||
using Instance = Builder::Instance;
|
||||
|
||||
TEST(BwdWeight_1DBf16_CShuffle_V3, Create)
|
||||
{
|
||||
const auto expected_transfer_parameters = to_string(ALGORITHM);
|
||||
cku::run_test<Builder>({"DeviceGroupedConvBwdWeight_Xdl_CShuffleV3",
|
||||
expected_transfer_parameters,
|
||||
"Filter1x1Stride1Pad0",
|
||||
"NGCW,GKXC,NGKW",
|
||||
"PassThrough,PassThrough,PassThrough",
|
||||
"Intrawave",
|
||||
"v2"});
|
||||
}
|
||||
@@ -30,11 +30,11 @@ TEST(FwdConvInstances,
|
||||
|
||||
constexpr auto FwdConvAlgorithm =
|
||||
ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3{}
|
||||
.with_thread_block(FwdThreadBlock_256_256x256x32)
|
||||
.with_thread_block(ThreadBlock_256_256x256x32)
|
||||
.with_gemm_config(FwdGemmParams_Xdl_4x4_per_wave)
|
||||
.with_transfer(FwdTransfer_4x64x1)
|
||||
.with_specializations(ConvFwdSpecialization::FILTER_1X1_STRIDE1_PAD0,
|
||||
GemmSpecialization::MNKPadding)
|
||||
.with_transfer(Transfer_4x64x1)
|
||||
.with_fwd_specializations(ConvSpecialization::FILTER_1X1_STRIDE1_PAD0,
|
||||
GemmSpecialization::MNKPadding)
|
||||
.with_block_gemm(BlockGemmDesc_v2_intrawave);
|
||||
|
||||
using Builder = ConvBuilder<FwdConvSignature, FwdConvAlgorithm>;
|
||||
|
||||
@@ -27,11 +27,12 @@ TEST(FwdConvInstances,
|
||||
|
||||
constexpr auto FwdConvAlgorithm =
|
||||
ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle{}
|
||||
.with_thread_block(FwdThreadBlock_64_64x32x32)
|
||||
.with_thread_block(ThreadBlock_64_64x32x32)
|
||||
.with_gemm_config(FwdGemmParams_Xdl_2x1_per_wave)
|
||||
.with_transfer(FwdTransfer_4x16x1)
|
||||
.with_specializations(ConvFwdSpecialization::DEFAULT, GemmSpecialization::MNKPadding)
|
||||
.with_prefetch_config(1, 2, PipelineScheduler::DEFAULT);
|
||||
.with_transfer(Transfer_4x16x1)
|
||||
.with_fwd_specializations(ConvSpecialization::DEFAULT, GemmSpecialization::MNKPadding)
|
||||
.with_prefetch_config(1, PipelineScheduler::DEFAULT)
|
||||
.with_num_conv_groups_to_merge(2);
|
||||
|
||||
using Builder = ConvBuilder<FwdConvSignature, FwdConvAlgorithm>;
|
||||
|
||||
|
||||
@@ -29,11 +29,13 @@ TEST(FwdConvInstances,
|
||||
|
||||
constexpr auto FwdConvAlgorithm =
|
||||
ConvAlgorithm_DeviceGroupedConvFwdMultipleD_Wmma_CShuffle{}
|
||||
.with_thread_block(FwdThreadBlock_128_64x64x64)
|
||||
.with_gemm_config(FwdGemmParams_Wmma_2x1_per_wave)
|
||||
.with_transfer(FwdTransfer_4x32x1)
|
||||
.with_specializations(ConvFwdSpecialization::DEFAULT, GemmSpecialization::MNKPadding)
|
||||
.with_prefetch_config(1, 0, PipelineScheduler::DEFAULT);
|
||||
.with_thread_block(ThreadBlock_128_64x64x64)
|
||||
.with_gemm_config(GemmParams_Wmma_2x1_per_wave)
|
||||
.with_transfer(Transfer_4x32x1)
|
||||
.with_fwd_specializations(ConvSpecialization::DEFAULT, GemmSpecialization::MNKPadding)
|
||||
.with_prefetch_config(1, PipelineScheduler::DEFAULT)
|
||||
.with_num_conv_groups_to_merge(2)
|
||||
.with_gridwise_gemm_pipeline(PipelineVersion::V1);
|
||||
|
||||
using Builder = ConvBuilder<FwdConvSignature, FwdConvAlgorithm>;
|
||||
|
||||
|
||||
@@ -27,10 +27,10 @@ TEST(FwdConvInstances,
|
||||
|
||||
constexpr auto FwdConvAlgorithm =
|
||||
ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3{}
|
||||
.with_thread_block(FwdThreadBlock_256_256x256x32)
|
||||
.with_thread_block(ThreadBlock_256_256x256x32)
|
||||
.with_gemm_config(FwdGemmParams_Xdl_4x4_per_wave)
|
||||
.with_transfer(FwdTransfer_4x64x1)
|
||||
.with_specializations(ConvFwdSpecialization::DEFAULT, GemmSpecialization::MNKPadding)
|
||||
.with_transfer(Transfer_4x64x1)
|
||||
.with_fwd_specializations(ConvSpecialization::DEFAULT, GemmSpecialization::MNKPadding)
|
||||
.with_block_gemm(BlockGemmDesc_v1_intrawave);
|
||||
|
||||
using Builder = ConvBuilder<FwdConvSignature, FwdConvAlgorithm>;
|
||||
@@ -64,10 +64,11 @@ TEST(FwdConvInstances,
|
||||
|
||||
constexpr auto FwdConvAlgorithm =
|
||||
ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3{}
|
||||
.with_thread_block(FwdThreadBlock_256_256x256x32)
|
||||
.with_thread_block(ThreadBlock_256_256x256x32)
|
||||
.with_gemm_config(FwdGemmParams_Xdl_4x4_per_wave)
|
||||
.with_transfer(FwdTransfer_4x64x1)
|
||||
.with_specializations(ConvFwdSpecialization::FILTER_3x3, GemmSpecialization::MNKPadding)
|
||||
.with_transfer(Transfer_4x64x1)
|
||||
.with_fwd_specializations(ConvSpecialization::FILTER_3x3,
|
||||
GemmSpecialization::MNKPadding)
|
||||
.with_block_gemm(BlockGemmDesc_v5_intrawave);
|
||||
|
||||
using Builder = ConvBuilder<FwdConvSignature, FwdConvAlgorithm>;
|
||||
|
||||
@@ -32,11 +32,12 @@ TEST(FwdConvInstances,
|
||||
|
||||
constexpr auto FwdConvAlgorithm =
|
||||
ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle{}
|
||||
.with_thread_block(FwdThreadBlock_64_64x32x32)
|
||||
.with_thread_block(ThreadBlock_64_64x32x32)
|
||||
.with_gemm_config(FwdGemmParams_Xdl_2x1_per_wave)
|
||||
.with_transfer(FwdTransfer_4x16x1)
|
||||
.with_specializations(ConvFwdSpecialization::DEFAULT, GemmSpecialization::MNKPadding)
|
||||
.with_prefetch_config(1, 1, PipelineScheduler::DEFAULT);
|
||||
.with_transfer(Transfer_4x16x1)
|
||||
.with_fwd_specializations(ConvSpecialization::DEFAULT, GemmSpecialization::MNKPadding)
|
||||
.with_prefetch_config(1, PipelineScheduler::DEFAULT)
|
||||
.with_num_conv_groups_to_merge(1);
|
||||
|
||||
using Builder = ConvBuilder<FwdConvSignature, FwdConvAlgorithm>;
|
||||
|
||||
|
||||
@@ -25,15 +25,16 @@ TEST(FwdConvInstances, Create_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK_Ins
|
||||
|
||||
constexpr auto FwdConvAlgorithm =
|
||||
ConvAlgorithm_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK{}
|
||||
.with_thread_block(FwdThreadBlock_256_128x128x16)
|
||||
.with_specializations(ConvFwdSpecialization::DEFAULT, GemmSpecialization::MNKPadding)
|
||||
.with_thread_block(ThreadBlock_256_128x128x16)
|
||||
.with_fwd_specializations(ConvSpecialization::DEFAULT, GemmSpecialization::MNKPadding)
|
||||
.with_dl_thread_config(DlThreadConfig_16x2x4x4x1)
|
||||
.with_dl_thread_cluster(DlThreadCluster_8x2)
|
||||
.with_dl_transfer(DlFwdTransfer);
|
||||
.with_dl_transfer(DlTransfer4D);
|
||||
|
||||
using Builder = ConvBuilder<FwdConvSignature, FwdConvAlgorithm>;
|
||||
|
||||
const auto expected_transfer_parameters = to_string(FwdConvAlgorithm);
|
||||
std::cout << "Expected Transfer Parameters: " << expected_transfer_parameters << std::endl;
|
||||
run_test<Builder>({"DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK",
|
||||
expected_transfer_parameters,
|
||||
"Default",
|
||||
@@ -59,16 +60,17 @@ TEST(FwdConvInstances,
|
||||
|
||||
constexpr auto FwdConvAlgorithm =
|
||||
ConvAlgorithm_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK{}
|
||||
.with_thread_block(FwdThreadBlock_256_128x128x16)
|
||||
.with_specializations(ConvFwdSpecialization::FILTER_1X1_PAD0,
|
||||
GemmSpecialization::MNKPadding)
|
||||
.with_thread_block(ThreadBlock_256_128x128x16)
|
||||
.with_fwd_specializations(ConvSpecialization::FILTER_1X1_PAD0,
|
||||
GemmSpecialization::MNKPadding)
|
||||
.with_dl_thread_config(DlThreadConfig_16x2x4x4x1)
|
||||
.with_dl_thread_cluster(DlThreadCluster_8x2)
|
||||
.with_dl_transfer(DlFwdTransfer);
|
||||
.with_dl_transfer(DlTransfer4D);
|
||||
|
||||
using Builder = ConvBuilder<FwdConvSignature, FwdConvAlgorithm>;
|
||||
|
||||
const auto expected_transfer_parameters = to_string(FwdConvAlgorithm);
|
||||
std::cout << "Expected Transfer Parameters: " << expected_transfer_parameters << std::endl;
|
||||
run_test<Builder>({"DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK",
|
||||
expected_transfer_parameters,
|
||||
"Filter1x1Pad0",
|
||||
|
||||
@@ -25,11 +25,11 @@ constexpr auto SIGNATURE =
|
||||
.output = {.config = {.layout = ckb::TensorLayout::GNHWK}}};
|
||||
|
||||
constexpr auto ALGORITHM = cku::ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3{}
|
||||
.with_thread_block(cku::FwdThreadBlock_256_256x256x32)
|
||||
.with_thread_block(cku::ThreadBlock_256_256x256x32)
|
||||
.with_gemm_config(cku::FwdGemmParams_Xdl_4x4_per_wave)
|
||||
.with_transfer(cku::FwdTransfer_4x64x1)
|
||||
.with_specializations(ckb::ConvFwdSpecialization::DEFAULT,
|
||||
ckb::GemmSpecialization::MNKPadding)
|
||||
.with_transfer(cku::Transfer_4x64x1)
|
||||
.with_fwd_specializations(ckb::ConvSpecialization::DEFAULT,
|
||||
ckb::GemmSpecialization::MNKPadding)
|
||||
.with_block_gemm(cku::BlockGemmDesc_v3_intrawave);
|
||||
|
||||
using Builder = ckb::ConvBuilder<SIGNATURE, ALGORITHM>;
|
||||
|
||||
@@ -26,11 +26,11 @@ TEST(FwdConvInstances,
|
||||
|
||||
constexpr auto FwdConvAlgorithm =
|
||||
ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3{}
|
||||
.with_thread_block(FwdThreadBlock_256_128x128x32)
|
||||
.with_thread_block(ThreadBlock_256_128x128x32)
|
||||
.with_gemm_config(FwdGemmParams_Xdl_4x4_per_wave)
|
||||
.with_transfer(FwdTransfer_4x64x1)
|
||||
.with_specializations(ConvFwdSpecialization::FILTER_1X1_STRIDE1_PAD0,
|
||||
GemmSpecialization::MNKPadding)
|
||||
.with_transfer(Transfer_4x64x1)
|
||||
.with_fwd_specializations(ConvSpecialization::FILTER_1X1_STRIDE1_PAD0,
|
||||
GemmSpecialization::MNKPadding)
|
||||
.with_block_gemm(BlockGemmDesc_v4_intrawave);
|
||||
|
||||
using Builder = ConvBuilder<FwdConvSignature, FwdConvAlgorithm>;
|
||||
|
||||
@@ -27,11 +27,12 @@ TEST(FwdConvInstances,
|
||||
|
||||
constexpr auto FwdConvAlgorithm =
|
||||
ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle{}
|
||||
.with_thread_block(FwdThreadBlock_256_256x128x32)
|
||||
.with_thread_block(ThreadBlock_256_256x128x32)
|
||||
.with_gemm_config(FwdGemmParams_Xdl_4x2_per_wave)
|
||||
.with_transfer(FwdTransfer_4x64x1_fp8)
|
||||
.with_specializations(ConvFwdSpecialization::DEFAULT, GemmSpecialization::MNKPadding)
|
||||
.with_prefetch_config(1, 1, PipelineScheduler::DEFAULT);
|
||||
.with_transfer(Transfer_4x64x1_fp8)
|
||||
.with_fwd_specializations(ConvSpecialization::DEFAULT, GemmSpecialization::MNKPadding)
|
||||
.with_prefetch_config(1, PipelineScheduler::DEFAULT)
|
||||
.with_num_conv_groups_to_merge(1);
|
||||
|
||||
using Builder = ConvBuilder<FwdConvSignature, FwdConvAlgorithm>;
|
||||
|
||||
|
||||
@@ -25,14 +25,13 @@ TEST(FwdConvInstances,
|
||||
.output = {.config = {.layout = GNHWK}}};
|
||||
|
||||
constexpr auto FwdConvAlgorithm =
|
||||
ConvAlgorithm_DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor{
|
||||
.base_algorithm = ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle{}
|
||||
.with_thread_block(FwdThreadBlock_256_256x128x32)
|
||||
.with_gemm_config(FwdGemmParams_Xdl_2x1_per_wave)
|
||||
.with_transfer(FwdTransfer_4x16x1)
|
||||
.with_specializations(ConvFwdSpecialization::DEFAULT,
|
||||
GemmSpecialization::MNKPadding)
|
||||
.with_prefetch_config(1, 1, PipelineScheduler::DEFAULT)};
|
||||
ConvAlgorithm_DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor{}
|
||||
.with_thread_block(ThreadBlock_256_256x128x32)
|
||||
.with_gemm_config(FwdGemmParams_Xdl_2x1_per_wave)
|
||||
.with_transfer(Transfer_4x16x1)
|
||||
.with_fwd_specializations(ConvSpecialization::DEFAULT, GemmSpecialization::MNKPadding)
|
||||
.with_prefetch_config(1, PipelineScheduler::DEFAULT)
|
||||
.with_num_conv_groups_to_merge(1);
|
||||
|
||||
using Builder = ConvBuilder<FwdConvSignature, FwdConvAlgorithm>;
|
||||
|
||||
@@ -62,14 +61,14 @@ TEST(
|
||||
.output = {.config = {.layout = GNHWK}}};
|
||||
|
||||
constexpr auto FwdConvAlgorithm =
|
||||
ConvAlgorithm_DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor{
|
||||
.base_algorithm = ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle{}
|
||||
.with_thread_block(FwdThreadBlock_128_128x128x32)
|
||||
.with_gemm_config(FwdGemmParams_Xdl_2x1_per_wave)
|
||||
.with_transfer(FwdTransfer_4x16x1)
|
||||
.with_specializations(ConvFwdSpecialization::FILTER_1X1_PAD0,
|
||||
GemmSpecialization::MNKPadding)
|
||||
.with_prefetch_config(1, 1, PipelineScheduler::DEFAULT)};
|
||||
ConvAlgorithm_DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor{}
|
||||
.with_thread_block(ThreadBlock_128_128x128x32)
|
||||
.with_gemm_config(FwdGemmParams_Xdl_2x1_per_wave)
|
||||
.with_transfer(Transfer_4x16x1)
|
||||
.with_fwd_specializations(ConvSpecialization::FILTER_1X1_PAD0,
|
||||
GemmSpecialization::MNKPadding)
|
||||
.with_prefetch_config(1, PipelineScheduler::DEFAULT)
|
||||
.with_num_conv_groups_to_merge(1);
|
||||
|
||||
using Builder = ConvBuilder<FwdConvSignature, FwdConvAlgorithm>;
|
||||
|
||||
|
||||
@@ -27,10 +27,10 @@ TEST(FwdConvInstances,
|
||||
|
||||
constexpr auto FwdConvAlgorithm =
|
||||
ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3{}
|
||||
.with_thread_block(FwdThreadBlock_256_256x256x32)
|
||||
.with_thread_block(ThreadBlock_256_256x256x32)
|
||||
.with_gemm_config(FwdGemmParams_Xdl_4x4_per_wave)
|
||||
.with_transfer(FwdTransfer_4x64x1)
|
||||
.with_specializations(ConvFwdSpecialization::DEFAULT, GemmSpecialization::MNKPadding)
|
||||
.with_transfer(Transfer_4x64x1)
|
||||
.with_fwd_specializations(ConvSpecialization::DEFAULT, GemmSpecialization::MNKPadding)
|
||||
.with_block_gemm(BlockGemmDesc_v3_intrawave);
|
||||
|
||||
using Builder = ConvBuilder<FwdConvSignature, FwdConvAlgorithm>;
|
||||
|
||||
@@ -27,11 +27,11 @@ TEST(FwdConvInstances,
|
||||
|
||||
constexpr auto FwdConvAlgorithm =
|
||||
ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3{}
|
||||
.with_thread_block(FwdThreadBlock_256_128x128x32)
|
||||
.with_thread_block(ThreadBlock_256_128x128x32)
|
||||
.with_gemm_config(FwdGemmParams_Xdl_2x1_per_wave)
|
||||
.with_transfer(FwdTransfer_4x64x1)
|
||||
.with_specializations(ConvFwdSpecialization::FILTER_1X1_PAD0,
|
||||
GemmSpecialization::MNKPadding)
|
||||
.with_transfer(Transfer_4x64x1)
|
||||
.with_fwd_specializations(ConvSpecialization::FILTER_1X1_PAD0,
|
||||
GemmSpecialization::MNKPadding)
|
||||
.with_block_gemm(BlockGemmDesc_v4_intrawave);
|
||||
|
||||
using Builder = ConvBuilder<FwdConvSignature, FwdConvAlgorithm>;
|
||||
|
||||
@@ -27,11 +27,11 @@ TEST(FwdConvInstances,
|
||||
|
||||
constexpr auto FwdConvAlgorithm =
|
||||
ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3{}
|
||||
.with_thread_block(FwdThreadBlock_256_256x256x32)
|
||||
.with_thread_block(ThreadBlock_256_256x256x32)
|
||||
.with_gemm_config(FwdGemmParams_Xdl_4x4_per_wave)
|
||||
.with_transfer(FwdTransfer_4x64x1)
|
||||
.with_specializations(ConvFwdSpecialization::FILTER_1X1_PAD0,
|
||||
GemmSpecialization::MNKPadding)
|
||||
.with_transfer(Transfer_4x64x1)
|
||||
.with_fwd_specializations(ConvSpecialization::FILTER_1X1_PAD0,
|
||||
GemmSpecialization::MNKPadding)
|
||||
.with_block_gemm(BlockGemmDesc_v1_intrawave);
|
||||
|
||||
using Builder = ConvBuilder<FwdConvSignature, FwdConvAlgorithm>;
|
||||
|
||||
@@ -101,7 +101,7 @@ TEST_F(ConvTraitsTest, ConvFwdTraitsExtraction)
|
||||
|
||||
// Verify specializations
|
||||
EXPECT_EQ(Traits::gemm_padding, ck_tile::builder::GemmPadding::DEFAULT);
|
||||
EXPECT_EQ(Traits::conv_specialization, ck_tile::builder::ConvFwdSpecialization::DEFAULT);
|
||||
EXPECT_EQ(Traits::conv_specialization, ck_tile::builder::ConvSpecialization::DEFAULT);
|
||||
|
||||
// Verify algorithm information
|
||||
EXPECT_EQ(Traits::thread_block_size, 256);
|
||||
@@ -229,7 +229,7 @@ TEST_F(ConvTraitsTest, ConvFwdBaseTraitsExtraction)
|
||||
|
||||
// Verify specializations
|
||||
EXPECT_EQ(Traits::gemm_padding, ck_tile::builder::GemmPadding::DEFAULT);
|
||||
EXPECT_EQ(Traits::conv_specialization, ck_tile::builder::ConvFwdSpecialization::DEFAULT);
|
||||
EXPECT_EQ(Traits::conv_specialization, ck_tile::builder::ConvSpecialization::DEFAULT);
|
||||
|
||||
// Verify algorithm information
|
||||
EXPECT_EQ(Traits::thread_block_size, 256);
|
||||
@@ -313,7 +313,7 @@ TEST_F(ConvTraitsTest, ConvFwdLargeTensorTraitsExtraction)
|
||||
|
||||
// Verify specializations
|
||||
EXPECT_EQ(Traits::gemm_padding, ck_tile::builder::GemmPadding::DEFAULT);
|
||||
EXPECT_EQ(Traits::conv_specialization, ck_tile::builder::ConvFwdSpecialization::DEFAULT);
|
||||
EXPECT_EQ(Traits::conv_specialization, ck_tile::builder::ConvSpecialization::DEFAULT);
|
||||
|
||||
// Verify algorithm information
|
||||
EXPECT_EQ(Traits::thread_block_size, 256);
|
||||
|
||||
@@ -230,7 +230,7 @@ TEST(InstanceToConvTraits, ExtractsDefaultSpecialization)
|
||||
|
||||
using Traits = ck_tile::reflect::conv::ConvTraits<DeviceInstance>;
|
||||
|
||||
EXPECT_EQ(Traits::conv_specialization, ck_tile::builder::ConvFwdSpecialization::DEFAULT);
|
||||
EXPECT_EQ(Traits::conv_specialization, ck_tile::builder::ConvSpecialization::DEFAULT);
|
||||
}
|
||||
|
||||
TEST(InstanceToConvTraits, ExtractsFilter1x1Pad0Specialization)
|
||||
@@ -289,8 +289,7 @@ TEST(InstanceToConvTraits, ExtractsFilter1x1Pad0Specialization)
|
||||
|
||||
using Traits = ck_tile::reflect::conv::ConvTraits<DeviceInstance>;
|
||||
|
||||
EXPECT_EQ(Traits::conv_specialization,
|
||||
ck_tile::builder::ConvFwdSpecialization::FILTER_1X1_PAD0);
|
||||
EXPECT_EQ(Traits::conv_specialization, ck_tile::builder::ConvSpecialization::FILTER_1X1_PAD0);
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
|
||||
@@ -8,26 +8,27 @@ namespace {
|
||||
|
||||
using namespace ck_tile::builder::test_utils;
|
||||
|
||||
TEST(FwdConvInstances, Create_ConvAlgorithm_Tile_GroupedConvolutionKernel_2D_FP16_NHWGC)
|
||||
TEST(BwdDataConvInstances, Create_ConvAlgorithm_Tile_GroupedConvolutionKernel_2D_FP16_NHWGC)
|
||||
{
|
||||
constexpr ConvSignature FwdConvSignature{.spatial_dim = 2,
|
||||
.direction = ConvDirection::BACKWARD_DATA,
|
||||
.data_type = DataType::FP16,
|
||||
.accumulation_data_type = DataType::FP32,
|
||||
.input = {.config = {.layout = TensorLayout::NHWGC}},
|
||||
.weight = {.config = {.layout = TensorLayout::GKYXC}},
|
||||
.output = {.config = {.layout = TensorLayout::NHWGK}}};
|
||||
constexpr ConvSignature BwdDataConvSignature{
|
||||
.spatial_dim = 2,
|
||||
.direction = ConvDirection::BACKWARD_DATA,
|
||||
.data_type = DataType::FP16,
|
||||
.accumulation_data_type = DataType::FP32,
|
||||
.input = {.config = {.layout = TensorLayout::NHWGC}},
|
||||
.weight = {.config = {.layout = TensorLayout::GKYXC}},
|
||||
.output = {.config = {.layout = TensorLayout::NHWGK}}};
|
||||
|
||||
constexpr auto FwdConvAlgorithm =
|
||||
constexpr auto BwdDataConvAlgorithm =
|
||||
ConvAlgorithm_Tile_GroupedConvolutionKernel{}
|
||||
.with_tile_specializations(TileConvSpecialization::DEFAULT)
|
||||
.with_tile_thread_block(FwdTileThreadBlock_64x64x64)
|
||||
.with_tile_thread_block(TileThreadBlock_64x64x64)
|
||||
.with_tile_block_gemm(TileBlockGemmDesc_16x16_v3_intrawave)
|
||||
.with_tile_transfer(FwdTileTransfer_4x4x4)
|
||||
.with_tile_transfer(TileTransfer_4x4x4)
|
||||
.with_tile_optimizations(TileOptimizations{
|
||||
.num_groups_to_merge = 1, .split_image = false, .explicit_gemm = false});
|
||||
|
||||
using Builder = ConvBuilder<FwdConvSignature, FwdConvAlgorithm>;
|
||||
using Builder = ConvBuilder<BwdDataConvSignature, BwdDataConvAlgorithm>;
|
||||
run_ck_tile_test<Builder>({
|
||||
"grouped_convolution_backward_data",
|
||||
"fp16",
|
||||
|
||||
@@ -8,26 +8,27 @@ namespace {
|
||||
|
||||
using namespace ck_tile::builder::test_utils;
|
||||
|
||||
TEST(FwdConvInstances, Create_ConvAlgorithm_Tile_GroupedConvolutionKernel_2D_FP16_NHWGC)
|
||||
TEST(BwdWeightConvInstances, Create_ConvAlgorithm_Tile_GroupedConvolutionKernel_2D_FP16_NHWGC)
|
||||
{
|
||||
constexpr ConvSignature FwdConvSignature{.spatial_dim = 2,
|
||||
.direction = ConvDirection::BACKWARD_WEIGHT,
|
||||
.data_type = DataType::FP16,
|
||||
.accumulation_data_type = DataType::FP32,
|
||||
.input = {.config = {.layout = TensorLayout::NHWGC}},
|
||||
.weight = {.config = {.layout = TensorLayout::GKYXC}},
|
||||
.output = {.config = {.layout = TensorLayout::NHWGK}}};
|
||||
constexpr ConvSignature BwdWeightConvSignature{
|
||||
.spatial_dim = 2,
|
||||
.direction = ConvDirection::BACKWARD_WEIGHT,
|
||||
.data_type = DataType::FP16,
|
||||
.accumulation_data_type = DataType::FP32,
|
||||
.input = {.config = {.layout = TensorLayout::NHWGC}},
|
||||
.weight = {.config = {.layout = TensorLayout::GKYXC}},
|
||||
.output = {.config = {.layout = TensorLayout::NHWGK}}};
|
||||
|
||||
constexpr auto FwdConvAlgorithm =
|
||||
constexpr auto BwdWeightConvAlgorithm =
|
||||
ConvAlgorithm_Tile_GroupedConvolutionKernel{}
|
||||
.with_tile_specializations(TileConvSpecialization::DEFAULT)
|
||||
.with_tile_thread_block(FwdTileThreadBlock_64x64x64)
|
||||
.with_tile_thread_block(TileThreadBlock_64x64x64)
|
||||
.with_tile_block_gemm(TileBlockGemmDesc_16x16_v3_intrawave)
|
||||
.with_tile_transfer(FwdTileTransfer_4x4x4)
|
||||
.with_tile_transfer(TileTransfer_4x4x4)
|
||||
.with_tile_optimizations(TileOptimizations{
|
||||
.num_groups_to_merge = 1, .split_image = false, .explicit_gemm = false});
|
||||
|
||||
using Builder = ConvBuilder<FwdConvSignature, FwdConvAlgorithm>;
|
||||
using Builder = ConvBuilder<BwdWeightConvSignature, BwdWeightConvAlgorithm>;
|
||||
run_ck_tile_test<Builder>({
|
||||
"grouped_convolution_backward_weight",
|
||||
"fp16",
|
||||
|
||||
@@ -21,9 +21,9 @@ TEST(FwdConvInstances, Create_ConvAlgorithm_Tile_GroupedConvolutionKernel_2D_FP1
|
||||
constexpr auto FwdConvAlgorithm =
|
||||
ConvAlgorithm_Tile_GroupedConvolutionKernel{}
|
||||
.with_tile_specializations(TileConvSpecialization::DEFAULT)
|
||||
.with_tile_thread_block(FwdTileThreadBlock_64x64x64)
|
||||
.with_tile_thread_block(TileThreadBlock_64x64x64)
|
||||
.with_tile_block_gemm(TileBlockGemmDesc_16x16_v3_intrawave)
|
||||
.with_tile_transfer(FwdTileTransfer_4x4x4)
|
||||
.with_tile_transfer(TileTransfer_4x4x4)
|
||||
.with_tile_optimizations(TileOptimizations{
|
||||
.num_groups_to_merge = 1, .split_image = false, .explicit_gemm = false});
|
||||
|
||||
|
||||
@@ -28,18 +28,31 @@ struct ThreadBlock
|
||||
};
|
||||
static_assert(ckb::ThreadBlockDescriptor<ThreadBlock>);
|
||||
|
||||
// Describe gridwise XDL GEMM parameters.
|
||||
struct GridwiseXdlGemm
|
||||
struct XdlParams
|
||||
{
|
||||
// NOTE: ak1 and bk1 are difficult to verify in the kernel instantiation!!!
|
||||
size_t ak1 = 0;
|
||||
size_t bk1 = 0;
|
||||
size_t m_per_xdl = 0;
|
||||
size_t n_per_xdl = 0;
|
||||
size_t m_xdl_per_wave = 0;
|
||||
size_t n_xdl_per_wave = 0;
|
||||
};
|
||||
static_assert(ckb::GridwiseXdlGemmDescriptor<GridwiseXdlGemm>);
|
||||
static_assert(ckb::GridwiseXdlGemmDescriptor<XdlParams>);
|
||||
|
||||
// Describe gridwise XDL GEMM parameters.
|
||||
struct GridwiseFwdXdlGemm
|
||||
{
|
||||
// NOTE: ak1 and bk1 are difficult to verify in the kernel instantiation!!!
|
||||
size_t ak1 = 0;
|
||||
size_t bk1 = 0;
|
||||
XdlParams xdl_params;
|
||||
};
|
||||
static_assert(ckb::GridwiseFwdXdlGemmDescriptor<GridwiseFwdXdlGemm>);
|
||||
|
||||
struct GridwiseBwdXdlGemm
|
||||
{
|
||||
size_t k1 = 0;
|
||||
XdlParams xdl_params;
|
||||
};
|
||||
static_assert(ckb::GridwiseBwdXdlGemmDescriptor<GridwiseBwdXdlGemm>);
|
||||
|
||||
// Describe gridwise WMMA GEMM parameters.
|
||||
struct GridwiseWmmaGemm
|
||||
@@ -49,25 +62,36 @@ struct GridwiseWmmaGemm
|
||||
size_t n_per_wmma = 0;
|
||||
size_t m_wmma_per_wave = 0;
|
||||
size_t n_wmma_per_wave = 0;
|
||||
PipelineVersion pipeline_version;
|
||||
};
|
||||
static_assert(ckb::GridwiseWmmaGemmDescriptor<GridwiseWmmaGemm>);
|
||||
|
||||
struct BlockGemm
|
||||
struct BlockGemmPipeline
|
||||
{
|
||||
PipelineVersion pipeline_version;
|
||||
PipelineScheduler scheduler;
|
||||
};
|
||||
static_assert(ckb::BlockGemmDescriptor<BlockGemm>);
|
||||
static_assert(ckb::BlockGemmPipelineDescriptor<BlockGemmPipeline>);
|
||||
|
||||
// Describe Aand B block transfer thread cluster lengths.
|
||||
template <size_t ThreadSliceLength = 3>
|
||||
struct BlockTransfer
|
||||
{
|
||||
size_t k0;
|
||||
size_t m_n;
|
||||
size_t k1;
|
||||
size_t k_batch_size;
|
||||
};
|
||||
static_assert(ckb::BlockTransferDescriptor<BlockTransfer>);
|
||||
|
||||
// Specialization for ThreadSliceLength == 3
|
||||
template <>
|
||||
struct BlockTransfer<3>
|
||||
{
|
||||
size_t k0;
|
||||
size_t m_n;
|
||||
size_t k1;
|
||||
};
|
||||
static_assert(ckb::BlockTransferDescriptor<BlockTransfer<3>, 3>);
|
||||
static_assert(ckb::BlockTransferDescriptor<BlockTransfer<4>, 4>);
|
||||
|
||||
// Describe C block transfer thread cluster lengths.
|
||||
struct ThreadCluster
|
||||
@@ -97,31 +121,35 @@ struct Epilogue
|
||||
};
|
||||
static_assert(EpilogueDescriptor<Epilogue>);
|
||||
|
||||
template <size_t ThreadSliceLength = 3>
|
||||
struct AccessOrder
|
||||
{
|
||||
std::array<size_t, 3> order;
|
||||
std::array<size_t, ThreadSliceLength> order;
|
||||
};
|
||||
static_assert(AccessOrderDescriptor<AccessOrder>);
|
||||
static_assert(AccessOrderDescriptor<AccessOrder<>>);
|
||||
static_assert(AccessOrderDescriptor<AccessOrder<4>>);
|
||||
|
||||
struct TransferAB
|
||||
template <size_t ThreadSliceLength = 3>
|
||||
struct InputTransfer
|
||||
{
|
||||
BlockTransfer block_transfer;
|
||||
BlockTransfer<ThreadSliceLength> block_transfer;
|
||||
LdsTransfer lds_transfer;
|
||||
AccessOrder block_transfer_access_order;
|
||||
AccessOrder src_access_order;
|
||||
AccessOrder<ThreadSliceLength> block_transfer_access_order;
|
||||
AccessOrder<ThreadSliceLength> src_access_order;
|
||||
};
|
||||
|
||||
struct TransferC
|
||||
struct OutputTransfer
|
||||
{
|
||||
ThreadCluster thread_cluster_dims;
|
||||
Epilogue epilogue;
|
||||
};
|
||||
|
||||
struct TransferABC
|
||||
template <size_t ThreadSliceLength = 3>
|
||||
struct Transfer
|
||||
{
|
||||
TransferAB a;
|
||||
TransferAB b;
|
||||
TransferC c;
|
||||
InputTransfer<ThreadSliceLength> a;
|
||||
InputTransfer<ThreadSliceLength> b;
|
||||
OutputTransfer c;
|
||||
};
|
||||
|
||||
// DL-specific descriptors
|
||||
@@ -142,17 +170,19 @@ struct DlThreadCluster
|
||||
};
|
||||
static_assert(ckb::DlThreadClusterDescriptor<DlThreadCluster>);
|
||||
|
||||
template <size_t D = 4>
|
||||
struct DlBlockTransfer
|
||||
{
|
||||
std::array<size_t, 4> thread_slice_lengths;
|
||||
std::array<size_t, 4> thread_cluster_lengths;
|
||||
std::array<size_t, 4> thread_cluster_arrange_order;
|
||||
std::array<size_t, 4> src_access_order;
|
||||
std::array<size_t, 4> src_vector_tensor_lengths;
|
||||
std::array<size_t, 4> src_vector_tensor_contiguous_dim_order;
|
||||
std::array<size_t, 4> dst_vector_tensor_lengths;
|
||||
std::array<size_t, D> thread_slice_lengths;
|
||||
std::array<size_t, D> thread_cluster_lengths;
|
||||
std::array<size_t, D> thread_cluster_arrange_order;
|
||||
std::array<size_t, D> src_access_order;
|
||||
std::array<size_t, D> src_vector_tensor_lengths;
|
||||
std::array<size_t, D> src_vector_tensor_contiguous_dim_order;
|
||||
std::array<size_t, D> dst_vector_tensor_lengths;
|
||||
};
|
||||
static_assert(ckb::DlBlockTransferDescriptor<DlBlockTransfer>);
|
||||
static_assert(ckb::DlBlockTransferDescriptor4D<DlBlockTransfer<4>>);
|
||||
static_assert(ckb::DlBlockTransferDescriptor5D<DlBlockTransfer<5>>);
|
||||
|
||||
struct DlEpilogue
|
||||
{
|
||||
@@ -169,9 +199,14 @@ struct ThreadBlock_
|
||||
ThreadBlock thread_block;
|
||||
};
|
||||
|
||||
struct XdlGemm_
|
||||
struct FwdXdlGemm_
|
||||
{
|
||||
GridwiseXdlGemm gridwise_gemm;
|
||||
GridwiseFwdXdlGemm gridwise_gemm;
|
||||
};
|
||||
|
||||
struct BwdXdlGemm_
|
||||
{
|
||||
GridwiseBwdXdlGemm gridwise_gemm;
|
||||
};
|
||||
|
||||
struct WmmaGemm_
|
||||
@@ -179,27 +214,48 @@ struct WmmaGemm_
|
||||
GridwiseWmmaGemm gridwise_gemm;
|
||||
};
|
||||
|
||||
template <size_t ThreadSliceLength = 3>
|
||||
struct Transfer_
|
||||
{
|
||||
TransferABC transfer;
|
||||
Transfer<ThreadSliceLength> transfer;
|
||||
};
|
||||
|
||||
struct ConvSpecialization_
|
||||
struct ConvSpecializationFwd_
|
||||
{
|
||||
ConvFwdSpecialization fwd_specialization;
|
||||
ConvSpecialization fwd_specialization;
|
||||
GemmSpecialization gemm_specialization;
|
||||
};
|
||||
|
||||
struct ConvSpecializationBwdWeight_
|
||||
{
|
||||
ConvSpecialization bwd_weight_specialization;
|
||||
};
|
||||
|
||||
struct Prefetch_
|
||||
{
|
||||
size_t num_gemm_k_prefetch_stages;
|
||||
size_t num_groups_to_merge;
|
||||
PipelineScheduler loop_scheduler;
|
||||
};
|
||||
|
||||
struct TransposeParams_
|
||||
{
|
||||
size_t max_transpose_transfer_src_scalar_per_vector{1};
|
||||
size_t max_transpose_transfer_dst_scalar_per_vector{1};
|
||||
};
|
||||
|
||||
struct GemmBatchOptions_
|
||||
{
|
||||
size_t num_conv_groups_to_merge{1};
|
||||
};
|
||||
|
||||
struct BlockGemm_
|
||||
{
|
||||
BlockGemm block_gemm;
|
||||
BlockGemmPipeline block_gemm_pipeline;
|
||||
};
|
||||
|
||||
struct GridGemm_
|
||||
{
|
||||
PipelineVersion pipeline_version;
|
||||
};
|
||||
|
||||
struct DlThreadConfig_
|
||||
@@ -212,33 +268,34 @@ struct DlThreadCluster_
|
||||
DlThreadCluster thread_cluster;
|
||||
};
|
||||
|
||||
struct DlBlockTransferAB
|
||||
template <size_t Dim = 4>
|
||||
struct DlTransfer
|
||||
{
|
||||
DlBlockTransfer block_transfer;
|
||||
};
|
||||
|
||||
struct DlBlockTransferC
|
||||
{
|
||||
DlEpilogue epilogue;
|
||||
};
|
||||
|
||||
struct DlTransferABC
|
||||
{
|
||||
DlBlockTransferAB a;
|
||||
DlBlockTransferAB b;
|
||||
DlBlockTransferC c;
|
||||
DlBlockTransfer<Dim> a;
|
||||
DlBlockTransfer<Dim> b;
|
||||
DlEpilogue c;
|
||||
};
|
||||
|
||||
template <size_t Dim = 4>
|
||||
struct DlTransfer_
|
||||
{
|
||||
DlTransferABC transfer;
|
||||
DlTransfer<Dim> transfer;
|
||||
};
|
||||
|
||||
// Specialization wrapper for large tensor support
|
||||
template <typename BaseAlgorithm>
|
||||
struct LargeTensorWrapper
|
||||
struct TwoStageSpecialization_
|
||||
{
|
||||
static constexpr ConvAlgorithmSpecialization specialization =
|
||||
ConvAlgorithmSpecialization::TWO_STAGE;
|
||||
};
|
||||
|
||||
struct MultipleDSpecialization_
|
||||
{
|
||||
static constexpr ConvAlgorithmSpecialization specialization =
|
||||
ConvAlgorithmSpecialization::MULTIPLE_D;
|
||||
};
|
||||
|
||||
struct LargeTensorSpecialization_
|
||||
{
|
||||
BaseAlgorithm base_algorithm;
|
||||
static constexpr ConvAlgorithmSpecialization specialization =
|
||||
ConvAlgorithmSpecialization::LARGE_TENSOR;
|
||||
};
|
||||
@@ -329,7 +386,11 @@ struct ConvAlgorithmTemplate : Components...
|
||||
constexpr auto with_gemm_config(const GemmConfig& gemm) const
|
||||
{
|
||||
auto result = *this;
|
||||
if constexpr(std::is_base_of_v<XdlGemm_, ConvAlgorithmTemplate>)
|
||||
if constexpr(std::is_base_of_v<FwdXdlGemm_, ConvAlgorithmTemplate>)
|
||||
{
|
||||
result.gridwise_gemm = gemm;
|
||||
}
|
||||
else if constexpr(std::is_base_of_v<BwdXdlGemm_, ConvAlgorithmTemplate>)
|
||||
{
|
||||
result.gridwise_gemm = gemm;
|
||||
}
|
||||
@@ -337,46 +398,82 @@ struct ConvAlgorithmTemplate : Components...
|
||||
{
|
||||
result.gridwise_gemm = gemm;
|
||||
}
|
||||
else
|
||||
{
|
||||
static_assert(false, "Unrecognized GemmConfig type");
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
constexpr auto with_transfer(const T& t) const
|
||||
{
|
||||
static_assert(std::is_base_of_v<Transfer_, ConvAlgorithmTemplate>);
|
||||
static_assert(std::is_base_of_v<Transfer_<3>, ConvAlgorithmTemplate> ||
|
||||
std::is_base_of_v<Transfer_<4>, ConvAlgorithmTemplate>);
|
||||
auto result = *this;
|
||||
result.transfer = t;
|
||||
return result;
|
||||
}
|
||||
|
||||
constexpr auto with_specializations(ConvFwdSpecialization fwd_spec,
|
||||
GemmSpecialization gemm_spec) const
|
||||
constexpr auto with_fwd_specializations(ConvSpecialization fwd_spec,
|
||||
GemmSpecialization gemm_spec) const
|
||||
{
|
||||
static_assert(std::is_base_of_v<ConvSpecialization_, ConvAlgorithmTemplate>);
|
||||
static_assert(std::is_base_of_v<ConvSpecializationFwd_, ConvAlgorithmTemplate>);
|
||||
auto result = *this;
|
||||
result.fwd_specialization = fwd_spec;
|
||||
result.gemm_specialization = gemm_spec;
|
||||
return result;
|
||||
}
|
||||
|
||||
constexpr auto with_prefetch_config(size_t k_prefetch_stages,
|
||||
size_t groups_to_merge,
|
||||
PipelineScheduler scheduler) const
|
||||
constexpr auto with_bwd_specialization(ConvSpecialization bwd_spec) const
|
||||
{
|
||||
static_assert(std::is_base_of_v<ConvSpecializationBwdWeight_, ConvAlgorithmTemplate>);
|
||||
auto result = *this;
|
||||
result.bwd_weight_specialization = bwd_spec;
|
||||
return result;
|
||||
}
|
||||
|
||||
constexpr auto with_prefetch_config(size_t k_prefetch_stages, PipelineScheduler scheduler) const
|
||||
{
|
||||
static_assert(std::is_base_of_v<Prefetch_, ConvAlgorithmTemplate>);
|
||||
auto result = *this;
|
||||
result.num_gemm_k_prefetch_stages = k_prefetch_stages;
|
||||
result.num_groups_to_merge = groups_to_merge;
|
||||
result.loop_scheduler = scheduler;
|
||||
return result;
|
||||
}
|
||||
|
||||
constexpr auto with_transpose_params(size_t max_src_scalar_per_vector,
|
||||
size_t max_dst_scalar_per_vector) const
|
||||
{
|
||||
static_assert(std::is_base_of_v<TransposeParams_, ConvAlgorithmTemplate>);
|
||||
auto result = *this;
|
||||
result.max_transpose_transfer_src_scalar_per_vector = max_src_scalar_per_vector;
|
||||
result.max_transpose_transfer_dst_scalar_per_vector = max_dst_scalar_per_vector;
|
||||
return result;
|
||||
}
|
||||
|
||||
constexpr auto with_num_conv_groups_to_merge(size_t num_groups_to_merge) const
|
||||
{
|
||||
static_assert(std::is_base_of_v<GemmBatchOptions_, ConvAlgorithmTemplate>);
|
||||
auto result = *this;
|
||||
result.num_conv_groups_to_merge = num_groups_to_merge;
|
||||
return result;
|
||||
}
|
||||
|
||||
template <typename BG>
|
||||
constexpr auto with_block_gemm(const BG& bg) const
|
||||
{
|
||||
static_assert(std::is_base_of_v<BlockGemm_, ConvAlgorithmTemplate>);
|
||||
auto result = *this;
|
||||
result.block_gemm = bg;
|
||||
auto result = *this;
|
||||
result.block_gemm_pipeline = bg;
|
||||
return result;
|
||||
}
|
||||
|
||||
constexpr auto with_gridwise_gemm_pipeline(const PipelineVersion plv) const
|
||||
{
|
||||
static_assert(std::is_base_of_v<GridGemm_, ConvAlgorithmTemplate>);
|
||||
auto result = *this;
|
||||
result.pipeline_version = plv;
|
||||
return result;
|
||||
}
|
||||
|
||||
@@ -401,7 +498,8 @@ struct ConvAlgorithmTemplate : Components...
|
||||
template <typename T>
|
||||
constexpr auto with_dl_transfer(const T& t) const
|
||||
{
|
||||
static_assert(std::is_base_of_v<DlTransfer_, ConvAlgorithmTemplate>);
|
||||
static_assert(std::is_base_of_v<DlTransfer_<4>, ConvAlgorithmTemplate> ||
|
||||
std::is_base_of_v<DlTransfer_<5>, ConvAlgorithmTemplate>);
|
||||
auto result = *this;
|
||||
result.transfer = t;
|
||||
return result;
|
||||
@@ -453,26 +551,49 @@ struct ConvAlgorithmTemplate : Components...
|
||||
}
|
||||
};
|
||||
|
||||
// Algorithm types
|
||||
// Fwd algorithm types
|
||||
|
||||
using ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle =
|
||||
ConvAlgorithmTemplate<ThreadBlock_, XdlGemm_, Transfer_, ConvSpecialization_, Prefetch_>;
|
||||
ConvAlgorithmTemplate<ThreadBlock_,
|
||||
FwdXdlGemm_,
|
||||
Transfer_<>,
|
||||
ConvSpecializationFwd_,
|
||||
Prefetch_,
|
||||
GemmBatchOptions_>;
|
||||
|
||||
using ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 =
|
||||
ConvAlgorithmTemplate<ThreadBlock_, XdlGemm_, Transfer_, ConvSpecialization_, BlockGemm_>;
|
||||
ConvAlgorithmTemplate<ThreadBlock_,
|
||||
FwdXdlGemm_,
|
||||
Transfer_<>,
|
||||
ConvSpecializationFwd_,
|
||||
BlockGemm_>;
|
||||
|
||||
using ConvAlgorithm_DeviceGroupedConvFwdMultipleD_Wmma_CShuffle =
|
||||
ConvAlgorithmTemplate<ThreadBlock_, WmmaGemm_, Transfer_, ConvSpecialization_, Prefetch_>;
|
||||
ConvAlgorithmTemplate<ThreadBlock_,
|
||||
WmmaGemm_,
|
||||
Transfer_<>,
|
||||
ConvSpecializationFwd_,
|
||||
GridGemm_,
|
||||
Prefetch_,
|
||||
GemmBatchOptions_>;
|
||||
|
||||
using ConvAlgorithm_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK =
|
||||
ConvAlgorithmTemplate<ThreadBlock_,
|
||||
ConvSpecialization_,
|
||||
ConvSpecializationFwd_,
|
||||
DlThreadConfig_,
|
||||
DlThreadCluster_,
|
||||
DlTransfer_>;
|
||||
DlTransfer_<>>;
|
||||
|
||||
using ConvAlgorithm_DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor =
|
||||
LargeTensorWrapper<ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle>;
|
||||
ConvAlgorithmTemplate<ThreadBlock_,
|
||||
FwdXdlGemm_,
|
||||
Transfer_<>,
|
||||
ConvSpecializationFwd_,
|
||||
Prefetch_,
|
||||
GemmBatchOptions_,
|
||||
LargeTensorSpecialization_>;
|
||||
|
||||
// CK Tile algorithm
|
||||
using ConvAlgorithm_Tile_GroupedConvolutionKernel = ConvAlgorithmTemplate<TileThreadBlock_,
|
||||
TileBlockGemm_,
|
||||
TileTransfer_,
|
||||
@@ -488,4 +609,77 @@ struct ConvAlgorithm_Reference
|
||||
// GPU reference uses simple algorithm, no tile configuration needed
|
||||
};
|
||||
|
||||
// Bwd weight algorithm types
|
||||
using ConvAlgorithm_DeviceGroupedConvBwdWeight_Xdl_CShuffle =
|
||||
ConvAlgorithmTemplate<ThreadBlock_,
|
||||
BwdXdlGemm_,
|
||||
Transfer_<4>,
|
||||
ConvSpecializationBwdWeight_,
|
||||
TransposeParams_>;
|
||||
|
||||
using ConvAlgorithm_DeviceGroupedConvBwdWeight_TwoStage_Xdl_CShuffle =
|
||||
ConvAlgorithmTemplate<ThreadBlock_,
|
||||
BwdXdlGemm_,
|
||||
Transfer_<>,
|
||||
ConvSpecializationBwdWeight_,
|
||||
BlockGemm_,
|
||||
TransposeParams_,
|
||||
GemmBatchOptions_,
|
||||
TwoStageSpecialization_>;
|
||||
|
||||
using ConvAlgorithm_DeviceGroupedConvBwdWeight_Xdl_CShuffle_V3 =
|
||||
ConvAlgorithmTemplate<ThreadBlock_,
|
||||
BwdXdlGemm_,
|
||||
Transfer_<>,
|
||||
ConvSpecializationBwdWeight_,
|
||||
BlockGemm_>;
|
||||
|
||||
using ConvAlgorithm_DeviceGroupedConvBwdWeight_Dl =
|
||||
ConvAlgorithmTemplate<ThreadBlock_,
|
||||
DlThreadConfig_,
|
||||
DlThreadCluster_,
|
||||
DlTransfer_<5>,
|
||||
ConvSpecializationBwdWeight_>;
|
||||
|
||||
using ConvAlgorithm_DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle =
|
||||
ConvAlgorithmTemplate<ThreadBlock_,
|
||||
BwdXdlGemm_,
|
||||
Transfer_<4>,
|
||||
ConvSpecializationBwdWeight_,
|
||||
MultipleDSpecialization_>;
|
||||
|
||||
using ConvAlgorithm_DeviceGroupedConvBwdWeight_Wmma_CShuffle_V3 =
|
||||
ConvAlgorithmTemplate<ThreadBlock_,
|
||||
WmmaGemm_,
|
||||
Transfer_<>,
|
||||
ConvSpecializationBwdWeight_,
|
||||
BlockGemm_,
|
||||
TransposeParams_>;
|
||||
|
||||
using ConvAlgorithm_DeviceGroupedConvBwdWeight_TwoStage_Wmma_CShuffle_V3 =
|
||||
ConvAlgorithmTemplate<ThreadBlock_,
|
||||
WmmaGemm_,
|
||||
Transfer_<>,
|
||||
ConvSpecializationBwdWeight_,
|
||||
BlockGemm_,
|
||||
TransposeParams_,
|
||||
GemmBatchOptions_,
|
||||
TwoStageSpecialization_>;
|
||||
|
||||
using ConvAlgorithm_DeviceGroupedConvBwdWeight_Wmma_CShuffle =
|
||||
ConvAlgorithmTemplate<ThreadBlock_,
|
||||
WmmaGemm_,
|
||||
Transfer_<>,
|
||||
ConvSpecializationBwdWeight_,
|
||||
GridGemm_,
|
||||
Prefetch_>;
|
||||
|
||||
using ConvAlgorithm_DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffle_V3 =
|
||||
ConvAlgorithmTemplate<ThreadBlock_,
|
||||
WmmaGemm_,
|
||||
Transfer_<>,
|
||||
ConvSpecializationBwdWeight_,
|
||||
BlockGemm_,
|
||||
MultipleDSpecialization_>;
|
||||
|
||||
} // namespace ck_tile::builder::test
|
||||
|
||||
@@ -120,14 +120,12 @@ struct DefaultAlgorithm
|
||||
ckb::test::ThreadBlock thread_block{.block_size = 256,
|
||||
.tile_size = {.m = 256, .n = 256, .k = 32}};
|
||||
|
||||
ckb::test::GridwiseXdlGemm gridwise_gemm{.ak1 = 8,
|
||||
.bk1 = 8,
|
||||
.m_per_xdl = 16,
|
||||
.n_per_xdl = 16,
|
||||
.m_xdl_per_wave = 8,
|
||||
.n_xdl_per_wave = 8};
|
||||
ckb::test::GridwiseFwdXdlGemm gridwise_gemm{
|
||||
.ak1 = 8,
|
||||
.bk1 = 8,
|
||||
.xdl_params = {.m_per_xdl = 16, .n_per_xdl = 16, .m_xdl_per_wave = 8, .n_xdl_per_wave = 8}};
|
||||
|
||||
ckb::test::TransferABC transfer{
|
||||
ckb::test::Transfer<> transfer{
|
||||
.a =
|
||||
{
|
||||
.block_transfer = {.k0 = 1, .m_n = 128, .k1 = 2},
|
||||
@@ -161,10 +159,11 @@ struct DefaultAlgorithm
|
||||
},
|
||||
};
|
||||
|
||||
ckb::ConvFwdSpecialization fwd_specialization = ckb::ConvFwdSpecialization::DEFAULT;
|
||||
ckb::GemmSpecialization gemm_specialization = ckb::GemmSpecialization::Default;
|
||||
ckb::test::BlockGemm block_gemm{.pipeline_version = ckb::PipelineVersion::V4,
|
||||
.scheduler = ckb::PipelineScheduler::INTRAWAVE};
|
||||
ckb::ConvSpecialization fwd_specialization = ckb::ConvSpecialization::DEFAULT;
|
||||
ckb::GemmSpecialization gemm_specialization = ckb::GemmSpecialization::Default;
|
||||
ckb::test::BlockGemmPipeline block_gemm_pipeline{.pipeline_version = ckb::PipelineVersion::V4,
|
||||
.scheduler =
|
||||
ckb::PipelineScheduler::INTRAWAVE};
|
||||
};
|
||||
static_assert(ckb::ConvAlgorithmDescriptor<DefaultAlgorithm>);
|
||||
|
||||
|
||||
@@ -38,11 +38,11 @@ TEST(ConvTensorLayout, AssignsLayoutsFor1D_NWGC_GKXC_NWGK)
|
||||
.weight = {.config = {.layout = GKXC}},
|
||||
.output = {.config = {.layout = NWGK}}};
|
||||
|
||||
using TensorLayouts = ConvTensorLayouts<sig, 1, FORWARD>;
|
||||
using TensorLayouts = ConvTensorLayouts<sig, 1>;
|
||||
|
||||
EXPECT_TRUE((std::is_same_v<TensorLayouts::ALayout, ck::tensor_layout::convolution::NWGC>));
|
||||
EXPECT_TRUE((std::is_same_v<TensorLayouts::BLayout, ck::tensor_layout::convolution::GKXC>));
|
||||
EXPECT_TRUE((std::is_same_v<TensorLayouts::ELayout, ck::tensor_layout::convolution::NWGK>));
|
||||
EXPECT_TRUE((std::is_same_v<TensorLayouts::InLayout, ck::tensor_layout::convolution::NWGC>));
|
||||
EXPECT_TRUE((std::is_same_v<TensorLayouts::WeiLayout, ck::tensor_layout::convolution::GKXC>));
|
||||
EXPECT_TRUE((std::is_same_v<TensorLayouts::OutLayout, ck::tensor_layout::convolution::NWGK>));
|
||||
EXPECT_TRUE((std::is_same_v<TensorLayouts::DsLayout, ck::Tuple<>>));
|
||||
}
|
||||
|
||||
@@ -57,11 +57,11 @@ TEST(ConvTensorLayout, AssignsLayoutsFor1D_NGCW_GKXC_NGKW)
|
||||
.weight = {.config = {.layout = GKXC}},
|
||||
.output = {.config = {.layout = NGKW}}};
|
||||
|
||||
using TensorLayouts = ConvTensorLayouts<sig, 1, FORWARD>;
|
||||
using TensorLayouts = ConvTensorLayouts<sig, 1>;
|
||||
|
||||
EXPECT_TRUE((std::is_same_v<TensorLayouts::ALayout, ck::tensor_layout::convolution::NGCW>));
|
||||
EXPECT_TRUE((std::is_same_v<TensorLayouts::BLayout, ck::tensor_layout::convolution::GKXC>));
|
||||
EXPECT_TRUE((std::is_same_v<TensorLayouts::ELayout, ck::tensor_layout::convolution::NGKW>));
|
||||
EXPECT_TRUE((std::is_same_v<TensorLayouts::InLayout, ck::tensor_layout::convolution::NGCW>));
|
||||
EXPECT_TRUE((std::is_same_v<TensorLayouts::WeiLayout, ck::tensor_layout::convolution::GKXC>));
|
||||
EXPECT_TRUE((std::is_same_v<TensorLayouts::OutLayout, ck::tensor_layout::convolution::NGKW>));
|
||||
EXPECT_TRUE((std::is_same_v<TensorLayouts::DsLayout, ck::Tuple<>>));
|
||||
}
|
||||
|
||||
@@ -76,11 +76,11 @@ TEST(ConvTensorLayout, AssignsLayoutsFor1D_GNWC_GKXC_GNWK)
|
||||
.weight = {.config = {.layout = GKXC}},
|
||||
.output = {.config = {.layout = GNWK}}};
|
||||
|
||||
using TensorLayouts = ConvTensorLayouts<sig, 1, FORWARD>;
|
||||
using TensorLayouts = ConvTensorLayouts<sig, 1>;
|
||||
|
||||
EXPECT_TRUE((std::is_same_v<TensorLayouts::ALayout, ck::tensor_layout::convolution::GNWC>));
|
||||
EXPECT_TRUE((std::is_same_v<TensorLayouts::BLayout, ck::tensor_layout::convolution::GKXC>));
|
||||
EXPECT_TRUE((std::is_same_v<TensorLayouts::ELayout, ck::tensor_layout::convolution::GNWK>));
|
||||
EXPECT_TRUE((std::is_same_v<TensorLayouts::InLayout, ck::tensor_layout::convolution::GNWC>));
|
||||
EXPECT_TRUE((std::is_same_v<TensorLayouts::WeiLayout, ck::tensor_layout::convolution::GKXC>));
|
||||
EXPECT_TRUE((std::is_same_v<TensorLayouts::OutLayout, ck::tensor_layout::convolution::GNWK>));
|
||||
EXPECT_TRUE((std::is_same_v<TensorLayouts::DsLayout, ck::Tuple<>>));
|
||||
}
|
||||
|
||||
@@ -95,11 +95,11 @@ TEST(ConvTensorLayout, AssignsLayoutsFor1D_NGCW_GKCX_NGKW)
|
||||
.weight = {.config = {.layout = GKCX}},
|
||||
.output = {.config = {.layout = NGKW}}};
|
||||
|
||||
using TensorLayouts = ConvTensorLayouts<sig, 1, FORWARD>;
|
||||
using TensorLayouts = ConvTensorLayouts<sig, 1>;
|
||||
|
||||
EXPECT_TRUE((std::is_same_v<TensorLayouts::ALayout, ck::tensor_layout::convolution::NGCW>));
|
||||
EXPECT_TRUE((std::is_same_v<TensorLayouts::BLayout, ck::tensor_layout::convolution::GKCX>));
|
||||
EXPECT_TRUE((std::is_same_v<TensorLayouts::ELayout, ck::tensor_layout::convolution::NGKW>));
|
||||
EXPECT_TRUE((std::is_same_v<TensorLayouts::InLayout, ck::tensor_layout::convolution::NGCW>));
|
||||
EXPECT_TRUE((std::is_same_v<TensorLayouts::WeiLayout, ck::tensor_layout::convolution::GKCX>));
|
||||
EXPECT_TRUE((std::is_same_v<TensorLayouts::OutLayout, ck::tensor_layout::convolution::NGKW>));
|
||||
EXPECT_TRUE((std::is_same_v<TensorLayouts::DsLayout, ck::Tuple<>>));
|
||||
}
|
||||
|
||||
@@ -114,11 +114,11 @@ TEST(ConvTensorLayout, AssignsLayoutsFor2D_NGCHW_GKYXC_NGKHW)
|
||||
.weight = {.config = {.layout = GKYXC}},
|
||||
.output = {.config = {.layout = NGKHW}}};
|
||||
|
||||
using TensorLayouts = ConvTensorLayouts<sig, 2, FORWARD>;
|
||||
using TensorLayouts = ConvTensorLayouts<sig, 2>;
|
||||
|
||||
EXPECT_TRUE((std::is_same_v<TensorLayouts::ALayout, ck::tensor_layout::convolution::NGCHW>));
|
||||
EXPECT_TRUE((std::is_same_v<TensorLayouts::BLayout, ck::tensor_layout::convolution::GKYXC>));
|
||||
EXPECT_TRUE((std::is_same_v<TensorLayouts::ELayout, ck::tensor_layout::convolution::NGKHW>));
|
||||
EXPECT_TRUE((std::is_same_v<TensorLayouts::InLayout, ck::tensor_layout::convolution::NGCHW>));
|
||||
EXPECT_TRUE((std::is_same_v<TensorLayouts::WeiLayout, ck::tensor_layout::convolution::GKYXC>));
|
||||
EXPECT_TRUE((std::is_same_v<TensorLayouts::OutLayout, ck::tensor_layout::convolution::NGKHW>));
|
||||
EXPECT_TRUE((std::is_same_v<TensorLayouts::DsLayout, ck::Tuple<>>));
|
||||
}
|
||||
|
||||
@@ -133,11 +133,11 @@ TEST(ConvTensorLayout, AssignsLayoutsFor2D_NHWGC_GKYXC_NHWGK)
|
||||
.weight = {.config = {.layout = GKYXC}},
|
||||
.output = {.config = {.layout = NHWGK}}};
|
||||
|
||||
using TensorLayouts = ConvTensorLayouts<sig, 2, FORWARD>;
|
||||
using TensorLayouts = ConvTensorLayouts<sig, 2>;
|
||||
|
||||
EXPECT_TRUE((std::is_same_v<TensorLayouts::ALayout, ck::tensor_layout::convolution::NHWGC>));
|
||||
EXPECT_TRUE((std::is_same_v<TensorLayouts::BLayout, ck::tensor_layout::convolution::GKYXC>));
|
||||
EXPECT_TRUE((std::is_same_v<TensorLayouts::ELayout, ck::tensor_layout::convolution::NHWGK>));
|
||||
EXPECT_TRUE((std::is_same_v<TensorLayouts::InLayout, ck::tensor_layout::convolution::NHWGC>));
|
||||
EXPECT_TRUE((std::is_same_v<TensorLayouts::WeiLayout, ck::tensor_layout::convolution::GKYXC>));
|
||||
EXPECT_TRUE((std::is_same_v<TensorLayouts::OutLayout, ck::tensor_layout::convolution::NHWGK>));
|
||||
EXPECT_TRUE((std::is_same_v<TensorLayouts::DsLayout, ck::Tuple<>>));
|
||||
}
|
||||
|
||||
@@ -152,11 +152,11 @@ TEST(ConvTensorLayout, AssignsLayoutsFor2D_GNHWC_GKYXC_GNHWK)
|
||||
.weight = {.config = {.layout = GKYXC}},
|
||||
.output = {.config = {.layout = GNHWK}}};
|
||||
|
||||
using TensorLayouts = ConvTensorLayouts<sig, 2, FORWARD>;
|
||||
using TensorLayouts = ConvTensorLayouts<sig, 2>;
|
||||
|
||||
EXPECT_TRUE((std::is_same_v<TensorLayouts::ALayout, ck::tensor_layout::convolution::GNHWC>));
|
||||
EXPECT_TRUE((std::is_same_v<TensorLayouts::BLayout, ck::tensor_layout::convolution::GKYXC>));
|
||||
EXPECT_TRUE((std::is_same_v<TensorLayouts::ELayout, ck::tensor_layout::convolution::GNHWK>));
|
||||
EXPECT_TRUE((std::is_same_v<TensorLayouts::InLayout, ck::tensor_layout::convolution::GNHWC>));
|
||||
EXPECT_TRUE((std::is_same_v<TensorLayouts::WeiLayout, ck::tensor_layout::convolution::GKYXC>));
|
||||
EXPECT_TRUE((std::is_same_v<TensorLayouts::OutLayout, ck::tensor_layout::convolution::GNHWK>));
|
||||
EXPECT_TRUE((std::is_same_v<TensorLayouts::DsLayout, ck::Tuple<>>));
|
||||
}
|
||||
|
||||
@@ -171,11 +171,11 @@ TEST(ConvTensorLayout, AssignsLayoutsFor2D_NGCHW_GKCYX_NGKHW)
|
||||
.weight = {.config = {.layout = GKCYX}},
|
||||
.output = {.config = {.layout = NGKHW}}};
|
||||
|
||||
using TensorLayouts = ConvTensorLayouts<sig, 2, FORWARD>;
|
||||
using TensorLayouts = ConvTensorLayouts<sig, 2>;
|
||||
|
||||
EXPECT_TRUE((std::is_same_v<TensorLayouts::ALayout, ck::tensor_layout::convolution::NGCHW>));
|
||||
EXPECT_TRUE((std::is_same_v<TensorLayouts::BLayout, ck::tensor_layout::convolution::GKCYX>));
|
||||
EXPECT_TRUE((std::is_same_v<TensorLayouts::ELayout, ck::tensor_layout::convolution::NGKHW>));
|
||||
EXPECT_TRUE((std::is_same_v<TensorLayouts::InLayout, ck::tensor_layout::convolution::NGCHW>));
|
||||
EXPECT_TRUE((std::is_same_v<TensorLayouts::WeiLayout, ck::tensor_layout::convolution::GKCYX>));
|
||||
EXPECT_TRUE((std::is_same_v<TensorLayouts::OutLayout, ck::tensor_layout::convolution::NGKHW>));
|
||||
EXPECT_TRUE((std::is_same_v<TensorLayouts::DsLayout, ck::Tuple<>>));
|
||||
}
|
||||
|
||||
@@ -190,11 +190,11 @@ TEST(ConvTensorLayout, AssignsLayoutsFor3D_NGCDHW_GKCZYX_NGKDHW)
|
||||
.weight = {.config = {.layout = GKCZYX}},
|
||||
.output = {.config = {.layout = NGKDHW}}};
|
||||
|
||||
using TensorLayouts = ConvTensorLayouts<sig, 3, FORWARD>;
|
||||
using TensorLayouts = ConvTensorLayouts<sig, 3>;
|
||||
|
||||
EXPECT_TRUE((std::is_same_v<TensorLayouts::ALayout, ck::tensor_layout::convolution::NGCDHW>));
|
||||
EXPECT_TRUE((std::is_same_v<TensorLayouts::BLayout, ck::tensor_layout::convolution::GKCZYX>));
|
||||
EXPECT_TRUE((std::is_same_v<TensorLayouts::ELayout, ck::tensor_layout::convolution::NGKDHW>));
|
||||
EXPECT_TRUE((std::is_same_v<TensorLayouts::InLayout, ck::tensor_layout::convolution::NGCDHW>));
|
||||
EXPECT_TRUE((std::is_same_v<TensorLayouts::WeiLayout, ck::tensor_layout::convolution::GKCZYX>));
|
||||
EXPECT_TRUE((std::is_same_v<TensorLayouts::OutLayout, ck::tensor_layout::convolution::NGKDHW>));
|
||||
EXPECT_TRUE((std::is_same_v<TensorLayouts::DsLayout, ck::Tuple<>>));
|
||||
}
|
||||
|
||||
@@ -209,11 +209,11 @@ TEST(ConvTensorLayout, AssignsLayoutsFor3D_NDHWGC_GKZYXC_NDHWGK)
|
||||
.weight = {.config = {.layout = GKZYXC}},
|
||||
.output = {.config = {.layout = NDHWGK}}};
|
||||
|
||||
using TensorLayouts = ConvTensorLayouts<sig, 3, FORWARD>;
|
||||
using TensorLayouts = ConvTensorLayouts<sig, 3>;
|
||||
|
||||
EXPECT_TRUE((std::is_same_v<TensorLayouts::ALayout, ck::tensor_layout::convolution::NDHWGC>));
|
||||
EXPECT_TRUE((std::is_same_v<TensorLayouts::BLayout, ck::tensor_layout::convolution::GKZYXC>));
|
||||
EXPECT_TRUE((std::is_same_v<TensorLayouts::ELayout, ck::tensor_layout::convolution::NDHWGK>));
|
||||
EXPECT_TRUE((std::is_same_v<TensorLayouts::InLayout, ck::tensor_layout::convolution::NDHWGC>));
|
||||
EXPECT_TRUE((std::is_same_v<TensorLayouts::WeiLayout, ck::tensor_layout::convolution::GKZYXC>));
|
||||
EXPECT_TRUE((std::is_same_v<TensorLayouts::OutLayout, ck::tensor_layout::convolution::NDHWGK>));
|
||||
EXPECT_TRUE((std::is_same_v<TensorLayouts::DsLayout, ck::Tuple<>>));
|
||||
}
|
||||
|
||||
@@ -228,11 +228,11 @@ TEST(ConvTensorLayout, AssignsLayoutsFor3D_GNDHWC_GKZYXC_GNDHWK)
|
||||
.weight = {.config = {.layout = GKZYXC}},
|
||||
.output = {.config = {.layout = GNDHWK}}};
|
||||
|
||||
using TensorLayouts = ConvTensorLayouts<sig, 3, FORWARD>;
|
||||
using TensorLayouts = ConvTensorLayouts<sig, 3>;
|
||||
|
||||
EXPECT_TRUE((std::is_same_v<TensorLayouts::ALayout, ck::tensor_layout::convolution::GNDHWC>));
|
||||
EXPECT_TRUE((std::is_same_v<TensorLayouts::BLayout, ck::tensor_layout::convolution::GKZYXC>));
|
||||
EXPECT_TRUE((std::is_same_v<TensorLayouts::ELayout, ck::tensor_layout::convolution::GNDHWK>));
|
||||
EXPECT_TRUE((std::is_same_v<TensorLayouts::InLayout, ck::tensor_layout::convolution::GNDHWC>));
|
||||
EXPECT_TRUE((std::is_same_v<TensorLayouts::WeiLayout, ck::tensor_layout::convolution::GKZYXC>));
|
||||
EXPECT_TRUE((std::is_same_v<TensorLayouts::OutLayout, ck::tensor_layout::convolution::GNDHWK>));
|
||||
EXPECT_TRUE((std::is_same_v<TensorLayouts::DsLayout, ck::Tuple<>>));
|
||||
}
|
||||
|
||||
@@ -273,7 +273,7 @@ TEST(AuxiliaryTensorLayoutIntegration, SingleBiasTensorWithG_K_Layout)
|
||||
static constexpr std::array<MockAuxiliaryTensorConfig, 1> aux_configs = {
|
||||
MockAuxiliaryTensorConfig{.layout = G_K_strided}};
|
||||
|
||||
using AuxLayouts = AuxiliaryTensorLayouts<aux_configs, 2, FORWARD>;
|
||||
using AuxLayouts = AuxiliaryTensorLayouts<aux_configs, 2>;
|
||||
|
||||
EXPECT_EQ(AuxLayouts::Size, 1);
|
||||
using ExpectedType = ck::Tuple<ck::tensor_layout::convolution::G_K>;
|
||||
@@ -287,7 +287,7 @@ TEST(AuxiliaryTensorLayoutIntegration, SingleBiasTensorWithGC_Layout)
|
||||
static constexpr std::array<MockAuxiliaryTensorConfig, 1> aux_configs = {
|
||||
MockAuxiliaryTensorConfig{.layout = TensorLayout::GC}};
|
||||
|
||||
using AuxLayouts = AuxiliaryTensorLayouts<aux_configs, 2, FORWARD>;
|
||||
using AuxLayouts = AuxiliaryTensorLayouts<aux_configs, 2>;
|
||||
|
||||
EXPECT_EQ(AuxLayouts::Size, 1);
|
||||
using ExpectedType = ck::Tuple<ck::tensor_layout::convolution::GC>;
|
||||
@@ -301,7 +301,7 @@ TEST(AuxiliaryTensorLayoutIntegration, SingleBiasTensorWithG_C_Layout)
|
||||
static constexpr std::array<MockAuxiliaryTensorConfig, 1> aux_configs = {
|
||||
MockAuxiliaryTensorConfig{.layout = G_C_strided}};
|
||||
|
||||
using AuxLayouts = AuxiliaryTensorLayouts<aux_configs, 2, FORWARD>;
|
||||
using AuxLayouts = AuxiliaryTensorLayouts<aux_configs, 2>;
|
||||
|
||||
EXPECT_EQ(AuxLayouts::Size, 1);
|
||||
using ExpectedType = ck::Tuple<ck::tensor_layout::convolution::G_C>;
|
||||
@@ -316,7 +316,7 @@ TEST(AuxiliaryTensorLayoutIntegration, TwoAuxiliaryTensors)
|
||||
MockAuxiliaryTensorConfig{.layout = TensorLayout::G_K_strided},
|
||||
MockAuxiliaryTensorConfig{.layout = GC}};
|
||||
|
||||
using AuxLayouts = AuxiliaryTensorLayouts<aux_configs, 2, FORWARD>;
|
||||
using AuxLayouts = AuxiliaryTensorLayouts<aux_configs, 2>;
|
||||
|
||||
EXPECT_EQ(AuxLayouts::Size, 2);
|
||||
using ExpectedType =
|
||||
@@ -333,7 +333,7 @@ TEST(AuxiliaryTensorLayoutIntegration, ThreeAuxiliaryTensors)
|
||||
MockAuxiliaryTensorConfig{.layout = GC},
|
||||
MockAuxiliaryTensorConfig{.layout = G_C_strided}};
|
||||
|
||||
using AuxLayouts = AuxiliaryTensorLayouts<aux_configs, 2, FORWARD>;
|
||||
using AuxLayouts = AuxiliaryTensorLayouts<aux_configs, 2>;
|
||||
|
||||
EXPECT_EQ(AuxLayouts::Size, 3);
|
||||
using ExpectedType = ck::Tuple<ck::tensor_layout::convolution::G_K,
|
||||
@@ -349,7 +349,7 @@ TEST(AuxiliaryTensorLayoutIntegration, WorksWith1DConvolution)
|
||||
static constexpr std::array<MockAuxiliaryTensorConfig, 1> aux_configs = {
|
||||
MockAuxiliaryTensorConfig{.layout = G_K_strided}};
|
||||
|
||||
using AuxLayouts = AuxiliaryTensorLayouts<aux_configs, 1, FORWARD>;
|
||||
using AuxLayouts = AuxiliaryTensorLayouts<aux_configs, 1>;
|
||||
|
||||
EXPECT_EQ(AuxLayouts::Size, 1);
|
||||
using ExpectedType = ck::Tuple<ck::tensor_layout::convolution::G_K>;
|
||||
@@ -363,7 +363,7 @@ TEST(AuxiliaryTensorLayoutIntegration, WorksWith3DConvolution)
|
||||
static constexpr std::array<MockAuxiliaryTensorConfig, 1> aux_configs = {
|
||||
MockAuxiliaryTensorConfig{.layout = GC}};
|
||||
|
||||
using AuxLayouts = AuxiliaryTensorLayouts<aux_configs, 3, FORWARD>;
|
||||
using AuxLayouts = AuxiliaryTensorLayouts<aux_configs, 3>;
|
||||
|
||||
EXPECT_EQ(AuxLayouts::Size, 1);
|
||||
using ExpectedType = ck::Tuple<ck::tensor_layout::convolution::GC>;
|
||||
@@ -387,11 +387,11 @@ TEST(ConvTensorLayoutsWithAuxiliary, Conv2DWithSingleBiasG_K)
|
||||
.operation =
|
||||
OutputOp{.elementwise_operation = ElementwiseOperation::SCALE}}};
|
||||
|
||||
using TensorLayouts = ConvTensorLayouts<sig, 2, FORWARD>;
|
||||
using TensorLayouts = ConvTensorLayouts<sig, 2>;
|
||||
|
||||
EXPECT_TRUE((std::is_same_v<TensorLayouts::ALayout, ck::tensor_layout::convolution::NGCHW>));
|
||||
EXPECT_TRUE((std::is_same_v<TensorLayouts::BLayout, ck::tensor_layout::convolution::GKYXC>));
|
||||
EXPECT_TRUE((std::is_same_v<TensorLayouts::ELayout, ck::tensor_layout::convolution::NGKHW>));
|
||||
EXPECT_TRUE((std::is_same_v<TensorLayouts::InLayout, ck::tensor_layout::convolution::NGCHW>));
|
||||
EXPECT_TRUE((std::is_same_v<TensorLayouts::WeiLayout, ck::tensor_layout::convolution::GKYXC>));
|
||||
EXPECT_TRUE((std::is_same_v<TensorLayouts::OutLayout, ck::tensor_layout::convolution::NGKHW>));
|
||||
|
||||
using ExpectedDsLayout = ck::Tuple<ck::tensor_layout::convolution::G_K>;
|
||||
EXPECT_TRUE((std::is_same_v<TensorLayouts::DsLayout, ExpectedDsLayout>));
|
||||
@@ -414,11 +414,11 @@ TEST(ConvTensorLayoutsWithAuxiliary, Conv2DWithSingleBiasGC)
|
||||
.operation =
|
||||
OutputOp{.elementwise_operation = ElementwiseOperation::SCALE}}};
|
||||
|
||||
using TensorLayouts = ConvTensorLayouts<sig, 2, FORWARD>;
|
||||
using TensorLayouts = ConvTensorLayouts<sig, 2>;
|
||||
|
||||
EXPECT_TRUE((std::is_same_v<TensorLayouts::ALayout, ck::tensor_layout::convolution::NHWGC>));
|
||||
EXPECT_TRUE((std::is_same_v<TensorLayouts::BLayout, ck::tensor_layout::convolution::GKYXC>));
|
||||
EXPECT_TRUE((std::is_same_v<TensorLayouts::ELayout, ck::tensor_layout::convolution::NHWGK>));
|
||||
EXPECT_TRUE((std::is_same_v<TensorLayouts::InLayout, ck::tensor_layout::convolution::NHWGC>));
|
||||
EXPECT_TRUE((std::is_same_v<TensorLayouts::WeiLayout, ck::tensor_layout::convolution::GKYXC>));
|
||||
EXPECT_TRUE((std::is_same_v<TensorLayouts::OutLayout, ck::tensor_layout::convolution::NHWGK>));
|
||||
|
||||
using ExpectedDsLayout = ck::Tuple<ck::tensor_layout::convolution::GC>;
|
||||
EXPECT_TRUE((std::is_same_v<TensorLayouts::DsLayout, ExpectedDsLayout>));
|
||||
@@ -442,11 +442,11 @@ TEST(ConvTensorLayoutsWithAuxiliary, Conv2DWithTwoAuxiliaryTensors)
|
||||
.operation = OutputOp{.elementwise_operation =
|
||||
ElementwiseOperation::SCALEADD_SCALEADD_RELU}}};
|
||||
|
||||
using TensorLayouts = ConvTensorLayouts<sig, 2, FORWARD>;
|
||||
using TensorLayouts = ConvTensorLayouts<sig, 2>;
|
||||
|
||||
EXPECT_TRUE((std::is_same_v<TensorLayouts::ALayout, ck::tensor_layout::convolution::GNHWC>));
|
||||
EXPECT_TRUE((std::is_same_v<TensorLayouts::BLayout, ck::tensor_layout::convolution::GKYXC>));
|
||||
EXPECT_TRUE((std::is_same_v<TensorLayouts::ELayout, ck::tensor_layout::convolution::GNHWK>));
|
||||
EXPECT_TRUE((std::is_same_v<TensorLayouts::InLayout, ck::tensor_layout::convolution::GNHWC>));
|
||||
EXPECT_TRUE((std::is_same_v<TensorLayouts::WeiLayout, ck::tensor_layout::convolution::GKYXC>));
|
||||
EXPECT_TRUE((std::is_same_v<TensorLayouts::OutLayout, ck::tensor_layout::convolution::GNHWK>));
|
||||
|
||||
using ExpectedDsLayout =
|
||||
ck::Tuple<ck::tensor_layout::convolution::G_K, ck::tensor_layout::convolution::GC>;
|
||||
@@ -470,11 +470,11 @@ TEST(ConvTensorLayoutsWithAuxiliary, Conv1DWithBias)
|
||||
.operation =
|
||||
OutputOp{.elementwise_operation = ElementwiseOperation::SCALE}}};
|
||||
|
||||
using TensorLayouts = ConvTensorLayouts<sig, 1, FORWARD>;
|
||||
using TensorLayouts = ConvTensorLayouts<sig, 1>;
|
||||
|
||||
EXPECT_TRUE((std::is_same_v<TensorLayouts::ALayout, ck::tensor_layout::convolution::NWGC>));
|
||||
EXPECT_TRUE((std::is_same_v<TensorLayouts::BLayout, ck::tensor_layout::convolution::GKXC>));
|
||||
EXPECT_TRUE((std::is_same_v<TensorLayouts::ELayout, ck::tensor_layout::convolution::NWGK>));
|
||||
EXPECT_TRUE((std::is_same_v<TensorLayouts::InLayout, ck::tensor_layout::convolution::NWGC>));
|
||||
EXPECT_TRUE((std::is_same_v<TensorLayouts::WeiLayout, ck::tensor_layout::convolution::GKXC>));
|
||||
EXPECT_TRUE((std::is_same_v<TensorLayouts::OutLayout, ck::tensor_layout::convolution::NWGK>));
|
||||
|
||||
using ExpectedDsLayout = ck::Tuple<ck::tensor_layout::convolution::G_K>;
|
||||
EXPECT_TRUE((std::is_same_v<TensorLayouts::DsLayout, ExpectedDsLayout>));
|
||||
@@ -497,11 +497,11 @@ TEST(ConvTensorLayoutsWithAuxiliary, Conv3DWithBias)
|
||||
.operation = OutputOp{.elementwise_operation =
|
||||
ElementwiseOperation::BIAS_BNORM_CLAMP}}};
|
||||
|
||||
using TensorLayouts = ConvTensorLayouts<sig, 3, FORWARD>;
|
||||
using TensorLayouts = ConvTensorLayouts<sig, 3>;
|
||||
|
||||
EXPECT_TRUE((std::is_same_v<TensorLayouts::ALayout, ck::tensor_layout::convolution::NDHWGC>));
|
||||
EXPECT_TRUE((std::is_same_v<TensorLayouts::BLayout, ck::tensor_layout::convolution::GKZYXC>));
|
||||
EXPECT_TRUE((std::is_same_v<TensorLayouts::ELayout, ck::tensor_layout::convolution::NDHWGK>));
|
||||
EXPECT_TRUE((std::is_same_v<TensorLayouts::InLayout, ck::tensor_layout::convolution::NDHWGC>));
|
||||
EXPECT_TRUE((std::is_same_v<TensorLayouts::WeiLayout, ck::tensor_layout::convolution::GKZYXC>));
|
||||
EXPECT_TRUE((std::is_same_v<TensorLayouts::OutLayout, ck::tensor_layout::convolution::NDHWGK>));
|
||||
|
||||
using ExpectedDsLayout = ck::Tuple<ck::tensor_layout::convolution::G_C>;
|
||||
EXPECT_TRUE((std::is_same_v<TensorLayouts::DsLayout, ExpectedDsLayout>));
|
||||
|
||||
@@ -19,7 +19,7 @@ TEST(ConvTuningParams, AssignsBlockGemmParams)
|
||||
{
|
||||
ckb::PipelineVersion pipeline_version = ckb::PipelineVersion::V3;
|
||||
ckb::PipelineScheduler scheduler = ckb::PipelineScheduler::INTRAWAVE;
|
||||
} block_gemm;
|
||||
} block_gemm_pipeline;
|
||||
} kAlgorithm;
|
||||
constexpr auto block_gemm = SetBlockGemm<kAlgorithm>();
|
||||
|
||||
@@ -42,10 +42,7 @@ TEST(ConvTuningParams, AssignsGridwiseGemmPipelineVersion)
|
||||
{
|
||||
constexpr struct Algorithm
|
||||
{
|
||||
struct GridwiseGemm
|
||||
{
|
||||
ckb::PipelineVersion pipeline_version = ckb::PipelineVersion::V4;
|
||||
} gridwise_gemm;
|
||||
ckb::PipelineVersion pipeline_version = ckb::PipelineVersion::V4;
|
||||
} kAlgorithm;
|
||||
constexpr auto pipeline_version = SetGridwiseGemmPipelineVersion<kAlgorithm>();
|
||||
|
||||
@@ -78,8 +75,8 @@ TEST(ConvTuningParams, AssignsFwdConvSpecialization)
|
||||
{
|
||||
constexpr struct Algorithm
|
||||
{
|
||||
ckb::ConvFwdSpecialization fwd_specialization =
|
||||
ckb::ConvFwdSpecialization::FILTER_1X1_STRIDE1_PAD0;
|
||||
ckb::ConvSpecialization fwd_specialization =
|
||||
ckb::ConvSpecialization::FILTER_1X1_STRIDE1_PAD0;
|
||||
} kAlgorithm;
|
||||
constexpr auto conv_spec = SetFwdConvSpecialization<kAlgorithm>();
|
||||
|
||||
|
||||
@@ -15,31 +15,42 @@ using namespace test;
|
||||
constexpr DlThreadConfig DlThreadConfig_16x2x4x4x1{
|
||||
.k0_per_block = 16, .k1 = 2, .m1_per_thread = 4, .n1_per_thread = 4, .k_per_thread = 1};
|
||||
|
||||
constexpr DlThreadConfig DlThreadConfig_16x1x4x4x1{
|
||||
.k0_per_block = 16, .k1 = 1, .m1_per_thread = 4, .n1_per_thread = 4, .k_per_thread = 1};
|
||||
|
||||
constexpr DlThreadCluster DlThreadCluster_8x2{.m1_xs = {8, 2}, .n1_xs = {8, 2}};
|
||||
|
||||
constexpr DlBlockTransfer DlBlockTransferAB{.thread_slice_lengths = {8, 1, 1, 2},
|
||||
.thread_cluster_lengths = {2, 1, 128, 1},
|
||||
.thread_cluster_arrange_order = {1, 2, 0, 3},
|
||||
.src_access_order = {1, 2, 0, 3},
|
||||
.src_vector_tensor_lengths = {4, 1, 1, 2},
|
||||
.src_vector_tensor_contiguous_dim_order = {1, 2, 0, 3},
|
||||
.dst_vector_tensor_lengths = {1, 1, 1, 2}};
|
||||
constexpr DlBlockTransfer<4> DlBlockTransfer_8x1x1x2{
|
||||
.thread_slice_lengths = {8, 1, 1, 2},
|
||||
.thread_cluster_lengths = {2, 1, 128, 1},
|
||||
.thread_cluster_arrange_order = {1, 2, 0, 3},
|
||||
.src_access_order = {1, 2, 0, 3},
|
||||
.src_vector_tensor_lengths = {4, 1, 1, 2},
|
||||
.src_vector_tensor_contiguous_dim_order = {1, 2, 0, 3},
|
||||
.dst_vector_tensor_lengths = {1, 1, 1, 2}};
|
||||
|
||||
constexpr DlTransferABC DlFwdTransfer{.a =
|
||||
{
|
||||
.block_transfer = DlBlockTransferAB,
|
||||
},
|
||||
.b =
|
||||
{
|
||||
.block_transfer = DlBlockTransferAB,
|
||||
},
|
||||
.c = {
|
||||
.epilogue = {.src_dst_access_order = {0, 1, 2, 3, 4, 5},
|
||||
.src_dst_vector_dim = 5,
|
||||
.dst_scalar_per_vector = 4},
|
||||
}};
|
||||
constexpr DlTransfer<4> DlTransfer4D{.a = DlBlockTransfer_8x1x1x2,
|
||||
.b = DlBlockTransfer_8x1x1x2,
|
||||
.c = {.src_dst_access_order = {0, 1, 2, 3, 4, 5},
|
||||
.src_dst_vector_dim = 5,
|
||||
.dst_scalar_per_vector = 4}};
|
||||
|
||||
constexpr TransferABC FwdTransfer_4x64x1{
|
||||
constexpr DlBlockTransfer<5> DlBlockTransfer_1x8x1x1x1{
|
||||
.thread_slice_lengths = {1, 8, 1, 1, 1},
|
||||
.thread_cluster_lengths = {1, 2, 1, 128, 1},
|
||||
.thread_cluster_arrange_order = {0, 2, 3, 1, 4},
|
||||
.src_access_order = {0, 2, 3, 1, 4},
|
||||
.src_vector_tensor_lengths = {1, 1, 1, 1, 1},
|
||||
.src_vector_tensor_contiguous_dim_order = {0, 2, 3, 1, 4},
|
||||
.dst_vector_tensor_lengths = {1, 1, 1, 1, 1}};
|
||||
|
||||
constexpr DlTransfer<5> DlTransfer5D{.a = DlBlockTransfer_1x8x1x1x1,
|
||||
.b = DlBlockTransfer_1x8x1x1x1,
|
||||
.c = {.src_dst_access_order = {0, 1, 2, 3, 4, 5},
|
||||
.src_dst_vector_dim = 5,
|
||||
.dst_scalar_per_vector = 1}};
|
||||
|
||||
constexpr Transfer<> Transfer_4x64x1{
|
||||
.a =
|
||||
{
|
||||
.block_transfer = {.k0 = 4, .m_n = 64, .k1 = 1},
|
||||
@@ -72,7 +83,73 @@ constexpr TransferABC FwdTransfer_4x64x1{
|
||||
},
|
||||
};
|
||||
|
||||
constexpr TransferABC FwdTransfer_4x64x1_fp8{
|
||||
constexpr Transfer<4> BwdTransfer_4x64x1{
|
||||
.a =
|
||||
{
|
||||
.block_transfer = {.k0 = 4, .m_n = 64, .k1 = 1, .k_batch_size = 1},
|
||||
.lds_transfer = {.src_vector_dim = 2,
|
||||
.src_scalar_per_vector = 2,
|
||||
.lds_dst_scalar_per_vector = 4,
|
||||
.is_direct_load = false,
|
||||
.lds_padding = true},
|
||||
.block_transfer_access_order = {0, 3, 1, 2},
|
||||
.src_access_order = {0, 2, 1, 3},
|
||||
},
|
||||
.b =
|
||||
{
|
||||
.block_transfer = {.k0 = 4, .m_n = 64, .k1 = 1, .k_batch_size = 1},
|
||||
.lds_transfer = {.src_vector_dim = 2,
|
||||
.src_scalar_per_vector = 2,
|
||||
.lds_dst_scalar_per_vector = 4,
|
||||
.is_direct_load = false,
|
||||
.lds_padding = true},
|
||||
.block_transfer_access_order = {0, 3, 1, 2},
|
||||
.src_access_order = {0, 2, 1, 3},
|
||||
},
|
||||
.c =
|
||||
{
|
||||
.thread_cluster_dims =
|
||||
{.m_block = 1, .m_wave_per_xdl = 32, .n_block = 1, .n_wave_per_xdl = 8},
|
||||
.epilogue = {.m_xdl_per_wave_per_shuffle = 1,
|
||||
.n_per_wave_per_shuffle = 1,
|
||||
.scalar_per_vector = 8},
|
||||
},
|
||||
};
|
||||
|
||||
constexpr Transfer<> BwdTransfer_4x8x1_4x16x1_v3{
|
||||
.a =
|
||||
{
|
||||
.block_transfer = {.k0 = 4, .m_n = 8, .k1 = 1},
|
||||
.lds_transfer = {.src_vector_dim = 1,
|
||||
.src_scalar_per_vector = 2,
|
||||
.lds_dst_scalar_per_vector = 2,
|
||||
.is_direct_load = false,
|
||||
.lds_padding = false},
|
||||
.block_transfer_access_order = {2, 0, 1},
|
||||
.src_access_order = {1, 0, 2},
|
||||
},
|
||||
.b =
|
||||
{
|
||||
.block_transfer = {.k0 = 4, .m_n = 16, .k1 = 1},
|
||||
.lds_transfer = {.src_vector_dim = 1,
|
||||
.src_scalar_per_vector = 2,
|
||||
.lds_dst_scalar_per_vector = 2,
|
||||
.is_direct_load = false,
|
||||
.lds_padding = false},
|
||||
.block_transfer_access_order = {2, 0, 1},
|
||||
.src_access_order = {1, 0, 2},
|
||||
},
|
||||
.c =
|
||||
{
|
||||
.thread_cluster_dims =
|
||||
{.m_block = 1, .m_wave_per_xdl = 8, .n_block = 1, .n_wave_per_xdl = 8},
|
||||
.epilogue = {.m_xdl_per_wave_per_shuffle = 1,
|
||||
.n_per_wave_per_shuffle = 1,
|
||||
.scalar_per_vector = 2},
|
||||
},
|
||||
};
|
||||
|
||||
constexpr Transfer<> Transfer_4x64x1_fp8{
|
||||
.a =
|
||||
{
|
||||
.block_transfer = {.k0 = 4, .m_n = 64, .k1 = 1},
|
||||
@@ -105,7 +182,7 @@ constexpr TransferABC FwdTransfer_4x64x1_fp8{
|
||||
},
|
||||
};
|
||||
|
||||
constexpr TransferABC FwdTransfer_4x16x1{
|
||||
constexpr Transfer<> Transfer_4x16x1{
|
||||
.a =
|
||||
{
|
||||
.block_transfer = {.k0 = 4, .m_n = 16, .k1 = 1},
|
||||
@@ -139,7 +216,7 @@ constexpr TransferABC FwdTransfer_4x16x1{
|
||||
},
|
||||
};
|
||||
|
||||
constexpr TransferABC FwdTransfer_4x32x1{
|
||||
constexpr Transfer<> Transfer_4x32x1{
|
||||
.a =
|
||||
{
|
||||
.block_transfer = {.k0 = 4, .m_n = 32, .k1 = 1},
|
||||
@@ -172,59 +249,80 @@ constexpr TransferABC FwdTransfer_4x32x1{
|
||||
},
|
||||
};
|
||||
|
||||
constexpr GridwiseXdlGemm FwdGemmParams_Xdl_4x4_per_wave{
|
||||
.ak1 = 8, .bk1 = 8, .m_per_xdl = 32, .n_per_xdl = 32, .m_xdl_per_wave = 4, .n_xdl_per_wave = 4};
|
||||
constexpr GridwiseBwdXdlGemm BwdGemmParams_Xdl_4x4_per_wave{
|
||||
.k1 = 8,
|
||||
.xdl_params = {.m_per_xdl = 32, .n_per_xdl = 32, .m_xdl_per_wave = 4, .n_xdl_per_wave = 4}};
|
||||
|
||||
constexpr GridwiseXdlGemm FwdGemmParams_Xdl_4x2_per_wave{
|
||||
.ak1 = 8, .bk1 = 8, .m_per_xdl = 32, .n_per_xdl = 32, .m_xdl_per_wave = 4, .n_xdl_per_wave = 2};
|
||||
constexpr GridwiseBwdXdlGemm BwdGemmParams_Xdl_1x1_per_wave{
|
||||
.k1 = 8,
|
||||
.xdl_params = {.m_per_xdl = 32, .n_per_xdl = 32, .m_xdl_per_wave = 1, .n_xdl_per_wave = 1}};
|
||||
|
||||
constexpr GridwiseXdlGemm FwdGemmParams_Xdl_2x2_per_wave{
|
||||
.ak1 = 8, .bk1 = 8, .m_per_xdl = 32, .n_per_xdl = 32, .m_xdl_per_wave = 2, .n_xdl_per_wave = 2};
|
||||
constexpr GridwiseFwdXdlGemm FwdGemmParams_Xdl_4x4_per_wave{
|
||||
.ak1 = 8,
|
||||
.bk1 = 8,
|
||||
.xdl_params = {.m_per_xdl = 32, .n_per_xdl = 32, .m_xdl_per_wave = 4, .n_xdl_per_wave = 4}};
|
||||
|
||||
constexpr GridwiseXdlGemm FwdGemmParams_Xdl_2x1_per_wave{
|
||||
.ak1 = 8, .bk1 = 8, .m_per_xdl = 32, .n_per_xdl = 32, .m_xdl_per_wave = 2, .n_xdl_per_wave = 1};
|
||||
constexpr GridwiseFwdXdlGemm FwdGemmParams_Xdl_4x2_per_wave{
|
||||
.ak1 = 8,
|
||||
.bk1 = 8,
|
||||
.xdl_params = {.m_per_xdl = 32, .n_per_xdl = 32, .m_xdl_per_wave = 4, .n_xdl_per_wave = 2}};
|
||||
|
||||
constexpr GridwiseWmmaGemm FwdGemmParams_Wmma_2x1_per_wave{.k1 = 8,
|
||||
.m_per_wmma = 32,
|
||||
.n_per_wmma = 32,
|
||||
.m_wmma_per_wave = 2,
|
||||
.n_wmma_per_wave = 1,
|
||||
.pipeline_version = PipelineVersion::V1};
|
||||
constexpr GridwiseFwdXdlGemm FwdGemmParams_Xdl_2x2_per_wave{
|
||||
.ak1 = 8,
|
||||
.bk1 = 8,
|
||||
.xdl_params = {.m_per_xdl = 32, .n_per_xdl = 32, .m_xdl_per_wave = 2, .n_xdl_per_wave = 2}};
|
||||
|
||||
constexpr ThreadBlock FwdThreadBlock_256_256x256x32{.block_size = 256,
|
||||
.tile_size = {.m = 256, .n = 256, .k = 32}};
|
||||
constexpr GridwiseFwdXdlGemm FwdGemmParams_Xdl_2x1_per_wave{
|
||||
.ak1 = 8,
|
||||
.bk1 = 8,
|
||||
.xdl_params = {.m_per_xdl = 32, .n_per_xdl = 32, .m_xdl_per_wave = 2, .n_xdl_per_wave = 1}};
|
||||
|
||||
constexpr ThreadBlock FwdThreadBlock_256_256x128x32{.block_size = 256,
|
||||
.tile_size = {.m = 256, .n = 128, .k = 32}};
|
||||
constexpr GridwiseWmmaGemm GemmParams_Wmma_2x1_per_wave{
|
||||
.k1 = 8, .m_per_wmma = 32, .n_per_wmma = 32, .m_wmma_per_wave = 2, .n_wmma_per_wave = 1};
|
||||
|
||||
constexpr ThreadBlock FwdThreadBlock_256_128x128x32{.block_size = 256,
|
||||
.tile_size = {.m = 128, .n = 128, .k = 32}};
|
||||
constexpr GridwiseWmmaGemm GemmParams_Wmma_16x16_2x1_per_wave{
|
||||
.k1 = 8, .m_per_wmma = 16, .n_per_wmma = 16, .m_wmma_per_wave = 2, .n_wmma_per_wave = 1};
|
||||
|
||||
constexpr ThreadBlock FwdThreadBlock_256_128x128x16{.block_size = 256,
|
||||
.tile_size = {.m = 128, .n = 128, .k = 16}};
|
||||
constexpr ThreadBlock ThreadBlock_256_256x256x32{.block_size = 256,
|
||||
.tile_size = {.m = 256, .n = 256, .k = 32}};
|
||||
|
||||
constexpr ThreadBlock FwdThreadBlock_64_64x32x32{.block_size = 64,
|
||||
.tile_size = {.m = 64, .n = 32, .k = 32}};
|
||||
constexpr ThreadBlock ThreadBlock_256_256x128x32{.block_size = 256,
|
||||
.tile_size = {.m = 256, .n = 128, .k = 32}};
|
||||
|
||||
constexpr ThreadBlock FwdThreadBlock_128_128x128x32{.block_size = 128,
|
||||
.tile_size = {.m = 128, .n = 128, .k = 32}};
|
||||
constexpr ThreadBlock ThreadBlock_256_128x128x32{.block_size = 256,
|
||||
.tile_size = {.m = 128, .n = 128, .k = 32}};
|
||||
|
||||
constexpr ThreadBlock FwdThreadBlock_128_64x64x64{.block_size = 128,
|
||||
.tile_size = {.m = 64, .n = 64, .k = 64}};
|
||||
constexpr ThreadBlock ThreadBlock_256_128x128x16{.block_size = 256,
|
||||
.tile_size = {.m = 128, .n = 128, .k = 16}};
|
||||
|
||||
constexpr BlockGemm BlockGemmDesc_v1_intrawave = {.pipeline_version = PipelineVersion::V1,
|
||||
.scheduler = PipelineScheduler::INTRAWAVE};
|
||||
constexpr ThreadBlock ThreadBlock_256_128x128x8{.block_size = 256,
|
||||
.tile_size = {.m = 128, .n = 128, .k = 8}};
|
||||
|
||||
constexpr BlockGemm BlockGemmDesc_v2_intrawave = {.pipeline_version = PipelineVersion::V2,
|
||||
.scheduler = PipelineScheduler::INTRAWAVE};
|
||||
constexpr ThreadBlock ThreadBlock_64_64x32x32{.block_size = 64,
|
||||
.tile_size = {.m = 64, .n = 32, .k = 32}};
|
||||
|
||||
constexpr BlockGemm BlockGemmDesc_v3_intrawave = {.pipeline_version = PipelineVersion::V3,
|
||||
.scheduler = PipelineScheduler::INTRAWAVE};
|
||||
constexpr ThreadBlock ThreadBlock_64_32x32x32{.block_size = 64,
|
||||
.tile_size = {.m = 32, .n = 32, .k = 32}};
|
||||
|
||||
constexpr BlockGemm BlockGemmDesc_v4_intrawave = {.pipeline_version = PipelineVersion::V4,
|
||||
.scheduler = PipelineScheduler::INTRAWAVE};
|
||||
constexpr ThreadBlock ThreadBlock_128_128x128x32{.block_size = 128,
|
||||
.tile_size = {.m = 128, .n = 128, .k = 32}};
|
||||
|
||||
constexpr BlockGemm BlockGemmDesc_v5_intrawave = {.pipeline_version = PipelineVersion::V5,
|
||||
.scheduler = PipelineScheduler::INTRAWAVE};
|
||||
constexpr ThreadBlock ThreadBlock_128_64x64x64{.block_size = 128,
|
||||
.tile_size = {.m = 64, .n = 64, .k = 64}};
|
||||
|
||||
constexpr BlockGemmPipeline BlockGemmDesc_v1_intrawave = {
|
||||
.pipeline_version = PipelineVersion::V1, .scheduler = PipelineScheduler::INTRAWAVE};
|
||||
|
||||
constexpr BlockGemmPipeline BlockGemmDesc_v2_intrawave = {
|
||||
.pipeline_version = PipelineVersion::V2, .scheduler = PipelineScheduler::INTRAWAVE};
|
||||
|
||||
constexpr BlockGemmPipeline BlockGemmDesc_v3_intrawave = {
|
||||
.pipeline_version = PipelineVersion::V3, .scheduler = PipelineScheduler::INTRAWAVE};
|
||||
|
||||
constexpr BlockGemmPipeline BlockGemmDesc_v4_intrawave = {
|
||||
.pipeline_version = PipelineVersion::V4, .scheduler = PipelineScheduler::INTRAWAVE};
|
||||
|
||||
constexpr BlockGemmPipeline BlockGemmDesc_v5_intrawave = {
|
||||
.pipeline_version = PipelineVersion::V5, .scheduler = PipelineScheduler::INTRAWAVE};
|
||||
|
||||
} // namespace ck_tile::builder::test_utils
|
||||
|
||||
@@ -12,35 +12,35 @@ namespace ck_tile::builder::test_utils {
|
||||
using namespace ck_tile::builder;
|
||||
using namespace test;
|
||||
|
||||
constexpr TileTransfer FwdTileTransfer_1x1x1{
|
||||
constexpr TileTransfer TileTransfer_1x1x1{
|
||||
.a_scalar_per_vector = 1,
|
||||
.b_scalar_per_vector = 1,
|
||||
.c_scalar_per_vector = 1,
|
||||
};
|
||||
|
||||
constexpr TileTransfer FwdTileTransfer_4x4x4{
|
||||
constexpr TileTransfer TileTransfer_4x4x4{
|
||||
.a_scalar_per_vector = 4,
|
||||
.b_scalar_per_vector = 4,
|
||||
.c_scalar_per_vector = 4,
|
||||
};
|
||||
|
||||
constexpr TileTransfer FwdTileTransfer_8x8x8{
|
||||
constexpr TileTransfer TileTransfer_8x8x8{
|
||||
.a_scalar_per_vector = 8,
|
||||
.b_scalar_per_vector = 8,
|
||||
.c_scalar_per_vector = 8,
|
||||
};
|
||||
|
||||
constexpr TileThreadBlock FwdTileThreadBlock_256x256x32{.tile_size = {.m = 256, .n = 256, .k = 32}};
|
||||
constexpr TileThreadBlock TileThreadBlock_256x256x32{.tile_size = {.m = 256, .n = 256, .k = 32}};
|
||||
|
||||
constexpr TileThreadBlock FwdTileThreadBlock_256x128x32{.tile_size = {.m = 256, .n = 128, .k = 32}};
|
||||
constexpr TileThreadBlock TileThreadBlock_256x128x32{.tile_size = {.m = 256, .n = 128, .k = 32}};
|
||||
|
||||
constexpr TileThreadBlock FwdTileThreadBlock_128x128x32{.tile_size = {.m = 128, .n = 128, .k = 32}};
|
||||
constexpr TileThreadBlock TileThreadBlock_128x128x32{.tile_size = {.m = 128, .n = 128, .k = 32}};
|
||||
|
||||
constexpr TileThreadBlock FwdTileThreadBlock_128x128x16{.tile_size = {.m = 128, .n = 128, .k = 16}};
|
||||
constexpr TileThreadBlock TileThreadBlock_128x128x16{.tile_size = {.m = 128, .n = 128, .k = 16}};
|
||||
|
||||
constexpr TileThreadBlock FwdTileThreadBlock_64x32x32{.tile_size = {.m = 64, .n = 32, .k = 32}};
|
||||
constexpr TileThreadBlock TileThreadBlock_64x32x32{.tile_size = {.m = 64, .n = 32, .k = 32}};
|
||||
|
||||
constexpr TileThreadBlock FwdTileThreadBlock_64x64x64{.tile_size = {.m = 64, .n = 64, .k = 64}};
|
||||
constexpr TileThreadBlock TileThreadBlock_64x64x64{.tile_size = {.m = 64, .n = 64, .k = 64}};
|
||||
|
||||
constexpr TileBlockGemm TileBlockGemmDesc_16x16_v1_intrawave = {
|
||||
.warps = {.m = 2, .n = 2, .k = 1},
|
||||
|
||||
@@ -54,7 +54,7 @@ inline std::string to_string<PipelineScheduler>(PipelineScheduler t)
|
||||
}
|
||||
|
||||
template <>
|
||||
inline std::string to_string<ConvFwdSpecialization>(ConvFwdSpecialization t)
|
||||
inline std::string to_string<ConvSpecialization>(ConvSpecialization t)
|
||||
{
|
||||
std::ostringstream oss;
|
||||
oss << t;
|
||||
@@ -86,11 +86,20 @@ inline std::string to_string<ThreadBlock>(ThreadBlock t)
|
||||
}
|
||||
|
||||
template <>
|
||||
inline std::string to_string<GridwiseXdlGemm>(GridwiseXdlGemm t)
|
||||
inline std::string to_string<GridwiseBwdXdlGemm>(GridwiseBwdXdlGemm t)
|
||||
{
|
||||
std::ostringstream oss;
|
||||
oss << t.ak1 << "," << t.bk1 << "," << t.m_per_xdl << "," << t.n_per_xdl << ","
|
||||
<< t.m_xdl_per_wave << "," << t.n_xdl_per_wave;
|
||||
oss << t.k1 << "," << t.xdl_params.m_per_xdl << "," << t.xdl_params.n_per_xdl << ","
|
||||
<< t.xdl_params.m_xdl_per_wave << "," << t.xdl_params.n_xdl_per_wave;
|
||||
return oss.str();
|
||||
}
|
||||
|
||||
template <>
|
||||
inline std::string to_string<GridwiseFwdXdlGemm>(GridwiseFwdXdlGemm t)
|
||||
{
|
||||
std::ostringstream oss;
|
||||
oss << t.ak1 << "," << t.bk1 << "," << t.xdl_params.m_per_xdl << "," << t.xdl_params.n_per_xdl
|
||||
<< "," << t.xdl_params.m_xdl_per_wave << "," << t.xdl_params.n_xdl_per_wave;
|
||||
return oss.str();
|
||||
}
|
||||
|
||||
@@ -104,17 +113,29 @@ inline std::string to_string<GridwiseWmmaGemm>(GridwiseWmmaGemm t)
|
||||
}
|
||||
|
||||
template <>
|
||||
inline std::string to_string<BlockGemm>(BlockGemm t)
|
||||
inline std::string to_string<BlockGemmPipeline>(BlockGemmPipeline t)
|
||||
{
|
||||
std::ostringstream oss;
|
||||
oss << to_string(t.scheduler) << "," << to_string(t.pipeline_version);
|
||||
return oss.str();
|
||||
}
|
||||
|
||||
template <>
|
||||
inline std::string to_string<BlockTransfer>(BlockTransfer t)
|
||||
template <size_t ThreadClusterRank>
|
||||
inline std::string to_string(BlockTransfer<ThreadClusterRank> t)
|
||||
{
|
||||
return array_to_seq(std::array<size_t, 3>{t.k0, t.m_n, t.k1});
|
||||
if constexpr(ThreadClusterRank == 4)
|
||||
{
|
||||
return array_to_seq(std::array<size_t, 4>{t.k_batch_size, t.k0, t.m_n, t.k1});
|
||||
}
|
||||
else if constexpr(ThreadClusterRank == 3)
|
||||
{
|
||||
return array_to_seq(std::array<size_t, 3>{t.k0, t.m_n, t.k1});
|
||||
}
|
||||
else
|
||||
{
|
||||
static_assert(ThreadClusterRank == 3 || ThreadClusterRank == 4,
|
||||
"Unsupported ThreadClusterRank");
|
||||
}
|
||||
}
|
||||
|
||||
template <>
|
||||
@@ -134,14 +155,14 @@ inline std::string to_string<LdsTransfer>(LdsTransfer t)
|
||||
return oss.str();
|
||||
}
|
||||
|
||||
template <>
|
||||
inline std::string to_string<AccessOrder>(AccessOrder t)
|
||||
template <size_t N>
|
||||
inline std::string to_string(AccessOrder<N> t)
|
||||
{
|
||||
return array_to_seq(t.order);
|
||||
}
|
||||
|
||||
template <>
|
||||
inline std::string to_string<TransferAB>(TransferAB t)
|
||||
template <size_t N = 3>
|
||||
inline std::string to_string(InputTransfer<N> t)
|
||||
{
|
||||
std::ostringstream oss;
|
||||
oss << to_string(t.block_transfer) << "," << to_string(t.block_transfer_access_order) << ","
|
||||
@@ -152,7 +173,7 @@ inline std::string to_string<TransferAB>(TransferAB t)
|
||||
}
|
||||
|
||||
template <>
|
||||
inline std::string to_string<TransferC>(TransferC t)
|
||||
inline std::string to_string<OutputTransfer>(OutputTransfer t)
|
||||
{
|
||||
std::ostringstream oss;
|
||||
oss << t.epilogue.m_xdl_per_wave_per_shuffle << "," << t.epilogue.n_per_wave_per_shuffle << ","
|
||||
@@ -160,8 +181,8 @@ inline std::string to_string<TransferC>(TransferC t)
|
||||
return oss.str();
|
||||
}
|
||||
|
||||
template <>
|
||||
inline std::string to_string<TransferABC>(TransferABC t)
|
||||
template <size_t N = 3>
|
||||
inline std::string to_string(Transfer<N> t)
|
||||
{
|
||||
std::ostringstream oss;
|
||||
oss << to_string(t.a) << "," << to_string(t.b) << "," << to_string(t.c);
|
||||
@@ -185,7 +206,19 @@ inline std::string to_string<DlThreadCluster>(DlThreadCluster t)
|
||||
}
|
||||
|
||||
template <>
|
||||
inline std::string to_string<DlBlockTransfer>(DlBlockTransfer t)
|
||||
inline std::string to_string<DlBlockTransfer<4>>(DlBlockTransfer<4> t)
|
||||
{
|
||||
std::ostringstream oss;
|
||||
oss << array_to_seq(t.thread_slice_lengths) << "," << array_to_seq(t.thread_cluster_lengths)
|
||||
<< "," << array_to_seq(t.thread_cluster_arrange_order) << ","
|
||||
<< array_to_seq(t.src_access_order) << "," << array_to_seq(t.src_vector_tensor_lengths)
|
||||
<< "," << array_to_seq(t.src_vector_tensor_contiguous_dim_order) << ","
|
||||
<< array_to_seq(t.dst_vector_tensor_lengths);
|
||||
return oss.str();
|
||||
}
|
||||
|
||||
template <>
|
||||
inline std::string to_string<DlBlockTransfer<5>>(DlBlockTransfer<5> t)
|
||||
{
|
||||
std::ostringstream oss;
|
||||
oss << array_to_seq(t.thread_slice_lengths) << "," << array_to_seq(t.thread_cluster_lengths)
|
||||
@@ -206,19 +239,24 @@ inline std::string to_string<DlEpilogue>(DlEpilogue t)
|
||||
}
|
||||
|
||||
template <>
|
||||
inline std::string to_string<DlBlockTransferAB>(DlBlockTransferAB t)
|
||||
inline std::string to_string<TransposeParams_>(TransposeParams_ t)
|
||||
{
|
||||
return to_string(t.block_transfer);
|
||||
std::ostringstream oss;
|
||||
oss << t.max_transpose_transfer_src_scalar_per_vector << ","
|
||||
<< t.max_transpose_transfer_dst_scalar_per_vector;
|
||||
return oss.str();
|
||||
}
|
||||
|
||||
template <>
|
||||
inline std::string to_string<DlBlockTransferC>(DlBlockTransferC t)
|
||||
inline std::string to_string<DlTransfer<4>>(DlTransfer<4> t)
|
||||
{
|
||||
return to_string(t.epilogue);
|
||||
std::ostringstream oss;
|
||||
oss << to_string(t.a) << "," << to_string(t.b) << "," << to_string(t.c);
|
||||
return oss.str();
|
||||
}
|
||||
|
||||
template <>
|
||||
inline std::string to_string<DlTransferABC>(DlTransferABC t)
|
||||
inline std::string to_string<DlTransfer<5>>(DlTransfer<5> t)
|
||||
{
|
||||
std::ostringstream oss;
|
||||
oss << to_string(t.a) << "," << to_string(t.b) << "," << to_string(t.c);
|
||||
@@ -234,7 +272,13 @@ inline std::string to_string<ThreadBlock_>(ThreadBlock_ t)
|
||||
}
|
||||
|
||||
template <>
|
||||
inline std::string to_string<XdlGemm_>(XdlGemm_ t)
|
||||
inline std::string to_string<FwdXdlGemm_>(FwdXdlGemm_ t)
|
||||
{
|
||||
return to_string(t.gridwise_gemm);
|
||||
}
|
||||
|
||||
template <>
|
||||
inline std::string to_string<BwdXdlGemm_>(BwdXdlGemm_ t)
|
||||
{
|
||||
return to_string(t.gridwise_gemm);
|
||||
}
|
||||
@@ -245,33 +289,40 @@ inline std::string to_string<WmmaGemm_>(WmmaGemm_ t)
|
||||
return to_string(t.gridwise_gemm);
|
||||
}
|
||||
|
||||
template <>
|
||||
inline std::string to_string<Transfer_>(Transfer_ t)
|
||||
template <size_t ThreadClusterRank = 3>
|
||||
inline std::string to_string(Transfer_<ThreadClusterRank> t)
|
||||
{
|
||||
return to_string(t.transfer);
|
||||
}
|
||||
|
||||
template <>
|
||||
inline std::string to_string<ConvSpecialization_>(ConvSpecialization_ t)
|
||||
inline std::string to_string<ConvSpecializationFwd_>(ConvSpecializationFwd_ t)
|
||||
{
|
||||
std::ostringstream oss;
|
||||
oss << to_string(t.fwd_specialization) << "," << to_string(t.gemm_specialization);
|
||||
return oss.str();
|
||||
}
|
||||
|
||||
template <>
|
||||
inline std::string to_string<ConvSpecializationBwdWeight_>(ConvSpecializationBwdWeight_ t)
|
||||
{
|
||||
std::ostringstream oss;
|
||||
oss << to_string(t.bwd_weight_specialization);
|
||||
return oss.str();
|
||||
}
|
||||
|
||||
template <>
|
||||
inline std::string to_string<Prefetch_>(Prefetch_ t)
|
||||
{
|
||||
std::ostringstream oss;
|
||||
oss << t.num_gemm_k_prefetch_stages << "," << t.num_groups_to_merge << ","
|
||||
<< to_string(t.loop_scheduler);
|
||||
oss << t.num_gemm_k_prefetch_stages << "," << to_string(t.loop_scheduler);
|
||||
return oss.str();
|
||||
}
|
||||
|
||||
template <>
|
||||
inline std::string to_string<BlockGemm_>(BlockGemm_ t)
|
||||
{
|
||||
return to_string(t.block_gemm);
|
||||
return to_string(t.block_gemm_pipeline);
|
||||
}
|
||||
|
||||
template <>
|
||||
@@ -287,7 +338,13 @@ inline std::string to_string<DlThreadCluster_>(DlThreadCluster_ t)
|
||||
}
|
||||
|
||||
template <>
|
||||
inline std::string to_string<DlTransfer_>(DlTransfer_ t)
|
||||
inline std::string to_string<DlTransfer_<4>>(DlTransfer_<4> t)
|
||||
{
|
||||
return to_string(t.transfer);
|
||||
}
|
||||
|
||||
template <>
|
||||
inline std::string to_string<DlTransfer_<5>>(DlTransfer_<5> t)
|
||||
{
|
||||
return to_string(t.transfer);
|
||||
}
|
||||
@@ -299,8 +356,8 @@ inline std::string to_string<ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_C
|
||||
ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle t)
|
||||
{
|
||||
std::ostringstream oss;
|
||||
oss << to_string(static_cast<ThreadBlock_>(t)) << "," << to_string(static_cast<XdlGemm_>(t))
|
||||
<< "," << to_string(static_cast<Transfer_>(t));
|
||||
oss << to_string(static_cast<ThreadBlock_>(t)) << "," << to_string(static_cast<FwdXdlGemm_>(t))
|
||||
<< "," << to_string(static_cast<Transfer_<>>(t));
|
||||
return oss.str();
|
||||
}
|
||||
|
||||
@@ -309,8 +366,8 @@ inline std::string to_string<ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_C
|
||||
ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 t)
|
||||
{
|
||||
std::ostringstream oss;
|
||||
oss << to_string(static_cast<ThreadBlock_>(t)) << "," << to_string(static_cast<XdlGemm_>(t))
|
||||
<< "," << to_string(static_cast<Transfer_>(t));
|
||||
oss << to_string(static_cast<ThreadBlock_>(t)) << "," << to_string(static_cast<FwdXdlGemm_>(t))
|
||||
<< "," << to_string(static_cast<Transfer_<>>(t));
|
||||
return oss.str();
|
||||
}
|
||||
|
||||
@@ -320,7 +377,7 @@ inline std::string to_string<ConvAlgorithm_DeviceGroupedConvFwdMultipleD_Wmma_CS
|
||||
{
|
||||
std::ostringstream oss;
|
||||
oss << to_string(static_cast<ThreadBlock_>(t)) << "," << to_string(static_cast<WmmaGemm_>(t))
|
||||
<< "," << to_string(static_cast<Transfer_>(t));
|
||||
<< "," << to_string(static_cast<Transfer_<>>(t));
|
||||
return oss.str();
|
||||
}
|
||||
|
||||
@@ -332,7 +389,7 @@ inline std::string to_string<ConvAlgorithm_DeviceGroupedConvFwdDlMultipleD_NHWC_
|
||||
oss << to_string(static_cast<ThreadBlock_>(t)) << ","
|
||||
<< to_string(static_cast<DlThreadConfig_>(t)) << ","
|
||||
<< to_string(static_cast<DlThreadCluster_>(t)) << ","
|
||||
<< to_string(static_cast<DlTransfer_>(t));
|
||||
<< to_string(static_cast<DlTransfer_<4>>(t));
|
||||
return oss.str();
|
||||
}
|
||||
|
||||
@@ -340,7 +397,102 @@ template <>
|
||||
inline std::string to_string<ConvAlgorithm_DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor>(
|
||||
ConvAlgorithm_DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor t)
|
||||
{
|
||||
return to_string(t.base_algorithm);
|
||||
std::ostringstream oss;
|
||||
oss << to_string(static_cast<ThreadBlock_>(t)) << "," << to_string(static_cast<FwdXdlGemm_>(t))
|
||||
<< "," << to_string(static_cast<Transfer_<>>(t));
|
||||
return oss.str();
|
||||
}
|
||||
|
||||
template <>
|
||||
inline std::string to_string<ConvAlgorithm_DeviceGroupedConvBwdWeight_Xdl_CShuffle>(
|
||||
ConvAlgorithm_DeviceGroupedConvBwdWeight_Xdl_CShuffle t)
|
||||
{
|
||||
std::ostringstream oss;
|
||||
oss << to_string(static_cast<ThreadBlock_>(t)) << "," << to_string(static_cast<BwdXdlGemm_>(t))
|
||||
<< "," << to_string(static_cast<Transfer_<4>>(t));
|
||||
return oss.str();
|
||||
}
|
||||
|
||||
template <>
|
||||
inline std::string to_string<ConvAlgorithm_DeviceGroupedConvBwdWeight_Xdl_CShuffle_V3>(
|
||||
ConvAlgorithm_DeviceGroupedConvBwdWeight_Xdl_CShuffle_V3 t)
|
||||
{
|
||||
std::ostringstream oss;
|
||||
oss << to_string(static_cast<ThreadBlock_>(t)) << "," << to_string(static_cast<BwdXdlGemm_>(t))
|
||||
<< "," << to_string(static_cast<Transfer_<>>(t));
|
||||
return oss.str();
|
||||
}
|
||||
|
||||
template <>
|
||||
inline std::string to_string<ConvAlgorithm_DeviceGroupedConvBwdWeight_Wmma_CShuffle>(
|
||||
ConvAlgorithm_DeviceGroupedConvBwdWeight_Wmma_CShuffle t)
|
||||
{
|
||||
std::ostringstream oss;
|
||||
oss << to_string(static_cast<ThreadBlock_>(t)) << "," << to_string(static_cast<WmmaGemm_>(t))
|
||||
<< "," << to_string(static_cast<Transfer_<>>(t));
|
||||
return oss.str();
|
||||
}
|
||||
|
||||
template <>
|
||||
inline std::string to_string<ConvAlgorithm_DeviceGroupedConvBwdWeight_Wmma_CShuffle_V3>(
|
||||
ConvAlgorithm_DeviceGroupedConvBwdWeight_Wmma_CShuffle_V3 t)
|
||||
{
|
||||
std::ostringstream oss;
|
||||
oss << to_string(static_cast<ThreadBlock_>(t)) << "," << to_string(static_cast<WmmaGemm_>(t))
|
||||
<< "," << to_string(static_cast<Transfer_<>>(t));
|
||||
return oss.str();
|
||||
}
|
||||
|
||||
template <>
|
||||
inline std::string to_string<ConvAlgorithm_DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffle_V3>(
|
||||
ConvAlgorithm_DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffle_V3 t)
|
||||
{
|
||||
std::ostringstream oss;
|
||||
oss << to_string(static_cast<ThreadBlock_>(t)) << "," << to_string(static_cast<WmmaGemm_>(t))
|
||||
<< "," << to_string(static_cast<Transfer_<>>(t));
|
||||
return oss.str();
|
||||
}
|
||||
|
||||
template <>
|
||||
inline std::string to_string<ConvAlgorithm_DeviceGroupedConvBwdWeight_TwoStage_Wmma_CShuffle_V3>(
|
||||
ConvAlgorithm_DeviceGroupedConvBwdWeight_TwoStage_Wmma_CShuffle_V3 t)
|
||||
{
|
||||
std::ostringstream oss;
|
||||
oss << to_string(static_cast<ThreadBlock_>(t)) << "," << to_string(static_cast<WmmaGemm_>(t))
|
||||
<< "," << to_string(static_cast<Transfer_<>>(t));
|
||||
return oss.str();
|
||||
}
|
||||
|
||||
template <>
|
||||
inline std::string to_string<ConvAlgorithm_DeviceGroupedConvBwdWeight_TwoStage_Xdl_CShuffle>(
|
||||
ConvAlgorithm_DeviceGroupedConvBwdWeight_TwoStage_Xdl_CShuffle t)
|
||||
{
|
||||
std::ostringstream oss;
|
||||
oss << to_string(static_cast<ThreadBlock_>(t)) << "," << to_string(static_cast<BwdXdlGemm_>(t))
|
||||
<< "," << to_string(static_cast<Transfer_<>>(t));
|
||||
return oss.str();
|
||||
}
|
||||
|
||||
template <>
|
||||
inline std::string to_string<ConvAlgorithm_DeviceGroupedConvBwdWeight_Dl>(
|
||||
ConvAlgorithm_DeviceGroupedConvBwdWeight_Dl t)
|
||||
{
|
||||
std::ostringstream oss;
|
||||
oss << to_string(static_cast<ThreadBlock_>(t)) << ","
|
||||
<< to_string(static_cast<DlThreadConfig_>(t)) << ","
|
||||
<< to_string(static_cast<DlThreadCluster_>(t)) << ","
|
||||
<< to_string(static_cast<DlTransfer_<5>>(t));
|
||||
return oss.str();
|
||||
}
|
||||
|
||||
template <>
|
||||
inline std::string to_string<ConvAlgorithm_DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle>(
|
||||
ConvAlgorithm_DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle t)
|
||||
{
|
||||
std::ostringstream oss;
|
||||
oss << to_string(static_cast<ThreadBlock_>(t)) << "," << to_string(static_cast<BwdXdlGemm_>(t))
|
||||
<< "," << to_string(static_cast<Transfer_<4>>(t));
|
||||
return oss.str();
|
||||
}
|
||||
|
||||
} // namespace ck_tile::builder::test
|
||||
|
||||
@@ -50,7 +50,7 @@ __global__ void
|
||||
#if CK_USE_LAUNCH_BOUNDS
|
||||
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
|
||||
#endif
|
||||
kernel_grouped_conv_bwd_weight_wmma_cshuffle_v3(
|
||||
kernel_grouped_conv_bwd_weight_wmma_cshuffle_v3_multiple_d(
|
||||
typename GridwiseGemm::Argument karg,
|
||||
const AGridDesc_AK0_M_K1 a_grid_desc_ak0_m_ak1,
|
||||
const BGridDesc_BK0_N_K1 b_grid_desc_bk0_n_bk1,
|
||||
@@ -858,30 +858,32 @@ struct DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3
|
||||
{
|
||||
if(gemm_arg.KBatch > 1)
|
||||
{
|
||||
const auto kernel = kernel_grouped_conv_bwd_weight_wmma_cshuffle_v3<
|
||||
GridwiseGemm,
|
||||
remove_reference_t<DeviceOp::AGridDesc_K0_M_K1>,
|
||||
remove_reference_t<DeviceOp::BGridDesc_K0_N_K1>,
|
||||
remove_reference_t<
|
||||
DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>,
|
||||
ComputePtrOffsetOfStridedBatch<I1, I1, NumDTensor>,
|
||||
true,
|
||||
InMemoryDataOperationEnum::AtomicAdd,
|
||||
minimum_occupancy>;
|
||||
const auto kernel =
|
||||
kernel_grouped_conv_bwd_weight_wmma_cshuffle_v3_multiple_d<
|
||||
GridwiseGemm,
|
||||
remove_reference_t<DeviceOp::AGridDesc_K0_M_K1>,
|
||||
remove_reference_t<DeviceOp::BGridDesc_K0_N_K1>,
|
||||
remove_reference_t<
|
||||
DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>,
|
||||
ComputePtrOffsetOfStridedBatch<I1, I1, NumDTensor>,
|
||||
true,
|
||||
InMemoryDataOperationEnum::AtomicAdd,
|
||||
minimum_occupancy>;
|
||||
Run(kernel);
|
||||
}
|
||||
else
|
||||
{
|
||||
const auto kernel = kernel_grouped_conv_bwd_weight_wmma_cshuffle_v3<
|
||||
GridwiseGemm,
|
||||
remove_reference_t<DeviceOp::AGridDesc_K0_M_K1>,
|
||||
remove_reference_t<DeviceOp::BGridDesc_K0_N_K1>,
|
||||
remove_reference_t<
|
||||
DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>,
|
||||
ComputePtrOffsetOfStridedBatch<I1, I1, NumDTensor>,
|
||||
true,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
minimum_occupancy>;
|
||||
const auto kernel =
|
||||
kernel_grouped_conv_bwd_weight_wmma_cshuffle_v3_multiple_d<
|
||||
GridwiseGemm,
|
||||
remove_reference_t<DeviceOp::AGridDesc_K0_M_K1>,
|
||||
remove_reference_t<DeviceOp::BGridDesc_K0_N_K1>,
|
||||
remove_reference_t<
|
||||
DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>,
|
||||
ComputePtrOffsetOfStridedBatch<I1, I1, NumDTensor>,
|
||||
true,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
minimum_occupancy>;
|
||||
Run(kernel);
|
||||
}
|
||||
}
|
||||
@@ -897,30 +899,32 @@ struct DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3
|
||||
{
|
||||
if(gemm_arg.KBatch > 1)
|
||||
{
|
||||
const auto kernel = kernel_grouped_conv_bwd_weight_wmma_cshuffle_v3<
|
||||
GridwiseGemm,
|
||||
remove_reference_t<DeviceOp::AGridDesc_K0_M_K1>,
|
||||
remove_reference_t<DeviceOp::BGridDesc_K0_N_K1>,
|
||||
remove_reference_t<
|
||||
DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>,
|
||||
ComputePtrOffsetOfStridedBatch<I1, I1, NumDTensor>,
|
||||
false,
|
||||
InMemoryDataOperationEnum::AtomicAdd,
|
||||
minimum_occupancy>;
|
||||
const auto kernel =
|
||||
kernel_grouped_conv_bwd_weight_wmma_cshuffle_v3_multiple_d<
|
||||
GridwiseGemm,
|
||||
remove_reference_t<DeviceOp::AGridDesc_K0_M_K1>,
|
||||
remove_reference_t<DeviceOp::BGridDesc_K0_N_K1>,
|
||||
remove_reference_t<
|
||||
DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>,
|
||||
ComputePtrOffsetOfStridedBatch<I1, I1, NumDTensor>,
|
||||
false,
|
||||
InMemoryDataOperationEnum::AtomicAdd,
|
||||
minimum_occupancy>;
|
||||
Run(kernel);
|
||||
}
|
||||
else
|
||||
{
|
||||
const auto kernel = kernel_grouped_conv_bwd_weight_wmma_cshuffle_v3<
|
||||
GridwiseGemm,
|
||||
remove_reference_t<DeviceOp::AGridDesc_K0_M_K1>,
|
||||
remove_reference_t<DeviceOp::BGridDesc_K0_N_K1>,
|
||||
remove_reference_t<
|
||||
DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>,
|
||||
ComputePtrOffsetOfStridedBatch<I1, I1, NumDTensor>,
|
||||
false,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
minimum_occupancy>;
|
||||
const auto kernel =
|
||||
kernel_grouped_conv_bwd_weight_wmma_cshuffle_v3_multiple_d<
|
||||
GridwiseGemm,
|
||||
remove_reference_t<DeviceOp::AGridDesc_K0_M_K1>,
|
||||
remove_reference_t<DeviceOp::BGridDesc_K0_N_K1>,
|
||||
remove_reference_t<
|
||||
DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>,
|
||||
ComputePtrOffsetOfStridedBatch<I1, I1, NumDTensor>,
|
||||
false,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
minimum_occupancy>;
|
||||
Run(kernel);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -52,19 +52,20 @@ __global__ void
|
||||
#if CK_USE_LAUNCH_BOUNDS
|
||||
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
|
||||
#endif
|
||||
kernel_batched_gemm_xdlops_bwd_weight(const FloatA* __restrict__ p_a_grid,
|
||||
const FloatB* __restrict__ p_b_grid,
|
||||
FloatC* __restrict__ p_c_grid,
|
||||
const AElementwiseOperation a_element_op,
|
||||
const BElementwiseOperation b_element_op,
|
||||
const CElementwiseOperation c_element_op,
|
||||
const index_t batch_count,
|
||||
const AGridDesc_B_K0_M_K1 a_b_k0_m_k1_grid_desc,
|
||||
const BGridDesc_B_K0_N_K1 b_b_k0_n_k1_grid_desc,
|
||||
const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
|
||||
c_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
const Block2CTileMap block_2_ctile_map,
|
||||
const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch)
|
||||
kernel_batched_gemm_xdlops_bwd_weight_multiple_d(
|
||||
const FloatA* __restrict__ p_a_grid,
|
||||
const FloatB* __restrict__ p_b_grid,
|
||||
FloatC* __restrict__ p_c_grid,
|
||||
const AElementwiseOperation a_element_op,
|
||||
const BElementwiseOperation b_element_op,
|
||||
const CElementwiseOperation c_element_op,
|
||||
const index_t batch_count,
|
||||
const AGridDesc_B_K0_M_K1 a_b_k0_m_k1_grid_desc,
|
||||
const BGridDesc_B_K0_N_K1 b_b_k0_n_k1_grid_desc,
|
||||
const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
|
||||
c_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
const Block2CTileMap block_2_ctile_map,
|
||||
const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch)
|
||||
{
|
||||
#if defined(__gfx9__) || defined(__gfx11__) || defined(__gfx12__)
|
||||
if constexpr(GridwiseGemm::template IsValidCompilationParameter<>())
|
||||
@@ -568,7 +569,7 @@ struct DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle
|
||||
int max_occupancy = 0;
|
||||
hip_check_error(hipOccupancyMaxActiveBlocksPerMultiprocessor(
|
||||
&max_occupancy,
|
||||
kernel_batched_gemm_xdlops_bwd_weight<
|
||||
kernel_batched_gemm_xdlops_bwd_weight_multiple_d<
|
||||
GridwiseGemm,
|
||||
ADataType,
|
||||
BDataType,
|
||||
@@ -841,7 +842,7 @@ struct DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle
|
||||
p_c_grid, 0, arg.c_space_size_bytes, stream_config.stream_id_));
|
||||
};
|
||||
|
||||
const auto kernel = kernel_batched_gemm_xdlops_bwd_weight<
|
||||
const auto kernel = kernel_batched_gemm_xdlops_bwd_weight_multiple_d<
|
||||
GridwiseGemm,
|
||||
ADataType,
|
||||
BDataType,
|
||||
|
||||
Reference in New Issue
Block a user