From 2b452ad13599cab90fd5b575a40f6cb4b796039d Mon Sep 17 00:00:00 2001
From: Adam Osewski <19374865+aosewski@users.noreply.github.com>
Date: Thu, 25 Apr 2024 22:12:53 +0200
Subject: [PATCH] Grouped GEMM Multiple D tile loop. (#1247)

* Overload output stream operator for LoopScheduler and PipelineVersion
* Add Run overload accepting grid descriptors MK.
* Add __device__ keyword for CalculateGridSize
* Create device op GroupedGemmMultipleD
* Add GroupedGemm MultipleD Tile Loop implementation.
* Add an example for GroupedGemm MultipleD tile loop.
* Device Op GroupedGEMMTileLoop.
* Bunch of small changes in example.
* CkProfiler
* Remove unused tparam.
* Fix include statement.
* Fix output stream overloads.
* Do not make descriptors and check validity until we find the group.
* Fix gemm desc initialization.
* Revert device op
* Fix compilation for DTYPES=FP16
* Validate tensor transfer parameters.
* Validate only NK dims on host if M is not known.
* Fix bug.
* Add a convenient debug function for selecting threads.
* Fix has_main_kblock_loop bug.
* Make sure that b2c has an up-to-date tile offset.
* Output stream operator for Sequence type.
* CMake file formatting.

[ROCm/composable_kernel commit: b4032629e57aa3e4aab01ed52a263886bb92acc1]
---
 example/15_grouped_gemm/CMakeLists.txt | 3 +
 .../grouped_gemm_multiple_d_xdl_fp16.cpp | 403 +++++++++
 .../device/device_grouped_gemm_tile_loop.hpp | 128 +++
 ...gemm_multiple_d_xdl_cshuffle_tile_loop.hpp | 787 ++++++++++++++++++
 .../gpu/grid/block_to_ctile_map.hpp | 13 +-
 .../gridwise_gemm_multiple_d_xdl_cshuffle.hpp | 126 ++-
 .../grid/gridwise_gemm_pipeline_selector.hpp | 16 +-
 include/ck/utility/debug.hpp | 9 +-
 include/ck/utility/loop_scheduler.hpp | 14 +-
 include/ck/utility/sequence.hpp | 15 +-
 .../gpu/grouped_gemm_tile_loop.hpp | 108 +++
 .../gpu/grouped_gemm_tile_loop/CMakeLists.txt | 9 +
 ...ile_loop_f16_f16_f16_mk_kn_mn_instance.cpp | 75 ++
 ...ile_loop_f16_f16_f16_mk_nk_mn_instance.cpp | 77 ++
 .../profile_grouped_gemm_fixed_nk_impl.hpp | 4 +-
 .../profile_grouped_gemm_tile_loop_impl.hpp | 319 +++++++
 profiler/src/CMakeLists.txt | 2 +
 .../src/profile_grouped_gemm_tile_loop.cpp | 152 ++++
 test/normalization_bwd_data/CMakeLists.txt | 13 +-
 .../CMakeLists.txt | 13 +-
 20 files changed, 2264 insertions(+), 22 deletions(-)
 create mode 100644 example/15_grouped_gemm/grouped_gemm_multiple_d_xdl_fp16.cpp
 create mode 100644 include/ck/tensor_operation/gpu/device/device_grouped_gemm_tile_loop.hpp
 create mode 100644 include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_xdl_cshuffle_tile_loop.hpp
 create mode 100644 library/include/ck/library/tensor_operation_instance/gpu/grouped_gemm_tile_loop.hpp
 create mode 100644 library/src/tensor_operation_instance/gpu/grouped_gemm_tile_loop/CMakeLists.txt
 create mode 100644 library/src/tensor_operation_instance/gpu/grouped_gemm_tile_loop/device_grouped_gemm_xdl_tile_loop_f16_f16_f16_mk_kn_mn_instance.cpp
 create mode 100644 library/src/tensor_operation_instance/gpu/grouped_gemm_tile_loop/device_grouped_gemm_xdl_tile_loop_f16_f16_f16_mk_nk_mn_instance.cpp
 create mode 100644 profiler/include/profiler/profile_grouped_gemm_tile_loop_impl.hpp
 create mode 100644 profiler/src/profile_grouped_gemm_tile_loop.cpp

diff --git a/example/15_grouped_gemm/CMakeLists.txt b/example/15_grouped_gemm/CMakeLists.txt
index 550dafb066..20cbc5fdca 100644
--- a/example/15_grouped_gemm/CMakeLists.txt
+++ b/example/15_grouped_gemm/CMakeLists.txt
@@ -26,6 +26,9 @@ add_example_dependencies(example_grouped_gemm_xdl example_grouped_gemm_xdl_int8)
add_example_executable(example_grouped_gemm_xdl_fixed_nk_fp16_fp8 grouped_gemm_xdl_fixed_nk_fp16_fp8.cpp) add_example_dependencies(example_grouped_gemm_xdl example_grouped_gemm_xdl_fixed_nk_fp16_fp8) +add_example_executable(example_grouped_gemm_multiple_d_xdl_fp16 grouped_gemm_multiple_d_xdl_fp16.cpp) +add_example_dependencies(example_grouped_gemm_xdl example_grouped_gemm_multiple_d_xdl_fp16) + if(USE_BITINT_EXTENSION_INT4) add_example_executable(example_grouped_gemm_xdl_int4 grouped_gemm_xdl_int4.cpp) add_example_dependencies(example_grouped_gemm_xdl example_grouped_gemm_xdl_int4) diff --git a/example/15_grouped_gemm/grouped_gemm_multiple_d_xdl_fp16.cpp b/example/15_grouped_gemm/grouped_gemm_multiple_d_xdl_fp16.cpp new file mode 100644 index 0000000000..d80c163e3f --- /dev/null +++ b/example/15_grouped_gemm/grouped_gemm_multiple_d_xdl_fp16.cpp @@ -0,0 +1,403 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_xdl_cshuffle_tile_loop.hpp" +#include "ck/tensor_operation/gpu/device/device_grouped_gemm_tile_loop.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include +#include + +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/literals.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_gemm_multiple_d.hpp" + +template +using S = ck::Sequence; + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using AddAdd = ck::tensor_operation::element_wise::AddAdd; + +using ADataType = F16; +using BDataType = F16; +using AccDataType = F32; +using CShuffleDataType = F32; +using DDataType = F16; +using DsDataType = ck::Tuple; +using EDataType = F16; + +using ALayout = Row; +using BLayout = Col; +using DLayout = Row; +using DsLayout = ck::Tuple; +using ELayout = Row; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CDEElementOp = AddAdd; + +static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding; +static constexpr int NumDs = 2; + +using DeviceGemmInstance = + ck::tensor_operation::device::DeviceGroupedGemmMultipleDXdlCShuffleTileLoop + // clang-format off +//######| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| +//######| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| 
SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| +//######| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| +//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + < ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 4>; +// clang-format on + +struct ProblemSize final +{ + std::vector Ms; + std::vector Ns; + std::vector Ks; + + std::vector stride_As; + std::vector stride_Bs; + std::vector> stride_Ds; + std::vector stride_Cs; + + ck::index_t group_count; +}; + +struct ExecutionConfig final +{ + bool do_verification = true; + int init_method = 1; + bool time_kernel = true; +}; + +bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& config) +{ + auto group_count = problem_size.group_count; + + using KernelArguments = ck::tensor_operation::device::GroupedGemmTileLoopKernelArguments; + + // GEMM shape + std::vector gemm_descs; + std::vector ggemm_kargs; + std::vector p_Cs; + std::vector p_As; + std::vector p_Bs; + std::vector> p_Ds = {}; + + gemm_descs.reserve(group_count); + ggemm_kargs.reserve(group_count); + p_As.reserve(group_count); + p_Bs.reserve(group_count); + p_Ds.reserve(group_count); + + auto f_host_tensor_descriptor = + [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { + using namespace ck::literals; + + if(std::is_same::value) + { + return HostTensorDescriptor({row, col}, {stride, 1_uz}); + } + else + { + return HostTensorDescriptor({row, col}, {1_uz, stride}); + } + }; + + std::vector> a_tensors; + std::vector> b_tensors; + std::vector, NumDs>> d_tensors; + std::vector> c_host_tensors; + std::vector> c_device_result_tensors; + + a_tensors.reserve(group_count); + b_tensors.reserve(group_count); + d_tensors.reserve(group_count); + c_host_tensors.reserve(group_count); + c_device_result_tensors.reserve(group_count); + + using DeviceMemPtr = std::unique_ptr; + + std::vector a_tensors_device, b_tensors_device, c_tensors_device; + std::vector> d_tensors_device; + + a_tensors_device.reserve(group_count); + b_tensors_device.reserve(group_count); + d_tensors_device.reserve(group_count); + c_tensors_device.reserve(group_count); + + std::size_t flop = 0, num_btype = 0; + + for(int i = 0; i < group_count; i++) + { + a_tensors.push_back(Tensor(f_host_tensor_descriptor( + problem_size.Ms[i], problem_size.Ks[i], problem_size.stride_As[i], ALayout{}))); + b_tensors.push_back(Tensor(f_host_tensor_descriptor( + problem_size.Ks[i], problem_size.Ns[i], problem_size.stride_Bs[i], BLayout{}))); + + auto d0_tensor = Tensor(f_host_tensor_descriptor( + problem_size.Ms[i], problem_size.Ns[i], problem_size.stride_Cs[i], DLayout{})); + auto d1_tensor = Tensor(f_host_tensor_descriptor( + problem_size.Ms[i], problem_size.Ns[i], problem_size.stride_Cs[i], DLayout{})); + + std::array, NumDs> d_tens = {d0_tensor, d1_tensor}; + d_tensors.push_back(d_tens); + c_host_tensors.push_back(Tensor(f_host_tensor_descriptor( + problem_size.Ms[i], problem_size.Ns[i], 
problem_size.stride_Cs[i], ELayout{}))); + c_device_result_tensors.push_back(Tensor(f_host_tensor_descriptor( + problem_size.Ms[i], problem_size.Ns[i], problem_size.stride_Cs[i], ELayout{}))); + std::cout << "gemm[" << i << "] a_m_k: " << a_tensors[i].mDesc + << " b_k_n: " << b_tensors[i].mDesc + << " c_m_n: " << c_device_result_tensors[i].mDesc << std::endl; + + flop += std::size_t(2) * problem_size.Ms[i] * problem_size.Ks[i] * problem_size.Ns[i]; + num_btype += sizeof(ADataType) * a_tensors[i].GetElementSize() + + sizeof(BDataType) * b_tensors[i].GetElementSize() + + sizeof(DDataType) * d_tensors[i][0].GetElementSize() * NumDs + + sizeof(EDataType) * c_device_result_tensors[i].GetElementSize(); + + switch(config.init_method) + { + case 0: break; + case 1: + a_tensors[i].GenerateTensorValue(GeneratorTensor_2{-5, 5}); + b_tensors[i].GenerateTensorValue(GeneratorTensor_2{-5, 5}); + for(int j = 0; j < NumDs; ++j) + { + d_tensors[i][j].GenerateTensorValue(GeneratorTensor_2{-5, 5}); + } + break; + case 2: + a_tensors[i].GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b_tensors[i].GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + for(int j = 0; j < NumDs; ++j) + { + d_tensors[i][j].GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + } + break; + default: + a_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<0>{}); + b_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<1>{}); + for(int j = 0; j < NumDs; ++j) + { + d_tensors[i][j].GenerateTensorValue(GeneratorTensor_Sequential<0>{}); + } + } + } + + for(int i = 0; i < group_count; i++) + { + a_tensors_device.emplace_back( + std::make_unique(a_tensors[i].GetElementSpaceSize() * sizeof(ADataType))); + b_tensors_device.emplace_back( + std::make_unique(b_tensors[i].GetElementSpaceSize() * sizeof(BDataType))); + c_tensors_device.emplace_back(std::make_unique( + c_device_result_tensors[i].GetElementSpaceSize() * sizeof(EDataType))); + + for(int j = 0; j < NumDs; ++j) + { + d_tensors_device[i].emplace_back(std::make_unique( + d_tensors[i][j].GetElementSpaceSize() * sizeof(DDataType))); + } + + a_tensors_device[i]->ToDevice(a_tensors[i].mData.data()); + b_tensors_device[i]->ToDevice(b_tensors[i].mData.data()); + for(int j = 0; j < NumDs; ++j) + { + d_tensors_device[i][j]->ToDevice(d_tensors[i][j].mData.data()); + } + c_tensors_device[i]->SetZero(); + + p_As.push_back(a_tensors_device[i]->GetDeviceBuffer()); + p_Bs.push_back(b_tensors_device[i]->GetDeviceBuffer()); + p_Ds.push_back( + {d_tensors_device[i][0]->GetDeviceBuffer(), d_tensors_device[i][1]->GetDeviceBuffer()}); + p_Cs.push_back(c_tensors_device[i]->GetDeviceBuffer()); + + // The device op does not have to know M problem size at lunch time. 
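+        // That is, at launch time: gemm_descs below is consumed only on the host
+        // (e.g. by IsSupportedArgument), so M is deliberately set to 0 here. The
+        // actual per-group M travels in ggemm_kargs, which is copied into the
+        // device kernel-argument buffer and read by the kernel itself at run time.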
+ gemm_descs.push_back({0, + problem_size.Ns[i], + problem_size.Ks[i], + problem_size.stride_As[i], + problem_size.stride_Bs[i], + problem_size.stride_Cs[i], + {problem_size.stride_Cs[i], problem_size.stride_Cs[i]}}); + ggemm_kargs.push_back( + {a_tensors_device[i]->GetDeviceBuffer(), + b_tensors_device[i]->GetDeviceBuffer(), + {d_tensors_device[i][0]->GetDeviceBuffer(), d_tensors_device[i][1]->GetDeviceBuffer()}, + c_tensors_device[i]->GetDeviceBuffer(), + problem_size.Ms[i], + problem_size.Ns[i], + problem_size.Ks[i], + problem_size.stride_As[i], + problem_size.stride_Bs[i], + {problem_size.stride_Cs[i], problem_size.stride_Cs[i]}, + problem_size.stride_Cs[i]}); + } + auto a_element_op = AElementOp{}; + auto b_element_op = BElementOp{}; + auto cde_element_op = CDEElementOp{}; + + auto gemm = DeviceGemmInstance{}; + auto invoker = gemm.MakeInvoker(); + + // do GEMM + auto argument = gemm.MakeArgument( + p_As, p_Bs, p_Ds, p_Cs, gemm_descs, a_element_op, b_element_op, cde_element_op); + if(!gemm.IsSupportedArgument(argument)) + { + throw std::runtime_error( + "wrong! device_gemm with the specified compilation parameters does " + "not support this GEMM problem"); + } + + DeviceMem gemm_arg_dev_mem(gemm.GetDeviceKernelArgSize(&argument)); + hip_check_error(hipMemcpy(gemm_arg_dev_mem.GetDeviceBuffer(), + ggemm_kargs.data(), + gemm.GetDeviceKernelArgSize(&argument), + hipMemcpyHostToDevice)); + gemm.SetDeviceKernelArgs(argument, gemm_arg_dev_mem.GetDeviceBuffer()); + + invoker.Run(argument, StreamConfig{nullptr, false, 1}); + + bool pass = true; + if(config.do_verification) + { + using ReferenceGemmInstance = + ck::tensor_operation::host::ReferenceGemmMultipleD; + + for(std::size_t i = 0; i < gemm_descs.size(); i++) + { + auto karg = ggemm_kargs[i]; + auto dev_res_tensor = + Tensor(f_host_tensor_descriptor(karg.M, karg.N, karg.StrideE, ELayout{})); + c_tensors_device[i]->FromDevice(c_device_result_tensors[i].mData.data()); + auto ref_gemm = ReferenceGemmInstance{}; + auto ref_invoker = ref_gemm.MakeInvoker(); + + auto ref_argument = ref_gemm.MakeArgument(a_tensors[i], + b_tensors[i], + d_tensors[i], + c_host_tensors[i], + a_element_op, + b_element_op, + cde_element_op); + + ref_invoker.Run(ref_argument); + pass &= ck::utils::check_err(c_device_result_tensors[i], c_host_tensors[i]); + } + + std::cout << "Verification: " << (pass ? "SUCCESS" : "FAILURE") << "!" 
<< std::endl; + } + + if(config.time_kernel) + { + float ave_time = invoker.Run(argument, StreamConfig{nullptr, config.time_kernel}); + float tflops = static_cast(flop) / 1.E9 / ave_time; + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec + << " GB/s, " << gemm.GetTypeString() << std::endl; + } + + return pass; +} + +std::vector argToIntArray(char* input) +{ + std::vector out; + std::istringstream in(input); + std::string item; + + while(std::getline(in, item, ',')) + { + out.push_back(std::stoi(item)); + } + return out; +} + +int main(int argc, char* argv[]) +{ + ProblemSize problem_size; + ExecutionConfig config; + + if(argc < 10) + { + std::vector Ms{64, 127, 255, 129, 260, 190, 77}; + problem_size.group_count = Ms.size(); + + for(int i = 0; i < problem_size.group_count; i++) + { + problem_size.Ms.push_back(Ms[i]); + problem_size.Ns.push_back(252); + problem_size.Ks.push_back(4608); + + problem_size.stride_As.push_back(problem_size.Ks[i]); + problem_size.stride_Bs.push_back(problem_size.Ks[i]); + problem_size.stride_Cs.push_back(problem_size.Ns[i]); + + problem_size.stride_Ds.push_back({}); + for(int j = 0; j < NumDs; ++j) + { + problem_size.stride_Ds[i].push_back(problem_size.Ns[i]); + } + } + + std::cout + << "Usage:\n" + << "arg1: verification (0=no, 1=yes)\n" + << "arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n" + << "arg3: time kernel (0=n0, 1=yes)\n" + << "arg4 to 9: Ms, Ns, Ks, StrideAs, StrideBs, StrideCs (e.g., 256,256 128,128 64,64 " + "64,64 64,64 128,128)\n" + << "... setting default values." << std::endl; + } + else + { + config.do_verification = std::stoi(argv[1]); + config.init_method = std::stoi(argv[2]); + config.time_kernel = std::stoi(argv[3]); + + problem_size.Ms = argToIntArray(argv[4]); + problem_size.Ns = argToIntArray(argv[5]); + problem_size.Ks = argToIntArray(argv[6]); + + problem_size.stride_As = argToIntArray(argv[7]); + problem_size.stride_Bs = argToIntArray(argv[8]); + problem_size.stride_Cs = argToIntArray(argv[9]); + + for(int j = 0; j < NumDs; ++j) + { + problem_size.stride_Ds.push_back(problem_size.stride_Cs); + } + + problem_size.group_count = problem_size.Ms.size(); + } + + return !run_grouped_gemm(problem_size, config); +} diff --git a/include/ck/tensor_operation/gpu/device/device_grouped_gemm_tile_loop.hpp b/include/ck/tensor_operation/gpu/device/device_grouped_gemm_tile_loop.hpp new file mode 100644 index 0000000000..c1030f31cc --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/device_grouped_gemm_tile_loop.hpp @@ -0,0 +1,128 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include +#include +#include + +#include "device_grouped_gemm.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +/// +/// @brief Structure representing single GEMM problem arguments. +/// +/// The pointer to the vector of those structures is passed to the GroupedGEMM entry +/// point kernel. +/// +/// @tparam NumDTensor The number of D input tensors. 
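+/// @note A minimal host-side sketch (assuming two D tensors; pointer and stride
+/// names are illustrative): one such structure is filled per group, and the
+/// packed array is copied into the device buffer later passed to
+/// SetDeviceKernelArgs.
+/// @code
+/// std::vector<GroupedGemmTileLoopKernelArguments<2>> kargs;
+/// kargs.push_back({p_a, p_b, {p_d0, p_d1}, p_e,
+///                  M, N, K, StrideA, StrideB, {StrideD0, StrideD1}, StrideE});
+/// DeviceMem dev_kargs(kargs.size() * sizeof(kargs[0]));
+/// hip_check_error(hipMemcpy(dev_kargs.GetDeviceBuffer(), kargs.data(),
+///                           kargs.size() * sizeof(kargs[0]),
+///                           hipMemcpyHostToDevice));
+/// @endcode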
+/// +template +struct GroupedGemmTileLoopKernelArguments +{ + __host__ __device__ + GroupedGemmTileLoopKernelArguments(const void* p_a_grid_, + const void* p_b_grid_, + std::array p_ds_grid_, + void* p_e_grid_, + index_t M_, + index_t N_, + index_t K_, + index_t StrideA_, + index_t StrideB_, + std::array StrideDs_, + index_t StrideE_) + : p_a_grid{p_a_grid_}, + p_b_grid{p_b_grid_}, + p_ds_grid{p_ds_grid_}, + p_e_grid{p_e_grid_}, + M{M_}, + N{N_}, + K{K_}, + StrideA{StrideA_}, + StrideB{StrideB_}, + StrideDs{StrideDs_}, + StrideE{StrideE_} + { + } + + const void* p_a_grid; + const void* p_b_grid; + std::array p_ds_grid; + void* p_e_grid; + index_t M; + index_t N; + index_t K; + index_t StrideA; + index_t StrideB; + std::array StrideDs; + index_t StrideE; + + void Print() const + { + std::stringstream str; + for(auto sd : StrideDs) + str << sd << ","; + + std::cout << "arg {" + << "M:" << M << ", " + << "N:" << N << ", " + << "K:" << K << ", " + << "SA:" << StrideA << ", " + << "SB:" << StrideB << ", " + << "SE:" << StrideE << ", " + << "SDs: {" << str.str() << "}" + << "}" << std::endl; + } +}; + +template +struct DeviceGroupedGemmTileLoop : public DeviceGroupedGemm +{ + //---------------------------------------------------------------------------------------------- + /// @brief Sets the device kernel arguments pointer. + /// + /// @param p_arg The pointer to the Argument we're going to update. + /// @param[in] p_dev_kernel_args The pointer to the device memory which contains kernel + /// arguments. + /// + virtual void SetDeviceKernelArgs(BaseArgument* p_arg, void* p_dev_kernel_args) const = 0; + + //---------------------------------------------------------------------------------------------- + /// @brief Gets the device kernel argument size. + /// + /// @param[in] p_arg The pointer to the Device op Argument. + /// + /// @return The device kernel argument size. + /// + virtual size_t GetDeviceKernelArgSize(const BaseArgument* p_arg) const = 0; +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_xdl_cshuffle_tile_loop.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_xdl_cshuffle_tile_loop.hpp new file mode 100644 index 0000000000..0a0e8072bf --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_xdl_cshuffle_tile_loop.hpp @@ -0,0 +1,787 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include +#include + +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" +#include "ck/host_utility/hip_check_error.hpp" +#include "ck/host_utility/stream_utility.hpp" +#include "ck/utility/common_header.hpp" +#include "ck/utility/loop_scheduler.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_grouped_gemm_tile_loop.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +/// +/// @brief Entry point kernel for device-wide Grouped GEMM operation. 
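+///
+/// This kernel is persistent: the invoker launches a grid sized to device
+/// occupancy rather than to the total output tile count, and each workgroup
+/// strides over output tiles with tile_id += get_grid_size(). For every tile
+/// it first locates the GEMM group that owns the tile, (re)builds that group's
+/// grid descriptors, and then runs the gridwise GEMM on the tile.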
+/// +/// @param[in] gemm_descs_const The pointer to the array of GEMM descriptor structures. +/// @param[in] group_count The number of together processed GEMMs. +/// +/// @tparam GridwiseGemm The specific GridwiseGEMM algorithm implementation. +/// @tparam GemmDesc The structure holding all necessary descriptors and +/// other data needed for grouped gemm calculation and work +/// distribution. +/// @tparam LocalBlock2ETileMap The structure providing mapping between workgroup ids, +/// the data tiles to process and the output tiles. +/// +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +#endif + kernel_grouped_gemm_multiple_d_xdl(const void CK_CONSTANT_ADDRESS_SPACE* gemm_descs_const, + const index_t group_count, + const AElementwiseOperation a_element_op, + const BElementwiseOperation b_element_op, + const CDEElementwiseOperation cde_element_op) +{ +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ + defined(__gfx94__)) + + constexpr index_t shared_size = GridwiseGemm::GetSharedMemoryNumberOfByte(); + __shared__ uint8_t p_shared[shared_size]; + + const auto gemm_desc_ptr = + reinterpret_cast(cast_pointer_to_generic_address_space(gemm_descs_const)); + + constexpr auto NumDTensor = DsDataType::Size(); + index_t tile_id = get_block_1d_id(); + index_t tile_offset = 0; + index_t group_id = -1; + index_t group_offset = 0; + index_t grid_size_grp = 0; + + index_t gemm_tile_id_start = 0; + index_t gemm_tile_id_end = 0; + + using AGridDescMK = + remove_cvref_t( + 1, 1, 1))>; + using BGridDescNK = + remove_cvref_t( + 1, 1, 1))>; + using EGridDescMN = + remove_cvref_t( + 1, 1, 1))>; + using DsGridDescMN = + remove_cvref_t( + {}, {}, {}))>; + + index_t M = 0, N = 0, K = 0; + index_t StrideA, StrideB, StrideE; + std::array StrideDs; + + AGridDescMK a_grid_desc_mk; + BGridDescNK b_grid_desc_nk; + EGridDescMN e_grid_desc_mn; + DsGridDescMN ds_grid_desc_mn; + auto b2c_tile_map = OffsettedBlockToCTileMap(LocalBlock2ETileMap(1, 1), 1, 1); + + do + { + // Find corresponding GEMM group for our tile + while(!(tile_id >= gemm_tile_id_start && tile_id < gemm_tile_id_end) && + group_id < group_count) + { + group_offset += grid_size_grp; + group_id++; + + if(group_id >= group_count) + return; + + M = gemm_desc_ptr[group_id].M; + N = gemm_desc_ptr[group_id].N; + K = gemm_desc_ptr[group_id].K; + + if(M * N * K == 0) + { + grid_size_grp = 0; + continue; + } + + b2c_tile_map = + OffsettedBlockToCTileMap(LocalBlock2ETileMap(M, N), group_offset, tile_offset); + grid_size_grp = b2c_tile_map.CalculateGridSize(M, N); + + gemm_tile_id_start = group_offset; + gemm_tile_id_end = group_offset + grid_size_grp; + } + + StrideA = gemm_desc_ptr[group_id].StrideA; + StrideB = gemm_desc_ptr[group_id].StrideB; + StrideDs = gemm_desc_ptr[group_id].StrideDs; + StrideE = gemm_desc_ptr[group_id].StrideE; + + a_grid_desc_mk = + GridwiseGemm::template MakeAGridDescriptor_M_K(M, K, StrideA); + b_grid_desc_nk = + GridwiseGemm::template MakeBGridDescriptor_N_K(K, N, StrideB); + e_grid_desc_mn = + GridwiseGemm::template MakeEGridDescriptor_M_N(M, N, StrideE); + + static_for<0, NumDTensor, 1>{}([&](auto j) { + using DLayout = remove_cvref_t>; + ds_grid_desc_mn(j) = GridwiseGemm::template MakeEGridDescriptor_M_N( + M, N, StrideDs[j]); + }); + + using DsGridPointer = decltype(GridwiseGemm::MakeDsGridPointer()); + DsGridPointer p_ds_grid; + + static_for<0, NumDTensor, 1>{}([&](auto i) { + using DDataType = remove_cvref_t>; + p_ds_grid(i) = 
static_cast(gemm_desc_ptr[group_id].p_ds_grid[i]); + }); + + bool has_main_kblock_loop = + GridwiseGemm::CalculateHasMainKBlockLoop(a_grid_desc_mk.GetLength(Number<1>{})); + // Update tile offset if we have moved within group + b2c_tile_map.UpdateTileOffset(tile_offset); + + if(has_main_kblock_loop) + { + GridwiseGemm::template Run(gemm_desc_ptr[group_id].p_a_grid, + gemm_desc_ptr[group_id].p_b_grid, + p_ds_grid, + gemm_desc_ptr[group_id].p_e_grid, + static_cast(p_shared), + a_element_op, + b_element_op, + cde_element_op, + a_grid_desc_mk, + b_grid_desc_nk, + ds_grid_desc_mn, + e_grid_desc_mn, + b2c_tile_map); + } + else + { + GridwiseGemm::template Run(gemm_desc_ptr[group_id].p_a_grid, + gemm_desc_ptr[group_id].p_b_grid, + p_ds_grid, + gemm_desc_ptr[group_id].p_e_grid, + static_cast(p_shared), + a_element_op, + b_element_op, + cde_element_op, + a_grid_desc_mk, + b_grid_desc_nk, + ds_grid_desc_mn, + e_grid_desc_mn, + b2c_tile_map); + } + + tile_id += get_grid_size(); + tile_offset += get_grid_size(); + + } while(group_id < group_count); +#else + ignore = gemm_descs_const; + ignore = group_count; + ignore = a_element_op; + ignore = b_element_op; + ignore = cde_element_op; +#endif // end of if (defined(__gfx908__) || defined(__gfx90a__)) +} + +template +struct DeviceGroupedGemmMultipleDXdlCShuffleTileLoop + : public DeviceGroupedGemmTileLoop +{ + using DeviceOp = DeviceGroupedGemmMultipleDXdlCShuffleTileLoop; + static constexpr index_t NumDTensor = DsDataType::Size(); + + using GridwiseGemm = GridwiseGemmMultipleD_xdl_cshuffle< + ADataType, + BDataType, + ComputeDataType, + AccDataType, + CShuffleDataType, + DsDataType, + EDataType, + AElementwiseOperation, + BElementwiseOperation, + CDEElementwiseOperation, + InMemoryDataOperationEnum::Set, + NumGemmKPrefetchStage, + BlockSize, + MPerBlock, + NPerBlock, + KPerBlock, + AK1, + BK1, + MPerXDL, + NPerXDL, + MXdlPerWave, + NXdlPerWave, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + false, // AThreadTransferSrcResetCoordinateAfterRun, + ABlockLdsExtraM, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + false, // BThreadTransferSrcResetCoordinateAfterRun, + BBlockLdsExtraN, + CShuffleMXdlPerWavePerShuffle, + CShuffleNXdlPerWavePerShuffle, + CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + CDEShuffleBlockTransferScalarPerVector_NPerBlock, + LoopSched, + PipelineVer>; + + template + struct OffsettedBlockToCTileMap + { + using underlying_type = UnderlyingBlockToCTileMap; + + __host__ __device__ OffsettedBlockToCTileMap(UnderlyingBlockToCTileMap block_to_ctile_map, + index_t group_offset, + index_t tile_offset) + : block_to_ctile_map_{block_to_ctile_map}, + group_offset_{group_offset}, + tile_offset_{tile_offset} + { + } + + template + __host__ __device__ constexpr auto CalculateBottomIndex(const TopIdx& idx_top) const + { + return block_to_ctile_map_.CalculateBottomIndex( + make_multi_index(idx_top[Number<0>{}] + tile_offset_ - group_offset_)); + } + + template + __host__ __device__ bool ValidCTileIndex(const CTileIdx& c_tile_idx, + const CTileDim& c_tile_dim) const + { + return block_to_ctile_map_.ValidCTileIndex(c_tile_idx, c_tile_dim); + } + + template + __host__ 
constexpr bool CheckValidity(const CGridDesc_M_N& c_grid_desc_m_n) const + { + return block_to_ctile_map_.CheckValidity(c_grid_desc_m_n); + } + + __host__ __device__ constexpr index_t CalculateGridSize(index_t M, index_t N) const + { + return block_to_ctile_map_.CalculateGridSize(M, N); + } + + __device__ void UpdateTileOffset(index_t offset) { tile_offset_ = offset; } + UnderlyingBlockToCTileMap block_to_ctile_map_; + index_t group_offset_; + index_t tile_offset_; + }; + + using KernelArguments = GroupedGemmTileLoopKernelArguments; + using Block2ETileMap = BlockToCTileMap_N00_M0_N01Adapt; + using OffsetedLocalBlock2ETileMap = OffsettedBlockToCTileMap; + + // Argument + struct Argument : public BaseArgument + { + Argument(std::vector& /* p_As */, + std::vector& /* p_Bs */, + std::vector>& /* p_Ds */, + std::vector& /* p_Es */, + std::vector& gemm_descs, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation cde_element_op, + int occupancy_num_blocks, + int gpu_cu_count) + : group_count_{static_cast(gemm_descs.size())}, + occupancy_num_blocks_{occupancy_num_blocks}, + gpu_cu_count_{gpu_cu_count}, + gemm_descs_{gemm_descs}, + a_element_op_{a_element_op}, + b_element_op_{b_element_op}, + cde_element_op_{cde_element_op}, + tile_count_{0} + { + for(const auto& desc : gemm_descs) + { + const auto M = desc.M_; + const auto N = desc.N_; + const auto b2c_tile_map = Block2ETileMap(M, N); + tile_count_ += b2c_tile_map.CalculateGridSize(M, N); + } + } + + index_t group_count_; + const void* p_dev_gemm_args_; + int occupancy_num_blocks_; + int gpu_cu_count_; + + const std::vector& gemm_descs_; + AElementwiseOperation a_element_op_; + BElementwiseOperation b_element_op_; + CDEElementwiseOperation cde_element_op_; + index_t tile_count_; + }; + + struct KernelConfig + { + // The oversubscription factor for the number of blocks that can simultaneously reside on + // GPU. + static constexpr int BLOCK_SUBSCRIPTION_FACTOR = 1; + static constexpr int BLOCK_WAVES = BlockSize / get_warp_size(); + static constexpr int CU_SIMDS = 4; + // Assume we want to have at most 2 waves per SIMD + static constexpr int CU_BLOCKS = math::integer_divide_floor(2 * CU_SIMDS, BLOCK_WAVES); + }; + + // Invoker + struct Invoker : public BaseInvoker + { + /// + /// @brief Launch Grouped Gemm kernel. + /// + /// @note This function overload is using user provided device buffer for kernel + /// arguments. + /// + /// @param[in] arg The structure containing kernel arguments (in host + /// memory). + /// @param[in] dev_gemm_args The pointer to device memory with kernel arguments. + /// @param[in] stream_config The device stream configuration. + /// + /// @return The average kernel execution time (if time measurement is enabled.) + /// + float Run(const Argument& arg, + const void* dev_gemm_args, + const StreamConfig& stream_config = StreamConfig{}) + { + if(dev_gemm_args == nullptr) + { + std::ostringstream err; + err << "The gemm arguments device buffer is not allocated!" + << " In " << __FILE__ << ":" << __LINE__ << ", in function: " << __func__; + throw std::runtime_error(err.str()); + } + + float ave_time = 0; + ave_time = DispatchKernel(arg, dev_gemm_args, stream_config); + + return ave_time; + } + + /// + /// @brief Launch Grouped Gemm kernel. + /// + /// @note This function overload is using device buffers (for kernel arguments and + /// for kernel auxiliary workspace) provided with an argument. 
The user should + /// call @see GetDeviceKernelArgSize, and @see SetDeviceKernelArgs, on arg + /// parameter to properly allocate those buffers. + /// + /// @param[in] arg The structure containing kernel arguments (in host memory). + /// @param[in] stream_config The device stream configuration. + /// + /// @return The average kernel execution time (if time measurement is enabled.) + /// + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + if(arg.p_dev_gemm_args_ == nullptr) + { + std::ostringstream err; + err << "The gemm arguments device buffer is not allocated!" + << " In " << __FILE__ << ":" << __LINE__ << ", in function: " << __func__; + throw std::runtime_error(err.str()); + } + + return Run(arg, arg.p_dev_gemm_args_, stream_config); + } + + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + + private: + float DispatchKernel(const Argument& arg, + const void* dev_gemm_args, + const StreamConfig& stream_config) const + { + const auto kernel = kernel_grouped_gemm_multiple_d_xdl; + return LaunchKernel(kernel, arg, dev_gemm_args, stream_config); + } + + template + int CalculateMaxOccupancyGridSize(const KernelFunction& kernel, + const StreamConfig& stream_config) const + { + // Calculate max number of workgroups that can simultaneously reside on the CU. + int occ_num_blocks = 0; + size_t dyn_shared_mem_per_blk = 0; + hip_check_error(hipOccupancyMaxActiveBlocksPerMultiprocessor( + &occ_num_blocks, kernel, BlockSize, dyn_shared_mem_per_blk)); + + int cu_count = getAvailableComputeUnitCount(stream_config); + + if(stream_config.log_level_ > 0) + { + std::cout << "MaxActiveBlocksPerCU: " << occ_num_blocks + << ", available CUs count: " << cu_count << ", occup. grid size: " + << ck::math::min(occ_num_blocks, KernelConfig::CU_BLOCKS) * cu_count + << std::endl; + } + + return cu_count * ck::math::min(occ_num_blocks, KernelConfig::CU_BLOCKS); + } + + template + float LaunchKernel(const KernelFunction& kernel, + const Argument& arg, + const void* dev_gemm_args, + const StreamConfig& stream_config) const + { + int grid_size = CalculateMaxOccupancyGridSize(kernel, stream_config); + + if(stream_config.log_level_ > 0) + { + std::cout << "grid_size: " << grid_size << " tile_count: " << arg.tile_count_ + << std::endl; + } + + return launch_and_time_kernel(stream_config, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + cast_pointer_to_constant_address_space(dev_gemm_args), + arg.group_count_, + arg.a_element_op_, + arg.b_element_op_, + arg.cde_element_op_); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } + + static bool IsSupportedArgument(const Argument& arg) + { + if(!ck::is_xdl_supported()) + { + return false; + } + + using DsGridDescMN = remove_cvref_t< + decltype(GridwiseGemm::template MakeDsGridDescriptor_M_N( + {}, {}, {}))>; + + bool supported = true; + + for(const auto& gdesc : arg.gemm_descs_) + { + const auto M = gdesc.M_; + const auto N = gdesc.N_; + const auto K = gdesc.K_; + + const auto StrideA = gdesc.stride_A_; + const auto StrideB = gdesc.stride_B_; + const auto StrideE = gdesc.stride_C_; + const auto& StrideDs = gdesc.stride_Ds_; + + // If M dimension is unknown at launch time then validate just NK. + // If N or K dim is zero (or unknown) then the vector loads responsibility lies on + // the user. 
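+            // Groups with unknown (zero) N or K are skipped: no meaningful
+            // host-side check is possible, so vector-load alignment becomes the
+            // caller's responsibility. Otherwise CheckTensorTransfersValidity
+            // verifies that each tensor's contiguous dimension is divisible by
+            // its configured ScalarPerVector width.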
+ if(N * K == 0) + continue; + + const auto a_grid_desc_mk = + GridwiseGemm::template MakeAGridDescriptor_M_K(M, K, StrideA); + const auto b_grid_desc_nk = + GridwiseGemm::template MakeBGridDescriptor_N_K(K, N, StrideB); + const auto e_grid_desc_mn = + GridwiseGemm::template MakeEGridDescriptor_M_N(M, N, StrideE); + + DsGridDescMN ds_grid_desc_mn; + static_for<0, NumDTensor, 1>{}([&](auto j) { + using DLayout = remove_cvref_t>; + ds_grid_desc_mn(j) = + GridwiseGemm::template MakeEGridDescriptor_M_N( + M, N, StrideDs[j]); + }); + + const auto b2c_tile_map = Block2ETileMap(M, N); + + if(!(GridwiseGemm::template CheckValidity(a_grid_desc_mk, + b_grid_desc_nk, + ds_grid_desc_mn, + e_grid_desc_mn, + b2c_tile_map) && + GridwiseGemm::template CheckTensorTransfersValidity( + M, N, K))) + { +#if DEBUG_LOG + std::cout << "The provided GEMM problem size (M,N,K) [" << M << "," << N << "," << K + << "] are not supported by current template parameters!" + << " In " << __FILE__ << ":" << __LINE__ << ", in function: " << __func__; +#endif + supported = false; + } + } + + return supported; + } + + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static auto MakeArgument(std::vector& p_As, + std::vector& p_Bs, + std::vector>& p_Ds, + std::vector& p_Es, + std::vector gemm_descs, + AElementwiseOperation a_elementwise_op, + BElementwiseOperation b_elementwise_op, + CDEElementwiseOperation cde_elementwise_op) + { + const auto kernel = kernel_grouped_gemm_multiple_d_xdl; + int occupancy, num_cu; + hip_check_error( + hipOccupancyMaxActiveBlocksPerMultiprocessor(&occupancy, kernel, BlockSize, 0)); + + hipDeviceProp_t dev_prop; + hipDevice_t dev; + hip_check_error(hipGetDevice(&dev)); + hip_check_error(hipGetDeviceProperties(&dev_prop, dev)); + num_cu = dev_prop.multiProcessorCount; + + return Argument{p_As, + p_Bs, + p_Ds, + p_Es, + gemm_descs, + a_elementwise_op, + b_elementwise_op, + cde_elementwise_op, + occupancy, + num_cu}; + } + + std::unique_ptr + MakeArgumentPointer(std::vector& p_As, + std::vector& p_Bs, + std::vector>& p_Ds, + std::vector& p_Es, + std::vector& gemm_descs, + AElementwiseOperation a_elementwise_op, + BElementwiseOperation b_elementwise_op, + CDEElementwiseOperation cde_elementwise_op) override + { + const auto kernel = kernel_grouped_gemm_multiple_d_xdl; + int occupancy, num_cu; + hip_check_error( + hipOccupancyMaxActiveBlocksPerMultiprocessor(&occupancy, kernel, BlockSize, 0)); + + hipDeviceProp_t dev_prop; + hipDevice_t dev; + hip_check_error(hipGetDevice(&dev)); + hip_check_error(hipGetDeviceProperties(&dev_prop, dev)); + num_cu = dev_prop.multiProcessorCount; + + return std::make_unique(p_As, + p_Bs, + p_Ds, + p_Es, + gemm_descs, + a_elementwise_op, + b_elementwise_op, + cde_elementwise_op, + occupancy, + num_cu); + } + + static auto MakeInvoker() { return Invoker{}; } + + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + std::string GetTypeString() const override + { + auto str = std::ostringstream(); + + // clang-format off + str << "DeviceGroupedGemmMultipleDXdlCShuffleTileLoop" + << "<" + << std::string(ALayout::name)[0] << "," + << std::string(BLayout::name)[0] << "," + << std::string(ELayout::name)[0] << "," + << BlockSize << ", " + << MPerBlock << ", " + << NPerBlock << ", " + << KPerBlock << ", " + << AK1 << ", " + << BK1 << ", " + << MPerXDL << ", " + << NPerXDL << ", " + << MXdlPerWave << ", " + << NXdlPerWave << ", " + << ABlockTransferSrcScalarPerVector 
<< ", " + << BBlockTransferSrcScalarPerVector << ", " + << CShuffleMXdlPerWavePerShuffle << ", " + << CShuffleNXdlPerWavePerShuffle << ", " + << getGemmSpecializationString(GemmSpec) << ", " + << PipelineVer << ", " + << LoopSched + << ">"; + // clang-format on + + return str.str(); + } + + void SetDeviceKernelArgs(Argument& arg, void* p_dev_kernel_args) const + { + arg.p_dev_gemm_args_ = p_dev_kernel_args; + } + + void SetDeviceKernelArgs(BaseArgument* p_arg, void* p_dev_kernel_args) const override + { + return SetDeviceKernelArgs(*dynamic_cast(p_arg), p_dev_kernel_args); + } + + size_t GetDeviceKernelArgSize(const BaseArgument* p_arg) const override + { + return dynamic_cast(p_arg)->group_count_ * sizeof(KernelArguments); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp b/include/ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp index 148aba5aaf..d92f504d52 100644 --- a/include/ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp +++ b/include/ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -151,7 +151,7 @@ struct BlockToCTileMap_M00_N0_M01Adapt { } - __host__ static constexpr index_t CalculateGridSize(index_t M, index_t N) + __host__ __device__ static constexpr index_t CalculateGridSize(index_t M, index_t N) { const auto M0 = math::integer_divide_ceil(M, MPerBlock); const auto N0 = math::integer_divide_ceil(N, NPerBlock); @@ -275,7 +275,7 @@ struct BlockToCTileMap_Grouped_M00_N0_M01Adapt { } - __host__ static constexpr index_t CalculateGridSize(index_t M, index_t N) + __host__ __device__ static constexpr index_t CalculateGridSize(index_t M, index_t N) { const auto M0 = math::integer_divide_ceil(M, MPerBlock); const auto N0 = math::integer_divide_ceil(N, NPerBlock); @@ -428,7 +428,7 @@ struct BlockToCTileMap_N00_M0_N01Adapt { } - __host__ static constexpr index_t CalculateGridSize(index_t M, index_t N) + __host__ __device__ static constexpr index_t CalculateGridSize(index_t M, index_t N) { const auto M0 = math::integer_divide_ceil(M, MPerBlock); const auto N0 = math::integer_divide_ceil(N, NPerBlock); @@ -900,6 +900,11 @@ struct OffsettedBlockToCTileMap return block_to_ctile_map_.CalculateGridSize(c_grid_desc_m_n); } + __host__ __device__ constexpr index_t CalculateGridSize(index_t M, index_t N) const + { + return block_to_ctile_map_.CalculateGridSize(M, N); + } + UnderlyingBlockToCTileMap block_to_ctile_map_; index_t block_start_; }; diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp index 6ddc3aca18..e6085fad8c 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp @@ -257,7 +257,70 @@ struct GridwiseGemmMultipleD_xdl_cshuffle e_grid_desc_m_n); } - // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01} + template + __host__ __device__ static bool + CheckTensorTransfersValidity(index_t MRaw, index_t NRaw, index_t KRaw) + { + // Check if the vector dim is K1 or M|N + const auto A_vector_dim_size = ABlockTransferSrcVectorDim == 2 ? 
KRaw : MRaw; + const auto B_vector_dim_size = BBlockTransferSrcVectorDim == 2 ? KRaw : NRaw; + const auto E_vector_dim_size = NRaw; + + // check vector load for A tensor + if constexpr(is_same_v) + { + if(!(A_vector_dim_size == KRaw && + A_vector_dim_size % ABlockTransferSrcScalarPerVector == 0)) + return false; + } + else if constexpr(is_same_v) + { + if(!(A_vector_dim_size == MRaw && + A_vector_dim_size % ABlockTransferSrcScalarPerVector == 0)) + return false; + } + else + { + return false; + } + + if constexpr(is_same_v) + { + if(!(B_vector_dim_size == NRaw && + B_vector_dim_size % BBlockTransferSrcScalarPerVector == 0)) + return false; + } + else if constexpr(is_same_v) + { + if(!(B_vector_dim_size == KRaw && + B_vector_dim_size % BBlockTransferSrcScalarPerVector == 0)) + return false; + } + else + { + return false; + } + + if constexpr(is_same_v) + { + if(!(E_vector_dim_size == NRaw && + E_vector_dim_size % CDEShuffleBlockTransferScalarPerVector_NPerBlock == 0)) + return false; + } + else if constexpr(is_same_v) + { + if(!(E_vector_dim_size == NRaw && + CDEShuffleBlockTransferScalarPerVector_NPerBlock == 1)) + return false; + } + else + { + return false; + } + + return true; + } + template {}([&](auto i) { @@ -306,7 +368,6 @@ struct GridwiseGemmMultipleD_xdl_cshuffle // check gridwise gemm pipeline const auto num_k_loop = AK / KPerBlock; - if(!GridwiseGemmPipe::IsSupported(num_k_loop)) { return false; @@ -938,6 +999,63 @@ struct GridwiseGemmMultipleD_xdl_cshuffle e_grid_desc_mblock_mperblock_nblock_nperblock, block_2_etile_map); } + + template + __device__ static void Run(const void* __restrict__ p_a_grid_, + const void* __restrict__ p_b_grid_, + DsGridPointer p_ds_grid, + void* __restrict__ p_e_grid_, + void* __restrict__ p_shared, + const AElementwiseOperation& a_element_op, + const BElementwiseOperation& b_element_op, + const CDEElementwiseOperation& cde_element_op, + const AGridDesc_MK& a_grid_desc_m_k, + const BGridDesc_NK& b_grid_desc_n_k, + const DsGridDesc_MN& ds_grid_desc_m_n, + const EGridDesc_MN& e_grid_desc_m_n, + const Block2ETileMap& block_2_etile_map) + { + const auto p_a_grid = reinterpret_cast(p_a_grid_); + const auto p_b_grid = reinterpret_cast(p_b_grid_); + const auto p_e_grid = reinterpret_cast(p_e_grid_); + + // tensor descriptors for block/thread-wise copy + const auto a_grid_desc_ak0_m_ak1 = MakeDefaultAGridDescriptor_AK0_M_AK1(a_grid_desc_m_k); + const auto b_grid_desc_bk0_n_bk1 = MakeDefaultBGridDescriptor_BK0_N_BK1(b_grid_desc_n_k); + + using DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = + remove_cvref_t; + + DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock ds_grid_desc_mblock_mperblock_nblock_nperblock; + + static_for<0, NumDTensor, 1>{}([&](auto j) { + ds_grid_desc_mblock_mperblock_nblock_nperblock(j) = + MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(ds_grid_desc_m_n[j]); + }); + + const auto e_grid_desc_mblock_mperblock_nblock_nperblock = + MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(e_grid_desc_m_n); + + Run(p_a_grid, + p_b_grid, + p_ds_grid, + p_e_grid, + p_shared, + a_element_op, + b_element_op, + cde_element_op, + a_grid_desc_ak0_m_ak1, + b_grid_desc_bk0_n_bk1, + ds_grid_desc_mblock_mperblock_nblock_nperblock, + e_grid_desc_mblock_mperblock_nblock_nperblock, + block_2_etile_map); + } }; } // namespace ck diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp index 567c42362c..44cbbcd049 100644 --- 
a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp @@ -1,9 +1,10 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once #include +#include #include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp" #include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v2.hpp" @@ -57,3 +58,16 @@ constexpr auto GridwiseGemmPipeline_Selector() } } // namespace ck + +inline std::ostream& operator<<(std::ostream& os, const ck::PipelineVersion& p) +{ + switch(p) + { + case ck::PipelineVersion::v1: os << "PipelineVersion::v1"; break; + case ck::PipelineVersion::v2: os << "PipelineVersion::v2"; break; + case ck::PipelineVersion::v4: os << "PipelineVersion::v4"; break; + case ck::PipelineVersion::weight_only: os << "PipelineVersion::weight_only"; break; + default: os << ""; + } + return os; +} diff --git a/include/ck/utility/debug.hpp b/include/ck/utility/debug.hpp index 80346f0d9f..03c4e16dd6 100644 --- a/include/ck/utility/debug.hpp +++ b/include/ck/utility/debug.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #ifndef UTILITY_DEBUG_HPP #define UTILITY_DEBUG_HPP @@ -79,6 +79,13 @@ __device__ void print_shared(T const* p_shared, index_t num_elements) __syncthreads(); } +template +__device__ static bool is_thread_local_1d_id_idx() +{ + const auto tid = get_thread_local_1d_id(); + return ((tid == Ids) || ...); +} + } // namespace debug } // namespace ck diff --git a/include/ck/utility/loop_scheduler.hpp b/include/ck/utility/loop_scheduler.hpp index b2eb0ddb93..0c4d85bedb 100644 --- a/include/ck/utility/loop_scheduler.hpp +++ b/include/ck/utility/loop_scheduler.hpp @@ -1,5 +1,6 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +#include #pragma once @@ -24,3 +25,14 @@ constexpr LoopScheduler make_default_loop_scheduler() } } // namespace ck + +inline std::ostream& operator<<(std::ostream& os, const ck::LoopScheduler& s) +{ + switch(s) + { + case ck::LoopScheduler::Default: os << "Default"; break; + case ck::LoopScheduler::Interwave: os << "Interwave"; break; + default: os << ""; + } + return os; +} diff --git a/include/ck/utility/sequence.hpp b/include/ck/utility/sequence.hpp index d6bfb2eba1..f9c9352dd7 100644 --- a/include/ck/utility/sequence.hpp +++ b/include/ck/utility/sequence.hpp @@ -1,8 +1,10 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
#pragma once +#include + #include "ck/utility/integral_constant.hpp" #include "ck/utility/type.hpp" #include "ck/utility/functional.hpp" @@ -897,3 +899,14 @@ template using uniform_sequence_gen_t = typename uniform_sequence_gen::type; } // namespace ck + +template +std::ostream& operator<<(std::ostream& os, const ck::Sequence) +{ + using S = ck::Sequence; + os << "{"; + ck::static_for<0, S::Size() - ck::Number<1>{}, 1>{}( + [&](auto i) { os << S::At(i).value << ", "; }); + os << S::At(S::Size() - ck::Number<1>{}).value << "}"; + return os; +} diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_gemm_tile_loop.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_gemm_tile_loop.hpp new file mode 100644 index 0000000000..d3fce12ce7 --- /dev/null +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_gemm_tile_loop.hpp @@ -0,0 +1,108 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_grouped_gemm_tile_loop.hpp" + +#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +#ifdef CK_ENABLE_FP16 +// fp16_output +void add_device_grouped_gemm_xdl_tile_loop_f16_f16_f16_mk_kn_mn_instances( + std::vector>>& instances); + +void add_device_grouped_gemm_xdl_tile_loop_f16_f16_f16_mk_nk_mn_instances( + std::vector>>& instances); +#endif + +template +struct DeviceOperationInstanceFactory< + ck::tensor_operation::device::DeviceGroupedGemmTileLoop> +{ + using DeviceOp = DeviceGroupedGemmTileLoop; + + static auto GetInstances() + { + std::vector> op_ptrs; + +#ifdef CK_ENABLE_FP16 + // fp16_output + if constexpr(is_same_v && is_same_v && + is_same_v) + { + if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_grouped_gemm_xdl_tile_loop_f16_f16_f16_mk_kn_mn_instances(op_ptrs); + } + if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_grouped_gemm_xdl_tile_loop_f16_f16_f16_mk_nk_mn_instances(op_ptrs); + } + } +#endif + return op_ptrs; + } +}; + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_gemm_tile_loop/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_gemm_tile_loop/CMakeLists.txt new file mode 100644 index 0000000000..50077e18b3 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_gemm_tile_loop/CMakeLists.txt @@ -0,0 +1,9 @@ +# ONLY XDL_KERNELS +set(GROUPED_GEMM_TILE_LOOP_INSTANCES) + +list(APPEND GROUPED_GEMM_TILE_LOOP_INSTANCES + device_grouped_gemm_xdl_tile_loop_f16_f16_f16_mk_kn_mn_instance.cpp + device_grouped_gemm_xdl_tile_loop_f16_f16_f16_mk_nk_mn_instance.cpp + ) + +add_instance_library(device_grouped_gemm_tile_loop_instance ${GROUPED_GEMM_TILE_LOOP_INSTANCES}) diff --git a/library/src/tensor_operation_instance/gpu/grouped_gemm_tile_loop/device_grouped_gemm_xdl_tile_loop_f16_f16_f16_mk_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_gemm_tile_loop/device_grouped_gemm_xdl_tile_loop_f16_f16_f16_mk_kn_mn_instance.cpp new file mode 100644 index 0000000000..505afbdff7 --- /dev/null +++ 
b/library/src/tensor_operation_instance/gpu/grouped_gemm_tile_loop/device_grouped_gemm_xdl_tile_loop_f16_f16_f16_mk_kn_mn_instance.cpp
@@ -0,0 +1,75 @@
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
+
+#include <cstdlib>
+
+#include "ck/ck.hpp"
+#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
+#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
+#include "ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_xdl_cshuffle_tile_loop.hpp"
+
+#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
+
+namespace ck {
+namespace tensor_operation {
+namespace device {
+namespace instance {
+
+using F16 = ck::half_t;
+using F32 = float;
+
+using Row = ck::tensor_layout::gemm::RowMajor;
+using Col = ck::tensor_layout::gemm::ColumnMajor;
+
+template <ck::index_t... Is>
+using S = ck::Sequence<Is...>;
+
+using DsDataType = ck::Tuple<>;
+
+using DsLayout = ck::Tuple<>;
+
+using PassThrough = ck::tensor_operation::element_wise::PassThrough;
+
+static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
+
+using device_grouped_gemm_xdl_tile_loop_f16_f16_f16_mk_kn_mn_irregular_tile_instances = std::tuple<
+    // clang-format off
+    //###########################################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
+    //###########################################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
+    //###########################################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
+    //###########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
+        DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
+        DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
+        DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 128, 64, 32, 8, 2, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>,
+        DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
+        DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 2, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>,
+        DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
+        DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 128, 64, 32, 8, 2, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 4>, 8>,
+        DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>,
+        DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 64, 128, 32, 8, 2, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 16, 1, 8>, 8>,
+        DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>
+    // clang-format on
+    >;
+
+void add_device_grouped_gemm_xdl_tile_loop_f16_f16_f16_mk_kn_mn_instances(
+    std::vector<std::unique_ptr<DeviceGroupedGemmTileLoop<Row,
+                                                          Row,
+                                                          DsLayout,
+                                                          Row,
+                                                          F16,
+                                                          F16,
+                                                          DsDataType,
+                                                          F16,
+                                                          PassThrough,
+                                                          PassThrough,
+                                                          PassThrough>>>& instances)
+{
+    add_device_operation_instances(
+        instances,
+        device_grouped_gemm_xdl_tile_loop_f16_f16_f16_mk_kn_mn_irregular_tile_instances{});
+}
+
+} // namespace instance
+} // namespace device
+} // namespace tensor_operation
+} // namespace ck
diff --git a/library/src/tensor_operation_instance/gpu/grouped_gemm_tile_loop/device_grouped_gemm_xdl_tile_loop_f16_f16_f16_mk_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_gemm_tile_loop/device_grouped_gemm_xdl_tile_loop_f16_f16_f16_mk_nk_mn_instance.cpp
new file mode 100644
index 0000000000..9653d3eef0
--- /dev/null
+++ b/library/src/tensor_operation_instance/gpu/grouped_gemm_tile_loop/device_grouped_gemm_xdl_tile_loop_f16_f16_f16_mk_nk_mn_instance.cpp
@@ -0,0 +1,77 @@
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
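+
+// Kernel instance list for fp16 grouped GEMM with the tile-loop scheduler:
+// row-major A (m,k) times column-major B (n,k) producing row-major E (m,n),
+// fp32 accumulation and an empty Ds tuple (no extra D tensors).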
+
+#include <cstdlib>
+
+#include "ck/ck.hpp"
+#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
+#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
+#include "ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_xdl_cshuffle_tile_loop.hpp"
+
+#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
+
+namespace ck {
+namespace tensor_operation {
+namespace device {
+namespace instance {
+
+using F16 = ck::half_t;
+using F32 = float;
+
+using Row = ck::tensor_layout::gemm::RowMajor;
+using Col = ck::tensor_layout::gemm::ColumnMajor;
+
+template <ck::index_t... Is>
+using S = ck::Sequence<Is...>;
+
+using DsDataType = ck::Tuple<>;
+using DsLayout = ck::Tuple<>;
+
+using PassThrough = ck::tensor_operation::element_wise::PassThrough;
+
+static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
+
+using device_grouped_gemm_xdl_tile_loop_f16_f16_f16_mk_nk_mn_irregular_tile_instances = std::tuple<
+    // clang-format off
+    //###########################################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
+    //###########################################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
+    //###########################################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
+    //###########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
+        DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Col, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 128, 256, 64, 8, 8, 32, 32, 2, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
+        DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Col, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 128, 128, 64, 8, 8, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
+        DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Col, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 128, 64, 64, 8, 8, 32, 32, 2, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
+        DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Col, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 64, 128, 64, 8, 8, 32, 32, 1, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
+        DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Col, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 128, 128, 64, 8, 8, 32, 32, 4, 2, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>,
+        DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Col, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 128, 64, 64, 8, 8, 32, 32, 2, 2, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>,
+        DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Col, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 64, 128, 64, 8, 8, 32, 32, 2, 2, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>,
+        DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Col, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 128, 32, 64, 8, 8, 32, 32, 2, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>,
+        DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Col, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 32, 128, 64, 8, 8, 32, 32, 1, 2, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>,
+        DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Col, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 32, 256, 64, 8, 8, 32, 32, 1, 4, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>,
+        DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Col, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 64, 64, 64, 64, 8, 8, 32, 32, 2, 2, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>,
+        DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Col, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 64, 64, 32, 64, 8, 8, 32, 32, 2, 1, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>,
+        DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Col, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 64, 32, 64, 64, 8, 8, 32, 32, 1, 2, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>
+    // clang-format on
+    >;
+
+void add_device_grouped_gemm_xdl_tile_loop_f16_f16_f16_mk_nk_mn_instances(
+    std::vector<std::unique_ptr<DeviceGroupedGemmTileLoop<Row,
+                                                          Col,
+                                                          DsLayout,
+                                                          Row,
+                                                          F16,
+                                                          F16,
+                                                          DsDataType,
+                                                          F16,
+                                                          PassThrough,
+                                                          PassThrough,
+                                                          PassThrough>>>& instances)
+{
+    add_device_operation_instances(
+        instances,
+        device_grouped_gemm_xdl_tile_loop_f16_f16_f16_mk_nk_mn_irregular_tile_instances{});
+}
+
+} // namespace instance
+} // namespace device
+} // namespace tensor_operation
+} // namespace ck
diff --git a/profiler/include/profiler/profile_grouped_gemm_fixed_nk_impl.hpp b/profiler/include/profiler/profile_grouped_gemm_fixed_nk_impl.hpp
index 5d2b7e0d9b..67fba43d64 100644
--- a/profiler/include/profiler/profile_grouped_gemm_fixed_nk_impl.hpp
+++ b/profiler/include/profiler/profile_grouped_gemm_fixed_nk_impl.hpp
@@ -73,9 +73,11 @@ bool profile_grouped_gemm_fixed_nk_impl(int do_verification,
     std::vector<Tensor<BDataType>> b_k_n;
     std::vector<Tensor<CDataType>> c_m_n_host_results;
     std::vector<Tensor<CDataType>> c_m_n_device_results;
+    int sum_of_m = 0;
 
     for(std::size_t i = 0; i < group_count; i++)
     {
+        sum_of_m += Ms[i];
         a_m_k.push_back(
             Tensor<ADataType>(f_host_tensor_descriptor(Ms[i], Ks[i], StrideAs[i], ALayout{})));
         b_k_n.push_back(
@@ -146,7 +148,7 @@ bool profile_grouped_gemm_fixed_nk_impl(int do_verification,
         a_device_buf[i]->ToDevice(a_m_k[i].mData.data());
         b_device_buf[i]->ToDevice(b_k_n[i].mData.data());
 
-        gemm_descs.push_back({Ms[i], Ns[i], Ks[i], StrideAs[i], StrideBs[i], StrideCs[i], {}});
+        gemm_descs.push_back({sum_of_m, Ns[i], Ks[i], StrideAs[i], StrideBs[i], StrideCs[i], {}});
 
         p_a.push_back(a_device_buf[i]->GetDeviceBuffer());
         p_b.push_back(b_device_buf[i]->GetDeviceBuffer());
diff --git a/profiler/include/profiler/profile_grouped_gemm_tile_loop_impl.hpp b/profiler/include/profiler/profile_grouped_gemm_tile_loop_impl.hpp
new file mode 100644
index 0000000000..3d7fa47077
--- /dev/null
+++ b/profiler/include/profiler/profile_grouped_gemm_tile_loop_impl.hpp
@@ -0,0 +1,319 @@
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
+
+#pragma once
+
+#include <iomanip>
+
+#include "ck/ck.hpp"
+#include "ck/host_utility/hip_check_error.hpp"
+#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
+#include "ck/tensor_operation/gpu/device/device_grouped_gemm_tile_loop.hpp"
+#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
+
+#include "ck/library/tensor_operation_instance/gpu/grouped_gemm_tile_loop.hpp"
+
+#include "ck/library/utility/check_err.hpp"
+#include "ck/library/utility/device_memory.hpp"
+#include "ck/library/utility/host_tensor.hpp"
+#include "ck/library/utility/literals.hpp"
+#include "ck/library/utility/fill.hpp"
+#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp"
+
+namespace ck {
+namespace profiler {
+
+template <typename ADataType,
+          typename BDataType,
+          typename CDataType,
+          typename ALayout,
+          typename BLayout,
+          typename CLayout>
+bool profile_grouped_gemm_tile_loop_impl(int do_verification,
+                                         int init_method,
+                                         bool do_log,
+                                         bool time_kernel,
+                                         const std::vector<int>& Ms,
+                                         const std::vector<int>& Ns,
+                                         const std::vector<int>& Ks,
+                                         const std::vector<int>& StrideAs,
+                                         const std::vector<int>& StrideBs,
+                                         const std::vector<int>& StrideCs,
+                                         int n_warmup = 10,
+                                         int n_iter   = 50)
+{
+    bool pass = true;
+
+    auto f_host_tensor_descriptor =
+        [](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
+            using namespace ck::literals;
+
+            if(is_same<decltype(layout), ck::tensor_layout::gemm::RowMajor>::value)
+            {
+                return HostTensorDescriptor({row, col}, {stride, 1_uz});
+            }
+            else
+            {
+                return HostTensorDescriptor({row, col}, {1_uz, stride});
+            }
+        };
+
+    std::size_t group_count = Ms.size();
+
+    if(!(group_count == Ns.size() && group_count == Ks.size() && group_count == StrideAs.size() &&
+         group_count == StrideBs.size() && group_count == StrideCs.size()))
+    {
+        throw std::runtime_error("wrong! inconsistent M/N/Ks, StrideA/B/Cs size\n");
+    }
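+
+    // Host-side copies of every group's operands; device results are copied back
+    // into c_m_n_device_results and compared against c_m_n_host_results when
+    // verification is enabled.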
+    std::vector<Tensor<ADataType>> a_m_k;
+    std::vector<Tensor<BDataType>> b_k_n;
+    std::vector<Tensor<CDataType>> c_m_n_host_results;
+    std::vector<Tensor<CDataType>> c_m_n_device_results;
+
+    for(std::size_t i = 0; i < group_count; i++)
+    {
+        a_m_k.push_back(
+            Tensor<ADataType>(f_host_tensor_descriptor(Ms[i], Ks[i], StrideAs[i], ALayout{})));
+        b_k_n.push_back(
+            Tensor<BDataType>(f_host_tensor_descriptor(Ks[i], Ns[i], StrideBs[i], BLayout{})));
+        c_m_n_device_results.push_back(
+            Tensor<CDataType>(f_host_tensor_descriptor(Ms[i], Ns[i], StrideCs[i], CLayout{})));
+        c_m_n_host_results.push_back(
+            Tensor<CDataType>(f_host_tensor_descriptor(Ms[i], Ns[i], StrideCs[i], CLayout{})));
+#if DEBUG_LOG
+        std::cout << "group: " << i << " a_m_k[" << i << "]:" << a_m_k[i].mDesc << ", b_k_n[" << i
+                  << "]:" << b_k_n[i].mDesc << ", c_m_n_device_results[" << i
+                  << "]:" << c_m_n_device_results[i].mDesc << std::endl;
+#endif // DEBUG_LOG
+        switch(init_method)
+        {
+        case 0: break;
+        case 1:
+            ck::utils::FillUniformDistributionIntegerValue<ADataType>{-5, 5}(a_m_k[i]);
+            ck::utils::FillUniformDistributionIntegerValue<BDataType>{-5, 5}(b_k_n[i]);
+            break;
+        case 2:
+            ck::utils::FillUniformDistribution<ADataType>{.0, 1.}(a_m_k[i]);
+            ck::utils::FillUniformDistribution<BDataType>{-0.5, 0.5}(b_k_n[i]);
+            break;
+        default:
+            ck::utils::FillConstant<ADataType>{1}(a_m_k[i]);
+            ck::utils::FillConstant<BDataType>{1}(b_k_n[i]);
+        }
+    }
+
+    using AElementOp = ck::tensor_operation::element_wise::PassThrough;
+    using BElementOp = ck::tensor_operation::element_wise::PassThrough;
+    using CElementOp = ck::tensor_operation::element_wise::PassThrough;
+
+    const auto a_element_op = AElementOp{};
+    const auto b_element_op = BElementOp{};
+    const auto c_element_op = CElementOp{};
+
+    using DeviceMemPtr = std::unique_ptr<DeviceMem>;
+    std::vector<DeviceMemPtr> a_device_buf, b_device_buf, c_device_buf;
+
+    a_device_buf.reserve(group_count);
+    b_device_buf.reserve(group_count);
+    c_device_buf.reserve(group_count);
+
+    std::vector<const void*> p_a, p_b;
+    std::vector<void*> p_c;
+
+    p_a.reserve(group_count);
+    p_b.reserve(group_count);
+    p_c.reserve(group_count);
+
+    using KernelArguments = ck::tensor_operation::device::GroupedGemmTileLoopKernelArguments<>;
+
+    std::vector<ck::tensor_operation::device::GemmDesc> gemm_descs;
+    std::vector<KernelArguments> gemm_kargs;
+
+    gemm_descs.reserve(group_count);
+    gemm_kargs.reserve(group_count);
+
+    for(std::size_t i = 0; i < group_count; i++)
+    {
+        a_device_buf.emplace_back(
+            std::make_unique<DeviceMem>(sizeof(ADataType) * a_m_k[i].mDesc.GetElementSpaceSize()));
+        b_device_buf.emplace_back(
+            std::make_unique<DeviceMem>(sizeof(BDataType) * b_k_n[i].mDesc.GetElementSpaceSize()));
+        c_device_buf.emplace_back(std::make_unique<DeviceMem>(
+            sizeof(CDataType) * c_m_n_device_results[i].mDesc.GetElementSpaceSize()));
+
+        a_device_buf[i]->ToDevice(a_m_k[i].mData.data());
+        b_device_buf[i]->ToDevice(b_k_n[i].mData.data());
+        c_device_buf[i]->SetZero();
+
+        p_a.push_back(a_device_buf[i]->GetDeviceBuffer());
+        p_b.push_back(b_device_buf[i]->GetDeviceBuffer());
+        p_c.push_back(c_device_buf[i]->GetDeviceBuffer());
+
+        gemm_descs.push_back({0, Ns[i], Ks[i], StrideAs[i], StrideBs[i], StrideCs[i], {}});
+        gemm_kargs.push_back({a_device_buf[i]->GetDeviceBuffer(),
+                              b_device_buf[i]->GetDeviceBuffer(),
+                              {},
+                              c_device_buf[i]->GetDeviceBuffer(),
+                              Ms[i],
+                              Ns[i],
+                              Ks[i],
+                              StrideAs[i],
+                              StrideBs[i],
+                              {},
+                              StrideCs[i]});
+    }
+
+    using DeviceOp = ck::tensor_operation::device::DeviceGroupedGemmTileLoop<ALayout,
+                                                                             BLayout,
+                                                                             ck::Tuple<>,
+                                                                             CLayout,
+                                                                             ADataType,
+                                                                             BDataType,
+                                                                             ck::Tuple<>,
+                                                                             CDataType,
+                                                                             AElementOp,
+                                                                             BElementOp,
+                                                                             CElementOp>;
+
+    const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory<
+        DeviceOp>::GetInstances();
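+
+    // The instance factory enumerates all pre-built tile-loop kernels registered
+    // for this layout/data-type combination; each supported instance is run,
+    // and optionally verified and timed, below.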
+    if(op_ptrs.size() <= 0)
+    {
+        throw std::runtime_error("wrong! no device GEMM instance found");
+    }
+
+    std::string best_gemm_name;
+    float best_ave_time   = 0;
+    float best_tflops     = 0;
+    float best_gb_per_sec = 0;
+
+    auto p_ds = std::vector<std::array<const void*, 0>>{};
+
+    if(do_verification)
+    {
+        for(std::size_t i = 0; i < gemm_descs.size(); i++)
+        {
+            using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<ADataType,
+                                                                                    BDataType,
+                                                                                    CDataType,
+                                                                                    float,
+                                                                                    AElementOp,
+                                                                                    BElementOp,
+                                                                                    CElementOp>;
+
+            auto ref_gemm     = ReferenceGemmInstance{};
+            auto ref_invoker  = ref_gemm.MakeInvoker();
+            auto ref_argument = ref_gemm.MakeArgument(a_m_k[i],
+                                                      b_k_n[i],
+                                                      c_m_n_host_results[i],
+                                                      a_element_op,
+                                                      b_element_op,
+                                                      c_element_op);
+            ref_invoker.Run(ref_argument);
+        }
+    }
+
+    // profile device GEMM instances
+    for(auto& gemm_ptr : op_ptrs)
+    {
+        auto argument_ptr =
+            gemm_ptr->MakeArgumentPointer(p_a,
+                                          p_b,
+                                          p_ds,
+                                          p_c,
+                                          gemm_descs,
+                                          ck::tensor_operation::element_wise::PassThrough{},
+                                          ck::tensor_operation::element_wise::PassThrough{},
+                                          ck::tensor_operation::element_wise::PassThrough{});
+        auto invoker_ptr      = gemm_ptr->MakeInvokerPointer();
+        std::string gemm_name = gemm_ptr->GetTypeString();
+
+        DeviceMem gemm_arg_dev_mem(gemm_ptr->GetDeviceKernelArgSize(argument_ptr.get()));
+        hip_check_error(hipMemcpy(gemm_arg_dev_mem.GetDeviceBuffer(),
+                                  gemm_kargs.data(),
+                                  gemm_ptr->GetDeviceKernelArgSize(argument_ptr.get()),
+                                  hipMemcpyHostToDevice));
+        gemm_ptr->SetDeviceKernelArgs(argument_ptr.get(), gemm_arg_dev_mem.GetDeviceBuffer());
+
+        if(gemm_ptr->IsSupportedArgument(argument_ptr.get()))
+        {
+            invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, false, 0, n_warmup, n_iter});
+            if(do_verification)
+            {
+                bool instance_pass = true;
+                for(std::size_t i = 0; i < gemm_descs.size(); i++)
+                {
+                    c_device_buf[i]->FromDevice(c_m_n_device_results[i].mData.data());
+                    instance_pass = instance_pass && ck::utils::check_err(c_m_n_device_results[i],
+                                                                          c_m_n_host_results[i]);
+
+                    if(do_log)
+                    {
+                        LogRangeAsType<float>(std::cout << "a : ", a_m_k[i].mData, ",")
+                            << std::endl;
+                        LogRangeAsType<float>(std::cout << "b: ", b_k_n[i].mData, ",") << std::endl;
+                        LogRangeAsType<float>(
+                            std::cout << "c_device: ", c_m_n_device_results[i].mData, ",")
+                            << std::endl;
+                        LogRangeAsType<float>(
+                            std::cout << "c_host : ", c_m_n_host_results[i].mData, ",")
+                            << std::endl;
+                    }
+                }
+
+                std::cout << "Instance: " << gemm_name << " verification "
+                          << (instance_pass ? "SUCCEED" : "FAILED") << std::endl;
+
+                pass = pass && instance_pass;
+            }
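+
+            // Re-run with timing enabled: the invoker reports the average kernel
+            // time over n_iter iterations after n_warmup warm-up launches.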
"SUCCEED" : "FAILED") << std::endl; + + pass = pass && instance_pass; + } + + if(time_kernel) + { + float ave_time = invoker_ptr->Run( + argument_ptr.get(), StreamConfig{nullptr, time_kernel, 0, n_warmup, n_iter}); + + std::size_t flop = 0, num_btype = 0; + for(std::size_t i = 0; i < gemm_descs.size(); i++) + { + flop += std::size_t(2) * Ms[i] * Ns[i] * Ks[i]; + + num_btype += sizeof(ADataType) * Ms[i] * Ks[i] + + sizeof(BDataType) * Ks[i] * Ns[i] + + sizeof(CDataType) * Ms[i] * Ns[i]; + } + + float tflops = static_cast(flop) / 1.E9 / ave_time; + float gb_per_sec = num_btype / 1.E6 / ave_time; + std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << tflops + << " TFlops, " << gb_per_sec << " GB/s, " << gemm_name << std::endl; + + if(tflops > best_tflops) + { + best_gemm_name = gemm_name; + best_tflops = tflops; + best_ave_time = ave_time; + best_gb_per_sec = gb_per_sec; + } + } + } + else + { + std::cout << "Instance: " << gemm_name << ", does not support this GEMM problem" + << std::endl; + } + } + + if(time_kernel) + { + std::cout << "Best Perf: " << best_ave_time << " ms, " << best_tflops << " TFlops, " + << best_gb_per_sec << " GB/s, " << best_gemm_name << std::endl; + } + + return pass; +} + +} // namespace profiler +} // namespace ck diff --git a/profiler/src/CMakeLists.txt b/profiler/src/CMakeLists.txt index ce813d05a1..1cfcbfff64 100644 --- a/profiler/src/CMakeLists.txt +++ b/profiler/src/CMakeLists.txt @@ -42,6 +42,7 @@ if(GPU_TARGETS MATCHES "gfx9") list(APPEND PROFILER_SOURCES profile_grouped_gemm_fixed_nk.cpp) list(APPEND PROFILER_SOURCES profile_grouped_gemm_two_stage.cpp) list(APPEND PROFILER_SOURCES profile_grouped_gemm_fastgelu.cpp) + list(APPEND PROFILER_SOURCES profile_grouped_gemm_tile_loop.cpp) endif() list(APPEND PROFILER_SOURCES profile_gemm_multiply_add.cpp) list(APPEND PROFILER_SOURCES profile_batched_gemm.cpp) @@ -111,6 +112,7 @@ if(GPU_TARGETS MATCHES "gfx9") target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_add_relu_add_layernorm_instance) target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_gemm_fixed_nk_instance) target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_gemm_fastgelu_instance) + target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_gemm_tile_loop_instance) endif() target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_batched_gemm_instance) target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_batched_gemm_reduce_instance) diff --git a/profiler/src/profile_grouped_gemm_tile_loop.cpp b/profiler/src/profile_grouped_gemm_tile_loop.cpp new file mode 100644 index 0000000000..76ff9e162e --- /dev/null +++ b/profiler/src/profile_grouped_gemm_tile_loop.cpp @@ -0,0 +1,152 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. 
+
+#include <cstdlib>
+#include <iostream>
+#include <sstream>
+#include <string>
+#include <vector>
+
+#include "profiler/profile_grouped_gemm_tile_loop_impl.hpp"
+#include "profiler_operation_registry.hpp"
+
+enum struct GemmMatrixLayout
+{
+    MK_KN_MN, // 0
+    MK_NK_MN, // 1
+};
+
+enum struct GemmDataType
+{
+    F16_F16_F16, // 0
+};
+
+#define OP_NAME "grouped_gemm_tile_loop"
+#define OP_DESC "Grouped GEMM Multiple D Tile Loop"
+
+namespace {
+
+std::vector<int> argToIntArray(char* input)
+{
+    std::vector<int> out;
+    std::istringstream in(input);
+    std::string item;
+
+    while(std::getline(in, item, ','))
+    {
+        out.push_back(std::stoi(item));
+    }
+    return out;
+}
+
+int profile_grouped_gemm_tile_loop(int argc, char* argv[])
+{
+    if(argc < 14)
+    {
+        std::cout
+            << "arg1: tensor operation (" OP_NAME ": " OP_DESC ")\n"
+            << "arg2: data type (0: fp16)\n"
+            << "arg3: matrix layout (0: A[m, k] * B[k, n] = C[m, n];\n"
+            << "                     1: A[m, k] * B[n, k] = C[m, n])\n"
+            << "arg4: verification (0: no; 1: yes)\n"
+            << "arg5: initialization (0: no init; 1: integer value; 2: decimal value)\n"
+            << "arg6: print tensor value (0: no; 1: yes)\n"
+            << "arg7: time kernel (0: no; 1: yes)\n"
+            << "arg8 to 13: Ms, Ns, Ks, StrideAs, StrideBs, StrideCs (e.g., 256,256 128,128 64,64 "
+               "64,64 64,64 128,128)\n"
+            << "optional:\n"
+            << "arg14: number of warm-up cycles (default 10)\n"
+            << "arg15: number of iterations (default 50)\n"
+            << std::endl;
+
+        exit(1);
+    }
+
+    const auto data_type       = static_cast<GemmDataType>(std::stoi(argv[2]));
+    const auto layout          = static_cast<GemmMatrixLayout>(std::stoi(argv[3]));
+    const bool do_verification = std::stoi(argv[4]);
+    const int init_method      = std::stoi(argv[5]);
+    const bool do_log          = std::stoi(argv[6]);
+    const bool time_kernel     = std::stoi(argv[7]);
+
+    const auto Ms = argToIntArray(argv[8]);
+    const auto Ns = argToIntArray(argv[9]);
+    const auto Ks = argToIntArray(argv[10]);
+
+    auto StrideAs = argToIntArray(argv[11]);
+    auto StrideBs = argToIntArray(argv[12]);
+    auto StrideCs = argToIntArray(argv[13]);
+
+    const int DefaultStrideA = Ks[0];
+    const int DefaultStrideB = Ns[0];
+    const int DefaultStrideC = Ns[0];
+
+    for(size_t i = 0; i < Ms.size(); ++i)
+    {
+        StrideAs[i] = StrideAs[i] == -1 ? DefaultStrideA : StrideAs[i];
+        StrideBs[i] = StrideBs[i] == -1 ? DefaultStrideB : StrideBs[i];
+        StrideCs[i] = StrideCs[i] == -1 ? DefaultStrideC : StrideCs[i];
+    }
+
+    int n_warmup = 10;
+    int n_iter   = 50;
+    if(argc == 16)
+    {
+        n_warmup = std::stoi(argv[14]);
+        n_iter   = std::stoi(argv[15]);
+    }
+
+    if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::MK_KN_MN)
+    {
+        ck::profiler::profile_grouped_gemm_tile_loop_impl<ck::half_t,
+                                                          ck::half_t,
+                                                          ck::half_t,
+                                                          ck::tensor_layout::gemm::RowMajor,
+                                                          ck::tensor_layout::gemm::RowMajor,
+                                                          ck::tensor_layout::gemm::RowMajor>(
+            do_verification,
+            init_method,
+            do_log,
+            time_kernel,
+            Ms,
+            Ns,
+            Ks,
+            StrideAs,
+            StrideBs,
+            StrideCs,
+            n_warmup,
+            n_iter);
+    }
+    else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::MK_NK_MN)
+    {
+        ck::profiler::profile_grouped_gemm_tile_loop_impl<ck::half_t,
+                                                          ck::half_t,
+                                                          ck::half_t,
+                                                          ck::tensor_layout::gemm::RowMajor,
+                                                          ck::tensor_layout::gemm::ColumnMajor,
+                                                          ck::tensor_layout::gemm::RowMajor>(
+            do_verification,
+            init_method,
+            do_log,
+            time_kernel,
+            Ms,
+            Ns,
+            Ks,
+            StrideAs,
+            StrideBs,
+            StrideCs,
+            n_warmup,
+            n_iter);
+    }
+    else
+    {
+        throw std::runtime_error("wrong! this GEMM data_type & layout is not implemented");
+    }
+    return 0;
+}
+
+} // anonymous namespace
+
+REGISTER_PROFILER_OPERATION(OP_NAME, OP_DESC, profile_grouped_gemm_tile_loop);
diff --git a/test/normalization_bwd_data/CMakeLists.txt b/test/normalization_bwd_data/CMakeLists.txt
index 65f33da74d..fb7ad81e19 100644
--- a/test/normalization_bwd_data/CMakeLists.txt
+++ b/test/normalization_bwd_data/CMakeLists.txt
@@ -1,8 +1,13 @@
 add_custom_target(test_normalization_bwd_data)
+
 add_gtest_executable(test_layernorm2d_bwd_data_fp32 test_layernorm2d_bwd_data_fp32.cpp)
-target_link_libraries(test_layernorm2d_bwd_data_fp32 PRIVATE utility device_normalization_bwd_data_instance)
-add_dependencies(test_normalization_bwd_data test_layernorm2d_bwd_data_fp32)
+if (result EQUAL 0)
+    target_link_libraries(test_layernorm2d_bwd_data_fp32 PRIVATE utility device_normalization_bwd_data_instance)
+    add_dependencies(test_normalization_bwd_data test_layernorm2d_bwd_data_fp32)
+endif()
 
 add_gtest_executable(test_groupnorm_bwd_data_fp32 test_groupnorm_bwd_data_fp32.cpp)
-target_link_libraries(test_groupnorm_bwd_data_fp32 PRIVATE utility device_normalization_bwd_data_instance)
-add_dependencies(test_normalization_bwd_data test_groupnorm_bwd_data_fp32)
+if (result EQUAL 0)
+    target_link_libraries(test_groupnorm_bwd_data_fp32 PRIVATE utility device_normalization_bwd_data_instance)
+    add_dependencies(test_normalization_bwd_data test_groupnorm_bwd_data_fp32)
+endif()
diff --git a/test/normalization_bwd_gamma_beta/CMakeLists.txt b/test/normalization_bwd_gamma_beta/CMakeLists.txt
index afb78dc58e..81b6d377ce 100644
--- a/test/normalization_bwd_gamma_beta/CMakeLists.txt
+++ b/test/normalization_bwd_gamma_beta/CMakeLists.txt
@@ -1,8 +1,11 @@
 add_custom_target(test_normalization_bwd_gamma_beta)
 add_gtest_executable(test_layernorm2d_bwd_gamma_beta_fp32 test_layernorm2d_bwd_gamma_beta_fp32.cpp)
-target_link_libraries(test_layernorm2d_bwd_gamma_beta_fp32 PRIVATE utility device_normalization_bwd_gamma_beta_instance)
-add_dependencies(test_normalization_bwd_gamma_beta test_layernorm2d_bwd_gamma_beta_fp32)
-
+if (result EQUAL 0)
+    target_link_libraries(test_layernorm2d_bwd_gamma_beta_fp32 PRIVATE utility device_normalization_bwd_gamma_beta_instance)
+    add_dependencies(test_normalization_bwd_gamma_beta test_layernorm2d_bwd_gamma_beta_fp32)
+endif()
 add_gtest_executable(test_groupnorm_bwd_gamma_beta_fp32 test_groupnorm_bwd_gamma_beta_fp32.cpp)
-target_link_libraries(test_groupnorm_bwd_gamma_beta_fp32 PRIVATE utility device_normalization_bwd_gamma_beta_instance)
-add_dependencies(test_normalization_bwd_gamma_beta test_groupnorm_bwd_gamma_beta_fp32)
+if (result EQUAL 0)
+    target_link_libraries(test_groupnorm_bwd_gamma_beta_fp32 PRIVATE utility device_normalization_bwd_gamma_beta_instance)
+    add_dependencies(test_normalization_bwd_gamma_beta test_groupnorm_bwd_gamma_beta_fp32)
+endif()
\ No newline at end of file