diff --git a/example/ck_tile/03_gemm/gemm_basic.cpp b/example/ck_tile/03_gemm/gemm_basic.cpp index 81fbd96323..b667886f84 100644 --- a/example/ck_tile/03_gemm/gemm_basic.cpp +++ b/example/ck_tile/03_gemm/gemm_basic.cpp @@ -40,7 +40,7 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& ck_tile::sequence, ck_tile::sequence>; - using TilePartitioner = ck_tile::GemmTile2DPartitioner; + using TilePartitioner = ck_tile::GemmTile1DPartitioner; using CodegenGemmTraits = ck_tile::TileGemmTraits; diff --git a/example/ck_tile/03_gemm/gemm_basic.hpp b/example/ck_tile/03_gemm/gemm_basic.hpp index 4500e3b4fd..3fdc4ac46c 100644 --- a/example/ck_tile/03_gemm/gemm_basic.hpp +++ b/example/ck_tile/03_gemm/gemm_basic.hpp @@ -79,7 +79,7 @@ auto create_args(int argc, char* argv[]) .insert("n", "4096", "n dimension") .insert("k", "2048", "k dimension") .insert("a_layout", "R", "A tensor data layout - Row by default") - .insert("b_layout", "R", "B tensor data layout - Row by default") + .insert("b_layout", "C", "B tensor data layout - Column by default") .insert("c_layout", "R", "C tensor data layout - Row by default") .insert("stride_a", "0", "Tensor A stride") .insert("stride_b", "0", "Tensor B stride") diff --git a/example/ck_tile/03_gemm/universal_gemm.cpp b/example/ck_tile/03_gemm/universal_gemm.cpp index fb43e6f504..eaaf3dbed9 100644 --- a/example/ck_tile/03_gemm/universal_gemm.cpp +++ b/example/ck_tile/03_gemm/universal_gemm.cpp @@ -50,7 +50,9 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& constexpr bool TransposeC = false; - constexpr int kBlockPerCu = 1; + constexpr int kBlockPerCu = 1; + constexpr ck_tile::index_t TileParitionerGroupNum = 8; + constexpr ck_tile::index_t TileParitionerM01 = 4; // =============================================== @@ -58,7 +60,8 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& ck_tile::TileGemmShape, ck_tile::sequence, ck_tile::sequence>; - using 
TilePartitioner = ck_tile::GemmTile2DPartitioner; + using TilePartitioner = ck_tile:: + GemmSpatiallyLocalTilePartitioner; using Traits = ck_tile::TileGemmTraits; using GemmUniversalTraits = ck_tile:: diff --git a/example/ck_tile/16_batched_gemm/batched_gemm.cpp b/example/ck_tile/16_batched_gemm/batched_gemm.cpp index 2a1cd58255..949621e116 100644 --- a/example/ck_tile/16_batched_gemm/batched_gemm.cpp +++ b/example/ck_tile/16_batched_gemm/batched_gemm.cpp @@ -43,7 +43,7 @@ float batched_gemm(const ck_tile::BatchedGemmHostArgs& args, const ck_tile::stre ck_tile::sequence, ck_tile::sequence>; - using TilePartitioner = ck_tile::GemmTile2DPartitioner; + using TilePartitioner = ck_tile::GemmTile1DPartitioner; using CodegenGemmTraits = ck_tile::TileGemmTraits; diff --git a/include/ck_tile/ops/gemm/kernel/batched_gemm_kernel.hpp b/include/ck_tile/ops/gemm/kernel/batched_gemm_kernel.hpp index 4b4a4d7a09..0f8bec3cf4 100644 --- a/include/ck_tile/ops/gemm/kernel/batched_gemm_kernel.hpp +++ b/include/ck_tile/ops/gemm/kernel/batched_gemm_kernel.hpp @@ -70,7 +70,7 @@ struct BatchedGemmKernel : public GemmKernelRunGemm(a_ptr, b_ptr, c_ptr, smem_ptr, kargs, splitk_batch_offset, i_m, i_n); } diff --git a/include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp b/include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp index 774736e1fa..4c65f51914 100644 --- a/include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp +++ b/include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp @@ -75,12 +75,12 @@ struct GemmKernel static constexpr auto I1 = number<1>(); static constexpr auto I2 = number<2>(); - __host__ static constexpr auto GridSize(index_t M, index_t N, index_t KBatch) + CK_TILE_HOST static constexpr auto GridSize(index_t M, index_t N, index_t KBatch) { - return TilePartitioner::GridSize(M, N, KBatch); + return dim3(TilePartitioner::GridSize(M, N), 1, KBatch); } - __host__ static constexpr auto BlockSize() { return dim3(KernelBlockSize); } + CK_TILE_HOST static constexpr auto BlockSize() { return 
dim3(KernelBlockSize); } struct GemmKernelArgs { @@ -93,7 +93,7 @@ struct GemmKernel index_t stride_A; index_t stride_B; index_t stride_C; - index_t KBatch; + index_t k_batch; }; CK_TILE_HOST static constexpr GemmKernelArgs MakeKernelArgs(const GemmHostArgs& hostArgs) @@ -121,7 +121,7 @@ struct GemmKernel const std::size_t k_id = blockIdx.z) { constexpr auto K1 = TilePartitioner::BlockGemmShape::WarpTile::at(number<2>{}); - const index_t K_t = kargs.KBatch * K1; + const index_t K_t = kargs.k_batch * K1; const index_t KRead = (kargs.K + K_t - 1) / K_t * K1; if constexpr(std::is_same_v) @@ -142,13 +142,13 @@ struct GemmKernel b_k_split_offset = k_id * KRead; } - if(k_id < static_cast(kargs.KBatch - 1)) + if(k_id < static_cast(kargs.k_batch - 1)) { splitted_k = KRead; } else { - splitted_k = kargs.K - KRead * (kargs.KBatch - 1); + splitted_k = kargs.K - KRead * (kargs.k_batch - 1); } } @@ -162,7 +162,7 @@ struct GemmKernel if constexpr(EpiloguePipeline::GetVectorSizeC() % 2 != 0 && is_any_of::value) { - if(kargs.KBatch != 1) + if(kargs.k_batch != 1) { std::cerr << "Conditions not met for Kbatch >1 !" 
<< std::endl; return false; @@ -489,19 +489,14 @@ struct GemmKernel // Run Epilogue Pipeline auto& c_block_window = gemm_tile_windows.at(I2); - if constexpr(DstInMemOp == memory_operation_enum::set || - !(EpiloguePipeline::GetVectorSizeC() % 2 != 0 && - is_any_of::value)) - { - EpiloguePipeline{} - .template operator()( - c_block_window, c_block_tile, smem_ptr); - } + EpiloguePipeline{} + .template operator()( + c_block_window, c_block_tile, smem_ptr); } CK_TILE_DEVICE void operator()(GemmKernelArgs kargs) const { - const auto [iM, iN] = TilePartitioner::GetOutputTileIndex(blockIdx.x, blockIdx.y); + const auto [iM, iN] = TilePartitioner{kargs.M, kargs.N}.GetOutputTileIndex(blockIdx.x); const index_t i_m = __builtin_amdgcn_readfirstlane(iM * TilePartitioner::MPerBlock); const index_t i_n = __builtin_amdgcn_readfirstlane(iN * TilePartitioner::NPerBlock); @@ -516,14 +511,20 @@ struct GemmKernel // allocate LDS __shared__ char smem_ptr[GetSmemSize()]; - if(kargs.KBatch == 1) + if(kargs.k_batch == 1) { RunGemm(a_ptr, b_ptr, c_ptr, smem_ptr, kargs, splitk_batch_offset, i_m, i_n); } else { - RunGemm( - a_ptr, b_ptr, c_ptr, smem_ptr, kargs, splitk_batch_offset, i_m, i_n); + // Do not compile in case where we have unsupported + // VectorSizeC & data type configuration. + if constexpr(!(EpiloguePipeline::GetVectorSizeC() % 2 != 0 && + is_any_of::value)) + { + RunGemm( + a_ptr, b_ptr, c_ptr, smem_ptr, kargs, splitk_batch_offset, i_m, i_n); + } } } }; diff --git a/include/ck_tile/ops/gemm/kernel/gemm_tile_partitioner.hpp b/include/ck_tile/ops/gemm/kernel/gemm_tile_partitioner.hpp index eb2b817db6..d8c0239153 100644 --- a/include/ck_tile/ops/gemm/kernel/gemm_tile_partitioner.hpp +++ b/include/ck_tile/ops/gemm/kernel/gemm_tile_partitioner.hpp @@ -1,13 +1,21 @@ // SPDX-License-Identifier: MIT // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. +/** + * @file + * GemmTilePartitioner allows customized mapping between a workgroup and the C-tile it computes. 
+ */ + #pragma once #include "ck_tile/core.hpp" namespace ck_tile { -/** @brief Struct representing 2D block index mapping into 3D output tile space. */ +/** + * @brief Class providing 2D workgroup index mapping into 2D output GEMM C-tile space. + * + */ template struct GemmTile2DPartitioner { @@ -17,21 +25,32 @@ struct GemmTile2DPartitioner static constexpr index_t NPerBlock = BlockGemmShape::kN; static constexpr index_t KPerBlock = BlockGemmShape::kK; - /** @brief Returns 3D grid size. */ - CK_TILE_HOST static constexpr auto GridSize(index_t M, index_t N, index_t batch_size) noexcept( - noexcept(MPerBlock != 0 && NPerBlock != 0)) -> dim3 + CK_TILE_HOST_DEVICE GemmTile2DPartitioner() noexcept = delete; + CK_TILE_HOST_DEVICE GemmTile2DPartitioner([[maybe_unused]] index_t M, + [[maybe_unused]] index_t N) noexcept; + + /** + * @brief Calculates GEMM kernel grid size. + * + * @param M GEMM's M dimension. + * @param N GEMM's N dimension. + * @return dim3 Structure holding grid's X,Y and Z dimensions. + */ + CK_TILE_HOST static auto + GridSize(index_t M, index_t N) noexcept(noexcept(MPerBlock != 0 && NPerBlock != 0)) -> dim3 { const index_t GridDimX = (M + MPerBlock - 1) / MPerBlock; const index_t GridDimY = (N + NPerBlock - 1) / NPerBlock; - const index_t GridDimZ = batch_size; - return dim3(GridDimX, GridDimY, GridDimZ); + return dim3(GridDimX, GridDimY, 1); } /** - * @brief Returns the number of loops. - * @param [in] K is dimension + * @brief Calculate number of loop iterations over GEMM's K dimension. + * + * @param K GEMM's K dimension. + * @return index_t The number of loop iterations over K dimension. */ - CK_TILE_HOST_DEVICE static constexpr auto GetLoopNum(index_t K) noexcept -> index_t + CK_TILE_HOST_DEVICE static auto GetLoopNum(index_t K) noexcept -> index_t { return integer_divide_ceil(K, KPerBlock); } @@ -42,8 +61,15 @@ struct GemmTile2DPartitioner * @param [in] blockIdy is blockIdx.y * @return Returns the output tile indexes. 
*/ - CK_TILE_DEVICE static constexpr auto GetOutputTileIndex(index_t blockIdx, - index_t blockIdy) noexcept + + /** + * @brief Calculate workgroup 2D index mapping into 2D output C-tile space. + * + * @param blockIdx WGP's X index. + * @param blockIdy WGP's Y index. + * @return const tuple Tuple containing 2D output C-tile index. + */ + CK_TILE_DEVICE static auto GetOutputTileIndex(index_t blockIdx, index_t blockIdy) noexcept -> const tuple { const index_t iM = __builtin_amdgcn_readfirstlane(blockIdx); @@ -53,61 +79,71 @@ struct GemmTile2DPartitioner }; /** - * @brief Struct representing 1D block index mapping into 2D output tile space. + * @brief Class providing 1D WGP index mapping into 2D output C-tile space. + * + * @tparam BlockGemmShape_ A class providing basic GEMM parameters. \link TileGemmShape */ -template +template struct GemmTile1DPartitioner { - using BlockGemmShape = remove_cvref_t; + using BlockGemmShape = remove_cvref_t; static constexpr index_t MPerBlock = BlockGemmShape::kM; static constexpr index_t NPerBlock = BlockGemmShape::kN; static constexpr index_t KPerBlock = BlockGemmShape::kK; - /** @brief delete default ctr with no any object */ - constexpr GemmTile1DPartitioner() noexcept = delete; + CK_TILE_HOST_DEVICE GemmTile1DPartitioner() noexcept = delete; - /** @brief constructs an object that does contain a N value. */ - constexpr GemmTile1DPartitioner(index_t N) noexcept { N_ = N; } + /** + * @brief Construct a new GemmTile1DPartitioner object. + * + * @param M GEMM's M dimension. + * @param N GEMM's N dimension. + */ + CK_TILE_HOST_DEVICE GemmTile1DPartitioner([[maybe_unused]] index_t M, index_t N) noexcept + { + N_ = N; + } - /** @brief Returns 1D grid size. */ - CK_TILE_HOST static constexpr auto - GridSize(index_t M, index_t N) noexcept(noexcept(MPerBlock != 0 && NPerBlock != 0)) -> dim3 + /** + * @brief Calculates GEMM kernel grid size. + * + * @param M GEMM's M dimension. + * @param N GEMM's N dimension. 
+ * @return index_t A total number of workgroups. + */ + CK_TILE_HOST static auto + GridSize(index_t M, index_t N) noexcept(noexcept(MPerBlock != 0 && NPerBlock != 0)) -> index_t { const index_t GridDimX = (M + MPerBlock - 1) / MPerBlock; const index_t GridDimY = (N + NPerBlock - 1) / NPerBlock; - return dim3(GridDimX * GridDimY, 1, 1); + return GridDimX * GridDimY; } /** - * @brief Returns the number of blocks in N. - * @param [in] N is dimension + * @brief Calculate number of loop iterations over GEMM's K dimension. + * + * @param K GEMM's K dimension. + * @return index_t The number of loop iterations over K dimension. */ - CK_TILE_HOST_DEVICE static constexpr auto GetNBlock(index_t N) noexcept -> index_t - { - return integer_divide_ceil(N, NPerBlock); - } - - /** - * @brief Returns the number of loops. - * @param [in] K is dimension - */ - CK_TILE_HOST_DEVICE static constexpr auto GetLoopNum(index_t K) noexcept -> index_t + CK_TILE_HOST_DEVICE static auto GetLoopNum(index_t K) noexcept -> index_t { return integer_divide_ceil(K, KPerBlock); } /** - * @brief The function returns 2D output tile space. - * @param [in] blockIdx is blockIdx.x - block_start. - * */ - CK_TILE_DEVICE static constexpr auto GetOutputTileIndex(index_t blockIdx) noexcept + * @brief Calculate workgroup 1D index mapping into 2D output C-tile space. + * + * @param blockIdx WGP's index. + * @return const tuple Tuple containing 2D output C-tile index. 
+ */ + CK_TILE_DEVICE static auto GetOutputTileIndex(index_t blockIdx) noexcept -> const tuple { - const index_t NBlock = GetNBlock(N_); + const index_t NBlocks = integer_divide_ceil(N_, NPerBlock); - const index_t iM = __builtin_amdgcn_readfirstlane(blockIdx / NBlock); - const index_t iN = __builtin_amdgcn_readfirstlane(blockIdx - (iM)*NBlock); + const index_t iM = __builtin_amdgcn_readfirstlane(blockIdx / NBlocks); + const index_t iN = __builtin_amdgcn_readfirstlane(blockIdx - iM * NBlocks); return make_tuple(iM, iN); } @@ -141,21 +177,176 @@ struct HasFnOneArgImpl().GetOutputTileIn * enable-if `GetOutputTileIndex`-fn is std::true_type when `GetOutputTileIndex`-fn is well-formed, * otherwise std::false_type. */ -template {}>> +template {}>> struct OffsettedTile1DPartitioner { /** * @brief The function subtracts the block's start (offset) from 1D raw-indexes. - * @param [in] block_start is `blockIdx.x - block_start`. - * @return Returns a `tuple` [Im, In] shifted index, used to shift 1d-tile index. + * @param [in] block_start Workgroup offset. + * @param [in] M GEMM's M dimension. + * @param [in] N GEMM's N dimension. + * @return Returns a `tuple` [Im, In] with shifted index. */ - [[nodiscard]] CK_TILE_DEVICE static constexpr auto GetOffsetedTileIndex(index_t block_start, - index_t N) noexcept + [[nodiscard]] CK_TILE_DEVICE static auto + GetOffsetedTileIndex(index_t block_start, index_t M, index_t N) noexcept -> const tuple { - const auto [iM, iN] = PartitionerFn(N).GetOutputTileIndex(blockIdx.x - block_start); + const auto [iM, iN] = TilePartitioner{M, N}.GetOutputTileIndex(blockIdx.x - block_start); return make_tuple(iM, iN); } }; + +/** + * @brief Class mapping 1D block index into 2D output tile space. + * + * @note It spatially groups workgroups in order to better utilize caches. + * It uses a grouped rows-of-column-vectors WGP pattern. It's optimized + * for gfx94x-like multi-die chips. + * + * @tparam GroupNum - The number of big groups. 
+ * @tparam M01 - The number of groups in M dim within spatially local WGPs. + * + */ +template +struct GemmSpatiallyLocalTilePartitioner +{ + using BlockGemmShape = remove_cvref_t; + + static constexpr index_t MPerBlock = BlockGemmShape::kM; + static constexpr index_t NPerBlock = BlockGemmShape::kN; + static constexpr index_t KPerBlock = BlockGemmShape::kK; + + CK_TILE_HOST_DEVICE GemmSpatiallyLocalTilePartitioner() noexcept = delete; + CK_TILE_HOST_DEVICE GemmSpatiallyLocalTilePartitioner(index_t M_, index_t N_) noexcept + : M(M_), N(N_) + { + } + + /** + * @brief Calculates GEMM kernel grid size. + * + * @param M GEMM's M dimension. + * @param N GEMM's N dimension. + * @return index_t A total number of workgroups. + */ + CK_TILE_HOST static auto + GridSize(index_t M, index_t N) noexcept(noexcept(MPerBlock != 0 && NPerBlock != 0)) -> index_t + { + const index_t GridDimX = integer_divide_ceil(M, MPerBlock); + const index_t GridDimY = integer_divide_ceil(N, NPerBlock); + return GridDimX * GridDimY; + } + + /** + * @brief Calculate number of loop iterations over GEMM's K dimension. + * + * @param K GEMM's K dimension. + * @return index_t The number of loop iterations over K dimension. + */ + CK_TILE_HOST_DEVICE static auto GetLoopNum(index_t K) noexcept -> index_t + { + return integer_divide_ceil(K, KPerBlock); + } + + /** + * @brief Calculate workgroup 1D index mapping into 2D output C-tile space. + * + * @param [in] block_1d_id WGP's index. + * @return const tuple Tuple containing 2D output C-tile index. 
+ */ + CK_TILE_DEVICE auto GetOutputTileIndex(index_t block_1d_id) noexcept + -> const tuple + { + const auto M0 = integer_divide_ceil(M, MPerBlock); + const auto N0 = integer_divide_ceil(N, NPerBlock); + + if(M0 == 1) + { + return make_tuple(0, block_1d_id); + } + else if(N0 == 1) + { + return make_tuple(block_1d_id, 0); + } + // block_1d_id = block_1d_id % (M0 * N0); // swallow batch index + else + { + const auto group_size = integer_divide_ceil(M0 * N0, GroupNum); + const auto big_group_num = GroupNum - (group_size * GroupNum - M0 * N0); + const auto group_id_y = block_1d_id / GroupNum; + const auto group_id_x = block_1d_id - group_id_y * GroupNum; + const auto remap_block_1d_id = + group_id_x <= big_group_num + ? group_id_x * group_size + group_id_y + : group_id_x * group_size + big_group_num - group_id_x + group_id_y; + + const index_t idx_M0 = remap_block_1d_id / N0; + const index_t idx_N0 = remap_block_1d_id - idx_M0 * N0; + + const index_t M0_tmp = M0 / M01; + const index_t M0_mod_M01 = M0 - M0_tmp * M01; + + const auto M01_adapt = (idx_M0 < M0 - M0_mod_M01) ? 
M01 : M0_mod_M01; + + const index_t idx_M00 = idx_M0 / M01; + const index_t idx_M01 = idx_M0 - idx_M00 * M01; + const index_t idx_N0_M01_local = idx_N0 + idx_M01 * N0; + + /** + * idxN0 + * + * |< mtx N >| + * + * NPerBlock NPerBlock NPerBlock NPerBlock + * N_0 N_1 N_2 N_3 + * - |-----------|-----------|-----------|-----|-----|- + * ^ | - - 0 |/----> 2 | | | | + * | | | / | | | | | M_0 MPerBlock + * | M | /| | | | | | + * |-0---|---/-|-----|-----|-----------|-----|-----|- + * | 1 | / | | | blockid | | | + * idxM0 | | | / | V | 5 | | | M_1 MPerBlock + * | - V 1 | - 3 | | | | + * |-----------|-----------|-----------|-----|-----|- + * mtx M | | | | | | + * | | | | | | M_2 MPerBlock + * | | | | | | + * |-----------|-----------|-----------|-----|-----|- + * | | | | | | + * | | | | | | M_3 MPerBlock + * | | | | | | + * |-----------|-----------|-----------|-----|-----|- + * V | | | | | | + * - |-----------|-----------|-----------|-----|-----|- M_4 MPerBlock + * | | | | | | + * |-----------|-----------|-----------|-----|-----|- + * Example: + * assume: + * M0 = 5 + * N0 = 4 + * block_1d_id = 5 + * M01 = 2 + * + * idx_N0 = 1 + * idx_M0 = 1 + * M01_adapt = 2 + * idx_M00 = 0 + * idx_M01 = 1 + * idx_N0_M01_local = 5 + * output {1, 2} + */ + + const index_t N_out = idx_N0_M01_local / M01_adapt; + const index_t idx_loc_mod_M01 = idx_N0_M01_local - N_out * M01_adapt; + + return make_tuple(idx_loc_mod_M01 + idx_M00 * M01, N_out); + } + } + + private: + index_t M; + index_t N; +}; + } // namespace ck_tile diff --git a/include/ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp b/include/ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp index 656939770c..13d3df02f9 100644 --- a/include/ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp +++ b/include/ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp @@ -77,8 +77,8 @@ struct GroupedGemmKernel : public GemmKernel, ck_tile::sequence>; - using TilePartitioner = ck_tile::GemmTile2DPartitioner; + using TilePartitioner = 
ck_tile::GemmTile1DPartitioner; using CodegenGemmTraits = ck_tile::TileGemmTraits; diff --git a/test/ck_tile/gemm/test_gemm_pipeline_util.hpp b/test/ck_tile/gemm/test_gemm_pipeline_util.hpp index 647b54cb8e..dc685567eb 100644 --- a/test/ck_tile/gemm/test_gemm_pipeline_util.hpp +++ b/test/ck_tile/gemm/test_gemm_pipeline_util.hpp @@ -55,7 +55,9 @@ class TestCkTileGemmPipeline : public ::testing::Test // TODO: For now - but this should also be a test parameter constexpr bool TransposeC = false; - constexpr int kBlockPerCu = 1; + constexpr int kBlockPerCu = 1; + constexpr ck_tile::index_t TileParitionerGroupNum = 8; + constexpr ck_tile::index_t TileParitionerM01 = 4; // =============================================== @@ -63,7 +65,8 @@ class TestCkTileGemmPipeline : public ::testing::Test ck_tile::TileGemmShape, ck_tile::sequence, ck_tile::sequence>; - using TilePartitioner = ck_tile::GemmTile2DPartitioner; + using TilePartitioner = ck_tile:: + GemmSpatiallyLocalTilePartitioner; using Traits = ck_tile::TileGemmTraits; using GemmUniversalTraits = ck_tile::