mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 14:59:17 +00:00
[CK Tile] Spatially local GEMM tile partitioner. (#1843)
* Add spatially local tile partitioner
* Use 1D Grid size & create partitioner object.
* Docs & use 1D partitioner in example.
* Clang format.
* Change kernel grid size.
Now: X is the number of output C-tiles,
Y is the batch count,
Z is the split-K batch count (KBatch).
* Formatting & more doc.
* Clang format.
* Fix batched gemm test. Use 1d partitioner.
* Move condition.
* Fix ctor.
* clang-format.
This commit is contained in:
@@ -75,12 +75,12 @@ struct GemmKernel
|
||||
static constexpr auto I1 = number<1>();
|
||||
static constexpr auto I2 = number<2>();
|
||||
|
||||
__host__ static constexpr auto GridSize(index_t M, index_t N, index_t KBatch)
|
||||
CK_TILE_HOST static constexpr auto GridSize(index_t M, index_t N, index_t KBatch)
|
||||
{
|
||||
return TilePartitioner::GridSize(M, N, KBatch);
|
||||
return dim3(TilePartitioner::GridSize(M, N), 1, KBatch);
|
||||
}
|
||||
|
||||
__host__ static constexpr auto BlockSize() { return dim3(KernelBlockSize); }
|
||||
CK_TILE_HOST static constexpr auto BlockSize() { return dim3(KernelBlockSize); }
|
||||
|
||||
struct GemmKernelArgs
|
||||
{
|
||||
@@ -93,7 +93,7 @@ struct GemmKernel
|
||||
index_t stride_A;
|
||||
index_t stride_B;
|
||||
index_t stride_C;
|
||||
index_t KBatch;
|
||||
index_t k_batch;
|
||||
};
|
||||
|
||||
CK_TILE_HOST static constexpr GemmKernelArgs MakeKernelArgs(const GemmHostArgs& hostArgs)
|
||||
@@ -121,7 +121,7 @@ struct GemmKernel
|
||||
const std::size_t k_id = blockIdx.z)
|
||||
{
|
||||
constexpr auto K1 = TilePartitioner::BlockGemmShape::WarpTile::at(number<2>{});
|
||||
const index_t K_t = kargs.KBatch * K1;
|
||||
const index_t K_t = kargs.k_batch * K1;
|
||||
const index_t KRead = (kargs.K + K_t - 1) / K_t * K1;
|
||||
|
||||
if constexpr(std::is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
|
||||
@@ -142,13 +142,13 @@ struct GemmKernel
|
||||
b_k_split_offset = k_id * KRead;
|
||||
}
|
||||
|
||||
if(k_id < static_cast<uint32_t>(kargs.KBatch - 1))
|
||||
if(k_id < static_cast<uint32_t>(kargs.k_batch - 1))
|
||||
{
|
||||
splitted_k = KRead;
|
||||
}
|
||||
else
|
||||
{
|
||||
splitted_k = kargs.K - KRead * (kargs.KBatch - 1);
|
||||
splitted_k = kargs.K - KRead * (kargs.k_batch - 1);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -162,7 +162,7 @@ struct GemmKernel
|
||||
if constexpr(EpiloguePipeline::GetVectorSizeC() % 2 != 0 &&
|
||||
is_any_of<CDataType, fp16_t, bf16_t>::value)
|
||||
{
|
||||
if(kargs.KBatch != 1)
|
||||
if(kargs.k_batch != 1)
|
||||
{
|
||||
std::cerr << "Conditions not met for Kbatch >1 !" << std::endl;
|
||||
return false;
|
||||
@@ -489,19 +489,14 @@ struct GemmKernel
|
||||
// Run Epilogue Pipeline
|
||||
auto& c_block_window = gemm_tile_windows.at(I2);
|
||||
|
||||
if constexpr(DstInMemOp == memory_operation_enum::set ||
|
||||
!(EpiloguePipeline::GetVectorSizeC() % 2 != 0 &&
|
||||
is_any_of<CDataType, fp16_t, bf16_t>::value))
|
||||
{
|
||||
EpiloguePipeline{}
|
||||
.template operator()<decltype(c_block_window), decltype(c_block_tile), DstInMemOp>(
|
||||
c_block_window, c_block_tile, smem_ptr);
|
||||
}
|
||||
EpiloguePipeline{}
|
||||
.template operator()<decltype(c_block_window), decltype(c_block_tile), DstInMemOp>(
|
||||
c_block_window, c_block_tile, smem_ptr);
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE void operator()(GemmKernelArgs kargs) const
|
||||
{
|
||||
const auto [iM, iN] = TilePartitioner::GetOutputTileIndex(blockIdx.x, blockIdx.y);
|
||||
const auto [iM, iN] = TilePartitioner{kargs.M, kargs.N}.GetOutputTileIndex(blockIdx.x);
|
||||
const index_t i_m = __builtin_amdgcn_readfirstlane(iM * TilePartitioner::MPerBlock);
|
||||
const index_t i_n = __builtin_amdgcn_readfirstlane(iN * TilePartitioner::NPerBlock);
|
||||
|
||||
@@ -516,14 +511,20 @@ struct GemmKernel
|
||||
// allocate LDS
|
||||
__shared__ char smem_ptr[GetSmemSize()];
|
||||
|
||||
if(kargs.KBatch == 1)
|
||||
if(kargs.k_batch == 1)
|
||||
{
|
||||
RunGemm(a_ptr, b_ptr, c_ptr, smem_ptr, kargs, splitk_batch_offset, i_m, i_n);
|
||||
}
|
||||
else
|
||||
{
|
||||
RunGemm<memory_operation_enum::atomic_add>(
|
||||
a_ptr, b_ptr, c_ptr, smem_ptr, kargs, splitk_batch_offset, i_m, i_n);
|
||||
// Do not compile in case where we have unsupported
|
||||
// VectorSizeC & data type configuration.
|
||||
if constexpr(!(EpiloguePipeline::GetVectorSizeC() % 2 != 0 &&
|
||||
is_any_of<CDataType, fp16_t, bf16_t>::value))
|
||||
{
|
||||
RunGemm<memory_operation_enum::atomic_add>(
|
||||
a_ptr, b_ptr, c_ptr, smem_ptr, kargs, splitk_batch_offset, i_m, i_n);
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
Reference in New Issue
Block a user