From 7566771bddb32b0b5ce010adb39ca73bf3abd944 Mon Sep 17 00:00:00 2001 From: zjing14 Date: Tue, 23 Aug 2022 14:41:56 -0500 Subject: [PATCH] Add examples of batched/grouped/SplitK Gemm for int8/bfp16/fp16/fp32 (#361) * add examples into grouped/batched_gemm * adding splitK examples * fixed splitK * add bfp16 int8 example into splitK * formatting * use static_cast * added common for batched_gemm * add commons for examples of splitK/batched/grouped_gemm * return true * adjust splitK check tol * update example Co-authored-by: Chao Liu [ROCm/composable_kernel commit: 6091458300996a1b4a4f30ff25a828e8a40df7f2] --- example/15_grouped_gemm/CMakeLists.txt | 3 + .../grouped_gemm_xdl_bfp16.cpp | 61 +++++ .../15_grouped_gemm/grouped_gemm_xdl_fp16.cpp | 195 +-------------- .../15_grouped_gemm/grouped_gemm_xdl_fp32.cpp | 61 +++++ .../15_grouped_gemm/grouped_gemm_xdl_int8.cpp | 58 +++++ .../run_grouped_gemm_example.inc | 233 ++++++++++++++++++ example/24_batched_gemm/CMakeLists.txt | 4 + .../batched_gemm_xdl_bfp16.cpp | 59 +++++ .../24_batched_gemm/batched_gemm_xdl_fp16.cpp | 59 +++++ .../24_batched_gemm/batched_gemm_xdl_fp32.cpp | 58 +++++ .../24_batched_gemm/batched_gemm_xdl_int8.cpp | 56 +++++ .../run_batched_gemm_example.inc | 194 +++++++++++++++ .../24_batched_gemm_e_permute/CMakeLists.txt | 2 - example/35_splitK_gemm/CMakeLists.txt | 4 + .../run_splitK_gemm_example.inc | 196 +++++++++++++++ .../35_splitK_gemm/splitK_gemm_xdl_bfp16.cpp | 58 +++++ .../35_splitK_gemm/splitK_gemm_xdl_fp16.cpp | 58 +++++ .../35_splitK_gemm/splitK_gemm_xdl_fp32.cpp | 58 +++++ .../35_splitK_gemm/splitK_gemm_xdl_int8.cpp | 55 +++++ example/CMakeLists.txt | 4 +- .../device_gemm_xdl_splitk_c_shuffle.hpp | 2 +- .../gpu/grid/gridwise_gemm_xdlops_v2r4r2.hpp | 13 +- 22 files changed, 1284 insertions(+), 207 deletions(-) create mode 100644 example/15_grouped_gemm/grouped_gemm_xdl_bfp16.cpp create mode 100644 example/15_grouped_gemm/grouped_gemm_xdl_fp32.cpp create mode 100644 example/15_grouped_gemm/grouped_gemm_xdl_int8.cpp create mode 100644 example/15_grouped_gemm/run_grouped_gemm_example.inc create mode 100644 example/24_batched_gemm/CMakeLists.txt create mode 100644 example/24_batched_gemm/batched_gemm_xdl_bfp16.cpp create mode 100644 example/24_batched_gemm/batched_gemm_xdl_fp16.cpp create mode 100644 example/24_batched_gemm/batched_gemm_xdl_fp32.cpp create mode 100644 example/24_batched_gemm/batched_gemm_xdl_int8.cpp create mode 100644 example/24_batched_gemm/run_batched_gemm_example.inc delete mode 100644 example/24_batched_gemm_e_permute/CMakeLists.txt create mode 100644 example/35_splitK_gemm/CMakeLists.txt create mode 100644 example/35_splitK_gemm/run_splitK_gemm_example.inc create mode 100644 example/35_splitK_gemm/splitK_gemm_xdl_bfp16.cpp create mode 100644 example/35_splitK_gemm/splitK_gemm_xdl_fp16.cpp create mode 100644 example/35_splitK_gemm/splitK_gemm_xdl_fp32.cpp create mode 100644 example/35_splitK_gemm/splitK_gemm_xdl_int8.cpp diff --git a/example/15_grouped_gemm/CMakeLists.txt b/example/15_grouped_gemm/CMakeLists.txt index a8cac06930..2c9d2d78cd 100644 --- a/example/15_grouped_gemm/CMakeLists.txt +++ b/example/15_grouped_gemm/CMakeLists.txt @@ -1 +1,4 @@ +add_example_executable(example_grouped_gemm_xdl_fp32 grouped_gemm_xdl_fp32.cpp) add_example_executable(example_grouped_gemm_xdl_fp16 grouped_gemm_xdl_fp16.cpp) +add_example_executable(example_grouped_gemm_xdl_bfp16 grouped_gemm_xdl_bfp16.cpp) +add_example_executable(example_grouped_gemm_xdl_int8 grouped_gemm_xdl_int8.cpp) diff --git a/example/15_grouped_gemm/grouped_gemm_xdl_bfp16.cpp b/example/15_grouped_gemm/grouped_gemm_xdl_bfp16.cpp new file mode 100644 index 0000000000..427e82b40a --- /dev/null +++ b/example/15_grouped_gemm/grouped_gemm_xdl_bfp16.cpp @@ -0,0 +1,61 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/device_grouped_gemm_xdl.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" + +template +using S = ck::Sequence; + +using BF16 = ck::bhalf_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +using ADataType = BF16; +using BDataType = BF16; +using AccDataType = F32; +using CShuffleDataType = BF16; +using DsDataType = ck::Tuple<>; +using EDataType = BF16; + +using ALayout = Row; +using BLayout = Col; +using DsLayout = ck::Tuple<>; +using ELayout = Row; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CDEElementOp = PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +using DeviceGemmInstance = ck::tensor_operation::device::DeviceGroupedGemm_Xdl + // clang-format off +//######| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| +//######| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| +//######| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| +//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + < ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>; +// clang-format on + +#include "run_grouped_gemm_example.inc" + +int main(int argc, char* argv[]) { return !run_grouped_gemm_example(argc, argv); } diff --git a/example/15_grouped_gemm/grouped_gemm_xdl_fp16.cpp b/example/15_grouped_gemm/grouped_gemm_xdl_fp16.cpp index a107b6b8c8..13bb1c5405 100644 --- a/example/15_grouped_gemm/grouped_gemm_xdl_fp16.cpp +++ b/example/15_grouped_gemm/grouped_gemm_xdl_fp16.cpp @@ -56,197 +56,6 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGroupedGemm_Xdl < ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>; // clang-format on -using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm; +#include "run_grouped_gemm_example.inc" -int main(int argc, char* argv[]) -{ - bool do_verification = true; - int init_method = 1; - bool time_kernel = false; - - if(argc == 4) - { - do_verification = std::stoi(argv[1]); - init_method = std::stoi(argv[2]); - time_kernel = std::stoi(argv[3]); - } - else - { - printf("arg1: verification (0=no, 1=yes)\n"); - printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); - printf("arg3: time kernel (0=n0, 1=yes)\n"); - exit(0); - } - - int group_count = rand() % 16 + 1; - - // GEMM shape - std::vector gemm_descs; - std::vector p_a, p_b; - std::vector p_c; - - gemm_descs.reserve(group_count); - - for(int i = 0; i < group_count; i++) - { - int M = 256 + 256 * i; - int N = 128 + 128 * i; - int K = 64 + 64 * i; - - int stride_A = K; - int stride_B = K; - int stride_C = N; - - gemm_descs.push_back({M, N, K, stride_A, stride_B, stride_C, {}}); - } - - auto f_host_tensor_descriptor = - [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { - if(std::is_same::value) - { - return HostTensorDescriptor(std::vector({row, col}), - std::vector({stride, 1})); - } - else - { - return HostTensorDescriptor(std::vector({row, col}), - std::vector({1, stride})); - } - }; - - std::vector> a_tensors; - std::vector> b_tensors; - std::vector> c_host_tensors; - std::vector> c_device_tensors; - - a_tensors.reserve(group_count); - b_tensors.reserve(group_count); - c_host_tensors.reserve(group_count); - c_device_tensors.reserve(group_count); - - using DeviceMemPtr = std::unique_ptr; - - std::vector a_tensors_device, b_tensors_device, c_tensors_device; - - a_tensors_device.reserve(group_count); - b_tensors_device.reserve(group_count); - c_tensors_device.reserve(group_count); - - std::size_t flop = 0, num_btype = 0; - - for(std::size_t i = 0; i < gemm_descs.size(); i++) - { - a_tensors.push_back(Tensor(f_host_tensor_descriptor( - gemm_descs[i].M_, gemm_descs[i].K_, gemm_descs[i].stride_A_, ALayout{}))); - b_tensors.push_back(Tensor(f_host_tensor_descriptor( - gemm_descs[i].K_, gemm_descs[i].N_, gemm_descs[i].stride_B_, BLayout{}))); - c_host_tensors.push_back(Tensor(f_host_tensor_descriptor( - gemm_descs[i].M_, gemm_descs[i].N_, gemm_descs[i].stride_C_, ELayout{}))); - c_device_tensors.push_back(Tensor(f_host_tensor_descriptor( - gemm_descs[i].M_, gemm_descs[i].N_, gemm_descs[i].stride_C_, ELayout{}))); - - std::cout << "gemm[" << i << "] a_m_k: " << a_tensors[i].mDesc - << " b_k_n: " << b_tensors[i].mDesc << " c_m_n: " << c_device_tensors[i].mDesc - << std::endl; - - flop += std::size_t(2) * gemm_descs[i].M_ * gemm_descs[i].K_ * gemm_descs[i].N_; - num_btype += sizeof(ADataType) * a_tensors[i].mDesc.GetElementSize() + - sizeof(BDataType) * b_tensors[i].mDesc.GetElementSize() + - sizeof(EDataType) * c_device_tensors[i].mDesc.GetElementSize(); - - switch(init_method) - { - case 0: break; - case 1: - a_tensors[i].GenerateTensorValue(GeneratorTensor_2{-5, 5}); - b_tensors[i].GenerateTensorValue(GeneratorTensor_2{-5, 5}); - break; - case 2: - a_tensors[i].GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); - b_tensors[i].GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); - break; - default: - a_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<0>{}); - b_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<1>{}); - } - } - - for(std::size_t i = 0; i < gemm_descs.size(); i++) - { - a_tensors_device.emplace_back(std::make_unique( - sizeof(ADataType) * a_tensors[i].mDesc.GetElementSpaceSize())); - b_tensors_device.emplace_back(std::make_unique( - sizeof(BDataType) * b_tensors[i].mDesc.GetElementSpaceSize())); - c_tensors_device.emplace_back(std::make_unique( - sizeof(EDataType) * c_device_tensors[i].mDesc.GetElementSpaceSize())); - - a_tensors_device[i]->ToDevice(a_tensors[i].mData.data()); - b_tensors_device[i]->ToDevice(b_tensors[i].mData.data()); - - p_a.push_back(a_tensors_device[i]->GetDeviceBuffer()); - p_b.push_back(b_tensors_device[i]->GetDeviceBuffer()); - p_c.push_back(c_tensors_device[i]->GetDeviceBuffer()); - } - - auto a_element_op = AElementOp{}; - auto b_element_op = BElementOp{}; - auto c_element_op = CDEElementOp{}; - - auto gemm = DeviceGemmInstance{}; - auto invoker = gemm.MakeInvoker(); - - std::vector> p_Ds = {}; - - // do GEMM - auto argument = gemm.MakeArgument( - p_a, p_b, p_Ds, p_c, gemm_descs, a_element_op, b_element_op, c_element_op); - - DeviceMem gemm_desc_workspace(gemm.GetWorkSpaceSize(&argument)); - - gemm.SetWorkSpacePointer(&argument, gemm_desc_workspace.GetDeviceBuffer()); - - if(!gemm.IsSupportedArgument(argument)) - { - throw std::runtime_error( - "wrong! device_gemm with the specified compilation parameters does " - "not support this GEMM problem"); - } - - float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); - - float tflops = static_cast(flop) / 1.E9 / ave_time; - - float gb_per_sec = num_btype / 1.E6 / ave_time; - - std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " - << gemm.GetTypeString() << std::endl; - - bool pass = true; - if(do_verification) - { - for(std::size_t i = 0; i < gemm_descs.size(); i++) - { - c_tensors_device[i]->FromDevice(c_device_tensors[i].mData.data()); - auto ref_gemm = ReferenceGemmInstance{}; - auto ref_invoker = ref_gemm.MakeInvoker(); - - auto ref_argument = ref_gemm.MakeArgument(a_tensors[i], - b_tensors[i], - c_host_tensors[i], - a_element_op, - b_element_op, - c_element_op); - - ref_invoker.Run(ref_argument); - pass &= ck::utils::check_err(c_device_tensors[i].mData, c_host_tensors[i].mData); - } - } - - return pass ? 0 : 1; -} +int main(int argc, char* argv[]) { return !run_grouped_gemm_example(argc, argv); } diff --git a/example/15_grouped_gemm/grouped_gemm_xdl_fp32.cpp b/example/15_grouped_gemm/grouped_gemm_xdl_fp32.cpp new file mode 100644 index 0000000000..7d1a102d14 --- /dev/null +++ b/example/15_grouped_gemm/grouped_gemm_xdl_fp32.cpp @@ -0,0 +1,61 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/device_grouped_gemm_xdl.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" + +template +using S = ck::Sequence; + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +using ADataType = F32; +using BDataType = F32; +using AccDataType = F32; +using CShuffleDataType = F32; +using DsDataType = ck::Tuple<>; +using EDataType = F32; + +using ALayout = Row; +using BLayout = Col; +using DsLayout = ck::Tuple<>; +using ELayout = Row; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CDEElementOp = PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +using DeviceGemmInstance = ck::tensor_operation::device::DeviceGroupedGemm_Xdl + // clang-format off +//######| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| +//######| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| +//######| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| +//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + < ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmDefault, 1, 256, 256, 128, 16, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 32, 1, 8>, 4>; +// clang-format on + +#include "run_grouped_gemm_example.inc" + +int main(int argc, char* argv[]) { return !run_grouped_gemm_example(argc, argv); } diff --git a/example/15_grouped_gemm/grouped_gemm_xdl_int8.cpp b/example/15_grouped_gemm/grouped_gemm_xdl_int8.cpp new file mode 100644 index 0000000000..c96ff76bf3 --- /dev/null +++ b/example/15_grouped_gemm/grouped_gemm_xdl_int8.cpp @@ -0,0 +1,58 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/device_grouped_gemm_xdl.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" + +template +using S = ck::Sequence; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +using ADataType = int8_t; +using BDataType = int8_t; +using AccDataType = int32_t; +using CShuffleDataType = int8_t; +using DsDataType = ck::Tuple<>; +using EDataType = int8_t; + +using ALayout = Row; +using BLayout = Col; +using DsLayout = ck::Tuple<>; +using ELayout = Row; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CDEElementOp = PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +using DeviceGemmInstance = ck::tensor_operation::device::DeviceGroupedGemm_Xdl + // clang-format off +//######| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| +//######| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| +//######| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| +//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + < ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmDefault, 1, 256, 256, 128, 64, 16, 16, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 64, 1, 4>, 16>; +// clang-format on + +#include "run_grouped_gemm_example.inc" + +int main(int argc, char* argv[]) { return !run_grouped_gemm_example(argc, argv); } diff --git a/example/15_grouped_gemm/run_grouped_gemm_example.inc b/example/15_grouped_gemm/run_grouped_gemm_example.inc new file mode 100644 index 0000000000..e1a4134846 --- /dev/null +++ b/example/15_grouped_gemm/run_grouped_gemm_example.inc @@ -0,0 +1,233 @@ +#pragma once + +struct ProblemSize final +{ + std::vector Ms; + std::vector Ns; + std::vector Ks; + + std::vector stride_As; + std::vector stride_Bs; + std::vector stride_Cs; + + ck::index_t group_count; +}; + +struct ExecutionConfig final +{ + bool do_verification = true; + int init_method = 1; + bool time_kernel = false; +}; + +bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& config) +{ + int group_count = problem_size.group_count; + + // GEMM shape + std::vector gemm_descs; + std::vector p_a, p_b; + std::vector p_c; + + gemm_descs.reserve(group_count); + + for(int i = 0; i < group_count; i++) + { + int M = problem_size.Ms[i]; + int N = problem_size.Ns[i]; + int K = problem_size.Ks[i]; + + int stride_A = problem_size.stride_As[i]; + int stride_B = problem_size.stride_Bs[i]; + int stride_C = problem_size.stride_Cs[i]; + + gemm_descs.push_back({M, N, K, stride_A, stride_B, stride_C, {}}); + } + + auto f_host_tensor_descriptor = + [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { + if(std::is_same::value) + { + return HostTensorDescriptor(std::vector({row, col}), + std::vector({stride, 1})); + } + else + { + return HostTensorDescriptor(std::vector({row, col}), + std::vector({1, stride})); + } + }; + + std::vector> a_tensors; + std::vector> b_tensors; + std::vector> c_host_tensors; + std::vector> c_device_tensors; + + a_tensors.reserve(group_count); + b_tensors.reserve(group_count); + c_host_tensors.reserve(group_count); + c_device_tensors.reserve(group_count); + + using DeviceMemPtr = std::unique_ptr; + + std::vector a_tensors_device, b_tensors_device, c_tensors_device; + + a_tensors_device.reserve(group_count); + b_tensors_device.reserve(group_count); + c_tensors_device.reserve(group_count); + + std::size_t flop = 0, num_btype = 0; + + for(std::size_t i = 0; i < gemm_descs.size(); i++) + { + a_tensors.push_back(Tensor(f_host_tensor_descriptor( + gemm_descs[i].M_, gemm_descs[i].K_, gemm_descs[i].stride_A_, ALayout{}))); + b_tensors.push_back(Tensor(f_host_tensor_descriptor( + gemm_descs[i].K_, gemm_descs[i].N_, gemm_descs[i].stride_B_, BLayout{}))); + c_host_tensors.push_back(Tensor(f_host_tensor_descriptor( + gemm_descs[i].M_, gemm_descs[i].N_, gemm_descs[i].stride_C_, ELayout{}))); + c_device_tensors.push_back(Tensor(f_host_tensor_descriptor( + gemm_descs[i].M_, gemm_descs[i].N_, gemm_descs[i].stride_C_, ELayout{}))); + + std::cout << "gemm[" << i << "] a_m_k: " << a_tensors[i].mDesc + << " b_k_n: " << b_tensors[i].mDesc << " c_m_n: " << c_device_tensors[i].mDesc + << std::endl; + + flop += std::size_t(2) * gemm_descs[i].M_ * gemm_descs[i].K_ * gemm_descs[i].N_; + num_btype += sizeof(ADataType) * a_tensors[i].mDesc.GetElementSize() + + sizeof(BDataType) * b_tensors[i].mDesc.GetElementSize() + + sizeof(EDataType) * c_device_tensors[i].mDesc.GetElementSize(); + + switch(config.init_method) + { + case 0: break; + case 1: + a_tensors[i].GenerateTensorValue(GeneratorTensor_2{-5, 5}); + b_tensors[i].GenerateTensorValue(GeneratorTensor_2{-5, 5}); + break; + case 2: + a_tensors[i].GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b_tensors[i].GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + break; + default: + a_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<0>{}); + b_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<1>{}); + } + } + + for(std::size_t i = 0; i < gemm_descs.size(); i++) + { + a_tensors_device.emplace_back(std::make_unique( + sizeof(ADataType) * a_tensors[i].mDesc.GetElementSpaceSize())); + b_tensors_device.emplace_back(std::make_unique( + sizeof(BDataType) * b_tensors[i].mDesc.GetElementSpaceSize())); + c_tensors_device.emplace_back(std::make_unique( + sizeof(EDataType) * c_device_tensors[i].mDesc.GetElementSpaceSize())); + + a_tensors_device[i]->ToDevice(a_tensors[i].mData.data()); + b_tensors_device[i]->ToDevice(b_tensors[i].mData.data()); + + p_a.push_back(a_tensors_device[i]->GetDeviceBuffer()); + p_b.push_back(b_tensors_device[i]->GetDeviceBuffer()); + p_c.push_back(c_tensors_device[i]->GetDeviceBuffer()); + } + + auto a_element_op = AElementOp{}; + auto b_element_op = BElementOp{}; + auto c_element_op = CDEElementOp{}; + + auto gemm = DeviceGemmInstance{}; + auto invoker = gemm.MakeInvoker(); + + std::vector> p_Ds = {}; + + // do GEMM + auto argument = gemm.MakeArgument( + p_a, p_b, p_Ds, p_c, gemm_descs, a_element_op, b_element_op, c_element_op); + + DeviceMem gemm_desc_workspace(gemm.GetWorkSpaceSize(&argument)); + + gemm.SetWorkSpacePointer(&argument, gemm_desc_workspace.GetDeviceBuffer()); + + if(!gemm.IsSupportedArgument(argument)) + { + throw std::runtime_error( + "wrong! device_gemm with the specified compilation parameters does " + "not support this GEMM problem"); + } + + float ave_time = invoker.Run(argument, StreamConfig{nullptr, config.time_kernel}); + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " + << gemm.GetTypeString() << std::endl; + + bool pass = true; + if(config.do_verification) + { + using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm; + + for(std::size_t i = 0; i < gemm_descs.size(); i++) + { + c_tensors_device[i]->FromDevice(c_device_tensors[i].mData.data()); + auto ref_gemm = ReferenceGemmInstance{}; + auto ref_invoker = ref_gemm.MakeInvoker(); + + auto ref_argument = ref_gemm.MakeArgument(a_tensors[i], + b_tensors[i], + c_host_tensors[i], + a_element_op, + b_element_op, + c_element_op); + + ref_invoker.Run(ref_argument); + pass &= ck::utils::check_err(c_device_tensors[i].mData, c_host_tensors[i].mData); + } + } + + return pass ? 0 : 1; +} + +bool run_grouped_gemm_example(int argc, char* argv[]) +{ + ProblemSize problem_size; + ExecutionConfig config; + + problem_size.group_count = 16; + + for(int i = 0; i < problem_size.group_count; i++) + { + problem_size.Ms.push_back(256 + 256 * i); + problem_size.Ns.push_back(128 + 128 * i); + problem_size.Ks.push_back(64 + 64 * i); + + problem_size.stride_As.push_back(problem_size.Ks[i]); + problem_size.stride_Bs.push_back(problem_size.Ks[i]); + problem_size.stride_Cs.push_back(problem_size.Ns[i]); + } + + if(argc == 4) + { + config.do_verification = std::stoi(argv[1]); + config.init_method = std::stoi(argv[2]); + config.time_kernel = std::stoi(argv[3]); + } + else + { + printf("arg1: verification (0=no, 1=yes)\n"); + printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); + printf("arg3: time kernel (0=n0, 1=yes)\n"); + exit(0); + } + + return run_grouped_gemm(problem_size, config); +} diff --git a/example/24_batched_gemm/CMakeLists.txt b/example/24_batched_gemm/CMakeLists.txt new file mode 100644 index 0000000000..8ca5e55dcb --- /dev/null +++ b/example/24_batched_gemm/CMakeLists.txt @@ -0,0 +1,4 @@ +add_example_executable(example_batched_gemm_xdl_fp32 batched_gemm_xdl_fp32.cpp) +add_example_executable(example_batched_gemm_xdl_fp16 batched_gemm_xdl_fp16.cpp) +add_example_executable(example_batched_gemm_xdl_bfp16 batched_gemm_xdl_bfp16.cpp) +add_example_executable(example_batched_gemm_xdl_int8 batched_gemm_xdl_int8.cpp) diff --git a/example/24_batched_gemm/batched_gemm_xdl_bfp16.cpp b/example/24_batched_gemm/batched_gemm_xdl_bfp16.cpp new file mode 100644 index 0000000000..42beb0e92c --- /dev/null +++ b/example/24_batched_gemm/batched_gemm_xdl_bfp16.cpp @@ -0,0 +1,59 @@ +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/device_batched_gemm_multi_d_xdl.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp" +#include "ck/library/utility/literals.hpp" + +template +using S = ck::Sequence; + +using BF16 = ck::bhalf_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +using ADataType = BF16; +using BDataType = BF16; +using AccDataType = F32; +using CShuffleDataType = BF16; +using DsDataType = ck::Tuple<>; +using EDataType = BF16; + +using ALayout = Row; +using BLayout = Col; +using DsLayout = ck::Tuple<>; +using ELayout = Row; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CDEElementOp = PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// clang-format off +using DeviceGemmInstance = ck::tensor_operation::device::DeviceBatchedGemmMultiD_Xdl +//######| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| +//######| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| +//######| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| +//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + < ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>; +// clang-format on + +#include "run_batched_gemm_example.inc" + +int main(int argc, char* argv[]) { return !run_batched_gemm_example(argc, argv); } diff --git a/example/24_batched_gemm/batched_gemm_xdl_fp16.cpp b/example/24_batched_gemm/batched_gemm_xdl_fp16.cpp new file mode 100644 index 0000000000..f9dc581087 --- /dev/null +++ b/example/24_batched_gemm/batched_gemm_xdl_fp16.cpp @@ -0,0 +1,59 @@ +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/device_batched_gemm_multi_d_xdl.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp" +#include "ck/library/utility/literals.hpp" + +template +using S = ck::Sequence; + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +using ADataType = F16; +using BDataType = F16; +using AccDataType = F32; +using CShuffleDataType = F16; +using DsDataType = ck::Tuple<>; +using EDataType = F16; + +using ALayout = Row; +using BLayout = Col; +using DsLayout = ck::Tuple<>; +using ELayout = Row; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CDEElementOp = PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// clang-format off +using DeviceGemmInstance = ck::tensor_operation::device::DeviceBatchedGemmMultiD_Xdl +//######| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| +//######| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| +//######| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| +//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + < ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>; +// clang-format on + +#include "run_batched_gemm_example.inc" + +int main(int argc, char* argv[]) { return !run_batched_gemm_example(argc, argv); } diff --git a/example/24_batched_gemm/batched_gemm_xdl_fp32.cpp b/example/24_batched_gemm/batched_gemm_xdl_fp32.cpp new file mode 100644 index 0000000000..304cd14dbf --- /dev/null +++ b/example/24_batched_gemm/batched_gemm_xdl_fp32.cpp @@ -0,0 +1,58 @@ +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/device_batched_gemm_multi_d_xdl.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp" +#include "ck/library/utility/literals.hpp" + +template +using S = ck::Sequence; + +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +using ADataType = F32; +using BDataType = F32; +using AccDataType = F32; +using CShuffleDataType = F32; +using DsDataType = ck::Tuple<>; +using EDataType = F32; + +using ALayout = Row; +using BLayout = Col; +using DsLayout = ck::Tuple<>; +using ELayout = Row; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CDEElementOp = PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// clang-format off +using DeviceGemmInstance = ck::tensor_operation::device::DeviceBatchedGemmMultiD_Xdl +//######| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| +//######| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| +//######| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| +//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + < ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmDefault, 1, 256, 256, 128, 16, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 32, 1, 8>, 4>; +// clang-format on + +#include "run_batched_gemm_example.inc" + +int main(int argc, char* argv[]) { return !run_batched_gemm_example(argc, argv); } diff --git a/example/24_batched_gemm/batched_gemm_xdl_int8.cpp b/example/24_batched_gemm/batched_gemm_xdl_int8.cpp new file mode 100644 index 0000000000..cc48355073 --- /dev/null +++ b/example/24_batched_gemm/batched_gemm_xdl_int8.cpp @@ -0,0 +1,56 @@ +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/device_batched_gemm_multi_d_xdl.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp" +#include "ck/library/utility/literals.hpp" + +template +using S = ck::Sequence; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +using ADataType = int8_t; +using BDataType = int8_t; +using AccDataType = int32_t; +using CShuffleDataType = int8_t; +using DsDataType = ck::Tuple<>; +using EDataType = int8_t; + +using ALayout = Row; +using BLayout = Col; +using DsLayout = ck::Tuple<>; +using ELayout = Row; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CDEElementOp = PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// clang-format off +using DeviceGemmInstance = ck::tensor_operation::device::DeviceBatchedGemmMultiD_Xdl +//######| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| +//######| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| +//######| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| +//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + < ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmDefault, 1, 256, 256, 128, 64, 16, 16, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 64, 1, 4>, 16>; +// clang-format on + +#include "run_batched_gemm_example.inc" + +int main(int argc, char* argv[]) { return !run_batched_gemm_example(argc, argv); } diff --git a/example/24_batched_gemm/run_batched_gemm_example.inc b/example/24_batched_gemm/run_batched_gemm_example.inc new file mode 100644 index 0000000000..2db6ab76be --- /dev/null +++ b/example/24_batched_gemm/run_batched_gemm_example.inc @@ -0,0 +1,194 @@ +#pragma once + +struct ProblemSize final +{ + ck::index_t M = 3840; + ck::index_t N = 4096; + ck::index_t K = 4096; + + ck::index_t stride_A = K; + ck::index_t stride_B = K; + ck::index_t stride_C = N; + + ck::index_t batch_stride_A = M * K; + ck::index_t batch_stride_B = K * N; + ck::index_t batch_stride_C = M * N; + + ck::index_t batch_count = 16; +}; + +struct ExecutionConfig final +{ + bool do_verification = true; + int init_method = 1; + bool time_kernel = false; +}; + +bool run_batched_gemm(const ProblemSize& problem_size, const ExecutionConfig& config) +{ + using namespace ck::literals; + + auto& [M, N, K, stride_A, stride_B, stride_C, batch_stride_A, batch_stride_B, batch_stride_C, batch_count] = problem_size; + + // GEMM shape + auto f_host_tensor_descriptor = [](std::size_t batch_count_, + std::size_t row, + std::size_t col, + std::size_t stride, + std::size_t batch_stride, + auto layout) { + if(std::is_same::value) + { + return HostTensorDescriptor(std::vector({batch_count_, row, col}), + std::vector({batch_stride, stride, 1})); + } + else + { + return HostTensorDescriptor(std::vector({batch_count_, row, col}), + std::vector({batch_stride, 1, stride})); + } + }; + + Tensor a_g_m_k( + f_host_tensor_descriptor(batch_count, M, K, stride_A, batch_stride_A, ALayout{})); + Tensor b_g_k_n( + f_host_tensor_descriptor(batch_count, K, N, stride_B, batch_stride_B, BLayout{})); + + Tensor e_g_m_n_device_result( + f_host_tensor_descriptor(batch_count, M, N, stride_C, batch_stride_C, ELayout{})); + + std::cout << "a_g_m_k: " << a_g_m_k.mDesc << std::endl; + std::cout << "b_g_k_n: " << b_g_k_n.mDesc << std::endl; + std::cout << "e_g_m_n: " << e_g_m_n_device_result.mDesc << std::endl; + + switch(config.init_method) + { + case 0: break; + case 1: + a_g_m_k.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + b_g_k_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + break; + default: + a_g_m_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b_g_k_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + break; + } + + DeviceMem a_device_buf(sizeof(ADataType) * a_g_m_k.mDesc.GetElementSpaceSize()); + DeviceMem b_device_buf(sizeof(BDataType) * b_g_k_n.mDesc.GetElementSpaceSize()); + DeviceMem c_device_buf(sizeof(EDataType) * e_g_m_n_device_result.mDesc.GetElementSpaceSize()); + + a_device_buf.ToDevice(a_g_m_k.mData.data()); + b_device_buf.ToDevice(b_g_k_n.mData.data()); + + auto a_element_op = AElementOp{}; + auto b_element_op = BElementOp{}; + auto cde_element_op = CDEElementOp{}; + + auto gemm = DeviceGemmInstance{}; + auto invoker = gemm.MakeInvoker(); + + // do GEMM + auto argument = gemm.MakeArgument(a_device_buf.GetDeviceBuffer(), + b_device_buf.GetDeviceBuffer(), + {}, + c_device_buf.GetDeviceBuffer(), + M, + N, + K, + batch_count, + stride_A, + stride_B, + {}, + stride_C, + batch_stride_A, + batch_stride_B, + {}, + batch_stride_C, + a_element_op, + b_element_op, + cde_element_op); + + if(!gemm.IsSupportedArgument(argument)) + { + throw std::runtime_error( + "wrong! device_gemm with the specified compilation parameters does " + "not support this GEMM problem"); + } + + float ave_time = invoker.Run(argument, StreamConfig{nullptr, config.time_kernel}); + + std::size_t flop = std::size_t(2) * batch_count * M * N * K; + std::size_t num_btype = sizeof(ADataType) * batch_count * M * K + + sizeof(BDataType) * batch_count * K * N + + sizeof(EDataType) * batch_count * M * N; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " + << gemm.GetTypeString() << std::endl; + + bool pass = true; + + if(config.do_verification) + { + c_device_buf.FromDevice(e_g_m_n_device_result.mData.data()); + + using ReferenceBatchedGemmInstance = ck::tensor_operation::host:: + ReferenceBatchedGemm; + + auto ref_batched_gemm = ReferenceBatchedGemmInstance{}; + auto ref_invoker = ref_batched_gemm.MakeInvoker(); + + Tensor e_g_m_n_host_result( + f_host_tensor_descriptor(batch_count, M, N, stride_C, batch_stride_C, ELayout{})); + + auto ref_argument = ref_batched_gemm.MakeArgument( + a_g_m_k, b_g_k_n, e_g_m_n_host_result, a_element_op, b_element_op, cde_element_op); + + ref_invoker.Run(ref_argument); + + pass = ck::utils::check_err( + e_g_m_n_host_result.mData, e_g_m_n_device_result.mData, "Error: Incorrect results c"); + } + + return pass ? 0 : 1; +} + +bool run_batched_gemm_example(int argc, char* argv[]) +{ + ProblemSize problem_size; + ExecutionConfig config; + + problem_size.M = 256 * (rand() % 16 + 1); + problem_size.N = 128 * (rand() % 16 + 1); + problem_size.K = 64 * (rand() % 16 + 1); + + problem_size.stride_A = problem_size.K; + problem_size.stride_B = problem_size.K; + problem_size.stride_C = problem_size.N; + + problem_size.batch_stride_A = problem_size.M * problem_size.K; + problem_size.batch_stride_B = problem_size.K * problem_size.N; + problem_size.batch_stride_C = problem_size.M * problem_size.N; + + problem_size.batch_count = 16; + + if(argc == 4) + { + config.do_verification = std::stoi(argv[1]); + config.init_method = std::stoi(argv[2]); + config.time_kernel = std::stoi(argv[3]); + } + else + { + printf("arg1: verification (0=no, 1=yes)\n"); + printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); + printf("arg3: time kernel (0=n0, 1=yes)\n"); + exit(0); + } + + return run_batched_gemm(problem_size, config); +} diff --git a/example/24_batched_gemm_e_permute/CMakeLists.txt b/example/24_batched_gemm_e_permute/CMakeLists.txt deleted file mode 100644 index 3c5d39784b..0000000000 --- a/example/24_batched_gemm_e_permute/CMakeLists.txt +++ /dev/null @@ -1,2 +0,0 @@ -add_example_executable(example_batched_gemm_e_permute_xdl_fp16 batched_gemm_e_permute_xdl_fp16.cpp) - diff --git a/example/35_splitK_gemm/CMakeLists.txt b/example/35_splitK_gemm/CMakeLists.txt new file mode 100644 index 0000000000..ceb20921f3 --- /dev/null +++ b/example/35_splitK_gemm/CMakeLists.txt @@ -0,0 +1,4 @@ +add_example_executable(example_splitK_gemm_xdl_fp32 splitK_gemm_xdl_fp32.cpp) +add_example_executable(example_splitK_gemm_xdl_fp16 splitK_gemm_xdl_fp16.cpp) +add_example_executable(example_splitK_gemm_xdl_bfp16 splitK_gemm_xdl_bfp16.cpp) +add_example_executable(example_splitK_gemm_xdl_int8 splitK_gemm_xdl_int8.cpp) diff --git a/example/35_splitK_gemm/run_splitK_gemm_example.inc b/example/35_splitK_gemm/run_splitK_gemm_example.inc new file mode 100644 index 0000000000..cbd43869dd --- /dev/null +++ b/example/35_splitK_gemm/run_splitK_gemm_example.inc @@ -0,0 +1,196 @@ +#pragma once + +struct ProblemSize final +{ + ck::index_t M = 3840; + ck::index_t N = 4096; + ck::index_t K = 4096; + + ck::index_t stride_A = K; + ck::index_t stride_B = K; + ck::index_t stride_C = N; + + ck::index_t k_batch = 4; +}; + +struct ExecutionConfig final +{ + bool do_verification = true; + int init_method = 1; + bool time_kernel = false; +}; + +bool run_splitK_gemm(const ProblemSize& problem_size, const ExecutionConfig& config) +{ + using namespace ck::literals; + + auto& [M, N, K, StrideA, StrideB, StrideC, KBatch] = problem_size; + + auto f_host_tensor_descriptor = + [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { + if(std::is_same::value) + { + return HostTensorDescriptor(std::vector({row, col}), + std::vector({stride, 1})); + } + else + { + return HostTensorDescriptor(std::vector({row, col}), + std::vector({1, stride})); + } + }; + + Tensor a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{})); + Tensor b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{})); + Tensor c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); + Tensor c_m_n_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); + + std::cout << "a_m_k: " << a_m_k.mDesc << std::endl; + std::cout << "b_k_n: " << b_k_n.mDesc << std::endl; + std::cout << "c_m_n: " << c_m_n_host_result.mDesc << std::endl; + + switch(config.init_method) + { + case 0: break; + case 1: + a_m_k.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + b_k_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + break; + case 2: + a_m_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b_k_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + break; + default: + a_m_k.GenerateTensorValue(GeneratorTensor_Sequential<0>{}); + b_k_n.GenerateTensorValue(GeneratorTensor_Sequential<1>{}); + } + + DeviceMem a_m_k_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpaceSize()); + DeviceMem b_k_n_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpaceSize()); + DeviceMem c_m_n_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpaceSize()); + + a_m_k_device_buf.ToDevice(a_m_k.mData.data()); + b_k_n_device_buf.ToDevice(b_k_n.mData.data()); + c_m_n_device_buf.SetZero(); + + auto a_element_op = AElementOp{}; + auto b_element_op = BElementOp{}; + auto c_element_op = CElementOp{}; + + // do GEMM + auto gemm = DeviceGemmInstance{}; + auto invoker = gemm.MakeInvoker(); + auto argument = gemm.MakeArgument(static_cast(a_m_k_device_buf.GetDeviceBuffer()), + static_cast(b_k_n_device_buf.GetDeviceBuffer()), + static_cast(c_m_n_device_buf.GetDeviceBuffer()), + M, + N, + K, + StrideA, + StrideB, + StrideC, + a_element_op, + b_element_op, + c_element_op, + KBatch); + + if(!gemm.IsSupportedArgument(argument)) + { + std::cout << gemm.GetTypeString() << " does not support this problem" << std::endl; + + return 0; + } + + float ave_time = invoker.Run(argument, StreamConfig{nullptr, config.time_kernel}); + + std::size_t flop = std::size_t(2) * M * N * K; + std::size_t num_btype = + sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + sizeof(CDataType) * M * N; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " + << gemm.GetTypeString() << std::endl; + + c_m_n_device_buf.FromDevice(c_m_n_device_result.mData.data()); + + if(config.do_verification) + { + using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm; + + auto ref_gemm = ReferenceGemmInstance{}; + auto ref_invoker = ref_gemm.MakeInvoker(); + + auto ref_argument = ref_gemm.MakeArgument( + a_m_k, b_k_n, c_m_n_host_result, a_element_op, b_element_op, c_element_op); + + ref_invoker.Run(ref_argument); + + if(std::is_same::value) + { + return ck::utils::check_err(c_m_n_device_result.mData, + c_m_n_host_result.mData, + "fp16 incorrect result", + 3e-3, + 1e-3); + } + else + { + return ck::utils::check_err(c_m_n_device_result.mData, c_m_n_host_result.mData); + } + } + + return true; +} + +bool run_splitK_gemm_example(int argc, char* argv[]) +{ + ProblemSize problem_size; + ExecutionConfig config; + + if(argc == 1) + { + // use default case + } + else if(argc == 5) + { + config.do_verification = std::stoi(argv[1]); + config.init_method = std::stoi(argv[2]); + config.time_kernel = std::stoi(argv[3]); + problem_size.k_batch = std::stoi(argv[4]); + } + else if(argc == 11) + { + config.do_verification = std::stoi(argv[1]); + config.init_method = std::stoi(argv[2]); + config.time_kernel = std::stoi(argv[3]); + problem_size.k_batch = std::stoi(argv[4]); + + problem_size.M = std::stoi(argv[5]); + problem_size.N = std::stoi(argv[6]); + problem_size.K = std::stoi(argv[7]); + + problem_size.stride_A = std::stoi(argv[8]); + problem_size.stride_B = std::stoi(argv[9]); + problem_size.stride_C = std::stoi(argv[10]); + } + else + { + printf("arg1: verification (0=no, 1=yes)\n"); + printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); + printf("arg3: time kernel (0=no, 1=yes)\n"); + printf("arg4: KBatch\n"); + printf("arg5 to 11: M (256x), N(128x), K(32x), StrideA, StrideB, StrideC\n"); + exit(0); + } + + return run_splitK_gemm(problem_size, config); +} diff --git a/example/35_splitK_gemm/splitK_gemm_xdl_bfp16.cpp b/example/35_splitK_gemm/splitK_gemm_xdl_bfp16.cpp new file mode 100644 index 0000000000..484a4494bd --- /dev/null +++ b/example/35_splitK_gemm/splitK_gemm_xdl_bfp16.cpp @@ -0,0 +1,58 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_xdl_splitk_c_shuffle.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" +#include "ck/library/utility/literals.hpp" + +template +using S = ck::Sequence; + +using BF16 = ck::bhalf_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +using ADataType = BF16; +using BDataType = BF16; +using AccDataType = F32; +using CDataType = F32; + +using ALayout = Row; +using BLayout = Col; +using CLayout = Row; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CElementOp = PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmXdlSplitKCShuffle + // clang-format off +//######| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| KPer| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| +//######| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| +//######| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| +//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + < ADataType, BDataType, CDataType, AccDataType, ALayout, BLayout, CLayout, AElementOp, BElementOp, CElementOp, GemmDefault, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 4>; +// clang-format on + +#include "run_splitK_gemm_example.inc" + +int main(int argc, char* argv[]) { return !run_splitK_gemm_example(argc, argv); } diff --git a/example/35_splitK_gemm/splitK_gemm_xdl_fp16.cpp b/example/35_splitK_gemm/splitK_gemm_xdl_fp16.cpp new file mode 100644 index 0000000000..a1c43d0389 --- /dev/null +++ b/example/35_splitK_gemm/splitK_gemm_xdl_fp16.cpp @@ -0,0 +1,58 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_xdl_splitk_c_shuffle.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" +#include "ck/library/utility/literals.hpp" + +template +using S = ck::Sequence; + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +using ADataType = F16; +using BDataType = F16; +using AccDataType = F32; +using CDataType = F16; + +using ALayout = Row; +using BLayout = Col; +using CLayout = Row; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CElementOp = PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmXdlSplitKCShuffle + // clang-format off +//######| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| KPer| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| +//######| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| +//######| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| +//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + < ADataType, BDataType, CDataType, AccDataType, ALayout, BLayout, CLayout, AElementOp, BElementOp, CElementOp, GemmDefault, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8>; +// clang-format on + +#include "run_splitK_gemm_example.inc" + +int main(int argc, char* argv[]) { return !run_splitK_gemm_example(argc, argv); } diff --git a/example/35_splitK_gemm/splitK_gemm_xdl_fp32.cpp b/example/35_splitK_gemm/splitK_gemm_xdl_fp32.cpp new file mode 100644 index 0000000000..01093461c3 --- /dev/null +++ b/example/35_splitK_gemm/splitK_gemm_xdl_fp32.cpp @@ -0,0 +1,58 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_xdl_splitk_c_shuffle.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" +#include "ck/library/utility/literals.hpp" + +template +using S = ck::Sequence; + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +using ADataType = F32; +using BDataType = F32; +using AccDataType = F32; +using CDataType = F32; + +using ALayout = Row; +using BLayout = Col; +using CLayout = Row; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CElementOp = PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmXdlSplitKCShuffle + // clang-format off +//######| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| KPer| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| +//######| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| +//######| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| +//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + < ADataType, BDataType, CDataType, AccDataType, ALayout, BLayout, CLayout, AElementOp, BElementOp, CElementOp, GemmDefault, 256, 256, 128, 4, 4, 32, 32, 4, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 4, 4, true, 1, 1, S<1, 32, 1, 8>, 4>; +// clang-format on + +#include "run_splitK_gemm_example.inc" + +int main(int argc, char* argv[]) { return !run_splitK_gemm_example(argc, argv); } diff --git a/example/35_splitK_gemm/splitK_gemm_xdl_int8.cpp b/example/35_splitK_gemm/splitK_gemm_xdl_int8.cpp new file mode 100644 index 0000000000..d2f51db2ce --- /dev/null +++ b/example/35_splitK_gemm/splitK_gemm_xdl_int8.cpp @@ -0,0 +1,55 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_xdl_splitk_c_shuffle.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" +#include "ck/library/utility/literals.hpp" + +template +using S = ck::Sequence; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +using ADataType = int8_t; +using BDataType = int8_t; +using AccDataType = int32_t; +using CDataType = int32_t; + +using ALayout = Row; +using BLayout = Col; +using CLayout = Row; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CElementOp = PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmXdlSplitKCShuffle + // clang-format off +//######| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| KPer| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| +//######| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| +//######| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| +//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + < ADataType, BDataType, CDataType, AccDataType, ALayout, BLayout, CLayout, AElementOp, BElementOp, CElementOp, GemmDefault, 256, 256, 128, 4, 16, 32, 32, 4, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 16, 16, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 16, 16, true, 1, 1, S<1, 32, 1, 8>, 4>; +// clang-format on + +#include "run_splitK_gemm_example.inc" + +int main(int argc, char* argv[]) { return !run_splitK_gemm_example(argc, argv); } diff --git a/example/CMakeLists.txt b/example/CMakeLists.txt index 1845d46c05..4324c92e10 100644 --- a/example/CMakeLists.txt +++ b/example/CMakeLists.txt @@ -38,7 +38,7 @@ add_subdirectory(20_convnd_bwd_weight) add_subdirectory(21_gemm_layernorm) add_subdirectory(22_cgemm) add_subdirectory(23_softmax) -add_subdirectory(24_batched_gemm_e_permute) +add_subdirectory(24_batched_gemm) add_subdirectory(25_gemm_bias_e_permute) add_subdirectory(26_contraction) add_subdirectory(27_layernorm) @@ -49,4 +49,4 @@ add_subdirectory(31_batched_gemm_gemm) add_subdirectory(32_batched_gemm_scale_softmax_gemm) add_subdirectory(33_multiple_reduce) add_subdirectory(34_batchnorm) - +add_subdirectory(35_splitK_gemm) diff --git a/include/ck/tensor_operation/gpu/device/device_gemm_xdl_splitk_c_shuffle.hpp b/include/ck/tensor_operation/gpu/device/device_gemm_xdl_splitk_c_shuffle.hpp index eb2e521bdb..50515189fa 100644 --- a/include/ck/tensor_operation/gpu/device/device_gemm_xdl_splitk_c_shuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/device_gemm_xdl_splitk_c_shuffle.hpp @@ -95,7 +95,7 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK{}, Sequence<1>{}), make_tuple(Sequence<0>{}, Sequence<1>{})); diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4r2.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4r2.hpp index 84e1af0a35..190194f1eb 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4r2.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4r2.hpp @@ -53,7 +53,7 @@ __global__ void GridwiseGemm::template Run(p_a_grid, p_b_grid, p_c_grid, - p_shared_block, + static_cast(p_shared_block), a_b_k0_m_k1_grid_desc, b_b_k0_n_k1_grid_desc, c_grid_desc_mblock_mperblock_nblock_nperblock, @@ -270,7 +270,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2 __device__ static void Run(const FloatAB* __restrict__ p_a_grid, const FloatAB* __restrict__ p_b_grid, FloatC* __restrict__ p_c_grid, - FloatAB* __restrict__ p_shared_block, + void* __restrict__ p_shared_block, const AGridDesc_B_K0_M_K1& a_b_k0_m_k1_grid_desc, const BGridDesc_B_K0_N_K1& b_b_k0_n_k1_grid_desc, const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock& @@ -463,8 +463,8 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2 constexpr auto a_block_space_size = math::integer_least_multiple(a_k0_m_k1_block_desc.GetElementSpaceSize(), max_lds_align); - FloatAB* p_a_block = p_shared_block; - FloatAB* p_b_block = p_shared_block + a_block_space_size; + FloatAB* p_a_block = static_cast(p_shared_block); + FloatAB* p_b_block = static_cast(p_shared_block) + a_block_space_size; constexpr auto a_block_slice_copy_step = make_multi_index(0, K0PerBlock, 0, 0); constexpr auto b_block_slice_copy_step = make_multi_index(0, K0PerBlock, 0, 0); @@ -547,11 +547,6 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2 static_cast(p_shared_block), c_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); - static_assert(M1 == MWave, ""); - static_assert(N1 == NWave, ""); - static_assert(M2 * M3 * M4 == MPerXDL, ""); - static_assert(N2 == NPerXDL, ""); - constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor( c_block_desc_mblock_mperblock_nblock_nperblock, make_tuple(