[CK Tile] Spatially local GEMM tile partitioner. (#1843)

* Add spatially local tile partitioner

* Use 1D grid size & create partitioner object.

* Docs & use 1D partitioner in example.

* Clang format.

* Change kernel grid size (see the sketch after this list).

  Now: X is the number of output C-tiles,
       Y is the batch count,
       Z is the split-K factor.

* Formatting & more doc.

* Clang format.

* Fix batched gemm test. Use 1D partitioner.

* Move condition.

* Fix ctor.

* clang-format.
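A minimal host-side sketch of the launch-grid mapping described in the "Change kernel grid size" item above: X flattens the M x N output-tile grid into one dimension, Y carries the batch count, and Z carries the split-K factor. The helper name integer_divide_ceil, MakeGrid, and the concrete tile sizes are illustrative, not the exact CK Tile identifiers:

#include <cstdint>
#include <iostream>

constexpr std::int32_t integer_divide_ceil(std::int32_t x, std::int32_t y)
{
    return (x + y - 1) / y;
}

struct LaunchGrid
{
    std::int32_t x; // number of output C-tiles (flattened M x N tile grid)
    std::int32_t y; // batch count
    std::int32_t z; // split-K factor
};

constexpr LaunchGrid MakeGrid(std::int32_t M, std::int32_t N, std::int32_t batch,
                              std::int32_t k_batch, std::int32_t tile_m, std::int32_t tile_n)
{
    return {integer_divide_ceil(M, tile_m) * integer_divide_ceil(N, tile_n), batch, k_batch};
}

int main()
{
    // E.g. a 512x640 GEMM with 128x128 tiles, batch 4, split-K 2 -> grid (20, 4, 2).
    const LaunchGrid g = MakeGrid(512, 640, 4, 2, 128, 128);
    std::cout << g.x << ' ' << g.y << ' ' << g.z << '\n';
}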
Author: Adam Osewski
Date: 2025-01-31 00:10:16 +01:00
Committed by: GitHub
Parent: e6d4180498
Commit: ce448002ee
10 changed files with 285 additions and 88 deletions

@@ -75,12 +75,12 @@ struct GemmKernel
     static constexpr auto I1 = number<1>();
     static constexpr auto I2 = number<2>();

-    __host__ static constexpr auto GridSize(index_t M, index_t N, index_t KBatch)
+    CK_TILE_HOST static constexpr auto GridSize(index_t M, index_t N, index_t KBatch)
     {
-        return TilePartitioner::GridSize(M, N, KBatch);
+        return dim3(TilePartitioner::GridSize(M, N), 1, KBatch);
     }

-    __host__ static constexpr auto BlockSize() { return dim3(KernelBlockSize); }
+    CK_TILE_HOST static constexpr auto BlockSize() { return dim3(KernelBlockSize); }

     struct GemmKernelArgs
     {
@@ -93,7 +93,7 @@ struct GemmKernel
         index_t stride_A;
         index_t stride_B;
         index_t stride_C;
-        index_t KBatch;
+        index_t k_batch;
     };

     CK_TILE_HOST static constexpr GemmKernelArgs MakeKernelArgs(const GemmHostArgs& hostArgs)
@@ -121,7 +121,7 @@ struct GemmKernel
                             const std::size_t k_id = blockIdx.z)
     {
         constexpr auto K1 = TilePartitioner::BlockGemmShape::WarpTile::at(number<2>{});
-        const index_t K_t = kargs.KBatch * K1;
+        const index_t K_t = kargs.k_batch * K1;
         const index_t KRead = (kargs.K + K_t - 1) / K_t * K1;

         if constexpr(std::is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
@@ -142,13 +142,13 @@ struct GemmKernel
             b_k_split_offset = k_id * KRead;
         }

-        if(k_id < static_cast<uint32_t>(kargs.KBatch - 1))
+        if(k_id < static_cast<uint32_t>(kargs.k_batch - 1))
         {
             splitted_k = KRead;
         }
         else
         {
-            splitted_k = kargs.K - KRead * (kargs.KBatch - 1);
+            splitted_k = kargs.K - KRead * (kargs.k_batch - 1);
         }
     }
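A host-side check of the split-K slicing in the hunk above: each k_id reads a K1-aligned chunk of length KRead starting at k_id * KRead, the last k_id takes the remainder, and the chunks exactly tile [0, K). The concrete K, k_batch, and K1 values below are made up for illustration:

#include <cassert>
#include <cstdint>
#include <iostream>

int main()
{
    const std::int32_t K = 1000, k_batch = 3, K1 = 16; // K1: warp-tile K extent
    const std::int32_t K_t   = k_batch * K1;
    const std::int32_t KRead = (K + K_t - 1) / K_t * K1;

    std::int32_t covered = 0;
    for(std::int32_t k_id = 0; k_id < k_batch; ++k_id)
    {
        const std::int32_t offset = k_id * KRead;
        const std::int32_t splitted_k =
            (k_id < k_batch - 1) ? KRead : K - KRead * (k_batch - 1);
        std::cout << "k_id " << k_id << ": offset " << offset
                  << ", length " << splitted_k << '\n';
        covered += splitted_k;
    }
    assert(covered == K); // the k_batch slices exactly cover the K dimension
}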
@@ -162,7 +162,7 @@ struct GemmKernel
         if constexpr(EpiloguePipeline::GetVectorSizeC() % 2 != 0 &&
                      is_any_of<CDataType, fp16_t, bf16_t>::value)
         {
-            if(kargs.KBatch != 1)
+            if(kargs.k_batch != 1)
             {
                 std::cerr << "Conditions not met for Kbatch >1 !" << std::endl;
                 return false;
@@ -489,19 +489,14 @@ struct GemmKernel
         // Run Epilogue Pipeline
         auto& c_block_window = gemm_tile_windows.at(I2);

-        if constexpr(DstInMemOp == memory_operation_enum::set ||
-                     !(EpiloguePipeline::GetVectorSizeC() % 2 != 0 &&
-                       is_any_of<CDataType, fp16_t, bf16_t>::value))
-        {
-            EpiloguePipeline{}
-                .template operator()<decltype(c_block_window), decltype(c_block_tile), DstInMemOp>(
-                    c_block_window, c_block_tile, smem_ptr);
-        }
+        EpiloguePipeline{}
+            .template operator()<decltype(c_block_window), decltype(c_block_tile), DstInMemOp>(
+                c_block_window, c_block_tile, smem_ptr);
     }

     CK_TILE_DEVICE void operator()(GemmKernelArgs kargs) const
     {
-        const auto [iM, iN] = TilePartitioner::GetOutputTileIndex(blockIdx.x, blockIdx.y);
+        const auto [iM, iN] = TilePartitioner{kargs.M, kargs.N}.GetOutputTileIndex(blockIdx.x);
         const index_t i_m = __builtin_amdgcn_readfirstlane(iM * TilePartitioner::MPerBlock);
         const index_t i_n = __builtin_amdgcn_readfirstlane(iN * TilePartitioner::NPerBlock);
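The hunk above replaces the static 2D (blockIdx.x, blockIdx.y) lookup with a partitioner object that decodes a single blockIdx.x. A minimal sketch of one common spatially local decoding (column-major within fixed-size groups of tile rows, as used by grouped swizzles in other GEMM libraries); the actual CK Tile mapping may differ:

#include <algorithm>
#include <cstdint>
#include <utility>

struct SpatiallyLocalPartitioner
{
    std::int32_t m_tiles; // number of tile rows    (ceil(M / MPerBlock))
    std::int32_t n_tiles; // number of tile columns (ceil(N / NPerBlock))
    static constexpr std::int32_t GroupM = 8; // tile rows per group (tunable)

    // Decode a flat 1D block id into a 2D (iM, iN) output-tile index so that
    // consecutive block ids land on nearby C tiles and reuse A/B data in L2.
    std::pair<std::int32_t, std::int32_t> GetOutputTileIndex(std::int32_t bid) const
    {
        const std::int32_t group_size = GroupM * n_tiles;
        const std::int32_t first_m    = (bid / group_size) * GroupM;
        const std::int32_t rows_here  = std::min(GroupM, m_tiles - first_m);
        const std::int32_t iM         = first_m + (bid % group_size) % rows_here;
        const std::int32_t iN         = (bid % group_size) / rows_here;
        return {iM, iN};
    }
};

With m_tiles = 16 and n_tiles = 16, block ids 0..127 walk the first eight tile rows column by column before moving to the next group, instead of streaming across a whole row of C.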
@@ -516,14 +511,20 @@ struct GemmKernel
         // allocate LDS
         __shared__ char smem_ptr[GetSmemSize()];

-        if(kargs.KBatch == 1)
+        if(kargs.k_batch == 1)
         {
             RunGemm(a_ptr, b_ptr, c_ptr, smem_ptr, kargs, splitk_batch_offset, i_m, i_n);
         }
         else
         {
-            RunGemm<memory_operation_enum::atomic_add>(
-                a_ptr, b_ptr, c_ptr, smem_ptr, kargs, splitk_batch_offset, i_m, i_n);
+            // Do not compile in case where we have unsupported
+            // VectorSizeC & data type configuration.
+            if constexpr(!(EpiloguePipeline::GetVectorSizeC() % 2 != 0 &&
+                           is_any_of<CDataType, fp16_t, bf16_t>::value))
+            {
+                RunGemm<memory_operation_enum::atomic_add>(
+                    a_ptr, b_ptr, c_ptr, smem_ptr, kargs, splitk_batch_offset, i_m, i_n);
+            }
         }
     }
};
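The final hunk moves the VectorSizeC/data-type guard from the epilogue call into the split-K dispatch, so the unsupported atomic_add instantiation is never compiled rather than being silently skipped inside the epilogue. A stripped-down sketch of that pattern, with stand-in traits in place of the real EpiloguePipeline and CDataType:

#include <iostream>

// Stand-ins for the kernel's traits; values chosen to hit the guarded path.
struct Epilogue { static constexpr int GetVectorSizeC() { return 1; } };
constexpr bool is_fp16_or_bf16 = true;

enum class memory_operation_enum { set, atomic_add };

template <memory_operation_enum Op>
void RunGemmSketch() { std::cout << "RunGemm<" << static_cast<int>(Op) << ">\n"; }

void Dispatch(int k_batch)
{
    if(k_batch == 1)
    {
        RunGemmSketch<memory_operation_enum::set>();
    }
    else
    {
        // if constexpr keeps the atomic_add instantiation out of the binary
        // when an odd C vector size meets fp16/bf16 output.
        if constexpr(!(Epilogue::GetVectorSizeC() % 2 != 0 && is_fp16_or_bf16))
        {
            RunGemmSketch<memory_operation_enum::atomic_add>();
        }
    }
}

int main() { Dispatch(2); } // prints nothing: the guarded path is compiled out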