Grouped convolution forward device implementation and base flavors for RDNA3/4 (#2964)

* Fixed typos for padded instances

* Added tests for fp16, KM_KN and KM_NK

* Padding not supported for when BDataType is pk_i4_t. Added fix for correct check and removed padding instances.

* Fixed typos

* Updated the set of tests for FP16

* Updated the set of tests for FP16

* Fix typo

* Moved f16xi4 test under the correct data layout group

* example for gemm_universal_bf16

* Adding examples for gemm_wmma instances

* Added the  missing parameters

* Fixed review comments and added executable to cmakeLists

* Fixing clang format

* Fixing build erros

* Fixed compilation failure.

* Modified some code as per gemm_universal_examples

* Fixed the gemm specialization error

* Fixed the build errors.

* Fix strides of a/b_thread_desc

The descriptors are larger than needed (even though the compiler don't alloc registers for unused values).

* Load in M/NRepeat dims with thread copy's slice instead of a loop

* Clone BlockwiseGemmXdlops_pipeline_v1 for WMMA implementation

* Implement Intrawave and Interwave variants of pipeline v1

* Add instances for Interwave and Intrawave v1

* Add instances with ABlockLdsExtraM and BBlockLdsExtraN = 0

* Remove instances that are too slow (mostly because of register spilling)

* Add a workaround for fp8/bf8->f32 packed conversion issue

* Add instances for Interwave and Intrawave v1

* Enable profiling of mixed precision with f8 and int4 on WMMA

* Fix segfault in profiler when B is pk_i4_t

b_device_buf's size in bytes is larger than b_k_n_permute so b_device_buf.ToDevice reads out-of-bounds.

* Remove instances that are too slow (mostly because of register spilling)

* Add missing add_device_gemm_wmma_universal_f8_f8_bf16 declarations

* Add test case for bf16_i4

* Add missing Regular tests

* Add test_gemm_universal_xdl/wmma_fp16 to REGRESSION_TESTS

They take more than 30 seconds

* Fix a bug that fp16_i4 validation passes only with PermuteB

A permutation required by conversion from pk_i4_t to half_t does not
depend on PermuteB, they can be used independently.

* Use PermuteB with f16_i4 in most instances (as xdl)

Some instances use PermuteB = false for checking correctness.
See also the previous commit.

* Fix cache flushing for pk_i4

* Add mixed precision examples

* Disable all tests and instances with f8 on gfx11

Even though f8_f16 and f16_f8 don't require f8 WMMA instructions,
gfx11 still lacks hardware instructions for fast f8->f32 conversion.

* Add FP16 KM_NK and KM_KN test suites for XDL

These tests were added to common .inc for better testing of WMMA instances

* Support multiple D in GridwiseGemm_wmma_cshuffle_v3

DeviceGemm_Wmma_CShuffleV3 is changed for new template parameters.

* Use ThreadGroupTensorSliceTransfer_v7r3

* Clone for device_gemm_wmma_cshuffle_v3.hpp for future Multiple D support

* Clone example/65_gemm_multiply_multiply/gemm_add_add_xdl_fp16.cpp for wmma

* Implement DeviceGemmMultipleD_Wmma_CShuffleV3

* Make gemm_add_add_wmma to work with DeviceGemmMultipleD_Wmma_CShuffleV3

* Prepare gemma_add tests for adding wmma

* Add gemm_add_fastgelu instances and test

* Add a special wrapper to use DeviceGemmMultipleD_Wmma_CShuffleV3 with old API

ckProfiler uses DeviceGemmMultipleD (tests also call its functions), the wrapper allows to use
DeviceGemmMultipleDSplitK instances there.

* removed unnecessary ck parts from compilation

* initial gemm_add_multiply instance implementations

* fixed profiler help message for gemm_add_multiply

* improved multiply_add profiler layout help

* fixed template arguments for test instances

* added test for gemm_add_multiply

* Support multiple D in GridwiseGemm_wmma_cshuffle_v3

DeviceGemm_Wmma_CShuffleV3 is changed for new template parameters.

* Use ThreadGroupTensorSliceTransfer_v7r3

* Clone for device_gemm_wmma_cshuffle_v3.hpp for future Multiple D support

* Clone example/65_gemm_multiply_multiply/gemm_add_add_xdl_fp16.cpp for wmma

* Implement DeviceGemmMultipleD_Wmma_CShuffleV3

* Make gemm_add_add_wmma to work with DeviceGemmMultipleD_Wmma_CShuffleV3

* Prepare gemma_add tests for adding wmma

* Add gemm_add_fastgelu instances and test

* Add a special wrapper to use DeviceGemmMultipleD_Wmma_CShuffleV3 with old API

ckProfiler uses DeviceGemmMultipleD (tests also call its functions), the wrapper allows to use
DeviceGemmMultipleDSplitK instances there.

* switched to splitK interface

* log print added to splitk benchmarks

* revert main cmake comments

* newline change reverted

* added add_fastgelu instances

* revert unintended change in xdl add_fastgelu

* created gemm_add_add_fastgelu instances

* created fastegelu instances

* added tests for all splitk fastgelus

* Added tests.

* multiply_add instances created

* updates to add_multiply splitk instances

* splitk xdl test fixes

* added wmma multiply_multiply instances

* fixed ONLY_XDL_AND_WMMA_KERNELS tag

* Added gemm_add examples for wmma v1 and v3

* fixed / workarounded i8 instances

* Modified the v3 code to added one fp16 bxdl instance.

* added bf16 xdl instance.

* adding gemm_add wmma_cshuffle and other support


(cherry picked from commit ec447e7f564095ea969eddc39ec77b843aa52976)

Co-authored-by: Cenxuan <cenxuan@streamhpc.com>

* add instances into camkelists


(cherry picked from commit 23bf2d2771c939ea3ca7f493433c55255bffd08e)

Co-authored-by: Cenxuan <cenxuan@streamhpc.com>

* This is work in progress, edited the template parameters in order to build

(cherry picked from commit b4fde8a3314cb44659c4bbda35f1a0133c63dc41)

Co-authored-by: Cenxuan <cenxuan@streamhpc.com>

* temp work saved, changed the BDataType to f16 or bf16 since wmma currently not support non-equal A and B datatype


(cherry picked from commit 22fbd68f1db458ab50780a394ee2544c7a1484d1)

Co-authored-by: Cenxuan <cenxuan@streamhpc.com>

* added datatype and use clang-format-12


(cherry picked from commit ae4e853682ef1bb27784b2f965b4a66b3751ceec)

Co-authored-by: Cenxuan <cenxuan@streamhpc.com>

* Fixing build errors

* Added instances for v3

* Adding instances and executables

* Code update of template parameters modified.

* Renamed file.

* Added tests.

* resolved error tests.

* Fixing build errors

* Updated comments

* removed the changes as per the MR review comment.

* Updated tests.

* fp8 instances - not tested

* Restored the Cmake file that was reverted by mistake during rebase.

* fixed wmma_op test

* Updated comments.

* Updated the template parameter description

* fixed rdna4 instances

* fixed back compatibility on gfx11

* cleanups

* fix ckProfiler

* one more cmake fix

* added fp8 instances

* Updated tests to ad BF16 instances as per review comment

* Added include file and cleaned up(as per review comment)

* Updated and optimized the example code for all types.

* Fixed clang format

* Resolve "Implement `device_gemm_bilinear` for RDNA4"

* test generalization to handle FP16 shuffle better

* added missing changes

* Added bf16 wmma instance for add_relu

* Added f16 wmma instance and corrected bf16 instance errors.

* Added instances to Cmake

* Modified the template parameters to make the instances work.

* Fixed typo in profiler

* Added v3 instances for gemm_add_relu

* addressed core review comments

* Added test for gemm_add_relu wmma instance

* Cleaned up the code.

* Added examples for gemm_add_relu

* Fixing typo to resolve build errors.

* Fixes applied to fix  the precision loss.

* fix billinear test after merge

* Removed the old wmma instances.

* Added wrapper and renamed the wmma_v3 instances

* Updated copyrights and added wrappers.

* Fixes applied according to review comments

* Apply 1 suggestion(s) to 1 file(s)

Co-authored-by: Robin Voetter <robin@streamhpc.com>

* Removed the old wmma instances.

* Updated wrapper for the v3 instances

* removed the old wmma examples

* Renamed the v3 instances

* Deleted the  gtest file added by mistake.

* Updated thge profiler with wrapper

* Fixed test errors.

* Fixed the review comments

* Fixed the if condition MACROS.

* REVERTED THE PROFILER CHANGES

* Revert "REVERTED THE PROFILER CHANGES"

This reverts commit 21cb98546c.

* Revert "Fixed test errors."

This reverts commit 13efcc6fe1.

* Revert "Updated thge profiler with wrapper"

This reverts commit 536f86661d.

* Added missing wrapper instances

* Updated copyrights.

* Fixed typo.

* Fixed copyrights.

* Updated copyrights.

* updated copyrights.

* comments on the atomics workaround

* fixed cmake comment

* Fix bug from merge

* clang-format-18

* Fix compilation error

* multi_abd wmma support:

 - Add multiple A and B support to multiple D implementation (gridwise level)
 - Add multi_abd GEMM (device level)
 - Add instances (xdl parity)
 - Add tests (both xdl and wmma)
 - Add examples
 - Add ckProfiler support (both xdl and wmma)

* Fix bug in device print function

* Fix unused template parameter

* Add support for fwd conv in gridwise implementation. Identical to run function for bwd data.

* Initial device implementation for grouped conv fwd multiABD wmma cshuffleV3. Functional but needs some fixups and extra features in the future.

* Make relevant profilers print the number of valid instances to aid testing.

* Add instances for all vanilla 2D and 3D flavors for f16 and bf16, only one instance per instance list to save compile time for now.  Also added incomplete set of comp instances and bias_clamp for f16 2D, just to make sure the multiple-D aspects of the device implementation are working.

* Reset output buffer after each run in profile_grouped_conv_fwd_impl().

* Disable sharding for the new instances for now, has tendency to lead to linker errors on repeat builds.

* Add CTranspose optimization for NCHW cases just like in xdl cshuffle non-v3 device implementation.

* Add instances for all 8-bit 3D vanilla grouped conv fwd types, including mixed types but with the exception of deprecated f16 comp fp8. Adapt test so we can test 8-bit and mixed types.

* Add int8 instances for 2D vanilla grouped conv fwd all layouts.

* Implement merged groups in device impl and add instances for merged groups 3D vanilla conv fwd

* Add merged groups instances for all 2D vanilla grouped conv fwd types and layouts.

* Implement multi-AB support for grouped conv fwd and add example.

* Add 1D instances

* Add D layout tests to IsSupportedArgument()

* Add comp and mem instances for all vanilla 2D grouped conv fwd types. Skipping "x2" and "part2" instance lists, can be added later without special names if necessary.

* Add comp and mem instances for vanilla 3D grouped conv fwd. Skipped 2x and part2 instances, can be added later in the same instance lists.

* Add some more tests for vanilla grouped conv fwd

* Add 2D bias clamp instances and tests

* Add 3D bias clamp instances and tests

* Add 2D and 3D clamp instances and tests

* Unify problem sizes across vanilla and clamp flavor tests

* Clean up device implementation: remove old todos, remove unnecessary comments and print statements, tweak description, wrap all prints in env check.

* Implement rotating memory and flush cache. Requires ad-hoc buffer size calculations.

* Remove wmma fp8 and bf8 instances when not targetting gfx12

* Add newer instances to DEVICE_INSTANCES so the main ckProfiler can build

* Remove old years for newly created files.

* No need to time kernels for now.

* Fixup comments

* Pass struct args to Gridwise Run() function by reference.

* Don't use workspace memory in the case where A needs explicit transposition but B does not.

* Move calculation of rotating memory buffer sizes to Argument member functions.

* After the convolution to gemm transformation, the resulting 2D tensor descriptors are not necessarily RowMajor or ColumnMajor, so things should not rely on this distinction. Therefore, pass all RowMajor to the Gridwise and use a special version of CheckValidity that does not rely on 2D tensor layouts.

* Unify xdl and wmma example code for grouped conv fwd scaleadd ab

* Go back to passing RCR 2D tensor layouts to gridwise gemm, and use CRC for the CTranspose case. Also remove the special convolution version of checkValidity(). It seems like no matter what 2D tensor layouts you pass to the gridwise gemm, and no matter if you are using extraMN, and no matter if you are using the convolution version of checkvalidity, the results of all tests are the same.

* Add wmma scaleadd ab instances to the device factory and add a completely new scaleadd_ab gtest test for wmma cshufflev3 and xdl. Currently there is no profiler for scaleadd_ab so I made my own inside the test. Furthermore for XDL only the (NDHWGC, GKZYXC, NDHWGK) layout combination existed in the instance factory so that is the only one I added for wmma cshufflev3 and the gtest test as well. Another layout is tested in example 62, for xdl and wmma cshufflev3.

* Add support for V3 pipeline (tested). To be able to support num_loop < 3 we need the fixes from the batched gemm gemm MR which was already merged upstream, so just need to rebase or merge.

* Small post-merge fixup, everything seems to work.

* Do not build or run Xdl operations with Wmma backend for now. Will be reverted before upstreaming.

* Extend scaleadd_ab instance lists

* Extend merged groups instance lists, including adaptations of xdl "2x" instances.

* Extend "comp" instance lists, including "2x" and "part2" instances. 2x instances disabled for now since they do not compile.

* Extend "mem" instance lists.

* Extend regular instance lists.

* Fixup comments and ignored kernel arg name

* Properly use the splitN offsets for D tensors in the gridwise Run() function. Was necessary to pass the bias_clamp_large_cases test.

* Make sure all strides in ComputePtrOffset are at least value initialized to avoid undefined strides. Not convinced this struct is properly initialized in other code / future code.

* Re-enable sharding for wmma cshufflev3 instances

* Post merge fix to vanilla test

* Optionally allow num_k_loop <= PrefetchStages in gridwise CheckValidity. Use this for grouped conv fwd but not in general.

* Remove spurious ck_tile changes that were presumably introduced somewhere in the repeated merging from develop.

* Post-merge fixes. Make sure the new gridwise gemm wmma v3 common Run function can be used. Remove splitK, and forceThreadTileTransfer for now. Also add CShuffle epilogue argument.

* Disable FP8 / BF8 testing on CDNA1/2, it doesn't work anymore and needs to be either fixed or removed.

* Re-enable old wmma instances

* Re-enable Linqun's Xdl Wmma instances

* Small post-merge fixes

* Fix copyright headers

* Remove commented code snippet in gridwise

Co-authored-by: Bartłomiej Kocot <barkocot@amd.com>

* Limit the explicit cast added in threadwise_tensor_slice_transfer_v7r3 to only be used for f8, just in case it hurts performance.

* Adding tuned instace list for groupoed conv fwd (#3288)

Following flavors are updated with tuned instance list:
  - grouped_conv2d_fwd
  - grouped_conv2d_fwd_bias_clamp
  - grouped_conv2d_fwd_clamp
  - grouped_conv3d_fwd
  - grouped_conv3d_fwd_bias_clamp
  - grouped_conv3d_fwd_clamp
  - grouped_conv3d_fwd_scaleadd_ab

Re-factored instance selection:
  - removed all the unnecessary instance tuples (comp/mem/16x16/generic)
  - removed all unnecessary layouts and data types

* Do not use std::remove_cvref_t, does not exist in C++17, use custom one.

* Splitting grouped conv fwd instances (#3449)

* Disable unnecessary and failing tests related to experimental CK builder

* Disable unnecessary ck builder experimental tests fully

---------

Co-authored-by: Anca Hamuraru <anca@streamhpc.com>
Co-authored-by: apoorva <apoorva@streamhpc.com>
Co-authored-by: Anton Gorenko <anton@streamhpc.com>
Co-authored-by: Zoltan Lakatos <zoltan.lakatos@streamhpc.com>
Co-authored-by: Cenxuan <cenxuan@streamhpc.com>
Co-authored-by: Robin Voetter <robin@streamhpc.com>
Co-authored-by: Enrico Degregori <enrico@streamhpc.com>
Co-authored-by: Bartłomiej Kocot <barkocot@amd.com>
Co-authored-by: Wojciech Laskowski <77888887+wj-laskowski@users.noreply.github.com>
This commit is contained in:
Kiefer van Teutem
2025-12-18 21:12:15 +01:00
committed by GitHub
parent 700b2ec9c0
commit 2ea710e88b
91 changed files with 8170 additions and 99 deletions

View File

@@ -11,3 +11,12 @@ add_example_executable(example_conv_fwd_xdl_scaleadd_ab_bf16 conv_fwd_xdl_scalea
add_example_dependencies(example_convnd_activ_multi_ab_xdl example_conv_fwd_xdl_scaleadd_ab_bf16)
add_example_executable(example_conv_fwd_xdl_scaleadd_ab_int8 conv_fwd_xdl_scaleadd_ab_int8.cpp)
add_example_dependencies(example_convnd_activ_multi_ab_xdl example_conv_fwd_xdl_scaleadd_ab_int8)
add_custom_target(example_convnd_activ_multi_ab_wmma_cshufflev3)
# ScaleAdd on A and B
add_example_executable(example_conv_fwd_wmma_cshufflev3_scaleadd_ab_fp16 conv_fwd_wmma_cshufflev3_scaleadd_ab_fp16.cpp)
add_example_dependencies(example_convnd_activ_multi_ab_wmma_cshufflev3 example_conv_fwd_wmma_cshufflev3_scaleadd_ab_fp16)
add_example_executable(example_conv_fwd_wmma_cshufflev3_scaleadd_ab_bf16 conv_fwd_wmma_cshufflev3_scaleadd_ab_bf16.cpp)
add_example_dependencies(example_convnd_activ_multi_ab_wmma_cshufflev3 example_conv_fwd_wmma_cshufflev3_scaleadd_ab_bf16)
add_example_executable(example_conv_fwd_wmma_cshufflev3_scaleadd_ab_int8 conv_fwd_wmma_cshufflev3_scaleadd_ab_int8.cpp)
add_example_dependencies(example_convnd_activ_multi_ab_wmma_cshufflev3 example_conv_fwd_wmma_cshufflev3_scaleadd_ab_int8)

View File

@@ -0,0 +1,27 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#define EXAMPLE_USE_WMMA
#include "convnd_fwd_activ_multi_ab_common.hpp"
using DataType = ck::bhalf_t;
using AccDataType = float;
using InDataType = DataType;
using WeiDataType = DataType;
using OutDataType = DataType;
using ADataTypes = ck::Tuple<DataType, DataType>;
using BDataTypes = ck::Tuple<DataType, DataType>;
using InElementOp = ck::tensor_operation::element_wise::ScaleAdd;
using WeiElementOp = ck::tensor_operation::element_wise::ScaleAdd;
using DeviceGroupedConvNDActivInstance = DeviceGroupedConvNDMultiABFwdInstance<DataType,
AccDataType,
ADataTypes,
BDataTypes,
InElementOp,
WeiElementOp>;
#include "../run_convnd_activ_example.inc"
int main(int argc, char* argv[]) { return !run_convnd_example(argc, argv); }

View File

@@ -0,0 +1,27 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#define EXAMPLE_USE_WMMA
#include "convnd_fwd_activ_multi_ab_common.hpp"
using DataType = ck::half_t;
using AccDataType = float;
using InDataType = DataType;
using WeiDataType = DataType;
using OutDataType = DataType;
using ADataTypes = ck::Tuple<DataType, DataType>;
using BDataTypes = ck::Tuple<DataType, DataType>;
using InElementOp = ck::tensor_operation::element_wise::ScaleAdd;
using WeiElementOp = ck::tensor_operation::element_wise::ScaleAdd;
using DeviceGroupedConvNDActivInstance = DeviceGroupedConvNDMultiABFwdInstance<DataType,
AccDataType,
ADataTypes,
BDataTypes,
InElementOp,
WeiElementOp>;
#include "../run_convnd_activ_example.inc"
int main(int argc, char* argv[]) { return !run_convnd_example(argc, argv); }

View File

@@ -0,0 +1,27 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#define EXAMPLE_USE_WMMA
#include "convnd_fwd_activ_multi_ab_common.hpp"
using DataType = int8_t;
using AccDataType = int32_t;
using InDataType = DataType;
using WeiDataType = DataType;
using OutDataType = DataType;
using ADataTypes = ck::Tuple<DataType, DataType>;
using BDataTypes = ck::Tuple<DataType, DataType>;
using InElementOp = ck::tensor_operation::element_wise::ScaleAdd;
using WeiElementOp = ck::tensor_operation::element_wise::ScaleAdd;
using DeviceGroupedConvNDActivInstance = DeviceGroupedConvNDMultiABFwdInstance<DataType,
AccDataType,
ADataTypes,
BDataTypes,
InElementOp,
WeiElementOp>;
#include "../run_convnd_activ_example.inc"
int main(int argc, char* argv[]) { return !run_convnd_example(argc, argv); }

View File

@@ -9,7 +9,11 @@
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#ifdef EXAMPLE_USE_WMMA
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_wmma_cshuffle_v3.hpp"
#else
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp"
#endif
#include "ck/library/utility/algorithm.hpp"
#include "ck/library/utility/check_err.hpp"
@@ -41,6 +45,62 @@ static constexpr auto ConvSpec =
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
#ifdef EXAMPLE_USE_WMMA
template <typename DataType,
typename AccDataType,
typename InDataTypes,
typename WeiDataTypes,
typename InElementOp,
typename WeiElementOp>
using DeviceGroupedConvNDMultiABFwdInstance =
ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3<
NDimSpatial,
InLayout,
WeiLayout,
ck::Tuple<>,
OutLayout,
InDataTypes,
WeiDataTypes,
AccDataType,
DataType,
ck::Tuple<>,
DataType,
InElementOp,
WeiElementOp,
OutElementOp,
ConvSpec, // ConvForwardSpecialization
GemmSpec, // GemmSpecialization
256, // BlockSize
128, // MPerBlock
256, // NPerBlock
32, // KPerBlock
8, // AK1
8, // BK1
16, // MPerWmma
16, // NPerWmma
4, // MWmmaPerWave
4, // NWmmaPerWave
S<4, 64, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1
S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder
S<1, 0, 2>, // ABlockTransferSrcAccessOrder
2, // ABlockTransferSrcVectorDim
8, // ABlockTransferSrcScalarPerVector
8, // ABlockTransferDstScalarPerVector_AK1
1, // ABlockLdsExtraM
S<4, 64, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1
S<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder
S<1, 0, 2>, // BBlockTransferSrcAccessOrder
2, // BBlockTransferSrcVectorDim
8, // BBlockTransferSrcScalarPerVector
8, // BBlockTransferDstScalarPerVector_BK1
1, // BBlockLdsExtraN
1,
1,
S<1, 32, 1, 8>,
8,
ck::BlockGemmPipelineScheduler::Intrawave,
ck::BlockGemmPipelineVersion::v1>;
#else
template <typename DataType,
typename AccDataType,
typename InDataTypes,
@@ -94,6 +154,7 @@ using DeviceGroupedConvNDMultiABFwdInstance =
1,
S<1, 32, 1, 8>,
4>;
#endif
namespace {
template <ck::index_t NDimSpatial,
@@ -261,6 +322,8 @@ bool run_grouped_conv(bool do_verification,
out_device_buf.FromDevice(out_device.mData.data());
printf("Running verification\n");
return ck::utils::check_err(out_device, out_host, "Error: incorrect results!");
}

View File

@@ -157,14 +157,14 @@ add_ck_factory_test(test_ckb_testing_utils test_testing_utils.cpp)
# TODO: Re-enable these tests once we have CI working across all GPU architectures.
# add_ck_factory_test(test_ckb_factory_grouped_convolution_forward test_ck_factory_grouped_convolution_forward.cpp)
# add_ck_factory_test(test_ckb_factory_grouped_convolution_forward_clamp test_ck_factory_grouped_convolution_forward_clamp.cpp)
add_ck_factory_test(test_ckb_factory_grouped_convolution_forward_convscale test_ck_factory_grouped_convolution_forward_convscale.cpp)
# add_ck_factory_test(test_ckb_factory_grouped_convolution_forward_convscale test_ck_factory_grouped_convolution_forward_convscale.cpp)
# add_ck_factory_test(test_ckb_factory_grouped_convolution_forward_bilinear test_ck_factory_grouped_convolution_forward_bilinear.cpp)
# add_ck_factory_test(test_ckb_factory_grouped_convolution_forward_scale test_ck_factory_grouped_convolution_forward_scale.cpp)
add_ck_factory_test(test_ckb_factory_grouped_convolution_forward_scaleadd_ab test_ck_factory_grouped_convolution_forward_scaleadd_ab.cpp)
# add_ck_factory_test(test_ckb_factory_grouped_convolution_forward_scaleadd_ab test_ck_factory_grouped_convolution_forward_scaleadd_ab.cpp)
# add_ck_factory_test(test_ckb_factory_grouped_convolution_forward_bias_clamp test_ck_factory_grouped_convolution_forward_bias_clamp.cpp)
# add_ck_factory_test(test_ckb_factory_grouped_convolution_forward_bias_bnorm_clamp test_ck_factory_grouped_convolution_forward_bias_bnorm_clamp.cpp)
add_ck_factory_test(test_ckb_factory_grouped_convolution_forward_scaleadd_scaleadd_relu test_ck_factory_grouped_convolution_forward_scaleadd_scaleadd_relu.cpp)
add_ck_factory_test(test_ckb_factory_grouped_convolution_forward_dynamic_op test_ck_factory_grouped_convolution_forward_dynamic_op.cpp)
# add_ck_factory_test(test_ckb_factory_grouped_convolution_forward_scaleadd_scaleadd_relu test_ck_factory_grouped_convolution_forward_scaleadd_scaleadd_relu.cpp)
# add_ck_factory_test(test_ckb_factory_grouped_convolution_forward_dynamic_op test_ck_factory_grouped_convolution_forward_dynamic_op.cpp)
################################################################################
# CTest Integration - Register Tests and Assign Labels
@@ -193,10 +193,10 @@ set(CKB_REGRESSION_TESTS
test_ckb_instance_string
test_ckb_build_fwd_instances
test_ckb_testing_utils
test_ckb_factory_grouped_convolution_forward_convscale
test_ckb_factory_grouped_convolution_forward_scaleadd_ab
test_ckb_factory_grouped_convolution_forward_scaleadd_scaleadd_relu
test_ckb_factory_grouped_convolution_forward_dynamic_op
# test_ckb_factory_grouped_convolution_forward_convscale
# test_ckb_factory_grouped_convolution_forward_scaleadd_ab
# test_ckb_factory_grouped_convolution_forward_scaleadd_scaleadd_relu
# test_ckb_factory_grouped_convolution_forward_dynamic_op
)
foreach(test_target ${CKB_REGRESSION_TESTS})

View File

@@ -198,10 +198,10 @@ struct ComputePtrOffsetOfStridedBatch<NumATensor,
return static_cast<long_index_t>(g_idx) * BatchStrideE_;
}
Array<long_index_t, NumATensor> BatchStrideA_;
Array<long_index_t, NumBTensor> BatchStrideB_;
Array<long_index_t, NumDTensor> BatchStrideDs_;
long_index_t BatchStrideE_;
Array<long_index_t, NumATensor> BatchStrideA_{};
Array<long_index_t, NumBTensor> BatchStrideB_{};
Array<long_index_t, NumDTensor> BatchStrideDs_{};
long_index_t BatchStrideE_{};
long_index_t& BatchStrideC_ = BatchStrideE_; // alias for kernels without multiple D
};
@@ -253,10 +253,10 @@ struct ComputePtrOffsetOfStridedBatch<NumATensor,
return static_cast<long_index_t>(g_idx) * BatchStrideE_;
}
long_index_t BatchStrideA_;
long_index_t BatchStrideB_;
Array<long_index_t, NumDTensor> BatchStrideDs_;
long_index_t BatchStrideE_;
long_index_t BatchStrideA_{};
long_index_t BatchStrideB_{};
Array<long_index_t, NumDTensor> BatchStrideDs_{};
long_index_t BatchStrideE_{};
long_index_t& BatchStrideC_ = BatchStrideE_; // alias for kernels without multiple D
};

View File

@@ -894,6 +894,144 @@ struct GridwiseGemm_wmma_cshuffle_v3
k_idx,
karg.KBatch);
}
// Run method for convolution fwd (grid descriptors are passed as arguments,
// not generated internally)
template <typename AGridDesc_AK0_M_K1,
typename BGridDesc_BK0_N_K1,
typename DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
typename EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
typename ComputePtrOffsetOfBatch,
typename ComputePtrOffsetOfN,
bool HasMainKBlockLoop,
InMemoryDataOperationEnum EGlobalMemoryDataOperation,
TailNumber TailNum,
typename EpilogueArgument>
__device__ static void Run(void* p_shared,
const AGridDesc_AK0_M_K1& a_grid_desc_ak0_m_ak1,
const BGridDesc_BK0_N_K1& b_grid_desc_bk0_n_bk1,
const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock&
ds_grid_desc_mblock_mperblock_nblock_nperblock,
const EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock&
e_grid_desc_mblock_mperblock_nblock_nperblock,
const ComputePtrOffsetOfBatch& compute_ptr_offset_of_batch,
const ComputePtrOffsetOfN& compute_ptr_offset_of_n,
[[maybe_unused]] const index_t num_k_per_block,
Argument& karg,
EpilogueArgument& epilogue_args)
{
const index_t g_idx = __builtin_amdgcn_readfirstlane(blockIdx.y);
const index_t n_idx = __builtin_amdgcn_readfirstlane(blockIdx.z / karg.KBatch);
// offset base pointer for each work-group
const long_index_t a_batch_offset =
amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx));
const long_index_t b_batch_offset =
amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx));
const long_index_t e_batch_offset =
amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx));
const auto ds_batch_offset = compute_ptr_offset_of_batch.GetDsPtrOffset(g_idx);
const long_index_t a_n_offset =
amd_wave_read_first_lane(compute_ptr_offset_of_n.GetAPtrOffset(n_idx));
const long_index_t b_n_offset =
amd_wave_read_first_lane(compute_ptr_offset_of_n.GetBPtrOffset(n_idx));
const long_index_t e_n_offset =
amd_wave_read_first_lane(compute_ptr_offset_of_n.GetEPtrOffset(n_idx));
const auto ds_n_offset = compute_ptr_offset_of_n.GetDsPtrOffset(n_idx);
AsGridPointer p_as_grid_;
static_for<0, NumATensor, 1>{}([&](auto i) {
using ADataType_ = remove_cvref_t<tuple_element_t<i.value, AsDataType>>;
p_as_grid_(i) =
static_cast<const ADataType_*>(karg.p_as_grid[i]) + a_batch_offset + a_n_offset;
});
BsGridPointer p_bs_grid_;
static_for<0, NumBTensor, 1>{}([&](auto i) {
using BDataType_ = remove_cvref_t<tuple_element_t<i.value, BsDataType>>;
p_bs_grid_(i) =
static_cast<const BDataType_*>(karg.p_bs_grid[i]) + b_batch_offset + b_n_offset;
});
DsGridPointer p_ds_grid_grp;
static_for<0, NumDTensor, 1>{}([&](auto i) {
using DDataType_ = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
p_ds_grid_grp(i) = static_cast<const DDataType_*>(karg.p_ds_grid[i]) +
ds_batch_offset[i] + ds_n_offset[i];
});
// Currently supporting one A and one B
const auto as_grid_desc_ak0_m_ak1 = generate_tuple(
[&](auto i) {
ignore = i;
return a_grid_desc_ak0_m_ak1;
},
Number<NumATensor>{});
const auto bs_grid_desc_bk0_n_bk1 = generate_tuple(
[&](auto i) {
ignore = i;
return b_grid_desc_bk0_n_bk1;
},
Number<NumBTensor>{});
// divide block work by [M, N]
const auto block_2_ctile_map = Block2CTileMap{karg.M, karg.N, 4};
const auto block_work_idx =
block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
if(!block_2_ctile_map.ValidCTileIndex(
block_work_idx,
make_tuple(e_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0),
e_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2))))
{
return;
}
const index_t block_m_id = __builtin_amdgcn_readfirstlane(block_work_idx[I0]);
const index_t block_n_id = __builtin_amdgcn_readfirstlane(block_work_idx[I1]);
// AScale struct (Empty)
using AScale = typename BlockwiseGemmPipe::Empty;
auto a_scale_struct = AScale{};
// BScale struct (Empty)
using BScale = typename BlockwiseGemmPipe::Empty;
auto b_scale_struct = BScale{};
const index_t num_k_block_per_scale = GetKBlockPerScale();
Base::template Run<decltype(as_grid_desc_ak0_m_ak1),
decltype(bs_grid_desc_bk0_n_bk1),
decltype(ds_grid_desc_mblock_mperblock_nblock_nperblock),
decltype(e_grid_desc_mblock_mperblock_nblock_nperblock),
decltype(a_scale_struct),
decltype(b_scale_struct),
decltype(epilogue_args),
HasMainKBlockLoop,
EGlobalMemoryDataOperation,
TailNum>(p_as_grid_,
p_bs_grid_,
p_ds_grid_grp,
karg.p_e_grid + e_batch_offset + e_n_offset,
p_shared,
as_grid_desc_ak0_m_ak1,
bs_grid_desc_bk0_n_bk1,
ds_grid_desc_mblock_mperblock_nblock_nperblock,
e_grid_desc_mblock_mperblock_nblock_nperblock,
karg.a_element_op,
karg.b_element_op,
karg.cde_element_op,
block_m_id,
block_n_id,
num_k_block_per_scale,
a_scale_struct,
b_scale_struct,
epilogue_args);
}
};
} // namespace ck

View File

@@ -620,8 +620,10 @@ struct GridwiseGemm_wmma_cshuffle_v3_base
}
template <typename DsGridDesc>
__device__ static constexpr auto MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
const DsGridDesc& ds_grid_desc_m_n, index_t MBlock, index_t NBlock)
__device__ __host__ static constexpr auto
MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const DsGridDesc& ds_grid_desc_m_n,
index_t MBlock,
index_t NBlock)
{
return generate_tuple(
[&](auto i) {
@@ -737,7 +739,8 @@ struct GridwiseGemm_wmma_cshuffle_v3_base
// block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01}
template <typename Argument>
__host__ static constexpr bool CheckValidity(const Argument& karg)
__host__ static constexpr bool CheckValidity(const Argument& karg,
bool allow_short_v3_pipe = false)
{
static_assert((MPerBlock % (MPerWmma * MRepeat) == 0) &&
(NPerBlock % (NPerWmma * NRepeat)) == 0,
@@ -928,14 +931,17 @@ struct GridwiseGemm_wmma_cshuffle_v3_base
{
if(num_k_loop <= BlockwiseGemmPipe::PrefetchStages)
{
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
if(!(allow_short_v3_pipe && BlkGemmPipelineVer == BlockGemmPipelineVersion::v3))
{
std::cout << "Pipeline validation failed: num_k_loop (" << num_k_loop
<< ") <= PrefetchStages (" << BlockwiseGemmPipe::PrefetchStages
<< ") for pipeline version != v1." << __FILE__ << ":" << __LINE__
<< ", in function: " << __func__ << std::endl;
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
{
std::cout << "Pipeline validation failed: num_k_loop (" << num_k_loop
<< ") <= PrefetchStages (" << BlockwiseGemmPipe::PrefetchStages
<< ") for pipeline version != v1." << __FILE__ << ":" << __LINE__
<< ", in function: " << __func__ << std::endl;
}
return false;
}
return false;
}
}

View File

@@ -286,8 +286,25 @@ struct ThreadwiseTensorSliceTransfer_v7r3
static_for<0, nDst, 1>{}([&](auto i) {
using elm_vector_t = typename remove_cvref_t<decltype(elm_vectors[i])>::type;
elm_vectors(i).template AsType<elm_vector_t>()(I0) =
oob_val ? elm_vectors(i).template AsType<elm_vector_t>()[I0] : elm_vector_t{0};
using scalar_t =
remove_cvref_t<decltype(elm_vectors(i).template AsType<elm_vector_t>()[I0])>;
// This is a bit ugly but necessary to be able to compile f8 instances for grouped
// convolution forward. For some reason for that specific type there is an ambiguity
// in the type resolution for the ternary expression. I added an explicit cast to
// disambiguate and only use it for f8 just in case it affects performance.
if constexpr(std::is_same_v<scalar_t, ck::f8_ocp_t>)
{
elm_vectors(i).template AsType<elm_vector_t>()(I0) =
oob_val ? elm_vector_t{elm_vectors(i).template AsType<elm_vector_t>()[I0]}
: elm_vector_t{0};
}
else
{
elm_vectors(i).template AsType<elm_vector_t>()(I0) =
oob_val ? elm_vectors(i).template AsType<elm_vector_t>()[I0]
: elm_vector_t{0};
}
});
elm_vectors_tuple_(thread_scratch_id)(iAccess) = elm_vectors;

View File

@@ -0,0 +1,273 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_wmma_cshuffle_v3.hpp"
#include "ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
using BF16 = ck::bhalf_t;
using F16 = ck::half_t;
using F32 = float;
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using Empty_Tuple = ck::Tuple<>;
using namespace ck::tensor_layout::convolution;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using AddClamp = ck::tensor_operation::element_wise::AddClamp;
using Clamp = ck::tensor_operation::element_wise::Clamp;
static constexpr auto ConvFwdDefault =
ck::tensor_operation::device::ConvolutionForwardSpecialization::Default;
static constexpr auto ConvFwd1x1P0 = ConvolutionForwardSpecialization::Filter1x1Pad0;
static constexpr auto ConvFwd1x1S1P0 = ConvolutionForwardSpecialization::Filter1x1Stride1Pad0;
static constexpr auto ConvFwdOddC =
ck::tensor_operation::device::ConvolutionForwardSpecialization::OddC;
static constexpr auto GemmMNKPadding = GemmSpecialization::MNKPadding;
template <index_t NDimSpatial,
typename ALayout,
typename BLayout,
typename DsLayout,
typename ELayout,
ConvolutionForwardSpecialization ConvSpec,
typename DsDataTypes = Tuple<>,
typename OutElementOp = PassThrough>
using device_grouped_conv_fwd_wmma_cshufflev3_bf16_instances_part1 = std::tuple<
// clang-format off
//########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MWmma| NWmma| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Pipeline scheduler | Pipeline version |
//########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Size| Block| Block| Block| | | WMMA| WMMA| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MWmmaPerWave| NWmmaPerWave| _MBlock_MWaveMPerWmma| ScalarPerVector| | |
//########################################| | | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerWmma| _NWaveNPerWmma| | |
//########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
// generic instance
DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, DsDataTypes, BF16, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 64, 64, 64, 32, 8, 8, 16, 16, 4, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>
#ifndef ONE_INSTANCE_PER_LIST
,
DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, DsDataTypes, BF16, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>,
DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, DsDataTypes, BF16, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 64, 64, 64, 32, 8, 8, 16, 16, 4, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>,
DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, DsDataTypes, BF16, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 64, 64, 32, 32, 8, 8, 16, 16, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>,
DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, DsDataTypes, BF16, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 256, 128, 256, 32, 8, 8, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1>,
DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, DsDataTypes, BF16, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>,
DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, DsDataTypes, BF16, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 128, 128, 48, 64, 8, 8, 16, 16, 2, 3, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 2>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>,
DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, DsDataTypes, BF16, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 128, 128, 48, 64, 8, 8, 16, 16, 2, 3, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 64, 1, 2>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>,
DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, DsDataTypes, BF16, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 256, 64, 64, 64, 8, 8, 16, 16, 2, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>,
DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, DsDataTypes, BF16, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 256, 64, 64, 64, 8, 8, 16, 16, 2, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>
#endif
// clang-format on
>;
template <index_t NDimSpatial,
typename ALayout,
typename BLayout,
typename DsLayout,
typename ELayout,
ConvolutionForwardSpecialization ConvSpec,
typename DsDataTypes = Tuple<>,
typename OutElementOp = PassThrough>
using device_grouped_conv_fwd_wmma_cshufflev3_bf16_instances_part2 = std::tuple<
// clang-format off
DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, DsDataTypes, BF16, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 128, 128, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>
#ifndef ONE_INSTANCE_PER_LIST
,
DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, DsDataTypes, BF16, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 128, 64, 32, 64, 8, 8, 16, 16, 2, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>,
DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, DsDataTypes, BF16, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 128, 64, 64, 64, 8, 8, 16, 16, 2, 2, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>,
DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, DsDataTypes, BF16, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 128, 64, 64, 64, 8, 8, 16, 16, 2, 2, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>,
DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, DsDataTypes, BF16, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 128, 64, 64, 64, 8, 8, 16, 16, 2, 2, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>,
DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, DsDataTypes, BF16, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 128, 64, 96, 64, 8, 8, 16, 16, 2, 3, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>,
DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, DsDataTypes, BF16, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 128, 64, 96, 64, 8, 8, 16, 16, 2, 3, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>,
DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, DsDataTypes, BF16, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 128, 128, 96, 32, 8, 8, 16, 16, 4, 3, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>,
DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, DsDataTypes, BF16, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 128, 128, 96, 32, 8, 8, 16, 16, 4, 3, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>
#endif
// clang-format on
>;
template <index_t NDimSpatial,
typename ALayout,
typename BLayout,
typename DsLayout,
typename ELayout,
ConvolutionForwardSpecialization ConvSpec,
typename DsDataTypes = Tuple<>,
typename OutElementOp = PassThrough>
using device_grouped_conv_fwd_wmma_cshufflev3_bf16_instances_part3 = std::tuple<
// clang-format off
DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, DsDataTypes, BF16, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 128, 128, 96, 64, 8, 8, 16, 16, 4, 3, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>
#ifndef ONE_INSTANCE_PER_LIST
,
DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, DsDataTypes, BF16, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 256, 128, 256, 64, 8, 8, 16, 16, 4, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>,
DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, DsDataTypes, BF16, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 256, 128, 96, 64, 8, 8, 16, 16, 2, 3, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 64, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>,
DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, DsDataTypes, BF16, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 256, 128, 96, 64, 8, 8, 16, 16, 2, 3, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>,
DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, DsDataTypes, BF16, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 256, 128, 96, 64, 8, 8, 16, 16, 2, 3, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 64, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>,
DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, DsDataTypes, BF16, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 128, 128, 64, 32, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>,
DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, DsDataTypes, BF16, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 128, 128, 64, 32, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>,
DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, DsDataTypes, BF16, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 128, 128, 64, 32, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>,
DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, DsDataTypes, BF16, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 128, 128, 32, 32, 8, 8, 16, 16, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>
#endif
// clang-format on
>;
template <index_t NDimSpatial,
typename ALayout,
typename BLayout,
typename DsLayout,
typename ELayout,
ConvolutionForwardSpecialization ConvSpec,
typename DsDataTypes = Tuple<>,
typename OutElementOp = PassThrough>
using device_grouped_conv_fwd_wmma_cshufflev3_bf16_instances_part4 = std::tuple<
// clang-format off
DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, DsDataTypes, BF16, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 128, 128, 32, 32, 8, 8, 16, 16, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>
#ifndef ONE_INSTANCE_PER_LIST
,
DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, DsDataTypes, BF16, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 128, 128, 32, 32, 8, 8, 16, 16, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>,
DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, DsDataTypes, BF16, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 256, 256, 64, 64, 8, 8, 16, 16, 4, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>,
DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, DsDataTypes, BF16, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 256, 128, 256, 32, 8, 8, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>,
DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, DsDataTypes, BF16, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 128, 128, 128, 32, 8, 8, 16, 16, 8, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>,
DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, DsDataTypes, BF16, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>,
DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, DsDataTypes, BF16, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 128, 64, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>,
DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, DsDataTypes, BF16, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 256, 64, 64, 32, 8, 8, 16, 16, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 8, 1, 1, 1, S<1, 32, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>,
DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, DsDataTypes, BF16, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 256, 64, 64, 32, 8, 8, 16, 16, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>
#endif
// clang-format on
>;
template <index_t NDimSpatial,
typename ALayout,
typename BLayout,
typename DsLayout,
typename ELayout,
ConvolutionForwardSpecialization ConvSpec,
typename DsDataTypes = Tuple<>,
typename OutElementOp = PassThrough>
using device_grouped_conv_fwd_wmma_cshufflev3_f16_instances_part1 = std::tuple<
// clang-format off
//########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MWmma| NWmma| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Pipeline scheduler | Pipeline version |
//########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Size| Block| Block| Block| | | WMMA| WMMA| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MWmmaPerWave| NWmmaPerWave| _MBlock_MWaveMPerWmma| ScalarPerVector| | |
//########################################| | | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerWmma| _NWaveNPerWmma| | |
//########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
// generic instance
DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, DsDataTypes, F16, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 64, 64, 64, 32, 8, 8, 16, 16, 4, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>
#ifndef ONE_INSTANCE_PER_LIST
,
DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, DsDataTypes, F16, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>,
DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, DsDataTypes, F16, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 64, 64, 64, 32, 8, 8, 16, 16, 4, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>,
DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, DsDataTypes, F16, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 64, 64, 32, 32, 8, 8, 16, 16, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>,
DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, DsDataTypes, F16, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 256, 128, 256, 32, 8, 8, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1>,
DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, DsDataTypes, F16, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>,
DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, DsDataTypes, F16, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 128, 128, 48, 64, 8, 8, 16, 16, 2, 3, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 2>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>,
DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, DsDataTypes, F16, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 128, 128, 48, 64, 8, 8, 16, 16, 2, 3, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 64, 1, 2>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>,
DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, DsDataTypes, F16, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 256, 64, 64, 64, 8, 8, 16, 16, 2, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>
#endif
// clang-format on
>;
template <index_t NDimSpatial,
typename ALayout,
typename BLayout,
typename DsLayout,
typename ELayout,
ConvolutionForwardSpecialization ConvSpec,
typename DsDataTypes = Tuple<>,
typename OutElementOp = PassThrough>
using device_grouped_conv_fwd_wmma_cshufflev3_f16_instances_part2 = std::tuple<
// clang-format off
//########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MWmma| NWmma| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Pipeline scheduler | Pipeline version |
//########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Size| Block| Block| Block| | | WMMA| WMMA| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MWmmaPerWave| NWmmaPerWave| _MBlock_MWaveMPerWmma| ScalarPerVector| | |
//########################################| | | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerWmma| _NWaveNPerWmma| | |
//########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
// generic instance
DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, DsDataTypes, F16, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 256, 64, 64, 64, 8, 8, 16, 16, 2, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>
#ifndef ONE_INSTANCE_PER_LIST
,
DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, DsDataTypes, F16, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 128, 128, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>,
DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, DsDataTypes, F16, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 128, 64, 32, 64, 8, 8, 16, 16, 2, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>,
DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, DsDataTypes, F16, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 128, 64, 64, 64, 8, 8, 16, 16, 2, 2, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>,
DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, DsDataTypes, F16, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 128, 64, 64, 64, 8, 8, 16, 16, 2, 2, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>,
DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, DsDataTypes, F16, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 128, 64, 64, 64, 8, 8, 16, 16, 2, 2, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>,
DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, DsDataTypes, F16, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 128, 64, 96, 64, 8, 8, 16, 16, 2, 3, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>,
DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, DsDataTypes, F16, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 128, 64, 96, 64, 8, 8, 16, 16, 2, 3, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>,
DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, DsDataTypes, F16, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 128, 128, 96, 32, 8, 8, 16, 16, 4, 3, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>,
DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, DsDataTypes, F16, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 128, 128, 96, 32, 8, 8, 16, 16, 4, 3, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>
#endif
// clang-format on
>;
template <index_t NDimSpatial,
typename ALayout,
typename BLayout,
typename DsLayout,
typename ELayout,
ConvolutionForwardSpecialization ConvSpec,
typename DsDataTypes = Tuple<>,
typename OutElementOp = PassThrough>
using device_grouped_conv_fwd_wmma_cshufflev3_f16_instances_part3 = std::tuple<
// clang-format off
//########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MWmma| NWmma| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Pipeline scheduler | Pipeline version |
//########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Size| Block| Block| Block| | | WMMA| WMMA| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MWmmaPerWave| NWmmaPerWave| _MBlock_MWaveMPerWmma| ScalarPerVector| | |
//########################################| | | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerWmma| _NWaveNPerWmma| | |
//########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
// generic instance
DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, DsDataTypes, F16, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 128, 128, 96, 64, 8, 8, 16, 16, 4, 3, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>
#ifndef ONE_INSTANCE_PER_LIST
,
DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, DsDataTypes, F16, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 256, 128, 256, 64, 8, 8, 16, 16, 4, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>,
DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, DsDataTypes, F16, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 256, 128, 96, 64, 8, 8, 16, 16, 2, 3, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 64, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>,
DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, DsDataTypes, F16, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 256, 128, 96, 64, 8, 8, 16, 16, 2, 3, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>,
DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, DsDataTypes, F16, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 256, 128, 96, 64, 8, 8, 16, 16, 2, 3, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 64, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>,
DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, DsDataTypes, F16, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 128, 128, 64, 32, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>,
DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, DsDataTypes, F16, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 128, 128, 64, 32, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>,
DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, DsDataTypes, F16, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 128, 128, 64, 32, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>,
DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, DsDataTypes, F16, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 128, 128, 32, 32, 8, 8, 16, 16, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>
#endif
// clang-format on
>;
template <index_t NDimSpatial,
typename ALayout,
typename BLayout,
typename DsLayout,
typename ELayout,
ConvolutionForwardSpecialization ConvSpec,
typename DsDataTypes = Tuple<>,
typename OutElementOp = PassThrough>
using device_grouped_conv_fwd_wmma_cshufflev3_f16_instances_part4 = std::tuple<
// clang-format off
//########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MWmma| NWmma| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Pipeline scheduler | Pipeline version |
//########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Size| Block| Block| Block| | | WMMA| WMMA| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MWmmaPerWave| NWmmaPerWave| _MBlock_MWaveMPerWmma| ScalarPerVector| | |
//########################################| | | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerWmma| _NWaveNPerWmma| | |
//########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
// generic instance
DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, DsDataTypes, F16, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 128, 128, 32, 32, 8, 8, 16, 16, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>
#ifndef ONE_INSTANCE_PER_LIST
,
DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, DsDataTypes, F16, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 128, 128, 32, 32, 8, 8, 16, 16, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>,
DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, DsDataTypes, F16, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 256, 256, 64, 64, 8, 8, 16, 16, 4, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>,
DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, DsDataTypes, F16, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 256, 128, 256, 32, 8, 8, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>,
DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, DsDataTypes, F16, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 128, 128, 128, 32, 8, 8, 16, 16, 8, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>,
DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, DsDataTypes, F16, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>,
DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, DsDataTypes, F16, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 128, 64, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>,
DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, DsDataTypes, F16, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 256, 64, 64, 32, 8, 8, 16, 16, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 8, 1, 1, 1, S<1, 32, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>,
DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, DsDataTypes, F16, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 256, 64, 64, 32, 8, 8, 16, 16, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>
#endif
// clang-format on
>;
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -0,0 +1,145 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_wmma_cshuffle_v3.hpp"
#include "ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
using BF16 = ck::bhalf_t;
using F16 = ck::half_t;
using F32 = float;
using I8 = int8_t;
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using namespace ck::tensor_layout::convolution;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using ScaleAdd = ck::tensor_operation::element_wise::ScaleAdd;
static constexpr auto ConvFwdDefault =
ck::tensor_operation::device::ConvolutionForwardSpecialization::Default;
static constexpr auto ConvFwd1x1P0 = ConvolutionForwardSpecialization::Filter1x1Pad0;
static constexpr auto ConvFwd1x1S1P0 = ConvolutionForwardSpecialization::Filter1x1Stride1Pad0;
static constexpr auto ConvFwdOddC =
ck::tensor_operation::device::ConvolutionForwardSpecialization::OddC;
static constexpr auto GemmMNKPadding = GemmSpecialization::MNKPadding;
template <index_t NDimSpatial,
typename ALayout,
typename BLayout,
typename ELayout,
ConvolutionForwardSpecialization ConvSpec>
using device_grouped_conv_fwd_wmma_cshufflev3_scaleadd_ab_bf16_instances =
std::tuple<
// clang-format off
//########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MWmma| NWmma| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Pipeline scheduler | Pipeline version |
//########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Size| Block| Block| Block| | | WMMA| WMMA| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MWmmaPerWave| NWmmaPerWave| _MBlock_MWaveMPerWmma| ScalarPerVector| | |
//########################################| | | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerWmma| _NWaveNPerWmma| | |
//########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
// generic instance
DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3<NDimSpatial, ALayout, BLayout, ck::Tuple<>, ELayout, ck::Tuple<BF16, BF16>, ck::Tuple<BF16, BF16>, F32, BF16, ck::Tuple<>, BF16, ScaleAdd, ScaleAdd, PassThrough, ConvSpec, GemmMNKPadding, 256, 128, 256, 32, 8, 8, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>
#ifndef ONE_INSTANCE_PER_LIST
,
DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3<NDimSpatial, ALayout, BLayout, ck::Tuple<>, ELayout, ck::Tuple<BF16, BF16>, ck::Tuple<BF16, BF16>, F32, BF16, ck::Tuple<>, BF16, ScaleAdd, ScaleAdd, PassThrough, ConvSpec, GemmMNKPadding, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>,
DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3<NDimSpatial, ALayout, BLayout, ck::Tuple<>, ELayout, ck::Tuple<BF16, BF16>, ck::Tuple<BF16, BF16>, F32, BF16, ck::Tuple<>, BF16, ScaleAdd, ScaleAdd, PassThrough, ConvSpec, GemmMNKPadding, 128, 128, 48, 64, 8, 8, 16, 16, 2, 3, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 2>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>,
DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3<NDimSpatial, ALayout, BLayout, ck::Tuple<>, ELayout, ck::Tuple<BF16, BF16>, ck::Tuple<BF16, BF16>, F32, BF16, ck::Tuple<>, BF16, ScaleAdd, ScaleAdd, PassThrough, ConvSpec, GemmMNKPadding, 128, 128, 48, 64, 8, 8, 16, 16, 2, 3, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 64, 1, 2>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>,
DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3<NDimSpatial, ALayout, BLayout, ck::Tuple<>, ELayout, ck::Tuple<BF16, BF16>, ck::Tuple<BF16, BF16>, F32, BF16, ck::Tuple<>, BF16, ScaleAdd, ScaleAdd, PassThrough, ConvSpec, GemmMNKPadding, 256, 64, 64, 64, 8, 8, 16, 16, 2, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>,
DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3<NDimSpatial, ALayout, BLayout, ck::Tuple<>, ELayout, ck::Tuple<BF16, BF16>, ck::Tuple<BF16, BF16>, F32, BF16, ck::Tuple<>, BF16, ScaleAdd, ScaleAdd, PassThrough, ConvSpec, GemmMNKPadding, 256, 64, 64, 64, 8, 8, 16, 16, 2, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>,
DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3<NDimSpatial, ALayout, BLayout, ck::Tuple<>, ELayout, ck::Tuple<BF16, BF16>, ck::Tuple<BF16, BF16>, F32, BF16, ck::Tuple<>, BF16, ScaleAdd, ScaleAdd, PassThrough, ConvSpec, GemmMNKPadding, 128, 128, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>,
DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3<NDimSpatial, ALayout, BLayout, ck::Tuple<>, ELayout, ck::Tuple<BF16, BF16>, ck::Tuple<BF16, BF16>, F32, BF16, ck::Tuple<>, BF16, ScaleAdd, ScaleAdd, PassThrough, ConvSpec, GemmMNKPadding, 128, 64, 32, 64, 8, 8, 16, 16, 2, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>,
DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3<NDimSpatial, ALayout, BLayout, ck::Tuple<>, ELayout, ck::Tuple<BF16, BF16>, ck::Tuple<BF16, BF16>, F32, BF16, ck::Tuple<>, BF16, ScaleAdd, ScaleAdd, PassThrough, ConvSpec, GemmMNKPadding, 128, 64, 64, 64, 8, 8, 16, 16, 2, 2, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>,
DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3<NDimSpatial, ALayout, BLayout, ck::Tuple<>, ELayout, ck::Tuple<BF16, BF16>, ck::Tuple<BF16, BF16>, F32, BF16, ck::Tuple<>, BF16, ScaleAdd, ScaleAdd, PassThrough, ConvSpec, GemmMNKPadding, 128, 64, 64, 64, 8, 8, 16, 16, 2, 2, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>,
DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3<NDimSpatial, ALayout, BLayout, ck::Tuple<>, ELayout, ck::Tuple<BF16, BF16>, ck::Tuple<BF16, BF16>, F32, BF16, ck::Tuple<>, BF16, ScaleAdd, ScaleAdd, PassThrough, ConvSpec, GemmMNKPadding, 128, 64, 64, 64, 8, 8, 16, 16, 2, 2, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>,
DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3<NDimSpatial, ALayout, BLayout, ck::Tuple<>, ELayout, ck::Tuple<BF16, BF16>, ck::Tuple<BF16, BF16>, F32, BF16, ck::Tuple<>, BF16, ScaleAdd, ScaleAdd, PassThrough, ConvSpec, GemmMNKPadding, 128, 64, 96, 64, 8, 8, 16, 16, 2, 3, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>,
DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3<NDimSpatial, ALayout, BLayout, ck::Tuple<>, ELayout, ck::Tuple<BF16, BF16>, ck::Tuple<BF16, BF16>, F32, BF16, ck::Tuple<>, BF16, ScaleAdd, ScaleAdd, PassThrough, ConvSpec, GemmMNKPadding, 128, 64, 96, 64, 8, 8, 16, 16, 2, 3, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>,
DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3<NDimSpatial, ALayout, BLayout, ck::Tuple<>, ELayout, ck::Tuple<BF16, BF16>, ck::Tuple<BF16, BF16>, F32, BF16, ck::Tuple<>, BF16, ScaleAdd, ScaleAdd, PassThrough, ConvSpec, GemmMNKPadding, 128, 128, 96, 32, 8, 8, 16, 16, 4, 3, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>,
DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3<NDimSpatial, ALayout, BLayout, ck::Tuple<>, ELayout, ck::Tuple<BF16, BF16>, ck::Tuple<BF16, BF16>, F32, BF16, ck::Tuple<>, BF16, ScaleAdd, ScaleAdd, PassThrough, ConvSpec, GemmMNKPadding, 128, 128, 96, 32, 8, 8, 16, 16, 4, 3, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>,
DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3<NDimSpatial, ALayout, BLayout, ck::Tuple<>, ELayout, ck::Tuple<BF16, BF16>, ck::Tuple<BF16, BF16>, F32, BF16, ck::Tuple<>, BF16, ScaleAdd, ScaleAdd, PassThrough, ConvSpec, GemmMNKPadding, 128, 128, 96, 64, 8, 8, 16, 16, 4, 3, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>,
DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3<NDimSpatial, ALayout, BLayout, ck::Tuple<>, ELayout, ck::Tuple<BF16, BF16>, ck::Tuple<BF16, BF16>, F32, BF16, ck::Tuple<>, BF16, ScaleAdd, ScaleAdd, PassThrough, ConvSpec, GemmMNKPadding, 256, 128, 256, 64, 8, 8, 16, 16, 4, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>,
DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3<NDimSpatial, ALayout, BLayout, ck::Tuple<>, ELayout, ck::Tuple<BF16, BF16>, ck::Tuple<BF16, BF16>, F32, BF16, ck::Tuple<>, BF16, ScaleAdd, ScaleAdd, PassThrough, ConvSpec, GemmMNKPadding, 256, 128, 96, 64, 8, 8, 16, 16, 2, 3, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 64, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>,
DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3<NDimSpatial, ALayout, BLayout, ck::Tuple<>, ELayout, ck::Tuple<BF16, BF16>, ck::Tuple<BF16, BF16>, F32, BF16, ck::Tuple<>, BF16, ScaleAdd, ScaleAdd, PassThrough, ConvSpec, GemmMNKPadding, 256, 128, 96, 64, 8, 8, 16, 16, 2, 3, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>,
DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3<NDimSpatial, ALayout, BLayout, ck::Tuple<>, ELayout, ck::Tuple<BF16, BF16>, ck::Tuple<BF16, BF16>, F32, BF16, ck::Tuple<>, BF16, ScaleAdd, ScaleAdd, PassThrough, ConvSpec, GemmMNKPadding, 256, 128, 96, 64, 8, 8, 16, 16, 2, 3, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 64, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>,
DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3<NDimSpatial, ALayout, BLayout, ck::Tuple<>, ELayout, ck::Tuple<BF16, BF16>, ck::Tuple<BF16, BF16>, F32, BF16, ck::Tuple<>, BF16, ScaleAdd, ScaleAdd, PassThrough, ConvSpec, GemmMNKPadding, 128, 128, 64, 32, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>,
DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3<NDimSpatial, ALayout, BLayout, ck::Tuple<>, ELayout, ck::Tuple<BF16, BF16>, ck::Tuple<BF16, BF16>, F32, BF16, ck::Tuple<>, BF16, ScaleAdd, ScaleAdd, PassThrough, ConvSpec, GemmMNKPadding, 128, 128, 64, 32, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>,
DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3<NDimSpatial, ALayout, BLayout, ck::Tuple<>, ELayout, ck::Tuple<BF16, BF16>, ck::Tuple<BF16, BF16>, F32, BF16, ck::Tuple<>, BF16, ScaleAdd, ScaleAdd, PassThrough, ConvSpec, GemmMNKPadding, 128, 128, 64, 32, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>,
DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3<NDimSpatial, ALayout, BLayout, ck::Tuple<>, ELayout, ck::Tuple<BF16, BF16>, ck::Tuple<BF16, BF16>, F32, BF16, ck::Tuple<>, BF16, ScaleAdd, ScaleAdd, PassThrough, ConvSpec, GemmMNKPadding, 128, 128, 32, 32, 8, 8, 16, 16, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>,
DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3<NDimSpatial, ALayout, BLayout, ck::Tuple<>, ELayout, ck::Tuple<BF16, BF16>, ck::Tuple<BF16, BF16>, F32, BF16, ck::Tuple<>, BF16, ScaleAdd, ScaleAdd, PassThrough, ConvSpec, GemmMNKPadding, 128, 128, 32, 32, 8, 8, 16, 16, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>,
DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3<NDimSpatial, ALayout, BLayout, ck::Tuple<>, ELayout, ck::Tuple<BF16, BF16>, ck::Tuple<BF16, BF16>, F32, BF16, ck::Tuple<>, BF16, ScaleAdd, ScaleAdd, PassThrough, ConvSpec, GemmMNKPadding, 128, 128, 32, 32, 8, 8, 16, 16, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>,
DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3<NDimSpatial, ALayout, BLayout, ck::Tuple<>, ELayout, ck::Tuple<BF16, BF16>, ck::Tuple<BF16, BF16>, F32, BF16, ck::Tuple<>, BF16, ScaleAdd, ScaleAdd, PassThrough, ConvSpec, GemmMNKPadding, 256, 256, 64, 64, 8, 8, 16, 16, 4, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>,
DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3<NDimSpatial, ALayout, BLayout, ck::Tuple<>, ELayout, ck::Tuple<BF16, BF16>, ck::Tuple<BF16, BF16>, F32, BF16, ck::Tuple<>, BF16, ScaleAdd, ScaleAdd, PassThrough, ConvSpec, GemmMNKPadding, 256, 128, 256, 32, 8, 8, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>,
DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3<NDimSpatial, ALayout, BLayout, ck::Tuple<>, ELayout, ck::Tuple<BF16, BF16>, ck::Tuple<BF16, BF16>, F32, BF16, ck::Tuple<>, BF16, ScaleAdd, ScaleAdd, PassThrough, ConvSpec, GemmMNKPadding, 128, 128, 128, 32, 8, 8, 16, 16, 8, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>,
DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3<NDimSpatial, ALayout, BLayout, ck::Tuple<>, ELayout, ck::Tuple<BF16, BF16>, ck::Tuple<BF16, BF16>, F32, BF16, ck::Tuple<>, BF16, ScaleAdd, ScaleAdd, PassThrough, ConvSpec, GemmMNKPadding, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>,
DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3<NDimSpatial, ALayout, BLayout, ck::Tuple<>, ELayout, ck::Tuple<BF16, BF16>, ck::Tuple<BF16, BF16>, F32, BF16, ck::Tuple<>, BF16, ScaleAdd, ScaleAdd, PassThrough, ConvSpec, GemmMNKPadding, 128, 64, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>,
DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3<NDimSpatial, ALayout, BLayout, ck::Tuple<>, ELayout, ck::Tuple<BF16, BF16>, ck::Tuple<BF16, BF16>, F32, BF16, ck::Tuple<>, BF16, ScaleAdd, ScaleAdd, PassThrough, ConvSpec, GemmMNKPadding, 256, 64, 64, 32, 8, 8, 16, 16, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 8, 1, 1, 1, S<1, 32, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>,
DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3<NDimSpatial, ALayout, BLayout, ck::Tuple<>, ELayout, ck::Tuple<BF16, BF16>, ck::Tuple<BF16, BF16>, F32, BF16, ck::Tuple<>, BF16, ScaleAdd, ScaleAdd, PassThrough, ConvSpec, GemmMNKPadding, 256, 64, 64, 32, 8, 8, 16, 16, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>
#endif
// clang-format on
>;
template <index_t NDimSpatial,
typename ALayout,
typename BLayout,
typename ELayout,
ConvolutionForwardSpecialization ConvSpec>
using device_grouped_conv_fwd_wmma_cshufflev3_scaleadd_ab_f16_instances =
std::tuple<
// clang-format off
//########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MWmma| NWmma| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Pipeline scheduler | Pipeline version |
//########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Size| Block| Block| Block| | | WMMA| WMMA| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MWmmaPerWave| NWmmaPerWave| _MBlock_MWaveMPerWmma| ScalarPerVector| | |
//########################################| | | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerWmma| _NWaveNPerWmma| | |
//########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
// generic instance
DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3<NDimSpatial, ALayout, BLayout, ck::Tuple<>, ELayout, ck::Tuple<F16, F16>, ck::Tuple<F16, F16>, F32, F16, ck::Tuple<>, F16, ScaleAdd, ScaleAdd, PassThrough, ConvSpec, GemmMNKPadding, 256, 128, 256, 32, 8, 8, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>
#ifndef ONE_INSTANCE_PER_LIST
,
DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3<NDimSpatial, ALayout, BLayout, ck::Tuple<>, ELayout, ck::Tuple<F16, F16>, ck::Tuple<F16, F16>, F32, F16, ck::Tuple<>, F16, ScaleAdd, ScaleAdd, PassThrough, ConvSpec, GemmMNKPadding, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>,
DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3<NDimSpatial, ALayout, BLayout, ck::Tuple<>, ELayout, ck::Tuple<F16, F16>, ck::Tuple<F16, F16>, F32, F16, ck::Tuple<>, F16, ScaleAdd, ScaleAdd, PassThrough, ConvSpec, GemmMNKPadding, 128, 128, 48, 64, 8, 8, 16, 16, 2, 3, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 2>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>,
DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3<NDimSpatial, ALayout, BLayout, ck::Tuple<>, ELayout, ck::Tuple<F16, F16>, ck::Tuple<F16, F16>, F32, F16, ck::Tuple<>, F16, ScaleAdd, ScaleAdd, PassThrough, ConvSpec, GemmMNKPadding, 128, 128, 48, 64, 8, 8, 16, 16, 2, 3, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 64, 1, 2>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>,
DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3<NDimSpatial, ALayout, BLayout, ck::Tuple<>, ELayout, ck::Tuple<F16, F16>, ck::Tuple<F16, F16>, F32, F16, ck::Tuple<>, F16, ScaleAdd, ScaleAdd, PassThrough, ConvSpec, GemmMNKPadding, 256, 64, 64, 64, 8, 8, 16, 16, 2, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>,
DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3<NDimSpatial, ALayout, BLayout, ck::Tuple<>, ELayout, ck::Tuple<F16, F16>, ck::Tuple<F16, F16>, F32, F16, ck::Tuple<>, F16, ScaleAdd, ScaleAdd, PassThrough, ConvSpec, GemmMNKPadding, 256, 64, 64, 64, 8, 8, 16, 16, 2, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>,
DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3<NDimSpatial, ALayout, BLayout, ck::Tuple<>, ELayout, ck::Tuple<F16, F16>, ck::Tuple<F16, F16>, F32, F16, ck::Tuple<>, F16, ScaleAdd, ScaleAdd, PassThrough, ConvSpec, GemmMNKPadding, 128, 128, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>,
DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3<NDimSpatial, ALayout, BLayout, ck::Tuple<>, ELayout, ck::Tuple<F16, F16>, ck::Tuple<F16, F16>, F32, F16, ck::Tuple<>, F16, ScaleAdd, ScaleAdd, PassThrough, ConvSpec, GemmMNKPadding, 128, 64, 32, 64, 8, 8, 16, 16, 2, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>,
DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3<NDimSpatial, ALayout, BLayout, ck::Tuple<>, ELayout, ck::Tuple<F16, F16>, ck::Tuple<F16, F16>, F32, F16, ck::Tuple<>, F16, ScaleAdd, ScaleAdd, PassThrough, ConvSpec, GemmMNKPadding, 128, 64, 64, 64, 8, 8, 16, 16, 2, 2, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>,
DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3<NDimSpatial, ALayout, BLayout, ck::Tuple<>, ELayout, ck::Tuple<F16, F16>, ck::Tuple<F16, F16>, F32, F16, ck::Tuple<>, F16, ScaleAdd, ScaleAdd, PassThrough, ConvSpec, GemmMNKPadding, 128, 64, 64, 64, 8, 8, 16, 16, 2, 2, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>,
DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3<NDimSpatial, ALayout, BLayout, ck::Tuple<>, ELayout, ck::Tuple<F16, F16>, ck::Tuple<F16, F16>, F32, F16, ck::Tuple<>, F16, ScaleAdd, ScaleAdd, PassThrough, ConvSpec, GemmMNKPadding, 128, 64, 64, 64, 8, 8, 16, 16, 2, 2, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>,
DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3<NDimSpatial, ALayout, BLayout, ck::Tuple<>, ELayout, ck::Tuple<F16, F16>, ck::Tuple<F16, F16>, F32, F16, ck::Tuple<>, F16, ScaleAdd, ScaleAdd, PassThrough, ConvSpec, GemmMNKPadding, 128, 64, 96, 64, 8, 8, 16, 16, 2, 3, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>,
DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3<NDimSpatial, ALayout, BLayout, ck::Tuple<>, ELayout, ck::Tuple<F16, F16>, ck::Tuple<F16, F16>, F32, F16, ck::Tuple<>, F16, ScaleAdd, ScaleAdd, PassThrough, ConvSpec, GemmMNKPadding, 128, 64, 96, 64, 8, 8, 16, 16, 2, 3, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>,
DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3<NDimSpatial, ALayout, BLayout, ck::Tuple<>, ELayout, ck::Tuple<F16, F16>, ck::Tuple<F16, F16>, F32, F16, ck::Tuple<>, F16, ScaleAdd, ScaleAdd, PassThrough, ConvSpec, GemmMNKPadding, 128, 128, 96, 32, 8, 8, 16, 16, 4, 3, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>,
DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3<NDimSpatial, ALayout, BLayout, ck::Tuple<>, ELayout, ck::Tuple<F16, F16>, ck::Tuple<F16, F16>, F32, F16, ck::Tuple<>, F16, ScaleAdd, ScaleAdd, PassThrough, ConvSpec, GemmMNKPadding, 128, 128, 96, 32, 8, 8, 16, 16, 4, 3, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>,
DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3<NDimSpatial, ALayout, BLayout, ck::Tuple<>, ELayout, ck::Tuple<F16, F16>, ck::Tuple<F16, F16>, F32, F16, ck::Tuple<>, F16, ScaleAdd, ScaleAdd, PassThrough, ConvSpec, GemmMNKPadding, 128, 128, 96, 64, 8, 8, 16, 16, 4, 3, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>,
DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3<NDimSpatial, ALayout, BLayout, ck::Tuple<>, ELayout, ck::Tuple<F16, F16>, ck::Tuple<F16, F16>, F32, F16, ck::Tuple<>, F16, ScaleAdd, ScaleAdd, PassThrough, ConvSpec, GemmMNKPadding, 256, 128, 256, 64, 8, 8, 16, 16, 4, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>,
DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3<NDimSpatial, ALayout, BLayout, ck::Tuple<>, ELayout, ck::Tuple<F16, F16>, ck::Tuple<F16, F16>, F32, F16, ck::Tuple<>, F16, ScaleAdd, ScaleAdd, PassThrough, ConvSpec, GemmMNKPadding, 256, 128, 96, 64, 8, 8, 16, 16, 2, 3, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 64, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>,
DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3<NDimSpatial, ALayout, BLayout, ck::Tuple<>, ELayout, ck::Tuple<F16, F16>, ck::Tuple<F16, F16>, F32, F16, ck::Tuple<>, F16, ScaleAdd, ScaleAdd, PassThrough, ConvSpec, GemmMNKPadding, 256, 128, 96, 64, 8, 8, 16, 16, 2, 3, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>,
DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3<NDimSpatial, ALayout, BLayout, ck::Tuple<>, ELayout, ck::Tuple<F16, F16>, ck::Tuple<F16, F16>, F32, F16, ck::Tuple<>, F16, ScaleAdd, ScaleAdd, PassThrough, ConvSpec, GemmMNKPadding, 256, 128, 96, 64, 8, 8, 16, 16, 2, 3, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 64, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>,
DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3<NDimSpatial, ALayout, BLayout, ck::Tuple<>, ELayout, ck::Tuple<F16, F16>, ck::Tuple<F16, F16>, F32, F16, ck::Tuple<>, F16, ScaleAdd, ScaleAdd, PassThrough, ConvSpec, GemmMNKPadding, 128, 128, 64, 32, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>,
DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3<NDimSpatial, ALayout, BLayout, ck::Tuple<>, ELayout, ck::Tuple<F16, F16>, ck::Tuple<F16, F16>, F32, F16, ck::Tuple<>, F16, ScaleAdd, ScaleAdd, PassThrough, ConvSpec, GemmMNKPadding, 128, 128, 64, 32, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>,
DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3<NDimSpatial, ALayout, BLayout, ck::Tuple<>, ELayout, ck::Tuple<F16, F16>, ck::Tuple<F16, F16>, F32, F16, ck::Tuple<>, F16, ScaleAdd, ScaleAdd, PassThrough, ConvSpec, GemmMNKPadding, 128, 128, 64, 32, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>,
DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3<NDimSpatial, ALayout, BLayout, ck::Tuple<>, ELayout, ck::Tuple<F16, F16>, ck::Tuple<F16, F16>, F32, F16, ck::Tuple<>, F16, ScaleAdd, ScaleAdd, PassThrough, ConvSpec, GemmMNKPadding, 128, 128, 32, 32, 8, 8, 16, 16, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>,
DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3<NDimSpatial, ALayout, BLayout, ck::Tuple<>, ELayout, ck::Tuple<F16, F16>, ck::Tuple<F16, F16>, F32, F16, ck::Tuple<>, F16, ScaleAdd, ScaleAdd, PassThrough, ConvSpec, GemmMNKPadding, 128, 128, 32, 32, 8, 8, 16, 16, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>,
DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3<NDimSpatial, ALayout, BLayout, ck::Tuple<>, ELayout, ck::Tuple<F16, F16>, ck::Tuple<F16, F16>, F32, F16, ck::Tuple<>, F16, ScaleAdd, ScaleAdd, PassThrough, ConvSpec, GemmMNKPadding, 128, 128, 32, 32, 8, 8, 16, 16, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>,
DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3<NDimSpatial, ALayout, BLayout, ck::Tuple<>, ELayout, ck::Tuple<F16, F16>, ck::Tuple<F16, F16>, F32, F16, ck::Tuple<>, F16, ScaleAdd, ScaleAdd, PassThrough, ConvSpec, GemmMNKPadding, 256, 256, 64, 64, 8, 8, 16, 16, 4, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>,
DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3<NDimSpatial, ALayout, BLayout, ck::Tuple<>, ELayout, ck::Tuple<F16, F16>, ck::Tuple<F16, F16>, F32, F16, ck::Tuple<>, F16, ScaleAdd, ScaleAdd, PassThrough, ConvSpec, GemmMNKPadding, 256, 128, 256, 32, 8, 8, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>,
DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3<NDimSpatial, ALayout, BLayout, ck::Tuple<>, ELayout, ck::Tuple<F16, F16>, ck::Tuple<F16, F16>, F32, F16, ck::Tuple<>, F16, ScaleAdd, ScaleAdd, PassThrough, ConvSpec, GemmMNKPadding, 128, 128, 128, 32, 8, 8, 16, 16, 8, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>,
DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3<NDimSpatial, ALayout, BLayout, ck::Tuple<>, ELayout, ck::Tuple<F16, F16>, ck::Tuple<F16, F16>, F32, F16, ck::Tuple<>, F16, ScaleAdd, ScaleAdd, PassThrough, ConvSpec, GemmMNKPadding, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>,
DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3<NDimSpatial, ALayout, BLayout, ck::Tuple<>, ELayout, ck::Tuple<F16, F16>, ck::Tuple<F16, F16>, F32, F16, ck::Tuple<>, F16, ScaleAdd, ScaleAdd, PassThrough, ConvSpec, GemmMNKPadding, 128, 64, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>,
DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3<NDimSpatial, ALayout, BLayout, ck::Tuple<>, ELayout, ck::Tuple<F16, F16>, ck::Tuple<F16, F16>, F32, F16, ck::Tuple<>, F16, ScaleAdd, ScaleAdd, PassThrough, ConvSpec, GemmMNKPadding, 256, 64, 64, 32, 8, 8, 16, 16, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 8, 1, 1, 1, S<1, 32, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>,
DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3<NDimSpatial, ALayout, BLayout, ck::Tuple<>, ELayout, ck::Tuple<F16, F16>, ck::Tuple<F16, F16>, F32, F16, ck::Tuple<>, F16, ScaleAdd, ScaleAdd, PassThrough, ConvSpec, GemmMNKPadding, 256, 64, 64, 32, 8, 8, 16, 16, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>
#endif
// clang-format on
>;
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -24,8 +24,14 @@
#include "grouped_convolution_forward_mem_intra_xdl.inc"
#endif
#ifdef CK_USE_WMMA
#define CK_USE_WMMA_OLD
#ifdef CK_USE_WMMA_OLD
#include "grouped_convolution_forward_wmma.inc"
#endif
#include "grouped_convolution_forward_wmma_cshufflev3.inc"
#include "grouped_convolution_forward_mem_inter_wmma_cshufflev3.inc"
#include "grouped_convolution_forward_mem_intra_wmma_cshufflev3.inc"
#endif
namespace ck {
namespace tensor_operation {
@@ -652,7 +658,9 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
#endif // CK_USE_XDL
#ifdef CK_USE_WMMA
// TODO: Put old Wmma instances back
#ifdef CK_USE_WMMA_OLD
if constexpr(NumDimSpatial == 2 && is_same_v<InLayout, GNHWC> &&
is_same_v<WeiLayout, GKYXC> && is_same_v<OutLayout, GNHWK>)
{
@@ -766,6 +774,93 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
}
#endif
}
#endif // CK_USE_WMMA_OLD
#ifdef CK_USE_WMMA
// layout NHWGC/GKYXC/NHWGK
if constexpr(NumDimSpatial == 2 && is_same_v<InLayout, NHWGC> &&
is_same_v<WeiLayout, GKYXC> && is_same_v<OutLayout, NHWGK>)
{
#ifdef CK_ENABLE_FP16
if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
is_same_v<OutDataType, half_t> && is_same_v<AComputeType, half_t> &&
is_same_v<BComputeType, half_t>)
{
add_device_grouped_conv2d_fwd_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_f16_instances_part1(
op_ptrs);
add_device_grouped_conv2d_fwd_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_f16_instances_part2(
op_ptrs);
add_device_grouped_conv2d_fwd_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_f16_instances_part3(
op_ptrs);
add_device_grouped_conv2d_fwd_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_f16_instances_part4(
op_ptrs);
// add_device_grouped_conv2d_fwd_wmma_cshufflev3_large_tensor_nhwgc_gkyxc_nhwgk_f16_instances(
// op_ptrs);
}
#endif
#ifdef CK_ENABLE_BF16
if constexpr(is_same_v<InDataType, ck::bhalf_t> &&
is_same_v<WeiDataType, ck::bhalf_t> &&
is_same_v<OutDataType, ck::bhalf_t> &&
is_same_v<AComputeType, ck::bhalf_t> &&
is_same_v<BComputeType, ck::bhalf_t>)
{
add_device_grouped_conv2d_fwd_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_bf16_instances_part1(
op_ptrs);
add_device_grouped_conv2d_fwd_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_bf16_instances_part2(
op_ptrs);
add_device_grouped_conv2d_fwd_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_bf16_instances_part3(
op_ptrs);
add_device_grouped_conv2d_fwd_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_bf16_instances_part4(
op_ptrs);
// add_device_grouped_conv2d_fwd_wmma_cshufflev3_large_tensor_nhwgc_gkyxc_nhwgk_bf16_instances(
// op_ptrs);
}
#endif
}
// 3D
// layout NDHWGC/GKZYXC/NDHWGK
if constexpr(NumDimSpatial == 3 && is_same_v<InLayout, NDHWGC> &&
is_same_v<WeiLayout, GKZYXC> && is_same_v<OutLayout, NDHWGK>)
{
#ifdef CK_ENABLE_FP16
if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
is_same_v<OutDataType, half_t> && is_same_v<AComputeType, half_t> &&
is_same_v<BComputeType, half_t>)
{
add_device_grouped_conv3d_fwd_wmma_cshufflev3_ndhwgc_gkzyxc_ndhwgk_f16_instances_part1(
op_ptrs);
add_device_grouped_conv3d_fwd_wmma_cshufflev3_ndhwgc_gkzyxc_ndhwgk_f16_instances_part2(
op_ptrs);
add_device_grouped_conv3d_fwd_wmma_cshufflev3_ndhwgc_gkzyxc_ndhwgk_f16_instances_part3(
op_ptrs);
add_device_grouped_conv3d_fwd_wmma_cshufflev3_ndhwgc_gkzyxc_ndhwgk_f16_instances_part4(
op_ptrs);
// add_device_grouped_conv3d_fwd_wmma_cshufflev3_large_tensor_ndhwgc_gkzyxc_ndhwgk_f16_instances(
// op_ptrs);
}
#endif
#ifdef CK_ENABLE_BF16
if constexpr(is_same_v<InDataType, ck::bhalf_t> &&
is_same_v<WeiDataType, ck::bhalf_t> &&
is_same_v<OutDataType, ck::bhalf_t> &&
is_same_v<AComputeType, ck::bhalf_t> &&
is_same_v<BComputeType, ck::bhalf_t>)
{
add_device_grouped_conv3d_fwd_wmma_cshufflev3_ndhwgc_gkzyxc_ndhwgk_bf16_instances_part1(
op_ptrs);
add_device_grouped_conv3d_fwd_wmma_cshufflev3_ndhwgc_gkzyxc_ndhwgk_bf16_instances_part2(
op_ptrs);
add_device_grouped_conv3d_fwd_wmma_cshufflev3_ndhwgc_gkzyxc_ndhwgk_bf16_instances_part3(
op_ptrs);
add_device_grouped_conv3d_fwd_wmma_cshufflev3_ndhwgc_gkzyxc_ndhwgk_bf16_instances_part4(
op_ptrs);
// add_device_grouped_conv3d_fwd_wmma_cshufflev3_large_tensor_ndhwgc_gkzyxc_ndhwgk_bf16_instances(
// op_ptrs);
}
#endif
}
#endif // CK_USE_WMMA
return op_ptrs;

View File

@@ -16,6 +16,10 @@
#include "grouped_convolution_forward_bias_clamp_xdl.inc"
#endif
#ifdef CK_USE_WMMA
#include "grouped_convolution_forward_bias_clamp_wmma_cshufflev3.inc"
#endif
namespace ck {
namespace tensor_operation {
namespace device {
@@ -269,6 +273,91 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
}
#endif // CK_USE_XDL
#ifdef CK_USE_WMMA
// layout NHWGC/GKYXC/NHWGK
if constexpr(NumDimSpatial == 2 && is_same_v<InLayout, NHWGC> &&
is_same_v<WeiLayout, GKYXC> && is_same_v<OutLayout, NHWGK>)
{
#ifdef CK_ENABLE_BF16
if constexpr(is_same_v<InDataType, ck::bhalf_t> &&
is_same_v<WeiDataType, ck::bhalf_t> &&
is_same_v<OutDataType, ck::bhalf_t> &&
is_same_v<AComputeType, ck::bhalf_t> &&
is_same_v<BComputeType, ck::bhalf_t>)
{
add_device_grouped_conv2d_fwd_bias_clamp_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_bf16_instances_part1(
op_ptrs);
add_device_grouped_conv2d_fwd_bias_clamp_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_bf16_instances_part2(
op_ptrs);
add_device_grouped_conv2d_fwd_bias_clamp_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_bf16_instances_part3(
op_ptrs);
add_device_grouped_conv2d_fwd_bias_clamp_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_bf16_instances_part4(
op_ptrs);
// add_device_grouped_conv2d_fwd_bias_clamp_wmma_cshufflev3_large_tensor_nhwgc_gkyxc_nhwgk_bf16_instances(
// op_ptrs);
}
#endif
#ifdef CK_ENABLE_FP16
if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
is_same_v<OutDataType, half_t> && is_same_v<AComputeType, half_t> &&
is_same_v<BComputeType, half_t>)
{
add_device_grouped_conv2d_fwd_bias_clamp_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_f16_instances_part1(
op_ptrs);
add_device_grouped_conv2d_fwd_bias_clamp_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_f16_instances_part2(
op_ptrs);
add_device_grouped_conv2d_fwd_bias_clamp_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_f16_instances_part3(
op_ptrs);
add_device_grouped_conv2d_fwd_bias_clamp_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_f16_instances_part4(
op_ptrs);
// add_device_grouped_conv2d_fwd_bias_clamp_wmma_cshufflev3_large_tensor_nhwgc_gkyxc_nhwgk_f16_instances(
// op_ptrs);
}
#endif
}
// layout NDHWGC/GKZYXC/NDHWGK
if constexpr(NumDimSpatial == 3 && is_same_v<InLayout, NDHWGC> &&
is_same_v<WeiLayout, GKZYXC> && is_same_v<OutLayout, NDHWGK>)
{
#ifdef CK_ENABLE_BF16
if constexpr(is_same_v<InDataType, ck::bhalf_t> &&
is_same_v<WeiDataType, ck::bhalf_t> &&
is_same_v<OutDataType, ck::bhalf_t> &&
is_same_v<AComputeType, ck::bhalf_t> &&
is_same_v<BComputeType, ck::bhalf_t>)
{
add_device_grouped_conv3d_fwd_bias_clamp_wmma_cshufflev3_ndhwgc_gkzyxc_ndhwgk_bf16_instances_part1(
op_ptrs);
add_device_grouped_conv3d_fwd_bias_clamp_wmma_cshufflev3_ndhwgc_gkzyxc_ndhwgk_bf16_instances_part2(
op_ptrs);
add_device_grouped_conv3d_fwd_bias_clamp_wmma_cshufflev3_ndhwgc_gkzyxc_ndhwgk_bf16_instances_part3(
op_ptrs);
add_device_grouped_conv3d_fwd_bias_clamp_wmma_cshufflev3_ndhwgc_gkzyxc_ndhwgk_bf16_instances_part4(
op_ptrs);
// add_device_grouped_conv3d_fwd_bias_clamp_wmma_cshufflev3_large_tensor_ndhwgc_gkzyxc_ndhwgk_bf16_instances(
// op_ptrs);
}
#endif
#ifdef CK_ENABLE_FP16
if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
is_same_v<OutDataType, half_t> && is_same_v<AComputeType, half_t> &&
is_same_v<BComputeType, half_t>)
{
add_device_grouped_conv3d_fwd_bias_clamp_wmma_cshufflev3_ndhwgc_gkzyxc_ndhwgk_f16_instances_part1(
op_ptrs);
add_device_grouped_conv3d_fwd_bias_clamp_wmma_cshufflev3_ndhwgc_gkzyxc_ndhwgk_f16_instances_part2(
op_ptrs);
add_device_grouped_conv3d_fwd_bias_clamp_wmma_cshufflev3_ndhwgc_gkzyxc_ndhwgk_f16_instances_part3(
op_ptrs);
add_device_grouped_conv3d_fwd_bias_clamp_wmma_cshufflev3_ndhwgc_gkzyxc_ndhwgk_f16_instances_part4(
op_ptrs);
// add_device_grouped_conv3d_fwd_bias_clamp_wmma_cshufflev3_large_tensor_ndhwgc_gkzyxc_ndhwgk_f16_instances(
// op_ptrs);
}
#endif
}
#endif // CK_USE_WMMA
return op_ptrs;
}
};

View File

@@ -0,0 +1,294 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
#ifdef CK_ENABLE_BF16
void add_device_grouped_conv2d_fwd_bias_clamp_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_bf16_instances_part1(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
NHWGC,
GKYXC,
Tuple<NHWGK>,
NHWGK,
BF16,
BF16,
Tuple<BF16>,
BF16,
PassThrough,
PassThrough,
AddClamp>>>& instances);
void add_device_grouped_conv2d_fwd_bias_clamp_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_bf16_instances_part2(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
NHWGC,
GKYXC,
Tuple<NHWGK>,
NHWGK,
BF16,
BF16,
Tuple<BF16>,
BF16,
PassThrough,
PassThrough,
AddClamp>>>& instances);
void add_device_grouped_conv2d_fwd_bias_clamp_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_bf16_instances_part3(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
NHWGC,
GKYXC,
Tuple<NHWGK>,
NHWGK,
BF16,
BF16,
Tuple<BF16>,
BF16,
PassThrough,
PassThrough,
AddClamp>>>& instances);
void add_device_grouped_conv2d_fwd_bias_clamp_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_bf16_instances_part4(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
NHWGC,
GKYXC,
Tuple<NHWGK>,
NHWGK,
BF16,
BF16,
Tuple<BF16>,
BF16,
PassThrough,
PassThrough,
AddClamp>>>& instances);
// void
// add_device_grouped_conv2d_fwd_bias_clamp_wmma_cshufflev3_large_tensor_nhwgc_gkyxc_nhwgk_bf16_instances(
// std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
// NHWGC,
// GKYXC,
// Tuple<NHWGK>,
// NHWGK,
// BF16,
// BF16,
// Tuple<BF16>,
// BF16,
// PassThrough,
// PassThrough,
// AddClamp>>>& instances);
void add_device_grouped_conv3d_fwd_bias_clamp_wmma_cshufflev3_ndhwgc_gkzyxc_ndhwgk_bf16_instances_part1(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
NDHWGC,
GKZYXC,
Tuple<NDHWGK>,
NDHWGK,
BF16,
BF16,
Tuple<BF16>,
BF16,
PassThrough,
PassThrough,
AddClamp>>>& instances);
void add_device_grouped_conv3d_fwd_bias_clamp_wmma_cshufflev3_ndhwgc_gkzyxc_ndhwgk_bf16_instances_part2(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
NDHWGC,
GKZYXC,
Tuple<NDHWGK>,
NDHWGK,
BF16,
BF16,
Tuple<BF16>,
BF16,
PassThrough,
PassThrough,
AddClamp>>>& instances);
void add_device_grouped_conv3d_fwd_bias_clamp_wmma_cshufflev3_ndhwgc_gkzyxc_ndhwgk_bf16_instances_part3(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
NDHWGC,
GKZYXC,
Tuple<NDHWGK>,
NDHWGK,
BF16,
BF16,
Tuple<BF16>,
BF16,
PassThrough,
PassThrough,
AddClamp>>>& instances);
void add_device_grouped_conv3d_fwd_bias_clamp_wmma_cshufflev3_ndhwgc_gkzyxc_ndhwgk_bf16_instances_part4(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
NDHWGC,
GKZYXC,
Tuple<NDHWGK>,
NDHWGK,
BF16,
BF16,
Tuple<BF16>,
BF16,
PassThrough,
PassThrough,
AddClamp>>>& instances);
// void
// add_device_grouped_conv3d_fwd_bias_clamp_wmma_cshufflev3_large_tensor_ndhwgc_gkzyxc_ndhwgk_bf16_instances(
// std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
// NDHWGC,
// GKZYXC,
// Tuple<NDHWGK>,
// NDHWGK,
// BF16,
// BF16,
// Tuple<BF16>,
// BF16,
// PassThrough,
// PassThrough,
// AddClamp>>>& instances);
#endif
#ifdef CK_ENABLE_FP16
void add_device_grouped_conv2d_fwd_bias_clamp_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_f16_instances_part1(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
NHWGC,
GKYXC,
Tuple<NHWGK>,
NHWGK,
F16,
F16,
Tuple<F16>,
F16,
PassThrough,
PassThrough,
AddClamp>>>& instances);
void add_device_grouped_conv2d_fwd_bias_clamp_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_f16_instances_part2(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
NHWGC,
GKYXC,
Tuple<NHWGK>,
NHWGK,
F16,
F16,
Tuple<F16>,
F16,
PassThrough,
PassThrough,
AddClamp>>>& instances);
void add_device_grouped_conv2d_fwd_bias_clamp_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_f16_instances_part3(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
NHWGC,
GKYXC,
Tuple<NHWGK>,
NHWGK,
F16,
F16,
Tuple<F16>,
F16,
PassThrough,
PassThrough,
AddClamp>>>& instances);
void add_device_grouped_conv2d_fwd_bias_clamp_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_f16_instances_part4(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
NHWGC,
GKYXC,
Tuple<NHWGK>,
NHWGK,
F16,
F16,
Tuple<F16>,
F16,
PassThrough,
PassThrough,
AddClamp>>>& instances);
// void
// add_device_grouped_conv2d_fwd_bias_clamp_wmma_cshufflev3_large_tensor_nhwgc_gkyxc_nhwgk_f16_instances(
// std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
// NHWGC,
// GKYXC,
// Tuple<NHWGK>,
// NHWGK,
// F16,
// F16,
// Tuple<F16>,
// F16,
// PassThrough,
// PassThrough,
// AddClamp>>>& instances);
void add_device_grouped_conv3d_fwd_bias_clamp_wmma_cshufflev3_ndhwgc_gkzyxc_ndhwgk_f16_instances_part1(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
NDHWGC,
GKZYXC,
Tuple<NDHWGK>,
NDHWGK,
F16,
F16,
Tuple<F16>,
F16,
PassThrough,
PassThrough,
AddClamp>>>& instances);
void add_device_grouped_conv3d_fwd_bias_clamp_wmma_cshufflev3_ndhwgc_gkzyxc_ndhwgk_f16_instances_part2(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
NDHWGC,
GKZYXC,
Tuple<NDHWGK>,
NDHWGK,
F16,
F16,
Tuple<F16>,
F16,
PassThrough,
PassThrough,
AddClamp>>>& instances);
void add_device_grouped_conv3d_fwd_bias_clamp_wmma_cshufflev3_ndhwgc_gkzyxc_ndhwgk_f16_instances_part3(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
NDHWGC,
GKZYXC,
Tuple<NDHWGK>,
NDHWGK,
F16,
F16,
Tuple<F16>,
F16,
PassThrough,
PassThrough,
AddClamp>>>& instances);
void add_device_grouped_conv3d_fwd_bias_clamp_wmma_cshufflev3_ndhwgc_gkzyxc_ndhwgk_f16_instances_part4(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
NDHWGC,
GKZYXC,
Tuple<NDHWGK>,
NDHWGK,
F16,
F16,
Tuple<F16>,
F16,
PassThrough,
PassThrough,
AddClamp>>>& instances);
// void
// add_device_grouped_conv3d_fwd_bias_clamp_wmma_cshufflev3_large_tensor_ndhwgc_gkzyxc_ndhwgk_f16_instances(
// std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
// NDHWGC,
// GKZYXC,
// Tuple<NDHWGK>,
// NDHWGK,
// F16,
// F16,
// Tuple<F16>,
// F16,
// PassThrough,
// PassThrough,
// AddClamp>>>& instances);
#endif
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -16,6 +16,10 @@
#include "grouped_convolution_forward_clamp_xdl.inc"
#endif
#ifdef CK_USE_WMMA
#include "grouped_convolution_forward_clamp_wmma_cshufflev3.inc"
#endif
namespace ck {
namespace tensor_operation {
namespace device {
@@ -266,6 +270,91 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
}
#endif // CK_USE_XDL
#ifdef CK_USE_WMMA
// layout NHWGC/GKYXC/NHWGK
if constexpr(NumDimSpatial == 2 && is_same_v<InLayout, NHWGC> &&
is_same_v<WeiLayout, GKYXC> && is_same_v<OutLayout, NHWGK>)
{
#ifdef CK_ENABLE_BF16
if constexpr(is_same_v<InDataType, ck::bhalf_t> &&
is_same_v<WeiDataType, ck::bhalf_t> &&
is_same_v<OutDataType, ck::bhalf_t> &&
is_same_v<AComputeType, ck::bhalf_t> &&
is_same_v<BComputeType, ck::bhalf_t>)
{
add_device_grouped_conv2d_fwd_clamp_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_bf16_instances_part1(
op_ptrs);
add_device_grouped_conv2d_fwd_clamp_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_bf16_instances_part2(
op_ptrs);
add_device_grouped_conv2d_fwd_clamp_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_bf16_instances_part3(
op_ptrs);
add_device_grouped_conv2d_fwd_clamp_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_bf16_instances_part4(
op_ptrs);
// add_device_grouped_conv2d_fwd_clamp_wmma_cshufflev3_large_tensor_nhwgc_gkyxc_nhwgk_bf16_instances(
// op_ptrs);
}
#endif
#ifdef CK_ENABLE_FP16
if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
is_same_v<OutDataType, half_t> && is_same_v<AComputeType, half_t> &&
is_same_v<BComputeType, half_t>)
{
add_device_grouped_conv2d_fwd_clamp_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_f16_instances_part1(
op_ptrs);
add_device_grouped_conv2d_fwd_clamp_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_f16_instances_part2(
op_ptrs);
add_device_grouped_conv2d_fwd_clamp_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_f16_instances_part3(
op_ptrs);
add_device_grouped_conv2d_fwd_clamp_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_f16_instances_part4(
op_ptrs);
// add_device_grouped_conv2d_fwd_clamp_wmma_cshufflev3_large_tensor_nhwgc_gkyxc_nhwgk_f16_instances(
// op_ptrs);
}
#endif
}
// layout NDHWGC/GKZYXC/NDHWGK
if constexpr(NumDimSpatial == 3 && is_same_v<InLayout, NDHWGC> &&
is_same_v<WeiLayout, GKZYXC> && is_same_v<OutLayout, NDHWGK>)
{
#ifdef CK_ENABLE_BF16
if constexpr(is_same_v<InDataType, ck::bhalf_t> &&
is_same_v<WeiDataType, ck::bhalf_t> &&
is_same_v<OutDataType, ck::bhalf_t> &&
is_same_v<AComputeType, ck::bhalf_t> &&
is_same_v<BComputeType, ck::bhalf_t>)
{
add_device_grouped_conv3d_fwd_clamp_wmma_cshufflev3_ndhwgc_gkzyxc_ndhwgk_bf16_instances_part1(
op_ptrs);
add_device_grouped_conv3d_fwd_clamp_wmma_cshufflev3_ndhwgc_gkzyxc_ndhwgk_bf16_instances_part2(
op_ptrs);
add_device_grouped_conv3d_fwd_clamp_wmma_cshufflev3_ndhwgc_gkzyxc_ndhwgk_bf16_instances_part3(
op_ptrs);
add_device_grouped_conv3d_fwd_clamp_wmma_cshufflev3_ndhwgc_gkzyxc_ndhwgk_bf16_instances_part4(
op_ptrs);
// add_device_grouped_conv3d_fwd_clamp_wmma_cshufflev3_large_tensor_ndhwgc_gkzyxc_ndhwgk_bf16_instances(
// op_ptrs);
}
#endif
#ifdef CK_ENABLE_FP16
if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
is_same_v<OutDataType, half_t> && is_same_v<AComputeType, half_t> &&
is_same_v<BComputeType, half_t>)
{
add_device_grouped_conv3d_fwd_clamp_wmma_cshufflev3_ndhwgc_gkzyxc_ndhwgk_f16_instances_part1(
op_ptrs);
add_device_grouped_conv3d_fwd_clamp_wmma_cshufflev3_ndhwgc_gkzyxc_ndhwgk_f16_instances_part2(
op_ptrs);
add_device_grouped_conv3d_fwd_clamp_wmma_cshufflev3_ndhwgc_gkzyxc_ndhwgk_f16_instances_part3(
op_ptrs);
add_device_grouped_conv3d_fwd_clamp_wmma_cshufflev3_ndhwgc_gkzyxc_ndhwgk_f16_instances_part4(
op_ptrs);
// add_device_grouped_conv3d_fwd_clamp_wmma_cshufflev3_large_tensor_ndhwgc_gkzyxc_ndhwgk_f16_instances(
// op_ptrs);
}
#endif
}
#endif // CK_USE_WMMA
return op_ptrs;
}
};

