mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-21 13:29:20 +00:00
[CK TILE] GEMM and Batched GEMM SplitK support (#1724)
* [CK TILE] Add split K support in GEMM
* Updates
* Fixes
* rebase
* fix
* Fix
* fixes
* support for batched gemm
[ROCm/composable_kernel commit: af66494880]
This commit is contained in:
@@ -93,7 +93,7 @@ class TestCkTileBatchedGemm : public ::testing::Test
|
||||
|
||||
auto kargs = Kernel::MakeKernelArgs(args);
|
||||
|
||||
const dim3 grids = Kernel::GridSize(args.M, args.N, args.batch_count);
|
||||
const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch, args.batch_count);
|
||||
constexpr dim3 blocks = Kernel::BlockSize();
|
||||
|
||||
if(s.log_level_ > 0)
|
||||
@@ -186,6 +186,7 @@ class TestCkTileBatchedGemm : public ::testing::Test
|
||||
args.a_ptr = a_m_k_dev_buf.GetDeviceBuffer();
|
||||
args.b_ptr = b_k_n_dev_buf.GetDeviceBuffer();
|
||||
args.c_ptr = c_m_n_dev_buf.GetDeviceBuffer();
|
||||
args.k_batch = 1;
|
||||
args.M = M;
|
||||
args.N = N;
|
||||
args.K = K;
|
||||
|
||||
@@ -74,7 +74,9 @@ class TestCkTileGemmPipeline : public ::testing::Test
|
||||
ck_tile::
|
||||
GemmPipelineProblem<ADataType, BDataType, AccDataType, GemmShape, Traits>>>;
|
||||
|
||||
const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(args.K);
|
||||
const ck_tile::index_t k_grain = args.k_batch * K_Tile;
|
||||
const ck_tile::index_t K_split = (args.K + k_grain - 1) / k_grain * K_Tile;
|
||||
const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split);
|
||||
const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop);
|
||||
const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop);
|
||||
|
||||
|
||||
Reference in New Issue
Block a user