From aa5ae1c0052d568ad033124a4b385b1ce119e9b9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bart=C5=82omiej=20Kocot?= Date: Thu, 19 Mar 2026 12:59:44 +0100 Subject: [PATCH] [CK][CK Tile] Fix dram step for KM/KN layouts in V1 pipeline (#5470) ## Motivation Fix v1 pipeline for KM/KN layouts by passing correct step for dram tile window. ## Technical Details - Fix dram step for KM/KN layouts in V1 pipeline - Disable instances which use more threads than warp size in continous dim (not supported in ck tile yet) - Use 1x1 specialization for explicit gemm - Use two stage for vectorsize =1 and sizeof(datatype) ==2 - remove not needed check sinze GetVectorSizeA/B check if vector size is fixed ## Test Plan test_grouped_convnd_bwd_weight_tile ## Test Result passed locally ## Submission Checklist - [x] Look over the contributing guidelines at https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests. AICK-966 --- .../generate_instances.py | 9 ++++++++- .../gemm_pipeline_agmem_bgmem_creg_v1.hpp | 16 ++++++++++++---- .../gemm_universal_pipeline_ag_bg_cr_policy.hpp | 13 +++++-------- 3 files changed, 25 insertions(+), 13 deletions(-) diff --git a/experimental/grouped_convolution_tile_instances/generate_instances.py b/experimental/grouped_convolution_tile_instances/generate_instances.py index 5346f6d2cb..b2cb37c39b 100755 --- a/experimental/grouped_convolution_tile_instances/generate_instances.py +++ b/experimental/grouped_convolution_tile_instances/generate_instances.py @@ -317,7 +317,7 @@ def parse_bwd_weight_instances(instances, problem_name): gemm_params = device_op_name = instance.split("<")[2].split(">")[1].split(",") args = [param.split(":")[1].strip() for param in gemm_params] - spec = "Default" + spec = "Filter1x1Stride1Pad0" block_size = int(args[0]) mnk_per_block = args[1].split("x") @@ -450,6 +450,13 @@ def parse_bwd_weight_instances(instances, problem_name): if pipeline_version == "V6": print(f"Skipping instance {instance_id} with V6 since it's not supported yet.") continue + if m_per_block > (warp_size * a_scalar_per_vector) or n_per_block > (warp_size * b_scalar_per_vector): + print(f"Skipping instance {instance_id} with multiple warps per continous tile dim since it's not supported yet.") + continue + + if is_explicit_gemm: + if dtype != "float" and c_scalar_per_vector % 2 != 0: + is_two_stage_instance = True conv = ConvInstanceTemplateParams( spec, diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp index 918eb3de26..a3268fa2ff 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp @@ -449,6 +449,14 @@ struct GemmPipelineAGmemBGmemCRegV1 : public BaseGemmPipelineAGmemBGmemCRegV1 CK_TILE_HOST_DEVICE static constexpr auto MakeADramTileDistribution() { - constexpr index_t BlockSize = Problem::kBlockSize; - constexpr index_t MPerBlock = Problem::BlockGemmShape::kM; - constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; - constexpr index_t VecLoadSize = - Problem::FixedVectorSize ? Problem::VectorSizeA : GetVectorSizeA(); + constexpr index_t BlockSize = Problem::kBlockSize; + constexpr index_t MPerBlock = Problem::BlockGemmShape::kM; + constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; + constexpr index_t VecLoadSize = GetVectorSizeA(); constexpr index_t NumWaveGroups = Problem::NumWaveGroups; using ALayout = remove_cvref_t< @@ -756,9 +755,7 @@ struct UniversalGemmBasePolicy // since the assumption is that A type is going to be the B LDS type constexpr bool IsBCastPolicyBeforeLDSWrite = IsBCastPolicyBeforeLDSWrite_v; constexpr index_t VecLoadSize = - IsBCastPolicyBeforeLDSWrite - ? (Problem::FixedVectorSize ? Problem::VectorSizeA : GetVectorSizeA()) - : (Problem::FixedVectorSize ? Problem::VectorSizeB : GetVectorSizeB()); + IsBCastPolicyBeforeLDSWrite ? GetVectorSizeA() : GetVectorSizeB(); constexpr index_t NumWaveGroups = Problem::NumWaveGroups; using BLayout = remove_cvref_t< std::tuple_element_t{}, remove_cvref_t>>;