// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include <algorithm>
#include <string>
#include <type_traits>

#include "ck_tile/core.hpp"
#include "ck_tile/ops/common.hpp"
#include "ck_tile/host/concat.hpp"

namespace ck_tile {

struct GemmProblem
{
    CK_TILE_HOST GemmProblem() = default;
    CK_TILE_HOST GemmProblem(
        index_t M_, index_t N_, index_t K_, index_t stride_A_, index_t stride_B_, index_t stride_C_)
        : M(M_), N(N_), K(K_), stride_A(stride_A_), stride_B(stride_B_), stride_C(stride_C_)
    {
    }

    index_t M;
    index_t N;
    index_t K;
    index_t stride_A;
    index_t stride_B;
    index_t stride_C;
};

struct GemmHostArgs : public GemmProblem
{
    CK_TILE_HOST GemmHostArgs() = default;
    CK_TILE_HOST GemmHostArgs(const void* a_ptr_,
                              const void* b_ptr_,
                              void* c_ptr_,
                              index_t k_batch_,
                              index_t M_,
                              index_t N_,
                              index_t K_,
                              index_t stride_A_,
                              index_t stride_B_,
                              index_t stride_C_)
        : GemmProblem(M_, N_, K_, stride_A_, stride_B_, stride_C_),
          a_ptr(a_ptr_),
          b_ptr(b_ptr_),
          c_ptr(c_ptr_),
          k_batch(k_batch_)
    {
    }

    const void* a_ptr;
    const void* b_ptr;
    void* c_ptr;
    index_t k_batch;
};

struct GemmKernelArgs
{
    const void* a_ptr;
    const void* b_ptr;
    void* c_ptr;
    index_t M;
    index_t N;
    index_t K;
    index_t stride_A;
    index_t stride_B;
    index_t stride_C;
    index_t k_batch;
};

template <typename TilePartitioner_, typename GemmPipeline_, typename EpiloguePipeline_>
struct GemmKernel
{
    using TilePartitioner  = remove_cvref_t<TilePartitioner_>;
    using GemmPipeline     = remove_cvref_t<GemmPipeline_>;
    using EpiloguePipeline = remove_cvref_t<EpiloguePipeline_>;
    using ALayout          = remove_cvref_t<typename GemmPipeline::ALayout>;
    using BLayout          = remove_cvref_t<typename GemmPipeline::BLayout>;
    using CLayout          = remove_cvref_t<typename GemmPipeline::CLayout>;

    static constexpr index_t KernelBlockSize = GemmPipeline::BlockSize;

    using ADataType = remove_cvref_t<typename GemmPipeline::ADataType>;
    using BDataType = remove_cvref_t<typename GemmPipeline::BDataType>;
    // Below type is actually the accumulation data type - the output of the block GEMM.
    using CDataType = remove_cvref_t<typename EpiloguePipeline::ODataType>;

    static constexpr auto I0 = number<0>();
    static constexpr auto I1 = number<1>();
    static constexpr auto I2 = number<2>();

    [[nodiscard]] CK_TILE_HOST static const std::string GetName()
    {
        // clang-format off
        return concat('_', "gemm", gemm_prec_str<ADataType, BDataType>, GemmPipeline::GetName());
        // clang-format on
    }

    CK_TILE_HOST static constexpr auto GridSize(index_t M, index_t N, index_t KBatch)
    {
        return dim3(TilePartitioner::GridSize(M, N), 1, KBatch);
    }

    CK_TILE_HOST static constexpr auto BlockSize() { return dim3(KernelBlockSize); }

    CK_TILE_HOST static constexpr GemmKernelArgs MakeKernelArgs(const GemmHostArgs& hostArgs)
    {
        return GemmKernelArgs{hostArgs.a_ptr,
                              hostArgs.b_ptr,
                              hostArgs.c_ptr,
                              hostArgs.M,
                              hostArgs.N,
                              hostArgs.K,
                              hostArgs.stride_A,
                              hostArgs.stride_B,
                              hostArgs.stride_C,
                              hostArgs.k_batch};
    }

    CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
    {
        return max(GemmPipeline::GetSmemSize(), EpiloguePipeline::GetSmemSize());
    }

    struct SplitKBatchOffset
    {
        __device__ SplitKBatchOffset(const GemmKernelArgs& kargs,
                                     const std::size_t k_id = blockIdx.z)
        {
            constexpr auto K1   = TilePartitioner::BlockGemmShape::WarpTile::at(number<2>{});
            const index_t K_t   = __builtin_amdgcn_readfirstlane(kargs.k_batch * K1);
            const index_t KRead = __builtin_amdgcn_readfirstlane((kargs.K + K_t - 1) / K_t * K1);

            if constexpr(std::is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
            {
                a_k_split_offset = __builtin_amdgcn_readfirstlane(k_id * KRead);
            }
            else if constexpr(std::is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)
            {
                a_k_split_offset = __builtin_amdgcn_readfirstlane(k_id * KRead * kargs.stride_A);
            }

            if constexpr(std::is_same_v<tensor_layout::gemm::RowMajor, BLayout>)
            {
                b_k_split_offset = __builtin_amdgcn_readfirstlane(k_id * KRead * kargs.stride_B);
            }
            else if constexpr(std::is_same_v<tensor_layout::gemm::ColumnMajor, BLayout>)
            {
                b_k_split_offset = __builtin_amdgcn_readfirstlane(k_id * KRead);
            }

            if(k_id < static_cast<std::size_t>(kargs.k_batch - 1))
            {
                splitted_k = __builtin_amdgcn_readfirstlane(KRead);
            }
            else
            {
                splitted_k = __builtin_amdgcn_readfirstlane(kargs.K - KRead * (kargs.k_batch - 1));
            }
        }

        index_t a_k_split_offset;
        index_t b_k_split_offset;
        index_t splitted_k;
    };
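    // Illustrative split-K arithmetic (hypothetical sizes, not taken from any particular
    // pipeline): with K = 1000, k_batch = 4 and a warp-tile K extent K1 = 32, we get
    // K_t = 4 * 32 = 128 and KRead = ((1000 + 127) / 128) * 32 = 8 * 32 = 256. Work-groups
    // with k_id 0..2 each consume splitted_k = 256 elements along K, and the last one
    // (k_id == 3) consumes the remainder 1000 - 3 * 256 = 232. For a row-major A, work-group
    // k_id starts reading at element offset k_id * KRead; for a column-major A that offset
    // is additionally scaled by stride_A.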
    CK_TILE_HOST static bool IsSupportedArgument(const GemmKernelArgs& kargs)
    {
        if constexpr(EpiloguePipeline::GetVectorSizeC() % 2 != 0 &&
                     is_any_of<CDataType, fp16_t, bf16_t>::value)
        {
            if(kargs.k_batch != 1)
            {
                if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
                {
                    CK_TILE_ERROR("Conditions not met for Kbatch > 1!");
                }
                return false;
            }
        }

        if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
        {
            if(kargs.K % (TilePartitioner::KPerBlock * kargs.k_batch) != 0 &&
               GemmPipeline::kPadK == false)
            {
                if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
                {
                    CK_TILE_ERROR("Can't support K that is not a multiple of k_batch * KPerBlock "
                                  "without padding!");
                }
                return false;
            }
            if(kargs.K % GemmPipeline::GetVectorSizeA() != 0)
            {
                if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
                {
                    CK_TILE_ERROR("K is not a multiple of vector load size for A tensor!");
                }
                return false;
            }
        }
        else
        {
            if(kargs.M % TilePartitioner::MPerBlock != 0 && GemmPipeline::kPadM == false)
            {
                if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
                {
                    CK_TILE_ERROR(
                        "Can't support M that is not a multiple of MPerBlock without padding!");
                }
                return false;
            }
            if(kargs.M % GemmPipeline::GetVectorSizeA() != 0)
            {
                if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
                {
                    CK_TILE_ERROR("M is not a multiple of vector load size for A tensor!");
                }
                return false;
            }
        }

        if constexpr(std::is_same_v<BLayout, tensor_layout::gemm::RowMajor>)
        {
            if(kargs.N % TilePartitioner::NPerBlock != 0 && GemmPipeline::kPadN == false)
            {
                if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
                {
                    CK_TILE_ERROR(
                        "Can't support N that is not a multiple of NPerBlock without padding!");
                }
                return false;
            }
            if(kargs.N % GemmPipeline::GetVectorSizeB() != 0)
            {
                if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
                {
                    CK_TILE_ERROR("N is not a multiple of vector load size for B tensor!");
                }
                return false;
            }
        }
        else
        {
            if(kargs.K % (TilePartitioner::KPerBlock * kargs.k_batch) != 0 &&
               GemmPipeline::kPadK == false)
            {
                if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
                {
                    CK_TILE_ERROR("Can't support K that is not a multiple of k_batch * KPerBlock "
                                  "without padding!");
                }
                return false;
            }
            if(kargs.K % GemmPipeline::GetVectorSizeB() != 0)
            {
                if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
                {
                    CK_TILE_ERROR("K is not a multiple of vector load size for B tensor!");
                }
                return false;
            }
        }

        if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::RowMajor>)
        {
            if(kargs.N % TilePartitioner::NPerBlock != 0 && GemmPipeline::kPadN == false)
            {
                if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
                {
                    CK_TILE_ERROR(
                        "Can't support N that is not a multiple of NPerBlock without padding!");
                }
                return false;
            }
            if(kargs.N % EpiloguePipeline::GetVectorSizeC() != 0)
            {
                if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
                {
                    CK_TILE_ERROR("N is not a multiple of vector load size for C tensor!");
                }
                return false;
            }
        }
        else
        {
            if(kargs.M % TilePartitioner::MPerBlock != 0 && GemmPipeline::kPadM == false)
            {
                if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
                {
                    CK_TILE_ERROR(
                        "Can't support M that is not a multiple of MPerBlock without padding!");
                }
                return false;
            }
            if(kargs.M % EpiloguePipeline::GetVectorSizeC() != 0)
            {
                if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
                {
                    CK_TILE_ERROR("M is not a multiple of vector load size for C tensor!");
                }
                return false;
            }
        }
        return true;
    }
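    // Illustrative example of the checks above (hypothetical tile parameters): with a
    // row-major A, KPerBlock = 32, k_batch = 4 and kPadK == false, K must be a multiple of
    // 32 * 4 = 128, so K = 1024 is accepted while K = 1000 is rejected. Independently, with
    // GemmPipeline::GetVectorSizeA() == 8, K = 1000 would pass the vector-size check
    // (1000 % 8 == 0) but K = 1002 would not.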
    template <memory_operation_enum DstInMemOp = memory_operation_enum::set>
    CK_TILE_DEVICE static auto MakeGemmTensorViews(const ADataType* a_ptr,
                                                   const BDataType* b_ptr,
                                                   CDataType* c_ptr,
                                                   const GemmKernelArgs& kargs,
                                                   const SplitKBatchOffset& splitk_batch_offset)
    {
        static_assert(!TilePartitioner::BlockGemmShape::PermuteA, "Not implemented!");

        const auto& a_tensor_view = [&]() {
            if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
            {
                return make_naive_tensor_view<address_space_enum::global>(
                    a_ptr,
                    make_tuple(kargs.M, splitk_batch_offset.splitted_k),
                    make_tuple(kargs.stride_A, 1),
                    number<GemmPipeline::GetVectorSizeA()>{},
                    number<1>{});
            }
            else
            {
                return make_naive_tensor_view<address_space_enum::global>(
                    a_ptr,
                    make_tuple(splitk_batch_offset.splitted_k, kargs.M),
                    make_tuple(kargs.stride_A, 1),
                    number<GemmPipeline::GetVectorSizeA()>{},
                    number<1>{});
            }
        }();

        const auto& b_tensor_view = [&]() {
            if constexpr(std::is_same_v<BLayout, tensor_layout::gemm::RowMajor>)
            {
                if constexpr(TilePartitioner::BlockGemmShape::PermuteB)
                {
                    constexpr index_t K1          = GemmPipeline::GetSmemPackB();
                    const index_t K0              = splitk_batch_offset.splitted_k / K1;
                    constexpr index_t VectorSizeB = std::min(K1, GemmPipeline::GetVectorSizeB());
                    const auto b_k0_n_k1_desc =
                        make_naive_tensor_descriptor(make_tuple(K0, kargs.N, K1),
                                                     make_tuple(kargs.N * K1, K1, I1),
                                                     number<VectorSizeB>{},
                                                     number<1>{});
                    const auto b_n_k_desc = transform_tensor_descriptor(
                        b_k0_n_k1_desc,
                        make_tuple(make_merge_transform(make_tuple(K0, K1)),
                                   make_pass_through_transform(kargs.N)),
                        make_tuple(sequence<0, 2>{}, sequence<1>{}),
                        make_tuple(sequence<0>{}, sequence<1>{}));
                    return make_tensor_view<address_space_enum::global>(b_ptr, b_n_k_desc);
                }
                else
                {
                    return make_naive_tensor_view<address_space_enum::global>(
                        b_ptr,
                        make_tuple(splitk_batch_offset.splitted_k, kargs.N),
                        make_tuple(kargs.stride_B, 1),
                        number<GemmPipeline::GetVectorSizeB()>{},
                        number<1>{});
                }
            }
            else
            {
                if constexpr(TilePartitioner::BlockGemmShape::PermuteB)
                {
                    constexpr index_t K1          = GemmPipeline::GetSmemPackB();
                    const index_t K0              = splitk_batch_offset.splitted_k / K1;
                    constexpr index_t VectorSizeB = std::min(K1, GemmPipeline::GetVectorSizeB());
                    const auto b_k0_n_k1_desc =
                        make_naive_tensor_descriptor(make_tuple(K0, kargs.N, K1),
                                                     make_tuple(kargs.N * K1, K1, I1),
                                                     number<VectorSizeB>{},
                                                     number<1>{});
                    const auto b_n_k_desc = transform_tensor_descriptor(
                        b_k0_n_k1_desc,
                        make_tuple(make_merge_transform(make_tuple(K0, K1)),
                                   make_pass_through_transform(kargs.N)),
                        make_tuple(sequence<0, 2>{}, sequence<1>{}),
                        make_tuple(sequence<1>{}, sequence<0>{}));
                    return make_tensor_view<address_space_enum::global>(b_ptr, b_n_k_desc);
                }
                else
                {
                    return make_naive_tensor_view<address_space_enum::global>(
                        b_ptr,
                        make_tuple(kargs.N, splitk_batch_offset.splitted_k),
                        make_tuple(kargs.stride_B, 1),
                        number<GemmPipeline::GetVectorSizeB()>{},
                        number<1>{});
                }
            }
        }();

        // TODO: enable vector write for C in ColMajor
        const auto& c_tensor_view = [&]() {
            if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::RowMajor>)
            {
                return make_naive_tensor_view<address_space_enum::global, DstInMemOp>(
                    c_ptr,
                    make_tuple(kargs.M, kargs.N),
                    make_tuple(kargs.stride_C, 1),
                    number<EpiloguePipeline::GetVectorSizeC()>{},
                    number<1>{});
            }
            else
            {
                return make_naive_tensor_view<address_space_enum::global, DstInMemOp>(
                    c_ptr,
                    make_tuple(kargs.M, kargs.N),
                    make_tuple(1, kargs.stride_C),
                    number<1>{},
                    number<1>{});
            }
        }();

        return make_tuple(a_tensor_view, b_tensor_view, c_tensor_view);
    }
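    // Illustrative PermuteB addressing (hypothetical sizes): with GemmPipeline::GetSmemPackB()
    // giving K1 = 8 and splitted_k = 256, b_k0_n_k1_desc above views B as a
    // (K0, N, K1) = (32, N, 8) box with strides (N * 8, 8, 1). After the merge transform
    // fuses (K0, K1) back into K, element (k, n) of the logical (K, N) view resolves to the
    // flat offset (k / 8) * N * 8 + n * 8 + (k % 8).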
    template <typename TensorView>
    CK_TILE_DEVICE static auto MakeGemmPadViews(const TensorView& views)
    {
        const auto& a_pad_view = [&]() {
            const auto& a_tensor_view = views.at(I0);
            if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
            {
                return pad_tensor_view(a_tensor_view,
                                       make_tuple(number<TilePartitioner::MPerBlock>{},
                                                  number<TilePartitioner::KPerBlock>{}),
                                       sequence<false, GemmPipeline::kPadK>{});
            }
            else
            {
                return pad_tensor_view(a_tensor_view,
                                       make_tuple(number<TilePartitioner::KPerBlock>{},
                                                  number<TilePartitioner::MPerBlock>{}),
                                       sequence<false, GemmPipeline::kPadM>{});
            }
        }();

        const auto& b_pad_view = [&]() {
            const auto& b_tensor_view = views.at(I1);
            if constexpr(std::is_same_v<BLayout, tensor_layout::gemm::ColumnMajor>)
            {
                return pad_tensor_view(b_tensor_view,
                                       make_tuple(number<TilePartitioner::NPerBlock>{},
                                                  number<TilePartitioner::KPerBlock>{}),
                                       sequence<false, GemmPipeline::kPadK>{});
            }
            else
            {
                return pad_tensor_view(b_tensor_view,
                                       make_tuple(number<TilePartitioner::KPerBlock>{},
                                                  number<TilePartitioner::NPerBlock>{}),
                                       sequence<false, GemmPipeline::kPadN>{});
            }
        }();

        // TODO: enable vector write for C in ColMajor
        const auto& c_pad_view = [&]() {
            const auto& c_tensor_view = views.at(I2);
            if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::RowMajor>)
            {
                return pad_tensor_view(c_tensor_view,
                                       make_tuple(number<TilePartitioner::MPerBlock>{},
                                                  number<TilePartitioner::NPerBlock>{}),
                                       sequence<false, GemmPipeline::kPadN>{});
            }
            else
            {
                return pad_tensor_view(c_tensor_view,
                                       make_tuple(number<TilePartitioner::MPerBlock>{},
                                                  number<TilePartitioner::NPerBlock>{}),
                                       sequence<GemmPipeline::kPadM, false>{});
            }
        }();

        return make_tuple(a_pad_view, b_pad_view, c_pad_view);
    }

    template <typename PadView>
    CK_TILE_DEVICE static auto
    MakeGemmTileWindows(const PadView& views, const index_t i_m, const index_t i_n)
    {
        const auto& a_pad_view = views.at(I0);
        const auto& b_pad_view = views.at(I1);
        const auto& c_pad_view = views.at(I2);

        const auto& a_block_window = [&]() {
            if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
            {
                return make_tile_window(a_pad_view,
                                        make_tuple(number<TilePartitioner::MPerBlock>{},
                                                   number<TilePartitioner::KPerBlock>{}),
                                        {i_m, 0});
            }
            else
            {
                return make_tile_window(a_pad_view,
                                        make_tuple(number<TilePartitioner::KPerBlock>{},
                                                   number<TilePartitioner::MPerBlock>{}),
                                        {0, i_m});
            }
        }();

        const auto& b_block_window = [&]() {
            if constexpr(std::is_same_v<BLayout, tensor_layout::gemm::ColumnMajor>)
            {
                return make_tile_window(b_pad_view,
                                        make_tuple(number<TilePartitioner::NPerBlock>{},
                                                   number<TilePartitioner::KPerBlock>{}),
                                        {i_n, 0});
            }
            else
            {
                return make_tile_window(b_pad_view,
                                        make_tuple(number<TilePartitioner::KPerBlock>{},
                                                   number<TilePartitioner::NPerBlock>{}),
                                        {0, i_n});
            }
        }();

        auto c_block_window = make_tile_window(
            c_pad_view,
            make_tuple(number<TilePartitioner::MPerBlock>{}, number<TilePartitioner::NPerBlock>{}),
            {i_m, i_n});

        return make_tuple(a_block_window, b_block_window, c_block_window);
    }

    /**
     * @brief Runs single GEMM problem cooperatively by whole workgroup.
     *
     * @param a_ptr input A pointer
     * @param b_ptr input B pointer
     * @param c_ptr output C pointer
     * @param smem_ptr_0 The starting pointer of the shared memory block.
     * @param kargs GEMM kernel arguments
     * @param splitk_batch_offset Utility structure used to calculate the k-batch offsets.
     * @param block_idx_m The GEMM's output M dimension tile index processed by this workgroup.
     * @param block_idx_n The GEMM's output N dimension tile index processed by this workgroup.
     *
     * @tparam DstInMemOp Destination memory operation (default: set).
     */
    template <memory_operation_enum DstInMemOp = memory_operation_enum::set>
    CK_TILE_DEVICE static void RunGemm(const ADataType* a_ptr,
                                       const BDataType* b_ptr,
                                       CDataType* c_ptr,
                                       void* smem_ptr_0,
                                       const GemmKernelArgs& kargs,
                                       const SplitKBatchOffset& splitk_batch_offset,
                                       const index_t block_idx_m,
                                       const index_t block_idx_n)
    {
        // Create Gemm tensor views, pad views and tile windows
        const auto& gemm_tensor_views_tuple =
            MakeGemmTensorViews<DstInMemOp>(a_ptr, b_ptr, c_ptr, kargs, splitk_batch_offset);

        const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple);
        auto gemm_tile_windows     = MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n);

        const index_t num_loop = __builtin_amdgcn_readfirstlane(
            TilePartitioner::GetLoopNum(splitk_batch_offset.splitted_k));

        // Run GEMM cooperatively by whole workgroup.
        const auto& a_block_window = gemm_tile_windows.at(I0);
        const auto& b_block_window = gemm_tile_windows.at(I1);

        const auto& c_block_tile = GemmPipeline{}.template operator()(
            a_block_window, b_block_window, num_loop, smem_ptr_0);

        // Run Epilogue Pipeline
        auto& c_block_window = gemm_tile_windows.at(I2);

        EpiloguePipeline{}
            .template operator()<decltype(c_block_window), decltype(c_block_tile), DstInMemOp>(
                c_block_window, c_block_tile, smem_ptr_0);
    }
    /**
     * @brief Runs single GEMM problem cooperatively by whole workgroup.
     *
     * @note RunGemm2LDS runs the GEMM with two shared memory buffers using a ping-pong
     *       buffering scheme.
     *
     * @param a_ptr input A pointer
     * @param b_ptr input B pointer
     * @param c_ptr output C pointer
     * @param smem_ptr_0 The starting pointer of 1st shared memory block.
     * @param smem_ptr_1 The starting pointer of 2nd shared memory block.
     * @param kargs GEMM kernel arguments
     * @param splitk_batch_offset Utility structure used to calculate the k-batch offsets.
     * @param block_idx_m The GEMM's output M dimension tile index processed by this workgroup.
     * @param block_idx_n The GEMM's output N dimension tile index processed by this workgroup.
     *
     * @tparam DstInMemOp Destination memory operation (default: set).
     */
    template <memory_operation_enum DstInMemOp = memory_operation_enum::set>
    CK_TILE_DEVICE static void RunGemm2LDS(const ADataType* a_ptr,
                                           const BDataType* b_ptr,
                                           CDataType* c_ptr,
                                           void* __restrict__ smem_ptr_0,
                                           void* __restrict__ smem_ptr_1,
                                           const GemmKernelArgs& kargs,
                                           const SplitKBatchOffset& splitk_batch_offset,
                                           const index_t block_idx_m,
                                           const index_t block_idx_n)
    {
        // Create Gemm tensor views, pad views and tile windows
        const auto& gemm_tensor_views_tuple =
            MakeGemmTensorViews<DstInMemOp>(a_ptr, b_ptr, c_ptr, kargs, splitk_batch_offset);

        const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple);
        auto gemm_tile_windows     = MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n);

        const index_t num_loop = __builtin_amdgcn_readfirstlane(
            TilePartitioner::GetLoopNum(splitk_batch_offset.splitted_k));

        // Run GEMM cooperatively by whole workgroup.
        const auto& a_block_window = gemm_tile_windows.at(I0);
        const auto& b_block_window = gemm_tile_windows.at(I1);

        const auto& c_block_tile = GemmPipeline{}.template operator()(
            a_block_window, b_block_window, num_loop, smem_ptr_0, smem_ptr_1);

        // Run Epilogue Pipeline
        auto& c_block_window = gemm_tile_windows.at(I2);

        EpiloguePipeline{}
            .template operator()<decltype(c_block_window), decltype(c_block_tile), DstInMemOp>(
                c_block_window, c_block_tile, smem_ptr_0);
    }

    CK_TILE_DEVICE void operator()(GemmKernelArgs kargs) const
    {
        const auto blockId  = __builtin_amdgcn_readfirstlane(blockIdx.x);
        const auto [iM, iN] = TilePartitioner{kargs.M, kargs.N}.GetOutputTileIndex(blockId);
        const index_t i_m   = __builtin_amdgcn_readfirstlane(iM * TilePartitioner::MPerBlock);
        const index_t i_n   = __builtin_amdgcn_readfirstlane(iN * TilePartitioner::NPerBlock);

        const SplitKBatchOffset splitk_batch_offset(kargs);

        // options
        const ADataType* a_ptr =
            static_cast<const ADataType*>(kargs.a_ptr) + splitk_batch_offset.a_k_split_offset;
        const BDataType* b_ptr =
            static_cast<const BDataType*>(kargs.b_ptr) + splitk_batch_offset.b_k_split_offset;
        CDataType* c_ptr = static_cast<CDataType*>(kargs.c_ptr);

        // allocate LDS
        __shared__ char smem_ptr_0[GetSmemSize()];

        if constexpr(GemmPipeline::DoubleSmemBuffer == true)
        {
            __shared__ char smem_ptr_1[GetSmemSize()];
            if(kargs.k_batch == 1)
            {
                RunGemm2LDS(a_ptr,
                            b_ptr,
                            c_ptr,
                            smem_ptr_0,
                            smem_ptr_1,
                            kargs,
                            splitk_batch_offset,
                            i_m,
                            i_n);
            }
            else
            {
                if constexpr(!(EpiloguePipeline::GetVectorSizeC() % 2 != 0 &&
                               is_any_of<CDataType, fp16_t, bf16_t>::value))
                {
                    RunGemm2LDS<memory_operation_enum::atomic_add>(a_ptr,
                                                                   b_ptr,
                                                                   c_ptr,
                                                                   smem_ptr_0,
                                                                   smem_ptr_1,
                                                                   kargs,
                                                                   splitk_batch_offset,
                                                                   i_m,
                                                                   i_n);
                }
            }
        }
        else
        {
            if(kargs.k_batch == 1)
            {
                RunGemm(a_ptr, b_ptr, c_ptr, smem_ptr_0, kargs, splitk_batch_offset, i_m, i_n);
            }
            else
            {
                if constexpr(!(EpiloguePipeline::GetVectorSizeC() % 2 != 0 &&
                               is_any_of<CDataType, fp16_t, bf16_t>::value))
                {
                    RunGemm<memory_operation_enum::atomic_add>(
                        a_ptr, b_ptr, c_ptr, smem_ptr_0, kargs, splitk_batch_offset, i_m, i_n);
                }
            }
        }
    }
};

} // namespace ck_tile
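// ---------------------------------------------------------------------------
// Host-side usage sketch (not part of the kernel itself). The partitioner and
// pipeline types below are hypothetical placeholders; any concrete types
// satisfying the interfaces used above (BlockSize, GetSmemSize, GetLoopNum,
// GetVectorSize*, ...) would do:
//
//   using Kernel = ck_tile::GemmKernel<MyTilePartitioner, MyGemmPipeline, MyEpilogue>;
//
//   ck_tile::GemmHostArgs host_args(a_dev, b_dev, c_dev, /*k_batch=*/1,
//                                   M, N, K, stride_A, stride_B, stride_C);
//   const auto kargs = Kernel::MakeKernelArgs(host_args);
//   if(Kernel::IsSupportedArgument(kargs))
//   {
//       const dim3 grids  = Kernel::GridSize(M, N, host_args.k_batch);
//       const dim3 blocks = Kernel::BlockSize();
//       // Launch Kernel{} with (grids, blocks), e.g. via a ck_tile launch
//       // wrapper such as ck_tile::launch_kernel / ck_tile::make_kernel.
//   }
// ---------------------------------------------------------------------------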