From af00674037c5a3ac920ec2d312a37bb47c38c36d Mon Sep 17 00:00:00 2001 From: linqunAMD Date: Wed, 18 Jun 2025 08:27:46 +0800 Subject: [PATCH] [CK_TILE] Support multi-config in tile_example_gemm_universal (#2240) * [CK_TILE] Support multi-config in tile_example_gemm_universal Add GemmConfig in run_gemm_example to support multiple tile configs. - It is useful when you need to compare gemm perf with different tile/pipeline configs - we can also use it to simplify the code for wmma support in the future. * [CK_TILE] Support multi-config in tile_example_gemm_universal Address review comments * rebase code and fix clang format. * fix clang format * support pipeline v5. * fix merge conflict * address review comment * add missing file * address review comment v2 * fix build error [ROCm/composable_kernel commit: 0eb8974502df073be0e131f25435a30ecbf9a656] --- example/ck_tile/03_gemm/gemm_basic.cpp | 41 +-- example/ck_tile/03_gemm/gemm_utils.hpp | 301 ++++++++++++------ example/ck_tile/03_gemm/run_gemm_example.inc | 40 ++- example/ck_tile/03_gemm/universal_gemm.cpp | 71 +++-- .../gemm/pipeline/gemm_pipeline_problem.hpp | 3 +- .../ops/gemm/pipeline/tile_gemm_traits.hpp | 5 +- 6 files changed, 306 insertions(+), 155 deletions(-) diff --git a/example/ck_tile/03_gemm/gemm_basic.cpp b/example/ck_tile/03_gemm/gemm_basic.cpp index 1906b0bda7..090a98486e 100644 --- a/example/ck_tile/03_gemm/gemm_basic.cpp +++ b/example/ck_tile/03_gemm/gemm_basic.cpp @@ -12,7 +12,8 @@ #include "ck_tile/host.hpp" #include "gemm_utils.hpp" -template + typename CDEElementWise> float gemm(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s) { @@ -140,12 +141,12 @@ int run_gemm_example_prec_type(std::string a_layout, std::string b_layout, int a { if(a_layout == "R" && b_layout == "C") { - return run_gemm_example_with_layouts( + return run_gemm_example_with_layouts( argc, argv, Row{}, Col{}, Row{}); } else if(a_layout == "C" && b_layout == "C") { - return run_gemm_example_with_layouts( + return 
run_gemm_example_with_layouts( argc, argv, Col{}, Col{}, Row{}); } else @@ -156,24 +157,24 @@ int run_gemm_example_prec_type(std::string a_layout, std::string b_layout, int a } else { - if(a_layout == "R" && b_layout == "R") + if(a_layout == "R" && b_layout == "C") { - return run_gemm_example_with_layouts( - argc, argv, Row{}, Row{}, Row{}); - } - else if(a_layout == "R" && b_layout == "C") - { - return run_gemm_example_with_layouts( + return run_gemm_example_with_layouts( argc, argv, Row{}, Col{}, Row{}); } + else if(a_layout == "R" && b_layout == "R") + { + return run_gemm_example_with_layouts( + argc, argv, Row{}, Row{}, Row{}); + } else if(a_layout == "C" && b_layout == "R") { - return run_gemm_example_with_layouts( + return run_gemm_example_with_layouts( argc, argv, Col{}, Row{}, Row{}); } else if(a_layout == "C" && b_layout == "C") { - return run_gemm_example_with_layouts( + return run_gemm_example_with_layouts( argc, argv, Col{}, Col{}, Row{}); } else @@ -211,15 +212,19 @@ int run_gemm_example(int argc, char* argv[]) return run_gemm_example_prec_type( a_layout, b_layout, argc, argv); } - -#if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE_V3) else if(data_type == "pk_int4_t") { // TODO: Add support for bhalf_t ADataType - return run_gemm_example_prec_type( - a_layout, b_layout, argc, argv); + if constexpr(GemmConfigBase::Pipeline == CK_TILE_PIPELINE_COMPUTE_V3) + { + return run_gemm_example_prec_type( + a_layout, b_layout, argc, argv); + } + else + { + throw std::runtime_error("Unsupported data type for this operation !!!"); + } } -#endif else { throw std::runtime_error("Unsupported data type for this operation !!!"); diff --git a/example/ck_tile/03_gemm/gemm_utils.hpp b/example/ck_tile/03_gemm/gemm_utils.hpp index 6987a2492e..101e195903 100644 --- a/example/ck_tile/03_gemm/gemm_utils.hpp +++ b/example/ck_tile/03_gemm/gemm_utils.hpp @@ -16,105 +16,8 @@ #define CK_TILE_PIPELINE_COMPUTE_V4 3 #define CK_TILE_PIPELINE_COMPUTE_V5 4 -#ifndef 
CK_TILE_PIPELINE_DEFAULT -#define CK_TILE_PIPELINE_DEFAULT CK_TILE_PIPELINE_COMPUTE_V3 -#endif - -#if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_MEMORY) -#define GEMM_PIPELINE ck_tile::GemmPipelineAgBgCrMem -#define UNIVERSAL_GEMM_PIPELINE ck_tile::BaseGemmPipelineAgBgCrMem -#define GEMM_PIPELINE_SCHEDULER ck_tile::GemmPipelineScheduler::Interwave -#elif(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE_V3) -#define GEMM_PIPELINE ck_tile::GemmPipelineAgBgCrCompV3 -#define UNIVERSAL_GEMM_PIPELINE ck_tile::BaseGemmPipelineAgBgCrCompV3 -#define GEMM_PIPELINE_SCHEDULER ck_tile::GemmPipelineScheduler::Intrawave -#elif(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE_V4) -#define GEMM_PIPELINE ck_tile::GemmPipelineAgBgCrCompV4 -#define UNIVERSAL_GEMM_PIPELINE ck_tile::BaseGemmPipelineAgBgCrCompV4 -#define GEMM_PIPELINE_SCHEDULER ck_tile::GemmPipelineScheduler::Intrawave -#elif(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE_V5) -#define GEMM_PIPELINE ck_tile::GemmPipelineAgBgCrCompV5 -#define UNIVERSAL_GEMM_PIPELINE ck_tile::BaseGemmPipelineAgBgCrCompV5 -#define GEMM_PIPELINE_SCHEDULER ck_tile::GemmPipelineScheduler::Intrawave -#else -#error "unsupported CK_TILE_PIPELINE_DEFAULT value" -#endif - -struct GemmConfig +struct GemmConfigBase { -#if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_MEMORY) - // Memory friendly for Interwave scheduler - static constexpr ck_tile::index_t M_Tile = 128; - static constexpr ck_tile::index_t N_Tile = 32; - static constexpr ck_tile::index_t K_Tile = 64; - - static constexpr ck_tile::index_t M_Warp = 4; - static constexpr ck_tile::index_t N_Warp = 1; - static constexpr ck_tile::index_t K_Warp = 1; - - static constexpr ck_tile::index_t M_Warp_Tile = 32; - static constexpr ck_tile::index_t N_Warp_Tile = 32; - static constexpr ck_tile::index_t K_Warp_Tile = 16; - - static constexpr bool DoubleSmemBuffer = false; - static constexpr ck_tile::index_t NumWaveGroups = 1; -#endif -#if(CK_TILE_PIPELINE_DEFAULT == 
CK_TILE_PIPELINE_COMPUTE_V3) - // Compute friendly for Intrawave scheduler - static constexpr ck_tile::index_t M_Tile = 128; - static constexpr ck_tile::index_t N_Tile = 128; - static constexpr ck_tile::index_t K_Tile = 128; - - static constexpr ck_tile::index_t M_Warp = 2; - static constexpr ck_tile::index_t N_Warp = 2; - static constexpr ck_tile::index_t K_Warp = 1; - - static constexpr ck_tile::index_t M_Warp_Tile = 16; - static constexpr ck_tile::index_t N_Warp_Tile = 16; - static constexpr ck_tile::index_t K_Warp_Tile = 32; - - static constexpr bool DoubleSmemBuffer = false; - static constexpr ck_tile::index_t NumWaveGroups = 1; -#elif(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE_V4) - // Compute friendly for Intrawave scheduler - // Using the ping pong reader in the lds level - static constexpr ck_tile::index_t M_Tile = 256; - static constexpr ck_tile::index_t N_Tile = 256; - static constexpr ck_tile::index_t K_Tile = 32; - - static constexpr ck_tile::index_t M_Warp = 2; - static constexpr ck_tile::index_t N_Warp = 2; - static constexpr ck_tile::index_t K_Warp = 1; - - static constexpr ck_tile::index_t M_Warp_Tile = 32; - static constexpr ck_tile::index_t N_Warp_Tile = 32; - static constexpr ck_tile::index_t K_Warp_Tile = 16; - - static constexpr bool DoubleSmemBuffer = true; - static constexpr ck_tile::index_t NumWaveGroups = 1; -#elif(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE_V5) - // Compute friendly for Intrawave scheduler - // Using the ping pong reader in the lds level - static constexpr ck_tile::index_t M_Tile = 128; - static constexpr ck_tile::index_t N_Tile = 128; - static constexpr ck_tile::index_t K_Tile = 32; - - static constexpr ck_tile::index_t M_Warp = 1; - static constexpr ck_tile::index_t N_Warp = 1; - static constexpr ck_tile::index_t K_Warp = 2; - - static constexpr ck_tile::index_t M_Warp_Tile = 32; - static constexpr ck_tile::index_t N_Warp_Tile = 32; - static constexpr ck_tile::index_t K_Warp_Tile = 16; - - static 
constexpr bool DoubleSmemBuffer = false; - - // Available wavegroups will be split into `NumWaveGroups` and each of these wavegroups - // will be responsible for specific jobs. For instance, perform Global Memory read operations, - // perform block-gemm operation etc... - static constexpr ck_tile::index_t NumWaveGroups = 2; -#endif - static constexpr bool kPadM = false; static constexpr bool kPadN = false; static constexpr bool kPadK = false; @@ -128,6 +31,169 @@ struct GemmConfig static constexpr int kBlockPerCu = 1; static constexpr ck_tile::index_t TileParitionerGroupNum = 8; static constexpr ck_tile::index_t TileParitionerM01 = 4; + static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave; + static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3; + static constexpr ck_tile::index_t NumWaveGroups = 1; +}; + +template +struct GemmConfigMemoryInterwave : public GemmConfigBase +{ + // Memory friendly for Interwave scheduler + static constexpr ck_tile::index_t M_Tile = 128; + static constexpr ck_tile::index_t N_Tile = 32; + static constexpr ck_tile::index_t K_Tile = 128 / sizeof(PrecType); + + static constexpr ck_tile::index_t M_Warp = 4; + static constexpr ck_tile::index_t N_Warp = 1; + static constexpr ck_tile::index_t K_Warp = 1; + + static constexpr ck_tile::index_t M_Warp_Tile = 32; + static constexpr ck_tile::index_t N_Warp_Tile = 32; + static constexpr ck_tile::index_t K_Warp_Tile = sizeof(PrecType) == 2 ? 
8 : 16; + + static constexpr bool DoubleSmemBuffer = false; + static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_MEMORY; + static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Interwave; +}; + +template +struct GemmConfigMemoryIntrawave : public GemmConfigBase +{ + static constexpr ck_tile::index_t M_Tile = 128; + static constexpr ck_tile::index_t N_Tile = 32; + static constexpr ck_tile::index_t K_Tile = 128 / sizeof(PrecType); + + static constexpr ck_tile::index_t M_Warp = 4; + static constexpr ck_tile::index_t N_Warp = 1; + static constexpr ck_tile::index_t K_Warp = 1; + + static constexpr ck_tile::index_t M_Warp_Tile = 32; + static constexpr ck_tile::index_t N_Warp_Tile = 32; + static constexpr ck_tile::index_t K_Warp_Tile = sizeof(PrecType) == 2 ? 8 : 16; + + static constexpr bool DoubleSmemBuffer = false; + static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_MEMORY; +}; + +template +struct GemmConfigComputeV3 : public GemmConfigBase +{ + // Compute V3 only support Intrawave scheduler + static constexpr ck_tile::index_t M_Tile = 256; + static constexpr ck_tile::index_t N_Tile = 256; + static constexpr ck_tile::index_t K_Tile = 64 / sizeof(PrecType); + + static constexpr ck_tile::index_t M_Warp = 2; + static constexpr ck_tile::index_t N_Warp = 2; + static constexpr ck_tile::index_t K_Warp = 1; + + static constexpr ck_tile::index_t M_Warp_Tile = 32; + static constexpr ck_tile::index_t N_Warp_Tile = 32; + static constexpr ck_tile::index_t K_Warp_Tile = sizeof(PrecType) == 2 ? 
16 : 64; + + static constexpr bool DoubleSmemBuffer = false; + static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3; +}; + +template +struct GemmConfigComputeV3_1 : public GemmConfigBase +{ + static constexpr ck_tile::index_t M_Tile = 256; + static constexpr ck_tile::index_t N_Tile = 256; + static constexpr ck_tile::index_t K_Tile = 128 / sizeof(PrecType); + + static constexpr ck_tile::index_t M_Warp = 2; + static constexpr ck_tile::index_t N_Warp = 2; + static constexpr ck_tile::index_t K_Warp = 1; + + static constexpr ck_tile::index_t M_Warp_Tile = 32; + static constexpr ck_tile::index_t N_Warp_Tile = 32; + static constexpr ck_tile::index_t K_Warp_Tile = sizeof(PrecType) == 2 ? 16 : 64; + + static constexpr bool DoubleSmemBuffer = false; + static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3; +}; + +template +struct GemmConfigComputeV3_2 : public GemmConfigBase +{ + static constexpr ck_tile::index_t M_Tile = 128; + static constexpr ck_tile::index_t N_Tile = 128; + static constexpr ck_tile::index_t K_Tile = 128 / sizeof(PrecType); + + static constexpr ck_tile::index_t M_Warp = 2; + static constexpr ck_tile::index_t N_Warp = 2; + static constexpr ck_tile::index_t K_Warp = 1; + + static constexpr ck_tile::index_t M_Warp_Tile = 16; + static constexpr ck_tile::index_t N_Warp_Tile = 16; + static constexpr ck_tile::index_t K_Warp_Tile = sizeof(PrecType) == 2 ? 
32 : 128; + + static constexpr bool DoubleSmemBuffer = false; + static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3; + + static constexpr int kBlockPerCu = 2; +}; + +template +struct GemmConfigComputeV4 : public GemmConfigBase +{ + // Compute V4 only support Intrawave scheduler + // Using the ping pong reader in the lds level + static constexpr ck_tile::index_t M_Tile = 256; + static constexpr ck_tile::index_t N_Tile = 256; + static constexpr ck_tile::index_t K_Tile = 64 / sizeof(PrecType); + + static constexpr ck_tile::index_t M_Warp = 2; + static constexpr ck_tile::index_t N_Warp = 2; + static constexpr ck_tile::index_t K_Warp = 1; + + static constexpr ck_tile::index_t M_Warp_Tile = 32; + static constexpr ck_tile::index_t N_Warp_Tile = 32; + static constexpr ck_tile::index_t K_Warp_Tile = sizeof(PrecType) == 2 ? 16 : 64; + + static constexpr bool DoubleSmemBuffer = true; + static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V4; +}; + +template +struct GemmConfigComputeV4_1 : public GemmConfigBase +{ + static constexpr ck_tile::index_t M_Tile = 256; + static constexpr ck_tile::index_t N_Tile = 256; + static constexpr ck_tile::index_t K_Tile = 128 / sizeof(PrecType); + + static constexpr ck_tile::index_t M_Warp = 2; + static constexpr ck_tile::index_t N_Warp = 2; + static constexpr ck_tile::index_t K_Warp = 1; + + static constexpr ck_tile::index_t M_Warp_Tile = 32; + static constexpr ck_tile::index_t N_Warp_Tile = 32; + static constexpr ck_tile::index_t K_Warp_Tile = sizeof(PrecType) == 2 ? 
16 : 64; + + static constexpr bool DoubleSmemBuffer = true; + static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V4; +}; + +template +struct GemmConfigComputeV5 : public GemmConfigBase +{ + static constexpr ck_tile::index_t M_Tile = 128; + static constexpr ck_tile::index_t N_Tile = 128; + static constexpr ck_tile::index_t K_Tile = 64 / sizeof(PrecType); + + static constexpr ck_tile::index_t M_Warp = 1; + static constexpr ck_tile::index_t N_Warp = 1; + static constexpr ck_tile::index_t K_Warp = 2; + + static constexpr ck_tile::index_t M_Warp_Tile = 32; + static constexpr ck_tile::index_t N_Warp_Tile = 32; + static constexpr ck_tile::index_t K_Warp_Tile = sizeof(PrecType) == 2 ? 16 : 64; + + static constexpr bool DoubleSmemBuffer = false; + static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V5; + static constexpr ck_tile::index_t NumWaveGroups = 2; }; template @@ -224,6 +290,45 @@ struct DataTypeTraits static constexpr const char* name = "pk_int4_t"; }; +template +struct PipelineTypeTraits; + +template <> +struct PipelineTypeTraits +{ + template + using GemmPipeline = ck_tile::GemmPipelineAgBgCrMem; + template + using UniversalGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrMem; +}; + +template <> +struct PipelineTypeTraits +{ + template + using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV3; + template + using UniversalGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrCompV3; +}; + +template <> +struct PipelineTypeTraits +{ + template + using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV4; + template + using UniversalGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrCompV4; +}; + +template <> +struct PipelineTypeTraits +{ + template + using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV5; + template + using UniversalGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrCompV5; +}; + auto create_args(int argc, char* argv[]) { ck_tile::ArgParser arg_parser; diff --git a/example/ck_tile/03_gemm/run_gemm_example.inc 
b/example/ck_tile/03_gemm/run_gemm_example.inc index cc9a825c73..140107bfb4 100644 --- a/example/ck_tile/03_gemm/run_gemm_example.inc +++ b/example/ck_tile/03_gemm/run_gemm_example.inc @@ -30,7 +30,8 @@ auto calculate_rtol_atol(const ck_tile::index_t K, return ck_tile::make_tuple(std::max(rtol, rtol_split_k), std::max(atol, atol_split_k)); } -template ; - using GemmPipeline = GEMM_PIPELINE; + using GemmPipeline = typename PipelineTypeTraits::template GemmPipeline< + UniversalGemmProblem>; const ck_tile::index_t K = tensor.get_length(0); const ck_tile::index_t N = tensor.get_length(1); @@ -144,7 +146,22 @@ void permute_vectors_i4x4_b(Tensor& tensor) } } -template +float gemm(const ck_tile::GemmHostArgs<>& args, const ck_tile::stream_config& s); + +template b_k_n_dev = b_k_n; if constexpr(GemmConfig::PermuteB) { - permute_tensor_b, AccDataType, diff --git a/example/ck_tile/03_gemm/universal_gemm.cpp b/example/ck_tile/03_gemm/universal_gemm.cpp index 3ec90e7f00..ecfaa92b9a 100644 --- a/example/ck_tile/03_gemm/universal_gemm.cpp +++ b/example/ck_tile/03_gemm/universal_gemm.cpp @@ -13,7 +13,8 @@ #include "gemm_utils.hpp" #include "run_gemm_example.inc" -template + typename CDEElementWise> float gemm(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s) { @@ -45,7 +46,8 @@ float gemm(const ck_tile::GemmHostArgs& args, const ck_tile: GemmConfig::kPadK, ALayout, BLayout, - ELayout>; + ELayout, + GemmConfig::NumWaveGroups>; using GemmUniversalTraits = ck_tile::TileGemmUniversalTraits& args, const ck_tile: using GemmPipelineProblem = ck_tile::GemmPipelineProblem; - using BaseGemmPipeline = UNIVERSAL_GEMM_PIPELINE; + using BaseGemmPipeline = typename PipelineTypeTraits< + GemmConfig::Pipeline>::template UniversalGemmPipeline; const ck_tile::index_t k_grain = args.k_batch * GemmConfig::K_Tile; const ck_tile::index_t K_split = (args.K + k_grain - 1) / k_grain * GemmConfig::K_Tile; @@ -75,7 +78,7 @@ float gemm(const ck_tile::GemmHostArgs& args, const ck_tile: 
[&](const auto has_hot_loop_, const auto tail_number_, const auto memory_operation_) { constexpr bool has_hot_loop_v = has_hot_loop_.value; constexpr auto tail_number_v = tail_number_.value; - constexpr auto scheduler = GEMM_PIPELINE_SCHEDULER; + constexpr auto scheduler = GemmConfig::Scheduler; constexpr auto memory_operation = memory_operation_.value; using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem& args, const ck_tile: has_hot_loop_v, tail_number_v>; - using GemmPipeline = GEMM_PIPELINE; + using GemmPipeline = typename PipelineTypeTraits< + GemmConfig::Pipeline>::template GemmPipeline; using GemmEpilogue = ck_tile::CShuffleEpilogue< ck_tile::CShuffleEpilogueProblem& args, const ck_tile: UniversalGemmProblem::TransposeC, memory_operation, GemmConfig::NumWaveGroups>>; - using Kernel = ck_tile::GemmKernel; auto kargs = Kernel::MakeKernelArgs(args); @@ -205,7 +208,10 @@ float gemm(const ck_tile::GemmHostArgs& args, const ck_tile: return ave_time; } -template +template int run_gemm_example_prec_type(std::string a_layout, std::string b_layout, int argc, char* argv[]) { using Row = ck_tile::tensor_layout::gemm::RowMajor; @@ -215,12 +221,12 @@ int run_gemm_example_prec_type(std::string a_layout, std::string b_layout, int a { if(a_layout == "R" && b_layout == "C") { - return run_gemm_example_with_layouts( + return run_gemm_example_with_layouts( argc, argv, Row{}, Col{}, Row{}); } else if(a_layout == "C" && b_layout == "C") { - return run_gemm_example_with_layouts( + return run_gemm_example_with_layouts( argc, argv, Col{}, Col{}, Row{}); } else @@ -233,22 +239,22 @@ int run_gemm_example_prec_type(std::string a_layout, std::string b_layout, int a { if(a_layout == "R" && b_layout == "R") { - return run_gemm_example_with_layouts( + return run_gemm_example_with_layouts( argc, argv, Row{}, Row{}, Row{}); } else if(a_layout == "R" && b_layout == "C") { - return run_gemm_example_with_layouts( + return run_gemm_example_with_layouts( argc, argv, Row{}, Col{}, 
Row{}); } else if(a_layout == "C" && b_layout == "R") { - return run_gemm_example_with_layouts( + return run_gemm_example_with_layouts( argc, argv, Col{}, Row{}, Row{}); } else if(a_layout == "C" && b_layout == "C") { - return run_gemm_example_with_layouts( + return run_gemm_example_with_layouts( argc, argv, Col{}, Col{}, Row{}); } else @@ -258,6 +264,7 @@ int run_gemm_example_prec_type(std::string a_layout, std::string b_layout, int a } } +template