View File

@@ -0,0 +1,294 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
#ifdef CK_ENABLE_BF16
void add_device_grouped_conv2d_fwd_clamp_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_bf16_instances_part1(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
NHWGC,
GKYXC,
Tuple<>,
NHWGK,
BF16,
BF16,
Tuple<>,
BF16,
PassThrough,
PassThrough,
Clamp>>>& instances);
void add_device_grouped_conv2d_fwd_clamp_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_bf16_instances_part2(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
NHWGC,
GKYXC,
Tuple<>,
NHWGK,
BF16,
BF16,
Tuple<>,
BF16,
PassThrough,
PassThrough,
Clamp>>>& instances);
void add_device_grouped_conv2d_fwd_clamp_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_bf16_instances_part3(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
NHWGC,
GKYXC,
Tuple<>,
NHWGK,
BF16,
BF16,
Tuple<>,
BF16,
PassThrough,
PassThrough,
Clamp>>>& instances);
void add_device_grouped_conv2d_fwd_clamp_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_bf16_instances_part4(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
NHWGC,
GKYXC,
Tuple<>,
NHWGK,
BF16,
BF16,
Tuple<>,
BF16,
PassThrough,
PassThrough,
Clamp>>>& instances);
// void
// add_device_grouped_conv2d_fwd_clamp_wmma_cshufflev3_large_tensor_nhwgc_gkyxc_nhwgk_bf16_instances(
// std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
// NHWGC,
// GKYXC,
// Tuple<>,
// NHWGK,
// BF16,
// BF16,
// Tuple<>,
// BF16,
// PassThrough,
// PassThrough,
// Clamp>>>& instances);
void add_device_grouped_conv3d_fwd_clamp_wmma_cshufflev3_ndhwgc_gkzyxc_ndhwgk_bf16_instances_part1(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
NDHWGC,
GKZYXC,
Tuple<>,
NDHWGK,
BF16,
BF16,
Tuple<>,
BF16,
PassThrough,
PassThrough,
Clamp>>>& instances);
void add_device_grouped_conv3d_fwd_clamp_wmma_cshufflev3_ndhwgc_gkzyxc_ndhwgk_bf16_instances_part2(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
NDHWGC,
GKZYXC,
Tuple<>,
NDHWGK,
BF16,
BF16,
Tuple<>,
BF16,
PassThrough,
PassThrough,
Clamp>>>& instances);
void add_device_grouped_conv3d_fwd_clamp_wmma_cshufflev3_ndhwgc_gkzyxc_ndhwgk_bf16_instances_part3(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
NDHWGC,
GKZYXC,
Tuple<>,
NDHWGK,
BF16,
BF16,
Tuple<>,
BF16,
PassThrough,
PassThrough,
Clamp>>>& instances);
void add_device_grouped_conv3d_fwd_clamp_wmma_cshufflev3_ndhwgc_gkzyxc_ndhwgk_bf16_instances_part4(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
NDHWGC,
GKZYXC,
Tuple<>,
NDHWGK,
BF16,
BF16,
Tuple<>,
BF16,
PassThrough,
PassThrough,
Clamp>>>& instances);
// void
// add_device_grouped_conv3d_fwd_clamp_wmma_cshufflev3_large_tensor_ndhwgc_gkzyxc_ndhwgk_bf16_instances(
// std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
// NDHWGC,
// GKZYXC,
// Tuple<>,
// NDHWGK,
// BF16,
// BF16,
// Tuple<>,
// BF16,
// PassThrough,
// PassThrough,
// Clamp>>>& instances);
#endif
#ifdef CK_ENABLE_FP16
void add_device_grouped_conv2d_fwd_clamp_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_f16_instances_part1(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
NHWGC,
GKYXC,
Tuple<>,
NHWGK,
F16,
F16,
Tuple<>,
F16,
PassThrough,
PassThrough,
Clamp>>>& instances);
void add_device_grouped_conv2d_fwd_clamp_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_f16_instances_part2(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
NHWGC,
GKYXC,
Tuple<>,
NHWGK,
F16,
F16,
Tuple<>,
F16,
PassThrough,
PassThrough,
Clamp>>>& instances);
void add_device_grouped_conv2d_fwd_clamp_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_f16_instances_part3(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
NHWGC,
GKYXC,
Tuple<>,
NHWGK,
F16,
F16,
Tuple<>,
F16,
PassThrough,
PassThrough,
Clamp>>>& instances);
void add_device_grouped_conv2d_fwd_clamp_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_f16_instances_part4(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
NHWGC,
GKYXC,
Tuple<>,
NHWGK,
F16,
F16,
Tuple<>,
F16,
PassThrough,
PassThrough,
Clamp>>>& instances);
void add_device_grouped_conv3d_fwd_clamp_wmma_cshufflev3_ndhwgc_gkzyxc_ndhwgk_f16_instances_part1(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
NDHWGC,
GKZYXC,
Tuple<>,
NDHWGK,
F16,
F16,
Tuple<>,
F16,
PassThrough,
PassThrough,
Clamp>>>& instances);
void add_device_grouped_conv3d_fwd_clamp_wmma_cshufflev3_ndhwgc_gkzyxc_ndhwgk_f16_instances_part2(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
NDHWGC,
GKZYXC,
Tuple<>,
NDHWGK,
F16,
F16,
Tuple<>,
F16,
PassThrough,
PassThrough,
Clamp>>>& instances);
void add_device_grouped_conv3d_fwd_clamp_wmma_cshufflev3_ndhwgc_gkzyxc_ndhwgk_f16_instances_part3(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
NDHWGC,
GKZYXC,
Tuple<>,
NDHWGK,
F16,
F16,
Tuple<>,
F16,
PassThrough,
PassThrough,
Clamp>>>& instances);
void add_device_grouped_conv3d_fwd_clamp_wmma_cshufflev3_ndhwgc_gkzyxc_ndhwgk_f16_instances_part4(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
NDHWGC,
GKZYXC,
Tuple<>,
NDHWGK,
F16,
F16,
Tuple<>,
F16,
PassThrough,
PassThrough,
Clamp>>>& instances);
// void
// add_device_grouped_conv2d_fwd_clamp_wmma_cshufflev3_large_tensor_nhwgc_gkyxc_nhwgk_f16_instances(
// std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
// NHWGC,
// GKYXC,
// Tuple<>,
// NHWGK,
// F16,
// F16,
// Tuple<>,
// F16,
// PassThrough,
// PassThrough,
// Clamp>>>& instances);
// void
// add_device_grouped_conv3d_fwd_clamp_wmma_cshufflev3_large_tensor_ndhwgc_gkzyxc_ndhwgk_f16_instances(
// std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
// NDHWGC,
// GKZYXC,
// Tuple<>,
// NDHWGK,
// F16,
// F16,
// Tuple<>,
// F16,
// PassThrough,
// PassThrough,
// Clamp>>>& instances);
#endif
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -0,0 +1,30 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
#ifdef CK_ENABLE_INT8
void add_device_grouped_conv2d_fwd_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_int8_mem_inter_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
NHWGC,
GKYXC,
Empty_Tuple,
NHWGK,
int8_t,
int8_t,
Empty_Tuple,
int8_t,
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -0,0 +1,30 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
#ifdef CK_ENABLE_INT8
void add_device_grouped_conv2d_fwd_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_int8_mem_intra_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
NHWGC,
GKYXC,
Empty_Tuple,
NHWGK,
int8_t,
int8_t,
Empty_Tuple,
int8_t,
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -21,6 +21,7 @@ namespace instance {
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using ScaleAdd = ck::tensor_operation::element_wise::ScaleAdd;
#ifdef CK_USE_XDL
#ifdef CK_ENABLE_BF16
// grouped conv3d forward multi AB scaleadd, NDHWGC/GKZYXC/NDHWGK
void add_device_grouped_conv3d_fwd_xdl_scaleadd_ab_ndhwgc_gkzyxc_ndhwgk_bf16_instances(
@@ -85,6 +86,42 @@ void add_device_grouped_conv3d_fwd_xdl_scaleadd_ab_ndhwgc_gkzyxc_ndhwgk_int8_ins
ScaleAdd,
PassThrough>>>& instances);
#endif
#endif // CK_USE_XDL
#ifdef CK_USE_WMMA
#ifdef CK_ENABLE_BF16
// grouped conv3d forward multi AB scaleadd, NDHWGC/GKZYXC/NDHWGK
void add_device_grouped_conv3d_fwd_wmma_cshufflev3_scaleadd_ab_ndhwgc_gkzyxc_ndhwgk_bf16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
NDHWGC,
GKZYXC,
ck::Tuple<>,
NDHWGK,
ck::Tuple<BF16, BF16>,
ck::Tuple<BF16, BF16>,
ck::Tuple<>,
BF16,
ScaleAdd,
ScaleAdd,
PassThrough>>>& instances);
#endif
#ifdef CK_ENABLE_FP16
void add_device_grouped_conv3d_fwd_wmma_cshufflev3_scaleadd_ab_ndhwgc_gkzyxc_ndhwgk_f16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
NDHWGC,
GKZYXC,
ck::Tuple<>,
NDHWGK,
ck::Tuple<F16, F16>,
ck::Tuple<F16, F16>,
ck::Tuple<>,
F16,
ScaleAdd,
ScaleAdd,
PassThrough>>>& instances);
#endif
#endif // CK_USE_WMMA
template <ck::index_t NumDimSpatial,
typename InLayout,
@@ -129,6 +166,7 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
static auto GetInstances()
{
std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
#ifdef CK_USE_XDL
if constexpr(NumDimSpatial == 3 && is_same_v<InLayout, NDHWGC> &&
is_same_v<WeiLayout, GKZYXC> && is_same_v<OutLayout, NDHWGK>)
{
@@ -169,6 +207,32 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
}
#endif
}
#endif // CK_USE_XDL
#ifdef CK_USE_WMMA
if constexpr(NumDimSpatial == 3 && is_same_v<InLayout, NDHWGC> &&
is_same_v<WeiLayout, GKZYXC> && is_same_v<OutLayout, NDHWGK>)
{
#ifdef CK_ENABLE_FP16
if constexpr(is_same_v<InDataType, ck::Tuple<half_t, half_t>> &&
is_same_v<WeiDataType, ck::Tuple<half_t, half_t>> &&
is_same_v<OutDataType, half_t> && is_same_v<ComputeType, half_t>)
{
add_device_grouped_conv3d_fwd_wmma_cshufflev3_scaleadd_ab_ndhwgc_gkzyxc_ndhwgk_f16_instances(
op_ptrs);
}
#endif
#ifdef CK_ENABLE_BF16
if constexpr(is_same_v<InDataType, ck::Tuple<ck::bhalf_t, ck::bhalf_t>> &&
is_same_v<WeiDataType, ck::Tuple<ck::bhalf_t, ck::bhalf_t>> &&
is_same_v<OutDataType, ck::bhalf_t> && is_same_v<ComputeType, ck::bhalf_t>)
{
add_device_grouped_conv3d_fwd_wmma_cshufflev3_scaleadd_ab_ndhwgc_gkzyxc_ndhwgk_bf16_instances(
op_ptrs);
}
#endif
}
#endif // CK_USE_WMMA
return op_ptrs;
}

