[CK TILE] GEMM and Batched GEMM SplitK support (#1724)

* [CK TILE] Add split K support in GEMM

* Updates

* Fixes

* rebase

* fix

* Fix

* fixes

* support for batched gemm

[ROCm/composable_kernel commit: af66494880]
This commit is contained in:
Bartłomiej Kocot
2024-12-28 14:40:17 +01:00
committed by GitHub
parent 282f02cc66
commit a5a7f2675f
18 changed files with 245 additions and 91 deletions

View File

@@ -93,7 +93,7 @@ class TestCkTileBatchedGemm : public ::testing::Test
auto kargs = Kernel::MakeKernelArgs(args);
const dim3 grids = Kernel::GridSize(args.M, args.N, args.batch_count);
const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch, args.batch_count);
constexpr dim3 blocks = Kernel::BlockSize();
if(s.log_level_ > 0)
@@ -186,6 +186,7 @@ class TestCkTileBatchedGemm : public ::testing::Test
args.a_ptr = a_m_k_dev_buf.GetDeviceBuffer();
args.b_ptr = b_k_n_dev_buf.GetDeviceBuffer();
args.c_ptr = c_m_n_dev_buf.GetDeviceBuffer();
args.k_batch = 1;
args.M = M;
args.N = N;
args.K = K;

View File

@@ -74,7 +74,9 @@ class TestCkTileGemmPipeline : public ::testing::Test
ck_tile::
GemmPipelineProblem<ADataType, BDataType, AccDataType, GemmShape, Traits>>>;
const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(args.K);
const ck_tile::index_t k_grain = args.k_batch * K_Tile;
const ck_tile::index_t K_split = (args.K + k_grain - 1) / k_grain * K_Tile;
const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split);
const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop);
const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop);