From 4b153bd9744d135b3e429b23133461df6fa7f4d4 Mon Sep 17 00:00:00 2001 From: "Po-Yen, Chen" Date: Thu, 18 Aug 2022 23:03:36 -0400 Subject: [PATCH] Add GEMM examples for int4 Currently the source files are just copied from int8 examples --- example/01_gemm/CMakeLists.txt | 6 + example/01_gemm/gemm_dl_int4.cpp | 207 ++++++++++++++++++++++++++ example/01_gemm/gemm_xdl_int4.cpp | 233 ++++++++++++++++++++++++++++++ 3 files changed, 446 insertions(+) create mode 100644 example/01_gemm/gemm_dl_int4.cpp create mode 100644 example/01_gemm/gemm_xdl_int4.cpp diff --git a/example/01_gemm/CMakeLists.txt b/example/01_gemm/CMakeLists.txt index fc22088ad4..ebcd01b57b 100644 --- a/example/01_gemm/CMakeLists.txt +++ b/example/01_gemm/CMakeLists.txt @@ -1,9 +1,15 @@ add_example_executable(example_gemm_dl_fp32 gemm_dl_fp32.cpp) add_example_executable(example_gemm_dl_fp16 gemm_dl_fp16.cpp) add_example_executable(example_gemm_dl_int8 gemm_dl_int8.cpp) +if (USE_BITINT_EXTENSION_INT4) + add_example_executable(example_gemm_dl_int4 gemm_dl_int4.cpp) +endif(USE_BITINT_EXTENSION_INT4) add_example_executable(example_gemm_xdl_fp16 gemm_xdl_fp16.cpp) add_example_executable(example_gemm_xdl_bf16 gemm_xdl_bf16.cpp) add_example_executable(example_gemm_xdl_int8 gemm_xdl_int8.cpp) +if (USE_BITINT_EXTENSION_INT4) + add_example_executable(example_gemm_xdl_int4 gemm_xdl_int4.cpp) +endif(USE_BITINT_EXTENSION_INT4) add_example_executable(example_gemm_xdl_skip_b_lds_fp16 gemm_xdl_skip_b_lds_fp16.cpp) # FIXME: re-enable this exampe as test when SWDEV-335738 is fixed add_example_executable_no_testing(example_gemm_xdl_fp64 gemm_xdl_fp64.cpp) diff --git a/example/01_gemm/gemm_dl_int4.cpp b/example/01_gemm/gemm_dl_int4.cpp new file mode 100644 index 0000000000..7787110580 --- /dev/null +++ b/example/01_gemm/gemm_dl_int4.cpp @@ -0,0 +1,207 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_dl.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" + +template +using S = ck::Sequence; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +using ADataType = int8_t; +using BDataType = int8_t; +using CDataType = int8_t; +using AccDataType = int32_t; + +using ALayout = Col; +using BLayout = Row; +using CLayout = Row; + +using AElementOp = ck::tensor_operation::element_wise::PassThrough; +using BElementOp = ck::tensor_operation::element_wise::PassThrough; +using CElementOp = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// clang-format off +using DeviceGemmInstance = ck::tensor_operation::device:: + // #########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M11N11Thread| M11N11Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer| + // #########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector| + // #########| | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_N0_N1_K1| K0_N0_N1_K1| ArrangeOrder| Order| Lengths_K0_N0_N1_K1| ContiguousDimOrder| Lengths_K0_N0_N1_K1| Order| | | + // #########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 16, 4, 4, 4, 1, S<8, 2>, S<8, 2>, S<2, 1, 4, 4>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<2, 1, 4, 4>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>; +// clang-format on + +using ReferenceGemmInstance = ck::tensor_operation::host:: + ReferenceGemm; + +int main(int argc, char* argv[]) +{ + bool do_verification = true; + int init_method = 1; + bool time_kernel = false; + + // GEMM shape + ck::index_t M = 3840; + ck::index_t N = 4096; + ck::index_t K = 4096; + + ck::index_t StrideA = 4096; + ck::index_t StrideB = 4096; + ck::index_t StrideC = 4096; + + if(argc == 1) + { + // do nothing + } + else if(argc == 4) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + } + else if(argc == 10) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + + M = std::stoi(argv[4]); + N = std::stoi(argv[5]); + K = std::stoi(argv[6]); + + StrideA = std::stoi(argv[7]); + StrideB = std::stoi(argv[8]); + StrideC = std::stoi(argv[9]); + } + else + { + printf("arg1: verification (0=no, 1=yes)\n"); + printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); + printf("arg3: time kernel (0=n0, 1=yes)\n"); + printf("arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideC\n"); + exit(1); + } + + auto f_host_tensor_descriptor = + [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { + if(std::is_same::value) + { + return HostTensorDescriptor(std::vector({row, col}), + std::vector({stride, 1})); + } + else + { + return HostTensorDescriptor(std::vector({row, col}), + std::vector({1, stride})); + } + }; + + Tensor a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{})); + Tensor b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{})); + Tensor c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); + Tensor c_m_n_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); + + std::cout << "a_m_k: " << a_m_k.mDesc << std::endl; + std::cout << "b_k_n: " << b_k_n.mDesc << std::endl; + std::cout << "c_m_n: " << c_m_n_host_result.mDesc << std::endl; + + switch(init_method) + { + case 0: break; + case 1: + a_m_k.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + b_k_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + break; + case 2: + a_m_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b_k_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + break; + default: + a_m_k.GenerateTensorValue(GeneratorTensor_Sequential<0>{}); + b_k_n.GenerateTensorValue(GeneratorTensor_Sequential<1>{}); + } + + DeviceMem a_m_k_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpaceSize()); + DeviceMem b_k_n_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpaceSize()); + DeviceMem c_m_n_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpaceSize()); + + a_m_k_device_buf.ToDevice(a_m_k.mData.data()); + b_k_n_device_buf.ToDevice(b_k_n.mData.data()); + + auto a_element_op = AElementOp{}; + auto b_element_op = BElementOp{}; + auto c_element_op = CElementOp{}; + + // do GEMM + auto gemm = DeviceGemmInstance{}; + auto invoker = gemm.MakeInvoker(); + auto argument = gemm.MakeArgument(static_cast(a_m_k_device_buf.GetDeviceBuffer()), + static_cast(b_k_n_device_buf.GetDeviceBuffer()), + static_cast(c_m_n_device_buf.GetDeviceBuffer()), + M, + N, + K, + StrideA, + StrideB, + StrideC, + a_element_op, + b_element_op, + c_element_op); + + if(!gemm.IsSupportedArgument(argument)) + { + std::cout << gemm.GetTypeString() << " does not support this problem" << std::endl; + + return 0; + } + + float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); + + std::size_t flop = std::size_t(2) * M * N * K; + std::size_t num_btype = + sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + sizeof(CDataType) * M * N; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " + << gemm.GetTypeString() << std::endl; + + c_m_n_device_buf.FromDevice(c_m_n_device_result.mData.data()); + + bool pass = true; + + if(do_verification) + { + auto ref_gemm = ReferenceGemmInstance{}; + auto ref_invoker = ref_gemm.MakeInvoker(); + + auto ref_argument = ref_gemm.MakeArgument( + a_m_k, b_k_n, c_m_n_host_result, a_element_op, b_element_op, c_element_op); + + ref_invoker.Run(ref_argument); + + pass = ck::utils::check_err(c_m_n_device_result.mData, c_m_n_host_result.mData); + } + + return pass ? 0 : 1; +} diff --git a/example/01_gemm/gemm_xdl_int4.cpp b/example/01_gemm/gemm_xdl_int4.cpp new file mode 100644 index 0000000000..ca5c66f8af --- /dev/null +++ b/example/01_gemm/gemm_xdl_int4.cpp @@ -0,0 +1,233 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_xdl_cshuffle.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +using ADataType = int8_t; +using BDataType = int8_t; +using CDataType = int8_t; +using AccDataType = int32_t; +using CShuffleDataType = int8_t; + +using ALayout = ck::tensor_layout::gemm::RowMajor; +using BLayout = ck::tensor_layout::gemm::ColumnMajor; +using CLayout = ck::tensor_layout::gemm::RowMajor; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// clang-format off +using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemm_Xdl_CShuffle< + ALayout, // typename ALayout + BLayout, // typename BLayout + CLayout, // typename CLayout + ADataType, // typename ADataType + BDataType, // typename BDataType + CDataType, // typename CDataType + AccDataType, // typename GemmAccDataType + CShuffleDataType, // typename CShuffleDataType + PassThrough, // typename AElementwiseOperation + PassThrough, // typename BElementwiseOperation + PassThrough, // typename CElementwiseOperation + GemmDefault, // GemmSpecialization GemmSpec + 1, // index_t NumGemmKPrefetchStage + 256, // index_t BlockSize + 256, // index_t MPerBlock + 128, // index_t NPerBlock + 64, // index_t KPerBlock + 16, // index_t AK1 + 16, // index_t BK1 + 32, // index_t MPerXDL + 32, // index_t NPerXDL + 4, // index_t MXdlPerWave + 2, // index_t NXdlPerWave + S<4, 64, 1>, // typename ABlockTransferThreadClusterLengths_AK0_M_AK1 + S<1, 0, 2>, // typename ABlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // typename ABlockTransferSrcAccessOrder + 2, // index_t ABlockTransferSrcVectorDim + 16, // index_t ABlockTransferSrcScalarPerVector + 16, // index_t ABlockTransferDstScalarPerVector_AK1 + 1, // index_t ABlockLdsExtraM + S<4, 64, 1>, // typename BBlockTransferThreadClusterLengths_BK0_N_BK1 + S<1, 0, 2>, // typename BBlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // typename BBlockTransferSrcAccessOrder + 2, // index_t BBlockTransferSrcVectorDim + 8, // index_t BBlockTransferSrcScalarPerVector + 8, // index_t BBlockTransferDstScalarPerVector_BK1 + 1, // index_t BBlockLdsExtraN + 1, // index_t CShuffleMXdlPerWavePerShuffle + 1, // index_t CShuffleNXdlPerWavePerShuffle + S<1, 64, 1, 4>, // typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock + 16>; // index_t CShuffleBlockTransferScalarPerVector_NPerBlock +// clang-format on + +using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm; + +int main(int argc, char* argv[]) +{ + bool do_verification = true; + int init_method = 1; + bool time_kernel = false; + + // GEMM shape + ck::index_t M = 3840; + ck::index_t N = 4096; + ck::index_t K = 4096; + + ck::index_t StrideA = 4096; + ck::index_t StrideB = 4096; + ck::index_t StrideC = 4096; + + if(argc == 4) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + } + else if(argc == 10) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + + M = std::stoi(argv[4]); + N = std::stoi(argv[5]); + K = std::stoi(argv[6]); + + StrideA = std::stoi(argv[7]); + StrideB = std::stoi(argv[8]); + StrideC = std::stoi(argv[9]); + } + else + { + printf("arg1: verification (0=no, 1=yes)\n"); + printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); + printf("arg3: time kernel (0=n0, 1=yes)\n"); + printf("arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideC\n"); + exit(0); + } + + auto f_host_tensor_descriptor = + [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { + if(std::is_same::value) + { + return HostTensorDescriptor(std::vector({row, col}), + std::vector({stride, 1})); + } + else + { + return HostTensorDescriptor(std::vector({row, col}), + std::vector({1, stride})); + } + }; + + Tensor a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{})); + Tensor b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{})); + Tensor c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); + Tensor c_m_n_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); + + std::cout << "a_m_k: " << a_m_k.mDesc << std::endl; + std::cout << "b_k_n: " << b_k_n.mDesc << std::endl; + std::cout << "c_m_n: " << c_m_n_host_result.mDesc << std::endl; + + switch(init_method) + { + case 0: break; + case 1: + a_m_k.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + b_k_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + break; + default: + a_m_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b_k_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + } + + DeviceMem a_m_k_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpaceSize()); + DeviceMem b_k_n_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpaceSize()); + DeviceMem c_m_n_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpaceSize()); + + a_m_k_device_buf.ToDevice(a_m_k.mData.data()); + b_k_n_device_buf.ToDevice(b_k_n.mData.data()); + + auto a_element_op = PassThrough{}; + auto b_element_op = PassThrough{}; + auto c_element_op = PassThrough{}; + + // do GEMM + auto gemm = DeviceGemmInstance{}; + auto invoker = gemm.MakeInvoker(); + auto argument = gemm.MakeArgument(static_cast(a_m_k_device_buf.GetDeviceBuffer()), + static_cast(b_k_n_device_buf.GetDeviceBuffer()), + static_cast(c_m_n_device_buf.GetDeviceBuffer()), + M, + N, + K, + StrideA, + StrideB, + StrideC, + a_element_op, + b_element_op, + c_element_op); + + if(!gemm.IsSupportedArgument(argument)) + { + std::cout << gemm.GetTypeString() << " does not support this problem" << std::endl; + + return 0; + } + + float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); + + std::size_t flop = std::size_t(2) * M * N * K; + std::size_t num_btype = + sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + sizeof(CDataType) * M * N; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " + << gemm.GetTypeString() << std::endl; + + c_m_n_device_buf.FromDevice(c_m_n_device_result.mData.data()); + + if(do_verification) + { + auto ref_gemm = ReferenceGemmInstance{}; + auto ref_invoker = ref_gemm.MakeInvoker(); + + auto ref_argument = ref_gemm.MakeArgument( + a_m_k, b_k_n, c_m_n_host_result, a_element_op, b_element_op, c_element_op); + + ref_invoker.Run(ref_argument); + + return ck::utils::check_err(c_m_n_device_result.mData, c_m_n_host_result.mData) ? 0 : 1; + } + + return 0; +}