From 8e868bf880189be00b202104713f56d900d2e6bc Mon Sep 17 00:00:00 2001 From: rocking5566 Date: Thu, 1 Dec 2022 04:13:04 +0800 Subject: [PATCH] gemm, conv perchannel quantization (#503) * Use gemm_multiple_D instead * Add gemm bias relu quantization example * Add pure gemm quantization example * Add quantization of perchannel conv + bias + relu example * Refine the code * Rename multiplier to requant_scale * Rename the folder * Remove redundant comment * Rename the file. Prepare to add perchannel * Add conv perchannel instance * Move to quantization folder * Add conv perchannel client example * Apply Rangify constructor of HostTensorDescriptor & Tensor<> * Fix merge error [ROCm/composable_kernel commit: ad541ad6b9de9b0579d5254f82e9d5b86103d309] --- client_example/09_quantization/CMakeLists.txt | 6 + ..._fwd_bias_relu_perchannel_quantization.cpp | 205 +++++++++++ ...2d_fwd_bias_relu_perlayer_quantization.cpp | 2 +- .../conv2d_fwd_perchannel_quantization.cpp | 198 ++++++++++ .../conv2d_fwd_perlayer_quantization.cpp | 2 +- example/14_gemm_quantization/CMakeLists.txt | 2 + .../gemm_xdl_bias_relu_quantization_int8.cpp | 235 ++++++++++++ .../gemm_xdl_quantization_int8.cpp | 207 +++++++++++ .../14_gemm_xdl_quantization/CMakeLists.txt | 1 - .../gemm_xdl_relu_quantization_int8.cpp | 233 ------------ .../CMakeLists.txt | 1 + ...bias_relu_perchannel_quantization_int8.cpp | 342 ++++++++++++++++++ ...l_bias_relu_perlayer_quantization_int8.cpp | 44 +-- ...v2d_fwd_xdl_perlayer_quantization_int8.cpp | 11 +- .../gpu/element/quantization_operation.hpp | 76 +++- .../device_operation_instance_factory.hpp | 18 +- ...n_bias_forward_perchannel_quantization.hpp | 114 ++++++ ...ion_bias_forward_perlayer_quantization.hpp | 6 +- ...lution_forward_perchannel_quantization.hpp | 113 ++++++ ...volution_forward_perlayer_quantization.hpp | 0 .../gpu/quantization/CMakeLists.txt | 6 +- ..._perchannel_quantization_int8_instance.cpp | 74 ++++ ...as_perlayer_quantization_int8_instance.cpp | 68 ++++ ...ce_conv2d_xdl_bias_quant_int8_instance.cpp | 112 ------ .../device_conv2d_xdl_int8_instance.hpp | 111 ++++++ ..._perchannel_quantization_int8_instance.cpp | 62 ++++ ...dl_perlayer_quantization_int8_instance.cpp | 62 ++++ .../device_conv2d_xdl_quant_int8_instance.cpp | 109 ------ 28 files changed, 1907 insertions(+), 513 deletions(-) create mode 100644 client_example/09_quantization/conv2d_fwd_bias_relu_perchannel_quantization.cpp create mode 100644 client_example/09_quantization/conv2d_fwd_perchannel_quantization.cpp create mode 100644 example/14_gemm_quantization/CMakeLists.txt create mode 100644 example/14_gemm_quantization/gemm_xdl_bias_relu_quantization_int8.cpp create mode 100644 example/14_gemm_quantization/gemm_xdl_quantization_int8.cpp delete mode 100644 example/14_gemm_xdl_quantization/CMakeLists.txt delete mode 100644 example/14_gemm_xdl_quantization/gemm_xdl_relu_quantization_int8.cpp rename example/{44_conv2d_fwd_quant => 44_conv2d_fwd_quantization}/CMakeLists.txt (65%) create mode 100644 example/44_conv2d_fwd_quantization/conv2d_fwd_xdl_bias_relu_perchannel_quantization_int8.cpp rename example/{44_conv2d_fwd_quant => 44_conv2d_fwd_quantization}/conv2d_fwd_xdl_bias_relu_perlayer_quantization_int8.cpp (90%) rename example/{44_conv2d_fwd_quant => 44_conv2d_fwd_quantization}/conv2d_fwd_xdl_perlayer_quantization_int8.cpp (96%) create mode 100644 library/include/ck/library/tensor_operation_instance/gpu/quantization/grouped_convolution_bias_forward_perchannel_quantization.hpp rename library/include/ck/library/tensor_operation_instance/gpu/{ => quantization}/grouped_convolution_bias_forward_perlayer_quantization.hpp (98%) create mode 100644 library/include/ck/library/tensor_operation_instance/gpu/quantization/grouped_convolution_forward_perchannel_quantization.hpp rename library/include/ck/library/tensor_operation_instance/gpu/{ => quantization}/grouped_convolution_forward_perlayer_quantization.hpp (100%) create mode 100644 library/src/tensor_operation_instance/gpu/quantization/device_conv2d_xdl_bias_perchannel_quantization_int8_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/quantization/device_conv2d_xdl_bias_perlayer_quantization_int8_instance.cpp delete mode 100644 library/src/tensor_operation_instance/gpu/quantization/device_conv2d_xdl_bias_quant_int8_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/quantization/device_conv2d_xdl_int8_instance.hpp create mode 100644 library/src/tensor_operation_instance/gpu/quantization/device_conv2d_xdl_perchannel_quantization_int8_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/quantization/device_conv2d_xdl_perlayer_quantization_int8_instance.cpp delete mode 100644 library/src/tensor_operation_instance/gpu/quantization/device_conv2d_xdl_quant_int8_instance.cpp diff --git a/client_example/09_quantization/CMakeLists.txt b/client_example/09_quantization/CMakeLists.txt index eceaa84174..7dc9b860c0 100644 --- a/client_example/09_quantization/CMakeLists.txt +++ b/client_example/09_quantization/CMakeLists.txt @@ -1,5 +1,11 @@ +add_executable(client_conv2d_fwd_bias_relu_perchannel_quantization conv2d_fwd_bias_relu_perchannel_quantization.cpp) +target_link_libraries(client_conv2d_fwd_bias_relu_perchannel_quantization PRIVATE composable_kernel::device_operations) + add_executable(client_conv2d_fwd_bias_relu_perlayer_quantization conv2d_fwd_bias_relu_perlayer_quantization.cpp) target_link_libraries(client_conv2d_fwd_bias_relu_perlayer_quantization PRIVATE composable_kernel::device_operations) +add_executable(client_conv2d_fwd_perchannel_quantization conv2d_fwd_perchannel_quantization.cpp) +target_link_libraries(client_conv2d_fwd_perchannel_quantization PRIVATE composable_kernel::device_operations) + add_executable(client_conv2d_fwd_perlayer_quantization conv2d_fwd_perlayer_quantization.cpp) target_link_libraries(client_conv2d_fwd_perlayer_quantization PRIVATE composable_kernel::device_operations) diff --git a/client_example/09_quantization/conv2d_fwd_bias_relu_perchannel_quantization.cpp b/client_example/09_quantization/conv2d_fwd_bias_relu_perchannel_quantization.cpp new file mode 100644 index 0000000000..bcb0cefa71 --- /dev/null +++ b/client_example/09_quantization/conv2d_fwd_bias_relu_perchannel_quantization.cpp @@ -0,0 +1,205 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/library/tensor_operation_instance/gpu/quantization/grouped_convolution_bias_forward_perchannel_quantization.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_conv_fwd.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +using InDataType = int8_t; +using WeiDataType = int8_t; +using BiasDataType = int32_t; +using RequantScaleDataType = float; +using OutDataType = int8_t; + +using InLayout = ck::tensor_layout::convolution::GNHWC; +using WeiLayout = ck::tensor_layout::convolution::GKYXC; +using BiasLayout = ck::tensor_layout::convolution::G_K; +using RequantScaleLayout = ck::tensor_layout::convolution::G_K; +using OutLayout = ck::tensor_layout::convolution::GNHWK; +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using ActivationOp = ck::tensor_operation::element_wise::Relu; +using OutElementOp = ck::tensor_operation::element_wise::Add_Activation_Mul2_Clamp; + +static constexpr ck::index_t NumDimSpatial = 2; +static constexpr ck::index_t G = 1; +static constexpr ck::index_t N = 4; +static constexpr ck::index_t K = 64; +static constexpr ck::index_t C = 32; +static constexpr ck::index_t Y = 3; +static constexpr ck::index_t X = 3; +static constexpr ck::index_t Hi = 71; +static constexpr ck::index_t Wi = 71; +static constexpr ck::index_t Ho = 36; +static constexpr ck::index_t Wo = 36; + +struct SimpleDeviceMem +{ + SimpleDeviceMem() = delete; + + SimpleDeviceMem(std::size_t mem_size) : p_mem_{} + { + (void)hipMalloc(static_cast(&p_mem_), mem_size); + } + + void* GetDeviceBuffer() { return p_mem_; } + + ~SimpleDeviceMem() { (void)hipFree(p_mem_); } + + void* p_mem_; +}; + +int main(int argc, char* argv[]) +{ + std::array in_lengths{G, N, C, Hi, Wi}; + std::array in_strides{N * Hi * Wi * C, Hi * Wi * C, 1, Wi * C, C}; + std::array weight_lengths{G, K, C, Y, X}; + std::array weight_strides{K * Y * X * C, Y * X * C, 1, X * C, C}; + std::array bias_lengths{G, N, K, Ho, Wo}; + std::array bias_strides{K, 0, 1, 0, 0}; + std::array requant_scale_lengths{G, N, K, Ho, Wo}; + std::array requant_scale_strides{K, 0, 1, 0, 0}; + std::array out_lengths{G, N, C, Ho, Wo}; + std::array out_strides{N * Ho * Wo * C, Ho * Wo * C, 1, Wo * C, C}; + std::array in_left_pad{1, 1}; + std::array in_right_pad{1, 1}; + std::array conv_strides{2, 2}; + std::array conv_dilations{1, 1}; + + SimpleDeviceMem in(sizeof(InDataType) * N * Hi * Wi * C); + SimpleDeviceMem wei(sizeof(WeiDataType) * K * Y * X * C); + SimpleDeviceMem bias(sizeof(BiasDataType) * K * Y * X * C); + SimpleDeviceMem requant_scale(sizeof(RequantScaleDataType) * K * Y * X * C); + SimpleDeviceMem out(sizeof(OutDataType) * N * Ho * Wo * K); + + using DeviceOp = ck::tensor_operation::device::DeviceGroupedConvFwdMultipleD< + NumDimSpatial, + InLayout, + WeiLayout, + ck::Tuple, + OutLayout, + InDataType, + WeiDataType, + ck::Tuple, + OutDataType, + PassThrough, + PassThrough, + OutElementOp>; + // get device op instances + const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< + DeviceOp>::GetInstances(); + + std::cout << "found " << op_ptrs.size() << " instances" << std::endl; + + std::string best_op_name; + int best_op_id = -1; + float best_avg_time = std::numeric_limits::max(); + float best_gb_per_sec = 0; + float best_tflops = 0; + + // profile device operation instances + std::cout << "Run all instances and do timing" << std::endl; + + for(int i = 0; i < op_ptrs.size(); ++i) + { + auto& op_ptr = op_ptrs[i]; + auto argument_ptr = + op_ptr->MakeArgumentPointer(in.GetDeviceBuffer(), + wei.GetDeviceBuffer(), + {bias.GetDeviceBuffer(), requant_scale.GetDeviceBuffer()}, + out.GetDeviceBuffer(), + in_lengths, + in_strides, + weight_lengths, + weight_strides, + {bias_lengths, requant_scale_lengths}, + {bias_strides, requant_scale_strides}, + out_lengths, + out_strides, + conv_strides, + conv_dilations, + in_left_pad, + in_right_pad, + PassThrough{}, + PassThrough{}, + OutElementOp{ActivationOp{}}); + + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + std::string op_name = op_ptr->GetTypeString(); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + float avg_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, true}); + + std::size_t flop = G * 2 * N * K * C * Ho * Wo * Y * X; + std::size_t num_bytes = G * sizeof(InDataType) * N * Hi * Wi * C + + G * sizeof(WeiDataType) * K * Y * X * C + + G * sizeof(OutDataType) * N * Ho * Wo * K; + + float tflops = static_cast(flop) / 1.E9 / avg_time; + float gb_per_sec = num_bytes / 1.E6 / avg_time; + + std::cout << "Perf: " << std::setw(10) << avg_time << " ms, " << tflops << " TFlops, " + << gb_per_sec << " GB/s, " << op_name << std::endl; + + if(tflops > best_tflops) + { + best_op_id = i; + best_op_name = op_name; + best_avg_time = avg_time; + best_gb_per_sec = gb_per_sec; + best_tflops = tflops; + } + } + else + { + std::cout << op_name << " does not support this problem" << std::endl; + } + } + + std::cout << "Best Perf: " << std::setw(10) << best_avg_time << " ms, " << best_tflops + << " TFlops, " << best_gb_per_sec << " GB/s, " << best_op_name << std::endl; + + // run the best intance + { + auto& op_ptr = op_ptrs[best_op_id]; + std::cout << "Run the best instance without timing: " << op_ptr->GetTypeString() + << std::endl; + auto argument_ptr = + op_ptr->MakeArgumentPointer(in.GetDeviceBuffer(), + wei.GetDeviceBuffer(), + {bias.GetDeviceBuffer(), requant_scale.GetDeviceBuffer()}, + out.GetDeviceBuffer(), + in_lengths, + in_strides, + weight_lengths, + weight_strides, + {bias_lengths, requant_scale_lengths}, + {bias_strides, requant_scale_strides}, + out_lengths, + out_strides, + conv_strides, + conv_dilations, + in_left_pad, + in_right_pad, + PassThrough{}, + PassThrough{}, + OutElementOp{ActivationOp{}}); + + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, false}); + } + + std::cout << "Done" << std::endl; + } + + return 0; +} \ No newline at end of file diff --git a/client_example/09_quantization/conv2d_fwd_bias_relu_perlayer_quantization.cpp b/client_example/09_quantization/conv2d_fwd_bias_relu_perlayer_quantization.cpp index 7416e12620..26c7aa15e2 100644 --- a/client_example/09_quantization/conv2d_fwd_bias_relu_perlayer_quantization.cpp +++ b/client_example/09_quantization/conv2d_fwd_bias_relu_perlayer_quantization.cpp @@ -6,7 +6,7 @@ #include #include "ck/ck.hpp" -#include "ck/library/tensor_operation_instance/gpu/grouped_convolution_bias_forward_perlayer_quantization.hpp" +#include "ck/library/tensor_operation_instance/gpu/quantization/grouped_convolution_bias_forward_perlayer_quantization.hpp" #include "ck/tensor_operation/gpu/device/tensor_layout.hpp" #include "ck/tensor_operation/gpu/device/device_conv_fwd.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" diff --git a/client_example/09_quantization/conv2d_fwd_perchannel_quantization.cpp b/client_example/09_quantization/conv2d_fwd_perchannel_quantization.cpp new file mode 100644 index 0000000000..475b2f03b4 --- /dev/null +++ b/client_example/09_quantization/conv2d_fwd_perchannel_quantization.cpp @@ -0,0 +1,198 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/library/tensor_operation_instance/gpu/quantization/grouped_convolution_forward_perchannel_quantization.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_conv_fwd.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +using InDataType = int8_t; +using WeiDataType = int8_t; +using RequantScaleDataType = float; +using OutDataType = int8_t; + +using InLayout = ck::tensor_layout::convolution::GNHWC; +using WeiLayout = ck::tensor_layout::convolution::GKYXC; +using RequantScaleLayout = ck::tensor_layout::convolution::G_K; +using OutLayout = ck::tensor_layout::convolution::GNHWK; +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using ActivationOp = PassThrough; +using OutElementOp = ck::tensor_operation::element_wise::Activation_Mul2_Clamp; + +static constexpr ck::index_t NumDimSpatial = 2; +static constexpr ck::index_t G = 1; +static constexpr ck::index_t N = 4; +static constexpr ck::index_t K = 64; +static constexpr ck::index_t C = 32; +static constexpr ck::index_t Y = 3; +static constexpr ck::index_t X = 3; +static constexpr ck::index_t Hi = 71; +static constexpr ck::index_t Wi = 71; +static constexpr ck::index_t Ho = 36; +static constexpr ck::index_t Wo = 36; + +struct SimpleDeviceMem +{ + SimpleDeviceMem() = delete; + + SimpleDeviceMem(std::size_t mem_size) : p_mem_{} + { + (void)hipMalloc(static_cast(&p_mem_), mem_size); + } + + void* GetDeviceBuffer() { return p_mem_; } + + ~SimpleDeviceMem() { (void)hipFree(p_mem_); } + + void* p_mem_; +}; + +int main(int argc, char* argv[]) +{ + std::array in_lengths{G, N, C, Hi, Wi}; + std::array in_strides{N * Hi * Wi * C, Hi * Wi * C, 1, Wi * C, C}; + std::array weight_lengths{G, K, C, Y, X}; + std::array weight_strides{K * Y * X * C, Y * X * C, 1, X * C, C}; + std::array requant_scale_lengths{G, N, K, Ho, Wo}; + std::array requant_scale_strides{K, 0, 1, 0, 0}; + std::array out_lengths{G, N, C, Ho, Wo}; + std::array out_strides{N * Ho * Wo * C, Ho * Wo * C, 1, Wo * C, C}; + std::array in_left_pad{1, 1}; + std::array in_right_pad{1, 1}; + std::array conv_strides{2, 2}; + std::array conv_dilations{1, 1}; + + SimpleDeviceMem in(sizeof(InDataType) * N * Hi * Wi * C); + SimpleDeviceMem wei(sizeof(WeiDataType) * K * Y * X * C); + SimpleDeviceMem requant_scale(sizeof(RequantScaleDataType) * K * Y * X * C); + SimpleDeviceMem out(sizeof(OutDataType) * N * Ho * Wo * K); + + using DeviceOp = + ck::tensor_operation::device::DeviceGroupedConvFwdMultipleD, + OutLayout, + InDataType, + WeiDataType, + ck::Tuple, + OutDataType, + PassThrough, + PassThrough, + OutElementOp>; + // get device op instances + const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< + DeviceOp>::GetInstances(); + + std::cout << "found " << op_ptrs.size() << " instances" << std::endl; + + std::string best_op_name; + int best_op_id = -1; + float best_avg_time = std::numeric_limits::max(); + float best_gb_per_sec = 0; + float best_tflops = 0; + + // profile device operation instances + std::cout << "Run all instances and do timing" << std::endl; + + for(int i = 0; i < op_ptrs.size(); ++i) + { + auto& op_ptr = op_ptrs[i]; + auto argument_ptr = op_ptr->MakeArgumentPointer(in.GetDeviceBuffer(), + wei.GetDeviceBuffer(), + {requant_scale.GetDeviceBuffer()}, + out.GetDeviceBuffer(), + in_lengths, + in_strides, + weight_lengths, + weight_strides, + {requant_scale_lengths}, + {requant_scale_strides}, + out_lengths, + out_strides, + conv_strides, + conv_dilations, + in_left_pad, + in_right_pad, + PassThrough{}, + PassThrough{}, + OutElementOp{ActivationOp{}}); + + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + std::string op_name = op_ptr->GetTypeString(); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + float avg_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, true}); + + std::size_t flop = G * 2 * N * K * C * Ho * Wo * Y * X; + std::size_t num_bytes = G * sizeof(InDataType) * N * Hi * Wi * C + + G * sizeof(WeiDataType) * K * Y * X * C + + G * sizeof(OutDataType) * N * Ho * Wo * K; + + float tflops = static_cast(flop) / 1.E9 / avg_time; + float gb_per_sec = num_bytes / 1.E6 / avg_time; + + std::cout << "Perf: " << std::setw(10) << avg_time << " ms, " << tflops << " TFlops, " + << gb_per_sec << " GB/s, " << op_name << std::endl; + + if(tflops > best_tflops) + { + best_op_id = i; + best_op_name = op_name; + best_avg_time = avg_time; + best_gb_per_sec = gb_per_sec; + best_tflops = tflops; + } + } + else + { + std::cout << op_name << " does not support this problem" << std::endl; + } + } + + std::cout << "Best Perf: " << std::setw(10) << best_avg_time << " ms, " << best_tflops + << " TFlops, " << best_gb_per_sec << " GB/s, " << best_op_name << std::endl; + + // run the best intance + { + auto& op_ptr = op_ptrs[best_op_id]; + std::cout << "Run the best instance without timing: " << op_ptr->GetTypeString() + << std::endl; + auto argument_ptr = op_ptr->MakeArgumentPointer(in.GetDeviceBuffer(), + wei.GetDeviceBuffer(), + {}, + out.GetDeviceBuffer(), + in_lengths, + in_strides, + weight_lengths, + weight_strides, + {}, + {}, + out_lengths, + out_strides, + conv_strides, + conv_dilations, + in_left_pad, + in_right_pad, + PassThrough{}, + PassThrough{}, + OutElementOp{ActivationOp{}}); + + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, false}); + } + + std::cout << "Done" << std::endl; + } + + return 0; +} \ No newline at end of file diff --git a/client_example/09_quantization/conv2d_fwd_perlayer_quantization.cpp b/client_example/09_quantization/conv2d_fwd_perlayer_quantization.cpp index 81176fd2e3..da7b7e6abf 100644 --- a/client_example/09_quantization/conv2d_fwd_perlayer_quantization.cpp +++ b/client_example/09_quantization/conv2d_fwd_perlayer_quantization.cpp @@ -6,7 +6,7 @@ #include #include "ck/ck.hpp" -#include "ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_perlayer_quantization.hpp" +#include "ck/library/tensor_operation_instance/gpu/quantization/grouped_convolution_forward_perlayer_quantization.hpp" #include "ck/tensor_operation/gpu/device/tensor_layout.hpp" #include "ck/tensor_operation/gpu/device/device_conv_fwd.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" diff --git a/example/14_gemm_quantization/CMakeLists.txt b/example/14_gemm_quantization/CMakeLists.txt new file mode 100644 index 0000000000..ca09c48c10 --- /dev/null +++ b/example/14_gemm_quantization/CMakeLists.txt @@ -0,0 +1,2 @@ +add_example_executable(example_gemm_xdl_bias_relu_quantization_int8 gemm_xdl_bias_relu_quantization_int8.cpp) +add_example_executable(example_gemm_xdl_quantization_int8 gemm_xdl_quantization_int8.cpp) \ No newline at end of file diff --git a/example/14_gemm_quantization/gemm_xdl_bias_relu_quantization_int8.cpp b/example/14_gemm_quantization/gemm_xdl_bias_relu_quantization_int8.cpp new file mode 100644 index 0000000000..d5f4e6f62c --- /dev/null +++ b/example/14_gemm_quantization/gemm_xdl_bias_relu_quantization_int8.cpp @@ -0,0 +1,235 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/literals.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" +#include "ck/library/utility/check_err.hpp" + +template +using S = ck::Sequence; + +using I8 = int8_t; +using I32 = int32_t; +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using ActivationOp = ck::tensor_operation::element_wise::Relu; +using CDEElementOp = ck::tensor_operation::element_wise::Add_Activation_Mul_Clamp; + +using ADataType = I8; +using BDataType = I8; +using AccDataType = I32; +using CShuffleDataType = I32; +using BiasDataType = I32; +using DsDataType = ck::Tuple; +using EDataType = I8; + +using ALayout = Row; +using BLayout = Col; +using BiasLayout = Row; +using DsLayout = ck::Tuple; +using ELayout = Row; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// clang-format off +using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmMultipleD_Xdl_CShuffle< + ALayout, + BLayout, + DsLayout, + ELayout, + ADataType, + BDataType, + AccDataType, + CShuffleDataType, + DsDataType, + EDataType, + PassThrough, // AElementwiseOperation, + PassThrough, // BElementwiseOperation, + CDEElementOp, // CDEElementwiseOperation, + GemmDefault, // GemmSpecialization GemmSpec, + 1, // NumGemmKPrefetchStage, + 256, // BlockSize, + 256, // MPerBlock, + 128, // NPerBlock, + 64, // KPerBlock, + 16, // AK1, + 16, // BK1, + 32, // MPerXDL, + 32, // NPerXDL, + 4, // MXdlPerWave, + 2, // NXdlPerWave, + S<4, 64, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1, + S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder, + S<1, 0, 2>, // ABlockTransferSrcAccessOrder, + 2, // index_t ABlockTransferSrcVectorDim, + 16, // index_t ABlockTransferSrcScalarPerVector, + 16, // index_t ABlockTransferDstScalarPerVector_AK1, + 1, // bool ABlockLdsExtraM, + S<4, 64, 1>, // typename BBlockTransferThreadClusterLengths_BK0_N_BK1, + S<1, 0, 2>, // typename BBlockTransferThreadClusterArrangeOrder, + S<1, 0, 2>, // typename BBlockTransferSrcAccessOrder, + 2, // index_t BBlockTransferSrcVectorDim, + 8, // index_t BBlockTransferSrcScalarPerVector, + 8, // index_t BBlockTransferDstScalarPerVector_BK1, + 1, // bool BBlockLdsExtraN, + 1, // index_t CShuffleMXdlPerWavePerShuffle, + 1, // index_t CShuffleNXdlPerWavePerShuffle, + S<1, 64, 1, 4>, // typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + 8>; // index_t CShuffleBlockTransferScalarPerVector_NPerBlock> +// clang-format on + +using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm; + +int main() +{ + bool do_verification = true; + bool time_kernel = false; + + // GEMM shape + ck::index_t M = 1024; + ck::index_t N = 1024; + ck::index_t K = 1024; + + ck::index_t StrideA = 1024; + ck::index_t StrideB = 1024; + ck::index_t StrideBias = 0; + ck::index_t StrideE = 1024; + + float requant_scale = 0.03; + + auto f_host_tensor_descriptor2d = + [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { + using namespace ck::literals; + + if(std::is_same::value) + { + return HostTensorDescriptor(std::vector({row, col}), + std::vector({stride, 1_uz})); + } + else + { + return HostTensorDescriptor(std::vector({row, col}), + std::vector({1_uz, stride})); + } + }; + + auto f_host_tensor_descriptor1d = [](std::size_t len, std::size_t stride) { + return HostTensorDescriptor(std::vector({len}), + std::vector({stride})); + }; + + Tensor a_m_k(f_host_tensor_descriptor2d(M, K, StrideA, ALayout{})); + Tensor b_k_n(f_host_tensor_descriptor2d(K, N, StrideB, BLayout{})); + Tensor bias_n(f_host_tensor_descriptor1d(N, 1)); + Tensor e_m_n_host_result(f_host_tensor_descriptor2d(M, N, StrideE, ELayout{})); + Tensor e_m_n_device_result(f_host_tensor_descriptor2d(M, N, StrideE, ELayout{})); + + std::cout << "a_m_k: " << a_m_k.mDesc << std::endl; + std::cout << "b_k_n: " << b_k_n.mDesc << std::endl; + std::cout << "bias_n: " << bias_n.mDesc << std::endl; + std::cout << "e_m_n: " << e_m_n_host_result.mDesc << std::endl; + + a_m_k.GenerateTensorValue(GeneratorTensor_2{-128, 127}); + b_k_n.GenerateTensorValue(GeneratorTensor_2{-128, 127}); + bias_n.GenerateTensorValue(GeneratorTensor_2{-128, 127}); + + DeviceMem a_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpaceSize()); + DeviceMem b_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpaceSize()); + DeviceMem bias_device_buf(sizeof(BiasDataType) * bias_n.mDesc.GetElementSpaceSize()); + DeviceMem e_device_buf(sizeof(EDataType) * e_m_n_device_result.mDesc.GetElementSpaceSize()); + + a_device_buf.ToDevice(a_m_k.mData.data()); + b_device_buf.ToDevice(b_k_n.mData.data()); + bias_device_buf.ToDevice(bias_n.mData.data()); + + auto a_element_op = PassThrough{}; + auto b_element_op = PassThrough{}; + auto cde_element_op = CDEElementOp{requant_scale, ActivationOp{}}; + + // do GEMM + auto gemm = DeviceGemmInstance{}; + auto invoker = gemm.MakeInvoker(); + auto argument = gemm.MakeArgument(a_device_buf.GetDeviceBuffer(), + b_device_buf.GetDeviceBuffer(), + {bias_device_buf.GetDeviceBuffer()}, + e_device_buf.GetDeviceBuffer(), + M, + N, + K, + StrideA, + StrideB, + {StrideBias}, + StrideE, + a_element_op, + b_element_op, + cde_element_op); + + if(!gemm.IsSupportedArgument(argument)) + { + throw std::runtime_error( + "wrong! device_gemm with the specified compilation parameters does " + "not support this GEMM problem"); + } + + float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); + + std::size_t flop = std::size_t(2) * M * N * K; + std::size_t num_btype = + sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + sizeof(EDataType) * M * N; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " + << gemm.GetTypeString() << std::endl; + + e_device_buf.FromDevice(e_m_n_device_result.mData.data()); + + if(do_verification) + { + Tensor c_m_n(HostTensorDescriptor{M, N}); + + auto ref_gemm = ReferenceGemmInstance{}; + auto ref_invoker = ref_gemm.MakeInvoker(); + + auto ref_argument = + ref_gemm.MakeArgument(a_m_k, b_k_n, c_m_n, a_element_op, b_element_op, PassThrough{}); + + ref_invoker.Run(ref_argument); + + for(int m = 0; m < M; ++m) + { + for(int n = 0; n < N; ++n) + { + cde_element_op(e_m_n_host_result(m, n), c_m_n(m, n), bias_n(n)); + } + } + + return ck::utils::check_err(e_m_n_device_result, e_m_n_host_result) ? 0 : 1; + } + + return 0; +} diff --git a/example/14_gemm_quantization/gemm_xdl_quantization_int8.cpp b/example/14_gemm_quantization/gemm_xdl_quantization_int8.cpp new file mode 100644 index 0000000000..2371737382 --- /dev/null +++ b/example/14_gemm_quantization/gemm_xdl_quantization_int8.cpp @@ -0,0 +1,207 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/literals.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" +#include "ck/library/utility/check_err.hpp" + +template +using S = ck::Sequence; + +using I8 = int8_t; +using I32 = int32_t; +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using ActivationOp = PassThrough; +using CDEElementOp = ck::tensor_operation::element_wise::Activation_Mul_Clamp; + +using ADataType = I8; +using BDataType = I8; +using AccDataType = I32; +using CShuffleDataType = I32; +using DsDataType = ck::Tuple<>; +using EDataType = I8; + +using ALayout = Row; +using BLayout = Col; +using DsLayout = ck::Tuple<>; +using ELayout = Row; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// clang-format off +using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmMultipleD_Xdl_CShuffle< + ALayout, + BLayout, + DsLayout, + ELayout, + ADataType, + BDataType, + AccDataType, + CShuffleDataType, + DsDataType, + EDataType, + PassThrough, // AElementwiseOperation, + PassThrough, // BElementwiseOperation, + CDEElementOp, // CDEElementwiseOperation, + GemmDefault, // GemmSpecialization GemmSpec, + 1, // NumGemmKPrefetchStage, + 256, // BlockSize, + 256, // MPerBlock, + 128, // NPerBlock, + 64, // KPerBlock, + 16, // AK1, + 16, // BK1, + 32, // MPerXDL, + 32, // NPerXDL, + 4, // MXdlPerWave, + 2, // NXdlPerWave, + S<4, 64, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1, + S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder, + S<1, 0, 2>, // ABlockTransferSrcAccessOrder, + 2, // index_t ABlockTransferSrcVectorDim, + 16, // index_t ABlockTransferSrcScalarPerVector, + 16, // index_t ABlockTransferDstScalarPerVector_AK1, + 1, // bool ABlockLdsExtraM, + S<4, 64, 1>, // typename BBlockTransferThreadClusterLengths_BK0_N_BK1, + S<1, 0, 2>, // typename BBlockTransferThreadClusterArrangeOrder, + S<1, 0, 2>, // typename BBlockTransferSrcAccessOrder, + 2, // index_t BBlockTransferSrcVectorDim, + 8, // index_t BBlockTransferSrcScalarPerVector, + 8, // index_t BBlockTransferDstScalarPerVector_BK1, + 1, // bool BBlockLdsExtraN, + 1, // index_t CShuffleMXdlPerWavePerShuffle, + 1, // index_t CShuffleNXdlPerWavePerShuffle, + S<1, 64, 1, 4>, // typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + 16>; // index_t CShuffleBlockTransferScalarPerVector_NPerBlock> +// clang-format on + +using ReferenceGemmInstance = ck::tensor_operation::host:: + ReferenceGemm; + +int main() +{ + bool do_verification = true; + bool time_kernel = false; + + // GEMM shape + ck::index_t M = 1024; + ck::index_t N = 1024; + ck::index_t K = 1024; + + ck::index_t StrideA = 1024; + ck::index_t StrideB = 1024; + ck::index_t StrideE = 1024; + + float requant_scale = 0.03; + + auto f_host_tensor_descriptor = + [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { + using namespace ck::literals; + + if(std::is_same::value) + { + return HostTensorDescriptor(std::vector({row, col}), + std::vector({stride, 1_uz})); + } + else + { + return HostTensorDescriptor(std::vector({row, col}), + std::vector({1_uz, stride})); + } + }; + + Tensor a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{})); + Tensor b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{})); + Tensor e_m_n_host_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{})); + Tensor e_m_n_device_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{})); + + std::cout << "a_m_k: " << a_m_k.mDesc << std::endl; + std::cout << "b_k_n: " << b_k_n.mDesc << std::endl; + std::cout << "e_m_n: " << e_m_n_host_result.mDesc << std::endl; + + a_m_k.GenerateTensorValue(GeneratorTensor_2{-128, 127}); + b_k_n.GenerateTensorValue(GeneratorTensor_2{-128, 127}); + + DeviceMem a_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpaceSize()); + DeviceMem b_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpaceSize()); + DeviceMem e_device_buf(sizeof(EDataType) * e_m_n_device_result.mDesc.GetElementSpaceSize()); + + a_device_buf.ToDevice(a_m_k.mData.data()); + b_device_buf.ToDevice(b_k_n.mData.data()); + + auto a_element_op = PassThrough{}; + auto b_element_op = PassThrough{}; + auto cde_element_op = CDEElementOp{requant_scale, ActivationOp{}}; + + // do GEMM + auto gemm = DeviceGemmInstance{}; + auto invoker = gemm.MakeInvoker(); + auto argument = gemm.MakeArgument(a_device_buf.GetDeviceBuffer(), + b_device_buf.GetDeviceBuffer(), + {}, + e_device_buf.GetDeviceBuffer(), + M, + N, + K, + StrideA, + StrideB, + {}, + StrideE, + a_element_op, + b_element_op, + cde_element_op); + + if(!gemm.IsSupportedArgument(argument)) + { + throw std::runtime_error( + "wrong! device_gemm with the specified compilation parameters does " + "not support this GEMM problem"); + } + + float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); + + std::size_t flop = std::size_t(2) * M * N * K; + std::size_t num_btype = + sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + sizeof(EDataType) * M * N; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " + << gemm.GetTypeString() << std::endl; + + e_device_buf.FromDevice(e_m_n_device_result.mData.data()); + + if(do_verification) + { + auto ref_gemm = ReferenceGemmInstance{}; + auto ref_invoker = ref_gemm.MakeInvoker(); + + auto ref_argument = ref_gemm.MakeArgument( + a_m_k, b_k_n, e_m_n_host_result, a_element_op, b_element_op, cde_element_op); + + ref_invoker.Run(ref_argument); + + return ck::utils::check_err(e_m_n_device_result, e_m_n_host_result) ? 0 : 1; + } + + return 0; +} diff --git a/example/14_gemm_xdl_quantization/CMakeLists.txt b/example/14_gemm_xdl_quantization/CMakeLists.txt deleted file mode 100644 index 9674aba2a4..0000000000 --- a/example/14_gemm_xdl_quantization/CMakeLists.txt +++ /dev/null @@ -1 +0,0 @@ -add_example_executable(example_gemm_xdl_relu_quantization_int8 gemm_xdl_relu_quantization_int8.cpp) \ No newline at end of file diff --git a/example/14_gemm_xdl_quantization/gemm_xdl_relu_quantization_int8.cpp b/example/14_gemm_xdl_quantization/gemm_xdl_relu_quantization_int8.cpp deleted file mode 100644 index bb50a90804..0000000000 --- a/example/14_gemm_xdl_quantization/gemm_xdl_relu_quantization_int8.cpp +++ /dev/null @@ -1,233 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. - -#include -#include -#include -#include - -#include "ck/ck.hpp" -#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" -#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" -#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle.hpp" -#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" - -#include "ck/library/utility/device_memory.hpp" -#include "ck/library/utility/host_tensor.hpp" -#include "ck/library/utility/host_tensor_generator.hpp" -#include "ck/library/utility/literals.hpp" -#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" -#include "ck/library/utility/check_err.hpp" - -template -using S = ck::Sequence; - -using PassThrough = ck::tensor_operation::element_wise::PassThrough; -using ActivationOp = ck::tensor_operation::element_wise::Relu; -using CElementOp = ck::tensor_operation::element_wise::Activation_Mul_Clamp; - -using ADataType = int8_t; -using BDataType = int8_t; -using CDataType = int8_t; -using AccDataType = int32_t; -using CShuffleDataType = float; - -using ALayout = ck::tensor_layout::gemm::RowMajor; -using BLayout = ck::tensor_layout::gemm::ColumnMajor; -using CLayout = ck::tensor_layout::gemm::RowMajor; - -static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; - -// clang-format off -using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemm_Xdl_CShuffle< - ALayout, // typename ALayout, - BLayout, // typename BLayout, - CLayout, // typename CLayout, - ADataType, // typename ADataType, - BDataType, // typename BDataType, - CDataType, // typename CDataType, - AccDataType, // typename GemmAccDataType, - CShuffleDataType, // typename CShuffleDataType, - PassThrough, // typename AElementwiseOperation, - PassThrough, // typename BElementwiseOperation, - CElementOp, // typename CElementwiseOperation, - GemmDefault, // GemmSpecialization GemmSpec, - 1, // index_t NumGemmKPrefetchStage, - 256, // index_t BlockSize, - 256, // index_t MPerBlock, - 128, // index_t NPerBlock, - 64, // index_t KPerBlock, - 16, // index_t AK1, - 16, // index_t BK1, - 32, // index_t MPerXDL, - 32, // index_t NPerXDL, - 4, // index_t MXdlPerWave, - 2, // index_t NXdlPerWave, - S<4, 64, 1>, // typename ABlockTransferThreadClusterLengths_AK0_M_AK1, - S<1, 0, 2>, // typename ABlockTransferThreadClusterArrangeOrder, - S<1, 0, 2>, // typename ABlockTransferSrcAccessOrder, - 2, // index_t ABlockTransferSrcVectorDim, - 16, // index_t ABlockTransferSrcScalarPerVector, - 16, // index_t ABlockTransferDstScalarPerVector_AK1, - 1, // bool ABlockLdsExtraM, - S<4, 64, 1>, // typename BBlockTransferThreadClusterLengths_BK0_N_BK1, - S<1, 0, 2>, // typename BBlockTransferThreadClusterArrangeOrder, - S<1, 0, 2>, // typename BBlockTransferSrcAccessOrder, - 2, // index_t BBlockTransferSrcVectorDim, - 8, // index_t BBlockTransferSrcScalarPerVector, - 8, // index_t BBlockTransferDstScalarPerVector_BK1, - 1, // bool BBlockLdsExtraN, - 1, // index_t CShuffleMXdlPerWavePerShuffle, - 1, // index_t CShuffleNXdlPerWavePerShuffle, - S<1, 64, 1, 4>, // typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, - 16>; // index_t CShuffleBlockTransferScalarPerVector_NPerBlock> -// clang-format on - -using ReferenceGemmInstance = ck::tensor_operation::host:: - ReferenceGemm; - -int main(int argc, char* argv[]) -{ - bool do_verification = true; - int init_method = 1; - bool time_kernel = false; - - // GEMM shape - ck::index_t M = 3840; - ck::index_t N = 4096; - ck::index_t K = 4096; - - ck::index_t StrideA = 4096; - ck::index_t StrideB = 4096; - ck::index_t StrideC = 4096; - - float quant_multiplier = 0.03; - - if(argc == 4) - { - do_verification = std::stoi(argv[1]); - init_method = std::stoi(argv[2]); - time_kernel = std::stoi(argv[3]); - } - else if(argc == 10) - { - do_verification = std::stoi(argv[1]); - init_method = std::stoi(argv[2]); - time_kernel = std::stoi(argv[3]); - - M = std::stoi(argv[4]); - N = std::stoi(argv[5]); - K = std::stoi(argv[6]); - - StrideA = std::stoi(argv[7]); - StrideB = std::stoi(argv[8]); - StrideC = std::stoi(argv[9]); - } - else - { - printf("arg1: verification (0=no, 1=yes)\n"); - printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); - printf("arg3: time kernel (0=n0, 1=yes)\n"); - printf("arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideC\n"); - exit(0); - } - - auto f_host_tensor_descriptor = - [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { - using namespace ck::literals; - - if(std::is_same::value) - { - return HostTensorDescriptor({row, col}, {stride, 1_uz}); - } - else - { - return HostTensorDescriptor({row, col}, {1_uz, stride}); - } - }; - - Tensor a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{})); - Tensor b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{})); - Tensor c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); - Tensor c_m_n_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); - - std::cout << "a_m_k: " << a_m_k.mDesc << std::endl; - std::cout << "b_k_n: " << b_k_n.mDesc << std::endl; - std::cout << "c_m_n: " << c_m_n_host_result.mDesc << std::endl; - - switch(init_method) - { - case 0: break; - case 1: - a_m_k.GenerateTensorValue(GeneratorTensor_2{-5, 5}); - b_k_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}); - break; - default: - a_m_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); - b_k_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); - } - - DeviceMem a_m_k_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpaceSize()); - DeviceMem b_k_n_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpaceSize()); - DeviceMem c_m_n_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpaceSize()); - - a_m_k_device_buf.ToDevice(a_m_k.mData.data()); - b_k_n_device_buf.ToDevice(b_k_n.mData.data()); - - auto a_element_op = PassThrough{}; - auto b_element_op = PassThrough{}; - auto c_element_op = CElementOp{quant_multiplier, ActivationOp{}}; - - // do GEMM - auto gemm = DeviceGemmInstance{}; - auto invoker = gemm.MakeInvoker(); - auto argument = gemm.MakeArgument(static_cast(a_m_k_device_buf.GetDeviceBuffer()), - static_cast(b_k_n_device_buf.GetDeviceBuffer()), - static_cast(c_m_n_device_buf.GetDeviceBuffer()), - M, - N, - K, - StrideA, - StrideB, - StrideC, - a_element_op, - b_element_op, - c_element_op); - - if(!gemm.IsSupportedArgument(argument)) - { - throw std::runtime_error( - "wrong! device_gemm with the specified compilation parameters does " - "not support this GEMM problem"); - } - - float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); - - std::size_t flop = std::size_t(2) * M * N * K; - std::size_t num_btype = - sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + sizeof(CDataType) * M * N; - - float tflops = static_cast(flop) / 1.E9 / ave_time; - - float gb_per_sec = num_btype / 1.E6 / ave_time; - - std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " - << gemm.GetTypeString() << std::endl; - - c_m_n_device_buf.FromDevice(c_m_n_device_result.mData.data()); - - if(do_verification) - { - auto ref_gemm = ReferenceGemmInstance{}; - auto ref_invoker = ref_gemm.MakeInvoker(); - - auto ref_argument = ref_gemm.MakeArgument( - a_m_k, b_k_n, c_m_n_host_result, a_element_op, b_element_op, c_element_op); - - ref_invoker.Run(ref_argument); - - return ck::utils::check_err(c_m_n_device_result, c_m_n_host_result) ? 0 : 1; - } - - return 0; -} diff --git a/example/44_conv2d_fwd_quant/CMakeLists.txt b/example/44_conv2d_fwd_quantization/CMakeLists.txt similarity index 65% rename from example/44_conv2d_fwd_quant/CMakeLists.txt rename to example/44_conv2d_fwd_quantization/CMakeLists.txt index 1ecf89ccb8..f02e5110d0 100644 --- a/example/44_conv2d_fwd_quant/CMakeLists.txt +++ b/example/44_conv2d_fwd_quantization/CMakeLists.txt @@ -1,2 +1,3 @@ +add_example_executable(example_conv2d_fwd_xdl_perchannel_quantization_int8 conv2d_fwd_xdl_bias_relu_perchannel_quantization_int8.cpp) add_example_executable(example_conv2d_fwd_xdl_perlayer_quantization_int8 conv2d_fwd_xdl_perlayer_quantization_int8.cpp) add_example_executable(example_conv2d_fwd_xdl_bias_relu_perlayer_quantization_int8 conv2d_fwd_xdl_bias_relu_perlayer_quantization_int8.cpp) diff --git a/example/44_conv2d_fwd_quantization/conv2d_fwd_xdl_bias_relu_perchannel_quantization_int8.cpp b/example/44_conv2d_fwd_quantization/conv2d_fwd_xdl_bias_relu_perchannel_quantization_int8.cpp new file mode 100644 index 0000000000..832665edc0 --- /dev/null +++ b/example/44_conv2d_fwd_quantization/conv2d_fwd_xdl_bias_relu_perchannel_quantization_int8.cpp @@ -0,0 +1,342 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_xdl_cshuffle.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/utility/algorithm.hpp" +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/literals.hpp" +#include "ck/library/utility/convolution_parameter.hpp" +#include "ck/library/utility/convolution_host_tensor_descriptor_helper.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_conv_fwd.hpp" + +using InDataType = int8_t; +using WeiDataType = int8_t; +using BiasDataType = int32_t; +using RequantScaleDataType = float; +using AccDataType = int32_t; +using CShuffleDataType = int32_t; +using OutDataType = int8_t; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using InElementOp = PassThrough; +using WeiElementOp = PassThrough; +using ActivationOp = ck::tensor_operation::element_wise::Relu; +using OutElementOp = ck::tensor_operation::element_wise::Add_Activation_Mul2_Clamp; + +static constexpr auto ConvSpec = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Default; + +static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +template +using DeviceGroupedConvNDFwdInstance = + ck::tensor_operation::device::DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< + NDimSpatial, + InLayout, + WeiLayout, + ck::Tuple, + OutLayout, + InDataType, + WeiDataType, + AccDataType, + CShuffleDataType, + ck::Tuple, + OutDataType, + InElementOp, + WeiElementOp, + OutElementOp, + ConvSpec, // ConvForwardSpecialization + GemmSpec, // GemmSpecialization + 1, // + 256, // BlockSize + 128, // MPerBlock + 256, // NPerBlock + 64, // KPerBlock + 16, // AK1 + 16, // BK1 + 32, // MPerXdl + 32, // NPerXdl + 2, // MXdlPerWave + 4, // NXdlPerWave + S<4, 64, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1 + S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // ABlockTransferSrcAccessOrder + 2, // ABlockTransferSrcVectorDim + 16, // ABlockTransferSrcScalarPerVector + 16, // ABlockTransferDstScalarPerVector_AK1 + 1, // ABlockLdsExtraM + S<4, 64, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1 + S<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // BBlockTransferSrcAccessOrder + 2, // BBlockTransferSrcVectorDim + 16, // BBlockTransferSrcScalarPerVector + 16, // BBlockTransferDstScalarPerVector_BK1 + 1, // BBlockLdsExtraN + 1, + 1, + S<1, 64, 1, 4>, + 8>; + +template +bool run_grouped_conv_fwd(bool do_verification, + bool time_kernel, + const ck::utils::conv::ConvParam& conv_param, + const HostTensorDescriptor& in_g_n_c_wis_desc, + const HostTensorDescriptor& wei_g_k_c_xs_desc, + const HostTensorDescriptor& bias_g_k_desc, + const HostTensorDescriptor& requant_scale_g_k_desc, + const HostTensorDescriptor& out_g_n_k_wos_desc, + const InElementOp& in_element_op, + const WeiElementOp& wei_element_op, + const OutElementOp& out_element_op) +{ + Tensor in(in_g_n_c_wis_desc); + Tensor wei(wei_g_k_c_xs_desc); + Tensor bias(bias_g_k_desc); + Tensor requant_scale(requant_scale_g_k_desc); + Tensor out_host(out_g_n_k_wos_desc); + Tensor out_device(out_g_n_k_wos_desc); + + std::cout << "in: " << in.mDesc << std::endl; + std::cout << "wei: " << wei.mDesc << std::endl; + std::cout << "bias: " << bias.mDesc << std::endl; + std::cout << "requant_scale: " << requant_scale.mDesc << std::endl; + std::cout << "out: " << out_host.mDesc << std::endl; + + in.GenerateTensorValue(GeneratorTensor_2{-128, 127}); + wei.GenerateTensorValue(GeneratorTensor_2{-128, 127}); + bias.GenerateTensorValue(GeneratorTensor_2{-128, 127}); + requant_scale.GenerateTensorValue(GeneratorTensor_2{0, 1}); + + DeviceMem in_device_buf(sizeof(InDataType) * in.mDesc.GetElementSpaceSize()); + DeviceMem wei_device_buf(sizeof(WeiDataType) * wei.mDesc.GetElementSpaceSize()); + DeviceMem bias_device_buf(sizeof(BiasDataType) * bias.mDesc.GetElementSpaceSize()); + DeviceMem requant_scale_device_buf(sizeof(RequantScaleDataType) * + requant_scale.mDesc.GetElementSpaceSize()); + DeviceMem out_device_buf(sizeof(OutDataType) * out_device.mDesc.GetElementSpaceSize()); + + in_device_buf.ToDevice(in.mData.data()); + wei_device_buf.ToDevice(wei.mData.data()); + bias_device_buf.ToDevice(bias.mData.data()); + requant_scale_device_buf.ToDevice(requant_scale.mData.data()); + + std::array a_g_n_c_wis_lengths{}; + std::array a_g_n_c_wis_strides{}; + std::array b_g_k_c_xs_lengths{}; + std::array b_g_k_c_xs_strides{}; + std::array d0_g_n_k_wos_lengths{}; + std::array d0_g_n_k_wos_strides{}; + std::array d1_g_n_k_wos_lengths{}; + std::array d1_g_n_k_wos_strides{}; + std::array e_g_n_k_wos_lengths{}; + std::array e_g_n_k_wos_strides{}; + std::array conv_filter_strides{}; + std::array conv_filter_dilations{}; + std::array input_left_pads{}; + std::array input_right_pads{}; + + auto copy = [](const auto& x, auto& y) { ck::ranges::copy(x, y.begin()); }; + + copy(in_g_n_c_wis_desc.GetLengths(), a_g_n_c_wis_lengths); + copy(in_g_n_c_wis_desc.GetStrides(), a_g_n_c_wis_strides); + copy(wei_g_k_c_xs_desc.GetLengths(), b_g_k_c_xs_lengths); + copy(wei_g_k_c_xs_desc.GetStrides(), b_g_k_c_xs_strides); + copy(bias_g_k_desc.GetLengths(), d0_g_n_k_wos_lengths); + copy(bias_g_k_desc.GetStrides(), d0_g_n_k_wos_strides); + copy(requant_scale_g_k_desc.GetLengths(), d1_g_n_k_wos_lengths); + copy(requant_scale_g_k_desc.GetStrides(), d1_g_n_k_wos_strides); + copy(out_g_n_k_wos_desc.GetLengths(), e_g_n_k_wos_lengths); + copy(out_g_n_k_wos_desc.GetStrides(), e_g_n_k_wos_strides); + copy(conv_param.conv_filter_strides_, conv_filter_strides); + copy(conv_param.conv_filter_dilations_, conv_filter_dilations); + copy(conv_param.input_left_pads_, input_left_pads); + copy(conv_param.input_right_pads_, input_right_pads); + + // do Conv + auto conv = DeviceConvNDFwdInstance{}; + auto invoker = conv.MakeInvoker(); + auto argument = conv.MakeArgument( + in_device_buf.GetDeviceBuffer(), + wei_device_buf.GetDeviceBuffer(), + {bias_device_buf.GetDeviceBuffer(), requant_scale_device_buf.GetDeviceBuffer()}, + out_device_buf.GetDeviceBuffer(), + a_g_n_c_wis_lengths, + a_g_n_c_wis_strides, + b_g_k_c_xs_lengths, + b_g_k_c_xs_strides, + {d0_g_n_k_wos_lengths, d1_g_n_k_wos_lengths}, + {d0_g_n_k_wos_strides, d1_g_n_k_wos_strides}, + e_g_n_k_wos_lengths, + e_g_n_k_wos_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + in_element_op, + wei_element_op, + out_element_op); + + if(!conv.IsSupportedArgument(argument)) + { + throw std::runtime_error( + "wrong! device_conv with the specified compilation parameters does " + "not support this Conv problem"); + } + + float avg_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); + + std::size_t flop = conv_param.GetFlops(); + std::size_t num_btype = conv_param.GetByte(); + + float tflops = static_cast(flop) / 1.E9 / avg_time; + float gb_per_sec = num_btype / 1.E6 / avg_time; + std::cout << "Perf: " << avg_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " + << conv.GetTypeString() << std::endl; + + bool pass = true; + + if(do_verification) + { + Tensor c_host(out_g_n_k_wos_desc); + + auto ref_conv = ck::tensor_operation::host::ReferenceConvFwd(); + + auto ref_invoker = ref_conv.MakeInvoker(); + auto ref_argument = ref_conv.MakeArgument(in, + wei, + c_host, + conv_param.conv_filter_strides_, + conv_param.conv_filter_dilations_, + conv_param.input_left_pads_, + conv_param.input_right_pads_, + in_element_op, + wei_element_op, + PassThrough{}); + + ref_invoker.Run(ref_argument); + + // TODO: implement elementwise operation for host + out_host.ForEach([&](auto&, auto idx) { + out_element_op(out_host(idx), c_host(idx), bias(idx), requant_scale(idx)); + }); + + out_device_buf.FromDevice(out_device.mData.data()); + + pass &= + ck::utils::check_err(out_device, out_host, "Error: incorrect results!", 1e-5f, 1e-4f); + } + + return (pass ? 0 : 1); +} + +int main() +{ + bool do_verification = true; + bool time_kernel = true; + const ck::index_t ndim_spatial = 2; + + ck::utils::conv::ConvParam conv_param{ + ndim_spatial, // n_dim + 1, // group + 4, // batch + 64, // output channels + 32, // input chanels + {3, 3}, // weight HW + {71, 71}, // x HW + {2, 2}, // strides + {1, 1}, // dilations + {1, 1}, // left_pads + {1, 1} // right_pads + }; + + const auto in_element_op = InElementOp{}; + const auto wei_element_op = WeiElementOp{}; + const auto out_element_op = OutElementOp{ActivationOp{}}; + + using InLayout = ck::tensor_layout::convolution::GNHWC; + using WeiLayout = ck::tensor_layout::convolution::GKYXC; + using BiasLayout = ck::tensor_layout::convolution::G_K; + using RequantScaleLayout = ck::tensor_layout::convolution::G_K; + using OutLayout = ck::tensor_layout::convolution::GNHWK; + + const auto in_g_n_c_wis_desc = + ck::utils::conv::make_input_host_tensor_descriptor_g_n_c_wis_packed(conv_param); + + const auto wei_g_k_c_xs_desc = + ck::utils::conv::make_weight_host_tensor_descriptor_g_k_c_xs_packed(conv_param); + + // TODO - make_bias_host_tensor_descriptor_g_n_k_wos_packed() + const auto bias_g_k_desc = HostTensorDescriptor({conv_param.G_, + conv_param.N_, + conv_param.K_, + conv_param.output_spatial_lengths_[0], + conv_param.output_spatial_lengths_[1]}, + { + conv_param.K_, // g + 0, // n + 1, // k + 0, // ho + 0 // wo + }); + + const auto requant_scale_g_k_desc = bias_g_k_desc; + + const auto out_g_n_k_wos_desc = + ck::utils::conv::make_output_host_tensor_descriptor_g_n_k_wos_packed(conv_param); + + std::cout << out_g_n_k_wos_desc << std::endl; + + using deviceOp = DeviceGroupedConvNDFwdInstance; + + return run_grouped_conv_fwd(do_verification, + time_kernel, + conv_param, + in_g_n_c_wis_desc, + wei_g_k_c_xs_desc, + bias_g_k_desc, + requant_scale_g_k_desc, + out_g_n_k_wos_desc, + in_element_op, + wei_element_op, + out_element_op); +} diff --git a/example/44_conv2d_fwd_quant/conv2d_fwd_xdl_bias_relu_perlayer_quantization_int8.cpp b/example/44_conv2d_fwd_quantization/conv2d_fwd_xdl_bias_relu_perlayer_quantization_int8.cpp similarity index 90% rename from example/44_conv2d_fwd_quant/conv2d_fwd_xdl_bias_relu_perlayer_quantization_int8.cpp rename to example/44_conv2d_fwd_quantization/conv2d_fwd_xdl_bias_relu_perlayer_quantization_int8.cpp index 51315de7ed..f540135035 100644 --- a/example/44_conv2d_fwd_quant/conv2d_fwd_xdl_bias_relu_perlayer_quantization_int8.cpp +++ b/example/44_conv2d_fwd_quantization/conv2d_fwd_xdl_bias_relu_perlayer_quantization_int8.cpp @@ -11,6 +11,7 @@ #include "ck/library/utility/device_memory.hpp" #include "ck/library/utility/host_tensor.hpp" #include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/literals.hpp" #include "ck/library/utility/convolution_parameter.hpp" #include "ck/library/utility/convolution_host_tensor_descriptor_helper.hpp" #include "ck/library/reference_tensor_operation/cpu/reference_conv_fwd.hpp" @@ -163,26 +164,25 @@ bool run_grouped_conv_fwd(bool do_verification, // do Conv auto conv = DeviceConvNDFwdInstance{}; auto invoker = conv.MakeInvoker(); - auto argument = conv.MakeArgument( - in_device_buf.GetDeviceBuffer(), - wei_device_buf.GetDeviceBuffer(), - std::array{bias_device_buf.GetDeviceBuffer()}, - out_device_buf.GetDeviceBuffer(), - a_g_n_c_wis_lengths, - a_g_n_c_wis_strides, - b_g_k_c_xs_lengths, - b_g_k_c_xs_strides, - std::array, 1>{{d0_g_n_k_wos_lengths}}, - std::array, 1>{{d0_g_n_k_wos_strides}}, - e_g_n_k_wos_lengths, - e_g_n_k_wos_strides, - conv_filter_strides, - conv_filter_dilations, - input_left_pads, - input_right_pads, - in_element_op, - wei_element_op, - out_element_op); + auto argument = conv.MakeArgument(in_device_buf.GetDeviceBuffer(), + wei_device_buf.GetDeviceBuffer(), + {bias_device_buf.GetDeviceBuffer()}, + out_device_buf.GetDeviceBuffer(), + a_g_n_c_wis_lengths, + a_g_n_c_wis_strides, + b_g_k_c_xs_lengths, + b_g_k_c_xs_strides, + {d0_g_n_k_wos_lengths}, + {d0_g_n_k_wos_strides}, + e_g_n_k_wos_lengths, + e_g_n_k_wos_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + in_element_op, + wei_element_op, + out_element_op); if(!conv.IsSupportedArgument(argument)) { @@ -235,8 +235,8 @@ bool run_grouped_conv_fwd(bool do_verification, out_device_buf.FromDevice(out_device.mData.data()); - pass &= ck::utils::check_err( - out_device.mData, out_host.mData, "Error: incorrect results!", 1e-5f, 1e-4f); + pass &= + ck::utils::check_err(out_device, out_host, "Error: incorrect results!", 1e-5f, 1e-4f); } return (pass ? 0 : 1); diff --git a/example/44_conv2d_fwd_quant/conv2d_fwd_xdl_perlayer_quantization_int8.cpp b/example/44_conv2d_fwd_quantization/conv2d_fwd_xdl_perlayer_quantization_int8.cpp similarity index 96% rename from example/44_conv2d_fwd_quant/conv2d_fwd_xdl_perlayer_quantization_int8.cpp rename to example/44_conv2d_fwd_quantization/conv2d_fwd_xdl_perlayer_quantization_int8.cpp index fa7f34cef0..2d46d86655 100644 --- a/example/44_conv2d_fwd_quant/conv2d_fwd_xdl_perlayer_quantization_int8.cpp +++ b/example/44_conv2d_fwd_quantization/conv2d_fwd_xdl_perlayer_quantization_int8.cpp @@ -11,6 +11,7 @@ #include "ck/library/utility/device_memory.hpp" #include "ck/library/utility/host_tensor.hpp" #include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/literals.hpp" #include "ck/library/utility/convolution_parameter.hpp" #include "ck/library/utility/convolution_host_tensor_descriptor_helper.hpp" #include "ck/library/reference_tensor_operation/cpu/reference_conv_fwd.hpp" @@ -150,14 +151,14 @@ bool run_grouped_conv_fwd(bool do_verification, auto invoker = conv.MakeInvoker(); auto argument = conv.MakeArgument(in_device_buf.GetDeviceBuffer(), wei_device_buf.GetDeviceBuffer(), - std::array{}, + {}, out_device_buf.GetDeviceBuffer(), a_g_n_c_wis_lengths, a_g_n_c_wis_strides, b_g_k_c_xs_lengths, b_g_k_c_xs_strides, - std::array, 0>{{}}, - std::array, 0>{{}}, + {}, + {}, e_g_n_k_wos_lengths, e_g_n_k_wos_strides, conv_filter_strides, @@ -213,8 +214,8 @@ bool run_grouped_conv_fwd(bool do_verification, out_device_buf.FromDevice(out_device.mData.data()); - pass &= ck::utils::check_err( - out_device.mData, out_host.mData, "Error: incorrect results!", 1e-5f, 1e-4f); + pass &= + ck::utils::check_err(out_device, out_host, "Error: incorrect results!", 1e-5f, 1e-4f); } return (pass ? 0 : 1); diff --git a/include/ck/tensor_operation/gpu/element/quantization_operation.hpp b/include/ck/tensor_operation/gpu/element/quantization_operation.hpp index f27b61ba53..3f2c2f8773 100644 --- a/include/ck/tensor_operation/gpu/element/quantization_operation.hpp +++ b/include/ck/tensor_operation/gpu/element/quantization_operation.hpp @@ -10,8 +10,8 @@ namespace element_wise { template struct Activation_Mul_Clamp { - Activation_Mul_Clamp(float multiplier, Activation activationOp) - : multiplier_(multiplier), activationOp_(activationOp) + Activation_Mul_Clamp(float requantScale, Activation activationOp) + : requantScale_(requantScale), activationOp_(activationOp) { } @@ -19,7 +19,7 @@ struct Activation_Mul_Clamp { float x_fp32 = ck::type_convert(x); activationOp_(x_fp32, x_fp32); - float y_fp32 = math::clamp(multiplier_ * x_fp32, -128.f, 127.f); + float y_fp32 = math::clamp(requantScale_ * x_fp32, -128.f, 127.f); y = ck::type_convert(y_fp32); } @@ -28,10 +28,29 @@ struct Activation_Mul_Clamp // We might type_convert to int8 after lambda in someplace float x_fp32 = ck::type_convert(x); activationOp_(x_fp32, x_fp32); - y = math::clamp(multiplier_ * x_fp32, -128.f, 127.f); + y = math::clamp(requantScale_ * x_fp32, -128.f, 127.f); + } + + float requantScale_; + Activation activationOp_; +}; + +// Conv Perchannel quantization + Activation function which is piecewise linear function, such as +// relu, leaky relu ...etc +template +struct Activation_Mul2_Clamp +{ + Activation_Mul2_Clamp(Activation activationOp) : activationOp_(activationOp) {} + + __host__ __device__ constexpr void + operator()(int8_t& y, const int32_t& x, const float& requantScale) const + { + float y_fp32 = ck::type_convert(x); + activationOp_(y_fp32, y_fp32); + y_fp32 = math::clamp(requantScale * y_fp32, -128.f, 127.f); + y = ck::type_convert(y_fp32); } - float multiplier_; Activation activationOp_; }; @@ -39,21 +58,40 @@ struct Activation_Mul_Clamp template struct Add_Activation_Mul_Clamp { - Add_Activation_Mul_Clamp(float multiplier, Activation activationOp) - : multiplier_(multiplier), activationOp_(activationOp) + Add_Activation_Mul_Clamp(float requantScale, Activation activationOp) + : requantScale_(requantScale), activationOp_(activationOp) { } __host__ __device__ constexpr void - operator()(int8_t& y, const int32_t& x1, const int32_t& x2) const + operator()(int8_t& y, const int32_t& x, const int32_t& bias) const { - float y_fp32 = ck::type_convert(x1 + x2); + float y_fp32 = ck::type_convert(x + bias); activationOp_(y_fp32, y_fp32); - y_fp32 = math::clamp(multiplier_ * y_fp32, -128.f, 127.f); + y_fp32 = math::clamp(requantScale_ * y_fp32, -128.f, 127.f); + y = ck::type_convert(y_fp32); + } + + float requantScale_; + Activation activationOp_; +}; + +// Conv Perchannel quantization + Activation function which is piecewise linear function, such as +// relu, leaky relu ...etc +template +struct Add_Activation_Mul2_Clamp +{ + Add_Activation_Mul2_Clamp(Activation activationOp) : activationOp_(activationOp) {} + + __host__ __device__ constexpr void + operator()(int8_t& y, const int32_t& x, const int32_t& bias, const float& requantScale) const + { + float y_fp32 = ck::type_convert(x + bias); + activationOp_(y_fp32, y_fp32); + y_fp32 = math::clamp(requantScale * y_fp32, -128.f, 127.f); y = ck::type_convert(y_fp32); } - float multiplier_; Activation activationOp_; }; @@ -61,23 +99,23 @@ struct Add_Activation_Mul_Clamp template struct Add_Mul_Activation_Mul_Clamp { - Add_Mul_Activation_Mul_Clamp(float multiplier1, float multiplier2, Activation activationOp) - : multiplier1_(multiplier1), multiplier2_(multiplier2), activationOp_(activationOp) + Add_Mul_Activation_Mul_Clamp(float requantScale1, float requantScale2, Activation activationOp) + : requantScale1_(requantScale1), requantScale2_(requantScale2), activationOp_(activationOp) { } __host__ __device__ constexpr void - operator()(int8_t& y, const int32_t& x1, const int32_t& x2) const + operator()(int8_t& y, const int32_t& x, const int32_t& bias) const { - float y_fp32 = ck::type_convert(x1 + x2); - y_fp32 = multiplier1_ * y_fp32; + float y_fp32 = ck::type_convert(x + bias); + y_fp32 = requantScale1_ * y_fp32; activationOp_(y_fp32, y_fp32); - y_fp32 = math::clamp(multiplier2_ * y_fp32, -128.f, 127.f); + y_fp32 = math::clamp(requantScale2_ * y_fp32, -128.f, 127.f); y = ck::type_convert(y_fp32); } - float multiplier1_; - float multiplier2_; + float requantScale1_; + float requantScale2_; Activation activationOp_; }; diff --git a/library/include/ck/library/tensor_operation_instance/device_operation_instance_factory.hpp b/library/include/ck/library/tensor_operation_instance/device_operation_instance_factory.hpp index 785d5510f3..91980a9a66 100644 --- a/library/include/ck/library/tensor_operation_instance/device_operation_instance_factory.hpp +++ b/library/include/ck/library/tensor_operation_instance/device_operation_instance_factory.hpp @@ -26,9 +26,9 @@ using Empty_Tuple = ck::Tuple<>; using F16_Tuple = ck::Tuple; using F16_F16_Tuple = ck::Tuple; -using F32_Tuple = ck::Tuple; - -using I32_Tuple = ck::Tuple; +using F32_Tuple = ck::Tuple; +using I32_Tuple = ck::Tuple; +using I32_F32_Tuple = ck::Tuple; // GEMM layout using Row = ck::tensor_layout::gemm::RowMajor; @@ -78,8 +78,9 @@ using NHWGK = ck::tensor_layout::convolution::NHWGK; using NDHWGK = ck::tensor_layout::convolution::NDHWGK; // -using GK = ck::tensor_layout::convolution::G_K; -using GK_TUPLE = ck::Tuple; +using GK = ck::tensor_layout::convolution::G_K; +using GK_Tuple = ck::Tuple; +using GK_GK_Tuple = ck::Tuple; // pointwise functor using PassThrough = ck::tensor_operation::element_wise::PassThrough; @@ -97,6 +98,13 @@ template using Add_Activation_Mul_Clamp = ck::tensor_operation::element_wise::Add_Activation_Mul_Clamp; +template +using Activation_Mul2_Clamp = ck::tensor_operation::element_wise::Activation_Mul2_Clamp; + +template +using Add_Activation_Mul2_Clamp = + ck::tensor_operation::element_wise::Add_Activation_Mul2_Clamp; + template struct DeviceOperationInstanceFactory; diff --git a/library/include/ck/library/tensor_operation_instance/gpu/quantization/grouped_convolution_bias_forward_perchannel_quantization.hpp b/library/include/ck/library/tensor_operation_instance/gpu/quantization/grouped_convolution_bias_forward_perchannel_quantization.hpp new file mode 100644 index 0000000000..eda81a233c --- /dev/null +++ b/library/include/ck/library/tensor_operation_instance/gpu/quantization/grouped_convolution_bias_forward_perchannel_quantization.hpp @@ -0,0 +1,114 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_d.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// grouped conv2d forward, GNHWC/GKYXC/GNHWK +void add_device_conv2d_bias_perchannel_quantization_int8_instances( + std::vector< + std::unique_ptr>>>& + instances); + +void add_device_conv2d_bias_relu_perchannel_quantization_int8_instances( + std::vector>>>& + instances); + +template +struct DeviceOperationInstanceFactory>> +{ + using DeviceOp = DeviceGroupedConvFwdMultipleD>; + + static auto GetInstances() + { + std::vector> op_ptrs; + + if constexpr(NumDimSpatial == 2 && is_same_v && + is_same_v && is_same_v && + is_same_v) + { + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v) + { + if constexpr(is_same_v) + add_device_conv2d_bias_perchannel_quantization_int8_instances(op_ptrs); + else if constexpr(is_same_v) + add_device_conv2d_bias_relu_perchannel_quantization_int8_instances(op_ptrs); + } + } + + return op_ptrs; + } +}; + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_bias_forward_perlayer_quantization.hpp b/library/include/ck/library/tensor_operation_instance/gpu/quantization/grouped_convolution_bias_forward_perlayer_quantization.hpp similarity index 98% rename from library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_bias_forward_perlayer_quantization.hpp rename to library/include/ck/library/tensor_operation_instance/gpu/quantization/grouped_convolution_bias_forward_perlayer_quantization.hpp index 9d441d14d1..1138402638 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_bias_forward_perlayer_quantization.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/quantization/grouped_convolution_bias_forward_perlayer_quantization.hpp @@ -23,7 +23,7 @@ void add_device_conv2d_bias_perlayer_quantization_int8_instances( std::unique_ptr> op_ptrs; if constexpr(NumDimSpatial == 2 && is_same_v && - is_same_v && is_same_v && + is_same_v && is_same_v && is_same_v) { if constexpr(is_same_v && is_same_v && diff --git a/library/include/ck/library/tensor_operation_instance/gpu/quantization/grouped_convolution_forward_perchannel_quantization.hpp b/library/include/ck/library/tensor_operation_instance/gpu/quantization/grouped_convolution_forward_perchannel_quantization.hpp new file mode 100644 index 0000000000..1a67ce5688 --- /dev/null +++ b/library/include/ck/library/tensor_operation_instance/gpu/quantization/grouped_convolution_forward_perchannel_quantization.hpp @@ -0,0 +1,113 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_d.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// grouped conv2d forward, GNHWC/GKYXC/GNHWK +void add_device_conv2d_perchannel_quantization_int8_instances( + std::vector>>>& + instances); + +void add_device_conv2d_relu_perchannel_quantization_int8_instances( + std::vector>>>& + instances); + +template +struct DeviceOperationInstanceFactory>> +{ + using DeviceOp = DeviceGroupedConvFwdMultipleD>; + + static auto GetInstances() + { + std::vector> op_ptrs; + + if constexpr(NumDimSpatial == 2 && is_same_v && + is_same_v && is_same_v && + is_same_v) + { + if constexpr(is_same_v && is_same_v && + is_same_v) + { + if constexpr(is_same_v) + add_device_conv2d_perchannel_quantization_int8_instances(op_ptrs); + else if constexpr(is_same_v) + add_device_conv2d_relu_perchannel_quantization_int8_instances(op_ptrs); + } + } + + return op_ptrs; + } +}; + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_perlayer_quantization.hpp b/library/include/ck/library/tensor_operation_instance/gpu/quantization/grouped_convolution_forward_perlayer_quantization.hpp similarity index 100% rename from library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_perlayer_quantization.hpp rename to library/include/ck/library/tensor_operation_instance/gpu/quantization/grouped_convolution_forward_perlayer_quantization.hpp diff --git a/library/src/tensor_operation_instance/gpu/quantization/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/quantization/CMakeLists.txt index 8b2149aefb..9f826afd68 100644 --- a/library/src/tensor_operation_instance/gpu/quantization/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/quantization/CMakeLists.txt @@ -1,4 +1,6 @@ add_instance_library(device_quantization_instance - device_conv2d_xdl_bias_quant_int8_instance.cpp - device_conv2d_xdl_quant_int8_instance.cpp + device_conv2d_xdl_bias_perchannel_quantization_int8_instance.cpp + device_conv2d_xdl_bias_perlayer_quantization_int8_instance.cpp + device_conv2d_xdl_perchannel_quantization_int8_instance.cpp + device_conv2d_xdl_perlayer_quantization_int8_instance.cpp ) diff --git a/library/src/tensor_operation_instance/gpu/quantization/device_conv2d_xdl_bias_perchannel_quantization_int8_instance.cpp b/library/src/tensor_operation_instance/gpu/quantization/device_conv2d_xdl_bias_perchannel_quantization_int8_instance.cpp new file mode 100644 index 0000000000..e87e987593 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/quantization/device_conv2d_xdl_bias_perchannel_quantization_int8_instance.cpp @@ -0,0 +1,74 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include "device_conv2d_xdl_int8_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { +void add_device_conv2d_bias_perchannel_quantization_int8_instances( + std::vector>>& instances) +{ + add_device_operation_instances(instances, + device_conv2d_int8_32Ds_instances{}); + add_device_operation_instances(instances, + device_conv2d_int8_32Ds_instances{}); + add_device_operation_instances(instances, + device_conv2d_int8_32Ds_instances{}); +} + +void add_device_conv2d_bias_relu_perchannel_quantization_int8_instances( + std::vector>>& instances) +{ + add_device_operation_instances(instances, + device_conv2d_int8_32Ds_instances{}); + add_device_operation_instances(instances, + device_conv2d_int8_32Ds_instances{}); + add_device_operation_instances(instances, + device_conv2d_int8_32Ds_instances{}); +} +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/quantization/device_conv2d_xdl_bias_perlayer_quantization_int8_instance.cpp b/library/src/tensor_operation_instance/gpu/quantization/device_conv2d_xdl_bias_perlayer_quantization_int8_instance.cpp new file mode 100644 index 0000000000..06eed76014 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/quantization/device_conv2d_xdl_bias_perlayer_quantization_int8_instance.cpp @@ -0,0 +1,68 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include "device_conv2d_xdl_int8_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { +void add_device_conv2d_bias_perlayer_quantization_int8_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_conv2d_int8_32Ds_instances{}); + add_device_operation_instances( + instances, + device_conv2d_int8_32Ds_instances{}); + add_device_operation_instances( + instances, + device_conv2d_int8_32Ds_instances{}); +} + +void add_device_conv2d_bias_relu_perlayer_quantization_int8_instances( + std::vector>>& instances) +{ + add_device_operation_instances(instances, + device_conv2d_int8_32Ds_instances{}); + + add_device_operation_instances( + instances, + device_conv2d_int8_32Ds_instances{}); + + add_device_operation_instances(instances, + device_conv2d_int8_32Ds_instances{}); +} +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/quantization/device_conv2d_xdl_bias_quant_int8_instance.cpp b/library/src/tensor_operation_instance/gpu/quantization/device_conv2d_xdl_bias_quant_int8_instance.cpp deleted file mode 100644 index 774758fb69..0000000000 --- a/library/src/tensor_operation_instance/gpu/quantization/device_conv2d_xdl_bias_quant_int8_instance.cpp +++ /dev/null @@ -1,112 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. - -#include "ck/ck.hpp" -#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" -#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_xdl_cshuffle.hpp" -#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" -#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { -namespace instance { - -template -using S = ck::Sequence; - -using GNHWC = ck::tensor_layout::convolution::GNHWC; -using GKYXC = ck::tensor_layout::convolution::GKYXC; -using GNHWK = ck::tensor_layout::convolution::GNHWK; -using GK = ck::tensor_layout::convolution::G_K; -using PassThrough = ck::tensor_operation::element_wise::PassThrough; -using Relu = ck::tensor_operation::element_wise::Relu; - -using GK_Tuple = ck::Tuple; -using I32_Tuple = ck::Tuple; - -using Add_Mul_Clamp = ck::tensor_operation::element_wise::Add_Activation_Mul_Clamp; -using Add_Relu_Mul_Clamp = ck::tensor_operation::element_wise::Add_Activation_Mul_Clamp; - -static constexpr ck::index_t NDimSpatial = 2; -static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding; -static constexpr auto ConvFwdDefault = - ck::tensor_operation::device::ConvolutionForwardSpecialization::Default; -static constexpr auto ConvFwd1x1P0 = - ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Pad0; -static constexpr auto ConvFwd1x1S1P0 = - ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0; - -// TODO - Add more instances -template -// clang-format off -using device_conv2d_int8_instances = - std::tuple < - //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| - //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| - //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| - //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, GK_Tuple, GNHWK, int8_t, int8_t, int32_t, int32_t, I32_Tuple, int8_t, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmSpec, 1, 256, 256, 128, 64, 16, 16, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 64, 1, 4>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, GK_Tuple, GNHWK, int8_t, int8_t, int32_t, int32_t, I32_Tuple, int8_t, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmSpec, 1, 256, 128, 256, 64, 16, 16, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 64, 1, 4>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, GK_Tuple, GNHWK, int8_t, int8_t, int32_t, int32_t, I32_Tuple, int8_t, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmSpec, 1, 128, 128, 128, 64, 16, 16, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 4>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, GK_Tuple, GNHWK, int8_t, int8_t, int32_t, int32_t, I32_Tuple, int8_t, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmSpec, 1, 256, 128, 128, 64, 16, 16, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 64, 1, 4>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, GK_Tuple, GNHWK, int8_t, int8_t, int32_t, int32_t, I32_Tuple, int8_t, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmSpec, 1, 128, 128, 64, 64, 16, 16, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 64, 1, 2>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, GK_Tuple, GNHWK, int8_t, int8_t, int32_t, int32_t, I32_Tuple, int8_t, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmSpec, 1, 128, 64, 128, 64, 16, 16, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 4>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, GK_Tuple, GNHWK, int8_t, int8_t, int32_t, int32_t, I32_Tuple, int8_t, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmSpec, 1, 64, 64, 64, 64, 16, 16, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 2>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, GK_Tuple, GNHWK, int8_t, int8_t, int32_t, int32_t, I32_Tuple, int8_t, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmSpec, 1, 256, 128, 64, 64, 16, 16, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 64, 1, 4>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, GK_Tuple, GNHWK, int8_t, int8_t, int32_t, int32_t, I32_Tuple, int8_t, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmSpec, 1, 256, 64, 128, 64, 16, 16, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 64, 1, 4>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, GK_Tuple, GNHWK, int8_t, int8_t, int32_t, int32_t, I32_Tuple, int8_t, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmSpec, 1, 128, 128, 32, 64, 16, 16, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 64, 1, 2>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, GK_Tuple, GNHWK, int8_t, int8_t, int32_t, int32_t, I32_Tuple, int8_t, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmSpec, 1, 128, 32, 128, 64, 16, 16, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 4>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, GK_Tuple, GNHWK, int8_t, int8_t, int32_t, int32_t, I32_Tuple, int8_t, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmSpec, 1, 64, 64, 32, 64, 16, 16, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 2>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, GK_Tuple, GNHWK, int8_t, int8_t, int32_t, int32_t, I32_Tuple, int8_t, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmSpec, 1, 64, 32, 64, 64, 16, 16, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 2>, 8> - >; -// clang-format on - -void add_device_conv2d_bias_perlayer_quantization_int8_instances( - std::vector, - GNHWK, - int8_t, - int8_t, - ck::Tuple, - int8_t, - PassThrough, - PassThrough, - Add_Mul_Clamp>>>& instances) -{ - add_device_operation_instances(instances, - device_conv2d_int8_instances{}); - add_device_operation_instances(instances, - device_conv2d_int8_instances{}); - add_device_operation_instances(instances, - device_conv2d_int8_instances{}); -} - -void add_device_conv2d_bias_relu_perlayer_quantization_int8_instances( - std::vector, - GNHWK, - int8_t, - int8_t, - ck::Tuple, - int8_t, - PassThrough, - PassThrough, - Add_Relu_Mul_Clamp>>>& instances) -{ - add_device_operation_instances( - instances, device_conv2d_int8_instances{}); - add_device_operation_instances( - instances, device_conv2d_int8_instances{}); - add_device_operation_instances( - instances, device_conv2d_int8_instances{}); -} - -} // namespace instance -} // namespace device -} // namespace tensor_operation -} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/quantization/device_conv2d_xdl_int8_instance.hpp b/library/src/tensor_operation_instance/gpu/quantization/device_conv2d_xdl_int8_instance.hpp new file mode 100644 index 0000000000..6904e269f9 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/quantization/device_conv2d_xdl_int8_instance.hpp @@ -0,0 +1,111 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_xdl_cshuffle.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using Empty_Tuple = ck::Tuple<>; +template +using S = ck::Sequence; + +using GNHWC = ck::tensor_layout::convolution::GNHWC; +using GKYXC = ck::tensor_layout::convolution::GKYXC; +using GNHWK = ck::tensor_layout::convolution::GNHWK; +using GK = ck::tensor_layout::convolution::G_K; +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using Relu = ck::tensor_operation::element_wise::Relu; + +using GK_Tuple = ck::Tuple; +using GK_GK_Tuple = ck::Tuple; +using I32_Tuple = ck::Tuple; +using F32_Tuple = ck::Tuple; +using I32_F32_Tuple = ck::Tuple; + +using Mul_Clamp = ck::tensor_operation::element_wise::Activation_Mul_Clamp; +using Relu_Mul_Clamp = ck::tensor_operation::element_wise::Activation_Mul_Clamp; + +using Add_Mul_Clamp = ck::tensor_operation::element_wise::Add_Activation_Mul_Clamp; +using Add_Relu_Mul_Clamp = ck::tensor_operation::element_wise::Add_Activation_Mul_Clamp; + +using Mul2_Clamp = ck::tensor_operation::element_wise::Activation_Mul2_Clamp; +using Relu_Mul2_Clamp = ck::tensor_operation::element_wise::Activation_Mul2_Clamp; + +using Add_Mul2_Clamp = ck::tensor_operation::element_wise::Add_Activation_Mul2_Clamp; +using Add_Relu_Mul2_Clamp = ck::tensor_operation::element_wise::Add_Activation_Mul2_Clamp; + +static constexpr ck::index_t NDimSpatial = 2; +static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding; +static constexpr auto ConvFwdDefault = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Default; +static constexpr auto ConvFwd1x1P0 = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Pad0; +static constexpr auto ConvFwd1x1S1P0 = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0; + +template +// clang-format off +using device_conv2d_int8_instances = + std::tuple < + //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, DsLayout, GNHWK, int8_t, int8_t, int32_t, int32_t, DsDatatype, int8_t, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmSpec, 1, 256, 256, 128, 64, 16, 16, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 64, 1, 4>, 16>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, DsLayout, GNHWK, int8_t, int8_t, int32_t, int32_t, DsDatatype, int8_t, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmSpec, 1, 256, 128, 256, 64, 16, 16, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 64, 1, 4>, 16>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, DsLayout, GNHWK, int8_t, int8_t, int32_t, int32_t, DsDatatype, int8_t, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmSpec, 1, 128, 128, 128, 64, 16, 16, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 4>, 16>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, DsLayout, GNHWK, int8_t, int8_t, int32_t, int32_t, DsDatatype, int8_t, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmSpec, 1, 256, 128, 128, 64, 16, 16, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 64, 1, 4>, 16>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, DsLayout, GNHWK, int8_t, int8_t, int32_t, int32_t, DsDatatype, int8_t, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmSpec, 1, 128, 128, 64, 64, 16, 16, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 64, 1, 2>, 16>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, DsLayout, GNHWK, int8_t, int8_t, int32_t, int32_t, DsDatatype, int8_t, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmSpec, 1, 128, 64, 128, 64, 16, 16, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 4>, 16>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, DsLayout, GNHWK, int8_t, int8_t, int32_t, int32_t, DsDatatype, int8_t, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmSpec, 1, 64, 64, 64, 64, 16, 16, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 2>, 16>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, DsLayout, GNHWK, int8_t, int8_t, int32_t, int32_t, DsDatatype, int8_t, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmSpec, 1, 256, 128, 64, 64, 16, 16, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 64, 1, 4>, 16>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, DsLayout, GNHWK, int8_t, int8_t, int32_t, int32_t, DsDatatype, int8_t, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmSpec, 1, 256, 64, 128, 64, 16, 16, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 64, 1, 4>, 16>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, DsLayout, GNHWK, int8_t, int8_t, int32_t, int32_t, DsDatatype, int8_t, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmSpec, 1, 128, 128, 32, 64, 16, 16, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 64, 1, 2>, 16>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, DsLayout, GNHWK, int8_t, int8_t, int32_t, int32_t, DsDatatype, int8_t, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmSpec, 1, 128, 32, 128, 64, 16, 16, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 4>, 16>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, DsLayout, GNHWK, int8_t, int8_t, int32_t, int32_t, DsDatatype, int8_t, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmSpec, 1, 64, 64, 32, 64, 16, 16, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 2>, 16>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, DsLayout, GNHWK, int8_t, int8_t, int32_t, int32_t, DsDatatype, int8_t, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmSpec, 1, 64, 32, 64, 64, 16, 16, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 2>, 16> + >; +// clang-format on + +// for conv + multiple of 32 bit Ds. bit of Ds will affect the ScalarPerVector of C +template +// clang-format off +using device_conv2d_int8_32Ds_instances = + std::tuple < + //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, DsLayout, GNHWK, int8_t, int8_t, int32_t, int32_t, DsDatatype, int8_t, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmSpec, 1, 256, 256, 128, 64, 16, 16, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 64, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, DsLayout, GNHWK, int8_t, int8_t, int32_t, int32_t, DsDatatype, int8_t, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmSpec, 1, 256, 128, 256, 64, 16, 16, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 64, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, DsLayout, GNHWK, int8_t, int8_t, int32_t, int32_t, DsDatatype, int8_t, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmSpec, 1, 128, 128, 128, 64, 16, 16, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, DsLayout, GNHWK, int8_t, int8_t, int32_t, int32_t, DsDatatype, int8_t, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmSpec, 1, 256, 128, 128, 64, 16, 16, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 64, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, DsLayout, GNHWK, int8_t, int8_t, int32_t, int32_t, DsDatatype, int8_t, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmSpec, 1, 128, 128, 64, 64, 16, 16, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 64, 1, 2>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, DsLayout, GNHWK, int8_t, int8_t, int32_t, int32_t, DsDatatype, int8_t, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmSpec, 1, 128, 64, 128, 64, 16, 16, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, DsLayout, GNHWK, int8_t, int8_t, int32_t, int32_t, DsDatatype, int8_t, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmSpec, 1, 64, 64, 64, 64, 16, 16, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 2>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, DsLayout, GNHWK, int8_t, int8_t, int32_t, int32_t, DsDatatype, int8_t, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmSpec, 1, 256, 128, 64, 64, 16, 16, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 64, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, DsLayout, GNHWK, int8_t, int8_t, int32_t, int32_t, DsDatatype, int8_t, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmSpec, 1, 256, 64, 128, 64, 16, 16, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 64, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, DsLayout, GNHWK, int8_t, int8_t, int32_t, int32_t, DsDatatype, int8_t, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmSpec, 1, 128, 128, 32, 64, 16, 16, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 64, 1, 2>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, DsLayout, GNHWK, int8_t, int8_t, int32_t, int32_t, DsDatatype, int8_t, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmSpec, 1, 128, 32, 128, 64, 16, 16, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, DsLayout, GNHWK, int8_t, int8_t, int32_t, int32_t, DsDatatype, int8_t, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmSpec, 1, 64, 64, 32, 64, 16, 16, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 2>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, DsLayout, GNHWK, int8_t, int8_t, int32_t, int32_t, DsDatatype, int8_t, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmSpec, 1, 64, 32, 64, 64, 16, 16, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 2>, 8> + >; +// clang-format on + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/quantization/device_conv2d_xdl_perchannel_quantization_int8_instance.cpp b/library/src/tensor_operation_instance/gpu/quantization/device_conv2d_xdl_perchannel_quantization_int8_instance.cpp new file mode 100644 index 0000000000..5f1aa0c5c7 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/quantization/device_conv2d_xdl_perchannel_quantization_int8_instance.cpp @@ -0,0 +1,62 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include "device_conv2d_xdl_int8_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { +void add_device_conv2d_perchannel_quantization_int8_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_conv2d_int8_32Ds_instances{}); + add_device_operation_instances( + instances, + device_conv2d_int8_32Ds_instances{}); + add_device_operation_instances( + instances, + device_conv2d_int8_32Ds_instances{}); +} + +void add_device_conv2d_relu_perchannel_quantization_int8_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_conv2d_int8_32Ds_instances{}); + add_device_operation_instances( + instances, + device_conv2d_int8_32Ds_instances{}); + add_device_operation_instances( + instances, + device_conv2d_int8_32Ds_instances{}); +} +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/quantization/device_conv2d_xdl_perlayer_quantization_int8_instance.cpp b/library/src/tensor_operation_instance/gpu/quantization/device_conv2d_xdl_perlayer_quantization_int8_instance.cpp new file mode 100644 index 0000000000..83435d8119 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/quantization/device_conv2d_xdl_perlayer_quantization_int8_instance.cpp @@ -0,0 +1,62 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include "device_conv2d_xdl_int8_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { +void add_device_conv2d_perlayer_quantization_int8_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_conv2d_int8_instances{}); + add_device_operation_instances( + instances, + device_conv2d_int8_instances{}); + add_device_operation_instances( + instances, + device_conv2d_int8_instances{}); +} + +void add_device_conv2d_relu_perlayer_quantization_int8_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_conv2d_int8_instances{}); + add_device_operation_instances( + instances, + device_conv2d_int8_instances{}); + add_device_operation_instances( + instances, + device_conv2d_int8_instances{}); +} +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/quantization/device_conv2d_xdl_quant_int8_instance.cpp b/library/src/tensor_operation_instance/gpu/quantization/device_conv2d_xdl_quant_int8_instance.cpp deleted file mode 100644 index eba5954c55..0000000000 --- a/library/src/tensor_operation_instance/gpu/quantization/device_conv2d_xdl_quant_int8_instance.cpp +++ /dev/null @@ -1,109 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. - -#include "ck/ck.hpp" -#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" -#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_xdl_cshuffle.hpp" -#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" -#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { -namespace instance { - -using Empty_Tuple = ck::Tuple<>; -template -using S = ck::Sequence; - -using GNHWC = ck::tensor_layout::convolution::GNHWC; -using GKYXC = ck::tensor_layout::convolution::GKYXC; -using GNHWK = ck::tensor_layout::convolution::GNHWK; -using PassThrough = ck::tensor_operation::element_wise::PassThrough; -using Relu = ck::tensor_operation::element_wise::Relu; - -using Mul_Clamp = ck::tensor_operation::element_wise::Activation_Mul_Clamp; -using Relu_Mul_Clamp = ck::tensor_operation::element_wise::Activation_Mul_Clamp; - -static constexpr ck::index_t NDimSpatial = 2; -static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding; -static constexpr auto ConvFwdDefault = - ck::tensor_operation::device::ConvolutionForwardSpecialization::Default; -static constexpr auto ConvFwd1x1P0 = - ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Pad0; -static constexpr auto ConvFwd1x1S1P0 = - ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0; - -// TODO - Add more instances -template -// clang-format off -using device_conv2d_int8_instances = - std::tuple < - //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| - //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| - //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| - //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, int8_t, int8_t, int32_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmSpec, 1, 256, 256, 128, 64, 16, 16, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 64, 1, 4>, 16>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, int8_t, int8_t, int32_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmSpec, 1, 256, 128, 256, 64, 16, 16, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 64, 1, 4>, 16>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, int8_t, int8_t, int32_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmSpec, 1, 128, 128, 128, 64, 16, 16, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 4>, 16>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, int8_t, int8_t, int32_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmSpec, 1, 256, 128, 128, 64, 16, 16, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 64, 1, 4>, 16>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, int8_t, int8_t, int32_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmSpec, 1, 128, 128, 64, 64, 16, 16, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 64, 1, 2>, 16>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, int8_t, int8_t, int32_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmSpec, 1, 128, 64, 128, 64, 16, 16, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 4>, 16>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, int8_t, int8_t, int32_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmSpec, 1, 64, 64, 64, 64, 16, 16, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 2>, 16>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, int8_t, int8_t, int32_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmSpec, 1, 256, 128, 64, 64, 16, 16, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 64, 1, 4>, 16>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, int8_t, int8_t, int32_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmSpec, 1, 256, 64, 128, 64, 16, 16, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 64, 1, 4>, 16>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, int8_t, int8_t, int32_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmSpec, 1, 128, 128, 32, 64, 16, 16, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 64, 1, 2>, 16>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, int8_t, int8_t, int32_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmSpec, 1, 128, 32, 128, 64, 16, 16, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 4>, 16>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, int8_t, int8_t, int32_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmSpec, 1, 64, 64, 32, 64, 16, 16, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 2>, 16>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, int8_t, int8_t, int32_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmSpec, 1, 64, 32, 64, 64, 16, 16, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 2>, 16> - >; -// clang-format on - -void add_device_conv2d_perlayer_quantization_int8_instances( - std::vector>>& instances) -{ - add_device_operation_instances(instances, - device_conv2d_int8_instances{}); - add_device_operation_instances(instances, - device_conv2d_int8_instances{}); - add_device_operation_instances(instances, - device_conv2d_int8_instances{}); -} - -void add_device_conv2d_relu_perlayer_quantization_int8_instances( - std::vector>>& instances) -{ - add_device_operation_instances(instances, - device_conv2d_int8_instances{}); - add_device_operation_instances(instances, - device_conv2d_int8_instances{}); - add_device_operation_instances(instances, - device_conv2d_int8_instances{}); -} - -} // namespace instance -} // namespace device -} // namespace tensor_operation -} // namespace ck