View File

@@ -0,0 +1,236 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
// grouped conv2d forward, NHWGC/GKYXC/NHWGK
#ifdef CK_ENABLE_BF16
void add_device_grouped_conv2d_fwd_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_bf16_instances_part1(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
NHWGC,
GKYXC,
Empty_Tuple,
NHWGK,
BF16,
BF16,
Empty_Tuple,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_conv2d_fwd_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_bf16_instances_part2(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
NHWGC,
GKYXC,
Empty_Tuple,
NHWGK,
BF16,
BF16,
Empty_Tuple,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_conv2d_fwd_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_bf16_instances_part3(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
NHWGC,
GKYXC,
Empty_Tuple,
NHWGK,
BF16,
BF16,
Empty_Tuple,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_conv2d_fwd_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_bf16_instances_part4(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
NHWGC,
GKYXC,
Empty_Tuple,
NHWGK,
BF16,
BF16,
Empty_Tuple,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
#ifdef CK_ENABLE_FP16
void add_device_grouped_conv2d_fwd_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_f16_instances_part1(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
NHWGC,
GKYXC,
Empty_Tuple,
NHWGK,
F16,
F16,
Empty_Tuple,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_conv2d_fwd_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_f16_instances_part2(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
NHWGC,
GKYXC,
Empty_Tuple,
NHWGK,
F16,
F16,
Empty_Tuple,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_conv2d_fwd_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_f16_instances_part3(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
NHWGC,
GKYXC,
Empty_Tuple,
NHWGK,
F16,
F16,
Empty_Tuple,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_conv2d_fwd_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_f16_instances_part4(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
NHWGC,
GKYXC,
Empty_Tuple,
NHWGK,
F16,
F16,
Empty_Tuple,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
// grouped conv3d forward, NDHWGC/GKZYXC/NDHWGK
#ifdef CK_ENABLE_BF16
void add_device_grouped_conv3d_fwd_wmma_cshufflev3_ndhwgc_gkzyxc_ndhwgk_bf16_instances_part1(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
NDHWGC,
GKZYXC,
Empty_Tuple,
NDHWGK,
BF16,
BF16,
Empty_Tuple,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_conv3d_fwd_wmma_cshufflev3_ndhwgc_gkzyxc_ndhwgk_bf16_instances_part2(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
NDHWGC,
GKZYXC,
Empty_Tuple,
NDHWGK,
BF16,
BF16,
Empty_Tuple,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_conv3d_fwd_wmma_cshufflev3_ndhwgc_gkzyxc_ndhwgk_bf16_instances_part3(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
NDHWGC,
GKZYXC,
Empty_Tuple,
NDHWGK,
BF16,
BF16,
Empty_Tuple,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_conv3d_fwd_wmma_cshufflev3_ndhwgc_gkzyxc_ndhwgk_bf16_instances_part4(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
NDHWGC,
GKZYXC,
Empty_Tuple,
NDHWGK,
BF16,
BF16,
Empty_Tuple,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
#ifdef CK_ENABLE_FP16
void add_device_grouped_conv3d_fwd_wmma_cshufflev3_ndhwgc_gkzyxc_ndhwgk_f16_instances_part1(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
NDHWGC,
GKZYXC,
Empty_Tuple,
NDHWGK,
F16,
F16,
Empty_Tuple,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_conv3d_fwd_wmma_cshufflev3_ndhwgc_gkzyxc_ndhwgk_f16_instances_part2(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
NDHWGC,
GKZYXC,
Empty_Tuple,
NDHWGK,
F16,
F16,
Empty_Tuple,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_conv3d_fwd_wmma_cshufflev3_ndhwgc_gkzyxc_ndhwgk_f16_instances_part3(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
NDHWGC,
GKZYXC,
Empty_Tuple,
NDHWGK,
F16,
F16,
Empty_Tuple,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_conv3d_fwd_wmma_cshufflev3_ndhwgc_gkzyxc_ndhwgk_f16_instances_part4(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
NDHWGC,
GKZYXC,
Empty_Tuple,
NDHWGK,
F16,
F16,
Empty_Tuple,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -98,6 +98,11 @@ function(add_instance_library INSTANCE_NAME)
message(DEBUG "removing gemm_wmma_universal_f8 instance ${source} ")
list(REMOVE_ITEM ARGN "${source}")
endif()
# Do not build WMMA grouped conv 3d fwd fp8 / bf8 for any targets except gfx12+
if(NOT INST_TARGETS MATCHES "gfx12" AND source_name MATCHES "grouped_conv3d_fwd_wmma" AND (source_name MATCHES "_fp8_" OR source_name MATCHES "_bf8_"))
message(DEBUG "removing grouped_conv3d_fwd_wmma fp8/bf8 instance ${source} ")
list(REMOVE_ITEM ARGN "${source}")
endif()
# Do not build gemm_universal_preshuffle_f8 for any targets except gfx94, gfx95 and gfx12
if(NOT (INST_TARGETS MATCHES "gfx94|gfx95|gfx12") AND (source_name MATCHES "gemm_universal_preshuffle") AND source_name MATCHES "_f8_" )
message(DEBUG "removing gemm_universal_preshuffle_f8 instance ${source} ")

View File

@@ -89,10 +89,11 @@ set(GROUPED_CONV2D_FWD
# GNHWC, GKYXC, GNHWK
dl/device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_f16_instance.cpp
dl/device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_f32_instance.cpp
# NHWGC, GKYXC, NHWGK
# # NHWGC, GKYXC, NHWGK
dl/device_grouped_conv2d_fwd_dl_nhwgc_gkyxc_nhwgk_f16_instance.cpp
dl/device_grouped_conv2d_fwd_dl_nhwgc_gkyxc_nhwgk_f32_instance.cpp
# WMMA
# WMMA_OLD
# GNHWC, GKYXC, GNHWK
wmma/device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_f16_instance.cpp
wmma/device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_i8_instance.cpp
@@ -111,6 +112,17 @@ set(GROUPED_CONV2D_FWD
wmma/device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_i8_1x1s1p0_instance.cpp
wmma/device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_f16_oddc_instance.cpp
wmma/device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_i8_oddc_instance.cpp
# WMMA CSHUFFLEV3
## NHWGC, GKYXC, NHWGK
wmma/device_grouped_conv2d_fwd_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_bf16_instance_part1.cpp
wmma/device_grouped_conv2d_fwd_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_bf16_instance_part2.cpp
wmma/device_grouped_conv2d_fwd_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_bf16_instance_part3.cpp
wmma/device_grouped_conv2d_fwd_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_bf16_instance_part4.cpp
wmma/device_grouped_conv2d_fwd_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_f16_instance_part1.cpp
wmma/device_grouped_conv2d_fwd_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_f16_instance_part2.cpp
wmma/device_grouped_conv2d_fwd_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_f16_instance_part3.cpp
wmma/device_grouped_conv2d_fwd_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_f16_instance_part4.cpp
)
# Add generated files for sharded instantiations.
include(ShardInstantiation)

View File

@@ -0,0 +1,65 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_cshufflev3_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_grouped_conv2d_fwd_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_bf16_instances_part1(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
NHWGC,
GKYXC,
Empty_Tuple,
NHWGK,
BF16,
BF16,
Empty_Tuple,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances)
{
add_device_operation_instances(
instances,
device_grouped_conv_fwd_wmma_cshufflev3_bf16_instances_part1<2,
NHWGC,
GKYXC,
Empty_Tuple,
NHWGK,
ConvFwdDefault>{});
add_device_operation_instances(
instances,
device_grouped_conv_fwd_wmma_cshufflev3_bf16_instances_part1<2,
NHWGC,
GKYXC,
Empty_Tuple,
NHWGK,
ConvFwd1x1P0>{});
add_device_operation_instances(
instances,
device_grouped_conv_fwd_wmma_cshufflev3_bf16_instances_part1<2,
NHWGC,
GKYXC,
Empty_Tuple,
NHWGK,
ConvFwd1x1S1P0>{});
add_device_operation_instances(
instances,
device_grouped_conv_fwd_wmma_cshufflev3_bf16_instances_part1<2,
NHWGC,
GKYXC,
Empty_Tuple,
NHWGK,
ConvFwdOddC>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -0,0 +1,65 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_cshufflev3_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_grouped_conv2d_fwd_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_bf16_instances_part2(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
NHWGC,
GKYXC,
Empty_Tuple,
NHWGK,
BF16,
BF16,
Empty_Tuple,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances)
{
add_device_operation_instances(
instances,
device_grouped_conv_fwd_wmma_cshufflev3_bf16_instances_part2<2,
NHWGC,
GKYXC,
Empty_Tuple,
NHWGK,
ConvFwdDefault>{});
add_device_operation_instances(
instances,
device_grouped_conv_fwd_wmma_cshufflev3_bf16_instances_part2<2,
NHWGC,
GKYXC,
Empty_Tuple,
NHWGK,
ConvFwd1x1P0>{});
add_device_operation_instances(
instances,
device_grouped_conv_fwd_wmma_cshufflev3_bf16_instances_part2<2,
NHWGC,
GKYXC,
Empty_Tuple,
NHWGK,
ConvFwd1x1S1P0>{});
add_device_operation_instances(
instances,
device_grouped_conv_fwd_wmma_cshufflev3_bf16_instances_part2<2,
NHWGC,
GKYXC,
Empty_Tuple,
NHWGK,
ConvFwdOddC>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -0,0 +1,65 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_cshufflev3_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_grouped_conv2d_fwd_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_bf16_instances_part3(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
NHWGC,
GKYXC,
Empty_Tuple,
NHWGK,
BF16,
BF16,
Empty_Tuple,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances)
{
add_device_operation_instances(
instances,
device_grouped_conv_fwd_wmma_cshufflev3_bf16_instances_part3<2,
NHWGC,
GKYXC,
Empty_Tuple,
NHWGK,
ConvFwdDefault>{});
add_device_operation_instances(
instances,
device_grouped_conv_fwd_wmma_cshufflev3_bf16_instances_part3<2,
NHWGC,
GKYXC,
Empty_Tuple,
NHWGK,
ConvFwd1x1P0>{});
add_device_operation_instances(
instances,
device_grouped_conv_fwd_wmma_cshufflev3_bf16_instances_part3<2,
NHWGC,
GKYXC,
Empty_Tuple,
NHWGK,
ConvFwd1x1S1P0>{});
add_device_operation_instances(
instances,
device_grouped_conv_fwd_wmma_cshufflev3_bf16_instances_part3<2,
NHWGC,
GKYXC,
Empty_Tuple,
NHWGK,
ConvFwdOddC>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -0,0 +1,65 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_cshufflev3_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_grouped_conv2d_fwd_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_bf16_instances_part4(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
NHWGC,
GKYXC,
Empty_Tuple,
NHWGK,
BF16,
BF16,
Empty_Tuple,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances)
{
add_device_operation_instances(
instances,
device_grouped_conv_fwd_wmma_cshufflev3_bf16_instances_part4<2,
NHWGC,
GKYXC,
Empty_Tuple,
NHWGK,
ConvFwdDefault>{});
add_device_operation_instances(
instances,
device_grouped_conv_fwd_wmma_cshufflev3_bf16_instances_part4<2,
NHWGC,
GKYXC,
Empty_Tuple,
NHWGK,
ConvFwd1x1P0>{});
add_device_operation_instances(
instances,
device_grouped_conv_fwd_wmma_cshufflev3_bf16_instances_part4<2,
NHWGC,
GKYXC,
Empty_Tuple,
NHWGK,
ConvFwd1x1S1P0>{});
add_device_operation_instances(
instances,
device_grouped_conv_fwd_wmma_cshufflev3_bf16_instances_part4<2,
NHWGC,
GKYXC,
Empty_Tuple,
NHWGK,
ConvFwdOddC>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -0,0 +1,65 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_cshufflev3_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_grouped_conv2d_fwd_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_f16_instances_part1(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
NHWGC,
GKYXC,
Empty_Tuple,
NHWGK,
F16,
F16,
Empty_Tuple,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances)
{
add_device_operation_instances(
instances,
device_grouped_conv_fwd_wmma_cshufflev3_f16_instances_part1<2,
NHWGC,
GKYXC,
Empty_Tuple,
NHWGK,
ConvFwdDefault>{});
add_device_operation_instances(
instances,
device_grouped_conv_fwd_wmma_cshufflev3_f16_instances_part1<2,
NHWGC,
GKYXC,
Empty_Tuple,
NHWGK,
ConvFwd1x1P0>{});
add_device_operation_instances(
instances,
device_grouped_conv_fwd_wmma_cshufflev3_f16_instances_part1<2,
NHWGC,
GKYXC,
Empty_Tuple,
NHWGK,
ConvFwd1x1S1P0>{});
add_device_operation_instances(
instances,
device_grouped_conv_fwd_wmma_cshufflev3_f16_instances_part1<2,
NHWGC,
GKYXC,
Empty_Tuple,
NHWGK,
ConvFwdOddC>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -0,0 +1,65 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_cshufflev3_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_grouped_conv2d_fwd_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_f16_instances_part2(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
NHWGC,
GKYXC,
Empty_Tuple,
NHWGK,
F16,
F16,
Empty_Tuple,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances)
{
add_device_operation_instances(
instances,
device_grouped_conv_fwd_wmma_cshufflev3_f16_instances_part2<2,
NHWGC,
GKYXC,
Empty_Tuple,
NHWGK,
ConvFwdDefault>{});
add_device_operation_instances(
instances,
device_grouped_conv_fwd_wmma_cshufflev3_f16_instances_part2<2,
NHWGC,
GKYXC,
Empty_Tuple,
NHWGK,
ConvFwd1x1P0>{});
add_device_operation_instances(
instances,
device_grouped_conv_fwd_wmma_cshufflev3_f16_instances_part2<2,
NHWGC,
GKYXC,
Empty_Tuple,
NHWGK,
ConvFwd1x1S1P0>{});
add_device_operation_instances(
instances,
device_grouped_conv_fwd_wmma_cshufflev3_f16_instances_part2<2,
NHWGC,
GKYXC,
Empty_Tuple,
NHWGK,
ConvFwdOddC>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -0,0 +1,65 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_cshufflev3_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_grouped_conv2d_fwd_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_f16_instances_part3(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
NHWGC,
GKYXC,
Empty_Tuple,
NHWGK,
F16,
F16,
Empty_Tuple,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances)
{
add_device_operation_instances(
instances,
device_grouped_conv_fwd_wmma_cshufflev3_f16_instances_part3<2,
NHWGC,
GKYXC,
Empty_Tuple,
NHWGK,
ConvFwdDefault>{});
add_device_operation_instances(
instances,
device_grouped_conv_fwd_wmma_cshufflev3_f16_instances_part3<2,
NHWGC,
GKYXC,
Empty_Tuple,
NHWGK,
ConvFwd1x1P0>{});
add_device_operation_instances(
instances,
device_grouped_conv_fwd_wmma_cshufflev3_f16_instances_part3<2,
NHWGC,
GKYXC,
Empty_Tuple,
NHWGK,
ConvFwd1x1S1P0>{});
add_device_operation_instances(
instances,
device_grouped_conv_fwd_wmma_cshufflev3_f16_instances_part3<2,
NHWGC,
GKYXC,
Empty_Tuple,
NHWGK,
ConvFwdOddC>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -0,0 +1,65 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_cshufflev3_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_grouped_conv2d_fwd_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_f16_instances_part4(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
NHWGC,
GKYXC,
Empty_Tuple,
NHWGK,
F16,
F16,
Empty_Tuple,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances)
{
add_device_operation_instances(
instances,
device_grouped_conv_fwd_wmma_cshufflev3_f16_instances_part4<2,
NHWGC,
GKYXC,
Empty_Tuple,
NHWGK,
ConvFwdDefault>{});
add_device_operation_instances(
instances,
device_grouped_conv_fwd_wmma_cshufflev3_f16_instances_part4<2,
NHWGC,
GKYXC,
Empty_Tuple,
NHWGK,
ConvFwd1x1P0>{});
add_device_operation_instances(
instances,
device_grouped_conv_fwd_wmma_cshufflev3_f16_instances_part4<2,
NHWGC,
GKYXC,
Empty_Tuple,
NHWGK,
ConvFwd1x1S1P0>{});
add_device_operation_instances(
instances,
device_grouped_conv_fwd_wmma_cshufflev3_f16_instances_part4<2,
NHWGC,
GKYXC,
Empty_Tuple,
NHWGK,
ConvFwdOddC>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -1,7 +1,7 @@
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
# ONLY XDL_KERNELS
# ONLY XDL_AND_WMMA_KERNELS
add_instance_library(device_grouped_conv2d_fwd_bias_clamp_instance
xdl/device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_bf16_instance.cpp
xdl/device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_bf16_16x16_instance.cpp
@@ -39,4 +39,14 @@ add_instance_library(device_grouped_conv2d_fwd_bias_clamp_instance
xdl/mem/device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_fp32_tf32_mem_inter_instance.cpp
xdl/comp/device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_fp32_comp_instance.cpp
xdl/comp/device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_fp32_tf32_comp_instance.cpp
# WMMA CSHUFFLE V3
wmma/device_grouped_conv2d_fwd_bias_clamp_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_bf16_instance_part1.cpp
wmma/device_grouped_conv2d_fwd_bias_clamp_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_bf16_instance_part2.cpp
wmma/device_grouped_conv2d_fwd_bias_clamp_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_bf16_instance_part3.cpp
wmma/device_grouped_conv2d_fwd_bias_clamp_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_bf16_instance_part4.cpp
wmma/device_grouped_conv2d_fwd_bias_clamp_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_f16_instance_part1.cpp
wmma/device_grouped_conv2d_fwd_bias_clamp_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_f16_instance_part2.cpp
wmma/device_grouped_conv2d_fwd_bias_clamp_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_f16_instance_part3.cpp
wmma/device_grouped_conv2d_fwd_bias_clamp_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_f16_instance_part4.cpp
)

View File

@@ -0,0 +1,62 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_cshufflev3_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_grouped_conv2d_fwd_bias_clamp_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_bf16_instances_part1(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
NHWGC,
GKYXC,
Tuple<NHWGK>,
NHWGK,
BF16,
BF16,
Tuple<BF16>,
BF16,
PassThrough,
PassThrough,
AddClamp>>>& instances)
{
add_device_operation_instances(
instances,
device_grouped_conv_fwd_wmma_cshufflev3_bf16_instances_part1<2,
NHWGC,
GKYXC,
Tuple<NHWGK>,
NHWGK,
ConvFwdDefault,
Tuple<BF16>,
AddClamp>{});
add_device_operation_instances(
instances,
device_grouped_conv_fwd_wmma_cshufflev3_bf16_instances_part1<2,
NHWGC,
GKYXC,
Tuple<NHWGK>,
NHWGK,
ConvFwd1x1P0,
Tuple<BF16>,
AddClamp>{});
add_device_operation_instances(
instances,
device_grouped_conv_fwd_wmma_cshufflev3_bf16_instances_part1<2,
NHWGC,
GKYXC,
Tuple<NHWGK>,
NHWGK,
ConvFwd1x1S1P0,
Tuple<BF16>,
AddClamp>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -0,0 +1,62 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_cshufflev3_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_grouped_conv2d_fwd_bias_clamp_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_bf16_instances_part2(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
NHWGC,
GKYXC,
Tuple<NHWGK>,
NHWGK,
BF16,
BF16,
Tuple<BF16>,
BF16,
PassThrough,
PassThrough,
AddClamp>>>& instances)
{
add_device_operation_instances(
instances,
device_grouped_conv_fwd_wmma_cshufflev3_bf16_instances_part2<2,
NHWGC,
GKYXC,
Tuple<NHWGK>,
NHWGK,
ConvFwdDefault,
Tuple<BF16>,
AddClamp>{});
add_device_operation_instances(
instances,
device_grouped_conv_fwd_wmma_cshufflev3_bf16_instances_part2<2,
NHWGC,
GKYXC,
Tuple<NHWGK>,
NHWGK,
ConvFwd1x1P0,
Tuple<BF16>,
AddClamp>{});
add_device_operation_instances(
instances,
device_grouped_conv_fwd_wmma_cshufflev3_bf16_instances_part2<2,
NHWGC,
GKYXC,
Tuple<NHWGK>,
NHWGK,
ConvFwd1x1S1P0,
Tuple<BF16>,
AddClamp>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -0,0 +1,62 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_cshufflev3_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_grouped_conv2d_fwd_bias_clamp_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_bf16_instances_part3(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
NHWGC,
GKYXC,
Tuple<NHWGK>,
NHWGK,
BF16,
BF16,
Tuple<BF16>,
BF16,
PassThrough,
PassThrough,
AddClamp>>>& instances)
{
add_device_operation_instances(
instances,
device_grouped_conv_fwd_wmma_cshufflev3_bf16_instances_part3<2,
NHWGC,
GKYXC,
Tuple<NHWGK>,
NHWGK,
ConvFwdDefault,
Tuple<BF16>,
AddClamp>{});
add_device_operation_instances(
instances,
device_grouped_conv_fwd_wmma_cshufflev3_bf16_instances_part3<2,
NHWGC,
GKYXC,
Tuple<NHWGK>,
NHWGK,
ConvFwd1x1P0,
Tuple<BF16>,
AddClamp>{});
add_device_operation_instances(
instances,
device_grouped_conv_fwd_wmma_cshufflev3_bf16_instances_part3<2,
NHWGC,
GKYXC,
Tuple<NHWGK>,
NHWGK,
ConvFwd1x1S1P0,
Tuple<BF16>,
AddClamp>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -0,0 +1,62 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_cshufflev3_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_grouped_conv2d_fwd_bias_clamp_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_bf16_instances_part4(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
NHWGC,
GKYXC,
Tuple<NHWGK>,
NHWGK,
BF16,
BF16,
Tuple<BF16>,
BF16,
PassThrough,
PassThrough,
AddClamp>>>& instances)
{
add_device_operation_instances(
instances,
device_grouped_conv_fwd_wmma_cshufflev3_bf16_instances_part4<2,
NHWGC,
GKYXC,
Tuple<NHWGK>,
NHWGK,
ConvFwdDefault,
Tuple<BF16>,
AddClamp>{});
add_device_operation_instances(
instances,
device_grouped_conv_fwd_wmma_cshufflev3_bf16_instances_part4<2,
NHWGC,
GKYXC,
Tuple<NHWGK>,
NHWGK,
ConvFwd1x1P0,
Tuple<BF16>,
AddClamp>{});
add_device_operation_instances(
instances,
device_grouped_conv_fwd_wmma_cshufflev3_bf16_instances_part4<2,
NHWGC,
GKYXC,
Tuple<NHWGK>,
NHWGK,
ConvFwd1x1S1P0,
Tuple<BF16>,
AddClamp>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -0,0 +1,62 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_cshufflev3_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_grouped_conv2d_fwd_bias_clamp_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_f16_instances_part1(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
NHWGC,
GKYXC,
Tuple<NHWGK>,
NHWGK,
F16,
F16,
Tuple<F16>,
F16,
PassThrough,
PassThrough,
AddClamp>>>& instances)
{
add_device_operation_instances(
instances,
device_grouped_conv_fwd_wmma_cshufflev3_f16_instances_part1<2,
NHWGC,
GKYXC,
Tuple<NHWGK>,
NHWGK,
ConvFwdDefault,
Tuple<F16>,
AddClamp>{});
add_device_operation_instances(
instances,
device_grouped_conv_fwd_wmma_cshufflev3_f16_instances_part1<2,
NHWGC,
GKYXC,
Tuple<NHWGK>,
NHWGK,
ConvFwd1x1P0,
Tuple<F16>,
AddClamp>{});
add_device_operation_instances(
instances,
device_grouped_conv_fwd_wmma_cshufflev3_f16_instances_part1<2,
NHWGC,
GKYXC,
Tuple<NHWGK>,
NHWGK,
ConvFwd1x1S1P0,
Tuple<F16>,
AddClamp>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -0,0 +1,62 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_cshufflev3_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_grouped_conv2d_fwd_bias_clamp_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_f16_instances_part2(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
NHWGC,
GKYXC,
Tuple<NHWGK>,
NHWGK,
F16,
F16,
Tuple<F16>,
F16,
PassThrough,
PassThrough,
AddClamp>>>& instances)
{
add_device_operation_instances(
instances,
device_grouped_conv_fwd_wmma_cshufflev3_f16_instances_part2<2,
NHWGC,
GKYXC,
Tuple<NHWGK>,
NHWGK,
ConvFwdDefault,
Tuple<F16>,
AddClamp>{});
add_device_operation_instances(
instances,
device_grouped_conv_fwd_wmma_cshufflev3_f16_instances_part2<2,
NHWGC,
GKYXC,
Tuple<NHWGK>,
NHWGK,
ConvFwd1x1P0,
Tuple<F16>,
AddClamp>{});
add_device_operation_instances(
instances,
device_grouped_conv_fwd_wmma_cshufflev3_f16_instances_part2<2,
NHWGC,
GKYXC,
Tuple<NHWGK>,
NHWGK,
ConvFwd1x1S1P0,
Tuple<F16>,
AddClamp>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -0,0 +1,62 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_cshufflev3_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_grouped_conv2d_fwd_bias_clamp_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_f16_instances_part3(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
NHWGC,
GKYXC,
Tuple<NHWGK>,
NHWGK,
F16,
F16,
Tuple<F16>,
F16,
PassThrough,
PassThrough,
AddClamp>>>& instances)
{
add_device_operation_instances(
instances,
device_grouped_conv_fwd_wmma_cshufflev3_f16_instances_part3<2,
NHWGC,
GKYXC,
Tuple<NHWGK>,
NHWGK,
ConvFwdDefault,
Tuple<F16>,
AddClamp>{});
add_device_operation_instances(
instances,
device_grouped_conv_fwd_wmma_cshufflev3_f16_instances_part3<2,
NHWGC,
GKYXC,
Tuple<NHWGK>,
NHWGK,
ConvFwd1x1P0,
Tuple<F16>,
AddClamp>{});
add_device_operation_instances(
instances,
device_grouped_conv_fwd_wmma_cshufflev3_f16_instances_part3<2,
NHWGC,
GKYXC,
Tuple<NHWGK>,
NHWGK,
ConvFwd1x1S1P0,
Tuple<F16>,
AddClamp>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -0,0 +1,62 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_cshufflev3_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_grouped_conv2d_fwd_bias_clamp_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_f16_instances_part4(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
NHWGC,
GKYXC,
Tuple<NHWGK>,
NHWGK,
F16,
F16,
Tuple<F16>,
F16,
PassThrough,
PassThrough,
AddClamp>>>& instances)
{
add_device_operation_instances(
instances,
device_grouped_conv_fwd_wmma_cshufflev3_f16_instances_part4<2,
NHWGC,
GKYXC,
Tuple<NHWGK>,
NHWGK,
ConvFwdDefault,
Tuple<F16>,
AddClamp>{});
add_device_operation_instances(
instances,
device_grouped_conv_fwd_wmma_cshufflev3_f16_instances_part4<2,
NHWGC,
GKYXC,
Tuple<NHWGK>,
NHWGK,
ConvFwd1x1P0,
Tuple<F16>,
AddClamp>{});
add_device_operation_instances(
instances,
device_grouped_conv_fwd_wmma_cshufflev3_f16_instances_part4<2,
NHWGC,
GKYXC,
Tuple<NHWGK>,
NHWGK,
ConvFwd1x1S1P0,
Tuple<F16>,
AddClamp>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -1,7 +1,7 @@
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
# ONLY XDL_KERNELS
# ONLY XDL_AND_WMMA_KERNELS
add_instance_library(device_grouped_conv2d_fwd_clamp_instance
xdl/device_grouped_conv2d_fwd_clamp_xdl_nhwgc_gkyxc_nhwgk_bf16_instance.cpp
xdl/device_grouped_conv2d_fwd_clamp_xdl_nhwgc_gkyxc_nhwgk_bf16_16x16_instance.cpp
@@ -39,4 +39,14 @@ add_instance_library(device_grouped_conv2d_fwd_clamp_instance
xdl/mem/device_grouped_conv2d_fwd_clamp_xdl_nhwgc_gkyxc_nhwgk_fp32_tf32_mem_inter_instance.cpp
xdl/comp/device_grouped_conv2d_fwd_clamp_xdl_nhwgc_gkyxc_nhwgk_fp32_comp_instance.cpp
xdl/comp/device_grouped_conv2d_fwd_clamp_xdl_nhwgc_gkyxc_nhwgk_fp32_tf32_comp_instance.cpp
# WMMA CSHUFFLE V3
wmma/device_grouped_conv2d_fwd_clamp_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_bf16_instance_part1.cpp
wmma/device_grouped_conv2d_fwd_clamp_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_bf16_instance_part2.cpp
wmma/device_grouped_conv2d_fwd_clamp_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_bf16_instance_part3.cpp
wmma/device_grouped_conv2d_fwd_clamp_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_bf16_instance_part4.cpp
wmma/device_grouped_conv2d_fwd_clamp_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_f16_instance_part1.cpp
wmma/device_grouped_conv2d_fwd_clamp_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_f16_instance_part2.cpp
wmma/device_grouped_conv2d_fwd_clamp_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_f16_instance_part3.cpp
wmma/device_grouped_conv2d_fwd_clamp_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_f16_instance_part4.cpp
)

View File

@@ -0,0 +1,62 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_cshufflev3_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_grouped_conv2d_fwd_clamp_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_bf16_instances_part1(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
NHWGC,
GKYXC,
Tuple<>,
NHWGK,
BF16,
BF16,
Tuple<>,
BF16,
PassThrough,
PassThrough,
Clamp>>>& instances)
{
add_device_operation_instances(
instances,
device_grouped_conv_fwd_wmma_cshufflev3_bf16_instances_part1<2,
NHWGC,
GKYXC,
Tuple<>,
NHWGK,
ConvFwdDefault,
Tuple<>,
Clamp>{});
add_device_operation_instances(
instances,
device_grouped_conv_fwd_wmma_cshufflev3_bf16_instances_part1<2,
NHWGC,
GKYXC,
Tuple<>,
NHWGK,
ConvFwd1x1P0,
Tuple<>,
Clamp>{});
add_device_operation_instances(
instances,
device_grouped_conv_fwd_wmma_cshufflev3_bf16_instances_part1<2,
NHWGC,
GKYXC,
Tuple<>,
NHWGK,
ConvFwd1x1S1P0,
Tuple<>,
Clamp>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -0,0 +1,62 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_cshufflev3_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_grouped_conv2d_fwd_clamp_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_bf16_instances_part2(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
NHWGC,
GKYXC,
Tuple<>,
NHWGK,
BF16,
BF16,
Tuple<>,
BF16,
PassThrough,
PassThrough,
Clamp>>>& instances)
{
add_device_operation_instances(
instances,
device_grouped_conv_fwd_wmma_cshufflev3_bf16_instances_part2<2,
NHWGC,
GKYXC,
Tuple<>,
NHWGK,
ConvFwdDefault,
Tuple<>,
Clamp>{});
add_device_operation_instances(
instances,
device_grouped_conv_fwd_wmma_cshufflev3_bf16_instances_part2<2,
NHWGC,
GKYXC,
Tuple<>,
NHWGK,
ConvFwd1x1P0,
Tuple<>,
Clamp>{});
add_device_operation_instances(
instances,
device_grouped_conv_fwd_wmma_cshufflev3_bf16_instances_part2<2,
NHWGC,
GKYXC,
Tuple<>,
NHWGK,
ConvFwd1x1S1P0,
Tuple<>,
Clamp>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -0,0 +1,62 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_cshufflev3_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_grouped_conv2d_fwd_clamp_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_bf16_instances_part3(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
NHWGC,
GKYXC,
Tuple<>,
NHWGK,
BF16,
BF16,
Tuple<>,
BF16,
PassThrough,
PassThrough,
Clamp>>>& instances)
{
add_device_operation_instances(
instances,
device_grouped_conv_fwd_wmma_cshufflev3_bf16_instances_part3<2,
NHWGC,
GKYXC,
Tuple<>,
NHWGK,
ConvFwdDefault,
Tuple<>,
Clamp>{});
add_device_operation_instances(
instances,
device_grouped_conv_fwd_wmma_cshufflev3_bf16_instances_part3<2,
NHWGC,
GKYXC,
Tuple<>,
NHWGK,
ConvFwd1x1P0,
Tuple<>,
Clamp>{});
add_device_operation_instances(
instances,
device_grouped_conv_fwd_wmma_cshufflev3_bf16_instances_part3<2,
NHWGC,
GKYXC,
Tuple<>,
NHWGK,
ConvFwd1x1S1P0,
Tuple<>,
Clamp>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -0,0 +1,62 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_cshufflev3_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_grouped_conv2d_fwd_clamp_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_bf16_instances_part4(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
NHWGC,
GKYXC,
Tuple<>,
NHWGK,
BF16,
BF16,
Tuple<>,
BF16,
PassThrough,
PassThrough,
Clamp>>>& instances)
{
add_device_operation_instances(
instances,
device_grouped_conv_fwd_wmma_cshufflev3_bf16_instances_part4<2,
NHWGC,
GKYXC,
Tuple<>,
NHWGK,
ConvFwdDefault,
Tuple<>,
Clamp>{});
add_device_operation_instances(
instances,
device_grouped_conv_fwd_wmma_cshufflev3_bf16_instances_part4<2,
NHWGC,
GKYXC,
Tuple<>,
NHWGK,
ConvFwd1x1P0,
Tuple<>,
Clamp>{});
add_device_operation_instances(
instances,
device_grouped_conv_fwd_wmma_cshufflev3_bf16_instances_part4<2,
NHWGC,
GKYXC,
Tuple<>,
NHWGK,
ConvFwd1x1S1P0,
Tuple<>,
Clamp>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -0,0 +1,62 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_cshufflev3_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_grouped_conv2d_fwd_clamp_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_f16_instances_part1(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
NHWGC,
GKYXC,
Tuple<>,
NHWGK,
F16,
F16,
Tuple<>,
F16,
PassThrough,
PassThrough,
Clamp>>>& instances)
{
add_device_operation_instances(
instances,
device_grouped_conv_fwd_wmma_cshufflev3_f16_instances_part1<2,
NHWGC,
GKYXC,
Tuple<>,
NHWGK,
ConvFwdDefault,
Tuple<>,
Clamp>{});
add_device_operation_instances(
instances,
device_grouped_conv_fwd_wmma_cshufflev3_f16_instances_part1<2,
NHWGC,
GKYXC,
Tuple<>,
NHWGK,
ConvFwd1x1P0,
Tuple<>,
Clamp>{});
add_device_operation_instances(
instances,
device_grouped_conv_fwd_wmma_cshufflev3_f16_instances_part1<2,
NHWGC,
GKYXC,
Tuple<>,
NHWGK,
ConvFwd1x1S1P0,
Tuple<>,
Clamp>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -0,0 +1,62 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_cshufflev3_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_grouped_conv2d_fwd_clamp_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_f16_instances_part2(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
NHWGC,
GKYXC,
Tuple<>,
NHWGK,
F16,
F16,
Tuple<>,
F16,
PassThrough,
PassThrough,
Clamp>>>& instances)
{
add_device_operation_instances(
instances,
device_grouped_conv_fwd_wmma_cshufflev3_f16_instances_part2<2,
NHWGC,
GKYXC,
Tuple<>,
NHWGK,
ConvFwdDefault,
Tuple<>,
Clamp>{});
add_device_operation_instances(
instances,
device_grouped_conv_fwd_wmma_cshufflev3_f16_instances_part2<2,
NHWGC,
GKYXC,
Tuple<>,
NHWGK,
ConvFwd1x1P0,
Tuple<>,
Clamp>{});
add_device_operation_instances(
instances,
device_grouped_conv_fwd_wmma_cshufflev3_f16_instances_part2<2,
NHWGC,
GKYXC,
Tuple<>,
NHWGK,
ConvFwd1x1S1P0,
Tuple<>,
Clamp>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -0,0 +1,62 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_cshufflev3_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_grouped_conv2d_fwd_clamp_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_f16_instances_part3(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
NHWGC,
GKYXC,
Tuple<>,
NHWGK,
F16,
F16,
Tuple<>,
F16,
PassThrough,
PassThrough,
Clamp>>>& instances)
{
add_device_operation_instances(
instances,
device_grouped_conv_fwd_wmma_cshufflev3_f16_instances_part3<2,
NHWGC,
GKYXC,
Tuple<>,
NHWGK,
ConvFwdDefault,
Tuple<>,
Clamp>{});
add_device_operation_instances(
instances,
device_grouped_conv_fwd_wmma_cshufflev3_f16_instances_part3<2,
NHWGC,
GKYXC,
Tuple<>,
NHWGK,
ConvFwd1x1P0,
Tuple<>,
Clamp>{});
add_device_operation_instances(
instances,
device_grouped_conv_fwd_wmma_cshufflev3_f16_instances_part3<2,
NHWGC,
GKYXC,
Tuple<>,
NHWGK,
ConvFwd1x1S1P0,
Tuple<>,
Clamp>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -0,0 +1,62 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_cshufflev3_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_grouped_conv2d_fwd_clamp_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_f16_instances_part4(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
NHWGC,
GKYXC,
Tuple<>,
NHWGK,
F16,
F16,
Tuple<>,
F16,
PassThrough,
PassThrough,
Clamp>>>& instances)
{
add_device_operation_instances(
instances,
device_grouped_conv_fwd_wmma_cshufflev3_f16_instances_part4<2,
NHWGC,
GKYXC,
Tuple<>,
NHWGK,
ConvFwdDefault,
Tuple<>,
Clamp>{});
add_device_operation_instances(
instances,
device_grouped_conv_fwd_wmma_cshufflev3_f16_instances_part4<2,
NHWGC,
GKYXC,
Tuple<>,
NHWGK,
ConvFwd1x1P0,
Tuple<>,
Clamp>{});
add_device_operation_instances(
instances,
device_grouped_conv_fwd_wmma_cshufflev3_f16_instances_part4<2,
NHWGC,
GKYXC,
Tuple<>,
NHWGK,
ConvFwd1x1S1P0,
Tuple<>,
Clamp>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -52,22 +52,33 @@ set(GROUPED_CONV3D_FWD
xdl/comp/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_comp_part2_instance.cpp
xdl/comp/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_comp_part2_instance.cpp
wmma/device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_f16_instance.cpp
wmma/device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_i8_instance.cpp
wmma/device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp
wmma/device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_i8_instance.cpp
wmma/device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_f16_1x1p0_instance.cpp
wmma/device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_i8_1x1p0_instance.cpp
wmma/device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_f16_1x1p0_instance.cpp
wmma/device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_i8_1x1p0_instance.cpp
wmma/device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_f16_1x1s1p0_instance.cpp
wmma/device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_i8_1x1s1p0_instance.cpp
wmma/device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_f16_1x1s1p0_instance.cpp
wmma/device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_i8_1x1s1p0_instance.cpp
wmma/device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_f16_oddc_instance.cpp
wmma/device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_i8_oddc_instance.cpp
wmma/device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_f16_oddc_instance.cpp
wmma/device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_i8_oddc_instance.cpp
# WMMA_OLD
wmma/device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_f16_instance.cpp
wmma/device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_i8_instance.cpp
wmma/device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp
wmma/device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_i8_instance.cpp
wmma/device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_f16_1x1p0_instance.cpp
wmma/device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_i8_1x1p0_instance.cpp
wmma/device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_f16_1x1p0_instance.cpp
wmma/device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_i8_1x1p0_instance.cpp
wmma/device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_f16_1x1s1p0_instance.cpp
wmma/device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_i8_1x1s1p0_instance.cpp
wmma/device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_f16_1x1s1p0_instance.cpp
wmma/device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_i8_1x1s1p0_instance.cpp
wmma/device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_f16_oddc_instance.cpp
wmma/device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_i8_oddc_instance.cpp
wmma/device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_f16_oddc_instance.cpp
wmma/device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_i8_oddc_instance.cpp
# WMMA CSHUFFLE V3
wmma/device_grouped_conv3d_fwd_wmma_cshufflev3_ndhwgc_gkzyxc_ndhwgk_bf16_instance_part1.cpp
wmma/device_grouped_conv3d_fwd_wmma_cshufflev3_ndhwgc_gkzyxc_ndhwgk_bf16_instance_part2.cpp
wmma/device_grouped_conv3d_fwd_wmma_cshufflev3_ndhwgc_gkzyxc_ndhwgk_bf16_instance_part3.cpp
wmma/device_grouped_conv3d_fwd_wmma_cshufflev3_ndhwgc_gkzyxc_ndhwgk_bf16_instance_part4.cpp
wmma/device_grouped_conv3d_fwd_wmma_cshufflev3_ndhwgc_gkzyxc_ndhwgk_f16_instance_part1.cpp
wmma/device_grouped_conv3d_fwd_wmma_cshufflev3_ndhwgc_gkzyxc_ndhwgk_f16_instance_part2.cpp
wmma/device_grouped_conv3d_fwd_wmma_cshufflev3_ndhwgc_gkzyxc_ndhwgk_f16_instance_part3.cpp
wmma/device_grouped_conv3d_fwd_wmma_cshufflev3_ndhwgc_gkzyxc_ndhwgk_f16_instance_part4.cpp
)
# Add generated files for sharded instantiations.
include(ShardInstantiation)

View File

@@ -0,0 +1,55 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_cshufflev3_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_grouped_conv3d_fwd_wmma_cshufflev3_ndhwgc_gkzyxc_ndhwgk_bf16_instances_part1(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
NDHWGC,
GKZYXC,
Empty_Tuple,
NDHWGK,
BF16,
BF16,
Empty_Tuple,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances)
{
add_device_operation_instances(
instances,
device_grouped_conv_fwd_wmma_cshufflev3_bf16_instances_part1<3,
NDHWGC,
GKZYXC,
Empty_Tuple,
NDHWGK,
ConvFwdDefault>{});
add_device_operation_instances(
instances,
device_grouped_conv_fwd_wmma_cshufflev3_bf16_instances_part1<3,
NDHWGC,
GKZYXC,
Empty_Tuple,
NDHWGK,
ConvFwd1x1P0>{});
add_device_operation_instances(
instances,
device_grouped_conv_fwd_wmma_cshufflev3_bf16_instances_part1<3,
NDHWGC,
GKZYXC,
Empty_Tuple,
NDHWGK,
ConvFwd1x1S1P0>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -0,0 +1,55 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_cshufflev3_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_grouped_conv3d_fwd_wmma_cshufflev3_ndhwgc_gkzyxc_ndhwgk_bf16_instances_part2(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
NDHWGC,
GKZYXC,
Empty_Tuple,
NDHWGK,
BF16,
BF16,
Empty_Tuple,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances)
{
add_device_operation_instances(
instances,
device_grouped_conv_fwd_wmma_cshufflev3_bf16_instances_part2<3,
NDHWGC,
GKZYXC,
Empty_Tuple,
NDHWGK,
ConvFwdDefault>{});
add_device_operation_instances(
instances,
device_grouped_conv_fwd_wmma_cshufflev3_bf16_instances_part2<3,
NDHWGC,
GKZYXC,
Empty_Tuple,
NDHWGK,
ConvFwd1x1P0>{});
add_device_operation_instances(
instances,
device_grouped_conv_fwd_wmma_cshufflev3_bf16_instances_part2<3,
NDHWGC,
GKZYXC,
Empty_Tuple,
NDHWGK,
ConvFwd1x1S1P0>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -0,0 +1,55 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_cshufflev3_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_grouped_conv3d_fwd_wmma_cshufflev3_ndhwgc_gkzyxc_ndhwgk_bf16_instances_part3(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
NDHWGC,
GKZYXC,
Empty_Tuple,
NDHWGK,
BF16,
BF16,
Empty_Tuple,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances)
{
add_device_operation_instances(
instances,
device_grouped_conv_fwd_wmma_cshufflev3_bf16_instances_part3<3,
NDHWGC,
GKZYXC,
Empty_Tuple,
NDHWGK,
ConvFwdDefault>{});
add_device_operation_instances(
instances,
device_grouped_conv_fwd_wmma_cshufflev3_bf16_instances_part3<3,
NDHWGC,
GKZYXC,
Empty_Tuple,
NDHWGK,
ConvFwd1x1P0>{});
add_device_operation_instances(
instances,
device_grouped_conv_fwd_wmma_cshufflev3_bf16_instances_part3<3,
NDHWGC,
GKZYXC,
Empty_Tuple,
NDHWGK,
ConvFwd1x1S1P0>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -0,0 +1,55 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_cshufflev3_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_grouped_conv3d_fwd_wmma_cshufflev3_ndhwgc_gkzyxc_ndhwgk_bf16_instances_part4(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
NDHWGC,
GKZYXC,
Empty_Tuple,
NDHWGK,
BF16,
BF16,
Empty_Tuple,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances)
{
add_device_operation_instances(
instances,
device_grouped_conv_fwd_wmma_cshufflev3_bf16_instances_part4<3,
NDHWGC,
GKZYXC,
Empty_Tuple,
NDHWGK,
ConvFwdDefault>{});
add_device_operation_instances(
instances,
device_grouped_conv_fwd_wmma_cshufflev3_bf16_instances_part4<3,
NDHWGC,
GKZYXC,
Empty_Tuple,
NDHWGK,
ConvFwd1x1P0>{});
add_device_operation_instances(
instances,
device_grouped_conv_fwd_wmma_cshufflev3_bf16_instances_part4<3,
NDHWGC,
GKZYXC,
Empty_Tuple,
NDHWGK,
ConvFwd1x1S1P0>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -0,0 +1,55 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_cshufflev3_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_grouped_conv3d_fwd_wmma_cshufflev3_ndhwgc_gkzyxc_ndhwgk_f16_instances_part1(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
NDHWGC,
GKZYXC,
Empty_Tuple,
NDHWGK,
F16,
F16,
Empty_Tuple,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances)
{
add_device_operation_instances(
instances,
device_grouped_conv_fwd_wmma_cshufflev3_f16_instances_part1<3,
NDHWGC,
GKZYXC,
Empty_Tuple,
NDHWGK,
ConvFwdDefault>{});
add_device_operation_instances(
instances,
device_grouped_conv_fwd_wmma_cshufflev3_f16_instances_part1<3,
NDHWGC,
GKZYXC,
Empty_Tuple,
NDHWGK,
ConvFwd1x1P0>{});
add_device_operation_instances(
instances,
device_grouped_conv_fwd_wmma_cshufflev3_f16_instances_part1<3,
NDHWGC,
GKZYXC,
Empty_Tuple,
NDHWGK,
ConvFwd1x1S1P0>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -0,0 +1,55 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_cshufflev3_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_grouped_conv3d_fwd_wmma_cshufflev3_ndhwgc_gkzyxc_ndhwgk_f16_instances_part2(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
NDHWGC,
GKZYXC,
Empty_Tuple,
NDHWGK,
F16,
F16,
Empty_Tuple,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances)
{
add_device_operation_instances(
instances,
device_grouped_conv_fwd_wmma_cshufflev3_f16_instances_part2<3,
NDHWGC,
GKZYXC,
Empty_Tuple,
NDHWGK,
ConvFwdDefault>{});
add_device_operation_instances(
instances,
device_grouped_conv_fwd_wmma_cshufflev3_f16_instances_part2<3,
NDHWGC,
GKZYXC,
Empty_Tuple,
NDHWGK,
ConvFwd1x1P0>{});
add_device_operation_instances(
instances,
device_grouped_conv_fwd_wmma_cshufflev3_f16_instances_part2<3,
NDHWGC,
GKZYXC,
Empty_Tuple,
NDHWGK,
ConvFwd1x1S1P0>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -0,0 +1,55 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_cshufflev3_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_grouped_conv3d_fwd_wmma_cshufflev3_ndhwgc_gkzyxc_ndhwgk_f16_instances_part3(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
NDHWGC,
GKZYXC,
Empty_Tuple,
NDHWGK,
F16,
F16,
Empty_Tuple,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances)
{
add_device_operation_instances(
instances,
device_grouped_conv_fwd_wmma_cshufflev3_f16_instances_part3<3,
NDHWGC,
GKZYXC,
Empty_Tuple,
NDHWGK,
ConvFwdDefault>{});
add_device_operation_instances(
instances,
device_grouped_conv_fwd_wmma_cshufflev3_f16_instances_part3<3,
NDHWGC,
GKZYXC,
Empty_Tuple,
NDHWGK,
ConvFwd1x1P0>{});
add_device_operation_instances(
instances,
device_grouped_conv_fwd_wmma_cshufflev3_f16_instances_part3<3,
NDHWGC,
GKZYXC,
Empty_Tuple,
NDHWGK,
ConvFwd1x1S1P0>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -0,0 +1,55 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_cshufflev3_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_grouped_conv3d_fwd_wmma_cshufflev3_ndhwgc_gkzyxc_ndhwgk_f16_instances_part4(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
NDHWGC,
GKZYXC,
Empty_Tuple,
NDHWGK,
F16,
F16,
Empty_Tuple,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances)
{
add_device_operation_instances(
instances,
device_grouped_conv_fwd_wmma_cshufflev3_f16_instances_part4<3,
NDHWGC,
GKZYXC,
Empty_Tuple,
NDHWGK,
ConvFwdDefault>{});
add_device_operation_instances(
instances,
device_grouped_conv_fwd_wmma_cshufflev3_f16_instances_part4<3,
NDHWGC,
GKZYXC,
Empty_Tuple,
NDHWGK,
ConvFwd1x1P0>{});
add_device_operation_instances(
instances,
device_grouped_conv_fwd_wmma_cshufflev3_f16_instances_part4<3,
NDHWGC,
GKZYXC,
Empty_Tuple,
NDHWGK,
ConvFwd1x1S1P0>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -1,7 +1,7 @@
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
# ONLY XDL_KERNELS
# ONLY XDL_AND_WMMA_KERNELS
set(GROUPED_CONV3D_FWD
xdl/device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp
xdl/device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_16x16_instance.cpp
@@ -34,6 +34,16 @@ set(GROUPED_CONV3D_FWD
xdl/comp/device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_fp32_tf32_comp_instance.cpp
xdl/device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_fp32_tf32_instance.cpp
)
# WMMA CSHUFFLE V3
wmma/device_grouped_conv3d_fwd_bias_clamp_wmma_cshufflev3_ndhwgc_gkzyxc_ndhwgk_bf16_instance_part1.cpp
wmma/device_grouped_conv3d_fwd_bias_clamp_wmma_cshufflev3_ndhwgc_gkzyxc_ndhwgk_bf16_instance_part2.cpp
wmma/device_grouped_conv3d_fwd_bias_clamp_wmma_cshufflev3_ndhwgc_gkzyxc_ndhwgk_bf16_instance_part3.cpp
wmma/device_grouped_conv3d_fwd_bias_clamp_wmma_cshufflev3_ndhwgc_gkzyxc_ndhwgk_bf16_instance_part4.cpp
wmma/device_grouped_conv3d_fwd_bias_clamp_wmma_cshufflev3_ndhwgc_gkzyxc_ndhwgk_f16_instance_part1.cpp
wmma/device_grouped_conv3d_fwd_bias_clamp_wmma_cshufflev3_ndhwgc_gkzyxc_ndhwgk_f16_instance_part2.cpp
wmma/device_grouped_conv3d_fwd_bias_clamp_wmma_cshufflev3_ndhwgc_gkzyxc_ndhwgk_f16_instance_part3.cpp
wmma/device_grouped_conv3d_fwd_bias_clamp_wmma_cshufflev3_ndhwgc_gkzyxc_ndhwgk_f16_instance_part4.cpp
)
add_instance_library(device_grouped_conv3d_fwd_bias_clamp_instance ${GROUPED_CONV3D_FWD})

View File

@@ -0,0 +1,61 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_cshufflev3_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_grouped_conv3d_fwd_bias_clamp_wmma_cshufflev3_ndhwgc_gkzyxc_ndhwgk_bf16_instances_part1(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
NDHWGC,
GKZYXC,
Tuple<NDHWGK>,
NDHWGK,
BF16,
BF16,
Tuple<BF16>,
BF16,
PassThrough,
PassThrough,
AddClamp>>>& instances)
{
add_device_operation_instances(
instances,
device_grouped_conv_fwd_wmma_cshufflev3_bf16_instances_part1<3,
NDHWGC,
GKZYXC,
Tuple<NDHWGK>,
NDHWGK,
ConvFwdDefault,
Tuple<BF16>,
AddClamp>{});
add_device_operation_instances(
instances,
device_grouped_conv_fwd_wmma_cshufflev3_bf16_instances_part1<3,
NDHWGC,
GKZYXC,
Tuple<NDHWGK>,
NDHWGK,
ConvFwd1x1P0,
Tuple<BF16>,
AddClamp>{});
add_device_operation_instances(
instances,
device_grouped_conv_fwd_wmma_cshufflev3_bf16_instances_part1<3,
NDHWGC,
GKZYXC,
Tuple<NDHWGK>,
NDHWGK,
ConvFwd1x1S1P0,
Tuple<BF16>,
AddClamp>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -0,0 +1,61 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_cshufflev3_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_grouped_conv3d_fwd_bias_clamp_wmma_cshufflev3_ndhwgc_gkzyxc_ndhwgk_bf16_instances_part2(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
NDHWGC,
GKZYXC,
Tuple<NDHWGK>,
NDHWGK,
BF16,
BF16,
Tuple<BF16>,
BF16,
PassThrough,
PassThrough,
AddClamp>>>& instances)
{
add_device_operation_instances(
instances,
device_grouped_conv_fwd_wmma_cshufflev3_bf16_instances_part2<3,
NDHWGC,
GKZYXC,
Tuple<NDHWGK>,
NDHWGK,
ConvFwdDefault,
Tuple<BF16>,
AddClamp>{});
add_device_operation_instances(
instances,
device_grouped_conv_fwd_wmma_cshufflev3_bf16_instances_part2<3,
NDHWGC,
GKZYXC,
Tuple<NDHWGK>,
NDHWGK,
ConvFwd1x1P0,
Tuple<BF16>,
AddClamp>{});
add_device_operation_instances(
instances,
device_grouped_conv_fwd_wmma_cshufflev3_bf16_instances_part2<3,
NDHWGC,
GKZYXC,
Tuple<NDHWGK>,
NDHWGK,
ConvFwd1x1S1P0,
Tuple<BF16>,
AddClamp>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -0,0 +1,61 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_cshufflev3_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_grouped_conv3d_fwd_bias_clamp_wmma_cshufflev3_ndhwgc_gkzyxc_ndhwgk_bf16_instances_part3(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
NDHWGC,
GKZYXC,
Tuple<NDHWGK>,
NDHWGK,
BF16,
BF16,
Tuple<BF16>,
BF16,
PassThrough,
PassThrough,
AddClamp>>>& instances)
{
add_device_operation_instances(
instances,
device_grouped_conv_fwd_wmma_cshufflev3_bf16_instances_part3<3,
NDHWGC,
GKZYXC,
Tuple<NDHWGK>,
NDHWGK,
ConvFwdDefault,
Tuple<BF16>,
AddClamp>{});
add_device_operation_instances(
instances,
device_grouped_conv_fwd_wmma_cshufflev3_bf16_instances_part3<3,
NDHWGC,
GKZYXC,
Tuple<NDHWGK>,
NDHWGK,
ConvFwd1x1P0,
Tuple<BF16>,
AddClamp>{});
add_device_operation_instances(
instances,
device_grouped_conv_fwd_wmma_cshufflev3_bf16_instances_part3<3,
NDHWGC,
GKZYXC,
Tuple<NDHWGK>,
NDHWGK,
ConvFwd1x1S1P0,
Tuple<BF16>,
AddClamp>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -0,0 +1,61 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_cshufflev3_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_grouped_conv3d_fwd_bias_clamp_wmma_cshufflev3_ndhwgc_gkzyxc_ndhwgk_bf16_instances_part4(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
NDHWGC,
GKZYXC,
Tuple<NDHWGK>,
NDHWGK,
BF16,
BF16,
Tuple<BF16>,
BF16,
PassThrough,
PassThrough,
AddClamp>>>& instances)
{
add_device_operation_instances(
instances,
device_grouped_conv_fwd_wmma_cshufflev3_bf16_instances_part4<3,
NDHWGC,
GKZYXC,
Tuple<NDHWGK>,
NDHWGK,
ConvFwdDefault,
Tuple<BF16>,
AddClamp>{});
add_device_operation_instances(
instances,
device_grouped_conv_fwd_wmma_cshufflev3_bf16_instances_part4<3,
NDHWGC,
GKZYXC,
Tuple<NDHWGK>,
NDHWGK,
ConvFwd1x1P0,
Tuple<BF16>,
AddClamp>{});
add_device_operation_instances(
instances,
device_grouped_conv_fwd_wmma_cshufflev3_bf16_instances_part4<3,
NDHWGC,
GKZYXC,
Tuple<NDHWGK>,
NDHWGK,
ConvFwd1x1S1P0,
Tuple<BF16>,
AddClamp>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -0,0 +1,61 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_cshufflev3_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_grouped_conv3d_fwd_bias_clamp_wmma_cshufflev3_ndhwgc_gkzyxc_ndhwgk_f16_instances_part1(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
NDHWGC,
GKZYXC,
Tuple<NDHWGK>,
NDHWGK,
F16,
F16,
Tuple<F16>,
F16,
PassThrough,
PassThrough,
AddClamp>>>& instances)
{
add_device_operation_instances(
instances,
device_grouped_conv_fwd_wmma_cshufflev3_f16_instances_part1<3,
NDHWGC,
GKZYXC,
Tuple<NDHWGK>,
NDHWGK,
ConvFwdDefault,
Tuple<F16>,
AddClamp>{});
add_device_operation_instances(
instances,
device_grouped_conv_fwd_wmma_cshufflev3_f16_instances_part1<3,
NDHWGC,
GKZYXC,
Tuple<NDHWGK>,
NDHWGK,
ConvFwd1x1P0,
Tuple<F16>,
AddClamp>{});
add_device_operation_instances(
instances,
device_grouped_conv_fwd_wmma_cshufflev3_f16_instances_part1<3,
NDHWGC,
GKZYXC,
Tuple<NDHWGK>,
NDHWGK,
ConvFwd1x1S1P0,
Tuple<F16>,
AddClamp>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -0,0 +1,61 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_cshufflev3_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_grouped_conv3d_fwd_bias_clamp_wmma_cshufflev3_ndhwgc_gkzyxc_ndhwgk_f16_instances_part2(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
NDHWGC,
GKZYXC,
Tuple<NDHWGK>,
NDHWGK,
F16,
F16,
Tuple<F16>,
F16,
PassThrough,
PassThrough,
AddClamp>>>& instances)
{
add_device_operation_instances(
instances,
device_grouped_conv_fwd_wmma_cshufflev3_f16_instances_part2<3,
NDHWGC,
GKZYXC,
Tuple<NDHWGK>,
NDHWGK,
ConvFwdDefault,
Tuple<F16>,
AddClamp>{});
add_device_operation_instances(
instances,
device_grouped_conv_fwd_wmma_cshufflev3_f16_instances_part2<3,
NDHWGC,
GKZYXC,
Tuple<NDHWGK>,
NDHWGK,
ConvFwd1x1P0,
Tuple<F16>,
AddClamp>{});
add_device_operation_instances(
instances,
device_grouped_conv_fwd_wmma_cshufflev3_f16_instances_part2<3,
NDHWGC,
GKZYXC,
Tuple<NDHWGK>,
NDHWGK,
ConvFwd1x1S1P0,
Tuple<F16>,
AddClamp>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -0,0 +1,61 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_cshufflev3_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_grouped_conv3d_fwd_bias_clamp_wmma_cshufflev3_ndhwgc_gkzyxc_ndhwgk_f16_instances_part3(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
NDHWGC,
GKZYXC,
Tuple<NDHWGK>,
NDHWGK,
F16,
F16,
Tuple<F16>,
F16,
PassThrough,
PassThrough,
AddClamp>>>& instances)
{
add_device_operation_instances(
instances,
device_grouped_conv_fwd_wmma_cshufflev3_f16_instances_part3<3,
NDHWGC,
GKZYXC,
Tuple<NDHWGK>,
NDHWGK,
ConvFwdDefault,
Tuple<F16>,
AddClamp>{});
add_device_operation_instances(
instances,
device_grouped_conv_fwd_wmma_cshufflev3_f16_instances_part3<3,
NDHWGC,
GKZYXC,
Tuple<NDHWGK>,
NDHWGK,
ConvFwd1x1P0,
Tuple<F16>,
AddClamp>{});
add_device_operation_instances(
instances,
device_grouped_conv_fwd_wmma_cshufflev3_f16_instances_part3<3,
NDHWGC,
GKZYXC,
Tuple<NDHWGK>,
NDHWGK,
ConvFwd1x1S1P0,
Tuple<F16>,
AddClamp>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -0,0 +1,61 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_cshufflev3_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_grouped_conv3d_fwd_bias_clamp_wmma_cshufflev3_ndhwgc_gkzyxc_ndhwgk_f16_instances_part4(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
NDHWGC,
GKZYXC,
Tuple<NDHWGK>,
NDHWGK,
F16,
F16,
Tuple<F16>,
F16,
PassThrough,
PassThrough,
AddClamp>>>& instances)
{
add_device_operation_instances(
instances,
device_grouped_conv_fwd_wmma_cshufflev3_f16_instances_part4<3,
NDHWGC,
GKZYXC,
Tuple<NDHWGK>,
NDHWGK,
ConvFwdDefault,
Tuple<F16>,
AddClamp>{});
add_device_operation_instances(
instances,
device_grouped_conv_fwd_wmma_cshufflev3_f16_instances_part4<3,
NDHWGC,
GKZYXC,
Tuple<NDHWGK>,
NDHWGK,
ConvFwd1x1P0,
Tuple<F16>,
AddClamp>{});
add_device_operation_instances(
instances,
device_grouped_conv_fwd_wmma_cshufflev3_f16_instances_part4<3,
NDHWGC,
GKZYXC,
Tuple<NDHWGK>,
NDHWGK,
ConvFwd1x1S1P0,
Tuple<F16>,
AddClamp>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -1,7 +1,7 @@
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
# ONLY XDL_KERNELS
# ONLY XDL_AND_WMMA_KERNELS
set(GROUPED_CONV3D_FWD
xdl/device_grouped_conv3d_fwd_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp
xdl/device_grouped_conv3d_fwd_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_16x16_instance.cpp
@@ -34,6 +34,16 @@ set(GROUPED_CONV3D_FWD
xdl/comp/device_grouped_conv3d_fwd_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_fp32_tf32_comp_instance.cpp
xdl/device_grouped_conv3d_fwd_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_fp32_tf32_instance.cpp
)
# WMMA CSHUFFLE V3
wmma/device_grouped_conv3d_fwd_clamp_wmma_cshufflev3_ndhwgc_gkzyxc_ndhwgk_bf16_instance_part1.cpp
wmma/device_grouped_conv3d_fwd_clamp_wmma_cshufflev3_ndhwgc_gkzyxc_ndhwgk_bf16_instance_part2.cpp
wmma/device_grouped_conv3d_fwd_clamp_wmma_cshufflev3_ndhwgc_gkzyxc_ndhwgk_bf16_instance_part3.cpp
wmma/device_grouped_conv3d_fwd_clamp_wmma_cshufflev3_ndhwgc_gkzyxc_ndhwgk_bf16_instance_part4.cpp
wmma/device_grouped_conv3d_fwd_clamp_wmma_cshufflev3_ndhwgc_gkzyxc_ndhwgk_f16_instance_part1.cpp
wmma/device_grouped_conv3d_fwd_clamp_wmma_cshufflev3_ndhwgc_gkzyxc_ndhwgk_f16_instance_part2.cpp
wmma/device_grouped_conv3d_fwd_clamp_wmma_cshufflev3_ndhwgc_gkzyxc_ndhwgk_f16_instance_part3.cpp
wmma/device_grouped_conv3d_fwd_clamp_wmma_cshufflev3_ndhwgc_gkzyxc_ndhwgk_f16_instance_part4.cpp
)
add_instance_library(device_grouped_conv3d_fwd_clamp_instance ${GROUPED_CONV3D_FWD})

View File

@@ -0,0 +1,61 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_cshufflev3_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_grouped_conv3d_fwd_clamp_wmma_cshufflev3_ndhwgc_gkzyxc_ndhwgk_bf16_instances_part1(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
NDHWGC,
GKZYXC,
Tuple<>,
NDHWGK,
BF16,
BF16,
Tuple<>,
BF16,
PassThrough,
PassThrough,
Clamp>>>& instances)
{
add_device_operation_instances(
instances,
device_grouped_conv_fwd_wmma_cshufflev3_bf16_instances_part1<3,
NDHWGC,
GKZYXC,
Tuple<>,
NDHWGK,
ConvFwdDefault,
Tuple<>,
Clamp>{});
add_device_operation_instances(
instances,
device_grouped_conv_fwd_wmma_cshufflev3_bf16_instances_part1<3,
NDHWGC,
GKZYXC,
Tuple<>,
NDHWGK,
ConvFwd1x1P0,
Tuple<>,
Clamp>{});
add_device_operation_instances(
instances,
device_grouped_conv_fwd_wmma_cshufflev3_bf16_instances_part1<3,
NDHWGC,
GKZYXC,
Tuple<>,
NDHWGK,
ConvFwd1x1S1P0,
Tuple<>,
Clamp>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -0,0 +1,61 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_cshufflev3_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_grouped_conv3d_fwd_clamp_wmma_cshufflev3_ndhwgc_gkzyxc_ndhwgk_bf16_instances_part2(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
NDHWGC,
GKZYXC,
Tuple<>,
NDHWGK,
BF16,
BF16,
Tuple<>,
BF16,
PassThrough,
PassThrough,
Clamp>>>& instances)
{
add_device_operation_instances(
instances,
device_grouped_conv_fwd_wmma_cshufflev3_bf16_instances_part2<3,
NDHWGC,
GKZYXC,
Tuple<>,
NDHWGK,
ConvFwdDefault,
Tuple<>,
Clamp>{});
add_device_operation_instances(
instances,
device_grouped_conv_fwd_wmma_cshufflev3_bf16_instances_part2<3,
NDHWGC,
GKZYXC,
Tuple<>,
NDHWGK,
ConvFwd1x1P0,
Tuple<>,
Clamp>{});
add_device_operation_instances(
instances,
device_grouped_conv_fwd_wmma_cshufflev3_bf16_instances_part2<3,
NDHWGC,
GKZYXC,
Tuple<>,
NDHWGK,
ConvFwd1x1S1P0,
Tuple<>,
Clamp>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -0,0 +1,61 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_cshufflev3_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_grouped_conv3d_fwd_clamp_wmma_cshufflev3_ndhwgc_gkzyxc_ndhwgk_bf16_instances_part3(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
NDHWGC,
GKZYXC,
Tuple<>,
NDHWGK,
BF16,
BF16,
Tuple<>,
BF16,
PassThrough,
PassThrough,
Clamp>>>& instances)
{
add_device_operation_instances(
instances,
device_grouped_conv_fwd_wmma_cshufflev3_bf16_instances_part3<3,
NDHWGC,
GKZYXC,
Tuple<>,
NDHWGK,
ConvFwdDefault,
Tuple<>,
Clamp>{});
add_device_operation_instances(
instances,
device_grouped_conv_fwd_wmma_cshufflev3_bf16_instances_part3<3,
NDHWGC,
GKZYXC,
Tuple<>,
NDHWGK,
ConvFwd1x1P0,
Tuple<>,
Clamp>{});
add_device_operation_instances(
instances,
device_grouped_conv_fwd_wmma_cshufflev3_bf16_instances_part3<3,
NDHWGC,
GKZYXC,
Tuple<>,
NDHWGK,
ConvFwd1x1S1P0,
Tuple<>,
Clamp>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -0,0 +1,61 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_cshufflev3_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_grouped_conv3d_fwd_clamp_wmma_cshufflev3_ndhwgc_gkzyxc_ndhwgk_bf16_instances_part4(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
NDHWGC,
GKZYXC,
Tuple<>,
NDHWGK,
BF16,
BF16,
Tuple<>,
BF16,
PassThrough,
PassThrough,
Clamp>>>& instances)
{
add_device_operation_instances(
instances,
device_grouped_conv_fwd_wmma_cshufflev3_bf16_instances_part4<3,
NDHWGC,
GKZYXC,
Tuple<>,
NDHWGK,
ConvFwdDefault,
Tuple<>,
Clamp>{});
add_device_operation_instances(
instances,
device_grouped_conv_fwd_wmma_cshufflev3_bf16_instances_part4<3,
NDHWGC,
GKZYXC,
Tuple<>,
NDHWGK,
ConvFwd1x1P0,
Tuple<>,
Clamp>{});
add_device_operation_instances(
instances,
device_grouped_conv_fwd_wmma_cshufflev3_bf16_instances_part4<3,
NDHWGC,
GKZYXC,
Tuple<>,
NDHWGK,
ConvFwd1x1S1P0,
Tuple<>,
Clamp>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -0,0 +1,61 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_cshufflev3_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_grouped_conv3d_fwd_clamp_wmma_cshufflev3_ndhwgc_gkzyxc_ndhwgk_f16_instances_part1(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
NDHWGC,
GKZYXC,
Tuple<>,
NDHWGK,
F16,
F16,
Tuple<>,
F16,
PassThrough,
PassThrough,
Clamp>>>& instances)
{
add_device_operation_instances(
instances,
device_grouped_conv_fwd_wmma_cshufflev3_f16_instances_part1<3,
NDHWGC,
GKZYXC,
Tuple<>,
NDHWGK,
ConvFwdDefault,
Tuple<>,
Clamp>{});
add_device_operation_instances(
instances,
device_grouped_conv_fwd_wmma_cshufflev3_f16_instances_part1<3,
NDHWGC,
GKZYXC,
Tuple<>,
NDHWGK,
ConvFwd1x1P0,
Tuple<>,
Clamp>{});
add_device_operation_instances(
instances,
device_grouped_conv_fwd_wmma_cshufflev3_f16_instances_part1<3,
NDHWGC,
GKZYXC,
Tuple<>,
NDHWGK,
ConvFwd1x1S1P0,
Tuple<>,
Clamp>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -0,0 +1,61 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_cshufflev3_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_grouped_conv3d_fwd_clamp_wmma_cshufflev3_ndhwgc_gkzyxc_ndhwgk_f16_instances_part2(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
NDHWGC,
GKZYXC,
Tuple<>,
NDHWGK,
F16,
F16,
Tuple<>,
F16,
PassThrough,
PassThrough,
Clamp>>>& instances)
{
add_device_operation_instances(
instances,
device_grouped_conv_fwd_wmma_cshufflev3_f16_instances_part2<3,
NDHWGC,
GKZYXC,
Tuple<>,
NDHWGK,
ConvFwdDefault,
Tuple<>,
Clamp>{});
add_device_operation_instances(
instances,
device_grouped_conv_fwd_wmma_cshufflev3_f16_instances_part2<3,
NDHWGC,
GKZYXC,
Tuple<>,
NDHWGK,
ConvFwd1x1P0,
Tuple<>,
Clamp>{});
add_device_operation_instances(
instances,
device_grouped_conv_fwd_wmma_cshufflev3_f16_instances_part2<3,
NDHWGC,
GKZYXC,
Tuple<>,
NDHWGK,
ConvFwd1x1S1P0,
Tuple<>,
Clamp>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -0,0 +1,61 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_cshufflev3_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_grouped_conv3d_fwd_clamp_wmma_cshufflev3_ndhwgc_gkzyxc_ndhwgk_f16_instances_part3(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
NDHWGC,
GKZYXC,
Tuple<>,
NDHWGK,
F16,
F16,
Tuple<>,
F16,
PassThrough,
PassThrough,
Clamp>>>& instances)
{
add_device_operation_instances(
instances,
device_grouped_conv_fwd_wmma_cshufflev3_f16_instances_part3<3,
NDHWGC,
GKZYXC,
Tuple<>,
NDHWGK,
ConvFwdDefault,
Tuple<>,
Clamp>{});
add_device_operation_instances(
instances,
device_grouped_conv_fwd_wmma_cshufflev3_f16_instances_part3<3,
NDHWGC,
GKZYXC,
Tuple<>,
NDHWGK,
ConvFwd1x1P0,
Tuple<>,
Clamp>{});
add_device_operation_instances(
instances,
device_grouped_conv_fwd_wmma_cshufflev3_f16_instances_part3<3,
NDHWGC,
GKZYXC,
Tuple<>,
NDHWGK,
ConvFwd1x1S1P0,
Tuple<>,
Clamp>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -0,0 +1,61 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_cshufflev3_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_grouped_conv3d_fwd_clamp_wmma_cshufflev3_ndhwgc_gkzyxc_ndhwgk_f16_instances_part4(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
NDHWGC,
GKZYXC,
Tuple<>,
NDHWGK,
F16,
F16,
Tuple<>,
F16,
PassThrough,
PassThrough,
Clamp>>>& instances)
{
add_device_operation_instances(
instances,
device_grouped_conv_fwd_wmma_cshufflev3_f16_instances_part4<3,
NDHWGC,
GKZYXC,
Tuple<>,
NDHWGK,
ConvFwdDefault,
Tuple<>,
Clamp>{});
add_device_operation_instances(
instances,
device_grouped_conv_fwd_wmma_cshufflev3_f16_instances_part4<3,
NDHWGC,
GKZYXC,
Tuple<>,
NDHWGK,
ConvFwd1x1P0,
Tuple<>,
Clamp>{});
add_device_operation_instances(
instances,
device_grouped_conv_fwd_wmma_cshufflev3_f16_instances_part4<3,
NDHWGC,
GKZYXC,
Tuple<>,
NDHWGK,
ConvFwd1x1S1P0,
Tuple<>,
Clamp>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -1,11 +1,16 @@
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
# ONLY XDL_KERNELS
# ONLY XDL_AND_WMMA_KERNELS
set(GROUPED_CONV3D_FWD_SCALEADD_AB
xdl/device_grouped_conv3d_fwd_xdl_scaleadd_ab_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp
xdl/device_grouped_conv3d_fwd_xdl_scaleadd_ab_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp
xdl/device_grouped_conv3d_fwd_xdl_scaleadd_ab_ndhwgc_gkzyxc_ndhwgk_f32_instance.cpp
xdl/device_grouped_conv3d_fwd_xdl_scaleadd_ab_ndhwgc_gkzyxc_ndhwgk_int8_instance.cpp)
xdl/device_grouped_conv3d_fwd_xdl_scaleadd_ab_ndhwgc_gkzyxc_ndhwgk_int8_instance.cpp
# WMMA CSHUFFLE V3
wmma/device_grouped_conv3d_fwd_wmma_cshufflev3_scaleadd_ab_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp
wmma/device_grouped_conv3d_fwd_wmma_cshufflev3_scaleadd_ab_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp
)
add_instance_library(device_grouped_conv3d_fwd_scaleadd_ab_instance ${GROUPED_CONV3D_FWD_SCALEADD_AB})

View File

@@ -0,0 +1,52 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_cshufflev3_scaleadd_ab_instance.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_grouped_conv3d_fwd_wmma_cshufflev3_scaleadd_ab_ndhwgc_gkzyxc_ndhwgk_bf16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
NDHWGC,
GKZYXC,
ck::Tuple<>,
NDHWGK,
ck::Tuple<BF16, BF16>,
ck::Tuple<BF16, BF16>,
ck::Tuple<>,
BF16,
ScaleAdd,
ScaleAdd,
PassThrough>>>& instances)
{
add_device_operation_instances(
instances,
device_grouped_conv_fwd_wmma_cshufflev3_scaleadd_ab_bf16_instances<3,
NDHWGC,
GKZYXC,
NDHWGK,
ConvFwdDefault>{});
add_device_operation_instances(
instances,
device_grouped_conv_fwd_wmma_cshufflev3_scaleadd_ab_bf16_instances<3,
NDHWGC,
GKZYXC,
NDHWGK,
ConvFwd1x1P0>{});
add_device_operation_instances(
instances,
device_grouped_conv_fwd_wmma_cshufflev3_scaleadd_ab_bf16_instances<3,
NDHWGC,
GKZYXC,
NDHWGK,
ConvFwd1x1S1P0>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -0,0 +1,52 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_cshufflev3_scaleadd_ab_instance.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_grouped_conv3d_fwd_wmma_cshufflev3_scaleadd_ab_ndhwgc_gkzyxc_ndhwgk_f16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
NDHWGC,
GKZYXC,
ck::Tuple<>,
NDHWGK,
ck::Tuple<F16, F16>,
ck::Tuple<F16, F16>,
ck::Tuple<>,
F16,
ScaleAdd,
ScaleAdd,
PassThrough>>>& instances)
{
add_device_operation_instances(
instances,
device_grouped_conv_fwd_wmma_cshufflev3_scaleadd_ab_f16_instances<3,
NDHWGC,
GKZYXC,
NDHWGK,
ConvFwdDefault>{});
add_device_operation_instances(
instances,
device_grouped_conv_fwd_wmma_cshufflev3_scaleadd_ab_f16_instances<3,
NDHWGC,
GKZYXC,
NDHWGK,
ConvFwd1x1P0>{});
add_device_operation_instances(
instances,
device_grouped_conv_fwd_wmma_cshufflev3_scaleadd_ab_f16_instances<3,
NDHWGC,
GKZYXC,
NDHWGK,
ConvFwd1x1S1P0>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -196,6 +196,7 @@ bool profile_grouped_conv_fwd_bias_clamp_impl(int do_verification,
float best_tflops = 0;
float best_gb_per_sec = 0;
int num_kernel = 0;
int valids = 0;
// profile device op instances
bool pass = true;
@@ -217,6 +218,8 @@ bool profile_grouped_conv_fwd_bias_clamp_impl(int do_verification,
// re-init output to zero before profiling next kernel
out_device_buf.SetZero();
valids++;
std::string op_name = op_ptr->GetTypeString();
auto invoker_ptr = op_ptr->MakeInvokerPointer();
@@ -322,6 +325,8 @@ bool profile_grouped_conv_fwd_bias_clamp_impl(int do_verification,
run_impl(op_ptr, argument_ptr);
}
printf("\033[36mvalids: %d\n\033[0m", valids);
std::cout << "Best configuration parameters:" << "\nname: " << best_op_name
<< "\navg_time: " << best_avg_time << "\ntflops: " << best_tflops
<< "\nGB/s: " << best_gb_per_sec << std::endl;

View File

@@ -177,6 +177,8 @@ bool profile_grouped_conv_fwd_impl(int do_verification,
float best_tflops = 0;
float best_gb_per_sec = 0;
index_t num_kernel = 0;
int valids = 0;
// profile device op instances
bool pass = true;
@@ -196,6 +198,9 @@ bool profile_grouped_conv_fwd_impl(int do_verification,
}
std::string op_name = op_ptr->GetTypeString();
valids++;
out_device_buf.SetZero();
auto invoker_ptr = op_ptr->MakeInvokerPointer();
@@ -289,6 +294,8 @@ bool profile_grouped_conv_fwd_impl(int do_verification,
run_impl(op_ptr, argument_ptr);
}
printf("\033[36mvalids: %d\033[0m\n", valids);
std::cout << "Best configuration parameters:" << "\nname: " << best_op_name
<< "\navg_time: " << best_avg_time << "\ntflops: " << best_tflops
<< "\nGB/s: " << best_gb_per_sec << std::endl;

View File

@@ -205,16 +205,11 @@ if(SUPPORTED_GPU_TARGETS MATCHES "gfx9|gfx1[12]")
list(APPEND DEVICE_INSTANCES device_conv2d_fwd_instance)
list(APPEND DEVICE_INSTANCES device_conv2d_fwd_bias_relu_instance)
list(APPEND DEVICE_INSTANCES device_conv2d_fwd_bias_relu_add_instance)
list(APPEND DEVICE_INSTANCES device_grouped_conv1d_fwd_instance)
list(APPEND DEVICE_INSTANCES device_conv1d_bwd_data_instance)
list(APPEND DEVICE_INSTANCES device_conv3d_bwd_data_instance)
list(APPEND DEVICE_INSTANCES device_conv2d_bwd_data_instance)
list(APPEND DEVICE_INSTANCES device_grouped_conv3d_fwd_convscale_instance)
list(APPEND DEVICE_INSTANCES device_grouped_conv3d_fwd_convinvscale_instance)
list(APPEND DEVICE_INSTANCES device_grouped_conv2d_fwd_clamp_instance)
list(APPEND DEVICE_INSTANCES device_grouped_conv3d_fwd_clamp_instance)
list(APPEND DEVICE_INSTANCES device_grouped_conv2d_fwd_bias_clamp_instance)
list(APPEND DEVICE_INSTANCES device_grouped_conv3d_fwd_bias_clamp_instance)
endif()
if((SUPPORTED_GPU_TARGETS MATCHES "gfx9" AND (DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)) OR
@@ -231,10 +226,15 @@ if(SUPPORTED_GPU_TARGETS MATCHES "gfx9|gfx1[12]")
list(APPEND DEVICE_INSTANCES device_gemm_b_scale_instance)
list(APPEND DEVICE_INSTANCES device_gemm_universal_reduce_instance)
list(APPEND DEVICE_INSTANCES device_batched_gemm_b_scale_instance)
list(APPEND DEVICE_INSTANCES device_grouped_conv3d_fwd_instance)
list(APPEND DEVICE_INSTANCES device_grouped_conv2d_bwd_data_instance)
list(APPEND DEVICE_INSTANCES device_grouped_conv3d_bwd_data_instance)
list(APPEND DEVICE_INSTANCES device_grouped_conv1d_fwd_instance)
list(APPEND DEVICE_INSTANCES device_grouped_conv2d_fwd_instance)
list(APPEND DEVICE_INSTANCES device_grouped_conv3d_fwd_instance)
list(APPEND DEVICE_INSTANCES device_grouped_conv2d_fwd_clamp_instance)
list(APPEND DEVICE_INSTANCES device_grouped_conv3d_fwd_clamp_instance)
list(APPEND DEVICE_INSTANCES device_grouped_conv2d_fwd_bias_clamp_instance)
list(APPEND DEVICE_INSTANCES device_grouped_conv3d_fwd_bias_clamp_instance)
list(APPEND DEVICE_INSTANCES device_gemm_add_relu_instance)
list(APPEND DEVICE_INSTANCES device_gemm_multi_abd_instance)
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)

View File

@@ -33,6 +33,7 @@ set(REGRESSION_TESTS
test_convnd_fwd
test_convnd_bwd_data
test_grouped_convnd_fwd
test_grouped_convnd_fwd_scaleadd_ab
test_grouped_convnd_bwd_weight
test_softmax_rank3
test_softmax_rank4

View File

@@ -5,6 +5,9 @@ if(GPU_TARGETS MATCHES "gfx9|gfx11|gfx12")
add_gtest_executable(test_grouped_convnd_fwd test_grouped_convnd_fwd.cpp)
target_link_libraries(test_grouped_convnd_fwd PRIVATE utility device_grouped_conv1d_fwd_instance device_grouped_conv2d_fwd_instance device_grouped_conv3d_fwd_instance)
add_gtest_executable(test_grouped_convnd_fwd_scaleadd_ab test_grouped_convnd_fwd_scaleadd_ab.cpp)
target_link_libraries(test_grouped_convnd_fwd_scaleadd_ab PRIVATE utility device_grouped_conv3d_fwd_scaleadd_ab_instance)
add_executable(test_grouped_convnd_fwd_large_cases_xdl test_grouped_convnd_fwd_large_cases_xdl.cpp)
target_compile_options(test_grouped_convnd_fwd_large_cases_xdl PRIVATE -Wno-global-constructors -Wno-undef)
target_link_libraries(test_grouped_convnd_fwd_large_cases_xdl PRIVATE gtest_main getopt::getopt utility device_grouped_conv1d_fwd_instance device_grouped_conv2d_fwd_instance device_grouped_conv3d_fwd_instance)

View File

@@ -7,20 +7,32 @@
#include <vector>
#include <gtest/gtest.h>
#include "ck/utility/common_header.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "profiler/profile_grouped_conv_fwd_impl.hpp"
static ck::index_t param_mask = 0xffff;
static ck::index_t instance_index = -1;
using I8 = int8_t;
using F8 = ck::f8_t;
using BF8 = ck::bf8_t;
using F16 = ck::half_t;
using BF16 = ck::bhalf_t;
using F32 = float;
template <typename Tuple>
class TestGroupedConvndFwd : public ::testing::Test
{
protected:
using DataType = std::tuple_element_t<0, Tuple>;
using InLayout = std::tuple_element_t<1, Tuple>;
using WeiLayout = std::tuple_element_t<2, Tuple>;
using OutLayout = std::tuple_element_t<3, Tuple>;
using IndexType = ck::index_t;
using InDataType = std::tuple_element_t<0, Tuple>;
using WeiDataType = std::tuple_element_t<1, Tuple>;
using OutDataType = std::tuple_element_t<2, Tuple>;
using AComputeType = std::tuple_element_t<3, Tuple>;
using BComputeType = std::tuple_element_t<4, Tuple>;
using InLayout = std::tuple_element_t<5, Tuple>;
using WeiLayout = std::tuple_element_t<6, Tuple>;
using OutLayout = std::tuple_element_t<7, Tuple>;
using IndexType = ck::index_t;
std::vector<ck::utils::conv::ConvParam> conv_params;
@@ -35,16 +47,25 @@ class TestGroupedConvndFwd : public ::testing::Test
{
continue;
}
// FP8 workaround for CDNA1/2 is currently broken, do not test.
if(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a")
{
if(std::is_same<InDataType, F8>::value || std::is_same<InDataType, BF8>::value)
{
printf("Skipping FP8 / BF8 tests on CDNA1/2.\n");
continue;
}
}
auto& param = conv_params[i];
pass = pass && ck::profiler::profile_grouped_conv_fwd_impl<NDimSpatial,
InLayout,
WeiLayout,
OutLayout,
DataType,
DataType,
DataType,
DataType,
DataType,
InDataType,
WeiDataType,
OutDataType,
AComputeType,
BComputeType,
IndexType>(
2, // do_verification
1, // init_method: integer value
@@ -60,36 +81,43 @@ class TestGroupedConvndFwd : public ::testing::Test
using namespace ck::tensor_layout::convolution;
using KernelTypes1d = ::testing::Types<std::tuple<float, GNWC, GKXC, GNWK>,
std::tuple<ck::half_t, GNWC, GKXC, GNWK>,
std::tuple<ck::bhalf_t, GNWC, GKXC, GNWK>,
std::tuple<int8_t, GNWC, GKXC, GNWK>>;
using KernelTypes1d = ::testing::Types<std::tuple<F32, F32, F32, F32, F32, GNWC, GKXC, GNWK>,
std::tuple<F16, F16, F16, F16, F16, GNWC, GKXC, GNWK>,
std::tuple<BF16, BF16, BF16, BF16, BF16, GNWC, GKXC, GNWK>,
std::tuple<I8, I8, I8, I8, I8, GNWC, GKXC, GNWK>>;
using KernelTypes2d = ::testing::Types<std::tuple<float, GNHWC, GKYXC, GNHWK>,
std::tuple<ck::half_t, GNHWC, GKYXC, GNHWK>,
std::tuple<ck::bhalf_t, GNHWC, GKYXC, GNHWK>,
std::tuple<float, NHWGC, GKYXC, NHWGK>,
std::tuple<ck::half_t, NHWGC, GKYXC, NHWGK>,
std::tuple<ck::bhalf_t, NHWGC, GKYXC, NHWGK>,
std::tuple<int8_t, NHWGC, GKYXC, NHWGK>,
std::tuple<float, NGCHW, GKYXC, NGKHW>,
std::tuple<ck::half_t, NGCHW, GKYXC, NGKHW>,
std::tuple<ck::bhalf_t, NGCHW, GKYXC, NGKHW>,
std::tuple<int8_t, NGCHW, GKYXC, NGKHW>,
std::tuple<float, NGCHW, GKCYX, NGKHW>,
std::tuple<ck::half_t, NGCHW, GKCYX, NGKHW>,
std::tuple<ck::bhalf_t, NGCHW, GKCYX, NGKHW>>;
using KernelTypes2d =
::testing::Types<std::tuple<F32, F32, F32, F32, F32, GNHWC, GKYXC, GNHWK>,
std::tuple<F16, F16, F16, F16, F16, GNHWC, GKYXC, GNHWK>,
std::tuple<BF16, BF16, BF16, BF16, BF16, GNHWC, GKYXC, GNHWK>,
std::tuple<F32, F32, F32, F32, F32, NHWGC, GKYXC, NHWGK>,
std::tuple<F16, F16, F16, F16, F16, NHWGC, GKYXC, NHWGK>,
std::tuple<BF16, BF16, BF16, BF16, BF16, NHWGC, GKYXC, NHWGK>,
std::tuple<I8, I8, I8, I8, I8, NHWGC, GKYXC, NHWGK>,
std::tuple<F32, F32, F32, F32, F32, NGCHW, GKYXC, NGKHW>,
std::tuple<F16, F16, F16, F16, F16, NGCHW, GKYXC, NGKHW>,
std::tuple<BF16, BF16, BF16, BF16, BF16, NGCHW, GKYXC, NGKHW>,
std::tuple<I8, I8, I8, I8, I8, NGCHW, GKYXC, NGKHW>,
std::tuple<F32, F32, F32, F32, F32, NGCHW, GKCYX, NGKHW>,
std::tuple<F16, F16, F16, F16, F16, NGCHW, GKCYX, NGKHW>,
std::tuple<BF16, BF16, BF16, BF16, BF16, NGCHW, GKCYX, NGKHW>>;
using KernelTypes3d = ::testing::Types<std::tuple<float, GNDHWC, GKZYXC, GNDHWK>,
std::tuple<ck::half_t, GNDHWC, GKZYXC, GNDHWK>,
std::tuple<ck::bhalf_t, GNDHWC, GKZYXC, GNDHWK>,
std::tuple<int8_t, GNDHWC, GKZYXC, GNDHWK>,
std::tuple<float, NDHWGC, GKZYXC, NDHWGK>,
std::tuple<ck::half_t, NDHWGC, GKZYXC, NDHWGK>,
std::tuple<ck::bhalf_t, NDHWGC, GKZYXC, NDHWGK>,
std::tuple<float, NGCDHW, GKCZYX, NGKDHW>,
std::tuple<ck::half_t, NGCDHW, GKCZYX, NGKDHW>,
std::tuple<ck::bhalf_t, NGCDHW, GKCZYX, NGKDHW>>;
using KernelTypes3d =
::testing::Types<std::tuple<F32, F32, F32, F32, F32, GNDHWC, GKZYXC, GNDHWK>,
std::tuple<F16, F16, F16, F16, F16, GNDHWC, GKZYXC, GNDHWK>,
std::tuple<BF16, BF16, BF16, BF16, BF16, GNDHWC, GKZYXC, GNDHWK>,
std::tuple<I8, I8, I8, I8, I8, GNDHWC, GKZYXC, GNDHWK>,
std::tuple<I8, I8, I8, I8, I8, NDHWGC, GKZYXC, NDHWGK>,
std::tuple<F8, F8, F8, F8, F8, NDHWGC, GKZYXC, NDHWGK>,
std::tuple<BF8, BF8, F8, BF8, BF8, NDHWGC, GKZYXC, NDHWGK>,
std::tuple<F8, BF8, F8, F8, BF8, NDHWGC, GKZYXC, NDHWGK>,
std::tuple<BF8, F8, F8, BF8, F8, NDHWGC, GKZYXC, NDHWGK>,
std::tuple<F32, F32, F32, F32, F32, NDHWGC, GKZYXC, NDHWGK>,
std::tuple<F16, F16, F16, F16, F16, NDHWGC, GKZYXC, NDHWGK>,
std::tuple<BF16, BF16, BF16, BF16, BF16, NDHWGC, GKZYXC, NDHWGK>,
std::tuple<F32, F32, F32, F32, F32, NGCDHW, GKCZYX, NGKDHW>,
std::tuple<F16, F16, F16, F16, F16, NGCDHW, GKCZYX, NGKDHW>,
std::tuple<BF16, BF16, BF16, BF16, BF16, NGCDHW, GKCZYX, NGKDHW>>;
template <typename Tuple>
class TestGroupedConvndFwd1d : public TestGroupedConvndFwd<Tuple>
@@ -125,15 +153,32 @@ TYPED_TEST(TestGroupedConvndFwd1d, Test1D)
TYPED_TEST(TestGroupedConvndFwd2d, Test2D)
{
this->conv_params.clear();
this->conv_params.push_back(
{2, 3, 5, 96, 200, {1, 1}, {73, 128}, {1, 1}, {1, 1}, {0, 0}, {0, 0}});
this->conv_params.push_back(
{2, 1, 1, 32, 32, {1, 1}, {128, 128}, {1, 1}, {1, 1}, {0, 0}, {0, 0}});
this->conv_params.push_back(
{2, 1, 1, 32, 32, {2, 2}, {128, 128}, {1, 1}, {1, 1}, {0, 0}, {0, 0}});
this->conv_params.push_back(
{2, 1, 1, 32, 32, {3, 3}, {128, 128}, {1, 1}, {1, 1}, {0, 0}, {0, 0}});
this->conv_params.push_back(
{2, 1, 1, 32, 32, {5, 5}, {128, 128}, {1, 1}, {1, 1}, {0, 0}, {0, 0}});
this->conv_params.push_back(
{2, 1, 1, 32, 32, {9, 9}, {128, 128}, {1, 1}, {1, 1}, {0, 0}, {0, 0}});
this->conv_params.push_back(
{2, 2, 32, 128, 256, {1, 1}, {7, 7}, {2, 2}, {1, 1}, {0, 0}, {0, 0}});
this->conv_params.push_back(
{2, 2, 32, 128, 256, {3, 3}, {14, 14}, {1, 1}, {1, 1}, {1, 1}, {1, 1}});
this->conv_params.push_back(
{2, 2, 32, 128, 256, {1, 1}, {3, 3}, {1, 1}, {1, 1}, {0, 0}, {0, 0}});
this->conv_params.push_back({2, 1, 1, 1, 32, {3, 3}, {32, 32}, {1, 1}, {1, 1}, {1, 1}, {1, 1}});
this->conv_params.push_back({2, 1, 1, 64, 3, {3, 3}, {32, 32}, {1, 1}, {1, 1}, {1, 1}, {1, 1}});
this->conv_params.push_back({2, 1, 1, 1, 1, {3, 3}, {32, 32}, {1, 1}, {1, 1}, {1, 1}, {1, 1}});
this->conv_params.push_back(
{2, 96, 1, 1, 1, {1, 1}, {120, 160}, {1, 1}, {1, 1}, {1, 1}, {1, 1}});
this->conv_params.push_back(
{2, 96, 1, 1, 1, {3, 3}, {120, 160}, {1, 1}, {1, 1}, {1, 1}, {1, 1}});
this->template Run<2>();
@@ -142,20 +187,41 @@ TYPED_TEST(TestGroupedConvndFwd2d, Test2D)
TYPED_TEST(TestGroupedConvndFwd3d, Test3D)
{
this->conv_params.clear();
this->conv_params.push_back(
{3, 3, 5, 96, 200, {1, 1, 1}, {37, 37, 16}, {1, 1, 1}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}});
this->conv_params.push_back(
{3, 1, 1, 32, 32, {1, 1, 1}, {32, 32, 32}, {1, 1, 1}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}});
this->conv_params.push_back(
{3, 1, 1, 32, 32, {2, 2, 2}, {32, 32, 32}, {1, 1, 1}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}});
this->conv_params.push_back(
{3, 1, 1, 32, 32, {3, 3, 3}, {32, 32, 32}, {1, 1, 1}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}});
this->conv_params.push_back(
{3, 1, 1, 32, 32, {5, 5, 5}, {32, 32, 32}, {1, 1, 1}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}});
this->conv_params.push_back(
{3, 1, 1, 32, 32, {9, 9, 9}, {32, 32, 32}, {1, 1, 1}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}});
this->conv_params.push_back(
{3, 2, 32, 128, 256, {1, 1, 1}, {7, 7, 7}, {2, 2, 2}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}});
this->conv_params.push_back(
{3, 2, 32, 128, 256, {3, 3, 3}, {14, 14, 3}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}});
this->conv_params.push_back(
{3, 2, 32, 128, 256, {1, 1, 1}, {3, 3, 3}, {1, 1, 1}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}});
this->conv_params.push_back(
{3, 1, 1, 32, 32, {1, 1, 1}, {16, 16, 16}, {1, 1, 1}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}});
this->conv_params.push_back(
{3, 1, 1, 1, 32, {3, 3, 3}, {32, 32, 32}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}});
this->conv_params.push_back(
{3, 1, 1, 64, 3, {3, 3, 3}, {32, 32, 32}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}});
this->conv_params.push_back(
{3, 1, 1, 1, 1, {3, 3, 3}, {32, 32, 32}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}});
this->conv_params.push_back(
{3, 96, 1, 1, 1, {3, 3, 3}, {4, 30, 160}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}});
{3, 96, 1, 1, 1, {1, 1, 1}, {120, 40, 20}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}});
this->conv_params.push_back(
{3, 96, 1, 1, 1, {3, 3, 3}, {120, 40, 20}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}});
this->template Run<3>();
}

View File

@@ -0,0 +1,411 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include <cstdlib>
#include <iomanip>
#include <iostream>
#include <iterator>
#include <numeric>
#include <vector>
#include <initializer_list>
#include <gtest/gtest.h>
#include "ck/utility/data_type.hpp"
#include "ck/utility/tuple.hpp"
#include "ck/ck.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_scaleadd_ab.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/utility/algorithm.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/utility/convolution_parameter.hpp"
#include "ck/library/utility/convolution_host_tensor_descriptor_helper.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_conv_fwd.hpp"
using I8 = int8_t;
using F16 = ck::half_t;
using BF16 = ck::bhalf_t;
using F32 = float;
// This is pretty much a fully functional profiler function, but I only implemented it here to add a
// proper gtest test for the scaleadd_ab flavor. At some point we may want to move this and add it
// to the ckProfiler.
template <ck::index_t NDimSpatial,
typename InLayout,
typename WeiLayout,
typename OutLayout,
typename InDataType,
typename WeiDataType,
typename OutDataType,
typename IndexType = ck::index_t>
bool profile_grouped_conv_fwd_scaleadd_ab_impl(int do_verification,
int init_method,
bool do_log,
[[maybe_unused]] bool time_kernel,
const ck::utils::conv::ConvParam& conv_param)
{
constexpr ck::index_t NumAs = 2;
constexpr ck::index_t NumBs = 2;
using InElementOp = ck::tensor_operation::element_wise::ScaleAdd;
using WeiElementOp = ck::tensor_operation::element_wise::ScaleAdd;
using OutElementOp = ck::tensor_operation::element_wise::PassThrough;
constexpr float scale = 1.5f;
const auto in_element_op = InElementOp{scale};
const auto wei_element_op = WeiElementOp{scale};
const auto out_element_op = OutElementOp{};
const auto in_g_n_c_wis_desc =
ck::utils::conv::make_input_host_tensor_descriptor_g_n_c_wis_packed<InLayout>(conv_param);
const auto wei_g_k_c_xs_desc =
ck::utils::conv::make_weight_host_tensor_descriptor_g_k_c_xs_packed<WeiLayout>(conv_param);
const auto out_g_n_k_wos_desc =
ck::utils::conv::make_output_host_tensor_descriptor_g_n_k_wos_packed<OutLayout>(conv_param);
std::array<IndexType, NDimSpatial + 3> a_g_n_c_wis_lengths{};
std::array<IndexType, NDimSpatial + 3> a_g_n_c_wis_strides{};
std::array<IndexType, NDimSpatial + 3> b_g_k_c_xs_lengths{};
std::array<IndexType, NDimSpatial + 3> b_g_k_c_xs_strides{};
std::array<IndexType, NDimSpatial + 3> e_g_n_k_wos_lengths{};
std::array<IndexType, NDimSpatial + 3> e_g_n_k_wos_strides{};
std::array<IndexType, NDimSpatial> conv_filter_strides{};
std::array<IndexType, NDimSpatial> conv_filter_dilations{};
std::array<IndexType, NDimSpatial> input_left_pads{};
std::array<IndexType, NDimSpatial> input_right_pads{};
auto copy = [](const auto& x, auto& y) { ck::ranges::copy(x, y.begin()); };
copy(in_g_n_c_wis_desc.GetLengths(), a_g_n_c_wis_lengths);
copy(in_g_n_c_wis_desc.GetStrides(), a_g_n_c_wis_strides);
copy(wei_g_k_c_xs_desc.GetLengths(), b_g_k_c_xs_lengths);
copy(wei_g_k_c_xs_desc.GetStrides(), b_g_k_c_xs_strides);
copy(out_g_n_k_wos_desc.GetLengths(), e_g_n_k_wos_lengths);
copy(out_g_n_k_wos_desc.GetStrides(), e_g_n_k_wos_strides);
copy(conv_param.conv_filter_strides_, conv_filter_strides);
copy(conv_param.conv_filter_dilations_, conv_filter_dilations);
copy(conv_param.input_left_pads_, input_left_pads);
copy(conv_param.input_right_pads_, input_right_pads);
ck::Tensor<InDataType> input(in_g_n_c_wis_desc);
ck::Tensor<InDataType> input_bias(in_g_n_c_wis_desc);
ck::Tensor<WeiDataType> weight(wei_g_k_c_xs_desc);
ck::Tensor<WeiDataType> weight_bias(wei_g_k_c_xs_desc);
ck::Tensor<OutDataType> host_output(out_g_n_k_wos_desc);
ck::Tensor<OutDataType> device_output(out_g_n_k_wos_desc);
std::cout << "input: " << input.mDesc << std::endl;
std::cout << "weight: " << weight.mDesc << std::endl;
std::cout << "output: " << host_output.mDesc << std::endl;
switch(init_method)
{
case 0: break;
case 1:
input.GenerateTensorValue(GeneratorTensor_2<InDataType>{-2, 2});
input_bias.GenerateTensorValue(GeneratorTensor_2<InDataType>{-2, 2});
weight.GenerateTensorValue(GeneratorTensor_2<WeiDataType>{-2, 2});
weight_bias.GenerateTensorValue(GeneratorTensor_2<WeiDataType>{-2, 2});
break;
default:
input.GenerateTensorValue(GeneratorTensor_3<InDataType>{-1.0, 1.0});
input_bias.GenerateTensorValue(GeneratorTensor_3<InDataType>{-1.0, 1.0});
weight.GenerateTensorValue(GeneratorTensor_3<WeiDataType>{-0.05, 0.05});
weight_bias.GenerateTensorValue(GeneratorTensor_3<WeiDataType>{-1.0, 1.0});
}
ck::DeviceMem in_device_buf(sizeof(InDataType) * input.mDesc.GetElementSpaceSize());
ck::DeviceMem in_bias_device_buf(sizeof(InDataType) * input_bias.mDesc.GetElementSpaceSize());
ck::DeviceMem wei_device_buf(sizeof(WeiDataType) * weight.mDesc.GetElementSpaceSize());
ck::DeviceMem wei_bias_device_buf(sizeof(WeiDataType) *
weight_bias.mDesc.GetElementSpaceSize());
ck::DeviceMem out_device_buf(sizeof(OutDataType) * device_output.mDesc.GetElementSpaceSize());
in_device_buf.ToDevice(input.mData.data());
in_bias_device_buf.ToDevice(input_bias.mData.data());
wei_device_buf.ToDevice(weight.mData.data());
wei_bias_device_buf.ToDevice(weight_bias.mData.data());
// Run reference op
if(do_verification)
{
const std::array<ck::Tensor<InDataType>, NumAs - 1> elementwise_a_tensors = {input_bias};
const std::array<ck::Tensor<WeiDataType>, NumBs - 1> elementwise_b_tensors = {weight_bias};
auto ref_conv = ck::tensor_operation::host::ReferenceConvFwd<NDimSpatial,
InDataType,
WeiDataType,
OutDataType,
InElementOp,
WeiElementOp,
OutElementOp,
NumAs - 1,
NumBs - 1>();
auto ref_invoker = ref_conv.MakeInvoker();
auto ref_argument = ref_conv.MakeArgument(input,
weight,
host_output,
conv_param.conv_filter_strides_,
conv_param.conv_filter_dilations_,
conv_param.input_left_pads_,
conv_param.input_right_pads_,
in_element_op,
wei_element_op,
out_element_op,
elementwise_a_tensors,
elementwise_b_tensors);
// init host output to zero
host_output.SetZero();
ref_invoker.Run(ref_argument);
}
std::string best_op_name;
float best_avg_time = 0;
float best_tflops = 0;
float best_gb_per_sec = 0;
int valids = 0;
// profile device op instances
bool pass = true;
auto run_impl = [&](auto& op_ptr, auto& argument_ptr) {
// workspace_sz will be equal to 0 for other layout than NGCHW
// TODO: Is workspace even necessary?
const std::size_t workspace_sz = op_ptr->GetWorkSpaceSize(argument_ptr.get());
ck::DeviceMem workspace_dev(workspace_sz);
op_ptr->SetWorkSpacePointer(argument_ptr.get(), workspace_dev.GetDeviceBuffer());
if(op_ptr->IsSupportedArgument(argument_ptr.get()))
{
// re-init output to zero before profiling next kernel
out_device_buf.SetZero();
valids++;
std::string op_name = op_ptr->GetTypeString();
auto invoker_ptr = op_ptr->MakeInvokerPointer();
float avg_time =
invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel});
std::size_t flop = conv_param.GetFlops() +
2 * conv_param.GetOutputByte<InDataType>() / sizeof(InDataType) +
2 * conv_param.GetOutputByte<WeiDataType>() / sizeof(WeiDataType);
std::size_t num_btype = conv_param.GetByte<InDataType, WeiDataType, OutDataType>() +
conv_param.GetInputByte<InDataType>() +
conv_param.GetWeightByte<WeiDataType>();
float tflops = static_cast<float>(flop) / 1.E9 / avg_time;
float gb_per_sec = num_btype / 1.E6 / avg_time;
std::cout << "Perf: " << std::setw(10) << avg_time << " ms, " << tflops << " TFlops, "
<< gb_per_sec << " GB/s, " << op_name << std::endl;
if(tflops > best_tflops)
{
best_op_name = op_name;
best_tflops = tflops;
best_avg_time = avg_time;
best_gb_per_sec = gb_per_sec;
}
if(do_verification)
{
out_device_buf.FromDevice(device_output.mData.data());
pass = pass & ck::utils::check_err(device_output, host_output);
if(do_log)
{
printf("log\n");
// LogRangeAsType<float>(std::cout << "input : ", input.mData, ",") <<
// std::endl; LogRangeAsType<float>(std::cout << "input_bias: ",
// input_bias.mData, ",")
// << std::endl;
// LogRangeAsType<float>(std::cout << "weight: ", weight.mData, ",") <<
// std::endl; LogRangeAsType<float>(std::cout << "weight_bias: ",
// weight_bias.mData, ",")
// << std::endl;
// LogRangeAsType<float>(std::cout << "host_output : ", host_output.mData,
// ",")
// << std::endl;
// LogRangeAsType<float>(std::cout << "device_output: ",
// device_output.mData, ",")
// << std::endl;
}
}
}
else
{
std::cout << op_ptr->GetTypeString() << " does not support this problem" << std::endl;
}
};
// InDataType and WeiDataType must be tuple, inLayout and weiLayout are single.
using DeviceOp = ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD<
NDimSpatial,
InLayout,
WeiLayout,
ck::Tuple<>,
OutLayout,
ck::Tuple<InDataType, InDataType>,
ck::Tuple<WeiDataType, WeiDataType>,
ck::Tuple<>,
OutDataType,
InElementOp,
WeiElementOp,
OutElementOp>;
// get device op instances
const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory<
DeviceOp>::GetInstances();
std::cout << "ckProfiler found " << op_ptrs.size() << " instances" << std::endl;
std::array<const void*, NumAs> as{in_device_buf.GetDeviceBuffer(),
in_bias_device_buf.GetDeviceBuffer()};
std::array<const void*, NumBs> bs{wei_device_buf.GetDeviceBuffer(),
wei_bias_device_buf.GetDeviceBuffer()};
std::array<const void*, 0> ds{};
for(auto& op_ptr : op_ptrs)
{
auto argument_ptr = op_ptr->MakeArgumentPointer(as,
bs,
ds,
out_device_buf.GetDeviceBuffer(),
a_g_n_c_wis_lengths,
a_g_n_c_wis_strides,
b_g_k_c_xs_lengths,
b_g_k_c_xs_strides,
{},
{},
e_g_n_k_wos_lengths,
e_g_n_k_wos_strides,
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
input_right_pads,
in_element_op,
wei_element_op,
out_element_op);
run_impl(op_ptr, argument_ptr);
}
printf("\033[36mvalids: %d\n\033[0m", valids);
std::cout << "Best configuration parameters:" << "\nname: " << best_op_name
<< "\navg_time: " << best_avg_time << "\ntflops: " << best_tflops
<< "\nGB/s: " << best_gb_per_sec << std::endl;
return pass;
}
template <typename Tuple>
class TestGroupedConvndFwdScaleaddAB : public ::testing::Test
{
protected:
using InDataType = std::tuple_element_t<0, Tuple>;
using WeiDataType = std::tuple_element_t<1, Tuple>;
using OutDataType = std::tuple_element_t<2, Tuple>;
using InLayout = std::tuple_element_t<3, Tuple>;
using WeiLayout = std::tuple_element_t<4, Tuple>;
using OutLayout = std::tuple_element_t<5, Tuple>;
std::vector<ck::utils::conv::ConvParam> conv_params;
template <ck::index_t NDimSpatial>
void Run()
{
EXPECT_FALSE(conv_params.empty());
bool pass = true;
for(auto& param : conv_params)
{
pass = pass && profile_grouped_conv_fwd_scaleadd_ab_impl<NDimSpatial,
InLayout,
WeiLayout,
OutLayout,
InDataType,
WeiDataType,
OutDataType>(
true, // do_verification
1, // init_method: integer value
false, // do_log
false, // time_kernel
param);
}
EXPECT_TRUE(pass);
}
};
using namespace ck::tensor_layout::convolution;
// TODO: Not all possible layouts exist in the instance factory, (GNDHWC, GKZYXC, GNDHWK) only
// exists in example 62.
using KernelTypes3d = ::testing::Types<std::tuple<F16, F16, F16, NDHWGC, GKZYXC, NDHWGK>,
std::tuple<BF16, BF16, BF16, NDHWGC, GKZYXC, NDHWGK>,
std::tuple<F32, F32, F32, NDHWGC, GKZYXC, NDHWGK>,
std::tuple<I8, I8, I8, NDHWGC, GKZYXC, NDHWGK>>;
template <typename Tuple>
class TestGroupedConvndFwdScaleaddAB3d : public TestGroupedConvndFwdScaleaddAB<Tuple>
{
};
TYPED_TEST_SUITE(TestGroupedConvndFwdScaleaddAB3d, KernelTypes3d);
TYPED_TEST(TestGroupedConvndFwdScaleaddAB3d, Test3D)
{
this->conv_params.clear();
// Client example 24. This one takes quite long.
this->conv_params.push_back(
{3, 32, 64, 32, 64, {3, 3, 3}, {14, 14, 14}, {1, 1, 1}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}});
// Generic problems, same set as for vanilla, clamp, and (gk) bias clamp tests.
this->conv_params.push_back(
{3, 3, 5, 96, 200, {1, 1, 1}, {37, 37, 16}, {1, 1, 1}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}});
this->conv_params.push_back(
{3, 1, 1, 32, 32, {1, 1, 1}, {32, 32, 32}, {1, 1, 1}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}});
this->conv_params.push_back(
{3, 1, 1, 32, 32, {2, 2, 2}, {32, 32, 32}, {1, 1, 1}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}});
this->conv_params.push_back(
{3, 1, 1, 32, 32, {3, 3, 3}, {32, 32, 32}, {1, 1, 1}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}});
this->conv_params.push_back(
{3, 1, 1, 32, 32, {5, 5, 5}, {32, 32, 32}, {1, 1, 1}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}});
this->conv_params.push_back(
{3, 1, 1, 32, 32, {9, 9, 9}, {32, 32, 32}, {1, 1, 1}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}});
this->conv_params.push_back(
{3, 2, 32, 128, 256, {1, 1, 1}, {7, 7, 7}, {2, 2, 2}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}});
this->conv_params.push_back(
{3, 2, 32, 128, 256, {3, 3, 3}, {14, 14, 3}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}});
this->conv_params.push_back(
{3, 2, 32, 128, 256, {1, 1, 1}, {3, 3, 3}, {1, 1, 1}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}});
this->conv_params.push_back(
{3, 1, 1, 32, 32, {1, 1, 1}, {16, 16, 16}, {1, 1, 1}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}});
this->conv_params.push_back(
{3, 1, 1, 1, 32, {3, 3, 3}, {32, 32, 32}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}});
this->conv_params.push_back(
{3, 1, 1, 64, 3, {3, 3, 3}, {32, 32, 32}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}});
this->conv_params.push_back(
{3, 1, 1, 1, 1, {3, 3, 3}, {32, 32, 32}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}});
this->conv_params.push_back(
{3, 96, 1, 1, 1, {1, 1, 1}, {120, 40, 20}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}});
this->conv_params.push_back(
{3, 96, 1, 1, 1, {3, 3, 3}, {120, 40, 20}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}});
this->template Run<3>();
}

View File

@@ -19,6 +19,7 @@ if(GPU_TARGETS MATCHES "gfx9|gfx11|gfx12")
add_gtest_executable(test_grouped_convnd_fwd_clamp test_grouped_convnd_fwd_clamp.cpp)
target_link_libraries(test_grouped_convnd_fwd_clamp PRIVATE utility device_grouped_conv2d_fwd_clamp_instance device_grouped_conv3d_fwd_clamp_instance)
add_executable(test_grouped_convnd_fwd_bias_clamp_large_cases test_grouped_convnd_fwd_bias_clamp_large_cases.cpp)
target_compile_options(test_grouped_convnd_fwd_bias_clamp_large_cases PRIVATE -Wno-global-constructors -Wno-undef)
target_link_libraries(test_grouped_convnd_fwd_bias_clamp_large_cases PRIVATE gtest_main getopt::getopt utility device_grouped_conv2d_fwd_bias_clamp_instance device_grouped_conv3d_fwd_bias_clamp_instance)

View File

@@ -86,20 +86,76 @@ TYPED_TEST_SUITE(TestGroupedConvndFwd3d, KernelTypes3d);
TYPED_TEST(TestGroupedConvndFwd2d, Test2D)
{
this->conv_params.clear();
this->conv_params.push_back(
{2, 3, 5, 96, 200, {1, 1}, {73, 128}, {1, 1}, {1, 1}, {0, 0}, {0, 0}});
this->conv_params.push_back(
{2, 1, 1, 32, 32, {1, 1}, {128, 128}, {1, 1}, {1, 1}, {0, 0}, {0, 0}});
this->conv_params.push_back(
{2, 1, 1, 32, 32, {2, 2}, {128, 128}, {1, 1}, {1, 1}, {0, 0}, {0, 0}});
this->conv_params.push_back(
{2, 1, 1, 32, 32, {3, 3}, {128, 128}, {1, 1}, {1, 1}, {0, 0}, {0, 0}});
this->conv_params.push_back(
{2, 1, 1, 32, 32, {5, 5}, {128, 128}, {1, 1}, {1, 1}, {0, 0}, {0, 0}});
this->conv_params.push_back(
{2, 1, 1, 32, 32, {9, 9}, {128, 128}, {1, 1}, {1, 1}, {0, 0}, {0, 0}});
this->conv_params.push_back(
{2, 2, 32, 128, 256, {1, 1}, {7, 7}, {2, 2}, {1, 1}, {0, 0}, {0, 0}});
this->conv_params.push_back(
{2, 2, 32, 128, 256, {3, 3}, {14, 14}, {1, 1}, {1, 1}, {1, 1}, {1, 1}});
this->conv_params.push_back(
{2, 2, 32, 128, 256, {1, 1}, {3, 3}, {1, 1}, {1, 1}, {0, 0}, {0, 0}});
this->conv_params.push_back({2, 1, 1, 1, 32, {3, 3}, {32, 32}, {1, 1}, {1, 1}, {1, 1}, {1, 1}});
this->conv_params.push_back({2, 1, 1, 64, 3, {3, 3}, {32, 32}, {1, 1}, {1, 1}, {1, 1}, {1, 1}});
this->conv_params.push_back({2, 1, 1, 1, 1, {3, 3}, {32, 32}, {1, 1}, {1, 1}, {1, 1}, {1, 1}});
this->conv_params.push_back(
{2, 96, 1, 1, 1, {1, 1}, {120, 160}, {1, 1}, {1, 1}, {1, 1}, {1, 1}});
this->conv_params.push_back(
{2, 96, 1, 1, 1, {3, 3}, {120, 160}, {1, 1}, {1, 1}, {1, 1}, {1, 1}});
this->template Run<2>();
}
TYPED_TEST(TestGroupedConvndFwd3d, Test3D)
{
this->conv_params.clear();
this->conv_params.push_back(
{3, 3, 5, 96, 200, {1, 1, 1}, {37, 37, 16}, {1, 1, 1}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}});
this->conv_params.push_back(
{3, 1, 1, 32, 32, {1, 1, 1}, {32, 32, 32}, {1, 1, 1}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}});
this->conv_params.push_back(
{3, 1, 1, 32, 32, {2, 2, 2}, {32, 32, 32}, {1, 1, 1}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}});
this->conv_params.push_back(
{3, 1, 1, 32, 32, {3, 3, 3}, {32, 32, 32}, {1, 1, 1}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}});
this->conv_params.push_back(
{3, 1, 1, 32, 32, {5, 5, 5}, {32, 32, 32}, {1, 1, 1}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}});
this->conv_params.push_back(
{3, 1, 1, 32, 32, {9, 9, 9}, {32, 32, 32}, {1, 1, 1}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}});
this->conv_params.push_back(
{3, 2, 32, 128, 256, {1, 1, 1}, {7, 7, 7}, {2, 2, 2}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}});
this->conv_params.push_back(
{3, 2, 32, 128, 256, {3, 3, 3}, {14, 14, 3}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}});
this->conv_params.push_back(
{3, 2, 32, 128, 256, {1, 1, 1}, {3, 3, 3}, {1, 1, 1}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}});
this->conv_params.push_back(
{3, 1, 1, 32, 32, {1, 1, 1}, {16, 16, 16}, {1, 1, 1}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}});
this->conv_params.push_back(
{3, 1, 1, 1, 32, {3, 3, 3}, {32, 32, 32}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}});
this->conv_params.push_back(
{3, 1, 1, 64, 3, {3, 3, 3}, {32, 32, 32}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}});
this->conv_params.push_back(
{3, 1, 1, 1, 1, {3, 3, 3}, {32, 32, 32}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}});
this->conv_params.push_back(
{3, 96, 1, 1, 1, {1, 1, 1}, {120, 40, 20}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}});
this->conv_params.push_back(
{3, 96, 1, 1, 1, {3, 3, 3}, {120, 40, 20}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}});
this->template Run<3>();
}
int main(int argc, char** argv)

View File

@@ -88,20 +88,75 @@ TYPED_TEST_SUITE(TestGroupedConvndFwd3d, KernelTypes3d);
TYPED_TEST(TestGroupedConvndFwd2d, Test2D)
{
this->conv_params.clear();
this->conv_params.push_back(
{2, 3, 5, 96, 200, {1, 1}, {73, 128}, {1, 1}, {1, 1}, {0, 0}, {0, 0}});
this->conv_params.push_back(
{2, 1, 1, 32, 32, {1, 1}, {128, 128}, {1, 1}, {1, 1}, {0, 0}, {0, 0}});
this->conv_params.push_back(
{2, 1, 1, 32, 32, {2, 2}, {128, 128}, {1, 1}, {1, 1}, {0, 0}, {0, 0}});
this->conv_params.push_back(
{2, 1, 1, 32, 32, {3, 3}, {128, 128}, {1, 1}, {1, 1}, {0, 0}, {0, 0}});
this->conv_params.push_back(
{2, 1, 1, 32, 32, {5, 5}, {128, 128}, {1, 1}, {1, 1}, {0, 0}, {0, 0}});
this->conv_params.push_back(
{2, 1, 1, 32, 32, {9, 9}, {128, 128}, {1, 1}, {1, 1}, {0, 0}, {0, 0}});
this->conv_params.push_back(
{2, 2, 32, 128, 256, {1, 1}, {7, 7}, {2, 2}, {1, 1}, {0, 0}, {0, 0}});
this->conv_params.push_back(
{2, 2, 32, 128, 256, {3, 3}, {14, 14}, {1, 1}, {1, 1}, {1, 1}, {1, 1}});
this->conv_params.push_back(
{2, 2, 32, 128, 256, {1, 1}, {3, 3}, {1, 1}, {1, 1}, {0, 0}, {0, 0}});
this->conv_params.push_back({2, 1, 1, 1, 32, {3, 3}, {32, 32}, {1, 1}, {1, 1}, {1, 1}, {1, 1}});
this->conv_params.push_back({2, 1, 1, 64, 3, {3, 3}, {32, 32}, {1, 1}, {1, 1}, {1, 1}, {1, 1}});
this->conv_params.push_back({2, 1, 1, 1, 1, {3, 3}, {32, 32}, {1, 1}, {1, 1}, {1, 1}, {1, 1}});
this->conv_params.push_back(
{2, 96, 1, 1, 1, {1, 1}, {120, 160}, {1, 1}, {1, 1}, {1, 1}, {1, 1}});
this->conv_params.push_back(
{2, 96, 1, 1, 1, {3, 3}, {120, 160}, {1, 1}, {1, 1}, {1, 1}, {1, 1}});
this->template Run<2>();
}
TYPED_TEST(TestGroupedConvndFwd3d, Test3D)
{
this->conv_params.clear();
this->conv_params.push_back(
{3, 3, 5, 96, 200, {1, 1, 1}, {37, 37, 16}, {1, 1, 1}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}});
this->conv_params.push_back(
{3, 1, 1, 32, 32, {1, 1, 1}, {32, 32, 32}, {1, 1, 1}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}});
this->conv_params.push_back(
{3, 1, 1, 32, 32, {2, 2, 2}, {32, 32, 32}, {1, 1, 1}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}});
this->conv_params.push_back(
{3, 1, 1, 32, 32, {3, 3, 3}, {32, 32, 32}, {1, 1, 1}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}});
this->conv_params.push_back(
{3, 1, 1, 32, 32, {5, 5, 5}, {32, 32, 32}, {1, 1, 1}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}});
this->conv_params.push_back(
{3, 1, 1, 32, 32, {9, 9, 9}, {32, 32, 32}, {1, 1, 1}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}});
this->conv_params.push_back(
{3, 2, 32, 128, 256, {1, 1, 1}, {7, 7, 7}, {2, 2, 2}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}});
this->conv_params.push_back(
{3, 2, 32, 128, 256, {3, 3, 3}, {14, 14, 3}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}});
this->conv_params.push_back(
{3, 2, 32, 128, 256, {1, 1, 1}, {3, 3, 3}, {1, 1, 1}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}});
this->conv_params.push_back(
{3, 1, 1, 32, 32, {1, 1, 1}, {16, 16, 16}, {1, 1, 1}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}});
this->conv_params.push_back(
{3, 1, 1, 1, 32, {3, 3, 3}, {32, 32, 32}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}});
this->conv_params.push_back(
{3, 1, 1, 64, 3, {3, 3, 3}, {32, 32, 32}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}});
this->conv_params.push_back(
{3, 1, 1, 1, 1, {3, 3, 3}, {32, 32, 32}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}});
this->conv_params.push_back(
{3, 96, 1, 1, 1, {1, 1, 1}, {120, 40, 20}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}});
this->conv_params.push_back(
{3, 96, 1, 1, 1, {3, 3, 3}, {120, 40, 20}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}});
this->template Run<3>();
}
int main(int argc, char** argv)

View File

@@ -81,20 +81,75 @@ TYPED_TEST_SUITE(TestGroupedConvndFwd3d, KernelTypes3d);
TYPED_TEST(TestGroupedConvndFwd2d, Test2D)
{
this->conv_params.clear();
this->conv_params.push_back(
{2, 3, 5, 96, 200, {1, 1}, {73, 128}, {1, 1}, {1, 1}, {0, 0}, {0, 0}});
this->conv_params.push_back(
{2, 1, 1, 32, 32, {1, 1}, {128, 128}, {1, 1}, {1, 1}, {0, 0}, {0, 0}});
this->conv_params.push_back(
{2, 1, 1, 32, 32, {2, 2}, {128, 128}, {1, 1}, {1, 1}, {0, 0}, {0, 0}});
this->conv_params.push_back(
{2, 1, 1, 32, 32, {3, 3}, {128, 128}, {1, 1}, {1, 1}, {0, 0}, {0, 0}});
this->conv_params.push_back(
{2, 1, 1, 32, 32, {5, 5}, {128, 128}, {1, 1}, {1, 1}, {0, 0}, {0, 0}});
this->conv_params.push_back(
{2, 1, 1, 32, 32, {9, 9}, {128, 128}, {1, 1}, {1, 1}, {0, 0}, {0, 0}});
this->conv_params.push_back(
{2, 2, 32, 128, 256, {1, 1}, {7, 7}, {2, 2}, {1, 1}, {0, 0}, {0, 0}});
this->conv_params.push_back(
{2, 2, 32, 128, 256, {3, 3}, {14, 14}, {1, 1}, {1, 1}, {1, 1}, {1, 1}});
this->conv_params.push_back(
{2, 2, 32, 128, 256, {1, 1}, {3, 3}, {1, 1}, {1, 1}, {0, 0}, {0, 0}});
this->conv_params.push_back({2, 1, 1, 1, 32, {3, 3}, {32, 32}, {1, 1}, {1, 1}, {1, 1}, {1, 1}});
this->conv_params.push_back({2, 1, 1, 64, 3, {3, 3}, {32, 32}, {1, 1}, {1, 1}, {1, 1}, {1, 1}});
this->conv_params.push_back({2, 1, 1, 1, 1, {3, 3}, {32, 32}, {1, 1}, {1, 1}, {1, 1}, {1, 1}});
this->conv_params.push_back(
{2, 96, 1, 1, 1, {1, 1}, {120, 160}, {1, 1}, {1, 1}, {1, 1}, {1, 1}});
this->conv_params.push_back(
{2, 96, 1, 1, 1, {3, 3}, {120, 160}, {1, 1}, {1, 1}, {1, 1}, {1, 1}});
this->template Run<2>();
}
TYPED_TEST(TestGroupedConvndFwd3d, Test3D)
{
this->conv_params.clear();
this->conv_params.push_back(
{3, 3, 5, 96, 200, {1, 1, 1}, {37, 37, 16}, {1, 1, 1}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}});
this->conv_params.push_back(
{3, 1, 1, 32, 32, {1, 1, 1}, {32, 32, 32}, {1, 1, 1}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}});
this->conv_params.push_back(
{3, 1, 1, 32, 32, {2, 2, 2}, {32, 32, 32}, {1, 1, 1}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}});
this->conv_params.push_back(
{3, 1, 1, 32, 32, {3, 3, 3}, {32, 32, 32}, {1, 1, 1}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}});
this->conv_params.push_back(
{3, 1, 1, 32, 32, {5, 5, 5}, {32, 32, 32}, {1, 1, 1}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}});
this->conv_params.push_back(
{3, 1, 1, 32, 32, {9, 9, 9}, {32, 32, 32}, {1, 1, 1}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}});
this->conv_params.push_back(
{3, 2, 32, 128, 256, {1, 1, 1}, {7, 7, 7}, {2, 2, 2}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}});
this->conv_params.push_back(
{3, 2, 32, 128, 256, {3, 3, 3}, {14, 14, 3}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}});
this->conv_params.push_back(
{3, 2, 32, 128, 256, {1, 1, 1}, {3, 3, 3}, {1, 1, 1}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}});
this->conv_params.push_back(
{3, 1, 1, 32, 32, {1, 1, 1}, {16, 16, 16}, {1, 1, 1}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}});
this->conv_params.push_back(
{3, 1, 1, 1, 32, {3, 3, 3}, {32, 32, 32}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}});
this->conv_params.push_back(
{3, 1, 1, 64, 3, {3, 3, 3}, {32, 32, 32}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}});
this->conv_params.push_back(
{3, 1, 1, 1, 1, {3, 3, 3}, {32, 32, 32}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}});
this->conv_params.push_back(
{3, 96, 1, 1, 1, {1, 1, 1}, {120, 40, 20}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}});
this->conv_params.push_back(
{3, 96, 1, 1, 1, {3, 3, 3}, {120, 40, 20}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}});
this->template Run<3>();
}
int main(int argc, char** argv)