From 6fdf2bd896cbd26391c0446dfb728546f417e118 Mon Sep 17 00:00:00 2001 From: YC Lin Date: Thu, 10 Apr 2025 14:22:22 +0000 Subject: [PATCH] [GEMM] Refactor GetStaticLdsSize and remove GetSmemSize --- .../block_gemm_pipeline_agmem_bgmem_creg.hpp | 5 -- ...peline_agmem_bgmem_creg_default_policy.hpp | 60 ++++--------------- .../ck_tile/99_toy_example/02_gemm/gemm.hpp | 43 ++++--------- 3 files changed, 20 insertions(+), 88 deletions(-) diff --git a/example/ck_tile/99_toy_example/02_gemm/block_gemm_pipeline_agmem_bgmem_creg.hpp b/example/ck_tile/99_toy_example/02_gemm/block_gemm_pipeline_agmem_bgmem_creg.hpp index d6adf81054..649d56336d 100644 --- a/example/ck_tile/99_toy_example/02_gemm/block_gemm_pipeline_agmem_bgmem_creg.hpp +++ b/example/ck_tile/99_toy_example/02_gemm/block_gemm_pipeline_agmem_bgmem_creg.hpp @@ -30,9 +30,6 @@ struct BlockGemmPipelineAGmemBGmemCReg CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetStaticLdsSize() { -#if defined(ENABLE_INSTRUCTION_SCH) - return Policy::template GetSmemSize(); -#else return integer_divide_ceil( sizeof(ADataType) * Policy::template MakeALdsBlockDescriptor().get_element_space_size(), @@ -40,7 +37,6 @@ struct BlockGemmPipelineAGmemBGmemCReg 16 + sizeof(BDataType) * Policy::template MakeBLdsBlockDescriptor().get_element_space_size(); -#endif } #if defined(ENABLE_INSTRUCTION_SCH) @@ -67,7 +63,6 @@ struct BlockGemmPipelineAGmemBGmemCReg static constexpr index_t GetVectorSizeA() { return Policy::template GetVectorSizeA(); } static constexpr index_t GetVectorSizeB() { return Policy::template GetVectorSizeB(); } - static constexpr index_t GetVectorSizeC() { return Policy::template GetVectorSizeC(); } static constexpr index_t GetSmemPackA() { return Policy::template GetSmemPackA(); } static constexpr index_t GetSmemPackB() { return Policy::template GetSmemPackB(); } diff --git a/example/ck_tile/99_toy_example/02_gemm/block_gemm_pipeline_agmem_bgmem_creg_default_policy.hpp b/example/ck_tile/99_toy_example/02_gemm/block_gemm_pipeline_agmem_bgmem_creg_default_policy.hpp index 917f86e960..6bde5eb939 100644 --- a/example/ck_tile/99_toy_example/02_gemm/block_gemm_pipeline_agmem_bgmem_creg_default_policy.hpp +++ b/example/ck_tile/99_toy_example/02_gemm/block_gemm_pipeline_agmem_bgmem_creg_default_policy.hpp @@ -336,23 +336,6 @@ struct BlockGemmPipelineAGmemBGmemCRegDefaultPolicy return GetGlobalVectorLoadSize(); } - template - CK_TILE_HOST_DEVICE static constexpr auto GetVectorSizeC() - { - using BlockGemm = remove_cvref_t())>; - using WG = typename BlockGemm::WarpGemm; - using CWarpDstr = typename WG::CWarpDstr; - - // In this case each thread has multiple consecutive elements in - // N dimension, however consecutive threads' elements have stride. - constexpr index_t NDimY = CWarpDstr::NDimY; - constexpr auto c_warp_y_lengths = - CWarpDstr{}.get_ys_to_d_descriptor().get_lengths(); - static_assert(WG::WarpGemmAttribute::Impl::kCM1PerLane == - c_warp_y_lengths.get(number{})); - return c_warp_y_lengths.get(number{}); - } - template CK_TILE_HOST_DEVICE static constexpr auto IsTransposeC() { @@ -362,44 +345,21 @@ struct BlockGemmPipelineAGmemBGmemCRegDefaultPolicy template CK_TILE_HOST_DEVICE static constexpr auto GetSmemPackA() { - using BlockGemm = remove_cvref_t())>; - constexpr index_t KPack = BlockGemm::Traits::KPack; - return KPack; + constexpr index_t kKPack = 8; + return kKPack; + // using BlockGemm = remove_cvref_t())>; + // constexpr index_t KPack = BlockGemm::Traits::KPack; + // return KPack; } template CK_TILE_HOST_DEVICE static constexpr auto GetSmemPackB() { - using BlockGemm = remove_cvref_t())>; - constexpr index_t KPack = BlockGemm::Traits::KPack; - return KPack; - } - - template - CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSizeA() - { - constexpr auto a_lds_desc = MakeALdsBlockDescriptor(); - constexpr index_t smem_size_a = integer_least_multiple( - sizeof(typename Problem::ADataType) * a_lds_desc.get_element_space_size(), 16); - return smem_size_a; - } - - template - CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSizeB() - { - constexpr auto b_lds_desc = MakeBLdsBlockDescriptor(); - constexpr index_t smem_size_b = integer_least_multiple( - sizeof(typename Problem::BDataType) * b_lds_desc.get_element_space_size(), 16); - return smem_size_b; - } - - template - CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() - { - constexpr index_t smem_size_a = GetSmemSizeA(); - constexpr index_t smem_size_b = GetSmemSizeB(); - - return smem_size_a + smem_size_b; + constexpr index_t kKPack = 8; + return kKPack; + // using BlockGemm = remove_cvref_t())>; + // constexpr index_t KPack = BlockGemm::Traits::KPack; + // return KPack; } template diff --git a/example/ck_tile/99_toy_example/02_gemm/gemm.hpp b/example/ck_tile/99_toy_example/02_gemm/gemm.hpp index 631d43b25d..421f442f3d 100644 --- a/example/ck_tile/99_toy_example/02_gemm/gemm.hpp +++ b/example/ck_tile/99_toy_example/02_gemm/gemm.hpp @@ -61,28 +61,6 @@ struct TileGemmShape #endif #if defined(ENABLE_INSTRUCTION_SCH) -template -struct TileGemmTraits -{ - static constexpr bool kPadM = kPadM_; - static constexpr bool kPadN = kPadN_; - static constexpr bool kPadK = kPadK_; - - // TODO this can't be hardcoded here! Should be in policy! - static constexpr int _VectorSize = 16; - - using ALayout = ALayout_; - using BLayout = BLayout_; - using CLayout = CLayout_; - - static constexpr bool TransposeC = false; -}; - template -struct TileGemmUniversalTraits +struct TileGemmTraits { static constexpr bool kPadM = kPadM_; static constexpr bool kPadN = kPadN_; @@ -284,22 +262,21 @@ struct Gemm PermuteA, PermuteB>; - using GemmUniversalTraits = - TileGemmUniversalTraits; + using GemmTraits = TileGemmTraits; using BlockGemmPipelineProblem_ = BlockGemmPipelineProblem;