diff --git a/example/38_grouped_conv_bwd_data_multiple_d/CMakeLists.txt b/example/38_grouped_conv_bwd_data_multiple_d/CMakeLists.txt
index 1675e3663a..96d1a6c3c0 100644
--- a/example/38_grouped_conv_bwd_data_multiple_d/CMakeLists.txt
+++ b/example/38_grouped_conv_bwd_data_multiple_d/CMakeLists.txt
@@ -1,15 +1,27 @@
-list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942)
+list(APPEND gpu_list_xdl gfx908 gfx90a gfx940 gfx941 gfx942)
+list(APPEND gpu_list_wmma gfx1100 gfx1101 gfx1102)
 set(target 0)
 foreach(gpu IN LISTS GPU_TARGETS)
-    if(gpu IN_LIST gpu_list AND target EQUAL 0)
+    if(gpu IN_LIST gpu_list_xdl AND target EQUAL 0)
         add_custom_target(example_grouped_conv_bwd_data)
-        add_example_executable(example_grouped_conv_bwd_data_fp16 grouped_conv_bwd_data_fp16.cpp)
+        add_example_executable(example_grouped_conv_bwd_data_xdl_fp16 grouped_conv_bwd_data_xdl_fp16.cpp)
         if(result EQUAL 0)
-            add_dependencies(example_grouped_conv_bwd_data example_grouped_conv_bwd_data_fp16)
+            add_dependencies(example_grouped_conv_bwd_data example_grouped_conv_bwd_data_xdl_fp16)
         endif()
-        add_example_executable(example_grouped_conv_bwd_data_bias_relu_fp16 grouped_conv_bwd_data_bias_relu_fp16.cpp)
+        add_example_executable(example_grouped_conv_bwd_data_bias_relu_xdl_fp16 grouped_conv_bwd_data_bias_relu_xdl_fp16.cpp)
         if(result EQUAL 0)
-            add_dependencies(example_grouped_conv_bwd_data example_grouped_conv_bwd_data_bias_relu_fp16)
+            add_dependencies(example_grouped_conv_bwd_data example_grouped_conv_bwd_data_bias_relu_xdl_fp16)
+        endif()
+        set(target 1)
+    endif()
+endforeach()
+
+foreach(gpu IN LISTS GPU_TARGETS)
+    if(gpu IN_LIST gpu_list_wmma AND target EQUAL 0)
+        add_custom_target(example_grouped_conv_bwd_data)
+        add_example_executable(example_grouped_conv_bwd_data_wmma_fp16 grouped_conv_bwd_data_wmma_fp16.cpp)
+        if(result EQUAL 0)
+            add_dependencies(example_grouped_conv_bwd_data example_grouped_conv_bwd_data_wmma_fp16)
         endif()
         set(target 1)
     endif()
diff --git a/example/38_grouped_conv_bwd_data_multiple_d/common.hpp b/example/38_grouped_conv_bwd_data_multiple_d/common.hpp
index ca824b1075..ebb1c606cb 100644
--- a/example/38_grouped_conv_bwd_data_multiple_d/common.hpp
+++ b/example/38_grouped_conv_bwd_data_multiple_d/common.hpp
@@ -10,7 +10,6 @@
 #include "ck/ck.hpp"
 #include "ck/tensor_operation/gpu/device/convolution_backward_data_specialization.hpp"
-#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp"
 #include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
 #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
diff --git a/example/38_grouped_conv_bwd_data_multiple_d/grouped_conv_bwd_data_bias_relu_fp16.cpp b/example/38_grouped_conv_bwd_data_multiple_d/grouped_conv_bwd_data_bias_relu_xdl_fp16.cpp
similarity index 97%
rename from example/38_grouped_conv_bwd_data_multiple_d/grouped_conv_bwd_data_bias_relu_fp16.cpp
rename to example/38_grouped_conv_bwd_data_multiple_d/grouped_conv_bwd_data_bias_relu_xdl_fp16.cpp
index a3533bb4cc..4c28e25e01 100644
--- a/example/38_grouped_conv_bwd_data_multiple_d/grouped_conv_bwd_data_bias_relu_fp16.cpp
+++ b/example/38_grouped_conv_bwd_data_multiple_d/grouped_conv_bwd_data_bias_relu_xdl_fp16.cpp
@@ -1,6 +1,7 @@
 // SPDX-License-Identifier: MIT
 // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
 
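Note on the bias + ReLU example whose rename continues just below: it pulls its element-wise operator types from common.hpp, so they do not appear in this hunk. As a rough, hedged illustration of what a fused CDE epilogue functor of that shape looks like (illustrative only; the struct name and exact arithmetic are assumptions, not the functor the example actually uses):

// Hedged sketch, not part of this patch: a CDE elementwise op fusing
// bias-add and ReLU, in the form the multiple-D epilogue expects.
struct BiasReluSketch
{
    template <typename E, typename C, typename D>
    __host__ __device__ constexpr void operator()(E& e, const C& c, const D& d) const
    {
        // c is the GEMM accumulator value, d is the broadcast bias element
        const float x = ck::type_convert<float>(c) + ck::type_convert<float>(d);
        e = ck::type_convert<E>(x > 0.f ? x : 0.f); // ReLU
    }
};
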
+#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp" #include "common.hpp" using OutDataType = FP16; diff --git a/example/38_grouped_conv_bwd_data_multiple_d/grouped_conv_bwd_data_wmma_fp16.cpp b/example/38_grouped_conv_bwd_data_multiple_d/grouped_conv_bwd_data_wmma_fp16.cpp new file mode 100644 index 0000000000..5baa521501 --- /dev/null +++ b/example/38_grouped_conv_bwd_data_multiple_d/grouped_conv_bwd_data_wmma_fp16.cpp @@ -0,0 +1,34 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle.hpp" +#include "common.hpp" + +using OutDataType = FP16; +using WeiDataType = FP16; +using AccDataType = FP32; +using CShuffleDataType = FP16; +using DsDataType = ck::Tuple<>; +using InDataType = FP16; + +using OutLayout = ck::tensor_layout::convolution::GNHWK; +using WeiLayout = ck::tensor_layout::convolution::GKYXC; +using DsLayout = ck::Tuple<>; +using InLayout = ck::tensor_layout::convolution::GNHWC; + +using OutElementOp = PassThrough; +using WeiElementOp = PassThrough; +using InElementOp = PassThrough; + +// clang-format off +using DeviceConvInstance = ck::tensor_operation::device::DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle + //| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| ConvForward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | WMMA| WMMA| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //| | | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + < 2,OutLayout, WeiLayout, DsLayout, InLayout, OutDataType, WeiDataType, AccDataType, CShuffleDataType, DsDataType, InDataType, OutElementOp, WeiElementOp, InElementOp, ConvBwdDataDefault, 128, 64, 64, 4, 8, 16, 16, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 4>, 1>; +// clang-format on + +#include "run_grouped_conv_bwd_data_example.inc" + +int main(int argc, char* argv[]) { return run_grouped_conv_bwd_data_example(argc, argv); } diff --git a/example/38_grouped_conv_bwd_data_multiple_d/grouped_conv_bwd_data_fp16.cpp b/example/38_grouped_conv_bwd_data_multiple_d/grouped_conv_bwd_data_xdl_fp16.cpp similarity index 97% rename from example/38_grouped_conv_bwd_data_multiple_d/grouped_conv_bwd_data_fp16.cpp rename to example/38_grouped_conv_bwd_data_multiple_d/grouped_conv_bwd_data_xdl_fp16.cpp index fb688b6f3f..b1554412b1 100644 --- 
a/example/38_grouped_conv_bwd_data_multiple_d/grouped_conv_bwd_data_fp16.cpp +++ b/example/38_grouped_conv_bwd_data_multiple_d/grouped_conv_bwd_data_xdl_fp16.cpp @@ -1,6 +1,7 @@ // SPDX-License-Identifier: MIT // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp" #include "common.hpp" using OutDataType = FP16; diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle.hpp new file mode 100644 index 0000000000..35581de72b --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle.hpp @@ -0,0 +1,879 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_grouped_conv_bwd_data_multiple_d.hpp" +#include "ck/tensor_operation/gpu/device/convolution_backward_data_specialization.hpp" +#include "ck/tensor_operation/operator_transform/transform_conv_bwd_data_to_gemm_v1.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_wmma_cshuffle.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" +#include "ck/host_utility/io.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +// Conv backward data multiple D: +// input : output image A: [G, N, K, Ho, Wo] +// input : weight B: [G, K, C, Y, X], +// input : D0, D1, ... : [G, N, K, Ho, Wo] +// output : input image E: [G, N, C, Hi, Wi] +// C = a_op(A) * b_op(B) +// E = cde_op(C, D0, D1, ...) +template +struct DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle + : public DeviceGroupedConvBwdDataMultipleD +{ + // TODO: Extend support for more spatial dimensions. + static_assert(NDimSpatial == 2 || NDimSpatial == 3, + "wrong! only implemented for 2D and 3D now"); + + using DeviceOp = DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle; + + static constexpr index_t NumDTensor = DsDataType::Size(); + + // TODO: Add support for different A and B data types. 
+ using ABDataType = ADataType; + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static constexpr index_t KPerBlock = K0PerBlock * K1; + + static constexpr auto transform_conv_to_gemm = + TransformConvBwdDataToGemm_v1{}; + + static auto GetDummyABDsEGridDescriptor() + { + const std::array dummy_tensor_lengths = {1}; + const std::array dummy_tensor_strides = {1}; + const std::array dummy_spatial_lengths = {1}; + + const auto a_grid_desc_ak0_m_ak1 = + transform_conv_to_gemm.template MakeADescriptor_AK0_M_AK1( + dummy_tensor_lengths, + dummy_tensor_strides, + dummy_tensor_lengths, + dummy_tensor_strides, + dummy_tensor_lengths, + dummy_tensor_strides, + dummy_spatial_lengths, + dummy_spatial_lengths, + dummy_spatial_lengths, + dummy_spatial_lengths, + dummy_spatial_lengths); + + const auto b_grid_desc_bk0_n_bk1 = + transform_conv_to_gemm.template MakeBDescriptor_BK0_N_BK1( + dummy_tensor_lengths, + dummy_tensor_strides, + dummy_tensor_lengths, + dummy_tensor_strides, + dummy_tensor_lengths, + dummy_tensor_strides, + dummy_spatial_lengths, + dummy_spatial_lengths, + dummy_spatial_lengths, + dummy_spatial_lengths, + dummy_spatial_lengths); + + const auto ds_grid_desc_m_n = generate_tuple( + [&](auto i) { + using DLayout = remove_cvref_t>; + + return transform_conv_to_gemm.template MakeCDescriptor_M_N( + dummy_tensor_lengths, + dummy_tensor_strides, + dummy_tensor_lengths, + dummy_tensor_strides, + dummy_tensor_lengths, + dummy_tensor_strides, + dummy_spatial_lengths, + dummy_spatial_lengths, + dummy_spatial_lengths, + dummy_spatial_lengths, + dummy_spatial_lengths); + }, + Number{}); + + const auto e_grid_desc_m_n = + transform_conv_to_gemm.template MakeCDescriptor_M_N(dummy_tensor_lengths, + dummy_tensor_strides, + dummy_tensor_lengths, + dummy_tensor_strides, + dummy_tensor_lengths, + dummy_tensor_strides, + dummy_spatial_lengths, + dummy_spatial_lengths, + dummy_spatial_lengths, + dummy_spatial_lengths, + dummy_spatial_lengths); + + return make_tuple( + a_grid_desc_ak0_m_ak1, b_grid_desc_bk0_n_bk1, ds_grid_desc_m_n, e_grid_desc_m_n); + } + + // desc + using ABDsEGridDesc = decltype(GetDummyABDsEGridDescriptor()); + + using AGridDesc_AK0_M_AK1 = remove_cvref_t>; + using BGridDesc_BK0_N_BK1 = remove_cvref_t>; + using DsGridDesc_M_N = remove_cvref_t>; + using EGridDesc_M_N = remove_cvref_t>; + + // GridwiseGemm + using GridwiseGemm = GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle< + // DataType Family + ADataType, + BDataType, + AccDataType, + CShuffleDataType, + DsDataType, + EDataType, + // InMemory Data Descriptor + AGridDesc_AK0_M_AK1, + BGridDesc_BK0_N_BK1, + DsGridDesc_M_N, + EGridDesc_M_N, + // ElementwiseOp Family + AElementwiseOp, + BElementwiseOp, + CDEElementwiseOp, + InMemoryDataOperationEnum::Set, + // Tiling Family + MPerBlock, + NPerBlock, + K0PerBlock, + MPerWMMA, + NPerWMMA, + K1, + MRepeat, + NRepeat, + // ThreadCluster Family + BlockSize, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + false, + ABlockLdsExtraM, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + false, + BBlockLdsExtraN, + 
CShuffleMRepeatPerShuffle, + CShuffleNRepeatPerShuffle, + CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + CDEShuffleBlockTransferScalarPerVector_NPerBlock, + NumGemmKPrefetchStage, + LoopSched, + PipelineVer>; + + using DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = + decltype(GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + DsGridDesc_M_N{})); + using EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = + decltype(GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + EGridDesc_M_N{})); + + // Argument + struct Argument : public BaseArgument + { + Argument(const void* p_a, // output image + const void* p_b, // weight + const std::array& p_ds, // bias + void* p_e, // input image + const std::array& a_g_n_k_wos_lengths, + const std::array& a_g_n_k_wos_strides, + const std::array& b_g_k_c_xs_lengths, + const std::array& b_g_k_c_xs_strides, + const std::array, NumDTensor>& + ds_g_n_c_wis_lengths, + const std::array, NumDTensor>& + ds_g_n_c_wis_strides, + const std::array& e_g_n_c_wis_lengths, + const std::array& e_g_n_c_wis_strides, + const std::array& conv_filter_strides, + const std::array& conv_filter_dilations, + const std::array& input_left_pads, + const std::array& input_right_pads, + const AElementwiseOp& a_element_op, + const BElementwiseOp& b_element_op, + const CDEElementwiseOp& cde_element_op) + : p_a_grid_{static_cast(p_a)}, + p_b_grid_{static_cast(p_b)}, + p_ds_grid_{}, + p_e_grid_{static_cast(p_e)}, + num_group_{a_g_n_k_wos_lengths[0]}, + a_element_op_{a_element_op}, + b_element_op_{b_element_op}, + cde_element_op_{cde_element_op}, + a_g_n_k_wos_lengths_{a_g_n_k_wos_lengths}, + a_g_n_k_wos_strides_{a_g_n_k_wos_strides}, + b_g_k_c_xs_lengths_{b_g_k_c_xs_lengths}, + b_g_k_c_xs_strides_{b_g_k_c_xs_strides}, + ds_g_n_c_wis_lengths_{ds_g_n_c_wis_lengths}, + ds_g_n_c_wis_strides_{ds_g_n_c_wis_strides}, + e_g_n_c_wis_lengths_{e_g_n_c_wis_lengths}, + e_g_n_c_wis_strides_{e_g_n_c_wis_strides}, + conv_filter_strides_{conv_filter_strides}, + conv_filter_dilations_{conv_filter_dilations}, + input_left_pads_{input_left_pads}, + input_right_pads_{input_right_pads} + { + // populate Ds pointer + static_for<0, NumDTensor, 1>{}([&](auto i) { + using DDataType = remove_cvref_t>; + + p_ds_grid_(i) = static_cast(p_ds[i]); + }); + + // A/B/Ds/E Batch Stride + compute_ptr_offset_of_batch_.BatchStrideA_ = a_g_n_k_wos_strides[0]; + compute_ptr_offset_of_batch_.BatchStrideB_ = b_g_k_c_xs_strides[0]; + compute_ptr_offset_of_batch_.BatchStrideE_ = e_g_n_c_wis_strides[0]; + + static_for<0, NumDTensor, 1>{}([&](auto i) { + compute_ptr_offset_of_batch_.BatchStrideDs_(i) = ds_g_n_c_wis_strides[i][0]; + }); + + static constexpr auto NonSpatialDimsNum = Number<3>{}; + + static constexpr auto DIdx = Number{}; + static constexpr auto HIdx = + NDimSpatial == 2 ? Number{} : Number{}; + static constexpr auto WIdx = NDimSpatial == 2 ? Number{} + : Number{}; + + static constexpr auto ZIdx = Number{}; + static constexpr auto YIdx = + NDimSpatial == 2 ? Number{} : Number{}; + static constexpr auto XIdx = NDimSpatial == 2 ? 
Number{} + : Number{}; + + // problem definition + const index_t Z = b_g_k_c_xs_lengths[ZIdx]; + const index_t Y = b_g_k_c_xs_lengths[YIdx]; + const index_t X = b_g_k_c_xs_lengths[XIdx]; + + const index_t ConvStrideD = conv_filter_strides[DIdx - NonSpatialDimsNum]; + const index_t ConvStrideH = conv_filter_strides[HIdx - NonSpatialDimsNum]; + const index_t ConvStrideW = conv_filter_strides[WIdx - NonSpatialDimsNum]; + + const index_t ConvDilationD = conv_filter_dilations[DIdx - NonSpatialDimsNum]; + const index_t ConvDilationH = conv_filter_dilations[HIdx - NonSpatialDimsNum]; + const index_t ConvDilationW = conv_filter_dilations[WIdx - NonSpatialDimsNum]; + + const auto GcdStrideDilationD = math::gcd(ConvStrideD, ConvDilationD); + const auto GcdStrideDilationH = math::gcd(ConvStrideH, ConvDilationH); + const auto GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW); + + const auto ZTilde = NDimSpatial == 3 ? ConvStrideD / GcdStrideDilationD : 1; + const auto YTilde = ConvStrideH / GcdStrideDilationH; + const auto XTilde = ConvStrideW / GcdStrideDilationW; + + for(index_t i_ztilde = 0; i_ztilde < ZTilde; ++i_ztilde) + { + + for(index_t i_ytilde = 0; i_ytilde < YTilde; ++i_ytilde) + { + for(index_t i_xtilde = 0; i_xtilde < XTilde; ++i_xtilde) + { + // check slice is valid + const auto ZDotSlice = + NDimSpatial == 3 ? math::integer_divide_ceil(Z - i_ztilde, ZTilde) : 1; + const auto YDotSlice = math::integer_divide_ceil(Y - i_ytilde, YTilde); + const auto XDotSlice = math::integer_divide_ceil(X - i_xtilde, XTilde); + + if(YDotSlice * XDotSlice * ZDotSlice <= 0) + { + continue; + } + + std::array tildes; + if constexpr(NDimSpatial == 2) + { + tildes = {i_ytilde, i_xtilde}; + } + else if constexpr(NDimSpatial == 3) + { + tildes = {i_ztilde, i_ytilde, i_xtilde}; + } + else + { + throw std::runtime_error("wrong! 
only implemented for 2D and 3D now"); + } + + const auto a_grid_desc_ak0_m_ak1 = + transform_conv_to_gemm.template MakeADescriptor_AK0_M_AK1( + a_g_n_k_wos_lengths, + a_g_n_k_wos_strides, + b_g_k_c_xs_lengths, + b_g_k_c_xs_strides, + e_g_n_c_wis_lengths, + e_g_n_c_wis_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + tildes); + + const auto b_grid_desc_bk0_n_bk1 = + transform_conv_to_gemm.template MakeBDescriptor_BK0_N_BK1( + a_g_n_k_wos_lengths, + a_g_n_k_wos_strides, + b_g_k_c_xs_lengths, + b_g_k_c_xs_strides, + e_g_n_c_wis_lengths, + e_g_n_c_wis_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + tildes); + + DsGridDesc_M_N ds_grid_desc_m_n; + + // populate Ds desc + static_for<0, NumDTensor, 1>{}([&](auto i) { + using DLayout = remove_cvref_t>; + + ds_grid_desc_m_n(i) = + transform_conv_to_gemm.template MakeCDescriptor_M_N( + a_g_n_k_wos_lengths, + a_g_n_k_wos_strides, + b_g_k_c_xs_lengths, + b_g_k_c_xs_strides, + ds_g_n_c_wis_lengths[i], + ds_g_n_c_wis_strides[i], + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + tildes); + }); + + const auto e_grid_desc_m_n = + transform_conv_to_gemm.template MakeCDescriptor_M_N( + a_g_n_k_wos_lengths, + a_g_n_k_wos_strides, + b_g_k_c_xs_lengths, + b_g_k_c_xs_strides, + e_g_n_c_wis_lengths, + e_g_n_c_wis_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + tildes); + + // for check validity + ds_grid_desc_m_n_container_.push_back(ds_grid_desc_m_n); + e_grid_desc_m_n_container_.push_back(e_grid_desc_m_n); + + // desc for blockwise copy + a_grid_desc_ak0_m_ak1_container_.push_back(a_grid_desc_ak0_m_ak1); + b_grid_desc_bk0_n_bk1_container_.push_back(b_grid_desc_bk0_n_bk1); + + // block-to-e-tile-map + auto block_2_ctile_map = GridwiseGemm::MakeDefaultBlock2CTileMap( + e_grid_desc_m_n, 1 /* M01 */, 1 /* N01 */); + + block_2_ctile_map_container_.push_back(block_2_ctile_map); + + ds_grid_desc_mblock_mperblock_nblock_nperblock_container_.push_back( + GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + ds_grid_desc_m_n)); + e_grid_desc_mblock_mperblock_nblock_nperblock_container_.push_back( + GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + e_grid_desc_m_n)); + } + } + } + } + + void Print() const + { + for(std::size_t i = 0; i < a_grid_desc_ak0_m_ak1_container_.size(); i++) + { + std::cout << "a_grid_desc_ak0_m_ak1_container_" + << a_grid_desc_ak0_m_ak1_container_[i] << std::endl; + + std::cout << "b_grid_desc_bk0_n_bk1_container_" + << b_grid_desc_bk0_n_bk1_container_[i] << std::endl; + + static_for<0, NumDTensor, 1>{}([&](auto j) { + std::cout << "ds_grid_desc_mblock_mperblock_nblock_nperblock_container_" + << ds_grid_desc_mblock_mperblock_nblock_nperblock_container_[i][j] + << std::endl; + }); + + std::cout << "e_grid_desc_mblock_mperblock_nblock_nperblock_container_" + << e_grid_desc_mblock_mperblock_nblock_nperblock_container_[i] + << std::endl; + } + } + + // pointers + const ADataType* p_a_grid_; + const BDataType* p_b_grid_; + typename GridwiseGemm::DsGridPointer p_ds_grid_; + EDataType* p_e_grid_; + + // tensor descriptor for problem definition + index_t num_group_; + std::vector ds_grid_desc_m_n_container_; + std::vector e_grid_desc_m_n_container_; + + // tensor descriptor for block-wise copy + std::vector a_grid_desc_ak0_m_ak1_container_; + std::vector b_grid_desc_bk0_n_bk1_container_; + std::vector + 
ds_grid_desc_mblock_mperblock_nblock_nperblock_container_; + std::vector + e_grid_desc_mblock_mperblock_nblock_nperblock_container_; + + // block-to-e-tile map + std::vector block_2_ctile_map_container_; + + // for computing batch offset + ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_batch_; + + // element-wise op + AElementwiseOp a_element_op_; + BElementwiseOp b_element_op_; + CDEElementwiseOp cde_element_op_; + + // for checking IsSupportedArgument() + std::array a_g_n_k_wos_lengths_; + std::array a_g_n_k_wos_strides_; + std::array b_g_k_c_xs_lengths_; + std::array b_g_k_c_xs_strides_; + std::array, NumDTensor> ds_g_n_c_wis_lengths_; + std::array, NumDTensor> ds_g_n_c_wis_strides_; + std::array e_g_n_c_wis_lengths_; + std::array e_g_n_c_wis_strides_; + std::array conv_filter_strides_; + std::array conv_filter_dilations_; + std::array input_left_pads_; + std::array input_right_pads_; + }; + + // Invoker + struct Invoker : public BaseInvoker + { + using Argument = DeviceOp::Argument; + + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + if(stream_config.log_level_ > 0) + { + arg.Print(); + } + + float ave_time = 0; + + for(std::size_t i = 0; i < arg.a_grid_desc_ak0_m_ak1_container_.size(); i++) + { + const index_t grid_size = arg.block_2_ctile_map_container_[i].CalculateGridSize( + arg.e_grid_desc_m_n_container_[i]) * + arg.num_group_; + + const auto GemmK = arg.a_grid_desc_ak0_m_ak1_container_[i].GetLength(I0) * + arg.a_grid_desc_ak0_m_ak1_container_[i].GetLength(I2); + + auto launch_kernel = [&](auto has_main_k_block_loop) { + constexpr bool has_main_loop = has_main_k_block_loop.value; + + const auto kernel = kernel_grouped_conv_fwd_multiple_d_wmma_cshuffle< + GridwiseGemm, + ADataType, + BDataType, + typename GridwiseGemm::DsGridPointer, + EDataType, + AElementwiseOp, + BElementwiseOp, + CDEElementwiseOp, + DeviceOp::AGridDesc_AK0_M_AK1, + DeviceOp::BGridDesc_BK0_N_BK1, + typename GridwiseGemm::DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock, + typename GridwiseGemm::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock, + remove_reference_t, + ComputePtrOffsetOfStridedBatch, + has_main_loop>; + + return launch_and_time_kernel( + stream_config, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_a_grid_, + arg.p_b_grid_, + arg.p_ds_grid_, + arg.p_e_grid_, + arg.a_element_op_, + arg.b_element_op_, + arg.cde_element_op_, + arg.a_g_n_k_wos_lengths_[0], // Group count + arg.a_grid_desc_ak0_m_ak1_container_[i], + arg.b_grid_desc_bk0_n_bk1_container_[i], + arg.ds_grid_desc_mblock_mperblock_nblock_nperblock_container_[i], + arg.e_grid_desc_mblock_mperblock_nblock_nperblock_container_[i], + arg.block_2_ctile_map_container_[i], + arg.compute_ptr_offset_of_batch_); + }; + + if(GridwiseGemm::CalculateHasMainKBlockLoop(GemmK)) + { + ave_time += launch_kernel(integral_constant{}); + } + else + { + ave_time += launch_kernel(integral_constant{}); + } + } + + return ave_time; + } + + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + static bool IsSupportedArgument(const Argument& arg) + { + // check device + if(get_device_name() == "gfx1100" || get_device_name() == "gfx1101" || + ck::get_device_name() == "gfx1102") + { + if constexpr(!(is_same_v || is_same_v)) + { + return false; + } + } + else + { + return false; + } + + const index_t ConvK = arg.b_g_k_c_xs_lengths_[1]; + const index_t ConvC = arg.b_g_k_c_xs_lengths_[2]; + + // 
Specialization + if constexpr(ConvBackwardDataSpecialization == + ConvolutionBackwardDataSpecialization::Filter1x1Stride1Pad0) + { + // check if it's a 1x1 convolution with stride=1 and no padding + for(int i = 0; i < NDimSpatial; i++) + { + if(!(arg.b_g_k_c_xs_lengths_[3 + i] == 1 && arg.conv_filter_strides_[i] == 1 && + arg.input_left_pads_[i] == 0 && arg.input_right_pads_[i] == 0)) + { + return false; + } + } + } + + // vector load for A matrix from global memory to LDS + if constexpr(is_same_v || + is_same_v || + is_same_v || + is_same_v) + { + if(!(ABlockTransferSrcVectorDim == 2 && ConvK % ABlockTransferSrcScalarPerVector == 0)) + { + return false; + } + } + else + { + return false; + } + + // vector load for B matrix from global memory to LDS + if constexpr(is_same_v || + is_same_v) + { + if(!(BBlockTransferSrcVectorDim == 1 && ConvC % BBlockTransferSrcScalarPerVector == 0)) + { + return false; + } + } + else + { + return false; + } + + // vector store for Ds + bool ds_valid = true; + + static_for<0, NumDTensor, 1>{}([&](auto i) { + using DLayout = remove_cvref_t>; + + if constexpr(is_same_v || + is_same_v || + is_same_v || + is_same_v || + is_same_v || + is_same_v || + is_same_v) + { + // vector load D matrix from global memory + if(!(ConvC % CDEShuffleBlockTransferScalarPerVector_NPerBlock == 0)) + { + ds_valid = false; + } + } + else + { + ds_valid = false; + } + }); + + if(!ds_valid) + { + return false; + } + + // vector store for E + if constexpr(is_same_v || + is_same_v || + is_same_v || + is_same_v) + { + // vector store C matrix into global memory + if(!(ConvC % CDEShuffleBlockTransferScalarPerVector_NPerBlock == 0)) + { + return false; + } + } + else + { + return false; + } + + // Gridwise GEMM size + for(std::size_t i = 0; i < arg.a_grid_desc_ak0_m_ak1_container_.size(); i++) + { + if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_ak0_m_ak1_container_[i], + arg.b_grid_desc_bk0_n_bk1_container_[i], + arg.ds_grid_desc_m_n_container_[i], + arg.e_grid_desc_m_n_container_[i], + arg.block_2_ctile_map_container_[i])) + { + return false; + } + } + + return true; + } + + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static auto + MakeArgument(const void* p_a, // output image + const void* p_b, // weight + const std::array& p_ds, // bias + void* p_e, // input image + const std::array& a_g_n_k_wos_lengths, // output image + const std::array& a_g_n_k_wos_strides, // output image + const std::array& b_g_k_c_xs_lengths, // weight + const std::array& b_g_k_c_xs_strides, // weight + const std::array, NumDTensor>& + ds_g_n_c_wis_lengths, // bias + const std::array, NumDTensor>& + ds_g_n_c_wis_strides, // bias + const std::array& e_g_n_c_wis_lengths, // input image + const std::array& e_g_n_c_wis_strides, // input image + const std::array& conv_filter_strides, + const std::array& conv_filter_dilations, + const std::array& input_left_pads, + const std::array& input_right_pads, + const AElementwiseOp& a_element_op, + const BElementwiseOp& b_element_op, + const CDEElementwiseOp& cde_element_op) + { + return Argument{p_a, + p_b, + p_ds, + p_e, + a_g_n_k_wos_lengths, + a_g_n_k_wos_strides, + b_g_k_c_xs_lengths, + b_g_k_c_xs_strides, + ds_g_n_c_wis_lengths, + ds_g_n_c_wis_strides, + e_g_n_c_wis_lengths, + e_g_n_c_wis_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + a_element_op, + b_element_op, + cde_element_op}; + } + + static auto MakeInvoker() { return Invoker{}; } + + 
std::unique_ptr MakeArgumentPointer( + const void* p_a, // output image + const void* p_b, // weight + const std::array& p_ds, // bias + void* p_e, // input image + const std::array& a_g_n_k_wos_lengths, // output image + const std::array& a_g_n_k_wos_strides, // output image + const std::array& b_g_k_c_xs_lengths, // weight + const std::array& b_g_k_c_xs_strides, // weight + const std::array, NumDTensor>& + ds_g_n_c_wis_lengths, // bias + const std::array, NumDTensor>& + ds_g_n_c_wis_strides, // bias + const std::array& e_g_n_c_wis_lengths, // input image + const std::array& e_g_n_c_wis_strides, // input image + const std::array& conv_filter_strides, + const std::array& conv_filter_dilations, + const std::array& input_left_pads, + const std::array& input_right_pads, + const AElementwiseOp& a_element_op, + const BElementwiseOp& b_element_op, + const CDEElementwiseOp& cde_element_op) override + { + return std::make_unique(p_a, + p_b, + p_ds, + p_e, + a_g_n_k_wos_lengths, + a_g_n_k_wos_strides, + b_g_k_c_xs_lengths, + b_g_k_c_xs_strides, + ds_g_n_c_wis_lengths, + ds_g_n_c_wis_strides, + e_g_n_c_wis_lengths, + e_g_n_c_wis_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + a_element_op, + b_element_op, + cde_element_op); + } + + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle" + << "<" + << BlockSize << ", " + << MPerBlock << ", " + << NPerBlock << ", " + << KPerBlock << ", " + << getConvBackwardDataSpecializationString(ConvBackwardDataSpecialization) << ", " + << K1 << ", " + << ABlockTransferSrcScalarPerVector << ", " + << BBlockTransferSrcScalarPerVector + << ">"; + // clang-format on + + return str.str(); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp index c828269acf..37d7b0b1fc 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp @@ -14,6 +14,7 @@ #include "ck/tensor_operation/gpu/device/convolution_backward_data_specialization.hpp" #include "ck/tensor_operation/operator_transform/transform_conv_bwd_data_to_gemm_v1.hpp" #include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp" #include "ck/host_utility/device_prop.hpp" #include "ck/host_utility/kernel_launch.hpp" #include "ck/host_utility/io.hpp" @@ -24,51 +25,6 @@ namespace device { namespace { -template -struct ComputePtrOffsetOfStridedBatch -{ - ComputePtrOffsetOfStridedBatch() = default; - - ComputePtrOffsetOfStridedBatch(index_t BatchStrideA, - index_t BatchStrideB, - Array BatchStrideDs, - index_t BatchStrideE) - : BatchStrideA_(BatchStrideA), - BatchStrideB_(BatchStrideB), - BatchStrideDs_(BatchStrideDs), - BatchStrideE_(BatchStrideE) - { - } - - __host__ __device__ constexpr long_index_t GetAPtrOffset(index_t g_idx) const - { - return g_idx * static_cast(BatchStrideA_); - } - - __host__ __device__ constexpr long_index_t GetBPtrOffset(index_t g_idx) 
const - { - return g_idx * static_cast(BatchStrideB_); - } - - __host__ __device__ constexpr auto GetDsPtrOffset(index_t g_idx) const - { - Array ds_offset; - static_for<0, NumDTensor, 1>{}( - [&](auto i) { ds_offset(i) = g_idx * static_cast(BatchStrideDs_[i]); }); - return ds_offset; - } - - __host__ __device__ constexpr long_index_t GetEPtrOffset(index_t g_idx) const - { - return g_idx * static_cast(BatchStrideE_); - } - - index_t BatchStrideA_; - index_t BatchStrideB_; - Array BatchStrideDs_; - index_t BatchStrideE_; -}; - /* * \brief Wrapper function of GridwiseGemm::Run to realize BatchedGEMM. * @@ -257,7 +213,7 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 BElementwiseOp, CDEElementwiseOp> { - // FIXME + // TODO: Extend support for more spatial dimensions. static_assert(NDimSpatial == 2 || NDimSpatial == 3, "wrong! only implemented for 2D and 3D now"); @@ -265,7 +221,7 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 static constexpr index_t NumDTensor = DsDataType::Size(); - // TODO make A/B datatype different + // TODO: Add support for different A and B data types. using ABDataType = ADataType; static constexpr auto I0 = Number<0>{}; diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp index 8b22bd2090..8cfdd04e55 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp @@ -19,6 +19,7 @@ #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" #include "ck/tensor_operation/gpu/device/matrix_padder.hpp" #include "ck/tensor_operation/gpu/grid/gridwise_gemm_dl_multiple_d.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp" #include "ck/host_utility/device_prop.hpp" #include "ck/host_utility/kernel_launch.hpp" #include "ck/host_utility/io.hpp" @@ -29,51 +30,6 @@ namespace device { namespace { -template -struct ComputePtrOffsetOfStridedBatch -{ - ComputePtrOffsetOfStridedBatch() = default; - - ComputePtrOffsetOfStridedBatch(index_t BatchStrideA, - index_t BatchStrideB, - Array BatchStrideDs, - index_t BatchStrideE) - : BatchStrideA_(BatchStrideA), - BatchStrideB_(BatchStrideB), - BatchStrideDs_(BatchStrideDs), - BatchStrideE_(BatchStrideE) - { - } - - __host__ __device__ constexpr long_index_t GetAPtrOffset(index_t g_idx) const - { - return g_idx * static_cast(BatchStrideA_); - } - - __host__ __device__ constexpr long_index_t GetBPtrOffset(index_t g_idx) const - { - return g_idx * static_cast(BatchStrideB_); - } - - __host__ __device__ constexpr auto GetDsPtrOffset(index_t g_idx) const - { - Array ds_offset; - static_for<0, NumDTensor, 1>{}( - [&](auto i) { ds_offset(i) = g_idx * static_cast(BatchStrideDs_[i]); }); - return ds_offset; - } - - __host__ __device__ constexpr long_index_t GetEPtrOffset(index_t g_idx) const - { - return g_idx * static_cast(BatchStrideE_); - } - - index_t BatchStrideA_; - index_t BatchStrideB_; - Array BatchStrideDs_; - index_t BatchStrideE_; -}; - /* * \brief Wrapper function of GridwiseGemm::Run to realize BatchedGEMM. 
* diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp index 1d73a72387..e4e793eff8 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp @@ -19,6 +19,7 @@ #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" #include "ck/tensor_operation/gpu/device/matrix_padder.hpp" #include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_wmma_cshuffle.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp" #include "ck/host_utility/device_prop.hpp" #include "ck/host_utility/kernel_launch.hpp" #include "ck/host_utility/io.hpp" @@ -27,55 +28,6 @@ namespace ck { namespace tensor_operation { namespace device { -namespace { - -template -struct ComputePtrOffsetOfStridedBatch -{ - ComputePtrOffsetOfStridedBatch() = default; - - ComputePtrOffsetOfStridedBatch(index_t BatchStrideA, - index_t BatchStrideB, - Array BatchStrideDs, - index_t BatchStrideE) - : BatchStrideA_(BatchStrideA), - BatchStrideB_(BatchStrideB), - BatchStrideDs_(BatchStrideDs), - BatchStrideE_(BatchStrideE) - { - } - - __host__ __device__ constexpr long_index_t GetAPtrOffset(index_t g_idx) const - { - return g_idx * static_cast(BatchStrideA_); - } - - __host__ __device__ constexpr long_index_t GetBPtrOffset(index_t g_idx) const - { - return g_idx * static_cast(BatchStrideB_); - } - - __host__ __device__ constexpr auto GetDsPtrOffset(index_t g_idx) const - { - Array ds_offset; - static_for<0, NumDTensor, 1>{}( - [&](auto i) { ds_offset(i) = g_idx * static_cast(BatchStrideDs_[i]); }); - return ds_offset; - } - - __host__ __device__ constexpr long_index_t GetEPtrOffset(index_t g_idx) const - { - return g_idx * static_cast(BatchStrideE_); - } - - index_t BatchStrideA_; - index_t BatchStrideB_; - Array BatchStrideDs_; - index_t BatchStrideE_; -}; - -} // namespace - // // @brief Device Convolution operation. 
// diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_xdl_cshuffle.hpp index bcef5c3b62..b04594fceb 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_xdl_cshuffle.hpp @@ -19,6 +19,7 @@ #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" #include "ck/tensor_operation/gpu/device/matrix_padder.hpp" #include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp" #include "ck/host_utility/device_prop.hpp" #include "ck/host_utility/kernel_launch.hpp" #include "ck/host_utility/io.hpp" @@ -29,51 +30,6 @@ namespace device { namespace { -template -struct ComputePtrOffsetOfStridedBatch -{ - ComputePtrOffsetOfStridedBatch() = default; - - ComputePtrOffsetOfStridedBatch(index_t BatchStrideA, - index_t BatchStrideB, - Array BatchStrideDs, - index_t BatchStrideE) - : BatchStrideA_(BatchStrideA), - BatchStrideB_(BatchStrideB), - BatchStrideDs_(BatchStrideDs), - BatchStrideE_(BatchStrideE) - { - } - - __host__ __device__ constexpr long_index_t GetAPtrOffset(index_t g_idx) const - { - return g_idx * static_cast(BatchStrideA_); - } - - __host__ __device__ constexpr long_index_t GetBPtrOffset(index_t g_idx) const - { - return g_idx * static_cast(BatchStrideB_); - } - - __host__ __device__ constexpr auto GetDsPtrOffset(index_t g_idx) const - { - Array ds_offset; - static_for<0, NumDTensor, 1>{}( - [&](auto i) { ds_offset(i) = g_idx * static_cast(BatchStrideDs_[i]); }); - return ds_offset; - } - - __host__ __device__ constexpr long_index_t GetEPtrOffset(index_t g_idx) const - { - return g_idx * static_cast(BatchStrideE_); - } - - index_t BatchStrideA_; - index_t BatchStrideB_; - Array BatchStrideDs_; - index_t BatchStrideE_; -}; - /* * \brief Wrapper function of GridwiseGemm::Run to realize BatchedGEMM. * diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp new file mode 100644 index 0000000000..948232ed6f --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp @@ -0,0 +1,59 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. 
+ +#pragma once + +#include "ck/utility/common_header.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +template +struct ComputePtrOffsetOfStridedBatch +{ + ComputePtrOffsetOfStridedBatch() = default; + + ComputePtrOffsetOfStridedBatch(index_t BatchStrideA, + index_t BatchStrideB, + Array BatchStrideDs, + index_t BatchStrideE) + : BatchStrideA_(BatchStrideA), + BatchStrideB_(BatchStrideB), + BatchStrideDs_(BatchStrideDs), + BatchStrideE_(BatchStrideE) + { + } + + __host__ __device__ constexpr long_index_t GetAPtrOffset(index_t g_idx) const + { + return g_idx * static_cast(BatchStrideA_); + } + + __host__ __device__ constexpr long_index_t GetBPtrOffset(index_t g_idx) const + { + return g_idx * static_cast(BatchStrideB_); + } + + __host__ __device__ constexpr auto GetDsPtrOffset(index_t g_idx) const + { + Array ds_offset; + static_for<0, NumDTensor, 1>{}( + [&](auto i) { ds_offset(i) = g_idx * static_cast(BatchStrideDs_[i]); }); + return ds_offset; + } + + __host__ __device__ constexpr long_index_t GetEPtrOffset(index_t g_idx) const + { + return g_idx * static_cast(BatchStrideE_); + } + + index_t BatchStrideA_; + index_t BatchStrideB_; + Array BatchStrideDs_; + index_t BatchStrideE_; +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_wmma_instance.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_wmma_instance.hpp new file mode 100644 index 0000000000..5db8226e11 --- /dev/null +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_wmma_instance.hpp @@ -0,0 +1,118 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. 
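Aside on the helper factored out above: device_grouped_conv_utils.hpp now owns the single ComputePtrOffsetOfStridedBatch definition that was previously duplicated in the anonymous namespaces of the XDL, DL and WMMA device headers. As a hedged sketch (variable names are placeholders, not taken from the patch) of how a per-group kernel wrapper uses it:

// Hedged sketch, not part of this patch. GetAPtrOffset(g_idx) simply returns
// g_idx * BatchStrideA_ (the G stride of tensor A), so a block working on
// group g_idx offsets each base pointer before calling GridwiseGemm::Run:
const auto p_a_for_group  = p_a_grid + compute_ptr_offset_of_batch.GetAPtrOffset(g_idx);
const auto p_b_for_group  = p_b_grid + compute_ptr_offset_of_batch.GetBPtrOffset(g_idx);
const auto p_e_for_group  = p_e_grid + compute_ptr_offset_of_batch.GetEPtrOffset(g_idx);
const auto ds_offsets     = compute_ptr_offset_of_batch.GetDsPtrOffset(g_idx); // one offset per D
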
+ +#pragma once + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/convolution_backward_data_specialization.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F16 = ck::half_t; +using F32 = float; +using I8 = int8_t; +using I32 = int32_t; + +using Empty_Tuple = ck::Tuple<>; + +template +using S = ck::Sequence; + +using namespace ck::tensor_layout::convolution; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto ConvBwdDataDefault = + ck::tensor_operation::device::ConvolutionBackwardDataSpecialization::Default; + +static constexpr auto ConvBwdData1x1S1P0 = + ck::tensor_operation::device::ConvolutionBackwardDataSpecialization::Filter1x1Stride1Pad0; + +template +using device_grouped_conv_bwd_data_wmma_f16_instances = std::tuple< + // clang-format off + //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| ConvForward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | WMMA| WMMA| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + // generic instance + DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 4>, 1>, + // blocksize=256 + DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + // blocksize=128 + DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 16, 1>, S<0, 2, 1>, S<0, 
2, 1>, 1, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + // blocksize=64 + DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 32, 1, 2>, 8>, + DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 32, 1, 2>, 8>, + DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 32, 1, 2>, 8>, + DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 32, 1, 2>, 8>, + // blocksize=32 + DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 2>, 8>, + DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 2>, 8>, + DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 2>, 8>, + DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 2>, 8> + // clang-format on + >; + +template +using device_grouped_conv_bwd_data_wmma_i8_instances = std::tuple< + // clang-format off + //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| ConvForward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | WMMA| WMMA| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + // generic 
instance + DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 16, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 16, 1, 1, 1, S<1, 32, 1, 4>, 1>, + // blocksize=256 + DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 16, 1, 1, 1, S<1, 32, 1, 8>, 8>, + // blocksize=128 + DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 16, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 16, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 16, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 16, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 16, 1, 1, 1, S<1, 32, 1, 4>, 8>, + // blocksize=64 + DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<8, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 16, 1, 1, 1, S<1, 32, 1, 2>, 8>, + DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<8, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 16, 1, 1, 1, S<1, 32, 1, 2>, 8>, + DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<8, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 16, 1, 1, 1, S<1, 32, 1, 2>, 8>, + DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<8, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 16, 1, 1, 1, S<1, 32, 1, 2>, 8>, + // blocksize=32 + DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<8, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 16, 1, 1, 1, S<1, 16, 1, 2>, 8>, + DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<8, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 16, 1, 1, 1, S<1, 16, 1, 2>, 8>, + DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<8, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 16, 1, 1, 1, S<1, 16, 1, 2>, 8>, + DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<8, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 16, 1, 1, 1, S<1, 16, 1, 2>, 8> + // clang-format on + >; + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_data.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_data.hpp index 2b49e17187..b259c640da 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_data.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_data.hpp @@ -30,6 +30,34 @@ void add_device_grouped_conv2d_bwd_data_xdl_gnhwk_gkyxc_gnhwc_f16_instances( PassThrough, PassThrough, PassThrough>>>& instances); + +void add_device_grouped_conv2d_bwd_data_wmma_gnhwk_gkyxc_gnhwc_f16_instances( + std::vector>>& instances); + +void add_device_grouped_conv2d_bwd_data_wmma_gnhwk_gkyxc_gnhwc_f16_1x1s1p0_instances( + std::vector>>& instances); #endif #ifdef CK_ENABLE_FP32 void 
add_device_grouped_conv2d_bwd_data_xdl_gnhwk_gkyxc_gnhwc_f32_instances( @@ -61,6 +89,35 @@ void add_device_grouped_conv2d_bwd_data_xdl_gnhwk_gkyxc_gnhwc_bf16_instances( PassThrough, PassThrough>>>& instances); #endif +#ifdef CK_ENABLE_INT8 +void add_device_grouped_conv2d_bwd_data_wmma_gnhwk_gkyxc_gnhwc_i8_instances( + std::vector>>& instances); + +void add_device_grouped_conv2d_bwd_data_wmma_gnhwk_gkyxc_gnhwc_i8_1x1s1p0_instances( + std::vector>>& instances); +#endif #ifdef CK_ENABLE_FP16 void add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f16_instances( std::vector>>& instances); + +void add_device_grouped_conv2d_bwd_data_wmma_nhwgk_gkyxc_nhwgc_f16_instances( + std::vector>>& instances); + +void add_device_grouped_conv2d_bwd_data_wmma_nhwgk_gkyxc_nhwgc_f16_1x1s1p0_instances( + std::vector>>& instances); + #endif #ifdef CK_ENABLE_FP32 void add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f32_instances( @@ -106,6 +192,35 @@ void add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_bf16_instances( PassThrough, PassThrough>>>& instances); #endif +#ifdef CK_ENABLE_INT8 +void add_device_grouped_conv2d_bwd_data_wmma_nhwgk_gkyxc_nhwgc_i8_instances( + std::vector>>& instances); + +void add_device_grouped_conv2d_bwd_data_wmma_nhwgk_gkyxc_nhwgc_i8_1x1s1p0_instances( + std::vector>>& instances); +#endif // conv3d backward data #ifdef CK_ENABLE_FP16 void add_device_grouped_conv3d_bwd_data_xdl_gndhwk_gkzyxc_gndhwc_f16_instances( @@ -121,6 +236,34 @@ void add_device_grouped_conv3d_bwd_data_xdl_gndhwk_gkzyxc_gndhwc_f16_instances( PassThrough, PassThrough, PassThrough>>>& instances); + +void add_device_grouped_conv3d_bwd_data_wmma_gndhwk_gkzyxc_gndhwc_f16_instances( + std::vector>>& instances); + +void add_device_grouped_conv3d_bwd_data_wmma_gndhwk_gkzyxc_gndhwc_f16_1x1s1p0_instances( + std::vector>>& instances); #endif #ifdef CK_ENABLE_FP32 void add_device_grouped_conv3d_bwd_data_xdl_gndhwk_gkzyxc_gndhwc_f32_instances( @@ -152,6 +295,35 @@ void add_device_grouped_conv3d_bwd_data_xdl_gndhwk_gkzyxc_gndhwc_bf16_instances( PassThrough, PassThrough>>>& instances); #endif +#ifdef CK_ENABLE_INT8 +void add_device_grouped_conv3d_bwd_data_wmma_gndhwk_gkzyxc_gndhwc_i8_instances( + std::vector>>& instances); + +void add_device_grouped_conv3d_bwd_data_wmma_gndhwk_gkzyxc_gndhwc_i8_1x1s1p0_instances( + std::vector>>& instances); +#endif #ifdef CK_ENABLE_FP16 void add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_f16_instances( std::vector>>& instances); + +void add_device_grouped_conv3d_bwd_data_wmma_ndhwgk_gkzyxc_ndhwgc_f16_instances( + std::vector>>& instances); + +void add_device_grouped_conv3d_bwd_data_wmma_ndhwgk_gkzyxc_ndhwgc_f16_1x1s1p0_instances( + std::vector>>& instances); #endif #ifdef CK_ENABLE_FP32 void add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_f32_instances( @@ -197,6 +397,35 @@ void add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_bf16_instances( PassThrough, PassThrough>>>& instances); #endif +#ifdef CK_ENABLE_INT8 +void add_device_grouped_conv3d_bwd_data_wmma_ndhwgk_gkzyxc_ndhwgc_i8_instances( + std::vector>>& instances); + +void add_device_grouped_conv3d_bwd_data_wmma_ndhwgk_gkzyxc_ndhwgc_i8_1x1s1p0_instances( + std::vector>>& instances); +#endif template ) { add_device_grouped_conv2d_bwd_data_xdl_gnhwk_gkyxc_gnhwc_f16_instances(op_ptrs); + add_device_grouped_conv2d_bwd_data_wmma_gnhwk_gkyxc_gnhwc_f16_instances( + op_ptrs); + add_device_grouped_conv2d_bwd_data_wmma_gnhwk_gkyxc_gnhwc_f16_1x1s1p0_instances( + op_ptrs); } #endif #ifdef 
CK_ENABLE_FP32 @@ -263,6 +496,15 @@ struct DeviceOperationInstanceFactory< add_device_grouped_conv2d_bwd_data_xdl_gnhwk_gkyxc_gnhwc_bf16_instances( op_ptrs); } +#endif +#ifdef CK_ENABLE_INT8 + else if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_grouped_conv2d_bwd_data_wmma_gnhwk_gkyxc_gnhwc_i8_instances(op_ptrs); + add_device_grouped_conv2d_bwd_data_wmma_gnhwk_gkyxc_gnhwc_i8_1x1s1p0_instances( + op_ptrs); + } #endif } else if constexpr(is_same_v && is_same_v && @@ -273,6 +515,10 @@ struct DeviceOperationInstanceFactory< is_same_v) { add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f16_instances(op_ptrs); + add_device_grouped_conv2d_bwd_data_wmma_nhwgk_gkyxc_nhwgc_f16_instances( + op_ptrs); + add_device_grouped_conv2d_bwd_data_wmma_nhwgk_gkyxc_nhwgc_f16_1x1s1p0_instances( + op_ptrs); } #endif #ifdef CK_ENABLE_FP32 @@ -289,6 +535,15 @@ struct DeviceOperationInstanceFactory< add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_bf16_instances( op_ptrs); } +#endif +#ifdef CK_ENABLE_INT8 + else if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_grouped_conv2d_bwd_data_wmma_nhwgk_gkyxc_nhwgc_i8_instances(op_ptrs); + add_device_grouped_conv2d_bwd_data_wmma_nhwgk_gkyxc_nhwgc_i8_1x1s1p0_instances( + op_ptrs); + } #endif } } @@ -304,6 +559,10 @@ struct DeviceOperationInstanceFactory< { add_device_grouped_conv3d_bwd_data_xdl_gndhwk_gkzyxc_gndhwc_f16_instances( op_ptrs); + add_device_grouped_conv3d_bwd_data_wmma_gndhwk_gkzyxc_gndhwc_f16_instances( + op_ptrs); + add_device_grouped_conv3d_bwd_data_wmma_gndhwk_gkzyxc_gndhwc_f16_1x1s1p0_instances( + op_ptrs); } #endif #ifdef CK_ENABLE_FP32 @@ -321,6 +580,16 @@ struct DeviceOperationInstanceFactory< add_device_grouped_conv3d_bwd_data_xdl_gndhwk_gkzyxc_gndhwc_bf16_instances( op_ptrs); } +#endif +#ifdef CK_ENABLE_INT8 + else if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_grouped_conv3d_bwd_data_wmma_gndhwk_gkzyxc_gndhwc_i8_instances( + op_ptrs); + add_device_grouped_conv3d_bwd_data_wmma_gndhwk_gkzyxc_gndhwc_i8_1x1s1p0_instances( + op_ptrs); + } #endif } else if constexpr(is_same_v && is_same_v && @@ -332,6 +601,10 @@ struct DeviceOperationInstanceFactory< { add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_f16_instances( op_ptrs); + add_device_grouped_conv3d_bwd_data_wmma_ndhwgk_gkzyxc_ndhwgc_f16_instances( + op_ptrs); + add_device_grouped_conv3d_bwd_data_wmma_ndhwgk_gkzyxc_ndhwgc_f16_1x1s1p0_instances( + op_ptrs); } #endif #ifdef CK_ENABLE_FP32 @@ -349,6 +622,16 @@ struct DeviceOperationInstanceFactory< add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_bf16_instances( op_ptrs); } +#endif +#ifdef CK_ENABLE_INT8 + else if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_grouped_conv3d_bwd_data_wmma_ndhwgk_gkzyxc_ndhwgc_i8_instances( + op_ptrs); + add_device_grouped_conv3d_bwd_data_wmma_ndhwgk_gkzyxc_ndhwgc_i8_1x1s1p0_instances( + op_ptrs); + } #endif } } diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/CMakeLists.txt index 85ec0f55aa..d8bd0de692 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/CMakeLists.txt @@ -1,8 +1,16 @@ add_instance_library(device_grouped_conv2d_bwd_data_instance - device_grouped_conv2d_bwd_data_xdl_gnhwc_gkyxc_gnhwk_f16_instance.cpp - 
device_grouped_conv2d_bwd_data_xdl_gnhwc_gkyxc_gnhwk_bf16_instance.cpp - device_grouped_conv2d_bwd_data_xdl_gnhwc_gkyxc_gnhwk_f32_instance.cpp - device_grouped_conv2d_bwd_data_xdl_nhwgc_gkyxc_nhwgk_f16_instance.cpp - device_grouped_conv2d_bwd_data_xdl_nhwgc_gkyxc_nhwgk_bf16_instance.cpp - device_grouped_conv2d_bwd_data_xdl_nhwgc_gkyxc_nhwgk_f32_instance.cpp -) + xdl/device_grouped_conv2d_bwd_data_xdl_gnhwc_gkyxc_gnhwk_f16_instance.cpp + xdl/device_grouped_conv2d_bwd_data_xdl_gnhwc_gkyxc_gnhwk_bf16_instance.cpp + xdl/device_grouped_conv2d_bwd_data_xdl_gnhwc_gkyxc_gnhwk_f32_instance.cpp + xdl/device_grouped_conv2d_bwd_data_xdl_nhwgc_gkyxc_nhwgk_f16_instance.cpp + xdl/device_grouped_conv2d_bwd_data_xdl_nhwgc_gkyxc_nhwgk_bf16_instance.cpp + xdl/device_grouped_conv2d_bwd_data_xdl_nhwgc_gkyxc_nhwgk_f32_instance.cpp + + wmma/device_grouped_conv2d_bwd_data_wmma_gnhwc_gkyxc_gnhwk_f16_1x1s1p0_instance.cpp + wmma/device_grouped_conv2d_bwd_data_wmma_nhwgc_gkyxc_nhwgk_f16_1x1s1p0_instance.cpp + wmma/device_grouped_conv2d_bwd_data_wmma_gnhwc_gkyxc_gnhwk_i8_1x1s1p0_instance.cpp + wmma/device_grouped_conv2d_bwd_data_wmma_nhwgc_gkyxc_nhwgk_i8_1x1s1p0_instance.cpp + wmma/device_grouped_conv2d_bwd_data_wmma_gnhwc_gkyxc_gnhwk_f16_instance.cpp + wmma/device_grouped_conv2d_bwd_data_wmma_nhwgc_gkyxc_nhwgk_f16_instance.cpp + wmma/device_grouped_conv2d_bwd_data_wmma_gnhwc_gkyxc_gnhwk_i8_instance.cpp + wmma/device_grouped_conv2d_bwd_data_wmma_nhwgc_gkyxc_nhwgk_i8_instance.cpp) diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/wmma/device_grouped_conv2d_bwd_data_wmma_gnhwc_gkyxc_gnhwk_f16_1x1s1p0_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/wmma/device_grouped_conv2d_bwd_data_wmma_gnhwc_gkyxc_gnhwk_f16_1x1s1p0_instance.cpp new file mode 100644 index 0000000000..ac2ba91b63 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/wmma/device_grouped_conv2d_bwd_data_wmma_gnhwc_gkyxc_gnhwk_f16_1x1s1p0_instance.cpp @@ -0,0 +1,40 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_wmma_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { +void add_device_grouped_conv2d_bwd_data_wmma_gnhwk_gkyxc_gnhwc_f16_1x1s1p0_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_conv_bwd_data_wmma_f16_instances<2, + GNHWK, + GKYXC, + Empty_Tuple, + GNHWC, + Empty_Tuple, + PassThrough, + ConvBwdData1x1S1P0>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/wmma/device_grouped_conv2d_bwd_data_wmma_gnhwc_gkyxc_gnhwk_f16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/wmma/device_grouped_conv2d_bwd_data_wmma_gnhwc_gkyxc_gnhwk_f16_instance.cpp new file mode 100644 index 0000000000..39af70f623 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/wmma/device_grouped_conv2d_bwd_data_wmma_gnhwc_gkyxc_gnhwk_f16_instance.cpp @@ -0,0 +1,40 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_wmma_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { +void add_device_grouped_conv2d_bwd_data_wmma_gnhwk_gkyxc_gnhwc_f16_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_conv_bwd_data_wmma_f16_instances<2, + GNHWK, + GKYXC, + Empty_Tuple, + GNHWC, + Empty_Tuple, + PassThrough, + ConvBwdDataDefault>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/wmma/device_grouped_conv2d_bwd_data_wmma_gnhwc_gkyxc_gnhwk_i8_1x1s1p0_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/wmma/device_grouped_conv2d_bwd_data_wmma_gnhwc_gkyxc_gnhwk_i8_1x1s1p0_instance.cpp new file mode 100644 index 0000000000..1de9f7a95c --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/wmma/device_grouped_conv2d_bwd_data_wmma_gnhwc_gkyxc_gnhwk_i8_1x1s1p0_instance.cpp @@ -0,0 +1,40 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_wmma_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { +void add_device_grouped_conv2d_bwd_data_wmma_gnhwk_gkyxc_gnhwc_i8_1x1s1p0_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_conv_bwd_data_wmma_i8_instances<2, + GNHWK, + GKYXC, + Empty_Tuple, + GNHWC, + Empty_Tuple, + PassThrough, + ConvBwdData1x1S1P0>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/wmma/device_grouped_conv2d_bwd_data_wmma_gnhwc_gkyxc_gnhwk_i8_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/wmma/device_grouped_conv2d_bwd_data_wmma_gnhwc_gkyxc_gnhwk_i8_instance.cpp new file mode 100644 index 0000000000..8eb6558a2a --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/wmma/device_grouped_conv2d_bwd_data_wmma_gnhwc_gkyxc_gnhwk_i8_instance.cpp @@ -0,0 +1,40 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_wmma_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { +void add_device_grouped_conv2d_bwd_data_wmma_gnhwk_gkyxc_gnhwc_i8_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_conv_bwd_data_wmma_i8_instances<2, + GNHWK, + GKYXC, + Empty_Tuple, + GNHWC, + Empty_Tuple, + PassThrough, + ConvBwdDataDefault>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/wmma/device_grouped_conv2d_bwd_data_wmma_nhwgc_gkyxc_nhwgk_f16_1x1s1p0_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/wmma/device_grouped_conv2d_bwd_data_wmma_nhwgc_gkyxc_nhwgk_f16_1x1s1p0_instance.cpp new file mode 100644 index 0000000000..f46cdf7f18 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/wmma/device_grouped_conv2d_bwd_data_wmma_nhwgc_gkyxc_nhwgk_f16_1x1s1p0_instance.cpp @@ -0,0 +1,40 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_wmma_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { +void add_device_grouped_conv2d_bwd_data_wmma_nhwgk_gkyxc_nhwgc_f16_1x1s1p0_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_conv_bwd_data_wmma_f16_instances<2, + NHWGK, + GKYXC, + Empty_Tuple, + NHWGC, + Empty_Tuple, + PassThrough, + ConvBwdData1x1S1P0>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/wmma/device_grouped_conv2d_bwd_data_wmma_nhwgc_gkyxc_nhwgk_f16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/wmma/device_grouped_conv2d_bwd_data_wmma_nhwgc_gkyxc_nhwgk_f16_instance.cpp new file mode 100644 index 0000000000..2f4659524f --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/wmma/device_grouped_conv2d_bwd_data_wmma_nhwgc_gkyxc_nhwgk_f16_instance.cpp @@ -0,0 +1,40 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_wmma_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { +void add_device_grouped_conv2d_bwd_data_wmma_nhwgk_gkyxc_nhwgc_f16_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_conv_bwd_data_wmma_f16_instances<2, + NHWGK, + GKYXC, + Empty_Tuple, + NHWGC, + Empty_Tuple, + PassThrough, + ConvBwdDataDefault>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/wmma/device_grouped_conv2d_bwd_data_wmma_nhwgc_gkyxc_nhwgk_i8_1x1s1p0_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/wmma/device_grouped_conv2d_bwd_data_wmma_nhwgc_gkyxc_nhwgk_i8_1x1s1p0_instance.cpp new file mode 100644 index 0000000000..789b80b8ac --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/wmma/device_grouped_conv2d_bwd_data_wmma_nhwgc_gkyxc_nhwgk_i8_1x1s1p0_instance.cpp @@ -0,0 +1,40 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_wmma_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { +void add_device_grouped_conv2d_bwd_data_wmma_nhwgk_gkyxc_nhwgc_i8_1x1s1p0_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_conv_bwd_data_wmma_i8_instances<2, + NHWGK, + GKYXC, + Empty_Tuple, + NHWGC, + Empty_Tuple, + PassThrough, + ConvBwdData1x1S1P0>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/wmma/device_grouped_conv2d_bwd_data_wmma_nhwgc_gkyxc_nhwgk_i8_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/wmma/device_grouped_conv2d_bwd_data_wmma_nhwgc_gkyxc_nhwgk_i8_instance.cpp new file mode 100644 index 0000000000..71394762b4 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/wmma/device_grouped_conv2d_bwd_data_wmma_nhwgc_gkyxc_nhwgk_i8_instance.cpp @@ -0,0 +1,40 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_wmma_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { +void add_device_grouped_conv2d_bwd_data_wmma_nhwgk_gkyxc_nhwgc_i8_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_conv_bwd_data_wmma_i8_instances<2, + NHWGK, + GKYXC, + Empty_Tuple, + NHWGC, + Empty_Tuple, + PassThrough, + ConvBwdDataDefault>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/device_grouped_conv2d_bwd_data_xdl_gnhwc_gkyxc_gnhwk_bf16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/xdl/device_grouped_conv2d_bwd_data_xdl_gnhwc_gkyxc_gnhwk_bf16_instance.cpp similarity index 100% rename from library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/device_grouped_conv2d_bwd_data_xdl_gnhwc_gkyxc_gnhwk_bf16_instance.cpp rename to library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/xdl/device_grouped_conv2d_bwd_data_xdl_gnhwc_gkyxc_gnhwk_bf16_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/device_grouped_conv2d_bwd_data_xdl_gnhwc_gkyxc_gnhwk_f16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/xdl/device_grouped_conv2d_bwd_data_xdl_gnhwc_gkyxc_gnhwk_f16_instance.cpp similarity index 100% rename from library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/device_grouped_conv2d_bwd_data_xdl_gnhwc_gkyxc_gnhwk_f16_instance.cpp rename to library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/xdl/device_grouped_conv2d_bwd_data_xdl_gnhwc_gkyxc_gnhwk_f16_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/device_grouped_conv2d_bwd_data_xdl_gnhwc_gkyxc_gnhwk_f32_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/xdl/device_grouped_conv2d_bwd_data_xdl_gnhwc_gkyxc_gnhwk_f32_instance.cpp similarity index 100% rename from library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/device_grouped_conv2d_bwd_data_xdl_gnhwc_gkyxc_gnhwk_f32_instance.cpp rename to library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/xdl/device_grouped_conv2d_bwd_data_xdl_gnhwc_gkyxc_gnhwk_f32_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/device_grouped_conv2d_bwd_data_xdl_nhwgc_gkyxc_nhwgk_bf16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/xdl/device_grouped_conv2d_bwd_data_xdl_nhwgc_gkyxc_nhwgk_bf16_instance.cpp similarity index 100% rename from library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/device_grouped_conv2d_bwd_data_xdl_nhwgc_gkyxc_nhwgk_bf16_instance.cpp rename to library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/xdl/device_grouped_conv2d_bwd_data_xdl_nhwgc_gkyxc_nhwgk_bf16_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/device_grouped_conv2d_bwd_data_xdl_nhwgc_gkyxc_nhwgk_f16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/xdl/device_grouped_conv2d_bwd_data_xdl_nhwgc_gkyxc_nhwgk_f16_instance.cpp similarity index 100% rename from 
library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/device_grouped_conv2d_bwd_data_xdl_nhwgc_gkyxc_nhwgk_f16_instance.cpp rename to library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/xdl/device_grouped_conv2d_bwd_data_xdl_nhwgc_gkyxc_nhwgk_f16_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/device_grouped_conv2d_bwd_data_xdl_nhwgc_gkyxc_nhwgk_f32_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/xdl/device_grouped_conv2d_bwd_data_xdl_nhwgc_gkyxc_nhwgk_f32_instance.cpp similarity index 100% rename from library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/device_grouped_conv2d_bwd_data_xdl_nhwgc_gkyxc_nhwgk_f32_instance.cpp rename to library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/xdl/device_grouped_conv2d_bwd_data_xdl_nhwgc_gkyxc_nhwgk_f32_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/CMakeLists.txt index 8383d622b2..f44e97ed75 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/CMakeLists.txt @@ -1,8 +1,17 @@ add_instance_library(device_grouped_conv3d_bwd_data_instance - device_grouped_conv3d_bwd_data_xdl_gndhwc_gkzyxc_gndhwk_f16_instance.cpp - device_grouped_conv3d_bwd_data_xdl_gndhwc_gkzyxc_gndhwk_bf16_instance.cpp - device_grouped_conv3d_bwd_data_xdl_gndhwc_gkzyxc_gndhwk_f32_instance.cpp - device_grouped_conv3d_bwd_data_xdl_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp - device_grouped_conv3d_bwd_data_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp - device_grouped_conv3d_bwd_data_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instance.cpp + xdl/device_grouped_conv3d_bwd_data_xdl_gndhwc_gkzyxc_gndhwk_f16_instance.cpp + xdl/device_grouped_conv3d_bwd_data_xdl_gndhwc_gkzyxc_gndhwk_bf16_instance.cpp + xdl/device_grouped_conv3d_bwd_data_xdl_gndhwc_gkzyxc_gndhwk_f32_instance.cpp + xdl/device_grouped_conv3d_bwd_data_xdl_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp + xdl/device_grouped_conv3d_bwd_data_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp + xdl/device_grouped_conv3d_bwd_data_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instance.cpp + + wmma/device_grouped_conv3d_bwd_data_wmma_gndhwc_gkzyxc_gndhwk_f16_instance.cpp + wmma/device_grouped_conv3d_bwd_data_wmma_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp + wmma/device_grouped_conv3d_bwd_data_wmma_gndhwc_gkzyxc_gndhwk_i8_instance.cpp + wmma/device_grouped_conv3d_bwd_data_wmma_ndhwgc_gkzyxc_ndhwgk_i8_instance.cpp + wmma/device_grouped_conv3d_bwd_data_wmma_gndhwc_gkzyxc_gndhwk_f16_1x1s1p0_instance.cpp + wmma/device_grouped_conv3d_bwd_data_wmma_ndhwgc_gkzyxc_ndhwgk_f16_1x1s1p0_instance.cpp + wmma/device_grouped_conv3d_bwd_data_wmma_gndhwc_gkzyxc_gndhwk_i8_1x1s1p0_instance.cpp + wmma/device_grouped_conv3d_bwd_data_wmma_ndhwgc_gkzyxc_ndhwgk_i8_1x1s1p0_instance.cpp ) diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/wmma/device_grouped_conv3d_bwd_data_wmma_gndhwc_gkzyxc_gndhwk_f16_1x1s1p0_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/wmma/device_grouped_conv3d_bwd_data_wmma_gndhwc_gkzyxc_gndhwk_f16_1x1s1p0_instance.cpp new file mode 100644 index 0000000000..ba2f6bdc0a --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/wmma/device_grouped_conv3d_bwd_data_wmma_gndhwc_gkzyxc_gndhwk_f16_1x1s1p0_instance.cpp @@ -0,0 +1,40 @@ +// 
SPDX-License-Identifier: MIT +// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_wmma_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { +void add_device_grouped_conv3d_bwd_data_wmma_gndhwk_gkzyxc_gndhwc_f16_1x1s1p0_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_conv_bwd_data_wmma_f16_instances<3, + GNDHWK, + GKZYXC, + Empty_Tuple, + GNDHWC, + Empty_Tuple, + PassThrough, + ConvBwdData1x1S1P0>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/wmma/device_grouped_conv3d_bwd_data_wmma_gndhwc_gkzyxc_gndhwk_f16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/wmma/device_grouped_conv3d_bwd_data_wmma_gndhwc_gkzyxc_gndhwk_f16_instance.cpp new file mode 100644 index 0000000000..26403bf0e5 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/wmma/device_grouped_conv3d_bwd_data_wmma_gndhwc_gkzyxc_gndhwk_f16_instance.cpp @@ -0,0 +1,40 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_wmma_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { +void add_device_grouped_conv3d_bwd_data_wmma_gndhwk_gkzyxc_gndhwc_f16_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_conv_bwd_data_wmma_f16_instances<3, + GNDHWK, + GKZYXC, + Empty_Tuple, + GNDHWC, + Empty_Tuple, + PassThrough, + ConvBwdDataDefault>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/wmma/device_grouped_conv3d_bwd_data_wmma_gndhwc_gkzyxc_gndhwk_i8_1x1s1p0_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/wmma/device_grouped_conv3d_bwd_data_wmma_gndhwc_gkzyxc_gndhwk_i8_1x1s1p0_instance.cpp new file mode 100644 index 0000000000..9e453a36c5 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/wmma/device_grouped_conv3d_bwd_data_wmma_gndhwc_gkzyxc_gndhwk_i8_1x1s1p0_instance.cpp @@ -0,0 +1,40 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_wmma_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { +void add_device_grouped_conv3d_bwd_data_wmma_gndhwk_gkzyxc_gndhwc_i8_1x1s1p0_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_conv_bwd_data_wmma_i8_instances<3, + GNDHWK, + GKZYXC, + Empty_Tuple, + GNDHWC, + Empty_Tuple, + PassThrough, + ConvBwdData1x1S1P0>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/wmma/device_grouped_conv3d_bwd_data_wmma_gndhwc_gkzyxc_gndhwk_i8_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/wmma/device_grouped_conv3d_bwd_data_wmma_gndhwc_gkzyxc_gndhwk_i8_instance.cpp new file mode 100644 index 0000000000..6f7d4f79d1 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/wmma/device_grouped_conv3d_bwd_data_wmma_gndhwc_gkzyxc_gndhwk_i8_instance.cpp @@ -0,0 +1,40 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_wmma_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { +void add_device_grouped_conv3d_bwd_data_wmma_gndhwk_gkzyxc_gndhwc_i8_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_conv_bwd_data_wmma_i8_instances<3, + GNDHWK, + GKZYXC, + Empty_Tuple, + GNDHWC, + Empty_Tuple, + PassThrough, + ConvBwdDataDefault>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/wmma/device_grouped_conv3d_bwd_data_wmma_ndhwgc_gkzyxc_ndhwgk_f16_1x1s1p0_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/wmma/device_grouped_conv3d_bwd_data_wmma_ndhwgc_gkzyxc_ndhwgk_f16_1x1s1p0_instance.cpp new file mode 100644 index 0000000000..4475a0e456 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/wmma/device_grouped_conv3d_bwd_data_wmma_ndhwgc_gkzyxc_ndhwgk_f16_1x1s1p0_instance.cpp @@ -0,0 +1,40 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_wmma_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { +void add_device_grouped_conv3d_bwd_data_wmma_ndhwgk_gkzyxc_ndhwgc_f16_1x1s1p0_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_conv_bwd_data_wmma_f16_instances<3, + NDHWGK, + GKZYXC, + Empty_Tuple, + NDHWGC, + Empty_Tuple, + PassThrough, + ConvBwdData1x1S1P0>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/wmma/device_grouped_conv3d_bwd_data_wmma_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/wmma/device_grouped_conv3d_bwd_data_wmma_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp new file mode 100644 index 0000000000..f3941636f8 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/wmma/device_grouped_conv3d_bwd_data_wmma_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp @@ -0,0 +1,40 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_wmma_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { +void add_device_grouped_conv3d_bwd_data_wmma_ndhwgk_gkzyxc_ndhwgc_f16_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_conv_bwd_data_wmma_f16_instances<3, + NDHWGK, + GKZYXC, + Empty_Tuple, + NDHWGC, + Empty_Tuple, + PassThrough, + ConvBwdDataDefault>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/wmma/device_grouped_conv3d_bwd_data_wmma_ndhwgc_gkzyxc_ndhwgk_i8_1x1s1p0_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/wmma/device_grouped_conv3d_bwd_data_wmma_ndhwgc_gkzyxc_ndhwgk_i8_1x1s1p0_instance.cpp new file mode 100644 index 0000000000..b8479c9c9d --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/wmma/device_grouped_conv3d_bwd_data_wmma_ndhwgc_gkzyxc_ndhwgk_i8_1x1s1p0_instance.cpp @@ -0,0 +1,40 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_wmma_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { +void add_device_grouped_conv3d_bwd_data_wmma_ndhwgk_gkzyxc_ndhwgc_i8_1x1s1p0_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_conv_bwd_data_wmma_i8_instances<3, + NDHWGK, + GKZYXC, + Empty_Tuple, + NDHWGC, + Empty_Tuple, + PassThrough, + ConvBwdData1x1S1P0>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/wmma/device_grouped_conv3d_bwd_data_wmma_ndhwgc_gkzyxc_ndhwgk_i8_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/wmma/device_grouped_conv3d_bwd_data_wmma_ndhwgc_gkzyxc_ndhwgk_i8_instance.cpp new file mode 100644 index 0000000000..d8a74e744b --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/wmma/device_grouped_conv3d_bwd_data_wmma_ndhwgc_gkzyxc_ndhwgk_i8_instance.cpp @@ -0,0 +1,40 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_wmma_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { +void add_device_grouped_conv3d_bwd_data_wmma_ndhwgk_gkzyxc_ndhwgc_i8_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_conv_bwd_data_wmma_i8_instances<3, + NDHWGK, + GKZYXC, + Empty_Tuple, + NDHWGC, + Empty_Tuple, + PassThrough, + ConvBwdDataDefault>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/device_grouped_conv3d_bwd_data_xdl_gndhwc_gkzyxc_gndhwk_bf16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/xdl/device_grouped_conv3d_bwd_data_xdl_gndhwc_gkzyxc_gndhwk_bf16_instance.cpp similarity index 100% rename from library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/device_grouped_conv3d_bwd_data_xdl_gndhwc_gkzyxc_gndhwk_bf16_instance.cpp rename to library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/xdl/device_grouped_conv3d_bwd_data_xdl_gndhwc_gkzyxc_gndhwk_bf16_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/device_grouped_conv3d_bwd_data_xdl_gndhwc_gkzyxc_gndhwk_f16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/xdl/device_grouped_conv3d_bwd_data_xdl_gndhwc_gkzyxc_gndhwk_f16_instance.cpp similarity index 100% rename from library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/device_grouped_conv3d_bwd_data_xdl_gndhwc_gkzyxc_gndhwk_f16_instance.cpp rename to library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/xdl/device_grouped_conv3d_bwd_data_xdl_gndhwc_gkzyxc_gndhwk_f16_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/device_grouped_conv3d_bwd_data_xdl_gndhwc_gkzyxc_gndhwk_f32_instance.cpp 
b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/xdl/device_grouped_conv3d_bwd_data_xdl_gndhwc_gkzyxc_gndhwk_f32_instance.cpp similarity index 100% rename from library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/device_grouped_conv3d_bwd_data_xdl_gndhwc_gkzyxc_gndhwk_f32_instance.cpp rename to library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/xdl/device_grouped_conv3d_bwd_data_xdl_gndhwc_gkzyxc_gndhwk_f32_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/device_grouped_conv3d_bwd_data_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/xdl/device_grouped_conv3d_bwd_data_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp similarity index 100% rename from library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/device_grouped_conv3d_bwd_data_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp rename to library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/xdl/device_grouped_conv3d_bwd_data_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/device_grouped_conv3d_bwd_data_xdl_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/xdl/device_grouped_conv3d_bwd_data_xdl_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp similarity index 100% rename from library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/device_grouped_conv3d_bwd_data_xdl_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp rename to library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/xdl/device_grouped_conv3d_bwd_data_xdl_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/device_grouped_conv3d_bwd_data_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/xdl/device_grouped_conv3d_bwd_data_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instance.cpp similarity index 100% rename from library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/device_grouped_conv3d_bwd_data_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instance.cpp rename to library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/xdl/device_grouped_conv3d_bwd_data_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instance.cpp diff --git a/test/grouped_convnd_bwd_data/CMakeLists.txt b/test/grouped_convnd_bwd_data/CMakeLists.txt index 436ab2af76..9773e5a9c6 100644 --- a/test/grouped_convnd_bwd_data/CMakeLists.txt +++ b/test/grouped_convnd_bwd_data/CMakeLists.txt @@ -1,6 +1,19 @@ -if(GPU_TARGETS MATCHES "gfx908" OR GPU_TARGETS MATCHES "gfx90a" OR GPU_TARGETS MATCHES "gfx940") - add_gtest_executable(test_grouped_convnd_bwd_data test_grouped_convnd_bwd_data.cpp) - target_link_libraries(test_grouped_convnd_bwd_data PRIVATE utility device_grouped_conv2d_bwd_data_instance device_grouped_conv3d_bwd_data_instance) - add_gtest_executable(test_grouped_convnd_bwd_data_interface test_grouped_convnd_bwd_data_interface.cpp) - target_link_libraries(test_grouped_convnd_bwd_data_interface PRIVATE utility device_grouped_conv2d_bwd_data_instance) -endif() \ No newline at end of file +list(APPEND gpu_list_xdl gfx908 gfx90a gfx940) +list(APPEND gpu_list_wmma gfx1100 gfx1101 gfx1102) +set(target 0) +foreach(gpu IN LISTS GPU_TARGETS) + if(gpu IN_LIST gpu_list_xdl AND target EQUAL 0) + add_gtest_executable(test_grouped_convnd_bwd_data test_grouped_convnd_bwd_data.cpp) + target_link_libraries(test_grouped_convnd_bwd_data PRIVATE utility 
device_grouped_conv2d_bwd_data_instance device_grouped_conv3d_bwd_data_instance) + add_gtest_executable(test_grouped_convnd_bwd_data_interface test_grouped_convnd_bwd_data_interface_xdl.cpp) + target_link_libraries(test_grouped_convnd_bwd_data_interface PRIVATE utility device_grouped_conv2d_bwd_data_instance) + set(target 1) + endif() + if(gpu IN_LIST gpu_list_wmma AND target EQUAL 0) + add_gtest_executable(test_grouped_convnd_bwd_data test_grouped_convnd_bwd_data.cpp) + target_link_libraries(test_grouped_convnd_bwd_data PRIVATE utility device_grouped_conv2d_bwd_data_instance device_grouped_conv3d_bwd_data_instance) + add_gtest_executable(test_grouped_convnd_bwd_data_interface test_grouped_convnd_bwd_data_interface_wmma.cpp) + target_link_libraries(test_grouped_convnd_bwd_data_interface PRIVATE utility device_grouped_conv2d_bwd_data_instance) + set(target 1) + endif() +endforeach() \ No newline at end of file diff --git a/test/grouped_convnd_bwd_data/test_grouped_convnd_bwd_data.cpp b/test/grouped_convnd_bwd_data/test_grouped_convnd_bwd_data.cpp index bdad37fe14..96506b876d 100644 --- a/test/grouped_convnd_bwd_data/test_grouped_convnd_bwd_data.cpp +++ b/test/grouped_convnd_bwd_data/test_grouped_convnd_bwd_data.cpp @@ -51,16 +51,20 @@ using namespace ck::tensor_layout::convolution; using KernelTypes2d = ::testing::Types, std::tuple, std::tuple, + std::tuple, std::tuple, std::tuple, - std::tuple>; + std::tuple, + std::tuple>; using KernelTypes3d = ::testing::Types, std::tuple, std::tuple, + std::tuple, std::tuple, std::tuple, - std::tuple>; + std::tuple, + std::tuple>; template class TestGroupedConvndBwdData2d : public TestGroupedConvndBwdData diff --git a/test/grouped_convnd_bwd_data/test_grouped_convnd_bwd_data_interface_wmma.cpp b/test/grouped_convnd_bwd_data/test_grouped_convnd_bwd_data_interface_wmma.cpp new file mode 100644 index 0000000000..c0429c6d09 --- /dev/null +++ b/test/grouped_convnd_bwd_data/test_grouped_convnd_bwd_data_interface_wmma.cpp @@ -0,0 +1,178 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/convolution_backward_data_specialization.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle.hpp" + +#include "ck/library/utility/convolution_parameter.hpp" +#include "ck/library/utility/algorithm.hpp" +#include "ck/library/utility/convolution_host_tensor_descriptor_helper.hpp" + +#include + +using DataType = ck::half_t; +using AccDataType = float; +using Pass = ck::tensor_operation::element_wise::PassThrough; + +template +using S = ck::Sequence; +using ConvBackwardDataSpecialization = + ck::tensor_operation::device::ConvolutionBackwardDataSpecialization; + +static constexpr auto ConvBwdDataDefault = ConvBackwardDataSpecialization::Default; +static constexpr auto Filter1x1Stride1Pad0 = ConvBackwardDataSpecialization::Filter1x1Stride1Pad0; + +template +class TestGroupedConvndBwdData : public ::testing::Test +{ + protected: + static constexpr ck::index_t NDimSpatial = 2; + + using OutLayout = std::tuple_element_t<0, Tuple>; + using WeiLayout = std::tuple_element_t<1, Tuple>; + using InLayout = std::tuple_element_t<2, Tuple>; + + // clang-format off + using GroupedConvBwdDataDeviceInstance = ck::tensor_operation::device::DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle + //| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| ConvForward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | WMMA| WMMA| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //| | | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + < NDimSpatial,OutLayout, WeiLayout, ck::Tuple<>, InLayout, DataType, DataType, AccDataType, DataType, ck::Tuple<>, DataType, Pass, Pass, Pass, ConvSpec, 64, 32, 64, 8, 8, 16, 16, 1, 4, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 32, 1, 2>, 8>; + // clang-format on + + ck::utils::conv::ConvParam conv_param; + + template + bool Run() + { + + const auto out_g_n_k_wos_desc = + ck::utils::conv::make_output_host_tensor_descriptor_g_n_k_wos_packed( + conv_param); + + const auto wei_g_k_c_xs_desc = + ck::utils::conv::make_weight_host_tensor_descriptor_g_k_c_xs_packed( + conv_param); + + const auto in_g_n_c_wis_desc = + ck::utils::conv::make_input_host_tensor_descriptor_g_n_c_wis_packed( + conv_param); + + std::array out_lengths{}; + std::array out_strides{}; + 
std::array wei_lengths{}; + std::array wei_strides{}; + std::array in_lengths{}; + std::array in_strides{}; + std::array conv_filter_strides{}; + std::array conv_filter_dilations{}; + std::array input_left_pads{}; + std::array input_right_pads{}; + + auto copy = [](const auto& x, auto& y) { ck::ranges::copy(x, y.begin()); }; + + copy(out_g_n_k_wos_desc.GetLengths(), out_lengths); + copy(out_g_n_k_wos_desc.GetStrides(), out_strides); + copy(wei_g_k_c_xs_desc.GetLengths(), wei_lengths); + copy(wei_g_k_c_xs_desc.GetStrides(), wei_strides); + copy(in_g_n_c_wis_desc.GetLengths(), in_lengths); + copy(in_g_n_c_wis_desc.GetStrides(), in_strides); + copy(conv_param.conv_filter_strides_, conv_filter_strides); + copy(conv_param.conv_filter_dilations_, conv_filter_dilations); + copy(conv_param.input_left_pads_, input_left_pads); + copy(conv_param.input_right_pads_, input_right_pads); + + auto conv = GroupedConvBwdDataDeviceInstance{}; + + auto argument = conv.MakeArgument(nullptr, + nullptr, + std::array{}, + nullptr, + out_lengths, + out_strides, + wei_lengths, + wei_strides, + {}, + {}, + in_lengths, + in_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + Pass{}, + Pass{}, + Pass{}); + return conv.IsSupportedArgument(argument); + } +}; + +using GNHWC = ck::tensor_layout::convolution::GNHWC; +using NHWGC = ck::tensor_layout::convolution::NHWGC; + +using GKYXC = ck::tensor_layout::convolution::GKYXC; + +using GNHWK = ck::tensor_layout::convolution::GNHWK; +using NHWGK = ck::tensor_layout::convolution::NHWGK; + +using KernelTypes = + ::testing::Types, std::tuple>; + +template +class TestGroupedConvndBwdDataDefault : public TestGroupedConvndBwdData +{ +}; + +template +class TestGroupedConvndBwdDataFilter1x1 + : public TestGroupedConvndBwdData +{ +}; + +TYPED_TEST_SUITE(TestGroupedConvndBwdDataDefault, KernelTypes); +TYPED_TEST_SUITE(TestGroupedConvndBwdDataFilter1x1, KernelTypes); + +TYPED_TEST(TestGroupedConvndBwdDataFilter1x1, SpecializationCheck) +{ + // Check filter 3,3 instead of 1,1 + this->conv_param = {2, 2, 4, 192, 192, {3, 3}, {28, 28}, {1, 1}, {1, 1}, {0, 0}, {0, 0}}; + bool is_supported = this->template Run<2>(); + EXPECT_FALSE(is_supported); + + // Check strides 2,2 instead of 1,1 + this->conv_param = {2, 2, 4, 192, 192, {1, 1}, {28, 28}, {2, 2}, {1, 1}, {0, 0}, {0, 0}}; + is_supported = this->template Run<2>(); + EXPECT_FALSE(is_supported); + + // Check with pad + this->conv_param = {2, 2, 4, 192, 192, {1, 1}, {28, 28}, {1, 1}, {1, 1}, {1, 1}, {1, 1}}; + is_supported = this->template Run<2>(); + EXPECT_FALSE(is_supported); + + // Supported version + this->conv_param = {2, 2, 4, 192, 192, {1, 1}, {28, 28}, {1, 1}, {1, 1}, {0, 0}, {0, 0}}; + is_supported = this->template Run<2>(); + EXPECT_TRUE(is_supported); +} + +TYPED_TEST(TestGroupedConvndBwdDataDefault, VectorLoadCheck) +{ + // vector load for A + this->conv_param = {2, 2, 128, 129, 256, {1, 1}, {7, 7}, {2, 2}, {1, 1}, {0, 0}, {0, 0}}; + bool is_supported = this->template Run<2>(); + EXPECT_FALSE(is_supported); + // vector load for B, E, Ds + this->conv_param = {2, 2, 128, 128, 257, {1, 1}, {7, 7}, {2, 2}, {1, 1}, {0, 0}, {0, 0}}; + is_supported = this->template Run<2>(); + EXPECT_FALSE(is_supported); +} diff --git a/test/grouped_convnd_bwd_data/test_grouped_convnd_bwd_data_interface.cpp b/test/grouped_convnd_bwd_data/test_grouped_convnd_bwd_data_interface_xdl.cpp similarity index 100% rename from test/grouped_convnd_bwd_data/test_grouped_convnd_bwd_data_interface.cpp rename to 
test/grouped_convnd_bwd_data/test_grouped_convnd_bwd_data_interface_xdl.cpp
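For reference, the factory changes above mean that the new WMMA instances are returned alongside the existing XDL ones whenever a matching layout/data-type combination is requested through CK's instance factory; callers do not change. The sketch below shows that lookup path for the 2D fp16 GNHWK/GKYXC/GNHWC case. It is a minimal, hedged example: the public factory header path, the `DeviceGroupedConvBwdDataMultipleD` template parameter order, and the `GetInstances()` / `GetTypeString()` entry points are assumed from the declarations in this diff and from the pattern used by the existing XDL instances, so names may need adjusting against the actual tree.

```cpp
// Hedged sketch: enumerate the grouped conv2d bwd-data instances registered in
// the factory (XDL and, after this change, WMMA) for the fp16 GNHWK/GKYXC/GNHWC
// case. Header paths and the template parameter order are assumptions taken
// from the declarations in this diff; verify against the checked-in headers.
#include <iostream>
#include <memory>
#include <vector>

#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/device_grouped_conv_bwd_data_multiple_d.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_data.hpp"

int main()
{
    using namespace ck::tensor_operation::device;
    using namespace ck::tensor_layout::convolution;
    using PassThrough = ck::tensor_operation::element_wise::PassThrough;

    // Out (A), Wei (B), Ds, In (E) layouts and data types, matching the order
    // used by the add_device_grouped_conv2d_bwd_data_*_instances declarations.
    using DeviceOp = DeviceGroupedConvBwdDataMultipleD<2,           // NumDimSpatial
                                                       GNHWK,       // out layout
                                                       GKYXC,       // weight layout
                                                       ck::Tuple<>, // Ds layouts
                                                       GNHWC,       // in layout
                                                       ck::half_t,  // out data type
                                                       ck::half_t,  // weight data type
                                                       ck::Tuple<>, // Ds data types
                                                       ck::half_t,  // in data type
                                                       PassThrough,
                                                       PassThrough,
                                                       PassThrough>;

    // The factory aggregates everything registered by the
    // add_device_grouped_conv2d_bwd_data_xdl_/_wmma_ functions shown above.
    const auto op_ptrs = instance::DeviceOperationInstanceFactory<DeviceOp>::GetInstances();

    std::cout << "found " << op_ptrs.size() << " instances:\n";
    for(const auto& op : op_ptrs)
    {
        // The type string typically encodes whether an instance is an XDL or
        // WMMA kernel, which is enough to see both families side by side.
        std::cout << "  " << op->GetTypeString() << '\n';
    }
    return 0;
}
```

As with the existing XDL path, which of the returned instances can actually run depends on the target GPU: the WMMA instances are intended for the gfx1100/gfx1101/gfx1102 targets added to `gpu_list_wmma` in the CMake changes, and callers are expected to filter with `IsSupportedArgument` before launching, so the XDL/WMMA split stays transparent at the factory level.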