// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include <iostream>
#include <string>

#include "ck_tile/core.hpp"
#include "ck_tile/ops/common.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp"

namespace ck_tile {

struct FlatmmProblem
{
    CK_TILE_HOST FlatmmProblem() = default;
    CK_TILE_HOST FlatmmProblem(index_t M_,
                               index_t N_,
                               index_t K_,
                               index_t stride_A_,
                               index_t stride_B_,
                               index_t stride_C_)
        : M(M_), N(N_), K(K_), stride_A(stride_A_), stride_B(stride_B_), stride_C(stride_C_)
    {
    }

    index_t M;
    index_t N;
    index_t K;
    index_t stride_A;
    index_t stride_B;
    index_t stride_C;
};

struct FlatmmHostArgs : public FlatmmProblem
{
    CK_TILE_HOST FlatmmHostArgs() = default;
    CK_TILE_HOST FlatmmHostArgs(const void* a_ptr_,
                                const void* b_shuffle_ptr_,
                                void* c_ptr_,
                                index_t k_batch_,
                                index_t M_,
                                index_t N_,
                                index_t K_,
                                index_t stride_A_,
                                index_t stride_B_,
                                index_t stride_C_)
        : FlatmmProblem(M_, N_, K_, stride_A_, stride_B_, stride_C_),
          a_ptr(a_ptr_),
          b_shuffle_ptr(b_shuffle_ptr_),
          c_ptr(c_ptr_),
          k_batch(k_batch_)
    {
    }

    const void* a_ptr;
    const void* b_shuffle_ptr;
    void* c_ptr;
    index_t k_batch;
};

template <typename TilePartitioner_, typename FlatmmPipeline_, typename EpiloguePipeline_>
struct FlatmmKernel
{
    using TilePartitioner  = remove_cvref_t<TilePartitioner_>;
    using FlatmmPipeline   = remove_cvref_t<FlatmmPipeline_>;
    using BlockGemmShape   = remove_cvref_t<typename FlatmmPipeline::BlockGemmShape>; // TileFlatmmShape
    using EpiloguePipeline = remove_cvref_t<EpiloguePipeline_>;
    using ALayout          = remove_cvref_t<typename FlatmmPipeline::ALayout>;
    using BLayout          = remove_cvref_t<typename FlatmmPipeline::BLayout>;
    using CLayout          = remove_cvref_t<typename FlatmmPipeline::CLayout>;
    static constexpr index_t KernelBlockSize = FlatmmPipeline::BlockSize;

    using ADataType = remove_cvref_t<typename FlatmmPipeline::ADataType>;
    using BDataType = remove_cvref_t<typename FlatmmPipeline::BDataType>;
    // Below type is actually accumulation data type - the output of block GEMM.
    using CDataType = remove_cvref_t<typename FlatmmPipeline::CDataType>;

    static constexpr auto I0 = number<0>();
    static constexpr auto I1 = number<1>();
    static constexpr auto I2 = number<2>();

    static constexpr auto idxM = I0;
    static constexpr auto idxN = I1;
    static constexpr auto idxK = I2;

    [[nodiscard]] CK_TILE_HOST static const std::string GetName()
    {
        // clang-format off
        return concat('_', "gemm", gemm_prec_str<ADataType, BDataType>, FlatmmPipeline::GetName());
        // clang-format on
    }

    CK_TILE_HOST static constexpr auto GridSize(index_t M, index_t N, index_t KBatch)
    {
        return dim3(TilePartitioner::GridSize(M, N), 1, KBatch);
    }

    CK_TILE_HOST static constexpr auto BlockSize() { return dim3(KernelBlockSize); }

    struct FlatmmKernelArgs
    {
        const void* a_ptr;
        const void* b_shuffle_ptr;
        void* c_ptr;
        index_t M;
        index_t N;
        index_t K;
        index_t stride_A;
        index_t stride_B;
        index_t stride_C;
        index_t k_batch;
    };

    CK_TILE_HOST static constexpr FlatmmKernelArgs MakeKernelArgs(const FlatmmHostArgs& hostArgs)
    {
        return FlatmmKernelArgs{hostArgs.a_ptr,
                                hostArgs.b_shuffle_ptr,
                                hostArgs.c_ptr,
                                hostArgs.M,
                                hostArgs.N,
                                hostArgs.K,
                                hostArgs.stride_A,
                                hostArgs.stride_B,
                                hostArgs.stride_C,
                                hostArgs.k_batch};
    }

    CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
    {
        return max(FlatmmPipeline::GetSmemSize(), EpiloguePipeline::GetSmemSize());
    }

    struct SplitKBatchOffset
    {
        __device__ SplitKBatchOffset(const FlatmmKernelArgs& kargs,
                                     const std::size_t k_id = blockIdx.z)
        {
            // Per-split K read size, rounded up to a multiple of the warp-tile K dimension.
            constexpr auto K1   = TilePartitioner::BlockGemmShape::WarpTile::at(number<2>{});
            const index_t K_t   = kargs.k_batch * K1;
            const index_t KRead = (kargs.K + K_t - 1) / K_t * K1;

            if constexpr(std::is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
            {
                a_k_split_offset = k_id * KRead;
            }
            else if constexpr(std::is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)
            {
                a_k_split_offset = k_id * KRead * kargs.stride_A;
            }

            if constexpr(std::is_same_v<tensor_layout::gemm::RowMajor, BLayout>)
            {
                b_k_split_offset = k_id * KRead * kargs.stride_B;
            }
            else if constexpr(std::is_same_v<tensor_layout::gemm::ColumnMajor, BLayout>)
            {
                b_k_split_offset = k_id * KRead;
            }

            // The last split consumes whatever K remains.
            if(k_id < static_cast<std::size_t>(kargs.k_batch - 1))
            {
                splitted_k = KRead;
            }
            else
            {
                splitted_k = kargs.K - KRead * (kargs.k_batch - 1);
            }
        }

        index_t a_k_split_offset;
        index_t b_k_split_offset;
        index_t splitted_k;
    };
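    // Worked example of the split-K bookkeeping above, assuming (for illustration only) a
    // warp-tile K of K1 = 32, kargs.K = 160 and kargs.k_batch = 2:
    //   K_t   = 2 * 32 = 64
    //   KRead = (160 + 64 - 1) / 64 * 32 = 96
    // Block z = 0 therefore reads splitted_k = 96 K-elements and the last block (z = 1) reads
    // the remaining 160 - 96 = 64. For a row-major A the k-offsets are 0 and 96 elements; for a
    // row-major B they are additionally scaled by stride_B.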
    CK_TILE_HOST static bool IsSupportedArgument(const FlatmmKernelArgs& kargs)
    {
        if constexpr(EpiloguePipeline::GetVectorSizeC() % 2 != 0 &&
                     is_any_of<CDataType, fp16_t, bf16_t>::value)
        {
            if(kargs.k_batch != 1)
            {
                std::cerr << "Conditions not met for Kbatch >1 !" << std::endl;
                return false;
            }
        }

        if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
        {
            if(kargs.K % TilePartitioner::KPerBlock != 0 && FlatmmPipeline::kPadK == false)
            {
                std::cerr << "Can't support K that is not a multiple of KPerBlock"
                             " without padding!"
                          << std::endl;
                return false;
            }
            if(kargs.K % FlatmmPipeline::GetVectorSizeA() != 0)
            {
                std::cerr << "K is not a multiple of vector load size for A tensor!" << std::endl;
                return false;
            }
        }
        else
        {
            if(kargs.M % TilePartitioner::MPerBlock != 0 && FlatmmPipeline::kPadM == false)
            {
                std::cerr << "Can't support M that is not a multiple of MPerBlock"
                             " without padding!"
                          << std::endl;
                return false;
            }
            if(kargs.M % FlatmmPipeline::GetVectorSizeA() != 0)
            {
                std::cerr << "M is not a multiple of vector load size for A tensor!" << std::endl;
                return false;
            }
        }

        if constexpr(std::is_same_v<BLayout, tensor_layout::gemm::RowMajor>)
        {
            if(kargs.N % TilePartitioner::NPerBlock != 0 && FlatmmPipeline::kPadN == false)
            {
                std::cerr << "Can't support N that is not a multiple of NPerBlock"
                             " without padding!"
                          << std::endl;
                return false;
            }
            if(kargs.N % FlatmmPipeline::GetVectorSizeB() != 0)
            {
                std::cerr << "N is not a multiple of vector load size for B tensor!" << std::endl;
                return false;
            }
        }
        else
        {
            if(kargs.K % TilePartitioner::KPerBlock != 0 && FlatmmPipeline::kPadK == false)
            {
                std::cerr << "Can't support K that is not a multiple of KPerBlock"
                             " without padding!"
                          << std::endl;
                return false;
            }
            if(kargs.K % FlatmmPipeline::GetVectorSizeB() != 0)
            {
                std::cerr << "K is not a multiple of vector load size for B tensor!" << std::endl;
                return false;
            }
        }

        if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::RowMajor>)
        {
            if(kargs.N % TilePartitioner::NPerBlock != 0 && FlatmmPipeline::kPadN == false)
            {
                std::cerr << "Can't support N that is not a multiple of NPerBlock"
                             " without padding!"
                          << std::endl;
                return false;
            }
            if(kargs.N % EpiloguePipeline::GetVectorSizeC() != 0)
            {
                std::cerr << "N is not a multiple of vector load size for C tensor!" << std::endl;
                return false;
            }
        }
        else
        {
            if(kargs.M % TilePartitioner::MPerBlock != 0 && FlatmmPipeline::kPadM == false)
            {
                std::cerr << "Can't support M that is not a multiple of MPerBlock"
                             " without padding!"
                          << std::endl;
                return false;
            }
            if(kargs.M % EpiloguePipeline::GetVectorSizeC() != 0)
            {
                std::cerr << "M is not a multiple of vector load size for C tensor!" << std::endl;
                return false;
            }
        }
        return true;
    }

    template <memory_operation_enum DstInMemOp = memory_operation_enum::set>
    CK_TILE_DEVICE static auto MakeGemmTensorViews(const ADataType* a_ptr,
                                                   const BDataType* b_flat_ptr,
                                                   CDataType* c_ptr,
                                                   const FlatmmKernelArgs& kargs,
                                                   const SplitKBatchOffset& splitk_batch_offset)
    {
        const auto& a_tensor_view = [&]() {
            if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
            {
                return make_naive_tensor_view<address_space_enum::global>(
                    a_ptr,
                    make_tuple(kargs.M, splitk_batch_offset.splitted_k),
                    make_tuple(kargs.stride_A, 1),
                    number<FlatmmPipeline::GetVectorSizeA()>{},
                    number<1>{});
            }
            else
            {
                return make_naive_tensor_view<address_space_enum::global>(
                    a_ptr,
                    make_tuple(splitk_batch_offset.splitted_k, kargs.M),
                    make_tuple(kargs.stride_A, 1),
                    number<FlatmmPipeline::GetVectorSizeA()>{},
                    number<1>{});
            }
        }();

        // Dimensions of the pre-shuffled (flat) B tensor view.
        index_t kFlatK =
            FlatmmPipeline::flatKPerWarp *
            (splitk_batch_offset.splitted_k / BlockGemmShape::WarpTile::at(number<2>{}));
        index_t kFlatN = kargs.N * kargs.K / kFlatK;

        const auto& b_flat_tensor_view = [&]() {
            return make_naive_tensor_view<address_space_enum::global>(
                b_flat_ptr,
                make_tuple(kFlatN, kFlatK),
                make_tuple(kFlatK, 1),
                number<FlatmmPipeline::GetVectorSizeB()>{},
                number<1>{});
        }();

        // TODO: enable vector write for C in ColMajor
        const auto& c_tensor_view = [&]() {
            if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::RowMajor>)
            {
                return make_naive_tensor_view<address_space_enum::global, DstInMemOp>(
                    c_ptr,
                    make_tuple(kargs.M, kargs.N),
                    make_tuple(kargs.stride_C, 1),
                    number<EpiloguePipeline::GetVectorSizeC()>{},
                    number<1>{});
            }
            else
            {
                return make_naive_tensor_view<address_space_enum::global, DstInMemOp>(
                    c_ptr,
                    make_tuple(kargs.M, kargs.N),
                    make_tuple(1, kargs.stride_C),
                    number<1>{},
                    number<1>{});
            }
        }();

        return make_tuple(a_tensor_view, b_flat_tensor_view, c_tensor_view);
    }

    template <typename TensorView>
    CK_TILE_DEVICE static auto MakeGemmPadViews(const TensorView& views)
    {
        const auto& a_pad_view = [&]() {
            const auto& a_tensor_view = views.at(I0);
            if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
            {
                return pad_tensor_view(a_tensor_view,
                                       make_tuple(number<TilePartitioner::MPerBlock>{},
                                                  number<TilePartitioner::KPerBlock>{}),
                                       sequence<false, FlatmmPipeline::kPadK>{});
            }
            else
            {
                return pad_tensor_view(a_tensor_view,
                                       make_tuple(number<TilePartitioner::KPerBlock>{},
                                                  number<TilePartitioner::MPerBlock>{}),
                                       sequence<false, FlatmmPipeline::kPadM>{});
            }
        }();

        const auto& b_flat_tensor_view = views.at(I1);

        // TODO: enable vector write for C in ColMajor
        const auto& c_pad_view = [&]() {
            const auto& c_tensor_view = views.at(I2);
            if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::RowMajor>)
            {
                return pad_tensor_view(c_tensor_view,
                                       make_tuple(number<TilePartitioner::MPerBlock>{},
                                                  number<TilePartitioner::NPerBlock>{}),
                                       sequence<false, FlatmmPipeline::kPadN>{});
            }
            else
            {
                return pad_tensor_view(c_tensor_view,
                                       make_tuple(number<TilePartitioner::MPerBlock>{},
                                                  number<TilePartitioner::NPerBlock>{}),
                                       sequence<FlatmmPipeline::kPadM, false>{});
            }
        }();

        return make_tuple(a_pad_view, b_flat_tensor_view, c_pad_view);
    }

    template <typename PadView>
    CK_TILE_DEVICE static auto
    MakeGemmTileWindows(const PadView& views, const index_t i_m, const index_t i_n)
    {
        const auto& a_pad_view      = views.at(I0);
        const auto& b_flat_pad_view = views.at(I1);
        const auto& c_pad_view      = views.at(I2);

        const auto& a_block_window = [&]() {
            if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
            {
                return make_tile_window(a_pad_view,
                                        make_tuple(number<TilePartitioner::MPerBlock>{},
                                                   number<TilePartitioner::KPerBlock>{}),
                                        {i_m, 0});
            }
            else
            {
                return make_tile_window(a_pad_view,
                                        make_tuple(number<TilePartitioner::KPerBlock>{},
                                                   number<TilePartitioner::MPerBlock>{}),
                                        {0, i_m});
            }
        }();

        const auto& b_flat_block_window =
            make_tile_window(b_flat_pad_view,
                             make_tuple(number<FlatmmPipeline::flatNPerWarp>{},
                                        number<FlatmmPipeline::flatKPerWarp>{}),
                             {static_cast<index_t>(i_n / BlockGemmShape::WarpTile::at(idxN)), 0});

        auto c_block_window = make_tile_window(
            c_pad_view,
            make_tuple(number<TilePartitioner::MPerBlock>{}, number<TilePartitioner::NPerBlock>{}),
            {i_m, i_n});

        return make_tuple(a_block_window, b_flat_block_window, c_block_window);
    }
    CK_TILE_DEVICE static void RunFlatmm(const ADataType* a_ptr,
                                         const BDataType* b_flat_ptr,
                                         CDataType* c_ptr,
                                         void* smem_ptr,
                                         const FlatmmKernelArgs& kargs,
                                         const SplitKBatchOffset& splitk_batch_offset,
                                         const index_t block_idx_m,
                                         const index_t block_idx_n)
    {
        // Create Gemm tensor views, pad views and tile windows
        const auto& gemm_tensor_views_tuple =
            MakeGemmTensorViews<EpiloguePipeline::MemoryOperation>(
                a_ptr, b_flat_ptr, c_ptr, kargs, splitk_batch_offset);

        const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple);
        auto gemm_tile_windows     = MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n);

        const index_t num_loop = TilePartitioner::GetLoopNum(splitk_batch_offset.splitted_k);

        // Run GEMM cooperatively by whole workgroup.
        const auto& a_block_window      = gemm_tile_windows.at(I0);
        const auto& b_flat_block_window = gemm_tile_windows.at(I1);

        const auto& c_block_tile = FlatmmPipeline{}.template operator()(
            a_block_window, b_flat_block_window, num_loop, smem_ptr);

        // Run Epilogue Pipeline
        auto& c_block_window = gemm_tile_windows.at(I2);

        EpiloguePipeline{}.template operator()<decltype(c_block_window), decltype(c_block_tile)>(
            c_block_window, c_block_tile, smem_ptr);
    }

    CK_TILE_DEVICE void operator()(FlatmmKernelArgs kargs) const
    {
        const auto [iM, iN] = TilePartitioner{kargs.M, kargs.N}.GetOutputTileIndex(blockIdx.x);
        const index_t i_m   = __builtin_amdgcn_readfirstlane(iM * TilePartitioner::MPerBlock);
        const index_t i_n   = __builtin_amdgcn_readfirstlane(iN * TilePartitioner::NPerBlock);

        const SplitKBatchOffset splitk_batch_offset(kargs);

        // options
        const ADataType* a_ptr =
            static_cast<const ADataType*>(kargs.a_ptr) + splitk_batch_offset.a_k_split_offset;
        const BDataType* b_flat_ptr =
            static_cast<const BDataType*>(kargs.b_shuffle_ptr) + splitk_batch_offset.b_k_split_offset;
        CDataType* c_ptr = static_cast<CDataType*>(kargs.c_ptr);

        // allocate LDS
        __shared__ char smem_ptr[GetSmemSize()];

        if constexpr(!(EpiloguePipeline::MemoryOperation == memory_operation_enum::atomic_add &&
                       EpiloguePipeline::GetVectorSizeC() % 2 != 0 &&
                       is_any_of<CDataType, fp16_t, bf16_t>::value))
        {
            RunFlatmm(a_ptr, b_flat_ptr, c_ptr, smem_ptr, kargs, splitk_batch_offset, i_m, i_n);
        }
    }
};

} // namespace ck_tile
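
// Host-side usage sketch (illustrative only): the TilePartitioner / FlatmmPipeline /
// EpiloguePipeline instantiations named below are hypothetical placeholders supplied by the code
// that instantiates the kernel, and B is expected to be pre-shuffled into the flat layout the
// pipeline consumes.
//
//   using Kernel = ck_tile::FlatmmKernel<MyTilePartitioner, MyFlatmmPipeline, MyEpiloguePipeline>;
//
//   ck_tile::FlatmmHostArgs host_args(a_dev, b_shuffle_dev, c_dev,
//                                     /*k_batch=*/1, M, N, K, stride_A, stride_B, stride_C);
//   auto kargs = Kernel::MakeKernelArgs(host_args);
//
//   if(Kernel::IsSupportedArgument(kargs))
//   {
//       const dim3 grids  = Kernel::GridSize(M, N, kargs.k_batch);
//       const dim3 blocks = Kernel::BlockSize();
//       // Launch Kernel{} with `grids` / `blocks` (e.g. via the ck_tile launch helpers used in
//       // the examples); the kernel allocates its own LDS based on GetSmemSize().
//   }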