From c5acb522de57400d1debeb92c469feddcab761c0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bart=C5=82omiej=20Kocot?= Date: Thu, 27 Feb 2025 11:01:14 +0100 Subject: [PATCH] [CK TILE] Gemm pk_int4_t permute B (#1907) * [CK TILE] Gemm pk_int4_t permute B * Fixes [ROCm/composable_kernel commit: 0356ee069e3cd40c5f17c3b78ef6fd8c920ff4a4] --- example/ck_tile/03_gemm/gemm_basic.cpp | 2 +- .../{gemm_basic.hpp => gemm_utils.hpp} | 77 +++++++++++- example/ck_tile/03_gemm/run_gemm_example.inc | 91 ++++++++++++-- example/ck_tile/03_gemm/universal_gemm.cpp | 116 +++++------------- .../ck_tile/17_grouped_gemm/grouped_gemm.hpp | 8 +- .../ck_tile/ops/gemm/kernel/gemm_kernel.hpp | 67 ++++++++-- .../gemm_pipeline_ag_bg_cr_comp_v3.hpp | 19 +-- .../gemm_pipeline_ag_bg_cr_comp_v4.hpp | 3 + .../pipeline/gemm_pipeline_ag_bg_cr_mem.hpp | 3 + .../gemm_pipeline_agmem_bgmem_creg_v1.hpp | 3 + .../gemm_pipeline_agmem_bgmem_creg_v2.hpp | 3 + .../ops/gemm/pipeline/tile_gemm_shape.hpp | 9 +- 12 files changed, 279 insertions(+), 122 deletions(-) rename example/ck_tile/03_gemm/{gemm_basic.hpp => gemm_utils.hpp} (62%) diff --git a/example/ck_tile/03_gemm/gemm_basic.cpp b/example/ck_tile/03_gemm/gemm_basic.cpp index 5dc7b9cd0b..57298b68dc 100644 --- a/example/ck_tile/03_gemm/gemm_basic.cpp +++ b/example/ck_tile/03_gemm/gemm_basic.cpp @@ -10,7 +10,7 @@ #include #include "ck_tile/host.hpp" -#include "gemm_basic.hpp" +#include "gemm_utils.hpp" template -struct GemmBasicTypeConfig; +struct GemmTypeConfig; template <> -struct GemmBasicTypeConfig +struct GemmTypeConfig { using ADataType = ck_tile::half_t; using BDataType = ck_tile::half_t; @@ -49,7 +114,7 @@ struct GemmBasicTypeConfig }; template <> -struct GemmBasicTypeConfig +struct GemmTypeConfig { using ADataType = ck_tile::bf16_t; using BDataType = ck_tile::bf16_t; @@ -58,7 +123,7 @@ struct GemmBasicTypeConfig }; template <> -struct GemmBasicTypeConfig +struct GemmTypeConfig { using ADataType = ck_tile::fp8_t; using BDataType = ck_tile::fp8_t; @@ -67,7 +132,7 @@ struct GemmBasicTypeConfig }; template <> -struct GemmBasicTypeConfig +struct GemmTypeConfig { using ADataType = ck_tile::bf8_t; using BDataType = ck_tile::bf8_t; @@ -76,7 +141,7 @@ struct GemmBasicTypeConfig }; template <> -struct GemmBasicTypeConfig +struct GemmTypeConfig { using ADataType = ck_tile::half_t; using BDataType = ck_tile::pk_int4_t; diff --git a/example/ck_tile/03_gemm/run_gemm_example.inc b/example/ck_tile/03_gemm/run_gemm_example.inc index f068cbc1da..6cb40e45d1 100644 --- a/example/ck_tile/03_gemm/run_gemm_example.inc +++ b/example/ck_tile/03_gemm/run_gemm_example.inc @@ -29,8 +29,67 @@ auto calculate_rtol_atol(const ck_tile::index_t K, // Use higher threshold return ck_tile::make_tuple(std::max(rtol, rtol_split_k), std::max(atol, atol_split_k)); } -template + +template void permute_tensor_b(Tensor& tensor) +{ + using GemmShape = ck_tile::TileGemmShape< + ck_tile::sequence, + ck_tile::sequence, + ck_tile:: + sequence, + GemmConfig::PermuteA, + GemmConfig::PermuteB>; + + using GemmUniversalTraits = ck_tile::TileGemmUniversalTraits; + + using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem; + + using GemmPipeline = GEMM_PIPELINE; + + const ck_tile::index_t K = tensor.get_length(0); + const ck_tile::index_t N = tensor.get_length(1); + const ck_tile::index_t K1 = GemmPipeline::GetSmemPackB(); + const ck_tile::index_t K0 = K / K1; + + Tensor tensor_copy = tensor; + + // int K0, N, K1 + for(int j = 0; j < K0; j++) + { + for(int i = 0; i < N; i++) + { + for(int jj = 0; jj < K1; jj++) + { + tensor(j * N * K1 + i * K1 + jj) = tensor_copy(i * K + (j * K1 + jj)); + } + } + } +} + +template +void permute_vectors_i4x4_b(Tensor& tensor) { const ck_tile::index_t K = tensor.get_length(0); const ck_tile::index_t N = tensor.get_length(1); @@ -153,7 +212,7 @@ int run_gemm_example_with_layouts(int argc, if(!result) return -1; - using AccDataType = typename GemmBasicTypeConfig::AccDataType; + using AccDataType = typename GemmTypeConfig::AccDataType; ck_tile::index_t M = arg_parser.get_int("m"); ck_tile::index_t N = arg_parser.get_int("n"); @@ -181,8 +240,8 @@ int run_gemm_example_with_layouts(int argc, if(init_method == 0) { - ck_tile::FillUniformDistribution{-1.f, 1.f}(a_m_k); - ck_tile::FillUniformDistribution{-1.f, 1.f}(b_k_n); + ck_tile::FillUniformDistribution{-5.f, 5.f}(a_m_k); + ck_tile::FillUniformDistribution{-5.f, 5.f}(b_k_n); } else if(init_method == 1) { @@ -204,18 +263,36 @@ int run_gemm_example_with_layouts(int argc, ck_tile::DeviceMem b_k_n_dev_buf(b_k_n.get_element_space_size_in_bytes()); ck_tile::DeviceMem c_m_n_dev_buf(c_m_n_dev_result.get_element_space_size_in_bytes()); - a_m_k_dev_buf.ToDevice(a_m_k.data()); + static_assert(!GemmConfig::PermuteA, "Not implemented"); if constexpr(std::is_same_v) { - // Permute data for device implementation + // Permute vector pk_i4x4 data for device implementation ck_tile::HostTensor b_k_n_dev = b_k_n; - permute_tensor_b(b_k_n_dev); + if constexpr(GemmConfig::PermuteB) + { + permute_tensor_b(b_k_n_dev); + } + permute_vectors_i4x4_b(b_k_n_dev); b_k_n_dev_buf.ToDevice(b_k_n_dev.data()); } else { + if constexpr(GemmConfig::PermuteB) + { + std::cout << "Permute for this DataType is not implemented." << std::endl; + return false; + } b_k_n_dev_buf.ToDevice(b_k_n.data()); } + + a_m_k_dev_buf.ToDevice(a_m_k.data()); c_m_n_dev_buf.SetZero(); c_m_n_dev_result.SetZero(); diff --git a/example/ck_tile/03_gemm/universal_gemm.cpp b/example/ck_tile/03_gemm/universal_gemm.cpp index ab763437e5..8c04066b20 100644 --- a/example/ck_tile/03_gemm/universal_gemm.cpp +++ b/example/ck_tile/03_gemm/universal_gemm.cpp @@ -10,7 +10,7 @@ #include #include "ck_tile/host.hpp" -#include "gemm_basic.hpp" +#include "gemm_utils.hpp" template float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s) { -#if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_MEMORY) - // Memory friendly for Interwave scheduler - constexpr ck_tile::index_t M_Tile = 128; - constexpr ck_tile::index_t N_Tile = 32; - constexpr ck_tile::index_t K_Tile = 64; + using GemmShape = ck_tile::TileGemmShape< + ck_tile::sequence, + ck_tile::sequence, + ck_tile:: + sequence, + GemmConfig::PermuteA, + GemmConfig::PermuteB>; + using TilePartitioner = + ck_tile::GemmSpatiallyLocalTilePartitioner; - constexpr ck_tile::index_t M_Warp = 4; - constexpr ck_tile::index_t N_Warp = 1; - constexpr ck_tile::index_t K_Warp = 1; - - constexpr ck_tile::index_t M_Warp_Tile = 32; - constexpr ck_tile::index_t N_Warp_Tile = 32; - constexpr ck_tile::index_t K_Warp_Tile = 8; - - constexpr bool DoubleSmemBuffer = false; -#endif -#if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE_V3) - // Compute friendly for Intrawave scheduler - constexpr ck_tile::index_t M_Tile = 256; - constexpr ck_tile::index_t N_Tile = 256; - constexpr ck_tile::index_t K_Tile = 64; - - constexpr ck_tile::index_t M_Warp = 2; - constexpr ck_tile::index_t N_Warp = 2; - constexpr ck_tile::index_t K_Warp = 1; - - constexpr ck_tile::index_t M_Warp_Tile = 32; - constexpr ck_tile::index_t N_Warp_Tile = 32; - constexpr ck_tile::index_t K_Warp_Tile = 16; - - constexpr bool DoubleSmemBuffer = false; -#elif(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE_V4) - // Compute friendly for Intrawave scheduler - // Using the ping pong reader in the lds level - constexpr ck_tile::index_t M_Tile = 256; - constexpr ck_tile::index_t N_Tile = 256; - constexpr ck_tile::index_t K_Tile = 32; - - constexpr ck_tile::index_t M_Warp = 2; - constexpr ck_tile::index_t N_Warp = 2; - constexpr ck_tile::index_t K_Warp = 1; - - constexpr ck_tile::index_t M_Warp_Tile = 32; - constexpr ck_tile::index_t N_Warp_Tile = 32; - constexpr ck_tile::index_t K_Warp_Tile = 16; - - constexpr bool DoubleSmemBuffer = true; -#endif - - constexpr bool kPadM = false; - constexpr bool kPadN = false; - constexpr bool kPadK = false; - - constexpr bool TransposeC = false; - - constexpr int kBlockPerCu = 1; - constexpr ck_tile::index_t TileParitionerGroupNum = 8; - constexpr ck_tile::index_t TileParitionerM01 = 4; - - // =============================================== - - using GemmShape = - ck_tile::TileGemmShape, - ck_tile::sequence, - ck_tile::sequence>; - using TilePartitioner = ck_tile:: - GemmSpatiallyLocalTilePartitioner; - - using Traits = ck_tile::TileGemmTraits; - using GemmUniversalTraits = ck_tile::TileGemmUniversalTraits; + using GemmUniversalTraits = ck_tile::TileGemmUniversalTraits; + GemmConfig::TransposeC>; using GemmPipelineProblem = ck_tile::GemmPipelineProblem; using BaseGemmPipeline = UNIVERSAL_GEMM_PIPELINE; - const ck_tile::index_t k_grain = args.k_batch * K_Tile; - const ck_tile::index_t K_split = (args.K + k_grain - 1) / k_grain * K_Tile; + const ck_tile::index_t k_grain = args.k_batch * GemmConfig::K_Tile; + const ck_tile::index_t K_split = (args.K + k_grain - 1) / k_grain * GemmConfig::K_Tile; const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split); const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop); const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop); @@ -133,11 +82,11 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& GemmPipelineProblem::kBlockSize, TilePartitioner::MPerBlock, TilePartitioner::NPerBlock, - M_Warp, - N_Warp, - M_Warp_Tile, - N_Warp_Tile, - K_Warp_Tile, + GemmConfig::M_Warp, + GemmConfig::N_Warp, + GemmConfig::M_Warp_Tile, + GemmConfig::N_Warp_Tile, + GemmConfig::K_Warp_Tile, UniversalGemmProblem::TransposeC>>; using Kernel = ck_tile::GemmKernel; auto kargs = Kernel::MakeKernelArgs(args); @@ -158,8 +107,9 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& << std::endl; } - ave_time = ck_tile::launch_kernel( - s, ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); + ave_time = ck_tile::launch_kernel(s, + ck_tile::make_kernel( + Kernel{}, grids, blocks, 0, kargs)); return ave_time; }; diff --git a/example/ck_tile/17_grouped_gemm/grouped_gemm.hpp b/example/ck_tile/17_grouped_gemm/grouped_gemm.hpp index 2ffef95196..14d450034d 100644 --- a/example/ck_tile/17_grouped_gemm/grouped_gemm.hpp +++ b/example/ck_tile/17_grouped_gemm/grouped_gemm.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -10,10 +10,10 @@ #include "ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp" template -struct GemmBasicTypeConfig; +struct GemmTypeConfig; template <> -struct GemmBasicTypeConfig +struct GemmTypeConfig { using ADataType = ck_tile::half_t; using BDataType = ck_tile::half_t; @@ -21,7 +21,7 @@ struct GemmBasicTypeConfig using AccDataType = float; }; -using Types = GemmBasicTypeConfig; +using Types = GemmTypeConfig; // Specific type aliases for easy access using ADataType = Types::ADataType; diff --git a/include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp b/include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp index f2aa3af196..915ce9b7aa 100644 --- a/include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp +++ b/include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp @@ -314,6 +314,7 @@ struct GemmKernel const GemmKernelArgs& kargs, const SplitKBatchOffset& splitk_batch_offset) { + static_assert(!TilePartitioner::BlockGemmShape::PermuteA, "Not implemented!"); const auto& a_tensor_view = [&]() { if constexpr(std::is_same_v) { @@ -338,21 +339,63 @@ struct GemmKernel const auto& b_tensor_view = [&]() { if constexpr(std::is_same_v) { - return make_naive_tensor_view( - b_ptr, - make_tuple(splitk_batch_offset.splitted_k, kargs.N), - make_tuple(kargs.stride_B, 1), - number{}, - number<1>{}); + if constexpr(TilePartitioner::BlockGemmShape::PermuteB) + { + constexpr index_t K1 = GemmPipeline::GetSmemPackB(); + const index_t K0 = splitk_batch_offset.splitted_k / K1; + constexpr index_t VectorSizeB = std::min(K1, GemmPipeline::GetVectorSizeB()); + const auto b_k0_n_k1_desc = + make_naive_tensor_descriptor(make_tuple(K0, kargs.N, K1), + make_tuple(kargs.N * K1, K1, I1), + number{}, + number<1>{}); + const auto b_n_k_desc = transform_tensor_descriptor( + b_k0_n_k1_desc, + make_tuple(make_merge_transform(make_tuple(K0, K1)), + make_pass_through_transform(kargs.N)), + make_tuple(sequence<0, 2>{}, sequence<1>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + return make_tensor_view(b_ptr, b_n_k_desc); + } + else + { + return make_naive_tensor_view( + b_ptr, + make_tuple(splitk_batch_offset.splitted_k, kargs.N), + make_tuple(kargs.stride_B, 1), + number{}, + number<1>{}); + } } else { - return make_naive_tensor_view( - b_ptr, - make_tuple(kargs.N, splitk_batch_offset.splitted_k), - make_tuple(kargs.stride_B, 1), - number{}, - number<1>{}); + if constexpr(TilePartitioner::BlockGemmShape::PermuteB) + { + constexpr index_t K1 = GemmPipeline::GetSmemPackB(); + const index_t K0 = splitk_batch_offset.splitted_k / K1; + constexpr index_t VectorSizeB = std::min(K1, GemmPipeline::GetVectorSizeB()); + const auto b_k0_n_k1_desc = + make_naive_tensor_descriptor(make_tuple(K0, kargs.N, K1), + make_tuple(kargs.N * K1, K1, I1), + number{}, + number<1>{}); + const auto b_n_k_desc = transform_tensor_descriptor( + b_k0_n_k1_desc, + make_tuple(make_merge_transform(make_tuple(K0, K1)), + make_pass_through_transform(kargs.N)), + make_tuple(sequence<0, 2>{}, sequence<1>{}), + make_tuple(sequence<1>{}, sequence<0>{})); + return make_tensor_view(b_ptr, b_n_k_desc); + } + else + { + return make_naive_tensor_view( + b_ptr, + make_tuple(kargs.N, splitk_batch_offset.splitted_k), + make_tuple(kargs.stride_B, 1), + number{}, + number<1>{}); + } } }(); diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v3.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v3.hpp index b6e165e6da..1e3694d24c 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v3.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v3.hpp @@ -77,6 +77,9 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3 static constexpr index_t GetVectorSizeB() { return Policy::template GetVectorSizeB(); } static constexpr index_t GetVectorSizeC() { return Policy::template GetVectorSizeC(); } + static constexpr index_t GetSmemPackA() { return Policy::template GetSmemPackA(); } + static constexpr index_t GetSmemPackB() { return Policy::template GetSmemPackB(); } + static constexpr bool kPadM = Problem::kPadM; static constexpr bool kPadN = Problem::kPadN; static constexpr bool kPadK = Problem::kPadK; @@ -114,11 +117,11 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3 constexpr index_t WaveNumN = BlockGemmShape::BlockWarps::at(I1{}); // Below should be equal to AK1|BK1 - constexpr index_t A_LDS_Read_Width = Policy::template GetSmemPackA(); - constexpr index_t B_LDS_Read_Width = Policy::template GetSmemPackB(); + constexpr index_t A_LDS_Read_Width = GetSmemPackA(); + constexpr index_t B_LDS_Read_Width = GetSmemPackB(); - constexpr index_t A_LDS_Write_Width = Policy::template GetSmemPackA(); - constexpr index_t B_LDS_Write_Width = Policy::template GetSmemPackB(); + constexpr index_t A_LDS_Write_Width = GetSmemPackA(); + constexpr index_t B_LDS_Write_Width = GetSmemPackB(); constexpr index_t A_Buffer_Load_Inst_Num = MPerBlock * KPerBlock / (BlockSize * GetVectorSizeA()); @@ -174,11 +177,11 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3 constexpr index_t WaveNumN = BlockGemmShape::BlockWarps::at(I1{}); // Below should be equal to AK1|BK1 - constexpr index_t A_LDS_Read_Width = Policy::template GetSmemPackA(); - constexpr index_t B_LDS_Read_Width = Policy::template GetSmemPackB(); + constexpr index_t A_LDS_Read_Width = GetSmemPackA(); + constexpr index_t B_LDS_Read_Width = GetSmemPackB(); - constexpr index_t A_LDS_Write_Width = Policy::template GetSmemPackA(); - constexpr index_t B_LDS_Write_Width = Policy::template GetSmemPackB(); + constexpr index_t A_LDS_Write_Width = GetSmemPackA(); + constexpr index_t B_LDS_Write_Width = GetSmemPackB(); constexpr index_t A_Buffer_Load_Inst_Num = MPerBlock * KPerBlock / (BlockSize * GetVectorSizeA()); diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v4.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v4.hpp index b679f8c8aa..f95d80a6f5 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v4.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v4.hpp @@ -86,6 +86,9 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV4 static constexpr index_t GetVectorSizeB() { return Policy::template GetVectorSizeB(); } static constexpr index_t GetVectorSizeC() { return Policy::template GetVectorSizeC(); } + static constexpr index_t GetSmemPackA() { return Policy::template GetSmemPackA(); } + static constexpr index_t GetSmemPackB() { return Policy::template GetSmemPackB(); } + static constexpr bool kPadM = Problem::kPadM; static constexpr bool kPadN = Problem::kPadN; static constexpr bool kPadK = Problem::kPadK; diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_mem.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_mem.hpp index 8a73b4b5a1..abf5b617ee 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_mem.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_mem.hpp @@ -129,6 +129,9 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem static constexpr index_t GetVectorSizeB() { return Policy::template GetVectorSizeB(); } static constexpr index_t GetVectorSizeC() { return Policy::template GetVectorSizeC(); } + static constexpr index_t GetSmemPackA() { return Policy::template GetSmemPackA(); } + static constexpr index_t GetSmemPackB() { return Policy::template GetSmemPackB(); } + static constexpr bool kPadM = Problem::kPadM; static constexpr bool kPadN = Problem::kPadN; static constexpr bool kPadK = Problem::kPadK; diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp index 76bece9398..41ea89b2bd 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp @@ -36,6 +36,9 @@ struct GemmPipelineAGmemBGmemCRegV1 static constexpr index_t GetVectorSizeB() { return Problem::VectorSizeB; } static constexpr index_t GetVectorSizeC() { return Problem::VectorSizeC; } + static constexpr index_t GetSmemPackA() { return Policy::template GetSmemPackA(); } + static constexpr index_t GetSmemPackB() { return Policy::template GetSmemPackB(); } + static constexpr bool kPadM = Problem::kPadM; static constexpr bool kPadN = Problem::kPadN; static constexpr bool kPadK = Problem::kPadK; diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v2.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v2.hpp index 2f658582c9..95b7618b11 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v2.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v2.hpp @@ -31,6 +31,9 @@ struct GemmPipelineAGmemBGmemCRegV2 static constexpr index_t kNPerBlock = BlockGemmShape::kN; static constexpr index_t kKPerBlock = BlockGemmShape::kK; + static constexpr index_t GetSmemPackA() { return Policy::template GetSmemPackA(); } + static constexpr index_t GetSmemPackB() { return Policy::template GetSmemPackB(); } + [[nodiscard]] CK_TILE_HOST static const std::string GetName() { // clang-format off diff --git a/include/ck_tile/ops/gemm/pipeline/tile_gemm_shape.hpp b/include/ck_tile/ops/gemm/pipeline/tile_gemm_shape.hpp index 24a399f18d..f0aa4472e1 100644 --- a/include/ck_tile/ops/gemm/pipeline/tile_gemm_shape.hpp +++ b/include/ck_tile/ops/gemm/pipeline/tile_gemm_shape.hpp @@ -8,7 +8,11 @@ namespace ck_tile { -template +template struct TileGemmShape { using BlockTile = remove_cvref_t; @@ -21,6 +25,9 @@ struct TileGemmShape static constexpr index_t kN = BlockTile::at(number<1>{}); static constexpr index_t kK = BlockTile::at(number<2>{}); + static constexpr bool PermuteA = PermuteA_; + static constexpr bool PermuteB = PermuteB_; + CK_TILE_HOST static std::string GetName() { // clang-format off