From d832dc4354c332acefb75d5565e2b968a0953e95 Mon Sep 17 00:00:00 2001 From: Po Yen Chen Date: Fri, 12 Aug 2022 06:31:28 +0800 Subject: [PATCH] Add examples for GEMM + AddAddFastGelu (data type: int8, bf16, fp32) (#340) * Add always_false<> util to delay symbol resolution * Use always_false<> to prevent trying instantiate unwanted method * Add new specializations of AddAddFastGelu::operator() method * Add GEMM + AddAddFastGelu examples for data types: int8, bf16, fp32 * Use floating point literal to simplify code * Remove unnecessary capture in lambda expressions * Extract fast GeLU calculation as standalone method * Mark methods as 'constexpr' * Add constraint for HostTensorDescriptor templated ctors * Simplify HostTensorDescriptor ctor calls * Add C++23 std::size_t literal suffix * Use _uz suffix to shorten example code * Remove unnecessary conversion to std::array<> * Re-order include directives * Remove C-style casting by literal suffix * Remove unnecessary statements in main() * Remove unused type parameter of always_false<> * Remove unused include directive * Exit main() by returning meaningful value * Use 'if constexpr' to switch example flow * Use std::is_same_v<> to shorten example code * Add 'inline' specifier to literal functions * Unify output methods in example * Move common codes into .inc file * Add type check in type_convert<>() * Add type_convert() before computation * Merge AddAddFastGelu method specializations * Remove always_false<> * Add constraint to AddAddFastGelu::operator() parameter types [ROCm/composable_kernel commit: 68b61504a38e26ad5ad4c8be26e9653cfe62c6ed] --- .../04_gemm_add_add_fastgelu/CMakeLists.txt | 3 + .../gemm_add_add_fastgelu_xdl_bf16.cpp | 67 ++++++ .../gemm_add_add_fastgelu_xdl_fp16.cpp | 198 +---------------- .../gemm_add_add_fastgelu_xdl_fp32.cpp | 67 ++++++ .../gemm_add_add_fastgelu_xdl_int8.cpp | 67 ++++++ .../run_gemm_add_add_fastgelu_example.inc | 201 ++++++++++++++++++ .../gpu/element/element_wise_operation.hpp | 43 ++-- include/ck/utility/data_type.hpp | 2 + .../ck/library/utility/host_tensor.hpp | 30 ++- .../include/ck/library/utility/literals.hpp | 16 ++ 10 files changed, 469 insertions(+), 225 deletions(-) create mode 100644 example/04_gemm_add_add_fastgelu/gemm_add_add_fastgelu_xdl_bf16.cpp create mode 100644 example/04_gemm_add_add_fastgelu/gemm_add_add_fastgelu_xdl_fp32.cpp create mode 100644 example/04_gemm_add_add_fastgelu/gemm_add_add_fastgelu_xdl_int8.cpp create mode 100644 example/04_gemm_add_add_fastgelu/run_gemm_add_add_fastgelu_example.inc create mode 100644 library/include/ck/library/utility/literals.hpp diff --git a/example/04_gemm_add_add_fastgelu/CMakeLists.txt b/example/04_gemm_add_add_fastgelu/CMakeLists.txt index 754de47c2b..0285a53f28 100644 --- a/example/04_gemm_add_add_fastgelu/CMakeLists.txt +++ b/example/04_gemm_add_add_fastgelu/CMakeLists.txt @@ -1 +1,4 @@ +add_example_executable(example_gemm_add_add_fastgelu_xdl_bf16 gemm_add_add_fastgelu_xdl_bf16.cpp) add_example_executable(example_gemm_add_add_fastgelu_xdl_fp16 gemm_add_add_fastgelu_xdl_fp16.cpp) +add_example_executable(example_gemm_add_add_fastgelu_xdl_fp32 gemm_add_add_fastgelu_xdl_fp32.cpp) +add_example_executable(example_gemm_add_add_fastgelu_xdl_int8 gemm_add_add_fastgelu_xdl_int8.cpp) diff --git a/example/04_gemm_add_add_fastgelu/gemm_add_add_fastgelu_xdl_bf16.cpp b/example/04_gemm_add_add_fastgelu/gemm_add_add_fastgelu_xdl_bf16.cpp new file mode 100644 index 0000000000..2f7a4fd862 --- /dev/null +++ b/example/04_gemm_add_add_fastgelu/gemm_add_add_fastgelu_xdl_bf16.cpp @@ -0,0 +1,67 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_multiple_d_xdl_cshuffle.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/literals.hpp" + +template +using S = ck::Sequence; + +using BF16 = ck::bhalf_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using AddAddFastGelu = ck::tensor_operation::element_wise::AddAddFastGelu; + +using ADataType = BF16; +using BDataType = BF16; +using AccDataType = F32; +using CShuffleDataType = F32; +using D0DataType = BF16; +using D1DataType = BF16; +using DsDataType = ck::Tuple; +using EDataType = BF16; + +using ALayout = Row; +using BLayout = Col; +using D0Layout = Row; +using D1Layout = Row; +using DsLayout = ck::Tuple; +using ELayout = Row; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CDEElementOp = AddAddFastGelu; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// clang-format off +using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultipleD_Xdl_CShuffle +//######| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| +//######| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| +//######| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| +//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + < ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>; +// clang-format on + +#include "run_gemm_add_add_fastgelu_example.inc" + +int main(int argc, char* argv[]) { return !run_gemm_add_add_fastgelu_example(argc, argv); } diff --git a/example/04_gemm_add_add_fastgelu/gemm_add_add_fastgelu_xdl_fp16.cpp b/example/04_gemm_add_add_fastgelu/gemm_add_add_fastgelu_xdl_fp16.cpp index c440297ec6..149cef6f81 100644 --- a/example/04_gemm_add_add_fastgelu/gemm_add_add_fastgelu_xdl_fp16.cpp +++ b/example/04_gemm_add_add_fastgelu/gemm_add_add_fastgelu_xdl_fp16.cpp @@ -1,10 +1,10 @@ // SPDX-License-Identifier: MIT // Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +#include #include -#include -#include -#include +#include +#include #include "ck/ck.hpp" #include "ck/tensor_operation/gpu/device/tensor_layout.hpp" @@ -12,11 +12,12 @@ #include "ck/tensor_operation/gpu/device/device_gemm_multiple_d_xdl_cshuffle.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" +#include "ck/library/utility/check_err.hpp" #include "ck/library/utility/device_memory.hpp" #include "ck/library/utility/host_tensor.hpp" #include "ck/library/utility/host_tensor_generator.hpp" -#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" -#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/literals.hpp" template using S = ck::Sequence; @@ -61,189 +62,6 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultipleD_Xdl_C < ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>; // clang-format on -int main(int argc, char* argv[]) -{ - bool do_verification = true; - int init_method = 1; - bool time_kernel = false; +#include "run_gemm_add_add_fastgelu_example.inc" - // GEMM shape - ck::index_t M = 3840; - ck::index_t N = 4096; - ck::index_t K = 4096; - - ck::index_t StrideA = 4096; - ck::index_t StrideB = 4096; - ck::index_t StrideD0 = 0; - ck::index_t StrideD1 = 4096; - ck::index_t StrideE = 4096; - - if(argc == 1) - { - // use default case - } - else if(argc == 4) - { - do_verification = std::stoi(argv[1]); - init_method = std::stoi(argv[2]); - time_kernel = std::stoi(argv[3]); - } - else if(argc == 12) - { - do_verification = std::stoi(argv[1]); - init_method = std::stoi(argv[2]); - time_kernel = std::stoi(argv[3]); - - M = std::stoi(argv[4]); - N = std::stoi(argv[5]); - K = std::stoi(argv[6]); - - StrideA = std::stoi(argv[7]); - StrideB = std::stoi(argv[8]); - StrideD0 = std::stoi(argv[9]); - StrideD1 = std::stoi(argv[10]); - StrideE = std::stoi(argv[11]); - } - else - { - printf("arg1: verification (0=no, 1=yes)\n"); - printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); - printf("arg3: time kernel (0=no, 1=yes)\n"); - printf("arg4 to 10: M (256x), N(128x), K(32x), StrideA, StrideB, StrideD0, StrideD1, " - "StrideE\n"); - exit(0); - } - - auto f_host_tensor_descriptor = - [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { - if(std::is_same::value) - { - return HostTensorDescriptor(std::vector({row, col}), - std::vector({stride, 1})); - } - else - { - return HostTensorDescriptor(std::vector({row, col}), - std::vector({1, stride})); - } - }; - - Tensor a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{})); - Tensor b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{})); - Tensor d0_m_n(f_host_tensor_descriptor(M, N, StrideD0, D0Layout{})); - Tensor d1_m_n(f_host_tensor_descriptor(M, N, StrideD1, D1Layout{})); - Tensor e_m_n_host_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{})); - Tensor e_m_n_device_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{})); - - std::cout << "a_m_k: " << a_m_k.mDesc << std::endl; - std::cout << "b_k_n: " << b_k_n.mDesc << std::endl; - std::cout << "d0_m_n: " << d0_m_n.mDesc << std::endl; - std::cout << "d1_m_n: " << d1_m_n.mDesc << std::endl; - std::cout << "e_m_n: " << e_m_n_host_result.mDesc << std::endl; - - switch(init_method) - { - case 0: break; - case 1: - a_m_k.GenerateTensorValue(GeneratorTensor_2{-5, 5}); - b_k_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}); - d0_m_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}); - d1_m_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}); - break; - default: - a_m_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); - b_k_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); - d0_m_n.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); - d1_m_n.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); - } - - DeviceMem a_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpaceSize()); - DeviceMem b_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpaceSize()); - DeviceMem d0_device_buf(sizeof(D0DataType) * d0_m_n.mDesc.GetElementSpaceSize()); - DeviceMem d1_device_buf(sizeof(D1DataType) * d1_m_n.mDesc.GetElementSpaceSize()); - DeviceMem e_device_buf(sizeof(EDataType) * e_m_n_device_result.mDesc.GetElementSpaceSize()); - - a_device_buf.ToDevice(a_m_k.mData.data()); - b_device_buf.ToDevice(b_k_n.mData.data()); - d0_device_buf.ToDevice(d0_m_n.mData.data()); - d1_device_buf.ToDevice(d1_m_n.mData.data()); - - auto a_element_op = AElementOp{}; - auto b_element_op = BElementOp{}; - auto cde_element_op = CDEElementOp{}; - - // do GEMM - auto device_op = DeviceOpInstance{}; - auto invoker = device_op.MakeInvoker(); - auto argument = - device_op.MakeArgument(a_device_buf.GetDeviceBuffer(), - b_device_buf.GetDeviceBuffer(), - std::array{d0_device_buf.GetDeviceBuffer(), - d1_device_buf.GetDeviceBuffer()}, - e_device_buf.GetDeviceBuffer(), - M, - N, - K, - StrideA, - StrideB, - std::array{StrideD0, StrideD1}, - StrideE, - a_element_op, - b_element_op, - cde_element_op); - - if(!device_op.IsSupportedArgument(argument)) - { - throw std::runtime_error("wrong! this device_op instance does not support this problem"); - } - - float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); - - std::size_t flop = std::size_t(2) * M * N * K; - std::size_t num_btype = sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + - sizeof(D0DataType) * N + sizeof(D1DataType) * M * N + - sizeof(EDataType) * M * N; - - float tflops = static_cast(flop) / 1.E9 / ave_time; - - float gb_per_sec = num_btype / 1.E6 / ave_time; - - std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " - << device_op.GetTypeString() << std::endl; - - if(do_verification) - { - Tensor c_m_n(HostTensorDescriptor( - std::vector{static_cast(M), static_cast(N)})); - - using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm; - - auto ref_gemm = ReferenceGemmInstance{}; - auto ref_invoker = ref_gemm.MakeInvoker(); - - auto ref_argument = - ref_gemm.MakeArgument(a_m_k, b_k_n, c_m_n, a_element_op, b_element_op, PassThrough{}); - - ref_invoker.Run(ref_argument); - - for(int m = 0; m < M; ++m) - { - for(int n = 0; n < N; ++n) - { - cde_element_op(e_m_n_host_result(m, n), c_m_n(m, n), d0_m_n(m, n), d1_m_n(m, n)); - } - } - - e_device_buf.FromDevice(e_m_n_device_result.mData.data()); - - return ck::utils::check_err(e_m_n_device_result.mData, e_m_n_host_result.mData) ? 0 : 1; - } - - return 0; -} +int main(int argc, char* argv[]) { return !run_gemm_add_add_fastgelu_example(argc, argv); } diff --git a/example/04_gemm_add_add_fastgelu/gemm_add_add_fastgelu_xdl_fp32.cpp b/example/04_gemm_add_add_fastgelu/gemm_add_add_fastgelu_xdl_fp32.cpp new file mode 100644 index 0000000000..dfef81fa0c --- /dev/null +++ b/example/04_gemm_add_add_fastgelu/gemm_add_add_fastgelu_xdl_fp32.cpp @@ -0,0 +1,67 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_multiple_d_xdl_cshuffle.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/literals.hpp" + +template +using S = ck::Sequence; + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using AddAddFastGelu = ck::tensor_operation::element_wise::AddAddFastGelu; + +using ADataType = F32; +using BDataType = F32; +using AccDataType = F32; +using CShuffleDataType = F32; +using D0DataType = F32; +using D1DataType = F32; +using DsDataType = ck::Tuple; +using EDataType = F32; + +using ALayout = Row; +using BLayout = Col; +using D0Layout = Row; +using D1Layout = Row; +using DsLayout = ck::Tuple; +using ELayout = Row; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CDEElementOp = AddAddFastGelu; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// clang-format off +using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultipleD_Xdl_CShuffle +//######| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| +//######| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| +//######| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| +//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + < ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 32, 1, 8>, 4>; +// clang-format on + +#include "run_gemm_add_add_fastgelu_example.inc" + +int main(int argc, char* argv[]) { return !run_gemm_add_add_fastgelu_example(argc, argv); } diff --git a/example/04_gemm_add_add_fastgelu/gemm_add_add_fastgelu_xdl_int8.cpp b/example/04_gemm_add_add_fastgelu/gemm_add_add_fastgelu_xdl_int8.cpp new file mode 100644 index 0000000000..c00339f7b8 --- /dev/null +++ b/example/04_gemm_add_add_fastgelu/gemm_add_add_fastgelu_xdl_int8.cpp @@ -0,0 +1,67 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_multiple_d_xdl_cshuffle.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/literals.hpp" + +template +using S = ck::Sequence; + +using I8 = int8_t; +using I32 = int32_t; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using AddAddFastGelu = ck::tensor_operation::element_wise::AddAddFastGelu; + +using ADataType = I8; +using BDataType = I8; +using AccDataType = I32; +using CShuffleDataType = I32; +using D0DataType = I8; +using D1DataType = I8; +using DsDataType = ck::Tuple; +using EDataType = I8; + +using ALayout = Row; +using BLayout = Col; +using D0Layout = Row; +using D1Layout = Row; +using DsLayout = ck::Tuple; +using ELayout = Row; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CDEElementOp = AddAddFastGelu; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// clang-format off +using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultipleD_Xdl_CShuffle +//######| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| +//######| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| +//######| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| +//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + < ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmDefault, 1, 256, 256, 128, 64, 16, 16, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, 16>; +// clang-format on + +#include "run_gemm_add_add_fastgelu_example.inc" + +int main(int argc, char* argv[]) { return !run_gemm_add_add_fastgelu_example(argc, argv); } diff --git a/example/04_gemm_add_add_fastgelu/run_gemm_add_add_fastgelu_example.inc b/example/04_gemm_add_add_fastgelu/run_gemm_add_add_fastgelu_example.inc new file mode 100644 index 0000000000..a860a780e7 --- /dev/null +++ b/example/04_gemm_add_add_fastgelu/run_gemm_add_add_fastgelu_example.inc @@ -0,0 +1,201 @@ +#pragma once + +struct ProblemSize final +{ + ck::index_t M = 3840; + ck::index_t N = 4096; + ck::index_t K = 4096; + + ck::index_t StrideA = 4096; + ck::index_t StrideB = 4096; + ck::index_t StrideD0 = 0; + ck::index_t StrideD1 = 4096; + ck::index_t StrideE = 4096; +}; + +struct ExecutionConfig final +{ + bool do_verification = true; + int init_method = 1; + bool time_kernel = false; +}; + +bool run_gemm_add_add_fastgelu(const ProblemSize& problem_size, const ExecutionConfig& config) +{ + auto& [M, N, K, StrideA, StrideB, StrideD0, StrideD1, StrideE] = problem_size; + + auto f_host_tensor_descriptor = + [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { + if constexpr(std::is_same_v) + { + return HostTensorDescriptor({row, col}, {stride, 1_uz}); + } + else + { + return HostTensorDescriptor({row, col}, {1_uz, stride}); + } + }; + + Tensor a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{})); + Tensor b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{})); + Tensor d0_m_n(f_host_tensor_descriptor(M, N, StrideD0, D0Layout{})); + Tensor d1_m_n(f_host_tensor_descriptor(M, N, StrideD1, D1Layout{})); + Tensor e_m_n_host_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{})); + Tensor e_m_n_device_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{})); + + std::cout << "a_m_k: " << a_m_k.mDesc << std::endl; + std::cout << "b_k_n: " << b_k_n.mDesc << std::endl; + std::cout << "d0_m_n: " << d0_m_n.mDesc << std::endl; + std::cout << "d1_m_n: " << d1_m_n.mDesc << std::endl; + std::cout << "e_m_n: " << e_m_n_host_result.mDesc << std::endl; + + switch(config.init_method) + { + case 0: break; + case 1: + a_m_k.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + b_k_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + d0_m_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + d1_m_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + break; + default: + a_m_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b_k_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + d0_m_n.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + d1_m_n.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + } + + DeviceMem a_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpaceSize()); + DeviceMem b_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpaceSize()); + DeviceMem d0_device_buf(sizeof(D0DataType) * d0_m_n.mDesc.GetElementSpaceSize()); + DeviceMem d1_device_buf(sizeof(D1DataType) * d1_m_n.mDesc.GetElementSpaceSize()); + DeviceMem e_device_buf(sizeof(EDataType) * e_m_n_device_result.mDesc.GetElementSpaceSize()); + + a_device_buf.ToDevice(a_m_k.mData.data()); + b_device_buf.ToDevice(b_k_n.mData.data()); + d0_device_buf.ToDevice(d0_m_n.mData.data()); + d1_device_buf.ToDevice(d1_m_n.mData.data()); + + auto a_element_op = AElementOp{}; + auto b_element_op = BElementOp{}; + auto cde_element_op = CDEElementOp{}; + + // do GEMM + auto device_op = DeviceOpInstance{}; + auto invoker = device_op.MakeInvoker(); + auto argument = + device_op.MakeArgument(a_device_buf.GetDeviceBuffer(), + b_device_buf.GetDeviceBuffer(), + {d0_device_buf.GetDeviceBuffer(), d1_device_buf.GetDeviceBuffer()}, + e_device_buf.GetDeviceBuffer(), + M, + N, + K, + StrideA, + StrideB, + {StrideD0, StrideD1}, + StrideE, + a_element_op, + b_element_op, + cde_element_op); + + if(!device_op.IsSupportedArgument(argument)) + { + throw std::runtime_error("wrong! this device_op instance does not support this problem"); + } + + float ave_time = invoker.Run(argument, StreamConfig{nullptr, config.time_kernel}); + + std::size_t flop = 2_uz * M * N * K; + std::size_t num_btype = sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + + sizeof(D0DataType) * N + sizeof(D1DataType) * M * N + + sizeof(EDataType) * M * N; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " + << device_op.GetTypeString() << std::endl; + + if(config.do_verification) + { + Tensor c_m_n(HostTensorDescriptor{M, N}); + + using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm; + + auto ref_gemm = ReferenceGemmInstance{}; + auto ref_invoker = ref_gemm.MakeInvoker(); + + auto ref_argument = + ref_gemm.MakeArgument(a_m_k, b_k_n, c_m_n, a_element_op, b_element_op, PassThrough{}); + + ref_invoker.Run(ref_argument); + + for(int m = 0; m < M; ++m) + { + for(int n = 0; n < N; ++n) + { + cde_element_op(e_m_n_host_result(m, n), c_m_n(m, n), d0_m_n(m, n), d1_m_n(m, n)); + } + } + + e_device_buf.FromDevice(e_m_n_device_result.mData.data()); + + return ck::utils::check_err(e_m_n_device_result.mData, e_m_n_host_result.mData); + } + + return true; +} + +bool run_gemm_add_add_fastgelu_example(int argc, char* argv[]) +{ + ProblemSize problem_size; + ExecutionConfig config; + + if(argc == 1) + { + // use default case + } + else if(argc == 4) + { + config.do_verification = std::stoi(argv[1]); + config.init_method = std::stoi(argv[2]); + config.time_kernel = std::stoi(argv[3]); + } + else if(argc == 12) + { + config.do_verification = std::stoi(argv[1]); + config.init_method = std::stoi(argv[2]); + config.time_kernel = std::stoi(argv[3]); + + problem_size.M = std::stoi(argv[4]); + problem_size.N = std::stoi(argv[5]); + problem_size.K = std::stoi(argv[6]); + + problem_size.StrideA = std::stoi(argv[7]); + problem_size.StrideB = std::stoi(argv[8]); + problem_size.StrideD0 = std::stoi(argv[9]); + problem_size.StrideD1 = std::stoi(argv[10]); + problem_size.StrideE = std::stoi(argv[11]); + } + else + { + std::cerr << "arg1: verification (0=no, 1=yes)" << std::endl + << "arg2: initialization (0=no init, 1=integer value, 2=decimal value)" + << std::endl + << "arg3: time kernel (0=no, 1=yes)" << std::endl + << "arg4 to 10: M (256x), N(128x), K(32x), StrideA, StrideB, StrideD0, StrideD1, " + "StrideE" + << std::endl; + return true; + } + + return run_gemm_add_add_fastgelu(problem_size, config); +} diff --git a/include/ck/tensor_operation/gpu/element/element_wise_operation.hpp b/include/ck/tensor_operation/gpu/element/element_wise_operation.hpp index 9c273e750b..20b40d9fce 100644 --- a/include/ck/tensor_operation/gpu/element/element_wise_operation.hpp +++ b/include/ck/tensor_operation/gpu/element/element_wise_operation.hpp @@ -114,28 +114,33 @@ struct AddHardswishAdd // E = FastGelu(C + D0 + D1) struct AddAddFastGelu { - template - __host__ __device__ void operator()(E&, const C&, const D0&, const D1&) const; - - template <> - __host__ __device__ void operator()(half_t& e, - const float& c, - const half_t& d0, - const half_t& d1) const + // Fast GeLU + // https://paperswithcode.com/method/gelu + // y = 0.5*x*(1+tanh(sqrt(2/pi)*(x+0.044715*x^3))) + __host__ __device__ static constexpr float GetFastGeLU(float x) { - // Fast GeLU - // https://paperswithcode.com/method/gelu - // y = 0.5*x*(1+tanh(sqrt(2/pi)*(x+0.044715*x^3))) - const auto fast_gelu = [&](float x) { - const float u = float(2) * x * (float(0.035677) * x * x + float(0.797885)); - const float emu = exp(-u); - const float cdf = float(0.5) + float(0.5) * (float(2) / (float(1) + emu) - float(1)); - return x * cdf; - }; + const float u = 2.f * x * (0.035677f * x * x + 0.797885f); + const float emu = exp(-u); + const float cdf = 0.5f + 0.5f * (2.f / (1.f + emu) - 1.f); + return x * cdf; + } - const float y = fast_gelu(c + float(d0) + float(d1)); + template + static inline constexpr bool is_valid_param_type_v = + std::is_same_v || std::is_same_v || std::is_same_v || + std::is_same_v || std::is_same_v; - e = type_convert(y); + template + __host__ __device__ constexpr void + operator()(E& e, const C& c, const D0& d0, const D1& d1) const + { + static_assert(is_valid_param_type_v && is_valid_param_type_v && + is_valid_param_type_v && is_valid_param_type_v); + + const float y = + GetFastGeLU(type_convert(c) + type_convert(d0) + type_convert(d1)); + + e = type_convert(y); } }; diff --git a/include/ck/utility/data_type.hpp b/include/ck/utility/data_type.hpp index 0e0d71a586..4b578bf149 100644 --- a/include/ck/utility/data_type.hpp +++ b/include/ck/utility/data_type.hpp @@ -934,6 +934,8 @@ using int8x64_t = typename vector_type::type; template __host__ __device__ constexpr Y type_convert(X x) { + static_assert(!std::is_reference_v && !std::is_reference_v); + return static_cast(x); } diff --git a/library/include/ck/library/utility/host_tensor.hpp b/library/include/ck/library/utility/host_tensor.hpp index 23596d553c..d6c033b2f4 100644 --- a/library/include/ck/library/utility/host_tensor.hpp +++ b/library/include/ck/library/utility/host_tensor.hpp @@ -77,38 +77,36 @@ struct HostTensorDescriptor void CalculateStrides(); - template + template >> HostTensorDescriptor(const std::initializer_list& lens) : mLens(lens.begin(), lens.end()) { this->CalculateStrides(); } - template - HostTensorDescriptor(const std::vector& lens) : mLens(lens.begin(), lens.end()) - { - this->CalculateStrides(); - } - - template + template ())), std::size_t>>> HostTensorDescriptor(const Range& lens) : mLens(lens.begin(), lens.end()) { this->CalculateStrides(); } - template + template && + std::is_convertible_v>> HostTensorDescriptor(const std::initializer_list& lens, const std::initializer_list& strides) : mLens(lens.begin(), lens.end()), mStrides(strides.begin(), strides.end()) { } - template - HostTensorDescriptor(const std::vector& lens, const std::vector& strides) - : mLens(lens.begin(), lens.end()), mStrides(strides.begin(), strides.end()) - { - } - - template + template < + typename Range1, + typename Range2, + typename = std::enable_if_t< + std::is_convertible_v())), std::size_t> && + std::is_convertible_v())), std::size_t>>> HostTensorDescriptor(const Range1& lens, const Range2& strides) : mLens(lens.begin(), lens.end()), mStrides(strides.begin(), strides.end()) { diff --git a/library/include/ck/library/utility/literals.hpp b/library/include/ck/library/utility/literals.hpp new file mode 100644 index 0000000000..a421a81190 --- /dev/null +++ b/library/include/ck/library/utility/literals.hpp @@ -0,0 +1,16 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +// [P0330] Literal Suffix for (signed) size_t (C++23) +// ref: https://wg21.link/p0330r8 +inline constexpr std::size_t operator""_uz(unsigned long long size) +{ + return static_cast(size); +} + +inline constexpr std::size_t operator""_zu(unsigned long long size) +{ + return static_cast(size); +}