From 2d23e434ff910e1966f13117bf5a002c3355d719 Mon Sep 17 00:00:00 2001 From: Mateusz Ozga Date: Mon, 16 Jun 2025 15:12:24 +0000 Subject: [PATCH] Fix develop: basic gemm --- example/ck_tile/03_gemm/gemm_basic.cpp | 3 +++ .../ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v5.hpp | 2 ++ .../ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp | 2 ++ .../ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v2.hpp | 2 ++ include/ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp | 2 ++ include/ck_tile/ops/gemm/pipeline/tile_gemm_traits.hpp | 1 + 6 files changed, 12 insertions(+) diff --git a/example/ck_tile/03_gemm/gemm_basic.cpp b/example/ck_tile/03_gemm/gemm_basic.cpp index defeffc2ee..29d449f803 100644 --- a/example/ck_tile/03_gemm/gemm_basic.cpp +++ b/example/ck_tile/03_gemm/gemm_basic.cpp @@ -69,9 +69,12 @@ float gemm(const ck_tile::GemmHostArgs& args, const ck_tile: using GemmEpilogue = ck_tile::CShuffleEpilogue< ck_tile::CShuffleEpilogueProblem; using I2 = number<2>; + static constexpr index_t NumWaveGroups = Problem::NumWaveGroups; + static constexpr index_t BlockSize = Problem::kBlockSize; static constexpr index_t kMPerBlock = BlockGemmShape::kM; diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v2.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v2.hpp index 95b7618b11..bc097356b1 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v2.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v2.hpp @@ -27,6 +27,8 @@ struct GemmPipelineAGmemBGmemCRegV2 static constexpr index_t kBlockSize = Problem::kBlockSize; + static constexpr index_t NumWaveGroups = Problem::NumWaveGroups; + static constexpr index_t kMPerBlock = BlockGemmShape::kM; static constexpr index_t kNPerBlock = BlockGemmShape::kN; static constexpr index_t kKPerBlock = BlockGemmShape::kK; diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp index 678fb6eb46..a78c0bb0b4 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp @@ -45,6 +45,8 @@ struct GemmPipelineProblemBase static constexpr auto Scheduler = GemmPipelineScheduler::Default; static constexpr index_t VectorLoadSize = Traits::_VectorSize; + static constexpr index_t NumWaveGroups = Traits::NumWaveGroups; + [[nodiscard]] CK_TILE_HOST static const std::string GetName() { // clang-format off diff --git a/include/ck_tile/ops/gemm/pipeline/tile_gemm_traits.hpp b/include/ck_tile/ops/gemm/pipeline/tile_gemm_traits.hpp index 353192d86f..c6f83068a9 100644 --- a/include/ck_tile/ops/gemm/pipeline/tile_gemm_traits.hpp +++ b/include/ck_tile/ops/gemm/pipeline/tile_gemm_traits.hpp @@ -28,6 +28,7 @@ struct TileGemmTraits static constexpr bool TransposeC = false; static constexpr bool UseStructuredSparsity = false; + static constexpr index_t NumWaveGroups = 1; }; template