From aef327296e5ca53e0ee08a383aaf10c8f03349cd Mon Sep 17 00:00:00 2001 From: Illia Silin <98187287+illsilin@users.noreply.github.com> Date: Tue, 3 Feb 2026 09:52:14 -0800 Subject: [PATCH] Revert "Implement device grouped gemm fixed nk multi abd for rdna4 (#3619)" (#3705) This reverts commit 372a284890dc19cfd3c241c3e9a6076d35e843a5. [ROCm/composable_kernel commit: 569640dc70bb9175fb1b6664b9a2e0970b7dec78] --- ...grouped_gemm_bias_fastgelu_xdl_bf16_i8.cpp | 2 - .../59_grouped_gemm_multi_ABD/CMakeLists.txt | 8 - ...m_multi_abd_wmma_fixed_nk_bias_bf16_i8.cpp | 400 -------- ...gemm_multi_abd_wmma_fixed_nk_bias_fp16.cpp | 396 -------- ...mm_multi_abd_xdl_fixed_nk_bias_bf16_i8.cpp | 27 +- ..._gemm_multi_abd_xdl_fixed_nk_bias_fp16.cpp | 33 +- ...e_grouped_gemm_multi_abd_wmma_fixed_nk.hpp | 899 ------------------ ...ce_grouped_gemm_multi_abd_xdl_fixed_nk.hpp | 27 +- .../impl/device_grouped_gemm_xdl_fixed_nk.hpp | 2 +- .../grid/gridwise_gemm_wmma_cshuffle_v3.hpp | 1 - include/ck/utility/tuple_helper.hpp | 9 - .../cpu/reference_gemm_multi_abd.hpp | 194 ---- .../gpu/grouped_gemm_multi_abd_fixed_nk.hpp | 295 +----- .../CMakeLists.txt | 6 +- ...as_gelu_bf16_i8_bf16_km_kn_mn_instance.cpp | 144 --- ...as_gelu_bf16_i8_bf16_mk_kn_mn_instance.cpp | 144 --- ...as_gelu_bf16_i8_bf16_mk_nk_mn_instance.cpp | 144 --- ..._fixed_nk_bf16_i8_bf16_km_kn_mn_common.hpp | 10 +- ..._fixed_nk_bf16_i8_bf16_mk_kn_mn_common.hpp | 10 +- ..._fixed_nk_bf16_i8_bf16_mk_nk_mn_common.hpp | 2 - .../profiler/profile_gemm_multi_abd_impl.hpp | 88 +- ...e_grouped_gemm_multi_abd_fixed_nk_impl.hpp | 534 ----------- test/grouped_gemm/CMakeLists.txt | 6 - .../test_grouped_gemm_multi_abd_fixed_nk.cpp | 256 ----- 24 files changed, 120 insertions(+), 3517 deletions(-) delete mode 100644 example/59_grouped_gemm_multi_ABD/grouped_gemm_multi_abd_wmma_fixed_nk_bias_bf16_i8.cpp delete mode 100644 example/59_grouped_gemm_multi_ABD/grouped_gemm_multi_abd_wmma_fixed_nk_bias_fp16.cpp delete mode 100644 include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multi_abd_wmma_fixed_nk.hpp delete mode 100644 library/include/ck/library/reference_tensor_operation/cpu/reference_gemm_multi_abd.hpp delete mode 100644 library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk_multi_abd/device_grouped_gemm_wmma_fixed_nk_bias_gelu_bf16_i8_bf16_km_kn_mn_instance.cpp delete mode 100644 library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk_multi_abd/device_grouped_gemm_wmma_fixed_nk_bias_gelu_bf16_i8_bf16_mk_kn_mn_instance.cpp delete mode 100644 library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk_multi_abd/device_grouped_gemm_wmma_fixed_nk_bias_gelu_bf16_i8_bf16_mk_nk_mn_instance.cpp delete mode 100644 profiler/include/profiler/profile_grouped_gemm_multi_abd_fixed_nk_impl.hpp delete mode 100644 test/grouped_gemm/test_grouped_gemm_multi_abd_fixed_nk.cpp diff --git a/client_example/31_grouped_gemm_bf16Aint8B/grouped_gemm_bias_fastgelu_xdl_bf16_i8.cpp b/client_example/31_grouped_gemm_bf16Aint8B/grouped_gemm_bias_fastgelu_xdl_bf16_i8.cpp index e6e2137bea..0766373465 100644 --- a/client_example/31_grouped_gemm_bf16Aint8B/grouped_gemm_bias_fastgelu_xdl_bf16_i8.cpp +++ b/client_example/31_grouped_gemm_bf16Aint8B/grouped_gemm_bias_fastgelu_xdl_bf16_i8.cpp @@ -15,8 +15,6 @@ #include "ck/library/tensor_operation_instance/gpu/grouped_gemm_multi_abd_fixed_nk.hpp" -#include "ck/host_utility/hip_check_error.hpp" - using ::ck::hip_check_error; template diff --git a/example/59_grouped_gemm_multi_ABD/CMakeLists.txt b/example/59_grouped_gemm_multi_ABD/CMakeLists.txt index d7ff58705c..4155e0a344 100644 --- a/example/59_grouped_gemm_multi_ABD/CMakeLists.txt +++ b/example/59_grouped_gemm_multi_ABD/CMakeLists.txt @@ -8,11 +8,3 @@ add_example_dependencies(example_grouped_gemm_xdl_multi_abd example_grouped_gemm add_example_executable(example_grouped_gemm_multi_abd_xdl_fixed_nk_bias_bf16_i8 grouped_gemm_multi_abd_xdl_fixed_nk_bias_bf16_i8.cpp) add_example_dependencies(example_grouped_gemm_xdl_multi_abd example_grouped_gemm_multi_abd_xdl_fixed_nk_bias_bf16_i8) - -add_custom_target(example_grouped_gemm_wmma_multi_abd) - -add_example_executable(example_grouped_gemm_multi_abd_wmma_fixed_nk_bias_fp16 grouped_gemm_multi_abd_wmma_fixed_nk_bias_fp16.cpp) -add_example_dependencies(example_grouped_gemm_wmma_multi_abd example_grouped_gemm_multi_abd_wmma_fixed_nk_bias_fp16) - -add_example_executable(example_grouped_gemm_multi_abd_wmma_fixed_nk_bias_bf16_i8 grouped_gemm_multi_abd_wmma_fixed_nk_bias_bf16_i8.cpp) -add_example_dependencies(example_grouped_gemm_wmma_multi_abd example_grouped_gemm_multi_abd_wmma_fixed_nk_bias_bf16_i8) \ No newline at end of file diff --git a/example/59_grouped_gemm_multi_ABD/grouped_gemm_multi_abd_wmma_fixed_nk_bias_bf16_i8.cpp b/example/59_grouped_gemm_multi_ABD/grouped_gemm_multi_abd_wmma_fixed_nk_bias_bf16_i8.cpp deleted file mode 100644 index 4eab6cfce2..0000000000 --- a/example/59_grouped_gemm_multi_ABD/grouped_gemm_multi_abd_wmma_fixed_nk_bias_bf16_i8.cpp +++ /dev/null @@ -1,400 +0,0 @@ -// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. -// SPDX-License-Identifier: MIT - -#include -#include -#include -#include - -#include "ck/ck.hpp" -#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" -#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" -#include "ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multi_abd_wmma_fixed_nk.hpp" -#include "ck/tensor_operation/gpu/device/device_grouped_gemm_multi_abd.hpp" -#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp" - -#include "ck/library/utility/check_err.hpp" -#include "ck/library/utility/device_memory.hpp" -#include "ck/library/utility/host_tensor.hpp" -#include "ck/library/utility/host_tensor_generator.hpp" -#include "ck/library/utility/literals.hpp" -#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" - -#include "ck/host_utility/hip_check_error.hpp" - -using ::ck::DeviceMem; -using ::ck::hip_check_error; -using ::ck::HostTensorDescriptor; -using ::ck::Tensor; - -template -using S = ck::Sequence; - -using BF16 = ck::bhalf_t; -using I8 = int8_t; -using F32 = float; - -using Row = ck::tensor_layout::gemm::RowMajor; -using Col = ck::tensor_layout::gemm::ColumnMajor; -using Bypass = ck::tensor_layout::BypassLayoutVerification; - -using PassThrough = ck::tensor_operation::element_wise::PassThrough; -using Add = ck::tensor_operation::element_wise::Add; - -using A0DataType = BF16; -using AsDataType = ck::Tuple; -using B0DataType = I8; -using B1DataType = BF16; -using BsDataType = ck::Tuple; -using AccDataType = F32; -using CShuffleDataType = BF16; -using D0DataType = BF16; -using DsDataType = ck::Tuple; -using EDataType = BF16; - -using A0Layout = Row; -using AsLayout = ck::Tuple; -using B0Layout = Col; -using B1Layout = B0Layout; -using BsLayout = ck::Tuple; -using DsLayout = ck::Tuple; -using ELayout = Row; - -using Multiply = ck::tensor_operation::element_wise::Multiply; -using PassThrough = ck::tensor_operation::element_wise::PassThrough; -using AddFastGelu = ck::tensor_operation::element_wise::AddFastGelu; - -using AElementOp = PassThrough; -using BElementOp = Multiply; -using CDEElementOp = AddFastGelu; - -static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::MNKPadding; - -using DeviceGemmInstance = ck::tensor_operation::device::DeviceGroupedGemm_Wmma_Multi_ABD_Fixed_NK - // clang-format off -///######| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| -///######| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| -///######| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| -///######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - < AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmDefault, 128, 32, 128, 32, 8, 8, 16, 16, 1, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 16, 1, 8>, 1>; - -// clang-format on - -struct ProblemSize final -{ - std::vector Ms; - std::vector Ns; - std::vector Ks; - - std::vector stride_As; - std::vector stride_Bs; - std::vector stride_Cs; - - ck::index_t group_count; -}; - -struct ExecutionConfig final -{ - bool do_verification = true; - int init_method = 1; - bool time_kernel = false; - int k_batch = 1; -}; - -bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& config) -{ - auto group_count = problem_size.group_count; - - // GEMM shape - std::vector gemm_descs; - - gemm_descs.reserve(group_count); - - int sum_of_m = 0; - - auto f_host_tensor_descriptor = - [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { - using namespace ck::literals; - - if(std::is_same::value) - { - return HostTensorDescriptor({row, col}, {stride, 1_uz}, Bypass{}); - } - else - { - return HostTensorDescriptor({row, col}, {1_uz, stride}, Bypass{}); - } - }; - - std::vector> a0_tensors; - std::vector> b_tensors; - std::vector> b0_tensors; - std::vector> b1_tensors; - std::vector> d0_tensors; - std::vector> c_host_tensors; - std::vector> c_device_tensors; - - a0_tensors.reserve(group_count); - b_tensors.reserve(group_count); - b0_tensors.reserve(group_count); - b1_tensors.reserve(group_count); - d0_tensors.reserve(group_count); - c_host_tensors.reserve(group_count); - c_device_tensors.reserve(group_count); - - using DeviceMemPtr = std::unique_ptr; - - std::vector a0_tensors_device, b0_tensors_device, b1_tensors_device, - d0_tensors_device, c_tensors_device; - - a0_tensors_device.reserve(group_count); - b0_tensors_device.reserve(group_count); - b1_tensors_device.reserve(group_count); - d0_tensors_device.reserve(group_count); - c_tensors_device.reserve(group_count); - - std::size_t flop = 0, num_btype = 0; - - for(int i = 0; i < group_count; i++) - { - sum_of_m += problem_size.Ms[i]; - - a0_tensors.push_back(Tensor(f_host_tensor_descriptor( - problem_size.Ms[i], problem_size.Ks[i], problem_size.stride_As[i], A0Layout{}))); - - b_tensors.push_back(Tensor(f_host_tensor_descriptor( - problem_size.Ks[i], problem_size.Ns[i], problem_size.stride_Bs[i], B0Layout{}))); - b0_tensors.push_back(Tensor(f_host_tensor_descriptor( - problem_size.Ks[i], problem_size.Ns[i], problem_size.stride_Bs[i], B0Layout{}))); - b1_tensors.push_back(Tensor( - f_host_tensor_descriptor(problem_size.Ks[i], problem_size.Ns[i], 0, B1Layout{}))); - - d0_tensors.push_back(Tensor( - f_host_tensor_descriptor(problem_size.Ms[i], problem_size.Ns[i], 0, ELayout{}))); - - c_host_tensors.push_back(Tensor(f_host_tensor_descriptor( - problem_size.Ms[i], problem_size.Ns[i], problem_size.stride_Cs[i], ELayout{}))); - c_device_tensors.push_back(Tensor(f_host_tensor_descriptor( - problem_size.Ms[i], problem_size.Ns[i], problem_size.stride_Cs[i], ELayout{}))); - - std::cout << "gemm[" << i << "] a_m_k: " << a0_tensors[i].mDesc - << " b_k_n: " << b0_tensors[i].mDesc << " d_m_n: " << d0_tensors[i].mDesc - << " c_m_n: " << c_device_tensors[i].mDesc << std::endl; - - flop += std::size_t(2) * problem_size.Ms[i] * problem_size.Ks[i] * problem_size.Ns[i]; - num_btype += sizeof(A0DataType) * a0_tensors[i].mDesc.GetElementSize() + - sizeof(B0DataType) * b0_tensors[i].mDesc.GetElementSize() + - sizeof(B1DataType) * b1_tensors[i].mDesc.GetElementSize() + - sizeof(D0DataType) * d0_tensors[i].mDesc.GetElementSize() + - sizeof(EDataType) * c_device_tensors[i].mDesc.GetElementSize(); - - switch(config.init_method) - { - case 0: break; - case 1: - a0_tensors[i].GenerateTensorValue(GeneratorTensor_2{-5, 5}); - b0_tensors[i].GenerateTensorValue(GeneratorTensor_2{-5, 5}); - b1_tensors[i].GenerateTensorValue(GeneratorTensor_2{0, 5}); - break; - case 2: - a0_tensors[i].GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); - b0_tensors[i].GenerateTensorValue(GeneratorTensor_3{-5, 5}); - b1_tensors[i].GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); - break; - default: - a0_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential{}); - b0_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential{}); - b1_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential{}); - } - - d0_tensors[i].GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); - } - - constexpr ck::index_t NumATensor = 1; - constexpr ck::index_t NumBTensor = 2; - constexpr ck::index_t NumDTensor = 1; - - using GroupedGemmKernelArgument = ck::tensor_operation::device:: - GroupedGemmMultiABDKernelArgument; - - std::vector grouped_gemm_kernel_args_; - grouped_gemm_kernel_args_.reserve(group_count); - - for(int i = 0; i < group_count; i++) - { - a0_tensors_device.emplace_back(std::make_unique( - sizeof(A0DataType) * problem_size.Ms[i] * problem_size.Ks[i])); - - b0_tensors_device.emplace_back(std::make_unique( - sizeof(B0DataType) * problem_size.Ns[i] * problem_size.Ks[i])); - - b1_tensors_device.emplace_back(std::make_unique( - sizeof(B1DataType) * problem_size.Ns[i] * problem_size.Ks[i])); - - d0_tensors_device.emplace_back( - std::make_unique(sizeof(D0DataType) * problem_size.Ns[i])); - - c_tensors_device.emplace_back(std::make_unique( - sizeof(EDataType) * problem_size.Ms[i] * problem_size.Ns[i])); - - a0_tensors_device[i]->ToDevice(a0_tensors[i].mData.data()); - b0_tensors_device[i]->ToDevice(b0_tensors[i].mData.data()); - b1_tensors_device[i]->ToDevice(b1_tensors[i].mData.data()); - d0_tensors_device[i]->ToDevice(d0_tensors[i].mData.data()); - c_tensors_device[i]->SetZero(); - - gemm_descs.push_back( - {sum_of_m, problem_size.Ns[i], problem_size.Ks[i], {1}, {1, 1}, {0}, 1}); - - grouped_gemm_kernel_args_.push_back( - {std::array{a0_tensors_device[i]->GetDeviceBuffer()}, - std::array{b0_tensors_device[i]->GetDeviceBuffer(), - b1_tensors_device[i]->GetDeviceBuffer()}, - std::array{d0_tensors_device[i]->GetDeviceBuffer()}, - c_tensors_device[i]->GetDeviceBuffer(), - problem_size.Ms[i], - problem_size.Ns[i], - problem_size.Ks[i], - std::array{problem_size.stride_As[i]}, - std::array{problem_size.stride_Bs[i], 0}, - std::array{0}, - problem_size.stride_Cs[i]}); - } - - auto a_element_op = AElementOp{}; - auto b_element_op = BElementOp{}; - auto cde_element_op = CDEElementOp{}; - - auto gemm = DeviceGemmInstance{}; - auto invoker = gemm.MakeInvoker(); - - std::vector> p_As = {}; - std::vector> p_Bs = {}; - std::vector> p_Ds = {}; - std::vector p_Cs = {}; - - // do GEMM - auto argument = gemm.MakeArgument(p_As, p_Bs, p_Ds, p_Cs, gemm_descs); - - if(!gemm.IsSupportedArgument(argument)) - { - throw std::runtime_error( - "wrong! device_gemm with the specified compilation parameters does " - "not support this GEMM problem"); - } - - DeviceMem gemm_workspace_dev(gemm.GetWorkSpaceSize(&argument)); - gemm.SetWorkSpacePointer(&argument, gemm_workspace_dev.GetDeviceBuffer()); - - DeviceMem gemm_kernel_args_dev(gemm.GetDeviceKernelArgSize(&argument)); - hip_check_error(hipMemcpy(gemm_kernel_args_dev.GetDeviceBuffer(), - grouped_gemm_kernel_args_.data(), - gemm.GetDeviceKernelArgSize(&argument), - hipMemcpyHostToDevice)); - - gemm.SetDeviceKernelArgs(argument, gemm_kernel_args_dev.GetDeviceBuffer()); - gemm.SetKBatch(argument, config.k_batch); - - gemm.SetElementwiseOps(argument, a_element_op, b_element_op, cde_element_op); - - invoker.Run(&argument, StreamConfig{nullptr, false}); - - if(config.time_kernel) - { - float ave_time = invoker.Run(&argument, StreamConfig{nullptr, config.time_kernel}); - float tflops = static_cast(flop) / 1.E9 / ave_time; - float gb_per_sec = num_btype / 1.E6 / ave_time; - - std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec - << " GB/s, " << gemm.GetTypeString() << std::endl; - } - - bool pass = true; - if(config.do_verification) - { - using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm; - - for(std::size_t i = 0; i < gemm_descs.size(); i++) - { - for(int n = 0; n < problem_size.Ns[i]; ++n) - { - for(int k = 0; k < problem_size.Ks[i]; ++k) - { - b_element_op(b_tensors[i](k, n), b0_tensors[i](k, n), b1_tensors[i](k, n)); - } - } - - c_tensors_device[i]->FromDevice(c_device_tensors[i].mData.data(), - c_device_tensors[i].mDesc.GetElementSize() * - sizeof(EDataType)); - - auto ref_gemm = ReferenceGemmInstance{}; - auto ref_invoker = ref_gemm.MakeInvoker(); - - auto ref_argument = ref_gemm.MakeArgument(a0_tensors[i], - b_tensors[i], - c_host_tensors[i], - PassThrough{}, - PassThrough{}, - PassThrough{}); - - ref_invoker.Run(ref_argument); - - for(int m = 0; m < problem_size.Ms[i]; ++m) - { - for(int n = 0; n < problem_size.Ns[i]; ++n) - { - cde_element_op( - c_host_tensors[i](m, n), c_host_tensors[i](m, n), d0_tensors[i](m, n)); - } - } - - pass &= ck::utils::check_err(c_device_tensors[i], c_host_tensors[i]); - } - } - - return pass; -} - -int main(int argc, char* argv[]) -{ - ProblemSize problem_size; - ExecutionConfig config; - - problem_size.group_count = 16; - - for(int i = 0; i < problem_size.group_count; i++) - { - problem_size.Ms.push_back(32 + rand() % 32); - problem_size.Ns.push_back(1024); - problem_size.Ks.push_back(512); - - problem_size.stride_As.push_back(problem_size.Ks[i]); - problem_size.stride_Bs.push_back(problem_size.Ks[i]); - problem_size.stride_Cs.push_back(problem_size.Ns[i]); - } - - if(argc == 5) - { - config.do_verification = std::stoi(argv[1]); - config.init_method = std::stoi(argv[2]); - config.time_kernel = std::stoi(argv[3]); - config.k_batch = std::stoi(argv[4]); - } - else - { - printf("arg1: verification (0=no, 1=yes)\n"); - printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); - printf("arg3: time kernel (0=no, 1=yes)\n"); - printf("arg4: k_batch (>0)\n"); - exit(0); - } - - return !run_grouped_gemm(problem_size, config); -} diff --git a/example/59_grouped_gemm_multi_ABD/grouped_gemm_multi_abd_wmma_fixed_nk_bias_fp16.cpp b/example/59_grouped_gemm_multi_ABD/grouped_gemm_multi_abd_wmma_fixed_nk_bias_fp16.cpp deleted file mode 100644 index c494e45bfb..0000000000 --- a/example/59_grouped_gemm_multi_ABD/grouped_gemm_multi_abd_wmma_fixed_nk_bias_fp16.cpp +++ /dev/null @@ -1,396 +0,0 @@ -// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. -// SPDX-License-Identifier: MIT - -#include -#include -#include -#include - -#include "ck/ck.hpp" -#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" -#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" -#include "ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multi_abd_wmma_fixed_nk.hpp" -#include "ck/tensor_operation/gpu/device/device_grouped_gemm_multi_abd.hpp" -#include "ck/tensor_operation/gpu/element/combined_element_wise_operation.hpp" - -#include "ck/utility/scheduler_enum.hpp" -#include "ck/library/utility/check_err.hpp" -#include "ck/library/utility/device_memory.hpp" -#include "ck/library/utility/host_tensor.hpp" -#include "ck/library/utility/host_tensor_generator.hpp" -#include "ck/library/utility/literals.hpp" -#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" - -#include "ck/host_utility/hip_check_error.hpp" - -using ::ck::DeviceMem; -using ::ck::hip_check_error; -using ::ck::HostTensorDescriptor; -using ::ck::Tensor; - -template -using S = ck::Sequence; - -using F16 = ck::half_t; -using F32 = float; - -using Row = ck::tensor_layout::gemm::RowMajor; -using Col = ck::tensor_layout::gemm::ColumnMajor; -using Bypass = ck::tensor_layout::BypassLayoutVerification; - -using PassThrough = ck::tensor_operation::element_wise::PassThrough; -using Add = ck::tensor_operation::element_wise::Add; -using Scale = ck::tensor_operation::element_wise::Scale; -using AddScale = ck::tensor_operation::element_wise::BinaryWithUnaryCombinedOp; - -using A0DataType = F16; -using A1DataType = F32; -using AsDataType = ck::Tuple; -using B0DataType = F16; -using BsDataType = ck::Tuple; -using AccDataType = F32; -using CShuffleDataType = F32; -using D0DataType = F16; -using DsDataType = ck::Tuple; -using EDataType = F16; - -using A0Layout = Row; -using A1Layout = Row; -using AsLayout = ck::Tuple; -using B0Layout = Col; -using BsLayout = ck::Tuple; -using D0Layout = Row; -using DsLayout = ck::Tuple; -using ELayout = Row; - -using AElementOp = AddScale; -using BElementOp = PassThrough; -using CDEElementOp = Add; - -static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::MNKPadding; - -using DeviceGemmInstance = ck::tensor_operation::device::DeviceGroupedGemm_Wmma_Multi_ABD_Fixed_NK - // clang-format off -///######| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| -///######| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| -///######| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| -///######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - < AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmDefault, 128, 32, 128, 32, 8, 8, 16, 16, 1, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 16, 1, 8>, 1>; -// clang-format on - -struct ProblemSize final -{ - std::vector Ms; - std::vector Ns; - std::vector Ks; - - std::vector stride_As; - std::vector stride_Bs; - std::vector stride_Cs; - - ck::index_t group_count; -}; - -struct ExecutionConfig final -{ - bool do_verification = true; - int init_method = 1; - bool time_kernel = false; - int k_batch = 1; -}; - -bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& config) -{ - auto group_count = problem_size.group_count; - - // GEMM shape - std::vector gemm_descs; - - gemm_descs.reserve(group_count); - - int sum_of_m = 0; - - auto f_host_tensor_descriptor = - [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { - using namespace ck::literals; - - if(std::is_same::value) - { - return HostTensorDescriptor({row, col}, {stride, 1_uz}, Bypass{}); - } - else - { - return HostTensorDescriptor({row, col}, {1_uz, stride}, Bypass{}); - } - }; - - std::vector> a0_tensors; - std::vector> a1_tensors; - std::vector> b_tensors; - std::vector> d0_tensors; - std::vector> e_host_tensors; - std::vector> e_device_tensors; - - a0_tensors.reserve(group_count); - a1_tensors.reserve(group_count); - b_tensors.reserve(group_count); - d0_tensors.reserve(group_count); - e_host_tensors.reserve(group_count); - e_device_tensors.reserve(group_count); - - using DeviceMemPtr = std::unique_ptr; - - std::vector a0_tensors_device, a1_tensors_device, b_tensors_device, - d0_tensors_device, c_tensors_device; - - a0_tensors_device.reserve(group_count); - a1_tensors_device.reserve(group_count); - b_tensors_device.reserve(group_count); - d0_tensors_device.reserve(group_count); - c_tensors_device.reserve(group_count); - - std::size_t flop = 0, num_btype = 0; - - for(int i = 0; i < group_count; i++) - { - sum_of_m += problem_size.Ms[i]; - a0_tensors.push_back(Tensor(f_host_tensor_descriptor( - problem_size.Ms[i], problem_size.Ks[i], problem_size.stride_As[i], A0Layout{}))); - a1_tensors.push_back(Tensor(f_host_tensor_descriptor( - problem_size.Ms[i], problem_size.Ks[i], problem_size.stride_As[i], A1Layout{}))); - b_tensors.push_back(Tensor(f_host_tensor_descriptor( - problem_size.Ks[i], problem_size.Ns[i], problem_size.stride_Bs[i], B0Layout{}))); - d0_tensors.push_back(Tensor( - f_host_tensor_descriptor(problem_size.Ms[i], problem_size.Ns[i], 0, ELayout{}))); - e_host_tensors.push_back(Tensor(f_host_tensor_descriptor( - problem_size.Ms[i], problem_size.Ns[i], problem_size.stride_Cs[i], ELayout{}))); - e_device_tensors.push_back(Tensor(f_host_tensor_descriptor( - problem_size.Ms[i], problem_size.Ns[i], problem_size.stride_Cs[i], ELayout{}))); - std::cout << "gemm[" << i << "] a_m_k: " << a0_tensors[i].mDesc - << " b_k_n: " << b_tensors[i].mDesc << " d_m_n: " << d0_tensors[i].mDesc - << " c_m_n: " << e_device_tensors[i].mDesc << std::endl; - - flop += std::size_t(2) * problem_size.Ms[i] * problem_size.Ks[i] * problem_size.Ns[i]; - num_btype += sizeof(A0DataType) * a0_tensors[i].mDesc.GetElementSize() + - sizeof(A1DataType) * a1_tensors[i].mDesc.GetElementSize() + - sizeof(B0DataType) * b_tensors[i].mDesc.GetElementSize() + - sizeof(D0DataType) * d0_tensors[i].mDesc.GetElementSize() + - sizeof(EDataType) * e_device_tensors[i].mDesc.GetElementSize(); - - switch(config.init_method) - { - case 0: break; - case 1: - a0_tensors[i].GenerateTensorValue(GeneratorTensor_2{-5, 5}); - a1_tensors[i].GenerateTensorValue(GeneratorTensor_2{-5, 5}); - b_tensors[i].GenerateTensorValue(GeneratorTensor_2{-5, 5}); - break; - case 2: - a0_tensors[i].GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); - a1_tensors[i].GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); - b_tensors[i].GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); - break; - default: - a0_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential{}); - a1_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential{}); - b_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential{}); - } - - d0_tensors[i].GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); - } - - constexpr ck::index_t NumATensor = 2; - constexpr ck::index_t NumBTensor = 1; - constexpr ck::index_t NumDTensor = 1; - - using GroupedGemmKernelArgument = ck::tensor_operation::device:: - GroupedGemmMultiABDKernelArgument; - - std::vector grouped_gemm_kernel_args_; - grouped_gemm_kernel_args_.reserve(group_count); - - for(int i = 0; i < group_count; i++) - { - a0_tensors_device.emplace_back(std::make_unique( - sizeof(A0DataType) * problem_size.Ms[i] * problem_size.Ks[i])); - - a1_tensors_device.emplace_back(std::make_unique( - sizeof(A1DataType) * problem_size.Ms[i] * problem_size.Ks[i])); - - b_tensors_device.emplace_back(std::make_unique( - sizeof(B0DataType) * problem_size.Ns[i] * problem_size.Ks[i])); - - d0_tensors_device.emplace_back( - std::make_unique(sizeof(D0DataType) * problem_size.Ns[i])); - - c_tensors_device.emplace_back(std::make_unique( - sizeof(EDataType) * problem_size.Ms[i] * problem_size.Ns[i])); - - a0_tensors_device[i]->ToDevice(a0_tensors[i].mData.data()); - a1_tensors_device[i]->ToDevice(a1_tensors[i].mData.data()); - b_tensors_device[i]->ToDevice(b_tensors[i].mData.data()); - d0_tensors_device[i]->ToDevice(d0_tensors[i].mData.data()); - c_tensors_device[i]->SetZero(); - - gemm_descs.push_back({sum_of_m, - problem_size.Ns[i], - problem_size.Ks[i], - {1, 1}, - {problem_size.stride_Bs[i]}, - {0}, - 1}); - - grouped_gemm_kernel_args_.push_back( - {std::array{a0_tensors_device[i]->GetDeviceBuffer(), - a1_tensors_device[i]->GetDeviceBuffer()}, - std::array{b_tensors_device[i]->GetDeviceBuffer()}, - std::array{d0_tensors_device[i]->GetDeviceBuffer()}, - c_tensors_device[i]->GetDeviceBuffer(), - problem_size.Ms[i], - problem_size.Ns[i], - problem_size.Ks[i], - std::array{problem_size.stride_As[i], - problem_size.stride_As[i]}, - std::array{problem_size.stride_Bs[i]}, - std::array{0}, - problem_size.stride_Cs[i]}); - } - - constexpr float scale = 1.f; - auto a_element_op = AElementOp{Add{}, Scale{scale}, Scale{scale}}; - auto b_element_op = BElementOp{}; - auto cde_element_op = CDEElementOp{}; - - auto gemm = DeviceGemmInstance{}; - auto invoker = gemm.MakeInvoker(); - - std::vector> p_As = {}; - std::vector> p_Bs = {}; - std::vector> p_Ds = {}; - std::vector p_Cs = {}; - - // do GEMM - auto argument = gemm.MakeArgument(p_As, p_Bs, p_Ds, p_Cs, gemm_descs); - - if(!gemm.IsSupportedArgument(argument)) - { - throw std::runtime_error( - "wrong! device_gemm with the specified compilation parameters does " - "not support this GEMM problem"); - } - - DeviceMem gemm_workspace_dev(gemm.GetWorkSpaceSize(&argument)); - gemm.SetWorkSpacePointer(&argument, gemm_workspace_dev.GetDeviceBuffer()); - - DeviceMem gemm_kernel_args_dev(gemm.GetDeviceKernelArgSize(&argument)); - hip_check_error(hipMemcpy(gemm_kernel_args_dev.GetDeviceBuffer(), - grouped_gemm_kernel_args_.data(), - gemm.GetDeviceKernelArgSize(&argument), - hipMemcpyHostToDevice)); - - gemm.SetDeviceKernelArgs(argument, gemm_kernel_args_dev.GetDeviceBuffer()); - gemm.SetKBatch(argument, config.k_batch); - - gemm.SetElementwiseOps(argument, a_element_op, b_element_op, cde_element_op); - - invoker.Run(&argument, StreamConfig{nullptr, false}); - - if(config.time_kernel) - { - float ave_time = invoker.Run(&argument, StreamConfig{nullptr, config.time_kernel}); - float tflops = static_cast(flop) / 1.E9 / ave_time; - float gb_per_sec = num_btype / 1.E6 / ave_time; - - std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec - << " GB/s, " << gemm.GetTypeString() << std::endl; - } - - bool pass = true; - if(config.do_verification) - { - using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm; - - for(std::size_t i = 0; i < gemm_descs.size(); i++) - { - for(int m = 0; m < problem_size.Ms[i]; ++m) - { - for(int k = 0; k < problem_size.Ks[i]; ++k) - { - a_element_op(a0_tensors[i](m, k), a0_tensors[i](m, k), a1_tensors[i](m, k)); - } - } - - c_tensors_device[i]->FromDevice(e_device_tensors[i].mData.data(), - e_device_tensors[i].mDesc.GetElementSize() * - sizeof(EDataType)); - - auto ref_gemm = ReferenceGemmInstance{}; - auto ref_invoker = ref_gemm.MakeInvoker(); - - auto ref_argument = ref_gemm.MakeArgument(a0_tensors[i], - b_tensors[i], - e_host_tensors[i], - PassThrough{}, - b_element_op, - PassThrough{}); - - ref_invoker.Run(ref_argument); - - for(int m = 0; m < problem_size.Ms[i]; ++m) - { - for(int n = 0; n < problem_size.Ns[i]; ++n) - { - cde_element_op( - e_host_tensors[i](m, n), e_host_tensors[i](m, n), d0_tensors[i](m, n)); - } - } - - pass &= ck::utils::check_err(e_device_tensors[i], e_host_tensors[i]); - } - } - - return pass; -} - -int main(int argc, char* argv[]) -{ - ProblemSize problem_size; - ExecutionConfig config; - - problem_size.group_count = 16; - - for(int i = 0; i < problem_size.group_count; i++) - { - problem_size.Ms.push_back(32 + rand() % 32); - problem_size.Ns.push_back(64); - problem_size.Ks.push_back(64); - - problem_size.stride_As.push_back(problem_size.Ks[i]); - problem_size.stride_Bs.push_back(problem_size.Ks[i]); - problem_size.stride_Cs.push_back(problem_size.Ns[i]); - } - - if(argc == 5) - { - config.do_verification = std::stoi(argv[1]); - config.init_method = std::stoi(argv[2]); - config.time_kernel = std::stoi(argv[3]); - config.k_batch = std::stoi(argv[4]); - } - else - { - printf("arg1: verification (0=no, 1=yes)\n"); - printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); - printf("arg3: time kernel (0=no, 1=yes)\n"); - printf("arg4: k_batch (>0)\n"); - exit(0); - } - - return !run_grouped_gemm(problem_size, config); -} diff --git a/example/59_grouped_gemm_multi_ABD/grouped_gemm_multi_abd_xdl_fixed_nk_bias_bf16_i8.cpp b/example/59_grouped_gemm_multi_ABD/grouped_gemm_multi_abd_xdl_fixed_nk_bias_bf16_i8.cpp index dfb20777bc..28b3fa9213 100644 --- a/example/59_grouped_gemm_multi_ABD/grouped_gemm_multi_abd_xdl_fixed_nk_bias_bf16_i8.cpp +++ b/example/59_grouped_gemm_multi_ABD/grouped_gemm_multi_abd_xdl_fixed_nk_bias_bf16_i8.cpp @@ -20,8 +20,6 @@ #include "ck/library/utility/literals.hpp" #include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" -#include "ck/host_utility/hip_check_error.hpp" - using ::ck::DeviceMem; using ::ck::hip_check_error; using ::ck::HostTensorDescriptor; @@ -222,8 +220,8 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co for(int i = 0; i < group_count; i++) { - a0_tensors_device.emplace_back(std::make_unique( - sizeof(A0DataType) * problem_size.Ms[i] * problem_size.Ks[i])); + a0_tensors_device.emplace_back( + std::make_unique(sizeof(A0DataType) * sum_of_m * problem_size.Ks[i])); b0_tensors_device.emplace_back(std::make_unique( sizeof(B0DataType) * problem_size.Ns[i] * problem_size.Ks[i])); @@ -234,12 +232,21 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co d0_tensors_device.emplace_back( std::make_unique(sizeof(D0DataType) * problem_size.Ns[i])); - c_tensors_device.emplace_back(std::make_unique( - sizeof(EDataType) * problem_size.Ms[i] * problem_size.Ns[i])); + c_tensors_device.emplace_back( + std::make_unique(sizeof(EDataType) * sum_of_m * problem_size.Ns[i])); + + a0_tensors_device[i]->ToDevice(a0_tensors[i].mData.data(), + a0_tensors[i].mDesc.GetElementSpaceSize() * + sizeof(A0DataType)); + + b0_tensors_device[i]->ToDevice(b0_tensors[i].mData.data(), + b0_tensors[i].mDesc.GetElementSpaceSize() * + sizeof(B0DataType)); + + b1_tensors_device[i]->ToDevice(b1_tensors[i].mData.data(), + b1_tensors[i].mDesc.GetElementSpaceSize() * + sizeof(B1DataType)); - a0_tensors_device[i]->ToDevice(a0_tensors[i].mData.data()); - b0_tensors_device[i]->ToDevice(b0_tensors[i].mData.data()); - b1_tensors_device[i]->ToDevice(b1_tensors[i].mData.data()); d0_tensors_device[i]->ToDevice(d0_tensors[i].mData.data()); c_tensors_device[i]->SetZero(); @@ -391,7 +398,7 @@ int main(int argc, char* argv[]) { printf("arg1: verification (0=no, 1=yes)\n"); printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); - printf("arg3: time kernel (0=no, 1=yes)\n"); + printf("arg3: time kernel (0=n0, 1=yes)\n"); printf("arg4: k_batch (>0)\n"); exit(0); } diff --git a/example/59_grouped_gemm_multi_ABD/grouped_gemm_multi_abd_xdl_fixed_nk_bias_fp16.cpp b/example/59_grouped_gemm_multi_ABD/grouped_gemm_multi_abd_xdl_fixed_nk_bias_fp16.cpp index 82c2e17308..032842b9eb 100644 --- a/example/59_grouped_gemm_multi_ABD/grouped_gemm_multi_abd_xdl_fixed_nk_bias_fp16.cpp +++ b/example/59_grouped_gemm_multi_ABD/grouped_gemm_multi_abd_xdl_fixed_nk_bias_fp16.cpp @@ -20,8 +20,6 @@ #include "ck/library/utility/literals.hpp" #include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" -#include "ck/host_utility/hip_check_error.hpp" - using ::ck::DeviceMem; using ::ck::hip_check_error; using ::ck::HostTensorDescriptor; @@ -49,9 +47,9 @@ using B0DataType = F16; using BsDataType = ck::Tuple; using AccDataType = F32; using CShuffleDataType = F32; -using D0DataType = F16; +using D0DataType = F32; using DsDataType = ck::Tuple; -using EDataType = F16; +using EDataType = F32; using A0Layout = Row; using A1Layout = Row; @@ -212,11 +210,11 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co for(int i = 0; i < group_count; i++) { - a0_tensors_device.emplace_back(std::make_unique( - sizeof(A0DataType) * problem_size.Ms[i] * problem_size.Ks[i])); + a0_tensors_device.emplace_back( + std::make_unique(sizeof(A0DataType) * sum_of_m * problem_size.Ks[i])); - a1_tensors_device.emplace_back(std::make_unique( - sizeof(A1DataType) * problem_size.Ms[i] * problem_size.Ks[i])); + a1_tensors_device.emplace_back( + std::make_unique(sizeof(A1DataType) * sum_of_m * problem_size.Ks[i])); b_tensors_device.emplace_back(std::make_unique( sizeof(B0DataType) * problem_size.Ns[i] * problem_size.Ks[i])); @@ -224,12 +222,19 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co d0_tensors_device.emplace_back( std::make_unique(sizeof(D0DataType) * problem_size.Ns[i])); - c_tensors_device.emplace_back(std::make_unique( - sizeof(EDataType) * problem_size.Ms[i] * problem_size.Ns[i])); + c_tensors_device.emplace_back( + std::make_unique(sizeof(EDataType) * sum_of_m * problem_size.Ns[i])); - a0_tensors_device[i]->ToDevice(a0_tensors[i].mData.data()); - a1_tensors_device[i]->ToDevice(a1_tensors[i].mData.data()); - b_tensors_device[i]->ToDevice(b_tensors[i].mData.data()); + a0_tensors_device[i]->ToDevice(a0_tensors[i].mData.data(), + a0_tensors[i].mDesc.GetElementSpaceSize() * + sizeof(A0DataType)); + + a1_tensors_device[i]->ToDevice(a1_tensors[i].mData.data(), + a1_tensors[i].mDesc.GetElementSpaceSize() * + sizeof(A1DataType)); + b_tensors_device[i]->ToDevice(b_tensors[i].mData.data(), + b_tensors[i].mDesc.GetElementSpaceSize() * + sizeof(B0DataType)); d0_tensors_device[i]->ToDevice(d0_tensors[i].mData.data()); c_tensors_device[i]->SetZero(); @@ -389,7 +394,7 @@ int main(int argc, char* argv[]) { printf("arg1: verification (0=no, 1=yes)\n"); printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); - printf("arg3: time kernel (0=no, 1=yes)\n"); + printf("arg3: time kernel (0=n0, 1=yes)\n"); printf("arg4: k_batch (>0)\n"); exit(0); } diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multi_abd_wmma_fixed_nk.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multi_abd_wmma_fixed_nk.hpp deleted file mode 100644 index 10e604de60..0000000000 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multi_abd_wmma_fixed_nk.hpp +++ /dev/null @@ -1,899 +0,0 @@ -// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. -// SPDX-License-Identifier: MIT - -#pragma once - -#include -#include - -#include "ck/ck.hpp" -#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp" -#include "ck/utility/env.hpp" -#include "ck/host_utility/hip_check_error.hpp" -#include "ck/utility/common_header.hpp" -#include "ck/utility/tuple.hpp" -#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" - -#include "ck/utility/common_header.hpp" -#include "ck/tensor_description/tensor_descriptor.hpp" -#include "ck/tensor_description/tensor_descriptor_helper.hpp" -#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" -#include "ck/tensor_operation/gpu/device/device_grouped_gemm_multi_abd_fixed_nk.hpp" -#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" -#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_abd_wmma_cshuffle_v3.hpp" -#include "ck/host_utility/device_prop.hpp" -#include "ck/host_utility/kernel_launch.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { - -template -__global__ void -#if CK_USE_LAUNCH_BOUNDS -__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) -#endif - kernel_grouped_gemm_wmma_fixed_nk(const void CK_CONSTANT_ADDRESS_SPACE* gemm_descs_const, - const index_t group_count, - const index_t grid_size_grp, - const AElementwiseOperation a_element_op, - const BElementwiseOperation b_element_op, - const CDEElementwiseOperation cde_element_op) -{ -#if defined(__gfx11__) || defined(__gfx12__) - __shared__ char p_shared[GridwiseGemm::template GetSharedMemoryNumberOfByte< - typename GridwiseGemm::EpilogueCShuffle>()]; - - const index_t KBatch = 1; - - const index_t block_id = get_block_1d_id(); - - const auto gemm_desc_ptr = - reinterpret_cast(cast_pointer_to_generic_address_space(gemm_descs_const)); - - const index_t group_id = block_id / grid_size_grp; - - if(group_id >= group_count) - return; - - auto karg = gemm_desc_ptr[group_id]; - - if(karg.M == 0 || karg.N == 0 || karg.K == 0) - return; - -#if defined(__gfx11__) - // gfx11 does not support *_atomic_pk_add_f16/bf16 instructions - if constexpr(!(EGlobalMemoryDataOperation == InMemoryDataOperationEnum::AtomicAdd && - (std::is_same_v || - std::is_same_v))) -#endif - { - - typename GridwiseGemm::Problem problem(karg.M, - karg.N, - karg.K, - karg.StrideAs, - karg.StrideBs, - karg.StrideDs, - karg.StrideE, - KBatch); - - const auto e_grid_desc_m_n = GridwiseGemm::template MakeDEGridDescriptor_M_N( - problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideE); - - const index_t BlockStart = group_id * grid_size_grp; - - const auto local_b2e_tile_map = Block2ETileMap{e_grid_desc_m_n, KBatch}; - - const auto local_grid_size = local_b2e_tile_map.CalculateGridSize(e_grid_desc_m_n); - - constexpr auto NumATensor = GridwiseGemm::AsGridPointer::Size(); - constexpr auto NumBTensor = GridwiseGemm::BsGridPointer::Size(); - constexpr auto NumDTensor = GridwiseGemm::DsGridPointer::Size(); - - typename GridwiseGemm::AsGridPointer p_as_grid_; - typename GridwiseGemm::BsGridPointer p_bs_grid_; - typename GridwiseGemm::DsGridPointer p_ds_grid_; - - static_for<0, NumATensor, 1>{}([&](auto i) { - using ADataType = remove_cvref_t; - p_as_grid_(i) = static_cast(karg.p_as_grid[i]); - }); - - static_for<0, NumBTensor, 1>{}([&](auto i) { - using BDataType = remove_cvref_t; - p_bs_grid_(i) = static_cast(karg.p_bs_grid[i]); - }); - - static_for<0, NumDTensor, 1>{}([&](auto i) { - using DDataType = remove_cvref_t; - p_ds_grid_(i) = static_cast(karg.p_ds_grid[i]); - }); - - index_t id_off = 0; - index_t id_local = get_block_1d_id() - BlockStart; - - while(id_local < local_grid_size) - { - const auto block_2_etile_map = - GroupedGemmBlock2ETileMap(local_b2e_tile_map, BlockStart, id_off); - - auto epilogue_args = typename GridwiseGemm::EpilogueCShuffle{}; - - GridwiseGemm::template Run( - p_as_grid_, - p_bs_grid_, - p_ds_grid_, - static_cast(karg.p_e_grid), - p_shared, - problem, - block_2_etile_map, - a_element_op, - b_element_op, - cde_element_op, - epilogue_args); - - id_off += grid_size_grp; - id_local += grid_size_grp; - } - } -#else - ignore = gemm_descs_const; - ignore = group_count; - ignore = grid_size_grp; - ignore = a_element_op; - ignore = b_element_op; - ignore = cde_element_op; -#endif -} - -template -struct DeviceGroupedGemm_Wmma_Multi_ABD_Fixed_NK - : public DeviceGroupedGemmMultiABDFixedNK -{ - using DeviceOp = DeviceGroupedGemm_Wmma_Multi_ABD_Fixed_NK; - - static constexpr index_t NumATensor = AsDataType::Size(); - static constexpr index_t NumBTensor = BsDataType::Size(); - static constexpr index_t NumDTensor = DsDataType::Size(); - - static constexpr auto I0 = Number<0>{}; - static constexpr auto I1 = Number<1>{}; - static constexpr auto I2 = Number<2>{}; - - // Note: Pass multiple layout but then using only the first one - // This is to replicate xdl functionality but it should be extended - using ALayout = remove_cvref_t>; - using BLayout = remove_cvref_t>; - - using GridwiseGemm = GridwiseGemm_wmma_cshuffle_v3< - ALayout, - BLayout, - DsLayout, - ELayout, - AsDataType, - BsDataType, - AccDataType, - CShuffleDataType, - DsDataType, - EDataType, - AElementwiseOperation, - BElementwiseOperation, - CDEElementwiseOperation, - GemmSpec, - BlockSize, - MPerBlock, - NPerBlock, - KPerBlock, - AK1, - BK1, - MPerWmma, - NPerWmma, - MRepeat, - NRepeat, - ABlockTransferThreadClusterLengths_K0_M_K1, - ABlockTransferThreadClusterArrangeOrder, - ABlockTransferSrcAccessOrder, - ABlockTransferSrcVectorDim, - ABlockTransferSrcScalarPerVector, - ABlockTransferDstScalarPerVector_K1, - false, - ABlockLdsExtraM, - BBlockTransferThreadClusterLengths_K0_N_K1, - BBlockTransferThreadClusterArrangeOrder, - BBlockTransferSrcAccessOrder, - BBlockTransferSrcVectorDim, - BBlockTransferSrcScalarPerVector, - BBlockTransferDstScalarPerVector_K1, - false, - BBlockLdsExtraN, - CShuffleMRepeatPerShuffle, - CShuffleNRepeatPerShuffle, - CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, - typename uniform_sequence_gen::type, - BlkGemmPipeSched, - BlkGemmPipelineVer, - ComputeTypeA, - ComputeTypeB, - false, - false>; - - // TODO: Block to tile mappings could potentially moved out to avoid code duplications between - // different device implementations. - - template - struct OffsettedBlockToCTileMapMLoops - { - using underlying_type = UnderlyingBlockToCTileMap; - - __host__ __device__ OffsettedBlockToCTileMapMLoops( - UnderlyingBlockToCTileMap block_to_ctile_map, index_t block_start, index_t id_off = 0) - { - block_to_ctile_map_ = block_to_ctile_map; - block_start_ = block_start; - id_off_ = id_off; - } - - template - __host__ __device__ constexpr auto CalculateBottomIndex(const TopIdx& idx_top) const - { - auto idx_bot = block_to_ctile_map_.CalculateBottomIndex( - make_multi_index(idx_top[Number<0>{}] - block_start_ + id_off_)); - - return make_tuple(idx_bot[Number<0>{}], idx_bot[Number<1>{}], idx_bot[Number<2>{}]); - } - - template - __host__ __device__ bool ValidCTileIndex(const CTileIdx& c_tile_idx, - const CTileDim& c_tile_dim) const - { - return block_to_ctile_map_.ValidCTileIndex(c_tile_idx, c_tile_dim); - } - - template - __host__ bool CheckValidity(const CGridDesc_M_N& c_grid_desc_m_n) const - { - return block_to_ctile_map_.CheckValidity(c_grid_desc_m_n); - } - - template - __host__ constexpr index_t CalculateGridSize(const CGridDesc_M_N& c_grid_desc_m_n) const - { - return block_to_ctile_map_.CalculateGridSize(c_grid_desc_m_n); - } - - UnderlyingBlockToCTileMap block_to_ctile_map_; - index_t block_start_; - index_t id_off_; - }; - - template - struct BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops - { - __host__ __device__ BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops() = default; - - __host__ __device__ BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops( - const BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops&) = default; - __host__ __device__ BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops( - BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops&&) = default; - __host__ __device__ BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops& - operator=(const BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops&) = default; - __host__ __device__ BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops& - operator=(BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops&&) = default; - - __host__ __device__ BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops(index_t M, - index_t N, - index_t KBatch, - index_t M01 = 8) - : M_(M), N_(N), KBatch_(KBatch), M01_(M01) - { - } - - template - __host__ __device__ BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops( - const CGridDesc_M_N& c_grid_desc_m_n, index_t KBatch, index_t M01 = 8) - : BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops( - c_grid_desc_m_n.GetLength(I0), c_grid_desc_m_n.GetLength(I1), KBatch, M01) - { - } - - __host__ __device__ constexpr index_t CalculateGridSize(index_t M, index_t N) const - { - const auto M0 = math::integer_divide_ceil(M, MPerBlock); - const auto N0 = math::integer_divide_ceil(N, NPerBlock); - - return M0 * N0 * KBatch_; - } - - template - __host__ __device__ constexpr index_t - CalculateGridSize(const CGridDesc_M_N& c_grid_desc_m_n) const - { - return CalculateGridSize(c_grid_desc_m_n.GetLength(I0), c_grid_desc_m_n.GetLength(I1)); - } - - template - __host__ bool CheckValidity(const CGridDesc_M_N& /* c_grid_desc_m_n */) const - { - return true; - } - - template - __host__ __device__ constexpr auto CalculateBottomIndex(const TopIdx& idx_top) const - { - auto block_1d_id = idx_top[I0]; - - const auto M0 = math::integer_divide_ceil(M_, MPerBlock_); - const auto N0 = math::integer_divide_ceil(N_, NPerBlock_); - - block_1d_id = block_1d_id % (M0 * N0 * KBatch_); // hide groups - - const index_t idx_ksplit = block_1d_id / (M0 * N0); - block_1d_id = block_1d_id % (M0 * N0); - - index_t idx_N0 = block_1d_id % N0; - index_t idx_M0 = block_1d_id / N0; - - const auto M01_adapt = (idx_M0 < M0 - M0 % M01_) ? M01_ : M0 % M01_; - - index_t idx_M00 = idx_M0 / M01_; - index_t idx_M01 = idx_M0 % M01_; - index_t idx_N0_M01_local = idx_N0 + idx_M01 * N0; - - return make_tuple(idx_ksplit, - idx_N0_M01_local % M01_adapt + idx_M00 * M01_, - idx_N0_M01_local / M01_adapt); - } - - template - __host__ __device__ bool ValidCTileIndex(const CTileIdx& /* c_tile_idx */, - const CTileDim& /* c_tile_dim */) const - { - return true; // always valid provided that user gets grid size from CalculateGridSize() - } - - private: - index_t M_; - index_t N_; - index_t KBatch_; - index_t M01_; - }; - - using Block2ETileMap = BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops; - using GroupedGemmBlock2ETileMap = OffsettedBlockToCTileMapMLoops; - - static constexpr index_t DefaultKBatch = 1; // implementation only supports KBatch == 1 - using KernelArgument = typename GridwiseGemm::Argument; - - using GemmTransKernelArg = - GroupedGemmMultiABDKernelArgument; - - static constexpr bool CalculateHasMainKBlockLoop(const GemmTransKernelArg& karg, - index_t k_batch) - { - index_t k_grain = k_batch * KPerBlock; - index_t K_split = (karg.K + k_grain - 1) / k_batch; - return GridwiseGemm::CalculateHasMainKBlockLoop(K_split); - } - - // Argument - struct Argument : public BaseArgument - { - - Argument(std::vector>& p_As, - std::vector>& p_Bs, - std::vector>& p_Ds, - std::vector& p_Es, - std::vector& gemm_descs, - AElementwiseOperation a_element_op, - BElementwiseOperation b_element_op, - CDEElementwiseOperation c_element_op) - : Argument(p_As, - p_Bs, - p_Ds, - p_Es, - gemm_descs, - a_element_op, - b_element_op, - c_element_op, - DefaultKBatch) - { - // TODO: use occupancy api to calculate appropriate batch size. - } - - // Client is expected to manually copy the kernel arguments to the device therefore there is - // no point in setting tensor device pointers for the argument structure. - Argument(std::vector>&, - std::vector>&, - std::vector>&, - std::vector&, - std::vector& gemm_descs, - AElementwiseOperation a_element_op, - BElementwiseOperation b_element_op, - CDEElementwiseOperation c_element_op, - index_t kbatch) - : group_count_{ck::type_convert(gemm_descs.size())}, - a_element_op_{a_element_op}, - b_element_op_{b_element_op}, - c_element_op_{c_element_op}, - grouped_gemm_kernel_args_dev{nullptr}, - gemm_kernel_host_args_{nullptr}, - grid_size_{0}, - k_batch_{kbatch} - { - gemm_desc_kernel_arg_.reserve(group_count_); - - index_t group_id = 0; - - sum_of_m = gemm_descs[0].M_; - const index_t AverM = math::integer_divide_ceil(sum_of_m, group_count_); - const index_t fixed_N = gemm_descs[0].N_; - const index_t fixed_K = gemm_descs[0].K_; - - for(std::size_t g = 0; g < gemm_descs.size(); g++) - { - const index_t M = gemm_descs[g].M_; - const index_t N = gemm_descs[g].N_; - const index_t K = gemm_descs[g].K_; - - if(M != sum_of_m || N != fixed_N || K != fixed_K) - { - throw std::runtime_error("wrong! M/N/K is not identical"); - } - - a_mtx_mraw_kraw_.emplace_back(sum_of_m, K); - b_mtx_nraw_kraw_.emplace_back(N, K); - - // pointer - std::array p_as_grid; - std::array p_bs_grid; - std::array p_ds_grid; - - static_for<0, NumATensor, 1>{}([&](auto i) { p_as_grid[i] = nullptr; }); - static_for<0, NumBTensor, 1>{}([&](auto i) { p_bs_grid[i] = nullptr; }); - static_for<0, NumDTensor, 1>{}([&](auto i) { p_ds_grid[i] = nullptr; }); - - std::array StrideAs; - std::array StrideBs; - std::array StrideDs; - - const index_t StrideE = gemm_descs[g].stride_C_; - - if(gemm_descs[g].stride_As_.size() != NumATensor) - { - throw std::runtime_error( - "wrong! gemm_descs[i].stride_As_.size() does not match NumATensor"); - } - - static_for<0, NumATensor, 1>{}( - [&](auto j) { StrideAs[j] = gemm_descs[g].stride_As_[j]; }); - - if(gemm_descs[g].stride_Bs_.size() != NumBTensor) - { - throw std::runtime_error( - "wrong! gemm_descs[i].stride_Bs_.size() does not match NumBTensor"); - } - - static_for<0, NumBTensor, 1>{}( - [&](auto j) { StrideBs[j] = gemm_descs[g].stride_Bs_[j]; }); - - if(gemm_descs[g].stride_Ds_.size() != NumDTensor) - { - throw std::runtime_error( - "wrong! gemm_descs[i].stride_Ds_.size() does not match NumDTensor"); - } - - static_for<0, NumDTensor, 1>{}( - [&](auto j) { StrideDs[j] = gemm_descs[g].stride_Ds_[j]; }); - - const auto e_grid_desc_m_n = - GridwiseGemm::template MakeDEGridDescriptor_M_N( - AverM, AverM, N, N, StrideE); - - // block-to-e-tile map - const auto local_b2c_tile_map = Block2ETileMap{e_grid_desc_m_n, k_batch_}; - - grid_size_grp_ = local_b2c_tile_map.CalculateGridSize(e_grid_desc_m_n); - - if(group_id * grid_size_grp_ != grid_size_) - { - throw std::runtime_error("wrong! grid_size_grp_ is not identical!"); - } - - const index_t block_start = grid_size_; - - grid_size_ += grid_size_grp_; - - if(!local_b2c_tile_map.CheckValidity(e_grid_desc_m_n)) - { - throw std::runtime_error("wrong! block_2_etile_map validation failed"); - } - - auto grouped_block_2_ctile_map = - GroupedGemmBlock2ETileMap(local_b2c_tile_map, block_start); - - auto karg = GemmTransKernelArg({p_as_grid, - p_bs_grid, - p_ds_grid, - nullptr, - AverM, - N, - K, - StrideAs, - StrideBs, - StrideDs, - StrideE}); - - gemm_desc_kernel_arg_.emplace_back(std::move(karg)); - - group_id++; - } - } - - void UpdateKBatch(index_t) {} - - index_t group_count_; - - AElementwiseOperation a_element_op_; - BElementwiseOperation b_element_op_; - CDEElementwiseOperation c_element_op_; - - std::vector gemm_desc_kernel_arg_; - std::vector> a_mtx_mraw_kraw_; - std::vector> b_mtx_nraw_kraw_; - - const void* grouped_gemm_kernel_args_dev; - void* gemm_kernel_host_args_; - index_t grid_size_; - index_t grid_size_grp_; - index_t sum_of_m; - - index_t k_batch_; - }; - - // Invoker - struct Invoker : public BaseInvoker - { - using Argument = DeviceOp::Argument; - - float RunImp(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) - { - if(arg.grouped_gemm_kernel_args_dev == nullptr) - { - throw std::runtime_error("wrong! grouped_gemm_kernel_args_dev is nullptr"); - } - - if(arg.k_batch_ != 1) - { - throw std::runtime_error("Split K functionality is not supported for wmma multi " - "abd fixed nk implementation."); - } - - float ave_time = 0; - - auto launch_kernel = [&](auto e_global_memory_operation_) { - const auto kernel = kernel_grouped_gemm_wmma_fixed_nk; - - return launch_and_time_kernel( - stream_config, - kernel, - dim3(arg.grid_size_), - dim3(BlockSize), - 0, - cast_pointer_to_constant_address_space(arg.grouped_gemm_kernel_args_dev), - arg.gemm_desc_kernel_arg_.size(), - arg.grid_size_grp_, - arg.a_element_op_, - arg.b_element_op_, - arg.c_element_op_); - }; - - constexpr auto Set = InMemoryDataOperationEnum::Set; - ave_time = launch_kernel(integral_constant{}); - - return ave_time; - } - - // polymorphic - float Run(const BaseArgument* p_arg, - const StreamConfig& stream_config = StreamConfig{}) override - { - return RunImp(*dynamic_cast(p_arg), stream_config); - } - }; - - static bool IsSupportedArgument(const Argument& arg) - { - if(!ck::is_gfx11_supported() && !ck::is_gfx12_supported()) - { - return false; - } - - if(ck::type_convert(arg.gemm_desc_kernel_arg_.size()) != arg.group_count_) - { - return false; - } - - bool supported = true; - - // If we use padding we do not support vector loads for dimensions not divisible by - // vector load size. - if constexpr(GemmSpec != GemmSpecialization::Default) - { - // [A|B]BlockTransferSrcVectorDim value define dimension in the block {K0,M,K1} layout, - // thus we have to adapt it to the {M,K} or {N,K} layout. - const auto a_raw_vector_dim = ABlockTransferSrcVectorDim != 1 ? 1 : 0; - const auto b_raw_vector_dim = BBlockTransferSrcVectorDim != 1 ? 1 : 0; - - for(index_t i = 0; i < arg.group_count_; ++i) - { - const auto a_vector_dim = arg.a_mtx_mraw_kraw_[i].At(Number{}); - const auto b_vector_dim = arg.b_mtx_nraw_kraw_[i].At(Number{}); - - supported = supported & (a_vector_dim % ABlockTransferSrcScalarPerVector == 0); - supported = supported & (b_vector_dim % BBlockTransferSrcScalarPerVector == 0); - } - } - - for(index_t i = 0; i < arg.group_count_; i++) - { - if(CalculateHasMainKBlockLoop(arg.gemm_desc_kernel_arg_[i], arg.k_batch_) != true) - { - supported = false; - } - } - - return supported; - } - - // polymorphic - bool IsSupportedArgument(const BaseArgument* p_arg) override - { - return IsSupportedArgument(*dynamic_cast(p_arg)); - } - - static auto MakeArgument(std::vector>& p_As, - std::vector>& p_Bs, - std::vector>& p_Ds, - std::vector& p_Es, - std::vector gemm_descs, - AElementwiseOperation a_element_op = AElementwiseOperation{}, - BElementwiseOperation b_element_op = BElementwiseOperation{}, - CDEElementwiseOperation c_element_op = CDEElementwiseOperation{}) - { - return Argument{ - p_As, p_Bs, p_Ds, p_Es, gemm_descs, a_element_op, b_element_op, c_element_op}; - } - - static auto MakeInvoker() { return Invoker{}; } - - // polymorphic - std::unique_ptr - MakeArgumentPointer(std::vector>& p_As, - std::vector>& p_Bs, - std::vector>& p_Ds, - std::vector& p_Es, - std::vector& gemm_descs, - AElementwiseOperation a_element_op = AElementwiseOperation{}, - BElementwiseOperation b_element_op = BElementwiseOperation{}, - CDEElementwiseOperation c_element_op = CDEElementwiseOperation{}) override - { - return std::make_unique( - p_As, p_Bs, p_Ds, p_Es, gemm_descs, a_element_op, b_element_op, c_element_op); - } - - // polymorphic - std::unique_ptr MakeInvokerPointer() override - { - return std::make_unique(Invoker{}); - } - - // polymorphic - std::string GetTypeString() const override - { - auto str = std::stringstream(); - - // clang-format off - str << "DeviceGroupedGemm_Wmma_Fixed_Nk" - << "<" - << BlockSize << ", " - << MPerBlock << ", " - << NPerBlock << ", " - << KPerBlock << ", " - << AK1 << ", " - << BK1 << ", " - << MPerWmma << ", " - << NPerWmma << ", " - << ABlockTransferSrcScalarPerVector << ", " - << BBlockTransferSrcScalarPerVector << ", " - << CShuffleMRepeatPerShuffle << ", " - << CShuffleNRepeatPerShuffle << ", " - << getGemmSpecializationString(GemmSpec) - << ">"; - // clang-format on - - return str.str(); - } - - static void SetElementwiseOps(Argument& arg, - AElementwiseOperation a_element_op, - BElementwiseOperation b_element_op, - CDEElementwiseOperation c_element_op) - { - arg.a_element_op_ = a_element_op; - arg.b_element_op_ = b_element_op; - arg.c_element_op_ = c_element_op; - } - - // polymorphic - void SetElementwiseOps(BaseArgument* p_arg, - AElementwiseOperation a_element_op, - BElementwiseOperation b_element_op, - CDEElementwiseOperation c_element_op) const override - { - - SetElementwiseOps( - *dynamic_cast(p_arg), a_element_op, b_element_op, c_element_op); - } - - static void SetDeviceKernelArgs(Argument& arg, const void* kernel_args) - { - arg.grouped_gemm_kernel_args_dev = kernel_args; - } - - // polymorphic - void SetDeviceKernelArgs(BaseArgument* p_arg, const void* kernel_args) const override - { - return SetDeviceKernelArgs(*dynamic_cast(p_arg), kernel_args); - } - - size_t GetDeviceKernelArgSize(const BaseArgument* p_arg) const override - { - auto arg = *dynamic_cast(p_arg); - - return arg.group_count_ * - sizeof(GroupedGemmMultiABDKernelArgument); - } - - size_t GetWorkSpaceSize(const BaseArgument* p_arg) const override - { - auto p_arg_ = dynamic_cast(p_arg); - if(p_arg_) - { - return p_arg_->gemm_desc_kernel_arg_.size() * sizeof(GemmTransKernelArg); - } - else - throw std::runtime_error( - "The argument pointer is not an object of " - "DeviceGroupedGemm_Wmma_Multi_ABD_Fixed_NK::Argument structure!"); - } - - void SetWorkSpacePointer(BaseArgument* p_arg, - void* p_workspace, - const StreamConfig& stream_config = StreamConfig{}) const override - { - auto p_arg_ = dynamic_cast(p_arg); - p_arg_->p_workspace_ = p_workspace; - - hip_check_error( - hipMemsetAsync(p_workspace, 0, GetWorkSpaceSize(p_arg), stream_config.stream_id_)); - } - - static void SetKBatch(Argument& arg, index_t k_batch) { arg.UpdateKBatch(k_batch); } - - // polymorphic - void SetKBatch(BaseArgument* p_arg, index_t k_batch) const override - { - return SetKBatch(*dynamic_cast(p_arg), k_batch); - } - - void SetHostKernelArgsPointer(BaseArgument* p_arg, void* p_host_kernel_args) const - { - Argument* pArg_ = dynamic_cast(p_arg); - if(!pArg_) - { - throw std::runtime_error("Failed to cast argument pointer!"); - } - - pArg_->gemm_kernel_host_args_ = p_host_kernel_args; - std::copy(pArg_->gemm_desc_kernel_arg_.begin(), - pArg_->gemm_desc_kernel_arg_.end(), - static_cast(pArg_->gemm_kernel_host_args_)); - } -}; - -} // namespace device -} // namespace tensor_operation -} // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multi_abd_xdl_fixed_nk.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multi_abd_xdl_fixed_nk.hpp index 36e66017c6..fb4e01b961 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multi_abd_xdl_fixed_nk.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multi_abd_xdl_fixed_nk.hpp @@ -605,7 +605,7 @@ struct DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK if(arg.grouped_gemm_kernel_args_dev == nullptr) { - throw std::runtime_error("wrong! grouped_gemm_kernel_args_dev is nullptr"); + throw std::runtime_error("wrong! grouped_gemm_kernel_args_dev is nullpr"); } float ave_time = 0; @@ -688,11 +688,6 @@ struct DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK static bool IsSupportedArgument(const Argument& arg) { - if(!ck::is_xdl_wmma_supported()) - { - return false; - } - // Split-K autodeduction is not supported if(arg.k_batch_ < 1) { @@ -725,26 +720,6 @@ struct DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK } } - for(index_t i = 0; i < arg.group_count_; i++) - { - if(get_warp_size() == 64) - { - if(GridwiseGemm64::CalculateHasMainKBlockLoop(arg.gemm_desc_kernel_arg_[i].K_) != - true) - { - supported = false; - } - } - else - { - if(GridwiseGemm32::CalculateHasMainKBlockLoop(arg.gemm_desc_kernel_arg_[i].K_) != - true) - { - supported = false; - } - } - } - return supported; } diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_fixed_nk.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_fixed_nk.hpp index 311a1c0bf4..7653724b21 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_fixed_nk.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_fixed_nk.hpp @@ -696,7 +696,7 @@ struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK& ty); } -template -auto concat_tuple_of_reference(ck::Tuple& tx, ck::Tuple& ty) -{ - return ck::unpack2( - [&](auto&&... zs) { return ck::Tuple{ck::forward(zs)...}; }, - tx, - ty); -} - template __host__ __device__ constexpr auto concat_tuple(const Tuple& tx, const Tuple& ty) { diff --git a/library/include/ck/library/reference_tensor_operation/cpu/reference_gemm_multi_abd.hpp b/library/include/ck/library/reference_tensor_operation/cpu/reference_gemm_multi_abd.hpp deleted file mode 100644 index 2d766e621b..0000000000 --- a/library/include/ck/library/reference_tensor_operation/cpu/reference_gemm_multi_abd.hpp +++ /dev/null @@ -1,194 +0,0 @@ -// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. -// SPDX-License-Identifier: MIT - -#pragma once - -#include -#include - -#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp" -#include "ck/tensor_operation/gpu/device/device_base.hpp" -#include "ck/library/utility/host_tensor.hpp" -#include "ck/utility/functional4.hpp" -#include "ck/utility/tuple_helper.hpp" - -#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" - -namespace ck { -namespace tensor_operation { -namespace host { - -template -struct ReferenceGemmMultiABD : public device::BaseOperator -{ - // Argument - struct Argument : public device::BaseArgument - { - Argument(const AsTensorTuple& as_m_k, - const BsTensorTuple& bs_k_n, - const DsTensorTuple& ds_m_n, - Tensor& e_m_n, - AElementwiseOperation a_element_op, - BElementwiseOperation b_element_op, - CDEElementwiseOperation cde_element_op) - : as_m_k_{as_m_k}, - bs_k_n_{bs_k_n}, - ds_m_n_{ds_m_n}, - e_m_n_{e_m_n}, - a_element_op_{a_element_op}, - b_element_op_{b_element_op}, - cde_element_op_{cde_element_op} - { - } - - const AsTensorTuple& as_m_k_; - const BsTensorTuple& bs_k_n_; - const DsTensorTuple& ds_m_n_; - Tensor& e_m_n_; - - AElementwiseOperation a_element_op_; - BElementwiseOperation b_element_op_; - CDEElementwiseOperation cde_element_op_; - }; - - // Invoker - struct Invoker : public device::BaseInvoker - { - using Argument = ReferenceGemmMultiABD::Argument; - - float Run(const Argument& arg) - { - static constexpr index_t NumATensor = AsTensorTuple::Size(); - static constexpr index_t NumBTensor = BsTensorTuple::Size(); - static constexpr index_t NumDTensor = DsTensorTuple::Size(); - - const int M = arg.as_m_k_[Number<0>{}].mDesc.GetLengths()[0]; - const int K = arg.as_m_k_[Number<0>{}].mDesc.GetLengths()[1]; - const int N = arg.bs_k_n_[Number<0>{}].mDesc.GetLengths()[1]; - - Tensor a_m_k({M, K}); - for(int m = 0; m < M; ++m) - { - for(int k = 0; k < K; ++k) - { - // result - auto data_refs1 = ck::tie(a_m_k(m, k)); - // inputs - auto data_refs2 = generate_tie( - [&](auto i) -> auto& { return arg.as_m_k_[Number{}](m, k); }, - Number{}); - auto data_refs = concat_tuple_of_reference(data_refs1, data_refs2); - unpack(arg.a_element_op_, data_refs); - } - } - - Tensor b_k_n({K, N}); - for(int k = 0; k < K; ++k) - { - for(int n = 0; n < N; ++n) - { - // result - auto data_refs1 = ck::tie(b_k_n(k, n)); - // inputs - auto data_refs2 = generate_tie( - [&](auto i) -> auto& { return arg.bs_k_n_[Number{}](k, n); }, - Number{}); - auto data_refs = concat_tuple_of_reference(data_refs1, data_refs2); - unpack(arg.b_element_op_, data_refs); - } - } - - using PassThrough = ck::tensor_operation::element_wise::PassThrough; - Tensor c_m_n({M, N}); - - using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm; - auto ref_gemm = ReferenceGemmInstance{}; - auto ref_invoker = ref_gemm.MakeInvoker(); - - auto ref_argument = ref_gemm.MakeArgument( - a_m_k, b_k_n, c_m_n, PassThrough{}, PassThrough{}, PassThrough{}); - - ref_invoker.Run(ref_argument); - - for(int m = 0; m < M; ++m) - { - for(int n = 0; n < N; ++n) - { - // compulsory - auto data_refs1 = ck::tie(arg.e_m_n_(m, n), c_m_n(m, n)); - // optional (if multiple Ds) - auto data_refs2 = generate_tie( - [&](auto i) -> auto& { return arg.ds_m_n_[Number{}](m, n); }, - Number{}); - auto data_refs = concat_tuple_of_reference(data_refs1, data_refs2); - unpack(arg.cde_element_op_, data_refs); - } - } - - return 0; - } - - float Run(const device::BaseArgument* p_arg, - const StreamConfig& /* stream_config */ = StreamConfig{}) override - { - return Run(*dynamic_cast(p_arg)); - } - }; - - static constexpr bool IsValidCompilationParameter() - { - // TODO: properly implement this check - return true; - } - - bool IsSupportedArgument(const device::BaseArgument*) override { return true; } - - static auto MakeArgument(const AsTensorTuple& as_m_k, - const BsTensorTuple& bs_k_n, - const DsTensorTuple& ds_m_n, - Tensor& e_m_n, - AElementwiseOperation a_element_op, - BElementwiseOperation b_element_op, - CDEElementwiseOperation cde_element_op) - { - return Argument{as_m_k, bs_k_n, ds_m_n, e_m_n, a_element_op, b_element_op, cde_element_op}; - } - - static auto MakeInvoker() { return Invoker{}; } - - virtual std::unique_ptr MakeInvokerPointer() - { - return std::make_unique(Invoker{}); - } - - std::string GetTypeString() const override - { - auto str = std::stringstream(); - - // clang-format off - str << "ReferenceGemmMultiABD" - << std::endl; - // clang-format on - - return str.str(); - } -}; - -} // namespace host -} // namespace tensor_operation -} // namespace ck diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_gemm_multi_abd_fixed_nk.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_gemm_multi_abd_fixed_nk.hpp index 0879bea4ea..6d97ec3a05 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_gemm_multi_abd_fixed_nk.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_gemm_multi_abd_fixed_nk.hpp @@ -10,6 +10,7 @@ #include "ck/ck.hpp" #include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp" #include "ck/tensor_operation/gpu/device/device_grouped_gemm_multi_abd_fixed_nk.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multi_abd_xdl_fixed_nk.hpp" namespace ck { namespace tensor_operation { @@ -20,7 +21,6 @@ using Multiply = ck::tensor_operation::element_wise::Multiply; using PassThrough = ck::tensor_operation::element_wise::PassThrough; using AddFastGelu = ck::tensor_operation::element_wise::AddFastGelu; -#if defined(CK_USE_XDL) // RRR void add_device_grouped_gemm_xdl_fixed_nk_multi_abd_bf16_i8_bf16_mk_kn_mn_bias_gelu_instances( std::vector, @@ -179,167 +179,6 @@ void add_device_grouped_gemm_xdl_fixed_nk_multi_abd_bf16_i8_bf16_km_kn_mn_instan PassThrough, Multiply, PassThrough>>>& instances); -#endif - -#if defined(CK_USE_WMMA) -void add_device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_mk_kn_mn_bias_gelu_instances( - std::vector, - ck::Tuple, - ck::Tuple, - Row, - ck::Tuple, - ck::Tuple, - ck::Tuple, - BF16, - PassThrough, - Multiply, - AddFastGelu>>>& instances); - -void add_device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_mk_kn_mn_bias_instances( - std::vector, - ck::Tuple, - ck::Tuple, - Row, - ck::Tuple, - ck::Tuple, - ck::Tuple, - BF16, - PassThrough, - Multiply, - Add>>>& instances); - -void add_device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_mk_kn_mn_gelu_instances( - std::vector, - ck::Tuple, - ck::Tuple<>, - Row, - ck::Tuple, - ck::Tuple, - ck::Tuple<>, - BF16, - PassThrough, - Multiply, - FastGelu>>>& instances); - -void add_device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_mk_kn_mn_instances( - std::vector, - ck::Tuple, - ck::Tuple<>, - Row, - ck::Tuple, - ck::Tuple, - ck::Tuple<>, - BF16, - PassThrough, - Multiply, - PassThrough>>>& instances); - -// RCR -void add_device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_mk_nk_mn_bias_gelu_instances( - std::vector, - ck::Tuple, - ck::Tuple, - Row, - ck::Tuple, - ck::Tuple, - ck::Tuple, - BF16, - PassThrough, - Multiply, - AddFastGelu>>>& instances); - -void add_device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_mk_nk_mn_bias_instances( - std::vector, - ck::Tuple, - ck::Tuple, - Row, - ck::Tuple, - ck::Tuple, - ck::Tuple, - BF16, - PassThrough, - Multiply, - Add>>>& instances); - -void add_device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_mk_nk_mn_gelu_instances( - std::vector, - ck::Tuple, - ck::Tuple<>, - Row, - ck::Tuple, - ck::Tuple, - ck::Tuple<>, - BF16, - PassThrough, - Multiply, - FastGelu>>>& instances); - -void add_device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_mk_nk_mn_instances( - std::vector, - ck::Tuple, - ck::Tuple<>, - Row, - ck::Tuple, - ck::Tuple, - ck::Tuple<>, - BF16, - PassThrough, - Multiply, - PassThrough>>>& instances); - -// CRR -void add_device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_km_kn_mn_bias_gelu_instances( - std::vector, - ck::Tuple, - ck::Tuple, - Row, - ck::Tuple, - ck::Tuple, - ck::Tuple, - BF16, - PassThrough, - Multiply, - AddFastGelu>>>& instances); - -void add_device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_km_kn_mn_bias_instances( - std::vector, - ck::Tuple, - ck::Tuple, - Row, - ck::Tuple, - ck::Tuple, - ck::Tuple, - BF16, - PassThrough, - Multiply, - Add>>>& instances); - -void add_device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_km_kn_mn_gelu_instances( - std::vector, - ck::Tuple, - ck::Tuple<>, - Row, - ck::Tuple, - ck::Tuple, - ck::Tuple<>, - BF16, - PassThrough, - Multiply, - FastGelu>>>& instances); - -void add_device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_km_kn_mn_instances( - std::vector, - ck::Tuple, - ck::Tuple<>, - Row, - ck::Tuple, - ck::Tuple, - ck::Tuple<>, - BF16, - PassThrough, - Multiply, - PassThrough>>>& instances); -#endif // CK_USE // GEMM + Add + Gelu template > op_ptrs; -#if defined(CK_USE_XDL) if constexpr(is_same_v> && is_same_v> && is_same_v> && is_same_v) @@ -408,38 +246,6 @@ struct DeviceOperationInstanceFactory< op_ptrs); } } -#endif // CK_USE_XDL - -#if defined(CK_USE_WMMA) - if constexpr(is_same_v> && - is_same_v> && - is_same_v> && is_same_v) - { - if constexpr(is_same_v> && - is_same_v> && - is_same_v> && is_same_v) - { - add_device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_mk_kn_mn_bias_gelu_instances( - op_ptrs); - } - - if constexpr(is_same_v> && - is_same_v> && - is_same_v> && is_same_v) - { - add_device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_km_kn_mn_bias_gelu_instances( - op_ptrs); - } - - if constexpr(is_same_v> && - is_same_v> && - is_same_v> && is_same_v) - { - add_device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_mk_nk_mn_bias_gelu_instances( - op_ptrs); - } - } -#endif // CK_USE_WMMA return op_ptrs; } @@ -483,7 +289,6 @@ struct DeviceOperationInstanceFactory< { std::vector> op_ptrs; -#if defined(CK_USE_XDL) if constexpr(is_same_v> && is_same_v> && is_same_v> && is_same_v) @@ -512,38 +317,6 @@ struct DeviceOperationInstanceFactory< op_ptrs); } } -#endif // CK_USE_XDL - -#if defined(CK_USE_WMMA) - if constexpr(is_same_v> && - is_same_v> && - is_same_v> && is_same_v) - { - if constexpr(is_same_v> && - is_same_v> && - is_same_v> && is_same_v) - { - add_device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_mk_kn_mn_bias_instances( - op_ptrs); - } - - if constexpr(is_same_v> && - is_same_v> && - is_same_v> && is_same_v) - { - add_device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_km_kn_mn_bias_instances( - op_ptrs); - } - - if constexpr(is_same_v> && - is_same_v> && - is_same_v> && is_same_v) - { - add_device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_mk_nk_mn_bias_instances( - op_ptrs); - } - } -#endif // CK_USE_WMMA return op_ptrs; } @@ -587,7 +360,6 @@ struct DeviceOperationInstanceFactory< { std::vector> op_ptrs; -#if defined(CK_USE_XDL) if constexpr(is_same_v> && is_same_v> && is_same_v> && is_same_v) @@ -616,38 +388,6 @@ struct DeviceOperationInstanceFactory< op_ptrs); } } -#endif // CK_USE_XDL - -#if defined(CK_USE_WMMA) - if constexpr(is_same_v> && - is_same_v> && - is_same_v> && is_same_v) - { - if constexpr(is_same_v> && - is_same_v> && - is_same_v> && is_same_v) - { - add_device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_mk_kn_mn_gelu_instances( - op_ptrs); - } - - if constexpr(is_same_v> && - is_same_v> && - is_same_v> && is_same_v) - { - add_device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_km_kn_mn_gelu_instances( - op_ptrs); - } - - if constexpr(is_same_v> && - is_same_v> && - is_same_v> && is_same_v) - { - add_device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_mk_nk_mn_gelu_instances( - op_ptrs); - } - } -#endif // CK_USE_WMMA return op_ptrs; } @@ -691,7 +431,6 @@ struct DeviceOperationInstanceFactory< { std::vector> op_ptrs; -#if defined(CK_USE_XDL) if constexpr(is_same_v> && is_same_v> && is_same_v> && is_same_v) @@ -720,38 +459,6 @@ struct DeviceOperationInstanceFactory< op_ptrs); } } -#endif // CK_USE_XDL - -#if defined(CK_USE_WMMA) - if constexpr(is_same_v> && - is_same_v> && - is_same_v> && is_same_v) - { - if constexpr(is_same_v> && - is_same_v> && - is_same_v> && is_same_v) - { - add_device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_mk_kn_mn_instances( - op_ptrs); - } - - if constexpr(is_same_v> && - is_same_v> && - is_same_v> && is_same_v) - { - add_device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_km_kn_mn_instances( - op_ptrs); - } - - if constexpr(is_same_v> && - is_same_v> && - is_same_v> && is_same_v) - { - add_device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_mk_nk_mn_instances( - op_ptrs); - } - } -#endif // CK_USE_WMMA return op_ptrs; } diff --git a/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk_multi_abd/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk_multi_abd/CMakeLists.txt index fc60f48727..9d9a0e691c 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk_multi_abd/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk_multi_abd/CMakeLists.txt @@ -1,17 +1,13 @@ # Copyright (c) Advanced Micro Devices, Inc., or its affiliates. # SPDX-License-Identifier: MIT -# ONLY XDL_AND_WMMA_KERNELS +# ONLY XDL_KERNELS set(GROUPED_GEMM_FIXED_NK_MULTI_ABD_INSTANCES) list(APPEND GROUPED_GEMM_FIXED_NK_MULTI_ABD_INSTANCES device_grouped_gemm_xdl_fixed_nk_bias_gelu_bf16_i8_bf16_mk_kn_mn_instance.cpp device_grouped_gemm_xdl_fixed_nk_bias_gelu_bf16_i8_bf16_mk_nk_mn_instance.cpp device_grouped_gemm_xdl_fixed_nk_bias_gelu_bf16_i8_bf16_km_kn_mn_instance.cpp - - device_grouped_gemm_wmma_fixed_nk_bias_gelu_bf16_i8_bf16_mk_kn_mn_instance.cpp - device_grouped_gemm_wmma_fixed_nk_bias_gelu_bf16_i8_bf16_mk_nk_mn_instance.cpp - device_grouped_gemm_wmma_fixed_nk_bias_gelu_bf16_i8_bf16_km_kn_mn_instance.cpp ) add_instance_library(device_grouped_gemm_fixed_nk_multi_abd_instance ${GROUPED_GEMM_FIXED_NK_MULTI_ABD_INSTANCES}) diff --git a/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk_multi_abd/device_grouped_gemm_wmma_fixed_nk_bias_gelu_bf16_i8_bf16_km_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk_multi_abd/device_grouped_gemm_wmma_fixed_nk_bias_gelu_bf16_i8_bf16_km_kn_mn_instance.cpp deleted file mode 100644 index a29f8513d8..0000000000 --- a/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk_multi_abd/device_grouped_gemm_wmma_fixed_nk_bias_gelu_bf16_i8_bf16_km_kn_mn_instance.cpp +++ /dev/null @@ -1,144 +0,0 @@ -// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. -// SPDX-License-Identifier: MIT - -#include - -#include "ck/ck.hpp" -#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" -#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" -#include "ck/tensor_operation/gpu/device/device_grouped_gemm_multi_abd.hpp" -#include "ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multi_abd_wmma_fixed_nk.hpp" -#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp" -#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { -namespace instance { - -template -using S = Sequence; - -using BF16 = bhalf_t; -using I8 = int8_t; -using F32 = float; - -using Row = tensor_layout::gemm::RowMajor; -using Col = tensor_layout::gemm::ColumnMajor; - -using Multiply = element_wise::Multiply; -using PassThrough = element_wise::PassThrough; -using AddFastGelu = element_wise::AddFastGelu; -using Add = element_wise::Add; -using FastGelu = element_wise::FastGelu; - -static constexpr auto GemmMNKPadding = GemmSpecialization::MNKPadding; - -template -using device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_km_kn_mn_instances = std::tuple< - // clang-format off - //#######################################| AsLayout| BsLayout| DsLayout| ELayout| AsData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| - //#######################################| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| _MBlock_MPerBlock_NBlock_NPerBlock| ScalarPerVector| - //#######################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| | _NWaveNPerXdl| - //#######################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceGroupedGemm_Wmma_Multi_ABD_Fixed_NK< Tuple, Tuple, DsLayout, Row, Tuple, Tuple, F32, BF16, DsDataType, BF16, PassThrough, Multiply, CDEElementOp, GemmSpec, 256, 256, 256, 32, 8, 8, 16, 16, 8, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 0, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 32, 1, 8>, 4>, - DeviceGroupedGemm_Wmma_Multi_ABD_Fixed_NK< Tuple, Tuple, DsLayout, Row, Tuple, Tuple, F32, BF16, DsDataType, BF16, PassThrough, Multiply, CDEElementOp, GemmSpec, 256, 128, 256, 32, 8, 8, 16, 16, 4, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 0, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 32, 1, 8>, 4>, - DeviceGroupedGemm_Wmma_Multi_ABD_Fixed_NK< Tuple, Tuple, DsLayout, Row, Tuple, Tuple, F32, BF16, DsDataType, BF16, PassThrough, Multiply, CDEElementOp, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 2, 4, S<4, 32, 2>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 0, S<4, 32, 2>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 32, 1, 8>, 4> - // clang-format on - >; - -void add_device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_km_kn_mn_bias_gelu_instances( - std::vector, - Tuple, - Tuple, - Row, - Tuple, - Tuple, - Tuple, - BF16, - PassThrough, - Multiply, - AddFastGelu>>>& instances) -{ - add_device_operation_instances( - instances, - device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_km_kn_mn_instances< - Tuple, - Tuple, - AddFastGelu, - GemmMNKPadding>{}); -} - -void add_device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_km_kn_mn_bias_instances( - std::vector, - Tuple, - Tuple, - Row, - Tuple, - Tuple, - Tuple, - BF16, - PassThrough, - Multiply, - Add>>>& instances) -{ - add_device_operation_instances( - instances, - device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_km_kn_mn_instances< - Tuple, - Tuple, - Add, - GemmMNKPadding>{}); -} - -void add_device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_km_kn_mn_instances( - std::vector, - Tuple, - Tuple<>, - Row, - Tuple, - Tuple, - Tuple<>, - BF16, - PassThrough, - Multiply, - PassThrough>>>& instances) -{ - add_device_operation_instances( - instances, - device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_km_kn_mn_instances< - Tuple<>, - Tuple<>, - PassThrough, - GemmMNKPadding>{}); -} - -void add_device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_km_kn_mn_gelu_instances( - std::vector, - Tuple, - Tuple<>, - Row, - Tuple, - Tuple, - Tuple<>, - BF16, - PassThrough, - Multiply, - FastGelu>>>& instances) -{ - add_device_operation_instances( - instances, - device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_km_kn_mn_instances< - Tuple<>, - Tuple<>, - FastGelu, - GemmMNKPadding>{}); -} - -} // namespace instance -} // namespace device -} // namespace tensor_operation -} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk_multi_abd/device_grouped_gemm_wmma_fixed_nk_bias_gelu_bf16_i8_bf16_mk_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk_multi_abd/device_grouped_gemm_wmma_fixed_nk_bias_gelu_bf16_i8_bf16_mk_kn_mn_instance.cpp deleted file mode 100644 index 2eaaaf009a..0000000000 --- a/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk_multi_abd/device_grouped_gemm_wmma_fixed_nk_bias_gelu_bf16_i8_bf16_mk_kn_mn_instance.cpp +++ /dev/null @@ -1,144 +0,0 @@ -// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. -// SPDX-License-Identifier: MIT - -#include - -#include "ck/ck.hpp" -#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" -#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" -#include "ck/tensor_operation/gpu/device/device_grouped_gemm_multi_abd.hpp" -#include "ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multi_abd_wmma_fixed_nk.hpp" -#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp" -#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { -namespace instance { - -template -using S = Sequence; - -using BF16 = bhalf_t; -using I8 = int8_t; -using F32 = float; - -using Row = tensor_layout::gemm::RowMajor; -using Col = tensor_layout::gemm::ColumnMajor; - -using Multiply = element_wise::Multiply; -using PassThrough = element_wise::PassThrough; -using AddFastGelu = element_wise::AddFastGelu; -using Add = element_wise::Add; -using FastGelu = element_wise::FastGelu; - -static constexpr auto GemmMNKPadding = GemmSpecialization::MNKPadding; - -template -using device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_mk_kn_mn_instances = std::tuple< - // clang-format off - //#######################################| AsLayout| BsLayout| DsLayout| ELayout| AsData| BsData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| - //#######################################| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| _MBlock_MPerBlock_NBlock_NPerBlock| ScalarPerVector| - //#######################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| | | - //#######################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceGroupedGemm_Wmma_Multi_ABD_Fixed_NK< Tuple, Tuple, DsLayout, Row, Tuple, Tuple, F32, BF16, DsDataType, BF16, PassThrough, Multiply, CDEElementOp, GemmSpec, 256, 256, 256, 32, 8, 8, 16, 16, 8, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 32, 1, 8>, 4>, - DeviceGroupedGemm_Wmma_Multi_ABD_Fixed_NK< Tuple, Tuple, DsLayout, Row, Tuple, Tuple, F32, BF16, DsDataType, BF16, PassThrough, Multiply, CDEElementOp, GemmSpec, 256, 128, 256, 32, 8, 8, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 32, 1, 8>, 4>, - DeviceGroupedGemm_Wmma_Multi_ABD_Fixed_NK< Tuple, Tuple, DsLayout, Row, Tuple, Tuple, F32, BF16, DsDataType, BF16, PassThrough, Multiply, CDEElementOp, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 2>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 4> - // clang-format on - >; - -void add_device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_mk_kn_mn_bias_gelu_instances( - std::vector, - Tuple, - Tuple, - Row, - Tuple, - Tuple, - Tuple, - BF16, - PassThrough, - Multiply, - AddFastGelu>>>& instances) -{ - add_device_operation_instances( - instances, - device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_mk_kn_mn_instances< - Tuple, - Tuple, - AddFastGelu, - GemmMNKPadding>{}); -} - -void add_device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_mk_kn_mn_bias_instances( - std::vector, - Tuple, - Tuple, - Row, - Tuple, - Tuple, - Tuple, - BF16, - PassThrough, - Multiply, - Add>>>& instances) -{ - add_device_operation_instances( - instances, - device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_mk_kn_mn_instances< - Tuple, - Tuple, - Add, - GemmMNKPadding>{}); -} - -void add_device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_mk_kn_mn_instances( - std::vector, - Tuple, - Tuple<>, - Row, - Tuple, - Tuple, - Tuple<>, - BF16, - PassThrough, - Multiply, - PassThrough>>>& instances) -{ - add_device_operation_instances( - instances, - device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_mk_kn_mn_instances< - Tuple<>, - Tuple<>, - PassThrough, - GemmMNKPadding>{}); -} - -void add_device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_mk_kn_mn_gelu_instances( - std::vector, - Tuple, - Tuple<>, - Row, - Tuple, - Tuple, - Tuple<>, - BF16, - PassThrough, - Multiply, - FastGelu>>>& instances) -{ - add_device_operation_instances( - instances, - device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_mk_kn_mn_instances< - Tuple<>, - Tuple<>, - FastGelu, - GemmMNKPadding>{}); -} - -} // namespace instance -} // namespace device -} // namespace tensor_operation -} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk_multi_abd/device_grouped_gemm_wmma_fixed_nk_bias_gelu_bf16_i8_bf16_mk_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk_multi_abd/device_grouped_gemm_wmma_fixed_nk_bias_gelu_bf16_i8_bf16_mk_nk_mn_instance.cpp deleted file mode 100644 index 3320b4afa6..0000000000 --- a/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk_multi_abd/device_grouped_gemm_wmma_fixed_nk_bias_gelu_bf16_i8_bf16_mk_nk_mn_instance.cpp +++ /dev/null @@ -1,144 +0,0 @@ -// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. -// SPDX-License-Identifier: MIT - -#include - -#include "ck/ck.hpp" -#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" -#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" -#include "ck/tensor_operation/gpu/device/device_grouped_gemm_multi_abd.hpp" -#include "ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multi_abd_wmma_fixed_nk.hpp" -#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp" -#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { -namespace instance { - -template -using S = Sequence; - -using BF16 = bhalf_t; -using I8 = int8_t; -using F32 = float; - -using Row = tensor_layout::gemm::RowMajor; -using Col = tensor_layout::gemm::ColumnMajor; - -using Multiply = element_wise::Multiply; -using PassThrough = element_wise::PassThrough; -using AddFastGelu = element_wise::AddFastGelu; -using Add = element_wise::Add; -using FastGelu = element_wise::FastGelu; - -static constexpr auto GemmMNKPadding = GemmSpecialization::MNKPadding; - -template -using device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_mk_nk_mn_instances = std::tuple< - // clang-format off - //######################################| AsLayout| BsLayout| DsLayout| ELayout| AsData| BsData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| - //######################################| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise|Spacialization| Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| _MBlock_MPerBlock_NBlock_NPerBlock| ScalarPerVector| - //######################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| | | - //######################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceGroupedGemm_Wmma_Multi_ABD_Fixed_NK< Tuple, Tuple, DsLayout, Row, Tuple, Tuple, F32, BF16, DsDataType, BF16, PassThrough, Multiply, CDEElementOp, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 2, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, 8>, - DeviceGroupedGemm_Wmma_Multi_ABD_Fixed_NK< Tuple, Tuple, DsLayout, Row, Tuple, Tuple, F32, BF16, DsDataType, BF16, PassThrough, Multiply, CDEElementOp, GemmSpec, 256, 128, 128, 64, 2, 2, 16, 16, 2, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, 1, 1, S<1, 64, 1, 4>, 8>, - DeviceGroupedGemm_Wmma_Multi_ABD_Fixed_NK< Tuple, Tuple, DsLayout, Row, Tuple, Tuple, F32, BF16, DsDataType, BF16, PassThrough, Multiply, CDEElementOp, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 2, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, 8> - // clang-format on - >; - -void add_device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_mk_nk_mn_bias_gelu_instances( - std::vector, - Tuple, - Tuple, - Row, - Tuple, - Tuple, - Tuple, - BF16, - PassThrough, - Multiply, - AddFastGelu>>>& instances) -{ - add_device_operation_instances( - instances, - device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_mk_nk_mn_instances< - Tuple, - Tuple, - AddFastGelu, - GemmMNKPadding>{}); -} - -void add_device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_mk_nk_mn_bias_instances( - std::vector, - Tuple, - Tuple, - Row, - Tuple, - Tuple, - Tuple, - BF16, - PassThrough, - Multiply, - Add>>>& instances) -{ - add_device_operation_instances( - instances, - device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_mk_nk_mn_instances< - Tuple, - Tuple, - Add, - GemmMNKPadding>{}); -} - -void add_device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_mk_nk_mn_instances( - std::vector, - Tuple, - Tuple<>, - Row, - Tuple, - Tuple, - Tuple<>, - BF16, - PassThrough, - Multiply, - PassThrough>>>& instances) -{ - add_device_operation_instances( - instances, - device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_mk_nk_mn_instances< - Tuple<>, - Tuple<>, - PassThrough, - GemmMNKPadding>{}); -} - -void add_device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_mk_nk_mn_gelu_instances( - std::vector, - Tuple, - Tuple<>, - Row, - Tuple, - Tuple, - Tuple<>, - BF16, - PassThrough, - Multiply, - FastGelu>>>& instances) -{ - add_device_operation_instances( - instances, - device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_mk_nk_mn_instances< - Tuple<>, - Tuple<>, - FastGelu, - GemmMNKPadding>{}); -} - -} // namespace instance -} // namespace device -} // namespace tensor_operation -} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk_multi_abd/device_grouped_gemm_xdl_fixed_nk_bf16_i8_bf16_km_kn_mn_common.hpp b/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk_multi_abd/device_grouped_gemm_xdl_fixed_nk_bf16_i8_bf16_km_kn_mn_common.hpp index 6e72d379d0..23e3b7f511 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk_multi_abd/device_grouped_gemm_xdl_fixed_nk_bf16_i8_bf16_km_kn_mn_common.hpp +++ b/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk_multi_abd/device_grouped_gemm_xdl_fixed_nk_bf16_i8_bf16_km_kn_mn_common.hpp @@ -61,8 +61,6 @@ static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecial static constexpr auto GemmMNPadding = ck::tensor_operation::device::GemmSpecialization::MNPadding; static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding; -// NOTE: After adding unit tests for DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK it tuned out that -// portion of the instances are failing. As a workaround these have been commented out. template , S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - // DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 64, 32, 8, 2, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>, DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - // DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 64, 128, 32, 8, 2, 32, 32, 1, 2, S<4, 32, 2>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 64, 128, 32, 8, 2, 32, 32, 1, 2, S<4, 32, 2>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>, DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 32, 2>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - // DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 128, 128, 64, 32, 8, 2, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 128, 128, 64, 32, 8, 2, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 4>, 8>, DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, - // DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 128, 64, 128, 32, 8, 2, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 128, 64, 128, 32, 8, 2, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 16, 1, 8>, 8>, DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, 8> // clang-format on >; diff --git a/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk_multi_abd/device_grouped_gemm_xdl_fixed_nk_bf16_i8_bf16_mk_kn_mn_common.hpp b/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk_multi_abd/device_grouped_gemm_xdl_fixed_nk_bf16_i8_bf16_mk_kn_mn_common.hpp index 5eedb8b5ee..0560f159fc 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk_multi_abd/device_grouped_gemm_xdl_fixed_nk_bf16_i8_bf16_mk_kn_mn_common.hpp +++ b/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk_multi_abd/device_grouped_gemm_xdl_fixed_nk_bf16_i8_bf16_mk_kn_mn_common.hpp @@ -61,8 +61,6 @@ static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecial static constexpr auto GemmMNPadding = ck::tensor_operation::device::GemmSpecialization::MNPadding; static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding; -// NOTE: After adding unit tests for DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK it tuned out that -// portion of the instances are failing. As a workaround these have been commented out. template , S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - // DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 64, 32, 8, 2, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>, DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - // DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 64, 128, 32, 8, 2, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 64, 128, 32, 8, 2, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>, DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - // DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 128, 128, 64, 32, 8, 2, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 128, 128, 64, 32, 8, 2, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 4>, 8>, DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, - // DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 128, 64, 128, 32, 8, 2, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 128, 64, 128, 32, 8, 2, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 16, 1, 8>, 8>, DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, 8> // clang-format on >; diff --git a/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk_multi_abd/device_grouped_gemm_xdl_fixed_nk_bf16_i8_bf16_mk_nk_mn_common.hpp b/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk_multi_abd/device_grouped_gemm_xdl_fixed_nk_bf16_i8_bf16_mk_nk_mn_common.hpp index 7d1fcb5552..95365c82e7 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk_multi_abd/device_grouped_gemm_xdl_fixed_nk_bf16_i8_bf16_mk_nk_mn_common.hpp +++ b/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk_multi_abd/device_grouped_gemm_xdl_fixed_nk_bf16_i8_bf16_mk_nk_mn_common.hpp @@ -61,8 +61,6 @@ static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecial static constexpr auto GemmMNPadding = ck::tensor_operation::device::GemmSpecialization::MNPadding; static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding; -// NOTE: After adding unit tests for DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK it tuned out that -// portion of the instances are failing. As a workaround these have been commented out. template +auto concat_tuple_of_refs(ck::Tuple& tx, ck::Tuple& ty) +{ + return ck::unpack2( + [&](auto&&... zs) { return ck::Tuple{ck::forward(zs)...}; }, + tx, + ty); +} + template c_m_n({M, N}); + using AComputeType = typename std::conditional<(NumATensor > 1), EDataType, remove_cvref_t>>::type; + Tensor a_m_k({M, K}); + for(int m = 0; m < M; ++m) + { + for(int k = 0; k < K; ++k) + { + // result + auto data_refs1 = ck::tie(a_m_k(m, k)); + // inputs + auto data_refs2 = + generate_tie([&](auto i) -> auto& { return as_m_k(Number{})(m, k); }, + Number{}); + auto data_refs = concat_tuple_of_refs(data_refs1, data_refs2); + unpack(a_element_op, data_refs); + } + } + using BComputeType = typename std::conditional<(NumBTensor > 1), EDataType, remove_cvref_t>>::type; - using ReferenceGemmInstance = - ck::tensor_operation::host::ReferenceGemmMultiABD; + Tensor b_k_n({K, N}); + for(int k = 0; k < K; ++k) + { + for(int n = 0; n < N; ++n) + { + // result + auto data_refs1 = ck::tie(b_k_n(k, n)); + // inputs + auto data_refs2 = + generate_tie([&](auto i) -> auto& { return bs_k_n(Number{})(k, n); }, + Number{}); + auto data_refs = concat_tuple_of_refs(data_refs1, data_refs2); + unpack(b_element_op, data_refs); + } + } - auto ref_gemm = ReferenceGemmInstance{}; - auto ref_invoker = ref_gemm.MakeInvoker(); + using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm; + auto ref_gemm = ReferenceGemmInstance{}; + auto ref_invoker = ref_gemm.MakeInvoker(); - auto ref_argument = ref_gemm.MakeArgument( - as_m_k, bs_k_n, ds_m_n, e_m_n_host_result, a_element_op, b_element_op, cde_element_op); + auto ref_argument = + ref_gemm.MakeArgument(a_m_k, b_k_n, c_m_n, PassThrough{}, PassThrough{}, PassThrough{}); ref_invoker.Run(ref_argument); + + for(int m = 0; m < M; ++m) + { + for(int n = 0; n < N; ++n) + { + // compulsory + auto data_refs1 = ck::tie(e_m_n_host_result(m, n), c_m_n(m, n)); + // optional (if multiple Ds) + auto data_refs2 = + generate_tie([&](auto i) -> auto& { return ds_m_n(Number{})(m, n); }, + Number{}); + auto data_refs = concat_tuple_of_refs(data_refs1, data_refs2); + unpack(cde_element_op, data_refs); + } + } } std::array as_device_buf; diff --git a/profiler/include/profiler/profile_grouped_gemm_multi_abd_fixed_nk_impl.hpp b/profiler/include/profiler/profile_grouped_gemm_multi_abd_fixed_nk_impl.hpp deleted file mode 100644 index eea72b324d..0000000000 --- a/profiler/include/profiler/profile_grouped_gemm_multi_abd_fixed_nk_impl.hpp +++ /dev/null @@ -1,534 +0,0 @@ -// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. -// SPDX-License-Identifier: MIT - -#pragma once - -#include -#include -#include -#include - -#include "ck/ck.hpp" -#include "ck/utility/env.hpp" -#include "ck/utility/tuple.hpp" -#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" -#include "ck/tensor_operation/gpu/device/device_grouped_gemm_multi_abd_fixed_nk.hpp" -#include "ck/tensor_operation/gpu/device/device_grouped_gemm.hpp" -#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" - -#include "ck/library/reference_tensor_operation/cpu/reference_gemm_multi_abd.hpp" -#include "ck/library/tensor_operation_instance/gpu/grouped_gemm_multi_abd_fixed_nk.hpp" - -#include "ck/library/utility/check_err.hpp" -#include "ck/library/utility/convolution_parameter.hpp" -#include "ck/library/utility/device_memory.hpp" -#include "ck/library/utility/host_tensor.hpp" -#include "ck/library/utility/host_tensor_generator.hpp" -#include "ck/library/utility/literals.hpp" -#include "ck/library/utility/fill.hpp" - -namespace ck { -namespace profiler { - -template -auto reserveVector(std::size_t size) -{ - std::vector vec; - vec.reserve(size); - return vec; -} - -template -bool profile_grouped_gemm_multi_abd_fixed_nk_impl(int do_verification, - int init_method, - bool do_log, - bool time_kernel, - const std::vector& Ms, - const std::vector& Ns, - const std::vector& Ks, - const std::vector& StrideAs, - const std::vector& StrideBs, - const std::vector& StrideDs, - const std::vector& StrideE, - const std::vector& kbatch_list = {1}, - int n_warmup = 1, - int n_iter = 10) -{ - bool pass = true; - - auto f_host_tensor_descriptor = - [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { - using namespace ck::literals; - - if(is_same::value) - { - return HostTensorDescriptor({row, col}, {stride, 1_uz}); - } - else - { - return HostTensorDescriptor({row, col}, {1_uz, stride}); - } - }; - - const std::size_t group_count = Ms.size(); - const int sum_of_m = std::accumulate(Ms.begin(), Ms.end(), 0); - - static constexpr index_t NumATensor = AsDataType::Size(); - static constexpr index_t NumBTensor = BsDataType::Size(); - static constexpr index_t NumDTensor = DsDataType::Size(); - - if(group_count != Ns.size() || group_count != Ks.size() || group_count != StrideAs.size() || - group_count != StrideBs.size() || (NumDTensor > 0 && group_count != StrideDs.size())) - { - throw std::runtime_error("wrong! inconsistent M/N/Ks, StrideAs/Bs/Ds/E size\n"); - } - - auto generateInputTupleA = [&](std::size_t g) { - if constexpr(NumATensor == 0) - { - static_assert("Gemm problem should have at least 1 A tensor."); - } - else - { - using ALayout = remove_cvref_t{}, AsLayout>>; - return generate_tuple( - [&](auto i) { - using ADataType = remove_cvref_t>; - return Tensor( - f_host_tensor_descriptor(Ms[g], Ks[g], StrideAs[g], ALayout{})); - }, - Number{}); - } - }; - auto generateInputTupleB = [&](std::size_t g) { - if constexpr(NumBTensor == 0) - { - static_assert("Gemm problem should have at least 1 B tensor."); - } - else - { - using BLayout = remove_cvref_t{}, BsLayout>>; - return generate_tuple( - [&](auto i) { - using BDataType = remove_cvref_t>; - return Tensor( - f_host_tensor_descriptor(Ks[g], Ns[g], StrideBs[g], BLayout{})); - }, - Number{}); - } - }; - auto generateInputTupleD = [&](std::size_t g) { - if constexpr(NumDTensor == 0) - { - return ck::Tuple<>(); - } - else - { - using DLayout = remove_cvref_t{}, DsLayout>>; - return generate_tuple( - [&](auto i) { - using DDataType = remove_cvref_t>; - return Tensor( - f_host_tensor_descriptor(Ms[g], Ns[g], StrideDs[g], DLayout{})); - }, - Number{}); - } - }; - - using AsTensorTuple = decltype(generateInputTupleA(0)); - using BsTensorTuple = decltype(generateInputTupleB(0)); - using DsTensorTuple = decltype(generateInputTupleD(0)); - - auto g_as_m_k = reserveVector(group_count); - auto g_bs_k_n = reserveVector(group_count); - auto g_ds_m_n = reserveVector(group_count); - auto g_e_m_n_host_results = reserveVector>(group_count); - auto g_e_m_n_device_results = reserveVector>(group_count); - - for(std::size_t g = 0; g < group_count; g++) - { - auto& as_m_k = g_as_m_k.emplace_back(generateInputTupleA(g)); - auto& bs_k_n = g_bs_k_n.emplace_back(generateInputTupleB(g)); - auto& ds_m_n = g_ds_m_n.emplace_back(generateInputTupleD(g)); - - g_e_m_n_host_results.push_back( - Tensor(f_host_tensor_descriptor(Ms[g], Ns[g], StrideE[g], ELayout{}))); - g_e_m_n_device_results.push_back( - Tensor(f_host_tensor_descriptor(Ms[g], Ns[g], StrideE[g], ELayout{}))); - - if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) - { - std::cout << "group: " << g << std::endl; - static_for<0, NumATensor, 1>{}([&](auto i) { - std::cout << "a" << i.value << "_m_k: " << as_m_k(i).mDesc << std::endl; - }); - static_for<0, NumBTensor, 1>{}([&](auto i) { - std::cout << "b" << i.value << "_k_n: " << bs_k_n(i).mDesc << std::endl; - }); - static_for<0, NumDTensor, 1>{}([&](auto i) { - std::cout << "d" << i.value << "_m_n: " << ds_m_n(i).mDesc << std::endl; - }); - std::cout << "e_m_n: " << g_e_m_n_device_results[g].mDesc << std::endl; - } - - std::size_t num_thread = 1; - switch(init_method) - { - case 0: break; - case 1: - static_for<0, NumATensor, 1>{}([&](auto i) { - using ADataType = remove_cvref_t>; - as_m_k(i).GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); - }); - - static_for<0, NumBTensor, 1>{}([&](auto i) { - using BDataType = remove_cvref_t>; - bs_k_n(i).GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); - }); - - static_for<0, NumDTensor, 1>{}([&](auto i) { - using DDataType = remove_cvref_t>; - ds_m_n(i).GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); - }); - - break; - default: - static_for<0, NumATensor, 1>{}([&](auto i) { - using ADataType = remove_cvref_t>; - as_m_k(i).GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}, num_thread); - }); - - static_for<0, NumBTensor, 1>{}([&](auto i) { - using BDataType = remove_cvref_t>; - bs_k_n(i).GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}, num_thread); - }); - - static_for<0, NumDTensor, 1>{}([&](auto i) { - using DDataType = remove_cvref_t>; - ds_m_n(i).GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}, num_thread); - }); - } - } - - const auto a_element_op = AElementOp{}; - const auto b_element_op = BElementOp{}; - const auto cde_element_op = CDEElementOp{}; - - using DeviceMemPtr = std::unique_ptr; - std::vector> g_as_device_buf(group_count); - std::vector> g_bs_device_buf(group_count); - std::vector> g_ds_device_buf(group_count); - std::vector g_e_device_buf(group_count); - - std::vector> g_as_device_view(group_count); - std::vector> g_bs_device_view(group_count); - std::vector> g_ds_device_view(group_count); - std::vector g_e_device_view(group_count); - - auto g_gemm_descs = reserveVector(group_count); - - auto grouped_gemm_kernel_args_host = - reserveVector>( - group_count); - - for(std::size_t g = 0; g < group_count; g++) - { - std::array as_stride; - std::array bs_stride; - std::array ds_stride; - - auto& as_m_k = g_as_m_k[g]; - auto& as_device_buf = g_as_device_buf[g]; - auto& as_device_view = g_as_device_view[g]; - - static_for<0, NumATensor, 1>{}([&](auto i) { - using ADataType = remove_cvref_t>; - as_device_buf[i] = std::make_unique(sizeof(ADataType) * Ms[g] * Ks[g]); - as_device_buf[i]->ToDevice(as_m_k[i].mData.data()); - as_device_view[i] = as_device_buf[i]->GetDeviceBuffer(); - as_stride[i] = StrideAs[g]; - }); - - auto& bs_k_n = g_bs_k_n[g]; - auto& bs_device_buf = g_bs_device_buf[g]; - auto& bs_device_view = g_bs_device_view[g]; - - static_for<0, NumBTensor, 1>{}([&](auto i) { - using BDataType = remove_cvref_t>; - bs_device_buf[i] = std::make_unique(sizeof(BDataType) * Ks[g] * Ns[g]); - bs_device_buf[i]->ToDevice(bs_k_n[i].mData.data()); - bs_device_view[i] = bs_device_buf[i]->GetDeviceBuffer(); - bs_stride[i] = StrideBs[g]; - }); - - auto& ds_m_n = g_ds_m_n[g]; - auto& ds_device_buf = g_ds_device_buf[g]; - auto& ds_device_view = g_ds_device_view[g]; - - static_for<0, NumDTensor, 1>{}([&](auto i) { - using DDataType = remove_cvref_t>; - ds_device_buf[i] = std::make_unique(sizeof(DDataType) * Ms[g] * Ns[g]); - ds_device_buf[i]->ToDevice(ds_m_n[i].mData.data()); - ds_device_view[i] = ds_device_buf[i]->GetDeviceBuffer(); - ds_stride[i] = StrideDs[g]; - }); - - g_e_device_buf[g] = std::make_unique(sizeof(EDataType) * Ms[g] * Ns[g]); - g_e_device_view[g] = g_e_device_buf[g]->GetDeviceBuffer(); - - g_gemm_descs.push_back(tensor_operation::device::GemmMultiABDDesc{ - sum_of_m, - Ns[g], - Ks[g], - std::vector(as_stride.begin(), as_stride.end()), - std::vector(bs_stride.begin(), bs_stride.end()), - std::vector(ds_stride.begin(), ds_stride.end()), - StrideE[g]}); - - tensor_operation::device:: - GroupedGemmMultiABDKernelArgument - kernelArg{as_device_view, - bs_device_view, - ds_device_view, - g_e_device_view[g], - Ms[g], - Ns[g], - Ks[g], - as_stride, - bs_stride, - ds_stride, - StrideE[g]}; - - grouped_gemm_kernel_args_host.push_back(std::move(kernelArg)); - } - - using DeviceOp = tensor_operation::device::DeviceGroupedGemmMultiABDFixedNK; - - const auto op_ptrs = tensor_operation::device::instance::DeviceOperationInstanceFactory< - DeviceOp>::GetInstances(); - - if(op_ptrs.size() <= 0) - { - throw std::runtime_error("wrong! no device GEMM instance found"); - } - - std::string best_gemm_name; - float best_ave_time = 0; - float best_tflops = 0; - float best_gb_per_sec = 0; - float best_kbatch = 0; - - if(do_verification) - { - using AComputeType = - typename std::conditional<(NumATensor > 1), - EDataType, - remove_cvref_t>>::type; - - using BComputeType = - typename std::conditional<(NumBTensor > 1), - EDataType, - remove_cvref_t>>::type; - - using ReferenceGemmInstance = - ck::tensor_operation::host::ReferenceGemmMultiABD; - - auto ref_gemm = ReferenceGemmInstance{}; - auto ref_invoker = ref_gemm.MakeInvoker(); - - for(std::size_t i = 0; i < group_count; i++) - { - auto ref_argument = ref_gemm.MakeArgument(g_as_m_k[i], - g_bs_k_n[i], - g_ds_m_n[i], - g_e_m_n_host_results[i], - a_element_op, - b_element_op, - cde_element_op); - - ref_invoker.Run(ref_argument); - } - } - - // profile device GEMM instances - for(auto& gemm_ptr : op_ptrs) - { - auto argument_ptr = gemm_ptr->MakeArgumentPointer( - g_as_device_view, g_bs_device_view, g_ds_device_view, g_e_device_view, g_gemm_descs); - - if(!gemm_ptr->IsSupportedArgument(argument_ptr.get())) - { - if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) - { - std::cout << "Gemm incompatible with runtime set parameters. Skipping..." - << std::endl; - } - - continue; - } - - DeviceMem gemm_workspace_dev(gemm_ptr->GetWorkSpaceSize(argument_ptr.get())); - gemm_ptr->SetWorkSpacePointer(argument_ptr.get(), gemm_workspace_dev.GetDeviceBuffer()); - - DeviceMem grouped_gemm_kernel_args_dev( - gemm_ptr->GetDeviceKernelArgSize(argument_ptr.get())); - hipGetErrorString(hipMemcpy(grouped_gemm_kernel_args_dev.GetDeviceBuffer(), - grouped_gemm_kernel_args_host.data(), - gemm_ptr->GetDeviceKernelArgSize(argument_ptr.get()), - hipMemcpyHostToDevice)); - - gemm_ptr->SetDeviceKernelArgs(argument_ptr.get(), - grouped_gemm_kernel_args_dev.GetDeviceBuffer()); - gemm_ptr->SetElementwiseOps(argument_ptr.get(), a_element_op, b_element_op, cde_element_op); - - auto invoker_ptr = gemm_ptr->MakeInvokerPointer(); - - std::string gemm_name = gemm_ptr->GetTypeString(); - - for(const auto kbatch_curr : kbatch_list) - { - gemm_ptr->SetKBatch(argument_ptr.get(), kbatch_curr); - - if(gemm_ptr->IsSupportedArgument(argument_ptr.get())) - { - for(std::size_t g = 0; g < group_count; g++) - { - g_e_device_buf[g]->SetZero(); - } - - float ave_time = invoker_ptr->Run( - argument_ptr.get(), StreamConfig{nullptr, time_kernel, 0, n_warmup, n_iter}); - - if(do_verification) - { - bool instance_pass = true; - for(std::size_t g = 0; g < group_count; g++) - { - g_e_device_buf[g]->FromDevice( - g_e_m_n_device_results[g].mData.data(), - g_e_m_n_device_results[g].mDesc.GetElementSize() * sizeof(EDataType)); - - instance_pass = - instance_pass && ck::utils::check_err(g_e_m_n_device_results[g], - g_e_m_n_host_results[g]); - - if(do_log) - { - static_for<0, NumATensor, 1>{}([&](auto i) { - LogRangeAsType( - std::cout << "a[" << g << "]: ", g_as_m_k[g](i).mData, ",") - << std::endl; - }); - static_for<0, NumBTensor, 1>{}([&](auto i) { - LogRangeAsType( - std::cout << "b[" << g << "]: ", g_bs_k_n[g](i).mData, ",") - << std::endl; - }); - static_for<0, NumDTensor, 1>{}([&](auto i) { - LogRangeAsType( - std::cout << "d[" << g << "]: ", g_ds_m_n[g](i).mData, ",") - << std::endl; - }); - LogRangeAsType( - std::cout << "e_device: ", g_e_m_n_device_results[g].mData, ",") - << std::endl; - LogRangeAsType( - std::cout << "e_host : ", g_e_m_n_host_results[g].mData, ",") - << std::endl; - } - } - - std::cout << "Instance: " << gemm_name << " verification " - << (instance_pass ? "SUCCEED" : "FAILED") << std::endl; - - pass = pass && instance_pass; - } - - if(time_kernel) - { - std::size_t flop = 0, num_btype = 0; - for(std::size_t g = 0; g < group_count; g++) - { - flop += std::size_t(2) * Ms[g] * Ns[g] * Ks[g]; - - static_for<0, NumATensor, 1>{}([&](auto i) { - using ADataType = remove_cvref_t>; - num_btype += sizeof(ADataType) * Ms[g] * Ks[g]; - }); - static_for<0, NumBTensor, 1>{}([&](auto i) { - using BDataType = remove_cvref_t>; - num_btype += sizeof(BDataType) * Ks[g] * Ns[g]; - }); - static_for<0, NumDTensor, 1>{}([&](auto i) { - using DDataType = remove_cvref_t>; - num_btype += sizeof(DDataType) * Ms[g] * Ns[g]; - }); - } - - float tflops = static_cast(flop) / 1.E9 / ave_time; - - float gb_per_sec = num_btype / 1.E6 / ave_time; - std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << tflops - << " TFlops, " << gb_per_sec << " GB/s, " << gemm_name << ", KBatch " - << kbatch_curr << std::endl; - - if(tflops > best_tflops) - { - best_gemm_name = gemm_name; - best_tflops = tflops; - best_ave_time = ave_time; - best_gb_per_sec = gb_per_sec; - best_kbatch = kbatch_curr; - } - } - } - else - { - std::cout << "Instance: " << gemm_name << ", does not support this GEMM problem" - << std::endl; - } - } - } - - if(time_kernel) - { - std::cout << "Best Perf: " << best_ave_time << " ms, " << best_tflops << " TFlops, " - << best_gb_per_sec << " GB/s, " << best_gemm_name << ", KBatch = " << best_kbatch - << std::endl; - } - return pass; -} - -} // namespace profiler -} // namespace ck diff --git a/test/grouped_gemm/CMakeLists.txt b/test/grouped_gemm/CMakeLists.txt index bc79c85e59..450950cbd6 100644 --- a/test/grouped_gemm/CMakeLists.txt +++ b/test/grouped_gemm/CMakeLists.txt @@ -18,12 +18,6 @@ if (CK_USE_XDL OR CK_USE_WMMA) target_link_libraries(test_grouped_gemm_fastgelu PRIVATE utility device_grouped_gemm_fastgelu_instance) add_dependencies(test_grouped_gemm test_grouped_gemm_fastgelu) endif() - - add_gtest_executable(test_grouped_gemm_multi_abd_fixed_nk test_grouped_gemm_multi_abd_fixed_nk.cpp) - if(result EQUAL 0) - target_link_libraries(test_grouped_gemm_multi_abd_fixed_nk PRIVATE utility device_grouped_gemm_fixed_nk_multi_abd_instance) - add_dependencies(test_grouped_gemm test_grouped_gemm_multi_abd_fixed_nk) - endif() endif() add_gtest_executable(test_grouped_gemm_interface test_grouped_gemm_interface_xdl.cpp) diff --git a/test/grouped_gemm/test_grouped_gemm_multi_abd_fixed_nk.cpp b/test/grouped_gemm/test_grouped_gemm_multi_abd_fixed_nk.cpp deleted file mode 100644 index 610e7f2b77..0000000000 --- a/test/grouped_gemm/test_grouped_gemm_multi_abd_fixed_nk.cpp +++ /dev/null @@ -1,256 +0,0 @@ -// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. -// SPDX-License-Identifier: MIT - -#include -#include -#include - -#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" -#include "ck/utility/data_type.hpp" - -#include "ck/ck.hpp" -#include "ck/utility/type.hpp" -#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" -#include "profiler/profile_grouped_gemm_multi_abd_fixed_nk_impl.hpp" - -#include "gtest/gtest.h" - -static ck::index_t param_mask = 0xffffff; -static ck::index_t instance_index = -1; - -using FP32 = float; -using FP16 = ck::half_t; -using BF16 = ck::bhalf_t; -using I8 = int8_t; -using Row = ck::tensor_layout::gemm::RowMajor; -using Col = ck::tensor_layout::gemm::ColumnMajor; - -using PassThrough = ck::tensor_operation::element_wise::PassThrough; -using AddFastGelu = ck::tensor_operation::element_wise::AddFastGelu; -using Add = ck::tensor_operation::element_wise::Add; -using Multiply = ck::tensor_operation::element_wise::Multiply; -using FastGelu = ck::tensor_operation::element_wise::FastGelu; - -// clang-format off -using KernelTypes = ::testing::Types< - std::tuple, ck::Tuple, ck::Tuple, BF16, ck::Tuple, ck::Tuple, ck::Tuple, Row, AddFastGelu>, - std::tuple, ck::Tuple, ck::Tuple, BF16, ck::Tuple, ck::Tuple, ck::Tuple, Row, AddFastGelu>, - std::tuple, ck::Tuple, ck::Tuple, BF16, ck::Tuple, ck::Tuple, ck::Tuple, Row, AddFastGelu>, - std::tuple, ck::Tuple, ck::Tuple, BF16, ck::Tuple, ck::Tuple, ck::Tuple, Row, Add>, - std::tuple, ck::Tuple, ck::Tuple, BF16, ck::Tuple, ck::Tuple, ck::Tuple, Row, Add>, - std::tuple, ck::Tuple, ck::Tuple, BF16, ck::Tuple, ck::Tuple, ck::Tuple, Row, Add>, - std::tuple, ck::Tuple, ck::Tuple<>, BF16, ck::Tuple, ck::Tuple, ck::Tuple<>, Row, PassThrough>, - std::tuple, ck::Tuple, ck::Tuple<>, BF16, ck::Tuple, ck::Tuple, ck::Tuple<>, Row, PassThrough>, - std::tuple, ck::Tuple, ck::Tuple<>, BF16, ck::Tuple, ck::Tuple, ck::Tuple<>, Row, PassThrough>, - std::tuple, ck::Tuple, ck::Tuple<>, BF16, ck::Tuple, ck::Tuple, ck::Tuple<>, Row, FastGelu>, - std::tuple, ck::Tuple, ck::Tuple<>, BF16, ck::Tuple, ck::Tuple, ck::Tuple<>, Row, FastGelu>, - std::tuple, ck::Tuple, ck::Tuple<>, BF16, ck::Tuple, ck::Tuple, ck::Tuple<>, Row, FastGelu> ->; -// clang-format on - -template -class TestGroupedGemmMultiABDFixedNK : public testing::Test -{ - protected: - using AsDataType = std::tuple_element_t<0, Tuple>; - using BsDataType = std::tuple_element_t<1, Tuple>; - using DsDataType = std::tuple_element_t<2, Tuple>; - using EDataType = std::tuple_element_t<3, Tuple>; - using AccDataType = float; - using AsLayout = std::tuple_element_t<4, Tuple>; - using BsLayout = std::tuple_element_t<5, Tuple>; - using DsLayout = std::tuple_element_t<6, Tuple>; - using ELayout = std::tuple_element_t<7, Tuple>; - using AElementOp = PassThrough; - using BElementOp = Multiply; - using CDEElementOp = std::tuple_element_t<8, Tuple>; - - using Row = ck::tensor_layout::gemm::RowMajor; - using Col = ck::tensor_layout::gemm::ColumnMajor; - - public: - static constexpr bool verify_ = true; - static constexpr int init_method_ = 1; // integer value initialization - static constexpr bool log_ = false; - static constexpr bool bench_ = false; // measure kernel performance - static constexpr int n_warmup_ = 0; - static constexpr int n_iter_ = 1; - - std::vector k_batches_ = {1}; - - private: - template - void SetStrides(std::vector& strides, - const std::vector& rows, - const std::vector& cols) const - { - if(std::is_same_v) - { - for(const auto c : cols) - { - strides.emplace_back(c); - } - } - else if(std::is_same_v) - { - for(const auto r : rows) - { - strides.emplace_back(r); - } - } - } - - template - void SetTupleStrides(std::vector& strides, - const std::vector& rows, - const std::vector& cols) const - { - if constexpr(Layouts::Size() > 0) - { - // As of now multi ABD implementation supports only tensors with matching layouts. - using Layout = ck::remove_cvref_t{}, Layouts>>; - SetStrides(strides, rows, cols); - } - } - - public: - void Run(const std::vector& Ms, - const std::vector& Ns, - const std::vector& Ks, - const std::vector& StrideAs = {}, - const std::vector& StrideBs = {}, - const std::vector& StrideDs = {}, - const std::vector& StrideE = {}) - { - std::vector stride_as = StrideAs; - std::vector stride_bs = StrideBs; - std::vector stride_ds = StrideDs; - std::vector stride_e = StrideE; - - if(stride_as.empty()) - { - SetTupleStrides(stride_as, Ms, Ks); - } - if(stride_bs.empty()) - { - SetTupleStrides(stride_bs, Ks, Ns); - } - if(stride_ds.empty()) - { - SetTupleStrides(stride_ds, Ms, Ns); - } - if(stride_e.empty()) - { - SetStrides(stride_e, Ms, Ns); - } - - RunSingle(Ms, Ns, Ks, stride_as, stride_bs, stride_ds, stride_e); - } - - void RunSingle(const std::vector& Ms, - const std::vector& Ns, - const std::vector& Ks, - const std::vector& StrideAs, - const std::vector& StrideBs, - const std::vector& StrideDs, - const std::vector& StrideE) - { - bool pass = - ck::profiler::profile_grouped_gemm_multi_abd_fixed_nk_impl(verify_, - init_method_, - log_, - bench_, - Ms, - Ns, - Ks, - StrideAs, - StrideBs, - StrideDs, - StrideE, - k_batches_, - n_warmup_, - n_iter_); - EXPECT_TRUE(pass); - } -}; - -TYPED_TEST_SUITE(TestGroupedGemmMultiABDFixedNK, KernelTypes); - -TYPED_TEST(TestGroupedGemmMultiABDFixedNK, TinyCases) -{ - const std::vector Ms{3, 4}; - constexpr int N = 8; - constexpr int K = 64; - - const std::vector Ns(Ms.size(), N); - const std::vector Ks(Ms.size(), K); - - this->Run(Ms, Ns, Ks); -} - -TYPED_TEST(TestGroupedGemmMultiABDFixedNK, SmallCases) -{ - const std::vector Ms{3, 5, 16, 7, 8}; - constexpr int N = 768; - constexpr int K = 544; - - const std::vector Ns(Ms.size(), N); - const std::vector Ks(Ms.size(), K); - - this->Run(Ms, Ns, Ks); -} - -TYPED_TEST(TestGroupedGemmMultiABDFixedNK, MidCases) -{ - const std::vector Ms{167, 183, 177, 153, 139, 204}; - constexpr int N = 768; - constexpr int K = 544; - - const std::vector Ns(Ms.size(), N); - const std::vector Ks(Ms.size(), K); - - this->Run(Ms, Ns, Ks); -} - -TYPED_TEST(TestGroupedGemmMultiABDFixedNK, Regular) -{ - const std::vector Ms{64, 128, 256}; - constexpr int N = 768; - constexpr int K = 320; - - const std::vector Ns(Ms.size(), N); - const std::vector Ks(Ms.size(), K); - - this->Run(Ms, Ns, Ks); -} - -int main(int argc, char** argv) -{ - testing::InitGoogleTest(&argc, argv); - if(argc == 1) - { - // Run with default arguments. - } - else if(argc == 3) - { - param_mask = strtol(argv[1], nullptr, 0); - instance_index = atoi(argv[2]); - } - else - { - std::cout << "Usage of " << argv[0] << std::endl; - std::cout << "Arg1,2: param_mask instance_index(-1 means all)" << std::endl; - } - return RUN_ALL_TESTS(); -}