From 05d18a052b76b571de105ee4cc05420cd40684d3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bart=C5=82omiej=20Kocot?= Date: Fri, 10 Nov 2023 15:54:44 +0100 Subject: [PATCH] Support multi AB for grouped conv fwd xdl (#1027) * Support multi AB for grouped conv fwd xdl * Add instances * Add client example * Add example * Add interface test * Minor fixes Minor fixes Minor fixes * Comment fixes * Fixes * Reference fix * Test xdl fixes * Improve multi_ab interface test [ROCm/composable_kernel commit: 49e52bb35714bcb81c2799994adbb6b23f6a4a29] --- ...rouped_conv_fwd_scaleadd_scaleadd_relu.inc | 2 +- .../CMakeLists.txt | 11 + .../grouped_conv_fwd_scaleadd_ab.inc | 221 ++++++++ .../grouped_conv_fwd_scaleadd_ab_bf16.cpp | 13 + .../grouped_conv_fwd_scaleadd_ab_fp16.cpp | 13 + .../grouped_conv_fwd_scaleadd_ab_fp32.cpp | 13 + .../grouped_conv_fwd_scaleadd_ab_int8.cpp | 13 + example/62_conv_fwd_activ/CMakeLists.txt | 9 + ...nd_fwd_xdl_scaleadd_scaleadd_relu_fp16.cpp | 21 +- .../conv_fwd_xdl_scaleadd_ab_bf16.cpp | 26 + .../conv_fwd_xdl_scaleadd_ab_fp16.cpp | 26 + .../conv_fwd_xdl_scaleadd_ab_fp32.cpp | 26 + .../conv_fwd_xdl_scaleadd_ab_int8.cpp | 26 + .../convnd_fwd_activ_multi_ab_common.hpp | 266 ++++++++++ .../device_grouped_conv_fwd_multiple_d.hpp | 88 +++- .../impl/device_column_to_image_impl.hpp | 29 +- ..._contraction_multiple_abd_xdl_cshuffle.hpp | 14 +- .../device_gemm_multiple_abd_xdl_cshuffle.hpp | 14 +- ...conv_bwd_data_multiple_d_wmma_cshuffle.hpp | 4 +- ...nv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp | 4 +- .../device_grouped_conv_bwd_weight_dl.hpp | 4 +- ..._grouped_conv_bwd_weight_wmma_cshuffle.hpp | 4 +- ...e_grouped_conv_bwd_weight_xdl_cshuffle.hpp | 4 +- ..._conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp | 4 +- ...uped_conv_fwd_multiple_d_wmma_cshuffle.hpp | 4 +- ...ouped_conv_fwd_multiple_d_xdl_cshuffle.hpp | 480 ++++++++++++------ .../device/impl/device_grouped_conv_utils.hpp | 129 ++++- .../impl/device_image_to_column_impl.hpp | 29 +- .../element/binary_element_wise_operation.hpp | 7 +- ...ridwise_gemm_multiple_abd_xdl_cshuffle.hpp | 24 +- .../cpu/reference_conv_fwd.hpp | 318 +++++++----- ...uped_conv_fwd_xdl_scaleadd_ab_instance.hpp | 124 +++++ ...rouped_convolution_forward_scaleadd_ab.hpp | 179 +++++++ .../CMakeLists.txt | 7 + ..._ab_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp | 52 ++ ...d_ab_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp | 52 ++ ...d_ab_ndhwgc_gkzyxc_ndhwgk_f32_instance.cpp | 52 ++ ..._ab_ndhwgc_gkzyxc_ndhwgk_int8_instance.cpp | 51 ++ test/grouped_convnd_fwd/CMakeLists.txt | 2 + ..._grouped_convnd_fwd_multi_ab_interface.cpp | 235 +++++++++ 40 files changed, 2235 insertions(+), 365 deletions(-) create mode 100644 client_example/24_grouped_convnd_fwd_scaleadd_ab/CMakeLists.txt create mode 100644 client_example/24_grouped_convnd_fwd_scaleadd_ab/grouped_conv_fwd_scaleadd_ab.inc create mode 100644 client_example/24_grouped_convnd_fwd_scaleadd_ab/grouped_conv_fwd_scaleadd_ab_bf16.cpp create mode 100644 client_example/24_grouped_convnd_fwd_scaleadd_ab/grouped_conv_fwd_scaleadd_ab_fp16.cpp create mode 100644 client_example/24_grouped_convnd_fwd_scaleadd_ab/grouped_conv_fwd_scaleadd_ab_fp32.cpp create mode 100644 client_example/24_grouped_convnd_fwd_scaleadd_ab/grouped_conv_fwd_scaleadd_ab_int8.cpp create mode 100644 example/62_conv_fwd_activ/multi_AB/conv_fwd_xdl_scaleadd_ab_bf16.cpp create mode 100644 example/62_conv_fwd_activ/multi_AB/conv_fwd_xdl_scaleadd_ab_fp16.cpp create mode 100644 example/62_conv_fwd_activ/multi_AB/conv_fwd_xdl_scaleadd_ab_fp32.cpp create mode 100644 
example/62_conv_fwd_activ/multi_AB/conv_fwd_xdl_scaleadd_ab_int8.cpp create mode 100644 example/62_conv_fwd_activ/multi_AB/convnd_fwd_activ_multi_ab_common.hpp create mode 100644 library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_scaleadd_ab_instance.hpp create mode 100644 library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_scaleadd_ab.hpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_scaleadd_ab/CMakeLists.txt create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_scaleadd_ab/xdl/device_grouped_conv3d_fwd_xdl_scaleadd_ab_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_scaleadd_ab/xdl/device_grouped_conv3d_fwd_xdl_scaleadd_ab_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_scaleadd_ab/xdl/device_grouped_conv3d_fwd_xdl_scaleadd_ab_ndhwgc_gkzyxc_ndhwgk_f32_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_scaleadd_ab/xdl/device_grouped_conv3d_fwd_xdl_scaleadd_ab_ndhwgc_gkzyxc_ndhwgk_int8_instance.cpp create mode 100644 test/grouped_convnd_fwd/test_grouped_convnd_fwd_multi_ab_interface.cpp diff --git a/client_example/23_grouped_convnd_fwd_scaleadd_scaleadd_relu/grouped_conv_fwd_scaleadd_scaleadd_relu.inc b/client_example/23_grouped_convnd_fwd_scaleadd_scaleadd_relu/grouped_conv_fwd_scaleadd_scaleadd_relu.inc index 0f316d5953..1c110cd8fa 100644 --- a/client_example/23_grouped_convnd_fwd_scaleadd_scaleadd_relu/grouped_conv_fwd_scaleadd_scaleadd_relu.inc +++ b/client_example/23_grouped_convnd_fwd_scaleadd_scaleadd_relu/grouped_conv_fwd_scaleadd_scaleadd_relu.inc @@ -63,7 +63,7 @@ int execute_conv_fwd_scaleadd_scaleadd_relu() K * Z * Y * X * C, Z * Y * X * C, 1, Y * X * C, X * C, C}; std::array out_lengths{G, N, K, Do, Ho, Wo}; std::array out_strides{ - C, Do * Ho * Wo * G * C, 1, Ho * Wo * G * C, Wo * G * C, G * C}; + K, Do * Ho * Wo * G * K, 1, Ho * Wo * G * K, Wo * G * K, G * K}; std::array filter_strides{1, 1, 1}; std::array filter_dilations{1, 1, 1}; diff --git a/client_example/24_grouped_convnd_fwd_scaleadd_ab/CMakeLists.txt b/client_example/24_grouped_convnd_fwd_scaleadd_ab/CMakeLists.txt new file mode 100644 index 0000000000..94a5ad0685 --- /dev/null +++ b/client_example/24_grouped_convnd_fwd_scaleadd_ab/CMakeLists.txt @@ -0,0 +1,11 @@ +add_executable(client_grouped_convnd_fwd_scaleadd_ab_fp32 grouped_conv_fwd_scaleadd_ab_fp32.cpp) +target_link_libraries(client_grouped_convnd_fwd_scaleadd_ab_fp32 PRIVATE composable_kernel::device_operations) + +add_executable(client_grouped_convnd_fwd_scaleadd_ab_fp16 grouped_conv_fwd_scaleadd_ab_fp16.cpp) +target_link_libraries(client_grouped_convnd_fwd_scaleadd_ab_fp16 PRIVATE composable_kernel::device_operations) + +add_executable(client_grouped_convnd_fwd_scaleadd_ab_bf16 grouped_conv_fwd_scaleadd_ab_bf16.cpp) +target_link_libraries(client_grouped_convnd_fwd_scaleadd_ab_bf16 PRIVATE composable_kernel::device_operations) + +add_executable(client_grouped_convnd_fwd_scaleadd_ab_int8 grouped_conv_fwd_scaleadd_ab_int8.cpp) +target_link_libraries(client_grouped_convnd_fwd_scaleadd_ab_int8 PRIVATE composable_kernel::device_operations) diff --git a/client_example/24_grouped_convnd_fwd_scaleadd_ab/grouped_conv_fwd_scaleadd_ab.inc b/client_example/24_grouped_convnd_fwd_scaleadd_ab/grouped_conv_fwd_scaleadd_ab.inc new file mode 100644 index 
0000000000..54f24f8554 --- /dev/null +++ b/client_example/24_grouped_convnd_fwd_scaleadd_ab/grouped_conv_fwd_scaleadd_ab.inc @@ -0,0 +1,221 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_scaleadd_ab.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +using InLayout = ck::tensor_layout::convolution::NDHWGC; +using WeiLayout = ck::tensor_layout::convolution::GKZYXC; +using OutLayout = ck::tensor_layout::convolution::NDHWGK; +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using ScaleAdd = ck::tensor_operation::element_wise::ScaleAdd; + +static constexpr ck::index_t NumDimSpatial = 3; +static constexpr ck::index_t G = 32; +static constexpr ck::index_t N = 64; // batch size +static constexpr ck::index_t K = 64; // output channel +static constexpr ck::index_t C = 32; // input channel (per group) +static constexpr ck::index_t Z = 3; // filter D +static constexpr ck::index_t Y = 3; // filter H +static constexpr ck::index_t X = 3; // filter W +static constexpr ck::index_t Di = 14; // input D +static constexpr ck::index_t Hi = 14; // input H +static constexpr ck::index_t Wi = 14; // input W +static constexpr ck::index_t Do = 14; // output D +static constexpr ck::index_t Ho = 14; // output H +static constexpr ck::index_t Wo = 14; // output W + +struct SimpleDeviceMem +{ + SimpleDeviceMem() = delete; + + SimpleDeviceMem(std::size_t mem_size) : p_mem_{} + { + (void)hipMalloc(static_cast(&p_mem_), mem_size); + } + + void* GetDeviceBuffer() { return p_mem_; } + + ~SimpleDeviceMem() { (void)hipFree(p_mem_); } + + void* p_mem_; +}; + +int execute_conv_fwd_scaleadd_ab() +{ + constexpr ck::index_t NumAs = 2; + constexpr ck::index_t NumBs = 2; + + constexpr float scale = 1.5f; + + // We have NHWGC/GKYXC/NHWGK (x, weight, y) in memory space. + // However, CK's API only accepts lengths and strides with order of GNCDHW/GKCZYX/GNKDHW. + // Hence, we need to adjust the order of strides. 
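// The reordering is mechanical: in a packed NDHWGC buffer the innermost C has
// stride 1, then G, W, H, D, N. A minimal sketch of that derivation
// (hypothetical helper, shown only to motivate the initializers below; it is
// not part of this patch):
constexpr std::array<ck::index_t, 6> packed_ndhwgc_strides_in_gncdhw_order(
    ck::index_t n, ck::index_t d, ck::index_t h, ck::index_t w, ck::index_t g, ck::index_t c)
{
    // returned in GNCDHW order, as expected by CK's API
    return {/*G*/ c,
            /*N*/ d * h * w * g * c,
            /*C*/ 1,
            /*D*/ h * w * g * c,
            /*H*/ w * g * c,
            /*W*/ g * c};
}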
+    std::array<ck::index_t, 6> in_lengths{G, N, C, Di, Hi, Wi};
+    std::array<ck::index_t, 6> in_strides{
+        C, Di * Hi * Wi * G * C, 1, Hi * Wi * G * C, Wi * G * C, G * C};
+    std::array<ck::index_t, 6> wei_lengths{G, K, C, Z, Y, X};
+    std::array<ck::index_t, 6> wei_strides{
+        K * Z * Y * X * C, Z * Y * X * C, 1, Y * X * C, X * C, C};
+    std::array<ck::index_t, 6> out_lengths{G, N, K, Do, Ho, Wo};
+    std::array<ck::index_t, 6> out_strides{
+        K, Do * Ho * Wo * G * K, 1, Ho * Wo * G * K, Wo * G * K, G * K};
+
+    std::array<ck::index_t, NumDimSpatial> filter_strides{1, 1, 1};
+    std::array<ck::index_t, NumDimSpatial> filter_dilations{1, 1, 1};
+    std::array<ck::index_t, NumDimSpatial> input_left_pads{1, 1, 1};
+    std::array<ck::index_t, NumDimSpatial> input_right_pads{1, 1, 1};
+
+    using InputDtype      = ck::tuple_element_t<0, InDataType>;
+    using InputBiasDtype  = ck::tuple_element_t<1, InDataType>;
+    using WeightDtype     = ck::tuple_element_t<0, WeiDataType>;
+    using WeightBiasDtype = ck::tuple_element_t<1, WeiDataType>;
+
+    SimpleDeviceMem in(sizeof(InputDtype) * N * Di * Hi * Wi * G * C);
+    SimpleDeviceMem in_bias(sizeof(InputBiasDtype) * N * Di * Hi * Wi * G * C);
+    SimpleDeviceMem wei(sizeof(WeightDtype) * G * K * Z * Y * X * C);
+    SimpleDeviceMem wei_bias(sizeof(WeightBiasDtype) * G * K * Z * Y * X * C);
+    SimpleDeviceMem out(sizeof(OutDataType) * N * Do * Ho * Wo * G * K);
+
+    using DeviceOp = ck::tensor_operation::device::DeviceGroupedConvFwdMultipleD<NumDimSpatial,
+                                                                                 InLayout,
+                                                                                 WeiLayout,
+                                                                                 ck::Tuple<>,
+                                                                                 OutLayout,
+                                                                                 InDataType,
+                                                                                 WeiDataType,
+                                                                                 ck::Tuple<>,
+                                                                                 OutDataType,
+                                                                                 ScaleAdd,
+                                                                                 ScaleAdd,
+                                                                                 PassThrough>;
+
+    // get device op instances
+    const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory<
+        DeviceOp>::GetInstances();
+
+    std::cout << "found " << op_ptrs.size() << " instances" << std::endl;
+
+    std::string best_op_name;
+    int best_op_id        = -1;
+    float best_avg_time   = std::numeric_limits<float>::max();
+    float best_gb_per_sec = 0;
+    float best_tflops     = 0;
+
+    // profile device operation instances
+    std::cout << "Run all instances and do timing" << std::endl;
+
+    std::array<const void*, NumAs> as = {in.GetDeviceBuffer(), in_bias.GetDeviceBuffer()};
+    std::array<const void*, NumBs> bs = {wei.GetDeviceBuffer(), wei_bias.GetDeviceBuffer()};
+    std::array<const void*, 0> ds{};
+
+    for(int i = 0; i < op_ptrs.size(); ++i)
+    {
+        auto& op_ptr      = op_ptrs[i];
+        auto argument_ptr = op_ptr->MakeArgumentPointer(as,
+                                                        bs,
+                                                        ds,
+                                                        out.GetDeviceBuffer(),
+                                                        in_lengths,
+                                                        in_strides,
+                                                        wei_lengths,
+                                                        wei_strides,
+                                                        {},
+                                                        {},
+                                                        out_lengths,
+                                                        out_strides,
+                                                        filter_strides,
+                                                        filter_dilations,
+                                                        input_left_pads,
+                                                        input_right_pads,
+                                                        ScaleAdd{scale},
+                                                        ScaleAdd{scale},
+                                                        PassThrough{});
+        auto invoker_ptr    = op_ptr->MakeInvokerPointer();
+        std::string op_name = op_ptr->GetTypeString();
+
+        if(op_ptr->IsSupportedArgument(argument_ptr.get()))
+        {
+            float avg_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, true});
+
+            std::size_t flop = std::size_t(2) * G * N * K * C * Do * Ho * Wo * Z * Y * X +
+                               N * Di * Hi * Wi * G * C + G * K * Z * Y * X * C;
+            std::size_t num_bytes = 2 * sizeof(InDataType) * N * Di * Hi * Wi * G * C +
+                                    2 * sizeof(WeiDataType) * G * K * Z * Y * X * C +
+                                    sizeof(OutDataType) * N * Do * Ho * Wo * G * K;
+
+            float tflops     = static_cast<float>(flop) / 1.E9 / avg_time;
+            float gb_per_sec = num_bytes / 1.E6 / avg_time;
+
+            std::cout << "Perf: " << std::setw(10) << avg_time << " ms, " << tflops << " TFlops, "
+                      << gb_per_sec << " GB/s, " << op_name << std::endl;
+
+            if(tflops > best_tflops)
+            {
+                best_op_id      = i;
+                best_op_name    = op_name;
+                best_avg_time   = avg_time;
+                best_gb_per_sec = gb_per_sec;
+                best_tflops     = tflops;
+            }
+        }
+        else
+        {
+            std::cerr << op_name << " does not support this problem" << std::endl;
+        }
+    }
+
+    if(best_op_id < 0)
+    {
+        std::cerr << "no suitable instance" << std::endl;
+        return EXIT_FAILURE;
+    }
+
+    std::cout << "Best Perf: " << std::setw(10) << best_avg_time << " ms, " << best_tflops
+              << " TFlops, " << best_gb_per_sec << " GB/s, " << best_op_name << std::endl;
+
+    // run the best instance
+    {
+        auto& op_ptr = op_ptrs[best_op_id];
+        std::cout << "Run the best instance without timing: " << op_ptr->GetTypeString()
+                  << std::endl;
+        auto argument_ptr = op_ptr->MakeArgumentPointer(as,
+                                                        bs,
+                                                        ds,
+                                                        out.GetDeviceBuffer(),
+                                                        in_lengths,
+                                                        in_strides,
+                                                        wei_lengths,
+                                                        wei_strides,
+                                                        {},
+                                                        {},
+                                                        out_lengths,
+                                                        out_strides,
+                                                        filter_strides,
+                                                        filter_dilations,
+                                                        input_left_pads,
+                                                        input_right_pads,
+                                                        ScaleAdd{scale},
+                                                        ScaleAdd{scale},
+                                                        PassThrough{});
+
+        auto invoker_ptr = op_ptr->MakeInvokerPointer();
+
+        if(op_ptr->IsSupportedArgument(argument_ptr.get()))
+        {
+            invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, false});
+        }
+
+        std::cout << "Done" << std::endl;
+    }
+    return 0;
+}
diff --git a/client_example/24_grouped_convnd_fwd_scaleadd_ab/grouped_conv_fwd_scaleadd_ab_bf16.cpp b/client_example/24_grouped_convnd_fwd_scaleadd_ab/grouped_conv_fwd_scaleadd_ab_bf16.cpp
new file mode 100644
index 0000000000..f384d854df
--- /dev/null
+++ b/client_example/24_grouped_convnd_fwd_scaleadd_ab/grouped_conv_fwd_scaleadd_ab_bf16.cpp
@@ -0,0 +1,13 @@
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
+
+#include "ck/utility/data_type.hpp"
+#include "ck/utility/tuple.hpp"
+
+using InDataType  = ck::Tuple<ck::bhalf_t, ck::bhalf_t>;
+using WeiDataType = ck::Tuple<ck::bhalf_t, ck::bhalf_t>;
+using OutDataType = ck::bhalf_t;
+
+#include "grouped_conv_fwd_scaleadd_ab.inc"
+
+int main() { return execute_conv_fwd_scaleadd_ab(); }
diff --git a/client_example/24_grouped_convnd_fwd_scaleadd_ab/grouped_conv_fwd_scaleadd_ab_fp16.cpp b/client_example/24_grouped_convnd_fwd_scaleadd_ab/grouped_conv_fwd_scaleadd_ab_fp16.cpp
new file mode 100644
index 0000000000..fd61ef1e15
--- /dev/null
+++ b/client_example/24_grouped_convnd_fwd_scaleadd_ab/grouped_conv_fwd_scaleadd_ab_fp16.cpp
@@ -0,0 +1,13 @@
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
+
+#include "ck/utility/data_type.hpp"
+#include "ck/utility/tuple.hpp"
+
+using InDataType  = ck::Tuple<ck::half_t, ck::half_t>;
+using WeiDataType = ck::Tuple<ck::half_t, ck::half_t>;
+using OutDataType = ck::half_t;
+
+#include "grouped_conv_fwd_scaleadd_ab.inc"
+
+int main() { return execute_conv_fwd_scaleadd_ab(); }
diff --git a/client_example/24_grouped_convnd_fwd_scaleadd_ab/grouped_conv_fwd_scaleadd_ab_fp32.cpp b/client_example/24_grouped_convnd_fwd_scaleadd_ab/grouped_conv_fwd_scaleadd_ab_fp32.cpp
new file mode 100644
index 0000000000..387369c667
--- /dev/null
+++ b/client_example/24_grouped_convnd_fwd_scaleadd_ab/grouped_conv_fwd_scaleadd_ab_fp32.cpp
@@ -0,0 +1,13 @@
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
+ +#include "ck/utility/data_type.hpp" +#include "ck/utility/tuple.hpp" + +using InDataType = ck::Tuple; +using WeiDataType = ck::Tuple; +using OutDataType = float; + +#include "grouped_conv_fwd_scaleadd_ab.inc" + +int main() { return execute_conv_fwd_scaleadd_ab(); } diff --git a/client_example/24_grouped_convnd_fwd_scaleadd_ab/grouped_conv_fwd_scaleadd_ab_int8.cpp b/client_example/24_grouped_convnd_fwd_scaleadd_ab/grouped_conv_fwd_scaleadd_ab_int8.cpp new file mode 100644 index 0000000000..20654c7180 --- /dev/null +++ b/client_example/24_grouped_convnd_fwd_scaleadd_ab/grouped_conv_fwd_scaleadd_ab_int8.cpp @@ -0,0 +1,13 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/utility/data_type.hpp" +#include "ck/utility/tuple.hpp" + +using InDataType = ck::Tuple; +using WeiDataType = ck::Tuple; +using OutDataType = int8_t; + +#include "grouped_conv_fwd_scaleadd_ab.inc" + +int main() { return execute_conv_fwd_scaleadd_ab(); } diff --git a/example/62_conv_fwd_activ/CMakeLists.txt b/example/62_conv_fwd_activ/CMakeLists.txt index 3cc69a6e87..bb95602416 100644 --- a/example/62_conv_fwd_activ/CMakeLists.txt +++ b/example/62_conv_fwd_activ/CMakeLists.txt @@ -30,6 +30,15 @@ foreach(gpu IN LISTS GPU_TARGETS) # Elu add_example_executable(example_convnd_fwd_xdl_elu_fp16 convnd_fwd_xdl_elu_fp16.cpp) add_example_dependencies(example_convnd_fwd_activ_xdl example_convnd_fwd_xdl_elu_fp16) + # ScaleAdd on A and B + add_example_executable(example_conv_fwd_xdl_scaleadd_ab_fp16 multi_AB/conv_fwd_xdl_scaleadd_ab_fp16.cpp) + add_example_dependencies(example_convnd_fwd_activ_xdl example_conv_fwd_xdl_scaleadd_ab_fp16) + add_example_executable(example_conv_fwd_xdl_scaleadd_ab_fp32 multi_AB/conv_fwd_xdl_scaleadd_ab_fp32.cpp) + add_example_dependencies(example_convnd_fwd_activ_xdl example_conv_fwd_xdl_scaleadd_ab_fp32) + add_example_executable(example_conv_fwd_xdl_scaleadd_ab_bf16 multi_AB/conv_fwd_xdl_scaleadd_ab_bf16.cpp) + add_example_dependencies(example_convnd_fwd_activ_xdl example_conv_fwd_xdl_scaleadd_ab_bf16) + add_example_executable(example_conv_fwd_xdl_scaleadd_ab_int8 multi_AB/conv_fwd_xdl_scaleadd_ab_int8.cpp) + add_example_dependencies(example_convnd_fwd_activ_xdl example_conv_fwd_xdl_scaleadd_ab_int8) # ScaleAdd ScaleAdd Relu add_example_executable(example_convnd_fwd_xdl_scaleadd_scaleadd_relu_fp16 convnd_fwd_xdl_scaleadd_scaleadd_relu_fp16.cpp) add_example_dependencies(example_convnd_fwd_activ_xdl example_convnd_fwd_xdl_scaleadd_scaleadd_relu_fp16) diff --git a/example/62_conv_fwd_activ/convnd_fwd_xdl_scaleadd_scaleadd_relu_fp16.cpp b/example/62_conv_fwd_activ/convnd_fwd_xdl_scaleadd_scaleadd_relu_fp16.cpp index e716a85010..126d95c176 100644 --- a/example/62_conv_fwd_activ/convnd_fwd_xdl_scaleadd_scaleadd_relu_fp16.cpp +++ b/example/62_conv_fwd_activ/convnd_fwd_xdl_scaleadd_scaleadd_relu_fp16.cpp @@ -226,14 +226,17 @@ bool run_grouped_conv_fwd(bool do_verification, if(do_verification) { - auto ref_conv = ck::tensor_operation::host::ReferenceConvFwd(); + auto ref_conv = + ck::tensor_operation::host::ReferenceConvFwd(); auto ref_invoker = ref_conv.MakeInvoker(); auto ref_argument = ref_conv.MakeArgument(in, @@ -246,6 +249,8 @@ bool run_grouped_conv_fwd(bool do_verification, in_element_op, wei_element_op, out_element_op, + {}, + {}, d_tensors); ref_invoker.Run(ref_argument); diff --git a/example/62_conv_fwd_activ/multi_AB/conv_fwd_xdl_scaleadd_ab_bf16.cpp b/example/62_conv_fwd_activ/multi_AB/conv_fwd_xdl_scaleadd_ab_bf16.cpp new file 
mode 100644 index 0000000000..7993552210 --- /dev/null +++ b/example/62_conv_fwd_activ/multi_AB/conv_fwd_xdl_scaleadd_ab_bf16.cpp @@ -0,0 +1,26 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "convnd_fwd_activ_multi_ab_common.hpp" + +using DataType = ck::bhalf_t; +using AccDataType = float; +using InDataType = DataType; +using WeiDataType = DataType; +using OutDataType = DataType; +using ADataTypes = ck::Tuple; +using BDataTypes = ck::Tuple; + +using InElementOp = ck::tensor_operation::element_wise::ScaleAdd; +using WeiElementOp = ck::tensor_operation::element_wise::ScaleAdd; + +using DeviceGroupedConvNDFwdActivInstance = DeviceGroupedConvNDMultiABFwdInstance; + +#include "../run_convnd_fwd_activ_example.inc" + +int main(int argc, char* argv[]) { return !run_convnd_fwd_example(argc, argv); } diff --git a/example/62_conv_fwd_activ/multi_AB/conv_fwd_xdl_scaleadd_ab_fp16.cpp b/example/62_conv_fwd_activ/multi_AB/conv_fwd_xdl_scaleadd_ab_fp16.cpp new file mode 100644 index 0000000000..696bc0c3fe --- /dev/null +++ b/example/62_conv_fwd_activ/multi_AB/conv_fwd_xdl_scaleadd_ab_fp16.cpp @@ -0,0 +1,26 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "convnd_fwd_activ_multi_ab_common.hpp" + +using DataType = ck::half_t; +using AccDataType = float; +using InDataType = DataType; +using WeiDataType = DataType; +using OutDataType = DataType; +using ADataTypes = ck::Tuple; +using BDataTypes = ck::Tuple; + +using InElementOp = ck::tensor_operation::element_wise::ScaleAdd; +using WeiElementOp = ck::tensor_operation::element_wise::ScaleAdd; + +using DeviceGroupedConvNDFwdActivInstance = DeviceGroupedConvNDMultiABFwdInstance; + +#include "../run_convnd_fwd_activ_example.inc" + +int main(int argc, char* argv[]) { return !run_convnd_fwd_example(argc, argv); } diff --git a/example/62_conv_fwd_activ/multi_AB/conv_fwd_xdl_scaleadd_ab_fp32.cpp b/example/62_conv_fwd_activ/multi_AB/conv_fwd_xdl_scaleadd_ab_fp32.cpp new file mode 100644 index 0000000000..a95f5e1347 --- /dev/null +++ b/example/62_conv_fwd_activ/multi_AB/conv_fwd_xdl_scaleadd_ab_fp32.cpp @@ -0,0 +1,26 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "convnd_fwd_activ_multi_ab_common.hpp" + +using DataType = float; +using AccDataType = float; +using InDataType = DataType; +using WeiDataType = DataType; +using OutDataType = DataType; +using ADataTypes = ck::Tuple; +using BDataTypes = ck::Tuple; + +using InElementOp = ck::tensor_operation::element_wise::ScaleAdd; +using WeiElementOp = ck::tensor_operation::element_wise::ScaleAdd; + +using DeviceGroupedConvNDFwdActivInstance = DeviceGroupedConvNDMultiABFwdInstance; + +#include "../run_convnd_fwd_activ_example.inc" + +int main(int argc, char* argv[]) { return !run_convnd_fwd_example(argc, argv); } diff --git a/example/62_conv_fwd_activ/multi_AB/conv_fwd_xdl_scaleadd_ab_int8.cpp b/example/62_conv_fwd_activ/multi_AB/conv_fwd_xdl_scaleadd_ab_int8.cpp new file mode 100644 index 0000000000..4fde3a722d --- /dev/null +++ b/example/62_conv_fwd_activ/multi_AB/conv_fwd_xdl_scaleadd_ab_int8.cpp @@ -0,0 +1,26 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. 
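// The int8 variant below accumulates in int32_t (AccDataType). A quick
// worst-case check of why 32 bits suffice, using the shape from the client
// example above (C = 32, 3x3x3 filter) purely for illustration:
constexpr long long kReduceLength  = 32LL * 3 * 3 * 3; // C * Z * Y * X
constexpr long long kMaxAbsProduct = 127LL * 127LL;    // int8 value extremes
static_assert(kReduceLength * kMaxAbsProduct < (1LL << 31),
              "the per-output dot product fits in an int32 accumulator");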
+ +#include "convnd_fwd_activ_multi_ab_common.hpp" + +using DataType = int8_t; +using AccDataType = int32_t; +using InDataType = DataType; +using WeiDataType = DataType; +using OutDataType = DataType; +using ADataTypes = ck::Tuple; +using BDataTypes = ck::Tuple; + +using InElementOp = ck::tensor_operation::element_wise::ScaleAdd; +using WeiElementOp = ck::tensor_operation::element_wise::ScaleAdd; + +using DeviceGroupedConvNDFwdActivInstance = DeviceGroupedConvNDMultiABFwdInstance; + +#include "../run_convnd_fwd_activ_example.inc" + +int main(int argc, char* argv[]) { return !run_convnd_fwd_example(argc, argv); } diff --git a/example/62_conv_fwd_activ/multi_AB/convnd_fwd_activ_multi_ab_common.hpp b/example/62_conv_fwd_activ/multi_AB/convnd_fwd_activ_multi_ab_common.hpp new file mode 100644 index 0000000000..6159b49805 --- /dev/null +++ b/example/62_conv_fwd_activ/multi_AB/convnd_fwd_activ_multi_ab_common.hpp @@ -0,0 +1,266 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_xdl_cshuffle.hpp" + +#include "ck/library/utility/algorithm.hpp" +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/convolution_parameter.hpp" +#include "ck/library/utility/convolution_host_tensor_descriptor_helper.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_conv_fwd.hpp" +#include "ck/library/utility/convolution_host_tensor_descriptor_helper.hpp" + +constexpr ck::index_t NDimSpatial = 3; + +template +using S = ck::Sequence; + +using InLayout = ck::tensor_layout::convolution::GNDHWC; +using WeiLayout = ck::tensor_layout::convolution::GKZYXC; +using OutLayout = ck::tensor_layout::convolution::GNDHWK; + +using OutElementOp = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto ConvSpec = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Default; + +static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +template +using DeviceGroupedConvNDMultiABFwdInstance = + ck::tensor_operation::device::DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< + NDimSpatial, + InLayout, + WeiLayout, + ck::Tuple<>, + OutLayout, + InDataTypes, + WeiDataTypes, + AccDataType, + DataType, + ck::Tuple<>, + DataType, + InElementOp, + WeiElementOp, + OutElementOp, + ConvSpec, // ConvForwardSpecialization + GemmSpec, // GemmSpecialization + 1, // + 256, // BlockSize + 128, // MPerBlock + 256, // NPerBlock + 32, // KPerBlock + 8, // AK1 + 8, // BK1 + 32, // MPerXdl + 32, // NPerXdl + 2, // MXdlPerWave + 4, // NXdlPerWave + S<4, 64, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1 + S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // ABlockTransferSrcAccessOrder + 2, // ABlockTransferSrcVectorDim + 8, // ABlockTransferSrcScalarPerVector + 8, // ABlockTransferDstScalarPerVector_AK1 + 1, // ABlockLdsExtraM + S<4, 64, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1 + S<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // BBlockTransferSrcAccessOrder + 2, // BBlockTransferSrcVectorDim + 8, // 
BBlockTransferSrcScalarPerVector + 8, // BBlockTransferDstScalarPerVector_BK1 + 1, // BBlockLdsExtraN + 1, + 1, + S<1, 32, 1, 8>, + 8>; + +namespace { +template +bool run_grouped_conv_fwd(bool do_verification, + int init_method, + bool time_kernel, + const ck::utils::conv::ConvParam& conv_param, + const HostTensorDescriptor& in_g_n_c_wis_desc, + const HostTensorDescriptor& wei_g_k_c_xs_desc, + const HostTensorDescriptor& out_g_n_k_wos_desc, + const InElementOp& in_element_op, + const WeiElementOp& wei_element_op, + const OutElementOp& out_element_op) +{ + constexpr ck::index_t NumAs = 2; + constexpr ck::index_t NumBs = 2; + Tensor in(in_g_n_c_wis_desc); + Tensor in_bias(in_g_n_c_wis_desc); + Tensor wei(wei_g_k_c_xs_desc); + Tensor wei_bias(wei_g_k_c_xs_desc); + Tensor out_host(out_g_n_k_wos_desc); + Tensor out_device(out_g_n_k_wos_desc); + + std::cout << "in: " << in.mDesc << std::endl; + std::cout << "wei: " << wei.mDesc << std::endl; + std::cout << "out: " << out_host.mDesc << std::endl; + + switch(init_method) + { + case 0: break; + case 1: + in.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + in_bias.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + wei.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + wei_bias.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + break; + default: + in.GenerateTensorValue(GeneratorTensor_3{-1.0, 1.0}); + in_bias.GenerateTensorValue(GeneratorTensor_3{-1.0, 1.0}); + wei.GenerateTensorValue(GeneratorTensor_3{-0.05, 0.05}); + wei_bias.GenerateTensorValue(GeneratorTensor_3{-1.0, 1.0}); + } + + DeviceMem in_device_buf(sizeof(InDataType) * in.mDesc.GetElementSpaceSize()); + DeviceMem in_bias_device_buf(sizeof(InDataType) * in_bias.mDesc.GetElementSpaceSize()); + DeviceMem wei_device_buf(sizeof(WeiDataType) * wei.mDesc.GetElementSpaceSize()); + DeviceMem wei_bias_device_buf(sizeof(WeiDataType) * wei_bias.mDesc.GetElementSpaceSize()); + DeviceMem out_device_buf(sizeof(OutDataType) * out_device.mDesc.GetElementSpaceSize()); + + in_device_buf.ToDevice(in.mData.data()); + in_bias_device_buf.ToDevice(in_bias.mData.data()); + wei_device_buf.ToDevice(wei.mData.data()); + wei_bias_device_buf.ToDevice(wei_bias.mData.data()); + + std::array a_g_n_c_wis_lengths{}; + std::array a_g_n_c_wis_strides{}; + std::array b_g_k_c_xs_lengths{}; + std::array b_g_k_c_xs_strides{}; + std::array e_g_n_k_wos_lengths{}; + std::array e_g_n_k_wos_strides{}; + std::array conv_filter_strides{}; + std::array conv_filter_dilations{}; + std::array input_left_pads{}; + std::array input_right_pads{}; + + auto copy = [](const auto& x, auto& y) { ck::ranges::copy(x, y.begin()); }; + + copy(in_g_n_c_wis_desc.GetLengths(), a_g_n_c_wis_lengths); + copy(in_g_n_c_wis_desc.GetStrides(), a_g_n_c_wis_strides); + copy(wei_g_k_c_xs_desc.GetLengths(), b_g_k_c_xs_lengths); + copy(wei_g_k_c_xs_desc.GetStrides(), b_g_k_c_xs_strides); + copy(out_g_n_k_wos_desc.GetLengths(), e_g_n_k_wos_lengths); + copy(out_g_n_k_wos_desc.GetStrides(), e_g_n_k_wos_strides); + copy(conv_param.conv_filter_strides_, conv_filter_strides); + copy(conv_param.conv_filter_dilations_, conv_filter_dilations); + copy(conv_param.input_left_pads_, input_left_pads); + copy(conv_param.input_right_pads_, input_right_pads); + + std::array as{in_device_buf.GetDeviceBuffer(), + in_bias_device_buf.GetDeviceBuffer()}; + std::array bs{wei_device_buf.GetDeviceBuffer(), + wei_bias_device_buf.GetDeviceBuffer()}; + std::array ds{}; + + // do Conv + auto conv = DeviceConvNDFwdInstance{}; + auto invoker = conv.MakeInvoker(); + auto argument = conv.MakeArgument(as, 
+ bs, + ds, + out_device_buf.GetDeviceBuffer(), + a_g_n_c_wis_lengths, + a_g_n_c_wis_strides, + b_g_k_c_xs_lengths, + b_g_k_c_xs_strides, + {}, + {}, + e_g_n_k_wos_lengths, + e_g_n_k_wos_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + in_element_op, + wei_element_op, + out_element_op); + + if(!conv.IsSupportedArgument(argument)) + { + throw std::runtime_error( + "wrong! device_conv with the specified compilation parameters does " + "not support this Conv problem"); + } + + float avg_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); + + std::size_t flop = conv_param.GetFlops() + + 2 * conv_param.GetOutputByte() / sizeof(InDataType) + + 2 * conv_param.GetOutputByte() / sizeof(WeiDataType); + std::size_t num_btype = conv_param.GetByte() + + conv_param.GetInputByte() + + conv_param.GetWeightByte(); + + float tflops = static_cast(flop) / 1.E9 / avg_time; + float gb_per_sec = num_btype / 1.E6 / avg_time; + std::cout << "Perf: " << avg_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " + << conv.GetTypeString() << std::endl; + + if(do_verification) + { + const std::array, NumAs - 1> elementwise_a_tensors = {in_bias}; + const std::array, NumBs - 1> elementwise_b_tensors = {wei_bias}; + auto ref_conv = ck::tensor_operation::host::ReferenceConvFwd(); + + auto ref_invoker = ref_conv.MakeInvoker(); + auto ref_argument = ref_conv.MakeArgument(in, + wei, + out_host, + conv_param.conv_filter_strides_, + conv_param.conv_filter_dilations_, + conv_param.input_left_pads_, + conv_param.input_right_pads_, + in_element_op, + wei_element_op, + out_element_op, + elementwise_a_tensors, + elementwise_b_tensors); + + ref_invoker.Run(ref_argument); + + out_device_buf.FromDevice(out_device.mData.data()); + + return ck::utils::check_err(out_device, out_host, "Error: incorrect results!"); + } + + return true; +} + +} // namespace diff --git a/include/ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_d.hpp b/include/ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_d.hpp index 2ca82dc6da..db1f50435a 100644 --- a/include/ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_d.hpp +++ b/include/ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_d.hpp @@ -6,18 +6,42 @@ #include #include "ck/tensor_operation/gpu/device/device_base.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp" +#include "ck/utility/is_detected.hpp" namespace ck { namespace tensor_operation { namespace device { -// Convolution Forward: -// input : input image A[G, N, C, Hi, Wi], -// input : weight B[G, K, C, Y, X], -// input : D0[G, N, K, Ho, Wo], D1[G, N, K, Ho, Wo], ... -// output : output image E[G, N, K, Ho, Wo] -// C = a_op(A) * b_op(B) -// E = cde_op(C, D0, D1, ...) +template +using is_tuple = decltype(std::declval().IsTuple()); + +/** + * \brief Grouped Convolution Forward + * + * \details + * input : input image A[G, N, C, Hi, Wi], A1[G, N, C, Hi, Wi]... + * input : weight B[G, K, C, Y, X], B1[G, K, C, Y, X]... + * input : D0[G, N, K, Ho, Wo], D1[G, N, K, Ho, Wo], ... + * output : output image E[G, N, K, Ho, Wo] + * + * C = a_op(A, A1...) * b_op(B, B1...) + * E = cde_op(C, D0, D1, ...) + * + * \tparam NDimSpatial Number of spatial dimensions. + * \tparam ALayout Input layout (also for a1, a2...). + * \tparam BLayout Weight layout (also for b1, b2...). + * \tparam DsLayout Ds layouts. + * \tparam ELayout Output layout. + * \tparam ADataType Input data type. 
Pass a tuple if there are multiple A tensors.
 * \tparam BDataType Weight data type. Pass a tuple if there are multiple B tensors.
 * \tparam DsDataType D data types.
 * \tparam EDataType Output data type.
 * \tparam AElementwiseOperation A elementwise operation.
 * \tparam BElementwiseOperation B elementwise operation.
 * \tparam CDEElementwiseOperation CDE elementwise operation.
 * \tparam ComputeType Compute data type (default: ADataType; the first tuple element if a tuple is passed).
+ */
 template <index_t NDimSpatial,
           typename ALayout,
           typename BLayout,
           typename DsLayout,
           typename ELayout,
           typename ADataType,
           typename BDataType,
           typename DsDataType,
           typename EDataType,
           typename AElementwiseOperation,
           typename BElementwiseOperation,
-          typename CDEElementwiseOperation>
+          typename CDEElementwiseOperation,
+          typename ComputeType =
+              decltype(UnpackDataType<is_detected<is_tuple, ADataType>::value,
+                                      Number<0>,
+                                      ADataType>())> // ComputeType is InputType by default (first
+                                                     // in tuple for MultiAB), unpack if tuple was
+                                                     // passed
 struct DeviceGroupedConvFwdMultipleD : public BaseOperator
 {
+    static constexpr bool isMultiA = is_detected<is_tuple, ADataType>::value;
+    static constexpr bool isMultiB = is_detected<is_tuple, BDataType>::value;
+
+    static constexpr index_t NumATensor = GetNumABTensors<isMultiA, ADataType>();
+    static constexpr index_t NumBTensor = GetNumABTensors<isMultiB, BDataType>();
     static constexpr index_t NumDTensor = DsDataType::Size();

     static_assert(NumDTensor == DsLayout::Size(), "wrong! Inconsistent NumDTensor");

+    // If ADataType/BDataType is a tuple, the user has to pass a std::array of pointers.
+    using APointers =
+        std::conditional_t<isMultiA, const std::array<const void*, NumATensor>&, const void*>;
+    using BPointers =
+        std::conditional_t<isMultiB, const std::array<const void*, NumBTensor>&, const void*>;
+
+    /**
+     * \brief Make argument pointer for grouped conv fwd.
+     *
+     * \param p_a A pointer to the input image (a std::array of
+       pointers when there are multiple A tensors).
+     * \param p_b A pointer to the weight (a std::array of
+       pointers when there are multiple B tensors).
+     * \param p_ds Pointers to the Ds.
+     * \param p_e A pointer to the output image.
+     * \param a_g_n_c_wis_lengths Input lengths [G, N, C, Spatial...] (for 3d).
+     * \param a_g_n_c_wis_strides Input strides [G, N, C, Spatial...] (for 3d).
+     * \param b_g_k_c_xs_lengths Weight lengths [G, K, C, Spatial...] (for 3d).
+     * \param b_g_k_c_xs_strides Weight strides [G, K, C, Spatial...] (for 3d).
+     * \param ds_g_n_k_wos_lengths Ds lengths [G, N, K, Spatial...] (for 3d).
+     * \param ds_g_n_k_wos_strides Ds strides [G, N, K, Spatial...] (for 3d).
+     * \param e_g_n_k_wos_lengths Output lengths [G, N, K, Spatial...] (for 3d).
+     * \param e_g_n_k_wos_strides Output strides [G, N, K, Spatial...] (for 3d).
+     * \param conv_filter_strides Convolution filter strides.
+     * \param conv_filter_dilations Convolution filter dilations.
+     * \param input_left_pads Input left paddings.
+     * \param input_right_pads Input right paddings.
+     * \param a_element_op A elementwise operation object.
+     * \param b_element_op B elementwise operation object.
+     * \param cde_element_op CDE elementwise operation object.
+     * \return Pointer to the argument.
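     *
     * A minimal multi-AB usage sketch (hedged: the pointer and geometry variable
     * names are hypothetical, and the arguments follow exactly the order
     * documented above):
     * \code
     * std::array<const void*, 2> as = {p_in, p_in_bias};   // two A tensors
     * std::array<const void*, 2> bs = {p_wei, p_wei_bias}; // two B tensors
     * auto argument = op.MakeArgumentPointer(as, bs, {}, p_out,
     *                                        in_lengths, in_strides,
     *                                        wei_lengths, wei_strides,
     *                                        {}, {}, out_lengths, out_strides,
     *                                        conv_filter_strides, conv_filter_dilations,
     *                                        input_left_pads, input_right_pads,
     *                                        ScaleAdd{scale}, ScaleAdd{scale}, PassThrough{});
     * \endcode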
+ */ virtual std::unique_ptr MakeArgumentPointer( - const void* p_a, // input image - const void* p_b, // weight + APointers p_a, + BPointers p_b, const std::array& p_ds, - void* p_e, // output image + void* p_e, const std::array& a_g_n_c_wis_lengths, const std::array& a_g_n_c_wis_strides, const std::array& b_g_k_c_xs_lengths, diff --git a/include/ck/tensor_operation/gpu/device/impl/device_column_to_image_impl.hpp b/include/ck/tensor_operation/gpu/device/impl/device_column_to_image_impl.hpp index 567be5f364..4c6546239b 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_column_to_image_impl.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_column_to_image_impl.hpp @@ -263,19 +263,18 @@ struct DeviceColumnToImageImpl decltype(BlockToCTileMap_M00_N0_M01Adapt( InputGridDesc{}))>; - using GridwiseTensorRearrangeKernel = - GridwiseTensorRearrange>; + using GridwiseTensorRearrangeKernel = GridwiseTensorRearrange>; struct Argument : public BaseArgument { @@ -453,7 +452,7 @@ struct DeviceColumnToImageImpl std::vector p_in_container_; std::vector p_out_container_; - ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_batch_; + ComputePtrOffsetOfStridedBatch<> compute_ptr_offset_of_batch_; }; struct Invoker : public BaseInvoker @@ -471,7 +470,7 @@ struct DeviceColumnToImageImpl OutputGridDesc, OutputDataType, Block2ETileMap, - ComputePtrOffsetOfStridedBatch, + ComputePtrOffsetOfStridedBatch<>, GridwiseTensorRearrangeKernel>; // Execute each set of independent filters diff --git a/include/ck/tensor_operation/gpu/device/impl/device_contraction_multiple_abd_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_contraction_multiple_abd_xdl_cshuffle.hpp index 309a32bf2f..29d7a2b949 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_contraction_multiple_abd_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_contraction_multiple_abd_xdl_cshuffle.hpp @@ -385,9 +385,11 @@ struct DeviceContractionMultipleABD_Xdl_CShuffle // desc for blockwise copy using AsGridDesc_AK0_M_AK1 = - remove_cvref_t; + remove_cvref_t; using BsGridDesc_BK0_N_BK1 = - remove_cvref_t; + remove_cvref_t; using DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t< decltype(GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( DsGridDesc_M_N{}))>; @@ -397,7 +399,7 @@ struct DeviceContractionMultipleABD_Xdl_CShuffle // block-to-e-tile map using Block2ETileMap = - remove_cvref_t; + remove_cvref_t; // Argument struct Argument : public BaseArgument @@ -429,7 +431,7 @@ struct DeviceContractionMultipleABD_Xdl_CShuffle bs_grid_desc_bk0_n_bk1_{}, ds_grid_desc_mblock_mperblock_nblock_nperblock_{}, e_grid_desc_mblock_mperblock_nblock_nperblock_{}, - block_2_etile_map_{GridwiseGemm::MakeBlock2ETileMap(e_grid_desc_m_n_)}, + block_2_etile_map_{GridwiseGemm::MakeDefaultBlock2ETileMap(e_grid_desc_m_n_)}, a_element_op_{a_element_op}, b_element_op_{b_element_op}, cde_element_op_{cde_element_op} @@ -481,10 +483,10 @@ struct DeviceContractionMultipleABD_Xdl_CShuffle block_2_etile_map_)) { as_grid_desc_ak0_m_ak1_ = - GridwiseGemm::MakeAsGridDescriptor_AK0_M_AK1(as_grid_desc_m_k_); + GridwiseGemm::MakeDefaultAsGridDescriptor_AK0_M_AK1(as_grid_desc_m_k_); bs_grid_desc_bk0_n_bk1_ = - GridwiseGemm::MakeBsGridDescriptor_BK0_N_BK1(bs_grid_desc_n_k_); + GridwiseGemm::MakeDefaultBsGridDescriptor_BK0_N_BK1(bs_grid_desc_n_k_); ds_grid_desc_mblock_mperblock_nblock_nperblock_ = GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( diff --git 
a/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_abd_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_abd_xdl_cshuffle.hpp index 827a341a50..3d17734b32 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_abd_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_abd_xdl_cshuffle.hpp @@ -305,9 +305,11 @@ struct DeviceGemmMultipleABD_Xdl_CShuffle : public DeviceGemmMultipleABD; + remove_cvref_t; using BsGridDesc_BK0_N_BK1 = - remove_cvref_t; + remove_cvref_t; using DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t< decltype(GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( DsGridDesc_M_N{}))>; @@ -317,7 +319,7 @@ struct DeviceGemmMultipleABD_Xdl_CShuffle : public DeviceGemmMultipleABD; + remove_cvref_t; // Argument struct Argument : public BaseArgument @@ -349,7 +351,7 @@ struct DeviceGemmMultipleABD_Xdl_CShuffle : public DeviceGemmMultipleABD block_2_ctile_map_container_; // for computing batch offset - ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_batch_; + ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_batch_; // element-wise op AElementwiseOp a_element_op_; @@ -579,7 +579,7 @@ struct DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle typename GridwiseGemm::DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock, typename GridwiseGemm::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock, remove_reference_t, - ComputePtrOffsetOfStridedBatch, + ComputePtrOffsetOfStridedBatch, has_main_loop>; return launch_and_time_kernel( diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp index 88b859efbc..a157d16181 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp @@ -677,7 +677,7 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 std::vector block_2_etile_map_container_; // for computing batch offset - ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_batch_; + ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_batch_; // element-wise op AElementwiseOp a_element_op_; @@ -746,7 +746,7 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 DeviceOp::DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, DeviceOp::EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, Block2ETileMap, - ComputePtrOffsetOfStridedBatch, + ComputePtrOffsetOfStridedBatch, has_main_loop>; return launch_and_time_kernel( diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_dl.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_dl.hpp index e2c9dca904..a5f34f0b24 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_dl.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_dl.hpp @@ -927,7 +927,7 @@ struct DeviceGroupedConvBwdWeight_Dl : public DeviceGroupedConvBwdWeight compute_ptr_offset_of_batch_; + ComputePtrOffsetOfStridedBatch<> compute_ptr_offset_of_batch_; // element-wise op OutElementwiseOperation a_element_op_; @@ -999,7 +999,7 @@ struct DeviceGroupedConvBwdWeight_Dl : public DeviceGroupedConvBwdWeight, remove_reference_t, remove_reference_t, - ComputePtrOffsetOfStridedBatch, + ComputePtrOffsetOfStridedBatch<>, 
has_main_loop, has_double_loop>; diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_wmma_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_wmma_cshuffle.hpp index e8a721eb36..dd591fb781 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_wmma_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_wmma_cshuffle.hpp @@ -565,7 +565,7 @@ struct DeviceGroupedConvBwdWeight_Wmma_CShuffle Block2CTileMap block_2_ctile_map_; // for computing batch offset - ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_batch_; + ComputePtrOffsetOfStridedBatch<> compute_ptr_offset_of_batch_; OutElementwiseOperation a_element_op_; InElementwiseOperation b_element_op_; @@ -647,7 +647,7 @@ struct DeviceGroupedConvBwdWeight_Wmma_CShuffle DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, remove_reference_t, - ComputePtrOffsetOfStridedBatch, + ComputePtrOffsetOfStridedBatch<>, has_main_loop>; using EmptyTuple = Tuple<>; diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle.hpp index cbb8643627..468c92348e 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle.hpp @@ -1197,7 +1197,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle Block2CTileMap block_2_ctile_map_; // for computing batch offset - ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_batch_; + ComputePtrOffsetOfStridedBatch<> compute_ptr_offset_of_batch_; index_t M01_; index_t N01_; @@ -1276,7 +1276,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle remove_reference_t, remove_reference_t, remove_reference_t, - ComputePtrOffsetOfStridedBatch, + ComputePtrOffsetOfStridedBatch<>, has_main_loop>; return launch_and_time_kernel(stream_config, diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp index 8cfdd04e55..484f1e729d 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp @@ -537,7 +537,7 @@ struct DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK DefaultBlock2CTileMap block_2_ctile_map_; // for computing batch offset - ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_batch_; + ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_batch_; // element-wise op AElementwiseOperation a_element_op_; @@ -601,7 +601,7 @@ struct DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK DeviceOp::DsGridDesc_M0_M10_M11_N0_N10_N11, DeviceOp::CGridDesc_M0_M10_M11_N0_N10_N11, DefaultBlock2CTileMap, - ComputePtrOffsetOfStridedBatch, + ComputePtrOffsetOfStridedBatch, has_main_loop, has_double_loop>; diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp index 3e303f5c5e..4c9178d6b2 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp +++ 
b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp @@ -428,7 +428,7 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle typename GridwiseOp::DefaultBlock2CTileMap block_2_etile_map_; // for computing batch offset - ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_batch_; + ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_batch_; // element-wise op AElementwiseOperation a_element_op_; @@ -485,7 +485,7 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle typename GridwiseOp::DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock, typename GridwiseOp::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock, remove_reference_t, - ComputePtrOffsetOfStridedBatch, + ComputePtrOffsetOfStridedBatch, has_main_loop>; return launch_and_time_kernel(stream_config, diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_xdl_cshuffle.hpp index f4b8d66ecf..e94f58c575 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_xdl_cshuffle.hpp @@ -19,6 +19,7 @@ #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" #include "ck/tensor_operation/gpu/device/matrix_padder.hpp" #include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_abd_xdl_cshuffle.hpp" #include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp" #include "ck/host_utility/device_prop.hpp" #include "ck/host_utility/kernel_launch.hpp" @@ -56,7 +57,8 @@ namespace { * */ template + bool HasMainKBlockLoop, + bool isMultiA, + bool isMultiB> __global__ void #if CK_USE_LAUNCH_BOUNDS __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) #endif kernel_grouped_conv_fwd_multiple_d_xdl_cshuffle( - const ABDataType* __restrict__ p_a_grid, - const ABDataType* __restrict__ p_b_grid, + AsPointer p_as_grid, + BsPointer p_bs_grid, DsPointer p_ds_grid, EDataType* __restrict__ p_e_grid, const AElementwiseOperation a_element_op, @@ -98,14 +102,9 @@ __global__ void __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch); - const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane( - static_cast(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx))); - const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane( - static_cast(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx))); const long_index_t e_batch_offset = __builtin_amdgcn_readfirstlane( static_cast(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx))); - - const auto ds_batch_offset = compute_ptr_offset_of_batch.GetDsPtrOffset(g_idx); + const auto& ds_batch_offset = compute_ptr_offset_of_batch.GetDsPtrOffset(g_idx); __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; @@ -117,22 +116,63 @@ __global__ void static_for<0, NumDTensor, 1>{}( [&](auto i) { p_ds_grid_grp(i) = p_ds_grid[i] + ds_batch_offset[i]; }); - GridwiseGemm::template Run(p_a_grid + a_batch_offset, - p_b_grid + b_batch_offset, - p_ds_grid_grp, - p_e_grid + e_batch_offset, - p_shared, - a_element_op, - b_element_op, - cde_element_op, - a_grid_desc_k0_m_k1, - b_grid_desc_k0_n_k1, - ds_grid_desc_mblock_mperblock_nblock_nperblock, - e_grid_desc_mblock_mperblock_nblock_nperblock_, - 
block_2_ctile_map); + if constexpr(isMultiA || isMultiB) + { + AsPointer p_as_grid_grp; + BsPointer p_bs_grid_grp; + + const auto& as_batch_offset = compute_ptr_offset_of_batch.GetAsPtrOffset(g_idx); + + static constexpr index_t NumATensor = AGridDesc_AK0_M_AK1::Size(); + static_for<0, NumATensor, 1>{}( + [&](auto i) { p_as_grid_grp(i) = p_as_grid[i] + as_batch_offset[i]; }); + + const auto& bs_batch_offset = compute_ptr_offset_of_batch.GetBsPtrOffset(g_idx); + + static constexpr index_t NumBTensor = BGridDesc_BK0_N_BK1::Size(); + static_for<0, NumBTensor, 1>{}( + [&](auto i) { p_bs_grid_grp(i) = p_bs_grid[i] + bs_batch_offset[i]; }); + + GridwiseGemm::template Run( + p_as_grid_grp, + p_bs_grid_grp, + p_ds_grid_grp, + p_e_grid + e_batch_offset, + p_shared, + a_element_op, + b_element_op, + cde_element_op, + a_grid_desc_k0_m_k1, + b_grid_desc_k0_n_k1, + ds_grid_desc_mblock_mperblock_nblock_nperblock, + e_grid_desc_mblock_mperblock_nblock_nperblock_, + block_2_ctile_map); + } + else + { + const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane( + static_cast(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx))); + const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane( + static_cast(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx))); + + GridwiseGemm::template Run( + p_as_grid + a_batch_offset, + p_bs_grid + b_batch_offset, + p_ds_grid_grp, + p_e_grid + e_batch_offset, + p_shared, + a_element_op, + b_element_op, + cde_element_op, + a_grid_desc_k0_m_k1, + b_grid_desc_k0_n_k1, + ds_grid_desc_mblock_mperblock_nblock_nperblock, + e_grid_desc_mblock_mperblock_nblock_nperblock_, + block_2_ctile_map); + } #else - ignore = p_a_grid; - ignore = p_b_grid; + ignore = p_as_grid; + ignore = p_bs_grid; ignore = p_ds_grid; ignore = p_e_grid; ignore = batch_count; @@ -150,6 +190,9 @@ __global__ void } // namespace +template +using is_tuple = decltype(std::declval().IsTuple()); + // // @brief Device Convolution operation. 
// @@ -211,8 +254,13 @@ template + typename ComputeDataType = + decltype(UnpackDataType::value, + Number<0>, + ADataType>()), // ComputeType is InputType by default (first + // in tuple for MultiAB), unpack if tuple was + // passed + LoopScheduler LoopSched = make_default_loop_scheduler()> struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle : public DeviceGroupedConvFwdMultipleD::value; + static constexpr bool isMultiB = is_detected::value; + + static constexpr index_t NumATensor = GetNumABTensors(); + static constexpr index_t NumBTensor = GetNumABTensors(); static constexpr index_t NumDTensor = DsDataType::Size(); static constexpr auto I0 = Number<0>{}; @@ -325,51 +378,43 @@ struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle using DsGridDesc_M_N = remove_cvref_t; using EGridDesc_M_N = remove_cvref_t({}, {}))>; - // GridwiseGemm - using GridwiseGemm = GridwiseGemmMultipleD_xdl_cshuffle< - ADataType, // TODO: distinguish A/B datatype - BDataType, - ComputeDataType, - AccDataType, - CShuffleDataType, - DsDataType, - EDataType, - AElementwiseOperation, - BElementwiseOperation, - CDEElementwiseOperation, - InMemoryDataOperationEnum::Set, - NumGemmKPrefetchStage, - BlockSize, - MPerBlock, - NPerBlock, - KPerBlock, - AK1, - BK1, - MPerXDL, - NPerXDL, - MXdlPerWave, - NXdlPerWave, - ABlockTransferThreadClusterLengths_AK0_M_AK1, - ABlockTransferThreadClusterArrangeOrder, - ABlockTransferSrcAccessOrder, - ABlockTransferSrcVectorDim, - ABlockTransferSrcScalarPerVector, - ABlockTransferDstScalarPerVector_AK1, - false, - ABlockLdsExtraM, - BBlockTransferThreadClusterLengths_BK0_N_BK1, - BBlockTransferThreadClusterArrangeOrder, - BBlockTransferSrcAccessOrder, - BBlockTransferSrcVectorDim, - BBlockTransferSrcScalarPerVector, - BBlockTransferDstScalarPerVector_BK1, - false, - BBlockLdsExtraN, - CShuffleMXdlPerWavePerShuffle, - CShuffleNXdlPerWavePerShuffle, - CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, - CDEBlockTransferScalarPerVector_NPerBlock, - LoopSched>; + // If we are using multiAB and one of the template datatype parameters is not a tuple, convert + // it to it + using GemmADataType = std::conditional_t, ADataType>; + using GemmBDataType = std::conditional_t, BDataType>; + +#define GridwiseGemmTemplateParameters \ + GemmADataType, GemmBDataType, ComputeDataType, AccDataType, CShuffleDataType, DsDataType, \ + EDataType, AElementwiseOperation, BElementwiseOperation, CDEElementwiseOperation, \ + InMemoryDataOperationEnum::Set, NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, \ + KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave, \ + ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, \ + ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, \ + ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, \ + ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, \ + BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, \ + BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, \ + BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, \ + CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, \ + CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, \ + CDEBlockTransferScalarPerVector_NPerBlock, LoopSched + // Use appropriate gridwise gemm + using GridwiseGemm = + std::conditional_t, + GridwiseGemmMultipleD_xdl_cshuffle>; + + // If ADataTypes or BDataTypes is tuple, user has to pass std::array with pointers. 
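// For reference, the isMultiA/isMultiB switches above come down to the classic
// member-detection idiom: a data-type parameter is treated as "multi" exactly
// when it exposes IsTuple(), which ck::Tuple provides. A self-contained model
// of the same check, as a hedged sketch (hypothetical names, C++17
// std::void_t in place of CK's is_detected):
template <typename, typename = void>
struct has_is_tuple : std::false_type
{
};
template <typename T>
struct has_is_tuple<T, std::void_t<decltype(std::declval<T>().IsTuple())>> : std::true_type
{
};
static_assert(!has_is_tuple<float>::value, "a plain data type selects the single-tensor path");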
+ using APointers = + std::conditional_t&, const void*>; + using BPointers = + std::conditional_t&, const void*>; + // Use Tuple for the both cases for GridPointer to initialize it in Argument constructor (not + // in initializer list what is required for single const pointer). + using AGridPointer = remove_cvref_t< + decltype(GetAGridPointer < isMultiA || isMultiB, GridwiseGemm, ADataType > ())>; + using BGridPointer = remove_cvref_t< + decltype(GetBGridPointer < isMultiA || isMultiB, GridwiseGemm, BDataType > ())>; // desc for blockwise copy using AGridDesc_AK0_M_AK1 = @@ -392,8 +437,8 @@ struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle // Argument struct Argument : public BaseArgument { - Argument(const void* p_a, - const void* p_b, + Argument(APointers p_as, + BPointers p_bs, const std::array& p_ds, void* p_e, const std::array& a_g_n_c_wis_lengths, @@ -413,8 +458,8 @@ struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle const AElementwiseOperation& a_element_op, const BElementwiseOperation& b_element_op, const CDEElementwiseOperation& cde_element_op) - : p_a_grid_{static_cast(p_a)}, - p_b_grid_{static_cast(p_b)}, + : p_as_grid_{}, + p_bs_grid_{}, p_ds_grid_{}, p_e_grid_{static_cast(p_e)}, num_group_{a_g_n_c_wis_lengths[0]}, @@ -458,9 +503,58 @@ struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle input_right_pads_{input_right_pads} { // A/B/E Batch Stride - compute_ptr_offset_of_batch_.BatchStrideA_ = a_g_n_c_wis_strides[0]; - compute_ptr_offset_of_batch_.BatchStrideB_ = b_g_k_c_xs_strides[0]; - compute_ptr_offset_of_batch_.BatchStrideE_ = e_g_n_k_wos_strides[0]; + if constexpr(isMultiA || isMultiB) + { + static_for<0, NumATensor, 1>{}([&](auto i) { + // Init compute_ptr_offset_of_batch_ for multiple AB + compute_ptr_offset_of_batch_.BatchStrideA_(i) = a_g_n_c_wis_strides[0]; + + // Use GemmADataType/GemmBDataType to iterate over tuple (even if passed data + // type is not tuple) + using DataType = remove_cvref_t>; + // It is possible that one of the AB is a pointer and one is a tuple. + // Then also use multiAB but we have to cast single pointer instead of tuple of + // pointer. + if constexpr(isMultiA) + { + // p_as is tuple + p_as_grid_(i) = static_cast(p_as[i.value]); + } + else + { + // if MultiB and not MultiA then p_as is single pointer + p_as_grid_(i) = static_cast(p_as); + } + }); + static_for<0, NumBTensor, 1>{}([&](auto i) { + // Init compute_ptr_offset_of_batch_ for multiple AB + compute_ptr_offset_of_batch_.BatchStrideB_(i) = b_g_k_c_xs_strides[0]; + + using DataType = remove_cvref_t>; + // It is possible that one of the AB is a pointer and one is a tuple. + // Then also use multiAB but we have to cast single pointer instead of tuple of + // pointer. 
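// In other words: when exactly one of A/B is a tuple, the multi-AB gridwise path
// still runs, but the non-tuple side has NumATensor (or NumBTensor) == 1 and
// arrives as a bare const void*, so it is cast directly here instead of being
// indexed like a std::array.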
 // desc for blockwise copy
     using AGridDesc_AK0_M_AK1 =
 @@ -392,8 +437,8 @@ struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle
     // Argument
     struct Argument : public BaseArgument
     {
-        Argument(const void* p_a,
-                 const void* p_b,
+        Argument(APointers p_as,
+                 BPointers p_bs,
                  const std::array& p_ds,
                  void* p_e,
                  const std::array& a_g_n_c_wis_lengths,
 @@ -413,8 +458,8 @@ struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle
                  const AElementwiseOperation& a_element_op,
                  const BElementwiseOperation& b_element_op,
                  const CDEElementwiseOperation& cde_element_op)
-            : p_a_grid_{static_cast(p_a)},
-              p_b_grid_{static_cast(p_b)},
+            : p_as_grid_{},
+              p_bs_grid_{},
               p_ds_grid_{},
               p_e_grid_{static_cast(p_e)},
               num_group_{a_g_n_c_wis_lengths[0]},
 @@ -458,9 +503,58 @@ struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle
               input_right_pads_{input_right_pads}
         {
             // A/B/E Batch Stride
-            compute_ptr_offset_of_batch_.BatchStrideA_ = a_g_n_c_wis_strides[0];
-            compute_ptr_offset_of_batch_.BatchStrideB_ = b_g_k_c_xs_strides[0];
-            compute_ptr_offset_of_batch_.BatchStrideE_ = e_g_n_k_wos_strides[0];
+            if constexpr(isMultiA || isMultiB)
+            {
+                static_for<0, NumATensor, 1>{}([&](auto i) {
+                    // Init compute_ptr_offset_of_batch_ for multiple AB
+                    compute_ptr_offset_of_batch_.BatchStrideA_(i) = a_g_n_c_wis_strides[0];
+
+                    // Use GemmADataType/GemmBDataType to iterate over the tuple (even if the
+                    // passed data type is not a tuple)
+                    using DataType = remove_cvref_t>;
+                    // It is possible that one of A/B is a single pointer and the other is a
+                    // tuple. In that case we still take the multi-AB path, but cast the
+                    // single pointer instead of indexing into a tuple of pointers.
+                    if constexpr(isMultiA)
+                    {
+                        // p_as is a tuple
+                        p_as_grid_(i) = static_cast(p_as[i.value]);
+                    }
+                    else
+                    {
+                        // if MultiB and not MultiA, then p_as is a single pointer
+                        p_as_grid_(i) = static_cast(p_as);
+                    }
+                });
+                static_for<0, NumBTensor, 1>{}([&](auto i) {
+                    // Init compute_ptr_offset_of_batch_ for multiple AB
+                    compute_ptr_offset_of_batch_.BatchStrideB_(i) = b_g_k_c_xs_strides[0];
+
+                    using DataType = remove_cvref_t>;
+                    // It is possible that one of A/B is a single pointer and the other is a
+                    // tuple. In that case we still take the multi-AB path, but cast the
+                    // single pointer instead of indexing into a tuple of pointers.
+                    if constexpr(isMultiB)
+                    {
+                        // p_bs is a tuple
+                        p_bs_grid_(i) = static_cast(p_bs[i.value]);
+                    }
+                    else
+                    {
+                        // if MultiA and not MultiB, then p_bs is a single pointer
+                        p_bs_grid_(i) = static_cast(p_bs);
+                    }
+                });
+            }
+            else
+            {
+                compute_ptr_offset_of_batch_.BatchStrideA_ = a_g_n_c_wis_strides[0];
+                compute_ptr_offset_of_batch_.BatchStrideB_ = b_g_k_c_xs_strides[0];
+
+                // p_as and p_bs are pointers
+                p_as_grid_(I0) = static_cast(p_as);
+                p_bs_grid_(I0) = static_cast(p_bs);
+            }
 
             // populate pointer, batch stride, desc for Ds
             static_for<0, NumDTensor, 1>{}([&](auto i) {
 @@ -477,21 +571,47 @@ struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle
                 ds_grid_desc_m_n_(i) = DeviceOp::MakeEGridDescriptor_M_N(
                     ds_g_n_k_wos_lengths[i], ds_g_n_k_wos_strides[i]);
             });
+            compute_ptr_offset_of_batch_.BatchStrideE_ = e_g_n_k_wos_strides[0];
 
             // populate desc for Ds/E
-            if(GridwiseGemm::CheckValidity(a_grid_desc_m_k_,
-                                           b_grid_desc_n_k_,
-                                           ds_grid_desc_m_n_,
-                                           e_grid_desc_m_n_,
-                                           block_2_etile_map_))
+            if constexpr(isMultiA || isMultiB)
             {
-                e_grid_desc_mblock_mperblock_nblock_nperblock_ =
-                    GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
-                        e_grid_desc_m_n_);
+                const auto as_grid_desc_ak0_m_ak1 =
+                    generate_tuple([&](auto) { return a_grid_desc_m_k_; }, Number{});
+                const auto bs_grid_desc_bk0_n_bk1 =
+                    generate_tuple([&](auto) { return b_grid_desc_n_k_; }, Number{});
 
-                ds_grid_desc_mblock_mperblock_nblock_nperblock_ =
-                    GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
-                        ds_grid_desc_m_n_);
+                if(GridwiseGemm::CheckValidity(as_grid_desc_ak0_m_ak1,
+                                               bs_grid_desc_bk0_n_bk1,
+                                               ds_grid_desc_m_n_,
+                                               e_grid_desc_m_n_,
+                                               block_2_etile_map_))
+                {
+                    e_grid_desc_mblock_mperblock_nblock_nperblock_ =
+                        GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
+                            e_grid_desc_m_n_);
+
+                    ds_grid_desc_mblock_mperblock_nblock_nperblock_ =
+                        GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
+                            ds_grid_desc_m_n_);
+                }
+            }
+            else
+            {
+                if(GridwiseGemm::CheckValidity(a_grid_desc_m_k_,
+                                               b_grid_desc_n_k_,
+                                               ds_grid_desc_m_n_,
+                                               e_grid_desc_m_n_,
+                                               block_2_etile_map_))
+                {
+                    e_grid_desc_mblock_mperblock_nblock_nperblock_ =
+                        GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
+                            e_grid_desc_m_n_);
+
+                    ds_grid_desc_mblock_mperblock_nblock_nperblock_ =
+                        GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
+                            ds_grid_desc_m_n_);
+                }
             }
         }
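Both validity checks above rely on one trick: because every A (and B) tensor shares the same layout, a single descriptor is replicated NumATensor/NumBTensor times into a tuple. A standalone sketch of that replication idiom, with std::tuple standing in for CK's generate_tuple/Number (names are illustrative only):

#include <cstddef>
#include <tuple>
#include <utility>

// Build an N-element tuple by repeating one value, the way the code above uses
// generate_tuple([&](auto) { return a_grid_desc_m_k_; }, Number{}).
template <std::size_t N, typename Desc, std::size_t... Is>
auto replicate_impl(const Desc& d, std::index_sequence<Is...>)
{
    return std::make_tuple(((void)Is, d)...);
}

template <std::size_t N, typename Desc>
auto replicate(const Desc& d)
{
    return replicate_impl<N>(d, std::make_index_sequence<N>{});
}

// replicate<2>(desc) yields std::tuple<Desc, Desc>{desc, desc}.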
 @@ -505,9 +625,9 @@ struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle
         }
 
         // private:
-        // pointers
-        const ADataType* p_a_grid_;
-        const BDataType* p_b_grid_;
+        // pointers (a tuple for multi-AB, a single pointer otherwise)
+        AGridPointer p_as_grid_;
+        BGridPointer p_bs_grid_;
         typename GridwiseGemm::DsGridPointer p_ds_grid_;
         EDataType* p_e_grid_;
 @@ -529,7 +649,8 @@ struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle
         Block2ETileMap block_2_etile_map_;
 
         // for computing batch offset
-        ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_batch_;
+        ComputePtrOffsetOfStridedBatch
+            compute_ptr_offset_of_batch_;
 
         // element-wise op
         AElementwiseOperation a_element_op_;
 @@ -563,16 +684,6 @@ struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle
                 arg.Print();
             }
 
-            if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_m_k_,
-                                            arg.b_grid_desc_n_k_,
-                                            arg.ds_grid_desc_m_n_,
-                                            arg.e_grid_desc_m_n_,
-                                            arg.block_2_etile_map_))
-            {
-                throw std::runtime_error(
-                    "wrong! GridwiseGemmMultipleD_xdl_cshuffle has invalid setting");
-            }
-
             const index_t grid_size =
                 arg.block_2_etile_map_.CalculateGridSize(arg.e_grid_desc_m_n_) * arg.num_group_;
 @@ -582,41 +693,96 @@ struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle
             auto launch_kernel = [&](auto has_main_k_block_loop) {
                 constexpr bool has_main_loop = has_main_k_block_loop.value;
 
-                const auto kernel = kernel_grouped_conv_fwd_multiple_d_xdl_cshuffle<
-                    GridwiseGemm,
-                    ADataType, // TODO: distiguish A/B datatype
-                    typename GridwiseGemm::DsGridPointer,
-                    EDataType,
-                    AElementwiseOperation,
-                    BElementwiseOperation,
-                    CDEElementwiseOperation,
-                    DeviceOp::AGridDesc_AK0_M_AK1,
-                    DeviceOp::BGridDesc_BK0_N_BK1,
-                    DeviceOp::DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
-                    DeviceOp::EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
-                    Block2ETileMap,
-                    ComputePtrOffsetOfStridedBatch,
-                    has_main_loop>;
+                if constexpr(isMultiA || isMultiB)
+                {
+                    // Generate tuples with grid descriptors for each A and B
+                    const auto as_grid_desc_ak0_m_ak1 = generate_tuple(
+                        [&](auto) { return arg.a_grid_desc_ak0_m_ak1_; }, Number{});
+                    const auto bs_grid_desc_bk0_n_bk1 = generate_tuple(
+                        [&](auto) { return arg.b_grid_desc_bk0_n_bk1_; }, Number{});
 
-                return launch_and_time_kernel(stream_config,
-                                              kernel,
-                                              dim3(grid_size),
-                                              dim3(BlockSize),
-                                              0,
-                                              arg.p_a_grid_,
-                                              arg.p_b_grid_,
-                                              arg.p_ds_grid_,
-                                              arg.p_e_grid_,
-                                              arg.a_element_op_,
-                                              arg.b_element_op_,
-                                              arg.cde_element_op_,
-                                              arg.a_g_n_c_wis_lengths_[0], // Group count
-                                              arg.a_grid_desc_ak0_m_ak1_,
-                                              arg.b_grid_desc_bk0_n_bk1_,
-                                              arg.ds_grid_desc_mblock_mperblock_nblock_nperblock_,
-                                              arg.e_grid_desc_mblock_mperblock_nblock_nperblock_,
-                                              arg.block_2_etile_map_,
-                                              arg.compute_ptr_offset_of_batch_);
+                    const auto kernel = kernel_grouped_conv_fwd_multiple_d_xdl_cshuffle<
+                        GridwiseGemm,
+                        AGridPointer,
+                        BGridPointer,
+                        typename GridwiseGemm::DsGridPointer,
+                        EDataType,
+                        AElementwiseOperation,
+                        BElementwiseOperation,
+                        CDEElementwiseOperation,
+                        decltype(as_grid_desc_ak0_m_ak1),
+                        decltype(bs_grid_desc_bk0_n_bk1),
+                        DeviceOp::DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
+                        DeviceOp::EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
+                        Block2ETileMap,
+                        ComputePtrOffsetOfStridedBatch,
+                        has_main_loop,
+                        isMultiA,
+                        isMultiB>;
+
+                    return launch_and_time_kernel(
+                        stream_config,
+                        kernel,
+                        dim3(grid_size),
+                        dim3(BlockSize),
+                        0,
+                        arg.p_as_grid_,
+                        arg.p_bs_grid_,
+                        arg.p_ds_grid_,
+                        arg.p_e_grid_,
+                        arg.a_element_op_,
+                        arg.b_element_op_,
+                        arg.cde_element_op_,
+                        arg.a_g_n_c_wis_lengths_[0], // Group count
+                        as_grid_desc_ak0_m_ak1,
+                        bs_grid_desc_bk0_n_bk1,
+                        arg.ds_grid_desc_mblock_mperblock_nblock_nperblock_,
+                        arg.e_grid_desc_mblock_mperblock_nblock_nperblock_,
+                        arg.block_2_etile_map_,
+                        arg.compute_ptr_offset_of_batch_);
+                }
+                else
+                {
+                    const auto kernel = kernel_grouped_conv_fwd_multiple_d_xdl_cshuffle<
+                        GridwiseGemm,
+                        const ADataType*,
+                        const BDataType*,
+                        typename GridwiseGemm::DsGridPointer,
+                        EDataType,
+                        AElementwiseOperation,
+                        BElementwiseOperation,
+                        CDEElementwiseOperation,
+                        DeviceOp::AGridDesc_AK0_M_AK1,
+                        DeviceOp::BGridDesc_BK0_N_BK1,
+                        DeviceOp::DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
+                        DeviceOp::EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
+                        Block2ETileMap,
+                        ComputePtrOffsetOfStridedBatch,
+                        has_main_loop,
+                        isMultiA,
+                        isMultiB>;
+
+                    return launch_and_time_kernel(
+                        stream_config,
+                        kernel,
+                        dim3(grid_size),
+                        dim3(BlockSize),
+                        0,
+                        arg.p_as_grid_.At(I0), // Pass just the A pointer instead of a tuple
+                        arg.p_bs_grid_.At(I0), // Pass just the B pointer instead of a tuple
+                        arg.p_ds_grid_,
+                        arg.p_e_grid_,
+                        arg.a_element_op_,
+                        arg.b_element_op_,
+                        arg.cde_element_op_,
+                        arg.a_g_n_c_wis_lengths_[0], // Group count
+                        arg.a_grid_desc_ak0_m_ak1_,
+                        arg.b_grid_desc_bk0_n_bk1_,
+                        arg.ds_grid_desc_mblock_mperblock_nblock_nperblock_,
+                        arg.e_grid_desc_mblock_mperblock_nblock_nperblock_,
+                        arg.block_2_etile_map_,
+                        arg.compute_ptr_offset_of_batch_);
+                }
             };
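Because the branch is chosen with if constexpr, only the selected kernel template is ever instantiated: the single-AB build never sees tuple-typed kernel parameters, and the multi-AB build never sees the plain-pointer signature. The dispatch pattern in miniature (hypothetical kernels, C++17):

#include <cstdio>

template <bool IsMulti>
float launch_conv_kernel()
{
    if constexpr(IsMulti)
    {
        // Only instantiated for IsMulti == true; free to use tuple-typed arguments.
        std::puts("launching multi-AB kernel");
    }
    else
    {
        // Only instantiated for IsMulti == false; plain-pointer arguments.
        std::puts("launching single-AB kernel");
    }
    return 0.f; // stand-in for the measured kernel time
}

int main()
{
    launch_conv_kernel<true>();
    launch_conv_kernel<false>();
}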
         if(GridwiseGemm::CalculateHasMainKBlockLoop(K))
 @@ -791,11 +957,27 @@ struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle
         }
 
         // check Gridwise GEMM
-        return GridwiseGemm::CheckValidity(arg.a_grid_desc_m_k_,
-                                           arg.b_grid_desc_n_k_,
-                                           arg.ds_grid_desc_m_n_,
-                                           arg.e_grid_desc_m_n_,
-                                           arg.block_2_etile_map_);
+        if constexpr(isMultiA || isMultiB)
+        {
+            // Generate tuples with the same descriptors
+            const auto as_grid_desc_ak0_m_ak1 =
+                generate_tuple([&](auto) { return arg.a_grid_desc_m_k_; }, Number{});
+            const auto bs_grid_desc_bk0_n_bk1 =
+                generate_tuple([&](auto) { return arg.b_grid_desc_n_k_; }, Number{});
+            return GridwiseGemm::CheckValidity(as_grid_desc_ak0_m_ak1,
+                                               bs_grid_desc_bk0_n_bk1,
+                                               arg.ds_grid_desc_m_n_,
+                                               arg.e_grid_desc_m_n_,
+                                               arg.block_2_etile_map_);
+        }
+        else
+        {
+            return GridwiseGemm::CheckValidity(arg.a_grid_desc_m_k_,
+                                               arg.b_grid_desc_n_k_,
+                                               arg.ds_grid_desc_m_n_,
+                                               arg.e_grid_desc_m_n_,
+                                               arg.block_2_etile_map_);
+        }
     }
 
     bool IsSupportedArgument(const BaseArgument* p_arg) override
 @@ -804,8 +986,8 @@ struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle
     }
 
     static auto MakeArgument(
-        const void* p_a,
-        const void* p_b,
+        APointers p_as,
+        BPointers p_bs,
         const std::array& p_ds,
         void* p_e,
         const std::array& a_g_n_c_wis_lengths,
 @@ -824,8 +1006,8 @@ struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle
         const BElementwiseOperation& b_element_op,
         const CDEElementwiseOperation& cde_element_op)
     {
-        return Argument{p_a,
-                        p_b,
+        return Argument{p_as,
+                        p_bs,
                         p_ds,
                         p_e,
                         a_g_n_c_wis_lengths,
 @@ -848,8 +1030,8 @@ struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle
     static auto MakeInvoker() { return Invoker{}; }
 
     std::unique_ptr MakeArgumentPointer(
-        const void* p_a,
-        const void* p_b,
+        APointers p_a,
+        BPointers p_b,
         const std::array& p_ds,
         void* p_e,
         const std::array& a_g_n_c_wis_lengths,
 diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp
 index e19a00299e..35f4393e36 100644
 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp
 +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp
 @@ -9,8 +9,77 @@ namespace ck {
 namespace tensor_operation {
 namespace device {
 
-template
+template
 struct ComputePtrOffsetOfStridedBatch
+{
+};
+
+template
+struct ComputePtrOffsetOfStridedBatch 1 || NumBTensor > 1)>>
+{
+    ComputePtrOffsetOfStridedBatch() = default;
+
+    ComputePtrOffsetOfStridedBatch(Array& BatchStrideAs,
+                                   Array& BatchStrideBs,
+                                   Array& BatchStrideDs,
+                                   index_t BatchStrideE)
+        : BatchStrideA_(BatchStrideAs),
+          BatchStrideB_(BatchStrideBs),
+          BatchStrideDs_(BatchStrideDs),
+          BatchStrideE_(BatchStrideE)
+    {
+    }
+
+    __host__ __device__ constexpr auto GetAsPtrOffset(index_t g_idx) const
+    {
+        Array as_offset;
+        static_for<0, NumATensor, 1>{}(
+            [&](auto i) { as_offset(i) = g_idx * static_cast(BatchStrideA_[i]); });
+        return as_offset;
+    }
+
+    __host__ __device__ constexpr auto GetBsPtrOffset(index_t g_idx) const
+    {
+        Array bs_offset;
+        static_for<0, NumBTensor, 1>{}(
+            [&](auto i) { bs_offset(i) = g_idx * static_cast(BatchStrideB_[i]); });
+        return bs_offset;
+    }
+
+    __host__ __device__ constexpr auto GetDsPtrOffset(index_t g_idx) const
+    {
+        Array ds_offset;
+        static_for<0, NumDTensor, 1>{}(
+            [&](auto i) { ds_offset(i) = g_idx * static_cast(BatchStrideDs_[i]); });
+        return ds_offset;
+    }
+
+    [[maybe_unused]] __host__ __device__ constexpr long_index_t GetEPtrOffset(index_t g_idx) const
+    {
+        return g_idx * static_cast(BatchStrideE_);
+    }
+
+    // alias for kernels without multiple D
+    [[maybe_unused]] __host__ __device__ constexpr long_index_t GetCPtrOffset(index_t g_idx) const
+    {
+        return g_idx * static_cast(BatchStrideE_);
+    }
+
+    Array BatchStrideA_;
+    Array BatchStrideB_;
+    Array BatchStrideDs_;
+    index_t BatchStrideE_;
+    index_t& BatchStrideC_ = BatchStrideE_; // alias for kernels without multiple D
+};
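Each getter above is the same arithmetic repeated per tensor: the group index times that tensor's batch stride, widened to 64 bits (long_index_t in the code above) before the multiply so large tensors do not overflow. A host-side sketch of the computation, with plain std::array in place of CK's Array (names are illustrative):

#include <array>
#include <cstdint>

// offset_i = g_idx * batch_stride_i, matching GetAsPtrOffset/GetBsPtrOffset/
// GetDsPtrOffset above.
template <int NumTensors>
std::array<std::int64_t, NumTensors>
group_ptr_offsets(std::int32_t g_idx, const std::array<std::int32_t, NumTensors>& batch_strides)
{
    std::array<std::int64_t, NumTensors> offsets{};
    for(int i = 0; i < NumTensors; ++i)
    {
        offsets[i] = static_cast<std::int64_t>(g_idx) * batch_strides[i];
    }
    return offsets;
}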
+
+template
+struct ComputePtrOffsetOfStridedBatch>
 {
     ComputePtrOffsetOfStridedBatch() = default;
 
 @@ -54,13 +123,67 @@ struct ComputePtrOffsetOfStridedBatch
         return g_idx * static_cast(BatchStrideE_);
     }
 
-    index_t BatchStrideA_;
-    index_t BatchStrideB_;
+    ck::index_t BatchStrideA_;
+    ck::index_t BatchStrideB_;
     Array BatchStrideDs_;
     index_t BatchStrideE_;
     index_t& BatchStrideC_ = BatchStrideE_; // alias for kernels without multiple D
 };
 
+template
+constexpr static auto GetNumABTensors()
+{
+    if constexpr(isTuple)
+    {
+        return Number{};
+    }
+    else
+    {
+        return Number<1>{};
+    }
+}
+
+template
+constexpr static auto GetAGridPointer()
+{
+    if constexpr(isTuple)
+    {
+        return typename GridwiseGemm::AsGridPointer{};
+    }
+    else
+    {
+        return Tuple{};
+    }
+}
+
+template
+constexpr static auto GetBGridPointer()
+{
+    if constexpr(isTuple)
+    {
+        return typename GridwiseGemm::BsGridPointer{};
+    }
+    else
+    {
+        return Tuple{};
+    }
+}
+
+template
+constexpr static auto UnpackDataType()
+{
+    if constexpr(isTuple)
+    {
+        // unpack if tuple
+        return tuple_element_t{};
+    }
+    else
+    {
+        // otherwise, return the type itself
+        return Type{};
+    }
+}
+
 } // namespace device
 } // namespace tensor_operation
 } // namespace ck
 diff --git a/include/ck/tensor_operation/gpu/device/impl/device_image_to_column_impl.hpp b/include/ck/tensor_operation/gpu/device/impl/device_image_to_column_impl.hpp
 index c83ffdcd26..52aeefa3a4 100644
 --- a/include/ck/tensor_operation/gpu/device/impl/device_image_to_column_impl.hpp
 +++ b/include/ck/tensor_operation/gpu/device/impl/device_image_to_column_impl.hpp
 @@ -142,19 +142,18 @@ struct DeviceImageToColumnImpl
                                          decltype(BlockToCTileMap_M00_N0_M01Adapt(
                                              OutputGridDesc{}))>;
 
-    using GridwiseTensorRearrangeKernel =
-        GridwiseTensorRearrange>;
+    using GridwiseTensorRearrangeKernel = GridwiseTensorRearrange>;
 
     struct Argument : public BaseArgument
     {
 @@ -224,7 +223,7 @@ struct DeviceImageToColumnImpl
         InputGridDesc in_grid_desc_m_k_;
         OutputGridDesc out_grid_desc_m_k_;
 
-        ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_batch_;
+        ComputePtrOffsetOfStridedBatch<> compute_ptr_offset_of_batch_;
     };
 
     struct Invoker : public BaseInvoker
 @@ -246,7 +245,7 @@ struct DeviceImageToColumnImpl
                                          OutputGridDesc,
                                          OutputDataType,
                                          Block2ETileMap,
-                                         ComputePtrOffsetOfStridedBatch,
+                                         ComputePtrOffsetOfStridedBatch<>,
                                          GridwiseTensorRearrangeKernel>;
 
         float elapsed_time = launch_and_time_kernel(stream_config,
 diff --git a/include/ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp b/include/ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp
 index 1b9e37b74a..f0f3b0f167 100644
 --- a/include/ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp
 +++ b/include/ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp
 @@ -85,10 +85,13 @@ struct Add
struct ScaleAdd { - __host__ __device__ ScaleAdd(float scale) : scale_(scale) {} + __host__ __device__ ScaleAdd(float scale = 1.f) : scale_(scale) {} template - __host__ __device__ constexpr void operator()(Y& y, const X0& x0, const X1& x1) const; + __host__ __device__ constexpr void operator()(Y& y, const X0& x0, const X1& x1) const + { + y = ck::type_convert(scale_ * ck::type_convert(x0) + ck::type_convert(x1)); + } template <> __host__ __device__ void diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_abd_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_abd_xdl_cshuffle.hpp index 89424468c6..4b7cc56796 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_abd_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_abd_xdl_cshuffle.hpp @@ -203,7 +203,7 @@ struct GridwiseGemmMultipleABD_xdl_cshuffle // A desc for source in blockwise copy template __host__ __device__ static constexpr auto - MakeAGridDescriptor_AK0_M_AK1(const AGridDesc_M_K& a_grid_desc_m_k) + MakeDefaultAGridDescriptor_AK0_M_AK1(const AGridDesc_M_K& a_grid_desc_m_k) { const auto M = a_grid_desc_m_k.GetLength(I0); const auto K = a_grid_desc_m_k.GetLength(I1); @@ -219,17 +219,17 @@ struct GridwiseGemmMultipleABD_xdl_cshuffle template __host__ __device__ static constexpr auto - MakeAsGridDescriptor_AK0_M_AK1(const AsGridDesc_M_K& as_grid_desc_m_k) + MakeDefaultAsGridDescriptor_AK0_M_AK1(const AsGridDesc_M_K& as_grid_desc_m_k) { return generate_tuple( - [&](auto i) { return MakeAGridDescriptor_AK0_M_AK1(as_grid_desc_m_k[i]); }, + [&](auto i) { return MakeDefaultAGridDescriptor_AK0_M_AK1(as_grid_desc_m_k[i]); }, Number{}); } // B desc for source in blockwise copy template __host__ __device__ static constexpr auto - MakeBGridDescriptor_BK0_N_BK1(const BGridDesc_N_K& b_grid_desc_n_k) + MakeDefaultBGridDescriptor_BK0_N_BK1(const BGridDesc_N_K& b_grid_desc_n_k) { const auto N = b_grid_desc_n_k.GetLength(I0); const auto K = b_grid_desc_n_k.GetLength(I1); @@ -245,10 +245,10 @@ struct GridwiseGemmMultipleABD_xdl_cshuffle template __host__ __device__ static constexpr auto - MakeBsGridDescriptor_BK0_N_BK1(const BsGridDesc_N_K& bs_grid_desc_n_k) + MakeDefaultBsGridDescriptor_BK0_N_BK1(const BsGridDesc_N_K& bs_grid_desc_n_k) { return generate_tuple( - [&](auto i) { return MakeBGridDescriptor_BK0_N_BK1(bs_grid_desc_n_k[i]); }, + [&](auto i) { return MakeDefaultBGridDescriptor_BK0_N_BK1(bs_grid_desc_n_k[i]); }, Number{}); } @@ -288,7 +288,7 @@ struct GridwiseGemmMultipleABD_xdl_cshuffle // return block_id to E matrix tile idx (m0, n0) mapping template __host__ __device__ static constexpr auto - MakeBlock2ETileMap(const EGridDesc_M_N& e_grid_desc_m_n) + MakeDefaultBlock2ETileMap(const EGridDesc_M_N& e_grid_desc_m_n) { return BlockToCTileMap_M00_N0_M01Adapt( e_grid_desc_m_n); @@ -591,6 +591,9 @@ struct GridwiseGemmMultipleABD_xdl_cshuffle generate_tuple([&](auto) { return make_multi_index(0, m_block_data_idx_on_grid, 0); }, Number{}); + static_assert(ABlockTransferSrcScalarPerVector == ABlockTransferDstScalarPerVector_AK1, + "Src and Dst ScalarPerVector must be the same"); + auto a_blockwise_copy = ThreadGroupTensorSliceTransfer_v7r2< ThisThreadBlock, AsDataType, @@ -619,6 +622,9 @@ struct GridwiseGemmMultipleABD_xdl_cshuffle generate_tuple([&](auto) { return make_multi_index(0, n_block_data_idx_on_grid, 0); }, Number{}); + static_assert(BBlockTransferSrcScalarPerVector == BBlockTransferDstScalarPerVector_BK1, + "Src and Dst ScalarPerVector must 
be the same"); + auto b_blockwise_copy = ThreadGroupTensorSliceTransfer_v7r2< ThisThreadBlock, BsDataType, @@ -1005,9 +1011,9 @@ struct GridwiseGemmMultipleABD_xdl_cshuffle const auto e_grid_desc_m_n = MakeEGridDescriptor_M_N(M, N, StrideE); // tensor descriptors for block/thread-wise copy - const auto as_grid_desc_ak0_m_ak1 = MakeAsGridDescriptor_AK0_M_AK1(as_grid_desc_m_k); + const auto as_grid_desc_ak0_m_ak1 = MakeDefaultAsGridDescriptor_AK0_M_AK1(as_grid_desc_m_k); - const auto bs_grid_desc_bk0_n_bk1 = MakeBsGridDescriptor_BK0_N_BK1(bs_grid_desc_n_k); + const auto bs_grid_desc_bk0_n_bk1 = MakeDefaultBsGridDescriptor_BK0_N_BK1(bs_grid_desc_n_k); const auto ds_grid_desc_mblock_mperblock_nblock_nperblock = MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(ds_grid_desc_m_n); diff --git a/library/include/ck/library/reference_tensor_operation/cpu/reference_conv_fwd.hpp b/library/include/ck/library/reference_tensor_operation/cpu/reference_conv_fwd.hpp index 2e01400b84..ffc9470df2 100644 --- a/library/include/ck/library/reference_tensor_operation/cpu/reference_conv_fwd.hpp +++ b/library/include/ck/library/reference_tensor_operation/cpu/reference_conv_fwd.hpp @@ -3,12 +3,23 @@ #pragma once -#include +#include +#include +#include #include -#include +#include +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" #include "ck/tensor_operation/gpu/device/device_base.hpp" + +#include "ck/library/utility/algorithm.hpp" +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/fill.hpp" #include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/convolution_parameter.hpp" +#include "ck/library/utility/convolution_host_tensor_descriptor_helper.hpp" namespace ck { namespace tensor_operation { @@ -22,6 +33,7 @@ namespace host { // Supports both GNCHW/NGCHW as well as GNHWC/NHWGC physical layout // as long as dimensions in tensor descriptor is in GNCHW order // +// @tparam NDimSpatial Number of spatial dimensions. // @tparam InDataType Input tensor data type. // @tparam WeiDataType Weights tensor data type. // @tparam OutDataType Output tensor data type. @@ -29,7 +41,9 @@ namespace host { // operation. // @tparam WeiElementwiseOperation Functor for weights tensor elementwise // operation. -// @tparam NDimSpatial Number of spatial dimensions. +// @tparam NumAElementwiseTensor Number of A elementwise tensors. +// @tparam NumBElementwiseTensor Number of B elementwise tensors. +// @tparam NumDElementwiseTensor Number of D elementwise tensors. 
// // input descriptor in [G, N, C, Do, Ho, Wo] order // weight descriptor in [G, K, C, Z, Y, X] order @@ -42,28 +56,35 @@ template = 1 && NDimSpatial <= 3, bool>::type = false> struct ReferenceConvFwd : public device::BaseOperator { // Argument struct Argument : public device::BaseArgument { - Argument(const Tensor& input, - const Tensor& weight, - Tensor& output, - std::vector conv_filter_strides, - std::vector conv_filter_dilations, - std::vector input_left_pads, - std::vector input_right_pads, - InElementwiseOperation in_element_op, - WeiElementwiseOperation wei_element_op, - OutElementwiseOperation out_element_op, - const std::array, NumDTensor>& d_tensors) + Argument( + const Tensor& input, + const Tensor& weight, + Tensor& output, + std::vector conv_filter_strides, + std::vector conv_filter_dilations, + std::vector input_left_pads, + std::vector input_right_pads, + InElementwiseOperation in_element_op, + WeiElementwiseOperation wei_element_op, + OutElementwiseOperation out_element_op, + const std::array, NumAElementwiseTensor>& elementwise_a_tensors, + const std::array, NumBElementwiseTensor>& elementwise_b_tensors, + const std::array, NumDElementwiseTensor>& elementwise_d_tensors) : input_{input}, weight_{weight}, output_{output}, - d_tensors_{d_tensors}, + elementwise_a_tensors_{elementwise_a_tensors}, + elementwise_b_tensors_{elementwise_b_tensors}, + elementwise_d_tensors_{elementwise_d_tensors}, conv_strides_{conv_filter_strides}, conv_dilations_{conv_filter_dilations}, in_left_pads_{input_left_pads}, @@ -78,7 +99,9 @@ struct ReferenceConvFwd : public device::BaseOperator const Tensor& weight_; Tensor& output_; - const std::array, NumDTensor>& d_tensors_; + const std::array, NumAElementwiseTensor>& elementwise_a_tensors_; + const std::array, NumBElementwiseTensor>& elementwise_b_tensors_; + const std::array, NumDElementwiseTensor>& elementwise_d_tensors_; std::vector conv_strides_; std::vector conv_dilations_; @@ -119,42 +142,43 @@ struct ReferenceConvFwd : public device::BaseOperator if(wi >= 0 && ck::type_convert(wi) < arg.input_.GetLengths()[3]) { - float v_in; - float v_wei; + InDataType v_in; + WeiDataType v_wei; - arg.in_element_op_( - v_in, ck::type_convert(arg.input_(g, n, c, wi))); - - arg.wei_element_op_( - v_wei, ck::type_convert(arg.weight_(g, k, c, x))); - - v_acc += v_in * v_wei; + ExecuteElementwiseOp(arg.in_element_op_, + arg.elementwise_a_tensors_, + Number{}, + v_in, + arg.input_(g, n, c, wi), + g, + n, + c, + wi); + ExecuteElementwiseOp(arg.wei_element_op_, + arg.elementwise_b_tensors_, + Number{}, + v_wei, + arg.weight_(g, k, c, x), + g, + k, + c, + x); + v_acc += + ck::type_convert(v_in) * ck::type_convert(v_wei); } } } - - OutDataType v_out; OutDataType v_acc_converted = ck::type_convert(v_acc); - if constexpr(NumDTensor == 0) - { - arg.out_element_op_(v_out, v_acc_converted); - } - else if constexpr(NumDTensor == 1) - { - arg.out_element_op_(v_out, v_acc_converted, arg.d_tensors_[0](g, n, k, wo)); - } - else if constexpr(NumDTensor == 2) - { - arg.out_element_op_(v_out, - v_acc_converted, - arg.d_tensors_[0](g, n, k, wo), - arg.d_tensors_[1](g, n, k, wo)); - } - else - { - throw std::runtime_error("Output ElementOp not supported in reference."); - } - arg.output_(g, n, k, wo) = v_out; + OutDataType& v_out = arg.output_(g, n, k, wo); + ExecuteElementwiseOp(arg.out_element_op_, + arg.elementwise_d_tensors_, + Number{}, + v_out, + v_acc_converted, + g, + n, + k, + wo); }; make_ParallelTensorFunctor(func, @@ -191,44 +215,47 @@ struct ReferenceConvFwd : public 
device::BaseOperator wi >= 0 && ck::type_convert(wi) < arg.input_.GetLengths()[4]) { - float v_in; - float v_wei; + InDataType v_in; + WeiDataType v_wei; - arg.in_element_op_( - v_in, ck::type_convert(arg.input_(g, n, c, hi, wi))); - - arg.wei_element_op_( - v_wei, ck::type_convert(arg.weight_(g, k, c, y, x))); - - v_acc += v_in * v_wei; + ExecuteElementwiseOp(arg.in_element_op_, + arg.elementwise_a_tensors_, + Number{}, + v_in, + arg.input_(g, n, c, hi, wi), + g, + n, + c, + hi, + wi); + ExecuteElementwiseOp(arg.wei_element_op_, + arg.elementwise_b_tensors_, + Number{}, + v_wei, + arg.weight_(g, k, c, y, x), + g, + k, + c, + y, + x); + v_acc += ck::type_convert(v_in) * + ck::type_convert(v_wei); } } } } - - OutDataType v_out; OutDataType v_acc_converted = ck::type_convert(v_acc); - if constexpr(NumDTensor == 0) - { - arg.out_element_op_(v_out, v_acc_converted); - } - else if constexpr(NumDTensor == 1) - { - arg.out_element_op_( - v_out, v_acc_converted, arg.d_tensors_[0](g, n, k, ho, wo)); - } - else if constexpr(NumDTensor == 2) - { - arg.out_element_op_(v_out, - v_acc_converted, - arg.d_tensors_[0](g, n, k, ho, wo), - arg.d_tensors_[1](g, n, k, ho, wo)); - } - else - { - throw std::runtime_error("Output ElementOp not supported in reference."); - } - arg.output_(g, n, k, ho, wo) = v_out; + OutDataType& v_out = arg.output_(g, n, k, ho, wo); + ExecuteElementwiseOp(arg.out_element_op_, + arg.elementwise_d_tensors_, + Number{}, + v_out, + v_acc_converted, + g, + n, + k, + ho, + wo); }; make_ParallelTensorFunctor(func, @@ -275,47 +302,51 @@ struct ReferenceConvFwd : public device::BaseOperator ck::type_convert(wi) < arg.input_.GetLengths()[5]) { - float v_in; - float v_wei; + InDataType v_in; + WeiDataType v_wei; - arg.in_element_op_(v_in, - ck::type_convert( - arg.input_(g, n, c, di, hi, wi))); - - arg.wei_element_op_( - v_wei, - ck::type_convert(arg.weight_(g, k, c, z, y, x))); - - v_acc += v_in * v_wei; + ExecuteElementwiseOp(arg.in_element_op_, + arg.elementwise_a_tensors_, + Number{}, + v_in, + arg.input_(g, n, c, di, hi, wi), + g, + n, + c, + di, + hi, + wi); + ExecuteElementwiseOp(arg.wei_element_op_, + arg.elementwise_b_tensors_, + Number{}, + v_wei, + arg.weight_(g, k, c, z, y, x), + g, + k, + c, + z, + y, + x); + v_acc += ck::type_convert(v_in) * + ck::type_convert(v_wei); } } } } } - - OutDataType v_out; OutDataType v_acc_converted = ck::type_convert(v_acc); - if constexpr(NumDTensor == 0) - { - arg.out_element_op_(v_out, v_acc_converted); - } - else if constexpr(NumDTensor == 1) - { - arg.out_element_op_( - v_out, v_acc_converted, arg.d_tensors_[0](g, n, k, d_o, ho, wo)); - } - else if constexpr(NumDTensor == 2) - { - arg.out_element_op_(v_out, - v_acc_converted, - arg.d_tensors_[0](g, n, k, d_o, ho, wo), - arg.d_tensors_[1](g, n, k, d_o, ho, wo)); - } - else - { - throw std::runtime_error("Output ElementOp not supported in reference."); - } - arg.output_(g, n, k, d_o, ho, wo) = v_out; + OutDataType& v_out = arg.output_(g, n, k, d_o, ho, wo); + ExecuteElementwiseOp(arg.out_element_op_, + arg.elementwise_d_tensors_, + Number{}, + v_out, + v_acc_converted, + g, + n, + k, + d_o, + ho, + wo); }; make_ParallelTensorFunctor(func, @@ -338,6 +369,36 @@ struct ReferenceConvFwd : public device::BaseOperator } }; + template + static void ExecuteElementwiseOp(ElementwiseOp& elementwise_op, + ElementwiseTensor& elementwise_tensors, + NumTensor, + T& y, + const T& x, + Args... 
dims) + { + if constexpr(NumTensor::value == 0) + { + elementwise_op(y, x); + } + else if constexpr(NumTensor::value == 1) + { + elementwise_op(y, x, elementwise_tensors[0](dims...)); + } + else if constexpr(NumTensor::value == 2) + { + elementwise_op(y, x, elementwise_tensors[0](dims...), elementwise_tensors[1](dims...)); + } + else + { + throw std::runtime_error("ElementOp not supported in reference."); + } + } + static constexpr bool IsValidCompilationParameter() { // TODO: properly implement this check @@ -349,17 +410,20 @@ struct ReferenceConvFwd : public device::BaseOperator return NDimSpatial >= 1 && NDimSpatial <= 3; } - static auto MakeArgument(const Tensor& input, - const Tensor& weight, - Tensor& output, - std::vector conv_filter_strides, - std::vector conv_filter_dilations, - std::vector input_left_pads, - std::vector input_right_pads, - InElementwiseOperation in_element_op, - WeiElementwiseOperation wei_element_op, - OutElementwiseOperation out_element_op, - const std::array, NumDTensor>& d_tensors = {}) + static auto MakeArgument( + const Tensor& input, + const Tensor& weight, + Tensor& output, + std::vector conv_filter_strides, + std::vector conv_filter_dilations, + std::vector input_left_pads, + std::vector input_right_pads, + InElementwiseOperation in_element_op, + WeiElementwiseOperation wei_element_op, + OutElementwiseOperation out_element_op, + const std::array, NumAElementwiseTensor>& elementwise_a_tensors = {}, + const std::array, NumBElementwiseTensor>& elementwise_b_tensors = {}, + const std::array, NumDElementwiseTensor>& elementwise_d_tensors = {}) { return Argument{input, weight, @@ -371,7 +435,9 @@ struct ReferenceConvFwd : public device::BaseOperator in_element_op, wei_element_op, out_element_op, - d_tensors}; + elementwise_a_tensors, + elementwise_b_tensors, + elementwise_d_tensors}; } static auto MakeInvoker() { return Invoker{}; } diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_scaleadd_ab_instance.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_scaleadd_ab_instance.hpp new file mode 100644 index 0000000000..3ab1d73d01 --- /dev/null +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_scaleadd_ab_instance.hpp @@ -0,0 +1,124 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_xdl_cshuffle.hpp" +#include "ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using BF16 = ck::bhalf_t; +using F16 = ck::half_t; +using F32 = float; + +template +using S = ck::Sequence; + +using namespace ck::tensor_layout::convolution; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using ScaleAdd = ck::tensor_operation::element_wise::ScaleAdd; + +static constexpr auto ConvFwdDefault = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Default; +static constexpr auto ConvFwd1x1P0 = ConvolutionForwardSpecialization::Filter1x1Pad0; +static constexpr auto ConvFwd1x1S1P0 = ConvolutionForwardSpecialization::Filter1x1Stride1Pad0; +static constexpr auto ConvFwdOddC = + ck::tensor_operation::device::ConvolutionForwardSpecialization::OddC; + +static constexpr auto GemmMNKPadding = GemmSpecialization::MNKPadding; + +template +using device_grouped_conv_fwd_xdl_scaleadd_ab_bf16_instances = std::tuple< + // clang-format off + //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + // generic instance + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle,ELayout, ck::Tuple, ck::Tuple, F32, BF16, ck::Tuple<>, BF16, ScaleAdd, ScaleAdd, PassThrough, ConvSpec, GemmMNKPadding, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 16, 1, 4>, 1>, + // instances for small conv.K and conv.C + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle,ELayout, ck::Tuple, ck::Tuple, F32, BF16, ck::Tuple<>, BF16, ScaleAdd, ScaleAdd, PassThrough, ConvSpec, GemmMNKPadding, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, 
S<1, 16, 1, 4>, 1>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle,ELayout, ck::Tuple, ck::Tuple, F32, BF16, ck::Tuple<>, BF16, ScaleAdd, ScaleAdd, PassThrough, ConvSpec, GemmMNKPadding, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 32, 1, 8>, 8>, + + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle,ELayout, ck::Tuple, ck::Tuple, F32, BF16, ck::Tuple<>, BF16, ScaleAdd, ScaleAdd, PassThrough, ConvSpec, GemmMNKPadding, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8> + // clang-format on + >; + +template +using device_grouped_conv_fwd_xdl_scaleadd_ab_f16_instances = std::tuple< + // clang-format off + //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + // generic instance + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle,ELayout, ck::Tuple, ck::Tuple, F32, F16, ck::Tuple<>, F16, ScaleAdd, ScaleAdd, PassThrough, ConvSpec, GemmMNKPadding, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 16, 1, 4>, 1>, + // instances for small conv.K and conv.C + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle,ELayout, ck::Tuple, ck::Tuple, F32, F16, ck::Tuple<>, F16, ScaleAdd, ScaleAdd, PassThrough, ConvSpec, GemmMNKPadding, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 1>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle,ELayout, ck::Tuple, ck::Tuple, F32, F16, ck::Tuple<>, F16, ScaleAdd, ScaleAdd, PassThrough, ConvSpec, GemmMNKPadding, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 32, 1, 8>, 8>, + + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle,ELayout, ck::Tuple, ck::Tuple, F32, F16, ck::Tuple<>, F16, ScaleAdd, ScaleAdd, PassThrough, ConvSpec, GemmMNKPadding, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 
0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8> + // clang-format on + >; + +template +using device_grouped_conv_fwd_xdl_scaleadd_ab_f32_instances = std::tuple< + // clang-format off + //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + // generic instance + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle,ELayout, ck::Tuple, ck::Tuple, F32, F32, ck::Tuple<>, F32, ScaleAdd, ScaleAdd, PassThrough, ConvSpec, GemmMNKPadding, 1, 64, 64, 64, 16, 4, 4, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 8, 1, 8>, 1>, + // instances for small conv.K and conv.C + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle,ELayout, ck::Tuple, ck::Tuple, F32, F32, ck::Tuple<>, F32, ScaleAdd, ScaleAdd, PassThrough, ConvSpec, GemmMNKPadding, 1, 64, 64, 32, 16, 4, 4, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 8>, 1>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle,ELayout, ck::Tuple, ck::Tuple, F32, F32, ck::Tuple<>, F32, ScaleAdd, ScaleAdd, PassThrough, ConvSpec, GemmMNKPadding, 1, 256, 128, 128, 16, 4, 4, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 16, 1, 16>, 4>, + + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle,ELayout, ck::Tuple, ck::Tuple, F32, F32, ck::Tuple<>, F32, ScaleAdd, ScaleAdd, PassThrough, ConvSpec, GemmMNKPadding, 1, 256, 256, 128, 16, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4> + // clang-format on + >; + +template +using device_grouped_conv_fwd_xdl_scaleadd_ab_int8_instances = std::tuple< + // clang-format off + //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| 
BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + // generic instance + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle,ELayout, ck::Tuple, ck::Tuple, int32_t, int8_t, ck::Tuple<>, int8_t, ScaleAdd, ScaleAdd, PassThrough, ConvSpec, GemmMNKPadding, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 16, 1, 4>, 1>, + // instances for small conv.K and conv.C + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle,ELayout, ck::Tuple, ck::Tuple, int32_t, int8_t, ck::Tuple<>, int8_t, ScaleAdd, ScaleAdd, PassThrough, ConvSpec, GemmMNKPadding, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 1>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle,ELayout, ck::Tuple, ck::Tuple, int32_t, int8_t, ck::Tuple<>, int8_t, ScaleAdd, ScaleAdd, PassThrough, ConvSpec, GemmMNKPadding, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 32, 1, 8>, 8>, + + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle,ELayout, ck::Tuple, ck::Tuple, int32_t, int8_t, ck::Tuple<>, int8_t, ScaleAdd, ScaleAdd, PassThrough, ConvSpec, GemmMNKPadding, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8> + // clang-format on + >; + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_scaleadd_ab.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_scaleadd_ab.hpp new file mode 100644 index 0000000000..4bcf1f08a1 --- /dev/null +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_scaleadd_ab.hpp @@ -0,0 +1,179 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. 
+ +#pragma once + +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_d.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using ScaleAdd = ck::tensor_operation::element_wise::ScaleAdd; + +#ifdef CK_ENABLE_BF16 +// grouped conv3d forward multi AB scaleadd, NDHWGC/GKZYXC/NDHWGK +void add_device_grouped_conv3d_fwd_xdl_scaleadd_ab_ndhwgc_gkzyxc_ndhwgk_bf16_instances( + std::vector, + NDHWGK, + ck::Tuple, + ck::Tuple, + ck::Tuple<>, + BF16, + ScaleAdd, + ScaleAdd, + PassThrough>>>& instances); +#endif + +#ifdef CK_ENABLE_FP16 +void add_device_grouped_conv3d_fwd_xdl_scaleadd_ab_ndhwgc_gkzyxc_ndhwgk_f16_instances( + std::vector, + NDHWGK, + ck::Tuple, + ck::Tuple, + ck::Tuple<>, + F16, + ScaleAdd, + ScaleAdd, + PassThrough>>>& instances); +#endif + +#ifdef CK_ENABLE_FP32 +void add_device_grouped_conv3d_fwd_xdl_scaleadd_ab_ndhwgc_gkzyxc_ndhwgk_f32_instances( + std::vector, + NDHWGK, + ck::Tuple, + ck::Tuple, + ck::Tuple<>, + F32, + ScaleAdd, + ScaleAdd, + PassThrough>>>& instances); +#endif + +#ifdef CK_ENABLE_INT8 +void add_device_grouped_conv3d_fwd_xdl_scaleadd_ab_ndhwgc_gkzyxc_ndhwgk_int8_instances( + std::vector, + NDHWGK, + ck::Tuple, + ck::Tuple, + ck::Tuple<>, + int8_t, + ScaleAdd, + ScaleAdd, + PassThrough>>>& instances); +#endif + +template +struct DeviceOperationInstanceFactory> +{ + using DeviceOp = DeviceGroupedConvFwdMultipleD; + + static auto GetInstances() + { + std::vector> op_ptrs; + if constexpr(NumDimSpatial == 3 && is_same_v && + is_same_v && is_same_v) + { +#ifdef CK_ENABLE_FP32 + if constexpr(is_same_v> && + is_same_v> && + is_same_v && is_same_v) + { + add_device_grouped_conv3d_fwd_xdl_scaleadd_ab_ndhwgc_gkzyxc_ndhwgk_f32_instances( + op_ptrs); + } +#endif +#ifdef CK_ENABLE_FP16 + if constexpr(is_same_v> && + is_same_v> && + is_same_v && is_same_v) + { + add_device_grouped_conv3d_fwd_xdl_scaleadd_ab_ndhwgc_gkzyxc_ndhwgk_f16_instances( + op_ptrs); + } +#endif +#ifdef CK_ENABLE_BF16 + if constexpr(is_same_v> && + is_same_v> && + is_same_v && is_same_v) + { + add_device_grouped_conv3d_fwd_xdl_scaleadd_ab_ndhwgc_gkzyxc_ndhwgk_bf16_instances( + op_ptrs); + } +#endif +#ifdef CK_ENABLE_INT8 + if constexpr(is_same_v> && + is_same_v> && + is_same_v && is_same_v) + { + add_device_grouped_conv3d_fwd_xdl_scaleadd_ab_ndhwgc_gkzyxc_ndhwgk_int8_instances( + op_ptrs); + } +#endif + } + + return op_ptrs; + } +}; + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_scaleadd_ab/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_scaleadd_ab/CMakeLists.txt new file mode 100644 index 0000000000..08fb23afc9 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_scaleadd_ab/CMakeLists.txt @@ -0,0 +1,7 @@ +set(GROUPED_CONV3D_FWD_SCALEADD_AB + xdl/device_grouped_conv3d_fwd_xdl_scaleadd_ab_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp + xdl/device_grouped_conv3d_fwd_xdl_scaleadd_ab_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp + xdl/device_grouped_conv3d_fwd_xdl_scaleadd_ab_ndhwgc_gkzyxc_ndhwgk_f32_instance.cpp + 
xdl/device_grouped_conv3d_fwd_xdl_scaleadd_ab_ndhwgc_gkzyxc_ndhwgk_int8_instance.cpp) + +add_instance_library(device_grouped_conv3d_fwd_scaleadd_ab_instance ${GROUPED_CONV3D_FWD_SCALEADD_AB}) diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_scaleadd_ab/xdl/device_grouped_conv3d_fwd_xdl_scaleadd_ab_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_scaleadd_ab/xdl/device_grouped_conv3d_fwd_xdl_scaleadd_ab_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp new file mode 100644 index 0000000000..797ac452d6 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_scaleadd_ab/xdl/device_grouped_conv3d_fwd_xdl_scaleadd_ab_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp @@ -0,0 +1,52 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_scaleadd_ab_instance.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_conv3d_fwd_xdl_scaleadd_ab_ndhwgc_gkzyxc_ndhwgk_bf16_instances( + std::vector, + NDHWGK, + ck::Tuple, + ck::Tuple, + ck::Tuple<>, + BF16, + ScaleAdd, + ScaleAdd, + PassThrough>>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_scaleadd_ab_bf16_instances<3, + NDHWGC, + GKZYXC, + NDHWGK, + ConvFwdDefault>{}); + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_scaleadd_ab_bf16_instances<3, + NDHWGC, + GKZYXC, + NDHWGK, + ConvFwd1x1P0>{}); + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_scaleadd_ab_bf16_instances<3, + NDHWGC, + GKZYXC, + NDHWGK, + ConvFwd1x1S1P0>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_scaleadd_ab/xdl/device_grouped_conv3d_fwd_xdl_scaleadd_ab_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_scaleadd_ab/xdl/device_grouped_conv3d_fwd_xdl_scaleadd_ab_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp new file mode 100644 index 0000000000..27942b5fe6 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_scaleadd_ab/xdl/device_grouped_conv3d_fwd_xdl_scaleadd_ab_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp @@ -0,0 +1,52 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_scaleadd_ab_instance.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_conv3d_fwd_xdl_scaleadd_ab_ndhwgc_gkzyxc_ndhwgk_f16_instances( + std::vector, + NDHWGK, + ck::Tuple, + ck::Tuple, + ck::Tuple<>, + F16, + ScaleAdd, + ScaleAdd, + PassThrough>>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_scaleadd_ab_f16_instances<3, + NDHWGC, + GKZYXC, + NDHWGK, + ConvFwdDefault>{}); + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_scaleadd_ab_f16_instances<3, + NDHWGC, + GKZYXC, + NDHWGK, + ConvFwd1x1P0>{}); + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_scaleadd_ab_f16_instances<3, + NDHWGC, + GKZYXC, + NDHWGK, + ConvFwd1x1S1P0>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_scaleadd_ab/xdl/device_grouped_conv3d_fwd_xdl_scaleadd_ab_ndhwgc_gkzyxc_ndhwgk_f32_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_scaleadd_ab/xdl/device_grouped_conv3d_fwd_xdl_scaleadd_ab_ndhwgc_gkzyxc_ndhwgk_f32_instance.cpp new file mode 100644 index 0000000000..4527f57fb4 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_scaleadd_ab/xdl/device_grouped_conv3d_fwd_xdl_scaleadd_ab_ndhwgc_gkzyxc_ndhwgk_f32_instance.cpp @@ -0,0 +1,52 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_scaleadd_ab_instance.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_conv3d_fwd_xdl_scaleadd_ab_ndhwgc_gkzyxc_ndhwgk_f32_instances( + std::vector, + NDHWGK, + ck::Tuple, + ck::Tuple, + ck::Tuple<>, + F32, + ScaleAdd, + ScaleAdd, + PassThrough>>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_scaleadd_ab_f32_instances<3, + NDHWGC, + GKZYXC, + NDHWGK, + ConvFwdDefault>{}); + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_scaleadd_ab_f32_instances<3, + NDHWGC, + GKZYXC, + NDHWGK, + ConvFwd1x1P0>{}); + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_scaleadd_ab_f32_instances<3, + NDHWGC, + GKZYXC, + NDHWGK, + ConvFwd1x1S1P0>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_scaleadd_ab/xdl/device_grouped_conv3d_fwd_xdl_scaleadd_ab_ndhwgc_gkzyxc_ndhwgk_int8_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_scaleadd_ab/xdl/device_grouped_conv3d_fwd_xdl_scaleadd_ab_ndhwgc_gkzyxc_ndhwgk_int8_instance.cpp new file mode 100644 index 0000000000..3a0f041b16 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_scaleadd_ab/xdl/device_grouped_conv3d_fwd_xdl_scaleadd_ab_ndhwgc_gkzyxc_ndhwgk_int8_instance.cpp @@ -0,0 +1,51 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_scaleadd_ab_instance.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { +void add_device_grouped_conv3d_fwd_xdl_scaleadd_ab_ndhwgc_gkzyxc_ndhwgk_int8_instances( + std::vector, + NDHWGK, + ck::Tuple, + ck::Tuple, + ck::Tuple<>, + int8_t, + ScaleAdd, + ScaleAdd, + PassThrough>>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_scaleadd_ab_int8_instances<3, + NDHWGC, + GKZYXC, + NDHWGK, + ConvFwdDefault>{}); + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_scaleadd_ab_int8_instances<3, + NDHWGC, + GKZYXC, + NDHWGK, + ConvFwd1x1P0>{}); + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_scaleadd_ab_int8_instances<3, + NDHWGC, + GKZYXC, + NDHWGK, + ConvFwd1x1S1P0>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/test/grouped_convnd_fwd/CMakeLists.txt b/test/grouped_convnd_fwd/CMakeLists.txt index 869bd77d3d..0b8356555b 100644 --- a/test/grouped_convnd_fwd/CMakeLists.txt +++ b/test/grouped_convnd_fwd/CMakeLists.txt @@ -1,3 +1,5 @@ add_gtest_executable(test_grouped_convnd_fwd test_grouped_convnd_fwd.cpp) target_link_libraries(test_grouped_convnd_fwd PRIVATE utility device_grouped_conv1d_fwd_instance device_grouped_conv2d_fwd_instance device_grouped_conv3d_fwd_instance) +add_gtest_executable(test_grouped_convnd_fwd_multi_ab_interface test_grouped_convnd_fwd_multi_ab_interface.cpp) +target_link_libraries(test_grouped_convnd_fwd_multi_ab_interface PRIVATE utility) diff --git a/test/grouped_convnd_fwd/test_grouped_convnd_fwd_multi_ab_interface.cpp b/test/grouped_convnd_fwd/test_grouped_convnd_fwd_multi_ab_interface.cpp new file mode 100644 index 0000000000..de151ea644 --- /dev/null +++ b/test/grouped_convnd_fwd/test_grouped_convnd_fwd_multi_ab_interface.cpp @@ -0,0 +1,235 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_xdl_cshuffle.hpp" + +#include "ck/host_utility/device_prop.hpp" + +#include "ck/library/utility/convolution_parameter.hpp" +#include "ck/library/utility/algorithm.hpp" +#include "ck/library/utility/convolution_host_tensor_descriptor_helper.hpp" + +#include + +template +using S = ck::Sequence; + +using ScaleAdd = ck::tensor_operation::element_wise::ScaleAdd; +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +template +class TestGroupedConvndFwdMultiABInterfaceBase : public ::testing::Test +{ + protected: + static constexpr ck::index_t NDimSpatial = 3; + static constexpr ck::index_t NumAs = 2; + static constexpr ck::index_t NumBs = 2; + static constexpr auto ConvSpec = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Default; + static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + using InLayout = ck::tensor_layout::convolution::GNDHWC; + using WeiLayout = ck::tensor_layout::convolution::GKZYXC; + using OutLayout = ck::tensor_layout::convolution::GNDHWK; + using OutElementOp = PassThrough; + + using DeviceGroupedConvNDMultiABFwdInstance = + ck::tensor_operation::device::DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< + NDimSpatial, + InLayout, + WeiLayout, + ck::Tuple<>, + OutLayout, + InDataTypes, + WeiDataTypes, + DataType, + DataType, + ck::Tuple<>, + DataType, + InElementOp, + WeiElementOp, + OutElementOp, + ConvSpec, // ConvForwardSpecialization + GemmSpec, // GemmSpecialization + 1, // + 256, // BlockSize + 128, // MPerBlock + 256, // NPerBlock + 32, // KPerBlock + 8, // AK1 + 8, // BK1 + 32, // MPerXdl + 32, // NPerXdl + 2, // MXdlPerWave + 4, // NXdlPerWave + S<4, 64, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1 + S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // ABlockTransferSrcAccessOrder + 2, // ABlockTransferSrcVectorDim + 8, // ABlockTransferSrcScalarPerVector + 8, // ABlockTransferDstScalarPerVector_AK1 + 1, // ABlockLdsExtraM + S<4, 64, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1 + S<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // BBlockTransferSrcAccessOrder + 2, // BBlockTransferSrcVectorDim + 8, // BBlockTransferSrcScalarPerVector + 8, // BBlockTransferDstScalarPerVector_BK1 + 1, // BBlockLdsExtraN + 1, + 1, + S<1, 32, 1, 8>, + 8>; + + const ck::utils::conv::ConvParam conv_param{ + 3, 1, 16, 16, 8, {3, 3, 3}, {17, 17, 17}, {2, 2, 2}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}}; + + void SetUp() override + { + if(!ck::is_xdl_supported()) + { + GTEST_SKIP(); + } + } + + template + bool Run(ADataType as, BDataType bs) + { + const auto in_g_n_c_wis_desc = + ck::utils::conv::make_input_host_tensor_descriptor_g_n_c_wis_packed( + conv_param); + const auto wei_g_k_c_xs_desc = + ck::utils::conv::make_weight_host_tensor_descriptor_g_k_c_xs_packed( + conv_param); + const auto out_g_n_k_wos_desc = + ck::utils::conv::make_output_host_tensor_descriptor_g_n_k_wos_packed( + conv_param); + + std::array a_g_n_c_wis_lengths{}; + std::array a_g_n_c_wis_strides{}; + std::array b_g_k_c_xs_lengths{}; + std::array b_g_k_c_xs_strides{}; + std::array e_g_n_k_wos_lengths{}; + std::array e_g_n_k_wos_strides{}; + std::array conv_filter_strides{}; + std::array conv_filter_dilations{}; + 
std::array input_left_pads{}; + std::array input_right_pads{}; + + auto copy = [](const auto& x, auto& y) { ck::ranges::copy(x, y.begin()); }; + + copy(in_g_n_c_wis_desc.GetLengths(), a_g_n_c_wis_lengths); + copy(in_g_n_c_wis_desc.GetStrides(), a_g_n_c_wis_strides); + copy(wei_g_k_c_xs_desc.GetLengths(), b_g_k_c_xs_lengths); + copy(wei_g_k_c_xs_desc.GetStrides(), b_g_k_c_xs_strides); + copy(out_g_n_k_wos_desc.GetLengths(), e_g_n_k_wos_lengths); + copy(out_g_n_k_wos_desc.GetStrides(), e_g_n_k_wos_strides); + copy(conv_param.conv_filter_strides_, conv_filter_strides); + copy(conv_param.conv_filter_dilations_, conv_filter_dilations); + copy(conv_param.input_left_pads_, input_left_pads); + copy(conv_param.input_right_pads_, input_right_pads); + + std::array ds{}; + + // do Conv + auto conv = DeviceGroupedConvNDMultiABFwdInstance{}; + auto invoker = conv.MakeInvoker(); + auto argument = conv.MakeArgument(as, + bs, + ds, + nullptr, + a_g_n_c_wis_lengths, + a_g_n_c_wis_strides, + b_g_k_c_xs_lengths, + b_g_k_c_xs_strides, + {}, + {}, + e_g_n_k_wos_lengths, + e_g_n_k_wos_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + InElementOp{}, + WeiElementOp{}, + OutElementOp{}); + + return conv.IsSupportedArgument(argument); + } +}; + +class TestGroupedConvndFwdMultiAInterface + : public TestGroupedConvndFwdMultiABInterfaceBase, + float, + ScaleAdd, + PassThrough> +{ +}; + +class TestGroupedConvndFwdMultiBInterface + : public TestGroupedConvndFwdMultiABInterfaceBase, + PassThrough, + ScaleAdd> +{ +}; + +class TestGroupedConvndFwdMultiABInterface + : public TestGroupedConvndFwdMultiABInterfaceBase, + ck::Tuple, + ScaleAdd, + ScaleAdd> +{ +}; + +class TestGroupedConvndFwdInterface + : public TestGroupedConvndFwdMultiABInterfaceBase +{ +}; + +TEST_F(TestGroupedConvndFwdMultiAInterface, MultiA) +{ + std::array as{nullptr, nullptr}; + const void* b = nullptr; + + EXPECT_TRUE(this->template Run(as, b)); +} + +TEST_F(TestGroupedConvndFwdMultiBInterface, MultiB) +{ + const void* a = nullptr; + std::array bs{nullptr, nullptr}; + + EXPECT_TRUE(this->template Run(a, bs)); +} + +TEST_F(TestGroupedConvndFwdMultiABInterface, MultiAB) +{ + std::array as{nullptr, nullptr}; + std::array bs{nullptr, nullptr}; + + EXPECT_TRUE(this->template Run(as, bs)); +} + +TEST_F(TestGroupedConvndFwdInterface, SingleAB) +{ + const void* a = nullptr; + const void* b = nullptr; + + EXPECT_TRUE(this->template Run(a, b)); +}
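As a closing illustration, here is a hedged sketch of how a client would enumerate the new ScaleAdd multi-AB instances through the DeviceOperationInstanceFactory declared earlier. The ck::Tuple element types (F16, F16) are an assumption, since the declaration's template argument lists were lost in extraction; real code would also allocate device buffers and build the argument as in the interface test above.

#include <iostream>

#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_scaleadd_ab.hpp"

using namespace ck::tensor_operation::device;
using namespace ck::tensor_layout::convolution;

using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using ScaleAdd    = ck::tensor_operation::element_wise::ScaleAdd;
using F16         = ck::half_t;

int main()
{
    // Grouped 3D forward conv, NDHWGC/GKZYXC/NDHWGK, two f16 A tensors and two
    // f16 B tensors combined with ScaleAdd, no D tensors.
    using DeviceOp = DeviceGroupedConvFwdMultipleD<3,
                                                   NDHWGC,
                                                   GKZYXC,
                                                   ck::Tuple<>,
                                                   NDHWGK,
                                                   ck::Tuple<F16, F16>,
                                                   ck::Tuple<F16, F16>,
                                                   ck::Tuple<>,
                                                   F16,
                                                   ScaleAdd,
                                                   ScaleAdd,
                                                   PassThrough>;

    const auto op_ptrs = instance::DeviceOperationInstanceFactory<DeviceOp>::GetInstances();
    std::cout << "found " << op_ptrs.size() << " scaleadd_ab instances\n";
    // Each op_ptr can then build an Argument (pointers passed as std::array on
    // the tuple-typed A/B sides) and be filtered with IsSupportedArgument(...).
    return 0;
}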