From 9052e8501c809045837cad0339c03b06b5cd4c44 Mon Sep 17 00:00:00 2001 From: rocking5566 Date: Thu, 3 Nov 2022 03:56:26 +0800 Subject: [PATCH] Conv perlayer int8 quantization (#471) * Add conv2d requant example * Fix bash error * Rename example * 1. Rename gemm quantization 2. shares the requantization lambda function with conv * Refine declare type * Add conv bias relu quantization exmaple * clang format * Fix compile error due to merge develop * Fix CI error * Extract quantization post operation into another file * Support quantization for non piecewise linear function * Add instance for conv quantization * Add convolution quantization factory * Add convolution quantization client example * Add more instances with different template parameters * clang format * Sync the naming with the develop [ROCm/composable_kernel commit: 226bc02b73468a9078374459dfe188d4051d1f7c] --- client_example/09_quantization/CMakeLists.txt | 5 + ...2d_fwd_bias_relu_perlayer_quantization.cpp | 198 +++++++++++ .../conv2d_fwd_perlayer_quantization.cpp | 192 +++++++++++ .../14_gemm_xdl_quantization/CMakeLists.txt | 1 + .../gemm_xdl_relu_quantization_int8.cpp} | 40 +-- .../CMakeLists.txt | 1 - example/44_conv2d_fwd_quant/CMakeLists.txt | 2 + ...l_bias_relu_perlayer_quantization_int8.cpp | 317 ++++++++++++++++++ ...v2d_fwd_xdl_perlayer_quantization_int8.cpp | 277 +++++++++++++++ .../gpu/element/element_wise_operation.hpp | 1 + .../gpu/element/quantization_operation.hpp | 86 +++++ .../element/unary_element_wise_operation.hpp | 1 + .../device_operation_instance_factory.hpp | 14 + ...ion_bias_forward_perlayer_quantization.hpp | 114 +++++++ ...volution_forward_perlayer_quantization.hpp | 110 ++++++ .../gpu/quantization/CMakeLists.txt | 4 + ...ce_conv2d_xdl_bias_quant_int8_instance.cpp | 112 +++++++ .../device_conv2d_xdl_quant_int8_instance.cpp | 109 ++++++ script/cmake-ck-dev.sh | 2 +- script/cmake-ck-release.sh | 2 +- 20 files changed, 1553 insertions(+), 35 deletions(-) create mode 100644 client_example/09_quantization/CMakeLists.txt create mode 100644 client_example/09_quantization/conv2d_fwd_bias_relu_perlayer_quantization.cpp create mode 100644 client_example/09_quantization/conv2d_fwd_perlayer_quantization.cpp create mode 100644 example/14_gemm_xdl_quantization/CMakeLists.txt rename example/{14_gemm_xdl_requant_relu_requant/gemm_xdl_requant_relu_requant_int8.cpp => 14_gemm_xdl_quantization/gemm_xdl_relu_quantization_int8.cpp} (86%) delete mode 100644 example/14_gemm_xdl_requant_relu_requant/CMakeLists.txt create mode 100644 example/44_conv2d_fwd_quant/CMakeLists.txt create mode 100644 example/44_conv2d_fwd_quant/conv2d_fwd_xdl_bias_relu_perlayer_quantization_int8.cpp create mode 100644 example/44_conv2d_fwd_quant/conv2d_fwd_xdl_perlayer_quantization_int8.cpp create mode 100644 include/ck/tensor_operation/gpu/element/quantization_operation.hpp create mode 100644 library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_bias_forward_perlayer_quantization.hpp create mode 100644 library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_perlayer_quantization.hpp create mode 100644 library/src/tensor_operation_instance/gpu/quantization/CMakeLists.txt create mode 100644 library/src/tensor_operation_instance/gpu/quantization/device_conv2d_xdl_bias_quant_int8_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/quantization/device_conv2d_xdl_quant_int8_instance.cpp diff --git a/client_example/09_quantization/CMakeLists.txt b/client_example/09_quantization/CMakeLists.txt new file mode 100644 index 0000000000..eceaa84174 --- /dev/null +++ b/client_example/09_quantization/CMakeLists.txt @@ -0,0 +1,5 @@ +add_executable(client_conv2d_fwd_bias_relu_perlayer_quantization conv2d_fwd_bias_relu_perlayer_quantization.cpp) +target_link_libraries(client_conv2d_fwd_bias_relu_perlayer_quantization PRIVATE composable_kernel::device_operations) + +add_executable(client_conv2d_fwd_perlayer_quantization conv2d_fwd_perlayer_quantization.cpp) +target_link_libraries(client_conv2d_fwd_perlayer_quantization PRIVATE composable_kernel::device_operations) diff --git a/client_example/09_quantization/conv2d_fwd_bias_relu_perlayer_quantization.cpp b/client_example/09_quantization/conv2d_fwd_bias_relu_perlayer_quantization.cpp new file mode 100644 index 0000000000..7416e12620 --- /dev/null +++ b/client_example/09_quantization/conv2d_fwd_bias_relu_perlayer_quantization.cpp @@ -0,0 +1,198 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_convolution_bias_forward_perlayer_quantization.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_conv_fwd.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +using InDataType = int8_t; +using WeiDataType = int8_t; +using BiasDataType = int32_t; +using OutDataType = int8_t; + +using InLayout = ck::tensor_layout::convolution::GNHWC; +using WeiLayout = ck::tensor_layout::convolution::GKYXC; +using BiasLayout = ck::tensor_layout::convolution::G_K; +using OutLayout = ck::tensor_layout::convolution::GNHWK; +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using ActivationOp = ck::tensor_operation::element_wise::Relu; +using OutElementOp = ck::tensor_operation::element_wise::Add_Activation_Mul_Clamp; + +static constexpr ck::index_t NumDimSpatial = 2; +static constexpr ck::index_t G = 1; +static constexpr ck::index_t N = 4; +static constexpr ck::index_t K = 64; +static constexpr ck::index_t C = 32; +static constexpr ck::index_t Y = 3; +static constexpr ck::index_t X = 3; +static constexpr ck::index_t Hi = 71; +static constexpr ck::index_t Wi = 71; +static constexpr ck::index_t Ho = 36; +static constexpr ck::index_t Wo = 36; + +struct SimpleDeviceMem +{ + SimpleDeviceMem() = delete; + + SimpleDeviceMem(std::size_t mem_size) : p_mem_{} + { + (void)hipMalloc(static_cast(&p_mem_), mem_size); + } + + void* GetDeviceBuffer() { return p_mem_; } + + ~SimpleDeviceMem() { (void)hipFree(p_mem_); } + + void* p_mem_; +}; + +int main(int argc, char* argv[]) +{ + std::array in_lengths{G, N, C, Hi, Wi}; + std::array in_strides{N * Hi * Wi * C, Hi * Wi * C, 1, Wi * C, C}; + std::array weight_lengths{G, K, C, Y, X}; + std::array weight_strides{K * Y * X * C, Y * X * C, 1, X * C, C}; + std::array bias_lengths{G, N, K, Ho, Wo}; + std::array bias_strides{K, 0, 1, 0, 0}; + std::array out_lengths{G, N, C, Ho, Wo}; + std::array out_strides{N * Ho * Wo * C, Ho * Wo * C, 1, Wo * C, C}; + std::array in_left_pad{1, 1}; + std::array in_right_pad{1, 1}; + std::array conv_strides{2, 2}; + std::array conv_dilations{1, 1}; + + SimpleDeviceMem in(sizeof(InDataType) * N * Hi * Wi * C); + SimpleDeviceMem wei(sizeof(WeiDataType) * K * Y * X * C); + SimpleDeviceMem bias(sizeof(BiasDataType) * K * Y * X * C); + SimpleDeviceMem out(sizeof(OutDataType) * N * Ho * Wo * K); + + using DeviceOp = + ck::tensor_operation::device::DeviceGroupedConvFwdMultipleD, + OutLayout, + InDataType, + WeiDataType, + ck::Tuple, + OutDataType, + PassThrough, + PassThrough, + OutElementOp>; + // get device op instances + const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< + DeviceOp>::GetInstances(); + + std::cout << "found " << op_ptrs.size() << " instances" << std::endl; + + std::string best_op_name; + int best_op_id = -1; + float best_avg_time = std::numeric_limits::max(); + float best_gb_per_sec = 0; + float best_tflops = 0; + + // profile device operation instances + std::cout << "Run all instances and do timing" << std::endl; + + for(int i = 0; i < op_ptrs.size(); ++i) + { + auto& op_ptr = op_ptrs[i]; + auto argument_ptr = op_ptr->MakeArgumentPointer(in.GetDeviceBuffer(), + wei.GetDeviceBuffer(), + {bias.GetDeviceBuffer()}, + out.GetDeviceBuffer(), + in_lengths, + in_strides, + weight_lengths, + weight_strides, + {bias_lengths}, + {bias_strides}, + out_lengths, + out_strides, + conv_strides, + conv_dilations, + in_left_pad, + in_right_pad, + PassThrough{}, + PassThrough{}, + OutElementOp{0.5f, ActivationOp{}}); + + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + std::string op_name = op_ptr->GetTypeString(); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + float avg_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, true}); + + std::size_t flop = G * 2 * N * K * C * Ho * Wo * Y * X; + std::size_t num_bytes = G * sizeof(InDataType) * N * Hi * Wi * C + + G * sizeof(WeiDataType) * K * Y * X * C + + G * sizeof(OutDataType) * N * Ho * Wo * K; + + float tflops = static_cast(flop) / 1.E9 / avg_time; + float gb_per_sec = num_bytes / 1.E6 / avg_time; + + std::cout << "Perf: " << std::setw(10) << avg_time << " ms, " << tflops << " TFlops, " + << gb_per_sec << " GB/s, " << op_name << std::endl; + + if(tflops > best_tflops) + { + best_op_id = i; + best_op_name = op_name; + best_avg_time = avg_time; + best_gb_per_sec = gb_per_sec; + best_tflops = tflops; + } + } + else + { + std::cout << op_name << " does not support this problem" << std::endl; + } + } + + std::cout << "Best Perf: " << std::setw(10) << best_avg_time << " ms, " << best_tflops + << " TFlops, " << best_gb_per_sec << " GB/s, " << best_op_name << std::endl; + + // run the best intance + { + auto& op_ptr = op_ptrs[best_op_id]; + std::cout << "Run the best instance without timing: " << op_ptr->GetTypeString() + << std::endl; + auto argument_ptr = op_ptr->MakeArgumentPointer(in.GetDeviceBuffer(), + wei.GetDeviceBuffer(), + {bias.GetDeviceBuffer()}, + out.GetDeviceBuffer(), + in_lengths, + in_strides, + weight_lengths, + weight_strides, + {bias_lengths}, + {bias_strides}, + out_lengths, + out_strides, + conv_strides, + conv_dilations, + in_left_pad, + in_right_pad, + PassThrough{}, + PassThrough{}, + OutElementOp{0.5f, ActivationOp{}}); + + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, false}); + } + + std::cout << "Done" << std::endl; + } + + return 0; +} \ No newline at end of file diff --git a/client_example/09_quantization/conv2d_fwd_perlayer_quantization.cpp b/client_example/09_quantization/conv2d_fwd_perlayer_quantization.cpp new file mode 100644 index 0000000000..81176fd2e3 --- /dev/null +++ b/client_example/09_quantization/conv2d_fwd_perlayer_quantization.cpp @@ -0,0 +1,192 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_perlayer_quantization.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_conv_fwd.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +using InDataType = int8_t; +using WeiDataType = int8_t; +using OutDataType = int8_t; + +using InLayout = ck::tensor_layout::convolution::GNHWC; +using WeiLayout = ck::tensor_layout::convolution::GKYXC; +using OutLayout = ck::tensor_layout::convolution::GNHWK; +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using ActivationOp = PassThrough; +using OutElementOp = ck::tensor_operation::element_wise::Activation_Mul_Clamp; + +static constexpr ck::index_t NumDimSpatial = 2; +static constexpr ck::index_t G = 1; +static constexpr ck::index_t N = 4; +static constexpr ck::index_t K = 64; +static constexpr ck::index_t C = 32; +static constexpr ck::index_t Y = 3; +static constexpr ck::index_t X = 3; +static constexpr ck::index_t Hi = 71; +static constexpr ck::index_t Wi = 71; +static constexpr ck::index_t Ho = 36; +static constexpr ck::index_t Wo = 36; + +struct SimpleDeviceMem +{ + SimpleDeviceMem() = delete; + + SimpleDeviceMem(std::size_t mem_size) : p_mem_{} + { + (void)hipMalloc(static_cast(&p_mem_), mem_size); + } + + void* GetDeviceBuffer() { return p_mem_; } + + ~SimpleDeviceMem() { (void)hipFree(p_mem_); } + + void* p_mem_; +}; + +int main(int argc, char* argv[]) +{ + std::array in_lengths{G, N, C, Hi, Wi}; + std::array in_strides{N * Hi * Wi * C, Hi * Wi * C, 1, Wi * C, C}; + std::array weight_lengths{G, K, C, Y, X}; + std::array weight_strides{K * Y * X * C, Y * X * C, 1, X * C, C}; + std::array out_lengths{G, N, C, Ho, Wo}; + std::array out_strides{N * Ho * Wo * C, Ho * Wo * C, 1, Wo * C, C}; + std::array in_left_pad{1, 1}; + std::array in_right_pad{1, 1}; + std::array conv_strides{2, 2}; + std::array conv_dilations{1, 1}; + + SimpleDeviceMem in(sizeof(InDataType) * N * Hi * Wi * C); + SimpleDeviceMem wei(sizeof(WeiDataType) * K * Y * X * C); + SimpleDeviceMem out(sizeof(OutDataType) * N * Ho * Wo * K); + + using DeviceOp = ck::tensor_operation::device::DeviceGroupedConvFwdMultipleD, + OutLayout, + InDataType, + WeiDataType, + ck::Tuple<>, + OutDataType, + PassThrough, + PassThrough, + OutElementOp>; + // get device op instances + const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< + DeviceOp>::GetInstances(); + + std::cout << "found " << op_ptrs.size() << " instances" << std::endl; + + std::string best_op_name; + int best_op_id = -1; + float best_avg_time = std::numeric_limits::max(); + float best_gb_per_sec = 0; + float best_tflops = 0; + + // profile device operation instances + std::cout << "Run all instances and do timing" << std::endl; + + for(int i = 0; i < op_ptrs.size(); ++i) + { + auto& op_ptr = op_ptrs[i]; + auto argument_ptr = op_ptr->MakeArgumentPointer(in.GetDeviceBuffer(), + wei.GetDeviceBuffer(), + {}, + out.GetDeviceBuffer(), + in_lengths, + in_strides, + weight_lengths, + weight_strides, + {}, + {}, + out_lengths, + out_strides, + conv_strides, + conv_dilations, + in_left_pad, + in_right_pad, + PassThrough{}, + PassThrough{}, + OutElementOp{0.5f, ActivationOp{}}); + + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + std::string op_name = op_ptr->GetTypeString(); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + float avg_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, true}); + + std::size_t flop = G * 2 * N * K * C * Ho * Wo * Y * X; + std::size_t num_bytes = G * sizeof(InDataType) * N * Hi * Wi * C + + G * sizeof(WeiDataType) * K * Y * X * C + + G * sizeof(OutDataType) * N * Ho * Wo * K; + + float tflops = static_cast(flop) / 1.E9 / avg_time; + float gb_per_sec = num_bytes / 1.E6 / avg_time; + + std::cout << "Perf: " << std::setw(10) << avg_time << " ms, " << tflops << " TFlops, " + << gb_per_sec << " GB/s, " << op_name << std::endl; + + if(tflops > best_tflops) + { + best_op_id = i; + best_op_name = op_name; + best_avg_time = avg_time; + best_gb_per_sec = gb_per_sec; + best_tflops = tflops; + } + } + else + { + std::cout << op_name << " does not support this problem" << std::endl; + } + } + + std::cout << "Best Perf: " << std::setw(10) << best_avg_time << " ms, " << best_tflops + << " TFlops, " << best_gb_per_sec << " GB/s, " << best_op_name << std::endl; + + // run the best intance + { + auto& op_ptr = op_ptrs[best_op_id]; + std::cout << "Run the best instance without timing: " << op_ptr->GetTypeString() + << std::endl; + auto argument_ptr = op_ptr->MakeArgumentPointer(in.GetDeviceBuffer(), + wei.GetDeviceBuffer(), + {}, + out.GetDeviceBuffer(), + in_lengths, + in_strides, + weight_lengths, + weight_strides, + {}, + {}, + out_lengths, + out_strides, + conv_strides, + conv_dilations, + in_left_pad, + in_right_pad, + PassThrough{}, + PassThrough{}, + OutElementOp{0.5f, ActivationOp{}}); + + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, false}); + } + + std::cout << "Done" << std::endl; + } + + return 0; +} \ No newline at end of file diff --git a/example/14_gemm_xdl_quantization/CMakeLists.txt b/example/14_gemm_xdl_quantization/CMakeLists.txt new file mode 100644 index 0000000000..9674aba2a4 --- /dev/null +++ b/example/14_gemm_xdl_quantization/CMakeLists.txt @@ -0,0 +1 @@ +add_example_executable(example_gemm_xdl_relu_quantization_int8 gemm_xdl_relu_quantization_int8.cpp) \ No newline at end of file diff --git a/example/14_gemm_xdl_requant_relu_requant/gemm_xdl_requant_relu_requant_int8.cpp b/example/14_gemm_xdl_quantization/gemm_xdl_relu_quantization_int8.cpp similarity index 86% rename from example/14_gemm_xdl_requant_relu_requant/gemm_xdl_requant_relu_requant_int8.cpp rename to example/14_gemm_xdl_quantization/gemm_xdl_relu_quantization_int8.cpp index 79838d1b2f..d2c9e66d31 100644 --- a/example/14_gemm_xdl_requant_relu_requant/gemm_xdl_requant_relu_requant_int8.cpp +++ b/example/14_gemm_xdl_quantization/gemm_xdl_relu_quantization_int8.cpp @@ -18,30 +18,12 @@ #include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" #include "ck/library/utility/check_err.hpp" -struct RequantReluRequant -{ - // FIXME: We just need one scale for Relu / Leaky Relu / PRelu - RequantReluRequant(float scaleGemm, float scaleRelu) - : scaleGemm_(scaleGemm), scaleRelu_(scaleRelu) - { - } - - __host__ __device__ constexpr void operator()(float& y, const float& x) const - { - float gemm_requant = scaleGemm_ * x; - float relu = gemm_requant > 0 ? gemm_requant : 0; - float relu_requant = scaleRelu_ * relu; - y = relu_requant > 127 ? 127 : relu_requant < -128 ? -128 : relu_requant; - } - - float scaleGemm_; - float scaleRelu_; -}; - template using S = ck::Sequence; -using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using ActivationOp = ck::tensor_operation::element_wise::Relu; +using CElementOp = ck::tensor_operation::element_wise::Activation_Mul_Clamp; using ADataType = int8_t; using BDataType = int8_t; @@ -67,7 +49,7 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemm_Xdl_CShuffle CShuffleDataType, // typename CShuffleDataType, PassThrough, // typename AElementwiseOperation, PassThrough, // typename BElementwiseOperation, - RequantReluRequant, // typename CElementwiseOperation, + CElementOp, // typename CElementwiseOperation, GemmDefault, // GemmSpecialization GemmSpec, 1, // index_t NumGemmKPrefetchStage, 256, // index_t BlockSize, @@ -100,13 +82,8 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemm_Xdl_CShuffle 16>; // index_t CShuffleBlockTransferScalarPerVector_NPerBlock> // clang-format on -using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm; +using ReferenceGemmInstance = ck::tensor_operation::host:: + ReferenceGemm; int main(int argc, char* argv[]) { @@ -123,8 +100,7 @@ int main(int argc, char* argv[]) ck::index_t StrideB = 4096; ck::index_t StrideC = 4096; - float scale_gemm = 0.03; - float scale_relu = 1; + float quant_multiplier = 0.03; if(argc == 4) { @@ -199,7 +175,7 @@ int main(int argc, char* argv[]) auto a_element_op = PassThrough{}; auto b_element_op = PassThrough{}; - auto c_element_op = RequantReluRequant{scale_gemm, scale_relu}; + auto c_element_op = CElementOp{quant_multiplier, ActivationOp{}}; // do GEMM auto gemm = DeviceGemmInstance{}; diff --git a/example/14_gemm_xdl_requant_relu_requant/CMakeLists.txt b/example/14_gemm_xdl_requant_relu_requant/CMakeLists.txt deleted file mode 100644 index 0f5b8e1bc7..0000000000 --- a/example/14_gemm_xdl_requant_relu_requant/CMakeLists.txt +++ /dev/null @@ -1 +0,0 @@ -add_example_executable(example_gemm_xdl_requant_relu_requant_int8 gemm_xdl_requant_relu_requant_int8.cpp) \ No newline at end of file diff --git a/example/44_conv2d_fwd_quant/CMakeLists.txt b/example/44_conv2d_fwd_quant/CMakeLists.txt new file mode 100644 index 0000000000..1ecf89ccb8 --- /dev/null +++ b/example/44_conv2d_fwd_quant/CMakeLists.txt @@ -0,0 +1,2 @@ +add_example_executable(example_conv2d_fwd_xdl_perlayer_quantization_int8 conv2d_fwd_xdl_perlayer_quantization_int8.cpp) +add_example_executable(example_conv2d_fwd_xdl_bias_relu_perlayer_quantization_int8 conv2d_fwd_xdl_bias_relu_perlayer_quantization_int8.cpp) diff --git a/example/44_conv2d_fwd_quant/conv2d_fwd_xdl_bias_relu_perlayer_quantization_int8.cpp b/example/44_conv2d_fwd_quant/conv2d_fwd_xdl_bias_relu_perlayer_quantization_int8.cpp new file mode 100644 index 0000000000..613f607d8b --- /dev/null +++ b/example/44_conv2d_fwd_quant/conv2d_fwd_xdl_bias_relu_perlayer_quantization_int8.cpp @@ -0,0 +1,317 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_xdl_cshuffle.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/convolution_parameter.hpp" +#include "ck/library/utility/convolution_host_tensor_descriptor_helper.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_conv_fwd.hpp" + +using InDataType = int8_t; +using WeiDataType = int8_t; +using BiasDataType = int32_t; +using AccDataType = int32_t; +using CShuffleDataType = int32_t; +using OutDataType = int8_t; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using InElementOp = PassThrough; +using WeiElementOp = PassThrough; +using ActivationOp = ck::tensor_operation::element_wise::Relu; +using OutElementOp = ck::tensor_operation::element_wise::Add_Activation_Mul_Clamp; + +static constexpr auto ConvSpec = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Default; + +static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +template +using DeviceGroupedConvNDFwdInstance = + ck::tensor_operation::device::DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< + NDimSpatial, + InLayout, + WeiLayout, + ck::Tuple, + OutLayout, + InDataType, + WeiDataType, + AccDataType, + CShuffleDataType, + ck::Tuple, + OutDataType, + InElementOp, + WeiElementOp, + OutElementOp, + ConvSpec, // ConvForwardSpecialization + GemmSpec, // GemmSpecialization + 1, // + 256, // BlockSize + 128, // MPerBlock + 256, // NPerBlock + 64, // KPerBlock + 16, // AK1 + 16, // BK1 + 32, // MPerXdl + 32, // NPerXdl + 2, // MXdlPerWave + 4, // NXdlPerWave + S<4, 64, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1 + S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // ABlockTransferSrcAccessOrder + 2, // ABlockTransferSrcVectorDim + 16, // ABlockTransferSrcScalarPerVector + 16, // ABlockTransferDstScalarPerVector_AK1 + 1, // ABlockLdsExtraM + S<4, 64, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1 + S<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // BBlockTransferSrcAccessOrder + 2, // BBlockTransferSrcVectorDim + 16, // BBlockTransferSrcScalarPerVector + 16, // BBlockTransferDstScalarPerVector_BK1 + 1, // BBlockLdsExtraN + 1, + 1, + S<1, 64, 1, 4>, + 8>; + +template +bool run_grouped_conv_fwd(bool do_verification, + bool time_kernel, + const ck::utils::conv::ConvParam& conv_param, + const HostTensorDescriptor& in_g_n_c_wis_desc, + const HostTensorDescriptor& wei_g_k_c_xs_desc, + const HostTensorDescriptor& bias_g_k_desc, + const HostTensorDescriptor& out_g_n_k_wos_desc, + const InElementOp& in_element_op, + const WeiElementOp& wei_element_op, + const OutElementOp& out_element_op) +{ + Tensor in(in_g_n_c_wis_desc); + Tensor wei(wei_g_k_c_xs_desc); + Tensor bias(bias_g_k_desc); + Tensor out_host(out_g_n_k_wos_desc); + Tensor out_device(out_g_n_k_wos_desc); + + std::cout << "in: " << in.mDesc << std::endl; + std::cout << "wei: " << wei.mDesc << std::endl; + std::cout << "bias: " << bias.mDesc << std::endl; + std::cout << "out: " << out_host.mDesc << std::endl; + + in.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + wei.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + bias.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + + DeviceMem in_device_buf(sizeof(InDataType) * in.mDesc.GetElementSpaceSize()); + DeviceMem wei_device_buf(sizeof(WeiDataType) * wei.mDesc.GetElementSpaceSize()); + DeviceMem bias_device_buf(sizeof(BiasDataType) * bias.mDesc.GetElementSpaceSize()); + DeviceMem out_device_buf(sizeof(OutDataType) * out_device.mDesc.GetElementSpaceSize()); + + in_device_buf.ToDevice(in.mData.data()); + wei_device_buf.ToDevice(wei.mData.data()); + bias_device_buf.ToDevice(bias.mData.data()); + + std::array a_g_n_c_wis_lengths{}; + std::array a_g_n_c_wis_strides{}; + std::array b_g_k_c_xs_lengths{}; + std::array b_g_k_c_xs_strides{}; + std::array d0_g_n_k_wos_lengths{}; + std::array d0_g_n_k_wos_strides{}; + std::array e_g_n_k_wos_lengths{}; + std::array e_g_n_k_wos_strides{}; + std::array conv_filter_strides{}; + std::array conv_filter_dilations{}; + std::array input_left_pads{}; + std::array input_right_pads{}; + + auto copy = [](auto& x, auto& y) { std::copy(x.begin(), x.end(), y.begin()); }; + + copy(in_g_n_c_wis_desc.GetLengths(), a_g_n_c_wis_lengths); + copy(in_g_n_c_wis_desc.GetStrides(), a_g_n_c_wis_strides); + copy(wei_g_k_c_xs_desc.GetLengths(), b_g_k_c_xs_lengths); + copy(wei_g_k_c_xs_desc.GetStrides(), b_g_k_c_xs_strides); + copy(bias_g_k_desc.GetLengths(), d0_g_n_k_wos_lengths); + copy(bias_g_k_desc.GetStrides(), d0_g_n_k_wos_strides); + copy(out_g_n_k_wos_desc.GetLengths(), e_g_n_k_wos_lengths); + copy(out_g_n_k_wos_desc.GetStrides(), e_g_n_k_wos_strides); + copy(conv_param.conv_filter_strides_, conv_filter_strides); + copy(conv_param.conv_filter_dilations_, conv_filter_dilations); + copy(conv_param.input_left_pads_, input_left_pads); + copy(conv_param.input_right_pads_, input_right_pads); + + // do Conv + auto conv = DeviceConvNDFwdInstance{}; + auto invoker = conv.MakeInvoker(); + auto argument = conv.MakeArgument( + in_device_buf.GetDeviceBuffer(), + wei_device_buf.GetDeviceBuffer(), + std::array{bias_device_buf.GetDeviceBuffer()}, + out_device_buf.GetDeviceBuffer(), + a_g_n_c_wis_lengths, + a_g_n_c_wis_strides, + b_g_k_c_xs_lengths, + b_g_k_c_xs_strides, + std::array, 1>{{d0_g_n_k_wos_lengths}}, + std::array, 1>{{d0_g_n_k_wos_strides}}, + e_g_n_k_wos_lengths, + e_g_n_k_wos_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + in_element_op, + wei_element_op, + out_element_op); + + if(!conv.IsSupportedArgument(argument)) + { + throw std::runtime_error( + "wrong! device_conv with the specified compilation parameters does " + "not support this Conv problem"); + } + + float avg_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); + + std::size_t flop = conv_param.GetFlops(); + std::size_t num_btype = conv_param.GetByte(); + + float tflops = static_cast(flop) / 1.E9 / avg_time; + float gb_per_sec = num_btype / 1.E6 / avg_time; + std::cout << "Perf: " << avg_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " + << conv.GetTypeString() << std::endl; + + bool pass = true; + + if(do_verification) + { + Tensor c_host(out_g_n_k_wos_desc); + + auto ref_conv = ck::tensor_operation::host::ReferenceConvFwd(); + + auto ref_invoker = ref_conv.MakeInvoker(); + auto ref_argument = ref_conv.MakeArgument(in, + wei, + c_host, + conv_param.conv_filter_strides_, + conv_param.conv_filter_dilations_, + conv_param.input_left_pads_, + conv_param.input_right_pads_, + in_element_op, + wei_element_op, + PassThrough{}); + + ref_invoker.Run(ref_argument); + + // TODO: implement elementwise operation for host + out_host.ForEach( + [&](auto&, auto idx) { out_element_op(out_host(idx), c_host(idx), bias(idx)); }); + + out_device_buf.FromDevice(out_device.mData.data()); + + pass &= ck::utils::check_err( + out_device.mData, out_host.mData, "Error: incorrect results!", 1e-5f, 1e-4f); + } + + return (pass ? 0 : 1); +} + +int main() +{ + bool do_verification = true; + bool time_kernel = true; + const ck::index_t ndim_spatial = 2; + + ck::utils::conv::ConvParam conv_param{ + ndim_spatial, // n_dim + 1, // group + 4, // batch + 64, // output channels + 32, // input chanels + {3, 3}, // weight HW + {71, 71}, // x HW + {2, 2}, // strides + {1, 1}, // dilations + {1, 1}, // left_pads + {1, 1} // right_pads + }; + + const auto in_element_op = InElementOp{}; + const auto wei_element_op = WeiElementOp{}; + const auto out_element_op = OutElementOp{0.5f, ActivationOp{}}; + + using InLayout = ck::tensor_layout::convolution::GNHWC; + using WeiLayout = ck::tensor_layout::convolution::GKYXC; + using BiasLayout = ck::tensor_layout::convolution::G_K; + using OutLayout = ck::tensor_layout::convolution::GNHWK; + + const auto in_g_n_c_wis_desc = + ck::utils::conv::make_input_host_tensor_descriptor_g_n_c_wis_packed(conv_param); + + const auto wei_g_k_c_xs_desc = + ck::utils::conv::make_weight_host_tensor_descriptor_g_k_c_xs_packed(conv_param); + + // TODO - make_bias_host_tensor_descriptor_g_n_k_wos_packed() + const auto bias_g_k_desc = HostTensorDescriptor({conv_param.G_, + conv_param.N_, + conv_param.K_, + conv_param.output_spatial_lengths_[0], + conv_param.output_spatial_lengths_[1]}, + { + conv_param.K_, // g + 0, // n + 1, // k + 0, // ho + 0 // wo + }); + + const auto out_g_n_k_wos_desc = + ck::utils::conv::make_output_host_tensor_descriptor_g_n_k_wos_packed(conv_param); + + std::cout << out_g_n_k_wos_desc << std::endl; + + return run_grouped_conv_fwd< + ndim_spatial, + InDataType, + WeiDataType, + OutDataType, + InElementOp, + WeiElementOp, + OutElementOp, + DeviceGroupedConvNDFwdInstance>( + do_verification, + time_kernel, + conv_param, + in_g_n_c_wis_desc, + wei_g_k_c_xs_desc, + bias_g_k_desc, + out_g_n_k_wos_desc, + in_element_op, + wei_element_op, + out_element_op); +} diff --git a/example/44_conv2d_fwd_quant/conv2d_fwd_xdl_perlayer_quantization_int8.cpp b/example/44_conv2d_fwd_quant/conv2d_fwd_xdl_perlayer_quantization_int8.cpp new file mode 100644 index 0000000000..71472440c9 --- /dev/null +++ b/example/44_conv2d_fwd_quant/conv2d_fwd_xdl_perlayer_quantization_int8.cpp @@ -0,0 +1,277 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_xdl_cshuffle.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/convolution_parameter.hpp" +#include "ck/library/utility/convolution_host_tensor_descriptor_helper.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_conv_fwd.hpp" + +using InDataType = int8_t; +using WeiDataType = int8_t; +using AccDataType = int32_t; +using CShuffleDataType = int32_t; +using OutDataType = int8_t; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using InElementOp = PassThrough; +using WeiElementOp = PassThrough; +using ActivationOp = PassThrough; +using OutElementOp = ck::tensor_operation::element_wise::Activation_Mul_Clamp; + +static constexpr auto ConvSpec = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Default; + +static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +template +using DeviceGroupedConvNDFwdInstance = + ck::tensor_operation::device::DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< + NDimSpatial, + InLayout, + WeiLayout, + ck::Tuple<>, + OutLayout, + InDataType, + WeiDataType, + AccDataType, + CShuffleDataType, + ck::Tuple<>, + OutDataType, + InElementOp, + WeiElementOp, + OutElementOp, + ConvSpec, // ConvForwardSpecialization + GemmSpec, // GemmSpecialization + 1, // + 256, // BlockSize + 128, // MPerBlock + 256, // NPerBlock + 64, // KPerBlock + 16, // AK1 + 16, // BK1 + 32, // MPerXdl + 32, // NPerXdl + 2, // MXdlPerWave + 4, // NXdlPerWave + S<4, 64, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1 + S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // ABlockTransferSrcAccessOrder + 2, // ABlockTransferSrcVectorDim + 16, // ABlockTransferSrcScalarPerVector + 16, // ABlockTransferDstScalarPerVector_AK1 + 1, // ABlockLdsExtraM + S<4, 64, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1 + S<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // BBlockTransferSrcAccessOrder + 2, // BBlockTransferSrcVectorDim + 16, // BBlockTransferSrcScalarPerVector + 16, // BBlockTransferDstScalarPerVector_BK1 + 1, // BBlockLdsExtraN + 1, + 1, + S<1, 64, 1, 4>, + 16>; + +template +bool run_grouped_conv_fwd(bool do_verification, + bool time_kernel, + const ck::utils::conv::ConvParam& conv_param, + const HostTensorDescriptor& in_g_n_c_wis_desc, + const HostTensorDescriptor& wei_g_k_c_xs_desc, + const HostTensorDescriptor& out_g_n_k_wos_desc, + const InElementOp& in_element_op, + const WeiElementOp& wei_element_op, + const OutElementOp& out_element_op) +{ + Tensor in(in_g_n_c_wis_desc); + Tensor wei(wei_g_k_c_xs_desc); + Tensor out_host(out_g_n_k_wos_desc); + Tensor out_device(out_g_n_k_wos_desc); + + std::cout << "in: " << in.mDesc << std::endl; + std::cout << "wei: " << wei.mDesc << std::endl; + std::cout << "out: " << out_host.mDesc << std::endl; + + in.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + wei.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + + DeviceMem in_device_buf(sizeof(InDataType) * in.mDesc.GetElementSpaceSize()); + DeviceMem wei_device_buf(sizeof(WeiDataType) * wei.mDesc.GetElementSpaceSize()); + DeviceMem out_device_buf(sizeof(OutDataType) * out_device.mDesc.GetElementSpaceSize()); + + in_device_buf.ToDevice(in.mData.data()); + wei_device_buf.ToDevice(wei.mData.data()); + + std::array a_g_n_c_wis_lengths{}; + std::array a_g_n_c_wis_strides{}; + std::array b_g_k_c_xs_lengths{}; + std::array b_g_k_c_xs_strides{}; + std::array e_g_n_k_wos_lengths{}; + std::array e_g_n_k_wos_strides{}; + std::array conv_filter_strides{}; + std::array conv_filter_dilations{}; + std::array input_left_pads{}; + std::array input_right_pads{}; + + auto copy = [](auto& x, auto& y) { std::copy(x.begin(), x.end(), y.begin()); }; + + copy(in_g_n_c_wis_desc.GetLengths(), a_g_n_c_wis_lengths); + copy(in_g_n_c_wis_desc.GetStrides(), a_g_n_c_wis_strides); + copy(wei_g_k_c_xs_desc.GetLengths(), b_g_k_c_xs_lengths); + copy(wei_g_k_c_xs_desc.GetStrides(), b_g_k_c_xs_strides); + copy(out_g_n_k_wos_desc.GetLengths(), e_g_n_k_wos_lengths); + copy(out_g_n_k_wos_desc.GetStrides(), e_g_n_k_wos_strides); + copy(conv_param.conv_filter_strides_, conv_filter_strides); + copy(conv_param.conv_filter_dilations_, conv_filter_dilations); + copy(conv_param.input_left_pads_, input_left_pads); + copy(conv_param.input_right_pads_, input_right_pads); + + // do Conv + auto conv = DeviceConvNDFwdInstance{}; + auto invoker = conv.MakeInvoker(); + auto argument = conv.MakeArgument(in_device_buf.GetDeviceBuffer(), + wei_device_buf.GetDeviceBuffer(), + std::array{}, + out_device_buf.GetDeviceBuffer(), + a_g_n_c_wis_lengths, + a_g_n_c_wis_strides, + b_g_k_c_xs_lengths, + b_g_k_c_xs_strides, + std::array, 0>{{}}, + std::array, 0>{{}}, + e_g_n_k_wos_lengths, + e_g_n_k_wos_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + in_element_op, + wei_element_op, + out_element_op); + + if(!conv.IsSupportedArgument(argument)) + { + throw std::runtime_error( + "wrong! device_conv with the specified compilation parameters does " + "not support this Conv problem"); + } + + float avg_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); + + std::size_t flop = conv_param.GetFlops(); + std::size_t num_btype = conv_param.GetByte(); + + float tflops = static_cast(flop) / 1.E9 / avg_time; + float gb_per_sec = num_btype / 1.E6 / avg_time; + std::cout << "Perf: " << avg_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " + << conv.GetTypeString() << std::endl; + + bool pass = true; + + if(do_verification) + { + auto ref_conv = ck::tensor_operation::host::ReferenceConvFwd(); + + auto ref_invoker = ref_conv.MakeInvoker(); + auto ref_argument = ref_conv.MakeArgument(in, + wei, + out_host, + conv_param.conv_filter_strides_, + conv_param.conv_filter_dilations_, + conv_param.input_left_pads_, + conv_param.input_right_pads_, + in_element_op, + wei_element_op, + out_element_op); + + ref_invoker.Run(ref_argument); + + out_device_buf.FromDevice(out_device.mData.data()); + + pass &= ck::utils::check_err( + out_device.mData, out_host.mData, "Error: incorrect results!", 1e-5f, 1e-4f); + } + + return (pass ? 0 : 1); +} + +int main() +{ + bool do_verification = true; + bool time_kernel = true; + const ck::index_t ndim_spatial = 2; + + ck::utils::conv::ConvParam conv_param{ + ndim_spatial, // n_dim + 1, // group + 4, // batch + 64, // output channels + 32, // input chanels + {3, 3}, // weight HW + {71, 71}, // x HW + {2, 2}, // strides + {1, 1}, // dilations + {1, 1}, // left_pads + {1, 1} // right_pads + }; + + const auto in_element_op = InElementOp{}; + const auto wei_element_op = WeiElementOp{}; + const auto out_element_op = OutElementOp{0.5f, ActivationOp{}}; + + using InLayout = ck::tensor_layout::convolution::GNHWC; + using WeiLayout = ck::tensor_layout::convolution::GKYXC; + using OutLayout = ck::tensor_layout::convolution::GNHWK; + + const auto in_g_n_c_wis_desc = + ck::utils::conv::make_input_host_tensor_descriptor_g_n_c_wis_packed(conv_param); + + const auto wei_g_k_c_xs_desc = + ck::utils::conv::make_weight_host_tensor_descriptor_g_k_c_xs_packed(conv_param); + + const auto out_g_n_k_wos_desc = + ck::utils::conv::make_output_host_tensor_descriptor_g_n_k_wos_packed(conv_param); + + return run_grouped_conv_fwd< + ndim_spatial, + InDataType, + WeiDataType, + OutDataType, + InElementOp, + WeiElementOp, + OutElementOp, + DeviceGroupedConvNDFwdInstance>( + do_verification, + time_kernel, + conv_param, + in_g_n_c_wis_desc, + wei_g_k_c_xs_desc, + out_g_n_k_wos_desc, + in_element_op, + wei_element_op, + out_element_op); +} diff --git a/include/ck/tensor_operation/gpu/element/element_wise_operation.hpp b/include/ck/tensor_operation/gpu/element/element_wise_operation.hpp index 47d018095d..b66107a525 100644 --- a/include/ck/tensor_operation/gpu/element/element_wise_operation.hpp +++ b/include/ck/tensor_operation/gpu/element/element_wise_operation.hpp @@ -7,6 +7,7 @@ #include "ck/utility/math_v2.hpp" #include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp" #include "ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/element/quantization_operation.hpp" namespace ck { namespace tensor_operation { diff --git a/include/ck/tensor_operation/gpu/element/quantization_operation.hpp b/include/ck/tensor_operation/gpu/element/quantization_operation.hpp new file mode 100644 index 0000000000..f27b61ba53 --- /dev/null +++ b/include/ck/tensor_operation/gpu/element/quantization_operation.hpp @@ -0,0 +1,86 @@ +#pragma once + +#include "ck/utility/data_type.hpp" + +namespace ck { +namespace tensor_operation { +namespace element_wise { + +// For Activation function which is piecewise linear function, such as relu, leaky relu ...etc +template +struct Activation_Mul_Clamp +{ + Activation_Mul_Clamp(float multiplier, Activation activationOp) + : multiplier_(multiplier), activationOp_(activationOp) + { + } + + __host__ __device__ constexpr void operator()(int8_t& y, const int32_t& x) const + { + float x_fp32 = ck::type_convert(x); + activationOp_(x_fp32, x_fp32); + float y_fp32 = math::clamp(multiplier_ * x_fp32, -128.f, 127.f); + y = ck::type_convert(y_fp32); + } + + __host__ __device__ constexpr void operator()(float& y, const int32_t& x) const + { + // We might type_convert to int8 after lambda in someplace + float x_fp32 = ck::type_convert(x); + activationOp_(x_fp32, x_fp32); + y = math::clamp(multiplier_ * x_fp32, -128.f, 127.f); + } + + float multiplier_; + Activation activationOp_; +}; + +// For Activation function which is piecewise linear function, such as relu, leaky relu ...etc +template +struct Add_Activation_Mul_Clamp +{ + Add_Activation_Mul_Clamp(float multiplier, Activation activationOp) + : multiplier_(multiplier), activationOp_(activationOp) + { + } + + __host__ __device__ constexpr void + operator()(int8_t& y, const int32_t& x1, const int32_t& x2) const + { + float y_fp32 = ck::type_convert(x1 + x2); + activationOp_(y_fp32, y_fp32); + y_fp32 = math::clamp(multiplier_ * y_fp32, -128.f, 127.f); + y = ck::type_convert(y_fp32); + } + + float multiplier_; + Activation activationOp_; +}; + +// For Activation function which is non piecewise linear function, such as TanH, Sigmoid ...etc +template +struct Add_Mul_Activation_Mul_Clamp +{ + Add_Mul_Activation_Mul_Clamp(float multiplier1, float multiplier2, Activation activationOp) + : multiplier1_(multiplier1), multiplier2_(multiplier2), activationOp_(activationOp) + { + } + + __host__ __device__ constexpr void + operator()(int8_t& y, const int32_t& x1, const int32_t& x2) const + { + float y_fp32 = ck::type_convert(x1 + x2); + y_fp32 = multiplier1_ * y_fp32; + activationOp_(y_fp32, y_fp32); + y_fp32 = math::clamp(multiplier2_ * y_fp32, -128.f, 127.f); + y = ck::type_convert(y_fp32); + } + + float multiplier1_; + float multiplier2_; + Activation activationOp_; +}; + +} // namespace element_wise +} // namespace tensor_operation +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp b/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp index 699b05fe3c..dcf80b9439 100644 --- a/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp +++ b/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp @@ -4,6 +4,7 @@ #pragma once #include "ck/utility/data_type.hpp" +#include "ck/utility/math.hpp" #include "ck/utility/math_v2.hpp" namespace ck { diff --git a/library/include/ck/library/tensor_operation_instance/device_operation_instance_factory.hpp b/library/include/ck/library/tensor_operation_instance/device_operation_instance_factory.hpp index fcc662e249..49ba995a46 100644 --- a/library/include/ck/library/tensor_operation_instance/device_operation_instance_factory.hpp +++ b/library/include/ck/library/tensor_operation_instance/device_operation_instance_factory.hpp @@ -28,6 +28,8 @@ using F16_F16_Tuple = ck::Tuple; using F32_Tuple = ck::Tuple; +using I32_Tuple = ck::Tuple; + // GEMM layout using Row = ck::tensor_layout::gemm::RowMajor; using Col = ck::tensor_layout::gemm::ColumnMajor; @@ -75,12 +77,24 @@ using NWGK = ck::tensor_layout::convolution::NWGK; using NHWGK = ck::tensor_layout::convolution::NHWGK; using NDHWGK = ck::tensor_layout::convolution::NDHWGK; +// +using GK = ck::tensor_layout::convolution::G_K; +using GK_TUPLE = ck::Tuple; + // pointwise functor using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using Relu = ck::tensor_operation::element_wise::Relu; using Scale = ck::tensor_operation::element_wise::Scale; using Bilinear = ck::tensor_operation::element_wise::Bilinear; using AddAddFastGelu = ck::tensor_operation::element_wise::AddAddFastGelu; +template +using Activation_Mul_Clamp = ck::tensor_operation::element_wise::Activation_Mul_Clamp; + +template +using Add_Activation_Mul_Clamp = + ck::tensor_operation::element_wise::Add_Activation_Mul_Clamp; + template struct DeviceOperationInstanceFactory; diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_bias_forward_perlayer_quantization.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_bias_forward_perlayer_quantization.hpp new file mode 100644 index 0000000000..9d441d14d1 --- /dev/null +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_bias_forward_perlayer_quantization.hpp @@ -0,0 +1,114 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_d.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// grouped conv2d forward, GNHWC/GKYXC/GNHWK +void add_device_conv2d_bias_perlayer_quantization_int8_instances( + std::vector< + std::unique_ptr>>>& + instances); + +void add_device_conv2d_bias_relu_perlayer_quantization_int8_instances( + std::vector>>>& + instances); + +template +struct DeviceOperationInstanceFactory>> +{ + using DeviceOp = DeviceGroupedConvFwdMultipleD>; + + static auto GetInstances() + { + std::vector> op_ptrs; + + if constexpr(NumDimSpatial == 2 && is_same_v && + is_same_v && is_same_v && + is_same_v) + { + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v) + { + if constexpr(is_same_v) + add_device_conv2d_bias_perlayer_quantization_int8_instances(op_ptrs); + else if constexpr(is_same_v) + add_device_conv2d_bias_relu_perlayer_quantization_int8_instances(op_ptrs); + } + } + + return op_ptrs; + } +}; + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_perlayer_quantization.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_perlayer_quantization.hpp new file mode 100644 index 0000000000..410be4a57b --- /dev/null +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_perlayer_quantization.hpp @@ -0,0 +1,110 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_d.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// grouped conv2d forward, GNHWC/GKYXC/GNHWK +void add_device_conv2d_perlayer_quantization_int8_instances( + std::vector>>>& + instances); + +void add_device_conv2d_relu_perlayer_quantization_int8_instances( + std::vector>>>& + instances); + +template +struct DeviceOperationInstanceFactory>> +{ + using DeviceOp = DeviceGroupedConvFwdMultipleD>; + + static auto GetInstances() + { + std::vector> op_ptrs; + + if constexpr(NumDimSpatial == 2 && is_same_v && + is_same_v && is_same_v) + { + if constexpr(is_same_v && is_same_v && + is_same_v) + { + if constexpr(is_same_v) + add_device_conv2d_perlayer_quantization_int8_instances(op_ptrs); + else if constexpr(is_same_v) + add_device_conv2d_relu_perlayer_quantization_int8_instances(op_ptrs); + } + } + + return op_ptrs; + } +}; + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/quantization/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/quantization/CMakeLists.txt new file mode 100644 index 0000000000..8b2149aefb --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/quantization/CMakeLists.txt @@ -0,0 +1,4 @@ +add_instance_library(device_quantization_instance + device_conv2d_xdl_bias_quant_int8_instance.cpp + device_conv2d_xdl_quant_int8_instance.cpp +) diff --git a/library/src/tensor_operation_instance/gpu/quantization/device_conv2d_xdl_bias_quant_int8_instance.cpp b/library/src/tensor_operation_instance/gpu/quantization/device_conv2d_xdl_bias_quant_int8_instance.cpp new file mode 100644 index 0000000000..774758fb69 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/quantization/device_conv2d_xdl_bias_quant_int8_instance.cpp @@ -0,0 +1,112 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_xdl_cshuffle.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +template +using S = ck::Sequence; + +using GNHWC = ck::tensor_layout::convolution::GNHWC; +using GKYXC = ck::tensor_layout::convolution::GKYXC; +using GNHWK = ck::tensor_layout::convolution::GNHWK; +using GK = ck::tensor_layout::convolution::G_K; +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using Relu = ck::tensor_operation::element_wise::Relu; + +using GK_Tuple = ck::Tuple; +using I32_Tuple = ck::Tuple; + +using Add_Mul_Clamp = ck::tensor_operation::element_wise::Add_Activation_Mul_Clamp; +using Add_Relu_Mul_Clamp = ck::tensor_operation::element_wise::Add_Activation_Mul_Clamp; + +static constexpr ck::index_t NDimSpatial = 2; +static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding; +static constexpr auto ConvFwdDefault = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Default; +static constexpr auto ConvFwd1x1P0 = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Pad0; +static constexpr auto ConvFwd1x1S1P0 = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0; + +// TODO - Add more instances +template +// clang-format off +using device_conv2d_int8_instances = + std::tuple < + //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, GK_Tuple, GNHWK, int8_t, int8_t, int32_t, int32_t, I32_Tuple, int8_t, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmSpec, 1, 256, 256, 128, 64, 16, 16, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 64, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, GK_Tuple, GNHWK, int8_t, int8_t, int32_t, int32_t, I32_Tuple, int8_t, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmSpec, 1, 256, 128, 256, 64, 16, 16, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 64, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, GK_Tuple, GNHWK, int8_t, int8_t, int32_t, int32_t, I32_Tuple, int8_t, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmSpec, 1, 128, 128, 128, 64, 16, 16, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, GK_Tuple, GNHWK, int8_t, int8_t, int32_t, int32_t, I32_Tuple, int8_t, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmSpec, 1, 256, 128, 128, 64, 16, 16, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 64, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, GK_Tuple, GNHWK, int8_t, int8_t, int32_t, int32_t, I32_Tuple, int8_t, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmSpec, 1, 128, 128, 64, 64, 16, 16, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 64, 1, 2>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, GK_Tuple, GNHWK, int8_t, int8_t, int32_t, int32_t, I32_Tuple, int8_t, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmSpec, 1, 128, 64, 128, 64, 16, 16, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, GK_Tuple, GNHWK, int8_t, int8_t, int32_t, int32_t, I32_Tuple, int8_t, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmSpec, 1, 64, 64, 64, 64, 16, 16, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 2>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, GK_Tuple, GNHWK, int8_t, int8_t, int32_t, int32_t, I32_Tuple, int8_t, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmSpec, 1, 256, 128, 64, 64, 16, 16, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 64, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, GK_Tuple, GNHWK, int8_t, int8_t, int32_t, int32_t, I32_Tuple, int8_t, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmSpec, 1, 256, 64, 128, 64, 16, 16, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 64, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, GK_Tuple, GNHWK, int8_t, int8_t, int32_t, int32_t, I32_Tuple, int8_t, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmSpec, 1, 128, 128, 32, 64, 16, 16, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 64, 1, 2>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, GK_Tuple, GNHWK, int8_t, int8_t, int32_t, int32_t, I32_Tuple, int8_t, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmSpec, 1, 128, 32, 128, 64, 16, 16, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, GK_Tuple, GNHWK, int8_t, int8_t, int32_t, int32_t, I32_Tuple, int8_t, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmSpec, 1, 64, 64, 32, 64, 16, 16, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 2>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, GK_Tuple, GNHWK, int8_t, int8_t, int32_t, int32_t, I32_Tuple, int8_t, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmSpec, 1, 64, 32, 64, 64, 16, 16, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 2>, 8> + >; +// clang-format on + +void add_device_conv2d_bias_perlayer_quantization_int8_instances( + std::vector, + GNHWK, + int8_t, + int8_t, + ck::Tuple, + int8_t, + PassThrough, + PassThrough, + Add_Mul_Clamp>>>& instances) +{ + add_device_operation_instances(instances, + device_conv2d_int8_instances{}); + add_device_operation_instances(instances, + device_conv2d_int8_instances{}); + add_device_operation_instances(instances, + device_conv2d_int8_instances{}); +} + +void add_device_conv2d_bias_relu_perlayer_quantization_int8_instances( + std::vector, + GNHWK, + int8_t, + int8_t, + ck::Tuple, + int8_t, + PassThrough, + PassThrough, + Add_Relu_Mul_Clamp>>>& instances) +{ + add_device_operation_instances( + instances, device_conv2d_int8_instances{}); + add_device_operation_instances( + instances, device_conv2d_int8_instances{}); + add_device_operation_instances( + instances, device_conv2d_int8_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/quantization/device_conv2d_xdl_quant_int8_instance.cpp b/library/src/tensor_operation_instance/gpu/quantization/device_conv2d_xdl_quant_int8_instance.cpp new file mode 100644 index 0000000000..eba5954c55 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/quantization/device_conv2d_xdl_quant_int8_instance.cpp @@ -0,0 +1,109 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_xdl_cshuffle.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using Empty_Tuple = ck::Tuple<>; +template +using S = ck::Sequence; + +using GNHWC = ck::tensor_layout::convolution::GNHWC; +using GKYXC = ck::tensor_layout::convolution::GKYXC; +using GNHWK = ck::tensor_layout::convolution::GNHWK; +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using Relu = ck::tensor_operation::element_wise::Relu; + +using Mul_Clamp = ck::tensor_operation::element_wise::Activation_Mul_Clamp; +using Relu_Mul_Clamp = ck::tensor_operation::element_wise::Activation_Mul_Clamp; + +static constexpr ck::index_t NDimSpatial = 2; +static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding; +static constexpr auto ConvFwdDefault = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Default; +static constexpr auto ConvFwd1x1P0 = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Pad0; +static constexpr auto ConvFwd1x1S1P0 = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0; + +// TODO - Add more instances +template +// clang-format off +using device_conv2d_int8_instances = + std::tuple < + //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, int8_t, int8_t, int32_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmSpec, 1, 256, 256, 128, 64, 16, 16, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 64, 1, 4>, 16>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, int8_t, int8_t, int32_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmSpec, 1, 256, 128, 256, 64, 16, 16, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 64, 1, 4>, 16>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, int8_t, int8_t, int32_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmSpec, 1, 128, 128, 128, 64, 16, 16, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 4>, 16>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, int8_t, int8_t, int32_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmSpec, 1, 256, 128, 128, 64, 16, 16, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 64, 1, 4>, 16>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, int8_t, int8_t, int32_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmSpec, 1, 128, 128, 64, 64, 16, 16, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 64, 1, 2>, 16>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, int8_t, int8_t, int32_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmSpec, 1, 128, 64, 128, 64, 16, 16, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 4>, 16>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, int8_t, int8_t, int32_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmSpec, 1, 64, 64, 64, 64, 16, 16, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 2>, 16>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, int8_t, int8_t, int32_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmSpec, 1, 256, 128, 64, 64, 16, 16, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 64, 1, 4>, 16>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, int8_t, int8_t, int32_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmSpec, 1, 256, 64, 128, 64, 16, 16, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 64, 1, 4>, 16>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, int8_t, int8_t, int32_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmSpec, 1, 128, 128, 32, 64, 16, 16, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 64, 1, 2>, 16>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, int8_t, int8_t, int32_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmSpec, 1, 128, 32, 128, 64, 16, 16, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 4>, 16>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, int8_t, int8_t, int32_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmSpec, 1, 64, 64, 32, 64, 16, 16, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 2>, 16>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, int8_t, int8_t, int32_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmSpec, 1, 64, 32, 64, 64, 16, 16, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 2>, 16> + >; +// clang-format on + +void add_device_conv2d_perlayer_quantization_int8_instances( + std::vector>>& instances) +{ + add_device_operation_instances(instances, + device_conv2d_int8_instances{}); + add_device_operation_instances(instances, + device_conv2d_int8_instances{}); + add_device_operation_instances(instances, + device_conv2d_int8_instances{}); +} + +void add_device_conv2d_relu_perlayer_quantization_int8_instances( + std::vector>>& instances) +{ + add_device_operation_instances(instances, + device_conv2d_int8_instances{}); + add_device_operation_instances(instances, + device_conv2d_int8_instances{}); + add_device_operation_instances(instances, + device_conv2d_int8_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/script/cmake-ck-dev.sh b/script/cmake-ck-dev.sh index f5a08204c4..2e605ce8de 100755 --- a/script/cmake-ck-dev.sh +++ b/script/cmake-ck-dev.sh @@ -11,7 +11,7 @@ cmake -D CMAKE_CXX_FLAGS="-O3 -ftemplate-backtrace-limit=0 -gline-tables-only -save-temps=$PWD" \ -D CMAKE_BUILD_TYPE=Release \ -D BUILD_DEV=ON \ --D GPU_TARGETS=gfx908;gfx90a \ +-D GPU_TARGETS="gfx908;gfx90a" \ -D CMAKE_VERBOSE_MAKEFILE:BOOL=ON \ -D USE_BITINT_EXTENSION_INT4=OFF \ ${MY_PROJECT_SOURCE} diff --git a/script/cmake-ck-release.sh b/script/cmake-ck-release.sh index a583cc35ed..268b1ebf9b 100755 --- a/script/cmake-ck-release.sh +++ b/script/cmake-ck-release.sh @@ -11,7 +11,7 @@ cmake -D CMAKE_CXX_FLAGS="-O3" \ -D CMAKE_BUILD_TYPE=Release \ -D BUILD_DEV=OFF \ --D GPU_TARGETS=gfx908;gfx90a \ +-D GPU_TARGETS="gfx908;gfx90a" \ -D CMAKE_VERBOSE_MAKEFILE:BOOL=ON \ -D USE_BITINT_EXTENSION_INT4=OFF \ ${MY_PROJECT_SOURCE}