diff --git a/include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp b/include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp index 915ce9b7aa..972c71e93b 100644 --- a/include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp +++ b/include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp @@ -654,11 +654,11 @@ struct GemmKernel // allocate LDS __shared__ char smem_ptr_0[GetSmemSize()]; - if(kargs.k_batch == 1) + if constexpr(GemmPipeline::DoubleSmemBuffer == true) { - if constexpr(GemmPipeline::DoubleSmemBuffer == true) + __shared__ char smem_ptr_1[GetSmemSize()]; + if(kargs.k_batch == 1) { - __shared__ char smem_ptr_1[GetSmemSize()]; RunGemm2LDS(a_ptr, b_ptr, c_ptr, @@ -671,19 +671,9 @@ struct GemmKernel } else { - RunGemm(a_ptr, b_ptr, c_ptr, smem_ptr_0, kargs, splitk_batch_offset, i_m, i_n); - } - } - else - { - // Do not compile in case where we have unsupported - // VectorSizeC & data type configuration. - if constexpr(!(EpiloguePipeline::template GetVectorSizeC() % 2 != 0 && - is_any_of::value)) - { - if constexpr(GemmPipeline::DoubleSmemBuffer == true) + if constexpr(!(EpiloguePipeline::template GetVectorSizeC() % 2 != 0 && + is_any_of::value)) { - __shared__ char smem_ptr_1[GetSmemSize()]; RunGemm2LDS(a_ptr, b_ptr, c_ptr, @@ -694,7 +684,18 @@ struct GemmKernel i_m, i_n); } - else + } + } + else + { + if(kargs.k_batch == 1) + { + RunGemm(a_ptr, b_ptr, c_ptr, smem_ptr_0, kargs, splitk_batch_offset, i_m, i_n); + } + else + { + if constexpr(!(EpiloguePipeline::template GetVectorSizeC() % 2 != 0 && + is_any_of::value)) { RunGemm( a_ptr, b_ptr, c_ptr, smem_ptr_0, kargs, splitk_batch_offset, i_m, i_n); diff --git a/test/ck_tile/gemm/test_gemm_pipeline_util.hpp b/test/ck_tile/gemm/test_gemm_pipeline_util.hpp index 155234cddc..3a9203a5bf 100644 --- a/test/ck_tile/gemm/test_gemm_pipeline_util.hpp +++ b/test/ck_tile/gemm/test_gemm_pipeline_util.hpp @@ -71,7 +71,9 @@ class TestCkTileGemmPipeline : public ::testing::Test constexpr ck_tile::index_t M_Warp_Tile = 32; constexpr ck_tile::index_t N_Warp_Tile = 32; - constexpr ck_tile::index_t K_Warp_Tile = 8; + // TODO: Restore to 8. At now after changes in block_universal_gemm_as_bs_cr it return wrong + // values. + constexpr ck_tile::index_t K_Warp_Tile = 16; constexpr bool kPadM = PadM; constexpr bool kPadN = PadN;