From 9c193594380f8319f2123dbeb19d48bff191a8db Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bart=C5=82omiej=20Kocot?= Date: Fri, 27 Feb 2026 23:16:19 +0100 Subject: [PATCH] [CK][CK Tile] Fix batched gemm kernel 2 lds (#4963) ## Motivation Fix the call from the batched GEMM kernel into the universal GEMM kernel for pipelines that use two LDS buffers. Disable split-K when the output vector size is not valid for the atomic add instruction. ## Technical Details The batched GEMM kernel no longer branches on `GemmPipeline::DoubleSmemBuffer` to call `RunGemm2LDS` with two shared-memory buffers; it always calls `UniversalGemmKernel::RunGemm` with a single shared-memory buffer. In the universal GEMM kernel, the guarded epilogue call is only compiled when `EpiloguePipeline::GetVectorSizeC()` is even or the `is_any_of` data-type check does not match, i.e. when the atomic add instruction size is valid. ## Test Plan Overall CI run. ## Test Result Pending. ## Submission Checklist - [x] Look over the contributing guidelines at https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests. --- .../ops/gemm/kernel/batched_gemm_kernel.hpp | 31 ++----------------- .../ops/gemm/kernel/universal_gemm_kernel.hpp | 10 ++++-- 2 files changed, 10 insertions(+), 31 deletions(-) diff --git a/include/ck_tile/ops/gemm/kernel/batched_gemm_kernel.hpp b/include/ck_tile/ops/gemm/kernel/batched_gemm_kernel.hpp index 2761b16571..f75c034555 100644 --- a/include/ck_tile/ops/gemm/kernel/batched_gemm_kernel.hpp +++ b/include/ck_tile/ops/gemm/kernel/batched_gemm_kernel.hpp @@ -228,34 +228,9 @@ struct BatchedGemmKernel CDataType* c_ptr = static_cast(kargs.e_ptr) + batch_offset_C; // allocate LDS - __shared__ char smem_ptr0[GetSmemSize()]; - - if constexpr(GemmPipeline::DoubleSmemBuffer == true) - { - __shared__ char smem_ptr1[GemmPipeline::GetSmemSize()]; - UniversalGemmKernel::RunGemm2LDS({a_ptr}, - {b_ptr}, - {/*ds_ptr*/}, - c_ptr, - smem_ptr0, - smem_ptr1, - kargs, - splitk_batch_offset, - i_m, - i_n); - } - else - { - UniversalGemmKernel::RunGemm({a_ptr}, - {b_ptr}, - {/*ds_ptr*/}, - c_ptr, - smem_ptr0, - kargs, - splitk_batch_offset, - i_m, - i_n); - } + __shared__ char smem_ptr[GetSmemSize()]; + UniversalGemmKernel::RunGemm( + {a_ptr}, {b_ptr}, {/*ds_ptr*/}, c_ptr, smem_ptr, kargs, splitk_batch_offset, i_m, i_n); } }; diff --git a/include/ck_tile/ops/gemm/kernel/universal_gemm_kernel.hpp 
b/include/ck_tile/ops/gemm/kernel/universal_gemm_kernel.hpp index f5d63e977d..8a3bbc425a 100644 --- a/include/ck_tile/ops/gemm/kernel/universal_gemm_kernel.hpp +++ b/include/ck_tile/ops/gemm/kernel/universal_gemm_kernel.hpp @@ -1156,9 +1156,13 @@ struct UniversalGemmKernel } else { - auto c_block_window = MakeCBlockWindows( - e_ptr, kargs, block_idx_m, block_idx_n); - EpiloguePipeline{}(c_block_window, c_block_tile, ds_block_window, smem_ptr); + if constexpr(EpiloguePipeline::GetVectorSizeC() % 2 == 0 || + !is_any_of::value) + { + auto c_block_window = MakeCBlockWindows( + e_ptr, kargs, block_idx_m, block_idx_n); + EpiloguePipeline{}(c_block_window, c_block_tile, ds_block_window, smem_ptr); + } } }