From 94dadbf4ed1bbdad41a79bae70450051ca8edcd6 Mon Sep 17 00:00:00 2001 From: zjing14 Date: Tue, 22 Mar 2022 18:18:18 -0500 Subject: [PATCH] Grouped GEMM for fp16 (#126) * init of grouped_gemm * 2 gemm test * perf test * clean * wrap desc into a struct * test cast static_arr to pointer * add ptr to GemmDesc * add grouped gemm profiler * fixed mem issue with unique_ptr * clean * clean * finished ckprofiler * Update README.md * readme * fixed readme * add example * improve code * fixed comments: reserve, seperate ptr and gemm_shapes * merge group and non-group * fixed comments: replace push_back with emplace_back to avoid copy constructor * fixed comments: unified blk2ctile; add test * ci fix * fixed ci * fixed ci * fixed ci [ROCm/composable_kernel commit: 716f1c7fb172733d7ec9330b75aece8bad10a423] --- example/15_grouped_gemm/CMakeLists.txt | 1 + example/15_grouped_gemm/README.md | 58 ++ .../15_grouped_gemm/grouped_gemm_xdl_fp16.cpp | 234 ++++++++ example/CMakeLists.txt | 1 + .../gpu/device/device_gemm.hpp | 29 + .../gpu/device/device_grouped_gemm_xdl.hpp | 562 ++++++++++++++++++ .../gpu/grid/gridwise_gemm_xdlops_v2r3.hpp | 74 +++ .../gpu/CMakeLists.txt | 1 + .../gpu/grouped_gemm/CMakeLists.txt | 15 + ...gemm_xdl_f16_f16_f16_km_kn_mn_instance.cpp | 53 ++ ...gemm_xdl_f16_f16_f16_km_nk_mn_instance.cpp | 53 ++ ...gemm_xdl_f16_f16_f16_mk_kn_mn_instance.cpp | 62 ++ ...gemm_xdl_f16_f16_f16_mk_nk_mn_instance.cpp | 73 +++ profiler/CMakeLists.txt | 3 + .../include/profile_grouped_gemm_impl.hpp | 314 ++++++++++ profiler/src/profile_grouped_gemm.cpp | 157 +++++ profiler/src/profiler.cpp | 10 + test/CMakeLists.txt | 1 + test/grouped_gemm/CMakeLists.txt | 3 + test/grouped_gemm/grouped_gemm_fp16.cpp | 213 +++++++ 20 files changed, 1917 insertions(+) create mode 100644 example/15_grouped_gemm/CMakeLists.txt create mode 100644 example/15_grouped_gemm/README.md create mode 100644 example/15_grouped_gemm/grouped_gemm_xdl_fp16.cpp create mode 100644 include/ck/tensor_operation/gpu/device/device_grouped_gemm_xdl.hpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_gemm/CMakeLists.txt create mode 100644 library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_xdl_f16_f16_f16_km_kn_mn_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_xdl_f16_f16_f16_km_nk_mn_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_xdl_f16_f16_f16_mk_kn_mn_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_xdl_f16_f16_f16_mk_nk_mn_instance.cpp create mode 100644 profiler/include/profile_grouped_gemm_impl.hpp create mode 100644 profiler/src/profile_grouped_gemm.cpp create mode 100644 test/grouped_gemm/CMakeLists.txt create mode 100644 test/grouped_gemm/grouped_gemm_fp16.cpp diff --git a/example/15_grouped_gemm/CMakeLists.txt b/example/15_grouped_gemm/CMakeLists.txt new file mode 100644 index 0000000000..a8cac06930 --- /dev/null +++ b/example/15_grouped_gemm/CMakeLists.txt @@ -0,0 +1 @@ +add_example_executable(example_grouped_gemm_xdl_fp16 grouped_gemm_xdl_fp16.cpp) diff --git a/example/15_grouped_gemm/README.md b/example/15_grouped_gemm/README.md new file mode 100644 index 0000000000..b8245dc05a --- /dev/null +++ b/example/15_grouped_gemm/README.md @@ -0,0 +1,58 @@ +# Instructions for ```grouped_gemm_xdl``` Example + +## Docker script +```bash +docker run \ +-it \ +--rm \ +--privileged \ +--group-add sudo \ +-w /root/workspace \ +-v ${PATH_TO_LOCAL_WORKSPACE}:/root/workspace \ +rocm/tensorflow:rocm4.3.1-tf2.6-dev \ +/bin/bash +``` + +## Build ```grouped_gemm_xdl``` +```bash +mkdir build && cd build +``` + +```bash +# Need to specify target ID, example below is gfx908 +cmake \ +-D BUILD_DEV=OFF \ +-D CMAKE_BUILD_TYPE=Release \ +-D CMAKE_CXX_FLAGS="-DCK_AMD_GPU_GFX908 --amdgpu-target=gfx908 -O3 " \ +-D CMAKE_CXX_COMPILER=/opt/rocm/bin/hipcc \ +-D CMAKE_PREFIX_PATH=/opt/rocm \ +.. +``` + +```bash + make -j example_grouped_gemm_xdl_fp16 +``` + +## Run ```grouped_gemm_xdl``` +```bash +#arg1: verification (0=no, 1=yes) +#arg2: initialization (0=no init, 1=integer value, 2=decimal value) +#arg3: run kernel # of times (>1) +./bin/example_grouped_gemm_xdl_fp16 0 1 5 +``` + +Result (MI100 @ 1087Mhz, 133.5TFlops peak FP16) +``` +gemm[0] a_m_k: dim 2, lengths {256, 64}, strides {64, 1} b_k_n: dim 2, lengths {64, 128}, strides {1, 64} c_m_n: dim 2, lengths {256, 128}, strides {128, 1} +gemm[1] a_m_k: dim 2, lengths {512, 128}, strides {128, 1} b_k_n: dim 2, lengths {128, 256}, strides {1, 128} c_m_n: dim 2, lengths {512, 256}, strides {256, 1} +gemm[2] a_m_k: dim 2, lengths {768, 192}, strides {192, 1} b_k_n: dim 2, lengths {192, 384}, strides {1, 192} c_m_n: dim 2, lengths {768, 384}, strides {384, 1} +gemm[3] a_m_k: dim 2, lengths {1024, 256}, strides {256, 1} b_k_n: dim 2, lengths {256, 512}, strides {1, 256} c_m_n: dim 2, lengths {1024, 512}, strides {512, 1} +group: 0 arg.a_grid_desc_k0_m_k1_{8, 256, 8}, arg.b_grid_desc_k0_n_k1_{8, 128, 8}, arg.c_grid_desc_m_n_{ 256, 128} +group: 1 arg.a_grid_desc_k0_m_k1_{16, 512, 8}, arg.b_grid_desc_k0_n_k1_{16, 256, 8}, arg.c_grid_desc_m_n_{ 512, 256} +group: 2 arg.a_grid_desc_k0_m_k1_{24, 768, 8}, arg.b_grid_desc_k0_n_k1_{24, 384, 8}, arg.c_grid_desc_m_n_{ 768, 384} +group: 3 arg.a_grid_desc_k0_m_k1_{32, 1024, 8}, arg.b_grid_desc_k0_n_k1_{32, 512, 8}, arg.c_grid_desc_m_n_{ 1024, 512} +launch_and_time_kernel: grid_dim {30, 1, 1}, block_dim {256, 1, 1} +Warm up +Start running 5 times... +Perf: 0.037887 ms, 11.0706 TFlops, 90.8132 GB/s, DeviceGroupedGemmXdl<256, 256, 128, 4, 8, 32, 32, 4, 2> +``` diff --git a/example/15_grouped_gemm/grouped_gemm_xdl_fp16.cpp b/example/15_grouped_gemm/grouped_gemm_xdl_fp16.cpp new file mode 100644 index 0000000000..03afb7c44c --- /dev/null +++ b/example/15_grouped_gemm/grouped_gemm_xdl_fp16.cpp @@ -0,0 +1,234 @@ +#include +#include +#include +#include +#include +#include +#include "config.hpp" +#include "print.hpp" +#include "device.hpp" +#include "host_tensor.hpp" +#include "host_tensor_generator.hpp" +#include "host_gemm.hpp" +#include "device_tensor.hpp" +#include "device_grouped_gemm_xdl.hpp" +#include "element_wise_operation.hpp" +#include "reference_gemm.hpp" +#include "gemm_specialization.hpp" + +template +using S = ck::Sequence; + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +using ADataType = ck::half_t; +using BDataType = ck::half_t; +using CDataType = ck::half_t; +using AccDataType = float; + +using ALayout = ck::tensor_layout::gemm::RowMajor; +using BLayout = ck::tensor_layout::gemm::ColumnMajor; +using CLayout = ck::tensor_layout::gemm::RowMajor; + +using AElementOp = ck::tensor_operation::element_wise::PassThrough; +using BElementOp = ck::tensor_operation::element_wise::PassThrough; +using CElementOp = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization_t::Default; +// static constexpr auto GemmMNPadding = +// ck::tensor_operation::device::GemmSpecialization_t::MNPadding; + +// clang-format off +using DeviceGemmInstance = ck::tensor_operation::device::DeviceGroupedGemmXdl +//######| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| Num| +//######| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Spacialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| Prefetch| +//######| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| | +//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + < F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 1>; +// clang-format on + +using ReferenceGemmInstance = ck::tensor_operation::host:: + ReferenceGemm; + +int main(int argc, char* argv[]) +{ + bool do_verification = 0; + int init_method = 0; + int nrepeat = 5; + + if(argc == 4) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + nrepeat = std::stoi(argv[3]); + } + else + { + printf("arg1: verification (0=no, 1=yes)\n"); + printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); + printf("arg3: run kernel # of times (>1)\n"); + exit(0); + } + + int group_count = 4; + + // GEMM shape + std::vector gemm_shapes; + std::vector p_a, p_b; + std::vector p_c; + + gemm_shapes.reserve(group_count); + + for(int i = 0; i < group_count; i++) + { + int M = 256 + 256 * i; + int N = 128 + 128 * i; + int K = 64 + 64 * i; + + gemm_shapes.push_back({M, N, K, K, K, N}); + } + + auto f_host_tensor_descriptor = + [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { + if(std::is_same::value) + { + return HostTensorDescriptor(std::vector({row, col}), + std::vector({stride, 1})); + } + else + { + return HostTensorDescriptor(std::vector({row, col}), + std::vector({1, stride})); + } + }; + + std::vector> a_tensors; + ; + std::vector> b_tensors; + std::vector> c_host_tensors; + std::vector> c_device_tensors; + + a_tensors.reserve(group_count); + b_tensors.reserve(group_count); + c_host_tensors.reserve(group_count); + c_device_tensors.reserve(group_count); + + using DeviceMemPtr = std::unique_ptr; + + std::vector a_tensors_device, b_tensors_device, c_tensors_device; + + a_tensors_device.reserve(group_count); + b_tensors_device.reserve(group_count); + c_tensors_device.reserve(group_count); + + std::size_t flop = 0, num_btype = 0; + + for(int i = 0; i < gemm_shapes.size(); i++) + { + a_tensors.push_back(Tensor(f_host_tensor_descriptor( + gemm_shapes[i].M, gemm_shapes[i].K, gemm_shapes[i].StrideA, ALayout{}))); + b_tensors.push_back(Tensor(f_host_tensor_descriptor( + gemm_shapes[i].K, gemm_shapes[i].N, gemm_shapes[i].StrideB, BLayout{}))); + c_host_tensors.push_back(Tensor(f_host_tensor_descriptor( + gemm_shapes[i].M, gemm_shapes[i].N, gemm_shapes[i].StrideC, CLayout{}))); + c_device_tensors.push_back(Tensor(f_host_tensor_descriptor( + gemm_shapes[i].M, gemm_shapes[i].N, gemm_shapes[i].StrideC, CLayout{}))); + + std::cout << "gemm[" << i << "] a_m_k: " << a_tensors[i].mDesc + << " b_k_n: " << b_tensors[i].mDesc << " c_m_n: " << c_device_tensors[i].mDesc + << std::endl; + + flop += std::size_t(2) * gemm_shapes[i].M * gemm_shapes[i].K * gemm_shapes[i].N; + num_btype += sizeof(ADataType) * a_tensors[i].mDesc.GetElementSize() + + sizeof(BDataType) * b_tensors[i].mDesc.GetElementSize() + + sizeof(CDataType) * c_device_tensors[i].mDesc.GetElementSize(); + + switch(init_method) + { + case 0: break; + case 1: + a_tensors[i].GenerateTensorValue(GeneratorTensor_2{-5, 5}); + b_tensors[i].GenerateTensorValue(GeneratorTensor_2{-5, 5}); + break; + case 2: + a_tensors[i].GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b_tensors[i].GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + break; + default: + a_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<0>{}); + b_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<1>{}); + } + } + + for(int i = 0; i < gemm_shapes.size(); i++) + { + a_tensors_device.emplace_back( + std::make_unique(sizeof(ADataType) * a_tensors[i].mDesc.GetElementSize())); + b_tensors_device.emplace_back( + std::make_unique(sizeof(BDataType) * b_tensors[i].mDesc.GetElementSize())); + c_tensors_device.emplace_back(std::make_unique( + sizeof(CDataType) * c_device_tensors[i].mDesc.GetElementSize())); + + a_tensors_device[i]->ToDevice(a_tensors[i].mData.data()); + b_tensors_device[i]->ToDevice(b_tensors[i].mData.data()); + + p_a.push_back(a_tensors_device[i]->GetDeviceBuffer()); + p_b.push_back(b_tensors_device[i]->GetDeviceBuffer()); + p_c.push_back(c_tensors_device[i]->GetDeviceBuffer()); + } + + auto a_element_op = AElementOp{}; + auto b_element_op = BElementOp{}; + auto c_element_op = CElementOp{}; + + // do GEMM + auto gemm = DeviceGemmInstance{}; + auto invoker = gemm.MakeInvoker(); + auto argument = + gemm.MakeArgument(p_a, p_b, p_c, gemm_shapes, a_element_op, b_element_op, c_element_op); + + if(!gemm.IsSupportedArgument(argument)) + { + throw std::runtime_error( + "wrong! device_gemm with the specified compilation parameters does " + "not support this GEMM problem"); + } + + float ave_time = invoker.Run(argument, nrepeat); + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " + << gemm.GetTypeString() << std::endl; + + if(do_verification) + { + for(int i = 0; i < gemm_shapes.size(); i++) + { + c_tensors_device[i]->FromDevice(c_device_tensors[i].mData.data()); + auto ref_gemm = ReferenceGemmInstance{}; + auto ref_invoker = ref_gemm.MakeInvoker(); + + auto ref_argument = ref_gemm.MakeArgument(a_tensors[i], + b_tensors[i], + c_host_tensors[i], + a_element_op, + b_element_op, + c_element_op); + + ref_invoker.Run(ref_argument); + + check_error(c_host_tensors[i], c_device_tensors[i]); + } + } + + return 0; +} diff --git a/example/CMakeLists.txt b/example/CMakeLists.txt index b9fa9040e1..0be312ea33 100644 --- a/example/CMakeLists.txt +++ b/example/CMakeLists.txt @@ -39,3 +39,4 @@ add_subdirectory(11_conv2d_bwd_wgt) add_subdirectory(12_reduce) add_subdirectory(13_pool2d_fwd) add_subdirectory(14_gemm_xdl_requant_relu_requant) +add_subdirectory(15_grouped_gemm) diff --git a/include/ck/tensor_operation/gpu/device/device_gemm.hpp b/include/ck/tensor_operation/gpu/device/device_gemm.hpp index 72b79e8531..72aa780c52 100644 --- a/include/ck/tensor_operation/gpu/device/device_gemm.hpp +++ b/include/ck/tensor_operation/gpu/device/device_gemm.hpp @@ -8,6 +8,12 @@ namespace ck { namespace tensor_operation { namespace device { +struct GemmShape +{ + ck::index_t M, N, K; + ck::index_t StrideA, StrideB, StrideC; +}; + template @@ -65,6 +71,29 @@ template >; +template +struct DeviceGroupedGemm : public BaseOperator +{ + virtual std::unique_ptr MakeArgumentPointer(std::vector& p_a, + std::vector& p_b, + std::vector& p_c, + std::vector& gemm_shapes, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op, + ck::index_t KBatch = 1) = 0; + + virtual std::unique_ptr MakeInvokerPointer() = 0; +}; + +template +using DeviceGroupedGemmPtr = std::unique_ptr< + DeviceGroupedGemm>; + } // namespace device } // namespace tensor_operation } // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/device_grouped_gemm_xdl.hpp b/include/ck/tensor_operation/gpu/device/device_grouped_gemm_xdl.hpp new file mode 100644 index 0000000000..0c74f569c0 --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/device_grouped_gemm_xdl.hpp @@ -0,0 +1,562 @@ +#ifndef DEVICE_GROUPED_GEMM_XDL_HPP +#define DEVICE_GROUPED_GEMM_XDL_HPP + +#include +#include +#include "device.hpp" +#include "device_base.hpp" +#include "device_gemm.hpp" +#include "common_header.hpp" +#include "tensor_layout.hpp" +#include "tensor_descriptor.hpp" +#include "tensor_descriptor_helper.hpp" +#include "gridwise_gemm_xdlops_v2r3.hpp" +#include "gemm_specialization.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +template +struct DeviceGroupedGemmXdl + : public DeviceGroupedGemm +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + + static constexpr auto K1Number = Number{}; + + static auto MakeAGridDescriptor_K0_M_K1(index_t M, index_t K, index_t StrideA) + { + assert(K % K1 == 0); + + const index_t K0 = K / K1; + + const auto a_grid_desc_m_k = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(StrideA, I1)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(I1, StrideA)); + } + }(); + + if constexpr(GemmSpecialization == GemmSpecialization_t::MNPadding) + { + const auto PadM = (MPerBlock - M % MPerBlock) % MPerBlock; + + return transform_tensor_descriptor( + a_grid_desc_m_k, + make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)), + make_right_pad_transform(M, PadM)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + } + else + { + return transform_tensor_descriptor( + a_grid_desc_m_k, + make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)), + make_pass_through_transform(M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + } + } + + static auto MakeBGridDescriptor_K0_N_K1(index_t K, index_t N, index_t StrideB) + { + assert(K % K1 == 0); + + const index_t K0 = K / K1; + + const auto b_grid_desc_k_n = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(StrideB, I1)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(I1, StrideB)); + } + }(); + + if constexpr(GemmSpecialization == GemmSpecialization_t::MNPadding) + { + const auto PadN = (NPerBlock - N % NPerBlock) % NPerBlock; + + return transform_tensor_descriptor( + b_grid_desc_k_n, + make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)), + make_right_pad_transform(N, PadN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + } + else + { + return transform_tensor_descriptor( + b_grid_desc_k_n, + make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)), + make_pass_through_transform(N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + } + } + + static auto MakeCGridDescriptor_M_N(index_t M, index_t N, index_t StrideC) + { + const auto c_grid_desc_m_n = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideC, I1)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, StrideC)); + } + }(); + + if constexpr(GemmSpecialization == GemmSpecialization_t::MNPadding) + { + const auto PadM = (MPerBlock - M % MPerBlock) % MPerBlock; + const auto PadN = (NPerBlock - N % NPerBlock) % NPerBlock; + + return transform_tensor_descriptor( + c_grid_desc_m_n, + make_tuple(make_right_pad_transform(M, PadM), make_right_pad_transform(N, PadN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + else + { + + return transform_tensor_descriptor( + c_grid_desc_m_n, + make_tuple(make_pass_through_transform(M), make_pass_through_transform(N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + } + + using AGridDesc_K0_M_K1 = decltype(MakeAGridDescriptor_K0_M_K1(1, 1, 1)); + using BGridDesc_K0_N_K1 = decltype(MakeBGridDescriptor_K0_N_K1(1, 1, 1)); + using CGridDesc_M_N = decltype(MakeCGridDescriptor_M_N(1, 1, 1)); + + // GridwiseGemm + using GridwiseGemm = GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3< + BlockSize, + ADataType, // TODO: distinguish A/B datatype + AccDataType, + CDataType, + InMemoryDataOperationEnum_t::Set, + AGridDesc_K0_M_K1, + BGridDesc_K0_N_K1, + CGridDesc_M_N, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation, + MPerBlock, + NPerBlock, + K0PerBlock, + MPerXDL, + NPerXDL, + K1, + MXdlPerWave, + NXdlPerWave, + ABlockTransferThreadClusterLengths_K0_M_K1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_K1, + false, // AThreadTransferSrcResetCoordinateAfterRun, + ABlockLdsAddExtraM, + BBlockTransferThreadClusterLengths_K0_N_K1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_K1, + false, // BThreadTransferSrcResetCoordinateAfterRun, + BBlockLdsAddExtraN, + Sequence<0, 2, 4, 5, 6, 1, 3, 7>, // CThreadTransferSrcDstAccessOrder, + CThreadTransferSrcDstVectorDim, + CThreadTransferDstScalarPerVector, + NumPrefetch>; + + struct GroupedGemmBlock2CTileMap + { + GroupedGemmBlock2CTileMap() + { + block_2_ctile_map_ = GridwiseGemm::MakeDefaultBlock2CTileMap(CGridDesc_M_N{}, 1, 1); + BlockStart_ = -1; + } + + GroupedGemmBlock2CTileMap(const CGridDesc_M_N& c_grid_desc_m_n, + index_t M01, + index_t N01, + ck::index_t BlockStart) + { + block_2_ctile_map_ = GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n, M01, N01); + BlockStart_ = BlockStart; + } + + template + __host__ __device__ constexpr auto CalculateBottomIndex(const TopIdx& idx_top) const + { + return block_2_ctile_map_.CalculateBottomIndex( + make_multi_index(idx_top[I0] - BlockStart_)); + } + + private: + typename GridwiseGemm::DefaultBlock2CTileMap block_2_ctile_map_; + ck::index_t BlockStart_; + }; + + struct GemmDescKernelArg + { + AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1_; + BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1_; + CGridDesc_M_N c_grid_desc_m_n_; + + typename GridwiseGemm::CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2 + c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_; + + GroupedGemmBlock2CTileMap grouped_gemm_block_2_ctile_map_; + + const ADataType* a_ptr; + const BDataType* b_ptr; + CDataType* c_ptr; + + ck::index_t BlockStart_, BlockEnd_; + }; + + // Argument + struct Argument : public BaseArgument + { + Argument(std::vector& p_a, + std::vector& p_b, + std::vector& p_c, + std::vector& gemm_shapes, + index_t M01, + index_t N01, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op) + : M01_{M01}, + N01_{N01}, + a_element_op_{a_element_op}, + b_element_op_{b_element_op}, + c_element_op_{c_element_op} + { + grid_size_ = 0; + + group_count_ = static_cast(gemm_shapes.size()); + + if(!(group_count_ == p_a.size() && group_count_ == p_b.size() && + group_count_ == p_c.size())) + { + throw std::runtime_error("wrong! group_count_ != P_a/b/c.size"); + } + + gemm_desc_kernel_arg_.reserve(group_count_); + + for(index_t i = 0; i < gemm_shapes.size(); i++) + { + const index_t M = gemm_shapes[i].M; + const index_t N = gemm_shapes[i].N; + const index_t K = gemm_shapes[i].K; + + const index_t StrideA = gemm_shapes[i].StrideA; + const index_t StrideB = gemm_shapes[i].StrideB; + const index_t StrideC = gemm_shapes[i].StrideC; + + const auto a_grid_desc_k0_m_k1_ = + DeviceGroupedGemmXdl::MakeAGridDescriptor_K0_M_K1(M, K, StrideA); + const auto b_grid_desc_k0_n_k1_ = + DeviceGroupedGemmXdl::MakeBGridDescriptor_K0_N_K1(K, N, StrideB); + const auto c_grid_desc_m_n_ = + DeviceGroupedGemmXdl::MakeCGridDescriptor_M_N(M, N, StrideC); + + const index_t grid_size_grp = GridwiseGemm::CalculateGridSize(c_grid_desc_m_n_); + + const index_t BlockStart = grid_size_; + const index_t BlockEnd = grid_size_ + grid_size_grp; + + grid_size_ += grid_size_grp; + + if(GridwiseGemm::CheckValidity( + a_grid_desc_k0_m_k1_, b_grid_desc_k0_n_k1_, c_grid_desc_m_n_, M01_, N01_)) + { + const auto c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_ = + GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_grid_desc_m_n_); + + const auto grouped_gemm_block_2_ctile_map_ = + GroupedGemmBlock2CTileMap(c_grid_desc_m_n_, M01, N01, BlockStart); + + gemm_desc_kernel_arg_.push_back( + GemmDescKernelArg{a_grid_desc_k0_m_k1_, + b_grid_desc_k0_n_k1_, + c_grid_desc_m_n_, + c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_, + grouped_gemm_block_2_ctile_map_, + static_cast(p_a[i]), + static_cast(p_b[i]), + static_cast(p_c[i]), + BlockStart, + BlockEnd}); + } + } + } + + // private: + index_t M01_; + index_t N01_; + index_t group_count_; + AElementwiseOperation a_element_op_; + BElementwiseOperation b_element_op_; + CElementwiseOperation c_element_op_; + + std::vector gemm_desc_kernel_arg_; + + index_t grid_size_; + }; + + // Invoker + struct Invoker : public BaseInvoker + { + using Argument = DeviceGroupedGemmXdl::Argument; + + float Run(const Argument& arg, int nrepeat = 1) + { + StaticallyIndexedArray gemm_desc_kernel_arg_arg; + + bool has_main_k0_block_loop = true; + + static_for<0, MaxGroupCount, 1>{}([&](auto i) { + if(i < arg.gemm_desc_kernel_arg_.size()) + { + gemm_desc_kernel_arg_arg(i) = arg.gemm_desc_kernel_arg_[i]; + + std::cout << "group: " << i << " arg.a_grid_desc_k0_m_k1_{" + << gemm_desc_kernel_arg_arg[i].a_grid_desc_k0_m_k1_.GetLength(I0) + << ", " + << gemm_desc_kernel_arg_arg[i].a_grid_desc_k0_m_k1_.GetLength(I1) + << ", " + << gemm_desc_kernel_arg_arg[i].a_grid_desc_k0_m_k1_.GetLength(I2) + << "}"; + + std::cout << ", arg.b_grid_desc_k0_n_k1_{" + << gemm_desc_kernel_arg_arg[i].b_grid_desc_k0_n_k1_.GetLength(I0) + << ", " + << gemm_desc_kernel_arg_arg[i].b_grid_desc_k0_n_k1_.GetLength(I1) + << ", " + << gemm_desc_kernel_arg_arg[i].b_grid_desc_k0_n_k1_.GetLength(I2) + << "}"; + + std::cout << ", arg.c_grid_desc_m_n_{ " + << gemm_desc_kernel_arg_arg[i].c_grid_desc_m_n_.GetLength(I0) << ", " + << gemm_desc_kernel_arg_arg[i].c_grid_desc_m_n_.GetLength(I1) << "}" + << std::endl; + + if(!GridwiseGemm::CheckValidity( + gemm_desc_kernel_arg_arg[i].a_grid_desc_k0_m_k1_, + gemm_desc_kernel_arg_arg[i].b_grid_desc_k0_n_k1_, + gemm_desc_kernel_arg_arg[i].c_grid_desc_m_n_, + arg.M01_, + arg.N01_)) + { + throw std::runtime_error( + "wrong! GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 has invalid setting"); + } + + const auto K0 = gemm_desc_kernel_arg_arg[i].a_grid_desc_k0_m_k1_.GetLength(I0); + + if(GridwiseGemm::CalculateHasMainK0BlockLoop(K0) != has_main_k0_block_loop) + { + throw std::runtime_error("wrong! not all gemm has_main_k0_block_loop"); + } + } + }); + + float ave_time = 0; + + if(has_main_k0_block_loop) + { + const auto kernel = + kernel_grouped_gemm_xdlops_v2r3, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation, + true, + MaxGroupCount>; + + ave_time = launch_and_time_kernel(kernel, + nrepeat, + dim3(arg.grid_size_), + dim3(BlockSize), + 0, + gemm_desc_kernel_arg_arg, + arg.gemm_desc_kernel_arg_.size(), + arg.a_element_op_, + arg.b_element_op_, + arg.c_element_op_); + } + else + { + const auto kernel = + kernel_grouped_gemm_xdlops_v2r3, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation, + false, + MaxGroupCount>; + + ave_time = launch_and_time_kernel(kernel, + nrepeat, + dim3(arg.grid_size_), + dim3(BlockSize), + 0, + gemm_desc_kernel_arg_arg, + arg.gemm_desc_kernel_arg_.size(), + arg.a_element_op_, + arg.b_element_op_, + arg.c_element_op_); + } + + return ave_time; + } + + // polymorphic + float Run(const BaseArgument* p_arg, int nrepeat = 1) override + { + return Run(*dynamic_cast(p_arg), nrepeat); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } + + static bool IsSupportedArgument(const Argument& arg) + { + if(arg.gemm_desc_kernel_arg_.size() != arg.group_count_) + return false; + else + return true; + } + + // polymorphic + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static auto MakeArgument(std::vector& p_a, + std::vector& p_b, + std::vector& p_c, + std::vector gemm_shapes, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op) + { + return Argument{p_a, p_b, p_c, gemm_shapes, 1, 1, a_element_op, b_element_op, c_element_op}; + } + + static auto MakeInvoker() { return Invoker{}; } + + // polymorphic + std::unique_ptr MakeArgumentPointer(std::vector& p_a, + std::vector& p_b, + std::vector& p_c, + std::vector& gemm_shapes, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op, + index_t /* KBatch */ = 1) override + { + return std::make_unique( + p_a, p_b, p_c, gemm_shapes, 1, 1, a_element_op, b_element_op, c_element_op); + } + + // polymorphic + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + // polymorphic + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "DeviceGroupedGemmXdl" + << "<" + << BlockSize << ", " + << MPerBlock << ", " + << NPerBlock << ", " + << K0PerBlock << ", " + << K1 << ", " + << MPerXDL << ", " + << NPerXDL << ", " + << MXdlPerWave << ", " + << NXdlPerWave + << ">"; + // clang-format on + + return str.str(); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck +#endif diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r3.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r3.hpp index 47622ad148..9ce5b3dae6 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r3.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r3.hpp @@ -54,6 +54,80 @@ __global__ void block_2_ctile_map); } +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +#endif + kernel_grouped_gemm_xdlops_v2r3( + const StaticallyIndexedArray gemm_desc_, + const index_t group_count, + const AElementwiseOperation a_element_op, + const BElementwiseOperation b_element_op, + const CElementwiseOperation c_element_op) +{ + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + + const index_t block_id = get_block_1d_id(); + +#if 1 + static_for<0, MaxGroupCount, 1>{}([&](auto i) { + if(block_id >= gemm_desc_[i].BlockStart_ && block_id < gemm_desc_[i].BlockEnd_ && + i < group_count) + { + auto group_id = i; + + GridwiseGemm::template Run( + gemm_desc_[group_id].a_ptr, + gemm_desc_[group_id].b_ptr, + gemm_desc_[group_id].c_ptr, + p_shared, + gemm_desc_[group_id].a_grid_desc_k0_m_k1_, + gemm_desc_[group_id].b_grid_desc_k0_n_k1_, + gemm_desc_[group_id].c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_, + a_element_op, + b_element_op, + c_element_op, + gemm_desc_[group_id].grouped_gemm_block_2_ctile_map_); + } + }); +#else + const auto gemm_desc_ptr = reinterpret_cast(&gemm_desc_); + + index_t group_id = 0; + static_for<0, MaxGroupCount, 1>{}([&](auto i) { + group_id = (block_id >= gemm_desc_[i].BlockStart && block_id < gemm_desc_[i].BlockEnd && + i < group_count) + ? i + : group_id; + }); + + const index_t block_id_grp = block_id - gemm_desc_ptr[group_id].BlockStart; + + GridwiseGemm::template Run( + gemm_desc_ptr[group_id].a_ptr, + gemm_desc_ptr[group_id].b_ptr, + gemm_desc_ptr[group_id].c_ptr, + p_shared, + gemm_desc_ptr[group_id].a_grid_desc_k0_m_k1_, + gemm_desc_ptr[group_id].b_grid_desc_k0_n_k1_, + gemm_desc_ptr[group_id].c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_, + a_element_op, + b_element_op, + c_element_op, + gemm_desc_ptr[group_id].block_2_ctile_map_, + block_id_grp); +#endif +} + template +#include "config.hpp" +#include "device_grouped_gemm_xdl.hpp" +#include "element_wise_operation.hpp" +#include "device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_grouped_gemm_instance { + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization_t::Default; + +// Compilation parameters for a[k, m] * b[k, n] = c[m, n] +using device_grouped_gemm_xdl_f16_f16_f16_km_kn_mn_instances = std::tuple< + // clang-format off + //#################| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //#################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Spacialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //#################| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //#################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGroupedGemmXdl< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, + DeviceGroupedGemmXdl< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>, + DeviceGroupedGemmXdl< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>, + DeviceGroupedGemmXdl< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, + DeviceGroupedGemmXdl< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, + DeviceGroupedGemmXdl< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>, + DeviceGroupedGemmXdl< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 7, 1>, + DeviceGroupedGemmXdl< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1> + // clang-format on + >; + +void add_device_grouped_gemm_xdl_f16_f16_f16_km_kn_mn_instances( + std::vector>& instances) +{ + add_device_operation_instances(instances, + device_grouped_gemm_xdl_f16_f16_f16_km_kn_mn_instances{}); +} + +} // namespace device_grouped_gemm_instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_xdl_f16_f16_f16_km_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_xdl_f16_f16_f16_km_nk_mn_instance.cpp new file mode 100644 index 0000000000..20c970cebe --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_xdl_f16_f16_f16_km_nk_mn_instance.cpp @@ -0,0 +1,53 @@ +#include +#include "config.hpp" +#include "device_grouped_gemm_xdl.hpp" +#include "element_wise_operation.hpp" +#include "device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_grouped_gemm_instance { + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization_t::Default; + +// Compilation parameters for a[k, m] * b[n, k] = c[m, n] +using device_grouped_gemm_xdl_f16_f16_f16_km_nk_mn_instances = std::tuple< + // clang-format off + //#################| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //#################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Spacialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //#################| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //#################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGroupedGemmXdl< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceGroupedGemmXdl< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceGroupedGemmXdl< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceGroupedGemmXdl< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceGroupedGemmXdl< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceGroupedGemmXdl< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceGroupedGemmXdl< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceGroupedGemmXdl< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1> + // clang-format on + >; + +void add_device_grouped_gemm_xdl_f16_f16_f16_km_nk_mn_instances( + std::vector>& instances) +{ + add_device_operation_instances(instances, + device_grouped_gemm_xdl_f16_f16_f16_km_nk_mn_instances{}); +} + +} // namespace device_grouped_gemm_instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_xdl_f16_f16_f16_mk_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_xdl_f16_f16_f16_mk_kn_mn_instance.cpp new file mode 100644 index 0000000000..b16d2b84c9 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_xdl_f16_f16_f16_mk_kn_mn_instance.cpp @@ -0,0 +1,62 @@ +#include +#include "config.hpp" +#include "device_grouped_gemm_xdl.hpp" +#include "element_wise_operation.hpp" +#include "device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_grouped_gemm_instance { + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization_t::Default; + +// Compilation parameters for a[m, k] * b[k, n] = c[m, n] +using device_grouped_gemm_xdl_f16_f16_f16_mk_kn_mn_instances = std::tuple< + // clang-format off + //#################| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //#################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Spacialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //#################| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //#################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGroupedGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, + DeviceGroupedGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>, + DeviceGroupedGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>, + DeviceGroupedGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, + DeviceGroupedGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, + DeviceGroupedGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>, + DeviceGroupedGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 7, 1>, + DeviceGroupedGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, + DeviceGroupedGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 32, 256, 4, 8, 32, 32, 1, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, true, 7, 1>, + DeviceGroupedGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>, + DeviceGroupedGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 32, 64, 4, 8, 32, 32, 1, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, + DeviceGroupedGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 32, 32, 4, 8, 32, 32, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, + DeviceGroupedGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 16, 256, 4, 8, 16, 16, 1, 8, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, true, 7, 1>, + DeviceGroupedGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 16, 128, 4, 8, 16, 16, 1, 4, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>, + DeviceGroupedGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 16, 64, 4, 8, 16, 16, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, + DeviceGroupedGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 16, 32, 4, 8, 16, 16, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 7, 1>, + DeviceGroupedGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 16, 16, 4, 8, 16, 16, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 7, 1> + // clang-format on + >; + +void add_device_grouped_gemm_xdl_f16_f16_f16_mk_kn_mn_instances( + std::vector>& instances) +{ + add_device_operation_instances(instances, + device_grouped_gemm_xdl_f16_f16_f16_mk_kn_mn_instances{}); +} + +} // namespace device_grouped_gemm_instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_xdl_f16_f16_f16_mk_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_xdl_f16_f16_f16_mk_nk_mn_instance.cpp new file mode 100644 index 0000000000..5a6f64b9da --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_xdl_f16_f16_f16_mk_nk_mn_instance.cpp @@ -0,0 +1,73 @@ +#include +#include "config.hpp" +#include "device_grouped_gemm_xdl.hpp" +#include "element_wise_operation.hpp" +#include "device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_grouped_gemm_instance { + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization_t::Default; +static constexpr auto GemmMNPadding = ck::tensor_operation::device::GemmSpecialization_t::MNPadding; + +// Compilation parameters for a[m, k] * b[n, k] = c[m, n] +using device_grouped_gemm_xdl_f16_f16_f16_mk_nk_mn_instances = std::tuple< + // clang-format off + //##################| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //##################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Spacialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //##################| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //##################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGroupedGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceGroupedGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceGroupedGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceGroupedGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceGroupedGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceGroupedGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceGroupedGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceGroupedGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceGroupedGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceGroupedGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceGroupedGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceGroupedGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceGroupedGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1> + // clang-format on + >; + +// irregular tile size +using device_grouped_gemm_xdl_f16_f16_f16_mk_nk_mn_irregular_tile_instances = std::tuple< + // clang-format off + //##################| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //##################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Spacialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //##################| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //##################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGroupedGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 256, 128, 144, 8, 8, 16, 16, 2, 9, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 8, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, true, 7, 1>, + DeviceGroupedGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 256, 128, 144, 4, 8, 16, 16, 2, 9, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, true, 7, 1> + // clang-format on + >; + +void add_device_grouped_gemm_xdl_f16_f16_f16_mk_nk_mn_instances( + std::vector>& instances) +{ + add_device_operation_instances(instances, + device_grouped_gemm_xdl_f16_f16_f16_mk_nk_mn_instances{}); + add_device_operation_instances( + instances, device_grouped_gemm_xdl_f16_f16_f16_mk_nk_mn_irregular_tile_instances{}); +} + +} // namespace device_grouped_gemm_instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/profiler/CMakeLists.txt b/profiler/CMakeLists.txt index 5e7156a399..74970b9aac 100644 --- a/profiler/CMakeLists.txt +++ b/profiler/CMakeLists.txt @@ -33,6 +33,7 @@ set(PROFILER_SOURCE src/profile_conv_fwd_bias_relu_atomic_add.cpp src/profile_conv_bwd_data.cpp src/profile_reduce.cpp + src/profile_grouped_gemm.cpp ) add_executable(ckProfiler ${PROFILER_SOURCE}) @@ -49,3 +50,5 @@ target_link_libraries(ckProfiler PRIVATE device_conv2d_fwd_bias_relu_add_instanc target_link_libraries(ckProfiler PRIVATE device_conv2d_fwd_bias_relu_atomic_add_instance) target_link_libraries(ckProfiler PRIVATE device_conv2d_bwd_data_instance) target_link_libraries(ckProfiler PRIVATE device_reduce_instance) +target_link_libraries(ckProfiler PRIVATE device_reduce_instance) +target_link_libraries(ckProfiler PRIVATE device_grouped_gemm_instance) diff --git a/profiler/include/profile_grouped_gemm_impl.hpp b/profiler/include/profile_grouped_gemm_impl.hpp new file mode 100644 index 0000000000..2d99e93cfd --- /dev/null +++ b/profiler/include/profile_grouped_gemm_impl.hpp @@ -0,0 +1,314 @@ +#pragma once +#include +#include "config.hpp" +#include "device.hpp" +#include "host_tensor.hpp" +#include "host_tensor_generator.hpp" +#include "host_conv.hpp" +#include "tensor_layout.hpp" +#include "device_tensor.hpp" +#include "element_wise_operation.hpp" +#include "device_gemm.hpp" +#include "reference_gemm.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_grouped_gemm_instance { + +using DeviceGroupedGemmNoOpPtr = ck::tensor_operation::device::DeviceGroupedGemmPtr< + ck::tensor_operation::element_wise::PassThrough, + ck::tensor_operation::element_wise::PassThrough, + ck::tensor_operation::element_wise::PassThrough>; + +void add_device_grouped_gemm_xdl_f16_f16_f16_mk_kn_mn_instances( + std::vector&); +void add_device_grouped_gemm_xdl_f16_f16_f16_mk_nk_mn_instances( + std::vector&); +void add_device_grouped_gemm_xdl_f16_f16_f16_km_kn_mn_instances( + std::vector&); +void add_device_grouped_gemm_xdl_f16_f16_f16_km_nk_mn_instances( + std::vector&); + +} // namespace device_grouped_gemm_instance +} // namespace device +} // namespace tensor_operation +} // namespace ck + +namespace ck { +namespace profiler { + +template +void profile_grouped_gemm_impl(int do_verification, + int init_method, + bool do_log, + int nrepeat, + std::vector Ms, + std::vector Ns, + std::vector Ks, + std::vector StrideAs, + std::vector StrideBs, + std::vector StrideCs) +{ + auto f_host_tensor_descriptor = + [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { + if(is_same::value) + { + return HostTensorDescriptor(std::vector({row, col}), + std::vector({stride, 1})); + } + else + { + return HostTensorDescriptor(std::vector({row, col}), + std::vector({1, stride})); + } + }; + + int group_count = Ms.size(); + + if(!(group_count == Ns.size() && group_count == Ks.size() && group_count == StrideAs.size() && + group_count == StrideBs.size() && group_count == StrideCs.size())) + { + throw std::runtime_error("wrong! inconsistent M/N/Ks, StrideA/B/Cs size\n"); + } + + std::vector> a_m_k; + std::vector> b_k_n; + std::vector> c_m_n_device_results; + + for(int i = 0; i < Ms.size(); i++) + { + a_m_k.push_back( + Tensor(f_host_tensor_descriptor(Ms[i], Ks[i], StrideAs[i], ALayout{}))); + b_k_n.push_back( + Tensor(f_host_tensor_descriptor(Ks[i], Ns[i], StrideBs[i], BLayout{}))); + + c_m_n_device_results.push_back( + Tensor(f_host_tensor_descriptor(Ms[i], Ns[i], StrideCs[i], CLayout{}))); + + std::cout << "group: " << i << " a_m_k[" << i << "]:" << a_m_k[i].mDesc << ", b_k_n[" << i + << "]:" << b_k_n[i].mDesc << ", c_m_n_device_results[" << i + << "]:" << c_m_n_device_results[i].mDesc << std::endl; + + std::size_t num_thread = std::thread::hardware_concurrency(); + switch(init_method) + { + case 0: break; + case 1: + a_m_k[i].GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); + b_k_n[i].GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); + break; + default: + a_m_k[i].GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}, num_thread); + b_k_n[i].GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}, num_thread); + } + + c_m_n_device_results[i].GenerateTensorValue(GeneratorTensor_0{}, num_thread); + } + + using AElementOp = ck::tensor_operation::element_wise::PassThrough; + using BElementOp = ck::tensor_operation::element_wise::PassThrough; + using CElementOp = ck::tensor_operation::element_wise::PassThrough; + + const auto a_element_op = AElementOp{}; + const auto b_element_op = BElementOp{}; + const auto c_element_op = CElementOp{}; + + // if(do_verification) + // { + + // } + + using DeviceMemPtr = std::unique_ptr; + std::vector a_device_buf, b_device_buf, c_device_buf; + + a_device_buf.reserve(group_count); + b_device_buf.reserve(group_count); + c_device_buf.reserve(group_count); + + std::vector p_a, p_b; + std::vector p_c; + + p_a.reserve(group_count); + p_b.reserve(group_count); + p_c.reserve(group_count); + + std::vector gemm_shapes; + + gemm_shapes.reserve(group_count); + + for(int i = 0; i < group_count; i++) + { + a_device_buf.emplace_back( + std::make_unique(sizeof(ADataType) * a_m_k[i].mDesc.GetElementSize())); + b_device_buf.emplace_back( + std::make_unique(sizeof(BDataType) * b_k_n[i].mDesc.GetElementSize())); + + c_device_buf.emplace_back(std::make_unique( + sizeof(CDataType) * c_m_n_device_results[i].mDesc.GetElementSize())); + + a_device_buf[i]->ToDevice(a_m_k[i].mData.data()); + b_device_buf[i]->ToDevice(b_k_n[i].mData.data()); + c_device_buf[i]->ToDevice(c_m_n_device_results[i].mData.data()); + + gemm_shapes.push_back({Ms[i], Ns[i], Ks[i], StrideAs[i], StrideBs[i], StrideCs[i]}); + + p_a.push_back(a_device_buf[i]->GetDeviceBuffer()); + p_b.push_back(b_device_buf[i]->GetDeviceBuffer()); + p_c.push_back(c_device_buf[i]->GetDeviceBuffer()); + } + + // add device GEMM instances + std::vector< + ck::tensor_operation::device::device_grouped_gemm_instance::DeviceGroupedGemmNoOpPtr> + gemm_ptrs; + + if constexpr(is_same::value && is_same::value && + is_same::value) + { + if constexpr(is_same::value && + is_same::value && + is_same::value) + { + ck::tensor_operation::device::device_grouped_gemm_instance:: + add_device_grouped_gemm_xdl_f16_f16_f16_mk_kn_mn_instances(gemm_ptrs); + } + else if constexpr(is_same::value && + is_same::value && + is_same::value) + { + ck::tensor_operation::device::device_grouped_gemm_instance:: + add_device_grouped_gemm_xdl_f16_f16_f16_mk_nk_mn_instances(gemm_ptrs); + } + else if constexpr(is_same::value && + is_same::value && + is_same::value) + { + ck::tensor_operation::device::device_grouped_gemm_instance:: + add_device_grouped_gemm_xdl_f16_f16_f16_km_kn_mn_instances(gemm_ptrs); + } + else if constexpr(is_same::value && + is_same::value && + is_same::value) + { + ck::tensor_operation::device::device_grouped_gemm_instance:: + add_device_grouped_gemm_xdl_f16_f16_f16_km_nk_mn_instances(gemm_ptrs); + } + } + + if(gemm_ptrs.size() <= 0) + { + throw std::runtime_error("wrong! no device GEMM instance found"); + } + + std::string best_gemm_name; + float best_ave_time = 0; + float best_tflops = 0; + float best_gb_per_sec = 0; + + // profile device GEMM instances + for(auto& gemm_ptr : gemm_ptrs) + { + auto argument_ptr = + gemm_ptr->MakeArgumentPointer(p_a, + p_b, + p_c, + gemm_shapes, + ck::tensor_operation::element_wise::PassThrough{}, + ck::tensor_operation::element_wise::PassThrough{}, + ck::tensor_operation::element_wise::PassThrough{}); + + auto invoker_ptr = gemm_ptr->MakeInvokerPointer(); + + if(gemm_ptr->IsSupportedArgument(argument_ptr.get())) + { + std::string gemm_name = gemm_ptr->GetTypeString(); + + float ave_time = invoker_ptr->Run(argument_ptr.get(), nrepeat); + + std::size_t flop = 0, num_btype = 0; + for(int i = 0; i < gemm_shapes.size(); i++) + { + flop += std::size_t(2) * Ms[i] * Ns[i] * Ks[i]; + + num_btype += sizeof(ADataType) * Ms[i] * Ks[i] + sizeof(BDataType) * Ks[i] * Ns[i] + + sizeof(CDataType) * Ms[i] * Ns[i]; + } + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << tflops << " TFlops, " + << gb_per_sec << " GB/s, " << gemm_name << std::endl; + + if(tflops > best_tflops) + { + best_gemm_name = gemm_name; + best_tflops = tflops; + best_ave_time = ave_time; + best_gb_per_sec = gb_per_sec; + } + + if(do_verification) + { + for(int i = 0; i < gemm_shapes.size(); i++) + { + + c_device_buf[i]->FromDevice(c_m_n_device_results[i].mData.data()); + + Tensor c_m_n_host_result( + f_host_tensor_descriptor(Ms[i], Ns[i], StrideCs[i], CLayout{})); + + using ReferenceGemmInstance = + ck::tensor_operation::host::ReferenceGemm; + + auto ref_gemm = ReferenceGemmInstance{}; + auto ref_invoker = ref_gemm.MakeInvoker(); + + auto ref_argument = ref_gemm.MakeArgument(a_m_k[i], + b_k_n[i], + c_m_n_host_result, + a_element_op, + b_element_op, + c_element_op); + + ref_invoker.Run(ref_argument); + check_error(c_m_n_host_result, c_m_n_device_results[i]); + + if(do_log) + { + LogRangeAsType(std::cout << "a : ", a_m_k[i].mData, ",") + << std::endl; + LogRangeAsType(std::cout << "b: ", b_k_n[i].mData, ",") << std::endl; + LogRangeAsType( + std::cout << "c_device: ", c_m_n_device_results[i].mData, ",") + << std::endl; + LogRangeAsType( + std::cout << "c_host : ", c_m_n_host_result.mData, ",") + << std::endl; + } + } + } + } + else + { + std::cout << "does not support this GEMM problem" << std::endl; + } + } + + std::cout << "Best Perf: " << best_ave_time << " ms, " << best_tflops << " TFlops, " + << best_gb_per_sec << " GB/s, " << best_gemm_name << std::endl; +} // namespace profiler + +} // namespace profiler +} // namespace ck diff --git a/profiler/src/profile_grouped_gemm.cpp b/profiler/src/profile_grouped_gemm.cpp new file mode 100644 index 0000000000..99ddb838ac --- /dev/null +++ b/profiler/src/profile_grouped_gemm.cpp @@ -0,0 +1,157 @@ +#include +#include +#include +#include +#include +#include +#include "profile_grouped_gemm_impl.hpp" + +enum GemmMatrixLayout +{ + MK_KN_MN, // 0 + MK_NK_MN, // 1 + KM_KN_MN, // 2 + KM_NK_MN, // 3 + MK_KN_NM, // 4 + MK_NK_NM, // 5 + KM_KN_NM, // 6 + KM_NK_NM, // 7 +}; + +enum GemmDataType +{ + F32_F32_F32, // 0 + F16_F16_F16, // 1 + BF16_BF16_BF16, // 2 + INT8_INT8_INT8, // 3 +}; + +std::vector argToIntArray(char* input) +{ + std::vector out; + + std::istringstream in(input); + + std::string item; + + while(std::getline(in, item, ',')) + { + out.push_back(std::stoi(item)); + } + + return out; +} + +int profile_grouped_gemm(int argc, char* argv[]) +{ + if(!(argc == 14)) + { + printf("arg1: tensor operation (grouped_gemm: Grouped GEMM)\n"); + printf("arg2: data type (0: fp32; 1: fp16; 2: bf16; 3: int8)\n"); + printf("arg3: matrix layout (0: A[m, k] * B[k, n] = C[m, n];\n"); + printf(" 1: A[m, k] * B[n, k] = C[m, n];\n"); + printf(" 2: A[k, m] * B[k, n] = C[m, n];\n"); + printf(" 3: A[k, m] * B[n, k] = C[m, n])\n"); + printf("arg4: verification (0: no; 1: yes)\n"); + printf("arg5: initialization (0: no init; 1: integer value; 2: decimal value)\n"); + printf("arg8: print tensor value (0: no; 1: yes)\n"); + printf("arg7: run kernel # of times (>1)\n"); + printf("arg8 to 13: Ms, Ns, Ks, StrideAs, StrideBs, StrideCs (e.g., 256,256 128,128 64,64 " + "64,64 64,64 128,128)\n"); + exit(1); + } + + const int data_type = static_cast(std::stoi(argv[2])); + const int layout = static_cast(std::stoi(argv[3])); + const bool do_verification = std::stoi(argv[4]); + const int init_method = std::stoi(argv[5]); + const bool do_log = std::stoi(argv[6]); + const int nrepeat = std::stoi(argv[7]); + + const auto Ms = argToIntArray(argv[8]); + const auto Ns = argToIntArray(argv[9]); + const auto Ks = argToIntArray(argv[10]); + + const auto StrideAs = argToIntArray(argv[11]); + const auto StrideBs = argToIntArray(argv[12]); + const auto StrideCs = argToIntArray(argv[13]); + + if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::MK_KN_MN) + { + ck::profiler::profile_grouped_gemm_impl(do_verification, + init_method, + do_log, + nrepeat, + Ms, + Ns, + Ks, + StrideAs, + StrideBs, + StrideCs); + } + else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::MK_NK_MN) + { + ck::profiler::profile_grouped_gemm_impl(do_verification, + init_method, + do_log, + nrepeat, + Ms, + Ns, + Ks, + StrideAs, + StrideBs, + StrideCs); + } + else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::KM_KN_MN) + { + ck::profiler::profile_grouped_gemm_impl(do_verification, + init_method, + do_log, + nrepeat, + Ms, + Ns, + Ks, + StrideAs, + StrideBs, + StrideCs); + } + else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::KM_NK_MN) + { + ck::profiler::profile_grouped_gemm_impl(do_verification, + init_method, + do_log, + nrepeat, + Ms, + Ns, + Ks, + StrideAs, + StrideBs, + StrideCs); + } + else + { + throw std::runtime_error("wrong! this GEMM data_type & layout is not implemented"); + } + + return 1; +} diff --git a/profiler/src/profiler.cpp b/profiler/src/profiler.cpp index 80ce1f8324..eb5ba53571 100644 --- a/profiler/src/profiler.cpp +++ b/profiler/src/profiler.cpp @@ -15,9 +15,11 @@ int profile_conv_fwd_bias_relu_add(int, char*[]); int profile_conv_fwd_bias_relu_atomic_add(int, char*[]); int profile_conv_bwd_data(int, char*[]); int profile_reduce(int, char*[]); +int profile_grouped_gemm(int, char*[]); int main(int argc, char* argv[]) { +#if 0 if(strcmp(argv[1], "gemm") == 0) { return profile_gemm(argc, argv); @@ -62,6 +64,10 @@ int main(int argc, char* argv[]) { return profile_reduce(argc, argv); } + else if(strcmp(argv[1], "grouped_gemm") == 0) + { + return profile_grouped_gemm(argc, argv); + } else { // clang-format off @@ -74,9 +80,13 @@ int main(int argc, char* argv[]) " conv_fwd_bias_relu_add: ForwardConvolution+Bias+ReLU+Add\n" " conv_fwd_bias_relu_atomic_add: ForwardConvolution+Bias+ReLU+AtomicAdd\n" " conv_bwd: BackwardConvolution\n" + " grouped_gemm: Grouped Gemm\n" " reduce: REDUCE\n"); // clang-format on return 0; } +#else + profile_grouped_gemm(argc, argv); +#endif } diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index 13289443fa..8e74fb9d7d 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -35,6 +35,7 @@ add_subdirectory(space_filling_curve) add_subdirectory(conv_util) add_subdirectory(reference_conv_fwd) add_subdirectory(gemm) +add_subdirectory(grouped_gemm) add_subdirectory(gemm_split_k) add_subdirectory(conv2d_fwd) add_subdirectory(convnd_fwd) diff --git a/test/grouped_gemm/CMakeLists.txt b/test/grouped_gemm/CMakeLists.txt new file mode 100644 index 0000000000..f04ee77062 --- /dev/null +++ b/test/grouped_gemm/CMakeLists.txt @@ -0,0 +1,3 @@ +add_test_executable(test_grouped_gemm_fp16 grouped_gemm_fp16.cpp) +target_link_libraries(test_grouped_gemm_fp16 PRIVATE host_tensor) +target_link_libraries(test_grouped_gemm_fp16 PRIVATE device_grouped_gemm_instance) diff --git a/test/grouped_gemm/grouped_gemm_fp16.cpp b/test/grouped_gemm/grouped_gemm_fp16.cpp new file mode 100644 index 0000000000..9b3d2901ee --- /dev/null +++ b/test/grouped_gemm/grouped_gemm_fp16.cpp @@ -0,0 +1,213 @@ +#include +#include +#include +#include +#include +#include +#include "config.hpp" +#include "print.hpp" +#include "device.hpp" +#include "host_tensor.hpp" +#include "host_tensor_generator.hpp" +#include "host_gemm.hpp" +#include "device_tensor.hpp" +#include "device_grouped_gemm_xdl.hpp" +#include "element_wise_operation.hpp" +#include "reference_gemm.hpp" +#include "gemm_specialization.hpp" +#include "test_util.hpp" + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +using DeviceGroupedGemmPtr_ = ck::tensor_operation::device::DeviceGroupedGemmPtr< + ck::tensor_operation::element_wise::PassThrough, + ck::tensor_operation::element_wise::PassThrough, + ck::tensor_operation::element_wise::PassThrough>; + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_grouped_gemm_instance { +void add_device_grouped_gemm_xdl_f16_f16_f16_mk_nk_mn_instances( + std::vector&); +} +} // namespace device +} // namespace tensor_operation +} // namespace ck + +namespace { + +using ADataType = ck::half_t; +using BDataType = ck::half_t; +using CDataType = ck::half_t; +using AccDataType = float; + +using ALayout = ck::tensor_layout::gemm::RowMajor; +using BLayout = ck::tensor_layout::gemm::ColumnMajor; +using CLayout = ck::tensor_layout::gemm::RowMajor; + +template +static bool check_err(const Tensor& ref, const Tensor& result) +{ + float max_diff = 1e-2; + + for(int i = 0; i < ref.mData.size(); ++i) + { + float diff = std::abs(double(ref.mData[i]) - double(result.mData[i])); + if(max_diff < diff) + { + std::cout << double(ref.mData[i]) << "," << double(result.mData[i]) << std::endl; + return false; + } + } + + return true; +} + +bool TestGroupedGemm(DeviceGroupedGemmPtr_& groupedGemmPtr) +{ + int group_count = 4; + + // GEMM shape + std::vector gemm_shapes; + std::vector p_a, p_b; + std::vector p_c; + + gemm_shapes.reserve(group_count); + + for(int i = 0; i < group_count; i++) + { + int M = 256 + 256 * i; + int N = 128 + 128 * i; + int K = 128 + 64 * i; + + int AStride = std::is_same::value ? K : M; + int BStride = std::is_same::value ? N : K; + int CStride = std::is_same::value ? N : M; + + gemm_shapes.push_back({M, N, K, AStride, BStride, CStride}); + } + + auto f_host_tensor_descriptor = + [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { + if(std::is_same::value) + { + return HostTensorDescriptor(std::vector({row, col}), + std::vector({stride, 1})); + } + else + { + return HostTensorDescriptor(std::vector({row, col}), + std::vector({1, stride})); + } + }; + + std::vector> a_tensors; + ; + std::vector> b_tensors; + std::vector> c_host_tensors; + std::vector> c_device_tensors; + + a_tensors.reserve(group_count); + b_tensors.reserve(group_count); + c_host_tensors.reserve(group_count); + c_device_tensors.reserve(group_count); + + using DeviceMemPtr = std::unique_ptr; + + std::vector a_tensors_device, b_tensors_device, c_tensors_device; + + a_tensors_device.reserve(group_count); + b_tensors_device.reserve(group_count); + c_tensors_device.reserve(group_count); + + for(int i = 0; i < gemm_shapes.size(); i++) + { + a_tensors.emplace_back(Tensor(f_host_tensor_descriptor( + gemm_shapes[i].M, gemm_shapes[i].K, gemm_shapes[i].StrideA, ALayout{}))); + b_tensors.emplace_back(Tensor(f_host_tensor_descriptor( + gemm_shapes[i].K, gemm_shapes[i].N, gemm_shapes[i].StrideB, BLayout{}))); + c_host_tensors.emplace_back(Tensor(f_host_tensor_descriptor( + gemm_shapes[i].M, gemm_shapes[i].N, gemm_shapes[i].StrideC, CLayout{}))); + c_device_tensors.emplace_back(Tensor(f_host_tensor_descriptor( + gemm_shapes[i].M, gemm_shapes[i].N, gemm_shapes[i].StrideC, CLayout{}))); + + a_tensors[i].GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b_tensors[i].GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + } + + for(int i = 0; i < gemm_shapes.size(); i++) + { + a_tensors_device.emplace_back( + std::make_unique(sizeof(ADataType) * a_tensors[i].mDesc.GetElementSize())); + b_tensors_device.emplace_back( + std::make_unique(sizeof(BDataType) * b_tensors[i].mDesc.GetElementSize())); + c_tensors_device.emplace_back(std::make_unique( + sizeof(CDataType) * c_device_tensors[i].mDesc.GetElementSize())); + + a_tensors_device[i]->ToDevice(a_tensors[i].mData.data()); + b_tensors_device[i]->ToDevice(b_tensors[i].mData.data()); + + p_a.push_back(a_tensors_device[i]->GetDeviceBuffer()); + p_b.push_back(b_tensors_device[i]->GetDeviceBuffer()); + p_c.push_back(c_tensors_device[i]->GetDeviceBuffer()); + } + + auto a_element_op = PassThrough{}; + auto b_element_op = PassThrough{}; + auto c_element_op = PassThrough{}; + + // do GEMM + auto invoker_ptr = groupedGemmPtr->MakeInvokerPointer(); + auto argument_ptr = groupedGemmPtr->MakeArgumentPointer( + p_a, p_b, p_c, gemm_shapes, a_element_op, b_element_op, c_element_op); + + invoker_ptr->Run(argument_ptr.get()); + + for(int i = 0; i < gemm_shapes.size(); i++) + { + c_tensors_device[i]->FromDevice(c_device_tensors[i].mData.data()); + + using ReferenceGemmInstance = ck::tensor_operation::host:: + ReferenceGemm; + + auto ref_gemm = ReferenceGemmInstance{}; + auto ref_invoker = ref_gemm.MakeInvoker(); + + auto ref_argument = ref_gemm.MakeArgument(a_tensors[i], + b_tensors[i], + c_host_tensors[i], + a_element_op, + b_element_op, + c_element_op); + + ref_invoker.Run(ref_argument); + + bool res = check_err(c_device_tensors[i], c_host_tensors[i]); + + std::cout << "group_id: " << i << (res ? " SUCCESS" : " FAILURE") << std::endl; + + if(!res) + return false; + } + + return true; +} + +} // anonymous namespace + +int main() +{ + std::vector groupedGemmPtrs; + ck::tensor_operation::device::device_grouped_gemm_instance:: + add_device_grouped_gemm_xdl_f16_f16_f16_mk_nk_mn_instances(groupedGemmPtrs); + + bool res = true; + + for(auto& gemmPtr : groupedGemmPtrs) + { + res &= TestGroupedGemm(gemmPtr); + } + + std::cout << "TestGroupedGemm ..... " << (res ? "SUCCESS" : "FAILURE") << std::endl; +}