diff --git a/client_example/CMakeLists.txt b/client_example/CMakeLists.txt
index 8e7aa76f87..1dfa845306 100644
--- a/client_example/CMakeLists.txt
+++ b/client_example/CMakeLists.txt
@@ -6,9 +6,10 @@ find_package(composable_kernel 1.0.0 COMPONENTS device_operations)
 find_package(hip REQUIRED PATHS /opt/rocm)
 message(STATUS "Build with HIP ${hip_VERSION}")
 
-add_subdirectory(01_gemm)
-add_subdirectory(02_gemm_add_add_fastgelu)
-add_subdirectory(03_gemm_layernorm)
-add_subdirectory(04_contraction)
-add_subdirectory(05_layernorm)
-add_subdirectory(06_softmax)
+# add all example subdirectories
+file(GLOB dir_list LIST_DIRECTORIES true *)
+foreach(subdir ${dir_list})
+    if(IS_DIRECTORY "${subdir}")
+        add_subdirectory(${subdir})
+    endif()
+endforeach()
diff --git a/example/24_batched_gemm_e_permute/batched_gemm_e_permute_xdl_fp16.cpp b/example/24_batched_gemm_e_permute/batched_gemm_e_permute_xdl_fp16.cpp
deleted file mode 100644
index 5b7f988134..0000000000
--- a/example/24_batched_gemm_e_permute/batched_gemm_e_permute_xdl_fp16.cpp
+++ /dev/null
@@ -1,258 +0,0 @@
-#include
-#include
-#include
-#include
-
-#include "ck/ck.hpp"
-#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
-#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
-#include "ck/tensor_operation/gpu/device/device_batched_gemm_e_permute_xdl.hpp"
-#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
-
-#include "ck/library/utility/check_err.hpp"
-#include "ck/library/utility/device_memory.hpp"
-#include "ck/library/utility/host_tensor.hpp"
-#include "ck/library/utility/host_tensor_generator.hpp"
-#include "ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp"
-
-template <ck::index_t... Is>
-using S = ck::Sequence<Is...>;
-
-using F16 = ck::half_t;
-using F32 = float;
-
-using Row = ck::tensor_layout::gemm::RowMajor;
-using Col = ck::tensor_layout::gemm::ColumnMajor;
-
-using PassThrough = ck::tensor_operation::element_wise::PassThrough;
-
-using ADataType = F16;
-using BDataType = F16;
-using AccDataType = F32;
-using CShuffleDataType = F16;
-using EDataType = F16;
-
-using ALayout = Row;
-using BLayout = Col;
-using ELayout = Row;
-
-using AElementOp = PassThrough;
-using BElementOp = PassThrough;
-using CDEElementOp = PassThrough;
-
-static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
-
-using DeviceGemmInstance = ck::tensor_operation::device::DeviceBatchedGemmEPermuteXdl
-    // clang-format off
-//######| ALayout| BLayout| ELayout| AData| BData| AccData| CShuffle| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
-//######| | | | Type| Type| Type| DataType| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
-//######| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| 
PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| -//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - < ALayout, BLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>; -// clang-format on - -using ReferenceBatchedGemmInstance = ck::tensor_operation::host::ReferenceBatchedGemm; - -int main(int argc, char* argv[]) -{ - bool do_verification = true; - int init_method = 1; - bool time_kernel = false; - - const int M = 256; - const int N = 128; - const int K = 64; - - const int stride_A = K; - const int stride_B = K; - - const int batch_stride_A = M * K; - const int batch_stride_B = K * N; - - const int G0 = 16; - const int G1 = 8; - - const int batch_count = G0 * G1; - - // output layout - [G0, M, G1, N] - const int stride_G0 = M * G1 * N; - const int stride_G1 = N; - const int stride_M = G1 * N; - const int stride_N = 1; - - if(argc == 4) - { - do_verification = std::stoi(argv[1]); - init_method = std::stoi(argv[2]); - time_kernel = std::stoi(argv[3]); - } - else - { - printf("arg1: verification (0=no, 1=yes)\n"); - printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); - printf("arg3: time kernel (0=n0, 1=yes)\n"); - exit(0); - } - - // GEMM shape - ck::tensor_operation::device::BatchedGemmEPermuteDesc batched_gemm_e_permute_desc{ - G0, G1, M, N, stride_G0, stride_G1, stride_M, stride_N}; - - auto f_host_tensor_descriptor = [](std::size_t batch_count_, - std::size_t row, - std::size_t col, - std::size_t stride, - std::size_t batch_stride, - auto layout) { - if(std::is_same::value) - { - return HostTensorDescriptor(std::vector({batch_count_, row, col}), - std::vector({batch_stride, stride, 1})); - } - else - { - return HostTensorDescriptor(std::vector({batch_count_, row, col}), - std::vector({batch_stride, 1, stride})); - } - }; - - Tensor a_g_m_k( - f_host_tensor_descriptor(batch_count, M, K, stride_A, batch_stride_A, ALayout{})); - Tensor b_g_k_n( - f_host_tensor_descriptor(batch_count, K, N, stride_B, batch_stride_B, BLayout{})); - - auto f_host_e_tensor_descriptor = [](std::size_t G0_, - std::size_t G1_, - std::size_t M_, - std::size_t N_, - std::size_t stride_G0_, - std::size_t stride_G1_, - std::size_t stride_M_, - std::size_t stride_N_) { - return HostTensorDescriptor( - std::vector({G0_, G1_, M_, N_}), - std::vector({stride_G0_, stride_G1_, stride_M_, stride_N_})); - }; - - Tensor e_g0_g1_m_n_host_result( - f_host_e_tensor_descriptor(G0, G1, M, N, stride_G0, stride_G1, stride_M, stride_N)); - - Tensor e_g0_g1_m_n_device_result( - f_host_e_tensor_descriptor(G0, G1, M, N, stride_G0, stride_G1, stride_M, stride_N)); - - std::cout << "a_g_m_k: " << a_g_m_k.mDesc << std::endl; - std::cout << "b_g_k_n: " << b_g_k_n.mDesc << std::endl; - std::cout << "e_g0_g1_m_n: " << e_g0_g1_m_n_host_result.mDesc << std::endl; - - switch(init_method) - { - case 0: break; - case 1: - a_g_m_k.GenerateTensorValue(GeneratorTensor_2{-5, 5}); - b_g_k_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}); - break; - default: - a_g_m_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); - b_g_k_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); - break; - } - - DeviceMem a_device_buf(sizeof(ADataType) * a_g_m_k.mDesc.GetElementSpaceSize()); - DeviceMem b_device_buf(sizeof(BDataType) * 
b_g_k_n.mDesc.GetElementSpaceSize()); - DeviceMem e_device_buf(sizeof(EDataType) * - e_g0_g1_m_n_device_result.mDesc.GetElementSpaceSize()); - - a_device_buf.ToDevice(a_g_m_k.mData.data()); - b_device_buf.ToDevice(b_g_k_n.mData.data()); - - auto a_element_op = AElementOp{}; - auto b_element_op = BElementOp{}; - auto cde_element_op = CDEElementOp{}; - - auto gemm = DeviceGemmInstance{}; - auto invoker = gemm.MakeInvoker(); - - // do GEM - auto argument = gemm.MakeArgument(static_cast(a_device_buf.GetDeviceBuffer()), - static_cast(b_device_buf.GetDeviceBuffer()), - static_cast(e_device_buf.GetDeviceBuffer()), - M, - N, - K, - stride_A, - stride_B, - batch_stride_A, - batch_stride_B, - batched_gemm_e_permute_desc, - batch_count, - a_element_op, - b_element_op, - cde_element_op); - - if(!gemm.IsSupportedArgument(argument)) - { - throw std::runtime_error( - "wrong! device_gemm with the specified compilation parameters does " - "not support this GEMM problem"); - } - - float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); - - std::size_t flop = std::size_t(2) * batch_count * M * N * K; - std::size_t num_btype = sizeof(ADataType) * batch_count * M * K + - sizeof(BDataType) * batch_count * K * N + - sizeof(EDataType) * batch_count * M * N; - - float tflops = static_cast(flop) / 1.E9 / ave_time; - - float gb_per_sec = num_btype / 1.E6 / ave_time; - - std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " - << gemm.GetTypeString() << std::endl; - - bool pass = true; - - if(do_verification) - { - e_device_buf.FromDevice(e_g0_g1_m_n_device_result.mData.data()); - - auto ref_batched_gemm = ReferenceBatchedGemmInstance{}; - auto ref_invoker = ref_batched_gemm.MakeInvoker(); - - Tensor c_g_m_n_host_result = HostTensorDescriptor( - std::vector({batch_count, M, N}), std::vector({M * N, N, 1})); - - auto ref_argument = ref_batched_gemm.MakeArgument( - a_g_m_k, b_g_k_n, c_g_m_n_host_result, a_element_op, b_element_op, cde_element_op); - - ref_invoker.Run(ref_argument); - - for(int g0 = 0; g0 < G0; g0++) - { - for(int g1 = 0; g1 < G1; g1++) - { - for(int m = 0; m < M; m++) - { - for(int n = 0; n < N; n++) - { - int g = g0 * G1 + g1; - - e_g0_g1_m_n_host_result(g0, g1, m, n) = c_g_m_n_host_result(g, m, n); - } - } - } - } - - pass = ck::utils::check_err(e_g0_g1_m_n_host_result.mData, - e_g0_g1_m_n_device_result.mData, - "Error: Incorrect results c"); - } - - return pass ? 
0 : 1; -} diff --git a/example/30_grouped_convnd_fwd_bias_relu_add/grouped_convnd_fwd_bias_relu_add_xdl_bf16.cpp b/example/30_grouped_convnd_fwd_bias_relu_add/grouped_convnd_fwd_bias_relu_add_xdl_bf16.cpp index 4ac996dbaa..bd5b48f884 100644 --- a/example/30_grouped_convnd_fwd_bias_relu_add/grouped_convnd_fwd_bias_relu_add_xdl_bf16.cpp +++ b/example/30_grouped_convnd_fwd_bias_relu_add/grouped_convnd_fwd_bias_relu_add_xdl_bf16.cpp @@ -137,7 +137,7 @@ int main(int argc, char* argv[]) { using InLayout = ctc::G_NW_C; using WeiLayout = ctc::G_K_X_C; - using BiasLayout = ctc::G_NW_K; + using BiasLayout = ctc::G_K; using ResidualLayout = ctc::G_NW_K; using OutLayout = ctc::G_NW_K; @@ -220,7 +220,7 @@ int main(int argc, char* argv[]) { using InLayout = ctc::G_NHW_C; using WeiLayout = ctc::G_K_YX_C; - using BiasLayout = ctc::G_NHW_K; + using BiasLayout = ctc::G_K; using ResidualLayout = ctc::G_NHW_K; using OutLayout = ctc::G_NHW_K; @@ -332,7 +332,7 @@ int main(int argc, char* argv[]) { using InLayout = ctc::G_NDHW_C; using WeiLayout = ctc::G_K_ZYX_C; - using BiasLayout = ctc::G_NDHW_K; + using BiasLayout = ctc::G_K; using ResidualLayout = ctc::G_NDHW_K; using OutLayout = ctc::G_NDHW_K; diff --git a/example/30_grouped_convnd_fwd_bias_relu_add/grouped_convnd_fwd_bias_relu_add_xdl_fp16.cpp b/example/30_grouped_convnd_fwd_bias_relu_add/grouped_convnd_fwd_bias_relu_add_xdl_fp16.cpp index 2fb2681ea6..36997c33c4 100644 --- a/example/30_grouped_convnd_fwd_bias_relu_add/grouped_convnd_fwd_bias_relu_add_xdl_fp16.cpp +++ b/example/30_grouped_convnd_fwd_bias_relu_add/grouped_convnd_fwd_bias_relu_add_xdl_fp16.cpp @@ -137,7 +137,7 @@ int main(int argc, char* argv[]) { using InLayout = ctc::G_NW_C; using WeiLayout = ctc::G_K_X_C; - using BiasLayout = ctc::G_NW_K; + using BiasLayout = ctc::G_K; using ResidualLayout = ctc::G_NW_K; using OutLayout = ctc::G_NW_K; @@ -220,7 +220,7 @@ int main(int argc, char* argv[]) { using InLayout = ctc::G_NHW_C; using WeiLayout = ctc::G_K_YX_C; - using BiasLayout = ctc::G_NHW_K; + using BiasLayout = ctc::G_K; using ResidualLayout = ctc::G_NHW_K; using OutLayout = ctc::G_NHW_K; @@ -332,7 +332,7 @@ int main(int argc, char* argv[]) { using InLayout = ctc::G_NDHW_C; using WeiLayout = ctc::G_K_ZYX_C; - using BiasLayout = ctc::G_NDHW_K; + using BiasLayout = ctc::G_K; using ResidualLayout = ctc::G_NDHW_K; using OutLayout = ctc::G_NDHW_K; diff --git a/example/30_grouped_convnd_fwd_bias_relu_add/grouped_convnd_fwd_bias_relu_add_xdl_fp32.cpp b/example/30_grouped_convnd_fwd_bias_relu_add/grouped_convnd_fwd_bias_relu_add_xdl_fp32.cpp index c792ac5fe3..9b2374de2e 100644 --- a/example/30_grouped_convnd_fwd_bias_relu_add/grouped_convnd_fwd_bias_relu_add_xdl_fp32.cpp +++ b/example/30_grouped_convnd_fwd_bias_relu_add/grouped_convnd_fwd_bias_relu_add_xdl_fp32.cpp @@ -137,7 +137,7 @@ int main(int argc, char* argv[]) { using InLayout = ctc::G_NW_C; using WeiLayout = ctc::G_K_X_C; - using BiasLayout = ctc::G_NW_K; + using BiasLayout = ctc::G_K; using ResidualLayout = ctc::G_NW_K; using OutLayout = ctc::G_NW_K; @@ -220,7 +220,7 @@ int main(int argc, char* argv[]) { using InLayout = ctc::G_NHW_C; using WeiLayout = ctc::G_K_YX_C; - using BiasLayout = ctc::G_NHW_K; + using BiasLayout = ctc::G_K; using ResidualLayout = ctc::G_NHW_K; using OutLayout = ctc::G_NHW_K; @@ -332,7 +332,7 @@ int main(int argc, char* argv[]) { using InLayout = ctc::G_NDHW_C; using WeiLayout = ctc::G_K_ZYX_C; - using BiasLayout = ctc::G_NDHW_K; + using BiasLayout = ctc::G_K; using ResidualLayout = ctc::G_NDHW_K; using OutLayout = 
ctc::G_NDHW_K; diff --git a/example/30_grouped_convnd_fwd_bias_relu_add/grouped_convnd_fwd_bias_relu_add_xdl_int4.cpp b/example/30_grouped_convnd_fwd_bias_relu_add/grouped_convnd_fwd_bias_relu_add_xdl_int4.cpp index d989e63590..be5b791249 100644 --- a/example/30_grouped_convnd_fwd_bias_relu_add/grouped_convnd_fwd_bias_relu_add_xdl_int4.cpp +++ b/example/30_grouped_convnd_fwd_bias_relu_add/grouped_convnd_fwd_bias_relu_add_xdl_int4.cpp @@ -137,7 +137,7 @@ int main(int argc, char* argv[]) { using InLayout = ctc::G_NW_C; using WeiLayout = ctc::G_K_X_C; - using BiasLayout = ctc::G_NW_K; + using BiasLayout = ctc::G_K; using ResidualLayout = ctc::G_NW_K; using OutLayout = ctc::G_NW_K; @@ -220,7 +220,7 @@ int main(int argc, char* argv[]) { using InLayout = ctc::G_NHW_C; using WeiLayout = ctc::G_K_YX_C; - using BiasLayout = ctc::G_NHW_K; + using BiasLayout = ctc::G_K; using ResidualLayout = ctc::G_NHW_K; using OutLayout = ctc::G_NHW_K; @@ -332,7 +332,7 @@ int main(int argc, char* argv[]) { using InLayout = ctc::G_NDHW_C; using WeiLayout = ctc::G_K_ZYX_C; - using BiasLayout = ctc::G_NDHW_K; + using BiasLayout = ctc::G_K; using ResidualLayout = ctc::G_NDHW_K; using OutLayout = ctc::G_NDHW_K; diff --git a/example/30_grouped_convnd_fwd_bias_relu_add/grouped_convnd_fwd_bias_relu_add_xdl_int8.cpp b/example/30_grouped_convnd_fwd_bias_relu_add/grouped_convnd_fwd_bias_relu_add_xdl_int8.cpp index 9aabe86948..1f3434694d 100644 --- a/example/30_grouped_convnd_fwd_bias_relu_add/grouped_convnd_fwd_bias_relu_add_xdl_int8.cpp +++ b/example/30_grouped_convnd_fwd_bias_relu_add/grouped_convnd_fwd_bias_relu_add_xdl_int8.cpp @@ -137,7 +137,7 @@ int main(int argc, char* argv[]) { using InLayout = ctc::G_NW_C; using WeiLayout = ctc::G_K_X_C; - using BiasLayout = ctc::G_NW_K; + using BiasLayout = ctc::G_K; using ResidualLayout = ctc::G_NW_K; using OutLayout = ctc::G_NW_K; @@ -220,7 +220,7 @@ int main(int argc, char* argv[]) { using InLayout = ctc::G_NHW_C; using WeiLayout = ctc::G_K_YX_C; - using BiasLayout = ctc::G_NHW_K; + using BiasLayout = ctc::G_K; using ResidualLayout = ctc::G_NHW_K; using OutLayout = ctc::G_NHW_K; @@ -332,7 +332,7 @@ int main(int argc, char* argv[]) { using InLayout = ctc::G_NDHW_C; using WeiLayout = ctc::G_K_ZYX_C; - using BiasLayout = ctc::G_NDHW_K; + using BiasLayout = ctc::G_K; using ResidualLayout = ctc::G_NDHW_K; using OutLayout = ctc::G_NDHW_K; diff --git a/example/38_grouped_conv_bwd_data_bias_relu/CMakeLists.txt b/example/38_grouped_conv_bwd_data_bias_relu/CMakeLists.txt new file mode 100644 index 0000000000..36112157d6 --- /dev/null +++ b/example/38_grouped_conv_bwd_data_bias_relu/CMakeLists.txt @@ -0,0 +1 @@ +add_example_executable(example_grouped_conv_bwd_data_bias_relu_fp16 grouped_conv_bwd_data_bias_relu_fp16.cpp) diff --git a/example/38_grouped_conv_bwd_data_bias_relu/grouped_conv_bwd_data_bias_relu_common.hpp b/example/38_grouped_conv_bwd_data_bias_relu/grouped_conv_bwd_data_bias_relu_common.hpp new file mode 100644 index 0000000000..481d2e6d39 --- /dev/null +++ b/example/38_grouped_conv_bwd_data_bias_relu/grouped_conv_bwd_data_bias_relu_common.hpp @@ -0,0 +1,199 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
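+
+// assumption: this common header is meant to be include-guarded, as it is
+// shared by the example translation units below
+#pragma once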
+
+#include
+#include
+#include
+#include
+
+#include "ck/ck.hpp"
+#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
+#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
+
+#include "ck/library/utility/check_err.hpp"
+#include "ck/library/utility/device_memory.hpp"
+#include "ck/library/utility/host_tensor.hpp"
+#include "ck/library/utility/host_tensor_generator.hpp"
+#include "ck/library/utility/convolution_parameter.hpp"
+#include "ck/library/utility/convolution_host_tensor_descriptor_helper.hpp"
+#include "ck/library/reference_tensor_operation/cpu/reference_conv_bwd_data.hpp"
+
+void print_helper_msg()
+{
+    std::cout << "arg1: verification (0=no, 1=yes)\n"
+              << "arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"
+              << "arg3: time kernel (0=no, 1=yes)\n"
+              << ck::utils::conv::get_conv_param_parser_helper_msg() << std::endl;
+}
+
+template <ck::index_t NDimSpatial,
+          typename OutDataType,
+          typename WeiDataType,
+          typename BiasDataType,
+          typename InDataType,
+          typename OutElementOp,
+          typename WeiElementOp,
+          typename InElementOp,
+          typename DeviceInstance>
+int run_conv_bwd_data_bias_relu(bool do_verification,
+                                int init_method,
+                                bool time_kernel,
+                                const ck::utils::conv::ConvParam& conv_param,
+                                const HostTensorDescriptor& out_g_n_k_wos_desc,
+                                const HostTensorDescriptor& wei_g_k_c_xs_desc,
+                                const HostTensorDescriptor& bias_g_n_c_wis_desc,
+                                const HostTensorDescriptor& in_g_n_c_wis_desc,
+                                const OutElementOp& out_element_op,
+                                const WeiElementOp& wei_element_op,
+                                const InElementOp& in_element_op)
+{
+    Tensor<OutDataType> out(out_g_n_k_wos_desc);
+    Tensor<WeiDataType> wei(wei_g_k_c_xs_desc);
+    Tensor<BiasDataType> bias(bias_g_n_c_wis_desc);
+    Tensor<InDataType> in_host(in_g_n_c_wis_desc);
+    Tensor<InDataType> in_device(in_g_n_c_wis_desc);
+
+    std::cout << "out: " << out.mDesc << std::endl;
+    std::cout << "wei: " << wei.mDesc << std::endl;
+    std::cout << "bias: " << bias.mDesc << std::endl;
+    std::cout << "in: " << in_host.mDesc << std::endl;
+
+    switch(init_method)
+    {
+    case 0: break;
+    case 1:
+        out.GenerateTensorValue(GeneratorTensor_2<OutDataType>{-5, 5});
+        wei.GenerateTensorValue(GeneratorTensor_2<WeiDataType>{-5, 5});
+        bias.GenerateTensorValue(GeneratorTensor_2<BiasDataType>{-5, 5});
+        break;
+    default:
+        out.GenerateTensorValue(GeneratorTensor_3<OutDataType>{0.0, 1.0});
+        wei.GenerateTensorValue(GeneratorTensor_3<WeiDataType>{-0.5, 0.5});
+        bias.GenerateTensorValue(GeneratorTensor_3<BiasDataType>{0.0, 1.0});
+    }
+
+    DeviceMem out_device_buf(sizeof(OutDataType) * out.mDesc.GetElementSpaceSize());
+    DeviceMem wei_device_buf(sizeof(WeiDataType) * wei.mDesc.GetElementSpaceSize());
+    DeviceMem bias_device_buf(sizeof(BiasDataType) * bias.mDesc.GetElementSpaceSize());
+    DeviceMem in_device_buf(sizeof(InDataType) * in_device.mDesc.GetElementSpaceSize());
+
+    out_device_buf.ToDevice(out.mData.data());
+    wei_device_buf.ToDevice(wei.mData.data());
+    bias_device_buf.ToDevice(bias.mData.data());
+
+    // reset input to zero
+    in_device_buf.SetZero();
+
+    std::array<ck::index_t, NDimSpatial + 3> a_g_n_k_wos_lengths{};
+    std::array<ck::index_t, NDimSpatial + 3> a_g_n_k_wos_strides{};
+    std::array<ck::index_t, NDimSpatial + 3> b_g_k_c_xs_lengths{};
+    std::array<ck::index_t, NDimSpatial + 3> b_g_k_c_xs_strides{};
+    std::array<ck::index_t, NDimSpatial + 3> d0_g_n_c_wis_lengths{};
+    std::array<ck::index_t, NDimSpatial + 3> d0_g_n_c_wis_strides{};
+    std::array<ck::index_t, NDimSpatial + 3> e_g_n_c_wis_lengths{};
+    std::array<ck::index_t, NDimSpatial + 3> e_g_n_c_wis_strides{};
+    std::array<ck::index_t, NDimSpatial> conv_filter_strides{};
+    std::array<ck::index_t, NDimSpatial> conv_filter_dilations{};
+    std::array<ck::index_t, NDimSpatial> input_left_pads{};
+    std::array<ck::index_t, NDimSpatial> input_right_pads{};
+
+    auto copy = [](auto& x, auto& y) { std::copy(x.begin(), x.end(), y.begin()); };
+
+    copy(out_g_n_k_wos_desc.GetLengths(), a_g_n_k_wos_lengths);
+    copy(out_g_n_k_wos_desc.GetStrides(), a_g_n_k_wos_strides);
+    copy(wei_g_k_c_xs_desc.GetLengths(), b_g_k_c_xs_lengths);
+    copy(wei_g_k_c_xs_desc.GetStrides(), b_g_k_c_xs_strides);
+    copy(bias_g_n_c_wis_desc.GetLengths(), d0_g_n_c_wis_lengths);
+    copy(bias_g_n_c_wis_desc.GetStrides(), d0_g_n_c_wis_strides);
+    copy(in_g_n_c_wis_desc.GetLengths(), e_g_n_c_wis_lengths);
+    copy(in_g_n_c_wis_desc.GetStrides(), e_g_n_c_wis_strides);
+    copy(conv_param.conv_filter_strides_, conv_filter_strides);
+    copy(conv_param.conv_filter_dilations_, conv_filter_dilations);
+    copy(conv_param.input_left_pads_, input_left_pads);
+    copy(conv_param.input_right_pads_, input_right_pads);
+
+    // do conv
+    auto conv     = DeviceInstance{};
+    auto invoker  = conv.MakeInvoker();
+    auto argument = conv.MakeArgument(
+        out_device_buf.GetDeviceBuffer(),
+        wei_device_buf.GetDeviceBuffer(),
+        std::array<const void*, 1>{bias_device_buf.GetDeviceBuffer()},
+        in_device_buf.GetDeviceBuffer(),
+        a_g_n_k_wos_lengths,
+        a_g_n_k_wos_strides,
+        b_g_k_c_xs_lengths,
+        b_g_k_c_xs_strides,
+        std::array<std::array<ck::index_t, NDimSpatial + 3>, 1>{d0_g_n_c_wis_lengths},
+        std::array<std::array<ck::index_t, NDimSpatial + 3>, 1>{d0_g_n_c_wis_strides},
+        e_g_n_c_wis_lengths,
+        e_g_n_c_wis_strides,
+        conv_filter_strides,
+        conv_filter_dilations,
+        input_left_pads,
+        input_right_pads,
+        out_element_op,
+        wei_element_op,
+        in_element_op);
+
+    if(!conv.IsSupportedArgument(argument))
+    {
+        printf("wrong! device_conv with the specified compilation parameters does "
+               "not support this Conv problem\n");
+
+        return 1;
+    }
+
+    float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});
+
+    std::size_t flop      = conv_param.GetFlops();
+    std::size_t num_btype = conv_param.GetByte<InDataType, WeiDataType, OutDataType>();
+
+    float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
+
+    float gb_per_sec = num_btype / 1.E6 / ave_time;
+
+    std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s"
+              << std::endl;
+
+    if(do_verification)
+    {
+        using PassThrough = ck::tensor_operation::element_wise::PassThrough;
+
+        // c doesn't physically exist, any layout is fine
+        Tensor<InDataType> c_host(in_g_n_c_wis_desc);
+
+        auto ref_conv = ck::tensor_operation::host::ReferenceConvBwdData();
+
+        auto ref_invoker = ref_conv.MakeInvoker();
+
+        auto ref_argument = ref_conv.MakeArgument(c_host,
+                                                  wei,
+                                                  out,
+                                                  conv_param.conv_filter_strides_,
+                                                  conv_param.conv_filter_dilations_,
+                                                  conv_param.input_left_pads_,
+                                                  conv_param.input_right_pads_,
+                                                  PassThrough{},
+                                                  wei_element_op,
+                                                  out_element_op);
+
+        ref_invoker.Run(ref_argument);
+
+        // TODO: implement elementwise operation for host
+        in_host.ForEach(
+            [&](auto&, auto idx) { in_element_op(in_host(idx), c_host(idx), bias(idx)); });
+
+        in_device_buf.FromDevice(in_device.mData.data());
+
+        return ck::utils::check_err(in_device.mData, in_host.mData) ? 0 : 1;
+    }
+
+    return 0;
+}
diff --git a/example/38_grouped_conv_bwd_data_bias_relu/grouped_conv_bwd_data_bias_relu_fp16.cpp b/example/38_grouped_conv_bwd_data_bias_relu/grouped_conv_bwd_data_bias_relu_fp16.cpp
new file mode 100644
index 0000000000..c1091a67ae
--- /dev/null
+++ b/example/38_grouped_conv_bwd_data_bias_relu/grouped_conv_bwd_data_bias_relu_fp16.cpp
@@ -0,0 +1,174 @@
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
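+
+// Usage sketch (a non-authoritative example; the binary name comes from the
+// add_example_executable(example_grouped_conv_bwd_data_bias_relu_fp16 ...) line
+// in this example's CMakeLists.txt, and its path depends on the local build tree):
+//
+//   ./example_grouped_conv_bwd_data_bias_relu_fp16          # default 2D problem
+//   ./example_grouped_conv_bwd_data_bias_relu_fp16 1 1 0    # verify, integer init, no timing
+//
+// optionally followed by the convolution sizes described by
+// ck::utils::conv::get_conv_param_parser_helper_msg().
+//
+// The bias is laid out as G_C: its host descriptor below is created with zero
+// strides over the N/Hi/Wi dimensions, e.g. for 2D
+//
+//   HostTensorDescriptor bias_desc({G, N, C, Hi, Wi}, {C, 0, 1, 0, 0});
+//
+// so a buffer holding only G * C elements is broadcast over the whole image.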
+
+#include "grouped_conv_bwd_data_bias_relu_common.hpp"
+
+#include "ck/tensor_operation/gpu/device/device_grouped_conv_bwd_data_multiple_d.hpp"
+#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp"
+
+template <ck::index_t... Is>
+using S = ck::Sequence<Is...>;
+
+using OutDataType = ck::half_t;
+using WeiDataType = ck::half_t;
+using AccDataType = float;
+using CShuffleDataType = ck::half_t;
+using BiasDataType = ck::half_t; // bias
+using InDataType = ck::half_t;
+
+using OutLayout = ck::tensor_layout::convolution::GNHWK;
+using WeiLayout = ck::tensor_layout::convolution::GKYXC;
+using BiasLayout = ck::tensor_layout::convolution::G_C;
+using InLayout = ck::tensor_layout::convolution::GNHWC;
+
+using OutElementOp = ck::tensor_operation::element_wise::PassThrough;
+using WeiElementOp = ck::tensor_operation::element_wise::PassThrough;
+using CBiasInElementOp = ck::tensor_operation::element_wise::AddRelu;
+
+static constexpr auto ConvBwdDataDefault =
+    ck::tensor_operation::device::ConvolutionBackwardDataSpecialization::Default;
+
+template <ck::index_t NDimSpatial>
+using DeviceConvNdBwdDataInstance =
+    ck::tensor_operation::device::DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1<
+        NDimSpatial,
+        OutLayout,
+        WeiLayout,
+        ck::Tuple<BiasLayout>,
+        InLayout,
+        OutDataType,
+        WeiDataType,
+        AccDataType,
+        CShuffleDataType,
+        ck::Tuple<BiasDataType>,
+        InDataType,
+        OutElementOp,
+        WeiElementOp,
+        CBiasInElementOp,
+        ConvBwdDataDefault,
+        true, // DoPadGemmM
+        true, // DoPadGemmN
+        1,
+        256,
+        128,
+        256,
+        32,
+        8,
+        2,
+        32,
+        32,
+        2,
+        4,
+        S<4, 64, 1>,
+        S<1, 0, 2>,
+        S<1, 0, 2>,
+        2,
+        8,
+        8,
+        1,
+        S<4, 64, 1>,
+        S<0, 2, 1>,
+        S<0, 2, 1>,
+        1,
+        4,
+        2,
+        0,
+        1,
+        1,
+        S<1, 32, 1, 8>,
+        8>;
+
+int main(int argc, char* argv[])
+{
+    namespace ctc = ck::tensor_layout::convolution;
+
+    print_helper_msg();
+
+    bool do_verification = true;
+    int init_method = 1;
+    bool time_kernel = false;
+
+    ck::utils::conv::ConvParam conv_param{
+        2, 2, 128, 256, 256, {3, 3}, {14, 14}, {2, 2}, {1, 1}, {1, 1}, {1, 1}};
+
+    if(argc == 1)
+    {
+        // use default
+    }
+    else if(argc == 4)
+    {
+        do_verification = std::stoi(argv[1]);
+        init_method     = std::stoi(argv[2]);
+        time_kernel     = std::stoi(argv[3]);
+    }
+    else
+    {
+        do_verification = std::stoi(argv[1]);
+        init_method     = std::stoi(argv[2]);
+        time_kernel     = std::stoi(argv[3]);
+        const ck::index_t num_dim_spatial = std::stoi(argv[4]);
+
+        conv_param = ck::utils::conv::parse_conv_param(num_dim_spatial, 5, argv);
+    }
+
+    const auto in_element_op  = CBiasInElementOp{};
+    const auto wei_element_op = WeiElementOp{};
+    const auto out_element_op = OutElementOp{};
+
+    if(conv_param.num_dim_spatial_ == 2)
+    {
+        // output image: GNHWK
+        const auto out_g_n_k_wos_desc =
+            ck::utils::conv::make_output_host_tensor_descriptor_g_n_k_wos_packed<OutLayout>(
+                conv_param);
+
+        // weight: GKYXC
+        const auto wei_g_k_c_xs_desc =
+            ck::utils::conv::make_weight_host_tensor_descriptor_g_k_c_xs_packed<WeiLayout>(
+                conv_param);
+
+        // input image bias: G_C
+        const auto bias_g_n_c_wis_desc =
+            HostTensorDescriptor({conv_param.G_,
+                                  conv_param.N_,
+                                  conv_param.C_,
+                                  conv_param.input_spatial_lengths_[0],
+                                  conv_param.input_spatial_lengths_[1]},
+                                 {
+                                     conv_param.C_, // g
+                                     0,             // n
+                                     1,             // c
+                                     0,             // hi
+                                     0              // wi
+                                 });
+
+        // input image: GNHWC
+        const auto in_g_n_c_wis_desc =
+            ck::utils::conv::make_input_host_tensor_descriptor_g_n_c_wis_packed<InLayout>(
+                conv_param);
+
+        using DeviceInstance = DeviceConvNdBwdDataInstance<2>;
+
+        run_conv_bwd_data_bias_relu<2,
+                                    OutDataType,
+                                    WeiDataType,
+                                    BiasDataType,
+                                    InDataType,
+                                    OutElementOp,
+                                    WeiElementOp,
+                                    CBiasInElementOp,
+                                    DeviceInstance>(do_verification,
+                                                    init_method,
+                                                    time_kernel,
+                                                    conv_param,
+                                                    out_g_n_k_wos_desc,
+                                                    wei_g_k_c_xs_desc,
+                                                    bias_g_n_c_wis_desc,
+                                                    in_g_n_c_wis_desc,
+                                                    out_element_op,
+                                                    wei_element_op,
+                                                    in_element_op);
+    }
+
+    return 0;
+}
diff --git a/example/CMakeLists.txt b/example/CMakeLists.txt
index 4a11997d87..32207ef3a6 100644
--- a/example/CMakeLists.txt
+++ b/example/CMakeLists.txt
@@ -21,36 +21,10 @@ function(add_example_executable_no_testing EXAMPLE_NAME FILE_NAME)
     add_dependencies(examples ${EXAMPLE_NAME})
 endfunction(add_example_executable_no_testing EXAMPLE_NAME)
 
-add_subdirectory(01_gemm)
-add_subdirectory(02_gemm_bilinear)
-add_subdirectory(03_gemm_bias_relu)
-add_subdirectory(04_gemm_add_add_fastgelu)
-add_subdirectory(09_convnd_fwd)
-add_subdirectory(10_convnd_fwd_multiple_d_multiple_reduce)
-add_subdirectory(12_reduce)
-add_subdirectory(13_pool2d_fwd)
-add_subdirectory(14_gemm_xdl_requant_relu_requant)
-add_subdirectory(15_grouped_gemm)
-add_subdirectory(16_gemm_multi_d_multi_reduces)
-add_subdirectory(17_convnd_bwd_data)
-add_subdirectory(18_batched_gemm_reduce)
-add_subdirectory(19_binary_elementwise)
-add_subdirectory(20_convnd_bwd_weight)
-add_subdirectory(21_gemm_layernorm)
-add_subdirectory(22_cgemm)
-add_subdirectory(23_softmax)
-add_subdirectory(24_batched_gemm)
-add_subdirectory(25_gemm_bias_e_permute)
-add_subdirectory(26_contraction)
-add_subdirectory(27_layernorm)
-add_subdirectory(28_grouped_gemm_bias_e_permute)
-add_subdirectory(29_batched_gemm_bias_e_permute)
-add_subdirectory(30_grouped_convnd_fwd_bias_relu_add)
-add_subdirectory(31_batched_gemm_gemm)
-add_subdirectory(32_batched_gemm_scale_softmax_gemm)
-add_subdirectory(33_multiple_reduce)
-add_subdirectory(34_batchnorm)
-add_subdirectory(35_splitK_gemm)
-add_subdirectory(36_sparse_embedding)
-add_subdirectory(37_batched_gemm_add_add_relu_gemm_add)
-add_subdirectory(41_grouped_conv_conv_fwd)
+# add all example subdirectories
+file(GLOB dir_list LIST_DIRECTORIES true *)
+foreach(subdir ${dir_list})
+    if(IS_DIRECTORY "${subdir}")
+        add_subdirectory(${subdir})
+    endif()
+endforeach()
diff --git a/include/ck/tensor_operation/gpu/device/device_batched_contraction_multiple_d_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/device_batched_contraction_multiple_d_xdl_cshuffle.hpp
index e0c4a408ed..9152e8d85a 100644
--- a/include/ck/tensor_operation/gpu/device/device_batched_contraction_multiple_d_xdl_cshuffle.hpp
+++ b/include/ck/tensor_operation/gpu/device/device_batched_contraction_multiple_d_xdl_cshuffle.hpp
@@ -549,10 +549,6 @@ struct DeviceBatchedContractionMultipleD_Xdl_CShuffle
         BElementwiseOperation,
         CDEElementwiseOperation,
         InMemoryDataOperationEnum::Set,
-        AGridDesc_M_K,
-        BGridDesc_N_K,
-        DsGridDesc_M_N,
-        EGridDesc_M_N,
         NumGemmKPrefetchStage,
         BlockSize,
         MPerBlock,
@@ -586,12 +582,19 @@ struct DeviceBatchedContractionMultipleD_Xdl_CShuffle
         CDEBlockTransferScalarPerVector_NPerBlock,
         LoopSched>;
 
-    using AGridDesc_AK0_M_AK1 = remove_cvref_t<decltype(GridwiseGemm::MakeDefaultAGridDescriptor_AK0_M_AK1(AGridDesc_M_K{}))>;
-    using BGridDesc_BK0_N_BK1 = remove_cvref_t<decltype(GridwiseGemm::MakeDefaultBGridDescriptor_BK0_N_BK1(BGridDesc_N_K{}))>;
+    using DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<decltype(GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(DsGridDesc_M_N{}))>;
+    using EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<decltype(GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(EGridDesc_M_N{}))>;
 
-    using Block2ETileMap = typename GridwiseGemm::DefaultBlock2ETileMap;
+    // block-to-e-tile map
+    using Block2ETileMap =
+        remove_cvref_t<decltype(GridwiseGemm::MakeDefaultBlock2ETileMap(EGridDesc_M_N{}))>;
 
     // Argument
     struct Argument : public BaseArgument
@@ -719,10 +722,9 @@ struct DeviceBatchedContractionMultipleD_Xdl_CShuffle
         // tensor descriptors for block/thread-wise copy
         AGridDesc_AK0_M_AK1
a_grid_desc_ak0_m_ak1_; BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_; - typename GridwiseGemm::DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock ds_grid_desc_mblock_mperblock_nblock_nperblock_; - typename GridwiseGemm::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock - e_grid_desc_mblock_mperblock_nblock_nperblock_; + EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock e_grid_desc_mblock_mperblock_nblock_nperblock_; // block-to-e-tile map Block2ETileMap block_2_etile_map_; @@ -786,10 +788,10 @@ struct DeviceBatchedContractionMultipleD_Xdl_CShuffle CDEElementwiseOperation, DeviceOp::AGridDesc_AK0_M_AK1, DeviceOp::BGridDesc_BK0_N_BK1, - typename GridwiseGemm::DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock, - typename GridwiseGemm::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock, + DeviceOp::DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, + DeviceOp::EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, ComputePtrOffsetOfStridedBatch, - typename GridwiseGemm::DefaultBlock2ETileMap, + DeviceOp::Block2ETileMap, has_main_loop>; return launch_and_time_kernel(stream_config, diff --git a/include/ck/tensor_operation/gpu/device/device_batched_gemm_e_permute_xdl.hpp b/include/ck/tensor_operation/gpu/device/device_batched_gemm_e_permute_xdl.hpp deleted file mode 100644 index 8c5dc7de1f..0000000000 --- a/include/ck/tensor_operation/gpu/device/device_batched_gemm_e_permute_xdl.hpp +++ /dev/null @@ -1,682 +0,0 @@ -#pragma once - -#include -#include - -#include "ck/utility/common_header.hpp" -#include "ck/tensor_description/tensor_descriptor.hpp" -#include "ck/tensor_description/tensor_descriptor_helper.hpp" -#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" -#include "ck/tensor_operation/gpu/device/device_batched_gemm_e_permute.hpp" -#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" -#include "ck/tensor_operation/gpu/device/matrix_padder.hpp" -#include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp" -#include "ck/host_utility/device_prop.hpp" -#include "ck/host_utility/kernel_launch.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { - -/* - * \brief Wrapper function of GridwiseGemm::Run to realize BatchedGEMM. - * - * \tparam ComputePtrOffsetOfBatch Class that computes the base pointer offsets of A, B, C matrix - * given the batch. For example, ComputePtrOffsetOfStridedBatch() computes the offsets of evenly - * strided batched, but we can easily extend to other layouts. The returned offset can be either \p - * index_t or \p long_index_t. If it returns \p long_index_t, we are not subject to the 2GB -#include "ck/tensor_operation/gpu/device/matrix_padder.hpp" - * limitations. - * - * \tparam Block2ETileMap Block2ETileMap::CalculateBottomIndex() takes in id of a workgroup and - * returns the 2D index of the tile that it computes. \see - * GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3::Run(). - * \note Using \p ComputePtrOffsetOfBatch gives us the flexibility that 2 workgroups can compute 2 - * tiles from different matrices. Keep in mind that these 2 matrices can share the same grid - * descriptor (like in BatchedGEMM), or use their own grid descriptors (in GroupedGemm). \link - * device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp kernel_gemm_xdlops_v2r3_for_conv3d \endlink for \link - * DeviceConv3d \endlink uses the same concept, but currently does NOT encapsulate the computing of - * pointer offset into \p ComputePtrOffsetOfStridedBatch. 
- * - * \note \p Block2ETileMap allows customized mapping between a workgroup and the C-tile it computes. - * Together with \p ComputePtrOffsetOfBatch, we can reuse GridwiseGemm (and GridwiseGemm fusion ) to - * realize BatchedGemmCPermute and GroupedGemm (and the corresponding GEMM fusion). - * - */ -template -__global__ void -#if CK_USE_LAUNCH_BOUNDS - __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) -#endif - kernel_batched_gemm_e_permute_xdl(const ABDataType* __restrict__ p_a_grid, - const ABDataType* __restrict__ p_b_grid, - EDataType* __restrict__ p_e_grid, - const index_t batch_count, - const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1, - const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1, - const EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock - e_grid_desc_mblock_mperblock_nblock_nperblock, - const AElementwiseOperation a_element_op, - const BElementwiseOperation b_element_op, - const CDEElementwiseOperation cde_element_op, - const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch, - const Block2ETileMap block_2_etile_map) -{ -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__)) - const index_t num_blocks_per_batch = - __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); - const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch); - - const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane( - static_cast(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx))); - const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane( - static_cast(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx))); - const long_index_t e_batch_offset = __builtin_amdgcn_readfirstlane( - static_cast(compute_ptr_offset_of_batch.GetCPtrOffset(g_idx))); - - __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; - - GridwiseGemm::template Run(p_a_grid + a_batch_offset, - p_b_grid + b_batch_offset, - ck::Tuple<>{}, - p_e_grid + e_batch_offset, - p_shared, - a_element_op, - b_element_op, - cde_element_op, - a_grid_desc_ak0_m_ak1, - b_grid_desc_bk0_n_bk1, - ck::Tuple<>{}, - e_grid_desc_mblock_mperblock_nblock_nperblock, - block_2_etile_map); -#else - ignore = p_a_grid; - ignore = p_b_grid; - ignore = p_e_grid; - ignore = batch_count; - ignore = a_grid_desc_ak0_m_ak1; - ignore = b_grid_desc_bk0_n_bk1; - ignore = e_grid_desc_mblock_mperblock_nblock_nperblock; - ignore = a_element_op; - ignore = b_element_op; - ignore = cde_element_op; - ignore = compute_ptr_offset_of_batch; - ignore = block_2_etile_map; -#endif -} - -template -struct DeviceBatchedGemmEPermuteXdl : public DeviceBatchedGemmEPermute -{ - using DeviceOp = DeviceBatchedGemmEPermuteXdl; - - static constexpr auto I0 = Number<0>{}; - static constexpr auto I1 = Number<1>{}; - static constexpr auto I2 = Number<2>{}; - - static constexpr auto matrix_padder = - MatrixPadder{MPerBlock, NPerBlock, KPerBlock}; - - static auto MakeAGridDescriptor_M_K(index_t MRaw, index_t KRaw, index_t StrideA) - { - const auto a_grid_desc_mraw_kraw = [&]() { - if constexpr(is_same_v) - { - return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw), - make_tuple(StrideA, I1)); - } - else if constexpr(is_same_v) - { - return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw), - make_tuple(I1, StrideA)); - } - }(); - - return matrix_padder.PadADescriptor_M_K(a_grid_desc_mraw_kraw); - } - - static auto MakeBGridDescriptor_N_K(index_t KRaw, index_t NRaw, index_t StrideB) - { - const auto b_grid_desc_nraw_kraw = [&]() { - if constexpr(is_same::value) - { - return 
make_naive_tensor_descriptor(make_tuple(NRaw, KRaw), - make_tuple(I1, StrideB)); - } - else if constexpr(is_same::value) - { - return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw), - make_tuple(StrideB, I1)); - } - }(); - - return matrix_padder.PadBDescriptor_N_K(b_grid_desc_nraw_kraw); - } - - static auto - MakeEGridDescriptor_M_N(index_t MRaw, index_t NRaw, index_t stride_M, index_t stride_N) - { - const auto e_grid_desc_mraw_nraw = - make_naive_tensor_descriptor(make_tuple(MRaw, NRaw), make_tuple(stride_M, stride_N)); - - return matrix_padder.PadCDescriptor_M_N(e_grid_desc_mraw_nraw); - } - - static auto MakeEGridDescriptor_G0_G1_M_N(index_t G0, - index_t G1, - index_t MRaw, - index_t NRaw, - index_t stride_G0, - index_t stride_G1, - index_t stride_M, - index_t stride_N) - { - const auto e_grid_desc_g0_g1_mraw_nraw = [&]() { - return make_naive_tensor_descriptor( - make_tuple(G0, G1, MRaw, NRaw), - make_tuple(stride_G0, stride_G1, stride_M, stride_N)); - }(); - - const auto M = math::integer_divide_ceil(MRaw, MPerBlock) * MPerBlock; - const auto N = math::integer_divide_ceil(NRaw, NPerBlock) * NPerBlock; - - const auto MPad = M - MRaw; - const auto NPad = N - NRaw; - - if constexpr(GemmSpec == GemmSpecialization::MNPadding || - GemmSpec == GemmSpecialization::MNKPadding) - { - // pad M and N - return transform_tensor_descriptor( - e_grid_desc_g0_g1_mraw_nraw, - make_tuple(make_pass_through_transform(G0), - make_pass_through_transform(G1), - make_right_pad_transform(MRaw, MPad), - make_right_pad_transform(NRaw, NPad)), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); - } - else if constexpr(GemmSpec == GemmSpecialization::MPadding || - GemmSpec == GemmSpecialization::MKPadding) - { - // pad M, but not N - return transform_tensor_descriptor( - e_grid_desc_g0_g1_mraw_nraw, - make_tuple(make_pass_through_transform(G0), - make_pass_through_transform(G1), - make_right_pad_transform(MRaw, MPad), - make_pass_through_transform(NRaw)), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); - } - else if constexpr(GemmSpec == GemmSpecialization::NPadding || - GemmSpec == GemmSpecialization::NKPadding) - { - // pad N, but not M - return transform_tensor_descriptor( - e_grid_desc_g0_g1_mraw_nraw, - make_tuple(make_pass_through_transform(G0), - make_pass_through_transform(G1), - make_pass_through_transform(MRaw), - make_right_pad_transform(NRaw, NPad)), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); - } - else - { - // not pad M or N - return e_grid_desc_g0_g1_mraw_nraw; - } - } - - using AGridDesc_M_K = decltype(MakeAGridDescriptor_M_K(1, 1, 1)); - using BGridDesc_N_K = decltype(MakeBGridDescriptor_N_K(1, 1, 1)); - using EGridDesc_M_N = decltype(MakeEGridDescriptor_M_N(1, 1, 1, 1)); - using EGridDesc_G0_G1_M_N = decltype(MakeEGridDescriptor_G0_G1_M_N(1, 1, 1, 1, 1, 1, 1, 1)); - - struct ComputePtrOffsetOfStridedBatch - { - ComputePtrOffsetOfStridedBatch(index_t Batchstride_A, - index_t Batchstride_B, - EGridDesc_G0_G1_M_N e_grid_desc_g0_g1_m_n) - : Batchstride_A_(Batchstride_A), - Batchstride_B_(Batchstride_B), - e_grid_desc_g0_g1_m_n_(e_grid_desc_g0_g1_m_n) - { - } - - __host__ __device__ constexpr long_index_t GetAPtrOffset(index_t g_idx) const - { - return g_idx * static_cast(Batchstride_A_); - } - - 
__host__ __device__ constexpr long_index_t GetBPtrOffset(index_t g_idx) const - { - return g_idx * static_cast(Batchstride_B_); - } - - __host__ __device__ constexpr long_index_t GetCPtrOffset(index_t g_idx) const - { - const index_t G1 = e_grid_desc_g0_g1_m_n_.GetLength(I1); - index_t b0 = g_idx / G1; - index_t b1 = g_idx - b0 * G1; // g_idx % G1 - return e_grid_desc_g0_g1_m_n_.CalculateOffset(make_multi_index(b0, b1, 0, 0)); - } - - private: - index_t Batchstride_A_; - index_t Batchstride_B_; - EGridDesc_G0_G1_M_N e_grid_desc_g0_g1_m_n_; - }; - - using GridwiseGemm = GridwiseGemmMultipleD_xdl_cshuffle< - ADataType, // TODO: distinguish A/B datatype - AccDataType, - CShuffleDataType, - ck::Tuple<>, // DsDataType, - EDataType, // EDataType, - AElementwiseOperation, - BElementwiseOperation, - CDEElementwiseOperation, - InMemoryDataOperationEnum::Set, - AGridDesc_M_K, - BGridDesc_N_K, - Tuple<>, - EGridDesc_M_N, - NumPrefetch, - BlockSize, - MPerBlock, - NPerBlock, - KPerBlock, - AK1, - BK1, - MPerXDL, - NPerXDL, - MXdlPerWave, - NXdlPerWave, - ABlockTransferThreadClusterLengths_K0_M_K1, - ABlockTransferThreadClusterArrangeOrder, - ABlockTransferSrcAccessOrder, - ABlockTransferSrcVectorDim, - ABlockTransferSrcScalarPerVector, - ABlockTransferDstScalarPerVector_K1, - false, // AThreadTransferSrcResetCoordinateAfterRun, - ABlockLdsExtraM, - BBlockTransferThreadClusterLengths_K0_N_K1, - BBlockTransferThreadClusterArrangeOrder, - BBlockTransferSrcAccessOrder, - BBlockTransferSrcVectorDim, - BBlockTransferSrcScalarPerVector, - BBlockTransferDstScalarPerVector_K1, - false, // BThreadTransferSrcResetCoordinateAfterRun, - BBlockLdsExtraN, - CShuffleMXdlPerWavePerShuffle, - CShuffleNXdlPerWavePerShuffle, - CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, - CDEBlockTransferScalarPerVector_NPerBlock, - LoopSched>; - - using AGridDesc_AK0_M_AK1 = remove_cvref_t; - using BGridDesc_BK0_N_BK1 = remove_cvref_t; - - using EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = decltype( - GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(EGridDesc_M_N{})); - using Block2ETileMap = typename GridwiseGemm::DefaultBlock2ETileMap; - - // Argument - struct Argument : public BaseArgument - { - Argument(const ADataType* p_a_grid, - const BDataType* p_b_grid, - EDataType* p_e_grid, - index_t M, - index_t N, - index_t K, - index_t stride_A, - index_t stride_B, - index_t batch_stride_A, - index_t batch_stride_B, - BatchedGemmEPermuteDesc batched_gemm_e_permute_desc, - index_t BatchCount, - AElementwiseOperation a_element_op, - BElementwiseOperation b_element_op, - CDEElementwiseOperation cde_element_op) - : p_a_grid_{p_a_grid}, - p_b_grid_{p_b_grid}, - p_e_grid_{p_e_grid}, - BatchCount_(BatchCount), - a_grid_desc_m_k_{DeviceOp::MakeAGridDescriptor_M_K(M, K, stride_A)}, - b_grid_desc_n_k_{DeviceOp::MakeBGridDescriptor_N_K(K, N, stride_B)}, - e_grid_desc_m_n_{ - DeviceOp::MakeEGridDescriptor_M_N(batched_gemm_e_permute_desc.M_, - batched_gemm_e_permute_desc.N_, - batched_gemm_e_permute_desc.stride_M_, - batched_gemm_e_permute_desc.stride_N_)}, - a_grid_desc_ak0_m_ak1_{ - GridwiseGemm::MakeDefaultAGridDescriptor_AK0_M_AK1(a_grid_desc_m_k_)}, - b_grid_desc_bk0_n_bk1_{ - GridwiseGemm::MakeDefaultBGridDescriptor_BK0_N_BK1(b_grid_desc_n_k_)}, - e_grid_desc_mblock_mperblock_nblock_nperblock{}, - e_grid_desc_g0_g1_m_n_{ - DeviceOp::MakeEGridDescriptor_G0_G1_M_N(batched_gemm_e_permute_desc.G0_, - batched_gemm_e_permute_desc.G1_, - batched_gemm_e_permute_desc.M_, - batched_gemm_e_permute_desc.N_, - 
batched_gemm_e_permute_desc.stride_G0_, - batched_gemm_e_permute_desc.stride_G1_, - batched_gemm_e_permute_desc.stride_M_, - batched_gemm_e_permute_desc.stride_N_)}, - compute_ptr_offset_of_batch_{batch_stride_A, batch_stride_B, e_grid_desc_g0_g1_m_n_}, - block_2_etile_map_{GridwiseGemm::MakeDefaultBlock2ETileMap(e_grid_desc_m_n_)}, - a_element_op_{a_element_op}, - b_element_op_{b_element_op}, - cde_element_op_{cde_element_op} - { - if(GridwiseGemm::CheckValidity(a_grid_desc_m_k_, - b_grid_desc_n_k_, - ck::Tuple<>{}, - e_grid_desc_m_n_, - block_2_etile_map_)) - { - e_grid_desc_mblock_mperblock_nblock_nperblock = - GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( - e_grid_desc_m_n_); - } - } - - void Print() const - { - std::cout << "A[M, K]: " << a_grid_desc_m_k_ << std::endl; - std::cout << "B[N, K]: " << b_grid_desc_n_k_ << std::endl; - std::cout << "C[M, N]: " << e_grid_desc_m_n_ << std::endl; - } - - // private: - // pointers - const ADataType* p_a_grid_; - const BDataType* p_b_grid_; - EDataType* p_e_grid_; - - // batch count - index_t BatchCount_; - - // tensor descriptors for problem definiton - AGridDesc_M_K a_grid_desc_m_k_; - BGridDesc_N_K b_grid_desc_n_k_; - EGridDesc_M_N e_grid_desc_m_n_; - - // tensor descriptors for block/thread-wise copy - AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_; - BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_; - EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock e_grid_desc_mblock_mperblock_nblock_nperblock; - EGridDesc_G0_G1_M_N e_grid_desc_g0_g1_m_n_; - - // for calculating Batch offset - ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_batch_; - - // block-to-e-tile map - Block2ETileMap block_2_etile_map_; - - // element-wise op - AElementwiseOperation a_element_op_; - BElementwiseOperation b_element_op_; - CDEElementwiseOperation cde_element_op_; - }; - - // Invoker - struct Invoker : public BaseInvoker - { - using Argument = DeviceOp::Argument; - - float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) - { - if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_m_k_, - arg.b_grid_desc_n_k_, - ck::Tuple<>{}, - arg.e_grid_desc_m_n_, - arg.block_2_etile_map_)) - { - throw std::runtime_error( - "wrong! 
GridwiseBatchedGemmCPermute_km_kn_m0m1n0n1_xdlops_v2r3 has invalid " - "setting"); - } - - const index_t grid_size = - arg.block_2_etile_map_.CalculateGridSize(arg.e_grid_desc_m_n_) * arg.BatchCount_; - - const auto K = - arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) * arg.a_grid_desc_ak0_m_ak1_.GetLength(I2); - - auto launch_kernel = [&](auto has_main_k_block_loop_) { - const auto kernel = kernel_batched_gemm_e_permute_xdl< - GridwiseGemm, - ADataType, // TODO: distiguish A/B datatype - EDataType, - remove_reference_t, - remove_reference_t, - typename GridwiseGemm::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock, - AElementwiseOperation, - BElementwiseOperation, - CDEElementwiseOperation, - ComputePtrOffsetOfStridedBatch, - remove_reference_t, - has_main_k_block_loop_>; - - return launch_and_time_kernel(stream_config, - kernel, - dim3(grid_size), - dim3(BlockSize), - 0, - arg.p_a_grid_, - arg.p_b_grid_, - arg.p_e_grid_, - arg.BatchCount_, - arg.a_grid_desc_ak0_m_ak1_, - arg.b_grid_desc_bk0_n_bk1_, - arg.e_grid_desc_mblock_mperblock_nblock_nperblock, - arg.a_element_op_, - arg.b_element_op_, - arg.cde_element_op_, - arg.compute_ptr_offset_of_batch_, - arg.block_2_etile_map_); - }; - - if(GridwiseGemm::CalculateHasMainKBlockLoop(K)) - { - return launch_kernel(integral_constant{}); - } - else - { - return launch_kernel(integral_constant{}); - } - } - - // polymorphic - float Run(const BaseArgument* p_arg, - const StreamConfig& stream_config = StreamConfig{}) override - { - return Run(*dynamic_cast(p_arg), stream_config); - } - }; - - static constexpr bool IsValidCompilationParameter() - { - // TODO: properly implement this check - return true; - } - - static bool IsSupportedArgument(const Argument& arg) - { - return GridwiseGemm::CheckValidity(arg.a_grid_desc_m_k_, - arg.b_grid_desc_n_k_, - ck::Tuple<>{}, - arg.e_grid_desc_m_n_, - arg.block_2_etile_map_); - } - - // polymorphic - bool IsSupportedArgument(const BaseArgument* p_arg) override - { - return IsSupportedArgument(*dynamic_cast(p_arg)); - } - - static auto MakeArgument(const ADataType* p_a, - const BDataType* p_b, - EDataType* p_e, - index_t M, - index_t N, - index_t K, - index_t stride_A, - index_t stride_B, - index_t batch_stride_A, - index_t batch_stride_B, - BatchedGemmEPermuteDesc batched_gemm_e_permute_desc, - index_t BatchCount, - AElementwiseOperation a_element_op, - BElementwiseOperation b_element_op, - CDEElementwiseOperation cde_element_op) - { - return Argument{p_a, - p_b, - p_e, - M, - N, - K, - stride_A, - stride_B, - batch_stride_A, - batch_stride_B, - batched_gemm_e_permute_desc, - BatchCount, - a_element_op, - b_element_op, - cde_element_op}; - } - - static auto MakeInvoker() { return Invoker{}; } - - // polymorphic - std::unique_ptr - MakeArgumentPointer(const void* p_a, - const void* p_b, - void* p_e, - index_t M, - index_t N, - index_t K, - index_t stride_A, - index_t stride_B, - index_t batch_stride_A, - index_t batch_stride_B, - BatchedGemmEPermuteDesc batched_gemm_e_permute_desc, - index_t BatchCount, - AElementwiseOperation a_element_op, - BElementwiseOperation b_element_op, - CDEElementwiseOperation cde_element_op) override - { - return std::make_unique(static_cast(p_a), - static_cast(p_b), - static_cast(p_e), - M, - N, - K, - stride_A, - stride_B, - batch_stride_A, - batch_stride_B, - batched_gemm_e_permute_desc, - BatchCount, - a_element_op, - b_element_op, - cde_element_op); - } - - // polymorphic - std::unique_ptr MakeInvokerPointer() override - { - return std::make_unique(Invoker{}); - } - - // polymorphic 
- std::string GetTypeString() const override - { - auto str = std::stringstream(); - - // clang-format off - str << "DeviceBatchedGemmEPermuteXdl" - << "<" - << BlockSize << ", " - << MPerBlock << ", " - << NPerBlock << ", " - << KPerBlock - << ">"; - // clang-format on - - return str.str(); - } -}; - -} // namespace device -} // namespace tensor_operation -} // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/device_batched_gemm_multi_d_xdl.hpp b/include/ck/tensor_operation/gpu/device/device_batched_gemm_multi_d_xdl.hpp index 2b9520145d..af5b880654 100644 --- a/include/ck/tensor_operation/gpu/device/device_batched_gemm_multi_d_xdl.hpp +++ b/include/ck/tensor_operation/gpu/device/device_batched_gemm_multi_d_xdl.hpp @@ -333,10 +333,6 @@ struct DeviceBatchedGemmMultiD_Xdl : public DeviceBatchedGemmMultiD; - using AGridDesc_AK0_M_AK1 = remove_cvref_t; - using BGridDesc_BK0_N_BK1 = remove_cvref_t; + using DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t; + using EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t; - using Block2ETileMap = typename GridwiseGemm::DefaultBlock2ETileMap; + // block-to-e-tile map + using Block2ETileMap = + remove_cvref_t; // Argument struct Argument : public BaseArgument @@ -478,10 +481,9 @@ struct DeviceBatchedGemmMultiD_Xdl : public DeviceBatchedGemmMultiD; + const auto kernel = + kernel_batched_gemm_xdl; return launch_and_time_kernel(stream_config, kernel, diff --git a/include/ck/tensor_operation/gpu/device/device_contraction_multiple_d_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/device_contraction_multiple_d_xdl_cshuffle.hpp index b1c2545ff0..72c6d0b6f7 100644 --- a/include/ck/tensor_operation/gpu/device/device_contraction_multiple_d_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/device_contraction_multiple_d_xdl_cshuffle.hpp @@ -320,10 +320,6 @@ struct DeviceContractionMultipleD_Xdl_CShuffle BElementwiseOperation, CDEElementwiseOperation, InMemoryDataOperationEnum::Set, - AGridDesc_M_K, - BGridDesc_N_K, - DsGridDesc_M_N, - EGridDesc_M_N, NumGemmKPrefetchStage, BlockSize, MPerBlock, @@ -357,12 +353,19 @@ struct DeviceContractionMultipleD_Xdl_CShuffle CDEBlockTransferScalarPerVector_NPerBlock, LoopSched>; - using AGridDesc_AK0_M_AK1 = remove_cvref_t; - using BGridDesc_BK0_N_BK1 = remove_cvref_t; + using DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t; + using EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t; - using Block2ETileMap = typename GridwiseGemm::DefaultBlock2ETileMap; + // block-to-e-tile map + using Block2ETileMap = + remove_cvref_t; // Argument struct Argument : public BaseArgument @@ -475,10 +478,9 @@ struct DeviceContractionMultipleD_Xdl_CShuffle // tensor descriptors for block/thread-wise copy AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_; BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_; - typename GridwiseGemm::DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock ds_grid_desc_mblock_mperblock_nblock_nperblock_; - typename GridwiseGemm::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock - e_grid_desc_mblock_mperblock_nblock_nperblock_; + EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock e_grid_desc_mblock_mperblock_nblock_nperblock_; // block-to-e-tile map Block2ETileMap block_2_etile_map_; @@ -535,9 +537,9 @@ struct DeviceContractionMultipleD_Xdl_CShuffle CDEElementwiseOperation, DeviceOp::AGridDesc_AK0_M_AK1, DeviceOp::BGridDesc_BK0_N_BK1, - typename GridwiseGemm::DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock, - 
typename GridwiseGemm::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock, - typename GridwiseGemm::DefaultBlock2ETileMap, + DeviceOp::DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, + DeviceOp::EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, + DeviceOp::Block2ETileMap, has_main_loop>; return launch_and_time_kernel(stream_config, diff --git a/include/ck/tensor_operation/gpu/device/device_gemm_bias_e_permute_xdl.hpp b/include/ck/tensor_operation/gpu/device/device_gemm_bias_e_permute_xdl.hpp index ffdb8d5894..1914068828 100644 --- a/include/ck/tensor_operation/gpu/device/device_gemm_bias_e_permute_xdl.hpp +++ b/include/ck/tensor_operation/gpu/device/device_gemm_bias_e_permute_xdl.hpp @@ -237,10 +237,6 @@ struct DeviceGemmBiasEPermute_Xdl : public DeviceGemmBiasCPermute{}); } + // desc for problem definition using AGridDesc_M_K = decltype(MakeAGridDescriptor_M_K(1, 1, 1)); using BGridDesc_N_K = decltype(MakeBGridDescriptor_N_K(1, 1, 1)); using DsGridDesc_M_N = remove_cvref_t; @@ -250,10 +251,6 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD; - using AGridDesc_AK0_M_AK1 = remove_cvref_t; - using BGridDesc_BK0_N_BK1 = remove_cvref_t; + using DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t; + using EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t; + + // block-to-e-tile map + using Block2ETileMap = + remove_cvref_t; // Argument struct Argument : public BaseArgument @@ -383,13 +389,12 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD; return launch_and_time_kernel(stream_config, diff --git a/include/ck/tensor_operation/gpu/device/device_grouped_contraction_multiple_d_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/device_grouped_contraction_multiple_d_xdl_cshuffle.hpp index 2dcd558273..03d9e26a46 100644 --- a/include/ck/tensor_operation/gpu/device/device_grouped_contraction_multiple_d_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/device_grouped_contraction_multiple_d_xdl_cshuffle.hpp @@ -365,10 +365,6 @@ struct DeviceGroupedContractionMultipleD_Xdl_CShuffle BElementwiseOperation, CDEElementwiseOperation, InMemoryDataOperationEnum::Set, - AGridDesc_M_K, - BGridDesc_N_K, - DsGridDesc_M_N, - EGridDesc_M_N, NumGemmKPrefetchStage, BlockSize, MPerBlock, @@ -402,17 +398,21 @@ struct DeviceGroupedContractionMultipleD_Xdl_CShuffle CDEBlockTransferScalarPerVector_NPerBlock, LoopSched>; - using AGridDesc_AK0_M_AK1 = remove_cvref_t; - using BGridDesc_BK0_N_BK1 = remove_cvref_t; + using DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t; + using EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t; struct GroupedContractionBlock2ETileMap { - static_assert( - std::is_same::value, - "Wrong! 
Should be the same type name"); + // block-to-e-tile map + using Block2ETileMap = + remove_cvref_t; GroupedContractionBlock2ETileMap(const EGridDesc_M_N& e_grid_desc_m_n, ck::index_t BlockStart) @@ -441,7 +441,7 @@ struct DeviceGroupedContractionMultipleD_Xdl_CShuffle return default_block_2_etile_map_.CheckValidity(e_grid_desc_m_n); } - typename GridwiseGemm::DefaultBlock2ETileMap default_block_2_etile_map_; + Block2ETileMap default_block_2_etile_map_; ck::index_t block_start_; }; @@ -456,10 +456,9 @@ struct DeviceGroupedContractionMultipleD_Xdl_CShuffle // tensor descriptors for block/thread-wise copy AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_; BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_; - typename GridwiseGemm::DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock ds_grid_desc_mblock_mperblock_nblock_nperblock_; - typename GridwiseGemm::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock - e_grid_desc_mblock_mperblock_nblock_nperblock_; + EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock e_grid_desc_mblock_mperblock_nblock_nperblock_; // lock-to-e-tile map GroupedContractionBlock2ETileMap block_2_etile_map_; diff --git a/include/ck/tensor_operation/gpu/device/device_grouped_conv_bwd_data_multiple_d.hpp b/include/ck/tensor_operation/gpu/device/device_grouped_conv_bwd_data_multiple_d.hpp new file mode 100644 index 0000000000..fa73188174 --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/device_grouped_conv_bwd_data_multiple_d.hpp @@ -0,0 +1,67 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include + +#include "ck/tensor_operation/gpu/device/device_base.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +// Conv backward data multiple D: +// input : output image A[G, N, K, Ho, Wo] +// input : weight B[G, K, C, Y, X], +// input : D0[G, N, K, Ho, Wo], D1[G, N, K, Ho, Wo], ... +// output : input image E[G, N, C, Hi, Wi], +// C = a_op(A) * b_op(B) +// E = cde_op(C, D0, D1, ...) +template +struct DeviceGroupedConvBwdDataMultipleD : public BaseOperator +{ + static constexpr index_t NumDTensor = DsDataType::Size(); + + static_assert(NumDTensor == DsLayout::Size(), "wrong! 
Inconsistent NumDTensor"); + + virtual std::unique_ptr MakeArgumentPointer( + const void* p_a, // output image + const void* p_b, // weight + const std::array& p_ds, // bias + void* p_e, // input image + const std::array& a_g_n_k_wos_lengths, // output image + const std::array& a_g_n_k_wos_strides, // output image + const std::array& b_g_k_c_xs_lengths, // weight + const std::array& b_g_k_c_xs_strides, // weight + const std::array, NumDTensor>& + ds_g_n_k_wos_lengths, // bias + const std::array, NumDTensor>& + ds_g_n_k_wos_strides, // bias + const std::array& e_g_n_c_wis_lengths, // input image + const std::array& e_g_n_c_wis_strides, // input image + const std::array& conv_filter_strides, + const std::array& conv_filter_dilations, + const std::array& input_left_pads, + const std::array& input_right_pads, + const AElementwiseOperation& a_element_op, + const BElementwiseOperation& b_element_op, + const CDEElementwiseOperation& cde_element_op) = 0; + + virtual std::unique_ptr MakeInvokerPointer() = 0; +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_d.hpp b/include/ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_d.hpp index 00e0614475..1e2f81915d 100644 --- a/include/ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_d.hpp +++ b/include/ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_d.hpp @@ -34,11 +34,13 @@ struct DeviceGroupedConvFwdMultipleD : public BaseOperator { static constexpr index_t NumDTensor = DsDataType::Size(); + static_assert(NumDTensor == DsLayout::Size(), "wrong! Inconsistent NumDTensor"); + virtual std::unique_ptr MakeArgumentPointer( - const void* p_a, - const void* p_b, + const void* p_a, // input image + const void* p_b, // weight const std::array& p_ds, - void* p_e, + void* p_e, // output image const std::array& a_g_n_c_wis_lengths, const std::array& a_g_n_c_wis_strides, const std::array& b_g_k_c_xs_lengths, diff --git a/include/ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_d_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_d_xdl_cshuffle.hpp index 03c17e6e76..5ea757c27c 100644 --- a/include/ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_d_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_d_xdl_cshuffle.hpp @@ -117,7 +117,7 @@ __global__ void #if CK_USE_LAUNCH_BOUNDS __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) #endif - kernel_batch_gemm_multiple_d_xdl_cshuffle( + kernel_grouped_conv_fwd_multiple_d_xdl_cshuffle( const ABDataType* __restrict__ p_a_grid, const ABDataType* __restrict__ p_b_grid, DsPointer p_ds_grid, @@ -136,8 +136,7 @@ __global__ void const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch) { #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__)) - -#if 1 + // offset base pointer for each work-group const index_t num_blocks_per_batch = __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch); @@ -174,24 +173,6 @@ __global__ void ds_grid_desc_mblock_mperblock_nblock_nperblock, e_grid_desc_mblock_mperblock_nblock_nperblock_, block_2_ctile_map); -#else - __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; - - GridwiseGemm::template Run(p_a_grid, - p_b_grid, - p_ds_grid, - p_e_grid, - p_shared, - a_element_op, - 
b_element_op,
-                                 cde_element_op,
-                                 a_grid_desc_k0_m_k1,
-                                 b_grid_desc_k0_n_k1,
-                                 ds_grid_desc_mblock_mperblock_nblock_nperblock,
-                                 e_grid_desc_mblock_mperblock_nblock_nperblock_,
-                                 block_2_ctile_map);
-#endif
-
 #else
     ignore = p_a_grid;
     ignore = p_b_grid;
@@ -378,6 +359,7 @@ struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle
             Number{});
     }

+    // desc for problem definition
     using AGridDesc_M_K =
        remove_cvref_t<decltype(MakeAGridDescriptor_M_K({}, {}, {}, {}, {}, {}, {}, {}, {}, {}))>;
     using BGridDesc_N_K = remove_cvref_t<decltype(MakeBGridDescriptor_N_K({}, {}))>;
@@ -395,10 +377,6 @@ struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle
         BElementwiseOperation,
         CDEElementwiseOperation,
         InMemoryDataOperationEnum::Set,
-        AGridDesc_M_K,
-        BGridDesc_N_K,
-        DsGridDesc_M_N,
-        EGridDesc_M_N,
         NumGemmKPrefetchStage,
         BlockSize,
         MPerBlock,
@@ -432,12 +410,19 @@ struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle
         CDEBlockTransferScalarPerVector_NPerBlock,
         LoopSched>;

-    using AGridDesc_AK0_M_AK1 = remove_cvref_t;
-    using BGridDesc_BK0_N_BK1 = remove_cvref_t;
+    using DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<decltype(
+        GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(DsGridDesc_M_N{}))>;
+    using EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<decltype(
+        GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(EGridDesc_M_N{}))>;

-    using Block2ETileMap = typename GridwiseGemm::DefaultBlock2ETileMap;
+    // block-to-e-tile map
+    using Block2ETileMap =
+        remove_cvref_t<decltype(GridwiseGemm::MakeDefaultBlock2ETileMap(EGridDesc_M_N{}))>;

     // Argument
     struct Argument : public BaseArgument
@@ -467,6 +452,7 @@ struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle
               p_b_grid_{static_cast<const BDataType*>(p_b)},
               p_ds_grid_{},
               p_e_grid_{static_cast<EDataType*>(p_e)},
+              num_group_{a_g_n_c_wis_lengths[0]},
               a_grid_desc_m_k_{DeviceOp::MakeAGridDescriptor_M_K(a_g_n_c_wis_lengths,
                                                                  a_g_n_c_wis_strides,
                                                                  b_g_k_c_xs_lengths,
@@ -561,6 +547,7 @@ struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle
         EDataType* p_e_grid_;

         // tensor descriptors for problem definition
+        index_t num_group_;
         AGridDesc_M_K a_grid_desc_m_k_;
         BGridDesc_N_K b_grid_desc_n_k_;
         DsGridDesc_M_N ds_grid_desc_m_n_;
@@ -569,14 +556,14 @@ struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle
         // tensor descriptors for block/thread-wise copy
         AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_;
         BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_;
-        typename GridwiseGemm::DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
+        DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
             ds_grid_desc_mblock_mperblock_nblock_nperblock_;
-        typename GridwiseGemm::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
-            e_grid_desc_mblock_mperblock_nblock_nperblock_;
+        EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock e_grid_desc_mblock_mperblock_nblock_nperblock_;

         // block-to-e-tile map
         Block2ETileMap block_2_etile_map_;

+        // for computing batch offset
         ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_batch_;

         // element-wise op
@@ -622,8 +609,7 @@ struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle
             }

             const index_t grid_size =
-                arg.block_2_etile_map_.CalculateGridSize(arg.e_grid_desc_m_n_) *
-                arg.a_g_n_c_wis_lengths_[0]; // Group count
+                arg.block_2_etile_map_.CalculateGridSize(arg.e_grid_desc_m_n_) * arg.num_group_;

             const auto K =
                 arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) * arg.a_grid_desc_ak0_m_ak1_.GetLength(I2);
@@ -631,7 +617,7 @@ struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle
             auto launch_kernel = [&](auto has_main_k_block_loop) {
                 constexpr bool has_main_loop = has_main_k_block_loop.value;

-                const auto kernel = kernel_batch_gemm_multiple_d_xdl_cshuffle<
+                const auto kernel = kernel_grouped_conv_fwd_multiple_d_xdl_cshuffle<
                     GridwiseGemm,
                     ADataType, // TODO: distinguish A/B datatype
                     typename GridwiseGemm::DsGridPointer,
                     EDataType,
                     AElementwiseOperation,
                     BElementwiseOperation,
                     CDEElementwiseOperation,
DeviceOp::AGridDesc_AK0_M_AK1, DeviceOp::BGridDesc_BK0_N_BK1, - typename GridwiseGemm::DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock, - typename GridwiseGemm::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock, + DeviceOp::DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, + DeviceOp::EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, Block2ETileMap, ComputePtrOffsetOfStridedBatch, has_main_loop>; @@ -798,7 +784,8 @@ struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle is_same_v || is_same_v || is_same_v || is_same_v || is_same_v || is_same_v || - is_same_v) + is_same_v || is_same_v || + is_same_v) { const index_t K = arg.ds_g_n_k_wos_lengths_[i][2]; diff --git a/include/ck/tensor_operation/gpu/device/device_grouped_gemm_xdl.hpp b/include/ck/tensor_operation/gpu/device/device_grouped_gemm_xdl.hpp index abdfd078cf..06e15c1eeb 100644 --- a/include/ck/tensor_operation/gpu/device/device_grouped_gemm_xdl.hpp +++ b/include/ck/tensor_operation/gpu/device/device_grouped_gemm_xdl.hpp @@ -238,10 +238,6 @@ struct DeviceGroupedGemm_Xdl : public DeviceGroupedGemm; - using AGridDesc_AK0_M_AK1 = remove_cvref_t; - using BGridDesc_BK0_N_BK1 = remove_cvref_t; + using DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t; + using EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t; struct GroupedGemmBlock2ETileMap { - using UnderlyingBlock2ETileMap = typename GridwiseGemm::DefaultBlock2ETileMap; - - static_assert( - std::is_same::value, - "Wrong! Should be the same type name"); + using Block2ETileMap = + remove_cvref_t; GroupedGemmBlock2ETileMap() { @@ -321,7 +317,7 @@ struct DeviceGroupedGemm_Xdl : public DeviceGroupedGemm{}([&](auto j) { diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp new file mode 100644 index 0000000000..9efcc5d8c8 --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp @@ -0,0 +1,1014 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
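+
+// Design note (summarizing the implementation in this file): conv bwd data is decomposed
+// into num_gemm_ = YTilde * XTilde GEMMs, where YTilde = ConvStrideH / gcd(ConvStrideH,
+// ConvDilationH) and XTilde is defined likewise for W. Each (i_ytilde, i_xtilde) slice
+// gets its own A/B/Ds/E grid descriptors and block-to-e-tile map; empty slices
+// (YDotSlice * XDotSlice <= 0) are skipped, and the Invoker launches one grid per
+// remaining GEMM, summing the measured kernel times.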
+
+#pragma once
+
+#include <iostream>
+#include <sstream>
+
+#include "ck/utility/common_header.hpp"
+#include "ck/tensor_description/tensor_descriptor.hpp"
+#include "ck/tensor_description/tensor_descriptor_helper.hpp"
+#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
+#include "ck/tensor_operation/gpu/device/device_grouped_conv_bwd_data_multiple_d.hpp"
+#include "ck/tensor_operation/gpu/device/convolution_backward_data_specialization.hpp"
+#include "ck/tensor_operation/operator_transform/transform_conv_bwd_data_to_gemm_v1.hpp"
+#include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp"
+#include "ck/host_utility/device_prop.hpp"
+#include "ck/host_utility/kernel_launch.hpp"
+
+namespace ck {
+namespace tensor_operation {
+namespace device {
+
+namespace {
+
+template <index_t NumDTensor>
+struct ComputePtrOffsetOfStridedBatch
+{
+    ComputePtrOffsetOfStridedBatch() = default;
+
+    ComputePtrOffsetOfStridedBatch(index_t BatchStrideA,
+                                   index_t BatchStrideB,
+                                   Array<ck::index_t, NumDTensor> BatchStrideDs,
+                                   index_t BatchStrideE)
+        : BatchStrideA_(BatchStrideA),
+          BatchStrideB_(BatchStrideB),
+          BatchStrideDs_(BatchStrideDs),
+          BatchStrideE_(BatchStrideE)
+    {
+    }
+
+    __host__ __device__ constexpr long_index_t GetAPtrOffset(index_t g_idx) const
+    {
+        return g_idx * static_cast<long_index_t>(BatchStrideA_);
+    }
+
+    __host__ __device__ constexpr long_index_t GetBPtrOffset(index_t g_idx) const
+    {
+        return g_idx * static_cast<long_index_t>(BatchStrideB_);
+    }
+
+    __host__ __device__ constexpr auto GetDsPtrOffset(index_t g_idx) const
+    {
+        Array<long_index_t, NumDTensor> ds_offset;
+        static_for<0, NumDTensor, 1>{}(
+            [&](auto i) { ds_offset(i) = g_idx * static_cast<long_index_t>(BatchStrideDs_[i]); });
+        return ds_offset;
+    }
+
+    __host__ __device__ constexpr long_index_t GetEPtrOffset(index_t g_idx) const
+    {
+        return g_idx * static_cast<long_index_t>(BatchStrideE_);
+    }
+
+    index_t BatchStrideA_;
+    index_t BatchStrideB_;
+    Array<ck::index_t, NumDTensor> BatchStrideDs_;
+    index_t BatchStrideE_;
+};
+
+/*
+ * \brief Wrapper function of GridwiseGemm::Run to realize BatchedGEMM.
+ *
+ * \tparam ComputePtrOffsetOfBatch Class that computes the base pointer offsets of A, B, C matrix
+ * given the batch. For example, ComputePtrOffsetOfStridedBatch() computes the offsets of evenly
+ * strided batches, but we can easily extend to other layouts. The returned offset can be either \p
+ * index_t or \p long_index_t. If it returns \p long_index_t, we are not subject to the 2GB
+ * limitations.
+ *
+ * \tparam Block2ETileMap Block2ETileMap::CalculateBottomIndex() takes in id of a workgroup and
+ * returns the 2D index of the tile that it computes. \see
+ * GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3::Run().
+ *
+ * \note Using \p ComputePtrOffsetOfBatch gives us the flexibility that 2 workgroups can compute 2
+ * tiles from different matrices. Keep in mind that these 2 matrices can share the same grid
+ * descriptor (like in BatchedGEMM), or use their own grid descriptors (in GroupedGemm). \link
+ * device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp kernel_gemm_xdlops_v2r3_for_conv3d \endlink for \link
+ * DeviceConv3d \endlink uses the same concept, but currently does NOT encapsulate the computing of
+ * pointer offset into \p ComputePtrOffsetOfStridedBatch.
+ *
+ * \note \p Block2ETileMap allows customized mapping between a workgroup and the C-tile it computes.
+ * Together with \p ComputePtrOffsetOfBatch, we can reuse GridwiseGemm (and GridwiseGemm fusion) to
+ * realize BatchedGemm and GroupedGemm (and the corresponding GEMM fusion).
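+ *
+ * \note As a concrete illustration (restating the kernel below): with
+ * num_blocks_per_batch = grid_size / batch_count, work-group bid handles batch
+ * g_idx = bid / num_blocks_per_batch and computes the E tile that
+ * Block2ETileMap::CalculateBottomIndex() assigns to it within that batch, reading A at
+ * p_a_grid + GetAPtrOffset(g_idx), i.e. p_a_grid + g_idx * BatchStrideA_
+ * (B, Ds and E are offset likewise).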
+ * + */ +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +#endif + kernel_grouped_conv_bwd_data_multiple_d_xdl_cshuffle( + const ABDataType* __restrict__ p_a_grid, + const ABDataType* __restrict__ p_b_grid, + DsPointer p_ds_grid, + EDataType* __restrict__ p_e_grid, + const AElementwiseOperation a_element_op, + const BElementwiseOperation b_element_op, + const CDEElementwiseOperation cde_element_op, + const index_t batch_count, + const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1, + const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1, + const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + ds_grid_desc_mblock_mperblock_nblock_nperblock, + const EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock + e_grid_desc_mblock_mperblock_nblock_nperblock_, + const Block2ETileMap block_2_ctile_map, + const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch) +{ +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__)) + // offset base pointer for each work-group + const index_t num_blocks_per_batch = + __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); + const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch); + + const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane( + static_cast(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx))); + const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane( + static_cast(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx))); + const long_index_t e_batch_offset = __builtin_amdgcn_readfirstlane( + static_cast(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx))); + + const auto ds_batch_offset = compute_ptr_offset_of_batch.GetDsPtrOffset(g_idx); + + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + + DsPointer p_ds_grid_grp; + + static constexpr index_t NumDTensor = + DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock::Size(); + + static_for<0, NumDTensor, 1>{}( + [&](auto i) { p_ds_grid_grp(i) = p_ds_grid[i] + ds_batch_offset[i]; }); + + GridwiseGemm::template Run(p_a_grid + a_batch_offset, + p_b_grid + b_batch_offset, + p_ds_grid_grp, + p_e_grid + e_batch_offset, + p_shared, + a_element_op, + b_element_op, + cde_element_op, + a_grid_desc_ak0_m_ak1, + b_grid_desc_bk0_n_bk1, + ds_grid_desc_mblock_mperblock_nblock_nperblock, + e_grid_desc_mblock_mperblock_nblock_nperblock_, + block_2_ctile_map); +#else + ignore = p_a_grid; + ignore = p_b_grid; + ignore = p_ds_grid; + ignore = p_e_grid; + ignore = batch_count; + ignore = a_grid_desc_ak0_m_ak1; + ignore = b_grid_desc_bk0_n_bk1; + ignore = ds_grid_desc_mblock_mperblock_nblock_nperblock; + ignore = e_grid_desc_mblock_mperblock_nblock_nperblock_; + ignore = a_element_op; + ignore = b_element_op; + ignore = cde_element_op; + ignore = compute_ptr_offset_of_batch; + ignore = block_2_ctile_map; +#endif +} + +} // namespace + +// Conv backward data multiple D: +// input : output image A: [G, N, K, Ho, Wo] +// input : weight B: [G, K, C, Y, X], +// input : D0, D1, ... : [G, N, K, Ho, Wo] +// output : input image E: [G, N, C, Hi, Wi] +// C = a_op(A) * b_op(B) +// E = cde_op(C, D0, D1, ...) +template +struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 + : public DeviceGroupedConvBwdDataMultipleD +{ + // FIXME + static_assert(NDimSpatial == 2, "wrong! 
only implemented for 2D now"); + + using DeviceOp = DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1; + + static constexpr index_t NumDTensor = DsDataType::Size(); + + // TODO make A/B datatype different + using ABDataType = ADataType; + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + + static constexpr auto transform_conv_to_gemm = + TransformConvBwdDataToGemm_v1{}; + + static auto GetDummyABDsEGridDescriptor() + { + const std::array dummy_tensor_lengths = {1}; + const std::array dummy_tensor_strides = {1}; + const std::array dummy_spatial_lengths = {1}; + + const auto a_grid_desc_ak0_m_ak1 = + transform_conv_to_gemm.template MakeADescriptor_AK0_M_AK1( + dummy_tensor_lengths, + dummy_tensor_strides, + dummy_tensor_lengths, + dummy_tensor_strides, + dummy_tensor_lengths, + dummy_tensor_strides, + dummy_spatial_lengths, + dummy_spatial_lengths, + dummy_spatial_lengths, + dummy_spatial_lengths, + dummy_spatial_lengths); + + const auto b_grid_desc_bk0_n_bk1 = + transform_conv_to_gemm.template MakeBDescriptor_BK0_N_BK1( + dummy_tensor_lengths, + dummy_tensor_strides, + dummy_tensor_lengths, + dummy_tensor_strides, + dummy_tensor_lengths, + dummy_tensor_strides, + dummy_spatial_lengths, + dummy_spatial_lengths, + dummy_spatial_lengths, + dummy_spatial_lengths, + dummy_spatial_lengths); + + const auto ds_grid_desc_m_n = generate_tuple( + [&](auto i) { + using DLayout = remove_cvref_t>; + + return transform_conv_to_gemm.template MakeCDescriptor_M_N( + dummy_tensor_lengths, + dummy_tensor_strides, + dummy_tensor_lengths, + dummy_tensor_strides, + dummy_tensor_lengths, + dummy_tensor_strides, + dummy_spatial_lengths, + dummy_spatial_lengths, + dummy_spatial_lengths, + dummy_spatial_lengths, + dummy_spatial_lengths); + }, + Number{}); + + const auto e_grid_desc_m_n = + transform_conv_to_gemm.template MakeCDescriptor_M_N(dummy_tensor_lengths, + dummy_tensor_strides, + dummy_tensor_lengths, + dummy_tensor_strides, + dummy_tensor_lengths, + dummy_tensor_strides, + dummy_spatial_lengths, + dummy_spatial_lengths, + dummy_spatial_lengths, + dummy_spatial_lengths, + dummy_spatial_lengths); + + return make_tuple( + a_grid_desc_ak0_m_ak1, b_grid_desc_bk0_n_bk1, ds_grid_desc_m_n, e_grid_desc_m_n); + } + + // GridwiseGemm + using GridwiseGemm = GridwiseGemmMultipleD_xdl_cshuffle< + ABDataType, // TODO: distinguish A/B datatype + AccDataType, + CShuffleDataType, + DsDataType, + EDataType, + AElementwiseOp, + BElementwiseOp, + CDEElementwiseOp, + InMemoryDataOperationEnum::Set, + NumGemmKPrefetchStage, + BlockSize, + MPerBlock, + NPerBlock, + KPerBlock, + AK1, + BK1, + MPerXDL, + NPerXDL, + MXdlPerWave, + NXdlPerWave, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + false, + ABlockLdsExtraM, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + false, + BBlockLdsExtraN, + CShuffleMXdlPerWavePerShuffle, + CShuffleNXdlPerWavePerShuffle, + CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + CDEBlockTransferScalarPerVector_NPerBlock, + LoopSched>; + + template + static auto transform_k0_m_k1_to_m_k(const Desc_K0_M_K1& 
desc_k0_m_k1) + { + const auto grid_desc_m_k = transform_tensor_descriptor( + desc_k0_m_k1, + make_tuple(make_pass_through_transform(desc_k0_m_k1.GetLength(I1)), + make_merge_transform( + make_tuple(desc_k0_m_k1.GetLength(I0), desc_k0_m_k1.GetLength(I2)))), + make_tuple(Sequence<1>{}, Sequence<0, 2>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + return grid_desc_m_k; + } + + // desc + using ABDsEGridDesc = decltype(GetDummyABDsEGridDescriptor()); + + using AGridDesc_AK0_M_AK1 = remove_cvref_t>; + using BGridDesc_BK0_N_BK1 = remove_cvref_t>; + using DsGridDesc_M_N = remove_cvref_t>; + using EGridDesc_M_N = remove_cvref_t>; + + using AGridDesc_M_K = decltype(transform_k0_m_k1_to_m_k(AGridDesc_AK0_M_AK1{})); + using BGridDesc_N_K = decltype(transform_k0_m_k1_to_m_k(BGridDesc_BK0_N_BK1{})); + + using DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = decltype( + GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(DsGridDesc_M_N{})); + using EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = decltype( + GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(EGridDesc_M_N{})); + + // block-to-e-tile map + using Block2ETileMap = + remove_cvref_t; + + // Argument + struct Argument : public BaseArgument + { + Argument(const void* p_a, // output image + const void* p_b, // weight + const std::array& p_ds, // bias + void* p_e, // input image + const std::array& a_g_n_k_wos_lengths, + const std::array& a_g_n_k_wos_strides, + const std::array& b_g_k_c_xs_lengths, + const std::array& b_g_k_c_xs_strides, + const std::array, NumDTensor>& + ds_g_n_c_wis_lengths, + const std::array, NumDTensor>& + ds_g_n_c_wis_strides, + const std::array& e_g_n_c_wis_lengths, + const std::array& e_g_n_c_wis_strides, + const std::array& conv_filter_strides, + const std::array& conv_filter_dilations, + const std::array& input_left_pads, + const std::array& input_right_pads, + const AElementwiseOp& a_element_op, + const BElementwiseOp& b_element_op, + const CDEElementwiseOp& cde_element_op) + : p_a_grid_{static_cast(p_a)}, + p_b_grid_{static_cast(p_b)}, + p_ds_grid_{}, + p_e_grid_{static_cast(p_e)}, + num_group_{a_g_n_k_wos_lengths[0]}, + num_gemm_{}, + a_element_op_{a_element_op}, + b_element_op_{b_element_op}, + cde_element_op_{cde_element_op}, + a_g_n_k_wos_lengths_{a_g_n_k_wos_lengths}, + a_g_n_k_wos_strides_{a_g_n_k_wos_strides}, + b_g_k_c_xs_lengths_{b_g_k_c_xs_lengths}, + b_g_k_c_xs_strides_{b_g_k_c_xs_strides}, + ds_g_n_c_wis_lengths_{ds_g_n_c_wis_lengths}, + ds_g_n_c_wis_strides_{ds_g_n_c_wis_strides}, + e_g_n_c_wis_lengths_{e_g_n_c_wis_lengths}, + e_g_n_c_wis_strides_{e_g_n_c_wis_strides}, + conv_filter_strides_{conv_filter_strides}, + conv_filter_dilations_{conv_filter_dilations}, + input_left_pads_{input_left_pads}, + input_right_pads_{input_right_pads} + { + // populate Ds pointer + static_for<0, NumDTensor, 1>{}([&](auto i) { + using DDataType = remove_cvref_t>; + + p_ds_grid_(i) = static_cast(p_ds[i]); + }); + + // A/B/Ds/E Batch Stride + compute_ptr_offset_of_batch_.BatchStrideA_ = a_g_n_k_wos_strides[0]; + compute_ptr_offset_of_batch_.BatchStrideB_ = b_g_k_c_xs_strides[0]; + compute_ptr_offset_of_batch_.BatchStrideE_ = e_g_n_c_wis_strides[0]; + + static_for<0, NumDTensor, 1>{}([&](auto i) { + compute_ptr_offset_of_batch_.BatchStrideDs_(i) = ds_g_n_c_wis_strides[i][0]; + }); + + // problem definition + const index_t Y = b_g_k_c_xs_lengths[3]; + const index_t X = b_g_k_c_xs_lengths[4]; + + const index_t ConvStrideH = conv_filter_strides_[0]; + const index_t ConvStrideW = 
conv_filter_strides_[1]; + + const index_t ConvDilationH = conv_filter_dilations_[0]; + const index_t ConvDilationW = conv_filter_dilations_[1]; + + const auto GcdStrideDilationH = math::gcd(ConvStrideH, ConvDilationH); + const auto GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW); + + const auto YTilde = ConvStrideH / GcdStrideDilationH; + const auto XTilde = ConvStrideW / GcdStrideDilationW; + + // number of GEMM + num_gemm_ = YTilde * XTilde; + + for(index_t i_ytilde = 0; i_ytilde < YTilde; ++i_ytilde) + { + for(index_t i_xtilde = 0; i_xtilde < XTilde; ++i_xtilde) + { + // check slice is valid + const auto YDotSlice = math::integer_divide_ceil(Y - i_ytilde, YTilde); + const auto XDotSlice = math::integer_divide_ceil(X - i_xtilde, XTilde); + + if(YDotSlice * XDotSlice <= 0) + { + continue; + } + + const auto a_grid_desc_ak0_m_ak1 = + transform_conv_to_gemm.template MakeADescriptor_AK0_M_AK1( + a_g_n_k_wos_lengths, + a_g_n_k_wos_strides, + b_g_k_c_xs_lengths, + b_g_k_c_xs_strides, + e_g_n_c_wis_lengths, + e_g_n_c_wis_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + {i_ytilde, i_xtilde}); + + const auto b_grid_desc_bk0_n_bk1 = + transform_conv_to_gemm.template MakeBDescriptor_BK0_N_BK1( + a_g_n_k_wos_lengths, + a_g_n_k_wos_strides, + b_g_k_c_xs_lengths, + b_g_k_c_xs_strides, + e_g_n_c_wis_lengths, + e_g_n_c_wis_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + {i_ytilde, i_xtilde}); + + DsGridDesc_M_N ds_grid_desc_m_n; + + // populate Ds desc + static_for<0, NumDTensor, 1>{}([&](auto i) { + using DLayout = remove_cvref_t>; + + ds_grid_desc_m_n(i) = + transform_conv_to_gemm.template MakeCDescriptor_M_N( + a_g_n_k_wos_lengths, + a_g_n_k_wos_strides, + b_g_k_c_xs_lengths, + b_g_k_c_xs_strides, + ds_g_n_c_wis_lengths[i], + ds_g_n_c_wis_strides[i], + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + {i_ytilde, i_xtilde}); + }); + + const auto e_grid_desc_m_n = + transform_conv_to_gemm.template MakeCDescriptor_M_N( + a_g_n_k_wos_lengths, + a_g_n_k_wos_strides, + b_g_k_c_xs_lengths, + b_g_k_c_xs_strides, + e_g_n_c_wis_lengths, + e_g_n_c_wis_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + {i_ytilde, i_xtilde}); + + // desc for problem definition + const auto a_grid_desc_m_k = transform_k0_m_k1_to_m_k(a_grid_desc_ak0_m_ak1); + const auto b_grid_desc_n_k = transform_k0_m_k1_to_m_k(b_grid_desc_bk0_n_bk1); + + a_grid_desc_m_k_container_.push_back(a_grid_desc_m_k); + b_grid_desc_n_k_container_.push_back(b_grid_desc_n_k); + ds_grid_desc_m_n_container_.push_back(ds_grid_desc_m_n); + e_grid_desc_m_n_container_.push_back(e_grid_desc_m_n); + + // desc for blockwise copy + a_grid_desc_ak0_m_ak1_container_.push_back(a_grid_desc_ak0_m_ak1); + b_grid_desc_bk0_n_bk1_container_.push_back(b_grid_desc_bk0_n_bk1); + + // block-to-e-tile-map + auto block_2_etile_map = + GridwiseGemm::MakeDefaultBlock2ETileMap(e_grid_desc_m_n); + + block_2_etile_map_container_.push_back(block_2_etile_map); + + if(GridwiseGemm::CheckValidity(a_grid_desc_m_k, + b_grid_desc_n_k, + ds_grid_desc_m_n, + e_grid_desc_m_n, + block_2_etile_map)) + { + ds_grid_desc_mblock_mperblock_nblock_nperblock_container_.push_back( + GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + ds_grid_desc_m_n)); + + e_grid_desc_mblock_mperblock_nblock_nperblock_container_.push_back( + GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + 
e_grid_desc_m_n)); + } + } + } + } + + void Print() const + { + for(index_t i = 0; i < num_gemm_; i++) + { + std::cout << "a_grid_desc_ak0_m_ak1_container_" + << a_grid_desc_ak0_m_ak1_container_[i] << std::endl; + + std::cout << "b_grid_desc_bk0_n_bk1_container_" + << b_grid_desc_bk0_n_bk1_container_[i] << std::endl; + + static_for<0, NumDTensor, 1>{}([&](auto j) { + std::cout << "ds_grid_desc_mblock_mperblock_nblock_nperblock_container_" + << ds_grid_desc_mblock_mperblock_nblock_nperblock_container_[i][j] + << std::endl; + }); + + std::cout << "e_grid_desc_mblock_mperblock_nblock_nperblock_container_" + << e_grid_desc_mblock_mperblock_nblock_nperblock_container_[i] + << std::endl; + } + } + + // pointers + const ADataType* p_a_grid_; + const BDataType* p_b_grid_; + typename GridwiseGemm::DsGridPointer p_ds_grid_; + EDataType* p_e_grid_; + + // tensor descriptor for problem definition + index_t num_group_; + index_t num_gemm_; + std::vector a_grid_desc_m_k_container_; + std::vector b_grid_desc_n_k_container_; + std::vector ds_grid_desc_m_n_container_; + std::vector e_grid_desc_m_n_container_; + + // tensor descriptor for block-wise copy + std::vector a_grid_desc_ak0_m_ak1_container_; + std::vector b_grid_desc_bk0_n_bk1_container_; + std::vector + ds_grid_desc_mblock_mperblock_nblock_nperblock_container_; + std::vector + e_grid_desc_mblock_mperblock_nblock_nperblock_container_; + + // block-to-e-tile map + std::vector block_2_etile_map_container_; + + // for computing batch offset + ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_batch_; + + // element-wise op + AElementwiseOp a_element_op_; + BElementwiseOp b_element_op_; + CDEElementwiseOp cde_element_op_; + + // for checking IsSupportedArgument() + std::array a_g_n_k_wos_lengths_; + std::array a_g_n_k_wos_strides_; + std::array b_g_k_c_xs_lengths_; + std::array b_g_k_c_xs_strides_; + std::array, NumDTensor> ds_g_n_c_wis_lengths_; + std::array, NumDTensor> ds_g_n_c_wis_strides_; + std::array e_g_n_c_wis_lengths_; + std::array e_g_n_c_wis_strides_; + std::array conv_filter_strides_; + std::array conv_filter_dilations_; + std::array input_left_pads_; + std::array input_right_pads_; + }; + + // Invoker + struct Invoker : public BaseInvoker + { + using Argument = DeviceOp::Argument; + + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + if(stream_config.log_level_ > 0) + { + arg.Print(); + } + + float ave_time = 0; + + for(index_t i = 0; i < arg.num_gemm_; i++) + { + if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_m_k_container_[i], + arg.b_grid_desc_n_k_container_[i], + arg.ds_grid_desc_m_n_container_[i], + arg.e_grid_desc_m_n_container_[i], + arg.block_2_etile_map_container_[i])) + { + throw std::runtime_error("wrong! 
device_op has invalid setting"); + } + + const index_t grid_size = arg.block_2_etile_map_container_[i].CalculateGridSize( + arg.e_grid_desc_m_n_container_[i]) * + arg.num_group_; + + const auto GemmK = arg.a_grid_desc_m_k_container_[i].GetLength(I1); + + auto launch_kernel = [&](auto has_main_k_block_loop) { + constexpr bool has_main_loop = has_main_k_block_loop.value; + + const auto kernel = kernel_grouped_conv_bwd_data_multiple_d_xdl_cshuffle< + GridwiseGemm, + ADataType, // TODO: distiguish A/B datatype + typename GridwiseGemm::DsGridPointer, + EDataType, + AElementwiseOp, + BElementwiseOp, + CDEElementwiseOp, + DeviceOp::AGridDesc_AK0_M_AK1, + DeviceOp::BGridDesc_BK0_N_BK1, + DeviceOp::DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, + DeviceOp::EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, + Block2ETileMap, + ComputePtrOffsetOfStridedBatch, + has_main_loop>; + + return launch_and_time_kernel( + stream_config, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_a_grid_, + arg.p_b_grid_, + arg.p_ds_grid_, + arg.p_e_grid_, + arg.a_element_op_, + arg.b_element_op_, + arg.cde_element_op_, + arg.a_g_n_k_wos_lengths_[0], // Group count + arg.a_grid_desc_ak0_m_ak1_container_[i], + arg.b_grid_desc_bk0_n_bk1_container_[i], + arg.ds_grid_desc_mblock_mperblock_nblock_nperblock_container_[i], + arg.e_grid_desc_mblock_mperblock_nblock_nperblock_container_[i], + arg.block_2_etile_map_container_[i], + arg.compute_ptr_offset_of_batch_); + }; + + if(GridwiseGemm::CalculateHasMainKBlockLoop(GemmK)) + { + ave_time += launch_kernel(integral_constant{}); + } + else + { + ave_time += launch_kernel(integral_constant{}); + } + } + + return ave_time; + } + + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + static bool IsSupportedArgument(const Argument& arg) + { + const index_t ConvK = arg.b_g_k_c_xs_lengths_[1]; + const index_t ConvC = arg.b_g_k_c_xs_lengths_[2]; + + // Specifialization + if constexpr(ConvBackwardDataSpecialization == + ConvolutionBackwardDataSpecialization::Filter1x1Stride1Pad0) + { + // check if it's 1x1, stride=1 pad = 0 conv + for(int i = 0; i < NDimSpatial; i++) + { + if(!(arg.filter_spatial_lengths_[i] == 1 && arg.conv_filter_strides_[i] == 1 && + arg.input_left_pads_[i] == 0 && arg.input_right_pads_[i] == 0)) + { + return false; + } + } + } + + // vector load for A matrix from global memory to LDS + if constexpr(is_same_v) + { + if(!(ABlockTransferSrcVectorDim == 2 && ConvK % ABlockTransferSrcScalarPerVector == 0)) + { + return false; + } + } + else + { + return false; + } + + // vector load for B matrix from global memory to LDS + if constexpr(is_same_v) + { + if(!(BBlockTransferSrcVectorDim == 1 && ConvC % BBlockTransferSrcScalarPerVector == 0)) + { + return false; + } + } + else + { + return false; + } + + // vector store for Ds + bool ds_valid = true; + + static_for<0, NumDTensor, 1>{}([&](auto i) { + using DLayout = remove_cvref_t>; + + if constexpr(is_same_v || + is_same_v || + is_same_v || + is_same_v || + is_same_v) + { + // vector load D matrix from global memory + if(!(ConvC % CDEBlockTransferScalarPerVector_NPerBlock == 0)) + { + ds_valid = false; + } + } + else + { + ds_valid = false; + } + }); + + if(!ds_valid) + { + return false; + } + + // vector store for E + if constexpr(is_same_v) + { + // vector store C matrix into global memory + if(!(ConvC % CDEBlockTransferScalarPerVector_NPerBlock == 0)) + { + return false; + } + } + else + { + return false; 
+ } + + // Gridwise GEMM size + for(std::size_t i = 0; i < arg.a_grid_desc_ak0_m_ak1_container_.size(); i++) + { + if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_m_k_container_[i], + arg.b_grid_desc_n_k_container_[i], + arg.ds_grid_desc_m_n_container_[i], + arg.e_grid_desc_m_n_container_[i], + arg.block_2_etile_map_container_[i])) + { + return false; + } + } + + return true; + } + + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static auto + MakeArgument(const void* p_a, // output image + const void* p_b, // weight + const std::array& p_ds, // bias + void* p_e, // input image + const std::array& a_g_n_k_wos_lengths, // output image + const std::array& a_g_n_k_wos_strides, // output image + const std::array& b_g_k_c_xs_lengths, // weight + const std::array& b_g_k_c_xs_strides, // weight + const std::array, NumDTensor>& + ds_g_n_c_wis_lengths, // bias + const std::array, NumDTensor>& + ds_g_n_c_wis_strides, // bias + const std::array& e_g_n_c_wis_lengths, // input image + const std::array& e_g_n_c_wis_strides, // input image + const std::array& conv_filter_strides, + const std::array& conv_filter_dilations, + const std::array& input_left_pads, + const std::array& input_right_pads, + const AElementwiseOp& a_element_op, + const BElementwiseOp& b_element_op, + const CDEElementwiseOp& cde_element_op) + { + return Argument{p_a, + p_b, + p_ds, + p_e, + a_g_n_k_wos_lengths, + a_g_n_k_wos_strides, + b_g_k_c_xs_lengths, + b_g_k_c_xs_strides, + ds_g_n_c_wis_lengths, + ds_g_n_c_wis_strides, + e_g_n_c_wis_lengths, + e_g_n_c_wis_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + a_element_op, + b_element_op, + cde_element_op}; + } + + static auto MakeInvoker() { return Invoker{}; } + + std::unique_ptr MakeArgumentPointer( + const void* p_a, // output image + const void* p_b, // weight + const std::array& p_ds, // bias + void* p_e, // input image + const std::array& a_g_n_k_wos_lengths, // output image + const std::array& a_g_n_k_wos_strides, // output image + const std::array& b_g_k_c_xs_lengths, // weight + const std::array& b_g_k_c_xs_strides, // weight + const std::array, NumDTensor>& + ds_g_n_c_wis_lengths, // bias + const std::array, NumDTensor>& + ds_g_n_c_wis_strides, // bias + const std::array& e_g_n_c_wis_lengths, // input image + const std::array& e_g_n_c_wis_strides, // input image + const std::array& conv_filter_strides, + const std::array& conv_filter_dilations, + const std::array& input_left_pads, + const std::array& input_right_pads, + const AElementwiseOp& a_element_op, + const BElementwiseOp& b_element_op, + const CDEElementwiseOp& cde_element_op) override + { + return std::make_unique(p_a, + p_b, + p_ds, + p_e, + a_g_n_k_wos_lengths, + a_g_n_k_wos_strides, + b_g_k_c_xs_lengths, + b_g_k_c_xs_strides, + ds_g_n_c_wis_lengths, + ds_g_n_c_wis_strides, + e_g_n_c_wis_lengths, + e_g_n_c_wis_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + a_element_op, + b_element_op, + cde_element_op); + } + + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1" + << "<" + << BlockSize << ", " + << MPerBlock << ", " + << NPerBlock << ", " + << KPerBlock << ", " + << AK1 << ", " + << BK1 << ", " + << 
getConvBackwardDataSpecializationString(ConvBackwardDataSpecialization) + << ">"; + + return str.str(); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/tensor_layout.hpp b/include/ck/tensor_operation/gpu/device/tensor_layout.hpp index a06a567c96..b44427411f 100644 --- a/include/ck/tensor_operation/gpu/device/tensor_layout.hpp +++ b/include/ck/tensor_operation/gpu/device/tensor_layout.hpp @@ -92,6 +92,12 @@ struct GNDHWC : public BaseTensorLayout static constexpr const char* name = "GNDHWC"; }; +// for input bias +struct GC : public BaseTensorLayout +{ + static constexpr const char* name = "GC"; +}; + // input tensor // packed NWGC/NHWGC/NDHWGC struct NWGC : public BaseTensorLayout @@ -126,6 +132,12 @@ struct G_NDHW_C : public BaseTensorLayout static constexpr const char* name = "G_NDHW_C"; }; +// for input bias +struct G_C : public BaseTensorLayout +{ + static constexpr const char* name = "G_C"; +}; + // weight tensor // packed KCX/KCYX/KCZYX struct KCX : public BaseTensorLayout @@ -296,6 +308,12 @@ struct GNDHWK : public BaseTensorLayout static constexpr const char* name = "GNDHWK"; }; +// for output bias +struct GK : public BaseTensorLayout +{ + static constexpr const char* name = "GK"; +}; + // output tensor // packed NWGK/NHWGK/NDHWGK struct NWGK : public BaseTensorLayout @@ -330,6 +348,12 @@ struct G_NDHW_K : public BaseTensorLayout static constexpr const char* name = "G_NDHW_K"; }; +// for output bias +struct G_K : public BaseTensorLayout +{ + static constexpr const char* name = "G_K"; +}; + // K-reduced output tensor (packed) struct GNW : public BaseTensorLayout { diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp index 4656ed439d..ade8b204a7 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp @@ -35,10 +35,6 @@ template __host__ __device__ static constexpr auto MakeDefaultAGridDescriptor_AK0_M_AK1(const AGridDesc_M_K& a_grid_desc_m_k) { @@ -182,6 +179,7 @@ struct GridwiseGemmMultipleD_xdl_cshuffle } // B desc for source in blockwise copy + template __host__ __device__ static constexpr auto MakeDefaultBGridDescriptor_BK0_N_BK1(const BGridDesc_N_K& b_grid_desc_n_k) { @@ -198,9 +196,9 @@ struct GridwiseGemmMultipleD_xdl_cshuffle } // E desc for destination in blockwise copy - template - __host__ __device__ static constexpr auto MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( - const EGridDescriptor_M_N& e_grid_desc_m_n) + template + __host__ __device__ static constexpr auto + MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const EGridDesc_M_N& e_grid_desc_m_n) { const auto M = e_grid_desc_m_n.GetLength(I0); const auto N = e_grid_desc_m_n.GetLength(I1); @@ -219,10 +217,9 @@ struct GridwiseGemmMultipleD_xdl_cshuffle } // Ds desc for source in blockwise copy - template + template __host__ __device__ static constexpr auto - MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( - const DsGridDescriptor_M_N& ds_grid_desc_m_n) + MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const DsGridDesc_M_N& ds_grid_desc_m_n) { return generate_tuple( [&](auto i) { @@ -232,6 +229,7 @@ struct GridwiseGemmMultipleD_xdl_cshuffle } // return block_id to E matrix tile idx (m0, n0) mapping + template __host__ __device__ static constexpr auto 
MakeDefaultBlock2ETileMap(const EGridDesc_M_N& e_grid_desc_m_n) { @@ -240,7 +238,11 @@ struct GridwiseGemmMultipleD_xdl_cshuffle } // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01} - template + template __host__ __device__ static constexpr bool CheckValidity(const AGridDesc_M_K& a_grid_desc_m_k, const BGridDesc_N_K& b_grid_desc_n_k, const DsGridDesc_M_N& ds_grid_desc_m_n, @@ -314,23 +316,13 @@ struct GridwiseGemmMultipleD_xdl_cshuffle return GridwiseGemmPipe::CalculateHasMainLoop(num_loop); } - using DefaultAGridDesc_AK0_M_AK1 = - remove_cvref_t; - using DefaultBGridDesc_BK0_N_BK1 = - remove_cvref_t; - using EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t; - using DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t; - - using DefaultBlock2ETileMap = - remove_cvref_t; - using DsGridPointer = decltype(MakeDsGridPointer()); template __device__ static void Run(const ABDataType* __restrict__ p_a_grid, const ABDataType* __restrict__ p_b_grid, @@ -342,9 +334,9 @@ struct GridwiseGemmMultipleD_xdl_cshuffle const CDEElementwiseOperation& cde_element_op, const AGridDesc_AK0_M_AK1& a_grid_desc_ak0_m_ak1, const BGridDesc_BK0_N_BK1& b_grid_desc_bk0_n_bk1, - const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock& + const DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock& ds_grid_desc_mblock_mperblock_nblock_nperblock, - const EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock& + const EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock& e_grid_desc_mblock_mperblock_nblock_nperblock, const Block2ETileMap& block_2_etile_map) { diff --git a/include/ck/tensor_operation/operator_transform/transform_conv_bwd_data_to_gemm_v1.hpp b/include/ck/tensor_operation/operator_transform/transform_conv_bwd_data_to_gemm_v1.hpp new file mode 100644 index 0000000000..13d0a28cfe --- /dev/null +++ b/include/ck/tensor_operation/operator_transform/transform_conv_bwd_data_to_gemm_v1.hpp @@ -0,0 +1,583 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
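+
+// Maps 2D conv bwd data onto implicit GEMM: A (output image) -> (GemmAK0, GemmM, GemmAK1),
+// B (weight) -> (GemmBK0, GemmN, GemmBK1) and C (input image, or a GC/G_C bias read with
+// zero stride along GemmM) -> (GemmM, GemmN), with a Filter1x1Stride1Pad0 fast path.
+// Worked example of the tilde arithmetic used below: ConvStrideH = 2, ConvDilationH = 1,
+// Y = 3 gives YTilde = 2 / gcd(2, 1) = 2 and YDot = ceil(3 / 2) = 2; the per-GEMM slice
+// YDotSlice = ceil((Y - i_ytilde) / YTilde) is then 2 taps for i_ytilde = 0 and 1 tap for
+// i_ytilde = 1, so the two GEMMs together cover all Y = 3 filter taps.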
+ +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/convolution_backward_data_specialization.hpp" +#include "ck/tensor_operation/gpu/device/matrix_padder.hpp" + +namespace ck { +namespace tensor_operation { + +template < + index_t NDimSpatial, + ck::tensor_operation::device::ConvolutionBackwardDataSpecialization ConvBwdDataSpecialization, + index_t AK1, + index_t BK1, + index_t GemmMPerBlock, + index_t GemmNPerBlock, + bool DoPadGemmM, + bool DoPadGemmN> +struct TransformConvBwdDataToGemm_v1 +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + + template , + bool>::type = false> + static auto MakeADescriptor_AK0_M_AK1( + const std::array& out_g_n_k_wos_lengths, + const std::array& /* out_g_n_k_wos_strides */, + const std::array& wei_g_k_c_xs_lengths, + const std::array& /* wei_g_k_c_xs_strides */, + const std::array& in_g_n_c_wis_lengths, + const std::array& /* in_g_n_c_wis_strides */, + const std::array& conv_filter_strides, + const std::array& conv_filter_dilations, + const std::array& input_left_pads, + const std::array& /* input_right_pads */, + const std::array& tildes) + { + index_t i_ytilde = tildes[0]; + index_t i_xtilde = tildes[1]; + + const index_t N = in_g_n_c_wis_lengths[1]; + const index_t K = wei_g_k_c_xs_lengths[1]; + + const index_t Hi = in_g_n_c_wis_lengths[3]; + const index_t Wi = in_g_n_c_wis_lengths[4]; + + const index_t Ho = out_g_n_k_wos_lengths[3]; + const index_t Wo = out_g_n_k_wos_lengths[4]; + + const index_t Y = wei_g_k_c_xs_lengths[3]; + const index_t X = wei_g_k_c_xs_lengths[4]; + + const index_t InLeftPadH = input_left_pads[0]; + const index_t InLeftPadW = input_left_pads[1]; + + const index_t ConvStrideH = conv_filter_strides[0]; + const index_t ConvStrideW = conv_filter_strides[1]; + + const index_t ConvDilationH = conv_filter_dilations[0]; + const index_t ConvDilationW = conv_filter_dilations[1]; + + const index_t AK0 = K / AK1; + + // assume packed + const auto out_n_ho_wo_k_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(N, Ho, Wo, K)); + + if constexpr(ConvBwdDataSpecialization == + ck::tensor_operation::device::ConvolutionBackwardDataSpecialization:: + Filter1x1Stride1Pad0) + { + // A: output tensor + const auto out_gemmak0_gemmmraw_gemmak1_grid_desc = transform_tensor_descriptor( + make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, K)), + make_tuple(make_pass_through_transform(N * Ho * Wo), + make_unmerge_transform(make_tuple(AK0, AK1))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<1>{}, Sequence<0, 2>{})); + + const auto out_gemmak0_gemmm_gemmak1_grid_desc = + ck::tensor_operation::device::PadTensorDescriptor( + out_gemmak0_gemmmraw_gemmak1_grid_desc, + make_tuple(AK0, GemmMPerBlock, AK1), + Sequence{}); + + return out_gemmak0_gemmm_gemmak1_grid_desc; + } + else + { + const auto GcdStrideDilationH = math::gcd(ConvStrideH, ConvDilationH); + const auto GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW); + + const auto YTilde = ConvStrideH / GcdStrideDilationH; + const auto XTilde = ConvStrideW / GcdStrideDilationW; + + const auto YDot = math::integer_divide_ceil(Y, YTilde); + const auto XDot = math::integer_divide_ceil(X, XTilde); + + const auto HTilde = + Ho + math::integer_divide_ceil(ConvDilationH * (Y - I1), ConvStrideH); + const auto 
WTilde = + Wo + math::integer_divide_ceil(ConvDilationW * (X - I1), ConvStrideW); + + // only work on HTilde and WTilde that contribute to non-padding area of input tensor + const auto IHTildeSliceBegin = math::integer_divide_floor( + math::max(I0, InLeftPadH - ConvDilationH * (YTilde - I1)), ConvStrideH); + const auto IWTildeSliceBegin = math::integer_divide_floor( + math::max(I0, InLeftPadW - ConvDilationW * (XTilde - I1)), ConvStrideW); + + const auto IHTildeSliceEnd = math::min( + HTilde, math::integer_divide_ceil(InLeftPadH + Hi - I1, ConvStrideH) + I1); + const auto IWTildeSliceEnd = math::min( + WTilde, math::integer_divide_ceil(InLeftPadW + Wi - I1, ConvStrideW) + I1); + + const auto HTildeSlice = IHTildeSliceEnd - IHTildeSliceBegin; + const auto WTildeSlice = IWTildeSliceEnd - IWTildeSliceBegin; + + // GemmK is different for each GEMM + const auto YDotSlice = math::integer_divide_ceil(Y - i_ytilde, YTilde); + const auto XDotSlice = math::integer_divide_ceil(X - i_xtilde, XTilde); + + // A: output tensor + const auto out_n_hop_wop_k_grid_desc = transform_tensor_descriptor( + out_n_ho_wo_k_grid_desc, + make_tuple(make_pass_through_transform(N), + make_pad_transform(Ho, I0, I0), + make_pad_transform(Wo, I0, I0), + make_pass_through_transform(K)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); + + const auto out_n_ydot_htilde_xdot_wtilde_k_grid_desc = transform_tensor_descriptor( + out_n_hop_wop_k_grid_desc, + make_tuple( + make_pass_through_transform(N), + make_embed_transform(make_tuple(YDot, HTilde), + make_tuple(-ConvDilationH / GcdStrideDilationH, I1)), + make_embed_transform(make_tuple(XDot, WTilde), + make_tuple(-ConvDilationW / GcdStrideDilationW, I1)), + make_pass_through_transform(K)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{})); + + const auto out_n_ydotslice_htildeslice_xdotslice_wtildeslice_ak0_ak1_grid_desc = + transform_tensor_descriptor( + out_n_ydot_htilde_xdot_wtilde_k_grid_desc, + make_tuple(make_pass_through_transform(N), + make_slice_transform(YDot, I0, YDotSlice), + make_slice_transform(HTilde, IHTildeSliceBegin, HTildeSlice), + make_slice_transform(XDot, I0, XDotSlice), + make_slice_transform(WTilde, IWTildeSliceBegin, WTildeSlice), + make_unmerge_transform(make_tuple(AK0, AK1))), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<5>{}), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<5, 6>{})); + + const auto out_gemmak0_gemmmraw_gemmak1_grid_desc = transform_tensor_descriptor( + out_n_ydotslice_htildeslice_xdotslice_wtildeslice_ak0_ak1_grid_desc, + make_tuple(make_merge_transform(make_tuple(YDotSlice, XDotSlice, AK0)), + make_merge_transform(make_tuple(N, HTildeSlice, WTildeSlice)), + make_pass_through_transform(AK1)), + make_tuple(Sequence<1, 3, 5>{}, Sequence<0, 2, 4>{}, Sequence<6>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + const auto out_gemmak0_gemmm_gemmak1_grid_desc = + ck::tensor_operation::device::PadTensorDescriptor( + out_gemmak0_gemmmraw_gemmak1_grid_desc, + make_tuple(AK0, GemmMPerBlock, AK1), + Sequence{}); + + return out_gemmak0_gemmm_gemmak1_grid_desc; + } + } + + template , + bool>::type = false> + static auto MakeBDescriptor_BK0_N_BK1( + const std::array& out_g_n_k_wos_lengths, + const std::array& 
/* out_g_n_k_wos_strides */,
+        const std::array& wei_g_k_c_xs_lengths,
+        const std::array& /* wei_g_k_c_xs_strides */,
+        const std::array& /* in_g_n_c_wis_lengths */,
+        const std::array& /* in_g_n_c_wis_strides */,
+        const std::array& conv_filter_strides,
+        const std::array& conv_filter_dilations,
+        const std::array& /* input_left_pads */,
+        const std::array& /* input_right_pads */,
+        const std::array& tildes)
+    {
+        index_t i_ytilde = tildes[0];
+        index_t i_xtilde = tildes[1];
+
+        const index_t K = wei_g_k_c_xs_lengths[1];
+        const index_t C = wei_g_k_c_xs_lengths[2];
+
+        const index_t Y = wei_g_k_c_xs_lengths[3];
+        const index_t X = wei_g_k_c_xs_lengths[4];
+
+        const index_t ConvStrideH = conv_filter_strides[0];
+        const index_t ConvStrideW = conv_filter_strides[1];
+
+        const index_t ConvDilationH = conv_filter_dilations[0];
+        const index_t ConvDilationW = conv_filter_dilations[1];
+
+        const index_t BK0 = K / BK1;
+
+        // assume packed
+        const auto wei_k_y_x_c_grid_desc =
+            make_naive_tensor_descriptor_packed(make_tuple(K, Y, X, C));
+
+        if constexpr(ConvBwdDataSpecialization ==
+                     ck::tensor_operation::device::ConvolutionBackwardDataSpecialization::
+                         Filter1x1Stride1Pad0)
+        {
+            // B: weight tensor
+            const auto wei_gemmbk0_gemmnraw_gemmbk1_grid_desc =
+                transform_tensor_descriptor(make_naive_tensor_descriptor_packed(make_tuple(K, C)),
+                                            make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)),
+                                                       make_pass_through_transform(C)),
+                                            make_tuple(Sequence<0>{}, Sequence<1>{}),
+                                            make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
+
+            const auto wei_gemmbk0_gemmn_gemmbk1_grid_desc =
+                ck::tensor_operation::device::PadTensorDescriptor(
+                    wei_gemmbk0_gemmnraw_gemmbk1_grid_desc,
+                    make_tuple(BK0, GemmNPerBlock, BK1),
+                    Sequence{});
+
+            return wei_gemmbk0_gemmn_gemmbk1_grid_desc;
+        }
+        else
+        {
+            const auto GcdStrideDilationH = math::gcd(ConvStrideH, ConvDilationH);
+            const auto GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW);
+
+            const auto YTilde = ConvStrideH / GcdStrideDilationH;
+            const auto XTilde = ConvStrideW / GcdStrideDilationW;
+
+            const auto YDot = math::integer_divide_ceil(Y, YTilde);
+            const auto XDot = math::integer_divide_ceil(X, XTilde);
+
+            // GemmK is different for each GEMM
+            const auto YDotSlice = math::integer_divide_ceil(Y - i_ytilde, YTilde);
+            const auto XDotSlice = math::integer_divide_ceil(X - i_xtilde, XTilde);
+
+            // B weight tensor
+            const auto wei_k_ydot_ytilde_xdot_xtilde_c_grid_desc = transform_tensor_descriptor(
+                wei_k_y_x_c_grid_desc,
+                make_tuple(make_pass_through_transform(K),
+                           make_embed_transform(make_tuple(YDot, YTilde),
+                                                make_tuple(ConvStrideH / GcdStrideDilationH, I1)),
+                           make_embed_transform(make_tuple(XDot, XTilde),
+                                                make_tuple(ConvStrideW / GcdStrideDilationW, I1)),
+                           make_pass_through_transform(C)),
+                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
+                make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}));
+
+            const auto wei_bk0_bk1_ydotslice_xdotslice_c_grid_desc =
+                transform_tensor_descriptor(wei_k_ydot_ytilde_xdot_xtilde_c_grid_desc,
+                                            make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)),
+                                                       make_slice_transform(YDot, I0, YDotSlice),
+                                                       make_slice_transform(XDot, I0, XDotSlice),
+                                                       make_freeze_transform(i_ytilde),
+                                                       make_freeze_transform(i_xtilde),
+                                                       make_pass_through_transform(C)),
make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<3>{}, + Sequence<2>{}, + Sequence<4>{}, + Sequence<5>{}), + make_tuple(Sequence<0, 1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<>{}, + Sequence<>{}, + Sequence<4>{})); + + const auto wei_gemmbk0_gemmnraw_gemmbk1_grid_desc = transform_tensor_descriptor( + wei_bk0_bk1_ydotslice_xdotslice_c_grid_desc, + make_tuple(make_merge_transform(make_tuple(YDotSlice, XDotSlice, BK0)), + make_pass_through_transform(C), + make_pass_through_transform(BK1)), + make_tuple(Sequence<2, 3, 0>{}, Sequence<4>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + const auto wei_gemmbk0_gemmn_gemmbk1_grid_desc = + ck::tensor_operation::device::PadTensorDescriptor( + wei_gemmbk0_gemmnraw_gemmbk1_grid_desc, + make_tuple( + wei_gemmbk0_gemmnraw_gemmbk1_grid_desc.GetLength(I0), GemmNPerBlock, BK1), + Sequence{}); + + return wei_gemmbk0_gemmn_gemmbk1_grid_desc; + } + } + + template || + is_same_v || + is_same_v), + bool>::type = false> + static auto + MakeCDescriptor_M_N(const std::array& out_g_n_k_wos_lengths, + const std::array& /* out_g_n_k_wos_strides */, + const std::array& wei_g_k_c_xs_lengths, + const std::array& /* wei_g_k_c_xs_strides */, + const std::array& in_g_n_c_wis_lengths, + const std::array& in_g_n_c_wis_strides, + const std::array& conv_filter_strides, + const std::array& conv_filter_dilations, + const std::array& input_left_pads, + const std::array& input_right_pads, + const std::array& tildes) + { + index_t i_ytilde = tildes[0]; + index_t i_xtilde = tildes[1]; + + const index_t N = in_g_n_c_wis_lengths[1]; + const index_t C = wei_g_k_c_xs_lengths[2]; + + const index_t Hi = in_g_n_c_wis_lengths[3]; + const index_t Wi = in_g_n_c_wis_lengths[4]; + + const index_t Ho = out_g_n_k_wos_lengths[3]; + const index_t Wo = out_g_n_k_wos_lengths[4]; + + const index_t Y = wei_g_k_c_xs_lengths[3]; + const index_t X = wei_g_k_c_xs_lengths[4]; + + const index_t InLeftPadH = input_left_pads[0]; + const index_t InLeftPadW = input_left_pads[1]; + + const index_t InRightPadH = input_right_pads[0]; + const index_t InRightPadW = input_right_pads[1]; + + const index_t ConvStrideH = conv_filter_strides[0]; + const index_t ConvStrideW = conv_filter_strides[1]; + + const index_t ConvDilationH = conv_filter_dilations[0]; + const index_t ConvDilationW = conv_filter_dilations[1]; + + // assume strided + const auto in_n_hi_wi_c_grid_desc = + make_naive_tensor_descriptor(make_tuple(N, Hi, Wi, C), + make_tuple(in_g_n_c_wis_strides[1], + in_g_n_c_wis_strides[3], + in_g_n_c_wis_strides[4], + in_g_n_c_wis_strides[2])); + + if constexpr(ConvBwdDataSpecialization == + ck::tensor_operation::device::ConvolutionBackwardDataSpecialization:: + Filter1x1Stride1Pad0) + { + // C: input tensor + const auto in_n_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor( + in_n_hi_wi_c_grid_desc, + make_tuple(make_pass_through_transform(N), + make_embed_transform(make_tuple(I1, Ho), make_tuple(I1, ConvStrideH)), + make_embed_transform(make_tuple(I1, Wo), make_tuple(I1, ConvStrideW)), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{})); + + const auto in_gemmmraw_gemmnraw_grid_desc = transform_tensor_descriptor( + in_n_y_ho_x_wo_c_grid_desc, + make_tuple(make_freeze_transform(I0), + make_freeze_transform(I0), + make_merge_transform(make_tuple(N, Ho, Wo)), + make_pass_through_transform(C)), + make_tuple(Sequence<1>{}, Sequence<3>{}, 
Sequence<0, 2, 4>{}, Sequence<5>{}), + make_tuple(Sequence<>{}, Sequence<>{}, Sequence<0>{}, Sequence<1>{})); + + const auto in_gemmm_gemmn_grid_desc = ck::tensor_operation::device::PadTensorDescriptor( + in_gemmmraw_gemmnraw_grid_desc, + make_tuple(GemmMPerBlock, GemmNPerBlock), + Sequence{}); + + return in_gemmm_gemmn_grid_desc; + } + else + { + const auto GcdStrideDilationH = math::gcd(ConvStrideH, ConvDilationH); + const auto GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW); + + const auto YTilde = ConvStrideH / GcdStrideDilationH; + const auto XTilde = ConvStrideW / GcdStrideDilationW; + + const auto HTilde = + Ho + math::integer_divide_ceil(ConvDilationH * (Y - I1), ConvStrideH); + const auto WTilde = + Wo + math::integer_divide_ceil(ConvDilationW * (X - I1), ConvStrideW); + + // only work on HTilde and WTilde that contribute to non-padding area of input tensor + const auto IHTildeSliceBegin = math::integer_divide_floor( + math::max(I0, InLeftPadH - ConvDilationH * (YTilde - I1)), ConvStrideH); + const auto IWTildeSliceBegin = math::integer_divide_floor( + math::max(I0, InLeftPadW - ConvDilationW * (XTilde - I1)), ConvStrideW); + + const auto IHTildeSliceEnd = math::min( + HTilde, math::integer_divide_ceil(InLeftPadH + Hi - I1, ConvStrideH) + I1); + const auto IWTildeSliceEnd = math::min( + WTilde, math::integer_divide_ceil(InLeftPadW + Wi - I1, ConvStrideW) + I1); + + const auto HTildeSlice = IHTildeSliceEnd - IHTildeSliceBegin; + const auto WTildeSlice = IWTildeSliceEnd - IWTildeSliceBegin; + + // C: input tensor + const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor( + in_n_hi_wi_c_grid_desc, + make_tuple(make_pass_through_transform(N), + make_pad_transform(Hi, InLeftPadH, InRightPadH), + make_pad_transform(Wi, InLeftPadW, InRightPadW), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); + + const auto in_n_ytilde_htilde_xtilde_wtilde_c_grid_desc = transform_tensor_descriptor( + in_n_hip_wip_c_grid_desc, + make_tuple(make_pass_through_transform(N), + make_embed_transform(make_tuple(YTilde, HTilde), + make_tuple(ConvDilationH, ConvStrideH)), + make_embed_transform(make_tuple(XTilde, WTilde), + make_tuple(ConvDilationW, ConvStrideW)), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{})); + + const auto in_n_htildeslice_wtildeslice_c_grid_desc = transform_tensor_descriptor( + in_n_ytilde_htilde_xtilde_wtilde_c_grid_desc, + make_tuple(make_pass_through_transform(N), + make_freeze_transform(i_ytilde), + make_slice_transform(HTilde, IHTildeSliceBegin, HTildeSlice), + make_freeze_transform(i_xtilde), + make_slice_transform(WTilde, IWTildeSliceBegin, WTildeSlice), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<5>{}), + make_tuple(Sequence<0>{}, + Sequence<>{}, + Sequence<1>{}, + Sequence<>{}, + Sequence<2>{}, + Sequence<3>{})); + + const auto in_gemmmraw_gemmnraw_grid_desc = transform_tensor_descriptor( + in_n_htildeslice_wtildeslice_c_grid_desc, + make_tuple(make_merge_transform(make_tuple(N, HTildeSlice, WTildeSlice)), + make_pass_through_transform(C)), + make_tuple(Sequence<0, 1, 2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto in_gemmm_gemmn_grid_desc = 
+    // for input bias
+    template <typename /*...*/,
+              typename std::enable_if<(/*...*/ ||
+                                       is_same_v</*...*/>),
+                                      bool>::type = false>
+    static auto
+    MakeCDescriptor_M_N(const std::array<index_t, NDimSpatial + 3>& out_g_n_k_wos_lengths,
+                        const std::array<index_t, NDimSpatial + 3>& /* out_g_n_k_wos_strides */,
+                        const std::array<index_t, NDimSpatial + 3>& wei_g_k_c_xs_lengths,
+                        const std::array<index_t, NDimSpatial + 3>& /* wei_g_k_c_xs_strides */,
+                        const std::array<index_t, NDimSpatial + 3>& in_g_n_c_wis_lengths,
+                        const std::array<index_t, NDimSpatial + 3>& /* in_g_n_c_wis_strides */,
+                        const std::array<index_t, NDimSpatial>& conv_filter_strides,
+                        const std::array<index_t, NDimSpatial>& conv_filter_dilations,
+                        const std::array<index_t, NDimSpatial>& input_left_pads,
+                        const std::array<index_t, NDimSpatial>& /* input_right_pads */,
+                        const std::array<index_t, NDimSpatial>& /* tildes */)
+    {
+        const index_t N = in_g_n_c_wis_lengths[1];
+        const index_t C = wei_g_k_c_xs_lengths[2];
+
+        const index_t Hi = in_g_n_c_wis_lengths[3];
+        const index_t Wi = in_g_n_c_wis_lengths[4];
+
+        const index_t Ho = out_g_n_k_wos_lengths[3];
+        const index_t Wo = out_g_n_k_wos_lengths[4];
+
+        const index_t Y = wei_g_k_c_xs_lengths[3];
+        const index_t X = wei_g_k_c_xs_lengths[4];
+
+        const index_t InLeftPadH = input_left_pads[0];
+        const index_t InLeftPadW = input_left_pads[1];
+
+        const index_t ConvStrideH = conv_filter_strides[0];
+        const index_t ConvStrideW = conv_filter_strides[1];
+
+        const index_t ConvDilationH = conv_filter_dilations[0];
+        const index_t ConvDilationW = conv_filter_dilations[1];
+
+        if constexpr(ConvBwdDataSpecialization ==
+                     ck::tensor_operation::device::ConvolutionBackwardDataSpecialization::
+                         Filter1x1Stride1Pad0)
+        {
+            const auto in_gemmm_gemmn_grid_desc =
+                make_naive_tensor_descriptor(make_tuple(N * Ho * Wo, C), make_tuple(I0, I1));
+
+            return in_gemmm_gemmn_grid_desc;
+        }
+        else
+        {
+            const auto GcdStrideDilationH = math::gcd(ConvStrideH, ConvDilationH);
+            const auto GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW);
+
+            const auto YTilde = ConvStrideH / GcdStrideDilationH;
+            const auto XTilde = ConvStrideW / GcdStrideDilationW;
+
+            const auto HTilde =
+                Ho + math::integer_divide_ceil(ConvDilationH * (Y - I1), ConvStrideH);
+            const auto WTilde =
+                Wo + math::integer_divide_ceil(ConvDilationW * (X - I1), ConvStrideW);
+
+            // only work on HTilde and WTilde that contribute to non-padding area of input tensor
+            const auto IHTildeSliceBegin = math::integer_divide_floor(
+                math::max(I0, InLeftPadH - ConvDilationH * (YTilde - I1)), ConvStrideH);
+            const auto IWTildeSliceBegin = math::integer_divide_floor(
+                math::max(I0, InLeftPadW - ConvDilationW * (XTilde - I1)), ConvStrideW);
+
+            const auto IHTildeSliceEnd = math::min(
+                HTilde, math::integer_divide_ceil(InLeftPadH + Hi - I1, ConvStrideH) + I1);
+            const auto IWTildeSliceEnd = math::min(
+                WTilde, math::integer_divide_ceil(InLeftPadW + Wi - I1, ConvStrideW) + I1);
+
+            const auto HTildeSlice = IHTildeSliceEnd - IHTildeSliceBegin;
+            const auto WTildeSlice = IWTildeSliceEnd - IWTildeSliceBegin;
+
+            // bias tensor
+            const auto in_gemmmraw_gemmnraw_grid_desc = make_naive_tensor_descriptor(
+                make_tuple(N * HTildeSlice * WTildeSlice, C), make_tuple(I0, I1));
+
+            const auto in_gemmm_gemmn_grid_desc =
+                ck::tensor_operation::device::PadTensorDescriptor(
+                    in_gemmmraw_gemmnraw_grid_desc,
+                    make_tuple(GemmMPerBlock, GemmNPerBlock),
+                    Sequence</*...*/>{});
+
+            return in_gemmm_gemmn_grid_desc;
+        }
+    }
+};
+
+} // namespace tensor_operation
+} // namespace ck
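The bias overloads above build a (GemmM, GemmN) descriptor whose strides are
(I0, I1), i.e. (0, 1): a single length-C bias row is read by every GEMM row
without being materialized M times. A minimal sketch of why a zero stride
broadcasts (standalone, with a hypothetical offset helper, not CK code):

    #include <array>
    #include <cstddef>
    #include <iostream>

    // offset(m, n) for a 2D descriptor with explicit per-dimension strides
    std::size_t offset(std::size_t m, std::size_t n, std::array<std::size_t, 2> strides)
    {
        return m * strides[0] + n * strides[1];
    }

    int main()
    {
        const std::array<std::size_t, 2> bias_strides{0, 1}; // (GemmM, GemmN) = (I0, I1)
        // rows 0 and 7 both read bias element 5: the same vector serves every row
        std::cout << offset(0, 5, bias_strides) << ' '
                  << offset(7, 5, bias_strides) << '\n'; // prints "5 5"
    }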
diff --git a/include/ck/tensor_operation/operator_transform/transform_conv_fwd_to_gemm.hpp b/include/ck/tensor_operation/operator_transform/transform_conv_fwd_to_gemm.hpp
index 37a6e362c4..80934f7803 100644
--- a/include/ck/tensor_operation/operator_transform/transform_conv_fwd_to_gemm.hpp
+++ b/include/ck/tensor_operation/operator_transform/transform_conv_fwd_to_gemm.hpp
@@ -16,6 +16,7 @@ namespace tensor_operation {
 template </*...*/>
 struct TransformConvFwdToGemm
 {
+    static constexpr auto I0 = Number<0>{};
     static constexpr auto I1 = Number<1>{};
 
+    template <typename /*...*/,
+              typename std::enable_if</*...*/ ||
+                                          is_same_v</*...*/>,
+                                      bool>::type = false>
+    static auto
+    MakeCDescriptor_M_N(const std::array<index_t, NDimSpatial + 3>& c_g_n_k_wos_lengths,
+                        const std::array<index_t, NDimSpatial + 3>& /* c_g_n_k_wos_strides */)
+    {
+        const index_t N = c_g_n_k_wos_lengths[1];
+        const index_t K = c_g_n_k_wos_lengths[2];
+
+        const index_t NHoWo = N * std::accumulate(c_g_n_k_wos_lengths.begin() + 3,
+                                                  c_g_n_k_wos_lengths.begin() + 3 + NDimSpatial,
+                                                  index_t{1},
+                                                  std::multiplies<index_t>());
+
+        const auto out_gemmm_gemmn_desc =
+            make_naive_tensor_descriptor(make_tuple(NHoWo, K), make_tuple(I0, I1));
+
+        return out_gemmm_gemmn_desc;
+    }
 };
 
 } // namespace tensor_operation
diff --git a/include/ck/utility/ignore.hpp b/include/ck/utility/ignore.hpp
index 0172458741..ac33cbf9a5 100644
--- a/include/ck/utility/ignore.hpp
+++ b/include/ck/utility/ignore.hpp
@@ -1,8 +1,7 @@
 // SPDX-License-Identifier: MIT
 // Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
 
-#ifndef CK_IGNORE_HPP
-#define CK_IGNORE_HPP
+#pragma once
 
 // https://en.cppreference.com/w/cpp/utility/tuple/ignore
 
@@ -21,4 +20,3 @@ struct ignore_t
 inline constexpr detail::ignore_t ignore;
 
 } // namespace ck
-#endif
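The TransformConvFwdToGemm addition above folds N and all NDimSpatial output
lengths into a single GemmM extent with std::accumulate. Restated standalone
with hypothetical 2D sizes:

    #include <array>
    #include <functional>
    #include <iostream>
    #include <numeric>

    int main()
    {
        constexpr int NDimSpatial = 2;
        // g_n_k_wos layout: [G, N, K, Ho, Wo] for a 2D convolution
        const std::array<int, NDimSpatial + 3> out_lengths{1, 4, 64, 14, 14};

        const int N = out_lengths[1];
        // multiply the spatial extents (indices 3 .. 3 + NDimSpatial) into N
        const int NHoWo = N * std::accumulate(out_lengths.begin() + 3,
                                              out_lengths.begin() + 3 + NDimSpatial,
                                              1,
                                              std::multiplies<int>());
        std::cout << NHoWo << '\n'; // 4 * 14 * 14 = 784
    }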