diff --git a/CHANGELOG.md b/CHANGELOG.md index 10d92d8e3e..f88df081da 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -14,7 +14,7 @@ None ### Additions - Added an image to a column kernel (#867) - Added a column to an image kernel (#930) -- Support for 3D grouped convolution forward on RDNA 3 GPUs (#935) +- Support for 3D grouped convolution on RDNA 3 GPUs (#935, #950, #985) - Grouped convolution support for small K and C (#822 #879 #897) - Support for NHWGC (2D and 3D) grouped convolution backward weight (#769 #804) - Support for bf16/f32/f16 and NHWGC (2D and 3d) grouped convolution backward data (#757 #799) diff --git a/example/20_grouped_conv_bwd_weight/CMakeLists.txt b/example/20_grouped_conv_bwd_weight/CMakeLists.txt index c83b384a01..1ecbab5825 100644 --- a/example/20_grouped_conv_bwd_weight/CMakeLists.txt +++ b/example/20_grouped_conv_bwd_weight/CMakeLists.txt @@ -1,7 +1,8 @@ -list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942) +list(APPEND gpu_list_xdl gfx908 gfx90a gfx940 gfx941 gfx942) +list(APPEND gpu_list_wmma gfx1100 gfx1101 gfx1102) set(target 0) foreach(gpu IN LISTS GPU_TARGETS) - if(gpu IN_LIST gpu_list AND target EQUAL 0) + if(gpu IN_LIST gpu_list_xdl AND target EQUAL 0) add_custom_target(example_grouped_conv_bwd_weight) add_example_executable(example_grouped_conv_bwd_weight_xdl_fp16 grouped_conv_bwd_weight_xdl_fp16.cpp) if(result EQUAL 0) @@ -18,6 +19,14 @@ foreach(gpu IN LISTS GPU_TARGETS) endif() endif() set(target 1) + endif() + if(gpu IN_LIST gpu_list_wmma AND target EQUAL 0) + add_custom_target(example_grouped_conv_bwd_weight) + add_example_executable(example_grouped_conv_bwd_weight_wmma_fp16 grouped_conv_bwd_weight_wmma_fp16.cpp) + if(result EQUAL 0) + add_dependencies(example_grouped_conv_bwd_weight example_grouped_conv_bwd_weight_wmma_fp16) + endif() + set(target 1) endif() endforeach() diff --git a/example/20_grouped_conv_bwd_weight/common.hpp b/example/20_grouped_conv_bwd_weight/common.hpp index 0e7eb00e4b..1d2206ff72 100644 --- a/example/20_grouped_conv_bwd_weight/common.hpp +++ b/example/20_grouped_conv_bwd_weight/common.hpp @@ -46,25 +46,21 @@ struct CommonLayoutSetting using OutputLayout = OutputLay; }; -template -struct CommonLayoutSettingSelector; - namespace ctl = ck::tensor_layout::convolution; - -template <> -struct CommonLayoutSettingSelector<1> final : CommonLayoutSetting -{ -}; - -template <> -struct CommonLayoutSettingSelector<2> final - : CommonLayoutSetting -{ -}; - -template <> -struct CommonLayoutSettingSelector<3> final - : CommonLayoutSetting +template +struct CommonLayoutSettingSelector + : CommonLayoutSetting>, + ck::tuple_element_t>, + ck::tuple_element_t>> { }; @@ -84,10 +80,10 @@ struct ExecutionConfig final bool time_kernel = false; }; -#define DefaultConvParam \ - ck::utils::conv::ConvParam \ - { \ - 2, 4, 1, 128, 256, {3, 3}, {14, 14}, {1, 1}, {1, 1}, {1, 1}, { 1, 1 } \ +#define DefaultConvParam \ + ck::utils::conv::ConvParam \ + { \ + 3, 4, 1, 128, 256, {3, 3, 3}, {14, 14, 14}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, { 1, 1, 1 } \ } inline void print_help_msg() diff --git a/example/20_grouped_conv_bwd_weight/grouped_conv_bwd_weight_dl_fp16.cpp b/example/20_grouped_conv_bwd_weight/grouped_conv_bwd_weight_dl_fp16.cpp index dccde28aa0..44ba3d8e02 100644 --- a/example/20_grouped_conv_bwd_weight/grouped_conv_bwd_weight_dl_fp16.cpp +++ b/example/20_grouped_conv_bwd_weight/grouped_conv_bwd_weight_dl_fp16.cpp @@ -76,4 +76,23 @@ using HostConvBwdWeightInstance = ck::tensor_operation::host::ReferenceConvBwdWe #include "run_grouped_conv_bwd_weight_example.inc" -int main(int argc, char* argv[]) { return !run_grouped_conv_bwd_weight_example(argc, argv); } +int main(int argc, char* argv[]) +{ + ExecutionConfig config; + ck::utils::conv::ConvParam conv_param = DefaultConvParam; + + if(!parse_cmd_args(argc, argv, config, conv_param)) + { + return 1; + } + + switch(conv_param.num_dim_spatial_) + { + case 1: return !run_grouped_conv_bwd_weight<1>(config, conv_param); + case 2: return !run_grouped_conv_bwd_weight<2>(config, conv_param); + case 3: return !run_grouped_conv_bwd_weight<3>(config, conv_param); + default: break; + } + + return 1; +} diff --git a/example/20_grouped_conv_bwd_weight/grouped_conv_bwd_weight_wmma_fp16.cpp b/example/20_grouped_conv_bwd_weight/grouped_conv_bwd_weight_wmma_fp16.cpp new file mode 100644 index 0000000000..2c2950b17e --- /dev/null +++ b/example/20_grouped_conv_bwd_weight/grouped_conv_bwd_weight_wmma_fp16.cpp @@ -0,0 +1,88 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "common.hpp" + +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_wmma_cshuffle.hpp" + +using InDataType = F16; +using WeiDataType = F16; +using OutDataType = F16; +using AccDataType = F32; + +using InElementOp = PassThrough; +using WeiElementOp = PassThrough; +using OutElementOp = PassThrough; + +template +using DeviceConvBwdWeightInstance = + ck::tensor_operation::device::DeviceGroupedConvBwdWeight_Wmma_CShuffle< + NDimSpatial, + ck::tensor_layout::convolution::GNDHWC, + ck::tensor_layout::convolution::GKZYXC, + ck::tensor_layout::convolution::GNDHWK, + InDataType, // InDataType + WeiDataType, // WeiDataType + OutDataType, // OutDataType + AccDataType, // AccDataType + InElementOp, // InElementwiseOperation + WeiElementOp, // WeiElementwiseOperation + OutElementOp, // OutElementwiseOperation + ConvBwdWeightDefault, // ConvolutionBackwardWeightSpecialization + 256, // BlockSize + 128, // MPerBlock + 128, // NPerBlock + 4, // K0PerBlock + 8, // K1 + 16, // MPerWMMA + 16, // NPerWMMA + 4, // MRepeat + 2, // NRepeat + S<4, 64, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1 + S<0, 2, 1>, // ABlockTransferThreadClusterArrangeOrder + S<0, 2, 1>, // ABlockTransferSrcAccessOrder + 1, // ABlockTransferSrcVectorDim + 1, // ABlockTransferSrcScalarPerVector + 8, // ABlockTransferDstScalarPerVector_AK1 + true, // ABlockLdsExtraM + S<4, 64, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1 + S<0, 2, 1>, // BBlockTransferThreadClusterArrangeOrder + S<0, 2, 1>, // BBlockTransferSrcAccessOrder + 1, // BBlockTransferSrcVectorDim + 1, // BBlockTransferSrcScalarPerVector + 8, // BBlockTransferDstScalarPerVector_BK1 + true, // BBlockLdsExtraN + 4, + 2, + S<1, 32, 1, 8>, + 1>; + +template +using HostConvBwdWeightInstance = ck::tensor_operation::host::ReferenceConvBwdWeight; + +#include "run_grouped_conv_bwd_weight_example.inc" + +int main(int argc, char* argv[]) +{ + ExecutionConfig config; + ck::utils::conv::ConvParam conv_param = DefaultConvParam; + + if(!parse_cmd_args(argc, argv, config, conv_param)) + { + return 1; + } + + switch(conv_param.num_dim_spatial_) + { + case 3: return !run_grouped_conv_bwd_weight<3>(config, conv_param); + default: break; + } + + return 1; +} diff --git a/example/20_grouped_conv_bwd_weight/grouped_conv_bwd_weight_xdl_bf16.cpp b/example/20_grouped_conv_bwd_weight/grouped_conv_bwd_weight_xdl_bf16.cpp index 467209ef28..80b9724930 100644 --- a/example/20_grouped_conv_bwd_weight/grouped_conv_bwd_weight_xdl_bf16.cpp +++ b/example/20_grouped_conv_bwd_weight/grouped_conv_bwd_weight_xdl_bf16.cpp @@ -78,4 +78,23 @@ using HostConvBwdWeightInstance = ck::tensor_operation::host::ReferenceConvBwdWe #include "run_grouped_conv_bwd_weight_example.inc" -int main(int argc, char* argv[]) { return !run_grouped_conv_bwd_weight_example(argc, argv); } +int main(int argc, char* argv[]) +{ + ExecutionConfig config; + ck::utils::conv::ConvParam conv_param = DefaultConvParam; + + if(!parse_cmd_args(argc, argv, config, conv_param)) + { + return 1; + } + + switch(conv_param.num_dim_spatial_) + { + case 1: return !run_grouped_conv_bwd_weight<1>(config, conv_param); + case 2: return !run_grouped_conv_bwd_weight<2>(config, conv_param); + case 3: return !run_grouped_conv_bwd_weight<3>(config, conv_param); + default: break; + } + + return 1; +} diff --git a/example/20_grouped_conv_bwd_weight/grouped_conv_bwd_weight_xdl_fp16.cpp b/example/20_grouped_conv_bwd_weight/grouped_conv_bwd_weight_xdl_fp16.cpp index 8b751b0d98..5c1c97d615 100644 --- a/example/20_grouped_conv_bwd_weight/grouped_conv_bwd_weight_xdl_fp16.cpp +++ b/example/20_grouped_conv_bwd_weight/grouped_conv_bwd_weight_xdl_fp16.cpp @@ -77,4 +77,23 @@ using HostConvBwdWeightInstance = ck::tensor_operation::host::ReferenceConvBwdWe #include "run_grouped_conv_bwd_weight_example.inc" -int main(int argc, char* argv[]) { return !run_grouped_conv_bwd_weight_example(argc, argv); } +int main(int argc, char* argv[]) +{ + ExecutionConfig config; + ck::utils::conv::ConvParam conv_param = DefaultConvParam; + + if(!parse_cmd_args(argc, argv, config, conv_param)) + { + return 1; + } + + switch(conv_param.num_dim_spatial_) + { + case 1: return !run_grouped_conv_bwd_weight<1>(config, conv_param); + case 2: return !run_grouped_conv_bwd_weight<2>(config, conv_param); + case 3: return !run_grouped_conv_bwd_weight<3>(config, conv_param); + default: break; + } + + return 1; +} diff --git a/example/20_grouped_conv_bwd_weight/grouped_conv_bwd_weight_xdl_fp16_comp_bf8_fp8.cpp b/example/20_grouped_conv_bwd_weight/grouped_conv_bwd_weight_xdl_fp16_comp_bf8_fp8.cpp index c678867e74..0d0d58eb2a 100644 --- a/example/20_grouped_conv_bwd_weight/grouped_conv_bwd_weight_xdl_fp16_comp_bf8_fp8.cpp +++ b/example/20_grouped_conv_bwd_weight/grouped_conv_bwd_weight_xdl_fp16_comp_bf8_fp8.cpp @@ -83,4 +83,23 @@ using HostConvBwdWeightInstance = ck::tensor_operation::host::ReferenceConvBwdWe #include "run_grouped_conv_bwd_weight_example.inc" -int main(int argc, char* argv[]) { return !run_grouped_conv_bwd_weight_example(argc, argv); } +int main(int argc, char* argv[]) +{ + ExecutionConfig config; + ck::utils::conv::ConvParam conv_param = DefaultConvParam; + + if(!parse_cmd_args(argc, argv, config, conv_param)) + { + return 1; + } + + switch(conv_param.num_dim_spatial_) + { + case 1: return !run_grouped_conv_bwd_weight<1>(config, conv_param); + case 2: return !run_grouped_conv_bwd_weight<2>(config, conv_param); + case 3: return !run_grouped_conv_bwd_weight<3>(config, conv_param); + default: break; + } + + return 1; +} diff --git a/example/20_grouped_conv_bwd_weight/run_grouped_conv_bwd_weight_example.inc b/example/20_grouped_conv_bwd_weight/run_grouped_conv_bwd_weight_example.inc index 755d6addd8..73d15f3ea8 100644 --- a/example/20_grouped_conv_bwd_weight/run_grouped_conv_bwd_weight_example.inc +++ b/example/20_grouped_conv_bwd_weight/run_grouped_conv_bwd_weight_example.inc @@ -5,7 +5,7 @@ template bool run_grouped_conv_bwd_weight(const ExecutionConfig& config, const ck::utils::conv::ConvParam& conv_param) { - // Dl op doesn't support split_k > 1 + // Dl and WMMA ops don't support split_k > 1 constexpr ck::index_t split_k = 1; const auto in_g_n_c_wis_desc = @@ -143,23 +143,3 @@ bool run_grouped_conv_bwd_weight(const ExecutionConfig& config, return true; } - -bool run_grouped_conv_bwd_weight_example(int argc, char* argv[]) -{ - ExecutionConfig config; - ck::utils::conv::ConvParam conv_param = DefaultConvParam; - - if(!parse_cmd_args(argc, argv, config, conv_param)) - { - return false; - } - - switch(conv_param.num_dim_spatial_) - { - case 1: return run_grouped_conv_bwd_weight<1>(config, conv_param); - case 2: return run_grouped_conv_bwd_weight<2>(config, conv_param); - case 3: return run_grouped_conv_bwd_weight<3>(config, conv_param); - } - - return false; -} diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle.hpp index 35581de72b..e2f9f24bf5 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle.hpp @@ -565,7 +565,7 @@ struct DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle auto launch_kernel = [&](auto has_main_k_block_loop) { constexpr bool has_main_loop = has_main_k_block_loop.value; - const auto kernel = kernel_grouped_conv_fwd_multiple_d_wmma_cshuffle< + const auto kernel = kernel_grouped_conv_multiple_d_wmma_cshuffle< GridwiseGemm, ADataType, BDataType, diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_dl.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_dl.hpp index ef57e6e4d2..e2c9dca904 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_dl.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_dl.hpp @@ -12,6 +12,7 @@ #include "ck/tensor_description/tensor_descriptor_helper.hpp" #include "ck/tensor_operation/gpu/device/tensor_layout.hpp" #include "ck/tensor_operation/gpu/device/device_grouped_conv_bwd_weight.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp" #include "ck/tensor_operation/gpu/device/convolution_backward_weight_specialization.hpp" #include "ck/tensor_operation/gpu/grid/gridwise_gemm_dl_v1r3.hpp" #include "ck/tensor_operation/gpu/device/matrix_padder.hpp" @@ -22,32 +23,6 @@ namespace ck { namespace tensor_operation { namespace device { -namespace { - -struct ComputePtrOffsetOfStridedBatch -{ - __host__ __device__ constexpr long_index_t GetAPtrOffset(index_t g_idx) const - { - return g_idx * static_cast(BatchStrideA_); - } - - __host__ __device__ constexpr long_index_t GetBPtrOffset(index_t g_idx) const - { - return g_idx * static_cast(BatchStrideB_); - } - - __host__ __device__ constexpr long_index_t GetCPtrOffset(index_t g_idx) const - { - return g_idx * static_cast(BatchStrideC_); - } - - index_t BatchStrideA_; - index_t BatchStrideB_; - index_t BatchStrideC_; -}; - -} // namespace - template compute_ptr_offset_of_batch_; // element-wise op OutElementwiseOperation a_element_op_; @@ -1024,7 +999,7 @@ struct DeviceGroupedConvBwdWeight_Dl : public DeviceGroupedConvBwdWeight, remove_reference_t, remove_reference_t, - ComputePtrOffsetOfStridedBatch, + ComputePtrOffsetOfStridedBatch, has_main_loop, has_double_loop>; diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_wmma_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_wmma_cshuffle.hpp new file mode 100644 index 0000000000..e8a721eb36 --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_wmma_cshuffle.hpp @@ -0,0 +1,877 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include +#include + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_grouped_conv_bwd_weight.hpp" +#include "ck/tensor_operation/gpu/device/convolution_backward_weight_specialization.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_wmma_cshuffle.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +template ::type = false> +struct DeviceGroupedConvBwdWeight_Wmma_CShuffle + : public DeviceGroupedConvBwdWeight +{ + using DeviceOp = DeviceGroupedConvBwdWeight_Wmma_CShuffle; + + using ADataType = OutDataType; + using BDataType = InDataType; + using CDataType = WeiDataType; + + using AElementwiseOperation = OutElementwiseOperation; + using BElementwiseOperation = InElementwiseOperation; + using CElementwiseOperation = WeiElementwiseOperation; + + // TODO make A/B datatype different + using ABDataType = InDataType; + + // 3d + static constexpr bool is_NDHWGK_GKZYXC_NDHWGC = + is_same_v && + is_same_v && + is_same_v; + static constexpr bool is_GNDHWK_GKZYXC_GNDHWC = + is_same_v && + is_same_v && + is_same_v; + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static constexpr auto I4 = Number<4>{}; + static constexpr auto I5 = Number<5>{}; + + static constexpr auto GemmK1Number = Number{}; + static constexpr index_t KPerBlock = K0PerBlock * GemmK1Number; + + template ::type = false> + constexpr static auto + make_out_grid_desc(const index_t N, + const index_t Do, + const index_t Ho, + const index_t Wo, + const index_t K, + const std::array& output_strides) + { + const index_t WoStride = output_strides[5]; + const auto KStride = Number<1>{}; + return make_naive_tensor_descriptor(make_tuple(N * Do * Ho * Wo, K), + make_tuple(WoStride, KStride)); + } + + template ::type = false> + constexpr static auto + make_in_grid_desc(const index_t N, + const index_t Di, + const index_t Hi, + const index_t Wi, + const index_t C, + const std::array& input_strides) + { + const index_t NStride = input_strides[1]; + const index_t DiStride = input_strides[3]; + const index_t HiStride = input_strides[4]; + const index_t WiStride = input_strides[5]; + const auto CStride = input_strides[2]; + if constexpr(ConvBackwardWeightSpecialization == + ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0) + { + return make_naive_tensor_descriptor(make_tuple(N * Di * Hi * Wi, C), + make_tuple(WiStride, CStride)); + } + else + { + return make_naive_tensor_descriptor( + make_tuple(N, Di, Hi, Wi, C), + make_tuple(NStride, DiStride, HiStride, WiStride, CStride)); + } + } + + template ::type = false> + constexpr static auto + make_wei_grid_desc(const index_t K, + const index_t Z, + const index_t Y, + const index_t X, + const index_t C, + const std::array& weights_strides) + { + const auto CStride = Number<1>{}; + const auto KStride = weights_strides[1]; + return make_naive_tensor_descriptor(make_tuple(K, Z * Y * X * C), + make_tuple(KStride, CStride)); + } + + template ::type = false> + static auto MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N( + const index_t N, + const index_t K, + const index_t C, + const std::array& input_spatial_lengths, + const std::array& filter_spatial_lengths, + const std::array& output_spatial_lengths, + const std::array& input_strides, + const std::array& weights_strides, + const std::array& output_strides, + const std::array& conv_filter_strides, + const std::array& conv_filter_dilations, + const std::array& input_left_pads, + const std::array& input_right_pads) + { + using namespace ck; + + const index_t Di = input_spatial_lengths[0]; + const index_t Hi = input_spatial_lengths[1]; + const index_t Wi = input_spatial_lengths[2]; + + const index_t Do = output_spatial_lengths[0]; + const index_t Ho = output_spatial_lengths[1]; + const index_t Wo = output_spatial_lengths[2]; + + const index_t Z = filter_spatial_lengths[0]; + const index_t Y = filter_spatial_lengths[1]; + const index_t X = filter_spatial_lengths[2]; + + const index_t ConvStrideD = conv_filter_strides[0]; + const index_t ConvStrideH = conv_filter_strides[1]; + const index_t ConvStrideW = conv_filter_strides[2]; + + const index_t ConvDilationD = conv_filter_dilations[0]; + const index_t ConvDilationH = conv_filter_dilations[1]; + const index_t ConvDilationW = conv_filter_dilations[2]; + + const index_t InLeftPadD = input_left_pads[0]; + const index_t InLeftPadH = input_left_pads[1]; + const index_t InLeftPadW = input_left_pads[2]; + + const index_t InRightPadD = input_right_pads[0]; + const index_t InRightPadH = input_right_pads[1]; + const index_t InRightPadW = input_right_pads[2]; + + const index_t GemmKTotal = N * Do * Ho * Wo; + const index_t GemmM = K; + const index_t GemmN = C * Z * X * Y; + + const auto PadGemmM = (MPerBlock - GemmM % MPerBlock) % MPerBlock; + const auto PadGemmN = (NPerBlock - GemmN % NPerBlock) % NPerBlock; + + const index_t GemmK0 = + math::integer_divide_ceil(GemmKTotal, GemmK1Number * K0PerBlock) * K0PerBlock; + const index_t GemmKPad = GemmK0 * GemmK1Number; + + const auto out_grid_desc = make_out_grid_desc(N, Do, Ho, Wo, K, output_strides); + const auto in_grid_desc = make_in_grid_desc(N, Di, Hi, Wi, C, input_strides); + const auto wei_grid_desc = make_wei_grid_desc(K, Z, Y, X, C, weights_strides); + + if constexpr(ConvBackwardWeightSpecialization == + ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0) + { + // A: output tensor + const auto out_gemmkpad_gemmm_grid_desc = transform_tensor_descriptor( + out_grid_desc, + make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal), + make_pass_through_transform(GemmM)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( + out_gemmkpad_gemmm_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1Number)), + make_pass_through_transform(GemmM)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + // B: input tensor + const auto in_gemmkpad_gemmn_grid_desc = transform_tensor_descriptor( + in_grid_desc, + make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal), + make_pass_through_transform(GemmN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( + in_gemmkpad_gemmn_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1Number)), + make_pass_through_transform(GemmN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc, + in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc, + wei_grid_desc); + } + else + { + // A: output tensor + const auto out_gemmkpad_gemmm_grid_desc = transform_tensor_descriptor( + out_grid_desc, + make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal), + make_pass_through_transform(GemmM)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( + out_gemmkpad_gemmm_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1Number)), + make_pass_through_transform(GemmM)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + // B: input tensor + const auto in_n_dip_hip_wip_c_grid_desc = transform_tensor_descriptor( + in_grid_desc, + make_tuple(make_pass_through_transform(N), + make_pad_transform(Di, InLeftPadD, InRightPadD), + make_pad_transform(Hi, InLeftPadH, InRightPadH), + make_pad_transform(Wi, InLeftPadW, InRightPadW), + make_pass_through_transform(C)), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{})); + + const auto in_n_z_do_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor( + in_n_dip_hip_wip_c_grid_desc, + make_tuple( + make_pass_through_transform(N), + make_embed_transform(make_tuple(Z, Do), make_tuple(ConvDilationD, ConvStrideD)), + make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)), + make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)), + make_pass_through_transform(C)), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}), + make_tuple(Sequence<0>{}, + Sequence<1, 2>{}, + Sequence<3, 4>{}, + Sequence<5, 6>{}, + Sequence<7>{})); + + const auto in_gemmktotal_gemmn_grid_desc = transform_tensor_descriptor( + in_n_z_do_y_ho_x_wo_c_grid_desc, + make_tuple(make_merge_transform(make_tuple(Z, Y, X, C)), + make_merge_transform(make_tuple(N, Do, Ho, Wo))), + make_tuple(Sequence<1, 3, 5, 7>{}, Sequence<0, 2, 4, 6>{}), + make_tuple(Sequence<1>{}, Sequence<0>{})); + + const auto in_gemmkpad_gemmn_grid_desc = transform_tensor_descriptor( + in_gemmktotal_gemmn_grid_desc, + make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal), + make_pass_through_transform(GemmN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( + in_gemmkpad_gemmn_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1Number)), + make_pass_through_transform(GemmN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + // Pad + const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_pad_grid_desc = + transform_tensor_descriptor( + out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc, + make_tuple(make_pass_through_transform(GemmK0), + make_right_pad_transform(GemmM, PadGemmM), + make_pass_through_transform(GemmK1Number)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_pad_grid_desc = + transform_tensor_descriptor( + in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc, + make_tuple(make_pass_through_transform(GemmK0), + make_right_pad_transform(GemmN, PadGemmN), + make_pass_through_transform(GemmK1Number)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + const auto wei_gemmm_gemmn_pad_grid_desc = + transform_tensor_descriptor(wei_grid_desc, + make_tuple(make_right_pad_transform(GemmM, PadGemmM), + make_right_pad_transform(GemmN, PadGemmN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_pad_grid_desc, + in_gemmkbatch_gemmk0_gemmn_gemmk1_pad_grid_desc, + wei_gemmm_gemmn_pad_grid_desc); + } + } + + template ::type = false> + static auto GetABCGridDesc() + { + const index_t dim = 1; + const std::array lengths{1, 1, 1}; + const std::array strides{1, 1, 1, 1, 1, 1}; + const std::array params{1, 1, 1}; + return MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<3>(dim, + dim, + dim, + lengths, + lengths, + lengths, + strides, + strides, + strides, + params, + params, + params, + params); + } + + using ABCGridDescs = decltype(GetABCGridDesc()); + + using AGridDesc_K0_M_K1 = remove_cvref_t; + using BGridDesc_K0_N_K1 = remove_cvref_t; + using CGridDesc_M_N = remove_cvref_t; + + using GridwiseGemm = GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle< + // DataType Family + ADataType, + BDataType, + AccDataType, + CDataType, + Tuple<>, + CDataType, + // InMemory Data Descriptor + AGridDesc_K0_M_K1, + BGridDesc_K0_N_K1, + Tuple<>, + CGridDesc_M_N, + // ElementwiseOp Family + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation, + InMemoryDataOperationEnum::Set, + // Tiling Family + MPerBlock, + NPerBlock, + K0PerBlock, + MPerWMMA, + NPerWMMA, + K1, + MRepeat, + NRepeat, + // ThreadCluster Family + BlockSize, + ABlockTransferThreadClusterLengths_K0_M_K1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_K1, + false, + ABlockLdsAddExtraM, + BBlockTransferThreadClusterLengths_K0_N_K1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_K1, + false, + BBlockLdsAddExtraN, + CShuffleMRepeatPerShuffle, + CShuffleNRepeatPerShuffle, + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + CShuffleBlockTransferScalarPerVector_NPerBlock, + NumGemmKPrefetchStage, + LoopSched, + PipelineVer>; + + using DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = + decltype(GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(Tuple<>{})); + + using CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = + decltype(GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + CGridDesc_M_N{})); + + using Block2CTileMap = decltype(GridwiseGemm::MakeDefaultBlock2CTileMap( + CGridDesc_M_N{}, I1 /* M01 */, I1 /* N01 */)); + + struct Argument : public BaseArgument + { + Argument(const InDataType* p_in_grid, + WeiDataType* p_wei_grid, + const OutDataType* p_out_grid, + const std::array& a_g_n_c_wis_lengths, // input + const std::array& a_g_n_c_wis_strides, + const std::array& b_g_k_c_xs_lengths, // weight + const std::array& b_g_k_c_xs_strides, + const std::array& e_g_n_k_wos_lengths, // output + const std::array& e_g_n_k_wos_strides, + const std::array& conv_filter_strides, + const std::array& conv_filter_dilations, + const std::array& input_left_pads, + const std::array& input_right_pads, + InElementwiseOperation in_element_op, + WeiElementwiseOperation wei_element_op, + OutElementwiseOperation out_element_op, + index_t split_k) + : p_a_grid_{p_out_grid}, + p_b_grid_{p_in_grid}, + p_c_grid_{p_wei_grid}, + a_grid_desc_kbatch_k0_m_k1_{}, + b_grid_desc_kbatch_k0_n_k1_{}, + c_grid_desc_m_n_{}, + c_grid_desc_mblock_mperblock_nblock_nperblock_{}, + block_2_ctile_map_{}, + compute_ptr_offset_of_batch_{}, + a_element_op_{out_element_op}, + b_element_op_{in_element_op}, + c_element_op_{wei_element_op}, + Conv_G_{a_g_n_c_wis_lengths[0]}, + Conv_N_{a_g_n_c_wis_lengths[1]}, + Conv_K_{b_g_k_c_xs_lengths[1]}, + Conv_C_{a_g_n_c_wis_lengths[2]}, + input_spatial_lengths_{}, + filter_spatial_lengths_{}, + output_spatial_lengths_{}, + conv_filter_strides_{conv_filter_strides}, + input_left_pads_{input_left_pads}, + input_right_pads_{input_right_pads}, + k_batch_{split_k} + { + constexpr index_t spatial_offset = 3; + std::copy(begin(a_g_n_c_wis_lengths) + spatial_offset, + end(a_g_n_c_wis_lengths), + begin(input_spatial_lengths_)); + std::copy(begin(b_g_k_c_xs_lengths) + spatial_offset, + end(b_g_k_c_xs_lengths), + begin(filter_spatial_lengths_)); + std::copy(begin(e_g_n_k_wos_lengths) + spatial_offset, + end(e_g_n_k_wos_lengths), + begin(output_spatial_lengths_)); + + const auto descs = + DeviceOp::MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N( + Conv_N_, + Conv_K_, + Conv_C_, + input_spatial_lengths_, + filter_spatial_lengths_, + output_spatial_lengths_, + a_g_n_c_wis_strides, + b_g_k_c_xs_strides, + e_g_n_k_wos_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads); + + a_grid_desc_kbatch_k0_m_k1_ = descs[I0]; + b_grid_desc_kbatch_k0_n_k1_ = descs[I1]; + c_grid_desc_m_n_ = descs[I2]; + + block_2_ctile_map_ = GridwiseGemm::MakeDefaultBlock2CTileMap( + c_grid_desc_m_n_, I1 /* M01 */, I1 /* N01 */); + + // A/B/C Batch Stride + compute_ptr_offset_of_batch_.BatchStrideA_ = e_g_n_k_wos_strides[0]; + compute_ptr_offset_of_batch_.BatchStrideB_ = a_g_n_c_wis_strides[0]; + compute_ptr_offset_of_batch_.BatchStrideE_ = + Conv_K_ * Conv_C_ * + std::accumulate(begin(filter_spatial_lengths_), + end(filter_spatial_lengths_), + index_t{1}, + std::multiplies<>{}); + + if(GridwiseGemm::CheckValidity(a_grid_desc_kbatch_k0_m_k1_, + b_grid_desc_kbatch_k0_n_k1_, + c_grid_desc_m_n_, + block_2_ctile_map_)) + { + c_grid_desc_mblock_mperblock_nblock_nperblock_ = + GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + c_grid_desc_m_n_); + } + } + + const ADataType* p_a_grid_; + const BDataType* p_b_grid_; + CDataType* p_c_grid_; + AGridDesc_K0_M_K1 a_grid_desc_kbatch_k0_m_k1_; + BGridDesc_K0_N_K1 b_grid_desc_kbatch_k0_n_k1_; + CGridDesc_M_N c_grid_desc_m_n_; + CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock c_grid_desc_mblock_mperblock_nblock_nperblock_; + + Block2CTileMap block_2_ctile_map_; + + // for computing batch offset + ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_batch_; + + OutElementwiseOperation a_element_op_; + InElementwiseOperation b_element_op_; + WeiElementwiseOperation c_element_op_; + + // for checking IsSupportedArgument() + const index_t Conv_G_; + const index_t Conv_N_; + const index_t Conv_K_; + const index_t Conv_C_; + std::array input_spatial_lengths_; + std::array filter_spatial_lengths_; + std::array output_spatial_lengths_; + const std::array& conv_filter_strides_; + const std::array& input_left_pads_; + const std::array& input_right_pads_; + const index_t k_batch_; + }; + + // Invoker + struct Invoker : public BaseInvoker + { + using Argument = DeviceOp::Argument; + + void Print(const Argument& arg) + { + std::cout << "arg.a_grid_desc_kbatch_k0_m_k1_{" + << arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I0) << ", " + << arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I1) << ", " + << arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I2) << "}" << std::endl; + + std::cout << "arg.b_grid_desc_kbatch_k0_n_k1_{" + << arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I0) << ", " + << arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I1) << ", " + << arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I2) << "}" << std::endl; + + std::cout << "arg.c_grid_desc_m_n_{" << arg.c_grid_desc_m_n_.GetLength(I0) << ", " + << arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl; + } + + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + if(stream_config.log_level_ > 0) + { + Print(arg); + } + + if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_kbatch_k0_m_k1_, + arg.b_grid_desc_kbatch_k0_n_k1_, + arg.c_grid_desc_m_n_, + arg.block_2_ctile_map_)) + { + throw std::runtime_error( + "wrong! GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle has invalid " + "setting"); + } + + const index_t grid_size = + arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_) * arg.Conv_G_; + + const auto K0 = arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I1); + + const bool has_main_k0_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K0); + + auto launch_kernel = [&](auto has_main_k_block_loop) { + constexpr bool has_main_loop = has_main_k_block_loop.value; + + const auto kernel = kernel_grouped_conv_multiple_d_wmma_cshuffle< + GridwiseGemm, + ADataType, + BDataType, + typename GridwiseGemm::DsGridPointer, + CDataType, + OutElementwiseOperation, + InElementwiseOperation, + WeiElementwiseOperation, + AGridDesc_K0_M_K1, + BGridDesc_K0_N_K1, + DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, + CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, + remove_reference_t, + ComputePtrOffsetOfStridedBatch, + has_main_loop>; + + using EmptyTuple = Tuple<>; + return launch_and_time_kernel(stream_config, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_a_grid_, + arg.p_b_grid_, + EmptyTuple{}, // Ds + arg.p_c_grid_, + arg.a_element_op_, + arg.b_element_op_, + arg.c_element_op_, + arg.Conv_G_, + arg.a_grid_desc_kbatch_k0_m_k1_, + arg.b_grid_desc_kbatch_k0_n_k1_, + DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock{}, + arg.c_grid_desc_mblock_mperblock_nblock_nperblock_, + arg.block_2_ctile_map_, + arg.compute_ptr_offset_of_batch_); + }; + + if(has_main_k0_block_loop) + { + return launch_kernel(integral_constant{}); + } + else + { + return launch_kernel(integral_constant{}); + } + } + + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } + + static bool IsSupportedArgument(const Argument& arg) + { + // check device + if(get_device_name() == "gfx1100" || get_device_name() == "gfx1101" || + get_device_name() == "gfx1102") + { + if constexpr(!(is_same_v || is_same_v)) + { + return false; + } + } + else + { + return false; + } + + // TODO: Add support for split_k > 1 + if(arg.k_batch_ != 1) + { + return false; + } + + if constexpr(!(is_NDHWGK_GKZYXC_NDHWGC || is_GNDHWK_GKZYXC_GNDHWC)) + { + return false; + } + + if constexpr(ConvBackwardWeightSpecialization == + ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0) + { + // check if it's a 1x1 convolution with stride=1 and no padding + for(int i = 0; i < NDimSpatial; i++) + { + if(!(arg.filter_spatial_lengths_[i] == 1 && arg.conv_filter_strides_[i] == 1 && + arg.input_left_pads_[i] == 0 && arg.input_right_pads_[i] == 0)) + { + return false; + } + } + } + + // vector load A/B matrix from global memory + if(!(ABlockTransferSrcVectorDim == 1 && BBlockTransferSrcVectorDim == 1 && + arg.Conv_K_ % ABlockTransferSrcScalarPerVector == 0 && + arg.Conv_C_ % BBlockTransferSrcScalarPerVector == 0)) + { + return false; + } + + // vector store C matrix into global memory + if(!(arg.Conv_C_ % CShuffleBlockTransferScalarPerVector_NPerBlock == 0)) + { + return false; + } + + // Gridwise GEMM size + return GridwiseGemm::CheckValidity(arg.a_grid_desc_kbatch_k0_m_k1_, + arg.b_grid_desc_kbatch_k0_n_k1_, + arg.c_grid_desc_m_n_, + arg.block_2_ctile_map_); + } + + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static auto + MakeArgument(const InDataType* p_in_grid, + WeiDataType* p_wei_grid, + const OutDataType* p_out_grid, + const std::array& a_g_n_c_wis_lengths, // input + const std::array& a_g_n_c_wis_strides, + const std::array& b_g_k_c_xs_lengths, // weight + const std::array& b_g_k_c_xs_strides, + const std::array& e_g_n_k_wos_lengths, // output + const std::array& e_g_n_k_wos_strides, + const std::array& conv_filter_strides, + const std::array& conv_filter_dilations, + const std::array& input_left_pads, + const std::array& input_right_pads, + InElementwiseOperation in_element_op, + WeiElementwiseOperation wei_element_op, + OutElementwiseOperation out_element_op, + const index_t split_k) + { + return Argument{p_in_grid, + p_wei_grid, + p_out_grid, + a_g_n_c_wis_lengths, // input + a_g_n_c_wis_strides, + b_g_k_c_xs_lengths, // weight + b_g_k_c_xs_strides, + e_g_n_k_wos_lengths, // output + e_g_n_k_wos_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + in_element_op, + wei_element_op, + out_element_op, + split_k}; + } + + static auto MakeInvoker() { return Invoker{}; } + + std::unique_ptr + MakeArgumentPointer(const void* p_in_grid, + void* p_wei_grid, + const void* p_out_grid, + const std::array& a_g_n_c_wis_lengths, // input + const std::array& a_g_n_c_wis_strides, + const std::array& b_g_k_c_xs_lengths, // weight + const std::array& b_g_k_c_xs_strides, + const std::array& e_g_n_k_wos_lengths, // output + const std::array& e_g_n_k_wos_strides, + const std::array& conv_filter_strides, + const std::array& conv_filter_dilations, + const std::array& input_left_pads, + const std::array& input_right_pads, + InElementwiseOperation in_element_op, + WeiElementwiseOperation wei_element_op, + OutElementwiseOperation out_element_op, + const index_t split_k) override + { + return std::make_unique(static_cast(p_in_grid), + static_cast(p_wei_grid), + static_cast(p_out_grid), + a_g_n_c_wis_lengths, // input + a_g_n_c_wis_strides, + b_g_k_c_xs_lengths, // weight + b_g_k_c_xs_strides, + e_g_n_k_wos_lengths, // output + e_g_n_k_wos_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + in_element_op, + wei_element_op, + out_element_op, + split_k); + } + + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "DeviceGroupedConvBwdWeight_Wmma_CShuffle" + << "<" + << BlockSize << ", " + << MPerBlock << ", " + << NPerBlock << ", " + << K0PerBlock << ", " + << getConvBackwardWeightSpecializationString(ConvBackwardWeightSpecialization) << ", " + << K1 << ", " + << ABlockTransferSrcScalarPerVector << ", " + << ABlockTransferDstScalarPerVector_K1 << ", " + << BBlockTransferSrcScalarPerVector << ", " + << BBlockTransferDstScalarPerVector_K1 << ", " + << CShuffleMRepeatPerShuffle << ", " + << CShuffleNRepeatPerShuffle << ", " + << CShuffleBlockTransferScalarPerVector_NPerBlock + << ">"; + // clang-format on + + return str.str(); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle.hpp index d4bba3bf64..cbb8643627 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle.hpp @@ -14,6 +14,7 @@ #include "ck/tensor_operation/gpu/device/device_grouped_conv_bwd_weight.hpp" #include "ck/tensor_operation/gpu/device/convolution_backward_weight_specialization.hpp" #include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_bwd_weight.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp" #include "ck/host_utility/device_prop.hpp" #include "ck/host_utility/kernel_launch.hpp" @@ -21,32 +22,6 @@ namespace ck { namespace tensor_operation { namespace device { -namespace { - -struct ComputePtrOffsetOfStridedBatch -{ - __host__ __device__ constexpr long_index_t GetAPtrOffset(index_t g_idx) const - { - return g_idx * static_cast(BatchStrideA_); - } - - __host__ __device__ constexpr long_index_t GetBPtrOffset(index_t g_idx) const - { - return g_idx * static_cast(BatchStrideB_); - } - - __host__ __device__ constexpr long_index_t GetCPtrOffset(index_t g_idx) const - { - return g_idx * static_cast(BatchStrideC_); - } - - index_t BatchStrideA_; - index_t BatchStrideB_; - index_t BatchStrideC_; -}; - -} // namespace - template compute_ptr_offset_of_batch_; index_t M01_; index_t N01_; @@ -1301,7 +1276,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle remove_reference_t, remove_reference_t, remove_reference_t, - ComputePtrOffsetOfStridedBatch, + ComputePtrOffsetOfStridedBatch, has_main_loop>; return launch_and_time_kernel(stream_config, @@ -1348,6 +1323,10 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle static bool IsSupportedArgument(const Argument& arg) { + if(!ck::is_xdl_supported()) + { + return false; + } if constexpr(NDimSpatial == 1) { if constexpr(!is_GNWK_GKXC_GNWC) diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp index e4e793eff8..3e303f5c5e 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp @@ -471,7 +471,7 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle auto launch_kernel = [&](auto has_main_k_block_loop) { constexpr bool has_main_loop = has_main_k_block_loop.value; - const auto kernel = kernel_grouped_conv_fwd_multiple_d_wmma_cshuffle< + const auto kernel = kernel_grouped_conv_multiple_d_wmma_cshuffle< GridwiseOp, ADataType, BDataType, diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp index 948232ed6f..e19a00299e 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp @@ -43,7 +43,13 @@ struct ComputePtrOffsetOfStridedBatch return ds_offset; } - __host__ __device__ constexpr long_index_t GetEPtrOffset(index_t g_idx) const + [[maybe_unused]] __host__ __device__ constexpr long_index_t GetEPtrOffset(index_t g_idx) const + { + return g_idx * static_cast(BatchStrideE_); + } + + // alias for kernels without multiple D + [[maybe_unused]] __host__ __device__ constexpr long_index_t GetCPtrOffset(index_t g_idx) const { return g_idx * static_cast(BatchStrideE_); } @@ -52,6 +58,7 @@ struct ComputePtrOffsetOfStridedBatch index_t BatchStrideB_; Array BatchStrideDs_; index_t BatchStrideE_; + index_t& BatchStrideC_ = BatchStrideE_; // alias for kernels without multiple D }; } // namespace device diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_wmma_cshuffle.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_wmma_cshuffle.hpp index 98ade85a30..53b2169bc6 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_wmma_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_wmma_cshuffle.hpp @@ -36,7 +36,7 @@ __global__ void #if CK_USE_LAUNCH_BOUNDS __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) #endif - kernel_grouped_conv_fwd_multiple_d_wmma_cshuffle( + kernel_grouped_conv_multiple_d_wmma_cshuffle( const ADataType* __restrict__ p_a_grid, const BDataType* __restrict__ p_b_grid, DsPointer p_ds_grid, @@ -452,11 +452,11 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle } // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01} + // CheckValidity for kernels without multi D template __host__ __device__ static constexpr bool CheckValidity(const AGridDesc_K0_M_K1& a_grid_desc_k0_m_k1, const BGridDesc_K0_N_K1& b_grid_desc_k0_n_k1, - const DsGridDesc_M_N& ds_grid_desc_m_n, const EGridDesc_M_N& e_grid_desc_m_n, const Block2CTileMap& block_2_ctile_map) { @@ -471,18 +471,6 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle const auto N = b_grid_desc_k0_n_k1.GetLength(I1); const auto K0 = a_grid_desc_k0_m_k1.GetLength(I0); - bool valid = true; - - static_for<0, NumDTensor, 1>{}([&](auto i) { - valid = valid && (M == ds_grid_desc_m_n[i].GetLength(I0) && - N == ds_grid_desc_m_n[i].GetLength(I1)); - }); - - if(!valid) - { - return false; - } - if(!(M == e_grid_desc_m_n.GetLength(I0) && N == e_grid_desc_m_n.GetLength(I1) && K0 == b_grid_desc_k0_n_k1.GetLength(I0) && K1 == a_grid_desc_k0_m_k1.GetLength(I2) && K1 == b_grid_desc_k0_n_k1.GetLength(I2))) @@ -517,6 +505,31 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle return true; } + template + __host__ __device__ static constexpr bool + CheckValidity(const AGridDesc_K0_M_K1& a_grid_desc_k0_m_k1, + const BGridDesc_K0_N_K1& b_grid_desc_k0_n_k1, + const DsGridDesc_M_N& ds_grid_desc_m_n, + const EGridDesc_M_N& e_grid_desc_m_n, + const Block2CTileMap& block_2_ctile_map) + { + const auto M = a_grid_desc_k0_m_k1.GetLength(I1); + const auto N = b_grid_desc_k0_n_k1.GetLength(I1); + bool valid = true; + static_for<0, NumDTensor, 1>{}([&](auto i) { + valid = valid && (M == ds_grid_desc_m_n[i].GetLength(I0) && + N == ds_grid_desc_m_n[i].GetLength(I1)); + }); + + if(!valid) + { + return false; + } + + return CheckValidity( + a_grid_desc_k0_m_k1, b_grid_desc_k0_n_k1, e_grid_desc_m_n, block_2_ctile_map); + } + __host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t K) { const index_t num_loop = K / (K0PerBlock * K1); diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_dl_instance.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_dl_instance.hpp index 29fe298e42..c3e333e720 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_dl_instance.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_dl_instance.hpp @@ -6,8 +6,6 @@ #include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_dl.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" -#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" - namespace ck { namespace tensor_operation { namespace device { diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_wmma_instance.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_wmma_instance.hpp new file mode 100644 index 0000000000..40c4d558b8 --- /dev/null +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_wmma_instance.hpp @@ -0,0 +1,119 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_wmma_cshuffle.hpp" +#include "ck/tensor_operation/gpu/device/convolution_backward_weight_specialization.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F16 = ck::half_t; +using F32 = float; +using I8 = int8_t; +using I32 = int32_t; + +template +using S = ck::Sequence; + +using namespace ck::tensor_layout::convolution; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto ConvBwdWeightDefault = + ck::tensor_operation::device::ConvolutionBackwardWeightSpecialization::Default; + +static constexpr auto ConvBwdWeightFilter1x1Stride1Pad0 = + ck::tensor_operation::device::ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0; + +template +using device_grouped_conv_bwd_weight_wmma_f16_instances = + std::tuple< + // clang-format off + //#####################################| NumDim| A| B| C| AData| BData| CData| AccData| A| B| C| ConvForward| Block| MPer| NPer| KPer| K1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //#####################################| Spatial| Layout| Layout| Layout| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | WMMA| WMMA| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeatPerWave| NRepeatPerWave| _MBlock_MPerBlock| ScalarPerVector| + //#####################################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| + //#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + // generic instance + DeviceGroupedConvBwdWeight_Wmma_CShuffle, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 4>, 1>, + // blocksize=256 + DeviceGroupedConvBwdWeight_Wmma_CShuffle, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 4>, + DeviceGroupedConvBwdWeight_Wmma_CShuffle, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 32, 1, 8>, 2>, + DeviceGroupedConvBwdWeight_Wmma_CShuffle, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 4>, + DeviceGroupedConvBwdWeight_Wmma_CShuffle, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 16>, 4>, + // blocksize=128 + DeviceGroupedConvBwdWeight_Wmma_CShuffle, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvBwdWeight_Wmma_CShuffle, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvBwdWeight_Wmma_CShuffle, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvBwdWeight_Wmma_CShuffle, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvBwdWeight_Wmma_CShuffle, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + // blocksize=64 + DeviceGroupedConvBwdWeight_Wmma_CShuffle, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, S<8, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceGroupedConvBwdWeight_Wmma_CShuffle, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<8, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 32, 1, 2>, 8>, + DeviceGroupedConvBwdWeight_Wmma_CShuffle, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, S<8, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 32, 1, 2>, 8>, + DeviceGroupedConvBwdWeight_Wmma_CShuffle, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, S<8, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceGroupedConvBwdWeight_Wmma_CShuffle, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<8, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 32, 1, 2>, 8>, + // blocksize=32 + DeviceGroupedConvBwdWeight_Wmma_CShuffle, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<8, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 2>, 8>, + DeviceGroupedConvBwdWeight_Wmma_CShuffle, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<8, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 2>, 8>, + DeviceGroupedConvBwdWeight_Wmma_CShuffle, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, S<8, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 2>, 8>, + DeviceGroupedConvBwdWeight_Wmma_CShuffle, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, S<8, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 2>, 8>, + DeviceGroupedConvBwdWeight_Wmma_CShuffle, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, S<8, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 2>, 8>, + DeviceGroupedConvBwdWeight_Wmma_CShuffle, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, S<8, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 2>, 8>, + DeviceGroupedConvBwdWeight_Wmma_CShuffle, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, S<8, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 2>, 8> + // clang-format on + >; + +template +using device_grouped_conv_bwd_weight_wmma_i8_instances = + std::tuple< + // clang-format off + //#####################################| NumDim| A| B| C| AData| BData| CData| AccData| A| B| C| ConvForward| Block| MPer| NPer| KPer| K1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //#####################################| Spatial| Layout| Layout| Layout| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | WMMA| WMMA| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeatPerWave| NRepeatPerWave| _MBlock_MPerBlock| ScalarPerVector| + //#####################################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| + //#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + // generic instance + DeviceGroupedConvBwdWeight_Wmma_CShuffle, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 4>, 1>, + // blocksize=256 + DeviceGroupedConvBwdWeight_Wmma_CShuffle, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvBwdWeight_Wmma_CShuffle, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 64, 1, 4>, 8>, + // blocksize=128 + DeviceGroupedConvBwdWeight_Wmma_CShuffle, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvBwdWeight_Wmma_CShuffle, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvBwdWeight_Wmma_CShuffle, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvBwdWeight_Wmma_CShuffle, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvBwdWeight_Wmma_CShuffle, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvBwdWeight_Wmma_CShuffle, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 8, 1, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvBwdWeight_Wmma_CShuffle, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 8, 1, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, 2>, + DeviceGroupedConvBwdWeight_Wmma_CShuffle, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 8, 1, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + // blocksize=64 + DeviceGroupedConvBwdWeight_Wmma_CShuffle, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<8, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 8, 1, 1, 1, S<1, 32, 1, 2>, 8>, + DeviceGroupedConvBwdWeight_Wmma_CShuffle, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, S<8, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 8, 1, 1, 1, S<1, 32, 1, 2>, 8>, + DeviceGroupedConvBwdWeight_Wmma_CShuffle, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 8, 1, S<8, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceGroupedConvBwdWeight_Wmma_CShuffle, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 8, 1, S<8, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, + // blocksize=32 + DeviceGroupedConvBwdWeight_Wmma_CShuffle, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<8, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 8, 1, 1, 1, S<1, 16, 1, 2>, 8>, + DeviceGroupedConvBwdWeight_Wmma_CShuffle, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 8, 1, S<8, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 8, 1, 1, 1, S<1, 16, 1, 2>, 8>, + DeviceGroupedConvBwdWeight_Wmma_CShuffle, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, S<8, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 2>, 8>, + DeviceGroupedConvBwdWeight_Wmma_CShuffle, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 8, 1, S<8, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 2>, 8> + // clang-format on + >; + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight.hpp index 9f536f210d..9310e9e57d 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight.hpp @@ -163,6 +163,30 @@ void add_device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_f16_instances PassThrough, PassThrough, PassThrough>>>& instances); + +void add_device_grouped_conv3d_bwd_weight_wmma_gndhwc_gkzyxc_gndhwk_f16_instances( + std::vector>>& instances); + +void add_device_grouped_conv3d_bwd_weight_wmma_gndhwc_gkzyxc_gndhwk_f16_1x1s1p0_instances( + std::vector>>& instances); #endif #ifdef CK_ENABLE_FP32 void add_device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_f32_instances( @@ -177,6 +201,31 @@ void add_device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_f32_instances PassThrough, PassThrough>>>& instances); #endif +#ifdef CK_ENABLE_INT8 +void add_device_grouped_conv3d_bwd_weight_wmma_gndhwc_gkzyxc_gndhwk_i8_instances( + std::vector>>& instances); + +void add_device_grouped_conv3d_bwd_weight_wmma_gndhwc_gkzyxc_gndhwk_i8_1x1s1p0_instances( + std::vector>>& instances); +#endif #ifdef CK_ENABLE_BF16 void add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_f32_bf16_instances( std::vector>>& instances); + +void add_device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_f16_instances( + std::vector>>& instances); + +void add_device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_f16_1x1s1p0_instances( + std::vector>>& instances); #endif #ifdef CK_ENABLE_FP32 void add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instances( @@ -231,6 +304,31 @@ void add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_bf8_ BF8, F8>>>& instances); #endif +#ifdef CK_ENABLE_INT8 +void add_device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_i8_instances( + std::vector>>& instances); + +void add_device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_i8_1x1s1p0_instances( + std::vector>>& instances); +#endif #ifdef DL_KERNELS // dl @@ -694,6 +792,10 @@ struct DeviceOperationInstanceFactory && is_same_v && + is_same_v) + { + add_device_grouped_conv3d_bwd_weight_wmma_gndhwc_gkzyxc_gndhwk_i8_instances( + op_ptrs); + add_device_grouped_conv3d_bwd_weight_wmma_gndhwc_gkzyxc_gndhwk_i8_1x1s1p0_instances( + op_ptrs); + } #endif } else if constexpr(is_same_v && is_same_v && @@ -737,6 +849,10 @@ struct DeviceOperationInstanceFactory && is_same_v && + is_same_v) + { + add_device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_i8_instances( + op_ptrs); + add_device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_i8_1x1s1p0_instances( + op_ptrs); + } +#endif #if defined CK_ENABLE_FP16 && defined CK_ENABLE_FP8 && defined CK_ENABLE_BF8 else if constexpr(is_same_v && is_same_v && is_same_v && diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv1d_bwd_weight/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_conv1d_bwd_weight/CMakeLists.txt index 660f2544ac..cfd829f87e 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv1d_bwd_weight/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/grouped_conv1d_bwd_weight/CMakeLists.txt @@ -1,16 +1,16 @@ set(GROUPED_CONV1D_BWD_WEIGHT - device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_f16_instance.cpp - device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_f32_instance.cpp - device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_bf16_instance.cpp) + xdl/device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_f16_instance.cpp + xdl/device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_f32_instance.cpp + xdl/device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_bf16_instance.cpp) if(DL_KERNELS) list(APPEND GROUPED_CONV1D_BWD_WEIGHT - device_grouped_conv1d_bwd_weight_dl_gnwc_gkxc_gnwk_f16_instance.cpp - device_grouped_conv1d_bwd_weight_dl_gnwc_gkxc_gnwk_f32_instance.cpp - device_grouped_conv1d_bwd_weight_dl_gnwc_gkxc_gnwk_bf16_instance.cpp - device_grouped_conv1d_bwd_weight_dl_nwgc_gkxc_nwgk_f16_instance.cpp - device_grouped_conv1d_bwd_weight_dl_nwgc_gkxc_nwgk_f32_instance.cpp - device_grouped_conv1d_bwd_weight_dl_nwgc_gkxc_nwgk_bf16_instance.cpp) + dl/device_grouped_conv1d_bwd_weight_dl_gnwc_gkxc_gnwk_f16_instance.cpp + dl/device_grouped_conv1d_bwd_weight_dl_gnwc_gkxc_gnwk_f32_instance.cpp + dl/device_grouped_conv1d_bwd_weight_dl_gnwc_gkxc_gnwk_bf16_instance.cpp + dl/device_grouped_conv1d_bwd_weight_dl_nwgc_gkxc_nwgk_f16_instance.cpp + dl/device_grouped_conv1d_bwd_weight_dl_nwgc_gkxc_nwgk_f32_instance.cpp + dl/device_grouped_conv1d_bwd_weight_dl_nwgc_gkxc_nwgk_bf16_instance.cpp) endif() add_instance_library(device_grouped_conv1d_bwd_weight_instance ${GROUPED_CONV1D_BWD_WEIGHT}) diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv1d_bwd_weight/device_grouped_conv1d_bwd_weight_dl_gnwc_gkxc_gnwk_bf16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv1d_bwd_weight/dl/device_grouped_conv1d_bwd_weight_dl_gnwc_gkxc_gnwk_bf16_instance.cpp similarity index 100% rename from library/src/tensor_operation_instance/gpu/grouped_conv1d_bwd_weight/device_grouped_conv1d_bwd_weight_dl_gnwc_gkxc_gnwk_bf16_instance.cpp rename to library/src/tensor_operation_instance/gpu/grouped_conv1d_bwd_weight/dl/device_grouped_conv1d_bwd_weight_dl_gnwc_gkxc_gnwk_bf16_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv1d_bwd_weight/device_grouped_conv1d_bwd_weight_dl_gnwc_gkxc_gnwk_f16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv1d_bwd_weight/dl/device_grouped_conv1d_bwd_weight_dl_gnwc_gkxc_gnwk_f16_instance.cpp similarity index 100% rename from library/src/tensor_operation_instance/gpu/grouped_conv1d_bwd_weight/device_grouped_conv1d_bwd_weight_dl_gnwc_gkxc_gnwk_f16_instance.cpp rename to library/src/tensor_operation_instance/gpu/grouped_conv1d_bwd_weight/dl/device_grouped_conv1d_bwd_weight_dl_gnwc_gkxc_gnwk_f16_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv1d_bwd_weight/device_grouped_conv1d_bwd_weight_dl_gnwc_gkxc_gnwk_f32_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv1d_bwd_weight/dl/device_grouped_conv1d_bwd_weight_dl_gnwc_gkxc_gnwk_f32_instance.cpp similarity index 100% rename from library/src/tensor_operation_instance/gpu/grouped_conv1d_bwd_weight/device_grouped_conv1d_bwd_weight_dl_gnwc_gkxc_gnwk_f32_instance.cpp rename to library/src/tensor_operation_instance/gpu/grouped_conv1d_bwd_weight/dl/device_grouped_conv1d_bwd_weight_dl_gnwc_gkxc_gnwk_f32_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv1d_bwd_weight/device_grouped_conv1d_bwd_weight_dl_nwgc_gkxc_nwgk_bf16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv1d_bwd_weight/dl/device_grouped_conv1d_bwd_weight_dl_nwgc_gkxc_nwgk_bf16_instance.cpp similarity index 100% rename from library/src/tensor_operation_instance/gpu/grouped_conv1d_bwd_weight/device_grouped_conv1d_bwd_weight_dl_nwgc_gkxc_nwgk_bf16_instance.cpp rename to library/src/tensor_operation_instance/gpu/grouped_conv1d_bwd_weight/dl/device_grouped_conv1d_bwd_weight_dl_nwgc_gkxc_nwgk_bf16_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv1d_bwd_weight/device_grouped_conv1d_bwd_weight_dl_nwgc_gkxc_nwgk_f16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv1d_bwd_weight/dl/device_grouped_conv1d_bwd_weight_dl_nwgc_gkxc_nwgk_f16_instance.cpp similarity index 100% rename from library/src/tensor_operation_instance/gpu/grouped_conv1d_bwd_weight/device_grouped_conv1d_bwd_weight_dl_nwgc_gkxc_nwgk_f16_instance.cpp rename to library/src/tensor_operation_instance/gpu/grouped_conv1d_bwd_weight/dl/device_grouped_conv1d_bwd_weight_dl_nwgc_gkxc_nwgk_f16_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv1d_bwd_weight/device_grouped_conv1d_bwd_weight_dl_nwgc_gkxc_nwgk_f32_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv1d_bwd_weight/dl/device_grouped_conv1d_bwd_weight_dl_nwgc_gkxc_nwgk_f32_instance.cpp similarity index 100% rename from library/src/tensor_operation_instance/gpu/grouped_conv1d_bwd_weight/device_grouped_conv1d_bwd_weight_dl_nwgc_gkxc_nwgk_f32_instance.cpp rename to library/src/tensor_operation_instance/gpu/grouped_conv1d_bwd_weight/dl/device_grouped_conv1d_bwd_weight_dl_nwgc_gkxc_nwgk_f32_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv1d_bwd_weight/device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_bf16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv1d_bwd_weight/xdl/device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_bf16_instance.cpp similarity index 100% rename from library/src/tensor_operation_instance/gpu/grouped_conv1d_bwd_weight/device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_bf16_instance.cpp rename to library/src/tensor_operation_instance/gpu/grouped_conv1d_bwd_weight/xdl/device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_bf16_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv1d_bwd_weight/device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_f16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv1d_bwd_weight/xdl/device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_f16_instance.cpp similarity index 100% rename from library/src/tensor_operation_instance/gpu/grouped_conv1d_bwd_weight/device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_f16_instance.cpp rename to library/src/tensor_operation_instance/gpu/grouped_conv1d_bwd_weight/xdl/device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_f16_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv1d_bwd_weight/device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_f32_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv1d_bwd_weight/xdl/device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_f32_instance.cpp similarity index 100% rename from library/src/tensor_operation_instance/gpu/grouped_conv1d_bwd_weight/device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_f32_instance.cpp rename to library/src/tensor_operation_instance/gpu/grouped_conv1d_bwd_weight/xdl/device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_f32_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/CMakeLists.txt index 578e9029e9..8a896b06c7 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/CMakeLists.txt @@ -1,19 +1,19 @@ set(GROUPED_CONV2D_BWD_WEIGHT - device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f16_instance.cpp - device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f32_instance.cpp - device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_bf16_instance.cpp - device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f16_instance.cpp - device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f32_instance.cpp - device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_bf16_instance.cpp) + xdl/device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f16_instance.cpp + xdl/device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f32_instance.cpp + xdl/device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_bf16_instance.cpp + xdl/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f16_instance.cpp + xdl/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f32_instance.cpp + xdl/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_bf16_instance.cpp) if(DL_KERNELS) list(APPEND GROUPED_CONV2D_BWD_WEIGHT - device_grouped_conv2d_bwd_weight_dl_gnhwc_gkyxc_gnhwk_f16_instance.cpp - device_grouped_conv2d_bwd_weight_dl_gnhwc_gkyxc_gnhwk_f32_instance.cpp - device_grouped_conv2d_bwd_weight_dl_gnhwc_gkyxc_gnhwk_bf16_instance.cpp - device_grouped_conv2d_bwd_weight_dl_nhwgc_gkyxc_nhwgk_f16_instance.cpp - device_grouped_conv2d_bwd_weight_dl_nhwgc_gkyxc_nhwgk_f32_instance.cpp - device_grouped_conv2d_bwd_weight_dl_nhwgc_gkyxc_nhwgk_bf16_instance.cpp) + dl/device_grouped_conv2d_bwd_weight_dl_gnhwc_gkyxc_gnhwk_f16_instance.cpp + dl/device_grouped_conv2d_bwd_weight_dl_gnhwc_gkyxc_gnhwk_f32_instance.cpp + dl/device_grouped_conv2d_bwd_weight_dl_gnhwc_gkyxc_gnhwk_bf16_instance.cpp + dl/device_grouped_conv2d_bwd_weight_dl_nhwgc_gkyxc_nhwgk_f16_instance.cpp + dl/device_grouped_conv2d_bwd_weight_dl_nhwgc_gkyxc_nhwgk_f32_instance.cpp + dl/device_grouped_conv2d_bwd_weight_dl_nhwgc_gkyxc_nhwgk_bf16_instance.cpp) endif() add_instance_library(device_grouped_conv2d_bwd_weight_instance ${GROUPED_CONV2D_BWD_WEIGHT}) diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/device_grouped_conv2d_bwd_weight_dl_gnhwc_gkyxc_gnhwk_bf16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/dl/device_grouped_conv2d_bwd_weight_dl_gnhwc_gkyxc_gnhwk_bf16_instance.cpp similarity index 100% rename from library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/device_grouped_conv2d_bwd_weight_dl_gnhwc_gkyxc_gnhwk_bf16_instance.cpp rename to library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/dl/device_grouped_conv2d_bwd_weight_dl_gnhwc_gkyxc_gnhwk_bf16_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/device_grouped_conv2d_bwd_weight_dl_gnhwc_gkyxc_gnhwk_f16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/dl/device_grouped_conv2d_bwd_weight_dl_gnhwc_gkyxc_gnhwk_f16_instance.cpp similarity index 100% rename from library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/device_grouped_conv2d_bwd_weight_dl_gnhwc_gkyxc_gnhwk_f16_instance.cpp rename to library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/dl/device_grouped_conv2d_bwd_weight_dl_gnhwc_gkyxc_gnhwk_f16_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/device_grouped_conv2d_bwd_weight_dl_gnhwc_gkyxc_gnhwk_f32_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/dl/device_grouped_conv2d_bwd_weight_dl_gnhwc_gkyxc_gnhwk_f32_instance.cpp similarity index 100% rename from library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/device_grouped_conv2d_bwd_weight_dl_gnhwc_gkyxc_gnhwk_f32_instance.cpp rename to library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/dl/device_grouped_conv2d_bwd_weight_dl_gnhwc_gkyxc_gnhwk_f32_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/device_grouped_conv2d_bwd_weight_dl_nhwgc_gkyxc_nhwgk_bf16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/dl/device_grouped_conv2d_bwd_weight_dl_nhwgc_gkyxc_nhwgk_bf16_instance.cpp similarity index 100% rename from library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/device_grouped_conv2d_bwd_weight_dl_nhwgc_gkyxc_nhwgk_bf16_instance.cpp rename to library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/dl/device_grouped_conv2d_bwd_weight_dl_nhwgc_gkyxc_nhwgk_bf16_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/device_grouped_conv2d_bwd_weight_dl_nhwgc_gkyxc_nhwgk_f16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/dl/device_grouped_conv2d_bwd_weight_dl_nhwgc_gkyxc_nhwgk_f16_instance.cpp similarity index 100% rename from library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/device_grouped_conv2d_bwd_weight_dl_nhwgc_gkyxc_nhwgk_f16_instance.cpp rename to library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/dl/device_grouped_conv2d_bwd_weight_dl_nhwgc_gkyxc_nhwgk_f16_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/device_grouped_conv2d_bwd_weight_dl_nhwgc_gkyxc_nhwgk_f32_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/dl/device_grouped_conv2d_bwd_weight_dl_nhwgc_gkyxc_nhwgk_f32_instance.cpp similarity index 100% rename from library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/device_grouped_conv2d_bwd_weight_dl_nhwgc_gkyxc_nhwgk_f32_instance.cpp rename to library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/dl/device_grouped_conv2d_bwd_weight_dl_nhwgc_gkyxc_nhwgk_f32_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_bf16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_bf16_instance.cpp similarity index 100% rename from library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_bf16_instance.cpp rename to library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_bf16_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f16_instance.cpp similarity index 100% rename from library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f16_instance.cpp rename to library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f16_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f32_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f32_instance.cpp similarity index 100% rename from library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f32_instance.cpp rename to library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f32_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_bf16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_bf16_instance.cpp similarity index 100% rename from library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_bf16_instance.cpp rename to library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_bf16_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f16_instance.cpp similarity index 100% rename from library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f16_instance.cpp rename to library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f16_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f32_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f32_instance.cpp similarity index 100% rename from library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f32_instance.cpp rename to library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f32_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/CMakeLists.txt index 65ea645a39..a2c4b1f80b 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/CMakeLists.txt @@ -1,20 +1,30 @@ set(GROUPED_CONV3D_BWD_WEIGHT - device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_f16_instance.cpp - device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_f32_instance.cpp - device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_bf16_instance.cpp - device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp - device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instance.cpp - device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp - device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_bf8_fp8_instance.cpp) + xdl/device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_f16_instance.cpp + xdl/device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_f32_instance.cpp + xdl/device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_bf16_instance.cpp + xdl/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp + xdl/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instance.cpp + xdl/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp + xdl/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_bf8_fp8_instance.cpp) if(DL_KERNELS) list(APPEND GROUPED_CONV3D_BWD_WEIGHT - device_grouped_conv3d_bwd_weight_dl_gndhwc_gkzyxc_gndhwk_f16_instance.cpp - device_grouped_conv3d_bwd_weight_dl_gndhwc_gkzyxc_gndhwk_f32_instance.cpp - device_grouped_conv3d_bwd_weight_dl_gndhwc_gkzyxc_gndhwk_bf16_instance.cpp - device_grouped_conv3d_bwd_weight_dl_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp - device_grouped_conv3d_bwd_weight_dl_ndhwgc_gkzyxc_ndhwgk_f32_instance.cpp - device_grouped_conv3d_bwd_weight_dl_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp) + dl/device_grouped_conv3d_bwd_weight_dl_gndhwc_gkzyxc_gndhwk_f16_instance.cpp + dl/device_grouped_conv3d_bwd_weight_dl_gndhwc_gkzyxc_gndhwk_f32_instance.cpp + dl/device_grouped_conv3d_bwd_weight_dl_gndhwc_gkzyxc_gndhwk_bf16_instance.cpp + dl/device_grouped_conv3d_bwd_weight_dl_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp + dl/device_grouped_conv3d_bwd_weight_dl_ndhwgc_gkzyxc_ndhwgk_f32_instance.cpp + dl/device_grouped_conv3d_bwd_weight_dl_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp) endif() +list(APPEND GROUPED_CONV3D_BWD_WEIGHT + wmma/device_grouped_conv3d_bwd_weight_wmma_gndhwc_gkzyxc_gndhwk_f16_1x1s1p0_instance.cpp + wmma/device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_f16_1x1s1p0_instance.cpp + wmma/device_grouped_conv3d_bwd_weight_wmma_gndhwc_gkzyxc_gndhwk_f16_instance.cpp + wmma/device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp + wmma/device_grouped_conv3d_bwd_weight_wmma_gndhwc_gkzyxc_gndhwk_i8_1x1s1p0_instance.cpp + wmma/device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_i8_1x1s1p0_instance.cpp + wmma/device_grouped_conv3d_bwd_weight_wmma_gndhwc_gkzyxc_gndhwk_i8_instance.cpp + wmma/device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_i8_instance.cpp) + add_instance_library(device_grouped_conv3d_bwd_weight_instance ${GROUPED_CONV3D_BWD_WEIGHT}) diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/device_grouped_conv3d_bwd_weight_dl_gndhwc_gkzyxc_gndhwk_bf16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/dl/device_grouped_conv3d_bwd_weight_dl_gndhwc_gkzyxc_gndhwk_bf16_instance.cpp similarity index 100% rename from library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/device_grouped_conv3d_bwd_weight_dl_gndhwc_gkzyxc_gndhwk_bf16_instance.cpp rename to library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/dl/device_grouped_conv3d_bwd_weight_dl_gndhwc_gkzyxc_gndhwk_bf16_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/device_grouped_conv3d_bwd_weight_dl_gndhwc_gkzyxc_gndhwk_f16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/dl/device_grouped_conv3d_bwd_weight_dl_gndhwc_gkzyxc_gndhwk_f16_instance.cpp similarity index 100% rename from library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/device_grouped_conv3d_bwd_weight_dl_gndhwc_gkzyxc_gndhwk_f16_instance.cpp rename to library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/dl/device_grouped_conv3d_bwd_weight_dl_gndhwc_gkzyxc_gndhwk_f16_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/device_grouped_conv3d_bwd_weight_dl_gndhwc_gkzyxc_gndhwk_f32_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/dl/device_grouped_conv3d_bwd_weight_dl_gndhwc_gkzyxc_gndhwk_f32_instance.cpp similarity index 100% rename from library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/device_grouped_conv3d_bwd_weight_dl_gndhwc_gkzyxc_gndhwk_f32_instance.cpp rename to library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/dl/device_grouped_conv3d_bwd_weight_dl_gndhwc_gkzyxc_gndhwk_f32_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/device_grouped_conv3d_bwd_weight_dl_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/dl/device_grouped_conv3d_bwd_weight_dl_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp similarity index 100% rename from library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/device_grouped_conv3d_bwd_weight_dl_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp rename to library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/dl/device_grouped_conv3d_bwd_weight_dl_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/device_grouped_conv3d_bwd_weight_dl_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/dl/device_grouped_conv3d_bwd_weight_dl_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp similarity index 100% rename from library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/device_grouped_conv3d_bwd_weight_dl_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp rename to library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/dl/device_grouped_conv3d_bwd_weight_dl_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/device_grouped_conv3d_bwd_weight_dl_ndhwgc_gkzyxc_ndhwgk_f32_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/dl/device_grouped_conv3d_bwd_weight_dl_ndhwgc_gkzyxc_ndhwgk_f32_instance.cpp similarity index 100% rename from library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/device_grouped_conv3d_bwd_weight_dl_ndhwgc_gkzyxc_ndhwgk_f32_instance.cpp rename to library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/dl/device_grouped_conv3d_bwd_weight_dl_ndhwgc_gkzyxc_ndhwgk_f32_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/wmma/device_grouped_conv3d_bwd_weight_wmma_gndhwc_gkzyxc_gndhwk_f16_1x1s1p0_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/wmma/device_grouped_conv3d_bwd_weight_wmma_gndhwc_gkzyxc_gndhwk_f16_1x1s1p0_instance.cpp new file mode 100644 index 0000000000..26171031aa --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/wmma/device_grouped_conv3d_bwd_weight_wmma_gndhwc_gkzyxc_gndhwk_f16_1x1s1p0_instance.cpp @@ -0,0 +1,35 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_wmma_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { +void add_device_grouped_conv3d_bwd_weight_wmma_gndhwc_gkzyxc_gndhwk_f16_1x1s1p0_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_conv_bwd_weight_wmma_f16_instances<3, + GNDHWC, + GKZYXC, + GNDHWK, + ConvBwdWeightFilter1x1Stride1Pad0>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/wmma/device_grouped_conv3d_bwd_weight_wmma_gndhwc_gkzyxc_gndhwk_f16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/wmma/device_grouped_conv3d_bwd_weight_wmma_gndhwc_gkzyxc_gndhwk_f16_instance.cpp new file mode 100644 index 0000000000..9399655700 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/wmma/device_grouped_conv3d_bwd_weight_wmma_gndhwc_gkzyxc_gndhwk_f16_instance.cpp @@ -0,0 +1,35 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_wmma_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { +void add_device_grouped_conv3d_bwd_weight_wmma_gndhwc_gkzyxc_gndhwk_f16_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_conv_bwd_weight_wmma_f16_instances<3, + GNDHWC, + GKZYXC, + GNDHWK, + ConvBwdWeightDefault>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/wmma/device_grouped_conv3d_bwd_weight_wmma_gndhwc_gkzyxc_gndhwk_i8_1x1s1p0_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/wmma/device_grouped_conv3d_bwd_weight_wmma_gndhwc_gkzyxc_gndhwk_i8_1x1s1p0_instance.cpp new file mode 100644 index 0000000000..3187173677 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/wmma/device_grouped_conv3d_bwd_weight_wmma_gndhwc_gkzyxc_gndhwk_i8_1x1s1p0_instance.cpp @@ -0,0 +1,35 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_wmma_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { +void add_device_grouped_conv3d_bwd_weight_wmma_gndhwc_gkzyxc_gndhwk_i8_1x1s1p0_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_conv_bwd_weight_wmma_i8_instances<3, + GNDHWC, + GKZYXC, + GNDHWK, + ConvBwdWeightFilter1x1Stride1Pad0>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/wmma/device_grouped_conv3d_bwd_weight_wmma_gndhwc_gkzyxc_gndhwk_i8_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/wmma/device_grouped_conv3d_bwd_weight_wmma_gndhwc_gkzyxc_gndhwk_i8_instance.cpp new file mode 100644 index 0000000000..19451b855e --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/wmma/device_grouped_conv3d_bwd_weight_wmma_gndhwc_gkzyxc_gndhwk_i8_instance.cpp @@ -0,0 +1,35 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_wmma_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { +void add_device_grouped_conv3d_bwd_weight_wmma_gndhwc_gkzyxc_gndhwk_i8_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_conv_bwd_weight_wmma_i8_instances<3, + GNDHWC, + GKZYXC, + GNDHWK, + ConvBwdWeightDefault>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/wmma/device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_f16_1x1s1p0_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/wmma/device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_f16_1x1s1p0_instance.cpp new file mode 100644 index 0000000000..71b4365d62 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/wmma/device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_f16_1x1s1p0_instance.cpp @@ -0,0 +1,35 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_wmma_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { +void add_device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_f16_1x1s1p0_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_conv_bwd_weight_wmma_f16_instances<3, + NDHWGC, + GKZYXC, + NDHWGK, + ConvBwdWeightFilter1x1Stride1Pad0>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/wmma/device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/wmma/device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp new file mode 100644 index 0000000000..06134d1f20 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/wmma/device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp @@ -0,0 +1,35 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_wmma_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { +void add_device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_f16_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_conv_bwd_weight_wmma_f16_instances<3, + NDHWGC, + GKZYXC, + NDHWGK, + ConvBwdWeightDefault>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/wmma/device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_i8_1x1s1p0_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/wmma/device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_i8_1x1s1p0_instance.cpp new file mode 100644 index 0000000000..39e2cc8c28 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/wmma/device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_i8_1x1s1p0_instance.cpp @@ -0,0 +1,35 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_wmma_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { +void add_device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_i8_1x1s1p0_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_conv_bwd_weight_wmma_i8_instances<3, + NDHWGC, + GKZYXC, + NDHWGK, + ConvBwdWeightFilter1x1Stride1Pad0>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/wmma/device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_i8_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/wmma/device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_i8_instance.cpp new file mode 100644 index 0000000000..0d56619aeb --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/wmma/device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_i8_instance.cpp @@ -0,0 +1,35 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_wmma_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { +void add_device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_i8_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_conv_bwd_weight_wmma_i8_instances<3, + NDHWGC, + GKZYXC, + NDHWGK, + ConvBwdWeightDefault>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_bf16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_bf16_instance.cpp similarity index 100% rename from library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_bf16_instance.cpp rename to library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_bf16_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_f16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_f16_instance.cpp similarity index 100% rename from library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_f16_instance.cpp rename to library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_f16_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_f32_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_f32_instance.cpp similarity index 100% rename from library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_f32_instance.cpp rename to library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_f32_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp similarity index 100% rename from library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp rename to library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_bf8_fp8_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_bf8_fp8_instance.cpp similarity index 100% rename from library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_bf8_fp8_instance.cpp rename to library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_bf8_fp8_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp similarity index 100% rename from library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp rename to library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instance.cpp similarity index 100% rename from library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instance.cpp rename to library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instance.cpp diff --git a/profiler/README.md b/profiler/README.md index d081fc33e9..a2d370b210 100644 --- a/profiler/README.md +++ b/profiler/README.md @@ -147,7 +147,9 @@ GB/s: 127.947 # arg1: tensor operation (grouped_conv_bwd_weight: Grouped Convolution Backward Weight) # arg2: data type (0: Input fp32, Weight fp32, Output fp32 # 1: Input fp16, Weight fp16, Output fp16 -# 2: Input bf16, Weight fp32, Output bf16) +# 2: Input bf16, Weight fp32, Output bf16 +# 3: Input fp16, Weight fp16, Output fp16, Gemm bf8@fp8 +# 4: Input int8, Weight int8, Output int8) # arg3: tensor layout (0: Input[G, N, C, Hi, Wi], Weight[G, K, C, Y, X], Output[G, N, K, Ho, Wo] # 1: Input[G, N, Hi, Wi, C], Weight[G, K, Y, X, C], Output[G, N, Ho, Wo, K] # 2: Input[N, Hi, Wi, G, C], Weight[G, K, Y, X, C], Output[N, Ho, Wo, G, K] @@ -167,7 +169,7 @@ GB/s: 127.947 # SplitK ################ op datatype layout verify init log time Ndims G N K C Y X Hi Wi Sy Sx Dy Dx LeftPy LeftPx RightPy RightPx SplitK -./bin/ckProfiler grouped_conv_bwd_weight 1 0 1 1 0 1 2 32 256 256 512 3 3 28 28 1 1 1 1 1 0 0 0 1 +./bin/ckProfiler grouped_conv_bwd_weight 1 1 0 1 0 1 2 32 256 256 512 3 3 28 28 1 1 1 1 1 0 0 0 1 ``` diff --git a/profiler/src/profile_grouped_conv_bwd_weight.cpp b/profiler/src/profile_grouped_conv_bwd_weight.cpp index 9a9b5a0f74..bd1016f286 100644 --- a/profiler/src/profile_grouped_conv_bwd_weight.cpp +++ b/profiler/src/profile_grouped_conv_bwd_weight.cpp @@ -20,10 +20,11 @@ enum struct ConvLayout enum struct ConvDataType { - F32_F32_F32, // 0 - F16_F16_F16, // 1 - BF16_F32_BF16, // 2 - F16_F16_F16_BF8_F8 // 3 + F32_F32_F32, // 0 + F16_F16_F16, // 1 + BF16_F32_BF16, // 2 + F16_F16_F16_BF8_F8, // 3 + I8_I8_I8 // 4 }; #define OP_NAME "grouped_conv_bwd_weight" @@ -35,7 +36,8 @@ static void print_helper_msg() << "arg2: data type (0: Input fp32, Weight fp32, Output fp32\n" << " 1: Input fp16, Weight fp16, Output fp16\n" << " 2: Input bf16, Weight fp32, Output bf16\n" - << " 3: Input fp16, Weight fp16, Output fp16, Gemm bf8@fp8)\n" + << " 3: Input fp16, Weight fp16, Output fp16, Gemm bf8@fp8\n" + << " 4: Input int8, Weight int8, Output int8)\n" << "arg3: tensor layout (0: Input[G, N, C, Hi, Wi], Weight[G, K, C, Y, X], Output[G, " "N, K, Ho, Wo]\n" << " 1: Input[G, N, Hi, Wi, C], Weight[G, K, Y, X, C], Output[G, " @@ -196,6 +198,11 @@ int profile_grouped_conv_bwd_weight(int argc, char* argv[]) // fp32 atomic add is used for weight tensor in bf16 kernel return profile(I3, GNDHWC{}, GKZYXC{}, GNDHWK{}, BF16{}, F32{}, BF16{}, BF16{}, BF16{}); } + else if(data_type == ConvDataType::I8_I8_I8) + { + return profile( + I3, GNDHWC{}, GKZYXC{}, GNDHWK{}, int8_t{}, int8_t{}, int8_t{}, int8_t{}, int8_t{}); + } } else if(num_dim_spatial == 3 && layout == ConvLayout::NHWGC_GKYXC_NHWGK) { @@ -216,6 +223,11 @@ int profile_grouped_conv_bwd_weight(int argc, char* argv[]) { return profile(I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, F16{}, F16{}, F16{}, BF8{}, F8{}); } + else if(data_type == ConvDataType::I8_I8_I8) + { + return profile( + I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, int8_t{}, int8_t{}, int8_t{}, int8_t{}, int8_t{}); + } } std::cout << "this data_type & layout is not implemented" << std::endl; diff --git a/test/grouped_convnd_bwd_weight/CMakeLists.txt b/test/grouped_convnd_bwd_weight/CMakeLists.txt index f1a3f04b77..b167943c97 100644 --- a/test/grouped_convnd_bwd_weight/CMakeLists.txt +++ b/test/grouped_convnd_bwd_weight/CMakeLists.txt @@ -1,11 +1,20 @@ -list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942) +list(APPEND gpu_list_xdl gfx908 gfx90a gfx940 gfx941 gfx942) +list(APPEND gpu_list_wmma gfx1100 gfx1101 gfx1102) + set(target 0) foreach(gpu IN LISTS GPU_TARGETS) - if(gpu IN_LIST gpu_list AND target EQUAL 0) + if(gpu IN_LIST gpu_list_xdl AND target EQUAL 0) add_gtest_executable(test_grouped_convnd_bwd_weight test_grouped_convnd_bwd_weight.cpp) target_link_libraries(test_grouped_convnd_bwd_weight PRIVATE utility device_grouped_conv1d_bwd_weight_instance device_grouped_conv2d_bwd_weight_instance device_grouped_conv3d_bwd_weight_instance) - add_gtest_executable(test_grouped_convnd_bwd_weight_interface test_grouped_convnd_bwd_weight_interface.cpp) - target_link_libraries(test_grouped_convnd_bwd_weight_interface PRIVATE utility device_grouped_conv1d_bwd_weight_instance device_grouped_conv2d_bwd_weight_instance device_grouped_conv3d_bwd_weight_instance) + add_gtest_executable(test_grouped_convnd_bwd_weight_interface test_grouped_convnd_bwd_weight_interface_xdl.cpp) + target_link_libraries(test_grouped_convnd_bwd_weight_interface PRIVATE utility) + set(target 1) + endif() + if(gpu IN_LIST gpu_list_wmma AND target EQUAL 0) + add_gtest_executable(test_grouped_convnd_bwd_weight test_grouped_convnd_bwd_weight.cpp) + target_link_libraries(test_grouped_convnd_bwd_weight PRIVATE utility device_grouped_conv1d_bwd_weight_instance device_grouped_conv2d_bwd_weight_instance device_grouped_conv3d_bwd_weight_instance) + add_gtest_executable(test_grouped_convnd_bwd_weight_interface test_grouped_convnd_bwd_weight_interface_wmma.cpp) + target_link_libraries(test_grouped_convnd_bwd_weight_interface PRIVATE utility) set(target 1) endif() endforeach() \ No newline at end of file diff --git a/test/grouped_convnd_bwd_weight/test_grouped_convnd_bwd_weight.cpp b/test/grouped_convnd_bwd_weight/test_grouped_convnd_bwd_weight.cpp index 2132020320..856f9fd15c 100644 --- a/test/grouped_convnd_bwd_weight/test_grouped_convnd_bwd_weight.cpp +++ b/test/grouped_convnd_bwd_weight/test_grouped_convnd_bwd_weight.cpp @@ -11,6 +11,7 @@ #include "ck/utility/common_header.hpp" #include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/host_utility/device_prop.hpp" #include "profiler/profile_grouped_conv_bwd_weight_impl.hpp" @@ -33,8 +34,9 @@ class TestGroupedConvndBwdWeight : public ::testing::Test bool skip_case(const ck::utils::conv::ConvParam& params, const ck::index_t split_k) { - // Odd K or C values are supported only by DL kernel (only applies to fp16) - // DL kernel currently supports only `split_k=1` + // Odd K or C values are supported only by DL and WMMA + // kernels (only applies to fp16) + // DL and WMMA kernels currently support only `split_k=1` if constexpr(std::is_same_v) { if(split_k != 1 && (params.K_ % 2 != 0 || params.C_ % 2 != 0)) @@ -53,6 +55,42 @@ class TestGroupedConvndBwdWeight : public ::testing::Test } } + const bool is_navi3x = ck::get_device_name() == "gfx1100" || + ck::get_device_name() == "gfx1101" || + ck::get_device_name() == "gfx1102"; + if(is_navi3x) + { + // on navi3x only support for 3d is implemented + if constexpr(NDimSpatial{} != 3) + { + return true; + } + // on navi3x only support for i8 and fp16 is implemented + if constexpr(!((std::is_same_v && + std::is_same_v && + std::is_same_v) || + (std::is_same_v && + std::is_same_v && + std::is_same_v))) + { + return true; + } + // WMMA kernel is only supported for split_k=1 + if(split_k != 1) + { + return true; + } + } + else + { + // support for i8 is only implemented on navi3x + if constexpr(std::is_same_v && + std::is_same_v && std::is_same_v) + { + return true; + } + } + return false; } @@ -120,9 +158,11 @@ using KernelTypes3d = ::testing::Types< std::tuple>, std::tuple>, std::tuple>, + std::tuple>, std::tuple>, std::tuple>, - std::tuple>>; + std::tuple>, + std::tuple>>; TYPED_TEST_SUITE(TestGroupedConvndBwdWeight1d, KernelTypes1d); TYPED_TEST_SUITE(TestGroupedConvndBwdWeight2d, KernelTypes2d); diff --git a/test/grouped_convnd_bwd_weight/test_grouped_convnd_bwd_weight_interface_wmma.cpp b/test/grouped_convnd_bwd_weight/test_grouped_convnd_bwd_weight_interface_wmma.cpp new file mode 100644 index 0000000000..1dcb8f866d --- /dev/null +++ b/test/grouped_convnd_bwd_weight/test_grouped_convnd_bwd_weight_interface_wmma.cpp @@ -0,0 +1,191 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/convolution_backward_weight_specialization.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_wmma_cshuffle.hpp" + +#include "ck/library/utility/convolution_parameter.hpp" +#include "ck/library/utility/algorithm.hpp" +#include "ck/library/utility/convolution_host_tensor_descriptor_helper.hpp" + +#include + +using F16 = ck::half_t; +using F32 = float; +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +template +using S = ck::Sequence; +using ConvolutionBackwardWeightSpecialization = + ck::tensor_operation::device::ConvolutionBackwardWeightSpecialization; + +static constexpr auto ConvBwdWeightDefault = ConvolutionBackwardWeightSpecialization::Default; +static constexpr auto Filter1x1Stride1Pad0 = + ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0; + +template +class TestGroupedConvndBwdWeight : public ::testing::Test +{ + protected: + using OutLayout = std::tuple_element_t<0, Tuple>; + using WeiLayout = std::tuple_element_t<1, Tuple>; + using InLayout = std::tuple_element_t<2, Tuple>; + static constexpr ck::index_t NDimSpatial = std::tuple_element_t<3, Tuple>{}; + + // clang-format off + using GroupedConvBwdWeightDeviceInstance = ck::tensor_operation::device::DeviceGroupedConvBwdWeight_Wmma_CShuffle + //| NumDim| A| B| C| AData| BData| CData| AccData| A| B| C| ConvForward| Block| MPer| NPer| KPer| K1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //| Spatial| Layout| Layout| Layout| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | WMMA| WMMA| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeatPerWave| NRepeatPerWave| _MBlock_MPerBlock| ScalarPerVector| + //| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| + //| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + , S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>; + // clang-format on + + ck::utils::conv::ConvParam conv_param; + + template + bool Run() + { + + const auto in_g_n_c_wis_desc = + ck::utils::conv::make_input_host_tensor_descriptor_g_n_c_wis_packed( + conv_param); + + const auto wei_g_k_c_xs_desc = + ck::utils::conv::make_weight_host_tensor_descriptor_g_k_c_xs_packed( + conv_param); + + const auto out_g_n_k_wos_desc = + ck::utils::conv::make_output_host_tensor_descriptor_g_n_k_wos_packed( + conv_param); + + std::array input_lengths{}; + std::array filter_lengths{}; + std::array output_lengths{}; + std::array input_strides{}; + std::array weights_strides{}; + std::array output_strides{}; + std::array conv_filter_strides{}; + std::array conv_filter_dilations{}; + std::array input_left_pads{}; + std::array input_right_pads{}; + + auto range_copy = [](const auto& from, auto to) { std::copy(begin(from), end(from), to); }; + + range_copy(in_g_n_c_wis_desc.GetLengths(), begin(input_lengths)); + range_copy(in_g_n_c_wis_desc.GetStrides(), begin(input_strides)); + range_copy(wei_g_k_c_xs_desc.GetLengths(), begin(filter_lengths)); + range_copy(wei_g_k_c_xs_desc.GetStrides(), begin(weights_strides)); + range_copy(out_g_n_k_wos_desc.GetLengths(), begin(output_lengths)); + range_copy(out_g_n_k_wos_desc.GetStrides(), begin(output_strides)); + range_copy(conv_param.conv_filter_strides_, begin(conv_filter_strides)); + range_copy(conv_param.conv_filter_dilations_, begin(conv_filter_dilations)); + range_copy(conv_param.input_left_pads_, begin(input_left_pads)); + range_copy(conv_param.input_right_pads_, begin(input_right_pads)); + + auto conv = GroupedConvBwdWeightDeviceInstance{}; + + auto argument = conv.MakeArgument(nullptr, + nullptr, + nullptr, + input_lengths, + input_strides, + filter_lengths, + weights_strides, + output_lengths, + output_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + PassThrough{}, + PassThrough{}, + PassThrough{}, + SplitK); + return conv.IsSupportedArgument(argument); + } +}; + +using namespace ck::tensor_layout::convolution; + +using KernelTypes3d = ::testing::Types>, + std::tuple>>; + +template +class TestGroupedConvndBwdWeightFilter1x13d + : public TestGroupedConvndBwdWeight +{ +}; + +template +class TestGroupedConvndBwdWeightDefault3d + : public TestGroupedConvndBwdWeight +{ +}; + +TYPED_TEST_SUITE(TestGroupedConvndBwdWeightFilter1x13d, KernelTypes3d); +TYPED_TEST_SUITE(TestGroupedConvndBwdWeightDefault3d, KernelTypes3d); + +TYPED_TEST(TestGroupedConvndBwdWeightFilter1x13d, SpecializationCheck) +{ + // Check filter 3x3x3 instead of 1x1x1 + this->conv_param = { + 3, 2, 4, 192, 192, {3, 3, 3}, {28, 28, 28}, {1, 1, 1}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}}; + bool is_supported = this->template Run<1>(); + EXPECT_FALSE(is_supported); + + // Check strides 2x2x2 instead of 1x1x1 + this->conv_param = { + 3, 2, 4, 192, 192, {1, 1, 1}, {28, 28, 28}, {2, 2, 2}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}}; + is_supported = this->template Run<1>(); + EXPECT_FALSE(is_supported); + + // Check with pad + this->conv_param = { + 3, 2, 4, 192, 192, {1, 1, 1}, {28, 28, 28}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}}; + is_supported = this->template Run<1>(); + EXPECT_FALSE(is_supported); + + // Supported version + this->conv_param = { + 3, 2, 128, 128, 256, {1, 1, 1}, {3, 3, 3}, {1, 1, 1}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}}; + is_supported = this->template Run<1>(); + EXPECT_TRUE(is_supported); +} + +TYPED_TEST(TestGroupedConvndBwdWeightDefault3d, VectorLoadCheck) +{ + // vector load for A + this->conv_param = { + 3, 2, 128, 129, 256, {1, 1, 1}, {7, 7, 7}, {2, 2, 2}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}}; + bool is_supported = this->template Run<1>(); + EXPECT_FALSE(is_supported); + // vector load for B, E, Ds + this->conv_param = { + 3, 2, 128, 128, 257, {1, 1, 1}, {7, 7, 7}, {2, 2, 2}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}}; + is_supported = this->template Run<1>(); + EXPECT_FALSE(is_supported); +} + +TYPED_TEST(TestGroupedConvndBwdWeightDefault3d, SplitKCheck) +{ + // SplitK=1 + this->conv_param = { + 3, 2, 128, 128, 256, {1, 1, 1}, {3, 3, 3}, {1, 1, 1}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}}; + bool is_supported = this->template Run<1>(); + EXPECT_TRUE(is_supported); + // SplitK=2 + this->conv_param = { + 3, 2, 128, 128, 256, {1, 1, 1}, {3, 3, 3}, {1, 1, 1}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}}; + is_supported = this->template Run<2>(); + EXPECT_FALSE(is_supported); +} diff --git a/test/grouped_convnd_bwd_weight/test_grouped_convnd_bwd_weight_interface.cpp b/test/grouped_convnd_bwd_weight/test_grouped_convnd_bwd_weight_interface_xdl.cpp similarity index 100% rename from test/grouped_convnd_bwd_weight/test_grouped_convnd_bwd_weight_interface.cpp rename to test/grouped_convnd_bwd_weight/test_grouped_convnd_bwd_weight_interface_xdl.cpp