// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT

// NOTE(review): this copy of the file appears to have lost every `<...>` span. Affected:
// the two header names after the first two #include directives, the template parameter
// list of GemmKernel, and the template arguments of UniversalGemmKernel, remove_cvref_t,
// is_detected and (presumably) UniversalGemmHostArgs. As written it cannot compile; restore
// those spans from the upstream ck_tile source. Each affected site is marked below.

#pragma once

// NOTE(review): header names missing on the next two lines; presumably standard headers
// such as <string> (std::string is used by GemmKernel::GetName) -- confirm against upstream.
#include
#include

#include "ck_tile/core.hpp"
#include "ck_tile/ops/common.hpp"
#include "ck_tile/host/concat.hpp"
#include "ck_tile/host/kernel_launch.hpp"
#include "ck_tile/host/stream_utils.hpp"
#include "ck_tile/core/utility/env.hpp"
#include "ck_tile/ops/gemm/kernel/universal_gemm_kernel.hpp"
#include "ck_tile/core/utility/type_traits.hpp"

namespace ck_tile {

/// @brief The GEMM kernel host arguments.
///
/// @par Overview
///      This structure is passed to @ref GemmKernel "GemmKernel" when creating the kernel
///      arguments object (see GemmKernel::MakeKernelArgs). It contains all the information
///      required to build the kernel arguments and launch the kernel on the GPU.
///      It describes the GEMM problem configuration: the M, N and K sizes, type-erased
///      pointers to the A, B and output (E, also reachable as C) buffers, their strides, and
///      the k_batch count.
struct GemmHostArgs
{
    /// @brief Default constructor.
    /// @note  No member has a default initializer, so default-initialization
    ///        (`GemmHostArgs args;`) leaves every field indeterminate. Assign every field, or
    ///        value-initialize with `GemmHostArgs{}`, before passing the object on.
    CK_TILE_HOST GemmHostArgs() = default;

    /// @brief Construct from explicit problem parameters.
    /// @note  `k_batch_` is the 4th parameter even though `k_batch` is the last data member;
    ///        the initializer list below follows member declaration order, not parameter order.
    /// @param a_ptr_    Type-erased pointer to matrix A.
    /// @param b_ptr_    Type-erased pointer to matrix B.
    /// @param e_ptr_    Type-erased pointer to the output matrix E (also reachable as c_ptr).
    /// @param k_batch_  K-batch count, forwarded unchanged to UniversalGemmHostArgs
    ///                  (presumably the split-K factor -- confirm against UniversalGemmKernel).
    /// @param M_        GEMM M dimension.
    /// @param N_        GEMM N dimension.
    /// @param K_        GEMM K dimension.
    /// @param stride_A_ Stride of A.
    /// @param stride_B_ Stride of B.
    /// @param stride_E_ Stride of E (also reachable as stride_C).
    CK_TILE_HOST GemmHostArgs(const void* a_ptr_,
                              const void* b_ptr_,
                              void* e_ptr_,
                              index_t k_batch_,
                              index_t M_,
                              index_t N_,
                              index_t K_,
                              index_t stride_A_,
                              index_t stride_B_,
                              index_t stride_E_)
        : a_ptr(a_ptr_),
          b_ptr(b_ptr_),
          e_ptr(e_ptr_),
          M(M_),
          N(N_),
          K(K_),
          stride_A(stride_A_),
          stride_B(stride_B_),
          stride_E(stride_E_),
          k_batch(k_batch_)
    {
    }

