From 181ea79a3d9f8671627d1a375b828af29c1765fd Mon Sep 17 00:00:00 2001 From: Alan Turner Date: Fri, 22 Sep 2023 20:09:41 +0000 Subject: [PATCH] Avoid pipeline version 2 when k % kpb != 0 --- library/src/jit_library/src/device_gemm_multiple_d.cpp | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/library/src/jit_library/src/device_gemm_multiple_d.cpp b/library/src/jit_library/src/device_gemm_multiple_d.cpp index b31f53c161..a10a1650a1 100644 --- a/library/src/jit_library/src/device_gemm_multiple_d.cpp +++ b/library/src/jit_library/src/device_gemm_multiple_d.cpp @@ -134,13 +134,15 @@ Solution Problem::MakeSolution(std::size_t idx, const std::string& arch) const const std::size_t k_per_block = std::stoi(k_per_block_str); const std::size_t grid_size = GetGridSize(M, N, m_per_block, n_per_block); params[gemm_spec_idx] = GetGemmSpec(M, N, K, m_per_block, n_per_block, k_per_block); - + std::string str = std::accumulate( params.begin() + 1, params.end(), std::string{}, [](const std::string& a, const std::string& b) { return a.empty() ? b : a + ", " + b; }); str = params.front() + "< " + str + ">"; + if (params.back().find("v2") != std::string::npos and K % k_per_block != 0) + str = ""; return Solution{str, block_size, grid_size}; } @@ -156,7 +158,9 @@ std::vector Problem::GetSolutions(const std::string& arch) const const std::size_t num_instances = GetInstances(arch).size(); for(std::size_t i = 0; i < num_instances; ++i) { - solutions.push_back(MakeSolution(i, arch)); + auto solution = MakeSolution(i, arch); + if (solution.template_str != "") + solutions.push_back(solution); } return solutions;