diff --git a/include/ck_tile/ops/gemm/kernel/batched_gemm_kernel.hpp b/include/ck_tile/ops/gemm/kernel/batched_gemm_kernel.hpp index 2761b16571..f75c034555 100644 --- a/include/ck_tile/ops/gemm/kernel/batched_gemm_kernel.hpp +++ b/include/ck_tile/ops/gemm/kernel/batched_gemm_kernel.hpp @@ -228,34 +228,9 @@ struct BatchedGemmKernel CDataType* c_ptr = static_cast(kargs.e_ptr) + batch_offset_C; // allocate LDS - __shared__ char smem_ptr0[GetSmemSize()]; - - if constexpr(GemmPipeline::DoubleSmemBuffer == true) - { - __shared__ char smem_ptr1[GemmPipeline::GetSmemSize()]; - UniversalGemmKernel::RunGemm2LDS({a_ptr}, - {b_ptr}, - {/*ds_ptr*/}, - c_ptr, - smem_ptr0, - smem_ptr1, - kargs, - splitk_batch_offset, - i_m, - i_n); - } - else - { - UniversalGemmKernel::RunGemm({a_ptr}, - {b_ptr}, - {/*ds_ptr*/}, - c_ptr, - smem_ptr0, - kargs, - splitk_batch_offset, - i_m, - i_n); - } + __shared__ char smem_ptr[GetSmemSize()]; + UniversalGemmKernel::RunGemm( + {a_ptr}, {b_ptr}, {/*ds_ptr*/}, c_ptr, smem_ptr, kargs, splitk_batch_offset, i_m, i_n); } }; diff --git a/include/ck_tile/ops/gemm/kernel/universal_gemm_kernel.hpp b/include/ck_tile/ops/gemm/kernel/universal_gemm_kernel.hpp index f5d63e977d..8a3bbc425a 100644 --- a/include/ck_tile/ops/gemm/kernel/universal_gemm_kernel.hpp +++ b/include/ck_tile/ops/gemm/kernel/universal_gemm_kernel.hpp @@ -1156,9 +1156,13 @@ struct UniversalGemmKernel } else { - auto c_block_window = MakeCBlockWindows( - e_ptr, kargs, block_idx_m, block_idx_n); - EpiloguePipeline{}(c_block_window, c_block_tile, ds_block_window, smem_ptr); + if constexpr(EpiloguePipeline::GetVectorSizeC() % 2 == 0 || + !is_any_of::value) + { + auto c_block_window = MakeCBlockWindows( + e_ptr, kargs, block_idx_m, block_idx_n); + EpiloguePipeline{}(c_block_window, c_block_tile, ds_block_window, smem_ptr); + } } }