From c2edc3c5b86023ed4c3ce39f36119a6742e60208 Mon Sep 17 00:00:00 2001 From: mtgu0705 Date: Wed, 13 Aug 2025 09:24:24 -0500 Subject: [PATCH] update gemm_mx_pipeline_ag_bg_cr_v3 --- .../pipeline/gemm_mx_pipeline_ag_bg_cr_v3.hpp | 47 ++++++++++++------- 1 file changed, 30 insertions(+), 17 deletions(-) diff --git a/include/ck_tile/ops/gemm_mx/pipeline/gemm_mx_pipeline_ag_bg_cr_v3.hpp b/include/ck_tile/ops/gemm_mx/pipeline/gemm_mx_pipeline_ag_bg_cr_v3.hpp index 1952f867b1..3131a45eb8 100644 --- a/include/ck_tile/ops/gemm_mx/pipeline/gemm_mx_pipeline_ag_bg_cr_v3.hpp +++ b/include/ck_tile/ops/gemm_mx/pipeline/gemm_mx_pipeline_ag_bg_cr_v3.hpp @@ -115,13 +115,6 @@ struct GemmMXPipelineAgBgCrCompV3 : public BaseGemmMXPipelineAgBgCrCompV3())>; - static constexpr index_t BlockSize = Problem::kBlockSize; - static constexpr index_t MPerBlock = BlockGemmShape::kM; - static constexpr index_t NPerBlock = BlockGemmShape::kN; - static constexpr index_t KPerBlock = BlockGemmShape::kK; - static constexpr index_t BlockScaleSize = Problem::kBlockScaleSize; - static constexpr index_t KPerBlockScale = BlockGemmShape::kK / BlockScaleSize; - using MXdlPack = PipelineImplBase::MXdlPack; using NXdlPack = PipelineImplBase::NXdlPack; using KXdlPack = PipelineImplBase::KXdlPack; @@ -129,6 +122,14 @@ struct GemmMXPipelineAgBgCrCompV3 : public BaseGemmMXPipelineAgBgCrCompV3(); } static constexpr index_t GetVectorSizeB() { return Policy::template GetVectorSizeB(); } static constexpr index_t GetVectorSizeC() { return Policy::template GetVectorSizeC(); } @@ -324,22 +325,26 @@ struct GemmMXPipelineAgBgCrCompV3 : public BaseGemmMXPipelineAgBgCrCompV3(ABlockTileDistr{})); using BBlockTile = decltype(make_static_distributed_tensor(BBlockTileDistr{})); - using AQBlockTile = - decltype(make_static_distributed_tensor(AQBlockTileDistr{})); + using AScaleBlockTile = + decltype(make_static_distributed_tensor(AScaleBlockTileDistr{})); + using BScaleBlockTile = + decltype(make_static_distributed_tensor(BScaleBlockTileDistr{})); auto block_gemm = BlockGemm(); ABlockTile a_block_tile; BBlockTile b_block_tile; - AQBlockTile aq_block_tile[2]; + AScaleBlockTile a_scale_block_tile[2]; + BScaleBlockTile b_scale_block_tile[2]; int currIdx = 0; auto c_block_tile = block_gemm.MakeCBlockTile(); @@ -348,14 +353,22 @@ struct GemmMXPipelineAgBgCrCompV3 : public BaseGemmMXPipelineAgBgCrCompV3