mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-19 22:39:03 +00:00
[CK Tile] Spatially local GEMM tile partitioner. (#1843)
* Add spatially local tile partitioner
* Use 1D Grid size & create partitioner object.
* Docs & use 1D partitioner in example.
* Clang format.
* Change kernel grid size
Now: X is the # of output C-tiles,
Y is the batch count
Z is the splitK
* Formatting & more doc.
* Clang format.
* Fix batched gemm test. Use 1d partitioner.
* Move condition.
* FIx ctor.
* clang-format.
This commit is contained in:
@@ -40,7 +40,7 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config&
|
||||
ck_tile::sequence<M_Warp, N_Warp, K_Warp>,
|
||||
ck_tile::sequence<M_Warp_Tile, N_Warp_Tile, K_Warp_Tile>>;
|
||||
|
||||
using TilePartitioner = ck_tile::GemmTile2DPartitioner<CodegenGemmShape>;
|
||||
using TilePartitioner = ck_tile::GemmTile1DPartitioner<CodegenGemmShape>;
|
||||
|
||||
using CodegenGemmTraits =
|
||||
ck_tile::TileGemmTraits<kPadM, kPadN, kPadK, ALayout, BLayout, CLayout>;
|
||||
|
||||
@@ -79,7 +79,7 @@ auto create_args(int argc, char* argv[])
|
||||
.insert("n", "4096", "n dimension")
|
||||
.insert("k", "2048", "k dimension")
|
||||
.insert("a_layout", "R", "A tensor data layout - Row by default")
|
||||
.insert("b_layout", "R", "B tensor data layout - Row by default")
|
||||
.insert("b_layout", "C", "B tensor data layout - Column by default")
|
||||
.insert("c_layout", "R", "C tensor data layout - Row by default")
|
||||
.insert("stride_a", "0", "Tensor A stride")
|
||||
.insert("stride_b", "0", "Tensor B stride")
|
||||
|
||||
@@ -50,7 +50,9 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config&
|
||||
|
||||
constexpr bool TransposeC = false;
|
||||
|
||||
constexpr int kBlockPerCu = 1;
|
||||
constexpr int kBlockPerCu = 1;
|
||||
constexpr ck_tile::index_t TileParitionerGroupNum = 8;
|
||||
constexpr ck_tile::index_t TileParitionerM01 = 4;
|
||||
|
||||
// ===============================================
|
||||
|
||||
@@ -58,7 +60,8 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config&
|
||||
ck_tile::TileGemmShape<ck_tile::sequence<M_Tile, N_Tile, K_Tile>,
|
||||
ck_tile::sequence<M_Warp, N_Warp, K_Warp>,
|
||||
ck_tile::sequence<M_Warp_Tile, N_Warp_Tile, K_Warp_Tile>>;
|
||||
using TilePartitioner = ck_tile::GemmTile2DPartitioner<GemmShape>;
|
||||
using TilePartitioner = ck_tile::
|
||||
GemmSpatiallyLocalTilePartitioner<GemmShape, TileParitionerGroupNum, TileParitionerM01>;
|
||||
|
||||
using Traits = ck_tile::TileGemmTraits<kPadM, kPadN, kPadK, ALayout, BLayout, CLayout>;
|
||||
using GemmUniversalTraits = ck_tile::
|
||||
|
||||
@@ -43,7 +43,7 @@ float batched_gemm(const ck_tile::BatchedGemmHostArgs& args, const ck_tile::stre
|
||||
ck_tile::sequence<M_Warp, N_Warp, K_Warp>,
|
||||
ck_tile::sequence<M_Warp_Tile, N_Warp_Tile, K_Warp_Tile>>;
|
||||
|
||||
using TilePartitioner = ck_tile::GemmTile2DPartitioner<CodegenGemmShape>;
|
||||
using TilePartitioner = ck_tile::GemmTile1DPartitioner<CodegenGemmShape>;
|
||||
|
||||
using CodegenGemmTraits =
|
||||
ck_tile::TileGemmTraits<kPadM, kPadN, kPadK, ALayout, BLayout, CLayout>;
|
||||
|
||||
@@ -70,7 +70,7 @@ struct BatchedGemmKernel : public GemmKernel<TilePartitioner_, GemmPipeline_, Ep
|
||||
__host__ static constexpr auto
|
||||
GridSize(index_t M, index_t N, index_t KBatch, index_t batch_count)
|
||||
{
|
||||
return TilePartitioner::GridSize(M, N, KBatch * batch_count);
|
||||
return dim3(TilePartitioner::GridSize(M, N), batch_count, KBatch);
|
||||
}
|
||||
|
||||
__host__ static constexpr auto BlockSize() { return dim3(Base::KernelBlockSize); }
|
||||
@@ -101,14 +101,14 @@ struct BatchedGemmKernel : public GemmKernel<TilePartitioner_, GemmPipeline_, Ep
|
||||
|
||||
CK_TILE_DEVICE void operator()(BatchedGemmKernelArgs kargs) const
|
||||
{
|
||||
const auto [iM, iN] = TilePartitioner::GetOutputTileIndex(blockIdx.x, blockIdx.y);
|
||||
const auto [iM, iN] = TilePartitioner{kargs.M, kargs.N}.GetOutputTileIndex(blockIdx.x);
|
||||
const index_t i_m = __builtin_amdgcn_readfirstlane(iM * TilePartitioner::MPerBlock);
|
||||
const index_t i_n = __builtin_amdgcn_readfirstlane(iN * TilePartitioner::NPerBlock);
|
||||
|
||||
const auto i_batch = __builtin_amdgcn_readfirstlane(blockIdx.z / kargs.KBatch);
|
||||
const auto i_k = __builtin_amdgcn_readfirstlane(blockIdx.z - i_batch * kargs.KBatch);
|
||||
const auto i_batch = __builtin_amdgcn_readfirstlane(blockIdx.y);
|
||||
const auto i_splitk = __builtin_amdgcn_readfirstlane(blockIdx.z);
|
||||
|
||||
const typename Base::SplitKBatchOffset splitk_batch_offset(kargs, i_k);
|
||||
const typename Base::SplitKBatchOffset splitk_batch_offset(kargs, i_splitk);
|
||||
|
||||
// options
|
||||
const auto batch_stride_A = __builtin_amdgcn_readfirstlane(kargs.batch_stride_A);
|
||||
@@ -128,7 +128,7 @@ struct BatchedGemmKernel : public GemmKernel<TilePartitioner_, GemmPipeline_, Ep
|
||||
// allocate LDS
|
||||
__shared__ char smem_ptr[GetSmemSize()];
|
||||
|
||||
if(kargs.KBatch == 1)
|
||||
if(kargs.k_batch == 1)
|
||||
{
|
||||
this->RunGemm(a_ptr, b_ptr, c_ptr, smem_ptr, kargs, splitk_batch_offset, i_m, i_n);
|
||||
}
|
||||
|
||||
@@ -75,12 +75,12 @@ struct GemmKernel
|
||||
static constexpr auto I1 = number<1>();
|
||||
static constexpr auto I2 = number<2>();
|
||||
|
||||
__host__ static constexpr auto GridSize(index_t M, index_t N, index_t KBatch)
|
||||
CK_TILE_HOST static constexpr auto GridSize(index_t M, index_t N, index_t KBatch)
|
||||
{
|
||||
return TilePartitioner::GridSize(M, N, KBatch);
|
||||
return dim3(TilePartitioner::GridSize(M, N), 1, KBatch);
|
||||
}
|
||||
|
||||
__host__ static constexpr auto BlockSize() { return dim3(KernelBlockSize); }
|
||||
CK_TILE_HOST static constexpr auto BlockSize() { return dim3(KernelBlockSize); }
|
||||
|
||||
struct GemmKernelArgs
|
||||
{
|
||||
@@ -93,7 +93,7 @@ struct GemmKernel
|
||||
index_t stride_A;
|
||||
index_t stride_B;
|
||||
index_t stride_C;
|
||||
index_t KBatch;
|
||||
index_t k_batch;
|
||||
};
|
||||
|
||||
CK_TILE_HOST static constexpr GemmKernelArgs MakeKernelArgs(const GemmHostArgs& hostArgs)
|
||||
@@ -121,7 +121,7 @@ struct GemmKernel
|
||||
const std::size_t k_id = blockIdx.z)
|
||||
{
|
||||
constexpr auto K1 = TilePartitioner::BlockGemmShape::WarpTile::at(number<2>{});
|
||||
const index_t K_t = kargs.KBatch * K1;
|
||||
const index_t K_t = kargs.k_batch * K1;
|
||||
const index_t KRead = (kargs.K + K_t - 1) / K_t * K1;
|
||||
|
||||
if constexpr(std::is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
|
||||
@@ -142,13 +142,13 @@ struct GemmKernel
|
||||
b_k_split_offset = k_id * KRead;
|
||||
}
|
||||
|
||||
if(k_id < static_cast<uint32_t>(kargs.KBatch - 1))
|
||||
if(k_id < static_cast<uint32_t>(kargs.k_batch - 1))
|
||||
{
|
||||
splitted_k = KRead;
|
||||
}
|
||||
else
|
||||
{
|
||||
splitted_k = kargs.K - KRead * (kargs.KBatch - 1);
|
||||
splitted_k = kargs.K - KRead * (kargs.k_batch - 1);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -162,7 +162,7 @@ struct GemmKernel
|
||||
if constexpr(EpiloguePipeline::GetVectorSizeC() % 2 != 0 &&
|
||||
is_any_of<CDataType, fp16_t, bf16_t>::value)
|
||||
{
|
||||
if(kargs.KBatch != 1)
|
||||
if(kargs.k_batch != 1)
|
||||
{
|
||||
std::cerr << "Conditions not met for Kbatch >1 !" << std::endl;
|
||||
return false;
|
||||
@@ -489,19 +489,14 @@ struct GemmKernel
|
||||
// Run Epilogue Pipeline
|
||||
auto& c_block_window = gemm_tile_windows.at(I2);
|
||||
|
||||
if constexpr(DstInMemOp == memory_operation_enum::set ||
|
||||
!(EpiloguePipeline::GetVectorSizeC() % 2 != 0 &&
|
||||
is_any_of<CDataType, fp16_t, bf16_t>::value))
|
||||
{
|
||||
EpiloguePipeline{}
|
||||
.template operator()<decltype(c_block_window), decltype(c_block_tile), DstInMemOp>(
|
||||
c_block_window, c_block_tile, smem_ptr);
|
||||
}
|
||||
EpiloguePipeline{}
|
||||
.template operator()<decltype(c_block_window), decltype(c_block_tile), DstInMemOp>(
|
||||
c_block_window, c_block_tile, smem_ptr);
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE void operator()(GemmKernelArgs kargs) const
|
||||
{
|
||||
const auto [iM, iN] = TilePartitioner::GetOutputTileIndex(blockIdx.x, blockIdx.y);
|
||||
const auto [iM, iN] = TilePartitioner{kargs.M, kargs.N}.GetOutputTileIndex(blockIdx.x);
|
||||
const index_t i_m = __builtin_amdgcn_readfirstlane(iM * TilePartitioner::MPerBlock);
|
||||
const index_t i_n = __builtin_amdgcn_readfirstlane(iN * TilePartitioner::NPerBlock);
|
||||
|
||||
@@ -516,14 +511,20 @@ struct GemmKernel
|
||||
// allocate LDS
|
||||
__shared__ char smem_ptr[GetSmemSize()];
|
||||
|
||||
if(kargs.KBatch == 1)
|
||||
if(kargs.k_batch == 1)
|
||||
{
|
||||
RunGemm(a_ptr, b_ptr, c_ptr, smem_ptr, kargs, splitk_batch_offset, i_m, i_n);
|
||||
}
|
||||
else
|
||||
{
|
||||
RunGemm<memory_operation_enum::atomic_add>(
|
||||
a_ptr, b_ptr, c_ptr, smem_ptr, kargs, splitk_batch_offset, i_m, i_n);
|
||||
// Do not compile in case where we have unsupported
|
||||
// VectorSizeC & data type configuration.
|
||||
if constexpr(!(EpiloguePipeline::GetVectorSizeC() % 2 != 0 &&
|
||||
is_any_of<CDataType, fp16_t, bf16_t>::value))
|
||||
{
|
||||
RunGemm<memory_operation_enum::atomic_add>(
|
||||
a_ptr, b_ptr, c_ptr, smem_ptr, kargs, splitk_batch_offset, i_m, i_n);
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
@@ -1,13 +1,21 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
/**
|
||||
* @file
|
||||
* GemmTilePartitioner allows customized mapping between a workgroup and the C-tile it computes.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
/** @brief Struct representing 2D block index mapping into 3D output tile space. */
|
||||
/**
|
||||
* @brief Class providing 2D workgroup index mapping into 2D output GEMM C-tile space.
|
||||
*
|
||||
*/
|
||||
template <typename BlockGemmShapeType>
|
||||
struct GemmTile2DPartitioner
|
||||
{
|
||||
@@ -17,21 +25,32 @@ struct GemmTile2DPartitioner
|
||||
static constexpr index_t NPerBlock = BlockGemmShape::kN;
|
||||
static constexpr index_t KPerBlock = BlockGemmShape::kK;
|
||||
|
||||
/** @brief Returns 3D grid size. */
|
||||
CK_TILE_HOST static constexpr auto GridSize(index_t M, index_t N, index_t batch_size) noexcept(
|
||||
noexcept(MPerBlock != 0 && NPerBlock != 0)) -> dim3
|
||||
CK_TILE_HOST_DEVICE GemmTile2DPartitioner() noexcept = delete;
|
||||
CK_TILE_HOST_DEVICE GemmTile2DPartitioner([[maybe_unused]] index_t M,
|
||||
[[maybe_unused]] index_t N) noexcept;
|
||||
|
||||
/**
|
||||
* @brief Calculates GEMM kernel grid size.
|
||||
*
|
||||
* @param M GEMM's M dimension.
|
||||
* @param N GEMM's N dimension.
|
||||
* @return dim3 Structure holding grid's X,Y and Z dimensions.
|
||||
*/
|
||||
CK_TILE_HOST static auto
|
||||
GridSize(index_t M, index_t N) noexcept(noexcept(MPerBlock != 0 && NPerBlock != 0)) -> dim3
|
||||
{
|
||||
const index_t GridDimX = (M + MPerBlock - 1) / MPerBlock;
|
||||
const index_t GridDimY = (N + NPerBlock - 1) / NPerBlock;
|
||||
const index_t GridDimZ = batch_size;
|
||||
return dim3(GridDimX, GridDimY, GridDimZ);
|
||||
return dim3(GridDimX, GridDimY, 1);
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Returns the number of loops.
|
||||
* @param [in] K is dimension
|
||||
* @brief Calculate number of loop iterations over GEMM's K dimension.
|
||||
*
|
||||
* @param K GEMM's K dimension.
|
||||
* @return index_t The number of loop iterations over K dimension.
|
||||
*/
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetLoopNum(index_t K) noexcept -> index_t
|
||||
CK_TILE_HOST_DEVICE static auto GetLoopNum(index_t K) noexcept -> index_t
|
||||
{
|
||||
return integer_divide_ceil(K, KPerBlock);
|
||||
}
|
||||
@@ -42,8 +61,15 @@ struct GemmTile2DPartitioner
|
||||
* @param [in] blockIdy is blockIdx.y
|
||||
* @return Returns the output tile indexes.
|
||||
*/
|
||||
CK_TILE_DEVICE static constexpr auto GetOutputTileIndex(index_t blockIdx,
|
||||
index_t blockIdy) noexcept
|
||||
|
||||
/**
|
||||
* @brief Calculate workgroup 2D index mapping into 2D output C-tile space.
|
||||
*
|
||||
* @param blockIdx WGP's X index.
|
||||
* @param blockIdy WGP's Y index.
|
||||
* @return const tuple<index_t, index_t> Tuple containing 2D output C-tile index.
|
||||
*/
|
||||
CK_TILE_DEVICE static auto GetOutputTileIndex(index_t blockIdx, index_t blockIdy) noexcept
|
||||
-> const tuple<index_t, index_t>
|
||||
{
|
||||
const index_t iM = __builtin_amdgcn_readfirstlane(blockIdx);
|
||||
@@ -53,61 +79,71 @@ struct GemmTile2DPartitioner
|
||||
};
|
||||
|
||||
/**
|
||||
* @brief Struct representing 1D block index mapping into 2D output tile space.
|
||||
* @brief Class providing 1D WGP index mapping into 2D output C-tile space.
|
||||
*
|
||||
* @tparam BlockGemmShape_ A class providing basic GEMM parameters. \link TileGemmShape
|
||||
*/
|
||||
template <typename BlockGemmShapeType>
|
||||
template <typename BlockGemmShape_>
|
||||
struct GemmTile1DPartitioner
|
||||
{
|
||||
using BlockGemmShape = remove_cvref_t<BlockGemmShapeType>;
|
||||
using BlockGemmShape = remove_cvref_t<BlockGemmShape_>;
|
||||
|
||||
static constexpr index_t MPerBlock = BlockGemmShape::kM;
|
||||
static constexpr index_t NPerBlock = BlockGemmShape::kN;
|
||||
static constexpr index_t KPerBlock = BlockGemmShape::kK;
|
||||
|
||||
/** @brief delete default ctr with no any object */
|
||||
constexpr GemmTile1DPartitioner() noexcept = delete;
|
||||
CK_TILE_HOST_DEVICE GemmTile1DPartitioner() noexcept = delete;
|
||||
|
||||
/** @brief constructs an object that does contain a N value. */
|
||||
constexpr GemmTile1DPartitioner(index_t N) noexcept { N_ = N; }
|
||||
/**
|
||||
* @brief Construct a new GemmTile1DPartitioner object.
|
||||
*
|
||||
* @param M GEMM's M dimension.
|
||||
* @param N GEMM's N dimension.
|
||||
*/
|
||||
CK_TILE_HOST_DEVICE GemmTile1DPartitioner([[maybe_unused]] index_t M, index_t N) noexcept
|
||||
{
|
||||
N_ = N;
|
||||
}
|
||||
|
||||
/** @brief Returns 1D grid size. */
|
||||
CK_TILE_HOST static constexpr auto
|
||||
GridSize(index_t M, index_t N) noexcept(noexcept(MPerBlock != 0 && NPerBlock != 0)) -> dim3
|
||||
/**
|
||||
* @brief Calculates GEMM kernel grid size.
|
||||
*
|
||||
* @param M GEMM's M dimension.
|
||||
* @param N GEMM's N dimension.
|
||||
* @return dim3 Structure holding grid's X,Y and Z dimensions.
|
||||
*/
|
||||
CK_TILE_HOST static auto
|
||||
GridSize(index_t M, index_t N) noexcept(noexcept(MPerBlock != 0 && NPerBlock != 0)) -> index_t
|
||||
{
|
||||
const index_t GridDimX = (M + MPerBlock - 1) / MPerBlock;
|
||||
const index_t GridDimY = (N + NPerBlock - 1) / NPerBlock;
|
||||
return dim3(GridDimX * GridDimY, 1, 1);
|
||||
return GridDimX * GridDimY;
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Returns the number of blocks in N.
|
||||
* @param [in] N is dimension
|
||||
* @brief Calculate number of loop iterations over GEMM's K dimension.
|
||||
*
|
||||
* @param K GEMM's K dimension.
|
||||
* @return index_t The number of loop iterations over K dimension.
|
||||
*/
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetNBlock(index_t N) noexcept -> index_t
|
||||
{
|
||||
return integer_divide_ceil(N, NPerBlock);
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Returns the number of loops.
|
||||
* @param [in] K is dimension
|
||||
*/
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetLoopNum(index_t K) noexcept -> index_t
|
||||
CK_TILE_HOST_DEVICE static auto GetLoopNum(index_t K) noexcept -> index_t
|
||||
{
|
||||
return integer_divide_ceil(K, KPerBlock);
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief The function returns 2D output tile space.
|
||||
* @param [in] blockIdx is blockIdx.x - block_start.
|
||||
* */
|
||||
CK_TILE_DEVICE static constexpr auto GetOutputTileIndex(index_t blockIdx) noexcept
|
||||
* @brief Calculate workgroup 1D index mapping into 2D output C-tile space.
|
||||
*
|
||||
* @param blockIdx WGP's index.
|
||||
* @return const tuple<index_t, index_t> Tuple containing 2D output C-tile index.
|
||||
*/
|
||||
CK_TILE_DEVICE static auto GetOutputTileIndex(index_t blockIdx) noexcept
|
||||
-> const tuple<index_t, index_t>
|
||||
{
|
||||
const index_t NBlock = GetNBlock(N_);
|
||||
const index_t NBlocks = integer_divide_ceil(N_, NPerBlock);
|
||||
|
||||
const index_t iM = __builtin_amdgcn_readfirstlane(blockIdx / NBlock);
|
||||
const index_t iN = __builtin_amdgcn_readfirstlane(blockIdx - (iM)*NBlock);
|
||||
const index_t iM = __builtin_amdgcn_readfirstlane(blockIdx / NBlocks);
|
||||
const index_t iN = __builtin_amdgcn_readfirstlane(blockIdx - iM * NBlocks);
|
||||
return make_tuple(iM, iN);
|
||||
}
|
||||
|
||||
@@ -141,21 +177,176 @@ struct HasFnOneArgImpl<T, std::void_t<decltype(std::declval<T>().GetOutputTileIn
|
||||
* enable-if `GetOutputTileIndex`-fn is std::true_type when `GetOutputTileIndex`-fn is well-formed,
|
||||
* otherwise std::false_type.
|
||||
*/
|
||||
template <typename PartitionerFn,
|
||||
typename = typename std::enable_if_t<HasFnOneArgImpl<PartitionerFn>{}>>
|
||||
template <typename TilePartitioner,
|
||||
typename = typename std::enable_if_t<HasFnOneArgImpl<TilePartitioner>{}>>
|
||||
struct OffsettedTile1DPartitioner
|
||||
{
|
||||
/**
|
||||
* @brief The function subtracts the block's start (offset) from 1D raw-indexes.
|
||||
* @param [in] block_start is `blockIdx.x - block_start`.
|
||||
* @return Returns a `tuple` [Im, In] shifted index, used to shift 1d-tile index.
|
||||
* @param [in] block_start Workgroup offset.
|
||||
* @param [in] M Gemm's M dimension.
|
||||
* @param [in] N Gemm's N dimension.
|
||||
* @return Returns a `tuple` [Im, In] with shifted index.
|
||||
*/
|
||||
[[nodiscard]] CK_TILE_DEVICE static constexpr auto GetOffsetedTileIndex(index_t block_start,
|
||||
index_t N) noexcept
|
||||
[[nodiscard]] CK_TILE_DEVICE static auto
|
||||
GetOffsetedTileIndex(index_t block_start, index_t M, index_t N) noexcept
|
||||
-> const tuple<index_t, index_t>
|
||||
{
|
||||
const auto [iM, iN] = PartitionerFn(N).GetOutputTileIndex(blockIdx.x - block_start);
|
||||
const auto [iM, iN] = TilePartitioner{M, N}.GetOutputTileIndex(blockIdx.x - block_start);
|
||||
return make_tuple(iM, iN);
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
* @brief Class mapping 1D block index into 2D output tile space.
|
||||
*
|
||||
* @note It groups spatially workgroups in order to better utilize caches.
|
||||
* It is using grouped Rows of column-vectors WGP pattern. It's optimized
|
||||
* for gfx94x-like multiple-die chip.
|
||||
*
|
||||
* @tparam GroupNum - The number of big groups.
|
||||
* @tparam M01 - The number of groups in M dim within spatially local WGPs,
|
||||
*
|
||||
*/
|
||||
template <typename BlockGemmShapeType, index_t GroupNum, index_t M01>
|
||||
struct GemmSpatiallyLocalTilePartitioner
|
||||
{
|
||||
using BlockGemmShape = remove_cvref_t<BlockGemmShapeType>;
|
||||
|
||||
static constexpr index_t MPerBlock = BlockGemmShape::kM;
|
||||
static constexpr index_t NPerBlock = BlockGemmShape::kN;
|
||||
static constexpr index_t KPerBlock = BlockGemmShape::kK;
|
||||
|
||||
CK_TILE_HOST_DEVICE GemmSpatiallyLocalTilePartitioner() noexcept = delete;
|
||||
CK_TILE_HOST_DEVICE GemmSpatiallyLocalTilePartitioner(index_t M_, index_t N_) noexcept
|
||||
: M(M_), N(N_)
|
||||
{
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Calculates GEMM kernel grid size.
|
||||
*
|
||||
* @param M GEMM's M dimension.
|
||||
* @param N GEMM's N dimension.
|
||||
* @return index_t A total number of workgroups.
|
||||
*/
|
||||
CK_TILE_HOST static auto
|
||||
GridSize(index_t M, index_t N) noexcept(noexcept(MPerBlock != 0 && NPerBlock != 0)) -> index_t
|
||||
{
|
||||
const index_t GridDimX = integer_divide_ceil(M, MPerBlock);
|
||||
const index_t GridDimY = integer_divide_ceil(N, NPerBlock);
|
||||
return GridDimX * GridDimY;
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Calculate number of loop iterations over GEMM's K dimension.
|
||||
*
|
||||
* @param K GEMM's K dimension.
|
||||
* @return index_t The number of loop iterations over K dimension.
|
||||
*/
|
||||
CK_TILE_HOST_DEVICE static auto GetLoopNum(index_t K) noexcept -> index_t
|
||||
{
|
||||
return integer_divide_ceil(K, KPerBlock);
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Calculate workgroup 1D index mapping into 2D output C-tile space.
|
||||
*
|
||||
* @param [in] block_1d_id WGP's index.
|
||||
* @return const tuple<index_t, index_t> Tuple containing 2D output C-tile index.
|
||||
*/
|
||||
CK_TILE_DEVICE auto GetOutputTileIndex(index_t block_1d_id) noexcept
|
||||
-> const tuple<index_t, index_t>
|
||||
{
|
||||
const auto M0 = integer_divide_ceil(M, MPerBlock);
|
||||
const auto N0 = integer_divide_ceil(N, NPerBlock);
|
||||
|
||||
if(M0 == 1)
|
||||
{
|
||||
return make_tuple(0, block_1d_id);
|
||||
}
|
||||
else if(N0 == 1)
|
||||
{
|
||||
return make_tuple(block_1d_id, 0);
|
||||
}
|
||||
// block_1d_id = block_1d_id % (M0 * N0); // swallow batch index
|
||||
else
|
||||
{
|
||||
const auto group_size = integer_divide_ceil(M0 * N0, GroupNum);
|
||||
const auto big_group_num = GroupNum - (group_size * GroupNum - M0 * N0);
|
||||
const auto group_id_y = block_1d_id / GroupNum;
|
||||
const auto group_id_x = block_1d_id - group_id_y * GroupNum;
|
||||
const auto remap_block_1d_id =
|
||||
group_id_x <= big_group_num
|
||||
? group_id_x * group_size + group_id_y
|
||||
: group_id_x * group_size + big_group_num - group_id_x + group_id_y;
|
||||
|
||||
const index_t idx_M0 = remap_block_1d_id / N0;
|
||||
const index_t idx_N0 = remap_block_1d_id - idx_M0 * N0;
|
||||
|
||||
const index_t M0_tmp = M0 / M01;
|
||||
const index_t M0_mod_M01 = M0 - M0_tmp * M01;
|
||||
|
||||
const auto M01_adapt = (idx_M0 < M0 - M0_mod_M01) ? M01 : M0_mod_M01;
|
||||
|
||||
const index_t idx_M00 = idx_M0 / M01;
|
||||
const index_t idx_M01 = idx_M0 - idx_M00 * M01;
|
||||
const index_t idx_N0_M01_local = idx_N0 + idx_M01 * N0;
|
||||
|
||||
/**
|
||||
* idxN0
|
||||
*
|
||||
* |< mtx N >|
|
||||
*
|
||||
* NPerBlock NPerBlock NPerBlock NPerBlock
|
||||
* N_0 N_1 N_2 N_3
|
||||
* - |-----------|-----------|-----------|-----|-----|-
|
||||
* ^ | - - 0 |/----> 2 | | | |
|
||||
* | | | / | | | | | M_0 MPerBlock
|
||||
* | M | /| | | | | |
|
||||
* |-0---|---/-|-----|-----|-----------|-----|-----|-
|
||||
* | 1 | / | | | blockid | | |
|
||||
* idxM0 | | | / | V | 5 | | | M_1 MPerBlock
|
||||
* | - V 1 | - 3 | | | |
|
||||
* |-----------|-----------|-----------|-----|-----|-
|
||||
* mtx M | | | | | |
|
||||
* | | | | | | M_2 MPerBlock
|
||||
* | | | | | |
|
||||
* |-----------|-----------|-----------|-----|-----|-
|
||||
* | | | | | |
|
||||
* | | | | | | M_3 MPerBlock
|
||||
* | | | | | |
|
||||
* |-----------|-----------|-----------|-----|-----|-
|
||||
* V | | | | | |
|
||||
* - |-----------|-----------|-----------|-----|-----|- M_4 MPerBlock
|
||||
* | | | | | |
|
||||
* |-----------|-----------|-----------|-----|-----|-
|
||||
* Example:
|
||||
* assume:
|
||||
* M0 = 5
|
||||
* N0 = 4
|
||||
* block_1d_id = 5
|
||||
* M01 = 2
|
||||
*
|
||||
* idx_N0 = 1
|
||||
* idx_M0 = 1
|
||||
* M01_adapt = 2
|
||||
* idx_M00 = 0
|
||||
* idx_M01 = 1
|
||||
* idx_N0_M01_local = 5
|
||||
* output {1, 2}
|
||||
*/
|
||||
|
||||
const index_t N_out = idx_N0_M01_local / M01_adapt;
|
||||
const index_t idx_loc_mod_M01 = idx_N0_M01_local - N_out * M01_adapt;
|
||||
|
||||
return make_tuple(idx_loc_mod_M01 + idx_M00 * M01, N_out);
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
index_t M;
|
||||
index_t N;
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -77,8 +77,8 @@ struct GroupedGemmKernel : public GemmKernel<TilePartitioner_, GemmPipeline_, Ep
|
||||
index_t grid_size = 0;
|
||||
for(const auto& it_desc : gemm_descs)
|
||||
{
|
||||
const auto dim3 = TilePartitioner::GridSize(it_desc.M, it_desc.N);
|
||||
grid_size += dim3.x * dim3.y * 1;
|
||||
const auto local_grid_size = TilePartitioner::GridSize(it_desc.M, it_desc.N);
|
||||
grid_size += local_grid_size * it_desc.k_batch;
|
||||
}
|
||||
return dim3(grid_size, 1, 1);
|
||||
}
|
||||
@@ -106,8 +106,7 @@ struct GroupedGemmKernel : public GemmKernel<TilePartitioner_, GemmPipeline_, Ep
|
||||
const index_t stride_b = gemm_descs[i].stride_B;
|
||||
const index_t stride_c = gemm_descs[i].stride_C;
|
||||
|
||||
const auto dim3 = TilePartitioner::GridSize(M, N);
|
||||
const index_t grid_size_grp = dim3.x;
|
||||
const index_t grid_size_grp = TilePartitioner::GridSize(M, N) * gemm_descs[i].k_batch;
|
||||
|
||||
const index_t block_start = grid_size;
|
||||
const index_t block_end = grid_size + grid_size_grp;
|
||||
@@ -138,8 +137,8 @@ struct GroupedGemmKernel : public GemmKernel<TilePartitioner_, GemmPipeline_, Ep
|
||||
|
||||
CK_TILE_DEVICE void Run(const GemmTransKernelArg& kargs) const
|
||||
{
|
||||
const auto [iM, iN] =
|
||||
OffsetTile1DPartitioner::GetOffsetedTileIndex(kargs.block_start, kargs.group_karg.N);
|
||||
const auto [iM, iN] = OffsetTile1DPartitioner::GetOffsetedTileIndex(
|
||||
kargs.block_start, kargs.group_karg.M, kargs.group_karg.N);
|
||||
|
||||
const index_t i_m = __builtin_amdgcn_readfirstlane(iM * TilePartitioner::MPerBlock);
|
||||
const index_t i_n = __builtin_amdgcn_readfirstlane(iN * TilePartitioner::NPerBlock);
|
||||
|
||||
@@ -53,7 +53,7 @@ class TestCkTileBatchedGemm : public ::testing::Test
|
||||
ck_tile::sequence<M_Warp, N_Warp, K_Warp>,
|
||||
ck_tile::sequence<M_Warp_Tile, N_Warp_Tile, K_Warp_Tile>>;
|
||||
|
||||
using TilePartitioner = ck_tile::GemmTile2DPartitioner<CodegenGemmShape>;
|
||||
using TilePartitioner = ck_tile::GemmTile1DPartitioner<CodegenGemmShape>;
|
||||
|
||||
using CodegenGemmTraits =
|
||||
ck_tile::TileGemmTraits<kPadM, kPadN, kPadK, ALayout, BLayout, CLayout>;
|
||||
|
||||
@@ -55,7 +55,9 @@ class TestCkTileGemmPipeline : public ::testing::Test
|
||||
// TODO: For now - but this should also be a test parameter
|
||||
constexpr bool TransposeC = false;
|
||||
|
||||
constexpr int kBlockPerCu = 1;
|
||||
constexpr int kBlockPerCu = 1;
|
||||
constexpr ck_tile::index_t TileParitionerGroupNum = 8;
|
||||
constexpr ck_tile::index_t TileParitionerM01 = 4;
|
||||
|
||||
// ===============================================
|
||||
|
||||
@@ -63,7 +65,8 @@ class TestCkTileGemmPipeline : public ::testing::Test
|
||||
ck_tile::TileGemmShape<ck_tile::sequence<M_Tile, N_Tile, K_Tile>,
|
||||
ck_tile::sequence<M_Warp, N_Warp, K_Warp>,
|
||||
ck_tile::sequence<M_Warp_Tile, N_Warp_Tile, K_Warp_Tile>>;
|
||||
using TilePartitioner = ck_tile::GemmTile2DPartitioner<GemmShape>;
|
||||
using TilePartitioner = ck_tile::
|
||||
GemmSpatiallyLocalTilePartitioner<GemmShape, TileParitionerGroupNum, TileParitionerM01>;
|
||||
|
||||
using Traits = ck_tile::TileGemmTraits<kPadM, kPadN, kPadK, ALayout, BLayout, CLayout>;
|
||||
using GemmUniversalTraits = ck_tile::
|
||||
|
||||
Reference in New Issue
Block a user