From c16cff1498c505f8c2f6903d004af2bd07f9dfa1 Mon Sep 17 00:00:00 2001 From: Rostyslav Geyyer <46627076+geyyer@users.noreply.github.com> Date: Tue, 21 May 2024 09:52:41 -0500 Subject: [PATCH] Move grouped conv fwd client examples (#1299) * Move grouped conv fwd client examples * Update existing examples * Format [ROCm/composable_kernel commit: 204da9c522cebec5220bba52cd3542ebcaf99e7a] --- .../07_grouped_convnd_fwd/CMakeLists.txt | 20 +- .../07_grouped_convnd_fwd/common.hpp | 304 ++++++++++++++++++ .../grouped_conv1d_fwd.cpp | 212 +----------- .../grouped_conv2d_fwd.cpp | 180 +---------- .../grouped_conv3d_fwd_bf8.cpp} | 0 .../grouped_conv3d_fwd_bf8_fp8.cpp} | 0 .../grouped_conv3d_fwd_fp8.cpp} | 0 .../grouped_conv3d_fwd_fp8_bf8.cpp} | 0 client_example/16_convnd_fwd/CMakeLists.txt | 16 - 9 files changed, 345 insertions(+), 387 deletions(-) create mode 100644 client_example/07_grouped_convnd_fwd/common.hpp rename client_example/{16_convnd_fwd/conv3d_fwd_bf8.cpp => 07_grouped_convnd_fwd/grouped_conv3d_fwd_bf8.cpp} (100%) rename client_example/{16_convnd_fwd/conv3d_fwd_bf8_fp8.cpp => 07_grouped_convnd_fwd/grouped_conv3d_fwd_bf8_fp8.cpp} (100%) rename client_example/{16_convnd_fwd/conv3d_fwd_fp8.cpp => 07_grouped_convnd_fwd/grouped_conv3d_fwd_fp8.cpp} (100%) rename client_example/{16_convnd_fwd/conv3d_fwd_fp8_bf8.cpp => 07_grouped_convnd_fwd/grouped_conv3d_fwd_fp8_bf8.cpp} (100%) diff --git a/client_example/07_grouped_convnd_fwd/CMakeLists.txt b/client_example/07_grouped_convnd_fwd/CMakeLists.txt index 710eca9f49..e8c046ff44 100644 --- a/client_example/07_grouped_convnd_fwd/CMakeLists.txt +++ b/client_example/07_grouped_convnd_fwd/CMakeLists.txt @@ -4,4 +4,22 @@ if(GPU_TARGETS MATCHES "gfx9") add_executable(client_grouped_conv1d_fwd grouped_conv1d_fwd.cpp) target_link_libraries(client_grouped_conv1d_fwd PRIVATE composable_kernel::device_conv_operations) -endif() \ No newline at end of file + + if((DTYPES MATCHES "fp8") OR NOT DEFINED DTYPES) + add_executable(client_grouped_conv3d_fwd_fp8 grouped_conv3d_fwd_fp8.cpp) + target_link_libraries(client_grouped_conv3d_fwd_fp8 PRIVATE composable_kernel::device_conv_operations) + endif() + + if((DTYPES MATCHES "bf8") OR NOT DEFINED DTYPES) + add_executable(client_grouped_conv3d_fwd_bf8 grouped_conv3d_fwd_bf8.cpp) + target_link_libraries(client_grouped_conv3d_fwd_bf8 PRIVATE composable_kernel::device_conv_operations) + endif() + + if((DTYPES MATCHES "fp8" AND DTYPES MATCHES "bf8") OR NOT DEFINED DTYPES) + add_executable(client_grouped_conv3d_fwd_fp8_bf8 grouped_conv3d_fwd_fp8_bf8.cpp) + target_link_libraries(client_grouped_conv3d_fwd_fp8_bf8 PRIVATE composable_kernel::device_conv_operations) + + add_executable(client_grouped_conv3d_fwd_bf8_fp8 grouped_conv3d_fwd_bf8_fp8.cpp) + target_link_libraries(client_grouped_conv3d_fwd_bf8_fp8 PRIVATE composable_kernel::device_conv_operations) + endif() +endif() diff --git a/client_example/07_grouped_convnd_fwd/common.hpp b/client_example/07_grouped_convnd_fwd/common.hpp new file mode 100644 index 0000000000..729af0b88b --- /dev/null +++ b/client_example/07_grouped_convnd_fwd/common.hpp @@ -0,0 +1,304 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_convolution_forward.hpp" +#include "ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_abd.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +struct SimpleDeviceMem +{ + SimpleDeviceMem() = delete; + + SimpleDeviceMem(std::size_t mem_size) : p_mem_{} + { + (void)hipMalloc(static_cast(&p_mem_), mem_size); + } + + void* GetDeviceBuffer() { return p_mem_; } + + ~SimpleDeviceMem() { (void)hipFree(p_mem_); } + + void* p_mem_; +}; + +template +std::size_t +GetFlops(const std::array& output_lengths, + const std::array& weights_lengths) +{ + // 2 * G * N * K * C * * + ck::index_t G = weights_lengths[0]; + ck::index_t N = output_lengths[1]; + ck::index_t K = weights_lengths[1]; + ck::index_t C = weights_lengths[2]; + + return static_cast(2) * G * N * K * C * + std::accumulate(std::next(std::begin(output_lengths), NumNonSpatialDim), + std::end(output_lengths), + static_cast(1), + std::multiplies<>()) * + std::accumulate(std::next(std::begin(weights_lengths), NumNonSpatialDim), + std::end(weights_lengths), + static_cast(1), + std::multiplies<>()); +} + +template +std::size_t +GetInputByte(const std::array& input_lengths) +{ + // sizeof(InDataType) * (G * N * C * ) + + return sizeof(InDataType) * std::accumulate(std::begin(input_lengths), + std::end(input_lengths), + static_cast(1), + std::multiplies<>()); +} + +template +std::size_t +GetWeightByte(const std::array& weights_lengths) +{ + // sizeof(WeiDataType) * (G * K * C * ) + + return sizeof(WeiDataType) * std::accumulate(std::begin(weights_lengths), + std::end(weights_lengths), + static_cast(1), + std::multiplies<>()); +} + +template +std::size_t +GetOutputByte(const std::array& output_lengths) +{ + // sizeof(OutDataType) * (G * N * K * ); + return sizeof(OutDataType) * std::accumulate(std::begin(output_lengths), + std::end(output_lengths), + static_cast(1), + std::multiplies()); +} + +template +bool run_grouped_conv_fwd(std::array in_lengths, + std::array wei_lengths, + std::array out_lengths) +{ + std::size_t in_mem_size = GetInputByte(in_lengths); + std::size_t wei_mem_size = GetWeightByte(wei_lengths); + std::size_t out_mem_size = GetOutputByte(out_lengths); + + SimpleDeviceMem in(in_mem_size); + SimpleDeviceMem wei(wei_mem_size); + SimpleDeviceMem out(out_mem_size); + + std::array in_strides; + std::array wei_strides; + std::array out_strides; + in_strides.fill(0); + wei_strides.fill(0); + out_strides.fill(0); + in_strides.back() = 1; + wei_strides.back() = 1; + out_strides.back() = 1; + + std::partial_sum(rbegin(in_lengths), + std::prev(rend(in_lengths)), + std::next(rbegin(in_strides)), + std::multiplies<>{}); + std::partial_sum(rbegin(wei_lengths), + std::prev(rend(wei_lengths)), + std::next(rbegin(wei_strides)), + std::multiplies<>{}); + std::partial_sum(rbegin(out_lengths), + std::prev(rend(out_lengths)), + std::next(rbegin(out_strides)), + std::multiplies<>{}); + + // transpose NDHWGC/KZYXGC/NDHWGK to GNDHWC/GKZYXC/GNDHWK to GNCDHW/GKCZYX/GNKDHW + std::rotate(std::next(rbegin(in_lengths)), std::next(rbegin(in_lengths), 2), rend(in_lengths)); + std::rotate(rbegin(in_lengths), + std::next(rbegin(in_lengths)), + std::next(rbegin(in_lengths), NumDimSpatial + 1)); + + std::rotate(std::next(rbegin(in_strides)), std::next(rbegin(in_strides), 2), rend(in_strides)); + std::rotate(rbegin(in_strides), + std::next(rbegin(in_strides)), + std::next(rbegin(in_strides), NumDimSpatial + 1)); + + std::rotate(rbegin(wei_lengths), + std::next(rbegin(wei_lengths)), + std::next(rbegin(wei_lengths), NumDimSpatial + 1)); + + std::rotate(rbegin(wei_strides), + std::next(rbegin(wei_strides)), + std::next(rbegin(wei_strides), NumDimSpatial + 1)); + + std::rotate( + std::next(rbegin(out_lengths)), std::next(rbegin(out_lengths), 2), rend(out_lengths)); + std::rotate(rbegin(out_lengths), + std::next(rbegin(out_lengths)), + std::next(rbegin(out_lengths), NumDimSpatial + 1)); + + std::rotate( + std::next(rbegin(out_strides)), std::next(rbegin(out_strides), 2), rend(out_strides)); + std::rotate(rbegin(out_strides), + std::next(rbegin(out_strides)), + std::next(rbegin(out_strides), NumDimSpatial + 1)); + + std::array conv_filter_strides; + std::array conv_filter_dilations; + std::array input_left_pads; + std::array input_right_pads; + conv_filter_strides.fill(1); + conv_filter_dilations.fill(1); + input_left_pads.fill(1); + input_right_pads.fill(1); + + std::size_t flop = GetFlops(out_lengths, wei_lengths); + std::size_t num_bytes = in_mem_size + wei_mem_size + out_mem_size; + + using DeviceOp = ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD, + OutLayout, + InDataType, + WeiDataType, + ck::Tuple<>, + OutDataType, + PassThrough, + PassThrough, + PassThrough, + AComputeType, + BComputeType>; + // get device op instances + const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< + DeviceOp>::GetInstances(); + + std::cout << "found " << op_ptrs.size() << " instances" << std::endl; + + std::string best_op_name; + int best_op_id = -1; + float best_avg_time = std::numeric_limits::max(); + float best_gb_per_sec = 0; + float best_tflops = 0; + + // profile device operation instances + std::cout << "Run all instances and do timing" << std::endl; + + for(int i = 0; i < op_ptrs.size(); ++i) + { + auto& op_ptr = op_ptrs[i]; + auto argument_ptr = op_ptr->MakeArgumentPointer( + in.GetDeviceBuffer(), + wei.GetDeviceBuffer(), + std::array{}, + out.GetDeviceBuffer(), + in_lengths, + in_strides, + wei_lengths, + wei_strides, + std::array, 0>{{}}, + std::array, 0>{{}}, + out_lengths, + out_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + PassThrough{}, + PassThrough{}, + PassThrough{}); + + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + std::string op_name = op_ptr->GetTypeString(); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + float avg_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, true}); + + float tflops = static_cast(flop) / 1.E9 / avg_time; + float gb_per_sec = num_bytes / 1.E6 / avg_time; + + std::cout << "Perf: " << std::setw(10) << avg_time << " ms, " << tflops << " TFlops, " + << gb_per_sec << " GB/s, " << op_name << std::endl; + + if(tflops > best_tflops) + { + best_op_id = i; + best_op_name = op_name; + best_avg_time = avg_time; + best_gb_per_sec = gb_per_sec; + best_tflops = tflops; + } + } + else + { + std::cerr << op_name << " does not support this problem" << std::endl; + } + } + + if(best_op_id < 0) + { + std::cerr << "no suitable instance" << std::endl; + return false; + } + + std::cout << "Best Perf: " << std::setw(10) << best_avg_time << " ms, " << best_tflops + << " TFlops, " << best_gb_per_sec << " GB/s, " << best_op_name << std::endl; + + // run the best intance + { + auto& op_ptr = op_ptrs[best_op_id]; + std::cout << "Run the best instance without timing: " << op_ptr->GetTypeString() + << std::endl; + auto argument_ptr = op_ptr->MakeArgumentPointer( + in.GetDeviceBuffer(), + wei.GetDeviceBuffer(), + std::array{}, + out.GetDeviceBuffer(), + in_lengths, + in_strides, + wei_lengths, + wei_strides, + std::array, 0>{{}}, + std::array, 0>{{}}, + out_lengths, + out_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + PassThrough{}, + PassThrough{}, + PassThrough{}); + + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, false}); + } + + std::cout << "Done" << std::endl; + } + return true; +} diff --git a/client_example/07_grouped_convnd_fwd/grouped_conv1d_fwd.cpp b/client_example/07_grouped_convnd_fwd/grouped_conv1d_fwd.cpp index 4983ac33c3..d3a3111e94 100644 --- a/client_example/07_grouped_convnd_fwd/grouped_conv1d_fwd.cpp +++ b/client_example/07_grouped_convnd_fwd/grouped_conv1d_fwd.cpp @@ -1,17 +1,10 @@ // SPDX-License-Identifier: MIT // Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. -#include -#include -#include -#include -#include -#include +#include "common.hpp" #include "ck/ck.hpp" -#include "ck/library/tensor_operation_instance/gpu/grouped_convolution_forward.hpp" #include "ck/tensor_operation/gpu/device/tensor_layout.hpp" -#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" using InDataType = ck::half_t; using WeiDataType = ck::half_t; @@ -31,199 +24,16 @@ static constexpr ck::index_t X = 3; static constexpr ck::index_t Wi = 28; static constexpr ck::index_t Wo = 28; -struct SimpleDeviceMem -{ - SimpleDeviceMem() = delete; - - SimpleDeviceMem(std::size_t mem_size) : p_mem_{} - { - (void)hipMalloc(static_cast(&p_mem_), mem_size); - } - - void* GetDeviceBuffer() { return p_mem_; } - - ~SimpleDeviceMem() { (void)hipFree(p_mem_); } - - void* p_mem_; -}; - int main() { - std::array in_lengths{G, N, Wi, C}; - std::array in_strides{0, 0, 0, 1}; - - std::array wei_lengths{G, K, X, C}; - std::array wei_strides{0, 0, 0, 1}; - - std::array out_lengths{G, N, Wo, K}; - std::array out_strides{0, 0, 0, 1}; - - std::partial_sum(rbegin(in_lengths), - std::prev(rend(in_lengths)), - std::next(rbegin(in_strides)), - std::multiplies<>{}); - std::partial_sum(rbegin(wei_lengths), - std::prev(rend(wei_lengths)), - std::next(rbegin(wei_strides)), - std::multiplies<>{}); - std::partial_sum(rbegin(out_lengths), - std::prev(rend(out_lengths)), - std::next(rbegin(out_strides)), - std::multiplies<>{}); - - // transpose GNWC/GKXC/GNWK to GNCW/GKCX/GNCW - std::rotate(rbegin(in_lengths), - std::next(rbegin(in_lengths)), - std::next(rbegin(in_lengths), NumDimSpatial + 1)); - std::rotate(rbegin(in_strides), - std::next(rbegin(in_strides)), - std::next(rbegin(in_strides), NumDimSpatial + 1)); - std::rotate(rbegin(wei_lengths), - std::next(rbegin(wei_lengths)), - std::next(rbegin(wei_lengths), NumDimSpatial + 1)); - std::rotate(rbegin(wei_strides), - std::next(rbegin(wei_strides)), - std::next(rbegin(wei_strides), NumDimSpatial + 1)); - std::rotate(rbegin(out_lengths), - std::next(rbegin(out_lengths)), - std::next(rbegin(out_lengths), NumDimSpatial + 1)); - std::rotate(rbegin(out_strides), - std::next(rbegin(out_strides)), - std::next(rbegin(out_strides), NumDimSpatial + 1)); - - std::array filter_strides{1}; - std::array filter_dilations{1}; - std::array input_left_pads{1}; - std::array input_right_pads{1}; - - SimpleDeviceMem in(sizeof(InDataType) * G * N * Wi * C); - SimpleDeviceMem wei(sizeof(WeiDataType) * G * K * X * C); - SimpleDeviceMem out(sizeof(OutDataType) * G * N * Wo * K); - - using DeviceOp = ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD, - OutLayout, - InDataType, - WeiDataType, - ck::Tuple<>, - OutDataType, - PassThrough, - PassThrough, - PassThrough>; - - // get device op instances - const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< - DeviceOp>::GetInstances(); - - std::cout << "found " << op_ptrs.size() << " instances" << std::endl; - - std::string best_op_name; - int best_op_id = -1; - float best_avg_time = std::numeric_limits::max(); - float best_gb_per_sec = 0; - float best_tflops = 0; - - // profile device operation instances - std::cout << "Run all instances and do timing" << std::endl; - - for(int i = 0; i < op_ptrs.size(); ++i) - { - auto& op_ptr = op_ptrs[i]; - auto argument_ptr = op_ptr->MakeArgumentPointer(in.GetDeviceBuffer(), - wei.GetDeviceBuffer(), - {}, - out.GetDeviceBuffer(), - in_lengths, - in_strides, - wei_lengths, - wei_strides, - {}, - {}, - out_lengths, - out_strides, - filter_strides, - filter_dilations, - input_left_pads, - input_right_pads, - PassThrough{}, - PassThrough{}, - PassThrough{}); - auto invoker_ptr = op_ptr->MakeInvokerPointer(); - std::string op_name = op_ptr->GetTypeString(); - - if(op_ptr->IsSupportedArgument(argument_ptr.get())) - { - float avg_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, true}); - - std::size_t flop = std::size_t(2) * G * N * K * C * Wo * X; - std::size_t num_bytes = sizeof(InDataType) * G * N * Wi * C + - sizeof(WeiDataType) * G * K * X * C + - sizeof(OutDataType) * G * N * Wo * K; - - float tflops = static_cast(flop) / 1.E9 / avg_time; - float gb_per_sec = num_bytes / 1.E6 / avg_time; - - std::cout << "Perf: " << std::setw(10) << avg_time << " ms, " << tflops << " TFlops, " - << gb_per_sec << " GB/s, " << op_name << std::endl; - - if(tflops > best_tflops) - { - best_op_id = i; - best_op_name = op_name; - best_avg_time = avg_time; - best_gb_per_sec = gb_per_sec; - best_tflops = tflops; - } - } - else - { - std::cerr << op_name << " does not support this problem" << std::endl; - } - } - - if(best_op_id < 0) - { - std::cerr << "no suitable instance" << std::endl; - return EXIT_FAILURE; - } - - std::cout << "Best Perf: " << std::setw(10) << best_avg_time << " ms, " << best_tflops - << " TFlops, " << best_gb_per_sec << " GB/s, " << best_op_name << std::endl; - - // run the best intance - { - auto& op_ptr = op_ptrs[best_op_id]; - std::cout << "Run the best instance without timing: " << op_ptr->GetTypeString() - << std::endl; - auto argument_ptr = op_ptr->MakeArgumentPointer(in.GetDeviceBuffer(), - wei.GetDeviceBuffer(), - {}, - out.GetDeviceBuffer(), - in_lengths, - in_strides, - wei_lengths, - wei_strides, - {}, - {}, - out_lengths, - out_strides, - filter_strides, - filter_dilations, - input_left_pads, - input_right_pads, - PassThrough{}, - PassThrough{}, - PassThrough{}); - - auto invoker_ptr = op_ptr->MakeInvokerPointer(); - - if(op_ptr->IsSupportedArgument(argument_ptr.get())) - { - invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, false}); - } - - std::cout << "Done" << std::endl; - } + return run_grouped_conv_fwd({N, Wi, G, C}, {G, K, X, C}, {N, Wo, G, K}) + ? EXIT_SUCCESS + : EXIT_FAILURE; } diff --git a/client_example/07_grouped_convnd_fwd/grouped_conv2d_fwd.cpp b/client_example/07_grouped_convnd_fwd/grouped_conv2d_fwd.cpp index 9383350629..fb8a410ab3 100644 --- a/client_example/07_grouped_convnd_fwd/grouped_conv2d_fwd.cpp +++ b/client_example/07_grouped_convnd_fwd/grouped_conv2d_fwd.cpp @@ -1,17 +1,10 @@ // SPDX-License-Identifier: MIT // Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. -#include -#include -#include -#include -#include -#include +#include "common.hpp" #include "ck/ck.hpp" -#include "ck/library/tensor_operation_instance/gpu/grouped_convolution_forward.hpp" #include "ck/tensor_operation/gpu/device/tensor_layout.hpp" -#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" using InDataType = ck::half_t; using WeiDataType = ck::half_t; @@ -34,167 +27,16 @@ static constexpr ck::index_t Wi = 28; // input W static constexpr ck::index_t Ho = 28; // output H static constexpr ck::index_t Wo = 28; // output W -struct SimpleDeviceMem -{ - SimpleDeviceMem() = delete; - - SimpleDeviceMem(std::size_t mem_size) : p_mem_{} - { - (void)hipMalloc(static_cast(&p_mem_), mem_size); - } - - void* GetDeviceBuffer() { return p_mem_; } - - ~SimpleDeviceMem() { (void)hipFree(p_mem_); } - - void* p_mem_; -}; - int main() { - // We have NHWGC/GKYXC/NHWGK (x, weight, y) in memory space - // However, CK's API only accept length and stride with order of GNCHW/GKCYX/GNCHW - // Hence, we need to adjust the order of stride - std::array in_lengths{G, N, C, Hi, Wi}; - std::array in_strides{C, Hi * Wi * G * C, 1, Wi * G * C, G * C}; - std::array wei_lengths{G, K, C, Y, X}; - std::array wei_strides{K * Y * X * C, Y * X * C, 1, X * C, C}; - std::array out_lengths{G, N, K, Ho, Wo}; - std::array out_strides{C, Ho * Wo * G * C, 1, Wo * G * C, G * C}; - - std::array filter_strides{1, 1}; - std::array filter_dilations{1, 1}; - std::array input_left_pads{1, 1}; - std::array input_right_pads{1, 1}; - - SimpleDeviceMem in(sizeof(InDataType) * N * Hi * Wi * G * C); - SimpleDeviceMem wei(sizeof(WeiDataType) * G * K * Y * X * C); - SimpleDeviceMem out(sizeof(OutDataType) * N * Ho * Wo * G * K); - - using DeviceOp = ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD, - OutLayout, - InDataType, - WeiDataType, - ck::Tuple<>, - OutDataType, - PassThrough, - PassThrough, - PassThrough>; - - // get device op instances - const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< - DeviceOp>::GetInstances(); - - std::cout << "found " << op_ptrs.size() << " instances" << std::endl; - - std::string best_op_name; - int best_op_id = -1; - float best_avg_time = std::numeric_limits::max(); - float best_gb_per_sec = 0; - float best_tflops = 0; - - // profile device operation instances - std::cout << "Run all instances and do timing" << std::endl; - - for(int i = 0; i < op_ptrs.size(); ++i) - { - auto& op_ptr = op_ptrs[i]; - auto argument_ptr = op_ptr->MakeArgumentPointer(in.GetDeviceBuffer(), - wei.GetDeviceBuffer(), - {}, - out.GetDeviceBuffer(), - in_lengths, - in_strides, - wei_lengths, - wei_strides, - {}, - {}, - out_lengths, - out_strides, - filter_strides, - filter_dilations, - input_left_pads, - input_right_pads, - PassThrough{}, - PassThrough{}, - PassThrough{}); - auto invoker_ptr = op_ptr->MakeInvokerPointer(); - std::string op_name = op_ptr->GetTypeString(); - - if(op_ptr->IsSupportedArgument(argument_ptr.get())) - { - float avg_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, true}); - - std::size_t flop = std::size_t(2) * G * N * K * C * Ho * Wo * Y * X; - std::size_t num_bytes = sizeof(InDataType) * N * Hi * Wi * G * C + - sizeof(WeiDataType) * G * K * Y * X * C + - sizeof(OutDataType) * N * Ho * Wo * G * K; - - float tflops = static_cast(flop) / 1.E9 / avg_time; - float gb_per_sec = num_bytes / 1.E6 / avg_time; - - std::cout << "Perf: " << std::setw(10) << avg_time << " ms, " << tflops << " TFlops, " - << gb_per_sec << " GB/s, " << op_name << std::endl; - - if(tflops > best_tflops) - { - best_op_id = i; - best_op_name = op_name; - best_avg_time = avg_time; - best_gb_per_sec = gb_per_sec; - best_tflops = tflops; - } - } - else - { - std::cerr << op_name << " does not support this problem" << std::endl; - } - } - - if(best_op_id < 0) - { - std::cerr << "no suitable instance" << std::endl; - return EXIT_FAILURE; - } - - std::cout << "Best Perf: " << std::setw(10) << best_avg_time << " ms, " << best_tflops - << " TFlops, " << best_gb_per_sec << " GB/s, " << best_op_name << std::endl; - - // run the best intance - { - auto& op_ptr = op_ptrs[best_op_id]; - std::cout << "Run the best instance without timing: " << op_ptr->GetTypeString() - << std::endl; - auto argument_ptr = op_ptr->MakeArgumentPointer(in.GetDeviceBuffer(), - wei.GetDeviceBuffer(), - {}, - out.GetDeviceBuffer(), - in_lengths, - in_strides, - wei_lengths, - wei_strides, - {}, - {}, - out_lengths, - out_strides, - filter_strides, - filter_dilations, - input_left_pads, - input_right_pads, - PassThrough{}, - PassThrough{}, - PassThrough{}); - - auto invoker_ptr = op_ptr->MakeInvokerPointer(); - - if(op_ptr->IsSupportedArgument(argument_ptr.get())) - { - invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, false}); - } - - std::cout << "Done" << std::endl; - } + return run_grouped_conv_fwd({N, Hi, Wi, G, C}, {G, K, Y, X, C}, {N, Ho, Wo, G, K}) + ? EXIT_SUCCESS + : EXIT_FAILURE; } diff --git a/client_example/16_convnd_fwd/conv3d_fwd_bf8.cpp b/client_example/07_grouped_convnd_fwd/grouped_conv3d_fwd_bf8.cpp similarity index 100% rename from client_example/16_convnd_fwd/conv3d_fwd_bf8.cpp rename to client_example/07_grouped_convnd_fwd/grouped_conv3d_fwd_bf8.cpp diff --git a/client_example/16_convnd_fwd/conv3d_fwd_bf8_fp8.cpp b/client_example/07_grouped_convnd_fwd/grouped_conv3d_fwd_bf8_fp8.cpp similarity index 100% rename from client_example/16_convnd_fwd/conv3d_fwd_bf8_fp8.cpp rename to client_example/07_grouped_convnd_fwd/grouped_conv3d_fwd_bf8_fp8.cpp diff --git a/client_example/16_convnd_fwd/conv3d_fwd_fp8.cpp b/client_example/07_grouped_convnd_fwd/grouped_conv3d_fwd_fp8.cpp similarity index 100% rename from client_example/16_convnd_fwd/conv3d_fwd_fp8.cpp rename to client_example/07_grouped_convnd_fwd/grouped_conv3d_fwd_fp8.cpp diff --git a/client_example/16_convnd_fwd/conv3d_fwd_fp8_bf8.cpp b/client_example/07_grouped_convnd_fwd/grouped_conv3d_fwd_fp8_bf8.cpp similarity index 100% rename from client_example/16_convnd_fwd/conv3d_fwd_fp8_bf8.cpp rename to client_example/07_grouped_convnd_fwd/grouped_conv3d_fwd_fp8_bf8.cpp diff --git a/client_example/16_convnd_fwd/CMakeLists.txt b/client_example/16_convnd_fwd/CMakeLists.txt index 23311b4024..5279e3dfcf 100644 --- a/client_example/16_convnd_fwd/CMakeLists.txt +++ b/client_example/16_convnd_fwd/CMakeLists.txt @@ -7,22 +7,6 @@ endif() if((DTYPES MATCHES "fp8") OR NOT DEFINED DTYPES) add_executable(client_conv3d_fwd_fp16_comp_fp8 conv3d_fwd_fp16_comp_fp8.cpp) target_link_libraries(client_conv3d_fwd_fp16_comp_fp8 PRIVATE composable_kernel::device_conv_operations) - - add_executable(client_conv3d_fwd_fp8 conv3d_fwd_fp8.cpp) - target_link_libraries(client_conv3d_fwd_fp8 PRIVATE composable_kernel::device_conv_operations) -endif() - -if((DTYPES MATCHES "bf8") OR NOT DEFINED DTYPES) - add_executable(client_conv3d_fwd_bf8 conv3d_fwd_bf8.cpp) - target_link_libraries(client_conv3d_fwd_bf8 PRIVATE composable_kernel::device_conv_operations) -endif() - -if((DTYPES MATCHES "fp8" AND DTYPES MATCHES "bf8") OR NOT DEFINED DTYPES) - add_executable(client_conv3d_fwd_fp8_bf8 conv3d_fwd_fp8_bf8.cpp) - target_link_libraries(client_conv3d_fwd_fp8_bf8 PRIVATE composable_kernel::device_conv_operations) - - add_executable(client_conv3d_fwd_bf8_fp8 conv3d_fwd_bf8_fp8.cpp) - target_link_libraries(client_conv3d_fwd_bf8_fp8 PRIVATE composable_kernel::device_conv_operations) endif() if((DTYPES MATCHES "fp32") OR NOT DEFINED DTYPES)