From 54f903def5a9dfb5b27b7a87a46f7a4bdd3b150d Mon Sep 17 00:00:00 2001 From: jakpiase Date: Fri, 5 Dec 2025 09:30:22 +0100 Subject: [PATCH] fix enforcing fixedvectorsizes for ck tile conv (#3344) [ROCm/composable_kernel commit: f7650ee82b306a05d9c3c44d3feefdd570a4bd58] --- .../gemm_universal_pipeline_ag_bg_cr_policy.hpp | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp index d843916f5e..76341af70b 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp @@ -545,7 +545,7 @@ struct UniversalGemmBasePolicy } template - CK_TILE_HOST_DEVICE static constexpr auto GetVectorSizeA() + CK_TILE_HOST_DEVICE static constexpr index_t GetVectorSizeA() { using AsLayout = remove_cvref_t; using AsDataType = remove_cvref_t; @@ -555,6 +555,11 @@ struct UniversalGemmBasePolicy using ALayout = remove_cvref_t{}, AsLayout>>; using ADataType = remove_cvref_t{}, AsDataType>>; + if constexpr(Problem::FixedVectorSize) + { + return Problem::VectorSizeA; + } + if constexpr(std::is_same_v) { return GetGlobalVectorLoadSize - CK_TILE_HOST_DEVICE static constexpr auto GetVectorSizeB() + CK_TILE_HOST_DEVICE static constexpr index_t GetVectorSizeB() { using BsLayout = remove_cvref_t; using BsDataType = remove_cvref_t; @@ -584,6 +589,11 @@ struct UniversalGemmBasePolicy using BLayout = remove_cvref_t{}, BsLayout>>; using BDataType = remove_cvref_t{}, BsDataType>>; + if constexpr(Problem::FixedVectorSize) + { + return Problem::VectorSizeB; + } + if constexpr(std::is_same_v) { return GetGlobalVectorLoadSize