From 334ae1c494b04f9e034d067b4facea0e367ea9ac Mon Sep 17 00:00:00 2001 From: "assistant-librarian[bot]" Date: Thu, 18 Dec 2025 00:34:53 +0000 Subject: [PATCH] Merge commit '87dd073887933fc2c75c234871e3885cee970a98' into develop --- .../20_grouped_conv_bwd_weight/CMakeLists.txt | 7 +- .../grouped_conv_bwd_weight_v3_wmma_bf16.cpp | 100 ++ ... grouped_conv_bwd_weight_v3_wmma_fp16.cpp} | 55 +- .../run_grouped_conv_bwd_weight_example.inc | 18 +- ...tched_gemm_multiple_d_wmma_cshuffle_v3.hpp | 764 ++++++++ ...atched_gemm_multiple_d_xdl_cshuffle_v3.hpp | 5 + .../device_grouped_conv_bwd_weight_dl.hpp | 9 +- ...vice_grouped_conv_bwd_weight_explicit.hpp} | 45 +- ...bwd_weight_multiple_d_wmma_cshuffle_v3.hpp | 1258 +++++++++++++ ..._bwd_weight_two_stage_wmma_cshuffle_v3.hpp | 1578 +++++++++++++++++ ...conv_bwd_weight_two_stage_xdl_cshuffle.hpp | 5 + ..._grouped_conv_bwd_weight_wmma_cshuffle.hpp | 9 +- ...ouped_conv_bwd_weight_wmma_cshuffle_v3.hpp | 1429 +++++++++++++++ .../gridwise_ab_transfer_thread_tiles.hpp | 6 +- .../grid/gridwise_gemm_wmma_cshuffle_v3.hpp | 146 +- ...ridwise_gemm_wmma_cshuffle_v3_ab_scale.hpp | 12 +- .../gridwise_gemm_wmma_cshuffle_v3_common.hpp | 34 +- .../utility/generic_memory_space_atomic.hpp | 23 + ..._bwd_wei_exp_device_operation_instance.hpp | 24 +- ..._gemm_wmma_universal_km_kn_mn_instance.hpp | 138 ++ ...onv_bwd_weight_two_stage_wmma_instance.hpp | 91 + ...ouped_conv_bwd_weight_v3_wmma_instance.hpp | 100 ++ ...conv_bwd_weight_wmma_bilinear_instance.hpp | 97 + ..._grouped_conv_bwd_weight_wmma_instance.hpp | 117 -- ...ed_conv_bwd_weight_wmma_scale_instance.hpp | 96 + .../grouped_convolution_backward_weight.hpp | 153 +- ...d_convolution_backward_weight_bilinear.hpp | 62 + ...volution_backward_weight_explicit_wmma.inc | 171 ++ ...nvolution_backward_weight_explicit_xdl.inc | 72 +- ...uped_convolution_backward_weight_scale.hpp | 62 + ...ouped_convolution_backward_weight_wmma.inc | 120 +- .../grouped_conv1d_bwd_weight/CMakeLists.txt | 2 +- .../grouped_conv2d_bwd_weight/CMakeLists.txt | 9 +- ...nhwgc_gkyxc_nhwgk_bf16_pipev1_instance.cpp | 41 + ...nhwgc_gkyxc_nhwgk_f16_pipev1_instance.cpp} | 28 +- ...t_wmma_nhwgc_gkyxc_nhwgk_bf16_instance.cpp | 38 + ...ht_wmma_nhwgc_gkyxc_nhwgk_f16_instance.cpp | 38 + .../grouped_conv3d_bwd_weight/CMakeLists.txt | 15 +- ...hwc_gkzyxc_gndhwk_f16_1x1s1p0_instance.cpp | 35 - ...dhwc_gkzyxc_gndhwk_i8_1x1s1p0_instance.cpp | 35 - ..._wmma_gndhwc_gkzyxc_gndhwk_i8_instance.cpp | 35 - ...gc_gkzyxc_ndhwgk_bf16_pipev1_instance.cpp} | 26 +- ...wgc_gkzyxc_ndhwgk_f16_pipev1_instance.cpp} | 20 +- ...ma_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp} | 23 +- ...wmma_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp | 15 +- .../CMakeLists.txt | 7 +- ...ear_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp | 50 + ...near_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp | 50 + .../CMakeLists.txt | 7 +- ...ale_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp | 49 + ...cale_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp | 49 + .../grouped_convnd_bwd_weight/CMakeLists.txt | 46 +- ...16_bf16_bf16_exp_comp_default_instance.cpp | 67 + ...bf16_bf16_exp_comp_mnkpadding_instance.cpp | 67 + ...mma_bf16_bf16_bf16_exp_odd_mn_instance.cpp | 67 + ..._f16_f16_f16_exp_comp_default_instance.cpp | 67 + ...6_f16_f16_exp_comp_mnkpadding_instance.cpp | 67 + ...t_wmma_f16_f16_f16_exp_odd_mn_instance.cpp | 67 + ...6_bf16_bf16_exp_comp_default_instance.cpp} | 4 +- ...f16_bf16_exp_comp_mnkpadding_instance.cpp} | 4 +- ...bf16_bf16_exp_mem_v1_default_instance.cpp} | 4 +- ...6_bf16_exp_mem_v1_mnkpadding_instance.cpp} | 4 +- ...bf16_bf16_exp_mem_v2_default_instance.cpp} | 4 +- ...6_bf16_exp_mem_v2_mnkpadding_instance.cpp} | 4 +- ...xdl_bf16_bf16_bf16_exp_odd_m_instance.cpp} | 4 +- ...dl_bf16_bf16_bf16_exp_odd_mn_instance.cpp} | 4 +- ...xdl_bf16_bf16_bf16_exp_odd_n_instance.cpp} | 4 +- ...f16_f16_f16_exp_comp_default_instance.cpp} | 4 +- ..._f16_f16_exp_comp_mnkpadding_instance.cpp} | 4 +- ...6_f16_f16_exp_mem_v1_default_instance.cpp} | 4 +- ...16_f16_exp_mem_v1_mnkpadding_instance.cpp} | 4 +- ...6_f16_f16_exp_mem_v2_default_instance.cpp} | 4 +- ...16_f16_exp_mem_v2_mnkpadding_instance.cpp} | 4 +- ...ht_xdl_f16_f16_f16_exp_odd_m_instance.cpp} | 4 +- ...t_xdl_f16_f16_f16_exp_odd_mn_instance.cpp} | 4 +- ...ht_xdl_f16_f16_f16_exp_odd_n_instance.cpp} | 4 +- .../profile_grouped_conv_bwd_weight_impl.hpp | 37 +- profiler/src/CMakeLists.txt | 8 +- test/grouped_convnd_bwd_weight/CMakeLists.txt | 13 +- .../test_grouped_convnd_bwd_weight.cpp | 38 - ...st_grouped_convnd_bwd_weight_bilinear.cpp} | 30 +- .../test_grouped_convnd_bwd_weight_scale.cpp | 294 +++ 82 files changed, 7696 insertions(+), 622 deletions(-) create mode 100644 example/20_grouped_conv_bwd_weight/grouped_conv_bwd_weight_v3_wmma_bf16.cpp rename example/20_grouped_conv_bwd_weight/{grouped_conv_bwd_weight_wmma_fp16.cpp => grouped_conv_bwd_weight_v3_wmma_fp16.cpp} (56%) create mode 100644 include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_wmma_cshuffle_v3.hpp rename include/ck/tensor_operation/gpu/device/impl/{device_grouped_conv_bwd_weight_explicit_xdl.hpp => device_grouped_conv_bwd_weight_explicit.hpp} (94%) create mode 100644 include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_multiple_d_wmma_cshuffle_v3.hpp create mode 100644 include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_wmma_cshuffle_v3.hpp create mode 100644 include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_wmma_cshuffle_v3.hpp create mode 100644 library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_exp_gemm_wmma_universal_km_kn_mn_instance.hpp create mode 100644 library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_two_stage_wmma_instance.hpp create mode 100644 library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_v3_wmma_instance.hpp create mode 100644 library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_wmma_bilinear_instance.hpp delete mode 100644 library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_wmma_instance.hpp create mode 100644 library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_wmma_scale_instance.hpp create mode 100644 library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight_explicit_wmma.inc create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/wmma/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_two_stage_wmma_nhwgc_gkyxc_nhwgk_bf16_pipev1_instance.cpp rename library/src/tensor_operation_instance/gpu/{grouped_conv3d_bwd_weight/wmma/device_grouped_conv3d_bwd_weight_wmma_gndhwc_gkzyxc_gndhwk_f16_instance.cpp => grouped_conv2d_bwd_weight/wmma/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_two_stage_wmma_nhwgc_gkyxc_nhwgk_f16_pipev1_instance.cpp} (52%) create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/wmma/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_wmma_nhwgc_gkyxc_nhwgk_bf16_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/wmma/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_wmma_nhwgc_gkyxc_nhwgk_f16_instance.cpp delete mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/wmma/device_grouped_conv3d_bwd_weight_wmma_gndhwc_gkzyxc_gndhwk_f16_1x1s1p0_instance.cpp delete mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/wmma/device_grouped_conv3d_bwd_weight_wmma_gndhwc_gkzyxc_gndhwk_i8_1x1s1p0_instance.cpp delete mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/wmma/device_grouped_conv3d_bwd_weight_wmma_gndhwc_gkzyxc_gndhwk_i8_instance.cpp rename library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/wmma/{device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_i8_instance.cpp => ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_two_stage_wmma_ndhwgc_gkzyxc_ndhwgk_bf16_pipev1_instance.cpp} (56%) rename library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/wmma/{device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_f16_1x1s1p0_instance.cpp => ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_two_stage_wmma_ndhwgc_gkzyxc_ndhwgk_f16_pipev1_instance.cpp} (67%) rename library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/wmma/{device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_i8_1x1s1p0_instance.cpp => ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp} (56%) rename library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/wmma/{ => ndhwgc_gkzyxc_ndhwgk}/device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp (68%) create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight_bilinear/wmma/device_grouped_conv3d_bwd_weight_wmma_bilinear_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight_bilinear/wmma/device_grouped_conv3d_bwd_weight_wmma_bilinear_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight_scale/wmma/device_grouped_conv3d_bwd_weight_wmma_scale_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight_scale/wmma/device_grouped_conv3d_bwd_weight_wmma_scale_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_convnd_bwd_weight/explicit_wmma/bf16_bf16_bf16/device_grouped_convnd_bwd_weight_wmma_bf16_bf16_bf16_exp_comp_default_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_convnd_bwd_weight/explicit_wmma/bf16_bf16_bf16/device_grouped_convnd_bwd_weight_wmma_bf16_bf16_bf16_exp_comp_mnkpadding_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_convnd_bwd_weight/explicit_wmma/bf16_bf16_bf16/device_grouped_convnd_bwd_weight_wmma_bf16_bf16_bf16_exp_odd_mn_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_convnd_bwd_weight/explicit_wmma/fp16_fp16_fp16/device_grouped_convnd_bwd_weight_wmma_f16_f16_f16_exp_comp_default_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_convnd_bwd_weight/explicit_wmma/fp16_fp16_fp16/device_grouped_convnd_bwd_weight_wmma_f16_f16_f16_exp_comp_mnkpadding_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_convnd_bwd_weight/explicit_wmma/fp16_fp16_fp16/device_grouped_convnd_bwd_weight_wmma_f16_f16_f16_exp_odd_mn_instance.cpp rename library/src/tensor_operation_instance/gpu/grouped_convnd_bwd_weight/explicit_xdl/bf16_bf16_bf16/{device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_comp_default_instance.cpp => device_grouped_convnd_bwd_weight_xdl_bf16_bf16_bf16_exp_comp_default_instance.cpp} (93%) rename library/src/tensor_operation_instance/gpu/grouped_convnd_bwd_weight/explicit_xdl/bf16_bf16_bf16/{device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_comp_mnkpadding_instance.cpp => device_grouped_convnd_bwd_weight_xdl_bf16_bf16_bf16_exp_comp_mnkpadding_instance.cpp} (93%) rename library/src/tensor_operation_instance/gpu/grouped_convnd_bwd_weight/explicit_xdl/bf16_bf16_bf16/{device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_mem_v1_default_instance.cpp => device_grouped_convnd_bwd_weight_xdl_bf16_bf16_bf16_exp_mem_v1_default_instance.cpp} (93%) rename library/src/tensor_operation_instance/gpu/grouped_convnd_bwd_weight/explicit_xdl/bf16_bf16_bf16/{device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_mem_v1_mnkpadding_instance.cpp => device_grouped_convnd_bwd_weight_xdl_bf16_bf16_bf16_exp_mem_v1_mnkpadding_instance.cpp} (92%) rename library/src/tensor_operation_instance/gpu/grouped_convnd_bwd_weight/explicit_xdl/bf16_bf16_bf16/{device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_mem_v2_default_instance.cpp => device_grouped_convnd_bwd_weight_xdl_bf16_bf16_bf16_exp_mem_v2_default_instance.cpp} (93%) rename library/src/tensor_operation_instance/gpu/grouped_convnd_bwd_weight/explicit_xdl/bf16_bf16_bf16/{device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_mem_v2_mnkpadding_instance.cpp => device_grouped_convnd_bwd_weight_xdl_bf16_bf16_bf16_exp_mem_v2_mnkpadding_instance.cpp} (92%) rename library/src/tensor_operation_instance/gpu/grouped_convnd_bwd_weight/explicit_xdl/bf16_bf16_bf16/{device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_odd_m_instance.cpp => device_grouped_convnd_bwd_weight_xdl_bf16_bf16_bf16_exp_odd_m_instance.cpp} (94%) rename library/src/tensor_operation_instance/gpu/grouped_convnd_bwd_weight/explicit_xdl/bf16_bf16_bf16/{device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_odd_mn_instance.cpp => device_grouped_convnd_bwd_weight_xdl_bf16_bf16_bf16_exp_odd_mn_instance.cpp} (94%) rename library/src/tensor_operation_instance/gpu/grouped_convnd_bwd_weight/explicit_xdl/bf16_bf16_bf16/{device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_odd_n_instance.cpp => device_grouped_convnd_bwd_weight_xdl_bf16_bf16_bf16_exp_odd_n_instance.cpp} (93%) rename library/src/tensor_operation_instance/gpu/grouped_convnd_bwd_weight/explicit_xdl/fp16_fp16_fp16/{device_grouped_convnd_bwd_weight_f16_f16_f16_exp_comp_default_instance.cpp => device_grouped_convnd_bwd_weight_xdl_f16_f16_f16_exp_comp_default_instance.cpp} (93%) rename library/src/tensor_operation_instance/gpu/grouped_convnd_bwd_weight/explicit_xdl/fp16_fp16_fp16/{device_grouped_convnd_bwd_weight_f16_f16_f16_exp_comp_mnkpadding_instance.cpp => device_grouped_convnd_bwd_weight_xdl_f16_f16_f16_exp_comp_mnkpadding_instance.cpp} (93%) rename library/src/tensor_operation_instance/gpu/grouped_convnd_bwd_weight/explicit_xdl/fp16_fp16_fp16/{device_grouped_convnd_bwd_weight_f16_f16_f16_exp_mem_v1_default_instance.cpp => device_grouped_convnd_bwd_weight_xdl_f16_f16_f16_exp_mem_v1_default_instance.cpp} (93%) rename library/src/tensor_operation_instance/gpu/grouped_convnd_bwd_weight/explicit_xdl/fp16_fp16_fp16/{device_grouped_convnd_bwd_weight_f16_f16_f16_exp_mem_v1_mnkpadding_instance.cpp => device_grouped_convnd_bwd_weight_xdl_f16_f16_f16_exp_mem_v1_mnkpadding_instance.cpp} (93%) rename library/src/tensor_operation_instance/gpu/grouped_convnd_bwd_weight/explicit_xdl/fp16_fp16_fp16/{device_grouped_convnd_bwd_weight_f16_f16_f16_exp_mem_v2_default_instance.cpp => device_grouped_convnd_bwd_weight_xdl_f16_f16_f16_exp_mem_v2_default_instance.cpp} (93%) rename library/src/tensor_operation_instance/gpu/grouped_convnd_bwd_weight/explicit_xdl/fp16_fp16_fp16/{device_grouped_convnd_bwd_weight_f16_f16_f16_exp_mem_v2_mnkpadding_instance.cpp => device_grouped_convnd_bwd_weight_xdl_f16_f16_f16_exp_mem_v2_mnkpadding_instance.cpp} (93%) rename library/src/tensor_operation_instance/gpu/grouped_convnd_bwd_weight/explicit_xdl/fp16_fp16_fp16/{device_grouped_convnd_bwd_weight_f16_f16_f16_exp_odd_m_instance.cpp => device_grouped_convnd_bwd_weight_xdl_f16_f16_f16_exp_odd_m_instance.cpp} (94%) rename library/src/tensor_operation_instance/gpu/grouped_convnd_bwd_weight/explicit_xdl/fp16_fp16_fp16/{device_grouped_convnd_bwd_weight_f16_f16_f16_exp_odd_mn_instance.cpp => device_grouped_convnd_bwd_weight_xdl_f16_f16_f16_exp_odd_mn_instance.cpp} (94%) rename library/src/tensor_operation_instance/gpu/grouped_convnd_bwd_weight/explicit_xdl/fp16_fp16_fp16/{device_grouped_convnd_bwd_weight_f16_f16_f16_exp_odd_n_instance.cpp => device_grouped_convnd_bwd_weight_xdl_f16_f16_f16_exp_odd_n_instance.cpp} (93%) rename test/grouped_convnd_bwd_weight/{test_grouped_conv_bwd_weight_xdl_bilinear.cpp => test_grouped_convnd_bwd_weight_bilinear.cpp} (89%) create mode 100644 test/grouped_convnd_bwd_weight/test_grouped_convnd_bwd_weight_scale.cpp diff --git a/example/20_grouped_conv_bwd_weight/CMakeLists.txt b/example/20_grouped_conv_bwd_weight/CMakeLists.txt index 2e381b09d3..a787a5c1fd 100644 --- a/example/20_grouped_conv_bwd_weight/CMakeLists.txt +++ b/example/20_grouped_conv_bwd_weight/CMakeLists.txt @@ -11,8 +11,11 @@ add_example_dependencies(example_grouped_conv_bwd_weight example_grouped_conv_bw add_example_executable(example_grouped_conv_bwd_weight_xdl_fp16_comp_bf8_fp8 grouped_conv_bwd_weight_xdl_fp16_comp_bf8_fp8.cpp) add_example_dependencies(example_grouped_conv_bwd_weight example_grouped_conv_bwd_weight_xdl_fp16_comp_bf8_fp8) -add_example_executable(example_grouped_conv_bwd_weight_wmma_fp16 grouped_conv_bwd_weight_wmma_fp16.cpp) -add_example_dependencies(example_grouped_conv_bwd_weight example_grouped_conv_bwd_weight_wmma_fp16) +add_example_executable(example_grouped_conv_bwd_weight_v3_wmma_fp16 grouped_conv_bwd_weight_v3_wmma_fp16.cpp) +add_example_dependencies(example_grouped_conv_bwd_weight example_grouped_conv_bwd_weight_v3_wmma_fp16) + +add_example_executable(example_grouped_conv_bwd_weight_v3_wmma_bf16 grouped_conv_bwd_weight_v3_wmma_bf16.cpp) +add_example_dependencies(example_grouped_conv_bwd_weight example_grouped_conv_bwd_weight_v3_wmma_bf16) add_example_executable(example_grouped_conv_bwd_weight_dl_fp16 grouped_conv_bwd_weight_dl_fp16.cpp) add_example_dependencies(example_grouped_conv_bwd_weight example_grouped_conv_bwd_weight_dl_fp16) diff --git a/example/20_grouped_conv_bwd_weight/grouped_conv_bwd_weight_v3_wmma_bf16.cpp b/example/20_grouped_conv_bwd_weight/grouped_conv_bwd_weight_v3_wmma_bf16.cpp new file mode 100644 index 0000000000..9c76a73b7e --- /dev/null +++ b/example/20_grouped_conv_bwd_weight/grouped_conv_bwd_weight_v3_wmma_bf16.cpp @@ -0,0 +1,100 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "common.hpp" + +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_wmma_cshuffle_v3.hpp" + +using InDataType = BF16; +// bf16 kernel use fp32 atomic add to accumulate Weight tensor into global memory +using WeiDataType = F32; +using OutDataType = BF16; +using AccDataType = F32; + +using InElementOp = PassThrough; +using WeiElementOp = PassThrough; +using OutElementOp = PassThrough; + +template +using DeviceConvBwdWeightInstance = + ck::tensor_operation::device::DeviceGroupedConvBwdWeight_Wmma_CShuffleV3< + NDimSpatial, + ck::tuple_element_t>, + ck::tuple_element_t>, + ck::tuple_element_t>, + InDataType, // InDataType + WeiDataType, // WeiDataType + OutDataType, // OutDataType + AccDataType, // AccDataType + InElementOp, // InElementwiseOperation + WeiElementOp, // WeiElementwiseOperation + OutElementOp, // OutElementwiseOperation + ConvBwdWeightDefault, // ConvolutionBackwardWeightSpecialization + 256, // BlockSize + 128, // MPerBlock + 128, // NPerBlock + 32, // KPerBlock + 8, // K1 + 16, // MPerWmma + 16, // NPerWmma + 4, // MRepeat + 2, // NRepeat + S<4, 16, 1>, // ABlockTransferThreadClusterLengths_K0_M_K1 + S<2, 0, 1>, // ABlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // ABlockTransferSrcAccessOrder + 1, // ABlockTransferSrcVectorDim + 1, // ABlockTransferSrcScalarPerVector + 2, // ABlockTransferDstScalarPerVector_K1 + true, // ABlockLdsAddExtraM + S<4, 16, 1>, // BBlockTransferThreadClusterLengths_K0_N_K1 + S<2, 0, 1>, // BBlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // BBlockTransferSrcAccessOrder + 1, // BBlockTransferSrcVectorDim + 1, // BBlockTransferSrcScalarPerVector + 2, // BBlockTransferDstScalarPerVector_K1 + true, // BBlockLdsAddExtraN + 1, // CShuffleMRepeatPerShuffle + 1, // CShuffleNRepeatPerShuffle + S<1, 32, 1, 4>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock + 4>; // CShuffleBlockTransferScalarPerVector_NPerBlock + +template +using HostConvBwdWeightInstance = ck::tensor_operation::host::ReferenceConvBwdWeight; + +#include "run_grouped_conv_bwd_weight_example.inc" + +int main(int argc, char* argv[]) +{ + ExecutionConfig config; + ck::utils::conv::ConvParam conv_param = DefaultConvParam; + + if(!parse_cmd_args(argc, argv, config, conv_param)) + { + return 1; + } + + switch(conv_param.num_dim_spatial_) + { + case 1: return !run_grouped_conv_bwd_weight<1>(config, conv_param); + case 2: return !run_grouped_conv_bwd_weight<2>(config, conv_param); + case 3: return !run_grouped_conv_bwd_weight<3>(config, conv_param); + default: break; + } + + return 1; +} diff --git a/example/20_grouped_conv_bwd_weight/grouped_conv_bwd_weight_wmma_fp16.cpp b/example/20_grouped_conv_bwd_weight/grouped_conv_bwd_weight_v3_wmma_fp16.cpp similarity index 56% rename from example/20_grouped_conv_bwd_weight/grouped_conv_bwd_weight_wmma_fp16.cpp rename to example/20_grouped_conv_bwd_weight/grouped_conv_bwd_weight_v3_wmma_fp16.cpp index a8e9c49d87..f0e2fa0b9d 100644 --- a/example/20_grouped_conv_bwd_weight/grouped_conv_bwd_weight_wmma_fp16.cpp +++ b/example/20_grouped_conv_bwd_weight/grouped_conv_bwd_weight_v3_wmma_fp16.cpp @@ -3,7 +3,7 @@ #include "common.hpp" -#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_wmma_cshuffle.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_wmma_cshuffle_v3.hpp" using InDataType = F16; using WeiDataType = F16; @@ -16,11 +16,20 @@ using OutElementOp = PassThrough; template using DeviceConvBwdWeightInstance = - ck::tensor_operation::device::DeviceGroupedConvBwdWeight_Wmma_CShuffle< + ck::tensor_operation::device::DeviceGroupedConvBwdWeight_Wmma_CShuffleV3< NDimSpatial, - ck::tensor_layout::convolution::GNDHWC, - ck::tensor_layout::convolution::GKZYXC, - ck::tensor_layout::convolution::GNDHWK, + ck::tuple_element_t>, + ck::tuple_element_t>, + ck::tuple_element_t>, InDataType, // InDataType WeiDataType, // WeiDataType OutDataType, // OutDataType @@ -32,30 +41,30 @@ using DeviceConvBwdWeightInstance = 256, // BlockSize 128, // MPerBlock 128, // NPerBlock - 4, // K0PerBlock + 32, // KPerBlock 8, // K1 - 16, // MPerWMMA - 16, // NPerWMMA + 16, // MPerWmma + 16, // NPerWmma 4, // MRepeat 2, // NRepeat - S<4, 64, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1 - S<0, 2, 1>, // ABlockTransferThreadClusterArrangeOrder - S<0, 2, 1>, // ABlockTransferSrcAccessOrder + S<4, 16, 1>, // ABlockTransferThreadClusterLengths_K0_M_K1 + S<2, 0, 1>, // ABlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // ABlockTransferSrcAccessOrder 1, // ABlockTransferSrcVectorDim 1, // ABlockTransferSrcScalarPerVector - 8, // ABlockTransferDstScalarPerVector_AK1 - true, // ABlockLdsExtraM - S<4, 64, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1 - S<0, 2, 1>, // BBlockTransferThreadClusterArrangeOrder - S<0, 2, 1>, // BBlockTransferSrcAccessOrder + 2, // ABlockTransferDstScalarPerVector_K1 + false, // ABlockLdsAddExtraM + S<4, 16, 1>, // BBlockTransferThreadClusterLengths_K0_N_K1 + S<2, 0, 1>, // BBlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // BBlockTransferSrcAccessOrder 1, // BBlockTransferSrcVectorDim 1, // BBlockTransferSrcScalarPerVector - 8, // BBlockTransferDstScalarPerVector_BK1 - true, // BBlockLdsExtraN - 4, - 2, - S<1, 32, 1, 8>, - 1>; + 2, // BBlockTransferDstScalarPerVector_K1 + false, // BBlockLdsAddExtraN + 1, // CShuffleMRepeatPerShuffle + 1, // CShuffleNRepeatPerShuffle + S<1, 32, 1, 4>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock + 4>; // CShuffleBlockTransferScalarPerVector_NPerBlock template using HostConvBwdWeightInstance = ck::tensor_operation::host::ReferenceConvBwdWeight(config, conv_param); + case 2: return !run_grouped_conv_bwd_weight<2>(config, conv_param); case 3: return !run_grouped_conv_bwd_weight<3>(config, conv_param); default: break; } diff --git a/example/20_grouped_conv_bwd_weight/run_grouped_conv_bwd_weight_example.inc b/example/20_grouped_conv_bwd_weight/run_grouped_conv_bwd_weight_example.inc index bc1a5edac6..8cc9f582eb 100644 --- a/example/20_grouped_conv_bwd_weight/run_grouped_conv_bwd_weight_example.inc +++ b/example/20_grouped_conv_bwd_weight/run_grouped_conv_bwd_weight_example.inc @@ -5,7 +5,7 @@ template bool run_grouped_conv_bwd_weight(const ExecutionConfig& config, const ck::utils::conv::ConvParam& conv_param) { - // Dl and WMMA ops don't support split_k > 1 + // Dl ops don't support split_k > 1 constexpr ck::index_t split_k = 1; const auto in_g_n_c_wis_desc = @@ -131,7 +131,21 @@ bool run_grouped_conv_bwd_weight(const ExecutionConfig& config, wei_device_buf.FromDevice(wei_device_result.mData.data()); - return ck::utils::check_err(wei_device_result.mData, wei_host_result.mData); + float max_accumulated_value = + *std::max_element(wei_host_result.mData.begin(), wei_host_result.mData.end()); + + const ck::index_t num_accums = out.GetElementSize() / conv_param.K_; + const ck::index_t num_accums_split_k = split_k; + double rtol = ck::utils::get_relative_threshold( + num_accums / num_accums_split_k); + double atol = ck::utils::get_absolute_threshold( + max_accumulated_value / num_accums_split_k, num_accums / num_accums_split_k); + + return ck::utils::check_err(wei_device_result.mData, + wei_host_result.mData, + "Error: Incorrect results!", + rtol, + atol); } else if(config.do_verification == 2) { diff --git a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_wmma_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_wmma_cshuffle_v3.hpp new file mode 100644 index 0000000000..2a1a210398 --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_wmma_cshuffle_v3.hpp @@ -0,0 +1,764 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#include +#include + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_batched_gemm_multi_d.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" +#include "ck/host_utility/flush_cache.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS +__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) +#endif + kernel_batched_gemm_multi_d_wmma_cshuffle_v3( + typename GridwiseGemm::Argument karg, // This works for now but it actually receives a + // DeviceBatchedGemm_Wmma_CShuffleV3::Argument + // argument through implicit conversion to base class! + const ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_batch) +{ +#if(defined(__gfx11__) || defined(__gfx12__)) +#if defined(__gfx11__) + // gfx11 does not support *_atomic_pk_add_f16/bf16 instructions + using EDataType = remove_cvref_t>; + if constexpr(!(EGlobalMemoryDataOperation == InMemoryDataOperationEnum::AtomicAdd && + (std::is_same_v || + std::is_same_v))) + { +#endif + // The normal approach to batching would be to increase the grid size by just stretching out + // the grid Z dimension (which is the outermost dimension), but this depends on lower level + // functions not directly using the Z dimension for other calculations. As it turns out, k + // batching does rely directly on blockIdx.Z through SplitKBatchOffset. Therefore, for now + // we will use the grid Y dimension for batching. This may be a bit fragile. + const index_t g_idx = amd_wave_read_first_lane(blockIdx.y); + + const long_index_t a_batch_offset = + amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx)); + const long_index_t b_batch_offset = + amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx)); + const auto ds_batch_offset = compute_ptr_offset_of_batch.GetDsPtrOffset(g_idx); + const long_index_t c_batch_offset = + amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetCPtrOffset(g_idx)); + + constexpr index_t LDS_size = GridwiseGemm::template GetSharedMemoryNumberOfByte< + typename GridwiseGemm::EpilogueCShuffle>(); + __shared__ char p_shared[LDS_size]; + + auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z); + + static_for<0, GridwiseGemm::NumATensor, 1>{}( + [&](auto i) { splitk_batch_offset.a_k_split_offset[i] += a_batch_offset; }); + + static_for<0, GridwiseGemm::NumBTensor, 1>{}( + [&](auto i) { splitk_batch_offset.b_k_split_offset[i] += b_batch_offset; }); + + splitk_batch_offset.c_reduce_offset += c_batch_offset; + + // populate pointer, desc for Ds + static_for<0, GridwiseGemm::NumDTensor, 1>{}([&](auto i) { + // D pointer + karg.p_ds_grid(i) = karg.p_ds_grid(i) + ds_batch_offset[i]; + }); + + auto epilogue_args = typename GridwiseGemm::EpilogueCShuffle{}; + + GridwiseGemm::template Run( + p_shared, splitk_batch_offset, karg, epilogue_args); +#if defined(__gfx11__) + } +#endif +#else + ignore = karg; + ignore = compute_ptr_offset_of_batch; +#endif +} + +template +struct DeviceBatchedGemmMultiD_Wmma_CShuffleV3 + : public DeviceBatchedGemmV2MultiD +{ + using CDEShuffleBlockTransferScalarPerVectors_ = CDEShuffleBlockTransferScalarPerVectors; + using CDataType_ = EDataType; + // GridwiseGemm + using GridwiseGemm = GridwiseGemm_wmma_cshuffle_v3< + ALayout, + BLayout, + DsLayout, + ELayout, + Tuple, + Tuple, + AccDataType, + CShuffleDataType, + DsDataType, + EDataType, + AElementwiseOperation, + BElementwiseOperation, + CDEElementwiseOperation, + GemmSpec, + BlockSize, + MPerBlock, + NPerBlock, + KPerBlock, + AK1, + BK1, + MPerWmma, + NPerWmma, + MRepeat, + NRepeat, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + false, + ABlockLdsExtraM, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + false, + BBlockLdsExtraN, + CShuffleMRepeatPerShuffle, + CShuffleNRepeatPerShuffle, + CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + CDEShuffleBlockTransferScalarPerVectors, + BlkGemmPipeSched, + BlkGemmPipelineVer, + ComputeTypeA, + ComputeTypeB, + false, + false>; + + struct ComputePtrOffsetOfStridedBatch + { + ComputePtrOffsetOfStridedBatch() = default; + ComputePtrOffsetOfStridedBatch( + index_t BatchStrideA, + index_t BatchStrideB, + std::array BatchStrideDs, + index_t BatchStrideC) + : BatchStrideA_(BatchStrideA), + BatchStrideB_(BatchStrideB), + BatchStrideDs_(BatchStrideDs), + BatchStrideC_(BatchStrideC) + { + } + + __host__ __device__ constexpr long_index_t GetAPtrOffset(index_t g_idx) const + { + return static_cast(BatchStrideA_) * g_idx; + } + + __host__ __device__ constexpr long_index_t GetBPtrOffset(index_t g_idx) const + { + return static_cast(BatchStrideB_) * g_idx; + } + + __host__ __device__ constexpr auto GetDsPtrOffset(index_t g_idx) const + { + std::array ds_offset_; + + static_for<0, GridwiseGemm::NumDTensor, 1>{}([&](auto i) { + ds_offset_[i] = static_cast(BatchStrideDs_[i]) * g_idx; + }); + + return ds_offset_; + } + + __host__ __device__ constexpr long_index_t GetCPtrOffset(index_t g_idx) const + { + return static_cast(BatchStrideC_) * g_idx; + } + + private: + index_t BatchStrideA_; + index_t BatchStrideB_; + std::array BatchStrideDs_; + index_t BatchStrideC_; + }; + + struct Argument : public GridwiseGemm::Argument + { + index_t Batch; + ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_batch; + + Argument() = default; + Argument(const ADataType* p_a_grid_, + const BDataType* p_b_grid_, + std::array p_ds_grid_, + EDataType* p_e_grid_, + index_t M_, + index_t N_, + index_t K_, + index_t StrideA_, + index_t StrideB_, + std::array StrideDs_, + index_t StrideE_, + index_t BatchStrideA_, + index_t BatchStrideB_, + const std::array& BatchStrideDs_, + index_t BatchStrideE_, + index_t Batch_, + AElementwiseOperation a_element_op_, + BElementwiseOperation b_element_op_, + CDEElementwiseOperation cde_element_op_, + index_t KBatch_) + : GridwiseGemm::Argument{std::array{p_a_grid_}, + std::array{p_b_grid_}, + p_ds_grid_, + p_e_grid_, + M_, + N_, + K_, + std::array{StrideA_}, + std::array{StrideB_}, + StrideDs_, + StrideE_, + KBatch_, + a_element_op_, + b_element_op_, + cde_element_op_, + false}, + Batch{Batch_}, + compute_ptr_offset_of_batch{ + BatchStrideA_, BatchStrideB_, BatchStrideDs_, BatchStrideE_} + { + } + template + void SetEPointer(void* ptr) + { + this->p_e_grid = static_cast(ptr); + } + }; + + struct ActiveWorkgroupsPerCU + { + ActiveWorkgroupsPerCU() + { + constexpr int dynamic_smem_size = 0; + int max_occupancy = 0; + + constexpr index_t minimum_occupancy = []() { + if constexpr(BlkGemmPipeSched == BlockGemmPipelineScheduler::Interwave) + { + return 2; + } + else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v3) + { + return (MPerBlock * NPerBlock / BlockSize <= 128) ? 2 : 1; + } + else + { + return 1; + } + }(); + + hip_check_error(hipOccupancyMaxActiveBlocksPerMultiprocessor( + &max_occupancy, + kernel_batched_gemm_multi_d_wmma_cshuffle_v3, + BlockSize, + dynamic_smem_size)); + + max_occupancy_ = std::max(1, max_occupancy); + } + int max_occupancy_; + }; + + // Invoker + struct Invoker : public BaseInvoker + { + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + if(stream_config.log_level_ > 0) + { + arg.Print(); + } + + if(!GridwiseGemm::CheckValidity(arg)) + { + throw std::runtime_error("wrong! GridwiseGemm has invalid setting"); + } + + index_t gdx, gdy, gdz; + std::tie(gdx, gdy, gdz) = GridwiseGemm::CalculateGridSize(arg.M, arg.N, arg.KBatch); + + gdy *= arg.Batch; + + float ave_time = 0; + + index_t k_grain = arg.KBatch * KPerBlock; + index_t K_split = (arg.K + k_grain - 1) / k_grain * KPerBlock; + + const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K_split); + + const auto Run = [&](const auto& kernel) { + if(stream_config.flush_cache) + { + + Argument arg_ = arg; + + const auto a_grid_desc_ak0_m_ak1 = GridwiseGemm::MakeAsGridDescriptor_AK0_M_AK1( + arg_.M, arg_.MPadded, arg_.K, arg_.KPadded, arg_.StrideAs, arg_.AK0); + const auto b_grid_desc_bk0_n_bk1 = GridwiseGemm::MakeBsGridDescriptor_BK0_N_BK1( + arg_.K, arg_.KPadded, arg_.N, arg_.NPadded, arg_.StrideBs, arg_.BK0); + + // Packed sizes are 1 for all implemented data types but we include it anyway + // for future compatibility. + std::array size_as_buffers; + size_as_buffers[0] = arg_.Batch * + a_grid_desc_ak0_m_ak1[Number<0>{}].GetElementSpaceSize() * + sizeof(ADataType) / GridwiseGemm::APackedSize; + + std::array size_bs_buffers; + size_bs_buffers[0] = arg_.Batch * + b_grid_desc_bk0_n_bk1[Number<0>{}].GetElementSpaceSize() * + sizeof(BDataType) / GridwiseGemm::BPackedSize; + + const auto ds_grid_desc_m_n = GridwiseGemm::MakeDsGridDescriptor_M_N( + arg_.M, arg_.MPadded, arg_.N, arg_.NPadded, arg_.StrideDs); + + std::array size_ds_buffers; + static_for<0, GridwiseGemm::NumDTensor, 1>{}([&](auto i) { + using DDataType = remove_cvref_t>; + size_ds_buffers[i] = + ds_grid_desc_m_n[i].GetElementSpaceSize() * sizeof(DDataType); + }); + ck::utility::RotatingMemWrapperMultiABD, + Tuple, + DsDataType> + rotating_mem(arg_, + stream_config.rotating_count, + size_as_buffers, + size_bs_buffers, + size_ds_buffers); + rotating_mem.Print(); + + auto run_flush_cache = [&]() { + // flush icache + ck::utility::flush_icache(); + // rotating mem + rotating_mem.Next(); + // clear c mem + if(arg_.KBatch > 1) + HIP_CHECK_ERROR( + hipMemsetAsync(arg_.p_e_grid, + 0, + arg.Batch * arg_.M * arg_.N * sizeof(EDataType), + stream_config.stream_id_)); + }; + + ave_time = ck::utility::launch_and_time_kernel_with_preprocess( + stream_config, + run_flush_cache, + kernel, + dim3(gdx, gdy, gdz), + dim3(BlockSize), + 0, + arg_, + arg_.compute_ptr_offset_of_batch); + } + else + { + const auto clear_workspace = [&]() { + if(arg.KBatch > 1) + HIP_CHECK_ERROR( + hipMemsetAsync(arg.p_e_grid, + 0, + arg.Batch * arg.M * arg.N * sizeof(EDataType), + stream_config.stream_id_)); + }; + + ave_time = + launch_and_time_kernel_with_preprocess(stream_config, + clear_workspace, + kernel, + dim3(gdx, gdy, gdz), + dim3(BlockSize), + 0, + arg, + arg.compute_ptr_offset_of_batch); + } + }; + + constexpr index_t minimum_occupancy = []() { + if constexpr(BlkGemmPipeSched == BlockGemmPipelineScheduler::Interwave) + { + return 2; + } + else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v3) + { + return (MPerBlock * NPerBlock / BlockSize <= 128) ? 2 : 1; + } + else + { + return 1; + } + }(); + + if(has_main_k_block_loop) + { + // Tail number always full + if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1 || + BlkGemmPipelineVer == BlockGemmPipelineVersion::v3) + { + if(arg.KBatch > 1) + { + const auto kernel = kernel_batched_gemm_multi_d_wmma_cshuffle_v3< + GridwiseGemm, + ComputePtrOffsetOfStridedBatch, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy>; + Run(kernel); + } + else + { + const auto kernel = kernel_batched_gemm_multi_d_wmma_cshuffle_v3< + GridwiseGemm, + ComputePtrOffsetOfStridedBatch, + true, + InMemoryDataOperationEnum::Set, + minimum_occupancy>; + Run(kernel); + } + } + } + else + { + // Tail number always 1 + if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1) + { + if(arg.KBatch > 1) + { + const auto kernel = kernel_batched_gemm_multi_d_wmma_cshuffle_v3< + GridwiseGemm, + ComputePtrOffsetOfStridedBatch, + false, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy>; + Run(kernel); + } + else + { + const auto kernel = kernel_batched_gemm_multi_d_wmma_cshuffle_v3< + GridwiseGemm, + ComputePtrOffsetOfStridedBatch, + false, + InMemoryDataOperationEnum::Set, + minimum_occupancy>; + Run(kernel); + } + } + } + + return ave_time; + } + + // polymorphic + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } + + static bool IsSupportedArgument(const Argument& arg) + { + if(!ck::is_gfx11_supported() && !ck::is_gfx12_supported()) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Unsupported: Architecture must be gfx11/gfx12." << std::endl; + } + return false; + } + + if constexpr(std::is_same_v || + std::is_same_v) + { + if(arg.KBatch > 1 && ck::is_gfx11_supported()) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Unsupported splitK on gfx11." << std::endl; + } + // gfx11 does not support *_atomic_pk_add_f16/bf16 instructions + return false; + } + } + + if constexpr(std::is_same_v || std::is_same_v || + std::is_same_v || std::is_same_v) + { + if(ck::is_gfx11_supported()) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Unsupported f8 / bf8 on gfx11." << std::endl; + } + return false; + } + } + + if((arg.K % AK1 != 0 || arg.K % BK1 != 0) && !(GemmSpec == GemmSpecialization::MKPadding || + GemmSpec == GemmSpecialization::NKPadding || + GemmSpec == GemmSpecialization::MNKPadding || + GemmSpec == GemmSpecialization::KPadding)) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Unsupported K dimension without padding." << std::endl; + } + return false; + } + + return GridwiseGemm::CheckValidity(arg); + } + + // polymorphic + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static auto MakeArgument(const void* p_a, + const void* p_b, + std::array p_ds, + void* p_e, + index_t M, + index_t N, + index_t K, + index_t Batch, + index_t StrideA, + index_t StrideB, + std::array StrideDs, + index_t StrideE, + index_t BatchStrideA, + index_t BatchStrideB, + const std::array& BatchStrideDs, + index_t BatchStrideE, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation cde_element_op, + index_t KBatch = 1) + { + return Argument{static_cast(p_a), + static_cast(p_b), + p_ds, + static_cast(p_e), + M, + N, + K, + StrideA, + StrideB, + StrideDs, + StrideE, + BatchStrideA, + BatchStrideB, + BatchStrideDs, + BatchStrideE, + Batch, + a_element_op, + b_element_op, + cde_element_op, + KBatch}; + } + + static auto MakeInvoker() { return Invoker{}; } + + // polymorphic + std::unique_ptr + MakeArgumentPointer(const void* p_a, + const void* p_b, + const std::array& p_ds, + void* p_e, + index_t M, + index_t N, + index_t K, + index_t Batch, + index_t StrideA, + index_t StrideB, + const std::array& StrideDs, + index_t StrideE, + index_t BatchStrideA, + index_t BatchStrideB, + const std::array& BatchStrideDs, + index_t BatchStrideE, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation cde_element_op, + index_t KBatch = 1) override + { + return std::make_unique(static_cast(p_a), + static_cast(p_b), + p_ds, + static_cast(p_e), + M, + N, + K, + StrideA, + StrideB, + StrideDs, + StrideE, + BatchStrideA, + BatchStrideB, + BatchStrideDs, + BatchStrideE, + Batch, + a_element_op, + b_element_op, + cde_element_op, + KBatch); + } + + // polymorphic + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + // polymorphic + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + std::map BlkGemmPipelineSchedulerToString{ + {BlockGemmPipelineScheduler::Intrawave, "Intrawave"}, + {BlockGemmPipelineScheduler::Interwave, "Interwave"}}; + + std::map BlkGemmPipelineVersionToString{ + {BlockGemmPipelineVersion::v1, "v1"}, + {BlockGemmPipelineVersion::v2, "v2"}, + {BlockGemmPipelineVersion::v3, "v3"}, + {BlockGemmPipelineVersion::v4, "v4"}, + {BlockGemmPipelineVersion::v5, "v5"}}; + + // clang-format off + str << "DeviceBatchedGemmMultipleD_Wmma_CShuffleV3" + << "<" + << getGemmSpecializationString(GemmSpec) << ", " + << std::string(ALayout::name)[0] + << std::string(BLayout::name)[0] + << std::string(ELayout::name)[0] + << ">" + << " BlkSize: " + << BlockSize << ", " + << "BlkTile: " + << MPerBlock<<"x"< + void SetEPointer(void* ptr) + { + this->p_c_grid = static_cast(ptr); + } }; using Argument = ArgumentBase; diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_dl.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_dl.hpp index 2152a72105..b52502eb45 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_dl.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_dl.hpp @@ -18,6 +18,7 @@ #include "ck/tensor_operation/gpu/device/matrix_padder.hpp" #include "ck/host_utility/device_prop.hpp" #include "ck/host_utility/kernel_launch.hpp" +#include "ck/tensor_operation/gpu/device/impl/split_k_arg.hpp" namespace ck { namespace tensor_operation { @@ -807,7 +808,7 @@ struct DeviceGroupedConvBwdWeight_Dl : public DeviceGroupedConvBwdWeight( a_g_n_c_wis_lengths, // input @@ -915,7 +917,6 @@ struct DeviceGroupedConvBwdWeight_Dl : public DeviceGroupedConvBwdWeight& conv_filter_dilations_; const std::array& input_left_pads_; const std::array& input_right_pads_; - index_t k_batch_; }; // Invoker diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_explicit_xdl.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_explicit.hpp similarity index 94% rename from include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_explicit_xdl.hpp rename to include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_explicit.hpp index eea8640151..640b373b66 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_explicit_xdl.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_explicit.hpp @@ -32,7 +32,7 @@ template -struct DeviceGroupedConvBwdWeight_Explicit_Xdl +struct DeviceGroupedConvBwdWeight_Explicit : public DeviceGroupedConvBwdWeight; - struct Argument : public BaseArgument + struct Argument : public BaseArgument, public ArgumentSplitK { using GemmArgument = typename DeviceGemmV3Op::Argument; @@ -153,11 +153,11 @@ struct DeviceGroupedConvBwdWeight_Explicit_Xdl std::tie(gdx, gdy, gdz) = DeviceGemmV3Op::GridwiseGemm::CalculateGridSize(M, N, BatchSize); const index_t grid_size = gdx * gdy * gdz; - split_k_ = get_best_occupancy_k_batch_value(max_occupancy, grid_size); + k_batch_ = get_best_occupancy_k_batch_value(max_occupancy, grid_size); } else { - split_k_ = split_k; + k_batch_ = split_k; } } else @@ -170,12 +170,12 @@ struct DeviceGroupedConvBwdWeight_Explicit_Xdl std::tie(gdx, gdy, gdz) = DeviceGemmV3Op::GridwiseGemm::CalculateGridSize(M, N, BatchSize); const index_t grid_size = gdx * gdy * gdz; - split_k_ = get_best_occupancy_k_batch_value(max_occupancy, grid_size); + k_batch_ = get_best_occupancy_k_batch_value(max_occupancy, grid_size); } else #endif { - split_k_ = split_k; + k_batch_ = split_k; } } @@ -213,7 +213,7 @@ struct DeviceGroupedConvBwdWeight_Explicit_Xdl out_element_op, in_element_op, wei_element_op, - split_k_}; + k_batch_}; } else { @@ -236,7 +236,7 @@ struct DeviceGroupedConvBwdWeight_Explicit_Xdl out_element_op, in_element_op, wei_element_op, - split_k_}; + k_batch_}; } } @@ -273,7 +273,6 @@ struct DeviceGroupedConvBwdWeight_Explicit_Xdl bool is_filter_data_packed; CElementwiseGridDesc elementwise_desc_; Block2TileMapElementwise elementwise_block_2_ctile_map_; - ck::index_t split_k_; }; // Invoker @@ -288,8 +287,8 @@ struct DeviceGroupedConvBwdWeight_Explicit_Xdl { // Modify to use workspace as output GemmArgument explicit_gemm_args_with_workspace = arg.explicit_gemm_args; - explicit_gemm_args_with_workspace.p_c_grid = - static_cast(arg.p_workspace_); + explicit_gemm_args_with_workspace.template SetEPointer( + arg.p_workspace_); float avg_time = explicit_gemm_op.Run(explicit_gemm_args_with_workspace, stream_config); const index_t grid_size = @@ -342,7 +341,7 @@ struct DeviceGroupedConvBwdWeight_Explicit_Xdl #if DISABLE_SPLIT_K_AUTODEDUCE_FOR_ONE_STAGE_KERNELS if constexpr(!IsTwoStageNeeded) { - if(arg.split_k_ < 0) + if(arg.k_batch_ < 0) { return false; } @@ -353,6 +352,10 @@ struct DeviceGroupedConvBwdWeight_Explicit_Xdl { if constexpr(!is_NHWGC_GKYXC_NHWGK()) { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Unsupported layout." << std::endl; + } return false; } } @@ -360,11 +363,19 @@ struct DeviceGroupedConvBwdWeight_Explicit_Xdl { if constexpr(!is_NDHWGC_GKZYXC_NDHWGK()) { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Unsupported layout." << std::endl; + } return false; } } else { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Unsupported layout." << std::endl; + } return false; } @@ -374,6 +385,10 @@ struct DeviceGroupedConvBwdWeight_Explicit_Xdl if(!(arg.filter_spatial_lengths_[i] == 1 && arg.conv_filter_strides_[i] == 1 && arg.input_left_pads_[i] == 0 && arg.input_right_pads_[i] == 0)) { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Unsupported stride / pad." << std::endl; + } return false; } } @@ -381,6 +396,10 @@ struct DeviceGroupedConvBwdWeight_Explicit_Xdl { if(!arg.is_filter_data_packed) { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Unsupported: Filter data must be packed." << std::endl; + } return false; } // Check this here, it allows to use other instances from factory even diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_multiple_d_wmma_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_multiple_d_wmma_cshuffle_v3.hpp new file mode 100644 index 0000000000..86e8defb83 --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_multiple_d_wmma_cshuffle_v3.hpp @@ -0,0 +1,1258 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#include +#include +#include + +#include "ck/utility/common_header.hpp" + +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_grouped_conv_bwd_weight_multiple_d.hpp" +#include "ck/tensor_operation/operator_transform/transform_conv_bwd_weight_to_gemm.hpp" +#include "ck/tensor_operation/operator_transform/transform_conv_bwd_weight_to_gemm_v2.hpp" +#include "ck/tensor_operation/gpu/device/convolution_backward_weight_specialization.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_elementwise_2d.hpp" +#include +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp" +#include "ck/tensor_operation/gpu/device/impl/split_k_arg.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" +#include "ck/host_utility/flush_cache.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS +__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) +#endif + kernel_grouped_conv_bwd_weight_wmma_cshuffle_v3( + typename GridwiseGemm::Argument karg, + const AGridDesc_AK0_M_K1 a_grid_desc_ak0_m_ak1, + const BGridDesc_BK0_N_K1 b_grid_desc_bk0_n_bk1, + const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock + c_grid_desc_mblock_mperblock_nblock_nperblock, + const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch, + const index_t num_k_per_block) +{ +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__) || defined(__gfx12__)) +#if defined(__gfx11__) + // gfx11 does not support *_atomic_pk_add_f16/bf16 instructions + using c_data_type = remove_cvref_t>; + if constexpr(!(CGlobalMemoryDataOperation == InMemoryDataOperationEnum::AtomicAdd && + (std::is_same_v || + std::is_same_v))) + { +#endif + + constexpr index_t LDS_size = GridwiseGemm::template GetSharedMemoryNumberOfByte< + typename GridwiseGemm::EpilogueCShuffle>(); + __shared__ char p_shared[LDS_size]; + + auto epilogue_args = typename GridwiseGemm::EpilogueCShuffle{}; + + GridwiseGemm::template Run(p_shared, + a_grid_desc_ak0_m_ak1, + b_grid_desc_bk0_n_bk1, + c_grid_desc_mblock_mperblock_nblock_nperblock, + compute_ptr_offset_of_batch, + num_k_per_block, + karg, + epilogue_args); + +#if defined(__gfx11__) + } +#endif +#else + ignore = karg; + ignore = a_grid_desc_ak0_m_ak1; + ignore = b_grid_desc_bk0_n_bk1; + ignore = c_grid_desc_mblock_mperblock_nblock_nperblock; + ignore = compute_ptr_offset_of_batch; + ignore = num_k_per_block; +#endif // end of if (defined(__gfx9__) +} + +// out[N, Ho, Wo, K] = in[N, Hi, Wi, C] * wei[K, Y, X, C] +template +struct DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3 + : public DeviceGroupedConvBwdWeightMultipleD +{ + using DeviceOp = DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3; + + using ADataType = OutDataType; + using BDataType = InDataType; + using EDataType = WeiDataType; + + static constexpr index_t NumDTensor = DsLayout::Size(); + + using AElementwiseOperation = OutElementwiseOperation; + using BElementwiseOperation = InElementwiseOperation; + using CDEElementwiseOperation = WeiElementwiseOperation; + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static constexpr auto I4 = Number<4>{}; + static constexpr auto I5 = Number<5>{}; + + static constexpr GemmSpecialization GemmSpec = GemmSpecialization::Default; + static constexpr auto ABK1Number = Number{}; + + static constexpr auto conv_to_gemm_transformer = + TransformConvBwdWeightToGemmV2{}; + + static constexpr index_t MaxScalarPerVectorFP32 = 4; + static constexpr index_t WorkspaceInOutScalarPerVector = + is_same_v + ? math::min(CShuffleBlockTransferScalarPerVector_NPerBlock, MaxScalarPerVectorFP32) + : CShuffleBlockTransferScalarPerVector_NPerBlock; + + template ::type = false> + static auto GetABCGridDesc() + { + const ck::index_t dim = 1; + const ck::index_t batch = 1; + const std::array lengths{1}; + const std::array strides{1, 1, 1, 1}; + const std::array params{1}; + return conv_to_gemm_transformer.template MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<1>( + dim, + dim, + dim, + lengths, + lengths, + lengths, + strides, + strides, + strides, + params, + params, + params, + params, + batch); + } + + template ::type = false> + static auto GetABCGridDesc() + { + const ck::index_t dim = 1; + const ck::index_t batch = 1; + const std::array lengths{1, 1}; + const std::array strides{1, 1, 1, 1, 1}; + const std::array params{1, 1}; + return conv_to_gemm_transformer.template MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<2>( + dim, + dim, + dim, + lengths, + lengths, + lengths, + strides, + strides, + strides, + params, + params, + params, + params, + batch); + } + + template ::type = false> + static auto GetABCGridDesc() + { + const ck::index_t dim = 1; + const ck::index_t batch = 1; + const std::array lengths{1, 1, 1}; + const std::array strides{1, 1, 1, 1, 1, 1}; + const std::array params{1, 1, 1}; + return conv_to_gemm_transformer.template MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<3>( + dim, + dim, + dim, + lengths, + lengths, + lengths, + strides, + strides, + strides, + params, + params, + params, + params, + batch); + } + + using ABCGridDescs = decltype(GetABCGridDesc()); + + using AGridDesc_K0_M_K1 = remove_cvref_t; + using BGridDesc_K0_N_K1 = remove_cvref_t; + using CGridDesc_M_N = remove_cvref_t; + + using GridwiseGemm = GridwiseGemm_wmma_cshuffle_v3< + tensor_layout::gemm::ColumnMajor, + tensor_layout::gemm::RowMajor, + Tuple<>, + tensor_layout::gemm::RowMajor, + Tuple, + Tuple, + AccDataType, + AccDataType, + Tuple<>, + AccDataType, + AElementwiseOperation, + BElementwiseOperation, + element_wise::PassThrough, // CDEElementwiseOperations + GemmSpec, + BlockSize, + MPerBlock, + NPerBlock, + KPerBlock, + ABK1, + ABK1, + MPerWmma, + NPerWmma, + MRepeat, + NRepeat, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + false, + ABlockLdsAddExtraM, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + false, + BBlockLdsAddExtraN, + CShuffleMRepeatPerShuffle, + CShuffleNRepeatPerShuffle, + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + Sequence, + BlkGemmPipeSched, + BlkGemmPipelineVer, + ComputeTypeA, + ComputeTypeB, + false, // permuteA + false, // permuteB + false, // IsBPreShuffled + true>; // ForceThreadTileTransfer + + static constexpr auto MakeElementwiseInputSequence() + { + return generate_sequence_v2( + [&](auto) constexpr { return Number{}; }, + Number{}); + } + + static constexpr auto GetDsGridPointerTuple() + { + return generate_tuple( + [&](auto i) { + using DDataType = remove_cvref_t>; + return static_cast(nullptr); + }, + Number{}); + } + + template ::type = false> + static auto MakeDsGridDescriptor_M_N( + const std::array, NumDTensor>& ds_g_k_c_xs_lengths, + const std::array, NumDTensor>& ds_g_k_c_xs_strides) + { + return generate_tuple( + [&](auto i) { + const index_t K = ds_g_k_c_xs_lengths[i][I1]; + const index_t C = ds_g_k_c_xs_lengths[i][I2]; + const index_t X = ds_g_k_c_xs_lengths[i][I3]; + const index_t CStride = ds_g_k_c_xs_strides[I2]; + const index_t KStride = ds_g_k_c_xs_strides[I1]; + + const auto wei_grid_desc = make_naive_tensor_descriptor( + make_tuple(K, X * C), make_tuple(KStride, CStride)); + + if constexpr(ConvBackwardWeightSpecialization == + device::ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0) + { + return wei_grid_desc; + } + else + { + const index_t GemmM = K; + const index_t GemmN = C * X; + const auto PadGemmM = + GemmM % MPerBlock == 0 ? 0 : MPerBlock - GemmM % MPerBlock; + const auto PadGemmN = + GemmN % NPerBlock == 0 ? 0 : NPerBlock - GemmN % NPerBlock; + + return transform_tensor_descriptor( + wei_grid_desc, + make_tuple(make_right_pad_transform(GemmM, PadGemmM), + make_right_pad_transform(GemmN, PadGemmN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + }, + Number{}); + } + + template ::type = false> + static auto MakeDsGridDescriptor_M_N( + const std::array, NumDTensor>& ds_g_k_c_xs_lengths, + const std::array, NumDTensor>& ds_g_k_c_xs_strides) + { + return generate_tuple( + [&](auto i) { + const index_t K = ds_g_k_c_xs_lengths[i][I1]; + const index_t C = ds_g_k_c_xs_lengths[i][I2]; + const index_t Y = ds_g_k_c_xs_lengths[i][I3]; + const index_t X = ds_g_k_c_xs_lengths[i][I4]; + + const auto wei_grid_desc = + conv_to_gemm_transformer.template make_wei_grid_desc( + K, Y, X, C, ds_g_k_c_xs_strides[i]); + + if constexpr(ConvBackwardWeightSpecialization == + device::ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0) + { + return wei_grid_desc; + } + else + { + const index_t GemmM = K; + const index_t GemmN = C * X * Y; + const auto PadGemmM = + GemmM % MPerBlock == 0 ? 0 : MPerBlock - GemmM % MPerBlock; + const auto PadGemmN = + GemmN % NPerBlock == 0 ? 0 : NPerBlock - GemmN % NPerBlock; + + return transform_tensor_descriptor( + wei_grid_desc, + make_tuple(make_right_pad_transform(GemmM, PadGemmM), + make_right_pad_transform(GemmN, PadGemmN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + }, + Number{}); + } + + template ::type = false> + static auto MakeDsGridDescriptor_M_N( + const std::array, NumDTensor>& ds_g_k_c_xs_lengths, + const std::array, NumDTensor>& ds_g_k_c_xs_strides) + { + return generate_tuple( + [&](auto i) { + const index_t K = ds_g_k_c_xs_lengths[i][I1]; + const index_t C = ds_g_k_c_xs_lengths[i][I2]; + const index_t Z = ds_g_k_c_xs_lengths[i][I3]; + const index_t Y = ds_g_k_c_xs_lengths[i][I4]; + const index_t X = ds_g_k_c_xs_lengths[i][I5]; + + const auto wei_grid_desc = + conv_to_gemm_transformer.template make_wei_grid_desc( + K, Z, Y, X, C, ds_g_k_c_xs_strides[i]); + + if constexpr(ConvBackwardWeightSpecialization == + device::ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0) + { + return wei_grid_desc; + } + else + { + const index_t GemmM = K; + const index_t GemmN = C * X * Y * Z; + const auto PadGemmM = + GemmM % MPerBlock == 0 ? 0 : MPerBlock - GemmM % MPerBlock; + const auto PadGemmN = + GemmN % NPerBlock == 0 ? 0 : NPerBlock - GemmN % NPerBlock; + + return transform_tensor_descriptor( + wei_grid_desc, + make_tuple(make_right_pad_transform(GemmM, PadGemmM), + make_right_pad_transform(GemmN, PadGemmN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + }, + Number{}); + } + + template + static void + InitElementwiseBatchStrides(const ComputePtrOffsetOfBatch& compute_ptr_offset_of_batch_, + std::array& input_batch_strides, + std::array& output_batch_strides) + { + input_batch_strides[I0] = compute_ptr_offset_of_batch_.BatchStrideC_; + output_batch_strides[I0] = compute_ptr_offset_of_batch_.BatchStrideC_; + + // input_batch_strides = {C, Ds...} + static_for<0, NumDTensor, 1>{}([&](auto i) { + input_batch_strides[i + 1] = compute_ptr_offset_of_batch_.BatchStrideDs_[i]; + }); + } + + using DsGridDesc_M_N = decltype(MakeDsGridDescriptor_M_N({}, {})); + using CDGridDesc_M_N = decltype(concat_tuple(Tuple{}, DsGridDesc_M_N{})); + using DsGridPointerTuple = decltype(GetDsGridPointerTuple()); + using CDDataTypes = decltype(concat_tuple(Tuple{}, DsGridPointerTuple{})); + using EGridDesc_M_N = CGridDesc_M_N; + static constexpr index_t ClusterLengthMPerBlock = + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock::At(1); + static constexpr index_t ClusterLengthNPerBlock = + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock::At(3); + using Block2TileMapElementwise = BlockToCTileMap_M00_N0_M01Adapt; + + using GridwiseElementwise = + GridwiseElementwise, + CDDataTypes, + Tuple, + Block2TileMapElementwise, + CDEElementwiseOperation, + BlockSize, + MPerBlock, + NPerBlock, + MPerBlock / ClusterLengthMPerBlock, + NPerBlock / ClusterLengthNPerBlock, + Sequence<0, 1>, + decltype(MakeElementwiseInputSequence()), + Sequence, + I1, + I1>; + + using CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = + decltype(GridwiseGemm::MakeDEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + CGridDesc_M_N{}, 1, 1)); + + struct Argument : public BaseArgument, public ArgumentSplitK + { + Argument( + const InDataType* p_in_grid, + WeiDataType* p_wei_grid, + const OutDataType* p_out_grid, + const std::array& p_ds, + const std::array& b_g_n_c_wis_lengths, // input + const std::array& b_g_n_c_wis_strides, + const std::array& e_g_k_c_xs_lengths, // weight + const std::array& e_g_k_c_xs_strides, + const std::array& a_g_n_k_wos_lengths, // output + const std::array& a_g_n_k_wos_strides, + const std::array, NumDTensor>& ds_g_k_c_xs_lengths, + const std::array, NumDTensor>& ds_g_k_c_xs_strides, + const std::array& conv_filter_strides, + const std::array& conv_filter_dilations, + const std::array& input_left_pads, + const std::array& input_right_pads, + const ck::index_t M01, + const ck::index_t N01, + InElementwiseOperation in_element_op, + WeiElementwiseOperation wei_element_op, + OutElementwiseOperation out_element_op, + ck::index_t split_k) + : p_a_grid_{p_out_grid}, + p_b_grid_{p_in_grid}, + p_ds_grid_{}, + p_e_grid_{p_wei_grid}, + a_grid_desc_kbatch_k0_m_k1_{}, + b_grid_desc_kbatch_k0_n_k1_{}, + ce_grid_desc_m_n_{}, + c_grid_desc_mblock_mperblock_nblock_nperblock_{}, + compute_ptr_offset_of_batch_{}, + M01_{M01}, + N01_{N01}, + a_element_op_{out_element_op}, + b_element_op_{in_element_op}, + cde_element_op_{wei_element_op}, + Conv_G_{b_g_n_c_wis_lengths[0]}, + Conv_N_{b_g_n_c_wis_lengths[1]}, + Conv_K_{e_g_k_c_xs_lengths[1]}, + Conv_C_{b_g_n_c_wis_lengths[2]}, + input_spatial_lengths_{}, + filter_spatial_lengths_{}, + output_spatial_lengths_{}, + conv_filter_strides_{conv_filter_strides}, + input_left_pads_{input_left_pads}, + input_right_pads_{input_right_pads} + { + constexpr index_t spatial_offset = 3; + std::copy(begin(b_g_n_c_wis_lengths) + spatial_offset, + end(b_g_n_c_wis_lengths), + begin(input_spatial_lengths_)); + std::copy(begin(e_g_k_c_xs_lengths) + spatial_offset, + end(e_g_k_c_xs_lengths), + begin(filter_spatial_lengths_)); + std::copy(begin(a_g_n_k_wos_lengths) + spatial_offset, + end(a_g_n_k_wos_lengths), + begin(output_spatial_lengths_)); + +#if !DISABLE_SPLIT_K_AUTODEDUCE_FOR_ONE_STAGE_KERNELS + if(split_k < 0) + { + ck::index_t gemmM, gemmN, gemmK; + std::tie(gemmM, gemmN, gemmK) = + get_bwd_weight_gemm_sizes(a_g_n_k_wos_lengths, e_g_k_c_xs_lengths); + + const auto grid_size = + calculate_mn_grid_size(gemmM, gemmN) * Conv_G_; + k_batch_ = get_best_occupancy_k_batch_value(active_workgroups_per_cu.max_occupancy_, + grid_size); + + // Ensure that k_batch_ does not exceed the maximum value + // for the GEMM pipeline. + const auto k_batch_max = math::integer_divide_ceil((gemmK - 1), KPerBlock); + k_batch_ = std::min(k_batch_, k_batch_max); + + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "[SPLIT-K AUTODEDUCE] k_batch max value: " << k_batch_max + << std::endl; + std::cout << "[SPLIT-K AUTODEDUCE] Final k_batch value: " << k_batch_ + << std::endl; + } + } + else +#endif + { + k_batch_ = split_k; + } + + const auto descs = + conv_to_gemm_transformer + .template MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N( + Conv_N_, + Conv_K_, + Conv_C_, + input_spatial_lengths_, + filter_spatial_lengths_, + output_spatial_lengths_, + b_g_n_c_wis_strides, + e_g_k_c_xs_strides, + a_g_n_k_wos_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + k_batch_); + + static_for<0, NumDTensor, 1>{}([&](auto i) { + using DLayout = remove_cvref_t>; + using DDataType = remove_cvref_t>; + + static_assert(is_same_v, "Not supported D data layout"); + + // D pointer + p_ds_grid_(i) = static_cast(p_ds[i]); + compute_ptr_offset_of_batch_.BatchStrideDs_(i) = ds_g_k_c_xs_strides[i][0]; + }); + + a_grid_desc_kbatch_k0_m_k1_ = descs[I0]; + b_grid_desc_kbatch_k0_n_k1_ = descs[I1]; + ce_grid_desc_m_n_ = descs[I2]; + + ds_grid_descs_tuple_ = + MakeDsGridDescriptor_M_N(ds_g_k_c_xs_lengths, ds_g_k_c_xs_strides); + + elementwise_block_2_ctile_map_ = Block2TileMapElementwise{ + ce_grid_desc_m_n_.GetLength(I0), ce_grid_desc_m_n_.GetLength(I1)}; + + // A/B/C Batch Stride + compute_ptr_offset_of_batch_.BatchStrideA_ = a_g_n_k_wos_strides[0]; + compute_ptr_offset_of_batch_.BatchStrideB_ = b_g_n_c_wis_strides[0]; + compute_ptr_offset_of_batch_.BatchStrideC_ = + Conv_K_ * Conv_C_ * + std::accumulate(begin(filter_spatial_lengths_), + end(filter_spatial_lengths_), + index_t{1}, + std::multiplies<>{}); + + const index_t GemmM = a_grid_desc_kbatch_k0_m_k1_.GetLength(I1); + const index_t GemmN = b_grid_desc_kbatch_k0_n_k1_.GetLength(I1); + + c_grid_desc_mblock_mperblock_nblock_nperblock_ = + GridwiseGemm::MakeDEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + ce_grid_desc_m_n_, + GridwiseGemm::CalculateMBlock(GemmM), + GridwiseGemm::CalculateNBlock(GemmN)); + } + + std::size_t GetWorkspaceSizeBytes() const + { + return sizeof(AccDataType) * ce_grid_desc_m_n_.GetElementSpaceSize() * Conv_G_; + } + + const ADataType* p_a_grid_; + const BDataType* p_b_grid_; + DsGridPointerTuple p_ds_grid_; + EDataType* p_e_grid_; + AGridDesc_K0_M_K1 a_grid_desc_kbatch_k0_m_k1_; + BGridDesc_K0_N_K1 b_grid_desc_kbatch_k0_n_k1_; + CGridDesc_M_N ce_grid_desc_m_n_; + CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock c_grid_desc_mblock_mperblock_nblock_nperblock_; + DsGridDesc_M_N ds_grid_descs_tuple_; + + Block2TileMapElementwise elementwise_block_2_ctile_map_; + + // for computing batch offset + ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_batch_; + + index_t M01_; + index_t N01_; + + OutElementwiseOperation a_element_op_; + InElementwiseOperation b_element_op_; + WeiElementwiseOperation cde_element_op_; + + // for checking IsSupportedArgument() + const index_t Conv_G_; + const index_t Conv_N_; + const index_t Conv_K_; + const index_t Conv_C_; + std::array input_spatial_lengths_; + std::array filter_spatial_lengths_; + std::array output_spatial_lengths_; + const std::array& conv_filter_strides_; + const std::array& input_left_pads_; + const std::array& input_right_pads_; + }; + + // Invoker + struct Invoker : public BaseInvoker + { + using Argument = DeviceOp::Argument; + + void ShowInfo(const Argument& arg) + { + std::cout << "arg.a_grid_desc_kbatch_k0_m_k1_{" + << arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I0) << ", " + << arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I1) << ", " + << arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I2) << ", " + << arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I3) << "}" << std::endl; + + std::cout << "arg.b_grid_desc_kbatch_k0_n_k1_{" + << arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I0) << ", " + << arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I1) << ", " + << arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I2) << ", " + << arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I3) << "}" << std::endl; + + std::cout << "arg.ce_grid_desc_m_n_{" << arg.ce_grid_desc_m_n_.GetLength(I0) << ", " + << arg.ce_grid_desc_m_n_.GetLength(I1) << "}" << std::endl; + } + + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + float ave_time = 0; + + const index_t GemmM = arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I1); + const index_t GemmN = arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I1); + const index_t GemmK = arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I0) * + arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I2); + + AccDataType* p_e_grid = type_convert(arg.p_workspace_); + ; + + // Convolution kernel dispatch + typename GridwiseGemm::Argument gemm_arg{ + std::array{arg.p_a_grid_}, + std::array{arg.p_b_grid_}, + std::array{}, // p_ds_grid_ + p_e_grid, + GemmM, + GemmN, + GemmK, + std::array{I0}, + std::array{I0}, + std::array{}, // StrideDs_ + I0, + arg.k_batch_, + AElementwiseOperation{}, + BElementwiseOperation{}, + element_wise::PassThrough{}}; // CElementwiseOperation + + index_t gdx, gdy, gdz; + std::tie(gdx, gdy, gdz) = GridwiseGemm::CalculateGridSize( + gemm_arg.M, gemm_arg.N, gemm_arg.KBatch, arg.Conv_G_); + + index_t k_grain = gemm_arg.KBatch * KPerBlock; + index_t K_split = (gemm_arg.K + k_grain - 1) / k_grain * KPerBlock; + const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K_split); + + const auto num_k_per_block = + arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(Number<0>{}) / gemm_arg.KBatch; + + const auto clear_workspace = [&]() { + hip_check_error(hipMemsetAsync( + p_e_grid, 0, arg.GetWorkspaceSizeBytes(), stream_config.stream_id_)); + }; + + const auto Run = [&](const auto& kernel) { + if(stream_config.flush_cache) + { + typename GridwiseGemm::Argument gemm_arg_ = gemm_arg; + + std::array size_as_buffers; + size_as_buffers[0] = arg.a_grid_desc_kbatch_k0_m_k1_.GetElementSpaceSize() * + sizeof(ADataType) / GridwiseGemm::APackedSize; + + std::array size_bs_buffers; + size_bs_buffers[0] = arg.b_grid_desc_kbatch_k0_n_k1_.GetElementSpaceSize() * + sizeof(BDataType) / GridwiseGemm::BPackedSize; + + std::array size_ds_buffers; + + ck::utility::RotatingMemWrapperMultiABD, + Tuple, + Tuple<>> + rotating_mem(gemm_arg_, + stream_config.rotating_count, + size_as_buffers, + size_bs_buffers, + size_ds_buffers); + rotating_mem.Print(); + + auto run_flush_cache = [&]() { + // flush icache + ck::utility::flush_icache(); + // rotating mem + rotating_mem.Next(); + }; + ave_time += ck::utility::launch_and_time_kernel_with_preprocess( + stream_config, + run_flush_cache, + kernel, + dim3(gdx, gdy, gdz), + dim3(BlockSize), + 0, + gemm_arg_, + arg.a_grid_desc_kbatch_k0_m_k1_, + arg.b_grid_desc_kbatch_k0_n_k1_, + arg.c_grid_desc_mblock_mperblock_nblock_nperblock_, + arg.compute_ptr_offset_of_batch_, + num_k_per_block); + } + else + { + ave_time += launch_and_time_kernel_with_preprocess( + stream_config, + clear_workspace, + kernel, + dim3(gdx, gdy, gdz), + dim3(BlockSize), + 0, + gemm_arg, + arg.a_grid_desc_kbatch_k0_m_k1_, + arg.b_grid_desc_kbatch_k0_n_k1_, + arg.c_grid_desc_mblock_mperblock_nblock_nperblock_, + arg.compute_ptr_offset_of_batch_, + num_k_per_block); + } + }; + + constexpr index_t minimum_occupancy = + BlkGemmPipeSched == BlockGemmPipelineScheduler::Intrawave ? 1 : 2; + + if(has_main_k_block_loop) + { + // Tail number always full + if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1 || + BlkGemmPipelineVer == BlockGemmPipelineVersion::v3) + { + if(gemm_arg.KBatch > 1) + { + const auto kernel = kernel_grouped_conv_bwd_weight_wmma_cshuffle_v3< + GridwiseGemm, + remove_reference_t, + remove_reference_t, + remove_reference_t< + DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, + ComputePtrOffsetOfStridedBatch, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy>; + Run(kernel); + } + else + { + const auto kernel = kernel_grouped_conv_bwd_weight_wmma_cshuffle_v3< + GridwiseGemm, + remove_reference_t, + remove_reference_t, + remove_reference_t< + DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, + ComputePtrOffsetOfStridedBatch, + true, + InMemoryDataOperationEnum::Set, + minimum_occupancy>; + Run(kernel); + } + } + else + { + // TODO: Implement + } + } + else + { + // Tail number always 1 + if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1) + { + if(gemm_arg.KBatch > 1) + { + const auto kernel = kernel_grouped_conv_bwd_weight_wmma_cshuffle_v3< + GridwiseGemm, + remove_reference_t, + remove_reference_t, + remove_reference_t< + DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, + ComputePtrOffsetOfStridedBatch, + false, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy>; + Run(kernel); + } + else + { + const auto kernel = kernel_grouped_conv_bwd_weight_wmma_cshuffle_v3< + GridwiseGemm, + remove_reference_t, + remove_reference_t, + remove_reference_t< + DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, + ComputePtrOffsetOfStridedBatch, + false, + InMemoryDataOperationEnum::Set, + minimum_occupancy>; + Run(kernel); + } + } + } + + auto launch_elementwise_kernel = [&]() { + const AccDataType* p_c_grid = type_convert(arg.p_workspace_); + const index_t grid_size = + arg.elementwise_block_2_ctile_map_.CalculateGridSize(arg.ce_grid_desc_m_n_) * + arg.Conv_G_; + + std::array input_batch_strides; + std::array output_batch_strides; + InitElementwiseBatchStrides( + arg.compute_ptr_offset_of_batch_, input_batch_strides, output_batch_strides); + + const auto kernel = kernel_batched_elementwise, + CDDataTypes, + ck::Tuple, + Block2TileMapElementwise, + CDEElementwiseOperation, + NumDTensor + I1, + I1>; + + return launch_and_time_kernel( + stream_config, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + concat_tuple(make_tuple(arg.ce_grid_desc_m_n_), arg.ds_grid_descs_tuple_), + make_tuple(arg.ce_grid_desc_m_n_), + concat_tuple(make_tuple(p_c_grid), arg.p_ds_grid_), + arg.p_e_grid_, + arg.elementwise_block_2_ctile_map_, + arg.cde_element_op_, + arg.Conv_G_, + input_batch_strides, + output_batch_strides); + }; + + ave_time += launch_elementwise_kernel(); + + return ave_time; + } + + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } + + static bool IsSupportedArgument(const Argument& arg) + { +#if DISABLE_SPLIT_K_AUTODEDUCE_FOR_ONE_STAGE_KERNELS + if(arg.k_batch_ < 0) + { + return false; + } +#endif + + const index_t GemmM = arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I1); + const index_t GemmN = arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I1); + const index_t GemmK = arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I0) * + arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I2); + + typename GridwiseGemm::Argument gemm_arg{std::array{nullptr}, // p_as_grid + std::array{nullptr}, // p_bs_grid + std::array{}, // p_ds_grid + nullptr, // p_e_grid + GemmM, // M + GemmN, // N + GemmK, // K + std::array{I0}, // StrideAs + std::array{I0}, // StrideBs + std::array{}, // StrideDs + I0, // StrideE + arg.k_batch_, + AElementwiseOperation{}, + BElementwiseOperation{}, + element_wise::PassThrough{}}; + + const auto num_k_loop = gemm_arg.AK0 / (KPerBlock / ABK1); + if constexpr(BlkGemmPipelineVer != BlockGemmPipelineVersion::v1) + { + if(num_k_loop <= GridwiseGemm::BlockwiseGemmPipe::PrefetchStages) + { + return false; + } + } + + if(!ck::is_gfx11_supported() && !ck::is_gfx12_supported()) + { + return false; + } + if constexpr(std::is_same_v || std::is_same_v || + std::is_same_v || std::is_same_v) + { + if(ck::is_gfx11_supported()) + { + return false; + } + } + if constexpr(NDimSpatial == 1) + { + if constexpr(!is_GNWC_GKXC_GNWK()) + { + return false; + } + } + else if constexpr(NDimSpatial == 2) + { + if constexpr(!(is_NHWGC_GKYXC_NHWGK() || + is_GNHWC_GKYXC_GNHWK())) + { + return false; + } + } + else if constexpr(NDimSpatial == 3) + { + if constexpr(!(is_NDHWGC_GKZYXC_NDHWGK() || + is_GNDHWC_GKZYXC_GNDHWK())) + { + return false; + } + } + else + { + return false; + } + + if constexpr(ConvBackwardWeightSpecialization == + ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0) + { + // check if it's 1x1, stride=1 pad = 0 conv + for(int i = 0; i < NDimSpatial; i++) + { + if(!(arg.filter_spatial_lengths_[i] == 1 && arg.conv_filter_strides_[i] == 1 && + arg.input_left_pads_[i] == 0 && arg.input_right_pads_[i] == 0)) + { + return false; + } + } + } + + // vector load A/B matrix from global memory + if(!(ABlockTransferSrcVectorDim == 1 && BBlockTransferSrcVectorDim == 1 && + arg.Conv_K_ % ABlockTransferSrcScalarPerVector == 0 && + arg.Conv_C_ % BBlockTransferSrcScalarPerVector == 0)) + { + return false; + } + + // vector store C matrix into global memory + if(!(arg.Conv_C_ % CShuffleBlockTransferScalarPerVector_NPerBlock == 0 && + arg.Conv_C_ % WorkspaceInOutScalarPerVector == 0)) + { + return false; + } + + return GridwiseGemm::CheckValidity(gemm_arg); + } + + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static auto MakeArgument( + const InDataType* p_in_grid, + WeiDataType* p_wei_grid, + const OutDataType* p_out_grid, + const std::array& p_ds, + const std::array& b_g_n_c_wis_lengths, // input + const std::array& b_g_n_c_wis_strides, + const std::array& e_g_k_c_xs_lengths, // weight + const std::array& e_g_k_c_xs_strides, + const std::array& a_g_n_k_wos_lengths, // output + const std::array& a_g_n_k_wos_strides, + const std::array, NumDTensor>& ds_g_k_c_xs_lengths, + const std::array, NumDTensor>& ds_g_k_c_xs_strides, + const std::array& conv_filter_strides, + const std::array& conv_filter_dilations, + const std::array& input_left_pads, + const std::array& input_right_pads, + InElementwiseOperation in_element_op, + WeiElementwiseOperation wei_element_op, + OutElementwiseOperation out_element_op, + const ck::index_t split_k) + { + return Argument{p_in_grid, + p_wei_grid, + p_out_grid, + p_ds, + b_g_n_c_wis_lengths, // input + b_g_n_c_wis_strides, + e_g_k_c_xs_lengths, // weight + e_g_k_c_xs_strides, + a_g_n_k_wos_lengths, // output + a_g_n_k_wos_strides, + ds_g_k_c_xs_lengths, + ds_g_k_c_xs_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + 1, + 1, + in_element_op, + wei_element_op, + out_element_op, + split_k}; + } + + static auto MakeInvoker() { return Invoker{}; } + + std::unique_ptr MakeArgumentPointer( + const void* p_in_grid, + void* p_wei_grid, + const void* p_out_grid, + const std::array& p_ds, + const std::array& b_g_n_c_wis_lengths, // input + const std::array& b_g_n_c_wis_strides, + const std::array& e_g_k_c_xs_lengths, // weight + const std::array& e_g_k_c_xs_strides, + const std::array& a_g_n_k_wos_lengths, // output + const std::array& a_g_n_k_wos_strides, + const std::array, NumDTensor>& ds_g_k_c_xs_lengths, + const std::array, NumDTensor>& ds_g_k_c_xs_strides, + const std::array& conv_filter_strides, + const std::array& conv_filter_dilations, + const std::array& input_left_pads, + const std::array& input_right_pads, + InElementwiseOperation in_element_op, + WeiElementwiseOperation wei_element_op, + OutElementwiseOperation out_element_op, + const ck::index_t split_k) override + { + return std::make_unique(static_cast(p_in_grid), + static_cast(p_wei_grid), + static_cast(p_out_grid), + p_ds, + b_g_n_c_wis_lengths, // input + b_g_n_c_wis_strides, + e_g_k_c_xs_lengths, // weight + e_g_k_c_xs_strides, + a_g_n_k_wos_lengths, // output + a_g_n_k_wos_strides, + ds_g_k_c_xs_lengths, + ds_g_k_c_xs_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + 1, + 1, + in_element_op, + wei_element_op, + out_element_op, + split_k); + } + + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3" + << "<" + << BlockSize << ", " + << MPerBlock << ", " + << NPerBlock << ", " + << KPerBlock << ", " + << getConvBackwardWeightSpecializationString(ConvBackwardWeightSpecialization) << ", " + << ABK1 << ", " + << MRepeat << ", " + << NRepeat << ", " + << ABlockTransferSrcScalarPerVector << ", " + << ABlockTransferDstScalarPerVector_AK1 << ", " + << BBlockTransferSrcScalarPerVector << ", " + << BBlockTransferDstScalarPerVector_BK1 << ", " + << CShuffleMRepeatPerShuffle << ", " + << CShuffleNRepeatPerShuffle << ", " + << CShuffleBlockTransferScalarPerVector_NPerBlock + << ">"; + // clang-format on + + return str.str(); + } + + size_t GetWorkSpaceSize(const BaseArgument* p_arg) const override + { + auto arg = dynamic_cast(p_arg); + if(arg) + { + return arg->GetWorkspaceSizeBytes(); + } + else + throw std::runtime_error( + "The argument pointer is not an object of " + "DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3::Argument structure!"); + } + + void SetWorkSpacePointer(BaseArgument* p_arg, + void* p_workspace, + const StreamConfig& = StreamConfig{}) const override + { + auto p_arg_ = dynamic_cast(p_arg); + if(p_arg_) + { + p_arg_->p_workspace_ = p_workspace; + } + else + throw std::runtime_error( + "The argument pointer is not an object of " + "DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3::Argument structure!"); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_wmma_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_wmma_cshuffle_v3.hpp new file mode 100644 index 0000000000..37fe0b2c7b --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_wmma_cshuffle_v3.hpp @@ -0,0 +1,1578 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#include +#include +#include + +#include "ck/utility/common_header.hpp" +#include "ck/utility/env.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_grouped_conv_bwd_weight.hpp" +#include "ck/tensor_operation/operator_transform/transform_conv_bwd_weight_to_gemm.hpp" +#include "ck/tensor_operation/operator_transform/transform_conv_bwd_weight_to_gemm_v2.hpp" +#include "ck/tensor_operation/operator_transform/transform_conv_ngchw_to_nhwgc.hpp" +#include "ck/tensor_operation/gpu/device/convolution_backward_weight_specialization.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_elementwise_2d.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3.hpp" +#include +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp" +#include "ck/tensor_operation/gpu/device/impl/split_k_utils.hpp" +#include "ck/tensor_operation/gpu/device/impl/split_k_arg.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" +#include "ck/host_utility/flush_cache.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS +__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) +#endif + kernel_grouped_conv_bwd_weight_wmma_cshuffle_v3( + typename GridwiseGemm::Argument karg, + const AGridDesc_AK0_M_K1 a_grid_desc_ak0_m_ak1, + const BGridDesc_BK0_N_K1 b_grid_desc_bk0_n_bk1, + const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock + c_grid_desc_mblock_mperblock_nblock_nperblock, + const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch, + const index_t num_k_per_block) +{ +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__) || defined(__gfx12__)) + + constexpr index_t LDS_size = GridwiseGemm::template GetSharedMemoryNumberOfByte< + typename GridwiseGemm::EpilogueCShuffle>(); + __shared__ char p_shared[LDS_size]; + + auto epilogue_args = typename GridwiseGemm::EpilogueCShuffle{}; + + GridwiseGemm::template Run(p_shared, + a_grid_desc_ak0_m_ak1, + b_grid_desc_bk0_n_bk1, + c_grid_desc_mblock_mperblock_nblock_nperblock, + compute_ptr_offset_of_batch, + num_k_per_block, + karg, + epilogue_args); +#else + ignore = karg; + ignore = a_grid_desc_ak0_m_ak1; + ignore = b_grid_desc_bk0_n_bk1; + ignore = c_grid_desc_mblock_mperblock_nblock_nperblock; + ignore = compute_ptr_offset_of_batch; + ignore = num_k_per_block; +#endif // end of if (defined(__gfx11__) || defined(__gfx12__)) +} + +template +struct DeviceGroupedConvBwdWeightTwoStage_Wmma_CShuffleV3 + : public DeviceGroupedConvBwdWeight +{ + static_assert(is_same_v); + static_assert(is_same_v); + static_assert(is_same_v); + + using DeviceOp = DeviceGroupedConvBwdWeightTwoStage_Wmma_CShuffleV3; + + using ADataType = OutDataType; + using BDataType = InDataType; + using EDataType = WeiDataType; + + // If NGCHW then ADataType must be equal to BDataType + static_assert(!(is_NGCHW_NGKHW() || + is_NGCDHW_NGKDHW()) || + is_same_v); + + using AElementwiseOperation = OutElementwiseOperation; + using BElementwiseOperation = InElementwiseOperation; + using CDEElementwiseOperation = WeiElementwiseOperation; + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static constexpr auto I4 = Number<4>{}; + static constexpr auto I5 = Number<5>{}; + + static constexpr auto ABK1Number = Number{}; + + static constexpr auto conv_to_gemm_transformer_v2 = + TransformConvBwdWeightToGemmV2{}; + + static constexpr auto conv_to_gemm_transformer_v1 = + TransformConvBwdWeightToGemm{}; + + static constexpr index_t ClusterLengthMPerBlock = + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock::At(1); + static constexpr index_t ClusterLengthNPerBlock = + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock::At(3); + + static constexpr auto conv_ngchw_to_nhwgc_transformer = + TransformConvNGCHWToNHWGC{}; + + static constexpr GemmSpecialization GemmSpec = GemmSpecialization::Default; + + template ::type = false> + static auto GetABCGridDesc() + { + const ck::index_t dim = 1; + const ck::index_t batch = 1; + const std::array lengths{1, 1}; + const std::array strides{1, 1, 1, 1, 1}; + const std::array params{1, 1}; + return conv_to_gemm_transformer_v2 + .template MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<2>(dim, + dim, + dim, + lengths, + lengths, + lengths, + strides, + strides, + strides, + params, + params, + params, + params, + batch); + } + + template ::type = false> + static auto GetABCGridDesc() + { + const ck::index_t dim = 1; + const ck::index_t batch = 1; + const std::array lengths{1, 1, 1}; + const std::array strides{1, 1, 1, 1, 1, 1}; + const std::array params{1, 1, 1}; + return conv_to_gemm_transformer_v2 + .template MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<3>(dim, + dim, + dim, + lengths, + lengths, + lengths, + strides, + strides, + strides, + params, + params, + params, + params, + batch); + } + + template ::type = false> + static auto GetElementwiseCGridDesc() + { + const ck::index_t dim = 1; + const ck::index_t batch = 1; + const std::array lengths{1, 1}; + const std::array strides{1, 1, 1, 1, 1}; + const std::array params{1, 1}; + return conv_to_gemm_transformer_v1 + .template MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<2>(dim, + dim, + dim, + lengths, + lengths, + lengths, + strides, + strides, + strides, + params, + params, + params, + params, + batch)[I2]; + } + + template ::type = false> + static auto GetElementwiseCGridDesc() + { + const ck::index_t dim = 1; + const ck::index_t batch = 1; + const std::array lengths{1, 1, 1}; + const std::array strides{1, 1, 1, 1, 1, 1}; + const std::array params{1, 1, 1}; + return conv_to_gemm_transformer_v1 + .template MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<3>(dim, + dim, + dim, + lengths, + lengths, + lengths, + strides, + strides, + strides, + params, + params, + params, + params, + batch)[I2]; + } + + using NGCHWTransposeDescType = + remove_cvref_t({}, {}))>; + using NHWGCTransposeDescType = + remove_cvref_t({}, {}))>; + using GKCYXTransposeDescType = + remove_cvref_t({}, {}))>; + using GKYXCTransposeDescType = + remove_cvref_t({}, {}))>; + + using ABCGridDescs = decltype(GetABCGridDesc()); + + using AGridDesc_K0_M_K1 = remove_cvref_t; + using BGridDesc_K0_N_K1 = remove_cvref_t; + using CGridDesc_M_N = remove_cvref_t; + using CElementwiseGridDesc_M_N = + remove_cvref_t())>; + + using GridwiseGemm = GridwiseGemm_wmma_cshuffle_v3< + tensor_layout::gemm::ColumnMajor, + tensor_layout::gemm::RowMajor, + Tuple<>, + tensor_layout::gemm::RowMajor, + Tuple, + Tuple, + AccDataType, + AccDataType, + Tuple<>, + AccDataType, + AElementwiseOperation, + BElementwiseOperation, + CDEElementwiseOperation, + GemmSpec, + BlockSize, + MPerBlock, + NPerBlock, + KPerBlock, + ABK1, + ABK1, + MPerWmma, + NPerWmma, + MRepeat, + NRepeat, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + false, + ABlockLdsAddExtraM, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + false, + BBlockLdsAddExtraN, + CShuffleMRepeatPerShuffle, + CShuffleNRepeatPerShuffle, + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + Sequence, + BlkGemmPipeSched, + BlkGemmPipelineVer, + ComputeTypeA, + ComputeTypeB, + false, // permuteA + false, // permuteB + false, // IsBPreShuffled + true>; // ForceThreadTileTransfer + + using Block2TileMapElementwise = BlockToCTileMap_M00_N0_M01Adapt; + + using GridwiseElementwiseCast = + GridwiseElementwise, + Tuple, + Tuple, + Tuple, + Block2TileMapElementwise, + CDEElementwiseOperation, + BlockSize, + MPerBlock, + NPerBlock, + MPerBlock / ClusterLengthMPerBlock, + NPerBlock / ClusterLengthNPerBlock, + Sequence<0, 1>, + Sequence, + Sequence, + I1, + I1>; + + // NPerBlock is used for the first dim which is store dimension + // (with CShuffleBlockTransferScalarPerVector_NPerBlock scalar per vector). + // CShuffleBlockTransferScalarPerVector_NPerBlock is aligned to NPerBlock so + // it is more flexible to use this dim for store dimension with such scalar + // per vector. + using GridwiseElementwiseWeightTransposeCast = + GridwiseElementwise, + Tuple, + Tuple, + Tuple, + Block2TileMapElementwise, + CDEElementwiseOperation, + BlockSize, + MPerBlock, + NPerBlock, + MPerBlock / ClusterLengthMPerBlock, + NPerBlock / ClusterLengthNPerBlock, + Sequence<0, 1>, + Sequence, + Sequence<1>, + I1, + I0>; + + using GridwiseElementwiseTranspose = + GridwiseElementwise, + Tuple, + Tuple, + Tuple, + Block2TileMapElementwise, + element_wise::PassThrough, + BlockSize, + MPerBlock, + NPerBlock, + MPerBlock / ClusterLengthMPerBlock, + NPerBlock / ClusterLengthNPerBlock, + Sequence<1, 0>, + Sequence, + Sequence, + I1, + I0>; + + // Argument + using CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = + decltype(GridwiseGemm::MakeDEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + CGridDesc_M_N{}, 1, 1)); + + struct ActiveWorkgroupsPerCU + { + ActiveWorkgroupsPerCU() + { + constexpr int dynamic_smem_size = 0; + constexpr index_t minimum_occupancy = + BlkGemmPipeSched == BlockGemmPipelineScheduler::Intrawave ? 1 : 2; + int max_occupancy = 0; + + if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v4) + { + // TODO: implement + } + else + { + hip_check_error(hipOccupancyMaxActiveBlocksPerMultiprocessor( + &max_occupancy, + kernel_grouped_conv_bwd_weight_wmma_cshuffle_v3< + GridwiseGemm, + remove_reference_t, + remove_reference_t, + remove_reference_t, + ComputePtrOffsetOfStridedBatch, + NumGroupsToMerge, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy>, + BlockSize, + dynamic_smem_size)); + } + max_occupancy_ = std::max(1, max_occupancy); + } + int max_occupancy_; + }; + + struct Argument : public BaseArgument, public ArgumentSplitK + { + Argument(const InDataType* p_in_grid, + WeiDataType* p_wei_grid, + const OutDataType* p_out_grid, + const std::array& b_g_n_c_wis_lengths, // input + const std::array& b_g_n_c_wis_strides, + const std::array& e_g_k_c_xs_lengths, // weight + const std::array& e_g_k_c_xs_strides, + const std::array& a_g_n_k_wos_lengths, // output + const std::array& a_g_n_k_wos_strides, + const std::array& conv_filter_strides, + const std::array& conv_filter_dilations, + const std::array& input_left_pads, + const std::array& input_right_pads, + const ck::index_t M01, + const ck::index_t N01, + InElementwiseOperation in_element_op, + WeiElementwiseOperation wei_element_op, + OutElementwiseOperation out_element_op, + ck::index_t split_k) + : p_a_grid_{p_out_grid}, + p_b_grid_{p_in_grid}, + p_e_grid_{p_wei_grid}, + a_grid_desc_k0_m_k1_{}, + b_grid_desc_k0_n_k1_{}, + ce_grid_desc_m_n_{}, + c_grid_desc_mblock_mperblock_nblock_nperblock_{}, + compute_ptr_offset_of_batch_{}, + M01_{M01}, + N01_{N01}, + a_element_op_{out_element_op}, + b_element_op_{in_element_op}, + cde_element_op_{wei_element_op}, + Conv_G_{b_g_n_c_wis_lengths[0]}, + Conv_N_{b_g_n_c_wis_lengths[1]}, + Conv_K_{e_g_k_c_xs_lengths[1]}, + Conv_C_{b_g_n_c_wis_lengths[2]}, + input_spatial_lengths_{}, + filter_spatial_lengths_{}, + output_spatial_lengths_{}, + conv_filter_strides_{conv_filter_strides}, + input_left_pads_{input_left_pads}, + input_right_pads_{input_right_pads} + { + static ActiveWorkgroupsPerCU active_workgroups_per_cu; + + constexpr index_t spatial_offset = 3; + std::copy(begin(b_g_n_c_wis_lengths) + spatial_offset, + end(b_g_n_c_wis_lengths), + begin(input_spatial_lengths_)); + std::copy(begin(e_g_k_c_xs_lengths) + spatial_offset, + end(e_g_k_c_xs_lengths), + begin(filter_spatial_lengths_)); + std::copy(begin(a_g_n_k_wos_lengths) + spatial_offset, + end(a_g_n_k_wos_lengths), + begin(output_spatial_lengths_)); + + std::array a_g_n_k_wos_strides_transposed = + conv_ngchw_to_nhwgc_transformer.TransposeInOutStrides(a_g_n_k_wos_lengths, + a_g_n_k_wos_strides); + std::array b_g_n_c_wis_strides_transposed = + conv_ngchw_to_nhwgc_transformer.TransposeInOutStrides(b_g_n_c_wis_lengths, + b_g_n_c_wis_strides); + std::array e_g_k_c_xs_strides_transposed = + conv_ngchw_to_nhwgc_transformer.TransposeWeiStrides(e_g_k_c_xs_lengths, + e_g_k_c_xs_strides); + + if(split_k < 0) + { + ck::index_t gemmM, gemmN, gemmK; + std::tie(gemmM, gemmN, gemmK) = + get_bwd_weight_gemm_sizes(a_g_n_k_wos_lengths, e_g_k_c_xs_lengths); + + const auto grid_size = calculate_mn_grid_size(gemmM, gemmN) * + Conv_G_ / NumGroupsToMerge; + k_batch_ = get_best_occupancy_k_batch_value(active_workgroups_per_cu.max_occupancy_, + grid_size); + + // Ensure that k_batch_ does not exceed the maximum value + // for the GEMM pipeline. + const auto k_batch_max = math::integer_divide_ceil((gemmK - 1), KPerBlock); + k_batch_ = std::min(k_batch_, k_batch_max); + + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "[SPLIT-K AUTODEDUCE] k_batch max value: " << k_batch_max + << std::endl; + std::cout << "[SPLIT-K AUTODEDUCE] Final k_batch value: " << k_batch_ + << std::endl; + } + } + else + { + k_batch_ = split_k; + } + + const auto descs = + conv_to_gemm_transformer_v2 + .template MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N( + Conv_N_, + Conv_K_, + Conv_C_, + input_spatial_lengths_, + filter_spatial_lengths_, + output_spatial_lengths_, + b_g_n_c_wis_strides_transposed, + e_g_k_c_xs_strides_transposed, + a_g_n_k_wos_strides_transposed, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + k_batch_); + + a_grid_desc_k0_m_k1_ = descs[I0]; + b_grid_desc_k0_n_k1_ = descs[I1]; + ce_grid_desc_m_n_ = descs[I2]; + + ce_elementwise_grid_desc_m_n_ = + conv_to_gemm_transformer_v1 + .template MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N( + Conv_N_, + Conv_K_, + Conv_C_, + input_spatial_lengths_, + filter_spatial_lengths_, + output_spatial_lengths_, + b_g_n_c_wis_strides, + e_g_k_c_xs_strides, + a_g_n_k_wos_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + k_batch_)[I2]; + + const index_t GemmM = a_grid_desc_k0_m_k1_.GetLength(I1); + const index_t GemmN = b_grid_desc_k0_n_k1_.GetLength(I1); + + // A/B/C Batch Stride + compute_ptr_offset_of_batch_.BatchStrideA_ = a_g_n_k_wos_strides_transposed[0]; + compute_ptr_offset_of_batch_.BatchStrideB_ = b_g_n_c_wis_strides_transposed[0]; + compute_ptr_offset_of_batch_.BatchStrideC_ = e_g_k_c_xs_strides_transposed[0]; + c_grid_desc_mblock_mperblock_nblock_nperblock_ = + GridwiseGemm::MakeDEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + ce_grid_desc_m_n_, + GridwiseGemm::CalculateMBlock(GemmM), + GridwiseGemm::CalculateNBlock(GemmN)); + + if constexpr(is_NGCHW_NGKHW() || + is_NGCDHW_NGKDHW()) + { + a_in_transpose_desc_ = + conv_ngchw_to_nhwgc_transformer.template MakeNGCHWTransposeDesc( + a_g_n_k_wos_lengths, a_g_n_k_wos_strides); + a_out_transpose_desc_ = + conv_ngchw_to_nhwgc_transformer.template MakeNHWGCTransposeDesc( + a_g_n_k_wos_lengths, a_g_n_k_wos_strides); + + b_in_transpose_desc_ = + conv_ngchw_to_nhwgc_transformer.template MakeNGCHWTransposeDesc( + b_g_n_c_wis_lengths, b_g_n_c_wis_strides); + b_out_transpose_desc_ = + conv_ngchw_to_nhwgc_transformer.template MakeNHWGCTransposeDesc( + b_g_n_c_wis_lengths, b_g_n_c_wis_strides); + + e_in_transpose_desc_ = + conv_ngchw_to_nhwgc_transformer.template MakeGKYXCTransposeDesc( + e_g_k_c_xs_lengths, e_g_k_c_xs_strides); + e_out_transpose_desc_ = + conv_ngchw_to_nhwgc_transformer.template MakeGKCYXTransposeDesc( + e_g_k_c_xs_lengths, e_g_k_c_xs_strides); + + elementwise_block_2_ctile_map_transpose_a_ = Block2TileMapElementwise{ + a_in_transpose_desc_.GetLength(I0), a_in_transpose_desc_.GetLength(I1)}; + + elementwise_block_2_ctile_map_transpose_b_ = Block2TileMapElementwise{ + b_in_transpose_desc_.GetLength(I0), b_in_transpose_desc_.GetLength(I1)}; + } + + elementwise_block_2_ctile_map_ = + is_NGCHW_GKCYX_NGKHW() || + is_NGCDHW_GKCZYX_NGKDHW() + ? Block2TileMapElementwise{e_in_transpose_desc_.GetLength(I0), + e_in_transpose_desc_.GetLength(I1)} + : Block2TileMapElementwise{ce_grid_desc_m_n_.GetLength(I0), + ce_grid_desc_m_n_.GetLength(I1)}; + } + + std::size_t GetWorkspaceATensorSizeBytes() const + { + // Align to 128B + return math::integer_divide_ceil( + sizeof(ADataType) * a_in_transpose_desc_.GetElementSpaceSize(), 128) * + 128; + } + + std::size_t GetWorkspaceBTensorSizeBytes() const + { + return sizeof(BDataType) * b_in_transpose_desc_.GetElementSpaceSize(); + } + + std::size_t GetWorkspaceETensorSizeBytes() const + { + // Align to 128B + return math::integer_divide_ceil(sizeof(AccDataType) * + ce_grid_desc_m_n_.GetElementSpaceSize() * Conv_G_, + 128) * + 128; + } + + std::size_t GetWorkspaceSizeBytes() const + { + // 1. We need to transpose A and B for NGCHW and NGKHW layouts + // 2. If C format is GKCYX then tranpose during second stage. + // If C format is GKYXC then just perform second stage. + // Due to the fact that E workspace is always needed, we + // allocate them as the first part of the workspace. + // [EWorkspace, AWorkspace, BWorkspace] + if constexpr(is_NGCHW_NGKHW() || + is_NGCDHW_NGKDHW()) + { + return GetWorkspaceATensorSizeBytes() + GetWorkspaceBTensorSizeBytes() + + GetWorkspaceETensorSizeBytes(); + } + else + { + return GetWorkspaceETensorSizeBytes(); + } + } + + const ADataType* p_a_grid_; + const BDataType* p_b_grid_; + EDataType* p_e_grid_; + + AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1_; + BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1_; + CGridDesc_M_N ce_grid_desc_m_n_; + CElementwiseGridDesc_M_N ce_elementwise_grid_desc_m_n_; + CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock c_grid_desc_mblock_mperblock_nblock_nperblock_; + + Block2TileMapElementwise elementwise_block_2_ctile_map_; + Block2TileMapElementwise elementwise_block_2_ctile_map_transpose_a_, + elementwise_block_2_ctile_map_transpose_b_; + + NGCHWTransposeDescType a_in_transpose_desc_, b_in_transpose_desc_; + NHWGCTransposeDescType a_out_transpose_desc_, b_out_transpose_desc_; + GKYXCTransposeDescType e_in_transpose_desc_; + GKCYXTransposeDescType e_out_transpose_desc_; + + // for computing batch offset + ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_batch_; + + index_t M01_; + index_t N01_; + + OutElementwiseOperation a_element_op_; + InElementwiseOperation b_element_op_; + WeiElementwiseOperation cde_element_op_; + + // for checking IsSupportedArgument() + const index_t Conv_G_; + const index_t Conv_N_; + const index_t Conv_K_; + const index_t Conv_C_; + std::array input_spatial_lengths_; + std::array filter_spatial_lengths_; + std::array output_spatial_lengths_; + const std::array& conv_filter_strides_; + const std::array& input_left_pads_; + const std::array& input_right_pads_; + }; + + // Invoker + struct Invoker : public BaseInvoker + { + using Argument = DeviceOp::Argument; + + void ShowInfo(const Argument& arg) + { + std::cout << "arg.a_grid_desc_k0_m_k1_{" << arg.a_grid_desc_k0_m_k1_.GetLength(I0) + << ", " << arg.a_grid_desc_k0_m_k1_.GetLength(I1) << ", " + << arg.a_grid_desc_k0_m_k1_.GetLength(I2) << "}" << std::endl; + + std::cout << "arg.b_grid_desc_k0_n_k1_{" << arg.b_grid_desc_k0_n_k1_.GetLength(I0) + << ", " << arg.b_grid_desc_k0_n_k1_.GetLength(I1) << ", " + << arg.b_grid_desc_k0_n_k1_.GetLength(I2) << "}" << std::endl; + + std::cout << "arg.ce_grid_desc_m_n_{" << arg.ce_grid_desc_m_n_.GetLength(I0) << ", " + << arg.ce_grid_desc_m_n_.GetLength(I1) << "}" << std::endl; + } + + float RunGemmV3(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + const index_t GemmM = arg.a_grid_desc_k0_m_k1_.GetLength(I1); + const index_t GemmN = arg.b_grid_desc_k0_n_k1_.GetLength(I1); + const index_t GemmK = + arg.a_grid_desc_k0_m_k1_.GetLength(I0) * arg.a_grid_desc_k0_m_k1_.GetLength(I2); + + AccDataType* p_c_grid = type_convert(arg.p_workspace_); + + const ADataType* p_a_grid = arg.p_a_grid_; + const BDataType* p_b_grid = arg.p_b_grid_; + + if constexpr(is_NGCHW_NGKHW() || + is_NGCDHW_NGKDHW()) + { + p_a_grid = type_convert(arg.p_workspace_) + + arg.GetWorkspaceETensorSizeBytes() / sizeof(ADataType); + p_b_grid = + type_convert(arg.p_workspace_) + + (arg.GetWorkspaceETensorSizeBytes() + arg.GetWorkspaceATensorSizeBytes()) / + sizeof(BDataType); + } + + // nullptr for output, will be set after workspace set + typename GridwiseGemm::Argument gemm_arg{std::array{p_a_grid}, + std::array{p_b_grid}, + std::array{}, // p_ds_grid_ + p_c_grid, + GemmM, + GemmN, + GemmK, + std::array{I0}, + std::array{I0}, + std::array{}, // StrideDs_ + I0, + arg.k_batch_, + AElementwiseOperation{}, + BElementwiseOperation{}, + CDEElementwiseOperation{}}; + + index_t gdx, gdy, gdz; + std::tie(gdx, gdy, gdz) = GridwiseGemm::CalculateGridSize( + gemm_arg.M, gemm_arg.N, gemm_arg.KBatch, arg.Conv_G_ / NumGroupsToMerge); + + float ave_time = 0; + + index_t k_grain = gemm_arg.KBatch * KPerBlock; + index_t K_split = (gemm_arg.K + k_grain - 1) / k_grain * KPerBlock; + const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K_split); + + const auto num_k_per_block = + arg.a_grid_desc_k0_m_k1_.GetLength(Number<0>{}) / gemm_arg.KBatch; + + const auto clear_workspace = [&]() { + hip_check_error(hipMemsetAsync(gemm_arg.p_e_grid, + 0, + arg.GetWorkspaceETensorSizeBytes(), + stream_config.stream_id_)); + }; + + const auto Run = [&](const auto& kernel) { + if(stream_config.flush_cache) + { + typename GridwiseGemm::Argument gemm_arg_ = gemm_arg; + + std::array size_as_buffers; + size_as_buffers[0] = arg.a_grid_desc_k0_m_k1_.GetElementSpaceSize() * + sizeof(ADataType) / GridwiseGemm::APackedSize; + + std::array size_bs_buffers; + size_bs_buffers[0] = arg.b_grid_desc_k0_n_k1_.GetElementSpaceSize() * + sizeof(BDataType) / GridwiseGemm::BPackedSize; + + std::array size_ds_buffers; + + ck::utility::RotatingMemWrapperMultiABD, + Tuple, + Tuple<>> + rotating_mem(gemm_arg_, + stream_config.rotating_count, + size_as_buffers, + size_bs_buffers, + size_ds_buffers); + rotating_mem.Print(); + + auto run_flush_cache = [&]() { + // flush icache + ck::utility::flush_icache(); + // rotating mem + rotating_mem.Next(); + clear_workspace(); + }; + + ave_time += ck::utility::launch_and_time_kernel_with_preprocess( + stream_config, + run_flush_cache, + kernel, + dim3(gdx, gdy, gdz), + dim3(BlockSize), + 0, + gemm_arg_, + arg.a_grid_desc_k0_m_k1_, + arg.b_grid_desc_k0_n_k1_, + arg.c_grid_desc_mblock_mperblock_nblock_nperblock_, + arg.compute_ptr_offset_of_batch_, + num_k_per_block); + } + else + { + ave_time += launch_and_time_kernel_with_preprocess( + stream_config, + clear_workspace, + kernel, + dim3(gdx, gdy, gdz), + dim3(BlockSize), + 0, + gemm_arg, + arg.a_grid_desc_k0_m_k1_, + arg.b_grid_desc_k0_n_k1_, + arg.c_grid_desc_mblock_mperblock_nblock_nperblock_, + arg.compute_ptr_offset_of_batch_, + num_k_per_block); + } + }; + + constexpr index_t minimum_occupancy = + BlkGemmPipeSched == BlockGemmPipelineScheduler::Intrawave ? 1 : 2; + + if(has_main_k_block_loop) + { + // Tail number always full + if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1 || + BlkGemmPipelineVer == BlockGemmPipelineVersion::v3) + { + if(gemm_arg.KBatch > 1) + { + const auto kernel = kernel_grouped_conv_bwd_weight_wmma_cshuffle_v3< + GridwiseGemm, + remove_reference_t, + remove_reference_t, + remove_reference_t< + DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, + ComputePtrOffsetOfStridedBatch, + NumGroupsToMerge, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy>; + Run(kernel); + } + else + { + const auto kernel = kernel_grouped_conv_bwd_weight_wmma_cshuffle_v3< + GridwiseGemm, + remove_reference_t, + remove_reference_t, + remove_reference_t< + DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, + ComputePtrOffsetOfStridedBatch, + NumGroupsToMerge, + true, + InMemoryDataOperationEnum::Set, + minimum_occupancy>; + Run(kernel); + } + } + else + { + // TODO: Implement + } + } + else + { + // Tail number always 1 + if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1) + { + if(gemm_arg.KBatch > 1) + { + const auto kernel = kernel_grouped_conv_bwd_weight_wmma_cshuffle_v3< + GridwiseGemm, + remove_reference_t, + remove_reference_t, + remove_reference_t< + DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, + ComputePtrOffsetOfStridedBatch, + NumGroupsToMerge, + false, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy>; + Run(kernel); + } + else + { + const auto kernel = kernel_grouped_conv_bwd_weight_wmma_cshuffle_v3< + GridwiseGemm, + remove_reference_t, + remove_reference_t, + remove_reference_t< + DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, + ComputePtrOffsetOfStridedBatch, + NumGroupsToMerge, + false, + InMemoryDataOperationEnum::Set, + minimum_occupancy>; + Run(kernel); + } + } + } + + return ave_time; + } + + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + float avg_time = 0.f; + auto launch_elementwise_kernel = [&]() { + const AccDataType* p_c_grid = type_convert(arg.p_workspace_); + + std::array in_out_batch_strides = { + static_cast(arg.compute_ptr_offset_of_batch_.BatchStrideC_)}; + + if constexpr(is_NGCHW_GKCYX_NGKHW() || + is_NGCDHW_GKCZYX_NGKDHW()) + { + const index_t grid_size = arg.elementwise_block_2_ctile_map_.CalculateGridSize( + arg.e_in_transpose_desc_); + + const auto kernel = kernel_elementwise, + ck::Tuple, + ck::Tuple, + ck::Tuple, + Block2TileMapElementwise, + CDEElementwiseOperation>; + + return launch_and_time_kernel(stream_config, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + make_tuple(arg.e_in_transpose_desc_), + make_tuple(arg.e_out_transpose_desc_), + make_tuple(p_c_grid), + make_tuple(arg.p_e_grid_), + arg.elementwise_block_2_ctile_map_, + arg.cde_element_op_); + } + else + { + const index_t grid_size = arg.elementwise_block_2_ctile_map_.CalculateGridSize( + arg.ce_elementwise_grid_desc_m_n_) * + arg.Conv_G_; + + const auto kernel = + kernel_batched_elementwise, + ck::Tuple, + ck::Tuple, + ck::Tuple, + Block2TileMapElementwise, + CDEElementwiseOperation, + I1, + I1>; + + return launch_and_time_kernel(stream_config, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + make_tuple(arg.ce_elementwise_grid_desc_m_n_), + make_tuple(arg.ce_elementwise_grid_desc_m_n_), + make_tuple(p_c_grid), + make_tuple(arg.p_e_grid_), + arg.elementwise_block_2_ctile_map_, + arg.cde_element_op_, + arg.Conv_G_, + in_out_batch_strides, + in_out_batch_strides); + } + }; + + if constexpr(is_NGCHW_NGKHW() || + is_NGCDHW_NGKDHW()) + { + const index_t grid_size_a = + arg.elementwise_block_2_ctile_map_transpose_a_.CalculateGridSize( + arg.a_in_transpose_desc_); + const index_t grid_size_b = + arg.elementwise_block_2_ctile_map_transpose_b_.CalculateGridSize( + arg.b_in_transpose_desc_); + + ADataType* p_a_out_grid = type_convert(arg.p_workspace_) + + arg.GetWorkspaceETensorSizeBytes() / sizeof(ADataType); + BDataType* p_b_out_grid = + type_convert(arg.p_workspace_) + + (arg.GetWorkspaceETensorSizeBytes() + arg.GetWorkspaceATensorSizeBytes()) / + sizeof(BDataType); + + // Different data type for A and B is not supported + auto kernel_transpose = kernel_elementwise_dual, + ck::Tuple, + ck::Tuple, + ck::Tuple, + ck::Tuple, + ck::Tuple, + ck::Tuple, + ck::Tuple, + Block2TileMapElementwise, + Block2TileMapElementwise, + element_wise::PassThrough>; + + avg_time += launch_and_time_kernel(stream_config, + kernel_transpose, + dim3(grid_size_a + grid_size_b), + dim3(BlockSize), + 0, + make_tuple(arg.a_in_transpose_desc_), + make_tuple(arg.b_in_transpose_desc_), + make_tuple(arg.a_out_transpose_desc_), + make_tuple(arg.b_out_transpose_desc_), + make_tuple(arg.p_a_grid_), + make_tuple(arg.p_b_grid_), + make_tuple(p_a_out_grid), + make_tuple(p_b_out_grid), + arg.elementwise_block_2_ctile_map_transpose_a_, + arg.elementwise_block_2_ctile_map_transpose_b_, + element_wise::PassThrough{}, + grid_size_a); + } + + avg_time += RunGemmV3(arg, stream_config); + avg_time += launch_elementwise_kernel(); + return avg_time; + } + + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } + + static bool IsSupportedArgument(const Argument& arg) + { + const index_t GemmM = arg.a_grid_desc_k0_m_k1_.GetLength(I1); + const index_t GemmN = arg.b_grid_desc_k0_n_k1_.GetLength(I1); + const index_t GemmK = + arg.a_grid_desc_k0_m_k1_.GetLength(I0) * arg.a_grid_desc_k0_m_k1_.GetLength(I2); + + typename GridwiseGemm::Argument gemm_arg{std::array{nullptr}, // p_as_grid + std::array{nullptr}, // p_bs_grid + std::array{}, // p_ds_grid + nullptr, // p_e_grid + GemmM, // M + GemmN, // N + GemmK, // K + std::array{I0}, // StrideAs + std::array{I0}, // StrideBs + std::array{}, // StrideDs + I0, // StrideE + arg.k_batch_, + AElementwiseOperation{}, + BElementwiseOperation{}, + CDEElementwiseOperation{}}; + + const auto num_k_loop = gemm_arg.AK0 / (KPerBlock / ABK1); + if constexpr(BlkGemmPipelineVer != BlockGemmPipelineVersion::v1) + { + if(num_k_loop <= GridwiseGemm::BlockwiseGemmPipe::PrefetchStages) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Unsupported num K loop." << std::endl; + } + return false; + } + } + + if(!ck::is_gfx11_supported() && !ck::is_gfx12_supported()) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Unsupported: Architecture must be gfx11/gfx12." << std::endl; + } + return false; + } + + // Check this here, it allows to use other instances from factory even + // if workspace is not allocated + if(!arg.p_workspace_) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Warning: Workspace for " + "DeviceGroupedConvBwdWeightTwoStage_Wmma_CShuffleV3::Argument is not " + "allocated, use SetWorkSpacePointer." + << std::endl; + } + return false; + } + + if constexpr(std::is_same_v || std::is_same_v || + std::is_same_v || std::is_same_v) + { + if(ck::is_gfx11_supported()) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Unsupported f8 / bf8 on gfx11." << std::endl; + } + return false; + } + } + + if constexpr(NDimSpatial == 2) + { + if constexpr(!(is_NHWGC_GKYXC_NHWGK() || + is_NGCHW_NGKHW())) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Unsupported layout." << std::endl; + } + return false; + } + } + else if constexpr(NDimSpatial == 3) + { + if constexpr(!(is_NDHWGC_GKZYXC_NDHWGK() || + is_NGCDHW_NGKDHW())) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Unsupported layout." << std::endl; + } + return false; + } + } + else + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Unsupported layout." << std::endl; + } + return false; + } + + if constexpr(ConvBackwardWeightSpecialization == + ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0) + { + // check if it's 1x1, stride=1 pad = 0 conv + for(int i = 0; i < NDimSpatial; i++) + { + if(!(arg.filter_spatial_lengths_[i] == 1 && arg.conv_filter_strides_[i] == 1 && + arg.input_left_pads_[i] == 0 && arg.input_right_pads_[i] == 0)) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Unsupported stride / pad." << std::endl; + } + return false; + } + } + } + + if constexpr(NumGroupsToMerge > 1) + { + // support only if whole M and N can be proccessed on one block + if(!(GemmM <= MPerBlock && GemmN <= NPerBlock)) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Unsupported GemmMN for merge groups." << std::endl; + } + return false; + } + if(!(arg.Conv_C_ == 1 && arg.Conv_K_ == 1)) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Unsupported conv CK for merge groups." << std::endl; + } + return false; + } + if(arg.Conv_G_ % NumGroupsToMerge != 0) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Unsupported conv G for merge groups." << std::endl; + } + return false; + } + } + + const bool is_w_pad_zero = arg.input_left_pads_[NDimSpatial - 1] == 0 && + arg.input_right_pads_[NDimSpatial - 1] == 0; + const auto X = arg.filter_spatial_lengths_[NDimSpatial - 1]; + const bool XC_access_allowed = arg.Conv_G_ == 1 && + (arg.Conv_C_ * X) % BBlockTransferSrcScalarPerVector == 0 && + is_w_pad_zero; + + if(!((arg.Conv_C_ % BBlockTransferSrcScalarPerVector == 0 || XC_access_allowed) && + arg.Conv_K_ % ABlockTransferSrcScalarPerVector == 0)) + { + if(!(arg.Conv_K_ == 1 && arg.compute_ptr_offset_of_batch_.BatchStrideA_ == 1 && + NumGroupsToMerge > 1)) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Unsupported Conv_K_ % ABlockTransferSrcScalarPerVector" + << std::endl; + } + return false; + } + if(!(arg.Conv_C_ == 1 && arg.compute_ptr_offset_of_batch_.BatchStrideB_ == 1 && + NumGroupsToMerge > 1)) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Unsupported Conv_C_ % BBlockTransferSrcScalarPerVector" + << std::endl; + } + return false; + } + } + + // vector load A/B matrix from global memory + if(!(ABlockTransferSrcVectorDim == 1 && BBlockTransferSrcVectorDim == 1)) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Unsupported BlockTransferSrcVectorDim." << std::endl; + } + return false; + } + + // vector store C matrix into global memory + if(!(arg.Conv_C_ % CShuffleBlockTransferScalarPerVector_NPerBlock == 0)) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Unsupported CShuffleBlockTransferScalarPerVector_NPerBlock." + << std::endl; + } + return false; + } + + if constexpr(is_NGCHW_NGKHW() || + is_NGCDHW_NGKDHW()) + { + if((arg.Conv_G_ * arg.Conv_C_) % TransposeTransferDstScalarPerVector != 0) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Unsupported TransposeTransferDstScalarPerVector with GC." + << std::endl; + } + return false; + } + + if((arg.Conv_G_ * arg.Conv_K_) % TransposeTransferDstScalarPerVector != 0) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Unsupported TransposeTransferDstScalarPerVector with GK." + << std::endl; + } + return false; + } + + const index_t input_spatial_acum = ck::accumulate_n( + arg.input_spatial_lengths_.begin(), NDimSpatial, 1, std::multiplies<>()); + const index_t output_spatial_acum = ck::accumulate_n( + arg.output_spatial_lengths_.begin(), NDimSpatial, 1, std::multiplies<>()); + + if(input_spatial_acum % TransposeTransferSrcScalarPerVector != 0) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout + << "Unsupported input_spatial_acum % TransposeTransferSrcScalarPerVector." + << std::endl; + } + return false; + } + + if(output_spatial_acum % TransposeTransferSrcScalarPerVector != 0) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout + << "Unsupported input_spatial_acum % TransposeTransferSrcScalarPerVector." + << std::endl; + } + return false; + } + + constexpr long_index_t TwoGB = (long_index_t{1} << 31); + if(!(arg.a_out_transpose_desc_.GetElementSpaceSize() * sizeof(ADataType) <= TwoGB && + arg.b_out_transpose_desc_.GetElementSpaceSize() * sizeof(BDataType) <= TwoGB)) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Unsupported: Problem exceeds 2GB limit." << std::endl; + } + return false; + } + } + + return true; + } + + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static auto + MakeArgument(const InDataType* p_in_grid, + WeiDataType* p_wei_grid, + const OutDataType* p_out_grid, + const std::array& b_g_n_c_wis_lengths, // input + const std::array& b_g_n_c_wis_strides, + const std::array& e_g_k_c_xs_lengths, // weight + const std::array& e_g_k_c_xs_strides, + const std::array& a_g_n_k_wos_lengths, // output + const std::array& a_g_n_k_wos_strides, + const std::array& conv_filter_strides, + const std::array& conv_filter_dilations, + const std::array& input_left_pads, + const std::array& input_right_pads, + InElementwiseOperation in_element_op, + WeiElementwiseOperation wei_element_op, + OutElementwiseOperation out_element_op, + const ck::index_t split_k) + { + return Argument{p_in_grid, + p_wei_grid, + p_out_grid, + b_g_n_c_wis_lengths, // input + b_g_n_c_wis_strides, + e_g_k_c_xs_lengths, // weight + e_g_k_c_xs_strides, + a_g_n_k_wos_lengths, // output + a_g_n_k_wos_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + 1, + 1, + in_element_op, + wei_element_op, + out_element_op, + split_k}; + } + + static auto MakeInvoker() { return Invoker{}; } + + std::unique_ptr + MakeArgumentPointer(const void* p_in_grid, + void* p_wei_grid, + const void* p_out_grid, + const std::array& b_g_n_c_wis_lengths, // input + const std::array& b_g_n_c_wis_strides, + const std::array& e_g_k_c_xs_lengths, // weight + const std::array& e_g_k_c_xs_strides, + const std::array& a_g_n_k_wos_lengths, // output + const std::array& a_g_n_k_wos_strides, + const std::array& conv_filter_strides, + const std::array& conv_filter_dilations, + const std::array& input_left_pads, + const std::array& input_right_pads, + InElementwiseOperation in_element_op, + WeiElementwiseOperation wei_element_op, + OutElementwiseOperation out_element_op, + const ck::index_t split_k) override + { + return std::make_unique(static_cast(p_in_grid), + static_cast(p_wei_grid), + static_cast(p_out_grid), + b_g_n_c_wis_lengths, // input + b_g_n_c_wis_strides, + e_g_k_c_xs_lengths, // weight + e_g_k_c_xs_strides, + a_g_n_k_wos_lengths, // output + a_g_n_k_wos_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + 1, + 1, + in_element_op, + wei_element_op, + out_element_op, + split_k); + } + + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + std::map BlkGemmPipelineSchedulerToString{ + {BlockGemmPipelineScheduler::Intrawave, "Intrawave"}, + {BlockGemmPipelineScheduler::Interwave, "Interwave"}}; + + std::map BlkGemmPipelineVersionToString{ + {BlockGemmPipelineVersion::v1, "v1"}, + {BlockGemmPipelineVersion::v2, "v2"}, + {BlockGemmPipelineVersion::v3, "v3"}, + {BlockGemmPipelineVersion::v4, "v4"}, + {BlockGemmPipelineVersion::v5, "v5"}}; + + // clang-format off + str << "DeviceGroupedConvBwdWeightTwoStage_Wmma_CShuffleV3" + << "<" + << BlockSize << ", " + << MPerBlock << ", " + << NPerBlock << ", " + << KPerBlock << ", " + << getConvBackwardWeightSpecializationString(ConvBackwardWeightSpecialization) << ", " + << ABK1 << ", " + << MRepeat << ", " + << NRepeat << ", " + << ABlockTransferSrcScalarPerVector << ", " + << ABlockTransferDstScalarPerVector_AK1 << ", " + << BBlockTransferSrcScalarPerVector << ", " + << BBlockTransferDstScalarPerVector_BK1 << ", " + << CShuffleMRepeatPerShuffle << ", " + << CShuffleNRepeatPerShuffle << ", " + << CShuffleBlockTransferScalarPerVector_NPerBlock << ", " + << "BlkGemmPipelineScheduler: " + << BlkGemmPipelineSchedulerToString[BlkGemmPipeSched] << ", " + << "BlkGemmPipelineVersion: " + << BlkGemmPipelineVersionToString[BlkGemmPipelineVer] << ", " + << NumGroupsToMerge; + + if constexpr(is_NGCHW_NGKHW() || + is_NGCDHW_NGKDHW()) { + str << ", TransposeTransferSrcScalarPerVector: " + << TransposeTransferSrcScalarPerVector <<", " + << "TransposeTransferDstScalarPerVector: " << TransposeTransferDstScalarPerVector; + } + + + str << ">"; + // clang-format on + + return str.str(); + } + + size_t GetWorkSpaceSize(const BaseArgument* p_arg) const override + { + auto arg = dynamic_cast(p_arg); + if(arg) + { + return arg->GetWorkspaceSizeBytes(); + } + else + throw std::runtime_error( + "The argument pointer is not an object of " + "DeviceGroupedConvBwdWeightTwoStage_Wmma_CShuffleV3::Argument structure!"); + } + + void SetWorkSpacePointer(BaseArgument* p_arg, + void* p_workspace, + const StreamConfig& = StreamConfig{}) const override + { + auto p_arg_ = dynamic_cast(p_arg); + if(p_arg_) + { + p_arg_->p_workspace_ = p_workspace; + } + else + throw std::runtime_error( + "The argument pointer is not an object of " + "DeviceGroupedConvBwdWeightTwoStage_Wmma_CShuffleV3::Argument structure!"); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp index 9b89b549f4..e975534a06 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp @@ -1745,6 +1745,11 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle { return false; } + // TODO: this is needed because there is a bug + if(arg.k_batch_ > 1) + { + return false; + } } // Check this here, it allows to use other instances from factory even diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_wmma_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_wmma_cshuffle.hpp index 3db7b85551..c50940da41 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_wmma_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_wmma_cshuffle.hpp @@ -17,6 +17,7 @@ #include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp" #include "ck/host_utility/device_prop.hpp" #include "ck/host_utility/kernel_launch.hpp" +#include "ck/tensor_operation/gpu/device/impl/split_k_arg.hpp" namespace ck { namespace tensor_operation { @@ -450,7 +451,7 @@ struct DeviceGroupedConvBwdWeight_Wmma_CShuffle using Block2CTileMap = decltype(GridwiseGemm::MakeDefaultBlock2CTileMap( CGridDesc_M_N{}, I1 /* M01 */, I1 /* N01 */)); - struct Argument : public BaseArgument + struct Argument : public BaseArgument, public ArgumentSplitK { Argument(const InDataType* p_in_grid, WeiDataType* p_wei_grid, @@ -490,8 +491,7 @@ struct DeviceGroupedConvBwdWeight_Wmma_CShuffle output_spatial_lengths_{}, conv_filter_strides_{conv_filter_strides}, input_left_pads_{input_left_pads}, - input_right_pads_{input_right_pads}, - k_batch_{split_k} + input_right_pads_{input_right_pads} { constexpr index_t spatial_offset = 3; std::copy(begin(a_g_n_c_wis_lengths) + spatial_offset, @@ -504,6 +504,8 @@ struct DeviceGroupedConvBwdWeight_Wmma_CShuffle end(e_g_n_k_wos_lengths), begin(output_spatial_lengths_)); + k_batch_ = split_k; + const auto descs = DeviceOp::MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N( Conv_N_, @@ -576,7 +578,6 @@ struct DeviceGroupedConvBwdWeight_Wmma_CShuffle const std::array& conv_filter_strides_; const std::array& input_left_pads_; const std::array& input_right_pads_; - const index_t k_batch_; }; // Invoker diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_wmma_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_wmma_cshuffle_v3.hpp new file mode 100644 index 0000000000..1ab6bc446f --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_wmma_cshuffle_v3.hpp @@ -0,0 +1,1429 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#include +#include +#include + +#include "ck/utility/common_header.hpp" + +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_grouped_conv_bwd_weight.hpp" +#include "ck/tensor_operation/operator_transform/transform_conv_bwd_weight_to_gemm.hpp" +#include "ck/tensor_operation/operator_transform/transform_conv_bwd_weight_to_gemm_v2.hpp" +#include "ck/tensor_operation/operator_transform/transform_conv_ngchw_to_nhwgc.hpp" +#include "ck/tensor_operation/gpu/device/convolution_backward_weight_specialization.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_elementwise_2d.hpp" +#include +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp" +#include "ck/tensor_operation/gpu/device/impl/split_k_utils.hpp" +#include "ck/tensor_operation/gpu/device/impl/split_k_arg.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/device/matrix_padder.hpp" + +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" +#include "ck/host_utility/flush_cache.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS +__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) +#endif + kernel_grouped_conv_bwd_weight_wmma_cshuffle_v3( + typename GridwiseGemm::Argument karg, + const AGridDesc_AK0_M_K1 a_grid_desc_ak0_m_ak1, + const BGridDesc_BK0_N_K1 b_grid_desc_bk0_n_bk1, + const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock + c_grid_desc_mblock_mperblock_nblock_nperblock, + const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch, + const index_t num_k_per_block) +{ +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__) || defined(__gfx12__)) +#if defined(__gfx11__) + // gfx11 does not support *_atomic_pk_add_f16/bf16 instructions + using e_data_type = remove_cvref_t>; + if constexpr(!(CGlobalMemoryDataOperation == InMemoryDataOperationEnum::AtomicAdd && + (std::is_same_v || + std::is_same_v))) + { +#endif + constexpr index_t LDS_size = GridwiseGemm::template GetSharedMemoryNumberOfByte< + typename GridwiseGemm::EpilogueCShuffle>(); + __shared__ char p_shared[LDS_size]; + + auto epilogue_args = typename GridwiseGemm::EpilogueCShuffle{}; + + GridwiseGemm::template Run(p_shared, + a_grid_desc_ak0_m_ak1, + b_grid_desc_bk0_n_bk1, + c_grid_desc_mblock_mperblock_nblock_nperblock, + compute_ptr_offset_of_batch, + num_k_per_block, + karg, + epilogue_args); + +#if defined(__gfx11__) + } +#endif +#else + ignore = karg; + ignore = a_grid_desc_ak0_m_ak1; + ignore = b_grid_desc_bk0_n_bk1; + ignore = c_grid_desc_mblock_mperblock_nblock_nperblock; + ignore = compute_ptr_offset_of_batch; + ignore = num_k_per_block; +#endif // end of if (defined(__gfx9__) +} + +// out[N, Ho, Wo, K] = in[N, Hi, Wi, C] * wei[K, Y, X, C] +template +struct DeviceGroupedConvBwdWeight_Wmma_CShuffleV3 + : public DeviceGroupedConvBwdWeight +{ + static_assert(is_same_v); + static_assert(is_same_v); + static_assert(is_same_v); + + using DeviceOp = DeviceGroupedConvBwdWeight_Wmma_CShuffleV3; + + using ADataType = OutDataType; + using BDataType = InDataType; + using CDataType = WeiDataType; + + // If NGCHW then ADataType must be equal to BDataType + static_assert(!(is_NGCHW_NGKHW() || + is_NGCDHW_NGKDHW()) || + is_same_v); + + using AElementwiseOperation = OutElementwiseOperation; + using BElementwiseOperation = InElementwiseOperation; + using CElementwiseOperation = WeiElementwiseOperation; + + static inline auto I0 = Number<0>{}; + static inline auto I1 = Number<1>{}; + static inline auto I2 = Number<2>{}; + static inline auto I3 = Number<3>{}; + static inline auto I4 = Number<4>{}; + static inline auto I5 = Number<5>{}; + + static constexpr GemmSpecialization GemmSpec = GemmSpecialization::Default; + static constexpr auto ABK1Number = Number{}; + + static constexpr auto conv_to_gemm_transformer = + TransformConvBwdWeightToGemmV2{}; + + static constexpr index_t ClusterLengthMPerBlock = + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock::At(1); + static constexpr index_t ClusterLengthNPerBlock = + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock::At(3); + + static constexpr auto conv_ngchw_to_nhwgc_transformer = + TransformConvNGCHWToNHWGC{}; + + static constexpr index_t TransposeTransferSrcScalarPerVector = + std::min(NPerBlock / ClusterLengthNPerBlock, MaxTransposeTransferSrcScalarPerVector); + static constexpr index_t TransposeTransferDstScalarPerVector = + std::min(MPerBlock / ClusterLengthMPerBlock, MaxTransposeTransferDstScalarPerVector); + + template ::type = false> + static auto GetABCGridDesc() + { + const ck::index_t dim = 1; + const ck::index_t batch = 1; + const std::array lengths{1}; + const std::array strides{1, 1, 1, 1}; + const std::array params{1}; + return conv_to_gemm_transformer.template MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<1>( + dim, + dim, + dim, + lengths, + lengths, + lengths, + strides, + strides, + strides, + params, + params, + params, + params, + batch); + } + + template ::type = false> + static auto GetABCGridDesc() + { + const ck::index_t dim = 1; + const ck::index_t batch = 1; + const std::array lengths{1, 1}; + const std::array strides{1, 1, 1, 1, 1}; + const std::array params{1, 1}; + return conv_to_gemm_transformer.template MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<2>( + dim, + dim, + dim, + lengths, + lengths, + lengths, + strides, + strides, + strides, + params, + params, + params, + params, + batch); + } + + template ::type = false> + static auto GetABCGridDesc() + { + const ck::index_t dim = 1; + const ck::index_t batch = 1; + const std::array lengths{1, 1, 1}; + const std::array strides{1, 1, 1, 1, 1, 1}; + const std::array params{1, 1, 1}; + return conv_to_gemm_transformer.template MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<3>( + dim, + dim, + dim, + lengths, + lengths, + lengths, + strides, + strides, + strides, + params, + params, + params, + params, + batch); + } + + using NGCHWTransposeDescType = + remove_cvref_t({}, {}))>; + using NHWGCTransposeDescType = + remove_cvref_t({}, {}))>; + using GKCYXTransposeDescType = + remove_cvref_t({}, {}))>; + using GKYXCTransposeDescType = + remove_cvref_t({}, {}))>; + + using ABCGridDescs = decltype(GetABCGridDesc()); + + using AGridDesc_K0_M_K1 = remove_cvref_t; + using BGridDesc_K0_N_K1 = remove_cvref_t; + using CGridDesc_M_N = remove_cvref_t; + + using Block2TileMapTranspose = BlockToCTileMap_M00_N0_M01Adapt; + + using GridwiseInOutTranspose = + GridwiseElementwise, + Tuple, + Tuple, + Tuple, + Block2TileMapTranspose, + element_wise::PassThrough, + BlockSize, + MPerBlock, + NPerBlock, + MPerBlock / ClusterLengthMPerBlock, + NPerBlock / ClusterLengthNPerBlock, + Sequence<1, 0>, + Sequence, + Sequence, + I1, + I0>; + + // NPerBlock is used for the first dim which is store dimension + // (with CShuffleBlockTransferScalarPerVector_NPerBlock scalar per vector). + using GridwiseElementwiseWeightTranspose = + GridwiseElementwise, + Tuple, + Tuple, + Tuple, + Block2TileMapTranspose, + element_wise::PassThrough, + BlockSize, + MPerBlock, + NPerBlock, + MPerBlock / ClusterLengthMPerBlock, + NPerBlock / ClusterLengthNPerBlock, + Sequence<1, 0>, + Sequence, + Sequence<1>, + I1, + I0>; + + using GridwiseGemm = GridwiseGemm_wmma_cshuffle_v3< + tensor_layout::gemm::ColumnMajor, + tensor_layout::gemm::RowMajor, + Tuple<>, + tensor_layout::gemm::RowMajor, + Tuple, + Tuple, + AccDataType, + CDataType, + Tuple<>, + CDataType, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation, + GemmSpec, + BlockSize, + MPerBlock, + NPerBlock, + KPerBlock, + ABK1, + ABK1, + MPerWmma, + NPerWmma, + MRepeat, + NRepeat, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + false, + ABlockLdsAddExtraM, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + false, + BBlockLdsAddExtraN, + CShuffleMRepeatPerShuffle, + CShuffleNRepeatPerShuffle, + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + Sequence, + BlkGemmPipeSched, + BlkGemmPipelineVer, + ComputeTypeA, + ComputeTypeB, + false, // PermuteA + false, // permuteB + false, // IsBPreshuffle + true>; // ForceThreadTileTransfer + + // Argument + using CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = + decltype(GridwiseGemm::MakeDEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + CGridDesc_M_N{}, 1, 1)); + + struct ActiveWorkgroupsPerCU + { + ActiveWorkgroupsPerCU() + { + constexpr int dynamic_smem_size = 0; + constexpr index_t minimum_occupancy = + BlkGemmPipeSched == BlockGemmPipelineScheduler::Intrawave ? 1 : 2; + int max_occupancy = 0; + + if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v4) + { + // TODO: implement + } + else + { + hip_check_error(hipOccupancyMaxActiveBlocksPerMultiprocessor( + &max_occupancy, + kernel_grouped_conv_bwd_weight_wmma_cshuffle_v3< + GridwiseGemm, + remove_reference_t, + remove_reference_t, + remove_reference_t, + ComputePtrOffsetOfStridedBatch, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy>, + BlockSize, + dynamic_smem_size)); + } + max_occupancy_ = std::max(1, max_occupancy); + } + int max_occupancy_; + }; + + struct Argument : public BaseArgument, public ArgumentSplitK + { + Argument(const InDataType* p_in_grid, + WeiDataType* p_wei_grid, + const OutDataType* p_out_grid, + const std::array& b_g_n_c_wis_lengths, // input + const std::array& b_g_n_c_wis_strides, + const std::array& e_g_k_c_xs_lengths, // weight + const std::array& e_g_k_c_xs_strides, + const std::array& a_g_n_k_wos_lengths, // output + const std::array& a_g_n_k_wos_strides, + const std::array& conv_filter_strides, + const std::array& conv_filter_dilations, + const std::array& input_left_pads, + const std::array& input_right_pads, + const ck::index_t M01, + const ck::index_t N01, + InElementwiseOperation in_element_op, + WeiElementwiseOperation wei_element_op, + OutElementwiseOperation out_element_op, + ck::index_t split_k) + : p_a_grid_{p_out_grid}, + p_b_grid_{p_in_grid}, + p_c_grid_{p_wei_grid}, + a_grid_desc_kbatch_k0_m_k1_{}, + b_grid_desc_kbatch_k0_n_k1_{}, + c_grid_desc_m_n_{}, + c_grid_desc_mblock_mperblock_nblock_nperblock_{}, + compute_ptr_offset_of_batch_{}, + M01_{M01}, + N01_{N01}, + a_element_op_{out_element_op}, + b_element_op_{in_element_op}, + c_element_op_{wei_element_op}, + Conv_G_{b_g_n_c_wis_lengths[0]}, + Conv_N_{b_g_n_c_wis_lengths[1]}, + Conv_K_{e_g_k_c_xs_lengths[1]}, + Conv_C_{b_g_n_c_wis_lengths[2]}, + input_spatial_lengths_{}, + filter_spatial_lengths_{}, + output_spatial_lengths_{}, + conv_filter_strides_{conv_filter_strides}, + input_left_pads_{input_left_pads}, + input_right_pads_{input_right_pads} + { + static ActiveWorkgroupsPerCU active_workgroups_per_cu; + + c_space_size_bytes = + ck::accumulate_n( + e_g_k_c_xs_lengths.begin(), NDimSpatial + I3, 1, std::multiplies<>()) * + sizeof(WeiDataType); + + constexpr index_t spatial_offset = 3; + std::copy(begin(b_g_n_c_wis_lengths) + spatial_offset, + end(b_g_n_c_wis_lengths), + begin(input_spatial_lengths_)); + std::copy(begin(e_g_k_c_xs_lengths) + spatial_offset, + end(e_g_k_c_xs_lengths), + begin(filter_spatial_lengths_)); + std::copy(begin(a_g_n_k_wos_lengths) + spatial_offset, + end(a_g_n_k_wos_lengths), + begin(output_spatial_lengths_)); +#if !DISABLE_SPLIT_K_AUTODEDUCE_FOR_ONE_STAGE_KERNELS + if(split_k < 0) + { + ck::index_t gemmM, gemmN, gemmK; + std::tie(gemmM, gemmN, gemmK) = + get_bwd_weight_gemm_sizes(a_g_n_k_wos_lengths, e_g_k_c_xs_lengths); + + const auto grid_size = + calculate_mn_grid_size(gemmM, gemmN) * Conv_G_; + k_batch_ = get_best_occupancy_k_batch_value(active_workgroups_per_cu.max_occupancy_, + grid_size); + + // Ensure that k_batch_ does not exceed the maximum value + // for the GEMM pipeline. + const auto k_batch_max = math::integer_divide_ceil((gemmK - 1), KPerBlock); + k_batch_ = std::min(k_batch_, k_batch_max); + + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "[SPLIT-K AUTODEDUCE] k_batch max value: " << k_batch_max + << std::endl; + std::cout << "[SPLIT-K AUTODEDUCE] Final k_batch value: " << k_batch_ + << std::endl; + } + } + else +#endif + { + k_batch_ = split_k; + } + + std::array a_g_n_k_wos_strides_transposed = + conv_ngchw_to_nhwgc_transformer.TransposeInOutStrides(a_g_n_k_wos_lengths, + a_g_n_k_wos_strides); + std::array b_g_n_c_wis_strides_transposed = + conv_ngchw_to_nhwgc_transformer.TransposeInOutStrides(b_g_n_c_wis_lengths, + b_g_n_c_wis_strides); + std::array e_g_k_c_xs_strides_transposed = + conv_ngchw_to_nhwgc_transformer.TransposeWeiStrides(e_g_k_c_xs_lengths, + e_g_k_c_xs_strides); + + const auto descs = + conv_to_gemm_transformer + .template MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N( + Conv_N_, + Conv_K_, + Conv_C_, + input_spatial_lengths_, + filter_spatial_lengths_, + output_spatial_lengths_, + b_g_n_c_wis_strides_transposed, + e_g_k_c_xs_strides_transposed, + a_g_n_k_wos_strides_transposed, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + k_batch_); + + a_grid_desc_kbatch_k0_m_k1_ = descs[I0]; + b_grid_desc_kbatch_k0_n_k1_ = descs[I1]; + c_grid_desc_m_n_ = descs[I2]; + + // A/B/C Batch Stride + compute_ptr_offset_of_batch_.BatchStrideA_ = a_g_n_k_wos_strides_transposed[0]; + compute_ptr_offset_of_batch_.BatchStrideB_ = b_g_n_c_wis_strides_transposed[0]; + compute_ptr_offset_of_batch_.BatchStrideC_ = e_g_k_c_xs_strides_transposed[0]; + const index_t GemmM = a_grid_desc_kbatch_k0_m_k1_.GetLength(I1); + const index_t GemmN = b_grid_desc_kbatch_k0_n_k1_.GetLength(I1); + + c_grid_desc_mblock_mperblock_nblock_nperblock_ = + GridwiseGemm::MakeDEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + c_grid_desc_m_n_, + GridwiseGemm::CalculateMBlock(GemmM), + GridwiseGemm::CalculateNBlock(GemmN)); + + if constexpr(is_NGCHW_NGKHW() || + is_NGCDHW_NGKDHW()) + { + a_in_transpose_desc_ = + conv_ngchw_to_nhwgc_transformer.template MakeNGCHWTransposeDesc( + a_g_n_k_wos_lengths, a_g_n_k_wos_strides); + a_out_transpose_desc_ = + conv_ngchw_to_nhwgc_transformer.template MakeNHWGCTransposeDesc( + a_g_n_k_wos_lengths, a_g_n_k_wos_strides); + + b_in_transpose_desc_ = + conv_ngchw_to_nhwgc_transformer.template MakeNGCHWTransposeDesc( + b_g_n_c_wis_lengths, b_g_n_c_wis_strides); + b_out_transpose_desc_ = + conv_ngchw_to_nhwgc_transformer.template MakeNHWGCTransposeDesc( + b_g_n_c_wis_lengths, b_g_n_c_wis_strides); + + e_in_transpose_desc_ = + conv_ngchw_to_nhwgc_transformer.template MakeGKYXCTransposeDesc( + e_g_k_c_xs_lengths, e_g_k_c_xs_strides); + e_out_transpose_desc_ = + conv_ngchw_to_nhwgc_transformer.template MakeGKCYXTransposeDesc( + e_g_k_c_xs_lengths, e_g_k_c_xs_strides); + + elementwise_block_2_ctile_map_transpose_a_ = Block2TileMapTranspose{ + a_in_transpose_desc_.GetLength(I0), a_in_transpose_desc_.GetLength(I1)}; + + elementwise_block_2_ctile_map_transpose_b_ = Block2TileMapTranspose{ + b_in_transpose_desc_.GetLength(I0), b_in_transpose_desc_.GetLength(I1)}; + + elementwise_block_2_ctile_map_transpose_e_ = Block2TileMapTranspose{ + e_in_transpose_desc_.GetLength(I0), e_in_transpose_desc_.GetLength(I1)}; + } + } + + std::size_t GetWorkspaceATensorSizeBytes() const + { + if constexpr(is_NGCHW_NGKHW() || + is_NGCDHW_NGKDHW()) + { + // Align to 128B + return math::integer_divide_ceil( + sizeof(ADataType) * a_in_transpose_desc_.GetElementSpaceSize(), 128) * + 128; + } + else + { + return 0; + } + } + + std::size_t GetWorkspaceBTensorSizeBytes() const + { + if constexpr(is_NGCHW_NGKHW() || + is_NGCDHW_NGKDHW()) + { + // Align to 128B + return math::integer_divide_ceil( + sizeof(BDataType) * b_in_transpose_desc_.GetElementSpaceSize(), 128) * + 128; + } + else + { + return 0; + } + } + + std::size_t GetWorkspaceETensorSizeBytes() const + { + if constexpr(is_NGCHW_GKCYX_NGKHW() || + is_NGCDHW_GKCZYX_NGKDHW()) + { + return sizeof(CDataType) * e_in_transpose_desc_.GetElementSpaceSize(); + } + else + { + return 0; + } + } + + std::size_t GetWorkspaceSizeBytes() const + { + return GetWorkspaceATensorSizeBytes() + GetWorkspaceBTensorSizeBytes() + + GetWorkspaceETensorSizeBytes(); + } + + const ADataType* p_a_grid_; + const BDataType* p_b_grid_; + CDataType* p_c_grid_; + AGridDesc_K0_M_K1 a_grid_desc_kbatch_k0_m_k1_; + BGridDesc_K0_N_K1 b_grid_desc_kbatch_k0_n_k1_; + CGridDesc_M_N c_grid_desc_m_n_; + CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock c_grid_desc_mblock_mperblock_nblock_nperblock_; + + Block2TileMapTranspose elementwise_block_2_ctile_map_transpose_a_; + Block2TileMapTranspose elementwise_block_2_ctile_map_transpose_b_; + Block2TileMapTranspose elementwise_block_2_ctile_map_transpose_e_; + + NGCHWTransposeDescType a_in_transpose_desc_, b_in_transpose_desc_; + NHWGCTransposeDescType a_out_transpose_desc_, b_out_transpose_desc_; + + GKYXCTransposeDescType e_in_transpose_desc_; + GKCYXTransposeDescType e_out_transpose_desc_; + + // for computing batch offset + ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_batch_; + + index_t M01_; + index_t N01_; + + OutElementwiseOperation a_element_op_; + InElementwiseOperation b_element_op_; + WeiElementwiseOperation c_element_op_; + + // for checking IsSupportedArgument() + const index_t Conv_G_; + const index_t Conv_N_; + const index_t Conv_K_; + const index_t Conv_C_; + std::array input_spatial_lengths_; + std::array filter_spatial_lengths_; + std::array output_spatial_lengths_; + const std::array& conv_filter_strides_; + const std::array& input_left_pads_; + const std::array& input_right_pads_; + long_index_t c_space_size_bytes; + }; + + // Invoker + struct Invoker : public BaseInvoker + { + using Argument = DeviceOp::Argument; + + void ShowInfo(const Argument& arg) + { + std::cout << "arg.a_grid_desc_kbatch_k0_m_k1_{" + << arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I0) << ", " + << arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I1) << ", " + << arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I2) << ", " + << arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I3) << "}" << std::endl; + + std::cout << "arg.b_grid_desc_kbatch_k0_n_k1_{" + << arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I0) << ", " + << arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I1) << ", " + << arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I2) << ", " + << arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I3) << "}" << std::endl; + + std::cout << "arg.c_grid_desc_m_n_{" << arg.c_grid_desc_m_n_.GetLength(I0) << ", " + << arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl; + } + + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + float ave_time = 0; + + const index_t GemmM = arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I1); + const index_t GemmN = arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I1); + const index_t GemmK = arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I0) * + arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I2); + + const ADataType* p_a_grid = arg.p_a_grid_; + const BDataType* p_b_grid = arg.p_b_grid_; + CDataType* p_e_grid = arg.p_c_grid_; + + // A/B Transpose kernel dispatch (if needed) + if constexpr(is_NGCHW_GKCYX_NGKHW() || + is_NGCDHW_GKCZYX_NGKDHW()) + { + p_e_grid = + type_convert(arg.p_workspace_) + + (arg.GetWorkspaceATensorSizeBytes() + arg.GetWorkspaceBTensorSizeBytes()) / + sizeof(CDataType); + } + + if constexpr(is_NGCHW_NGKHW() || + is_NGCDHW_NGKDHW()) + { + const index_t grid_size_a = + arg.elementwise_block_2_ctile_map_transpose_a_.CalculateGridSize( + arg.a_in_transpose_desc_); + const index_t grid_size_b = + arg.elementwise_block_2_ctile_map_transpose_b_.CalculateGridSize( + arg.b_in_transpose_desc_); + + p_a_grid = type_convert(arg.p_workspace_); + p_b_grid = type_convert(arg.p_workspace_) + + arg.GetWorkspaceATensorSizeBytes() / sizeof(BDataType); + ADataType* p_out_a_grid = type_convert(arg.p_workspace_); + BDataType* p_out_b_grid = type_convert(arg.p_workspace_) + + arg.GetWorkspaceATensorSizeBytes() / sizeof(BDataType); + + // Different data type for A and B is not supported + auto kernel_transpose = kernel_elementwise_dual, + ck::Tuple, + ck::Tuple, + ck::Tuple, + ck::Tuple, + ck::Tuple, + ck::Tuple, + ck::Tuple, + Block2TileMapTranspose, + Block2TileMapTranspose, + element_wise::PassThrough>; + + ave_time += launch_and_time_kernel(stream_config, + kernel_transpose, + dim3(grid_size_a + grid_size_b), + dim3(BlockSize), + 0, + make_tuple(arg.a_in_transpose_desc_), + make_tuple(arg.b_in_transpose_desc_), + make_tuple(arg.a_out_transpose_desc_), + make_tuple(arg.b_out_transpose_desc_), + make_tuple(arg.p_a_grid_), + make_tuple(arg.p_b_grid_), + make_tuple(p_out_a_grid), + make_tuple(p_out_b_grid), + arg.elementwise_block_2_ctile_map_transpose_a_, + arg.elementwise_block_2_ctile_map_transpose_b_, + element_wise::PassThrough{}, + grid_size_a); + } + + // Convolution kernel dispatch + typename GridwiseGemm::Argument gemm_arg{std::array{p_a_grid}, + std::array{p_b_grid}, + std::array{}, // p_ds_grid_ + p_e_grid, + GemmM, + GemmN, + GemmK, + std::array{I0}, + std::array{I0}, + std::array{}, // StrideDs_ + I0, + arg.k_batch_, + AElementwiseOperation{}, + BElementwiseOperation{}, + CElementwiseOperation{}}; + + index_t gdx, gdy, gdz; + std::tie(gdx, gdy, gdz) = GridwiseGemm::CalculateGridSize( + gemm_arg.M, gemm_arg.N, gemm_arg.KBatch, arg.Conv_G_); + + index_t k_grain = gemm_arg.KBatch * KPerBlock; + index_t K_split = (gemm_arg.K + k_grain - 1) / k_grain * KPerBlock; + const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K_split); + + const auto num_k_per_block = + arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(Number<0>{}) / gemm_arg.KBatch; + + const auto clear_workspace = [&]() { + hip_check_error( + hipMemsetAsync(p_e_grid, 0, arg.c_space_size_bytes, stream_config.stream_id_)); + }; + + const auto Run = [&](const auto& kernel) { + if(stream_config.flush_cache && + !(is_NGCHW_GKCYX_NGKHW() || + is_NGCDHW_GKCZYX_NGKDHW())) + { + typename GridwiseGemm::Argument gemm_arg_ = gemm_arg; + + std::array size_as_buffers; + size_as_buffers[0] = arg.a_grid_desc_kbatch_k0_m_k1_.GetElementSpaceSize() * + sizeof(ADataType) / GridwiseGemm::APackedSize; + + std::array size_bs_buffers; + size_bs_buffers[0] = arg.b_grid_desc_kbatch_k0_n_k1_.GetElementSpaceSize() * + sizeof(BDataType) / GridwiseGemm::BPackedSize; + + std::array size_ds_buffers; + + ck::utility::RotatingMemWrapperMultiABD, + Tuple, + Tuple<>> + rotating_mem(gemm_arg_, + stream_config.rotating_count, + size_as_buffers, + size_bs_buffers, + size_ds_buffers); + rotating_mem.Print(); + + auto run_flush_cache = [&]() { + // flush icache + ck::utility::flush_icache(); + // rotating mem + rotating_mem.Next(); + }; + ave_time += ck::utility::launch_and_time_kernel_with_preprocess( + stream_config, + run_flush_cache, + kernel, + dim3(gdx, gdy, gdz), + dim3(BlockSize), + 0, + gemm_arg_, + arg.a_grid_desc_kbatch_k0_m_k1_, + arg.b_grid_desc_kbatch_k0_n_k1_, + arg.c_grid_desc_mblock_mperblock_nblock_nperblock_, + arg.compute_ptr_offset_of_batch_, + num_k_per_block); + } + else + { + ave_time += launch_and_time_kernel_with_preprocess( + stream_config, + clear_workspace, + kernel, + dim3(gdx, gdy, gdz), + dim3(BlockSize), + 0, + gemm_arg, + arg.a_grid_desc_kbatch_k0_m_k1_, + arg.b_grid_desc_kbatch_k0_n_k1_, + arg.c_grid_desc_mblock_mperblock_nblock_nperblock_, + arg.compute_ptr_offset_of_batch_, + num_k_per_block); + } + }; + + constexpr index_t minimum_occupancy = + BlkGemmPipeSched == BlockGemmPipelineScheduler::Intrawave ? 1 : 2; + + if(has_main_k_block_loop) + { + // Tail number always full + if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1 || + BlkGemmPipelineVer == BlockGemmPipelineVersion::v3) + { + if(gemm_arg.KBatch > 1) + { + const auto kernel = kernel_grouped_conv_bwd_weight_wmma_cshuffle_v3< + GridwiseGemm, + remove_reference_t, + remove_reference_t, + remove_reference_t< + DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, + ComputePtrOffsetOfStridedBatch, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy>; + Run(kernel); + } + else + { + const auto kernel = kernel_grouped_conv_bwd_weight_wmma_cshuffle_v3< + GridwiseGemm, + remove_reference_t, + remove_reference_t, + remove_reference_t< + DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, + ComputePtrOffsetOfStridedBatch, + true, + InMemoryDataOperationEnum::Set, + minimum_occupancy>; + Run(kernel); + } + } + else + { + // TODO: Implement + } + } + else + { + // Tail number always 1 + if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1) + { + if(gemm_arg.KBatch > 1) + { + const auto kernel = kernel_grouped_conv_bwd_weight_wmma_cshuffle_v3< + GridwiseGemm, + remove_reference_t, + remove_reference_t, + remove_reference_t< + DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, + ComputePtrOffsetOfStridedBatch, + false, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy>; + Run(kernel); + } + else + { + const auto kernel = kernel_grouped_conv_bwd_weight_wmma_cshuffle_v3< + GridwiseGemm, + remove_reference_t, + remove_reference_t, + remove_reference_t< + DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, + ComputePtrOffsetOfStridedBatch, + false, + InMemoryDataOperationEnum::Set, + minimum_occupancy>; + Run(kernel); + } + } + } + + // C Transpose kernel dispatch (if needed) + if constexpr(is_NGCHW_GKCYX_NGKHW() || + is_NGCDHW_GKCZYX_NGKDHW()) + { + const index_t grid_size_e = + arg.elementwise_block_2_ctile_map_transpose_e_.CalculateGridSize( + arg.e_in_transpose_desc_); + + const CDataType* p_e_in_grid = static_cast(p_e_grid); + + // Different data type for A and B is not supported + auto kernel_transpose = kernel_elementwise, + ck::Tuple, + ck::Tuple, + ck::Tuple, + Block2TileMapTranspose, + element_wise::PassThrough>; + + ave_time += launch_and_time_kernel(stream_config, + kernel_transpose, + dim3(grid_size_e), + dim3(BlockSize), + 0, + make_tuple(arg.e_in_transpose_desc_), + make_tuple(arg.e_out_transpose_desc_), + make_tuple(p_e_in_grid), + make_tuple(arg.p_c_grid_), + arg.elementwise_block_2_ctile_map_transpose_e_, + element_wise::PassThrough{}); + } + + return ave_time; + } + + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } + + static bool IsSupportedArgument(const Argument& arg) + { +#if DISABLE_SPLIT_K_AUTODEDUCE_FOR_ONE_STAGE_KERNELS + if(arg.k_batch_ < 0) + { + return false; + } +#endif + const index_t GemmM = arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I1); + const index_t GemmN = arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I1); + const index_t GemmK = arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I0) * + arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I2); + + typename GridwiseGemm::Argument gemm_arg{std::array{nullptr}, // p_as_grid + std::array{nullptr}, // p_bs_grid + std::array{}, // p_ds_grid + nullptr, // p_e_grid + GemmM, // M + GemmN, // N + GemmK, // K + std::array{I0}, // StrideAs + std::array{I0}, // StrideBs + std::array{}, // StrideDs + I0, // StrideE + arg.k_batch_, + AElementwiseOperation{}, + BElementwiseOperation{}, + CElementwiseOperation{}}; + + const auto num_k_loop = gemm_arg.AK0 / (KPerBlock / ABK1); + if constexpr(BlkGemmPipelineVer != BlockGemmPipelineVersion::v1) + { + if(num_k_loop <= GridwiseGemm::BlockwiseGemmPipe::PrefetchStages) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Unsupported num K loop." << std::endl; + } + return false; + } + } + + if(!ck::is_gfx11_supported() && !ck::is_gfx12_supported()) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Unsupported: Architecture must be gfx11/gfx12." << std::endl; + } + return false; + } + + if constexpr(std::is_same_v || + std::is_same_v) + { + if(gemm_arg.KBatch > 1 && ck::is_gfx11_supported()) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Unsupported splitK on gfx11." << std::endl; + } + // gfx11 does not support *_atomic_pk_add_f16/bf16 instructions + return false; + } + } + + if constexpr(std::is_same_v || std::is_same_v || + std::is_same_v || std::is_same_v) + { + if(ck::is_gfx11_supported()) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Unsupported f8 / bf8 on gfx11." << std::endl; + } + return false; + } + } + + if constexpr(NDimSpatial == 1) + { + if constexpr(!is_GNWC_GKXC_GNWK()) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Unsupported layout." << std::endl; + } + return false; + } + } + else if constexpr(NDimSpatial == 2) + { + if constexpr(!(is_NHWGC_GKYXC_NHWGK() || + is_GNHWC_GKYXC_GNHWK() || + is_NGCHW_NGKHW())) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Unsupported layout." << std::endl; + } + return false; + } + } + else if constexpr(NDimSpatial == 3) + { + if constexpr(!(is_NDHWGC_GKZYXC_NDHWGK() || + is_GNDHWC_GKZYXC_GNDHWK() || + is_NGCDHW_NGKDHW())) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Unsupported layout." << std::endl; + } + return false; + } + } + else + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Unsupported layout." << std::endl; + } + return false; + } + + if constexpr(ConvBackwardWeightSpecialization == + ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0) + { + // check if it's 1x1, stride=1 pad = 0 conv + for(int i = 0; i < NDimSpatial; i++) + { + if(!(arg.filter_spatial_lengths_[i] == 1 && arg.conv_filter_strides_[i] == 1 && + arg.input_left_pads_[i] == 0 && arg.input_right_pads_[i] == 0)) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Unsupported stride / pad." << std::endl; + } + return false; + } + } + } + if(!(ABlockTransferSrcVectorDim == 1 && BBlockTransferSrcVectorDim == 1 && + arg.Conv_K_ % ABlockTransferSrcScalarPerVector == 0 && + arg.Conv_C_ % BBlockTransferSrcScalarPerVector == 0)) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Unsupported BlockTransferSrcScalarPerVector." << std::endl; + } + return false; + } + + // vector store C matrix into global memory + if(!(arg.Conv_C_ % CShuffleBlockTransferScalarPerVector_NPerBlock == 0)) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Unsupported CShuffleBlockTransferScalarPerVector_NPerBlock." + << std::endl; + } + return false; + } + + if constexpr(is_NGCHW_NGKHW() || + is_NGCDHW_NGKDHW()) + { + if((arg.Conv_G_ * arg.Conv_C_) % TransposeTransferDstScalarPerVector != 0) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Unsupported TransposeTransferDstScalarPerVector with GC." + << std::endl; + } + return false; + } + + if((arg.Conv_G_ * arg.Conv_K_) % TransposeTransferDstScalarPerVector != 0) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Unsupported TransposeTransferDstScalarPerVector with GK." + << std::endl; + } + return false; + } + + const index_t input_spatial_acum = ck::accumulate_n( + arg.input_spatial_lengths_.begin(), NDimSpatial, 1, std::multiplies<>()); + const index_t output_spatial_acum = ck::accumulate_n( + arg.output_spatial_lengths_.begin(), NDimSpatial, 1, std::multiplies<>()); + + if(input_spatial_acum % TransposeTransferSrcScalarPerVector != 0) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout + << "Unsupported input_spatial_acum % TransposeTransferSrcScalarPerVector." + << std::endl; + } + return false; + } + + if(output_spatial_acum % TransposeTransferSrcScalarPerVector != 0) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout + << "Unsupported output_spatial_acum % TransposeTransferSrcScalarPerVector." + << std::endl; + } + return false; + } + + if(!arg.p_workspace_) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Warning: Workspace for " + "DeviceGroupedConvBwdWeight_Xdl_CShuffle::Argument is not " + "allocated, use SetWorkSpacePointer." + << std::endl; + } + return false; + } + + constexpr long_index_t TwoGB = (long_index_t{1} << 31); + if(!(arg.a_out_transpose_desc_.GetElementSpaceSize() * sizeof(ADataType) <= TwoGB && + arg.b_out_transpose_desc_.GetElementSpaceSize() * sizeof(BDataType) <= TwoGB)) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Unsupported: Problem exceeds 2GB limit." << std::endl; + } + return false; + } + } + + return GridwiseGemm::CheckValidity(gemm_arg); + } + + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static auto + MakeArgument(const InDataType* p_in_grid, + WeiDataType* p_wei_grid, + const OutDataType* p_out_grid, + const std::array& b_g_n_c_wis_lengths, // input + const std::array& b_g_n_c_wis_strides, + const std::array& e_g_k_c_xs_lengths, // weight + const std::array& e_g_k_c_xs_strides, + const std::array& a_g_n_k_wos_lengths, // output + const std::array& a_g_n_k_wos_strides, + const std::array& conv_filter_strides, + const std::array& conv_filter_dilations, + const std::array& input_left_pads, + const std::array& input_right_pads, + InElementwiseOperation in_element_op, + WeiElementwiseOperation wei_element_op, + OutElementwiseOperation out_element_op, + const ck::index_t split_k) + { + return Argument{p_in_grid, + p_wei_grid, + p_out_grid, + b_g_n_c_wis_lengths, // input + b_g_n_c_wis_strides, + e_g_k_c_xs_lengths, // weight + e_g_k_c_xs_strides, + a_g_n_k_wos_lengths, // output + a_g_n_k_wos_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + 1, + 1, + in_element_op, + wei_element_op, + out_element_op, + split_k}; + } + + static auto MakeInvoker() { return Invoker{}; } + + std::unique_ptr + MakeArgumentPointer(const void* p_in_grid, + void* p_wei_grid, + const void* p_out_grid, + const std::array& b_g_n_c_wis_lengths, // input + const std::array& b_g_n_c_wis_strides, + const std::array& e_g_k_c_xs_lengths, // weight + const std::array& e_g_k_c_xs_strides, + const std::array& a_g_n_k_wos_lengths, // output + const std::array& a_g_n_k_wos_strides, + const std::array& conv_filter_strides, + const std::array& conv_filter_dilations, + const std::array& input_left_pads, + const std::array& input_right_pads, + InElementwiseOperation in_element_op, + WeiElementwiseOperation wei_element_op, + OutElementwiseOperation out_element_op, + const ck::index_t split_k) override + { + return std::make_unique(static_cast(p_in_grid), + static_cast(p_wei_grid), + static_cast(p_out_grid), + b_g_n_c_wis_lengths, // input + b_g_n_c_wis_strides, + e_g_k_c_xs_lengths, // weight + e_g_k_c_xs_strides, + a_g_n_k_wos_lengths, // output + a_g_n_k_wos_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + 1, + 1, + in_element_op, + wei_element_op, + out_element_op, + split_k); + } + + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "DeviceGroupedConvBwdWeight_Wmma_CShuffleV3" + << "<" + << BlockSize << ", " + << MPerBlock << ", " + << NPerBlock << ", " + << KPerBlock << ", " + << getConvBackwardWeightSpecializationString(ConvBackwardWeightSpecialization) << ", " + << ABK1 << ", " + << MRepeat << ", " + << NRepeat << ", " + << ABlockTransferSrcScalarPerVector << ", " + << ABlockTransferDstScalarPerVector_AK1 << ", " + << BBlockTransferSrcScalarPerVector << ", " + << BBlockTransferDstScalarPerVector_BK1 << ", " + << CShuffleMRepeatPerShuffle << ", " + << CShuffleNRepeatPerShuffle << ", " + << CShuffleBlockTransferScalarPerVector_NPerBlock << ", " + << TransposeTransferSrcScalarPerVector << ", " + << TransposeTransferDstScalarPerVector + << ">"; + // clang-format on + + return str.str(); + } + + size_t GetWorkSpaceSize(const BaseArgument* p_arg) const override + { + auto arg = dynamic_cast(p_arg); + if(arg) + { + return arg->GetWorkspaceSizeBytes(); + } + else + throw std::runtime_error( + "The argument pointer is not an object of " + "DeviceGroupedConvBwdWeight_Wmma_CShuffleV3::Argument structure!"); + } + + void SetWorkSpacePointer(BaseArgument* p_arg, + void* p_workspace, + const StreamConfig& = StreamConfig{}) const override + { + auto p_arg_ = dynamic_cast(p_arg); + if(p_arg_) + { + p_arg_->p_workspace_ = p_workspace; + } + else + throw std::runtime_error( + "The argument pointer is not an object of " + "DeviceGroupedConvBwdWeight_Wmma_CShuffleV3::Argument structure!"); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_ab_transfer_thread_tiles.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_ab_transfer_thread_tiles.hpp index 69f8f44390..96387c6f64 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_ab_transfer_thread_tiles.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_ab_transfer_thread_tiles.hpp @@ -295,7 +295,7 @@ struct ABTransferThreadTiles BlockDescriptor& block_descriptor, ABElementwiseOperation& ab_element_op, const index_t block_mn_id, - const index_t) + const index_t k_id) { constexpr index_t NumABTensor = ABsDataType::Size(); const index_t mn_block_data_idx_on_grid = @@ -304,7 +304,7 @@ struct ABTransferThreadTiles if constexpr(NumABTensor > 1) { const auto idx_as_block_begin = generate_tuple( - [&](auto) { return make_multi_index(0, mn_block_data_idx_on_grid, 0); }, + [&](auto) { return make_multi_index(k_id, mn_block_data_idx_on_grid, 0); }, Number{}); return ThreadGroupTensorSliceTransfer_v7r2< @@ -357,7 +357,7 @@ struct ABTransferThreadTiles ABThreadTransferSrcResetCoordinateAfterRun, true, GlobalBufferNum>(grid_descriptor[I0], - make_multi_index(0, mn_block_data_idx_on_grid, 0), + make_multi_index(k_id, mn_block_data_idx_on_grid, 0), ab_element_op, block_descriptor, make_multi_index(0, 0, 0), diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3.hpp index fea0102337..0166e2f005 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3.hpp @@ -333,6 +333,7 @@ struct GridwiseGemm_wmma_cshuffle_v3 struct Problem { + __host__ Problem() = default; __host__ Problem(index_t M_, index_t N_, index_t K_, @@ -409,6 +410,7 @@ struct GridwiseGemm_wmma_cshuffle_v3 // Argument struct Argument : public tensor_operation::device::BaseArgument, public Problem { + __host__ Argument() = default; __host__ Argument(std::array p_as_grid_, std::array p_bs_grid_, std::array p_ds_grid_, @@ -583,7 +585,8 @@ struct GridwiseGemm_wmma_cshuffle_v3 BElementwiseOperation b_element_op, CDEElementwiseOperation cde_element_op, EpilogueArgument& epilogue_args, - const index_t k_id = 0) + const index_t A_k_id = 0, + const index_t B_k_id = 0) { const auto as_grid_desc_ak0_m_ak1 = MakeAsGridDescriptor_AK0_M_AK1( problem.M, problem.MPadded, problem.K, problem.KPadded, problem.StrideAs, problem.AK0); @@ -651,7 +654,8 @@ struct GridwiseGemm_wmma_cshuffle_v3 a_scale_struct, b_scale_struct, epilogue_args, - k_id); + A_k_id, + B_k_id); } template ( - p_shared, splitk_batch_offset, karg, DefaultBlock2CTileMap(karg), epilogue_args, k_id); + EpilogueArgument>(p_shared, + splitk_batch_offset, + karg, + DefaultBlock2CTileMap(karg), + epilogue_args, + A_k_id, + B_k_id); } __device__ static auto DefaultBlock2CTileMap(const Problem& problem) { return Block2CTileMap{problem.M, problem.N, 4}; } + + // Run method for convolution (grid descriptors are passed as arguments, + // not generated internally) + template + __device__ static void Run(void* p_shared, + const AGridDesc_AK0_M_K1 a_grid_desc_ak0_m_ak1, + const BGridDesc_BK0_N_K1 b_grid_desc_bk0_n_bk1, + const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock + c_grid_desc_mblock_mperblock_nblock_nperblock, + const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch, + const index_t num_k_per_block, + Argument& karg, + EpilogueArgument& epilogue_args) + { + const index_t g_idx = __builtin_amdgcn_readfirstlane(blockIdx.z * NumGroupsToMerge); + const index_t k_idx = __builtin_amdgcn_readfirstlane(blockIdx.y * num_k_per_block); + + const long_index_t a_batch_offset = + amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx)); + const long_index_t b_batch_offset = + amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx)); + const long_index_t e_batch_offset = + amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx)); + + AsGridPointer p_as_grid_; + static_for<0, NumATensor, 1>{}([&](auto i) { + using ADataType_ = remove_cvref_t>; + p_as_grid_(i) = static_cast(karg.p_as_grid[i]) + a_batch_offset; + }); + + BsGridPointer p_bs_grid_; + static_for<0, NumBTensor, 1>{}([&](auto i) { + using BDataType_ = remove_cvref_t>; + p_bs_grid_(i) = static_cast(karg.p_bs_grid[i]) + b_batch_offset; + }); + + const auto ds_grid_desc_m_n = + MakeDsGridDescriptor_M_N(karg.M, karg.MPadded, karg.N, karg.NPadded, karg.StrideDs); + + const auto ds_grid_desc_mblock_mperblock_nblock_nperblock = + MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + ds_grid_desc_m_n, karg.MBlock, karg.NBlock); + + const auto as_grid_desc_ak0_m_ak1 = generate_tuple( + [&](auto i) { + ignore = i; + return a_grid_desc_ak0_m_ak1; + }, + Number{}); + + const auto bs_grid_desc_bk0_n_bk1 = generate_tuple( + [&](auto i) { + ignore = i; + return b_grid_desc_bk0_n_bk1; + }, + Number{}); + + // divide block work by [M, N] + const auto block_2_ctile_map = Block2CTileMap{karg.M, karg.N, 4}; + + const auto block_work_idx = + block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id())); + + if(!block_2_ctile_map.ValidCTileIndex( + block_work_idx, + make_tuple(c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0), + c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2)))) + { + return; + } + + const index_t block_m_id = __builtin_amdgcn_readfirstlane(block_work_idx[I0]); + const index_t block_n_id = __builtin_amdgcn_readfirstlane(block_work_idx[I1]); + + // Scale structs (Empty) + using Scale = typename BlockwiseGemmPipe::Empty; + auto b_scale_struct = Scale{}; + auto a_scale_struct = Scale{}; + + const index_t num_k_block_per_scale = GetKBlockPerScale(); + + Base::template Run(p_as_grid_, + p_bs_grid_, + karg.p_ds_grid, + karg.p_e_grid + e_batch_offset, + p_shared, + as_grid_desc_ak0_m_ak1, + bs_grid_desc_bk0_n_bk1, + ds_grid_desc_mblock_mperblock_nblock_nperblock, + c_grid_desc_mblock_mperblock_nblock_nperblock, + karg.a_element_op, + karg.b_element_op, + karg.cde_element_op, + block_m_id, + block_n_id, + num_k_block_per_scale, + a_scale_struct, + b_scale_struct, + epilogue_args, + k_idx, + k_idx, + karg.KBatch); + } }; } // namespace ck diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_ab_scale.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_ab_scale.hpp index 0974f45a2b..92561d00d4 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_ab_scale.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_ab_scale.hpp @@ -723,7 +723,8 @@ struct GridwiseGemm_wmma_cshuffle_v3_ab_scale BElementwiseOperation b_element_op, CDEElementwiseOperation cde_element_op, EpilogueArgument& epilogue_args, - const index_t k_id = 0) + const index_t A_k_id = 0, + const index_t B_k_id = 0) { const auto as_grid_desc_ak0_m_ak1 = MakeAsGridDescriptor_AK0_M_AK1( problem.M, problem.MPadded, problem.K, problem.KPadded, problem.StrideAs, problem.AK0); @@ -793,7 +794,8 @@ struct GridwiseGemm_wmma_cshuffle_v3_ab_scale a_scale_struct, b_scale_struct, epilogue_args, - k_id); + A_k_id, + B_k_id); } // NOTE: Wrapper function to have __global__ function in common @@ -806,7 +808,8 @@ struct GridwiseGemm_wmma_cshuffle_v3_ab_scale const SplitKBatchOffset& splitk_batch_offset, Argument& karg, EpilogueArgument& epilogue_args, - const index_t k_id = 0) + const index_t A_k_id = 0, + const index_t B_k_id = 0) { // shift A matrices pointer for splitk AsGridPointer p_as_grid_splitk; @@ -857,7 +860,8 @@ struct GridwiseGemm_wmma_cshuffle_v3_ab_scale karg.b_element_op, karg.cde_element_op, epilogue_args, - k_id); + A_k_id, + B_k_id); } }; diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_common.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_common.hpp index 81aa1ac986..9b5dab493e 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_common.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_common.hpp @@ -101,7 +101,12 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) auto epilogue_args = typename GridwiseGemm::EpilogueCShuffle{}; GridwiseGemm::template Run( - p_shared, splitk_batch_offset, karg, epilogue_args, k_id); + p_shared, + splitk_batch_offset, + karg, + epilogue_args, + 0, /* A_k_id == 0 (we shift the pointer for splitk) */ + k_id); #if defined(__gfx11__) } @@ -344,11 +349,20 @@ struct GridwiseGemm_wmma_cshuffle_v3_base // return block_id to C matrix tile idx (m0, n0) mapping using Block2CTileMap = BlockToCTileMap_Grouped_M00_N0_M01Adapt<8, MPerBlock, NPerBlock>; + // Calculate grid size taking into account splitk (KBatch) + // 2D grid (x,z) __host__ static auto CalculateGridSize(index_t M, index_t N, index_t KBatch) { return std::make_tuple(Block2CTileMap::CalculateGridSize(M, N), 1, KBatch); } + // Calculate grid size taking into account splitk (KBatch) and multiple groups (Batch) + // 3D grid (x,y,z) + __host__ static auto CalculateGridSize(index_t M, index_t N, index_t KBatch, index_t Batch) + { + return std::make_tuple(Block2CTileMap::CalculateGridSize(M, N), KBatch, Batch); + } + __host__ static auto CalculateMPadded(index_t M) { return math::integer_least_multiple(M, MPerBlock); @@ -706,8 +720,10 @@ struct GridwiseGemm_wmma_cshuffle_v3_base ReduceTrait>; template - __device__ static constexpr auto MakeDEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( - const DEGridDesc& de_grid_desc_m_n, index_t MBlock, index_t NBlock) + __host__ __device__ static constexpr auto + MakeDEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const DEGridDesc& de_grid_desc_m_n, + index_t MBlock, + index_t NBlock) { const auto de_grid_desc_mblock_mperblock_nblock_nperblock = transform_tensor_descriptor( de_grid_desc_m_n, @@ -1004,6 +1020,8 @@ struct GridwiseGemm_wmma_cshuffle_v3_base } } + // Note: arguments k_batch and k_id should be set if splitk is used + // with implicit gemm (no pointer shift but shift using tensor descriptors) template ( - as_grid_desc_ak0_m_ak1, a_block_desc_ak0_m_ak1, a_element_op, block_m_id, k_id); + as_grid_desc_ak0_m_ak1, a_block_desc_ak0_m_ak1, a_element_op, block_m_id, A_k_id); // B matrix blockwise copy auto b_blockwise_copy = @@ -1075,7 +1095,7 @@ struct GridwiseGemm_wmma_cshuffle_v3_base BsDataType, BElementwiseOperation, BlockwiseGemmPipe::GlobalBufferNum>( - bs_grid_desc_bk0_n_bk1, b_block_desc_bk0_n_bk1, b_element_op, block_n_id, k_id); + bs_grid_desc_bk0_n_bk1, b_block_desc_bk0_n_bk1, b_element_op, block_n_id, B_k_id); // LDS allocation for A and B: be careful of alignment constexpr auto a_block_space_size_aligned = math::integer_least_multiple( @@ -1100,7 +1120,7 @@ struct GridwiseGemm_wmma_cshuffle_v3_base auto c_thread_buf = blockwise_gemm_pipeline.GetCThreadBuffer(); const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane( - ATransfer::GetKDimension(as_grid_desc_ak0_m_ak1[I0]) / KPerBlock); + ATransfer::GetKDimension(as_grid_desc_ak0_m_ak1[I0]) / (KPerBlock * k_batch)); blockwise_gemm_pipeline.template Run( get_first_element_workaround(as_grid_desc_ak0_m_ak1), diff --git a/include/ck/utility/generic_memory_space_atomic.hpp b/include/ck/utility/generic_memory_space_atomic.hpp index 210b354504..b76d957044 100644 --- a/include/ck/utility/generic_memory_space_atomic.hpp +++ b/include/ck/utility/generic_memory_space_atomic.hpp @@ -71,6 +71,29 @@ __device__ float2_t atomic_add(float2_t* p_dst, const float2_t& x) return vy.template AsType()[I0]; } +template <> +__device__ float4_t atomic_add(float4_t* p_dst, const float4_t& x) +{ + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + constexpr auto I2 = Number<2>{}; + constexpr auto I3 = Number<3>{}; + + const vector_type vx{x}; + vector_type vy{0}; + + vy.template AsType()(I0) = + atomicAdd(c_style_pointer_cast(p_dst), vx.template AsType()[I0]); + vy.template AsType()(I1) = + atomicAdd(c_style_pointer_cast(p_dst) + 1, vx.template AsType()[I1]); + vy.template AsType()(I2) = + atomicAdd(c_style_pointer_cast(p_dst) + 2, vx.template AsType()[I2]); + vy.template AsType()(I3) = + atomicAdd(c_style_pointer_cast(p_dst) + 3, vx.template AsType()[I3]); + + return vy.template AsType()[I0]; +} + template <> __device__ double2_t atomic_add(double2_t* p_dst, const double2_t& x) { diff --git a/library/include/ck/library/tensor_operation_instance/add_grouped_conv_bwd_wei_exp_device_operation_instance.hpp b/library/include/ck/library/tensor_operation_instance/add_grouped_conv_bwd_wei_exp_device_operation_instance.hpp index 6a23a595bc..594c9ca5a7 100644 --- a/library/include/ck/library/tensor_operation_instance/add_grouped_conv_bwd_wei_exp_device_operation_instance.hpp +++ b/library/include/ck/library/tensor_operation_instance/add_grouped_conv_bwd_wei_exp_device_operation_instance.hpp @@ -7,7 +7,7 @@ #include #include "ck/utility/functional2.hpp" -#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_explicit_xdl.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_explicit.hpp" namespace ck { namespace tensor_operation { @@ -32,17 +32,17 @@ void add_explicit_gemm_device_operation_instances( ck::static_for<0, std::tuple_size_v, 1>{}([&](auto i) { using DeviceGemmOp = std::tuple_element_t; - using NewOpInstance = DeviceGroupedConvBwdWeight_Explicit_Xdl; + using NewOpInstance = DeviceGroupedConvBwdWeight_Explicit; static_assert(std::is_base_of_v, "wrong! NewOpInstance should be derived from BaseOp"); diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_exp_gemm_wmma_universal_km_kn_mn_instance.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_exp_gemm_wmma_universal_km_kn_mn_instance.hpp new file mode 100644 index 0000000000..48c9f10312 --- /dev/null +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_exp_gemm_wmma_universal_km_kn_mn_instance.hpp @@ -0,0 +1,138 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_wmma_cshuffle_v3.hpp" + +#include "ck/library/tensor_operation_instance/add_grouped_conv_bwd_wei_exp_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using namespace ck::tensor_layout::convolution; + +using BF16 = bhalf_t; +using F16 = half_t; +using F32 = float; + +using Row = tensor_layout::gemm::RowMajor; +using Col = tensor_layout::gemm::ColumnMajor; + +template +using S = Sequence; + +using PassThrough = element_wise::PassThrough; + +static constexpr auto GemmDefault = GemmSpecialization::Default; +static constexpr auto GemmKPadding = GemmSpecialization::KPadding; +static constexpr auto GemmMPadding = GemmSpecialization::MPadding; +static constexpr auto GemmMNPadding = GemmSpecialization::MNPadding; +static constexpr auto GemmMKPadding = GemmSpecialization::MKPadding; +static constexpr auto GemmMNKPadding = GemmSpecialization::MNKPadding; + +static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave; +static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave; + +template +using device_gemm_wmma_universal_km_kn_mn_GemmDefault_instances = std::tuple< + // clang-format off + //#####################################| ALayout| BLayout| DsLayout |ELayout| ADataType| BDataType| DsDataType| CDataType| AccDataType| CShuffle| A| B| CDE| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransferClusterLengths| CShuffleBlockTransfer| BlockwiseGemm| BlockwiseGemm| + //#####################################| | | | | | | | | | DataType| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| _MBlock_MPerBlock_NBlock_NPerBlock| ScalarPerVectors| Pipeline| Pipeline| + //#####################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| | | Scheduler| Verision| + //#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceBatchedGemmMultiD_Wmma_CShuffleV3< Col, Row, Tuple<>, Row, InOutDataType, InOutDataType, Tuple<>, InOutDataType, F32, InOutDataType, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 48, 96, 64, 8, 8, 16, 16, 3, 3, S<8, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 6, 8, 0, S<8, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 6, 8, 0, 1, 1, S<1, 16, 1, 4>, S<8, 8, 8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceBatchedGemmMultiD_Wmma_CShuffleV3< Col, Row, Tuple<>, Row, InOutDataType, InOutDataType, Tuple<>, InOutDataType, F32, InOutDataType, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 96, 32, 8, 8, 16, 16, 4, 3, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 0, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 6, 8, 1, 1, 1, S<1, 16, 1, 4>, S<8, 8, 8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceBatchedGemmMultiD_Wmma_CShuffleV3< Col, Row, Tuple<>, Row, InOutDataType, InOutDataType, Tuple<>, InOutDataType, F32, InOutDataType, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 32, 64, 128, 8, 8, 16, 16, 2, 1, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 0, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, S<8, 8, 8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceBatchedGemmMultiD_Wmma_CShuffleV3< Col, Row, Tuple<>, Row, InOutDataType, InOutDataType, Tuple<>, InOutDataType, F32, InOutDataType, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 64, 32, 8, 8, 16, 16, 4, 2, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 0, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 4>, S<8, 8, 8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceBatchedGemmMultiD_Wmma_CShuffleV3< Col, Row, Tuple<>, Row, InOutDataType, InOutDataType, Tuple<>, InOutDataType, F32, InOutDataType, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 96, 32, 8, 8, 16, 16, 4, 3, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 6, 8, 0, 1, 1, S<1, 16, 1, 4>, S<8, 8, 8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceBatchedGemmMultiD_Wmma_CShuffleV3< Col, Row, Tuple<>, Row, InOutDataType, InOutDataType, Tuple<>, InOutDataType, F32, InOutDataType, PassThrough, PassThrough, PassThrough, GemmDefault, 192, 48, 96, 192, 8, 8, 16, 16, 3, 1, S<24, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 6, 8, 1, S<24, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 6, 8, 0, 1, 1, S<1, 16, 1, 12>, S<8, 8, 8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceBatchedGemmMultiD_Wmma_CShuffleV3< Col, Row, Tuple<>, Row, InOutDataType, InOutDataType, Tuple<>, InOutDataType, F32, InOutDataType, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 48, 64, 64, 8, 8, 16, 16, 3, 2, S<8, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 6, 8, 1, S<8, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, S<8, 8, 8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceBatchedGemmMultiD_Wmma_CShuffleV3< Col, Row, Tuple<>, Row, InOutDataType, InOutDataType, Tuple<>, InOutDataType, F32, InOutDataType, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 96, 128, 64, 8, 8, 16, 16, 6, 2, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 6, 8, 0, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, S<8, 8, 8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceBatchedGemmMultiD_Wmma_CShuffleV3< Col, Row, Tuple<>, Row, InOutDataType, InOutDataType, Tuple<>, InOutDataType, F32, InOutDataType, PassThrough, PassThrough, PassThrough, GemmDefault, 192, 32, 96, 192, 8, 8, 16, 16, 2, 1, S<24, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 0, S<24, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 6, 8, 1, 1, 1, S<1, 16, 1, 12>, S<8, 8, 8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceBatchedGemmMultiD_Wmma_CShuffleV3< Col, Row, Tuple<>, Row, InOutDataType, InOutDataType, Tuple<>, InOutDataType, F32, InOutDataType, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 32, 96, 64, 8, 8, 16, 16, 2, 3, S<8, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 0, S<8, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 6, 8, 0, 1, 1, S<1, 16, 1, 4>, S<8, 8, 8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceBatchedGemmMultiD_Wmma_CShuffleV3< Col, Row, Tuple<>, Row, InOutDataType, InOutDataType, Tuple<>, InOutDataType, F32, InOutDataType, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 32, 32, 128, 8, 8, 16, 16, 2, 1, S<16, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 0, S<16, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 0, 1, 1, S<1, 16, 1, 4>, S<8, 8, 8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceBatchedGemmMultiD_Wmma_CShuffleV3< Col, Row, Tuple<>, Row, InOutDataType, InOutDataType, Tuple<>, InOutDataType, F32, InOutDataType, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<8, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<8, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 0, 1, 1, S<1, 16, 1, 4>, S<8, 8, 8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceBatchedGemmMultiD_Wmma_CShuffleV3< Col, Row, Tuple<>, Row, InOutDataType, InOutDataType, Tuple<>, InOutDataType, F32, InOutDataType, PassThrough, PassThrough, PassThrough, GemmDefault, 192, 32, 96, 192, 8, 8, 16, 16, 2, 1, S<24, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<24, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 6, 8, 0, 1, 1, S<1, 16, 1, 12>, S<8, 8, 8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceBatchedGemmMultiD_Wmma_CShuffleV3< Col, Row, Tuple<>, Row, InOutDataType, InOutDataType, Tuple<>, InOutDataType, F32, InOutDataType, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<8, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 0, S<8, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 0, 1, 1, S<1, 16, 1, 4>, S<8, 8, 8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceBatchedGemmMultiD_Wmma_CShuffleV3< Col, Row, Tuple<>, Row, InOutDataType, InOutDataType, Tuple<>, InOutDataType, F32, InOutDataType, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 64, 32, 8, 8, 16, 16, 4, 2, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 0, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 0, 1, 1, S<1, 16, 1, 4>, S<8, 8, 8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceBatchedGemmMultiD_Wmma_CShuffleV3< Col, Row, Tuple<>, Row, InOutDataType, InOutDataType, Tuple<>, InOutDataType, F32, InOutDataType, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 32, 32, 64, 8, 8, 16, 16, 2, 1, S<8, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 0, S<8, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 0, 1, 1, S<1, 16, 1, 4>, S<8, 8, 8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1> + // clang-format on + >; + +using device_gemm_wmma_universal_km_kn_mn_GemmMNKPadding_f16_instances = std::tuple< + // clang-format off + //#####################################| ALayout| BLayout| DsLayout |ELayout| ADataType| BDataType| DsDataType| CDataType| AccDataType| CShuffle| A| B| CDE| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransferClusterLengths| CShuffleBlockTransfer| BlockwiseGemm| BlockwiseGemm| + //#####################################| | | | | | | | | | DataType| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| _MBlock_MPerBlock_NBlock_NPerBlock| ScalarPerVectors| Pipeline| Pipeline| + //#####################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| | | Scheduler| Verision| + //#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceBatchedGemmMultiD_Wmma_CShuffleV3< Col, Row, Tuple<>, Row, F16, F16, Tuple<>, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 64, 64, 96, 32, 8, 8, 16, 16, 4, 3, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 0, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 6, 8, 1, 1, 1, S<1, 16, 1, 4>, S<8, 8, 8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceBatchedGemmMultiD_Wmma_CShuffleV3< Col, Row, Tuple<>, Row, F16, F16, Tuple<>, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 64, 64, 64, 64, 8, 8, 16, 16, 4, 2, S<8, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 0, S<8, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 0, 1, 1, S<1, 16, 1, 4>, S<8, 8, 8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceBatchedGemmMultiD_Wmma_CShuffleV3< Col, Row, Tuple<>, Row, F16, F16, Tuple<>, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 128, 48, 64, 128, 8, 8, 16, 16, 3, 1, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 6, 8, 1, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 0, 1, 1, S<1, 16, 1, 8>, S<8, 8, 8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceBatchedGemmMultiD_Wmma_CShuffleV3< Col, Row, Tuple<>, Row, F16, F16, Tuple<>, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 64, 64, 64, 64, 8, 8, 16, 16, 4, 2, S<8, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 0, S<8, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, S<8, 8, 8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceBatchedGemmMultiD_Wmma_CShuffleV3< Col, Row, Tuple<>, Row, F16, F16, Tuple<>, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 64, 64, 96, 32, 8, 8, 16, 16, 4, 3, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 0, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 6, 8, 0, 1, 1, S<1, 16, 1, 4>, S<8, 8, 8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceBatchedGemmMultiD_Wmma_CShuffleV3< Col, Row, Tuple<>, Row, F16, F16, Tuple<>, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 64, 48, 64, 64, 8, 8, 16, 16, 3, 2, S<8, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 6, 8, 1, S<8, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 0, 1, 1, S<1, 16, 1, 4>, S<8, 8, 8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceBatchedGemmMultiD_Wmma_CShuffleV3< Col, Row, Tuple<>, Row, F16, F16, Tuple<>, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 64, 96, 64, 32, 8, 8, 16, 16, 6, 2, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 6, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 0, 1, 1, S<1, 16, 1, 4>, S<8, 8, 8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceBatchedGemmMultiD_Wmma_CShuffleV3< Col, Row, Tuple<>, Row, F16, F16, Tuple<>, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 64, 32, 32, 128, 8, 8, 16, 16, 2, 1, S<16, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 0, S<16, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, S<8, 8, 8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceBatchedGemmMultiD_Wmma_CShuffleV3< Col, Row, Tuple<>, Row, F16, F16, Tuple<>, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 64, 32, 32, 128, 8, 8, 16, 16, 2, 1, S<16, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 0, S<16, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, S<8, 8, 8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceBatchedGemmMultiD_Wmma_CShuffleV3< Col, Row, Tuple<>, Row, F16, F16, Tuple<>, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 64, 48, 32, 128, 8, 8, 16, 16, 3, 1, S<16, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 6, 8, 1, S<16, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, S<8, 8, 8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceBatchedGemmMultiD_Wmma_CShuffleV3< Col, Row, Tuple<>, Row, F16, F16, Tuple<>, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 64, 32, 96, 64, 8, 8, 16, 16, 2, 3, S<8, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 0, S<8, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 6, 8, 0, 1, 1, S<1, 16, 1, 4>, S<8, 8, 8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + // DeviceBatchedGemmMultiD_Wmma_CShuffleV3< Col, Row, Tuple<>, Row, F16, F16, Tuple<>, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 96, 64, 96, 48, 8, 8, 16, 16, 4, 2, S<6, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<6, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 6, 8, 0, 1, 1, S<1, 16, 1, 6>, S<8, 8, 8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, // Incorrect results for f16 + DeviceBatchedGemmMultiD_Wmma_CShuffleV3< Col, Row, Tuple<>, Row, F16, F16, Tuple<>, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 64, 32, 32, 128, 8, 8, 16, 16, 2, 1, S<16, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 0, S<16, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 0, 1, 1, S<1, 16, 1, 4>, S<8, 8, 8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1> + // clang-format on + >; + +using device_gemm_wmma_universal_km_kn_mn_GemmMNKPadding_bf16_instances = std::tuple< + // clang-format off + //#####################################| ALayout| BLayout| DsLayout |ELayout| ADataType| BDataType| DsDataType| CDataType| AccDataType| CShuffle| A| B| CDE| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransferClusterLengths| CShuffleBlockTransfer| BlockwiseGemm| BlockwiseGemm| + //#####################################| | | | | | | | | | DataType| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| _MBlock_MPerBlock_NBlock_NPerBlock| ScalarPerVectors| Pipeline| Pipeline| + //#####################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| | | Scheduler| Verision| + //#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceBatchedGemmMultiD_Wmma_CShuffleV3< Col, Row, Tuple<>, Row, BF16, BF16, Tuple<>, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 64, 64, 96, 32, 8, 8, 16, 16, 4, 3, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 0, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 6, 8, 1, 1, 1, S<1, 16, 1, 4>, S<8, 8, 8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceBatchedGemmMultiD_Wmma_CShuffleV3< Col, Row, Tuple<>, Row, BF16, BF16, Tuple<>, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 64, 64, 64, 64, 8, 8, 16, 16, 4, 2, S<8, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 0, S<8, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 0, 1, 1, S<1, 16, 1, 4>, S<8, 8, 8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceBatchedGemmMultiD_Wmma_CShuffleV3< Col, Row, Tuple<>, Row, BF16, BF16, Tuple<>, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 128, 48, 64, 128, 8, 8, 16, 16, 3, 1, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 6, 8, 1, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 0, 1, 1, S<1, 16, 1, 8>, S<8, 8, 8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceBatchedGemmMultiD_Wmma_CShuffleV3< Col, Row, Tuple<>, Row, BF16, BF16, Tuple<>, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 64, 64, 64, 64, 8, 8, 16, 16, 4, 2, S<8, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 0, S<8, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, S<8, 8, 8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceBatchedGemmMultiD_Wmma_CShuffleV3< Col, Row, Tuple<>, Row, BF16, BF16, Tuple<>, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 64, 64, 96, 32, 8, 8, 16, 16, 4, 3, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 0, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 6, 8, 0, 1, 1, S<1, 16, 1, 4>, S<8, 8, 8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceBatchedGemmMultiD_Wmma_CShuffleV3< Col, Row, Tuple<>, Row, BF16, BF16, Tuple<>, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 64, 48, 64, 64, 8, 8, 16, 16, 3, 2, S<8, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 6, 8, 1, S<8, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 0, 1, 1, S<1, 16, 1, 4>, S<8, 8, 8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceBatchedGemmMultiD_Wmma_CShuffleV3< Col, Row, Tuple<>, Row, BF16, BF16, Tuple<>, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 64, 96, 64, 32, 8, 8, 16, 16, 6, 2, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 6, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 0, 1, 1, S<1, 16, 1, 4>, S<8, 8, 8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceBatchedGemmMultiD_Wmma_CShuffleV3< Col, Row, Tuple<>, Row, BF16, BF16, Tuple<>, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 64, 32, 32, 128, 8, 8, 16, 16, 2, 1, S<16, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 0, S<16, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, S<8, 8, 8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceBatchedGemmMultiD_Wmma_CShuffleV3< Col, Row, Tuple<>, Row, BF16, BF16, Tuple<>, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 64, 32, 32, 128, 8, 8, 16, 16, 2, 1, S<16, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 0, S<16, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, S<8, 8, 8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceBatchedGemmMultiD_Wmma_CShuffleV3< Col, Row, Tuple<>, Row, BF16, BF16, Tuple<>, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 64, 48, 32, 128, 8, 8, 16, 16, 3, 1, S<16, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 6, 8, 1, S<16, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, S<8, 8, 8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceBatchedGemmMultiD_Wmma_CShuffleV3< Col, Row, Tuple<>, Row, BF16, BF16, Tuple<>, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 64, 32, 96, 64, 8, 8, 16, 16, 2, 3, S<8, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 0, S<8, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 6, 8, 0, 1, 1, S<1, 16, 1, 4>, S<8, 8, 8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceBatchedGemmMultiD_Wmma_CShuffleV3< Col, Row, Tuple<>, Row, BF16, BF16, Tuple<>, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 96, 64, 96, 48, 8, 8, 16, 16, 4, 2, S<6, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<6, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 6, 8, 0, 1, 1, S<1, 16, 1, 6>, S<8, 8, 8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, // Incorrect results for f16 + DeviceBatchedGemmMultiD_Wmma_CShuffleV3< Col, Row, Tuple<>, Row, BF16, BF16, Tuple<>, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 64, 32, 32, 128, 8, 8, 16, 16, 2, 1, S<16, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 0, S<16, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 0, 1, 1, S<1, 16, 1, 4>, S<8, 8, 8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1> + // clang-format on + >; + +template +using device_gemm_wmma_universal_km_kn_mn_irregular_odd_mn_instances = std::tuple< + // clang-format off + //#####################################| ALayout| BLayout| DsLayout |ELayout| ADataType| BDataType| DsDataType| CDataType| AccDataType| CShuffle| A| B| CDE| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransferClusterLengths| CShuffleBlockTransfer| BlockwiseGemm| BlockwiseGemm| + //#####################################| | | | | | | | | | DataType| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| _MBlock_MPerBlock_NBlock_NPerBlock| ScalarPerVectors| Pipeline| Pipeline| + //#####################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| | | Scheduler| Verision| + //#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + // Latency friendly + DeviceBatchedGemmMultiD_Wmma_CShuffleV3< Col, Row, Tuple<>, Row, InOutDataType, InOutDataType, Tuple<>, F32, F32, InOutDataType, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 256, 128, 32, 64, 8, 8, 16, 16, 1, 2, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, 1, 1, S<1, 16, 1, 16>, S<1, 1, 1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, + DeviceBatchedGemmMultiD_Wmma_CShuffleV3< Col, Row, Tuple<>, Row, InOutDataType, InOutDataType, Tuple<>, F32, F32, InOutDataType, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 256, 128, 32, 64, 8, 8, 16, 16, 1, 2, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 16>, S<1, 1, 1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, + DeviceBatchedGemmMultiD_Wmma_CShuffleV3< Col, Row, Tuple<>, Row, InOutDataType, InOutDataType, Tuple<>, F32, F32, InOutDataType, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 256, 128, 32, 128, 8, 8, 16, 16, 1, 2, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 16>, S<1, 1, 1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, + DeviceBatchedGemmMultiD_Wmma_CShuffleV3< Col, Row, Tuple<>, Row, InOutDataType, InOutDataType, Tuple<>, F32, F32, InOutDataType, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 256, 128, 48, 128, 8, 8, 16, 16, 1, 3, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 16>, S<1, 1, 1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, + DeviceBatchedGemmMultiD_Wmma_CShuffleV3< Col, Row, Tuple<>, Row, InOutDataType, InOutDataType, Tuple<>, F32, F32, InOutDataType, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 256, 128, 64, 32, 8, 8, 16, 16, 1, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 16>, S<1, 1, 1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, + DeviceBatchedGemmMultiD_Wmma_CShuffleV3< Col, Row, Tuple<>, Row, InOutDataType, InOutDataType, Tuple<>, F32, F32, InOutDataType, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 256, 128, 64, 32, 8, 8, 16, 16, 1, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 16>, S<1, 1, 1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, + DeviceBatchedGemmMultiD_Wmma_CShuffleV3< Col, Row, Tuple<>, Row, InOutDataType, InOutDataType, Tuple<>, F32, F32, InOutDataType, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 256, 128, 64, 64, 8, 8, 16, 16, 1, 4, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 16>, S<1, 1, 1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, + DeviceBatchedGemmMultiD_Wmma_CShuffleV3< Col, Row, Tuple<>, Row, InOutDataType, InOutDataType, Tuple<>, F32, F32, InOutDataType, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 256, 128, 64, 128, 8, 8, 16, 16, 1, 4, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 16>, S<1, 1, 1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, + DeviceBatchedGemmMultiD_Wmma_CShuffleV3< Col, Row, Tuple<>, Row, InOutDataType, InOutDataType, Tuple<>, F32, F32, InOutDataType, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 256, 128, 96, 64, 8, 8, 16, 16, 1, 6, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 16>, S<1, 1, 1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, + DeviceBatchedGemmMultiD_Wmma_CShuffleV3< Col, Row, Tuple<>, Row, InOutDataType, InOutDataType, Tuple<>, F32, F32, InOutDataType, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 256, 128, 96, 128, 8, 8, 16, 16, 1, 6, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 16>, S<1, 1, 1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, + DeviceBatchedGemmMultiD_Wmma_CShuffleV3< Col, Row, Tuple<>, Row, InOutDataType, InOutDataType, Tuple<>, F32, F32, InOutDataType, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 256, 128, 192, 32, 8, 8, 16, 16, 1, 12, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 16>, S<1, 1, 1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, + DeviceBatchedGemmMultiD_Wmma_CShuffleV3< Col, Row, Tuple<>, Row, InOutDataType, InOutDataType, Tuple<>, F32, F32, InOutDataType, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 256, 256, 96, 64, 8, 8, 16, 16, 2, 6, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 16>, S<1, 1, 1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1> + // Memory friendly + // TODO: add once v2 is implemented + // clang-format on + >; + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_two_stage_wmma_instance.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_two_stage_wmma_instance.hpp new file mode 100644 index 0000000000..b77c7348db --- /dev/null +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_two_stage_wmma_instance.hpp @@ -0,0 +1,91 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_wmma_cshuffle_v3.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using namespace ck::tensor_layout::convolution; + +using BF16 = ck::bhalf_t; +using F16 = ck::half_t; +using F32 = float; + +using Empty_Tuple = ck::Tuple<>; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto ConvBwdWeightDefault = + ck::tensor_operation::device::ConvolutionBackwardWeightSpecialization::Default; + +static constexpr auto ConvBwdWeightFilter1x1Stride1Pad0 = + ck::tensor_operation::device::ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0; + +template +using device_grouped_conv_bwd_weight_two_stage_nhwgc_wmma_c_shuffle_f16_instances = std::tuple< + // clang-format off + //################################################| Num| InLayout| WeiLayout| OutLayout| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Block| MPer| NPer| KPer| ABK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CShuffleBlockTransfer| BlockGemm| BlockGemm| NumGroups| + //################################################| Dim| | | | Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Weight| Size| Block| Block| Block| | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVector| Pipeline| Pipeline| ToMerge| + //################################################| Spatial| | | | | | | | Operation| Operation| Operation| Specialization| | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| MBlock_MPerBlock| _NPerBlock| Sched| Ver| | + //################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | NBlock_NPerBlock| | | | | + DeviceGroupedConvBwdWeightTwoStage_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 32, 16, 16, 32, 8, 16, 16, 1, 1, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 1, 4, 0, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 1, 4, 0, 1, 1, S<1, 4, 1, 8>, 1, Scheduler, PipelineVersion, 1> + // DeviceGroupedConvBwdWeightTwoStage_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 128, 128, 32, 8, 16, 16, 8, 2, S<4, 32, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 8, 0, S<4, 32, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 8, 0, 1, 1, S<1, 16, 1, 8>, 4, Scheduler, PipelineVersion, 1>, + // DeviceGroupedConvBwdWeightTwoStage_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 128, 128, 32, 8, 16, 16, 8, 2, S<4, 32, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 8, 1, S<4, 32, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 8, 0, 1, 1, S<1, 16, 1, 8>, 4, Scheduler, PipelineVersion, 1>, + // DeviceGroupedConvBwdWeightTwoStage_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 64, 64, 64, 8, 16, 16, 4, 2, S<8, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, 0, S<8, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, 4, Scheduler, PipelineVersion, 1>, + // DeviceGroupedConvBwdWeightTwoStage_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 128, 256, 64, 8, 16, 16, 8, 2, S<8, 32, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 8, 1, S<8, 32, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, 0, 1, 1, S<1, 16, 1, 16>, 4, Scheduler, PipelineVersion, 1>, + // DeviceGroupedConvBwdWeightTwoStage_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 48, 64, 128, 8, 16, 16, 3, 1, S<16, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 6, 8, 1, S<16, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 4, Scheduler, PipelineVersion, 1>, + // DeviceGroupedConvBwdWeightTwoStage_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 96, 128, 64, 8, 16, 16, 6, 2, S<8, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 6, 8, 1, S<8, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 4, Scheduler, PipelineVersion, 1>, + // DeviceGroupedConvBwdWeightTwoStage_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 64, 64, 128, 8, 16, 16, 4, 1, S<16, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, 0, S<16, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 4, Scheduler, PipelineVersion, 1>, + // DeviceGroupedConvBwdWeightTwoStage_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 96, 128, 128, 8, 16, 16, 6, 1, S<16, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 6, 8, 1, S<16, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, 0, 1, 1, S<1, 16, 1, 16>, 4, Scheduler, PipelineVersion, 1> + // DeviceGroupedConvBwdWeightTwoStage_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 96, 96, 96, 48, 8, 16, 16, 6, 2, S<6, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 6, 8, 0, S<6, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 6, 8, 0, 1, 1, S<1, 16, 1, 6>, 4, Scheduler, PipelineVersion, 1>, // Incorrect results for at least GemmDefault + // DeviceGroupedConvBwdWeightTwoStage_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 96, 96, 96, 48, 8, 16, 16, 6, 2, S<6, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 6, 8, 1, S<6, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 6, 8, 0, 1, 1, S<1, 16, 1, 6>, 4, Scheduler, PipelineVersion, 1> // Incorrect results for at least GemmDefault + // clang-format on + >; + +template +using device_grouped_conv_bwd_weight_two_stage_nhwgc_wmma_c_shuffle_bf16_instances = std::tuple< + // clang-format off + //################################################| Num| InLayout| WeiLayout| OutLayout| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Block| MPer| NPer| KPer| ABK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CShuffleBlockTransfer| BlockGemm| BlockGemm| NumGroups| + //################################################| Dim| | | | Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Weight| Size| Block| Block| Block| | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVector| Pipeline| Pipeline| ToMerge| + //################################################| Spatial| | | | | | | | Operation| Operation| Operation| Specialization| | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| MBlock_MPerBlock| _NPerBlock| Sched| Ver| | + //################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | NBlock_NPerBlock| | | | | + DeviceGroupedConvBwdWeightTwoStage_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 32, 16, 16, 32, 8, 16, 16, 1, 1, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 1, 4, 0, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 1, 4, 0, 1, 1, S<1, 4, 1, 8>, 1, Scheduler, PipelineVersion, 1> + // DeviceGroupedConvBwdWeightTwoStage_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 128, 128, 32, 8, 16, 16, 8, 2, S<4, 32, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 8, 0, S<4, 32, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 8, 0, 1, 1, S<1, 16, 1, 8>, 4, Scheduler, PipelineVersion, 1>, + // DeviceGroupedConvBwdWeightTwoStage_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 128, 128, 32, 8, 16, 16, 8, 2, S<4, 32, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 8, 1, S<4, 32, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 8, 0, 1, 1, S<1, 16, 1, 8>, 4, Scheduler, PipelineVersion, 1>, + // DeviceGroupedConvBwdWeightTwoStage_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 64, 64, 64, 8, 16, 16, 4, 2, S<8, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, 0, S<8, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, 4, Scheduler, PipelineVersion, 1>, + // DeviceGroupedConvBwdWeightTwoStage_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 128, 256, 64, 8, 16, 16, 8, 2, S<8, 32, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 8, 1, S<8, 32, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, 0, 1, 1, S<1, 16, 1, 16>, 4, Scheduler, PipelineVersion, 1>, + // DeviceGroupedConvBwdWeightTwoStage_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 48, 64, 128, 8, 16, 16, 3, 1, S<16, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 6, 8, 1, S<16, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 4, Scheduler, PipelineVersion, 1>, + // DeviceGroupedConvBwdWeightTwoStage_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 96, 128, 64, 8, 16, 16, 6, 2, S<8, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 6, 8, 1, S<8, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 4, Scheduler, PipelineVersion, 1>, + // DeviceGroupedConvBwdWeightTwoStage_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 64, 64, 128, 8, 16, 16, 4, 1, S<16, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, 0, S<16, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 4, Scheduler, PipelineVersion, 1>, + // DeviceGroupedConvBwdWeightTwoStage_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 96, 128, 128, 8, 16, 16, 6, 1, S<16, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 6, 8, 1, S<16, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, 0, 1, 1, S<1, 16, 1, 16>, 4, Scheduler, PipelineVersion, 1>, + // DeviceGroupedConvBwdWeightTwoStage_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 96, 96, 96, 48, 8, 16, 16, 6, 2, S<6, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 6, 8, 0, S<6, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 6, 8, 0, 1, 1, S<1, 16, 1, 6>, 4, Scheduler, PipelineVersion, 1>, + // DeviceGroupedConvBwdWeightTwoStage_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 96, 96, 96, 48, 8, 16, 16, 6, 2, S<6, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 6, 8, 1, S<6, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 6, 8, 0, 1, 1, S<1, 16, 1, 6>, 4, Scheduler, PipelineVersion, 1> + // clang-format on + >; +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_v3_wmma_instance.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_v3_wmma_instance.hpp new file mode 100644 index 0000000000..761b07ea60 --- /dev/null +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_v3_wmma_instance.hpp @@ -0,0 +1,100 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_wmma_cshuffle_v3.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using namespace ck::tensor_layout::convolution; + +using BF16 = ck::bhalf_t; +using F16 = ck::half_t; +using F32 = float; + +#ifdef CK_ENABLE_FP8 +using F8 = ck::f8_t; +#endif + +#ifdef CK_ENABLE_BF8 +using BF8 = ck::bf8_t; +#endif + +using Empty_Tuple = ck::Tuple<>; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto ConvBwdWeightDefault = + ck::tensor_operation::device::ConvolutionBackwardWeightSpecialization::Default; + +static constexpr auto ConvBwdWeightFilter1x1Stride1Pad0 = + ck::tensor_operation::device::ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0; + +template +using device_grouped_conv_bwd_weight_v3_wmma_c_shuffle_f16_instances = std::tuple< + // clang-format off + //#########################################| Num| InLayout| WeiLayout| OutLayout| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Block| MPer| NPer| KPer| ABK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CShuffleBlockTransfer| BlockGemm| BlockGemm| + //#########################################| Dim| | | | Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Weight| Size| Block| Block| Block| | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVector| Pipeline| Pipeline| + //#########################################| Spatial| | | | | | | | Operation| Operation| Operation| Specialization| | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| MBlock_MPerBlock| _NPerBlock| Scheduler| Version| + //#########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | NBlock_NPerBlock| | | | + DeviceGroupedConvBwdWeight_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 32, 32, 32, 8, 16, 16, 2, 1, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 2, 2, 0, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 2, 2, 0, 1, 1, S<1, 8, 1, 8>, 2, Scheduler, PipelineVersion>, + DeviceGroupedConvBwdWeight_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 128, 128, 32, 8, 16, 16, 8, 2, S<4, 32, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 8, 0, S<4, 32, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 8, 0, 1, 1, S<1, 16, 1, 8>, 8, Scheduler, PipelineVersion>, + DeviceGroupedConvBwdWeight_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 128, 128, 32, 8, 16, 16, 8, 2, S<4, 32, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 8, 1, S<4, 32, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 8, 0, 1, 1, S<1, 16, 1, 8>, 8, Scheduler, PipelineVersion>, + DeviceGroupedConvBwdWeight_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 64, 64, 64, 8, 16, 16, 4, 2, S<8, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, 0, S<8, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, 8, Scheduler, PipelineVersion>, + DeviceGroupedConvBwdWeight_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 128, 256, 64, 8, 16, 16, 8, 2, S<8, 32, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 8, 1, S<8, 32, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, 0, 1, 1, S<1, 16, 1, 16>, 8, Scheduler, PipelineVersion>, + DeviceGroupedConvBwdWeight_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 48, 64, 128, 8, 16, 16, 3, 1, S<16, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 6, 8, 1, S<16, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 8, Scheduler, PipelineVersion>, + DeviceGroupedConvBwdWeight_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 96, 128, 64, 8, 16, 16, 6, 2, S<8, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 6, 8, 1, S<8, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 8, Scheduler, PipelineVersion>, + DeviceGroupedConvBwdWeight_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 64, 64, 128, 8, 16, 16, 4, 1, S<16, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, 0, S<16, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 8, Scheduler, PipelineVersion>, + DeviceGroupedConvBwdWeight_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 96, 128, 128, 8, 16, 16, 6, 1, S<16, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 6, 8, 1, S<16, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, 0, 1, 1, S<1, 16, 1, 16>, 8, Scheduler, PipelineVersion> + // DeviceGroupedConvBwdWeight_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 96, 96, 96, 48, 8, 16, 16, 6, 2, S<6, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 6, 8, 0, S<6, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 6, 8, 0, 1, 1, S<1, 16, 1, 6>, 8, Scheduler, PipelineVersion>, // Incorrect results for at least GemmDefault + // DeviceGroupedConvBwdWeight_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 96, 96, 96, 48, 8, 16, 16, 6, 2, S<6, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 6, 8, 1, S<6, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 6, 8, 0, 1, 1, S<1, 16, 1, 6>, 8, Scheduler, PipelineVersion> // Incorrect results for at least GemmDefault + // clang-format on + >; + +template +using device_grouped_conv_bwd_weight_v3_wmma_c_shuffle_bf16_instances = std::tuple< + // clang-format off + //#########################################| Num| InLayout| WeiLayout| OutLayout| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Block| MPer| NPer| KPer| ABK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CShuffleBlockTransfer| BlockGemm| BlockGemm| + //#########################################| Dim| | | | Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Weight| Size| Block| Block| Block| | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVector| Pipeline| Pipeline| + //#########################################| Spatial| | | | | | | | Operation| Operation| Operation| Specialization| | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| MBlock_MPerBlock| _NPerBlock| Scheduler| Version| + //#########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | NBlock_NPerBlock| | | | + DeviceGroupedConvBwdWeight_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 32, 32, 32, 8, 16, 16, 2, 1, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 2, 2, 0, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 2, 2, 0, 1, 1, S<1, 8, 1, 8>, 2, Scheduler, PipelineVersion>, + DeviceGroupedConvBwdWeight_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 128, 128, 32, 8, 16, 16, 8, 2, S<4, 32, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 8, 0, S<4, 32, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 8, 0, 1, 1, S<1, 16, 1, 8>, 8, Scheduler, PipelineVersion>, + DeviceGroupedConvBwdWeight_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 128, 128, 32, 8, 16, 16, 8, 2, S<4, 32, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 8, 1, S<4, 32, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 8, 0, 1, 1, S<1, 16, 1, 8>, 8, Scheduler, PipelineVersion>, + DeviceGroupedConvBwdWeight_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 64, 64, 64, 8, 16, 16, 4, 2, S<8, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, 0, S<8, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, 8, Scheduler, PipelineVersion>, + DeviceGroupedConvBwdWeight_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 128, 256, 64, 8, 16, 16, 8, 2, S<8, 32, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 8, 1, S<8, 32, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, 0, 1, 1, S<1, 16, 1, 16>, 8, Scheduler, PipelineVersion>, + DeviceGroupedConvBwdWeight_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 48, 64, 128, 8, 16, 16, 3, 1, S<16, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 6, 8, 1, S<16, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 8, Scheduler, PipelineVersion>, + DeviceGroupedConvBwdWeight_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 96, 128, 64, 8, 16, 16, 6, 2, S<8, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 6, 8, 1, S<8, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 8, Scheduler, PipelineVersion>, + DeviceGroupedConvBwdWeight_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 64, 64, 128, 8, 16, 16, 4, 1, S<16, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, 0, S<16, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 8, Scheduler, PipelineVersion>, + DeviceGroupedConvBwdWeight_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 96, 128, 128, 8, 16, 16, 6, 1, S<16, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 6, 8, 1, S<16, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, 0, 1, 1, S<1, 16, 1, 16>, 8, Scheduler, PipelineVersion>, + DeviceGroupedConvBwdWeight_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 96, 96, 96, 48, 8, 16, 16, 6, 2, S<6, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 6, 8, 0, S<6, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 6, 8, 0, 1, 1, S<1, 16, 1, 6>, 8, Scheduler, PipelineVersion>, + DeviceGroupedConvBwdWeight_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 96, 96, 96, 48, 8, 16, 16, 6, 2, S<6, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 6, 8, 1, S<6, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 6, 8, 0, 1, 1, S<1, 16, 1, 6>, 8, Scheduler, PipelineVersion> + //clang-format on + >; + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_wmma_bilinear_instance.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_wmma_bilinear_instance.hpp new file mode 100644 index 0000000000..f254628f73 --- /dev/null +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_wmma_bilinear_instance.hpp @@ -0,0 +1,97 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_multiple_d_wmma_cshuffle_v3.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using namespace ck::tensor_layout::convolution; + +using BF16 = ck::bhalf_t; +using F16 = ck::half_t; +using F32 = float; + +using Empty_Tuple = ck::Tuple<>; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using Bilinear = ck::tensor_operation::element_wise::Bilinear; + +static constexpr auto ConvBwdWeightDefault = + ck::tensor_operation::device::ConvolutionBackwardWeightSpecialization::Default; + +static constexpr auto ConvBwdWeightFilter1x1Stride1Pad0 = + ck::tensor_operation::device::ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0; + +template +using device_grouped_conv_bwd_weight_wmma_c_shuffle_f16_bilinear_instances = std::tuple< + // clang-format off + //#################################################| Num| InLayout| WeiLayout| OutLayout| DsLayout| InData| WeiData| OutData| AccData| DsData| In| Wei| Out| ConvBackward| Block| MPer| NPer| KPer| ABK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CBlockTransfer| BlockGemm| BlockGemm| + //#################################################| Dim| | | | | Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Weight| Size| Block| Block| Block| | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVector| Pipeline| Pipeline | + //#################################################| Spatial| | | | | | | | | | Operation| Operation| Operation| Specialization| | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| MBlock_MPerBlock| _NPerBlock| Scheduler| Version | + //#################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | NBlock_NPerBlock| | | | + // generic instance + DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, Tuple, F16, F16, F16, F32, Tuple, PassThrough, Bilinear, PassThrough, ConvSpec, 64, 64, 64, 32, 8, 16, 16, 4, 2, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 1, 4, 1, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 1, 4, 1, 1, 1, S<1, 16, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, Tuple, F16, F16, F16, F32, Tuple, PassThrough, Bilinear, PassThrough, ConvSpec, 64, 64, 64, 32, 8, 16, 16, 4, 2, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 2, 4, 1, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 2, 4, 1, 1, 1, S<1, 16, 1, 4>, 2, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + // for fp16 conv.K and conv.C must be divisible by 2 + // since half_t atomic_add require scalar_per_x_vector % 2 == 0 + DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, Tuple, F16, F16, F16, F32, Tuple, PassThrough, Bilinear, PassThrough, ConvSpec, 64, 32, 32, 32, 8, 16, 16, 2, 1, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 2, 2, 0, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 2, 2, 0, 1, 1, S<1, 8, 1, 8>, 2, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, Tuple, F16, F16, F16, F32, Tuple, PassThrough, Bilinear, PassThrough, ConvSpec, 128, 128, 128, 32, 8, 16, 16, 8, 2, S<4, 32, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 8, 0, S<4, 32, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 8, 0, 1, 1, S<1, 16, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, Tuple, F16, F16, F16, F32, Tuple, PassThrough, Bilinear, PassThrough, ConvSpec, 128, 128, 128, 32, 8, 16, 16, 8, 2, S<4, 32, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 8, 1, S<4, 32, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 8, 0, 1, 1, S<1, 16, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, Tuple, F16, F16, F16, F32, Tuple, PassThrough, Bilinear, PassThrough, ConvSpec, 64, 64, 64, 64, 8, 16, 16, 4, 2, S<8, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, 0, S<8, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, Tuple, F16, F16, F16, F32, Tuple, PassThrough, Bilinear, PassThrough, ConvSpec, 256, 128, 256, 64, 8, 16, 16, 8, 2, S<8, 32, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 8, 1, S<8, 32, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, 0, 1, 1, S<1, 16, 1, 16>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, Tuple, F16, F16, F16, F32, Tuple, PassThrough, Bilinear, PassThrough, ConvSpec, 128, 48, 64, 128, 8, 16, 16, 3, 1, S<16, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 6, 8, 1, S<16, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, Tuple, F16, F16, F16, F32, Tuple, PassThrough, Bilinear, PassThrough, ConvSpec, 128, 96, 128, 64, 8, 16, 16, 6, 2, S<8, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 6, 8, 1, S<8, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, Tuple, F16, F16, F16, F32, Tuple, PassThrough, Bilinear, PassThrough, ConvSpec, 128, 64, 64, 128, 8, 16, 16, 4, 1, S<16, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, 0, S<16, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, Tuple, F16, F16, F16, F32, Tuple, PassThrough, Bilinear, PassThrough, ConvSpec, 256, 96, 128, 128, 8, 16, 16, 6, 1, S<16, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 6, 8, 1, S<16, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, 0, 1, 1, S<1, 16, 1, 16>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1> + // DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, Tuple, F16, F16, F16, F32, Tuple, PassThrough, Bilinear, PassThrough, ConvSpec, 96, 96, 96, 48, 8, 16, 16, 6, 2, S<6, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 6, 8, 0, S<6, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 6, 8, 0, 1, 1, S<1, 16, 1, 6>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, // Presumably doesn't produce correct results for f16 + // DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, Tuple, F16, F16, F16, F32, Tuple, PassThrough, Bilinear, PassThrough, ConvSpec, 96, 96, 96, 48, 8, 16, 16, 6, 2, S<6, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 6, 8, 1, S<6, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 6, 8, 0, 1, 1, S<1, 16, 1, 6>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1> // Presumably doesn't produce correct results for f16 + // clang-format on + >; + +template +using device_grouped_conv_bwd_weight_wmma_c_shuffle_bf16_bilinear_instances = std::tuple< + // clang-format off + //#################################################| Num| InLayout| WeiLayout| OutLayout| DsLayout| InData| WeiData| OutData| AccData| DsData| In| Wei| Out| ConvBackward| Block| MPer| NPer| KPer| ABK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CBlockTransfer| BlockGemm| BlockGemm| + //#################################################| Dim| | | | | Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Weight| Size| Block| Block| Block| | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVector| Pipeline| Pipeline | + //#################################################| Spatial| | | | | | | | | | Operation| Operation| Operation| Specialization| | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| MBlock_MPerBlock| _NPerBlock| Scheduler| Version | + //#################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | NBlock_NPerBlock| | | | + // generic instance + DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, Tuple, BF16, F32, BF16, F32, Tuple, PassThrough, Bilinear, PassThrough, ConvSpec, 64, 64, 64, 32, 8, 16, 16, 4, 2, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 1, 4, 1, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 1, 4, 1, 1, 1, S<1, 16, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + // other instances + DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, Tuple, BF16, F32, BF16, F32, Tuple, PassThrough, Bilinear, PassThrough, ConvSpec, 64, 32, 32, 32, 8, 16, 16, 2, 1, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 2, 2, 0, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 2, 2, 0, 1, 1, S<1, 8, 1, 8>, 2, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, Tuple, BF16, F32, BF16, F32, Tuple, PassThrough, Bilinear, PassThrough, ConvSpec, 128, 128, 128, 32, 8, 16, 16, 8, 2, S<4, 32, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 8, 0, S<4, 32, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 8, 0, 1, 1, S<1, 16, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, Tuple, BF16, F32, BF16, F32, Tuple, PassThrough, Bilinear, PassThrough, ConvSpec, 128, 128, 128, 32, 8, 16, 16, 8, 2, S<4, 32, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 8, 1, S<4, 32, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 8, 0, 1, 1, S<1, 16, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, Tuple, BF16, F32, BF16, F32, Tuple, PassThrough, Bilinear, PassThrough, ConvSpec, 64, 64, 64, 64, 8, 16, 16, 4, 2, S<8, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, 0, S<8, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, Tuple, BF16, F32, BF16, F32, Tuple, PassThrough, Bilinear, PassThrough, ConvSpec, 256, 128, 256, 64, 8, 16, 16, 8, 2, S<8, 32, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 8, 1, S<8, 32, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, 0, 1, 1, S<1, 16, 1, 16>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, Tuple, BF16, F32, BF16, F32, Tuple, PassThrough, Bilinear, PassThrough, ConvSpec, 128, 48, 64, 128, 8, 16, 16, 3, 1, S<16, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 6, 8, 1, S<16, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, Tuple, BF16, F32, BF16, F32, Tuple, PassThrough, Bilinear, PassThrough, ConvSpec, 128, 96, 128, 64, 8, 16, 16, 6, 2, S<8, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 6, 8, 1, S<8, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, Tuple, BF16, F32, BF16, F32, Tuple, PassThrough, Bilinear, PassThrough, ConvSpec, 128, 64, 64, 128, 8, 16, 16, 4, 1, S<16, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, 0, S<16, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, Tuple, BF16, F32, BF16, F32, Tuple, PassThrough, Bilinear, PassThrough, ConvSpec, 256, 96, 128, 128, 8, 16, 16, 6, 1, S<16, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 6, 8, 1, S<16, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, 0, 1, 1, S<1, 16, 1, 16>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1> + // DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, Tuple, BF16, F32, BF16, F32, Tuple, PassThrough, Bilinear, PassThrough, ConvSpec, 96, 96, 96, 48, 8, 16, 16, 6, 2, S<6, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 6, 8, 0, S<6, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 6, 8, 0, 1, 1, S<1, 16, 1, 6>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, // Verification failure + // DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, Tuple, BF16, F32, BF16, F32, Tuple, PassThrough, Bilinear, PassThrough, ConvSpec, 96, 96, 96, 48, 8, 16, 16, 6, 2, S<6, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 6, 8, 1, S<6, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 6, 8, 0, 1, 1, S<1, 16, 1, 6>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1> // Verification failure + // clang-format on + >; + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_wmma_instance.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_wmma_instance.hpp deleted file mode 100644 index 8743fb041c..0000000000 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_wmma_instance.hpp +++ /dev/null @@ -1,117 +0,0 @@ -// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. -// SPDX-License-Identifier: MIT - -#pragma once - -#include "ck/ck.hpp" -#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" -#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" -#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_wmma_cshuffle.hpp" -#include "ck/tensor_operation/gpu/device/convolution_backward_weight_specialization.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { -namespace instance { - -using F16 = ck::half_t; -using F32 = float; -using I8 = int8_t; -using I32 = int32_t; - -template -using S = ck::Sequence; - -using namespace ck::tensor_layout::convolution; - -using PassThrough = ck::tensor_operation::element_wise::PassThrough; - -static constexpr auto ConvBwdWeightDefault = - ck::tensor_operation::device::ConvolutionBackwardWeightSpecialization::Default; - -static constexpr auto ConvBwdWeightFilter1x1Stride1Pad0 = - ck::tensor_operation::device::ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0; - -template -using device_grouped_conv_bwd_weight_wmma_f16_instances = std::tuple< - // clang-format off - //#####################################| NumDim| A| B| C| AData| BData| CData| AccData| A| B| C| ConvForward| Block| MPer| NPer| KPer| K1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| - //#####################################| Spatial| Layout| Layout| Layout| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | WMMA| WMMA| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeatPerWave| NRepeatPerWave| _MBlock_MPerBlock| ScalarPerVector| - //#####################################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| - //#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - // generic instance - DeviceGroupedConvBwdWeight_Wmma_CShuffle, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 4>, 1>, - // blocksize=256 - DeviceGroupedConvBwdWeight_Wmma_CShuffle, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 4>, - DeviceGroupedConvBwdWeight_Wmma_CShuffle, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 32, 1, 8>, 2>, - DeviceGroupedConvBwdWeight_Wmma_CShuffle, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 4>, - DeviceGroupedConvBwdWeight_Wmma_CShuffle, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 16>, 4>, - // blocksize=128 - DeviceGroupedConvBwdWeight_Wmma_CShuffle, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, - DeviceGroupedConvBwdWeight_Wmma_CShuffle, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, - DeviceGroupedConvBwdWeight_Wmma_CShuffle, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, - DeviceGroupedConvBwdWeight_Wmma_CShuffle, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, - DeviceGroupedConvBwdWeight_Wmma_CShuffle, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, - // blocksize=64 - DeviceGroupedConvBwdWeight_Wmma_CShuffle, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, S<8, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, - DeviceGroupedConvBwdWeight_Wmma_CShuffle, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<8, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 32, 1, 2>, 8>, - DeviceGroupedConvBwdWeight_Wmma_CShuffle, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, S<8, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 32, 1, 2>, 8>, - DeviceGroupedConvBwdWeight_Wmma_CShuffle, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, S<8, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, - DeviceGroupedConvBwdWeight_Wmma_CShuffle, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<8, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 32, 1, 2>, 8>, - // blocksize=32 - DeviceGroupedConvBwdWeight_Wmma_CShuffle, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<8, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 2>, 8>, - DeviceGroupedConvBwdWeight_Wmma_CShuffle, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<8, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 2>, 8>, - DeviceGroupedConvBwdWeight_Wmma_CShuffle, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, S<8, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 2>, 8>, - DeviceGroupedConvBwdWeight_Wmma_CShuffle, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, S<8, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 2>, 8>, - DeviceGroupedConvBwdWeight_Wmma_CShuffle, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, S<8, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 2>, 8>, - DeviceGroupedConvBwdWeight_Wmma_CShuffle, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, S<8, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 2>, 8>, - DeviceGroupedConvBwdWeight_Wmma_CShuffle, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, S<8, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 2>, 8> - // clang-format on - >; - -template -using device_grouped_conv_bwd_weight_wmma_i8_instances = std::tuple< - // clang-format off - //#####################################| NumDim| A| B| C| AData| BData| CData| AccData| A| B| C| ConvForward| Block| MPer| NPer| KPer| K1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| - //#####################################| Spatial| Layout| Layout| Layout| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | WMMA| WMMA| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeatPerWave| NRepeatPerWave| _MBlock_MPerBlock| ScalarPerVector| - //#####################################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| - //#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - // generic instance - DeviceGroupedConvBwdWeight_Wmma_CShuffle, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 4>, 1>, - // blocksize=256 - DeviceGroupedConvBwdWeight_Wmma_CShuffle, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvBwdWeight_Wmma_CShuffle, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 64, 1, 4>, 8>, - // blocksize=128 - DeviceGroupedConvBwdWeight_Wmma_CShuffle, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, - DeviceGroupedConvBwdWeight_Wmma_CShuffle, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, - DeviceGroupedConvBwdWeight_Wmma_CShuffle, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, - DeviceGroupedConvBwdWeight_Wmma_CShuffle, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, - DeviceGroupedConvBwdWeight_Wmma_CShuffle, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, - DeviceGroupedConvBwdWeight_Wmma_CShuffle, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 8, 1, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, - DeviceGroupedConvBwdWeight_Wmma_CShuffle, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 8, 1, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, 2>, - DeviceGroupedConvBwdWeight_Wmma_CShuffle, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 8, 1, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, - // blocksize=64 - DeviceGroupedConvBwdWeight_Wmma_CShuffle, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<8, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 8, 1, 1, 1, S<1, 32, 1, 2>, 8>, - DeviceGroupedConvBwdWeight_Wmma_CShuffle, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, S<8, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 8, 1, 1, 1, S<1, 32, 1, 2>, 8>, - DeviceGroupedConvBwdWeight_Wmma_CShuffle, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 8, 1, S<8, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, - DeviceGroupedConvBwdWeight_Wmma_CShuffle, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 8, 1, S<8, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, - // blocksize=32 - DeviceGroupedConvBwdWeight_Wmma_CShuffle, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<8, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 8, 1, 1, 1, S<1, 16, 1, 2>, 8>, - DeviceGroupedConvBwdWeight_Wmma_CShuffle, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 8, 1, S<8, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 8, 1, 1, 1, S<1, 16, 1, 2>, 8>, - DeviceGroupedConvBwdWeight_Wmma_CShuffle, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, S<8, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 2>, 8>, - DeviceGroupedConvBwdWeight_Wmma_CShuffle, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 8, 1, S<8, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 2>, 8> - // clang-format on - >; - -} // namespace instance -} // namespace device -} // namespace tensor_operation -} // namespace ck diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_wmma_scale_instance.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_wmma_scale_instance.hpp new file mode 100644 index 0000000000..e893c92d1d --- /dev/null +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_wmma_scale_instance.hpp @@ -0,0 +1,96 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_multiple_d_wmma_cshuffle_v3.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using namespace ck::tensor_layout::convolution; + +using BF16 = ck::bhalf_t; +using F16 = ck::half_t; +using F32 = float; + +using Empty_Tuple = ck::Tuple<>; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using Scale = ck::tensor_operation::element_wise::Scale; + +static constexpr auto ConvBwdWeightDefault = + ck::tensor_operation::device::ConvolutionBackwardWeightSpecialization::Default; + +static constexpr auto ConvBwdWeightFilter1x1Stride1Pad0 = + ck::tensor_operation::device::ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0; + +template +using device_grouped_conv_bwd_weight_wmma_c_shuffle_f16_scale_instances = std::tuple< + // clang-format off + //#################################################| Num| InLayout| WeiLayout| OutLayout| DsLayout| InData| WeiData| OutData| AccData| DsData| In| Wei| Out| ConvBackward| Block| MPer| NPer| KPer| ABK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CBlockTransfer| BlockGemm| BlockGemm| + //#################################################| Dim| | | | | Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Weight| Size| Block| Block| Block| | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVector| Pipeline| Pipeline | + //#################################################| Spatial| | | | | | | | | | Operation| Operation| Operation| Specialization| | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| MBlock_MPerBlock| _NPerBlock| Scheduler| Version | + //#################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | NBlock_NPerBlock| | | | + // generic instance + DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, Empty_Tuple, F16, F16, F16, F32, Empty_Tuple, PassThrough, Scale, PassThrough, ConvSpec, 64, 64, 64, 32, 8, 16, 16, 4, 2, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 2, 4, 1, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 2, 4, 1, 1, 1, S<1, 16, 1, 4>, 2, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + // for fp16 conv.K and conv.C must be divisible by 2 + // since half_t atomic_add require scalar_per_x_vector % 2 == 0 + DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, Empty_Tuple, F16, F16, F16, F32, Empty_Tuple, PassThrough, Scale, PassThrough, ConvSpec, 64, 32, 32, 32, 8, 16, 16, 2, 1, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 2, 2, 0, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 2, 2, 0, 1, 1, S<1, 8, 1, 8>, 2, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, Empty_Tuple, F16, F16, F16, F32, Empty_Tuple, PassThrough, Scale, PassThrough, ConvSpec, 128, 128, 128, 32, 8, 16, 16, 8, 2, S<4, 32, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 8, 0, S<4, 32, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 8, 0, 1, 1, S<1, 16, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, Empty_Tuple, F16, F16, F16, F32, Empty_Tuple, PassThrough, Scale, PassThrough, ConvSpec, 128, 128, 128, 32, 8, 16, 16, 8, 2, S<4, 32, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 8, 1, S<4, 32, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 8, 0, 1, 1, S<1, 16, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, Empty_Tuple, F16, F16, F16, F32, Empty_Tuple, PassThrough, Scale, PassThrough, ConvSpec, 64, 64, 64, 64, 8, 16, 16, 4, 2, S<8, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, 0, S<8, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, Empty_Tuple, F16, F16, F16, F32, Empty_Tuple, PassThrough, Scale, PassThrough, ConvSpec, 256, 128, 256, 64, 8, 16, 16, 8, 2, S<8, 32, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 8, 1, S<8, 32, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, 0, 1, 1, S<1, 16, 1, 16>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, Empty_Tuple, F16, F16, F16, F32, Empty_Tuple, PassThrough, Scale, PassThrough, ConvSpec, 128, 48, 64, 128, 8, 16, 16, 3, 1, S<16, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 6, 8, 1, S<16, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, Empty_Tuple, F16, F16, F16, F32, Empty_Tuple, PassThrough, Scale, PassThrough, ConvSpec, 128, 96, 128, 64, 8, 16, 16, 6, 2, S<8, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 6, 8, 1, S<8, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, Empty_Tuple, F16, F16, F16, F32, Empty_Tuple, PassThrough, Scale, PassThrough, ConvSpec, 128, 64, 64, 128, 8, 16, 16, 4, 1, S<16, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, 0, S<16, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, Empty_Tuple, F16, F16, F16, F32, Empty_Tuple, PassThrough, Scale, PassThrough, ConvSpec, 256, 96, 128, 128, 8, 16, 16, 6, 1, S<16, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 6, 8, 1, S<16, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, 0, 1, 1, S<1, 16, 1, 16>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1> + // DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, Empty_Tuple, F16, F16, F16, F32, Empty_Tuple, PassThrough, Scale, PassThrough, ConvSpec, 96, 96, 96, 48, 8, 16, 16, 6, 2, S<6, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 6, 8, 0, S<6, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 6, 8, 0, 1, 1, S<1, 16, 1, 6>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, // Presumably doesn't produce correct results for fp16 + // DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, Empty_Tuple, F16, F16, F16, F32, Empty_Tuple, PassThrough, Scale, PassThrough, ConvSpec, 96, 96, 96, 48, 8, 16, 16, 6, 2, S<6, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 6, 8, 1, S<6, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 6, 8, 0, 1, 1, S<1, 16, 1, 6>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1> // Presumably doesn't produce correct results for fp16 + // clang-format on + >; + +template +using device_grouped_conv_bwd_weight_wmma_c_shuffle_bf16_scale_instances = std::tuple< + // clang-format off + //#################################################| Num| InLayout| WeiLayout| OutLayout| DsLayout| InData| WeiData| OutData| AccData| DsData| In| Wei| Out| ConvBackward| Block| MPer| NPer| KPer| ABK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CBlockTransfer| BlockGemm| BlockGemm| + //#################################################| Dim| | | | | Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Weight| Size| Block| Block| Block| | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVector| Pipeline| Pipeline | + //#################################################| Spatial| | | | | | | | | | Operation| Operation| Operation| Specialization| | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| MBlock_MPerBlock| _NPerBlock| Scheduler| Version | + //#################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | NBlock_NPerBlock| | | | + // generic instance + DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, Empty_Tuple, BF16, F32, BF16, F32, Empty_Tuple, PassThrough, Scale, PassThrough, ConvSpec, 64, 64, 64, 32, 8, 16, 16, 4, 2, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 1, 4, 1, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 1, 4, 1, 1, 1, S<1, 16, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + // other instances + DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, Empty_Tuple, BF16, F32, BF16, F32, Empty_Tuple, PassThrough, Scale, PassThrough, ConvSpec, 64, 32, 32, 32, 8, 16, 16, 2, 1, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 2, 2, 0, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 2, 2, 0, 1, 1, S<1, 8, 1, 8>, 2, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, Empty_Tuple, BF16, F32, BF16, F32, Empty_Tuple, PassThrough, Scale, PassThrough, ConvSpec, 128, 128, 128, 32, 8, 16, 16, 8, 2, S<4, 32, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 8, 0, S<4, 32, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 8, 0, 1, 1, S<1, 16, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, Empty_Tuple, BF16, F32, BF16, F32, Empty_Tuple, PassThrough, Scale, PassThrough, ConvSpec, 128, 128, 128, 32, 8, 16, 16, 8, 2, S<4, 32, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 8, 1, S<4, 32, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 8, 0, 1, 1, S<1, 16, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, Empty_Tuple, BF16, F32, BF16, F32, Empty_Tuple, PassThrough, Scale, PassThrough, ConvSpec, 64, 64, 64, 64, 8, 16, 16, 4, 2, S<8, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, 0, S<8, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, Empty_Tuple, BF16, F32, BF16, F32, Empty_Tuple, PassThrough, Scale, PassThrough, ConvSpec, 256, 128, 256, 64, 8, 16, 16, 8, 2, S<8, 32, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 8, 1, S<8, 32, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, 0, 1, 1, S<1, 16, 1, 16>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, Empty_Tuple, BF16, F32, BF16, F32, Empty_Tuple, PassThrough, Scale, PassThrough, ConvSpec, 128, 48, 64, 128, 8, 16, 16, 3, 1, S<16, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 6, 8, 1, S<16, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, Empty_Tuple, BF16, F32, BF16, F32, Empty_Tuple, PassThrough, Scale, PassThrough, ConvSpec, 128, 96, 128, 64, 8, 16, 16, 6, 2, S<8, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 6, 8, 1, S<8, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, Empty_Tuple, BF16, F32, BF16, F32, Empty_Tuple, PassThrough, Scale, PassThrough, ConvSpec, 128, 64, 64, 128, 8, 16, 16, 4, 1, S<16, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, 0, S<16, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, Empty_Tuple, BF16, F32, BF16, F32, Empty_Tuple, PassThrough, Scale, PassThrough, ConvSpec, 256, 96, 128, 128, 8, 16, 16, 6, 1, S<16, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 6, 8, 1, S<16, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, 0, 1, 1, S<1, 16, 1, 16>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1> + // DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, Empty_Tuple, BF16, F32, BF16, F32, Empty_Tuple, PassThrough, Scale, PassThrough, ConvSpec, 96, 96, 96, 48, 8, 16, 16, 6, 2, S<6, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 6, 8, 0, S<6, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 6, 8, 0, 1, 1, S<1, 16, 1, 6>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, // Verification failure + // DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, Empty_Tuple, BF16, F32, BF16, F32, Empty_Tuple, PassThrough, Scale, PassThrough, ConvSpec, 96, 96, 96, 48, 8, 16, 16, 6, 2, S<6, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 6, 8, 1, S<6, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 6, 8, 0, 1, 1, S<1, 16, 1, 6>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1> // Verification failure + // clang-format on + >; + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight.hpp index 3fe8fa9c5a..6dd8758eb7 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight.hpp @@ -21,6 +21,7 @@ #endif #ifdef CK_USE_WMMA #include "grouped_convolution_backward_weight_wmma.inc" +#include "grouped_convolution_backward_weight_explicit_wmma.inc" #endif namespace ck { namespace tensor_operation { @@ -414,21 +415,24 @@ struct DeviceOperationInstanceFactory && is_same_v && - is_same_v) + if constexpr(is_same_v && is_same_v && + is_same_v) { #ifdef CK_ENABLE_FP16 if constexpr(is_same_v && is_same_v && is_same_v && is_same_v && is_same_v) { - add_device_grouped_conv3d_bwd_weight_wmma_gndhwc_gkzyxc_gndhwk_f16_instances( + add_device_grouped_conv2d_bwd_weight_wmma_nhwgc_gkyxc_nhwgk_f16_instances( op_ptrs); - add_device_grouped_conv3d_bwd_weight_wmma_gndhwc_gkzyxc_gndhwk_f16_1x1s1p0_instances( + add_device_grouped_conv2d_bwd_weight_two_stage_wmma_nhwgc_gkyxc_nhwgk_f16_pipev1_instances( + op_ptrs); + // Explicit GEMM + add_device_grouped_convnd_bwd_weight_wmma_f16_f16_f16_exp_comp_default_instances( + op_ptrs); + add_device_grouped_convnd_bwd_weight_wmma_f16_f16_f16_exp_comp_mnkpadding_instances( + op_ptrs); + add_device_grouped_convnd_bwd_weight_wmma_f16_f16_f16_exp_odd_mn_instances( op_ptrs); } #endif -#ifdef CK_ENABLE_INT8 - else if constexpr(is_same_v && is_same_v && - is_same_v && - is_same_v && - is_same_v) +#ifdef CK_ENABLE_BF16 + if constexpr(is_same_v && + is_same_v && + is_same_v && + is_same_v && + is_same_v) { - add_device_grouped_conv3d_bwd_weight_wmma_gndhwc_gkzyxc_gndhwk_i8_instances( + add_device_grouped_conv2d_bwd_weight_wmma_nhwgc_gkyxc_nhwgk_bf16_instances( op_ptrs); - add_device_grouped_conv3d_bwd_weight_wmma_gndhwc_gkzyxc_gndhwk_i8_1x1s1p0_instances( + add_device_grouped_conv2d_bwd_weight_two_stage_wmma_nhwgc_gkyxc_nhwgk_bf16_pipev1_instances( + op_ptrs); + // Explicit GEMM + add_device_grouped_convnd_bwd_weight_wmma_bf16_bf16_bf16_exp_comp_default_instances( + op_ptrs); + add_device_grouped_convnd_bwd_weight_wmma_bf16_bf16_bf16_exp_odd_mn_instances( + op_ptrs); + add_device_grouped_convnd_bwd_weight_wmma_bf16_bf16_bf16_exp_comp_mnkpadding_instances( op_ptrs); } #endif } + } + if constexpr(NumDimSpatial == 3) + { if constexpr(is_same_v && is_same_v && is_same_v) { @@ -889,26 +914,40 @@ struct DeviceOperationInstanceFactory && is_same_v && - is_same_v && - is_same_v && - is_same_v) +#ifdef CK_ENABLE_BF16 + if constexpr(is_same_v && + is_same_v && + is_same_v && + is_same_v && + is_same_v) { - add_device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_i8_instances( + add_device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_bf16_instances( op_ptrs); - add_device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_i8_1x1s1p0_instances( + add_device_grouped_conv3d_bwd_weight_two_stage_wmma_ndhwgc_gkzyxc_ndhwgk_bf16_pipev1_instances( + op_ptrs); + // Explicit GEMM + add_device_grouped_convnd_bwd_weight_wmma_bf16_bf16_bf16_exp_comp_default_instances( + op_ptrs); + add_device_grouped_convnd_bwd_weight_wmma_bf16_bf16_bf16_exp_odd_mn_instances( + op_ptrs); + add_device_grouped_convnd_bwd_weight_wmma_bf16_bf16_bf16_exp_comp_mnkpadding_instances( op_ptrs); } #endif } } #endif - return op_ptrs; } }; diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight_bilinear.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight_bilinear.hpp index a0e8e46570..48a43e59ad 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight_bilinear.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight_bilinear.hpp @@ -17,6 +17,39 @@ namespace tensor_operation { namespace device { namespace instance { +#ifdef CK_USE_WMMA +#ifdef CK_ENABLE_FP16 +void add_device_grouped_conv3d_bwd_weight_wmma_bilinear_ndhwgc_gkzyxc_ndhwgk_f16_instances( + std::vector, + F16, + F16, + F16, + Tuple, + PassThrough, + Bilinear, + PassThrough>>>& instances); +#endif +#ifdef CK_ENABLE_BF16 +void add_device_grouped_conv3d_bwd_weight_wmma_bilinear_ndhwgc_gkzyxc_ndhwgk_bf16_f32_bf16_instances( + std::vector, + BF16, + F32, + BF16, + Tuple, + PassThrough, + Bilinear, + PassThrough>>>& instances); +#endif +#endif + #ifdef CK_USE_XDL #ifdef CK_ENABLE_BF16 void add_device_grouped_conv3d_bwd_weight_xdl_bilinear_ndhwgc_gkzyxc_ndhwgk_bf16_f32_bf16_instances( @@ -148,6 +181,35 @@ struct DeviceOperationInstanceFactory< { std::vector> op_ptrs; +#ifdef CK_USE_WMMA + if constexpr(NumDimSpatial == 3) + { + if constexpr(is_same_v && is_same_v && + is_same_v) + { +#ifdef CK_ENABLE_FP16 + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v && + is_same_v) + { + add_device_grouped_conv3d_bwd_weight_wmma_bilinear_ndhwgc_gkzyxc_ndhwgk_f16_instances( + op_ptrs); + } +#endif +#ifdef CK_ENABLE_BF16 + if constexpr(is_same_v && is_same_v && + is_same_v && + is_same_v && + is_same_v) + { + add_device_grouped_conv3d_bwd_weight_wmma_bilinear_ndhwgc_gkzyxc_ndhwgk_bf16_f32_bf16_instances( + op_ptrs); + } +#endif + } + } +#endif + #ifdef CK_USE_XDL if constexpr(NumDimSpatial == 3) { diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight_explicit_wmma.inc b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight_explicit_wmma.inc new file mode 100644 index 0000000000..d7fefde5cd --- /dev/null +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight_explicit_wmma.inc @@ -0,0 +1,171 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { +// 2D +#ifdef CK_ENABLE_BF16 + +void add_device_grouped_convnd_bwd_weight_wmma_bf16_bf16_bf16_exp_comp_default_instances( + std::vector>>& instances); + +void add_device_grouped_convnd_bwd_weight_wmma_bf16_bf16_bf16_exp_odd_mn_instances( + std::vector>>& instances); + +void add_device_grouped_convnd_bwd_weight_wmma_bf16_bf16_bf16_exp_comp_mnkpadding_instances( + std::vector>>& instances); +#endif + +#ifdef CK_ENABLE_FP16 + +void add_device_grouped_convnd_bwd_weight_wmma_f16_f16_f16_exp_comp_default_instances( + std::vector>>& instances); + +void add_device_grouped_convnd_bwd_weight_wmma_f16_f16_f16_exp_comp_mnkpadding_instances( + std::vector>>& instances); + +void add_device_grouped_convnd_bwd_weight_wmma_f16_f16_f16_exp_odd_mn_instances( + std::vector>>& instances); +#endif + +// 3D +#ifdef CK_ENABLE_BF16 + +void add_device_grouped_convnd_bwd_weight_wmma_bf16_bf16_bf16_exp_comp_default_instances( + std::vector>>& instances); + +void add_device_grouped_convnd_bwd_weight_wmma_bf16_bf16_bf16_exp_odd_mn_instances( + std::vector>>& instances); + +void add_device_grouped_convnd_bwd_weight_wmma_bf16_bf16_bf16_exp_comp_mnkpadding_instances( + std::vector>>& instances); +#endif + +#ifdef CK_ENABLE_FP16 + +void add_device_grouped_convnd_bwd_weight_wmma_f16_f16_f16_exp_comp_default_instances( + std::vector>>& instances); + +void add_device_grouped_convnd_bwd_weight_wmma_f16_f16_f16_exp_comp_mnkpadding_instances( + std::vector>>& instances); + +void add_device_grouped_convnd_bwd_weight_wmma_f16_f16_f16_exp_odd_mn_instances( + std::vector>>& instances); +#endif + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight_explicit_xdl.inc b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight_explicit_xdl.inc index d566c331f9..faa0120776 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight_explicit_xdl.inc +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight_explicit_xdl.inc @@ -10,7 +10,7 @@ namespace instance { // 2D #ifdef CK_ENABLE_BF16 -void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_comp_default_instances( +void add_device_grouped_convnd_bwd_weight_xdl_bf16_bf16_bf16_exp_comp_default_instances( std::vector>>& instances); -void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_comp_mnkpadding_instances( +void add_device_grouped_convnd_bwd_weight_xdl_bf16_bf16_bf16_exp_comp_mnkpadding_instances( std::vector>>& instances); -void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_mem_v1_default_instances( +void add_device_grouped_convnd_bwd_weight_xdl_bf16_bf16_bf16_exp_mem_v1_default_instances( std::vector>>& instances); -void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_mem_v1_mnkpadding_instances( +void add_device_grouped_convnd_bwd_weight_xdl_bf16_bf16_bf16_exp_mem_v1_mnkpadding_instances( std::vector>>& instances); -void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_mem_v2_default_instances( +void add_device_grouped_convnd_bwd_weight_xdl_bf16_bf16_bf16_exp_mem_v2_default_instances( std::vector>>& instances); -void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_mem_v2_mnkpadding_instances( +void add_device_grouped_convnd_bwd_weight_xdl_bf16_bf16_bf16_exp_mem_v2_mnkpadding_instances( std::vector>>& instances); -void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_odd_mn_instances( +void add_device_grouped_convnd_bwd_weight_xdl_bf16_bf16_bf16_exp_odd_mn_instances( std::vector>>& instances); -void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_odd_m_instances( +void add_device_grouped_convnd_bwd_weight_xdl_bf16_bf16_bf16_exp_odd_m_instances( std::vector>>& instances); -void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_odd_n_instances( +void add_device_grouped_convnd_bwd_weight_xdl_bf16_bf16_bf16_exp_odd_n_instances( std::vector>>& instances); -void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_comp_mnkpadding_instances( +void add_device_grouped_convnd_bwd_weight_xdl_f16_f16_f16_exp_comp_mnkpadding_instances( std::vector>>& instances); -void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_mem_v1_default_instances( +void add_device_grouped_convnd_bwd_weight_xdl_f16_f16_f16_exp_mem_v1_default_instances( std::vector>>& instances); -void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_mem_v1_mnkpadding_instances( +void add_device_grouped_convnd_bwd_weight_xdl_f16_f16_f16_exp_mem_v1_mnkpadding_instances( std::vector>>& instances); -void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_mem_v2_default_instances( +void add_device_grouped_convnd_bwd_weight_xdl_f16_f16_f16_exp_mem_v2_default_instances( std::vector>>& instances); -void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_mem_v2_mnkpadding_instances( +void add_device_grouped_convnd_bwd_weight_xdl_f16_f16_f16_exp_mem_v2_mnkpadding_instances( std::vector>>& instances); -void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_odd_mn_instances( +void add_device_grouped_convnd_bwd_weight_xdl_f16_f16_f16_exp_odd_mn_instances( std::vector>>& instances); -void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_odd_m_instances( +void add_device_grouped_convnd_bwd_weight_xdl_f16_f16_f16_exp_odd_m_instances( std::vector>>& instances); -void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_odd_n_instances( +void add_device_grouped_convnd_bwd_weight_xdl_f16_f16_f16_exp_odd_n_instances( std::vector>>& instances); -void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_comp_mnkpadding_instances( +void add_device_grouped_convnd_bwd_weight_xdl_bf16_bf16_bf16_exp_comp_mnkpadding_instances( std::vector>>& instances); -void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_mem_v1_default_instances( +void add_device_grouped_convnd_bwd_weight_xdl_bf16_bf16_bf16_exp_mem_v1_default_instances( std::vector>>& instances); -void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_mem_v1_mnkpadding_instances( +void add_device_grouped_convnd_bwd_weight_xdl_bf16_bf16_bf16_exp_mem_v1_mnkpadding_instances( std::vector>>& instances); -void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_mem_v2_default_instances( +void add_device_grouped_convnd_bwd_weight_xdl_bf16_bf16_bf16_exp_mem_v2_default_instances( std::vector>>& instances); -void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_mem_v2_mnkpadding_instances( +void add_device_grouped_convnd_bwd_weight_xdl_bf16_bf16_bf16_exp_mem_v2_mnkpadding_instances( std::vector>>& instances); -void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_odd_mn_instances( +void add_device_grouped_convnd_bwd_weight_xdl_bf16_bf16_bf16_exp_odd_mn_instances( std::vector>>& instances); -void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_odd_m_instances( +void add_device_grouped_convnd_bwd_weight_xdl_bf16_bf16_bf16_exp_odd_m_instances( std::vector>>& instances); -void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_odd_n_instances( +void add_device_grouped_convnd_bwd_weight_xdl_bf16_bf16_bf16_exp_odd_n_instances( std::vector>>& instances); -void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_comp_mnkpadding_instances( +void add_device_grouped_convnd_bwd_weight_xdl_f16_f16_f16_exp_comp_mnkpadding_instances( std::vector>>& instances); -void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_mem_v1_default_instances( +void add_device_grouped_convnd_bwd_weight_xdl_f16_f16_f16_exp_mem_v1_default_instances( std::vector>>& instances); -void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_mem_v1_mnkpadding_instances( +void add_device_grouped_convnd_bwd_weight_xdl_f16_f16_f16_exp_mem_v1_mnkpadding_instances( std::vector>>& instances); -void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_mem_v2_default_instances( +void add_device_grouped_convnd_bwd_weight_xdl_f16_f16_f16_exp_mem_v2_default_instances( std::vector>>& instances); -void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_mem_v2_mnkpadding_instances( +void add_device_grouped_convnd_bwd_weight_xdl_f16_f16_f16_exp_mem_v2_mnkpadding_instances( std::vector>>& instances); -void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_odd_mn_instances( +void add_device_grouped_convnd_bwd_weight_xdl_f16_f16_f16_exp_odd_mn_instances( std::vector>>& instances); -void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_odd_m_instances( +void add_device_grouped_convnd_bwd_weight_xdl_f16_f16_f16_exp_odd_m_instances( std::vector>>& instances); -void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_odd_n_instances( +void add_device_grouped_convnd_bwd_weight_xdl_f16_f16_f16_exp_odd_n_instances( std::vector, + F16, + F16, + F16, + Tuple<>, + PassThrough, + Scale, + PassThrough>>>& instances); +#endif + +#ifdef CK_ENABLE_BF16 +void add_device_grouped_conv3d_bwd_weight_wmma_scale_ndhwgc_gkzyxc_ndhwgk_bf16_f32_bf16_instances( + std::vector, + BF16, + F32, + BF16, + Tuple<>, + PassThrough, + Scale, + PassThrough>>>& instances); +#endif +#endif + #ifdef CK_USE_XDL #ifdef CK_ENABLE_BF16 void add_device_grouped_conv3d_bwd_weight_xdl_scale_ndhwgc_gkzyxc_ndhwgk_bf16_f32_bf16_instances( @@ -147,6 +181,34 @@ struct DeviceOperationInstanceFactory< static auto GetInstances() { std::vector> op_ptrs; +#ifdef CK_USE_WMMA + if constexpr(NumDimSpatial == 3) + { + if constexpr(is_same_v && is_same_v && + is_same_v) + { +#ifdef CK_ENABLE_FP16 + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v && + is_same_v) + { + add_device_grouped_conv3d_bwd_weight_wmma_scale_ndhwgc_gkzyxc_ndhwgk_f16_instances( + op_ptrs); + } +#endif +#ifdef CK_ENABLE_BF16 + if constexpr(is_same_v && is_same_v && + is_same_v && + is_same_v && + is_same_v) + { + add_device_grouped_conv3d_bwd_weight_wmma_scale_ndhwgc_gkzyxc_ndhwgk_bf16_f32_bf16_instances( + op_ptrs); + } +#endif + } + } +#endif #ifdef CK_USE_XDL if constexpr(NumDimSpatial == 3) diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight_wmma.inc b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight_wmma.inc index 658cdf431d..06247019f1 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight_wmma.inc +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight_wmma.inc @@ -8,32 +8,61 @@ namespace tensor_operation { namespace device { namespace instance { +// conv2d backward weight +#ifdef CK_ENABLE_FP16 +void add_device_grouped_conv2d_bwd_weight_wmma_nhwgc_gkyxc_nhwgk_f16_instances( + std::vector>>& instances); + +void add_device_grouped_conv2d_bwd_weight_two_stage_wmma_nhwgc_gkyxc_nhwgk_f16_pipev1_instances( + std::vector>>& instances); +#endif + +#ifdef CK_ENABLE_BF16 +void add_device_grouped_conv2d_bwd_weight_wmma_nhwgc_gkyxc_nhwgk_bf16_instances( + std::vector>>& instances); + +void add_device_grouped_conv2d_bwd_weight_two_stage_wmma_nhwgc_gkyxc_nhwgk_bf16_pipev1_instances( + std::vector>>& instances); +#endif + // conv3d backward weight #ifdef CK_ENABLE_FP16 -void add_device_grouped_conv3d_bwd_weight_wmma_gndhwc_gkzyxc_gndhwk_f16_instances( - std::vector>>& instances); - -void add_device_grouped_conv3d_bwd_weight_wmma_gndhwc_gkzyxc_gndhwk_f16_1x1s1p0_instances( - std::vector>>& instances); - void add_device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_f16_instances( std::vector>>& instances); -void add_device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_f16_1x1s1p0_instances( +void add_device_grouped_conv3d_bwd_weight_two_stage_wmma_ndhwgc_gkzyxc_ndhwgk_f16_pipev1_instances( std::vector>>& instances); #endif -#ifdef CK_ENABLE_INT8 -void add_device_grouped_conv3d_bwd_weight_wmma_gndhwc_gkzyxc_gndhwk_i8_instances( - std::vector>>& instances); -void add_device_grouped_conv3d_bwd_weight_wmma_gndhwc_gkzyxc_gndhwk_i8_1x1s1p0_instances( - std::vector>>& instances); - -void add_device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_i8_instances( +#ifdef CK_ENABLE_BF16 +void add_device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_bf16_instances( std::vector>>& instances); -void add_device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_i8_1x1s1p0_instances( +void add_device_grouped_conv3d_bwd_weight_two_stage_wmma_ndhwgc_gkzyxc_ndhwgk_bf16_pipev1_instances( std::vector>>& instances); diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv1d_bwd_weight/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_conv1d_bwd_weight/CMakeLists.txt index 4ef6722ab5..56a9d16623 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv1d_bwd_weight/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/grouped_conv1d_bwd_weight/CMakeLists.txt @@ -1,7 +1,7 @@ # Copyright (c) Advanced Micro Devices, Inc., or its affiliates. # SPDX-License-Identifier: MIT -# ONLY XDL_AND_DL_KERNELS +# XDL_DL_WMMA_KERNELS set(GROUPED_CONV1D_BWD_WEIGHT xdl/device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_f16_instance.cpp xdl/device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_f32_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/CMakeLists.txt index 7e9a26c092..ec9e7da391 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/CMakeLists.txt @@ -1,7 +1,7 @@ # Copyright (c) Advanced Micro Devices, Inc., or its affiliates. # SPDX-License-Identifier: MIT -# ONLY XDL_AND_DL_KERNELS +# XDL_DL_WMMA_KERNELS set(GROUPED_CONV2D_BWD_WEIGHT xdl/gnhwc_gkyxc_gnhwk/device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f16_instance.cpp xdl/gnhwc_gkyxc_gnhwk/device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f32_instance.cpp @@ -72,4 +72,11 @@ if(DL_KERNELS) dl/device_grouped_conv2d_bwd_weight_dl_nhwgc_gkyxc_nhwgk_bf16_f32_bf16_instance.cpp) endif() +list(APPEND GROUPED_CONV2D_BWD_WEIGHT + wmma/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_wmma_nhwgc_gkyxc_nhwgk_bf16_instance.cpp + wmma/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_two_stage_wmma_nhwgc_gkyxc_nhwgk_bf16_pipev1_instance.cpp + wmma/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_wmma_nhwgc_gkyxc_nhwgk_f16_instance.cpp + wmma/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_two_stage_wmma_nhwgc_gkyxc_nhwgk_f16_pipev1_instance.cpp + ) + add_instance_library(device_grouped_conv2d_bwd_weight_instance ${GROUPED_CONV2D_BWD_WEIGHT}) diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/wmma/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_two_stage_wmma_nhwgc_gkyxc_nhwgk_bf16_pipev1_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/wmma/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_two_stage_wmma_nhwgc_gkyxc_nhwgk_bf16_pipev1_instance.cpp new file mode 100644 index 0000000000..94601d8f27 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/wmma/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_two_stage_wmma_nhwgc_gkyxc_nhwgk_bf16_pipev1_instance.cpp @@ -0,0 +1,41 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_two_stage_wmma_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] +void add_device_grouped_conv2d_bwd_weight_two_stage_wmma_nhwgc_gkyxc_nhwgk_bf16_pipev1_instances( + std::vector>>& instances) +{ + // 1. Default + add_device_operation_instances( + instances, + device_grouped_conv_bwd_weight_two_stage_nhwgc_wmma_c_shuffle_bf16_instances< + 2, + NHWGC, + GKYXC, + NHWGK, + ConvBwdWeightDefault, + BlockGemmPipelineScheduler::Intrawave, + BlockGemmPipelineVersion::v1>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/wmma/device_grouped_conv3d_bwd_weight_wmma_gndhwc_gkzyxc_gndhwk_f16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/wmma/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_two_stage_wmma_nhwgc_gkyxc_nhwgk_f16_pipev1_instance.cpp similarity index 52% rename from library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/wmma/device_grouped_conv3d_bwd_weight_wmma_gndhwc_gkzyxc_gndhwk_f16_instance.cpp rename to library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/wmma/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_two_stage_wmma_nhwgc_gkyxc_nhwgk_f16_pipev1_instance.cpp index cd8f8f5726..24ff6dcb96 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/wmma/device_grouped_conv3d_bwd_weight_wmma_gndhwc_gkzyxc_gndhwk_f16_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/wmma/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_two_stage_wmma_nhwgc_gkyxc_nhwgk_f16_pipev1_instance.cpp @@ -2,17 +2,19 @@ // SPDX-License-Identifier: MIT #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" -#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_wmma_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_two_stage_wmma_instance.hpp" namespace ck { namespace tensor_operation { namespace device { namespace instance { -void add_device_grouped_conv3d_bwd_weight_wmma_gndhwc_gkzyxc_gndhwk_f16_instances( - std::vector>>& instances) { + // 1. Default add_device_operation_instances( instances, - device_grouped_conv_bwd_weight_wmma_f16_instances<3, - GNDHWC, - GKZYXC, - GNDHWK, - ConvBwdWeightDefault>{}); + device_grouped_conv_bwd_weight_two_stage_nhwgc_wmma_c_shuffle_f16_instances< + 2, + NHWGC, + GKYXC, + NHWGK, + ConvBwdWeightDefault, + BlockGemmPipelineScheduler::Intrawave, + BlockGemmPipelineVersion::v1>{}); } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/wmma/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_wmma_nhwgc_gkyxc_nhwgk_bf16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/wmma/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_wmma_nhwgc_gkyxc_nhwgk_bf16_instance.cpp new file mode 100644 index 0000000000..adc9de3a3d --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/wmma/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_wmma_nhwgc_gkyxc_nhwgk_bf16_instance.cpp @@ -0,0 +1,38 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_v3_wmma_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] +void add_device_grouped_conv2d_bwd_weight_wmma_nhwgc_gkyxc_nhwgk_bf16_instances( + std::vector>>& instances) +{ + // 1. Default + add_device_operation_instances( + instances, + device_grouped_conv_bwd_weight_v3_wmma_c_shuffle_bf16_instances<2, + NHWGC, + GKYXC, + NHWGK, + ConvBwdWeightDefault>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/wmma/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_wmma_nhwgc_gkyxc_nhwgk_f16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/wmma/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_wmma_nhwgc_gkyxc_nhwgk_f16_instance.cpp new file mode 100644 index 0000000000..f304d1bba4 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/wmma/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_wmma_nhwgc_gkyxc_nhwgk_f16_instance.cpp @@ -0,0 +1,38 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_v3_wmma_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] +void add_device_grouped_conv2d_bwd_weight_wmma_nhwgc_gkyxc_nhwgk_f16_instances( + std::vector>>& instances) +{ + // 1. Default + add_device_operation_instances( + instances, + device_grouped_conv_bwd_weight_v3_wmma_c_shuffle_f16_instances<2, + NHWGC, + GKYXC, + NHWGK, + ConvBwdWeightDefault>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/CMakeLists.txt index 24c608f4ba..b246b87178 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/CMakeLists.txt @@ -1,7 +1,7 @@ # Copyright (c) Advanced Micro Devices, Inc., or its affiliates. # SPDX-License-Identifier: MIT - # XDL_DL_WMMA_KERNELS +# XDL_DL_WMMA_KERNELS set(GROUPED_CONV3D_BWD_WEIGHT xdl/gndhwc_gkzyxc_gndhwk/device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_f16_instance.cpp xdl/gndhwc_gkzyxc_gndhwk/device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_f32_instance.cpp @@ -69,14 +69,11 @@ if(DL_KERNELS) endif() list(APPEND GROUPED_CONV3D_BWD_WEIGHT - wmma/device_grouped_conv3d_bwd_weight_wmma_gndhwc_gkzyxc_gndhwk_f16_1x1s1p0_instance.cpp - wmma/device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_f16_1x1s1p0_instance.cpp - wmma/device_grouped_conv3d_bwd_weight_wmma_gndhwc_gkzyxc_gndhwk_f16_instance.cpp - wmma/device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp - wmma/device_grouped_conv3d_bwd_weight_wmma_gndhwc_gkzyxc_gndhwk_i8_1x1s1p0_instance.cpp - wmma/device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_i8_1x1s1p0_instance.cpp - wmma/device_grouped_conv3d_bwd_weight_wmma_gndhwc_gkzyxc_gndhwk_i8_instance.cpp - wmma/device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_i8_instance.cpp) + wmma/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp + wmma/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp + wmma/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_two_stage_wmma_ndhwgc_gkzyxc_ndhwgk_f16_pipev1_instance.cpp + wmma/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_two_stage_wmma_ndhwgc_gkzyxc_ndhwgk_bf16_pipev1_instance.cpp + ) if((DTYPES MATCHES "fp8" AND DTYPES MATCHES "bf8" AND DTYPES MATCHES "fp16") OR NOT DEFINED DTYPES) list(APPEND GROUPED_CONV3D_BWD_WEIGHT diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/wmma/device_grouped_conv3d_bwd_weight_wmma_gndhwc_gkzyxc_gndhwk_f16_1x1s1p0_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/wmma/device_grouped_conv3d_bwd_weight_wmma_gndhwc_gkzyxc_gndhwk_f16_1x1s1p0_instance.cpp deleted file mode 100644 index 643f1914c8..0000000000 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/wmma/device_grouped_conv3d_bwd_weight_wmma_gndhwc_gkzyxc_gndhwk_f16_1x1s1p0_instance.cpp +++ /dev/null @@ -1,35 +0,0 @@ -// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. -// SPDX-License-Identifier: MIT - -#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" -#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_wmma_instance.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { -namespace instance { -void add_device_grouped_conv3d_bwd_weight_wmma_gndhwc_gkzyxc_gndhwk_f16_1x1s1p0_instances( - std::vector>>& instances) -{ - add_device_operation_instances( - instances, - device_grouped_conv_bwd_weight_wmma_f16_instances<3, - GNDHWC, - GKZYXC, - GNDHWK, - ConvBwdWeightFilter1x1Stride1Pad0>{}); -} - -} // namespace instance -} // namespace device -} // namespace tensor_operation -} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/wmma/device_grouped_conv3d_bwd_weight_wmma_gndhwc_gkzyxc_gndhwk_i8_1x1s1p0_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/wmma/device_grouped_conv3d_bwd_weight_wmma_gndhwc_gkzyxc_gndhwk_i8_1x1s1p0_instance.cpp deleted file mode 100644 index 7eb5e434ff..0000000000 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/wmma/device_grouped_conv3d_bwd_weight_wmma_gndhwc_gkzyxc_gndhwk_i8_1x1s1p0_instance.cpp +++ /dev/null @@ -1,35 +0,0 @@ -// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. -// SPDX-License-Identifier: MIT - -#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" -#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_wmma_instance.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { -namespace instance { -void add_device_grouped_conv3d_bwd_weight_wmma_gndhwc_gkzyxc_gndhwk_i8_1x1s1p0_instances( - std::vector>>& instances) -{ - add_device_operation_instances( - instances, - device_grouped_conv_bwd_weight_wmma_i8_instances<3, - GNDHWC, - GKZYXC, - GNDHWK, - ConvBwdWeightFilter1x1Stride1Pad0>{}); -} - -} // namespace instance -} // namespace device -} // namespace tensor_operation -} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/wmma/device_grouped_conv3d_bwd_weight_wmma_gndhwc_gkzyxc_gndhwk_i8_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/wmma/device_grouped_conv3d_bwd_weight_wmma_gndhwc_gkzyxc_gndhwk_i8_instance.cpp deleted file mode 100644 index 0ae9ee61e4..0000000000 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/wmma/device_grouped_conv3d_bwd_weight_wmma_gndhwc_gkzyxc_gndhwk_i8_instance.cpp +++ /dev/null @@ -1,35 +0,0 @@ -// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. -// SPDX-License-Identifier: MIT - -#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" -#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_wmma_instance.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { -namespace instance { -void add_device_grouped_conv3d_bwd_weight_wmma_gndhwc_gkzyxc_gndhwk_i8_instances( - std::vector>>& instances) -{ - add_device_operation_instances( - instances, - device_grouped_conv_bwd_weight_wmma_i8_instances<3, - GNDHWC, - GKZYXC, - GNDHWK, - ConvBwdWeightDefault>{}); -} - -} // namespace instance -} // namespace device -} // namespace tensor_operation -} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/wmma/device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_i8_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/wmma/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_two_stage_wmma_ndhwgc_gkzyxc_ndhwgk_bf16_pipev1_instance.cpp similarity index 56% rename from library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/wmma/device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_i8_instance.cpp rename to library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/wmma/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_two_stage_wmma_ndhwgc_gkzyxc_ndhwgk_bf16_pipev1_instance.cpp index 268aeb617c..60435d0a43 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/wmma/device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_i8_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/wmma/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_two_stage_wmma_ndhwgc_gkzyxc_ndhwgk_bf16_pipev1_instance.cpp @@ -2,31 +2,37 @@ // SPDX-License-Identifier: MIT #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" -#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_wmma_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_two_stage_wmma_instance.hpp" namespace ck { namespace tensor_operation { namespace device { namespace instance { -void add_device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_i8_instances( + +// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] +void add_device_grouped_conv3d_bwd_weight_two_stage_wmma_ndhwgc_gkzyxc_ndhwgk_bf16_pipev1_instances( std::vector>>& instances) { + // 1. Default add_device_operation_instances( instances, - device_grouped_conv_bwd_weight_wmma_i8_instances<3, - NDHWGC, - GKZYXC, - NDHWGK, - ConvBwdWeightDefault>{}); + device_grouped_conv_bwd_weight_two_stage_nhwgc_wmma_c_shuffle_bf16_instances< + 3, + NDHWGC, + GKZYXC, + NDHWGK, + ConvBwdWeightDefault, + BlockGemmPipelineScheduler::Intrawave, + BlockGemmPipelineVersion::v1>{}); } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/wmma/device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_f16_1x1s1p0_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/wmma/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_two_stage_wmma_ndhwgc_gkzyxc_ndhwgk_f16_pipev1_instance.cpp similarity index 67% rename from library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/wmma/device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_f16_1x1s1p0_instance.cpp rename to library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/wmma/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_two_stage_wmma_ndhwgc_gkzyxc_ndhwgk_f16_pipev1_instance.cpp index 7a9d75560f..e912b086c0 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/wmma/device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_f16_1x1s1p0_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/wmma/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_two_stage_wmma_ndhwgc_gkzyxc_ndhwgk_f16_pipev1_instance.cpp @@ -2,13 +2,15 @@ // SPDX-License-Identifier: MIT #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" -#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_wmma_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_two_stage_wmma_instance.hpp" namespace ck { namespace tensor_operation { namespace device { namespace instance { -void add_device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_f16_1x1s1p0_instances( + +// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] +void add_device_grouped_conv3d_bwd_weight_two_stage_wmma_ndhwgc_gkzyxc_ndhwgk_f16_pipev1_instances( std::vector>>& instances) { + // 1. Default add_device_operation_instances( instances, - device_grouped_conv_bwd_weight_wmma_f16_instances<3, - NDHWGC, - GKZYXC, - NDHWGK, - ConvBwdWeightFilter1x1Stride1Pad0>{}); + device_grouped_conv_bwd_weight_two_stage_nhwgc_wmma_c_shuffle_f16_instances< + 3, + NDHWGC, + GKZYXC, + NDHWGK, + ConvBwdWeightDefault, + BlockGemmPipelineScheduler::Intrawave, + BlockGemmPipelineVersion::v1>{}); } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/wmma/device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_i8_1x1s1p0_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/wmma/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp similarity index 56% rename from library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/wmma/device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_i8_1x1s1p0_instance.cpp rename to library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/wmma/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp index 7dd523bae3..728f514f9a 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/wmma/device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_i8_1x1s1p0_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/wmma/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp @@ -2,31 +2,34 @@ // SPDX-License-Identifier: MIT #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" -#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_wmma_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_v3_wmma_instance.hpp" namespace ck { namespace tensor_operation { namespace device { namespace instance { -void add_device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_i8_1x1s1p0_instances( + +// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] +void add_device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_bf16_instances( std::vector>>& instances) { + // 1. Default add_device_operation_instances( instances, - device_grouped_conv_bwd_weight_wmma_i8_instances<3, - NDHWGC, - GKZYXC, - NDHWGK, - ConvBwdWeightFilter1x1Stride1Pad0>{}); + device_grouped_conv_bwd_weight_v3_wmma_c_shuffle_bf16_instances<3, + NDHWGC, + GKZYXC, + NDHWGK, + ConvBwdWeightDefault>{}); } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/wmma/device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/wmma/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp similarity index 68% rename from library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/wmma/device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp rename to library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/wmma/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp index 764c4a0224..f929196ddb 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/wmma/device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/wmma/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp @@ -2,12 +2,14 @@ // SPDX-License-Identifier: MIT #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" -#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_wmma_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_v3_wmma_instance.hpp" namespace ck { namespace tensor_operation { namespace device { namespace instance { + +// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] void add_device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_f16_instances( std::vector>>& instances) { + // 1. Default add_device_operation_instances( instances, - device_grouped_conv_bwd_weight_wmma_f16_instances<3, - NDHWGC, - GKZYXC, - NDHWGK, - ConvBwdWeightDefault>{}); + device_grouped_conv_bwd_weight_v3_wmma_c_shuffle_f16_instances<3, + NDHWGC, + GKZYXC, + NDHWGK, + ConvBwdWeightDefault>{}); } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight_bilinear/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight_bilinear/CMakeLists.txt index f2187485a9..455f14d9f4 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight_bilinear/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight_bilinear/CMakeLists.txt @@ -1,7 +1,7 @@ # Copyright (c) Advanced Micro Devices, Inc., or its affiliates. # SPDX-License-Identifier: MIT -# ONLY XDL_KERNELS +# ONLY XDL_AND_WMMA_KERNELS set(GROUPED_CONV3D_BWD_WEIGHT_BILINEAR xdl/device_grouped_conv3d_bwd_weight_xdl_bilinear_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp xdl/device_grouped_conv3d_bwd_weight_xdl_bilinear_ndhwgc_gkzyxc_ndhwgk_f32_instance.cpp @@ -13,4 +13,9 @@ if((DTYPES MATCHES "fp8" AND DTYPES MATCHES "bf8" AND DTYPES MATCHES "fp16") OR xdl/device_grouped_conv3d_bwd_weight_xdl_bilinear_ndhwgc_gkzyxc_ndhwgk_f16_comp_bf8_fp8_instance.cpp) endif() +list(APPEND GROUPED_CONV3D_BWD_WEIGHT_BILINEAR + wmma/device_grouped_conv3d_bwd_weight_wmma_bilinear_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp + wmma/device_grouped_conv3d_bwd_weight_wmma_bilinear_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp + ) + add_instance_library(device_grouped_conv3d_bwd_weight_bilinear_instance ${GROUPED_CONV3D_BWD_WEIGHT_BILINEAR}) diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight_bilinear/wmma/device_grouped_conv3d_bwd_weight_wmma_bilinear_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight_bilinear/wmma/device_grouped_conv3d_bwd_weight_wmma_bilinear_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp new file mode 100644 index 0000000000..06398729af --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight_bilinear/wmma/device_grouped_conv3d_bwd_weight_wmma_bilinear_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp @@ -0,0 +1,50 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_wmma_bilinear_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] +void add_device_grouped_conv3d_bwd_weight_wmma_bilinear_ndhwgc_gkzyxc_ndhwgk_bf16_f32_bf16_instances( + std::vector, + BF16, + F32, + BF16, + Tuple, + PassThrough, + Bilinear, + PassThrough>>>& instances) +{ + // 1. Default + add_device_operation_instances( + instances, + device_grouped_conv_bwd_weight_wmma_c_shuffle_bf16_bilinear_instances< + 3, + NDHWGC, + GKZYXC, + NDHWGK, + ConvBwdWeightDefault>{}); + // 2. Filter1x1Stride1Pad0 + add_device_operation_instances( + instances, + device_grouped_conv_bwd_weight_wmma_c_shuffle_bf16_bilinear_instances< + 3, + NDHWGC, + GKZYXC, + NDHWGK, + ConvBwdWeightFilter1x1Stride1Pad0>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight_bilinear/wmma/device_grouped_conv3d_bwd_weight_wmma_bilinear_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight_bilinear/wmma/device_grouped_conv3d_bwd_weight_wmma_bilinear_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp new file mode 100644 index 0000000000..a7df39161a --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight_bilinear/wmma/device_grouped_conv3d_bwd_weight_wmma_bilinear_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp @@ -0,0 +1,50 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_wmma_bilinear_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] +void add_device_grouped_conv3d_bwd_weight_wmma_bilinear_ndhwgc_gkzyxc_ndhwgk_f16_instances( + std::vector, + F16, + F16, + F16, + Tuple, + PassThrough, + Bilinear, + PassThrough>>>& instances) +{ + // 1. Default + add_device_operation_instances( + instances, + device_grouped_conv_bwd_weight_wmma_c_shuffle_f16_bilinear_instances< + 3, + NDHWGC, + GKZYXC, + NDHWGK, + ConvBwdWeightDefault>{}); + // 2. Filter1x1Stride1Pad0 + add_device_operation_instances( + instances, + device_grouped_conv_bwd_weight_wmma_c_shuffle_f16_bilinear_instances< + 3, + NDHWGC, + GKZYXC, + NDHWGK, + ConvBwdWeightFilter1x1Stride1Pad0>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight_scale/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight_scale/CMakeLists.txt index bce32f3bdb..b7fefdc94f 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight_scale/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight_scale/CMakeLists.txt @@ -1,7 +1,7 @@ # Copyright (c) Advanced Micro Devices, Inc., or its affiliates. # SPDX-License-Identifier: MIT -# ONLY XDL_KERNELS +# ONLY XDL_AND_WMMA_KERNELS set(GROUPED_CONV3D_BWD_WEIGHT_SCALE xdl/device_grouped_conv3d_bwd_weight_xdl_scale_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp xdl/device_grouped_conv3d_bwd_weight_xdl_scale_ndhwgc_gkzyxc_ndhwgk_f32_instance.cpp @@ -13,4 +13,9 @@ if((DTYPES MATCHES "fp8" AND DTYPES MATCHES "bf8" AND DTYPES MATCHES "fp16") OR xdl/device_grouped_conv3d_bwd_weight_xdl_scale_ndhwgc_gkzyxc_ndhwgk_f16_comp_bf8_fp8_instance.cpp) endif() +list(APPEND GROUPED_CONV3D_BWD_WEIGHT_SCALE + wmma/device_grouped_conv3d_bwd_weight_wmma_scale_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp + wmma/device_grouped_conv3d_bwd_weight_wmma_scale_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp + ) + add_instance_library(device_grouped_conv3d_bwd_weight_scale_instance ${GROUPED_CONV3D_BWD_WEIGHT_SCALE}) diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight_scale/wmma/device_grouped_conv3d_bwd_weight_wmma_scale_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight_scale/wmma/device_grouped_conv3d_bwd_weight_wmma_scale_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp new file mode 100644 index 0000000000..32aeb2f19f --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight_scale/wmma/device_grouped_conv3d_bwd_weight_wmma_scale_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp @@ -0,0 +1,49 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_wmma_scale_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] +void add_device_grouped_conv3d_bwd_weight_wmma_scale_ndhwgc_gkzyxc_ndhwgk_bf16_f32_bf16_instances( + std::vector, + BF16, + F32, + BF16, + Tuple<>, + PassThrough, + Scale, + PassThrough>>>& instances) +{ + // 1. Default + add_device_operation_instances( + instances, + device_grouped_conv_bwd_weight_wmma_c_shuffle_bf16_scale_instances<3, + NDHWGC, + GKZYXC, + NDHWGK, + ConvBwdWeightDefault>{}); + // 2. Filter1x1Stride1Pad0 + add_device_operation_instances( + instances, + device_grouped_conv_bwd_weight_wmma_c_shuffle_bf16_scale_instances< + 3, + NDHWGC, + GKZYXC, + NDHWGK, + ConvBwdWeightFilter1x1Stride1Pad0>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight_scale/wmma/device_grouped_conv3d_bwd_weight_wmma_scale_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight_scale/wmma/device_grouped_conv3d_bwd_weight_wmma_scale_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp new file mode 100644 index 0000000000..389b80cfb5 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight_scale/wmma/device_grouped_conv3d_bwd_weight_wmma_scale_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp @@ -0,0 +1,49 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_wmma_scale_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] +void add_device_grouped_conv3d_bwd_weight_wmma_scale_ndhwgc_gkzyxc_ndhwgk_f16_instances( + std::vector, + F16, + F16, + F16, + Tuple<>, + PassThrough, + Scale, + PassThrough>>>& instances) +{ + // 1. Default + add_device_operation_instances( + instances, + device_grouped_conv_bwd_weight_wmma_c_shuffle_f16_scale_instances<3, + NDHWGC, + GKZYXC, + NDHWGK, + ConvBwdWeightDefault>{}); + // 2. Filter1x1Stride1Pad0 + add_device_operation_instances( + instances, + device_grouped_conv_bwd_weight_wmma_c_shuffle_f16_scale_instances< + 3, + NDHWGC, + GKZYXC, + NDHWGK, + ConvBwdWeightFilter1x1Stride1Pad0>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_convnd_bwd_weight/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_convnd_bwd_weight/CMakeLists.txt index f909fe0356..08f95601f7 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_convnd_bwd_weight/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/grouped_convnd_bwd_weight/CMakeLists.txt @@ -1,29 +1,37 @@ # Copyright (c) Advanced Micro Devices, Inc., or its affiliates. # SPDX-License-Identifier: MIT -# ONLY XDL_KERNELS +# ONLY XDL_AND_WMMA_KERNELS set(GROUPED_CONVND_EXP_BWD_WEIGHT # Explicit instances are common for 2d and 3d - explicit_xdl/bf16_bf16_bf16/device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_comp_default_instance.cpp - explicit_xdl/bf16_bf16_bf16/device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_comp_mnkpadding_instance.cpp - explicit_xdl/bf16_bf16_bf16/device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_mem_v1_default_instance.cpp - explicit_xdl/bf16_bf16_bf16/device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_mem_v1_mnkpadding_instance.cpp - explicit_xdl/bf16_bf16_bf16/device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_mem_v2_default_instance.cpp - explicit_xdl/bf16_bf16_bf16/device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_mem_v2_mnkpadding_instance.cpp + explicit_xdl/bf16_bf16_bf16/device_grouped_convnd_bwd_weight_xdl_bf16_bf16_bf16_exp_comp_default_instance.cpp + explicit_xdl/bf16_bf16_bf16/device_grouped_convnd_bwd_weight_xdl_bf16_bf16_bf16_exp_comp_mnkpadding_instance.cpp + explicit_xdl/bf16_bf16_bf16/device_grouped_convnd_bwd_weight_xdl_bf16_bf16_bf16_exp_mem_v1_default_instance.cpp + explicit_xdl/bf16_bf16_bf16/device_grouped_convnd_bwd_weight_xdl_bf16_bf16_bf16_exp_mem_v1_mnkpadding_instance.cpp + explicit_xdl/bf16_bf16_bf16/device_grouped_convnd_bwd_weight_xdl_bf16_bf16_bf16_exp_mem_v2_default_instance.cpp + explicit_xdl/bf16_bf16_bf16/device_grouped_convnd_bwd_weight_xdl_bf16_bf16_bf16_exp_mem_v2_mnkpadding_instance.cpp - explicit_xdl/bf16_bf16_bf16/device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_odd_mn_instance.cpp - explicit_xdl/bf16_bf16_bf16/device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_odd_m_instance.cpp - explicit_xdl/bf16_bf16_bf16/device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_odd_n_instance.cpp + explicit_xdl/bf16_bf16_bf16/device_grouped_convnd_bwd_weight_xdl_bf16_bf16_bf16_exp_odd_mn_instance.cpp + explicit_xdl/bf16_bf16_bf16/device_grouped_convnd_bwd_weight_xdl_bf16_bf16_bf16_exp_odd_m_instance.cpp + explicit_xdl/bf16_bf16_bf16/device_grouped_convnd_bwd_weight_xdl_bf16_bf16_bf16_exp_odd_n_instance.cpp - explicit_xdl/fp16_fp16_fp16/device_grouped_convnd_bwd_weight_f16_f16_f16_exp_comp_default_instance.cpp - explicit_xdl/fp16_fp16_fp16/device_grouped_convnd_bwd_weight_f16_f16_f16_exp_comp_mnkpadding_instance.cpp - explicit_xdl/fp16_fp16_fp16/device_grouped_convnd_bwd_weight_f16_f16_f16_exp_mem_v1_default_instance.cpp - explicit_xdl/fp16_fp16_fp16/device_grouped_convnd_bwd_weight_f16_f16_f16_exp_mem_v1_mnkpadding_instance.cpp - explicit_xdl/fp16_fp16_fp16/device_grouped_convnd_bwd_weight_f16_f16_f16_exp_mem_v2_default_instance.cpp - explicit_xdl/fp16_fp16_fp16/device_grouped_convnd_bwd_weight_f16_f16_f16_exp_mem_v2_mnkpadding_instance.cpp + explicit_xdl/fp16_fp16_fp16/device_grouped_convnd_bwd_weight_xdl_f16_f16_f16_exp_comp_default_instance.cpp + explicit_xdl/fp16_fp16_fp16/device_grouped_convnd_bwd_weight_xdl_f16_f16_f16_exp_comp_mnkpadding_instance.cpp + explicit_xdl/fp16_fp16_fp16/device_grouped_convnd_bwd_weight_xdl_f16_f16_f16_exp_mem_v1_default_instance.cpp + explicit_xdl/fp16_fp16_fp16/device_grouped_convnd_bwd_weight_xdl_f16_f16_f16_exp_mem_v1_mnkpadding_instance.cpp + explicit_xdl/fp16_fp16_fp16/device_grouped_convnd_bwd_weight_xdl_f16_f16_f16_exp_mem_v2_default_instance.cpp + explicit_xdl/fp16_fp16_fp16/device_grouped_convnd_bwd_weight_xdl_f16_f16_f16_exp_mem_v2_mnkpadding_instance.cpp - explicit_xdl/fp16_fp16_fp16/device_grouped_convnd_bwd_weight_f16_f16_f16_exp_odd_mn_instance.cpp - explicit_xdl/fp16_fp16_fp16/device_grouped_convnd_bwd_weight_f16_f16_f16_exp_odd_n_instance.cpp - explicit_xdl/fp16_fp16_fp16/device_grouped_convnd_bwd_weight_f16_f16_f16_exp_odd_m_instance.cpp + explicit_xdl/fp16_fp16_fp16/device_grouped_convnd_bwd_weight_xdl_f16_f16_f16_exp_odd_mn_instance.cpp + explicit_xdl/fp16_fp16_fp16/device_grouped_convnd_bwd_weight_xdl_f16_f16_f16_exp_odd_n_instance.cpp + explicit_xdl/fp16_fp16_fp16/device_grouped_convnd_bwd_weight_xdl_f16_f16_f16_exp_odd_m_instance.cpp + + explicit_wmma/bf16_bf16_bf16/device_grouped_convnd_bwd_weight_wmma_bf16_bf16_bf16_exp_comp_default_instance.cpp + explicit_wmma/bf16_bf16_bf16/device_grouped_convnd_bwd_weight_wmma_bf16_bf16_bf16_exp_comp_mnkpadding_instance.cpp + explicit_wmma/bf16_bf16_bf16/device_grouped_convnd_bwd_weight_wmma_bf16_bf16_bf16_exp_odd_mn_instance.cpp + + explicit_wmma/fp16_fp16_fp16/device_grouped_convnd_bwd_weight_wmma_f16_f16_f16_exp_comp_default_instance.cpp + explicit_wmma/fp16_fp16_fp16/device_grouped_convnd_bwd_weight_wmma_f16_f16_f16_exp_comp_mnkpadding_instance.cpp + explicit_wmma/fp16_fp16_fp16/device_grouped_convnd_bwd_weight_wmma_f16_f16_f16_exp_odd_mn_instance.cpp ) add_instance_library(device_grouped_convnd_bwd_weight_instance ${GROUPED_CONVND_EXP_BWD_WEIGHT}) diff --git a/library/src/tensor_operation_instance/gpu/grouped_convnd_bwd_weight/explicit_wmma/bf16_bf16_bf16/device_grouped_convnd_bwd_weight_wmma_bf16_bf16_bf16_exp_comp_default_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_convnd_bwd_weight/explicit_wmma/bf16_bf16_bf16/device_grouped_convnd_bwd_weight_wmma_bf16_bf16_bf16_exp_comp_default_instance.cpp new file mode 100644 index 0000000000..894063e081 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_convnd_bwd_weight/explicit_wmma/bf16_bf16_bf16/device_grouped_convnd_bwd_weight_wmma_bf16_bf16_bf16_exp_comp_default_instance.cpp @@ -0,0 +1,67 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_exp_gemm_wmma_universal_km_kn_mn_instance.hpp" +#include "ck/host_utility/device_prop.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_convnd_bwd_weight_wmma_bf16_bf16_bf16_exp_comp_default_instances( + std::vector>>& instances) +{ + add_explicit_gemm_device_operation_instances< + 2, + NHWGC, + GKYXC, + NHWGK, + BF16, + BF16, + BF16, + PassThrough, + PassThrough, + PassThrough, + device_gemm_wmma_universal_km_kn_mn_GemmDefault_instances>(instances); +} + +void add_device_grouped_convnd_bwd_weight_wmma_bf16_bf16_bf16_exp_comp_default_instances( + std::vector>>& instances) +{ + add_explicit_gemm_device_operation_instances< + 3, + NDHWGC, + GKZYXC, + NDHWGK, + BF16, + BF16, + BF16, + PassThrough, + PassThrough, + PassThrough, + device_gemm_wmma_universal_km_kn_mn_GemmDefault_instances>(instances); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_convnd_bwd_weight/explicit_wmma/bf16_bf16_bf16/device_grouped_convnd_bwd_weight_wmma_bf16_bf16_bf16_exp_comp_mnkpadding_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_convnd_bwd_weight/explicit_wmma/bf16_bf16_bf16/device_grouped_convnd_bwd_weight_wmma_bf16_bf16_bf16_exp_comp_mnkpadding_instance.cpp new file mode 100644 index 0000000000..a3b16e4216 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_convnd_bwd_weight/explicit_wmma/bf16_bf16_bf16/device_grouped_convnd_bwd_weight_wmma_bf16_bf16_bf16_exp_comp_mnkpadding_instance.cpp @@ -0,0 +1,67 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_exp_gemm_wmma_universal_km_kn_mn_instance.hpp" +#include "ck/host_utility/device_prop.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_convnd_bwd_weight_wmma_bf16_bf16_bf16_exp_comp_mnkpadding_instances( + std::vector>>& instances) +{ + add_explicit_gemm_device_operation_instances< + 2, + NHWGC, + GKYXC, + NHWGK, + BF16, + BF16, + BF16, + PassThrough, + PassThrough, + PassThrough, + device_gemm_wmma_universal_km_kn_mn_GemmMNKPadding_bf16_instances>(instances); +} + +void add_device_grouped_convnd_bwd_weight_wmma_bf16_bf16_bf16_exp_comp_mnkpadding_instances( + std::vector>>& instances) +{ + add_explicit_gemm_device_operation_instances< + 3, + NDHWGC, + GKZYXC, + NDHWGK, + BF16, + BF16, + BF16, + PassThrough, + PassThrough, + PassThrough, + device_gemm_wmma_universal_km_kn_mn_GemmMNKPadding_bf16_instances>(instances); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_convnd_bwd_weight/explicit_wmma/bf16_bf16_bf16/device_grouped_convnd_bwd_weight_wmma_bf16_bf16_bf16_exp_odd_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_convnd_bwd_weight/explicit_wmma/bf16_bf16_bf16/device_grouped_convnd_bwd_weight_wmma_bf16_bf16_bf16_exp_odd_mn_instance.cpp new file mode 100644 index 0000000000..967e2884f9 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_convnd_bwd_weight/explicit_wmma/bf16_bf16_bf16/device_grouped_convnd_bwd_weight_wmma_bf16_bf16_bf16_exp_odd_mn_instance.cpp @@ -0,0 +1,67 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_exp_gemm_wmma_universal_km_kn_mn_instance.hpp" +#include "ck/host_utility/device_prop.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_convnd_bwd_weight_wmma_bf16_bf16_bf16_exp_odd_mn_instances( + std::vector>>& instances) +{ + add_explicit_gemm_device_operation_instances< + 2, + NHWGC, + GKYXC, + NHWGK, + BF16, + BF16, + BF16, + PassThrough, + PassThrough, + PassThrough, + device_gemm_wmma_universal_km_kn_mn_irregular_odd_mn_instances>(instances); +} + +void add_device_grouped_convnd_bwd_weight_wmma_bf16_bf16_bf16_exp_odd_mn_instances( + std::vector>>& instances) +{ + add_explicit_gemm_device_operation_instances< + 3, + NDHWGC, + GKZYXC, + NDHWGK, + BF16, + BF16, + BF16, + PassThrough, + PassThrough, + PassThrough, + device_gemm_wmma_universal_km_kn_mn_irregular_odd_mn_instances>(instances); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_convnd_bwd_weight/explicit_wmma/fp16_fp16_fp16/device_grouped_convnd_bwd_weight_wmma_f16_f16_f16_exp_comp_default_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_convnd_bwd_weight/explicit_wmma/fp16_fp16_fp16/device_grouped_convnd_bwd_weight_wmma_f16_f16_f16_exp_comp_default_instance.cpp new file mode 100644 index 0000000000..38e98e719e --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_convnd_bwd_weight/explicit_wmma/fp16_fp16_fp16/device_grouped_convnd_bwd_weight_wmma_f16_f16_f16_exp_comp_default_instance.cpp @@ -0,0 +1,67 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_exp_gemm_wmma_universal_km_kn_mn_instance.hpp" +#include "ck/host_utility/device_prop.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_convnd_bwd_weight_wmma_f16_f16_f16_exp_comp_default_instances( + std::vector>>& instances) +{ + add_explicit_gemm_device_operation_instances< + 2, + NHWGC, + GKYXC, + NHWGK, + F16, + F16, + F16, + PassThrough, + PassThrough, + PassThrough, + device_gemm_wmma_universal_km_kn_mn_GemmDefault_instances>(instances); +} + +void add_device_grouped_convnd_bwd_weight_wmma_f16_f16_f16_exp_comp_default_instances( + std::vector>>& instances) +{ + add_explicit_gemm_device_operation_instances< + 3, + NDHWGC, + GKZYXC, + NDHWGK, + F16, + F16, + F16, + PassThrough, + PassThrough, + PassThrough, + device_gemm_wmma_universal_km_kn_mn_GemmDefault_instances>(instances); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_convnd_bwd_weight/explicit_wmma/fp16_fp16_fp16/device_grouped_convnd_bwd_weight_wmma_f16_f16_f16_exp_comp_mnkpadding_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_convnd_bwd_weight/explicit_wmma/fp16_fp16_fp16/device_grouped_convnd_bwd_weight_wmma_f16_f16_f16_exp_comp_mnkpadding_instance.cpp new file mode 100644 index 0000000000..b0a8998562 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_convnd_bwd_weight/explicit_wmma/fp16_fp16_fp16/device_grouped_convnd_bwd_weight_wmma_f16_f16_f16_exp_comp_mnkpadding_instance.cpp @@ -0,0 +1,67 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_exp_gemm_wmma_universal_km_kn_mn_instance.hpp" +#include "ck/host_utility/device_prop.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_convnd_bwd_weight_wmma_f16_f16_f16_exp_comp_mnkpadding_instances( + std::vector>>& instances) +{ + add_explicit_gemm_device_operation_instances< + 2, + NHWGC, + GKYXC, + NHWGK, + F16, + F16, + F16, + PassThrough, + PassThrough, + PassThrough, + device_gemm_wmma_universal_km_kn_mn_GemmMNKPadding_f16_instances>(instances); +} + +void add_device_grouped_convnd_bwd_weight_wmma_f16_f16_f16_exp_comp_mnkpadding_instances( + std::vector>>& instances) +{ + add_explicit_gemm_device_operation_instances< + 3, + NDHWGC, + GKZYXC, + NDHWGK, + F16, + F16, + F16, + PassThrough, + PassThrough, + PassThrough, + device_gemm_wmma_universal_km_kn_mn_GemmMNKPadding_f16_instances>(instances); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_convnd_bwd_weight/explicit_wmma/fp16_fp16_fp16/device_grouped_convnd_bwd_weight_wmma_f16_f16_f16_exp_odd_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_convnd_bwd_weight/explicit_wmma/fp16_fp16_fp16/device_grouped_convnd_bwd_weight_wmma_f16_f16_f16_exp_odd_mn_instance.cpp new file mode 100644 index 0000000000..ace411ea68 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_convnd_bwd_weight/explicit_wmma/fp16_fp16_fp16/device_grouped_convnd_bwd_weight_wmma_f16_f16_f16_exp_odd_mn_instance.cpp @@ -0,0 +1,67 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_exp_gemm_wmma_universal_km_kn_mn_instance.hpp" +#include "ck/host_utility/device_prop.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_convnd_bwd_weight_wmma_f16_f16_f16_exp_odd_mn_instances( + std::vector>>& instances) +{ + add_explicit_gemm_device_operation_instances< + 2, + NHWGC, + GKYXC, + NHWGK, + F16, + F16, + F16, + PassThrough, + PassThrough, + PassThrough, + device_gemm_wmma_universal_km_kn_mn_irregular_odd_mn_instances>(instances); +} + +void add_device_grouped_convnd_bwd_weight_wmma_f16_f16_f16_exp_odd_mn_instances( + std::vector>>& instances) +{ + add_explicit_gemm_device_operation_instances< + 3, + NDHWGC, + GKZYXC, + NDHWGK, + F16, + F16, + F16, + PassThrough, + PassThrough, + PassThrough, + device_gemm_wmma_universal_km_kn_mn_irregular_odd_mn_instances>(instances); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_convnd_bwd_weight/explicit_xdl/bf16_bf16_bf16/device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_comp_default_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_convnd_bwd_weight/explicit_xdl/bf16_bf16_bf16/device_grouped_convnd_bwd_weight_xdl_bf16_bf16_bf16_exp_comp_default_instance.cpp similarity index 93% rename from library/src/tensor_operation_instance/gpu/grouped_convnd_bwd_weight/explicit_xdl/bf16_bf16_bf16/device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_comp_default_instance.cpp rename to library/src/tensor_operation_instance/gpu/grouped_convnd_bwd_weight/explicit_xdl/bf16_bf16_bf16/device_grouped_convnd_bwd_weight_xdl_bf16_bf16_bf16_exp_comp_default_instance.cpp index 331b3a7eaa..2aefcde0fa 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_convnd_bwd_weight/explicit_xdl/bf16_bf16_bf16/device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_comp_default_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_convnd_bwd_weight/explicit_xdl/bf16_bf16_bf16/device_grouped_convnd_bwd_weight_xdl_bf16_bf16_bf16_exp_comp_default_instance.cpp @@ -9,7 +9,7 @@ namespace tensor_operation { namespace device { namespace instance { -void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_comp_default_instances( +void add_device_grouped_convnd_bwd_weight_xdl_bf16_bf16_bf16_exp_comp_default_instances( std::vector>(instances); } -void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_comp_default_instances( +void add_device_grouped_convnd_bwd_weight_xdl_bf16_bf16_bf16_exp_comp_default_instances( std::vector>(instances); } -void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_comp_mnkpadding_instances( +void add_device_grouped_convnd_bwd_weight_xdl_bf16_bf16_bf16_exp_comp_mnkpadding_instances( std::vector>(instances); } -void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_mem_v1_default_instances( +void add_device_grouped_convnd_bwd_weight_xdl_bf16_bf16_bf16_exp_mem_v1_default_instances( std::vector>(instances); } -void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_mem_v2_default_instances( +void add_device_grouped_convnd_bwd_weight_xdl_bf16_bf16_bf16_exp_mem_v2_default_instances( std::vector>(instances); } -void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_odd_m_instances( +void add_device_grouped_convnd_bwd_weight_xdl_bf16_bf16_bf16_exp_odd_m_instances( std::vector>(instances); } -void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_odd_mn_instances( +void add_device_grouped_convnd_bwd_weight_xdl_bf16_bf16_bf16_exp_odd_mn_instances( std::vector>(instances); } -void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_comp_default_instances( +void add_device_grouped_convnd_bwd_weight_xdl_f16_f16_f16_exp_comp_default_instances( std::vector>(instances); } -void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_comp_mnkpadding_instances( +void add_device_grouped_convnd_bwd_weight_xdl_f16_f16_f16_exp_comp_mnkpadding_instances( std::vector>(instances); } -void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_mem_v1_default_instances( +void add_device_grouped_convnd_bwd_weight_xdl_f16_f16_f16_exp_mem_v1_default_instances( std::vector>(instances); } -void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_mem_v2_default_instances( +void add_device_grouped_convnd_bwd_weight_xdl_f16_f16_f16_exp_mem_v2_default_instances( std::vector>(instances); } -void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_odd_m_instances( +void add_device_grouped_convnd_bwd_weight_xdl_f16_f16_f16_exp_odd_m_instances( std::vector>(instances); } -void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_odd_mn_instances( +void add_device_grouped_convnd_bwd_weight_xdl_f16_f16_f16_exp_odd_mn_instances( std::vector(argument_ptr.get()); - if(split_k_arg && split_k_value < 0) + + // If split_k was determined by the device implementation, get the resulting value. + if(split_k_value < 0) { - split_k_value = split_k_arg->k_batch_; - split_k_param_str = std::to_string(split_k_value) + " (best occupancy)"; + auto* split_k_arg = + dynamic_cast(argument_ptr.get()); + if(split_k_arg) + { + split_k_value = split_k_arg->k_batch_; + split_k_param_str = std::to_string(split_k_value) + " (best occupancy)"; + } + else + { + // We may have an implementation whose argument is not derived from + // ArgumentSplitK, which means we can not determine the splitK value. Warn. + printf("Warning: Unable to determine split_k value for this instance!\n"); + } + } + + // Not all device implementation actually do anything with the passed split_k value but + // it needs to be positive to determine error tolerances. + if(split_k_value < 0) + { + split_k_value = 1; } const std::size_t workspace_sz = op_ptr->GetWorkSpaceSize(argument_ptr.get()); @@ -297,12 +315,13 @@ bool profile_grouped_conv_bwd_weight_impl(int do_verification, "Error: Incorrect results!", rtol, atol); - std::cout << "Relative error threshold: " << rtol - << " Absolute error threshold: " << atol << std::endl; if(!pass) { - std::cout << "Fail info: " << op_ptr->GetTypeString() << std::endl; + std::cout << "Relative error threshold: " << rtol + << " Absolute error threshold: " << atol << std::endl; + std::cout << "Fail info: splitK: " << split_k_value << " " + << op_ptr->GetTypeString() << std::endl; } all_pass &= pass; @@ -330,6 +349,8 @@ bool profile_grouped_conv_bwd_weight_impl(int do_verification, } } + printf("\033[36mvalids: %d\033[0m\n", num_kernel); + std::cout << "Best configuration parameters:" << "\nname: " << best_op_name << "\navg_time: " << best_avg_time << "\ntflops: " << best_tflops << "\nGB/s: " << best_gb_per_sec << ", SplitK " << best_split_k << std::endl; diff --git a/profiler/src/CMakeLists.txt b/profiler/src/CMakeLists.txt index 71f1637653..5833fc3626 100644 --- a/profiler/src/CMakeLists.txt +++ b/profiler/src/CMakeLists.txt @@ -209,9 +209,6 @@ if(SUPPORTED_GPU_TARGETS MATCHES "gfx9|gfx1[12]") list(APPEND DEVICE_INSTANCES device_conv1d_bwd_data_instance) list(APPEND DEVICE_INSTANCES device_conv3d_bwd_data_instance) list(APPEND DEVICE_INSTANCES device_conv2d_bwd_data_instance) - list(APPEND DEVICE_INSTANCES device_grouped_conv1d_bwd_weight_instance) - list(APPEND DEVICE_INSTANCES device_grouped_conv2d_bwd_weight_instance) - list(APPEND DEVICE_INSTANCES device_grouped_convnd_bwd_weight_instance) list(APPEND DEVICE_INSTANCES device_grouped_conv3d_fwd_convscale_instance) list(APPEND DEVICE_INSTANCES device_grouped_conv3d_fwd_convinvscale_instance) list(APPEND DEVICE_INSTANCES device_grouped_conv2d_fwd_clamp_instance) @@ -238,7 +235,6 @@ if(SUPPORTED_GPU_TARGETS MATCHES "gfx9|gfx1[12]") list(APPEND DEVICE_INSTANCES device_grouped_conv2d_bwd_data_instance) list(APPEND DEVICE_INSTANCES device_grouped_conv3d_bwd_data_instance) list(APPEND DEVICE_INSTANCES device_grouped_conv2d_fwd_instance) - list(APPEND DEVICE_INSTANCES device_grouped_conv3d_bwd_weight_instance) list(APPEND DEVICE_INSTANCES device_gemm_add_relu_instance) list(APPEND DEVICE_INSTANCES device_gemm_multi_abd_instance) if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) @@ -251,6 +247,10 @@ if(SUPPORTED_GPU_TARGETS MATCHES "gfx9|gfx1[12]") list(APPEND DEVICE_INSTANCES device_gemm_add_add_fastgelu_instance) endif() list(APPEND DEVICE_INSTANCES device_batched_gemm_gemm_instance) + list(APPEND DEVICE_INSTANCES device_grouped_conv1d_bwd_weight_instance) + list(APPEND DEVICE_INSTANCES device_grouped_conv2d_bwd_weight_instance) + list(APPEND DEVICE_INSTANCES device_grouped_convnd_bwd_weight_instance) + list(APPEND DEVICE_INSTANCES device_grouped_conv3d_bwd_weight_instance) endif() if(DL_KERNELS) diff --git a/test/grouped_convnd_bwd_weight/CMakeLists.txt b/test/grouped_convnd_bwd_weight/CMakeLists.txt index 7b994f5bb8..165c3b7863 100644 --- a/test/grouped_convnd_bwd_weight/CMakeLists.txt +++ b/test/grouped_convnd_bwd_weight/CMakeLists.txt @@ -5,16 +5,19 @@ if(GPU_TARGETS MATCHES "gfx9|gfx11|gfx12") add_gtest_executable(test_grouped_convnd_bwd_weight test_grouped_convnd_bwd_weight.cpp) target_link_libraries(test_grouped_convnd_bwd_weight PRIVATE utility device_grouped_conv1d_bwd_weight_instance device_grouped_conv2d_bwd_weight_instance device_grouped_conv3d_bwd_weight_instance device_grouped_convnd_bwd_weight_instance) + add_gtest_executable(test_grouped_convnd_bwd_weight_bilinear test_grouped_convnd_bwd_weight_bilinear.cpp) + target_link_libraries(test_grouped_convnd_bwd_weight_bilinear PRIVATE utility device_grouped_conv3d_bwd_weight_bilinear_instance) + add_gtest_executable(test_grouped_convnd_bwd_weight_scale test_grouped_convnd_bwd_weight_scale.cpp) + target_link_libraries(test_grouped_convnd_bwd_weight_scale PRIVATE utility device_grouped_conv3d_bwd_weight_scale_instance) + add_executable(test_grouped_convnd_bwd_weight_dataset_xdl test_grouped_convnd_bwd_weight_dataset_xdl.cpp) target_compile_options(test_grouped_convnd_bwd_weight_dataset_xdl PRIVATE -Wno-global-constructors -Wno-undef) target_link_libraries(test_grouped_convnd_bwd_weight_dataset_xdl PRIVATE gtest_main getopt::getopt utility device_grouped_conv1d_bwd_weight_instance device_grouped_conv2d_bwd_weight_instance device_grouped_conv3d_bwd_weight_instance device_grouped_convnd_bwd_weight_instance) elseif(DL_KERNELS) add_gtest_executable(test_grouped_convnd_bwd_weight test_grouped_convnd_bwd_weight.cpp) target_link_libraries(test_grouped_convnd_bwd_weight PRIVATE utility device_grouped_conv1d_bwd_weight_instance device_grouped_conv2d_bwd_weight_instance device_grouped_conv3d_bwd_weight_instance) -elseif(GPU_TARGETS MATCHES "gfx11") - add_gtest_executable(test_grouped_convnd_bwd_weight test_grouped_convnd_bwd_weight.cpp) - target_link_libraries(test_grouped_convnd_bwd_weight PRIVATE utility device_grouped_conv3d_bwd_weight_instance) endif() + add_gtest_executable(test_grouped_convnd_bwd_weight_interface_xdl test_grouped_convnd_bwd_weight_interface_xdl.cpp) if(result EQUAL 0) target_link_libraries(test_grouped_convnd_bwd_weight_interface_xdl PRIVATE utility) @@ -27,7 +30,3 @@ add_gtest_executable(test_grouped_convnd_bwd_weight_interface_wmma test_grouped_ if(result EQUAL 0) target_link_libraries(test_grouped_convnd_bwd_weight_interface_wmma PRIVATE utility) endif() -add_gtest_executable(test_grouped_conv_bwd_weight_xdl_bilinear test_grouped_conv_bwd_weight_xdl_bilinear.cpp) -if(result EQUAL 0) - target_link_libraries(test_grouped_conv_bwd_weight_xdl_bilinear PRIVATE utility device_grouped_conv3d_bwd_weight_bilinear_instance) -endif() diff --git a/test/grouped_convnd_bwd_weight/test_grouped_convnd_bwd_weight.cpp b/test/grouped_convnd_bwd_weight/test_grouped_convnd_bwd_weight.cpp index 4d4fcb300d..4b5e38dea6 100644 --- a/test/grouped_convnd_bwd_weight/test_grouped_convnd_bwd_weight.cpp +++ b/test/grouped_convnd_bwd_weight/test_grouped_convnd_bwd_weight.cpp @@ -46,44 +46,6 @@ class TestGroupedConvndBwdWeight : public ::testing::Test return true; } } - if(ck::is_gfx11_supported() || ck::is_gfx12_supported()) - { - // on gfx11 only support for 3d is implemented - if constexpr(NDimSpatial{} != 3) - { - return true; - } - // on gfx11 only support for i8 and fp16 is implemented - if constexpr(!((std::is_same_v && - std::is_same_v && - std::is_same_v) || - (std::is_same_v && - std::is_same_v && - std::is_same_v))) - { - return true; - } - // WMMA kernel is only supported for split_k=1 - if(split_k != 1) - { - return true; - } - // Skip due to the lack of kernels for NGCDHW - if constexpr(std::is_same_v || std::is_same_v || - std::is_same_v) - { - return true; - } - } - else - { - // support for i8 is only implemented on gfx11 - if constexpr(std::is_same_v && - std::is_same_v && std::is_same_v) - { - return true; - } - } return false; } diff --git a/test/grouped_convnd_bwd_weight/test_grouped_conv_bwd_weight_xdl_bilinear.cpp b/test/grouped_convnd_bwd_weight/test_grouped_convnd_bwd_weight_bilinear.cpp similarity index 89% rename from test/grouped_convnd_bwd_weight/test_grouped_conv_bwd_weight_xdl_bilinear.cpp rename to test/grouped_convnd_bwd_weight/test_grouped_convnd_bwd_weight_bilinear.cpp index fe71ba86c0..08f509a7e5 100644 --- a/test/grouped_convnd_bwd_weight/test_grouped_conv_bwd_weight_xdl_bilinear.cpp +++ b/test/grouped_convnd_bwd_weight/test_grouped_convnd_bwd_weight_bilinear.cpp @@ -212,7 +212,34 @@ class TestGroupedConvndBwdWeight : public ::testing::Test } float avg_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr}); wei_device_buf.FromDevice(wei_device.mData.data()); - passed &= ck::utils::check_err(wei_device, wei_host, "Error: incorrect results!"); + + using AccDataType = float; + float max_accumulated_value = + *std::max_element(wei_host.mData.begin(), wei_host.mData.end()); + + const ck::index_t num_accums = out.GetElementSize() / conv_param.K_; + const ck::index_t num_accums_split_k = split_k; + double rtol = + ck::utils::get_relative_threshold( + num_accums / num_accums_split_k); + double atol = + ck::utils::get_absolute_threshold( + max_accumulated_value / num_accums_split_k, + num_accums / num_accums_split_k); + + // Calculate error due to split_k accumulation + auto rtol_split_k = + ck::utils::get_relative_threshold( + num_accums_split_k); + auto atol_split_k = + ck::utils::get_absolute_threshold( + max_accumulated_value, num_accums_split_k); + // Use higher threshold + rtol = std::max(rtol, rtol_split_k); + atol = std::max(atol, atol_split_k); + + passed &= ck::utils::check_err( + wei_device, wei_host, "Error: incorrect results!", rtol, atol); std::size_t flop = conv_param.GetFlops() + @@ -236,6 +263,7 @@ class TestGroupedConvndBwdWeight : public ::testing::Test std::cout << "grouped_conv_bwd_weight_instance (" << instance_index << "/" << num_kernel << "): Passed" << std::endl; } + printf("\033[36mvalids: %d\033[0m\n", num_kernel); return passed; } diff --git a/test/grouped_convnd_bwd_weight/test_grouped_convnd_bwd_weight_scale.cpp b/test/grouped_convnd_bwd_weight/test_grouped_convnd_bwd_weight_scale.cpp new file mode 100644 index 0000000000..5600ab5c0a --- /dev/null +++ b/test/grouped_convnd_bwd_weight/test_grouped_convnd_bwd_weight_scale.cpp @@ -0,0 +1,294 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include +#include +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_grouped_conv_bwd_weight_multiple_d.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight_scale.hpp" + +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/convolution_parameter.hpp" +#include "ck/library/utility/convolution_host_tensor_descriptor_helper.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_conv_bwd_weight.hpp" + +template +class TestGroupedConvndBwdWeight : public ::testing::Test +{ + protected: + using InDataType = std::tuple_element_t<0, Tuple>; + using WeiDataType = std::tuple_element_t<1, Tuple>; + using OutDataType = std::tuple_element_t<2, Tuple>; + using InLayout = ck::tensor_layout::convolution::NDHWGC; + using WeiLayout = ck::tensor_layout::convolution::GKZYXC; + using OutLayout = ck::tensor_layout::convolution::NDHWGK; + using InElementOp = ck::tensor_operation::element_wise::PassThrough; + using WeiElementOp = ck::tensor_operation::element_wise::Scale; + using OutElementOp = ck::tensor_operation::element_wise::PassThrough; + + static constexpr ck::index_t NDimSpatial = std::tuple_element_t<3, Tuple>{}; + static constexpr float alpha = 2.f; + + std::vector conv_params; + std::vector split_ks{1, 2}; + + void RunReference(ck::utils::conv::ConvParam& conv_param, + ck::Tensor& in, + ck::Tensor& wei_host, + ck::Tensor& out) + { + auto ref_conv = + ck::tensor_operation::host::ReferenceConvBwdWeight /*Num D Elementwise Tensors*/ + {}; + + auto ref_invoker = ref_conv.MakeInvoker(); + auto ref_argument = ref_conv.MakeArgument(in, + wei_host, + out, + conv_param.conv_filter_strides_, + conv_param.conv_filter_dilations_, + conv_param.input_left_pads_, + conv_param.input_right_pads_, + InElementOp{}, + WeiElementOp{alpha}, + OutElementOp{}, + {}, + {}, + {}); + + ref_invoker.Run(ref_argument); + } + + bool PerformConvWeightScale(ck::utils::conv::ConvParam& conv_param, const ck::index_t split_k) + { + bool passed = true; + + const auto in_g_n_c_wis_desc = + ck::utils::conv::make_input_host_tensor_descriptor_g_n_c_wis_packed( + conv_param); + const auto wei_g_k_c_xs_desc = + ck::utils::conv::make_weight_host_tensor_descriptor_g_k_c_xs_packed( + conv_param); + const auto out_g_n_k_wos_desc = + ck::utils::conv::make_output_host_tensor_descriptor_g_n_k_wos_packed( + conv_param); + + ck::Tensor in(in_g_n_c_wis_desc); + ck::Tensor out(out_g_n_k_wos_desc); + ck::Tensor wei_host(wei_g_k_c_xs_desc); + ck::Tensor wei_device(wei_g_k_c_xs_desc); + + std::cout << "in: " << in.mDesc << std::endl; + std::cout << "wei: " << wei_host.mDesc << std::endl; + std::cout << "out: " << out.mDesc << std::endl; + + in.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + out.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + + ck::DeviceMem in_device_buf(sizeof(InDataType) * in.mDesc.GetElementSpaceSize()); + ck::DeviceMem out_device_buf(sizeof(OutDataType) * out.mDesc.GetElementSpaceSize()); + ck::DeviceMem wei_device_buf(sizeof(WeiDataType) * wei_device.mDesc.GetElementSpaceSize()); + in_device_buf.ToDevice(in.mData.data()); + wei_device_buf.ToDevice(wei_device.mData.data()); + out_device_buf.ToDevice(out.mData.data()); + + std::array b_g_n_c_wis_lengths{}; + std::array b_g_n_c_wis_strides{}; + std::array e_g_k_c_xs_lengths{}; + std::array e_g_k_c_xs_strides{}; + std::array a_g_n_k_wos_lengths{}; + std::array a_g_n_k_wos_strides{}; + std::array conv_filter_strides{}; + std::array conv_filter_dilations{}; + std::array input_left_pads{}; + std::array input_right_pads{}; + + auto copy = [](const auto& x, auto& y) { ck::ranges::copy(x, y.begin()); }; + + copy(in_g_n_c_wis_desc.GetLengths(), b_g_n_c_wis_lengths); + copy(in_g_n_c_wis_desc.GetStrides(), b_g_n_c_wis_strides); + copy(wei_g_k_c_xs_desc.GetLengths(), e_g_k_c_xs_lengths); + copy(wei_g_k_c_xs_desc.GetStrides(), e_g_k_c_xs_strides); + copy(out_g_n_k_wos_desc.GetLengths(), a_g_n_k_wos_lengths); + copy(out_g_n_k_wos_desc.GetStrides(), a_g_n_k_wos_strides); + copy(conv_param.conv_filter_strides_, conv_filter_strides); + copy(conv_param.conv_filter_dilations_, conv_filter_dilations); + copy(conv_param.input_left_pads_, input_left_pads); + copy(conv_param.input_right_pads_, input_right_pads); + + RunReference(conv_param, in, wei_host, out); + + using DeviceOp = + ck::tensor_operation::device::DeviceGroupedConvBwdWeightMultipleD, + InDataType, + WeiDataType, + OutDataType, + ck::Tuple<>, + InElementOp, + WeiElementOp, + OutElementOp>; + + // get device op instances + const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< + DeviceOp>::GetInstances(); + + int num_kernel = 0; + + for(std::size_t i = 0; i < op_ptrs.size(); ++i) + { + auto& op_ptr = op_ptrs[i]; + auto argument_ptr = op_ptr->MakeArgumentPointer( + static_cast(in_device_buf.GetDeviceBuffer()), + static_cast(wei_device_buf.GetDeviceBuffer()), + static_cast(out_device_buf.GetDeviceBuffer()), + std::array{}, + b_g_n_c_wis_lengths, + b_g_n_c_wis_strides, + e_g_k_c_xs_lengths, + e_g_k_c_xs_strides, + a_g_n_k_wos_lengths, + a_g_n_k_wos_strides, + std::array, 0>{}, + std::array, 0>{}, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + InElementOp{}, + WeiElementOp{alpha}, + OutElementOp{}, + split_k); + + ck::DeviceMem workspace_buf(op_ptr->GetWorkSpaceSize(argument_ptr.get())); + op_ptr->SetWorkSpacePointer(argument_ptr.get(), workspace_buf.GetDeviceBuffer()); + + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + std::string op_name = op_ptr->GetTypeString(); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + num_kernel++; + float avg_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr}); + wei_device_buf.FromDevice(wei_device.mData.data()); + + using AccDataType = float; + float max_accumulated_value = + *std::max_element(wei_host.mData.begin(), wei_host.mData.end()); + + const ck::index_t num_accums = out.GetElementSize() / conv_param.K_; + const ck::index_t num_accums_split_k = split_k; + double rtol = + ck::utils::get_relative_threshold( + num_accums / num_accums_split_k); + double atol = + ck::utils::get_absolute_threshold( + max_accumulated_value / num_accums_split_k, + num_accums / num_accums_split_k); + + // Calculate error due to split_k accumulation + auto rtol_split_k = + ck::utils::get_relative_threshold( + num_accums_split_k); + auto atol_split_k = + ck::utils::get_absolute_threshold( + max_accumulated_value, num_accums_split_k); + // Use higher threshold + rtol = std::max(rtol, rtol_split_k); + atol = std::max(atol, atol_split_k); + + passed &= ck::utils::check_err( + wei_device, wei_host, "Error: incorrect results!", rtol, atol); + + std::size_t flop = + conv_param.GetFlops() + + 3 * conv_param.GetOutputByte() / sizeof(WeiDataType); + std::size_t num_bytes = conv_param.GetByte() + + conv_param.GetOutputByte(); + + float tflops = static_cast(flop) / 1.E9 / avg_time; + float gb_per_sec = num_bytes / 1.E6 / avg_time; + + std::cout << "Perf: " << std::setw(10) << avg_time << " ms, " << tflops + << " TFlops, " << gb_per_sec << " GB/s, " << op_name << std::endl; + } + else + { + std::cerr << op_name << " does not support this problem" << std::endl; + } + } + + printf("\033[36mvalids: %d\033[0m\n", num_kernel); + return passed; + } + + void Run() + { + EXPECT_FALSE(conv_params.empty()); + bool pass = true; + + for(auto split_k : split_ks) + { + for(auto& param : conv_params) + { + pass = pass && PerformConvWeightScale(param, split_k); + } + } + EXPECT_TRUE(pass); + } +}; + +template +class TestGroupedConvndBwdWeight3d : public TestGroupedConvndBwdWeight +{ +}; + +using KernelTypes3d = + ::testing::Types>, + std::tuple>, + std::tuple>>; + +TYPED_TEST_SUITE(TestGroupedConvndBwdWeight3d, KernelTypes3d); + +TYPED_TEST(TestGroupedConvndBwdWeight3d, Test3D) +{ + this->conv_params.clear(); + this->conv_params.push_back( + {3, 2, 16, 128, 128, {1, 1, 1}, {7, 7, 7}, {2, 2, 2}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}}); + this->conv_params.push_back( + {3, 2, 2, 128, 128, {3, 3, 3}, {14, 14, 3}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}}); + this->conv_params.push_back( + {3, 2, 32, 128, 128, {1, 1, 1}, {3, 3, 3}, {1, 1, 1}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}}); + this->conv_params.push_back( + {3, 1, 1, 1, 32, {3, 3, 3}, {32, 32, 32}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}}); + this->conv_params.push_back( + {3, 1, 1, 64, 3, {3, 3, 3}, {32, 32, 32}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}}); + this->conv_params.push_back( + {3, 1, 1, 1, 1, {3, 3, 3}, {32, 32, 32}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}}); + this->conv_params.push_back( + {3, 1, 1, 4, 4, {3, 3, 3}, {14, 28, 28}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}}); + this->Run(); +}