    const void* a_ptr; ///< Type-erased pointer to A.
    const void* b_ptr; ///< Type-erased pointer to B.
    // e_ptr and c_ptr share storage: they are two names for the same output pointer.
    union
    {
        void* e_ptr;
        void* c_ptr;
    };
    index_t M;        ///< GEMM M dimension.
    index_t N;        ///< GEMM N dimension.
    index_t K;        ///< GEMM K dimension.
    index_t stride_A; ///< Stride of A.
    index_t stride_B; ///< Stride of B.
    // stride_E and stride_C share storage: they are two names for the same output stride.
    union
    {
        index_t stride_E;
        index_t stride_C;
    };
    index_t k_batch; ///< K-batch count passed through to UniversalGemmHostArgs.
};

// NOTE(review): the template parameter list after `template` is missing (presumably the tile
// partitioner, GEMM pipeline and epilogue pipeline types) -- restore from upstream.
/// @brief GEMM kernel with exactly one A tensor, one B tensor and no D tensors.
///
/// @par Overview
///      Thin wrapper around UniversalGemmKernel. Every host helper and the device entry point
///      forward directly to the universal kernel. The only logic of its own is
///      MakeKernelArgs, which translates GemmHostArgs into UniversalGemmHostArgs.
template struct GemmKernel
{
    /// @brief Alias of the UniversalGemmKernel type that every member below forwards to.
    ///        The alias deliberately reuses the name, so inside GemmKernel `UniversalGemmKernel`
    ///        refers to this alias (not a base class -- GemmKernel has no base).
    // NOTE(review): template arguments missing on the right-hand side -- restore from upstream.
    using UniversalGemmKernel = UniversalGemmKernel;

    // NOTE(review): every remove_cvref_t below is missing its template argument.
    using TilePartitioner = remove_cvref_t;
    using GemmPipeline = remove_cvref_t;
    using EpiloguePipeline = remove_cvref_t;

    /// @brief Layout configurations for A, B and C/E (this kernel uses no D tensors).
    using ALayout = remove_cvref_t;
    using BLayout = remove_cvref_t;
    using CLayout = remove_cvref_t;

    /// @brief Data type configurations for A, B and E (this kernel uses no D tensors).
    using ADataType = remove_cvref_t;
    using BDataType = remove_cvref_t;
    using EDataType = remove_cvref_t;

    // NOTE(review): is_detected is missing its template arguments in the three asserts below.
    // ALayout and ADataType are expected to be scalars, not tuples.
    static_assert(
        !is_detected::value && !is_detected::value,
        "ALayout and ADataType must be scalars. Multiple parameters are not currently supported.");

    // BLayout and BDataType are expected to be scalars, not tuples.
    static_assert(
        !is_detected::value && !is_detected::value,
        "BLayout and BDataType must be scalars. Multiple parameters are not currently supported.");

    // C/CLayout and C/EDataType are expected to be scalars, not tuples.
    static_assert(!is_detected::value && !is_detected::value,
                  "C/CLayout and C/EDataType must be scalars.");

    /// @brief This kernel always consumes exactly one A tensor and one B tensor.
    static constexpr index_t NumATensor = 1;
    static constexpr index_t NumBTensor = 1;

    /// @brief Block size, taken unchanged from UniversalGemmKernel.
    static constexpr index_t kBlockSize = UniversalGemmKernel::kBlockSize;

    /// @return The kernel name reported by UniversalGemmKernel::GetName().
    CK_TILE_HOST static auto GetName() -> const std::string
    {
        return UniversalGemmKernel::GetName();
    }

    /// @brief Launch grid for the given problem; forwards to UniversalGemmKernel::GridSize.
    /// @param M      GEMM M dimension.
    /// @param N      GEMM N dimension.
    /// @param KBatch K-batch count.
    CK_TILE_HOST static constexpr auto GridSize(index_t M, index_t N, index_t KBatch) -> dim3
    {
        return UniversalGemmKernel::GridSize(M, N, KBatch);
    }

    /// @brief Forwards to UniversalGemmKernel::MaxOccupancyGridSize for stream config @p s.
    CK_TILE_HOST static auto MaxOccupancyGridSize(const stream_config& s) -> dim3
    {
        return UniversalGemmKernel::MaxOccupancyGridSize(s);
    }

    /// @brief Block dimensions; forwards to UniversalGemmKernel::BlockSize.
    CK_TILE_HOST static constexpr auto BlockSize() -> dim3
    {
        return UniversalGemmKernel::BlockSize();
    }

    /// @brief Build device kernel arguments from GEMM host arguments.
    /// @param hostArgs Problem description (pointers, sizes, strides, k_batch).
    /// @return UniversalGemmKernel::KernelArgs produced by UniversalGemmKernel::MakeKernelArgs.
    CK_TILE_HOST static constexpr auto MakeKernelArgs(const GemmHostArgs& hostArgs)
        -> typename UniversalGemmKernel::KernelArgs
    {
        // Universal GEMM requires array objects and corresponding stride information for
        // matrices A and B. The single A/B pointer and stride are therefore wrapped in braces,
        // and empty brace lists are passed for the D tensors, which this kernel does not use.
        // NOTE(review): UniversalGemmHostArgs appears to be missing its template arguments
        // (presumably the A/B/D tensor counts, cf. NumATensor/NumBTensor) -- confirm upstream.
        return UniversalGemmKernel::MakeKernelArgs(
            UniversalGemmHostArgs({hostArgs.a_ptr},
                                  {hostArgs.b_ptr},
                                  {/*hostArgs.ds_ptr*/},
                                  hostArgs.e_ptr,
                                  hostArgs.k_batch,
                                  hostArgs.M,
                                  hostArgs.N,
                                  hostArgs.K,
                                  {hostArgs.stride_A},
                                  {hostArgs.stride_B},
                                  {/*hostArgs.stride_Ds*/},
                                  hostArgs.stride_E));
    }

    /// @brief Forwards to UniversalGemmKernel::IsSupportedArgument.
    /// @return true when the universal kernel accepts @p kargs.
    CK_TILE_HOST static auto
    IsSupportedArgument(const typename UniversalGemmKernel::KernelArgs& kargs) -> bool
    {
        return UniversalGemmKernel::IsSupportedArgument(kargs);
    }

    /// @brief Device entry point: default-constructs a UniversalGemmKernel and invokes its call
    ///        operator with @p kargs.
    // NOTE(review): `.template operator()` has no explicit template argument list here;
    // confirm against upstream whether one was lost along with the other `<...>` spans.
    CK_TILE_DEVICE auto operator()(typename UniversalGemmKernel::KernelArgs kargs) const -> void
    {
        UniversalGemmKernel{}.template operator()(kargs);
    }
};
} // namespace ck_tile