From a71a7b2d834cd47af51fa940bf9d3e522e74b6c6 Mon Sep 17 00:00:00 2001 From: ApoorvaKalyani Date: Tue, 30 Dec 2025 16:25:08 +0100 Subject: [PATCH] Grouped convolution backward data WMMA v3 implementation (#3460) * Added device level implementation for bwd_data_wmma_v3. * Added first instance of bwd_data_wmma_v3(f16). * Add support for bwd data in gridwise implementation Some changes are general for convolution and some are specific for bwd data. We need to generalize them once we have fwd, bwd data and bwd weight * Initial device implementation of bwd data * Remove unused template parameters in device impl * Add one instance for different layout initial check of device implementation * Add tests for splitk and for different layouts * Appended more instances to wmma_v3_f16. * Added conv_2d bf16 wmma_v3 instances. * Added conv_3d_bf16 wmma_v3_instances. * Added conv_3d_f16_wmma_v3_instances. * Added SplitN test cases for wmma. * Conv3d_bwd_data_scale_wmma_v3 instances. * Conv3d_bwd_data_bilinear_wmma_v3_instances * Renaming the device level instances file to common name , since it is defined for different DataTypes. * Renaming the instances and fixing typo * Added the test cases to regression test list * NCHW support for wmma_v3 * Examples for bf16 and f16 bwd_data_wmma_v3 * Added transpose conditons for device impl * fixing bugs * Added the gemm_args array implmentation * WIP debug conv bwd * fix splitk * Grouped gemm fix * Update CmakeLists with EOF * Added more instances for tests * Fixed the run time error in examples and removed 3d conv examples. * Fixed a typo. * Updated CmakeLists to removed the 3d convultion deleted files * Added print error statements for unsupoorted argument * Added the merge conflict related changes * Fixed compilation error * Fixed the InstanceFactory duplication error. * Removed the print statements and added logs to Arg function * All the merge conflict related errors resolved * Added d_tensor tests. * Added the missing example types of wmm_v3 * Merge error fix * Corrected the instance name * Reverted the bias relu change * Revereted the transpose load local change * Updated the regression test list with bwd_data_scale * Revert "Revereted the transpose load local change" This reverts commit 0b7281edb2bf008e407006690a00621174d9d19b. * Revert "Merge error fix" This reverts commit f3c85daa474b1b83d10c8a3ce077354e71d91a2b. * Reverting the local change * Added merge error fix * Build error fix due to merge conflicts * Added bias_relu example for wmma_v3 * Modified the main method in dtensor tests * Updated the dtensor tests to pick all the shapes * Updated the dtensor test shapes. * Updated the mem operations in tests. * Added reference func * Fixed typos in device impl * Added new header file and modified the include file for 3d tests * Renamed the test file and added reference func call. * clang format fix * Added ignore params * Modified device impl and tests * Removed debug print statements and updated dtensor test shapes * Fixing merge conflicts * Fixing more merge conflicts * Fixed copyrights * Updated the tuned instances to bilinear and scale. * Adding tuned instances to vanilla wmma_v3 * Removed all unused instances and modified test layouts. * Cleaned up all instances , reverted back fwd fp16 instances and updated tuned fp16 instances. * Fix clang format * Updated tuned f16/-genric instances * Formatting the instances file * Fixed copyrights and clang issues * Nonsense commit to force git to force * Removed the transpose instances * Added verified genric instances * Fixing namespace errors * Added todo for failing shapes * Formatting instance file * Fix instance list formatting * Removing unnecessary formats * Renamed the common file * Unification of xdl and wmma bwd_data tests * Updated Cmake * Added all layout types and deleted code. * Updated Cmake to add the condition to all tests. --------- Co-authored-by: Enrico Degregori Co-authored-by: Anton Gorenko Co-authored-by: kiefer [ROCm/composable_kernel commit: 53a1e4f551cd0a063ec722ffd19f79a67a9ef320] --- .../CMakeLists.txt | 21 + .../common.hpp | 4 + .../common_conv3d.hpp | 116 + .../grouped_conv3d_bwd_data_wmma_v3_bf16.cpp | 31 + .../grouped_conv3d_bwd_data_wmma_v3_fp16.cpp | 30 + ...d_conv_bwd_data_bias_relu_wmma_v3_fp16.cpp | 34 + .../grouped_conv_bwd_data_wmma_v3_bf16.cpp | 34 + .../grouped_conv_bwd_data_wmma_v3_fp16.cpp | 35 + ...onv_bwd_data_wmma_v3_fp16_comp_bf8_fp8.cpp | 47 + .../run_grouped_conv3d_bwd_data_example.inc | 192 ++ ...v_bwd_data_multiple_d_wmma_cshuffle_v3.hpp | 1994 +++++++++++++++++ .../grid/gridwise_gemm_wmma_cshuffle_v3.hpp | 141 ++ ...onv_bwd_data_wmma_v3_bilinear_instance.hpp | 100 + ...rouped_conv_bwd_data_wmma_v3_instances.hpp | 125 ++ ...d_conv_bwd_data_wmma_v3_scale_instance.hpp | 102 + .../gpu/grouped_convolution_backward_data.hpp | 33 +- ...ped_convolution_backward_data_bilinear.hpp | 71 + ...rouped_convolution_backward_data_scale.hpp | 67 + ...grouped_convolution_backward_data_wmma.inc | 122 +- .../grouped_conv2d_bwd_data/CMakeLists.txt | 12 +- ..._nhwgc_gkyxc_nhwgk_bf16_16_16_instance.cpp | 49 + ...mma_v3_nhwgc_gkyxc_nhwgk_bf16_instance.cpp | 49 + ...3_nhwgc_gkyxc_nhwgk_f16_16_16_instance.cpp | 49 + ...wmma_v3_nhwgc_gkyxc_nhwgk_f16_instance.cpp | 49 + .../grouped_conv3d_bwd_data/CMakeLists.txt | 10 +- ...hwgc_gkzyxc_ndhwgk_bf16_16_16_instance.cpp | 48 + ..._v3_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp | 48 + ...dhwgc_gkzyxc_ndhwgk_f16_16_16_instance.cpp | 48 + ...a_v3_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp | 48 + .../CMakeLists.txt | 7 +- ...ear_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp | 49 + ...near_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp | 49 + .../CMakeLists.txt | 11 +- ...ale_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp | 49 + ...cale_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp | 49 + test/CMakeLists.txt | 3 + test/grouped_convnd_bwd_data/CMakeLists.txt | 27 +- .../test_grouped_conv_bwd_data_bilinear.cpp | 324 +++ .../test_grouped_conv_bwd_data_scale.cpp | 324 +++ ...l.cpp => test_grouped_convnd_bwd_data.cpp} | 14 +- ...t_grouped_convnd_bwd_data_large_cases.cpp} | 14 +- .../test_grouped_convnd_bwd_data_wmma.cpp | 135 -- 42 files changed, 4593 insertions(+), 171 deletions(-) create mode 100644 example/38_grouped_conv_bwd_data_multiple_d/common_conv3d.hpp create mode 100644 example/38_grouped_conv_bwd_data_multiple_d/grouped_conv3d_bwd_data_wmma_v3_bf16.cpp create mode 100644 example/38_grouped_conv_bwd_data_multiple_d/grouped_conv3d_bwd_data_wmma_v3_fp16.cpp create mode 100644 example/38_grouped_conv_bwd_data_multiple_d/grouped_conv_bwd_data_bias_relu_wmma_v3_fp16.cpp create mode 100644 example/38_grouped_conv_bwd_data_multiple_d/grouped_conv_bwd_data_wmma_v3_bf16.cpp create mode 100644 example/38_grouped_conv_bwd_data_multiple_d/grouped_conv_bwd_data_wmma_v3_fp16.cpp create mode 100644 example/38_grouped_conv_bwd_data_multiple_d/grouped_conv_bwd_data_wmma_v3_fp16_comp_bf8_fp8.cpp create mode 100644 example/38_grouped_conv_bwd_data_multiple_d/run_grouped_conv3d_bwd_data_example.inc create mode 100644 include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle_v3.hpp create mode 100644 library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_wmma_v3_bilinear_instance.hpp create mode 100644 library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_wmma_v3_instances.hpp create mode 100644 library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_wmma_v3_scale_instance.hpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/wmma/device_grouped_conv2d_bwd_data_wmma_v3_nhwgc_gkyxc_nhwgk_bf16_16_16_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/wmma/device_grouped_conv2d_bwd_data_wmma_v3_nhwgc_gkyxc_nhwgk_bf16_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/wmma/device_grouped_conv2d_bwd_data_wmma_v3_nhwgc_gkyxc_nhwgk_f16_16_16_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/wmma/device_grouped_conv2d_bwd_data_wmma_v3_nhwgc_gkyxc_nhwgk_f16_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/wmma/device_grouped_conv3d_bwd_data_wmma_v3_ndhwgc_gkzyxc_ndhwgk_bf16_16_16_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/wmma/device_grouped_conv3d_bwd_data_wmma_v3_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/wmma/device_grouped_conv3d_bwd_data_wmma_v3_ndhwgc_gkzyxc_ndhwgk_f16_16_16_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/wmma/device_grouped_conv3d_bwd_data_wmma_v3_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data_bilinear/wmma/device_grouped_conv3d_bwd_data_wmma_v3_bilinear_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data_bilinear/wmma/device_grouped_conv3d_bwd_data_wmma_v3_bilinear_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data_scale/wmma/device_grouped_conv3d_bwd_data_wmma_v3_scale_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data_scale/wmma/device_grouped_conv3d_bwd_data_wmma_v3_scale_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp create mode 100644 test/grouped_convnd_bwd_data/test_grouped_conv_bwd_data_bilinear.cpp create mode 100644 test/grouped_convnd_bwd_data/test_grouped_conv_bwd_data_scale.cpp rename test/grouped_convnd_bwd_data/{test_grouped_convnd_bwd_data_xdl.cpp => test_grouped_convnd_bwd_data.cpp} (94%) rename test/grouped_convnd_bwd_data/{test_grouped_convnd_bwd_data_xdl_large_cases.cpp => test_grouped_convnd_bwd_data_large_cases.cpp} (91%) delete mode 100644 test/grouped_convnd_bwd_data/test_grouped_convnd_bwd_data_wmma.cpp diff --git a/example/38_grouped_conv_bwd_data_multiple_d/CMakeLists.txt b/example/38_grouped_conv_bwd_data_multiple_d/CMakeLists.txt index b58bd7cb3a..a50bac177e 100644 --- a/example/38_grouped_conv_bwd_data_multiple_d/CMakeLists.txt +++ b/example/38_grouped_conv_bwd_data_multiple_d/CMakeLists.txt @@ -9,8 +9,29 @@ add_example_dependencies(example_grouped_conv_bwd_data example_grouped_conv_bwd_ add_example_executable(example_grouped_conv_bwd_data_xdl_fp16_comp_bf8_fp8 grouped_conv_bwd_data_xdl_fp16_comp_bf8_fp8.cpp) add_example_dependencies(example_grouped_conv_bwd_data example_grouped_conv_bwd_data_xdl_fp16_comp_bf8_fp8) +add_example_executable(example_grouped_conv_bwd_data_wmma_v3_fp16_comp_bf8_fp8 grouped_conv_bwd_data_wmma_v3_fp16_comp_bf8_fp8.cpp) +add_example_dependencies(example_grouped_conv_bwd_data example_grouped_conv_bwd_data_wmma_v3_fp16_comp_bf8_fp8) + add_example_executable(example_grouped_conv_bwd_data_bias_relu_xdl_fp16 grouped_conv_bwd_data_bias_relu_xdl_fp16.cpp) add_example_dependencies(example_grouped_conv_bwd_data example_grouped_conv_bwd_data_bias_relu_xdl_fp16) +add_example_executable(example_grouped_conv_bwd_data_bias_relu_wmma_v3_fp16 grouped_conv_bwd_data_bias_relu_wmma_v3_fp16.cpp) +add_example_dependencies(example_grouped_conv_bwd_data example_grouped_conv_bwd_data_bias_relu_wmma_v3_fp16) + add_example_executable(example_grouped_conv_bwd_data_wmma_fp16 grouped_conv_bwd_data_wmma_fp16.cpp) add_example_dependencies(example_grouped_conv_bwd_data example_grouped_conv_bwd_data_wmma_fp16) + +add_example_executable(example_grouped_conv_bwd_data_wmma_v3_bf16 grouped_conv_bwd_data_wmma_v3_bf16.cpp) +add_example_dependencies(example_grouped_conv_bwd_data example_grouped_conv_bwd_data_wmma_v3_bf16) + +add_example_executable(example_grouped_conv3d_bwd_data_wmma_v3_bf16 grouped_conv3d_bwd_data_wmma_v3_bf16.cpp) +add_example_dependencies(example_grouped_conv_bwd_data example_grouped_conv3d_bwd_data_wmma_v3_bf16) + +add_example_executable(example_grouped_conv3d_bwd_data_wmma_v3_fp16 grouped_conv3d_bwd_data_wmma_v3_fp16.cpp) +add_example_dependencies(example_grouped_conv_bwd_data example_grouped_conv3d_bwd_data_wmma_v3_fp16) + +add_example_executable(example_grouped_conv_bwd_data_wmma_v3_fp16 grouped_conv_bwd_data_wmma_v3_fp16.cpp) +add_example_dependencies(example_grouped_conv_bwd_data example_grouped_conv_bwd_data_wmma_v3_fp16) + + + diff --git a/example/38_grouped_conv_bwd_data_multiple_d/common.hpp b/example/38_grouped_conv_bwd_data_multiple_d/common.hpp index 074718c901..a00150ba6c 100644 --- a/example/38_grouped_conv_bwd_data_multiple_d/common.hpp +++ b/example/38_grouped_conv_bwd_data_multiple_d/common.hpp @@ -37,7 +37,11 @@ static inline constexpr ck::index_t NDimSpatial = 2; static constexpr auto ConvBwdDataDefault = ck::tensor_operation::device::ConvolutionBackwardDataSpecialization::Default; +static constexpr auto ConvBwdDataFilter1x1Stride1Pad0 = + ck::tensor_operation::device::ConvolutionBackwardDataSpecialization::Filter1x1Stride1Pad0; + using FP16 = ck::half_t; +using BF16 = ck::bhalf_t; using FP32 = float; using FP8 = ck::f8_t; using BF8 = ck::bf8_t; diff --git a/example/38_grouped_conv_bwd_data_multiple_d/common_conv3d.hpp b/example/38_grouped_conv_bwd_data_multiple_d/common_conv3d.hpp new file mode 100644 index 0000000000..0e3ee14253 --- /dev/null +++ b/example/38_grouped_conv_bwd_data_multiple_d/common_conv3d.hpp @@ -0,0 +1,116 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/convolution_backward_data_specialization.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/reference_tensor_operation/cpu/reference_conv_bwd_data.hpp" +#include "ck/library/utility/algorithm.hpp" +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/convolution_host_tensor_descriptor_helper.hpp" +#include "ck/library/utility/convolution_parameter.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" + +using ::ck::DeviceMem; +using ::ck::hip_check_error; +using ::ck::HostTensorDescriptor; +using ::ck::Tensor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static inline constexpr ck::index_t NDimSpatial = 3; + +static constexpr auto ConvBwdDataDefault = + ck::tensor_operation::device::ConvolutionBackwardDataSpecialization::Default; + +static constexpr auto ConvBwdDataFilter1x1Stride1Pad0 = + ck::tensor_operation::device::ConvolutionBackwardDataSpecialization::Filter1x1Stride1Pad0; + +using FP16 = ck::half_t; +using BF16 = ck::bhalf_t; +using FP32 = float; +using FP8 = ck::f8_t; +using BF8 = ck::bf8_t; + +struct ExecutionConfig final +{ + bool do_verification = true; + int init_method = 1; + bool time_kernel = false; +}; + +#define DefaultConvParams \ + ck::utils::conv::ConvParam \ + { \ + NDimSpatial, 32, 4, 192, 192, {3, 3, 3}, {28, 28, 28}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, \ + { \ + 1, 1, 1 \ + } \ + } + +inline void print_help_msg() +{ + std::cerr << "arg1: verification (0=no, 1=yes)\n" + << "arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n" + << "arg3: time kernel (0=no, 1=yes)\n" + << ck::utils::conv::get_conv_param_parser_helper_msg() << std::endl; +} + +inline bool parse_cmd_args(int argc, + char* argv[], + ExecutionConfig& config, + ck::utils::conv::ConvParam& conv_params) +{ + constexpr int num_execution_config_args = + 3; // arguments for do_verification, init_method, time_kernel + constexpr int num_conv_param_leading_args = 5; // arguments for num_dim_spatial_, G_, N_, K_, C_ + + constexpr int threshold_to_catch_partial_args = 1 + num_execution_config_args; + constexpr int threshold_to_catch_all_args = + threshold_to_catch_partial_args + num_conv_param_leading_args; + + if(argc == 1) + { + // use default + config = ExecutionConfig{}; + } + // catch only ExecutionConfig arguments + else if(argc == threshold_to_catch_partial_args) + { + config.do_verification = std::stoi(argv[1]); + config.init_method = std::stoi(argv[2]); + config.time_kernel = std::stoi(argv[3]); + } + // catch both ExecutionConfig & ConvParam arguments + else if(threshold_to_catch_all_args < argc && ((argc - threshold_to_catch_all_args) % 3 == 0)) + { + config.do_verification = std::stoi(argv[1]); + config.init_method = std::stoi(argv[2]); + config.time_kernel = std::stoi(argv[3]); + + const ck::index_t num_dim_spatial = std::stoi(argv[4]); + conv_params = ck::utils::conv::parse_conv_param( + num_dim_spatial, threshold_to_catch_partial_args + 1, argv); + } + else + { + print_help_msg(); + return false; + } + + return true; +} diff --git a/example/38_grouped_conv_bwd_data_multiple_d/grouped_conv3d_bwd_data_wmma_v3_bf16.cpp b/example/38_grouped_conv_bwd_data_multiple_d/grouped_conv3d_bwd_data_wmma_v3_bf16.cpp new file mode 100644 index 0000000000..e35b006c90 --- /dev/null +++ b/example/38_grouped_conv_bwd_data_multiple_d/grouped_conv3d_bwd_data_wmma_v3_bf16.cpp @@ -0,0 +1,31 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle_v3.hpp" +#include "common_conv3d.hpp" + +using OutDataType = BF16; +using WeiDataType = BF16; +using AccDataType = FP32; +using CShuffleDataType = BF16; +using DsDataType = ck::Tuple<>; +using InDataType = BF16; + +using InLayout = ck::tensor_layout::convolution::NDHWGC; +using WeiLayout = ck::tensor_layout::convolution::GKZYXC; +using DsLayout = ck::Tuple<>; +using OutLayout = ck::tensor_layout::convolution::NDHWGK; +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +// clang-format off +using DeviceConvInstance = ck::tensor_operation::device::DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffleV3 +// ######| NDimSpatial| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| AElementwise| BElementwise| CDEElementwise| ConvolutionBackward| DoPad| DoPad| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle | CShuffle | CDEBlockTransfer| CDEBlockTransfer| +// ######| | | | | | Type| Type| Type| DataType| Type| Type| Operation| Operation| Operation| DataSpecialization| GemmM| GemmN| Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MRepeat | NRepeat | _MBlock_MPerBlock| ScalarPerVector| +// ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| +// ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + < NDimSpatial, OutLayout, WeiLayout, DsLayout, InLayout, OutDataType, WeiDataType, AccDataType, CShuffleDataType, DsDataType, InDataType, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, true, true, 128, 64, 64, 32, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, S<8,8,8>>; +// clang-format on + +#include "run_grouped_conv3d_bwd_data_example.inc" + +int main(int argc, char* argv[]) { return run_grouped_conv_bwd_data_example(argc, argv); } diff --git a/example/38_grouped_conv_bwd_data_multiple_d/grouped_conv3d_bwd_data_wmma_v3_fp16.cpp b/example/38_grouped_conv_bwd_data_multiple_d/grouped_conv3d_bwd_data_wmma_v3_fp16.cpp new file mode 100644 index 0000000000..2d9da2aab4 --- /dev/null +++ b/example/38_grouped_conv_bwd_data_multiple_d/grouped_conv3d_bwd_data_wmma_v3_fp16.cpp @@ -0,0 +1,30 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle_v3.hpp" +#include "common_conv3d.hpp" +using OutDataType = FP16; +using WeiDataType = FP16; +using AccDataType = FP32; +using CShuffleDataType = FP16; +using DsDataType = ck::Tuple<>; +using InDataType = FP16; + +using InLayout = ck::tensor_layout::convolution::NDHWGC; +using WeiLayout = ck::tensor_layout::convolution::GKZYXC; +using DsLayout = ck::Tuple<>; +using OutLayout = ck::tensor_layout::convolution::NDHWGK; +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +// clang-format off +using DeviceConvInstance = ck::tensor_operation::device::DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffleV3 +// ######| NDimSpatial| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| AElementwise| BElementwise| CDEElementwise| ConvolutionBackward| DoPad| DoPad| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle | CShuffle | CDEBlockTransfer| CDEBlockTransfer| +// ######| | | | | | Type| Type| Type| DataType| Type| Type| Operation| Operation| Operation| DataSpecialization| GemmM| GemmN| Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MRepeat | NRepeat | _MBlock_MPerBlock| ScalarPerVector| +// ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| +// ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + < NDimSpatial, OutLayout, WeiLayout, DsLayout, InLayout, OutDataType, WeiDataType, AccDataType, CShuffleDataType, DsDataType, InDataType, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, true, true, 128, 64, 64, 32, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, S<8,8,8>>; +// clang-format on + +#include "run_grouped_conv3d_bwd_data_example.inc" + +int main(int argc, char* argv[]) { return run_grouped_conv_bwd_data_example(argc, argv); } diff --git a/example/38_grouped_conv_bwd_data_multiple_d/grouped_conv_bwd_data_bias_relu_wmma_v3_fp16.cpp b/example/38_grouped_conv_bwd_data_multiple_d/grouped_conv_bwd_data_bias_relu_wmma_v3_fp16.cpp new file mode 100644 index 0000000000..2a90d5a143 --- /dev/null +++ b/example/38_grouped_conv_bwd_data_multiple_d/grouped_conv_bwd_data_bias_relu_wmma_v3_fp16.cpp @@ -0,0 +1,34 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle_v3.hpp" +#include "common.hpp" + +using OutDataType = FP16; +using WeiDataType = FP16; +using AccDataType = FP32; +using CShuffleDataType = FP16; +using BiasDataType = FP16; // bias +using InDataType = FP16; + +using OutLayout = ck::tensor_layout::convolution::GNHWK; +using WeiLayout = ck::tensor_layout::convolution::GKYXC; +using BiasLayout = ck::Tuple; +using InLayout = ck::tensor_layout::convolution::GNHWC; + +using OutElementOp = PassThrough; +using WeiElementOp = PassThrough; +using InElementOp = ck::tensor_operation::element_wise::AddRelu; + +// clang-format off +using DeviceConvInstance = ck::tensor_operation::device::DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffleV3 +// ######| NDimSpatial| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| AElementwise| BElementwise| CDEElementwise| ConvolutionBackward| DoPad| DoPad| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffleMXdl| CShuffleNXdl| CDEBlockTransfer| CDEBlockTransfer| +// ######| | | | | | Type| Type| Type| DataType| Type| Type| Operation| Operation| Operation| DataSpecialization| GemmM| GemmN| PrefetchStage| Size| Block| Block| Block| | | XDL| XDL| PerWave| PerWave| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| PerWave| PerWave| _MBlock_MPerBlock| ScalarPerVector| +// ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| +// ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + < NDimSpatial, OutLayout, WeiLayout, BiasLayout, InLayout, OutDataType, WeiDataType, AccDataType, CShuffleDataType, ck::Tuple, InDataType, OutElementOp, WeiElementOp, InElementOp, ConvBwdDataDefault, true, true, 64, 64, 64, 32, 8, 8, 16, 16, 4, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, S<8,8,8>>; +// clang-format on + +#include "run_grouped_conv_bwd_data_bias_relu_example.inc" + +int main(int argc, char* argv[]) { return run_grouped_conv_bwd_data_bias_relu_example(argc, argv); } diff --git a/example/38_grouped_conv_bwd_data_multiple_d/grouped_conv_bwd_data_wmma_v3_bf16.cpp b/example/38_grouped_conv_bwd_data_multiple_d/grouped_conv_bwd_data_wmma_v3_bf16.cpp new file mode 100644 index 0000000000..49e1662332 --- /dev/null +++ b/example/38_grouped_conv_bwd_data_multiple_d/grouped_conv_bwd_data_wmma_v3_bf16.cpp @@ -0,0 +1,34 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle_v3.hpp" +#include "common.hpp" + +using OutDataType = BF16; +using WeiDataType = BF16; +using AccDataType = FP32; +using CShuffleDataType = BF16; +using DsDataType = ck::Tuple<>; +using InDataType = BF16; + +using OutLayout = ck::tensor_layout::convolution::GNHWK; +using WeiLayout = ck::tensor_layout::convolution::GKYXC; +using DsLayout = ck::Tuple<>; +using InLayout = ck::tensor_layout::convolution::GNHWC; + +using OutElementOp = PassThrough; +using WeiElementOp = PassThrough; +using InElementOp = PassThrough; + +// clang-format off +using DeviceConvInstance = ck::tensor_operation::device::DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffleV3 +// ######| NDimSpatial| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| AElementwise| BElementwise| CDEElementwise| ConvolutionBackward| DoPad| DoPad| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle | CShuffle | CDEBlockTransfer| CDEBlockTransfer| +// ######| | | | | | Type| Type| Type| DataType| Type| Type| Operation| Operation| Operation| DataSpecialization| GemmM| GemmN| Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MRepeat | NRepeat | _MBlock_MPerBlock| ScalarPerVector| +// ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| +// ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + < NDimSpatial, OutLayout, WeiLayout, DsLayout, InLayout, OutDataType, WeiDataType, AccDataType, CShuffleDataType, DsDataType, InDataType, OutElementOp, WeiElementOp, InElementOp, ConvBwdDataDefault, true, true, 128, 64, 64, 32, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, S<8,8,8>>; +// clang-format on + +#include "run_grouped_conv_bwd_data_example.inc" + +int main(int argc, char* argv[]) { return run_grouped_conv_bwd_data_example(argc, argv); } diff --git a/example/38_grouped_conv_bwd_data_multiple_d/grouped_conv_bwd_data_wmma_v3_fp16.cpp b/example/38_grouped_conv_bwd_data_multiple_d/grouped_conv_bwd_data_wmma_v3_fp16.cpp new file mode 100644 index 0000000000..46f71c00b5 --- /dev/null +++ b/example/38_grouped_conv_bwd_data_multiple_d/grouped_conv_bwd_data_wmma_v3_fp16.cpp @@ -0,0 +1,35 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle_v3.hpp" +#include "common.hpp" + +using OutDataType = FP16; +using WeiDataType = FP16; +using AccDataType = FP32; +using CShuffleDataType = FP16; +using DsDataType = ck::Tuple<>; +using InDataType = FP16; + +using OutLayout = ck::tensor_layout::convolution::GNHWK; +using WeiLayout = ck::tensor_layout::convolution::GKYXC; +using DsLayout = ck::Tuple<>; +using InLayout = ck::tensor_layout::convolution::GNHWC; + +using OutElementOp = PassThrough; +using WeiElementOp = PassThrough; +using InElementOp = PassThrough; + +// clang-format off +using DeviceConvInstance = ck::tensor_operation::device::DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffleV3 +// ######| NDimSpatial| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| AElementwise| BElementwise| CDEElementwise| ConvolutionBackward| DoPad| DoPad| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle | CShuffle | CDEBlockTransfer| CDEBlockTransfer| +// ######| | | | | | Type| Type| Type| DataType| Type| Type| Operation| Operation| Operation| DataSpecialization| GemmM| GemmN| Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MRepeat | NRepeat | _MBlock_MPerBlock| ScalarPerVector| +// ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| +// ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + < NDimSpatial, OutLayout, WeiLayout, DsLayout, InLayout, OutDataType, WeiDataType, AccDataType, CShuffleDataType, DsDataType, InDataType, OutElementOp, WeiElementOp, InElementOp, ConvBwdDataDefault, true, true, 64, 64, 64, 32, 8, 8, 16, 16, 4, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, S<8,8,8>>; + +// clang-format on + +#include "run_grouped_conv_bwd_data_example.inc" + +int main(int argc, char* argv[]) { return run_grouped_conv_bwd_data_example(argc, argv); } diff --git a/example/38_grouped_conv_bwd_data_multiple_d/grouped_conv_bwd_data_wmma_v3_fp16_comp_bf8_fp8.cpp b/example/38_grouped_conv_bwd_data_multiple_d/grouped_conv_bwd_data_wmma_v3_fp16_comp_bf8_fp8.cpp new file mode 100644 index 0000000000..3c49710416 --- /dev/null +++ b/example/38_grouped_conv_bwd_data_multiple_d/grouped_conv_bwd_data_wmma_v3_fp16_comp_bf8_fp8.cpp @@ -0,0 +1,47 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle_v3.hpp" +#include "common.hpp" + +using OutDataType = FP16; +using WeiDataType = FP16; +using AccDataType = FP32; +using CShuffleDataType = FP16; +using DsDataType = ck::Tuple<>; +using InDataType = FP16; +using AComputeType = BF8; +using BComputeType = FP8; + +using OutLayout = ck::tensor_layout::convolution::GNHWK; +using WeiLayout = ck::tensor_layout::convolution::GKYXC; +using DsLayout = ck::Tuple<>; +using InLayout = ck::tensor_layout::convolution::GNHWC; + +using OutElementOp = PassThrough; +using WeiElementOp = PassThrough; +using InElementOp = PassThrough; + +static constexpr auto BlkGemmPipeSched = ck::BlockGemmPipelineScheduler::Intrawave; +static constexpr auto BlkGemmPipelineVer = ck::BlockGemmPipelineVersion::v1; + +// clang-format off +using DeviceConvInstance = ck::tensor_operation::device::DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffleV3 +// ######| NDimSpatial| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| AElementwise| BElementwise| CDEElementwise| ConvolutionBackward| DoPad| DoPad| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffleMXdl| CShuffleNXdl| CDEBlockTransfer| CDEBlockTransfer| Loop| ACompute| BCompute| +// ######| | | | | | Type| Type| Type| DataType| Type| Type| Operation| Operation| Operation| DataSpecialization| GemmM| GemmN| PrefetchStage| Size| Block| Block| Block| | | XDL| XDL| PerWave| PerWave| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| PerWave| PerWave| _MBlock_MPerBlock| ScalarPerVector| Scheduler| Type| Type| +// ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| | | | +// ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + < NDimSpatial, OutLayout, WeiLayout, DsLayout, InLayout, OutDataType, WeiDataType, AccDataType, CShuffleDataType, DsDataType, InDataType, OutElementOp, WeiElementOp, InElementOp, ConvBwdDataDefault, true, true, 64, 64, 64, 32, 8, 8, 16, 16, 4, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, S<8,8,8>, BlkGemmPipeSched,BlkGemmPipelineVer, AComputeType, BComputeType , false , false>; +// clang-format on + +#include "run_grouped_conv_bwd_data_example.inc" + +int main(int argc, char* argv[]) +{ + // temp disable on gfx11 + if(ck::is_gfx11_supported()) + { + return 0; + } + return run_grouped_conv_bwd_data_example(argc, argv); +} diff --git a/example/38_grouped_conv_bwd_data_multiple_d/run_grouped_conv3d_bwd_data_example.inc b/example/38_grouped_conv_bwd_data_multiple_d/run_grouped_conv3d_bwd_data_example.inc new file mode 100644 index 0000000000..0e323f6320 --- /dev/null +++ b/example/38_grouped_conv_bwd_data_multiple_d/run_grouped_conv3d_bwd_data_example.inc @@ -0,0 +1,192 @@ + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using OutElementOp = PassThrough; +using WeiElementOp = PassThrough; +using InElementOp = PassThrough; + +bool run_conv_bwd_data(const ExecutionConfig& config, + const ck::utils::conv::ConvParam& conv_params, + const HostTensorDescriptor& out_g_n_k_wos_desc, + const HostTensorDescriptor& wei_g_k_c_xs_desc, + const HostTensorDescriptor& in_g_n_c_wis_desc, + const OutElementOp& out_element_op, + const WeiElementOp& wei_element_op, + const InElementOp& in_element_op) +{ + + Tensor out(out_g_n_k_wos_desc); + Tensor wei(wei_g_k_c_xs_desc); + Tensor in_host(in_g_n_c_wis_desc); + Tensor in_device(in_g_n_c_wis_desc); + + std::cout << "out: " << out.mDesc << std::endl; + std::cout << "wei: " << wei.mDesc << std::endl; + std::cout << "in: " << in_host.mDesc << std::endl; + + switch(config.init_method) + { + case 0: break; + case 1: + out.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + wei.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + break; + default: + out.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + wei.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + } + + DeviceMem out_device_buf(sizeof(OutDataType) * out.mDesc.GetElementSpaceSize()); + DeviceMem wei_device_buf(sizeof(WeiDataType) * wei.mDesc.GetElementSpaceSize()); + DeviceMem in_device_buf(sizeof(InDataType) * in_device.mDesc.GetElementSpaceSize()); + + out_device_buf.ToDevice(out.mData.data()); + wei_device_buf.ToDevice(wei.mData.data()); + + // reset input to zero + in_device_buf.SetZero(); + + std::array a_g_n_k_wos_lengths{}; + std::array a_g_n_k_wos_strides{}; + std::array b_g_k_c_xs_lengths{}; + std::array b_g_k_c_xs_strides{}; + std::array e_g_n_c_wis_lengths{}; + std::array e_g_n_c_wis_strides{}; + std::array conv_filter_strides{}; + std::array conv_filter_dilations{}; + std::array input_left_pads{}; + std::array input_right_pads{}; + + auto copy = [](auto& x, auto& y) { ck::ranges::copy(x, y.begin()); }; + + copy(out_g_n_k_wos_desc.GetLengths(), a_g_n_k_wos_lengths); + copy(out_g_n_k_wos_desc.GetStrides(), a_g_n_k_wos_strides); + copy(wei_g_k_c_xs_desc.GetLengths(), b_g_k_c_xs_lengths); + copy(wei_g_k_c_xs_desc.GetStrides(), b_g_k_c_xs_strides); + copy(in_g_n_c_wis_desc.GetLengths(), e_g_n_c_wis_lengths); + copy(in_g_n_c_wis_desc.GetStrides(), e_g_n_c_wis_strides); + copy(conv_params.conv_filter_strides_, conv_filter_strides); + copy(conv_params.conv_filter_dilations_, conv_filter_dilations); + copy(conv_params.input_left_pads_, input_left_pads); + copy(conv_params.input_right_pads_, input_right_pads); + + static_assert(std::is_default_constructible_v); + // do conv + auto conv = DeviceConvInstance{}; + auto invoker = conv.MakeInvoker(); + auto argument = conv.MakeArgument(out_device_buf.GetDeviceBuffer(), + wei_device_buf.GetDeviceBuffer(), + std::array{}, + in_device_buf.GetDeviceBuffer(), + a_g_n_k_wos_lengths, + a_g_n_k_wos_strides, + b_g_k_c_xs_lengths, + b_g_k_c_xs_strides, + std::array, 0>{}, + std::array, 0>{}, + e_g_n_c_wis_lengths, + e_g_n_c_wis_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + out_element_op, + wei_element_op, + in_element_op); + + if(!conv.IsSupportedArgument(argument)) + { + std::cerr << "wrong! device_conv with the specified compilation parameters does " + "not support this Conv problem" + << std::endl; + + return false; + } + std::string op_name = conv.GetTypeString(); + float ave_time = invoker.Run(argument, StreamConfig{nullptr, config.time_kernel}); + + std::size_t flop = conv_params.GetFlops(); + std::size_t num_btype = conv_params.GetByte(); + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s" + << std::endl; + + if(config.do_verification) + { + auto ref_conv = ck::tensor_operation::host::ReferenceConvBwdData(); + + auto ref_invoker = ref_conv.MakeInvoker(); + + auto ref_argument = ref_conv.MakeArgument(in_host, + wei, + out, + conv_params.conv_filter_strides_, + conv_params.conv_filter_dilations_, + conv_params.input_left_pads_, + conv_params.input_right_pads_, + PassThrough{}, + wei_element_op, + out_element_op); + + ref_invoker.Run(ref_argument); + + in_device_buf.FromDevice(in_device.mData.data()); + return ck::utils::check_err(in_device.mData, in_host.mData); + } + + return true; +} + +int run_grouped_conv_bwd_data_example(int argc, char* argv[]) +{ + namespace ctc = ck::tensor_layout::convolution; + + ExecutionConfig config; + ck::utils::conv::ConvParam conv_params = DefaultConvParams; + + if(!parse_cmd_args(argc, argv, config, conv_params)) + { + return EXIT_FAILURE; + } + + const auto in_element_op = InElementOp{}; + const auto wei_element_op = WeiElementOp{}; + const auto out_element_op = OutElementOp{}; + + if(conv_params.num_dim_spatial_ != NDimSpatial) + { + std::cerr << "unsupported # of spatials dimensions" << std::endl; + return EXIT_FAILURE; + } + + // output image: GNHWK + const auto out_g_n_k_wos_desc = + ck::utils::conv::make_output_host_tensor_descriptor_g_n_k_wos_packed( + conv_params); + + // weight: GKYXC + const auto wei_g_k_c_xs_desc = + ck::utils::conv::make_weight_host_tensor_descriptor_g_k_c_xs_packed(conv_params); + + // input image: GNHWC + const auto in_g_n_c_wis_desc = + ck::utils::conv::make_input_host_tensor_descriptor_g_n_c_wis_packed(conv_params); + + return !run_conv_bwd_data(config, + conv_params, + out_g_n_k_wos_desc, + wei_g_k_c_xs_desc, + in_g_n_c_wis_desc, + wei_element_op, + out_element_op, + in_element_op); +} diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle_v3.hpp new file mode 100644 index 0000000000..7bc3be1a95 --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle_v3.hpp @@ -0,0 +1,1994 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#include + +#include + +#include "ck/library/utility/numeric.hpp" +#include "ck/utility/common_header.hpp" +#include "ck/utility/env.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_grouped_conv_bwd_data_multiple_d.hpp" +#include "ck/tensor_operation/gpu/device/convolution_backward_data_specialization.hpp" +#include "ck/tensor_operation/operator_transform/transform_conv_bwd_data_to_gemm_v1.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/operator_transform/transform_conv_ngchw_to_nhwgc.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_elementwise_2d.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" +#include "ck/host_utility/flush_cache.hpp" +#include "ck/host_utility/io.hpp" + +#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +namespace { + +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS +__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) +#endif + kernel_grouped_conv_bwd_data_wmma_cshuffle_v3( + typename GridwiseGemm::Argument karg, + const std::array gemm_kernel_args, + const index_t gemms_count, + const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch, + const ComputePtrOffsetOfN compute_ptr_offset_of_n, + const index_t KBatch) +{ +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__) || defined(__gfx12__)) +#if defined(__gfx11__) + // gfx11 does not support *_atomic_pk_add_f16/bf16 instructions + using e_data_type = remove_cvref_t>; + if constexpr(!(EGlobalMemoryDataOperation == InMemoryDataOperationEnum::AtomicAdd && + (std::is_same_v || + std::is_same_v))) + { +#endif + __shared__ char p_shared[GridwiseGemm::template GetSharedMemoryNumberOfByte< + typename GridwiseGemm::EpilogueCShuffle>()]; + auto epilogue_args = typename GridwiseGemm::EpilogueCShuffle{}; + + const index_t block_args_id = __builtin_amdgcn_readfirstlane(blockIdx.x); + index_t left = 0; + index_t right = gemms_count; + index_t group_id = index_t((left + right) / 2); + while((!(block_args_id >= gemm_kernel_args[group_id].BlockStart_ && + block_args_id < gemm_kernel_args[group_id].BlockEnd_)) && + left <= right) + { + if(block_args_id < gemm_kernel_args[group_id].BlockStart_) + { + right = group_id; + } + else + { + left = group_id; + } + group_id = index_t((left + right) / 2); + } + + const auto num_k_per_block = + gemm_kernel_args[group_id].a_grid_desc_ak0_m_ak1_.GetLength(Number<0>{}) / KBatch; + + if constexpr(HasMainKBlockLoopInAllGemm || NoMainKBlockLoopInAllGemm) + { + + GridwiseGemm::template Run( + p_shared, + gemm_kernel_args[group_id].a_grid_desc_ak0_m_ak1_, + gemm_kernel_args[group_id].b_grid_desc_bk0_n_bk1_, + gemm_kernel_args[group_id].ds_grid_desc_mblock_mperblock_nblock_nperblock_, + gemm_kernel_args[group_id].e_grid_desc_mblock_mperblock_nblock_nperblock_, + gemm_kernel_args[group_id].block_2_ctile_map_, + compute_ptr_offset_of_batch, + compute_ptr_offset_of_n, + num_k_per_block, + karg, + epilogue_args); + } + else + { + if(gemm_kernel_args[group_id].HasMainKBlockLoop_) + { + GridwiseGemm::template Run( + p_shared, + gemm_kernel_args[group_id].a_grid_desc_ak0_m_ak1_, + gemm_kernel_args[group_id].b_grid_desc_bk0_n_bk1_, + gemm_kernel_args[group_id].ds_grid_desc_mblock_mperblock_nblock_nperblock_, + gemm_kernel_args[group_id].e_grid_desc_mblock_mperblock_nblock_nperblock_, + gemm_kernel_args[group_id].block_2_ctile_map_, + compute_ptr_offset_of_batch, + compute_ptr_offset_of_n, + num_k_per_block, + karg, + epilogue_args); + } + else + { + GridwiseGemm::template Run( + p_shared, + gemm_kernel_args[group_id].a_grid_desc_ak0_m_ak1_, + gemm_kernel_args[group_id].b_grid_desc_bk0_n_bk1_, + gemm_kernel_args[group_id].ds_grid_desc_mblock_mperblock_nblock_nperblock_, + gemm_kernel_args[group_id].e_grid_desc_mblock_mperblock_nblock_nperblock_, + gemm_kernel_args[group_id].block_2_ctile_map_, + compute_ptr_offset_of_batch, + compute_ptr_offset_of_n, + num_k_per_block, + karg, + epilogue_args); + } + } + +#if defined(__gfx11__) + } +#endif +#else + ignore = karg; + ignore = gemm_kernel_args; + ignore = gemms_count; + ignore = compute_ptr_offset_of_batch; + ignore = compute_ptr_offset_of_n; + ignore = KBatch; + +#endif // End of if (!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__) || defined(__gfx12__)) +} + +} // namespace + +// Conv backward data multiple D: +// input : output image A: [G, N, K, Ho, Wo] +// input : weight B: [G, K, C, Y, X], +// input : D0, D1, ... : [G, N, K, Ho, Wo] +// output : input image E: [G, N, C, Hi, Wi] +// C = a_op(A) * b_op(B) +// E = cde_op(C, D0, D1, ...) +template +struct DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffleV3 + : public DeviceGroupedConvBwdDataMultipleD +{ + // TODO: Extend support for more spatial dimensions. + static_assert(NDimSpatial == 2 || NDimSpatial == 3, + "wrong! only implemented for 2D and 3D now"); + + // MaxGroupedGemmGroupsNum is used to specify number of gemm args in compile time. With this + // implementation we can avoid copy data to workspace before kernel launch since number of + // groups is runtime parameter. If number of groups is larger than MaxGroupedGemmGroupsNum then + // we run this kernel in the loop. + static constexpr index_t MaxGroupedGemmGroupsNum = + ConvBackwardDataSpecialization == + ConvolutionBackwardDataSpecialization::Filter1x1Stride1Pad0 + ? 1 + : 32; + + using DeviceOp = DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffleV3; + + static constexpr index_t NumDTensor = DsDataType::Size(); + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + // Note: the values in CShuffleBlockTransferScalarPerVector sequence must be all the same. + // This is a limitation of the thread transfer implementation (v7r3) + // It should be fixed later on + static constexpr index_t CShuffleBlockTransferScalarPerVector_NPerBlock = + CShuffleBlockTransferScalarPerVector{}[I0]; + + static constexpr GemmSpecialization GemmSpec = GemmSpecialization::MNKPadding; + static constexpr bool IsSplitKSupported = + (CShuffleBlockTransferScalarPerVector_NPerBlock % 2 == 0 || sizeof(EDataType) % 4 == 0) && + std::is_same_v, element_wise::PassThrough>; + + // TODO: Add support for different A and B data types. + using ABDataType = ADataType; + + static constexpr bool isATensorColMajor = + (ConvBackwardDataSpecialization == + ConvolutionBackwardDataSpecialization::Filter1x1Stride1Pad0) && + (ABlockTransferSrcVectorDim == 1) && + (is_NGCHW_NGKHW() || + is_NGCDHW_NGKDHW()); + + static constexpr bool NeedTransposeKernel = + (isATensorColMajor == false) && (is_NGCHW_NGKHW() || + is_NGCDHW_NGKDHW()); + + static constexpr bool CTranspose = + (NeedTransposeKernel == false) && (is_same_v || + is_same_v); + + using ALayoutAfterTranspose = std::conditional_t< + is_NGCHW_NGKHW() && NeedTransposeKernel, + tensor_layout::convolution::NHWGK, + std::conditional_t() && NeedTransposeKernel, + tensor_layout::convolution::NDHWGK, + ALayout>>; + using BLayoutAfterTranspose = std::conditional_t< + is_NGCHW_GKCYX_NGKHW() && NeedTransposeKernel, + tensor_layout::convolution::GKYXC, + std::conditional_t() && + NeedTransposeKernel, + tensor_layout::convolution::GKZYXC, + BLayout>>; + using ELayoutAfterTranspose = std::conditional_t< + is_NGCHW_NGKHW() && NeedTransposeKernel, + tensor_layout::convolution::NHWGC, + std::conditional_t() && NeedTransposeKernel, + tensor_layout::convolution::NDHWGC, + ELayout>>; + + using ConvToGemmBwdDataTransform = TransformConvBwdDataToGemm_v1; + + // Dummy function just used to create an alias to Grid Descriptors + static auto + GetDummyABDsEGridDescriptor(const ConvToGemmBwdDataTransform& conv_to_gemm_transform) + { + const auto a_grid_desc_ak0_m_ak1 = conv_to_gemm_transform.MakeADescriptor_AK0_M_AK1(); + + const auto b_grid_desc_bk0_n_bk1 = conv_to_gemm_transform.MakeBDescriptor_BK0_N_BK1(); + + const auto ds_grid_desc_m_n = generate_tuple( + [&](auto i) { + using DLayout = remove_cvref_t>; + using DDataType = remove_cvref_t>; + using ConvToGemmBwdDataTransformD = + TransformConvBwdDataToGemm_v1; + return ConvToGemmBwdDataTransformD{}.MakeCDescriptor_M_N(); + }, + Number{}); + + const auto e_grid_desc_m_n = conv_to_gemm_transform.MakeCDescriptor_M_N(); + + if constexpr(CTranspose) + { + return make_tuple( + b_grid_desc_bk0_n_bk1, a_grid_desc_ak0_m_ak1, ds_grid_desc_m_n, e_grid_desc_m_n); + } + else + { + return make_tuple( + a_grid_desc_ak0_m_ak1, b_grid_desc_bk0_n_bk1, ds_grid_desc_m_n, e_grid_desc_m_n); + } + } + + // GridwiseGemm + using GridwiseGemm = GridwiseGemm_wmma_cshuffle_v3< + ALayout, + BLayout, + DsLayout, + ELayout, + Tuple, + Tuple, + AccDataType, + CShuffleDataType, + DsDataType, + EDataType, + AElementwiseOp, + BElementwiseOp, + CDEElementwiseOp, + GemmSpec, + BlockSize, + MPerBlock, + NPerBlock, + KPerBlock, + AK1, + BK1, + MPerWmma, + NPerWmma, + MRepeat, + NRepeat, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + false, + ABlockLdsExtraM, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + false, + BBlockLdsExtraN, + CShuffleMRepeatPerShuffle, + CShuffleNRepeatPerShuffle, + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + CShuffleBlockTransferScalarPerVector, + BlkGemmPipeSched, + BlkGemmPipelineVer, + AComputeType, + BComputeType, + false, + false>; + +#define GridwiseGemmCTransposeTemplateParameters \ + ALayout, BLayout, DsLayout, ELayout, Tuple, Tuple, AccDataType, \ + CShuffleDataType, DsDataType, EDataType, BElementwiseOp, AElementwiseOp, CDEElementwiseOp, \ + GemmSpec, BlockSize, NPerBlock, MPerBlock, KPerBlock, BK1, AK1, NPerWmma, MPerWmma, \ + NRepeat, MRepeat, BBlockTransferThreadClusterLengths_BK0_N_BK1, \ + BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, \ + BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, \ + BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, \ + ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, \ + ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, \ + ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, \ + ABlockLdsExtraM, CShuffleMRepeatPerShuffle, CShuffleNRepeatPerShuffle, \ + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, \ + CShuffleBlockTransferScalarPerVector, BlkGemmPipeSched, BlkGemmPipelineVer, BComputeType, \ + AComputeType, false, false + + using GridwiseGemmCTranspose = + std::conditional_t, + GridwiseGemm>; + + template + static auto transform_k0_m_k1_to_m_k(const Desc_K0_M_K1& desc_k0_m_k1) + { + const auto grid_desc_m_k = transform_tensor_descriptor( + desc_k0_m_k1, + make_tuple(make_pass_through_transform(desc_k0_m_k1.GetLength(I1)), + make_merge_transform( + make_tuple(desc_k0_m_k1.GetLength(I0), desc_k0_m_k1.GetLength(I2)))), + make_tuple(Sequence<1>{}, Sequence<0, 2>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + return grid_desc_m_k; + } + + // Note: the dummy function is used just to create the alias + constexpr static ConvToGemmBwdDataTransform dummy_conv_to_gemm_transform; + using ABDsEGridDesc = decltype(GetDummyABDsEGridDescriptor(dummy_conv_to_gemm_transform)); + + using AGridDesc_AK0_M_AK1 = remove_cvref_t>; + using BGridDesc_BK0_N_BK1 = remove_cvref_t>; + using DsGridDesc_M_N = remove_cvref_t>; + using EGridDesc_M_N = remove_cvref_t>; + + using AGridDesc_M_K = decltype(transform_k0_m_k1_to_m_k(AGridDesc_AK0_M_AK1{})); + using BGridDesc_N_K = decltype(transform_k0_m_k1_to_m_k(BGridDesc_BK0_N_BK1{})); + + // Note: here we can call gridwise functions with dummy arguments, + // just to create the alias + using DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = + decltype(GridwiseGemmCTranspose::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + DsGridDesc_M_N{}, 1, 1)); + using EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = + decltype(GridwiseGemmCTranspose::MakeDEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + EGridDesc_M_N{}, 1, 1)); + + using Block2ETileMap = typename GridwiseGemmCTranspose::Block2CTileMap; + using GroupedGemmBlock2ETileMap = OffsettedBlockToCTileMap; + + struct GemmArgs + { + GemmArgs() = default; + GemmArgs(AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1, + BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1, + DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock + ds_grid_desc_mblock_mperblock_nblock_nperblock, + EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock + e_grid_desc_mblock_mperblock_nblock_nperblock, + GroupedGemmBlock2ETileMap block_2_ctile_map, + index_t BlockStart, + index_t BlockEnd, + bool HasMainKBlockLoop) + : a_grid_desc_ak0_m_ak1_(a_grid_desc_ak0_m_ak1), + b_grid_desc_bk0_n_bk1_(b_grid_desc_bk0_n_bk1), + + ds_grid_desc_mblock_mperblock_nblock_nperblock_( + ds_grid_desc_mblock_mperblock_nblock_nperblock), + + e_grid_desc_mblock_mperblock_nblock_nperblock_( + e_grid_desc_mblock_mperblock_nblock_nperblock), + block_2_ctile_map_(block_2_ctile_map), + BlockStart_(BlockStart), + BlockEnd_(BlockEnd), + HasMainKBlockLoop_(HasMainKBlockLoop) + + { + } + // tensor descriptors for block/thread-wise copy + AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_; + BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_; + DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock + ds_grid_desc_mblock_mperblock_nblock_nperblock_; + EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock e_grid_desc_mblock_mperblock_nblock_nperblock_; + GroupedGemmBlock2ETileMap block_2_ctile_map_; + index_t BlockStart_, BlockEnd_; + bool HasMainKBlockLoop_; + }; + // block-to-e-tile map for elementwise kernels + using Block2TileMapInOutElementwise = BlockToCTileMap_M00_N0_M01Adapt; + using Block2TileMapWeiElementwise = BlockToCTileMap_M00_N0_M01Adapt; + static constexpr index_t ClusterLengthMPerBlock = + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock::At(1); + static constexpr index_t ClusterLengthNPerBlock = + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock::At(3); + + static constexpr auto conv_ngchw_to_nhwgc_transformer = + TransformConvNGCHWToNHWGC{}; + + static constexpr index_t TransposeTransferInScalarPerVectorAligned = + std::min(MPerBlock / ClusterLengthMPerBlock, MaxTransposeTransferInScalarPerVector); + static constexpr index_t TransposeTransferOutScalarPerVectorAligned = + std::min(MPerBlock / ClusterLengthMPerBlock, MaxTransposeTransferOutScalarPerVector); + + using NGCHWTransposeDescType = + remove_cvref_t({}, {}))>; + using NHWGCTransposeDescType = + remove_cvref_t({}, {}))>; + using GKCYXTransposeDescType = + remove_cvref_t({}, {}))>; + using GKYXCTransposeDescType = + remove_cvref_t({}, {}))>; + + static constexpr index_t ElementwiseBlocksize = ClusterLengthMPerBlock * ClusterLengthNPerBlock; + + using GridwiseElementwiseInputTranspose = + GridwiseElementwise, + Tuple, + Tuple, + Tuple, + Block2TileMapInOutElementwise, + element_wise::PassThrough, + ElementwiseBlocksize, + NPerBlock, + MPerBlock, + NPerBlock / ClusterLengthNPerBlock, + MPerBlock / ClusterLengthMPerBlock, + Sequence<1, 0>, + Sequence, + Sequence, + I1, + I0>; + + using GridwiseElementwiseWeightTranspose = + GridwiseElementwise, + Tuple, + Tuple, + Tuple, + Block2TileMapWeiElementwise, + element_wise::PassThrough, + ElementwiseBlocksize, + MPerBlock, + NPerBlock, + MPerBlock / ClusterLengthMPerBlock, + NPerBlock / ClusterLengthNPerBlock, + Sequence<1, 0>, + Sequence<1>, + Sequence, + I0, + I1>; + + using GridwiseElementwiseOutputTranspose = + GridwiseElementwise, + Tuple, + Tuple, + Tuple, + Block2TileMapInOutElementwise, + element_wise::PassThrough, + ElementwiseBlocksize, + NPerBlock, + MPerBlock, + NPerBlock / ClusterLengthNPerBlock, + MPerBlock / ClusterLengthMPerBlock, + Sequence<1, 0>, + Sequence, + Sequence, + I0, + I1>; + // Argument + struct Argument : public BaseArgument + { + Argument(const void* p_a, // output image + const void* p_b, // weight + const std::array& p_ds, // bias + void* p_e, // input image + const std::array& a_g_n_k_wos_lengths, + const std::array& a_g_n_k_wos_strides, + const std::array& b_g_k_c_xs_lengths, + const std::array& b_g_k_c_xs_strides, + const std::array, NumDTensor>& + ds_g_n_c_wis_lengths, + const std::array, NumDTensor>& + ds_g_n_c_wis_strides, + const std::array& e_g_n_c_wis_lengths, + const std::array& e_g_n_c_wis_strides, + const std::array& conv_filter_strides, + const std::array& conv_filter_dilations, + const std::array& input_left_pads, + const std::array& input_right_pads, + const AElementwiseOp& a_element_op, + const BElementwiseOp& b_element_op, + const CDEElementwiseOp& cde_element_op, + ck::index_t split_k = 1) + : p_a_grid_{static_cast(p_a)}, + p_b_grid_{static_cast(p_b)}, + p_ds_grid_{}, + p_e_grid_{static_cast(p_e)}, + num_group_{a_g_n_k_wos_lengths[0]}, + a_element_op_{a_element_op}, + b_element_op_{b_element_op}, + cde_element_op_{cde_element_op}, + a_g_n_k_wos_lengths_{a_g_n_k_wos_lengths}, + b_g_k_c_xs_lengths_{b_g_k_c_xs_lengths}, + e_g_n_c_wis_lengths_{e_g_n_c_wis_lengths}, + conv_filter_strides_{conv_filter_strides}, + input_left_pads_{input_left_pads}, + input_right_pads_{input_right_pads}, + k_batch_{split_k} + { + bool image_covered_dilation = true; + bool image_covered_strides = true; + for(index_t d = 0; d < NDimSpatial; d++) + { + // If dilation and stride is not equal we will have some empty places + image_covered_dilation &= + conv_filter_dilations[d] == 1 || conv_filter_strides[d] == 1; + // If stride is larger than windows size then we will have some empty places + image_covered_strides &= conv_filter_strides[d] <= b_g_k_c_xs_lengths[d + I3]; + } + bool if_d_is_output_mem = false; + const void* out_mem_void = static_cast(p_e); + static_for<0, NumDTensor, 1>{}([&](auto i) { + if(p_ds[i] == out_mem_void) + { + if_d_is_output_mem = true; + } + }); + + bwd_needs_zero_out = k_batch_ > 1 || !image_covered_dilation || !image_covered_strides; + + // Temporary workaround untill prove/fix above conditions. + bwd_needs_zero_out = !if_d_is_output_mem; + e_space_size_bytes = + ck::accumulate_n( + e_g_n_c_wis_lengths_.begin(), NDimSpatial + I3, 1, std::multiplies<>()) * + sizeof(EDataType); + + std::array a_g_n_k_wos_strides_transposed = + NeedTransposeKernel ? conv_ngchw_to_nhwgc_transformer.TransposeInOutStrides( + a_g_n_k_wos_lengths, a_g_n_k_wos_strides) + : a_g_n_k_wos_strides; + std::array b_g_k_c_xs_strides_transposed = + NeedTransposeKernel ? conv_ngchw_to_nhwgc_transformer.TransposeWeiStrides( + b_g_k_c_xs_lengths, b_g_k_c_xs_strides) + : b_g_k_c_xs_strides; + std::array e_g_n_c_wis_strides_transposed = + NeedTransposeKernel ? conv_ngchw_to_nhwgc_transformer.TransposeInOutStrides( + e_g_n_c_wis_lengths, e_g_n_c_wis_strides) + : e_g_n_c_wis_strides; + + // populate Ds pointer + static_for<0, NumDTensor, 1>{}([&](auto i) { + using DDataType = remove_cvref_t>; + + p_ds_grid_(i) = static_cast(p_ds[i]); + }); + + static_for<0, NumDTensor, 1>{}([&](auto i) { + compute_ptr_offset_of_batch_.BatchStrideDs_(i) = ds_g_n_c_wis_strides[i][0]; + }); + + static constexpr auto NonSpatialDimsNum = Number<3>{}; + + static constexpr auto DIdx = Number{}; + static constexpr auto HIdx = + NDimSpatial == 2 ? Number{} : Number{}; + static constexpr auto WIdx = NDimSpatial == 2 ? Number{} + : Number{}; + + static constexpr auto ZIdx = Number{}; + static constexpr auto YIdx = + NDimSpatial == 2 ? Number{} : Number{}; + static constexpr auto XIdx = NDimSpatial == 2 ? Number{} + : Number{}; + + // problem definition + const index_t Z = b_g_k_c_xs_lengths[ZIdx]; + const index_t Y = b_g_k_c_xs_lengths[YIdx]; + const index_t X = b_g_k_c_xs_lengths[XIdx]; + + const index_t ConvStrideD = conv_filter_strides[DIdx - NonSpatialDimsNum]; + const index_t ConvStrideH = conv_filter_strides[HIdx - NonSpatialDimsNum]; + const index_t ConvStrideW = conv_filter_strides[WIdx - NonSpatialDimsNum]; + + const index_t ConvDilationD = conv_filter_dilations[DIdx - NonSpatialDimsNum]; + const index_t ConvDilationH = conv_filter_dilations[HIdx - NonSpatialDimsNum]; + const index_t ConvDilationW = conv_filter_dilations[WIdx - NonSpatialDimsNum]; + + const auto GcdStrideDilationD = math::gcd(ConvStrideD, ConvDilationD); + const auto GcdStrideDilationH = math::gcd(ConvStrideH, ConvDilationH); + const auto GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW); + + const auto ZTilde = NDimSpatial == 3 ? ConvStrideD / GcdStrideDilationD : 1; + const auto YTilde = ConvStrideH / GcdStrideDilationH; + const auto XTilde = ConvStrideW / GcdStrideDilationW; + + index_t grid_size = 0; + // Allocate place for sets of gemms + gemm_kernel_args_.resize( + math::integer_divide_ceil(ZTilde * YTilde * XTilde, MaxGroupedGemmGroupsNum)); + + for(index_t i_ztilde = 0; i_ztilde < ZTilde; ++i_ztilde) + { + for(index_t i_ytilde = 0; i_ytilde < YTilde; ++i_ytilde) + { + for(index_t i_xtilde = 0; i_xtilde < XTilde; ++i_xtilde) + { + // check slice is valid + const auto ZDotSlice = + NDimSpatial == 3 ? math::integer_divide_ceil(Z - i_ztilde, ZTilde) : 1; + const auto YDotSlice = math::integer_divide_ceil(Y - i_ytilde, YTilde); + const auto XDotSlice = math::integer_divide_ceil(X - i_xtilde, XTilde); + + if(YDotSlice * XDotSlice * ZDotSlice <= 0) + { + continue; + } + + std::array tildes; + if constexpr(NDimSpatial == 2) + { + tildes = {i_ytilde, i_xtilde}; + } + else if constexpr(NDimSpatial == 3) + { + tildes = {i_ztilde, i_ytilde, i_xtilde}; + } + else + { + throw std::runtime_error("wrong! only implemented for 2D and 3D now"); + } + + ConvToGemmBwdDataTransform conv_to_gemm_transform_{ + a_g_n_k_wos_lengths, + a_g_n_k_wos_strides_transposed, + b_g_k_c_xs_lengths, + b_g_k_c_xs_strides_transposed, + e_g_n_c_wis_lengths, + e_g_n_c_wis_strides_transposed, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + tildes, + k_batch_}; + + conv_N_per_block_ = conv_to_gemm_transform_.N_; + + const auto a_grid_desc_ak0_m_ak1 = [&]() { + if constexpr(CTranspose) + { + return conv_to_gemm_transform_.MakeBDescriptor_BK0_N_BK1(); + } + else + { + return conv_to_gemm_transform_.MakeADescriptor_AK0_M_AK1(); + } + }(); + + const auto b_grid_desc_bk0_n_bk1 = [&]() { + if constexpr(CTranspose) + { + return conv_to_gemm_transform_.MakeADescriptor_AK0_M_AK1(); + } + else + { + return conv_to_gemm_transform_.MakeBDescriptor_BK0_N_BK1(); + } + }(); + + DsGridDesc_M_N ds_grid_desc_m_n; + + // populate Ds desc + static_for<0, NumDTensor, 1>{}([&](auto i) { + using DLayout = remove_cvref_t>; + using DDataType = remove_cvref_t>; + using ConvToGemmBwdDataTransformD = + TransformConvBwdDataToGemm_v1; + ConvToGemmBwdDataTransformD conv_to_gemm_transform_d{ + a_g_n_k_wos_lengths, + a_g_n_k_wos_strides_transposed, + b_g_k_c_xs_lengths, + b_g_k_c_xs_strides_transposed, + ds_g_n_c_wis_lengths[i], + ds_g_n_c_wis_strides[i], + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + tildes}; + + ds_grid_desc_m_n(i) = conv_to_gemm_transform_d.MakeCDescriptor_M_N(); + }); + + const auto e_grid_desc_m_n = conv_to_gemm_transform_.MakeCDescriptor_M_N(); + + // desc for problem definition + const auto a_grid_desc_m_k = + transform_k0_m_k1_to_m_k(a_grid_desc_ak0_m_ak1); + const auto b_grid_desc_n_k = + transform_k0_m_k1_to_m_k(b_grid_desc_bk0_n_bk1); + + a_grid_desc_m_k_container_.push_back(a_grid_desc_m_k); + b_grid_desc_n_k_container_.push_back(b_grid_desc_n_k); + ds_grid_desc_m_n_container_.push_back(ds_grid_desc_m_n); + e_grid_desc_m_n_container_.push_back(e_grid_desc_m_n); + + const index_t grid_size_grp = + std::get<0>(GridwiseGemmCTranspose::CalculateGridSize( + e_grid_desc_m_n.GetLength(I0), e_grid_desc_m_n.GetLength(I1), 1)); + const index_t BlockStart = grid_size; + const index_t BlockEnd = grid_size + grid_size_grp; + + grid_size += grid_size_grp; + + const auto block_2_etile_map = GroupedGemmBlock2ETileMap( + Block2ETileMap( + e_grid_desc_m_n.GetLength(I0), e_grid_desc_m_n.GetLength(I1), 4), + BlockStart); + + const index_t GemmM = a_grid_desc_m_k.GetLength(I0); + const index_t GemmN = b_grid_desc_n_k.GetLength(I0); + const index_t GemmK = a_grid_desc_m_k.GetLength(I1); + + const auto MBlock = GridwiseGemmCTranspose::CalculateMBlock(GemmM); + const auto NBlock = GridwiseGemmCTranspose::CalculateNBlock(GemmN); + + index_t k_grain = split_k * KPerBlock; + index_t K_split = (GemmK + k_grain - 1) / k_grain * KPerBlock; + + const bool HasMainKBlockLoop = + GridwiseGemmCTranspose::CalculateHasMainKBlockLoop(K_split); + + gemm_kernel_args_[gemms_count_ / + MaxGroupedGemmGroupsNum][gemms_count_ % + MaxGroupedGemmGroupsNum] = + GemmArgs{a_grid_desc_ak0_m_ak1, + b_grid_desc_bk0_n_bk1, + GridwiseGemmCTranspose:: + MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + ds_grid_desc_m_n, MBlock, NBlock), + GridwiseGemmCTranspose:: + MakeDEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + e_grid_desc_m_n, MBlock, NBlock), + block_2_etile_map, + BlockStart, + BlockEnd, + HasMainKBlockLoop}; + gemms_count_++; + if(gemms_count_ % MaxGroupedGemmGroupsNum == 0) + { + gemms_grid_size_.push_back(grid_size); + grid_size = 0; + } + } + } + } + gemm_kernel_args_.resize( + math::integer_divide_ceil(gemms_count_, MaxGroupedGemmGroupsNum)); + gemms_grid_size_.push_back(grid_size); + + // A/B/Ds/E Batch Stride + compute_ptr_offset_of_batch_.BatchStrideA_ = a_g_n_k_wos_strides_transposed[0]; + compute_ptr_offset_of_batch_.BatchStrideB_ = b_g_k_c_xs_strides_transposed[0]; + compute_ptr_offset_of_batch_.BatchStrideE_ = e_g_n_c_wis_strides_transposed[0]; + + compute_ptr_offset_of_n_.BatchStrideA_ = + a_g_n_k_wos_strides_transposed[1] * conv_N_per_block_; + compute_ptr_offset_of_n_.BatchStrideE_ = + e_g_n_c_wis_strides_transposed[1] * conv_N_per_block_; + + num_workgroups_per_Conv_N_ = a_g_n_k_wos_lengths_[I1] / conv_N_per_block_; + + if constexpr(NeedTransposeKernel) + { + // Use not modified base strides + a_in_transpose_desc_ = + conv_ngchw_to_nhwgc_transformer.template MakeNGCHWTransposeDesc( + a_g_n_k_wos_lengths, a_g_n_k_wos_strides, num_workgroups_per_Conv_N_); + a_out_transpose_desc_ = + conv_ngchw_to_nhwgc_transformer.template MakeNHWGCTransposeDesc( + a_g_n_k_wos_lengths, a_g_n_k_wos_strides, num_workgroups_per_Conv_N_); + + b_in_transpose_desc_ = + conv_ngchw_to_nhwgc_transformer.template MakeGKCYXTransposeDesc( + b_g_k_c_xs_lengths, b_g_k_c_xs_strides); + b_out_transpose_desc_ = + conv_ngchw_to_nhwgc_transformer.template MakeGKYXCTransposeDesc( + b_g_k_c_xs_lengths, b_g_k_c_xs_strides); + + e_in_transpose_desc_ = + conv_ngchw_to_nhwgc_transformer.template MakeNHWGCTransposeDesc( + e_g_n_c_wis_lengths, e_g_n_c_wis_strides, num_workgroups_per_Conv_N_); + e_out_transpose_desc_ = + conv_ngchw_to_nhwgc_transformer.template MakeNGCHWTransposeDesc( + e_g_n_c_wis_lengths, e_g_n_c_wis_strides, num_workgroups_per_Conv_N_); + + elementwise_block_2_ctile_map_transpose_a_ = Block2TileMapInOutElementwise{ + a_in_transpose_desc_.GetLength(I0), a_in_transpose_desc_.GetLength(I1)}; + elementwise_block_2_ctile_map_transpose_b_ = Block2TileMapWeiElementwise{ + b_in_transpose_desc_.GetLength(I0), b_in_transpose_desc_.GetLength(I1)}; + elementwise_block_2_ctile_map_transpose_e_ = Block2TileMapInOutElementwise{ + e_in_transpose_desc_.GetLength(I0), e_in_transpose_desc_.GetLength(I1)}; + + compute_ptr_offset_of_workspace_n_.BatchStrideA_ = + a_g_n_k_wos_strides[1] * conv_N_per_block_; + compute_ptr_offset_of_workspace_n_.BatchStrideE_ = + e_g_n_c_wis_strides[1] * conv_N_per_block_; + } + } + + std::size_t GetWorkspaceATensorSizeBytes() const + { + if constexpr(NeedTransposeKernel) + { + const long_index_t a_acum = ck::accumulate_n( + a_g_n_k_wos_lengths_.begin(), NDimSpatial + I3, 1, std::multiplies<>()); + // Align to 128B + return math::integer_divide_ceil(sizeof(ADataType) * a_acum, 128) * 128; + } + else + { + return 0; + } + } + + std::size_t GetWorkspaceBTensorSizeBytes() const + { + if constexpr(NeedTransposeKernel) + { + const long_index_t b_acum = ck::accumulate_n( + b_g_k_c_xs_lengths_.begin(), NDimSpatial + I3, 1, std::multiplies<>()); + // Align to 128B + return math::integer_divide_ceil(sizeof(BDataType) * b_acum, 128) * 128; + } + else + { + return 0; + } + } + + std::size_t GetWorkspaceETensorSizeBytes() const + { + if constexpr(NeedTransposeKernel) + { + const long_index_t e_accum = ck::accumulate_n( + e_g_n_c_wis_lengths_.begin(), NDimSpatial + I3, 1, std::multiplies<>()); + return sizeof(EDataType) * e_accum; + } + else + { + return 0; + } + } + + std::size_t GetWorkspaceSizeBytes() const + { + return GetWorkspaceATensorSizeBytes() + GetWorkspaceBTensorSizeBytes() + + GetWorkspaceETensorSizeBytes(); + } + + void Print() const + { + for(std::size_t i = 0; i < a_grid_desc_m_k_container_.size(); i++) + { + std::cout << "a_grid_desc_m_ak_container_" << a_grid_desc_m_k_container_[i] + << std::endl; + + std::cout << "b_grid_desc_n_bk_container_" << b_grid_desc_n_k_container_[i] + << std::endl; + + static_for<0, NumDTensor, 1>{}([&](auto j) { + std::cout << "ds_grid_desc_mblock_mperblock_nblock_nperblock_container_" + << ds_grid_desc_m_n_container_[i][j] << std::endl; + }); + + std::cout << "e_grid_desc_mblock_mperblock_nblock_nperblock_container_" + << e_grid_desc_m_n_container_[i] << std::endl; + } + } + + // pointers + const ADataType* p_a_grid_; + const BDataType* p_b_grid_; + typename GridwiseGemm::DsGridPointer p_ds_grid_; + EDataType* p_e_grid_; + + // tensor descriptor for problem definition + index_t num_group_; + index_t conv_N_per_block_; + std::vector a_grid_desc_m_k_container_; + std::vector b_grid_desc_n_k_container_; + std::vector ds_grid_desc_m_n_container_; + std::vector e_grid_desc_m_n_container_; + + // tensor descriptor for block-wise copy + std::vector a_grid_desc_ak0_m_ak1_container_; + std::vector b_grid_desc_bk0_n_bk1_container_; + std::vector + ds_grid_desc_mblock_mperblock_nblock_nperblock_container_; + std::vector + e_grid_desc_mblock_mperblock_nblock_nperblock_container_; + + // block-to-e-tile map elementwise kernels + Block2TileMapInOutElementwise elementwise_block_2_ctile_map_transpose_a_, + elementwise_block_2_ctile_map_transpose_e_; + Block2TileMapWeiElementwise elementwise_block_2_ctile_map_transpose_b_; + + NGCHWTransposeDescType a_in_transpose_desc_, e_out_transpose_desc_; + NHWGCTransposeDescType a_out_transpose_desc_, e_in_transpose_desc_; + GKCYXTransposeDescType b_in_transpose_desc_; + GKYXCTransposeDescType b_out_transpose_desc_; + + // for computing batch offset + ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_batch_; + ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_n_; + ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_workspace_n_; + + // element-wise op + AElementwiseOp a_element_op_; + BElementwiseOp b_element_op_; + CDEElementwiseOp cde_element_op_; + + std::array a_g_n_k_wos_lengths_; + std::array b_g_k_c_xs_lengths_; + std::array e_g_n_c_wis_lengths_; + std::array conv_filter_strides_; + std::array input_left_pads_; + std::array input_right_pads_; + + const index_t k_batch_; + index_t num_workgroups_per_Conv_N_; + std::vector gemms_grid_size_; + index_t gemms_count_ = 0; + std::vector> gemm_kernel_args_; + + bool bwd_needs_zero_out; + long_index_t e_space_size_bytes; + }; + + // Invoker + struct Invoker : public BaseInvoker + { + using Argument = DeviceOp::Argument; + + template + float RunMultiDGemm(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + float ave_time = 0; + + const index_t gdy = arg.num_group_; + const index_t gdz = arg.num_workgroups_per_Conv_N_ * arg.k_batch_; + + const ADataType* p_a_grid = arg.p_a_grid_; + const BDataType* p_b_grid = arg.p_b_grid_; + EDataType* p_e_grid = arg.p_e_grid_; + if constexpr(NeedTransposeKernel) + { + if constexpr(is_NGCHW_NGKHW() || + is_NGCDHW_NGKDHW()) + { + p_a_grid = type_convert(arg.p_workspace_); + p_e_grid = + type_convert(arg.p_workspace_) + + (arg.GetWorkspaceATensorSizeBytes() + arg.GetWorkspaceBTensorSizeBytes()) / + sizeof(EDataType); + } + + if constexpr(is_NGCHW_GKCYX_NGKHW() || + is_NGCDHW_GKCZYX_NGKDHW()) + { + p_b_grid = type_convert(arg.p_workspace_) + + arg.GetWorkspaceATensorSizeBytes() / sizeof(BDataType); + } + } + // Create dummy Ds strides because they are not used in convolution + // since we pass the grid descriptor to gridwise gemm + std::array StrideDs_dummy; + static_for<0, NumDTensor, 1>{}([&](auto i) { StrideDs_dummy[i] = I0; }); + // TODO: fix this, it's not nice to go back and forth + std::array p_ds; + static_for<0, NumDTensor, 1>{}( + [&](auto i) { p_ds[i] = static_cast(arg.p_ds_grid_[i]); }); + + for(std::size_t gemm_set_id = 0; gemm_set_id < arg.gemm_kernel_args_.size(); + gemm_set_id++) + { + const index_t GemmM = arg.a_grid_desc_m_k_container_[gemm_set_id].GetLength(I0); + const index_t GemmN = arg.b_grid_desc_n_k_container_[gemm_set_id].GetLength(I0); + const index_t GemmK = arg.a_grid_desc_m_k_container_[gemm_set_id].GetLength(I1); + typename GridwiseGemmCTranspose::Argument gemm_arg{ + CTranspose ? std::array{p_b_grid} + : std::array{p_a_grid}, + CTranspose ? std::array{p_a_grid} + : std::array{p_b_grid}, + p_ds, + p_e_grid, + GemmM, + GemmN, + GemmK, + std::array{I0}, + std::array{I0}, + StrideDs_dummy, + I0, + arg.k_batch_, + CTranspose ? arg.b_element_op_ : arg.a_element_op_, + CTranspose ? arg.a_element_op_ : arg.b_element_op_, + arg.cde_element_op_}; + if(!GridwiseGemmCTranspose::CheckValidity(gemm_arg)) + { + throw std::runtime_error("wrong! device_op has invalid setting"); + } + const index_t gdx = arg.gemms_grid_size_[gemm_set_id]; + + const index_t gemms_count_for_set = + gemm_set_id == arg.gemm_kernel_args_.size() - 1 + ? arg.gemms_count_ - MaxGroupedGemmGroupsNum * gemm_set_id + : MaxGroupedGemmGroupsNum; + + const std::array& gemm_kernel_args = + arg.gemm_kernel_args_[gemm_set_id]; + + const auto clear_workspace = [&]() { + if(arg.bwd_needs_zero_out && gemm_set_id == 0) + { + hip_check_error(hipMemsetAsync( + p_e_grid, 0, arg.e_space_size_bytes, stream_config.stream_id_)); + } + }; + + bool has_loop_in_all_gemm = true; + bool no_loop_in_all_gemm = true; + for(auto i = 0; i < gemms_count_for_set; i++) + { + has_loop_in_all_gemm &= gemm_kernel_args[i].HasMainKBlockLoop_; + no_loop_in_all_gemm &= !gemm_kernel_args[i].HasMainKBlockLoop_; + } + + auto launch_kernel = [&](auto has_main_k_block_loop_, auto no_main_k_block_loop) { + constexpr bool has_main_loop = has_main_k_block_loop_.value; + constexpr bool no_main_loop = no_main_k_block_loop.value; + const auto kernel = kernel_grouped_conv_bwd_data_wmma_cshuffle_v3< + GridwiseGemmCTranspose, + DeviceOp::AGridDesc_AK0_M_AK1, + DeviceOp::BGridDesc_BK0_N_BK1, + DeviceOp::DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, + DeviceOp::EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, + MaxGroupedGemmGroupsNum, + GemmArgs, + ComputePtrOffsetOfStridedBatch, + ComputePtrOffsetOfStridedBatch, + ElementOp, + has_main_loop, + no_main_loop, + CTranspose>; + + return launch_and_time_kernel_with_preprocess(stream_config, + clear_workspace, + kernel, + dim3(gdx, gdy, gdz), + dim3(BlockSize), + 0, + gemm_arg, + gemm_kernel_args, + gemms_count_for_set, + arg.compute_ptr_offset_of_batch_, + arg.compute_ptr_offset_of_n_, + arg.k_batch_); + }; + if(has_loop_in_all_gemm) + { + ave_time += launch_kernel(integral_constant{}, + integral_constant{}); + } + else if(no_loop_in_all_gemm) + { + ave_time += launch_kernel(integral_constant{}, + integral_constant{}); + } + else + { + ave_time += launch_kernel(integral_constant{}, + integral_constant{}); + } + } + + return ave_time; + } + + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + float ave_time = 0; + + if(stream_config.log_level_ > 0) + { + arg.Print(); + } + // Transpose from NGKHW to NHWGK + if constexpr(NeedTransposeKernel) + { + EDataType* p_e_in_grid = + type_convert(arg.p_workspace_) + + (arg.GetWorkspaceATensorSizeBytes() + arg.GetWorkspaceBTensorSizeBytes()) / + sizeof(EDataType); + + const auto clear_workspace = [&]() { + hip_check_error(hipMemsetAsync(p_e_in_grid, + 0, + arg.GetWorkspaceETensorSizeBytes(), + stream_config.stream_id_)); + }; + + const index_t a_grid_size = + arg.elementwise_block_2_ctile_map_transpose_a_.CalculateGridSize( + arg.a_in_transpose_desc_) * + arg.num_workgroups_per_Conv_N_; + const index_t b_grid_size = + (is_NGCHW_GKCYX_NGKHW() || + is_NGCDHW_GKCZYX_NGKDHW()) + ? arg.elementwise_block_2_ctile_map_transpose_b_.CalculateGridSize( + arg.b_in_transpose_desc_) + : 0; // Dont run transpose B if not needed + + ADataType* p_a_out_grid = type_convert(arg.p_workspace_); + BDataType* p_b_out_grid = type_convert(arg.p_workspace_) + + arg.GetWorkspaceATensorSizeBytes() / sizeof(BDataType); + + auto kernel_transpose = + kernel_elementwise_batched_dual, + ck::Tuple, + ck::Tuple, + ck::Tuple, + ck::Tuple, + ck::Tuple, + ck::Tuple, + ck::Tuple, + Block2TileMapInOutElementwise, + Block2TileMapWeiElementwise, + element_wise::PassThrough, + I1, + I1, + I1, + I1>; + + ave_time += launch_and_time_kernel_with_preprocess( + stream_config, + clear_workspace, + kernel_transpose, + dim3(a_grid_size + b_grid_size), + dim3(ElementwiseBlocksize), + 0, + make_tuple(arg.a_in_transpose_desc_), + make_tuple(arg.b_in_transpose_desc_), + make_tuple(arg.a_out_transpose_desc_), + make_tuple(arg.b_out_transpose_desc_), + make_tuple(arg.p_a_grid_), + make_tuple(arg.p_b_grid_), + make_tuple(p_a_out_grid), + make_tuple(p_b_out_grid), + arg.elementwise_block_2_ctile_map_transpose_a_, + arg.elementwise_block_2_ctile_map_transpose_b_, + element_wise::PassThrough{}, + a_grid_size, + arg.num_workgroups_per_Conv_N_, + I1, // B is not splited per N + std::array{ + static_cast(arg.compute_ptr_offset_of_workspace_n_.BatchStrideA_)}, + std::array{0}, + std::array{ + static_cast(arg.compute_ptr_offset_of_n_.BatchStrideA_)}, + std::array{0}); + } + if(arg.k_batch_ > 1) + { + if constexpr(IsSplitKSupported) + { + ave_time += + RunMultiDGemm(arg, stream_config); + } + } + else + { + ave_time += RunMultiDGemm(arg, stream_config); + } + + arg.Print(); + + // Transpose from NHWGC to NGCHW + if constexpr(NeedTransposeKernel) + { + const index_t grid_size = + arg.elementwise_block_2_ctile_map_transpose_e_.CalculateGridSize( + arg.e_in_transpose_desc_) * + arg.num_workgroups_per_Conv_N_; + + const EDataType* p_e_in_grid = + type_convert(arg.p_workspace_) + + (arg.GetWorkspaceATensorSizeBytes() + arg.GetWorkspaceBTensorSizeBytes()) / + sizeof(EDataType); + + EDataType* p_e_out_grid = arg.p_e_grid_; + + auto kernel_transpose = + kernel_batched_elementwise, + ck::Tuple, + ck::Tuple, + ck::Tuple, + Block2TileMapInOutElementwise, + element_wise::PassThrough, + I1, + I1>; + + ave_time += launch_and_time_kernel( + stream_config, + kernel_transpose, + dim3(grid_size), + dim3(ElementwiseBlocksize), + 0, + make_tuple(arg.e_in_transpose_desc_), + make_tuple(arg.e_out_transpose_desc_), + make_tuple(p_e_in_grid), + make_tuple(p_e_out_grid), + arg.elementwise_block_2_ctile_map_transpose_e_, + element_wise::PassThrough{}, + arg.num_workgroups_per_Conv_N_, + std::array{ + static_cast(arg.compute_ptr_offset_of_n_.BatchStrideE_)}, + std::array{static_cast( + arg.compute_ptr_offset_of_workspace_n_.BatchStrideE_)}); + } + + return ave_time; + } + + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + static bool IsSupportedArgument(const Argument& arg) + { + if(!ck::is_gfx11_supported() && !ck::is_gfx12_supported()) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "This configuration is not supported!" << " In " << __FILE__ << ":" + << __LINE__ << ", in function: " << __func__ << std::endl; + } + return false; + } + + if constexpr(!IsSplitKSupported) + { + if(arg.k_batch_ > 1) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "SplitK tests are not supported!" << " In " << __FILE__ << ":" + << __LINE__ << ", in function: " << __func__ << std::endl; + } + + return false; + } + } + + if(ck::is_gfx11_supported() && arg.k_batch_ > 1) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "SplitK tests are not supported!" << " In " << __FILE__ << ":" + << __LINE__ << ", in function: " << __func__ << std::endl; + } + + return false; + } + + const index_t ConvG = arg.b_g_k_c_xs_lengths_[0]; + const index_t ConvK = arg.b_g_k_c_xs_lengths_[1]; + const index_t ConvC = arg.b_g_k_c_xs_lengths_[2]; + const index_t output_spatial_acum = ck::accumulate_n( + arg.e_g_n_c_wis_lengths_.begin() + I3, NDimSpatial, 1, std::multiplies<>()); + const index_t input_spatial_acum = ck::accumulate_n( + arg.a_g_n_k_wos_lengths_.begin() + I3, NDimSpatial, 1, std::multiplies<>()); + + // Specialization + if constexpr(ConvBackwardDataSpecialization == + ConvolutionBackwardDataSpecialization::Filter1x1Stride1Pad0) + { + // check if it's 1x1, stride=1 pad = 0 conv + for(int i = 0; i < NDimSpatial; i++) + { + if(!(arg.b_g_k_c_xs_lengths_[3 + i] == 1 && arg.conv_filter_strides_[i] == 1 && + arg.input_left_pads_[i] == 0 && arg.input_right_pads_[i] == 0)) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "ConvBwdDataSpecialization is unsupported!" << " In " + << __FILE__ << ":" << __LINE__ << ", in function: " << __func__ + << std::endl; + } + + return false; + } + } + } + + // vector load for A matrix from global memory to LDS + if constexpr(is_same_v || + is_same_v || + is_same_v || + is_same_v || NeedTransposeKernel) + { + if(!(ABlockTransferSrcVectorDim == 2 && ConvK % ABlockTransferSrcScalarPerVector == 0)) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "VectorDim is wrong!" << " In " << __FILE__ << ":" << __LINE__ + << ", in function: " << __func__ << std::endl; + } + + return false; + } + } + else if(is_same_v || + is_same_v) + { + static_assert(NeedTransposeKernel == false); + + if constexpr(ABlockTransferSrcScalarPerVector != 1) + { + if(ABlockTransferSrcVectorDim != 1) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "CTranspose is not supported!" << " In " << __FILE__ << ":" + << __LINE__ << ", in function: " << __func__ << std::endl; + } + + return false; + } + if(output_spatial_acum % ABlockTransferSrcScalarPerVector != 0) + { + + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "CTranspose is not supported!" << " In " << __FILE__ << ":" + << __LINE__ << ", in function: " << __func__ << std::endl; + } + + return false; + } + } + } + + else + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Unsupported A Layout!" << " In " << __FILE__ << ":" << __LINE__ + << ", in function: " << __func__ << std::endl; + } + + return false; + } + + // vector load for B matrix from global memory to LDS + if constexpr(is_same_v || + is_same_v || + is_same_v || + is_same_v) + { + + if(!(BBlockTransferSrcVectorDim == 1 && ConvC % BBlockTransferSrcScalarPerVector == 0)) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "VectorDim is wrong!" << " In " << __FILE__ << ":" << __LINE__ + << ", in function: " << __func__ << std::endl; + } + + return false; + } + } + else + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Unsupported B Layout!" << " In " << __FILE__ << ":" << __LINE__ + << ", in function: " << __func__ << std::endl; + } + + return false; + } + + // vector store for Ds + bool ds_valid = true; + + static_for<0, NumDTensor, 1>{}([&](auto i) { + using DLayout = remove_cvref_t>; + + if constexpr(is_same_v || + is_same_v || + is_same_v || + is_same_v || + is_same_v || + is_same_v || + is_same_v) + { + if(CTranspose == false) + { + // vector load D matrix from global memory + if(!(ConvC % CShuffleBlockTransferScalarPerVector_NPerBlock == 0)) + { + + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "VectorDim is wrong!" << " In " << __FILE__ << ":" + << __LINE__ << ", in function: " << __func__ << std::endl; + } + + ds_valid = false; + } + } + else + { + if(input_spatial_acum % CShuffleBlockTransferScalarPerVector_NPerBlock != 0) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "input_spatial_acum / " + "CShuffleBlockTransferScalarPerVector_NPerBlock is wrong!" + << " In " << __FILE__ << ":" << __LINE__ + << ", in function: " << __func__ << std::endl; + } + + ds_valid = false; + } + } + } + else + { + + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "ds_valid is false!" << " In " << __FILE__ << ":" << __LINE__ + << ", in function: " << __func__ << std::endl; + } + + ds_valid = false; + } + }); + + if(!ds_valid) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "ds_valid is false!" << " In " << __FILE__ << ":" << __LINE__ + << ", in function: " << __func__ << std::endl; + } + + return false; + } + + // vector store for E + if constexpr(is_same_v || + is_same_v || + is_same_v || + is_same_v || + is_same_v || + is_same_v) + { + if(CTranspose == false) + { + // vector store C matrix into global memory + if(!(ConvC % CShuffleBlockTransferScalarPerVector_NPerBlock == 0)) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + + std::cout << "VectorDim is wrong!" << " In " << __FILE__ << ":" << __LINE__ + << ", in function: " << __func__ << std::endl; + } + + return false; + } + } + else + { + if(input_spatial_acum % CShuffleBlockTransferScalarPerVector_NPerBlock != 0) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "input_spatial_acum / " + "ChuffleBlockTransferScalarPerVector_NPerBlock is wrong!" + << " In " << __FILE__ << ":" << __LINE__ + << ", in function: " << __func__ << std::endl; + } + + return false; + } + } + } + else + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Unsupported E Layout!" << " In " << __FILE__ << ":" << __LINE__ + << ", in function: " << __func__ << std::endl; + } + + return false; + } + + if constexpr(NeedTransposeKernel) + { + if((ConvG * ConvC) % CShuffleBlockTransferScalarPerVector_NPerBlock != 0) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "VectorDim is wrong!" << " In " << __FILE__ << ":" << __LINE__ + << ", in function: " << __func__ << std::endl; + } + + return false; + } + + if((ConvG * ConvK) % CShuffleBlockTransferScalarPerVector_NPerBlock != 0) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "VectorDim is wrong!" << " In " << __FILE__ << ":" << __LINE__ + << ", in function: " << __func__ << std::endl; + } + + return false; + } + + const index_t a_spatial_acum = ck::accumulate_n( + arg.a_g_n_k_wos_lengths_.begin() + I3, NDimSpatial, 1, std::multiplies<>()); + const index_t e_spatial_acum = ck::accumulate_n( + arg.e_g_n_c_wis_lengths_.begin() + I3, NDimSpatial, 1, std::multiplies<>()); + + if(a_spatial_acum % TransposeTransferInScalarPerVectorAligned != 0) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout + << "a_spatial_acum % TransposeTransferInScalarPerVectorAligned is wrong!" + << " In " << __FILE__ << ":" << __LINE__ << ", in function: " << __func__ + << std::endl; + } + + return false; + } + + if(e_spatial_acum % TransposeTransferOutScalarPerVectorAligned != 0) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout + << "e_spatial_acum % TransposeTransferOutScalarPerVectorAligned is wrong!" + << " In " << __FILE__ << ":" << __LINE__ << ", in function: " << __func__ + << std::endl; + } + + return false; + } + + if(!arg.p_workspace_) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout + << "Warning: Workspace for " + "DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffleV3::Argument is not " + "allocated, use SetWorkSpacePointer." + << std::endl; + } + return false; + } + } + + // Check gridwise gemm validity + // Create dummy values for Ds pointers and strides + std::array p_ds_grid_dummy; + std::array StrideDs_dummy; + static_for<0, NumDTensor, 1>{}([&](auto i) { + p_ds_grid_dummy[i] = nullptr; + StrideDs_dummy[i] = I0; + }); + for(std::size_t i = 0; i < arg.a_grid_desc_ak0_m_ak1_container_.size(); i++) + { + const index_t GemmM = arg.a_grid_desc_ak0_m_ak1_container_[i].GetLength(I1); + const index_t GemmN = arg.b_grid_desc_bk0_n_bk1_container_[i].GetLength(I1); + const index_t GemmK = arg.a_grid_desc_ak0_m_ak1_container_[i].GetLength(I0) * + arg.a_grid_desc_ak0_m_ak1_container_[i].GetLength(I2); + // Create gemm arguments with dummy values to check for validity + typename GridwiseGemmCTranspose::Argument gemm_arg{ + std::array{nullptr}, // p_as_grid + std::array{nullptr}, // p_bs_grid + p_ds_grid_dummy, // p_ds_grid + nullptr, // p_e_grid + GemmM, // M + GemmN, // N + GemmK, // K + std::array{I0}, // StrideAs + std::array{I0}, // StrideBs + StrideDs_dummy, // StrideDs + I0, // StrideE + arg.k_batch_, + AElementwiseOp{}, + BElementwiseOp{}, + CDEElementwiseOp{}}; + + if(!GridwiseGemmCTranspose::CheckValidity(gemm_arg)) + { + return false; + } + } + + return true; + } + + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static auto + MakeArgument(const void* p_a, // output image + const void* p_b, // weight + const std::array& p_ds, // bias + void* p_e, // input image + const std::array& a_g_n_k_wos_lengths, // output image + const std::array& a_g_n_k_wos_strides, // output image + const std::array& b_g_k_c_xs_lengths, // weight + const std::array& b_g_k_c_xs_strides, // weight + const std::array, NumDTensor>& + ds_g_n_c_wis_lengths, // bias + const std::array, NumDTensor>& + ds_g_n_c_wis_strides, // bias + const std::array& e_g_n_c_wis_lengths, // input image + const std::array& e_g_n_c_wis_strides, // input image + const std::array& conv_filter_strides, + const std::array& conv_filter_dilations, + const std::array& input_left_pads, + const std::array& input_right_pads, + const AElementwiseOp& a_element_op, + const BElementwiseOp& b_element_op, + const CDEElementwiseOp& cde_element_op, + const ck::index_t split_k = 1) + { + return Argument{p_a, + p_b, + p_ds, + p_e, + a_g_n_k_wos_lengths, + a_g_n_k_wos_strides, + b_g_k_c_xs_lengths, + b_g_k_c_xs_strides, + ds_g_n_c_wis_lengths, + ds_g_n_c_wis_strides, + e_g_n_c_wis_lengths, + e_g_n_c_wis_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + a_element_op, + b_element_op, + cde_element_op, + split_k}; + } + + static auto MakeInvoker() { return Invoker{}; } + + std::unique_ptr MakeArgumentPointer( + const void* p_a, // output image + const void* p_b, // weight + const std::array& p_ds, // bias + void* p_e, // input image + const std::array& a_g_n_k_wos_lengths, // output image + const std::array& a_g_n_k_wos_strides, // output image + const std::array& b_g_k_c_xs_lengths, // weight + const std::array& b_g_k_c_xs_strides, // weight + const std::array, NumDTensor>& + ds_g_n_c_wis_lengths, // bias + const std::array, NumDTensor>& + ds_g_n_c_wis_strides, // bias + const std::array& e_g_n_c_wis_lengths, // input image + const std::array& e_g_n_c_wis_strides, // input image + const std::array& conv_filter_strides, + const std::array& conv_filter_dilations, + const std::array& input_left_pads, + const std::array& input_right_pads, + const AElementwiseOp& a_element_op, + const BElementwiseOp& b_element_op, + const CDEElementwiseOp& cde_element_op, + const ck::index_t split_k = 1) override + { + return std::make_unique(p_a, + p_b, + p_ds, + p_e, + a_g_n_k_wos_lengths, + a_g_n_k_wos_strides, + b_g_k_c_xs_lengths, + b_g_k_c_xs_strides, + ds_g_n_c_wis_lengths, + ds_g_n_c_wis_strides, + e_g_n_c_wis_lengths, + e_g_n_c_wis_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + a_element_op, + b_element_op, + cde_element_op, + split_k); + } + + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffleV3" + << "<" + << BlockSize << ", " + << MPerBlock << ", " + << NPerBlock << ", " + << KPerBlock << ", " + << AK1 << ", " + << BK1 << ", " + << getConvBackwardDataSpecializationString(ConvBackwardDataSpecialization) << ", " + << MPerWmma << ", " + << NPerWmma << ", " + << MRepeat << ", " + << NRepeat << ", " + << ABlockTransferSrcScalarPerVector << ", " + << BBlockTransferSrcScalarPerVector << ", " + << CShuffleMRepeatPerShuffle << ", " + << CShuffleNRepeatPerShuffle; + + if constexpr(is_NGCHW_NGKHW() || + is_NGCDHW_NGKDHW()) { + str << ", TransposeTransferInScalarPerVectorAligned: " + << TransposeTransferInScalarPerVectorAligned <<", " + << "TransposeTransferOutScalarPerVectorAligned: " << TransposeTransferOutScalarPerVectorAligned; + } + + + str << ">"; + + return str.str(); + } + + size_t GetWorkSpaceSize(const BaseArgument* p_arg) const override + { + auto arg = dynamic_cast(p_arg); + if(arg) + { + return arg->GetWorkspaceSizeBytes(); + } + else + throw std::runtime_error( + "The argument pointer is not an object of " + "DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffleV3::Argument structure!"); + } + + void SetWorkSpacePointer(BaseArgument* p_arg, + void* p_workspace, + const StreamConfig& = StreamConfig{}) const override + { + auto p_arg_ = dynamic_cast(p_arg); + if(p_arg_) + { + p_arg_->p_workspace_ = p_workspace; + } + else + throw std::runtime_error( + "The argument pointer is not an object of " + "DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffleV3::Argument structure!"); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3.hpp index 7a9d5517c0..c3c14edfb8 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3.hpp @@ -775,6 +775,147 @@ struct GridwiseGemm_wmma_cshuffle_v3 return Block2CTileMap{problem.M, problem.N, 4}; } + // Run method for convolution for bwd_data (grid descriptors are passed as arguments, + // not generated internally) + template + __device__ static void Run(void* p_shared, + const AGridDesc_AK0_M_K1 a_grid_desc_ak0_m_ak1, + const BGridDesc_BK0_N_K1 b_grid_desc_bk0_n_bk1, + const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + ds_grid_desc_mblock_mperblock_nblock_nperblock, + const EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock + e_grid_desc_mblock_mperblock_nblock_nperblock, + const Block2CTileMapExt& block_2_ctile_map, + const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch, + const ComputePtrOffsetOfN compute_ptr_offset_of_n, + const index_t num_k_per_block, + Argument& karg, + EpilogueArgument& epilogue_args) + { + const index_t g_idx = __builtin_amdgcn_readfirstlane(blockIdx.y); + const index_t n_idx = __builtin_amdgcn_readfirstlane(blockIdx.z / karg.KBatch); + const index_t k_idx = + __builtin_amdgcn_readfirstlane((blockIdx.z - n_idx * karg.KBatch) * num_k_per_block); + + // offset base pointer for each work-group + const long_index_t a_batch_offset = + CTranspose ? amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx)) + : amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx)); + const long_index_t b_batch_offset = + CTranspose ? amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx)) + : amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx)); + const long_index_t e_batch_offset = + amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx)); + + const auto ds_batch_offset = compute_ptr_offset_of_batch.GetDsPtrOffset(g_idx); + + const long_index_t a_n_offset = + CTranspose ? 0 : amd_wave_read_first_lane(compute_ptr_offset_of_n.GetAPtrOffset(n_idx)); + const long_index_t b_n_offset = + CTranspose ? amd_wave_read_first_lane(compute_ptr_offset_of_n.GetAPtrOffset(n_idx)) : 0; + const long_index_t e_n_offset = + amd_wave_read_first_lane(compute_ptr_offset_of_n.GetEPtrOffset(n_idx)); + + AsGridPointer p_as_grid_; + static_for<0, NumATensor, 1>{}([&](auto i) { + using ADataType_ = remove_cvref_t>; + p_as_grid_(i) = + static_cast(karg.p_as_grid[i]) + a_batch_offset + a_n_offset; + }); + + BsGridPointer p_bs_grid_; + static_for<0, NumBTensor, 1>{}([&](auto i) { + using BDataType_ = remove_cvref_t>; + p_bs_grid_(i) = + static_cast(karg.p_bs_grid[i]) + b_batch_offset + b_n_offset; + }); + + DsGridPointer p_ds_grid_grp; + static_for<0, NumDTensor, 1>{}( + [&](auto i) { p_ds_grid_grp(i) = karg.p_ds_grid[i] + ds_batch_offset[i]; }); + + // Currently supporting one A and one B + const auto as_grid_desc_ak0_m_ak1 = generate_tuple( + [&](auto i) { + ignore = i; + return a_grid_desc_ak0_m_ak1; + }, + Number{}); + + const auto bs_grid_desc_bk0_n_bk1 = generate_tuple( + [&](auto i) { + ignore = i; + return b_grid_desc_bk0_n_bk1; + }, + Number{}); + + const auto block_work_idx = + block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id())); + + if(!block_2_ctile_map.ValidCTileIndex( + block_work_idx, + make_tuple(e_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0), + e_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2)))) + { + return; + } + + const index_t block_m_id = __builtin_amdgcn_readfirstlane(block_work_idx[I0]); + const index_t block_n_id = __builtin_amdgcn_readfirstlane(block_work_idx[I1]); + + // AScale struct (Empty) + using AScale = typename BlockwiseGemmPipe::Empty; + auto a_scale_struct = AScale{}; + + // BScale struct (Empty) + using BScale = typename BlockwiseGemmPipe::Empty; + auto b_scale_struct = BScale{}; + + const index_t num_k_block_per_scale = GetKBlockPerScale(); + + Base::template Run(p_as_grid_, + p_bs_grid_, + p_ds_grid_grp, + karg.p_e_grid + e_batch_offset + e_n_offset, + p_shared, + as_grid_desc_ak0_m_ak1, + bs_grid_desc_bk0_n_bk1, + ds_grid_desc_mblock_mperblock_nblock_nperblock, + e_grid_desc_mblock_mperblock_nblock_nperblock, + karg.a_element_op, + karg.b_element_op, + karg.cde_element_op, + block_m_id, + block_n_id, + num_k_block_per_scale, + a_scale_struct, + b_scale_struct, + epilogue_args, + k_idx, + k_idx, + karg.KBatch); + } + // Run method for convolution (grid descriptors are passed as arguments, // not generated internally) template +using S = ck::Sequence; + +using namespace ck::tensor_layout::convolution; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using Bilinear = ck::tensor_operation::element_wise::Bilinear; + +static constexpr auto ConvBwdDataDefault = ConvolutionBackwardDataSpecialization::Default; + +static constexpr auto ConvBwdDataFilter1x1Stride1Pad0 = + ConvolutionBackwardDataSpecialization::Filter1x1Stride1Pad0; + +// bf16_bf16_f32_bf16 +template +using device_grouped_conv_bwd_data_wmma_v3_bilinear_bf16_instances = std::tuple< + // clang-format off + // ##############################################| NDim| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| AElementwise| BElementwise| CDEElementwise| ConvolutionBackward| DoPad| DoPad| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle | CShuffle | CDEBlockTransfer| CDEBlockTransfer| + // ##############################################| Spatial| | | | | Type| Type| Type| DataType| Type| Type| Operation| Operation| Operation| DataSpecialization| GemmM| GemmN| Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MRepeat| NRepeat | _MBlock_MPerBlock| ScalarPerVector| + // ##############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| | + // ##############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + // generic instance + DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffleV3, BF16, PassThrough, PassThrough, Bilinear, ConvSpec, true, true, 64, 64, 64, 32, 8, 8, 16, 16, 4, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, S<1,1,1>>, + + DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffleV3, BF16, PassThrough, PassThrough, Bilinear, ConvSpec, true, true, 128, 128, 128, 32, 8, 8, 16, 16, 8, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, S<8,8,8>>, + DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffleV3, BF16, PassThrough, PassThrough, Bilinear, ConvSpec, true, true, 64, 64, 32, 32, 8, 8, 16, 16, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, S<8,8,8>>, + DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffleV3, BF16, PassThrough, PassThrough, Bilinear, ConvSpec, true, true, 128, 32, 64, 32, 8, 8, 16, 16, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 8>, S<8,8,8>>, + DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffleV3, BF16, PassThrough, PassThrough, Bilinear, ConvSpec, true, true, 64, 32, 32, 128, 8, 8, 16, 16, 2, 1, S<16, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<16, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, S<8,8,8>>, + DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffleV3, BF16, PassThrough, PassThrough, Bilinear, ConvSpec, true, true, 128, 128, 128, 32, 8, 8, 16, 16, 8, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, S<8,8,8>>, + DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffleV3, BF16, PassThrough, PassThrough, Bilinear, ConvSpec, true, true, 128, 128, 128, 32, 8, 8, 16, 16, 8, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, S<8,8,8>>, + DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffleV3, BF16, PassThrough, PassThrough, Bilinear, ConvSpec, true, true, 128, 128, 64, 32, 8, 8, 16, 16, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 8>, S<8,8,8>>, + DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffleV3, BF16, PassThrough, PassThrough, Bilinear, ConvSpec, true, true, 128, 128, 128, 32, 8, 8, 16, 16, 8, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, S<8,8,8>>, + DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffleV3, BF16, PassThrough, PassThrough, Bilinear, ConvSpec, true, true, 128, 64, 64, 32, 8, 8, 16, 16, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 8>, S<8,8,8>>, + DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffleV3, BF16, PassThrough, PassThrough, Bilinear, ConvSpec, true, true, 64, 64, 32, 32, 8, 8, 16, 16, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, S<8,8,8>>, + DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffleV3, BF16, PassThrough, PassThrough, Bilinear, ConvSpec, true, true, 64, 32, 64, 32, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 4>, S<8,8,8>>, + DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffleV3, BF16, PassThrough, PassThrough, Bilinear, ConvSpec, true, true, 128, 128, 128, 32, 8, 8, 16, 16, 8, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, S<8,8,8>>, + DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffleV3, BF16, PassThrough, PassThrough, Bilinear, ConvSpec, true, true, 128, 64, 64, 32, 8, 8, 16, 16, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 8>, S<8,8,8>>, + DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffleV3, BF16, PassThrough, PassThrough, Bilinear, ConvSpec, true, true, 256, 128, 128, 32, 8, 8, 16, 16, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 16>, S<8,8,8>> + // clang-format on + >; + +// f16_f16_f32_f16 +template +using device_grouped_conv_bwd_data_wmma_v3_bilinear_f16_instances = std::tuple< + // clang-format off + // ##############################################| NDim| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| AElementwise| BElementwise| CDEElementwise| ConvolutionBackward| DoPad| DoPad| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle | CShuffle | CDEBlockTransfer| CDEBlockTransfer| + // ##############################################| Spatial| | | | | Type| Type| Type| DataType| Type| Type| Operation| Operation| Operation| DataSpecialization| GemmM| GemmN| Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MRepeat| NRepeat | _MBlock_MPerBlock| ScalarPerVector| + // ##############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| | + // ##############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + // generic instance + DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffleV3, F16, PassThrough, PassThrough, Bilinear, ConvSpec, true, true, 64, 64, 64, 32, 8, 8, 16, 16, 4, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, S<1,1,1>>, + + DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffleV3, F16, PassThrough, PassThrough, Bilinear, ConvSpec, true, true, 128, 128, 128, 32, 8, 8, 16, 16, 8, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, S<8,8,8>>, + DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffleV3, F16, PassThrough, PassThrough, Bilinear, ConvSpec, true, true, 128, 128, 128, 32, 8, 8, 16, 16, 8, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, S<8,8,8>>, + DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffleV3, F16, PassThrough, PassThrough, Bilinear, ConvSpec, true, true, 64, 64, 64, 32, 8, 8, 16, 16, 4, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 4>, S<8,8,8>>, + DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffleV3, F16, PassThrough, PassThrough, Bilinear, ConvSpec, true, true, 128, 128, 128, 32, 8, 8, 16, 16, 8, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, S<8,8,8>>, + DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffleV3, F16, PassThrough, PassThrough, Bilinear, ConvSpec, true, true, 128, 128, 128, 32, 8, 8, 16, 16, 8, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, S<8,8,8>>, + DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffleV3, F16, PassThrough, PassThrough, Bilinear, ConvSpec, true, true, 128, 128, 128, 32, 8, 8, 16, 16, 8, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, S<8,8,8>>, + DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffleV3, F16, PassThrough, PassThrough, Bilinear, ConvSpec, true, true, 128, 128, 128, 32, 8, 8, 16, 16, 8, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, S<8,8,8>>, + DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffleV3, F16, PassThrough, PassThrough, Bilinear, ConvSpec, true, true, 64, 64, 64, 32, 8, 8, 16, 16, 4, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 4>, S<8,8,8>>, + DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffleV3, F16, PassThrough, PassThrough, Bilinear, ConvSpec, true, true, 128, 128, 128, 32, 8, 8, 16, 16, 8, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, S<8,8,8>>, + DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffleV3, F16, PassThrough, PassThrough, Bilinear, ConvSpec, true, true, 128, 128, 128, 32, 8, 8, 16, 16, 8, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, S<8,8,8>> + // clang-format on + >; + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_wmma_v3_instances.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_wmma_v3_instances.hpp new file mode 100644 index 0000000000..6afb160728 --- /dev/null +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_wmma_v3_instances.hpp @@ -0,0 +1,125 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle_v3.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using BF16 = ck::bhalf_t; +using F16 = ck::half_t; +using F32 = float; +using BF8 = ck::bf8_t; +using F8 = ck::f8_t; + +using Empty_Tuple = ck::Tuple<>; + +template +using S = ck::Sequence; + +using namespace ck::tensor_layout::convolution; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto ConvBwdDataDefault = ConvolutionBackwardDataSpecialization::Default; + +static constexpr auto ConvBwdDataFilter1x1Stride1Pad0 = + ConvolutionBackwardDataSpecialization::Filter1x1Stride1Pad0; +template +using device_grouped_conv_bwd_data_wmma_v3_f16_16_16_instances = std::tuple< + // clang-format off + // ##############################################| NDim| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| AElementwise| BElementwise| CDEElementwise| ConvolutionBackward| DoPad| DoPad| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle | CShuffle | CDEBlockTransfer| CDEBlockTransfer| + // ##############################################| Spatial| | | | | Type| Type| Type| DataType| Type| Type| Operation| Operation| Operation| DataSpecialization| GemmM| GemmN| Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MRepeat| NRepeat | _MBlock_MPerBlock| ScalarPerVector| + // ##############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| | + // ##############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffleV3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 4>, S<8,8,8>>, + DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffleV3, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 4>, S<8,8,8>> + // clang-format on + >; + +template +using device_grouped_conv_bwd_data_wmma_v3_f16_instances = std::tuple< + // clang-format off + // ##############################################| NDim| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| AElementwise| BElementwise| CDEElementwise| ConvolutionBackward| DoPad| DoPad| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle | CShuffle | CDEBlockTransfer| CDEBlockTransfer| + // ##############################################| Spatial| | | | | Type| Type| Type| DataType| Type| Type| Operation| Operation| Operation| DataSpecialization| GemmM| GemmN| Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MRepeat| NRepeat | _MBlock_MPerBlock| ScalarPerVector| + // ##############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| | + // ##############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + // generic instance + DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffleV3, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, S<1,1,1>>, + + DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffleV3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 4>, S<8,8,8>>, + DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffleV3, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 4>, S<8,8,8>>, + DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffleV3, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 8>, S<8,8,8>>, + DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffleV3, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 8>, S<8,8,8>>, + DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffleV3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, S<8,8,8>> + // clang-format on + >; + +template +using device_grouped_conv_bwd_data_wmma_v3_bf16_16_16_instances = std::tuple< + // clang-format off + // ##############################################| NDim| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| AElementwise| BElementwise| CDEElementwise| ConvolutionBackward| DoPad| DoPad| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle | CShuffle | CDEBlockTransfer| CDEBlockTransfer| + // ##############################################| Spatial| | | | | Type| Type| Type| DataType| Type| Type| Operation| Operation| Operation| DataSpecialization| GemmM| GemmN| Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MRepeat| NRepeat | _MBlock_MPerBlock| ScalarPerVector| + // ##############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| | + // ##############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffleV3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, S<8,8,8>>, + DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffleV3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<16, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, S<8,8,8>> + // clang-format on + >; + +template +using device_grouped_conv_bwd_data_wmma_v3_bf16_instances = std::tuple< + // clang-format off + // ##############################################| NDim| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| AElementwise| BElementwise| CDEElementwise| ConvolutionBackward| DoPad| DoPad| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle | CShuffle | CDEBlockTransfer| CDEBlockTransfer| + // ##############################################| Spatial| | | | | Type| Type| Type| DataType| Type| Type| Operation| Operation| Operation| DataSpecialization| GemmM| GemmN| Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MRepeat| NRepeat | _MBlock_MPerBlock| ScalarPerVector| + // ##############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| | + // ##############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + // generic instance + DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffleV3, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, S<1,1,1>>, + + DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffleV3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, S<8,8,8>>, + DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffleV3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, S<8,8,8>>, + DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffleV3, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 8>, S<8,8,8>>, + DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffleV3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<16, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, S<8,8,8>>, + DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffleV3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, S<8,8,8>>, + DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffleV3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, S<8,8,8>>, + DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffleV3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 8>, S<8,8,8>>, + DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffleV3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, S<8,8,8>>, + DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffleV3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, S<8,8,8>>, + DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffleV3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 16>, S<8,8,8>> + // clang-format on + >; + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_wmma_v3_scale_instance.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_wmma_v3_scale_instance.hpp new file mode 100644 index 0000000000..e7700a8fee --- /dev/null +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_wmma_v3_scale_instance.hpp @@ -0,0 +1,102 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle_v3.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using BF16 = ck::bhalf_t; +using F16 = ck::half_t; +using F32 = float; + +using Empty_Tuple = ck::Tuple<>; + +template +using S = ck::Sequence; + +using namespace ck::tensor_layout::convolution; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using Scale = ck::tensor_operation::element_wise::Scale; + +static constexpr auto ConvBwdDataDefault = ConvolutionBackwardDataSpecialization::Default; + +static constexpr auto ConvBwdDataFilter1x1Stride1Pad0 = + ConvolutionBackwardDataSpecialization::Filter1x1Stride1Pad0; + +// bf16_bf16_f32_bf16 +template +using device_grouped_conv_bwd_data_wmma_v3_scale_bf16_instances = std::tuple< + // clang-format off + // ##############################################| NDim| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| AElementwise| BElementwise| CDEElementwise| ConvolutionBackward| DoPad| DoPad| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle | CShuffle | CDEBlockTransfer| CDEBlockTransfer| + // ##############################################| Spatial| | | | | Type| Type| Type| DataType| Type| Type| Operation| Operation| Operation| DataSpecialization| GemmM| GemmN| Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MRepeat| NRepeat | _MBlock_MPerBlock| ScalarPerVector| + // ##############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| | + // ##############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + // generic instance + DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffleV3, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, S<1,1,1>>, + + DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffleV3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, S<8,8,8>>, + DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffleV3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, S<8,8,8>>, + DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffleV3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, S<8,8,8>>, + DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffleV3, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 8>, S<8,8,8>>, + DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffleV3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<16, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, S<8,8,8>>, + DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffleV3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, S<8,8,8>>, + DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffleV3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, S<8,8,8>>, + DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffleV3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 8>, S<8,8,8>>, + DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffleV3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, S<8,8,8>>, + DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffleV3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 8>, S<8,8,8>>, + DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffleV3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, S<8,8,8>>, + DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffleV3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 4>, S<8,8,8>>, + DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffleV3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, S<8,8,8>>, + DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffleV3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 8>, S<8,8,8>>, + DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffleV3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 16>, S<8,8,8>> + // clang-format on + >; + +// f16_f16_f32_f16 +template +using device_grouped_conv_bwd_data_wmma_v3_scale_f16_instances = std::tuple< + // clang-format off + // ##############################################| NDim| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| AElementwise| BElementwise| CDEElementwise| ConvolutionBackward| DoPad| DoPad| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle | CShuffle | CDEBlockTransfer| CDEBlockTransfer| + // ##############################################| Spatial| | | | | Type| Type| Type| DataType| Type| Type| Operation| Operation| Operation| DataSpecialization| GemmM| GemmN| Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MRepeat| NRepeat | _MBlock_MPerBlock| ScalarPerVector| + // ##############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| | + // ##############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + // generic instance + DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffleV3, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, S<1,1,1>>, + + DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffleV3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, S<8,8,8>>, + DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffleV3, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, S<8,8,8>>, + DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffleV3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 4>, S<8,8,8>>, + DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffleV3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, S<8,8,8>>, + DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffleV3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, S<8,8,8>>, + DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffleV3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, S<8,8,8>>, + DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffleV3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, S<8,8,8>>, + DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffleV3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 4>, S<8,8,8>>, + DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffleV3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, S<8,8,8>>, + DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffleV3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, S<8,8,8>> + // clang-format on + >; +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_data.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_data.hpp index 89009c6d0b..f784b6ea51 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_data.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_data.hpp @@ -422,6 +422,7 @@ struct DeviceOperationInstanceFactory< op_ptrs); } #endif + #ifdef CK_ENABLE_INT8 if constexpr(is_same_v && is_same_v && is_same_v && is_same_v && @@ -441,12 +442,27 @@ struct DeviceOperationInstanceFactory< is_same_v && is_same_v && is_same_v) { + add_device_grouped_conv2d_bwd_data_wmma_v3_nhwgk_gkyxc_nhwgc_f16_instances( + op_ptrs); + add_device_grouped_conv2d_bwd_data_wmma_v3_nhwgk_gkyxc_nhwgc_f16_16_16_instances( + op_ptrs); add_device_grouped_conv2d_bwd_data_wmma_nhwgk_gkyxc_nhwgc_f16_instances( op_ptrs); add_device_grouped_conv2d_bwd_data_wmma_nhwgk_gkyxc_nhwgc_f16_1x1s1p0_instances( op_ptrs); } #endif +#ifdef CK_ENABLE_BF16 + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v && + is_same_v) + { + add_device_grouped_conv2d_bwd_data_wmma_v3_nhwgk_gkyxc_nhwgc_bf16_instances( + op_ptrs); + add_device_grouped_conv2d_bwd_data_wmma_v3_nhwgk_gkyxc_nhwgc_bf16_16_16_instances( + op_ptrs); + } +#endif #ifdef CK_ENABLE_INT8 if constexpr(is_same_v && is_same_v && is_same_v && is_same_v && @@ -475,6 +491,7 @@ struct DeviceOperationInstanceFactory< op_ptrs); } #endif + #ifdef CK_ENABLE_INT8 if constexpr(is_same_v && is_same_v && is_same_v && is_same_v && @@ -499,6 +516,21 @@ struct DeviceOperationInstanceFactory< op_ptrs); add_device_grouped_conv3d_bwd_data_wmma_ndhwgk_gkzyxc_ndhwgc_f16_1x1s1p0_instances( op_ptrs); + add_device_grouped_conv3d_bwd_data_wmma_v3_ndhwgk_gkzyxc_ndhwgc_f16_instances( + op_ptrs); + add_device_grouped_conv3d_bwd_data_wmma_v3_ndhwgk_gkzyxc_ndhwgc_f16_16_16_instances( + op_ptrs); + } +#endif +#ifdef CK_ENABLE_BF16 + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v && + is_same_v) + { + add_device_grouped_conv3d_bwd_data_wmma_v3_ndhwgk_gkzyxc_ndhwgc_bf16_instances( + op_ptrs); + add_device_grouped_conv3d_bwd_data_wmma_v3_ndhwgk_gkzyxc_ndhwgc_bf16_16_16_instances( + op_ptrs); } #endif #ifdef CK_ENABLE_INT8 @@ -515,7 +547,6 @@ struct DeviceOperationInstanceFactory< } } #endif - return op_ptrs; } }; diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_data_bilinear.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_data_bilinear.hpp index 84a715b70a..c8269e9fcd 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_data_bilinear.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_data_bilinear.hpp @@ -15,6 +15,7 @@ namespace tensor_operation { namespace device { namespace instance { +#ifdef CK_USE_XDL #ifdef CK_ENABLE_FP16 void add_device_grouped_conv3d_bwd_data_xdl_bilinear_ndhwgk_gkzyxc_ndhwgc_f16_instances( std::vector>>& instances); #endif +#endif + +#ifdef CK_USE_WMMA +#ifdef CK_ENABLE_FP16 +void add_device_grouped_conv3d_bwd_data_wmma_v3_bilinear_ndhwgk_gkzyxc_ndhwgc_f16_instances( + std::vector, + NDHWGC, + F16, + F16, + Tuple, + F16, + PassThrough, + PassThrough, + Bilinear>>>& instances); +#endif + +#ifdef CK_ENABLE_BF16 +void add_device_grouped_conv3d_bwd_data_wmma_v3_bilinear_ndhwgk_gkzyxc_ndhwgc_bf16_instances( + std::vector, + NDHWGC, + BF16, + BF16, + Tuple, + BF16, + PassThrough, + PassThrough, + Bilinear>>>& instances); +#endif +#endif + template > op_ptrs; + +#ifdef CK_USE_XDL if constexpr(NumDimSpatial == 3) { if constexpr(is_same_v && is_same_v && @@ -169,6 +208,38 @@ struct DeviceOperationInstanceFactory< } } +#endif + +#ifdef CK_USE_WMMA + + if constexpr(NumDimSpatial == 3) + { + if constexpr(is_same_v && is_same_v && + is_same_v) + { +#ifdef CK_ENABLE_FP16 + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v && + is_same_v) + { + add_device_grouped_conv3d_bwd_data_wmma_v3_bilinear_ndhwgk_gkzyxc_ndhwgc_f16_instances( + op_ptrs); + } +#endif + +#ifdef CK_ENABLE_BF16 + else if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v && + is_same_v) + { + add_device_grouped_conv3d_bwd_data_wmma_v3_bilinear_ndhwgk_gkzyxc_ndhwgc_bf16_instances( + op_ptrs); + } +#endif + } + } +#endif + return op_ptrs; } }; diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_data_scale.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_data_scale.hpp index c898dbf781..5f189a75a0 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_data_scale.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_data_scale.hpp @@ -15,6 +15,7 @@ namespace tensor_operation { namespace device { namespace instance { +#ifdef CK_USE_XDL #ifdef CK_ENABLE_FP16 void add_device_grouped_conv3d_bwd_data_xdl_scale_ndhwgk_gkzyxc_ndhwgc_f16_instances( std::vector>>& instances); #endif +#endif + +#ifdef CK_USE_WMMA +#ifdef CK_ENABLE_FP16 +void add_device_grouped_conv3d_bwd_data_wmma_v3_scale_ndhwgk_gkzyxc_ndhwgc_f16_instances( + std::vector, + NDHWGC, + F16, + F16, + Tuple<>, + F16, + PassThrough, + PassThrough, + Scale>>>& instances); +#endif +#ifdef CK_ENABLE_BF16 +void add_device_grouped_conv3d_bwd_data_wmma_v3_scale_ndhwgk_gkzyxc_ndhwgc_bf16_instances( + std::vector, + NDHWGC, + BF16, + BF16, + Tuple<>, + BF16, + PassThrough, + PassThrough, + Scale>>>& instances); +#endif +#endif + template > op_ptrs; +#ifdef CK_USE_XDL if constexpr(NumDimSpatial == 3) { if constexpr(is_same_v && is_same_v && @@ -168,6 +205,36 @@ struct DeviceOperationInstanceFactory< #endif } } +#endif + +#ifdef CK_USE_WMMA + + if constexpr(NumDimSpatial == 3) + { + if constexpr(is_same_v && is_same_v && + is_same_v) + { +#ifdef CK_ENABLE_FP16 + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v && + is_same_v) + { + add_device_grouped_conv3d_bwd_data_wmma_v3_scale_ndhwgk_gkzyxc_ndhwgc_f16_instances( + op_ptrs); + } +#endif +#ifdef CK_ENABLE_BF16 + else if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v && + is_same_v) + { + add_device_grouped_conv3d_bwd_data_wmma_v3_scale_ndhwgk_gkzyxc_ndhwgc_bf16_instances( + op_ptrs); + } +#endif + } + } +#endif return op_ptrs; } diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_data_wmma.inc b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_data_wmma.inc index 40f36d24a5..40b659a87f 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_data_wmma.inc +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_data_wmma.inc @@ -38,6 +38,34 @@ void add_device_grouped_conv2d_bwd_data_wmma_gnhwk_gkyxc_gnhwc_f16_1x1s1p0_insta PassThrough, PassThrough>>>& instances); +void add_device_grouped_conv2d_bwd_data_wmma_v3_nhwgk_gkyxc_nhwgc_f16_instances( + std::vector>>& instances); + +void add_device_grouped_conv2d_bwd_data_wmma_v3_nhwgk_gkyxc_nhwgc_f16_16_16_instances( + std::vector>>& instances); + void add_device_grouped_conv2d_bwd_data_wmma_nhwgk_gkyxc_nhwgc_f16_instances( std::vector>>& instances); - void add_device_grouped_conv3d_bwd_data_wmma_gndhwk_gkzyxc_gndhwc_f16_instances( std::vector>>& instances); #endif +#ifdef CK_ENABLE_BF16 + +void add_device_grouped_conv2d_bwd_data_wmma_v3_nhwgk_gkyxc_nhwgc_bf16_instances( + std::vector>>& instances); + +void add_device_grouped_conv2d_bwd_data_wmma_v3_nhwgk_gkyxc_nhwgc_bf16_16_16_instances( + std::vector>>& instances); + +#endif + +// conv3dbwdData + +#ifdef CK_ENABLE_BF16 +void add_device_grouped_conv3d_bwd_data_wmma_v3_ndhwgk_gkzyxc_ndhwgc_bf16_instances( + std::vector>>& instances); + +void add_device_grouped_conv3d_bwd_data_wmma_v3_ndhwgk_gkzyxc_ndhwgc_bf16_16_16_instances( + std::vector>>& instances); +#endif +#ifdef CK_ENABLE_FP16 +void add_device_grouped_conv3d_bwd_data_wmma_v3_ndhwgk_gkzyxc_ndhwgc_f16_instances( + std::vector>>& instances); + +void add_device_grouped_conv3d_bwd_data_wmma_v3_ndhwgk_gkzyxc_ndhwgc_f16_16_16_instances( + std::vector>>& instances); +#endif + } // namespace instance } // namespace device } // namespace tensor_operation diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/CMakeLists.txt index 9da738480b..19e27cf173 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/CMakeLists.txt @@ -3,7 +3,8 @@ # ONLY XDL_AND_WMMA_KERNELS add_instance_library( - device_grouped_conv2d_bwd_data_instance + device_grouped_conv2d_bwd_data_instance + xdl/device_grouped_conv2d_bwd_data_xdl_gnhwc_gkyxc_gnhwk_f16_instance.cpp xdl/device_grouped_conv2d_bwd_data_xdl_gnhwc_gkyxc_gnhwk_bf16_instance.cpp xdl/device_grouped_conv2d_bwd_data_xdl_gnhwc_gkyxc_gnhwk_f32_instance.cpp @@ -40,4 +41,13 @@ add_instance_library( wmma/device_grouped_conv2d_bwd_data_wmma_nhwgc_gkyxc_nhwgk_f16_instance.cpp wmma/device_grouped_conv2d_bwd_data_wmma_gnhwc_gkyxc_gnhwk_i8_instance.cpp wmma/device_grouped_conv2d_bwd_data_wmma_nhwgc_gkyxc_nhwgk_i8_instance.cpp + + + wmma/device_grouped_conv2d_bwd_data_wmma_v3_nhwgc_gkyxc_nhwgk_f16_instance.cpp + wmma/device_grouped_conv2d_bwd_data_wmma_v3_nhwgc_gkyxc_nhwgk_f16_16_16_instance.cpp + + wmma/device_grouped_conv2d_bwd_data_wmma_v3_nhwgc_gkyxc_nhwgk_bf16_instance.cpp + wmma/device_grouped_conv2d_bwd_data_wmma_v3_nhwgc_gkyxc_nhwgk_bf16_16_16_instance.cpp + + ) diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/wmma/device_grouped_conv2d_bwd_data_wmma_v3_nhwgc_gkyxc_nhwgk_bf16_16_16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/wmma/device_grouped_conv2d_bwd_data_wmma_v3_nhwgc_gkyxc_nhwgk_bf16_16_16_instance.cpp new file mode 100644 index 0000000000..0bfb02e692 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/wmma/device_grouped_conv2d_bwd_data_wmma_v3_nhwgc_gkyxc_nhwgk_bf16_16_16_instance.cpp @@ -0,0 +1,49 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_wmma_v3_instances.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_conv2d_bwd_data_wmma_v3_nhwgk_gkyxc_nhwgc_bf16_16_16_instances( + std::vector>>& instances) +{ + // 1. Default + add_device_operation_instances( + instances, + device_grouped_conv_bwd_data_wmma_v3_bf16_16_16_instances<2, + NHWGK, + GKYXC, + Empty_Tuple, + NHWGC, + ConvBwdDataDefault>{}); + // 2. Filter1x1Stride1Pad0 + add_device_operation_instances(instances, + device_grouped_conv_bwd_data_wmma_v3_bf16_16_16_instances< + 2, + NHWGK, + GKYXC, + Empty_Tuple, + NHWGC, + ConvBwdDataFilter1x1Stride1Pad0>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/wmma/device_grouped_conv2d_bwd_data_wmma_v3_nhwgc_gkyxc_nhwgk_bf16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/wmma/device_grouped_conv2d_bwd_data_wmma_v3_nhwgc_gkyxc_nhwgk_bf16_instance.cpp new file mode 100644 index 0000000000..ee7e26523f --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/wmma/device_grouped_conv2d_bwd_data_wmma_v3_nhwgc_gkyxc_nhwgk_bf16_instance.cpp @@ -0,0 +1,49 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_wmma_v3_instances.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_conv2d_bwd_data_wmma_v3_nhwgk_gkyxc_nhwgc_bf16_instances( + std::vector>>& instances) +{ + // 1. Default + add_device_operation_instances( + instances, + device_grouped_conv_bwd_data_wmma_v3_bf16_instances<2, + NHWGK, + GKYXC, + Empty_Tuple, + NHWGC, + ConvBwdDataDefault>{}); + // 2. Filter1x1Stride1Pad0 + add_device_operation_instances( + instances, + device_grouped_conv_bwd_data_wmma_v3_bf16_instances<2, + NHWGK, + GKYXC, + Empty_Tuple, + NHWGC, + ConvBwdDataFilter1x1Stride1Pad0>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/wmma/device_grouped_conv2d_bwd_data_wmma_v3_nhwgc_gkyxc_nhwgk_f16_16_16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/wmma/device_grouped_conv2d_bwd_data_wmma_v3_nhwgc_gkyxc_nhwgk_f16_16_16_instance.cpp new file mode 100644 index 0000000000..b8f85c8c4f --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/wmma/device_grouped_conv2d_bwd_data_wmma_v3_nhwgc_gkyxc_nhwgk_f16_16_16_instance.cpp @@ -0,0 +1,49 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_wmma_v3_instances.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_conv2d_bwd_data_wmma_v3_nhwgk_gkyxc_nhwgc_f16_16_16_instances( + std::vector>>& instances) +{ + // 1. Default + add_device_operation_instances( + instances, + device_grouped_conv_bwd_data_wmma_v3_f16_16_16_instances<2, + NHWGK, + GKYXC, + Empty_Tuple, + NHWGC, + ConvBwdDataDefault>{}); + // 2. Filter1x1Stride1Pad0 + add_device_operation_instances(instances, + device_grouped_conv_bwd_data_wmma_v3_f16_16_16_instances< + 2, + NHWGK, + GKYXC, + Empty_Tuple, + NHWGC, + ConvBwdDataFilter1x1Stride1Pad0>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/wmma/device_grouped_conv2d_bwd_data_wmma_v3_nhwgc_gkyxc_nhwgk_f16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/wmma/device_grouped_conv2d_bwd_data_wmma_v3_nhwgc_gkyxc_nhwgk_f16_instance.cpp new file mode 100644 index 0000000000..6965bb00fc --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/wmma/device_grouped_conv2d_bwd_data_wmma_v3_nhwgc_gkyxc_nhwgk_f16_instance.cpp @@ -0,0 +1,49 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_wmma_v3_instances.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_conv2d_bwd_data_wmma_v3_nhwgk_gkyxc_nhwgc_f16_instances( + std::vector>>& instances) +{ + // 1. Default + add_device_operation_instances( + instances, + device_grouped_conv_bwd_data_wmma_v3_f16_instances<2, + NHWGK, + GKYXC, + Empty_Tuple, + NHWGC, + ConvBwdDataDefault>{}); + // 2. Filter1x1Stride1Pad0 + add_device_operation_instances( + instances, + device_grouped_conv_bwd_data_wmma_v3_f16_instances<2, + NHWGK, + GKYXC, + Empty_Tuple, + NHWGC, + ConvBwdDataFilter1x1Stride1Pad0>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/CMakeLists.txt index a2a792e745..01ff4095d7 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/CMakeLists.txt @@ -38,7 +38,15 @@ set(GROUPED_CONV3D_BWD_DATA wmma/device_grouped_conv3d_bwd_data_wmma_gndhwc_gkzyxc_gndhwk_f16_1x1s1p0_instance.cpp wmma/device_grouped_conv3d_bwd_data_wmma_ndhwgc_gkzyxc_ndhwgk_f16_1x1s1p0_instance.cpp wmma/device_grouped_conv3d_bwd_data_wmma_gndhwc_gkzyxc_gndhwk_i8_1x1s1p0_instance.cpp - wmma/device_grouped_conv3d_bwd_data_wmma_ndhwgc_gkzyxc_ndhwgk_i8_1x1s1p0_instance.cpp) + wmma/device_grouped_conv3d_bwd_data_wmma_ndhwgc_gkzyxc_ndhwgk_i8_1x1s1p0_instance.cpp + + wmma/device_grouped_conv3d_bwd_data_wmma_v3_ndhwgc_gkzyxc_ndhwgk_bf16_16_16_instance.cpp + wmma/device_grouped_conv3d_bwd_data_wmma_v3_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp + + wmma/device_grouped_conv3d_bwd_data_wmma_v3_ndhwgc_gkzyxc_ndhwgk_f16_16_16_instance.cpp + wmma/device_grouped_conv3d_bwd_data_wmma_v3_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp +) + if((DTYPES MATCHES "fp8" AND DTYPES MATCHES "bf8" AND DTYPES MATCHES "fp16") OR NOT DEFINED DTYPES) list(APPEND GROUPED_CONV3D_BWD_DATA diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/wmma/device_grouped_conv3d_bwd_data_wmma_v3_ndhwgc_gkzyxc_ndhwgk_bf16_16_16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/wmma/device_grouped_conv3d_bwd_data_wmma_v3_ndhwgc_gkzyxc_ndhwgk_bf16_16_16_instance.cpp new file mode 100644 index 0000000000..1822efaaaf --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/wmma/device_grouped_conv3d_bwd_data_wmma_v3_ndhwgc_gkzyxc_ndhwgk_bf16_16_16_instance.cpp @@ -0,0 +1,48 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_wmma_v3_instances.hpp" +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_conv3d_bwd_data_wmma_v3_ndhwgk_gkzyxc_ndhwgc_bf16_16_16_instances( + std::vector>>& instances) +{ + // 1. Default + add_device_operation_instances( + instances, + device_grouped_conv_bwd_data_wmma_v3_bf16_16_16_instances<3, + NDHWGK, + GKZYXC, + Empty_Tuple, + NDHWGC, + ConvBwdDataDefault>{}); + // 2. Filter1x1Stride1Pad0 + add_device_operation_instances(instances, + device_grouped_conv_bwd_data_wmma_v3_bf16_16_16_instances< + 3, + NDHWGK, + GKZYXC, + Empty_Tuple, + NDHWGC, + ConvBwdDataFilter1x1Stride1Pad0>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/wmma/device_grouped_conv3d_bwd_data_wmma_v3_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/wmma/device_grouped_conv3d_bwd_data_wmma_v3_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp new file mode 100644 index 0000000000..b85ab90331 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/wmma/device_grouped_conv3d_bwd_data_wmma_v3_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp @@ -0,0 +1,48 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_wmma_v3_instances.hpp" +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_conv3d_bwd_data_wmma_v3_ndhwgk_gkzyxc_ndhwgc_bf16_instances( + std::vector>>& instances) +{ + // 1. Default + add_device_operation_instances( + instances, + device_grouped_conv_bwd_data_wmma_v3_bf16_instances<3, + NDHWGK, + GKZYXC, + Empty_Tuple, + NDHWGC, + ConvBwdDataDefault>{}); + // 2. Filter1x1Stride1Pad0 + add_device_operation_instances( + instances, + device_grouped_conv_bwd_data_wmma_v3_bf16_instances<3, + NDHWGK, + GKZYXC, + Empty_Tuple, + NDHWGC, + ConvBwdDataFilter1x1Stride1Pad0>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/wmma/device_grouped_conv3d_bwd_data_wmma_v3_ndhwgc_gkzyxc_ndhwgk_f16_16_16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/wmma/device_grouped_conv3d_bwd_data_wmma_v3_ndhwgc_gkzyxc_ndhwgk_f16_16_16_instance.cpp new file mode 100644 index 0000000000..ae5a3fac28 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/wmma/device_grouped_conv3d_bwd_data_wmma_v3_ndhwgc_gkzyxc_ndhwgk_f16_16_16_instance.cpp @@ -0,0 +1,48 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_wmma_v3_instances.hpp" +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_conv3d_bwd_data_wmma_v3_ndhwgk_gkzyxc_ndhwgc_f16_16_16_instances( + std::vector>>& instances) +{ + // 1. Default + add_device_operation_instances( + instances, + device_grouped_conv_bwd_data_wmma_v3_f16_16_16_instances<3, + NDHWGK, + GKZYXC, + Empty_Tuple, + NDHWGC, + ConvBwdDataDefault>{}); + // 2. Filter1x1Stride1Pad0 + add_device_operation_instances(instances, + device_grouped_conv_bwd_data_wmma_v3_f16_16_16_instances< + 3, + NDHWGK, + GKZYXC, + Empty_Tuple, + NDHWGC, + ConvBwdDataFilter1x1Stride1Pad0>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/wmma/device_grouped_conv3d_bwd_data_wmma_v3_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/wmma/device_grouped_conv3d_bwd_data_wmma_v3_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp new file mode 100644 index 0000000000..5583c65e86 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/wmma/device_grouped_conv3d_bwd_data_wmma_v3_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp @@ -0,0 +1,48 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_wmma_v3_instances.hpp" +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_conv3d_bwd_data_wmma_v3_ndhwgk_gkzyxc_ndhwgc_f16_instances( + std::vector>>& instances) +{ + // 1. Default + add_device_operation_instances( + instances, + device_grouped_conv_bwd_data_wmma_v3_f16_instances<3, + NDHWGK, + GKZYXC, + Empty_Tuple, + NDHWGC, + ConvBwdDataDefault>{}); + // 2. Filter1x1Stride1Pad0 + add_device_operation_instances( + instances, + device_grouped_conv_bwd_data_wmma_v3_f16_instances<3, + NDHWGK, + GKZYXC, + Empty_Tuple, + NDHWGC, + ConvBwdDataFilter1x1Stride1Pad0>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data_bilinear/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data_bilinear/CMakeLists.txt index 69ea0c5ccf..18a127c4a9 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data_bilinear/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data_bilinear/CMakeLists.txt @@ -1,11 +1,14 @@ # Copyright (c) Advanced Micro Devices, Inc., or its affiliates. # SPDX-License-Identifier: MIT -# ONLY XDL_KERNELS +#WMMA_AND_XDL_KERNELS set(GROUPED_CONV3D_BWD_DATA_BILINEAR xdl/device_grouped_conv3d_bwd_data_xdl_bilinear_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp xdl/device_grouped_conv3d_bwd_data_xdl_bilinear_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp xdl/device_grouped_conv3d_bwd_data_xdl_bilinear_ndhwgc_gkzyxc_ndhwgk_f32_instance.cpp - xdl/device_grouped_conv3d_bwd_data_xdl_bilinear_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instance.cpp) + xdl/device_grouped_conv3d_bwd_data_xdl_bilinear_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instance.cpp + + wmma/device_grouped_conv3d_bwd_data_wmma_v3_bilinear_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp + wmma/device_grouped_conv3d_bwd_data_wmma_v3_bilinear_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp) add_instance_library(device_grouped_conv3d_bwd_data_bilinear_instance ${GROUPED_CONV3D_BWD_DATA_BILINEAR}) diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data_bilinear/wmma/device_grouped_conv3d_bwd_data_wmma_v3_bilinear_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data_bilinear/wmma/device_grouped_conv3d_bwd_data_wmma_v3_bilinear_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp new file mode 100644 index 0000000000..98527eb425 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data_bilinear/wmma/device_grouped_conv3d_bwd_data_wmma_v3_bilinear_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp @@ -0,0 +1,49 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_wmma_v3_bilinear_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_conv3d_bwd_data_wmma_v3_bilinear_ndhwgk_gkzyxc_ndhwgc_bf16_instances( + std::vector, + NDHWGC, + BF16, + BF16, + Tuple, + BF16, + PassThrough, + PassThrough, + Bilinear>>>& instances) +{ + // 1. Default + add_device_operation_instances( + instances, + device_grouped_conv_bwd_data_wmma_v3_bilinear_bf16_instances<3, + NDHWGK, + GKZYXC, + Tuple, + NDHWGC, + ConvBwdDataDefault>{}); + // 2. Filter1x1Stride1Pad0 + add_device_operation_instances(instances, + device_grouped_conv_bwd_data_wmma_v3_bilinear_bf16_instances< + 3, + NDHWGK, + GKZYXC, + Tuple, + NDHWGC, + ConvBwdDataFilter1x1Stride1Pad0>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data_bilinear/wmma/device_grouped_conv3d_bwd_data_wmma_v3_bilinear_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data_bilinear/wmma/device_grouped_conv3d_bwd_data_wmma_v3_bilinear_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp new file mode 100644 index 0000000000..3b71f14566 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data_bilinear/wmma/device_grouped_conv3d_bwd_data_wmma_v3_bilinear_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp @@ -0,0 +1,49 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_wmma_v3_bilinear_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_conv3d_bwd_data_wmma_v3_bilinear_ndhwgk_gkzyxc_ndhwgc_f16_instances( + std::vector, + NDHWGC, + F16, + F16, + Tuple, + F16, + PassThrough, + PassThrough, + Bilinear>>>& instances) +{ + // 1. Default + add_device_operation_instances( + instances, + device_grouped_conv_bwd_data_wmma_v3_bilinear_f16_instances<3, + NDHWGK, + GKZYXC, + Tuple, + NDHWGC, + ConvBwdDataDefault>{}); + // 2. Filter1x1Stride1Pad0 + add_device_operation_instances(instances, + device_grouped_conv_bwd_data_wmma_v3_bilinear_f16_instances< + 3, + NDHWGK, + GKZYXC, + Tuple, + NDHWGC, + ConvBwdDataFilter1x1Stride1Pad0>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data_scale/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data_scale/CMakeLists.txt index a3837c51b9..9c78904abe 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data_scale/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data_scale/CMakeLists.txt @@ -1,11 +1,14 @@ # Copyright (c) Advanced Micro Devices, Inc., or its affiliates. # SPDX-License-Identifier: MIT -# ONLY XDL_KERNELS -set(GROUPED_CONV3D_BWD_DATA_BILINEAR +# WMMA_AND_XDL_KERNELS +set(GROUPED_CONV3D_BWD_DATA_SCALE xdl/device_grouped_conv3d_bwd_data_xdl_scale_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp xdl/device_grouped_conv3d_bwd_data_xdl_scale_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp xdl/device_grouped_conv3d_bwd_data_xdl_scale_ndhwgc_gkzyxc_ndhwgk_f32_instance.cpp - xdl/device_grouped_conv3d_bwd_data_xdl_scale_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instance.cpp) + xdl/device_grouped_conv3d_bwd_data_xdl_scale_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instance.cpp + + wmma/device_grouped_conv3d_bwd_data_wmma_v3_scale_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp + wmma/device_grouped_conv3d_bwd_data_wmma_v3_scale_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp) -add_instance_library(device_grouped_conv3d_bwd_data_scale_instance ${GROUPED_CONV3D_BWD_DATA_BILINEAR}) +add_instance_library(device_grouped_conv3d_bwd_data_scale_instance ${GROUPED_CONV3D_BWD_DATA_SCALE}) \ No newline at end of file diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data_scale/wmma/device_grouped_conv3d_bwd_data_wmma_v3_scale_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data_scale/wmma/device_grouped_conv3d_bwd_data_wmma_v3_scale_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp new file mode 100644 index 0000000000..04ee8007d7 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data_scale/wmma/device_grouped_conv3d_bwd_data_wmma_v3_scale_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp @@ -0,0 +1,49 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_wmma_v3_scale_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_conv3d_bwd_data_wmma_v3_scale_ndhwgk_gkzyxc_ndhwgc_bf16_instances( + std::vector, + NDHWGC, + BF16, + BF16, + Tuple<>, + BF16, + PassThrough, + PassThrough, + Scale>>>& instances) +{ + // 1. Default + add_device_operation_instances( + instances, + device_grouped_conv_bwd_data_wmma_v3_scale_bf16_instances<3, + NDHWGK, + GKZYXC, + Tuple<>, + NDHWGC, + ConvBwdDataDefault>{}); + // 2. Filter1x1Stride1Pad0 + add_device_operation_instances(instances, + device_grouped_conv_bwd_data_wmma_v3_scale_bf16_instances< + 3, + NDHWGK, + GKZYXC, + Tuple<>, + NDHWGC, + ConvBwdDataFilter1x1Stride1Pad0>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data_scale/wmma/device_grouped_conv3d_bwd_data_wmma_v3_scale_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data_scale/wmma/device_grouped_conv3d_bwd_data_wmma_v3_scale_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp new file mode 100644 index 0000000000..de197f93ce --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data_scale/wmma/device_grouped_conv3d_bwd_data_wmma_v3_scale_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp @@ -0,0 +1,49 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_wmma_v3_scale_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_conv3d_bwd_data_wmma_v3_scale_ndhwgk_gkzyxc_ndhwgc_f16_instances( + std::vector, + NDHWGC, + F16, + F16, + Tuple<>, + F16, + PassThrough, + PassThrough, + Scale>>>& instances) +{ + // 1. Default + add_device_operation_instances( + instances, + device_grouped_conv_bwd_data_wmma_v3_scale_f16_instances<3, + NDHWGK, + GKZYXC, + Tuple<>, + NDHWGC, + ConvBwdDataDefault>{}); + // 2. Filter1x1Stride1Pad0 + add_device_operation_instances(instances, + device_grouped_conv_bwd_data_wmma_v3_scale_f16_instances< + 3, + NDHWGK, + GKZYXC, + Tuple<>, + NDHWGC, + ConvBwdDataFilter1x1Stride1Pad0>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index 719ab861ce..81d1ed4063 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -40,6 +40,9 @@ set(REGRESSION_TESTS test_batchnorm_fwd_rank_4 test_batchnorm_bwd_rank_4 test_grouped_convnd_bwd_data_xdl + test_grouped_convnd_bwd_data_wmma + test_grouped_convnd_bwd_data_wmma_large_cases + test_grouped_conv_bwd_data_wmma_scale test_conv_tensor_rearrange test_gemm_mx test_ck_tile_batched_transpose diff --git a/test/grouped_convnd_bwd_data/CMakeLists.txt b/test/grouped_convnd_bwd_data/CMakeLists.txt index a9413bd25b..0f6285cfea 100644 --- a/test/grouped_convnd_bwd_data/CMakeLists.txt +++ b/test/grouped_convnd_bwd_data/CMakeLists.txt @@ -1,22 +1,25 @@ # Copyright (c) Advanced Micro Devices, Inc., or its affiliates. # SPDX-License-Identifier: MIT -add_gtest_executable(test_grouped_convnd_bwd_data_xdl test_grouped_convnd_bwd_data_xdl.cpp) -if(result EQUAL 0) - target_link_libraries(test_grouped_convnd_bwd_data_xdl PRIVATE utility device_grouped_conv2d_bwd_data_instance device_grouped_conv3d_bwd_data_instance) -endif() if(GPU_TARGETS MATCHES "gfx9|gfx11|gfx12") - add_executable(test_grouped_convnd_bwd_data_xdl_large_cases test_grouped_convnd_bwd_data_xdl_large_cases.cpp) - target_compile_options(test_grouped_convnd_bwd_data_xdl_large_cases PRIVATE -Wno-global-constructors -Wno-undef) - target_link_libraries(test_grouped_convnd_bwd_data_xdl_large_cases PRIVATE gtest_main getopt::getopt utility device_grouped_conv2d_bwd_data_instance device_grouped_conv3d_bwd_data_instance) + add_gtest_executable(test_grouped_convnd_bwd_data test_grouped_convnd_bwd_data.cpp) + target_link_libraries(test_grouped_convnd_bwd_data PRIVATE utility device_grouped_conv2d_bwd_data_instance device_grouped_conv3d_bwd_data_instance) + + add_executable(test_grouped_convnd_bwd_data_large_cases test_grouped_convnd_bwd_data_large_cases.cpp) + target_compile_options(test_grouped_convnd_bwd_data_large_cases PRIVATE -Wno-global-constructors -Wno-undef) + target_link_libraries(test_grouped_convnd_bwd_data_large_cases PRIVATE gtest_main getopt::getopt utility device_grouped_conv2d_bwd_data_instance device_grouped_conv3d_bwd_data_instance) add_executable(test_grouped_convnd_bwd_data_dataset_xdl test_grouped_convnd_bwd_data_dataset_xdl.cpp) target_compile_options(test_grouped_convnd_bwd_data_dataset_xdl PRIVATE -Wno-global-constructors -Wno-undef) target_link_libraries(test_grouped_convnd_bwd_data_dataset_xdl PRIVATE gtest_main getopt::getopt utility device_grouped_conv2d_bwd_data_instance device_grouped_conv3d_bwd_data_instance) -endif() -add_gtest_executable(test_grouped_convnd_bwd_data_wmma test_grouped_convnd_bwd_data_wmma.cpp) -if(result EQUAL 0) - target_link_libraries(test_grouped_convnd_bwd_data_wmma PRIVATE utility device_grouped_conv2d_bwd_data_instance device_grouped_conv3d_bwd_data_instance) + + add_executable(test_grouped_conv_bwd_data_bilinear test_grouped_conv_bwd_data_bilinear.cpp) + target_compile_options(test_grouped_conv_bwd_data_bilinear PRIVATE -Wno-global-constructors -Wno-undef) + target_link_libraries(test_grouped_conv_bwd_data_bilinear PRIVATE gtest_main getopt::getopt utility device_grouped_conv3d_bwd_data_bilinear_instance) + + add_executable(test_grouped_conv_bwd_data_scale test_grouped_conv_bwd_data_scale.cpp) + target_compile_options(test_grouped_conv_bwd_data_scale PRIVATE -Wno-global-constructors -Wno-undef) + target_link_libraries(test_grouped_conv_bwd_data_scale PRIVATE gtest_main getopt::getopt utility device_grouped_conv3d_bwd_data_scale_instance) endif() add_gtest_executable(test_grouped_convnd_bwd_data_interface_xdl test_grouped_convnd_bwd_data_interface_xdl.cpp) if(result EQUAL 0) @@ -25,4 +28,4 @@ endif() add_gtest_executable(test_grouped_convnd_bwd_data_interface_wmma test_grouped_convnd_bwd_data_interface_wmma.cpp) if(result EQUAL 0) target_link_libraries(test_grouped_convnd_bwd_data_interface_wmma PRIVATE utility device_grouped_conv2d_bwd_data_instance) -endif() +endif() \ No newline at end of file diff --git a/test/grouped_convnd_bwd_data/test_grouped_conv_bwd_data_bilinear.cpp b/test/grouped_convnd_bwd_data/test_grouped_conv_bwd_data_bilinear.cpp new file mode 100644 index 0000000000..b45f204b40 --- /dev/null +++ b/test/grouped_convnd_bwd_data/test_grouped_conv_bwd_data_bilinear.cpp @@ -0,0 +1,324 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include +#include +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" + +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_data_bilinear.hpp" + +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/convolution_parameter.hpp" +#include "ck/library/utility/convolution_host_tensor_descriptor_helper.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_conv_bwd_data.hpp" + +using ::ck::DeviceMem; +using ::ck::HostTensorDescriptor; +using ::ck::Tensor; + +static ck::index_t param_mask = 0xffff; +static ck::index_t instance_index = -1; + +template +class TestGroupedConvndBwdData : public ::testing::Test +{ + protected: + using InDataType = std::tuple_element_t<0, Tuple>; + using WeiDataType = std::tuple_element_t<0, Tuple>; + using OutDataType = std::tuple_element_t<0, Tuple>; + + using ComputeDataType = InDataType; + using InLayout = std::tuple_element_t<3, Tuple>; + using WeiLayout = std::tuple_element_t<2, Tuple>; + using OutLayout = std::tuple_element_t<1, Tuple>; + + using WeiElementOp = ck::tensor_operation::element_wise::PassThrough; + using InElementOp = ck::tensor_operation::element_wise::Bilinear; + using OutElementOp = ck::tensor_operation::element_wise::PassThrough; + using PassThrough = ck::tensor_operation::element_wise::PassThrough; + + using Bilinear = ck::tensor_operation::element_wise::Bilinear; + static constexpr ck::index_t NDimSpatial = 3; + static constexpr float alpha = 2.f; + static constexpr float beta = 2.f; + static constexpr ck::index_t NumDs = 1; + + std::vector conv_params; + std::vector split_ks{1}; + + void RunReference(ck::utils::conv::ConvParam& conv_param, + Tensor& in_host, + Tensor& wei, + Tensor& out, + Tensor& d) + { + + std::array, NumDs> d_tensors = {d}; + auto ref_conv = + ck::tensor_operation::host::ReferenceConvBwdData(); + + auto ref_invoker = ref_conv.MakeInvoker(); + + auto ref_argument = ref_conv.MakeArgument(in_host, + wei, + out, + conv_param.conv_filter_strides_, + conv_param.conv_filter_dilations_, + conv_param.input_left_pads_, + conv_param.input_right_pads_, + Bilinear{alpha, beta}, + WeiElementOp{}, + OutElementOp{}, + {}, + {}, + d_tensors); + + ref_invoker.Run(ref_argument); + } + + bool PerformConvDataBilinear(ck::utils::conv::ConvParam& conv_param, + const ck::index_t split_k, + ck::index_t instance_index_ = -1) + { + bool passed = true; + + const auto out_g_n_k_wos_desc = + ck::utils::conv::make_output_host_tensor_descriptor_g_n_k_wos_packed( + conv_param); + + const auto wei_g_k_c_xs_desc = + ck::utils::conv::make_weight_host_tensor_descriptor_g_k_c_xs_packed( + conv_param); + + const auto in_g_n_c_wis_desc = + ck::utils::conv::make_input_host_tensor_descriptor_g_n_c_wis_packed( + conv_param); + + Tensor wei(wei_g_k_c_xs_desc); + Tensor out(out_g_n_k_wos_desc); + Tensor in_host(in_g_n_c_wis_desc); + Tensor in_device(in_g_n_c_wis_desc); + Tensor d(in_g_n_c_wis_desc); + + std::cout << "in: " << in_host.mDesc << std::endl; + std::cout << "wei: " << wei.mDesc << std::endl; + std::cout << "out: " << out.mDesc << std::endl; + + out.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + wei.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + d.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + + DeviceMem wei_device_buf(sizeof(WeiDataType) * wei.mDesc.GetElementSpaceSize()); + DeviceMem out_device_buf(sizeof(OutDataType) * out.mDesc.GetElementSpaceSize()); + DeviceMem in_device_buf(sizeof(InDataType) * in_device.mDesc.GetElementSpaceSize()); + DeviceMem d_device_buf(sizeof(InDataType) * d.mDesc.GetElementSpaceSize()); + + in_device_buf.ToDevice(in_device.mData.data()); + out_device_buf.ToDevice(out.mData.data()); + wei_device_buf.ToDevice(wei.mData.data()); + d_device_buf.ToDevice(d.mData.data()); + + std::array out_lengths{}; + std::array out_strides{}; + std::array wei_lengths{}; + std::array wei_strides{}; + std::array in_lengths{}; + std::array in_strides{}; + std::array conv_filter_strides{}; + std::array conv_filter_dilations{}; + std::array input_left_pads{}; + std::array input_right_pads{}; + + auto copy = [](const auto& x, auto& y) { ck::ranges::copy(x, y.begin()); }; + + copy(out_g_n_k_wos_desc.GetLengths(), out_lengths); + copy(out_g_n_k_wos_desc.GetStrides(), out_strides); + copy(wei_g_k_c_xs_desc.GetLengths(), wei_lengths); + copy(wei_g_k_c_xs_desc.GetStrides(), wei_strides); + copy(in_g_n_c_wis_desc.GetLengths(), in_lengths); + copy(in_g_n_c_wis_desc.GetStrides(), in_strides); + copy(conv_param.conv_filter_strides_, conv_filter_strides); + copy(conv_param.conv_filter_dilations_, conv_filter_dilations); + copy(conv_param.input_left_pads_, input_left_pads); + copy(conv_param.input_right_pads_, input_right_pads); + + RunReference(conv_param, in_host, wei, out, d); + + using DeviceOp = + ck::tensor_operation::device::DeviceGroupedConvBwdDataMultipleD, + InLayout, + OutDataType, + WeiDataType, + ck::Tuple, + InDataType, + PassThrough, + PassThrough, + Bilinear>; + + // get device op instances + const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< + DeviceOp>::GetInstances(); + + int num_kernel = 0; + + for(std::size_t i = 0; i < op_ptrs.size(); ++i) + { + auto& op_ptr = op_ptrs[i]; + auto argument_ptr = op_ptr->MakeArgumentPointer( + out_device_buf.GetDeviceBuffer(), + wei_device_buf.GetDeviceBuffer(), + {d_device_buf.GetDeviceBuffer()}, + in_device_buf.GetDeviceBuffer(), + out_lengths, + out_strides, + wei_lengths, + wei_strides, + std::array, NumDs>{in_lengths}, + std::array, NumDs>{in_strides}, + in_lengths, + in_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + PassThrough{}, + PassThrough{}, + Bilinear{alpha, beta}, + split_k); + + DeviceMem workspace_buf(op_ptr->GetWorkSpaceSize(argument_ptr.get())); + op_ptr->SetWorkSpacePointer(argument_ptr.get(), workspace_buf.GetDeviceBuffer()); + + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + std::string op_name = op_ptr->GetTypeString(); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + ++num_kernel; + if((instance_index_ != -1) && (instance_index_ + 1 != num_kernel)) + { + // skip test if instance_index is specified + continue; + } + + float avg_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, true}); + in_device_buf.FromDevice(in_device.mData.data()); + + passed &= ck::utils::check_err(in_device, in_host); + + std::size_t flop = conv_param.GetFlops() + + 3 * conv_param.GetOutputByte() / sizeof(InDataType); + std::size_t num_bytes = conv_param.GetByte() + + conv_param.GetOutputByte(); + + float tflops = static_cast(flop) / 1.E9 / avg_time; + float gb_per_sec = num_bytes / 1.E6 / avg_time; + + std::cout << "Perf: " << std::setw(10) << avg_time << " ms, " << tflops + << " TFlops, " << gb_per_sec << " GB/s, " << op_name << std::endl; + } + else + { + std::cerr << op_name << " does not support this problem" << std::endl; + } + } + if(instance_index != -1) + { + std::cout << "grouped_conv_bwd_data_instance (" << instance_index << "/" << num_kernel + << "): Passed" << std::endl; + } + printf("\033[36mvalids: %d\033[0m\n", num_kernel); + return passed; + } + + void Run() + { + EXPECT_FALSE(conv_params.empty()); + bool pass = true; + for(auto split_k : split_ks) + { + for(auto& param : conv_params) + { + pass = pass && PerformConvDataBilinear(param, split_k, instance_index); + } + } + EXPECT_TRUE(pass); + } +}; + +template +class TestGroupedConvndBwdData3d : public TestGroupedConvndBwdData +{ +}; + +using NDHWGC = ck::tensor_layout::convolution::NDHWGC; +using GKZYXC = ck::tensor_layout::convolution::GKZYXC; +using NDHWGK = ck::tensor_layout::convolution::NDHWGK; + +using KernelTypes3d = ::testing::Types, + std::tuple, + std::tuple>; + +TYPED_TEST_SUITE(TestGroupedConvndBwdData3d, KernelTypes3d); + +TYPED_TEST(TestGroupedConvndBwdData3d, Test3D) +{ + // TODO: To fix the impl to pass with stride greater than 1. + // this->conv_params.push_back( + // {3, 2, 16, 128, 128, {1, 1, 1}, {7, 7, 7}, {2, 2, 2}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}}); + this->conv_params.push_back( + {3, 2, 16, 128, 128, {1, 1, 1}, {7, 7, 7}, {1, 1, 1}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}}); + this->conv_params.push_back( + {3, 2, 2, 128, 128, {3, 3, 3}, {14, 14, 3}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}}); + this->conv_params.push_back( + {3, 2, 32, 128, 128, {1, 1, 1}, {3, 3, 3}, {1, 1, 1}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}}); + this->conv_params.push_back( + {3, 1, 1, 1, 32, {3, 3, 3}, {32, 32, 32}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}}); + this->conv_params.push_back( + {3, 1, 1, 64, 3, {3, 3, 3}, {32, 32, 32}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}}); + this->conv_params.push_back( + {3, 1, 1, 1, 1, {3, 3, 3}, {32, 32, 32}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}}); + this->conv_params.push_back( + {3, 1, 1, 4, 4, {3, 3, 3}, {14, 28, 28}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}}); + this->Run(); +} + +int main(int argc, char** argv) +{ + testing::InitGoogleTest(&argc, argv); + if(argc == 1) {} + else if(argc == 3) + { + param_mask = strtol(argv[1], nullptr, 0); + instance_index = atoi(argv[2]); + } + else + { + std::cout << "Usage of " << argv[0] << std::endl; + std::cout << "Arg1,2: param_mask instance_index(-1 means all)" << std::endl; + } + return RUN_ALL_TESTS(); +} diff --git a/test/grouped_convnd_bwd_data/test_grouped_conv_bwd_data_scale.cpp b/test/grouped_convnd_bwd_data/test_grouped_conv_bwd_data_scale.cpp new file mode 100644 index 0000000000..84d013bca7 --- /dev/null +++ b/test/grouped_convnd_bwd_data/test_grouped_conv_bwd_data_scale.cpp @@ -0,0 +1,324 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include +#include +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" + +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_data_scale.hpp" + +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/convolution_parameter.hpp" +#include "ck/library/utility/convolution_host_tensor_descriptor_helper.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_conv_bwd_data.hpp" + +using ::ck::DeviceMem; +using ::ck::HostTensorDescriptor; +using ::ck::Tensor; + +template +class TestGroupedConvndBwdData : public ::testing::Test +{ + protected: + using F16 = ck::half_t; + using InDataType = std::tuple_element_t<0, Tuple>; + using WeiDataType = std::tuple_element_t<0, Tuple>; + using OutDataType = std::tuple_element_t<0, Tuple>; + + using ComputeDataType = InDataType; + using InLayout = std::tuple_element_t<3, Tuple>; + using WeiLayout = std::tuple_element_t<2, Tuple>; + using OutLayout = std::tuple_element_t<1, Tuple>; + + using WeiElementOp = ck::tensor_operation::element_wise::PassThrough; + using InElementOp = ck::tensor_operation::element_wise::Scale; + using OutElementOp = ck::tensor_operation::element_wise::PassThrough; + using PassThrough = ck::tensor_operation::element_wise::PassThrough; + + using Scale = ck::tensor_operation::element_wise::Scale; + static constexpr ck::index_t NDimSpatial = 3; + static constexpr float alpha = 2.f; + + std::vector conv_params; + std::vector split_ks{1}; + + void RunReference(ck::utils::conv::ConvParam& conv_param, + Tensor& in_host, + Tensor& wei, + Tensor& out) + { + auto ref_conv = + ck::tensor_operation::host::ReferenceConvBwdData /*Num D Elementwise + Tensors*/ + {}; + + auto ref_invoker = ref_conv.MakeInvoker(); + + auto ref_argument = ref_conv.MakeArgument(in_host, + wei, + out, + conv_param.conv_filter_strides_, + conv_param.conv_filter_dilations_, + conv_param.input_left_pads_, + conv_param.input_right_pads_, + InElementOp{alpha}, + WeiElementOp{}, + OutElementOp{}); + + ref_invoker.Run(ref_argument); + } + + bool PerformConvDataScale(ck::utils::conv::ConvParam& conv_param, const ck::index_t split_k) + { + bool passed = true; + + const auto out_g_n_k_wos_desc = + ck::utils::conv::make_output_host_tensor_descriptor_g_n_k_wos_packed( + conv_param); + + const auto wei_g_k_c_xs_desc = + ck::utils::conv::make_weight_host_tensor_descriptor_g_k_c_xs_packed( + conv_param); + + const auto in_g_n_c_wis_desc = + ck::utils::conv::make_input_host_tensor_descriptor_g_n_c_wis_packed( + conv_param); + + Tensor wei(wei_g_k_c_xs_desc); + Tensor out(out_g_n_k_wos_desc); + Tensor in_host(in_g_n_c_wis_desc); + Tensor in_device(in_g_n_c_wis_desc); + + std::cout << "in: " << in_host.mDesc << std::endl; + std::cout << "wei: " << wei.mDesc << std::endl; + std::cout << "out: " << out.mDesc << std::endl; + + out.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + wei.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + + DeviceMem in_device_buf(sizeof(InDataType) * in_device.mDesc.GetElementSpaceSize()); + DeviceMem wei_device_buf(sizeof(WeiDataType) * wei.mDesc.GetElementSpaceSize()); + DeviceMem out_device_buf(sizeof(OutDataType) * out.mDesc.GetElementSpaceSize()); + + in_device_buf.ToDevice(in_device.mData.data()); + out_device_buf.ToDevice(out.mData.data()); + wei_device_buf.ToDevice(wei.mData.data()); + + std::array out_lengths{}; + std::array out_strides{}; + std::array wei_lengths{}; + std::array wei_strides{}; + std::array in_lengths{}; + std::array in_strides{}; + std::array conv_filter_strides{}; + std::array conv_filter_dilations{}; + std::array input_left_pads{}; + std::array input_right_pads{}; + + auto copy = [](const auto& x, auto& y) { ck::ranges::copy(x, y.begin()); }; + + copy(out_g_n_k_wos_desc.GetLengths(), out_lengths); + copy(out_g_n_k_wos_desc.GetStrides(), out_strides); + copy(wei_g_k_c_xs_desc.GetLengths(), wei_lengths); + copy(wei_g_k_c_xs_desc.GetStrides(), wei_strides); + copy(in_g_n_c_wis_desc.GetLengths(), in_lengths); + copy(in_g_n_c_wis_desc.GetStrides(), in_strides); + copy(conv_param.conv_filter_strides_, conv_filter_strides); + copy(conv_param.conv_filter_dilations_, conv_filter_dilations); + copy(conv_param.input_left_pads_, input_left_pads); + copy(conv_param.input_right_pads_, input_right_pads); + + RunReference(conv_param, in_host, wei, out); + + using DeviceOp = + ck::tensor_operation::device::DeviceGroupedConvBwdDataMultipleD, + InLayout, + OutDataType, + WeiDataType, + ck::Tuple<>, + InDataType, + PassThrough, + PassThrough, + Scale>; + + // get device op instances + const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< + DeviceOp>::GetInstances(); + + int num_kernel = 0; + + for(std::size_t i = 0; i < op_ptrs.size(); ++i) + { + auto& op_ptr = op_ptrs[i]; + auto argument_ptr = op_ptr->MakeArgumentPointer(out_device_buf.GetDeviceBuffer(), + wei_device_buf.GetDeviceBuffer(), + {}, + in_device_buf.GetDeviceBuffer(), + out_lengths, + out_strides, + wei_lengths, + wei_strides, + {}, + {}, + in_lengths, + in_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + PassThrough{}, + PassThrough{}, + Scale{alpha}); + + DeviceMem workspace_buf(op_ptr->GetWorkSpaceSize(argument_ptr.get())); + op_ptr->SetWorkSpacePointer(argument_ptr.get(), workspace_buf.GetDeviceBuffer()); + + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + std::string op_name = op_ptr->GetTypeString(); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + num_kernel++; + float avg_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, true}); + in_device_buf.FromDevice(in_device.mData.data()); + + using ComputeType_ = std::conditional_t; + using ComputeType = + std::conditional_t; + using AccDataType = + std::conditional_t, int32_t, float>; + const ck::index_t num_accums = conv_param.K_; + float max_accumulated_value = + *std::max_element(in_host.mData.begin(), in_host.mData.end()); + + const ck::index_t split_k_for_run = split_k; + // Calculate thresholds + auto rtol = ck::utils::get_relative_threshold( + num_accums / split_k_for_run); + auto atol = ck::utils::get_absolute_threshold( + max_accumulated_value / split_k_for_run, num_accums / split_k_for_run); + // Calculate error due to split_k accumulation + auto rtol_split_k = + ck::utils::get_relative_threshold( + split_k_for_run); + auto atol_split_k = + ck::utils::get_absolute_threshold( + max_accumulated_value, split_k_for_run); + // Use higher threshold + rtol = std::max(rtol, rtol_split_k); + atol = std::max(atol, atol_split_k); + if(split_k_for_run > 1) + { + passed &= ck::utils::check_err( + in_device, in_host, "Error: Incorrect results!", rtol, atol); + std::cout << "Relative error threshold: " << rtol + << " Absolute error threshold: " << atol << std::endl; + } + else + { + passed &= ck::utils::check_err( + in_device, in_host, "Error: Incorrect results!", rtol, atol); + std::cout << "Relative error threshold: " << rtol + << " Absolute error threshold: " << atol << std::endl; + } + std::size_t flop = conv_param.GetFlops() + + 3 * conv_param.GetOutputByte() / sizeof(InDataType); + std::size_t num_bytes = conv_param.GetByte() + + conv_param.GetOutputByte(); + + float tflops = static_cast(flop) / 1.E9 / avg_time; + float gb_per_sec = num_bytes / 1.E6 / avg_time; + + std::cout << "Perf: " << std::setw(10) << avg_time << " ms, " << tflops + << " TFlops, " << gb_per_sec << " GB/s, " << op_name << std::endl; + } + else + { + std::cerr << op_name << " does not support this problem" << std::endl; + } + } + + printf("\033[36mvalids: %d\033[0m\n", num_kernel); + return passed; + } + + void Run() + { + EXPECT_FALSE(conv_params.empty()); + bool pass = true; + + for(auto split_k : split_ks) + { + for(auto& param : conv_params) + { + pass = pass && PerformConvDataScale(param, split_k); + } + } + EXPECT_TRUE(pass); + } +}; + +template +class TestGroupedConvndBwdData3d : public TestGroupedConvndBwdData +{ +}; + +using NDHWGC = ck::tensor_layout::convolution::NDHWGC; +using GKZYXC = ck::tensor_layout::convolution::GKZYXC; +using NDHWGK = ck::tensor_layout::convolution::NDHWGK; + +using KernelTypes3d = ::testing::Types, + std::tuple, + std::tuple>; + +TYPED_TEST_SUITE(TestGroupedConvndBwdData3d, KernelTypes3d); + +TYPED_TEST(TestGroupedConvndBwdData3d, Test3D) +{ + this->conv_params.push_back( + {3, 2, 16, 128, 128, {1, 1, 1}, {7, 7, 7}, {2, 2, 2}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}}); + this->conv_params.push_back( + {3, 2, 2, 128, 128, {3, 3, 3}, {14, 14, 3}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}}); + this->conv_params.push_back( + {3, 2, 32, 128, 128, {1, 1, 1}, {3, 3, 3}, {1, 1, 1}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}}); + this->conv_params.push_back( + {3, 1, 1, 1, 32, {3, 3, 3}, {32, 32, 32}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}}); + this->conv_params.push_back( + {3, 1, 1, 64, 3, {3, 3, 3}, {32, 32, 32}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}}); + this->conv_params.push_back( + {3, 1, 1, 1, 1, {3, 3, 3}, {32, 32, 32}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}}); + this->conv_params.push_back( + {3, 1, 1, 4, 4, {3, 3, 3}, {14, 28, 28}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}}); + this->conv_params.push_back( + {3, 1, 64, 16, 32, {3, 3, 3}, {28, 28, 28}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}}); + + this->Run(); +} diff --git a/test/grouped_convnd_bwd_data/test_grouped_convnd_bwd_data_xdl.cpp b/test/grouped_convnd_bwd_data/test_grouped_convnd_bwd_data.cpp similarity index 94% rename from test/grouped_convnd_bwd_data/test_grouped_convnd_bwd_data_xdl.cpp rename to test/grouped_convnd_bwd_data/test_grouped_convnd_bwd_data.cpp index efedf416f0..846d477973 100644 --- a/test/grouped_convnd_bwd_data/test_grouped_convnd_bwd_data_xdl.cpp +++ b/test/grouped_convnd_bwd_data/test_grouped_convnd_bwd_data.cpp @@ -15,7 +15,7 @@ static ck::index_t param_mask = 0xffffff; static ck::index_t instance_index = -1; template -class TestGroupedConvndBwdDataXdl : public ::testing::Test +class TestGroupedConvndBwdData : public ::testing::Test { protected: using DataType = std::tuple_element_t<0, Tuple>; @@ -89,19 +89,19 @@ using KernelTypes3d = ::testing::Types std::tuple>; template -class TestGroupedConvndBwdDataXdl2d : public TestGroupedConvndBwdDataXdl +class TestGroupedConvndBwdData2d : public TestGroupedConvndBwdData { }; template -class TestGroupedConvndBwdDataXdl3d : public TestGroupedConvndBwdDataXdl +class TestGroupedConvndBwdData3d : public TestGroupedConvndBwdData { }; -TYPED_TEST_SUITE(TestGroupedConvndBwdDataXdl2d, KernelTypes2d); -TYPED_TEST_SUITE(TestGroupedConvndBwdDataXdl3d, KernelTypes3d); +TYPED_TEST_SUITE(TestGroupedConvndBwdData2d, KernelTypes2d); +TYPED_TEST_SUITE(TestGroupedConvndBwdData3d, KernelTypes3d); -TYPED_TEST(TestGroupedConvndBwdDataXdl2d, Test2D) +TYPED_TEST(TestGroupedConvndBwdData2d, Test2D) { this->conv_params.clear(); @@ -137,7 +137,7 @@ TYPED_TEST(TestGroupedConvndBwdDataXdl2d, Test2D) this->template Run<2>(); } -TYPED_TEST(TestGroupedConvndBwdDataXdl3d, Test3D) +TYPED_TEST(TestGroupedConvndBwdData3d, Test3D) { this->conv_params.clear(); this->conv_params.push_back( diff --git a/test/grouped_convnd_bwd_data/test_grouped_convnd_bwd_data_xdl_large_cases.cpp b/test/grouped_convnd_bwd_data/test_grouped_convnd_bwd_data_large_cases.cpp similarity index 91% rename from test/grouped_convnd_bwd_data/test_grouped_convnd_bwd_data_xdl_large_cases.cpp rename to test/grouped_convnd_bwd_data/test_grouped_convnd_bwd_data_large_cases.cpp index ffea6516db..207b085e1a 100644 --- a/test/grouped_convnd_bwd_data/test_grouped_convnd_bwd_data_xdl_large_cases.cpp +++ b/test/grouped_convnd_bwd_data/test_grouped_convnd_bwd_data_large_cases.cpp @@ -12,7 +12,7 @@ #include "profiler/profile_grouped_conv_bwd_data_impl.hpp" template -class TestGroupedConvndBwdDataXdl : public ::testing::Test +class TestGroupedConvndBwdData : public ::testing::Test { protected: using DataType = std::tuple_element_t<0, Tuple>; @@ -80,19 +80,19 @@ using KernelTypes3d = ::testing::Types std::tuple>; template -class TestGroupedConvndBwdDataXdl2d : public TestGroupedConvndBwdDataXdl +class TestGroupedConvndBwdData2d : public TestGroupedConvndBwdData { }; template -class TestGroupedConvndBwdDataXdl3d : public TestGroupedConvndBwdDataXdl +class TestGroupedConvndBwdData3d : public TestGroupedConvndBwdData { }; -TYPED_TEST_SUITE(TestGroupedConvndBwdDataXdl2d, KernelTypes2d); -TYPED_TEST_SUITE(TestGroupedConvndBwdDataXdl3d, KernelTypes3d); +TYPED_TEST_SUITE(TestGroupedConvndBwdData2d, KernelTypes2d); +TYPED_TEST_SUITE(TestGroupedConvndBwdData3d, KernelTypes3d); -TYPED_TEST(TestGroupedConvndBwdDataXdl2d, Test2D) +TYPED_TEST(TestGroupedConvndBwdData2d, Test2D) { this->conv_params.clear(); // SplitN case @@ -101,7 +101,7 @@ TYPED_TEST(TestGroupedConvndBwdDataXdl2d, Test2D) this->template Run<2>(); } -TYPED_TEST(TestGroupedConvndBwdDataXdl3d, Test3D) +TYPED_TEST(TestGroupedConvndBwdData3d, Test3D) { this->conv_params.clear(); // SplitN case diff --git a/test/grouped_convnd_bwd_data/test_grouped_convnd_bwd_data_wmma.cpp b/test/grouped_convnd_bwd_data/test_grouped_convnd_bwd_data_wmma.cpp deleted file mode 100644 index 9becb48be2..0000000000 --- a/test/grouped_convnd_bwd_data/test_grouped_convnd_bwd_data_wmma.cpp +++ /dev/null @@ -1,135 +0,0 @@ -// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. -// SPDX-License-Identifier: MIT - -#include -#include -#include -#include -#include - -#include - -#include "profiler/profile_grouped_conv_bwd_data_impl.hpp" - -static ck::index_t param_mask = 0xffff; -static ck::index_t instance_index = -1; - -template -class TestGroupedConvndBwdDataWmma : public ::testing::Test -{ - protected: - using DataType = std::tuple_element_t<0, Tuple>; - using OutLayout = std::tuple_element_t<1, Tuple>; - using WeiLayout = std::tuple_element_t<2, Tuple>; - using InLayout = std::tuple_element_t<3, Tuple>; - - std::vector conv_params; - - template - void Run() - { - EXPECT_FALSE(conv_params.empty()); - bool pass = true; - for(size_t i = 0; i < conv_params.size(); i++) - { - if((param_mask & (1 << i)) == 0) - { - continue; - } - auto& param = conv_params[i]; - pass = pass && ck::profiler::profile_grouped_conv_bwd_data_impl( - true, // do_verification - 1, // init_method: integer value - false, // do_log - false, // time_kernel - param, - 1, // splitK - instance_index); - } - EXPECT_TRUE(pass); - } -}; - -using namespace ck::tensor_layout::convolution; - -using KernelTypes2d = ::testing::Types, - std::tuple, - std::tuple, - std::tuple>; - -using KernelTypes3d = ::testing::Types, - std::tuple, - std::tuple, - std::tuple>; - -template -class TestGroupedConvndBwdDataWmma2d : public TestGroupedConvndBwdDataWmma -{ -}; - -template -class TestGroupedConvndBwdDataWmma3d : public TestGroupedConvndBwdDataWmma -{ -}; - -TYPED_TEST_SUITE(TestGroupedConvndBwdDataWmma2d, KernelTypes2d); -TYPED_TEST_SUITE(TestGroupedConvndBwdDataWmma3d, KernelTypes3d); - -TYPED_TEST(TestGroupedConvndBwdDataWmma2d, Test2D) -{ - this->conv_params.clear(); - - this->conv_params.push_back( - {2, 2, 4, 192, 192, {3, 3}, {28, 28}, {1, 1}, {1, 1}, {1, 1}, {1, 1}}); - this->conv_params.push_back( - {2, 2, 128, 128, 256, {3, 3}, {14, 14}, {1, 1}, {1, 1}, {1, 1}, {1, 1}}); - this->conv_params.push_back( - {2, 2, 128, 128, 256, {1, 1}, {7, 7}, {2, 2}, {1, 1}, {0, 0}, {0, 0}}); - this->conv_params.push_back( - {2, 2, 128, 128, 256, {1, 1}, {3, 3}, {1, 1}, {1, 1}, {0, 0}, {0, 0}}); - this->conv_params.push_back({2, 1, 1, 1, 32, {8, 8}, {32, 32}, {1, 1}, {1, 1}, {1, 1}, {1, 1}}); - this->conv_params.push_back({2, 1, 1, 64, 3, {8, 8}, {32, 32}, {1, 1}, {1, 1}, {1, 1}, {1, 1}}); - this->conv_params.push_back({2, 1, 1, 1, 1, {8, 8}, {32, 32}, {1, 1}, {1, 1}, {1, 1}, {1, 1}}); - this->template Run<2>(); -} - -TYPED_TEST(TestGroupedConvndBwdDataWmma3d, Test3D) -{ - this->conv_params.clear(); - this->conv_params.push_back( - {3, 2, 16, 128, 256, {1, 1, 1}, {7, 7, 7}, {2, 2, 2}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}}); - this->conv_params.push_back( - {3, 2, 2, 128, 256, {3, 3, 3}, {14, 14, 3}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}}); - this->conv_params.push_back( - {3, 2, 32, 128, 256, {1, 1, 1}, {3, 3, 3}, {1, 1, 1}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}}); - this->conv_params.push_back( - {3, 1, 1, 1, 32, {3, 3, 3}, {32, 32, 32}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}}); - this->conv_params.push_back( - {3, 1, 1, 64, 3, {3, 3, 3}, {32, 32, 32}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}}); - this->conv_params.push_back( - {3, 1, 1, 1, 1, {3, 3, 3}, {32, 32, 32}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}}); - this->template Run<3>(); -} - -int main(int argc, char** argv) -{ - testing::InitGoogleTest(&argc, argv); - if(argc == 1) {} - else if(argc == 3) - { - param_mask = strtol(argv[1], nullptr, 0); - instance_index = atoi(argv[2]); - } - else - { - std::cout << "Usage of " << argv[0] << std::endl; - std::cout << "Arg1,2: param_mask instance_index(-1 means all)" << std::endl; - } - return RUN_ALL_TESTS(); -}