From 896f8b4ccf01ccaabbfcc89d4f4769d9441a2836 Mon Sep 17 00:00:00 2001 From: Jakub Piasecki Date: Fri, 10 Jan 2025 11:57:03 +0000 Subject: [PATCH] add gemm_api and instances --- example/ck_tile/03_gemm/CMakeLists.txt | 27 +- example/ck_tile/03_gemm/gemm_basic.cpp | 31 +- example/ck_tile/03_gemm/gemm_basic.hpp | 67 ++- .../ck_tile/03_gemm/instances/gemm_api.cpp | 482 ++++++++++++++++++ ...universal_comp_bf16_bf16_bf16_km_kn_mn.cpp | 27 + ...universal_comp_bf16_bf16_bf16_km_nk_mn.cpp | 27 + ...universal_comp_bf16_bf16_bf16_mk_kn_mn.cpp | 26 + ...universal_comp_bf16_bf16_bf16_mk_nk_mn.cpp | 27 + ...mm_universal_comp_f16_f16_f16_km_kn_mn.cpp | 27 + ...mm_universal_comp_f16_f16_f16_km_nk_mn.cpp | 27 + ...mm_universal_comp_f16_f16_f16_mk_kn_mn.cpp | 26 + ...mm_universal_comp_f16_f16_f16_mk_nk_mn.cpp | 27 + .../gemm_universal_comp_instance_common.hpp | 206 ++++++++ ..._universal_mem_bf16_bf16_bf16_km_kn_mn.cpp | 27 + ..._universal_mem_bf16_bf16_bf16_km_nk_mn.cpp | 27 + ..._universal_mem_bf16_bf16_bf16_mk_kn_mn.cpp | 26 + ..._universal_mem_bf16_bf16_bf16_mk_nk_mn.cpp | 27 + ...emm_universal_mem_f16_f16_f16_km_kn_mn.cpp | 27 + ...emm_universal_mem_f16_f16_f16_km_nk_mn.cpp | 27 + ...emm_universal_mem_f16_f16_f16_mk_kn_mn.cpp | 26 + ...emm_universal_mem_f16_f16_f16_mk_nk_mn.cpp | 27 + .../gemm_universal_mem_instance_common.hpp | 206 ++++++++ example/ck_tile/03_gemm/run_gemm_example.inc | 30 +- example/ck_tile/03_gemm/universal_gemm.cpp | 203 +------- 24 files changed, 1453 insertions(+), 227 deletions(-) create mode 100644 example/ck_tile/03_gemm/instances/gemm_api.cpp create mode 100644 example/ck_tile/03_gemm/instances/gemm_universal_comp_bf16_bf16_bf16_km_kn_mn.cpp create mode 100644 example/ck_tile/03_gemm/instances/gemm_universal_comp_bf16_bf16_bf16_km_nk_mn.cpp create mode 100644 example/ck_tile/03_gemm/instances/gemm_universal_comp_bf16_bf16_bf16_mk_kn_mn.cpp create mode 100644 example/ck_tile/03_gemm/instances/gemm_universal_comp_bf16_bf16_bf16_mk_nk_mn.cpp create mode 100644 example/ck_tile/03_gemm/instances/gemm_universal_comp_f16_f16_f16_km_kn_mn.cpp create mode 100644 example/ck_tile/03_gemm/instances/gemm_universal_comp_f16_f16_f16_km_nk_mn.cpp create mode 100644 example/ck_tile/03_gemm/instances/gemm_universal_comp_f16_f16_f16_mk_kn_mn.cpp create mode 100644 example/ck_tile/03_gemm/instances/gemm_universal_comp_f16_f16_f16_mk_nk_mn.cpp create mode 100644 example/ck_tile/03_gemm/instances/gemm_universal_comp_instance_common.hpp create mode 100644 example/ck_tile/03_gemm/instances/gemm_universal_mem_bf16_bf16_bf16_km_kn_mn.cpp create mode 100644 example/ck_tile/03_gemm/instances/gemm_universal_mem_bf16_bf16_bf16_km_nk_mn.cpp create mode 100644 example/ck_tile/03_gemm/instances/gemm_universal_mem_bf16_bf16_bf16_mk_kn_mn.cpp create mode 100644 example/ck_tile/03_gemm/instances/gemm_universal_mem_bf16_bf16_bf16_mk_nk_mn.cpp create mode 100644 example/ck_tile/03_gemm/instances/gemm_universal_mem_f16_f16_f16_km_kn_mn.cpp create mode 100644 example/ck_tile/03_gemm/instances/gemm_universal_mem_f16_f16_f16_km_nk_mn.cpp create mode 100644 example/ck_tile/03_gemm/instances/gemm_universal_mem_f16_f16_f16_mk_kn_mn.cpp create mode 100644 example/ck_tile/03_gemm/instances/gemm_universal_mem_f16_f16_f16_mk_nk_mn.cpp create mode 100644 example/ck_tile/03_gemm/instances/gemm_universal_mem_instance_common.hpp diff --git a/example/ck_tile/03_gemm/CMakeLists.txt b/example/ck_tile/03_gemm/CMakeLists.txt index bc3799f015..f682ef0ac9 100644 --- a/example/ck_tile/03_gemm/CMakeLists.txt +++ b/example/ck_tile/03_gemm/CMakeLists.txt @@ -1,2 +1,27 @@ +# add_executable(tile_example_gemm_basic EXCLUDE_FROM_ALL gemm_basic.cpp) +# add_executable(tile_example_universal_gemm EXCLUDE_FROM_ALL universal_gemm.cpp) + +function (add_gemm_example TARGET_NAME MAIN_SRC) +message("adding ${TARGET_NAME}") +# not using add_example_executable() to add target, since we don't want this to have +# to be included in "make all/install/check" +add_executable(${TARGET_NAME} EXCLUDE_FROM_ALL ${MAIN_SRC}) +target_include_directories(${TARGET_NAME} PRIVATE ${CMAKE_CURRENT_LIST_DIR}) + +foreach(source IN LISTS ARGN) + list(APPEND INSTANCE_SRCS ${source}) +endforeach() + +target_sources(${TARGET_NAME} PRIVATE ${INSTANCE_SRCS}) + +set(COMPILE_OPTIONS) +# NOTE: we turn off undefined-func-template to let source compile without explicit declare function specializations +list(APPEND COMPILE_OPTIONS -Wno-undefined-func-template -Wno-float-equal) + +target_compile_options(${TARGET_NAME} PRIVATE ${COMPILE_OPTIONS}) +endfunction(add_gemm_example TARGET_NAME MAIN_SRC) + +file(GLOB INSTANCE_SRCS instances/*.cpp) +add_gemm_example(tile_example_gemm_universal universal_gemm.cpp ${INSTANCE_SRCS}) + add_executable(tile_example_gemm_basic EXCLUDE_FROM_ALL gemm_basic.cpp) -add_executable(tile_example_gemm_universal EXCLUDE_FROM_ALL universal_gemm.cpp) diff --git a/example/ck_tile/03_gemm/gemm_basic.cpp b/example/ck_tile/03_gemm/gemm_basic.cpp index 4c630375f4..71c508bd45 100644 --- a/example/ck_tile/03_gemm/gemm_basic.cpp +++ b/example/ck_tile/03_gemm/gemm_basic.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. #include @@ -9,13 +9,10 @@ #include #include -#include "ck_tile/ops/epilogue.hpp" -#include "ck_tile/ops/gemm.hpp" -#include "ck_tile/host.hpp" #include "gemm_basic.hpp" template -float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s) +float gemm_(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s) { // The kPadM, kPadN, kPadK & kBlockPerCu should also come from the Codegen part. constexpr bool kPadM = false; @@ -103,6 +100,30 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& return ave_time; } +float gemm(const gemm_traits& t, const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s) +{ + if(t.is_a_rowmajor && t.is_b_rowmajor && t.is_c_rowmajor) + { + return gemm_(args, s); + } + else if(t.is_a_rowmajor && !t.is_b_rowmajor && t.is_c_rowmajor) + { + return gemm_(args, s); + } + else if(!t.is_a_rowmajor && t.is_b_rowmajor && t.is_c_rowmajor) + { + return gemm_(args, s); + } + else if(!t.is_a_rowmajor && !t.is_b_rowmajor && t.is_c_rowmajor) + { + return gemm_(args, s); + } + else + { + throw std::runtime_error("Wrong! Layouts not supported!\n"); + } +} + #include "run_gemm_example.inc" int main(int argc, char* argv[]) { return !run_gemm_example(argc, argv); } diff --git a/example/ck_tile/03_gemm/gemm_basic.hpp b/example/ck_tile/03_gemm/gemm_basic.hpp index 38c0a279db..e659890f4b 100644 --- a/example/ck_tile/03_gemm/gemm_basic.hpp +++ b/example/ck_tile/03_gemm/gemm_basic.hpp @@ -1,13 +1,14 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once #include - -#include "ck_tile/core.hpp" +#include "ck_tile/host.hpp" #include "ck_tile/host/kernel_launch.hpp" +#include "ck_tile/ops/gemm.hpp" +#include "ck_tile/ops/epilogue.hpp" template struct GemmBasicTypeConfig; @@ -51,6 +52,59 @@ using BDataType = Types::BDataType; using AccDataType = Types::AccDataType; using CDataType = Types::CDataType; +using Row = ck_tile::tensor_layout::gemm::RowMajor; +using Col = ck_tile::tensor_layout::gemm::ColumnMajor; + +struct gemm_traits +{ + std::string data_type; + bool is_a_rowmajor; + bool is_b_rowmajor; + bool is_c_rowmajor; +}; + +template +struct gemm_traits_ +{ + using ADataType = ck_tile::remove_cvref_t; + using BDataType = ck_tile::remove_cvref_t; + using AccDataType = ck_tile::remove_cvref_t; + using CDataType = ck_tile::remove_cvref_t; + using ALayout = ck_tile::remove_cvref_t; + using BLayout = ck_tile::remove_cvref_t; + using CLayout = ck_tile::remove_cvref_t; + static constexpr ck_tile::index_t M_Tile = M_Tile_; + static constexpr ck_tile::index_t N_Tile = N_Tile_; + static constexpr ck_tile::index_t K_Tile = K_Tile_; + static constexpr ck_tile::index_t M_Warp = M_Warp_; + static constexpr ck_tile::index_t N_Warp = N_Warp_; + static constexpr ck_tile::index_t K_Warp = K_Warp_; + static constexpr ck_tile::index_t M_Warp_Tile = M_Warp_Tile_; + static constexpr ck_tile::index_t N_Warp_Tile = N_Warp_Tile_; + static constexpr ck_tile::index_t K_Warp_Tile = K_Warp_Tile_; + static constexpr bool kPadM = kPadM_; + static constexpr bool kPadN = kPadN_; + static constexpr bool kPadK = kPadK_; +}; + auto create_args(int argc, char* argv[]) { ck_tile::ArgParser arg_parser; @@ -75,4 +129,9 @@ auto create_args(int argc, char* argv[]) } // host API -float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s); +template +float gemm_(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s); + +float gemm(const gemm_traits& traits, + const ck_tile::GemmHostArgs& args, + const ck_tile::stream_config& s); diff --git a/example/ck_tile/03_gemm/instances/gemm_api.cpp b/example/ck_tile/03_gemm/instances/gemm_api.cpp new file mode 100644 index 0000000000..05fba01bd2 --- /dev/null +++ b/example/ck_tile/03_gemm/instances/gemm_api.cpp @@ -0,0 +1,482 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "gemm_basic.hpp" + +using Row = ck_tile::tensor_layout::gemm::RowMajor; +using Col = ck_tile::tensor_layout::gemm::ColumnMajor; + +using FP32 = float; +using FP16 = ck_tile::half_t; +using BF16 = ck_tile::bf16_t; + +template +using trait_ = gemm_traits_; + +float gemm(const gemm_traits& t, const ck_tile::GemmHostArgs& a, const ck_tile::stream_config& s) +{ + if(t.data_type.compare("fp16") == 0) + { + if(t.is_a_rowmajor && t.is_b_rowmajor && t.is_c_rowmajor) + { + if(a.M > 512) + { + // universal gemm compute bound RR + std::cout << "fp16 comp\n"; + return gemm_>(a, s); + } + else + { + // universal gemm memory bound RR + std::cout << "fp16 mem\n"; + return gemm_>(a, s); + } + } + else if(t.is_a_rowmajor && !t.is_b_rowmajor && t.is_c_rowmajor) + { + if(a.M > 512) + { + // universal gemm compute bound RC + std::cout << "fp16 comp RC\n"; + return gemm_>(a, s); + } + else + { + // universal gemm memory bound RC + std::cout << "fp16 mem RC\n"; + return gemm_>(a, s); + } + } + else if(!t.is_a_rowmajor && t.is_b_rowmajor && t.is_c_rowmajor) + { + if(a.M > 512) + { + // universal gemm compute bound CR + std::cout << "fp16 comp CR\n"; + return gemm_>(a, s); + } + else + { + // universal gemm memory bound CR + std::cout << "fp16 mem CR\n"; + return gemm_>(a, s); + } + } + else if(!t.is_a_rowmajor && !t.is_b_rowmajor && t.is_c_rowmajor) + { + if(a.M > 512) + { + // universal gemm compute bound CC + std::cout << "fp16 comp CC\n"; + return gemm_>(a, s); + } + else + { + // universal gemm memory bound CC + std::cout << "fp16 mem CC\n"; + return gemm_>(a, s); + } + } + else + { + throw std::runtime_error("Wrong! ColumnMajor layout not supported for C Matrix!\n"); + } + } + else if(t.data_type.compare("bf16") == 0) + { + if(t.is_a_rowmajor && t.is_b_rowmajor && t.is_c_rowmajor) + { + if(a.M > 512) + { + // universal gemm compute bound RR + std::cout << "bf16 comp\n"; + return gemm_>(a, s); + } + else + { + // universal gemm memory bound RR + std::cout << "bf16 mem\n"; + return gemm_>(a, s); + } + } + else if(t.is_a_rowmajor && !t.is_b_rowmajor && t.is_c_rowmajor) + { + if(a.M > 512) + { + // universal gemm compute bound RC + std::cout << "bf16 comp RC\n"; + return gemm_>(a, s); + } + else + { + // universal gemm memory bound RC + std::cout << "bf16 mem RC\n"; + return gemm_>(a, s); + } + } + else if(!t.is_a_rowmajor && t.is_b_rowmajor && t.is_c_rowmajor) + { + if(a.M > 512) + { + // universal gemm compute bound CR + std::cout << "bf16 comp CR\n"; + return gemm_>(a, s); + } + else + { + // universal gemm memory bound CR + std::cout << "bf16 mem CR\n"; + return gemm_>(a, s); + } + } + else if(!t.is_a_rowmajor && !t.is_b_rowmajor && t.is_c_rowmajor) + { + if(a.M > 512) + { + // universal gemm compute bound CC + std::cout << "bf16 comp CC\n"; + return gemm_>(a, s); + } + else + { + // universal gemm memory bound CC + std::cout << "bf16 mem CC\n"; + return gemm_>(a, s); + } + } + else + { + throw std::runtime_error("Wrong! ColumnMajor layout not supported for C Matrix!\n"); + } + } + else + { + throw std::runtime_error("Wrong! DataTypes not supported!\n"); + } + + return 1.0f; +} diff --git a/example/ck_tile/03_gemm/instances/gemm_universal_comp_bf16_bf16_bf16_km_kn_mn.cpp b/example/ck_tile/03_gemm/instances/gemm_universal_comp_bf16_bf16_bf16_km_kn_mn.cpp new file mode 100644 index 0000000000..121b676b1c --- /dev/null +++ b/example/ck_tile/03_gemm/instances/gemm_universal_comp_bf16_bf16_bf16_km_kn_mn.cpp @@ -0,0 +1,27 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "gemm_universal_comp_instance_common.hpp" + +using Row = ck_tile::tensor_layout::gemm::RowMajor; +using Col = ck_tile::tensor_layout::gemm::ColumnMajor; + +template float gemm_>(const A&, const S&); diff --git a/example/ck_tile/03_gemm/instances/gemm_universal_comp_bf16_bf16_bf16_km_nk_mn.cpp b/example/ck_tile/03_gemm/instances/gemm_universal_comp_bf16_bf16_bf16_km_nk_mn.cpp new file mode 100644 index 0000000000..29d856c001 --- /dev/null +++ b/example/ck_tile/03_gemm/instances/gemm_universal_comp_bf16_bf16_bf16_km_nk_mn.cpp @@ -0,0 +1,27 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "gemm_universal_comp_instance_common.hpp" + +using Row = ck_tile::tensor_layout::gemm::RowMajor; +using Col = ck_tile::tensor_layout::gemm::ColumnMajor; + +template float gemm_>(const A&, const S&); diff --git a/example/ck_tile/03_gemm/instances/gemm_universal_comp_bf16_bf16_bf16_mk_kn_mn.cpp b/example/ck_tile/03_gemm/instances/gemm_universal_comp_bf16_bf16_bf16_mk_kn_mn.cpp new file mode 100644 index 0000000000..76138c42d6 --- /dev/null +++ b/example/ck_tile/03_gemm/instances/gemm_universal_comp_bf16_bf16_bf16_mk_kn_mn.cpp @@ -0,0 +1,26 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "gemm_universal_comp_instance_common.hpp" + +using Row = ck_tile::tensor_layout::gemm::RowMajor; + +template float gemm_>(const A&, const S&); diff --git a/example/ck_tile/03_gemm/instances/gemm_universal_comp_bf16_bf16_bf16_mk_nk_mn.cpp b/example/ck_tile/03_gemm/instances/gemm_universal_comp_bf16_bf16_bf16_mk_nk_mn.cpp new file mode 100644 index 0000000000..130b3e2691 --- /dev/null +++ b/example/ck_tile/03_gemm/instances/gemm_universal_comp_bf16_bf16_bf16_mk_nk_mn.cpp @@ -0,0 +1,27 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "gemm_universal_comp_instance_common.hpp" + +using Row = ck_tile::tensor_layout::gemm::RowMajor; +using Col = ck_tile::tensor_layout::gemm::ColumnMajor; + +template float gemm_>(const A&, const S&); diff --git a/example/ck_tile/03_gemm/instances/gemm_universal_comp_f16_f16_f16_km_kn_mn.cpp b/example/ck_tile/03_gemm/instances/gemm_universal_comp_f16_f16_f16_km_kn_mn.cpp new file mode 100644 index 0000000000..43971b017f --- /dev/null +++ b/example/ck_tile/03_gemm/instances/gemm_universal_comp_f16_f16_f16_km_kn_mn.cpp @@ -0,0 +1,27 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "gemm_universal_comp_instance_common.hpp" + +using Row = ck_tile::tensor_layout::gemm::RowMajor; +using Col = ck_tile::tensor_layout::gemm::ColumnMajor; + +template float gemm_>(const A&, const S&); diff --git a/example/ck_tile/03_gemm/instances/gemm_universal_comp_f16_f16_f16_km_nk_mn.cpp b/example/ck_tile/03_gemm/instances/gemm_universal_comp_f16_f16_f16_km_nk_mn.cpp new file mode 100644 index 0000000000..0ad95f8831 --- /dev/null +++ b/example/ck_tile/03_gemm/instances/gemm_universal_comp_f16_f16_f16_km_nk_mn.cpp @@ -0,0 +1,27 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "gemm_universal_comp_instance_common.hpp" + +using Row = ck_tile::tensor_layout::gemm::RowMajor; +using Col = ck_tile::tensor_layout::gemm::ColumnMajor; + +template float gemm_>(const A&, const S&); diff --git a/example/ck_tile/03_gemm/instances/gemm_universal_comp_f16_f16_f16_mk_kn_mn.cpp b/example/ck_tile/03_gemm/instances/gemm_universal_comp_f16_f16_f16_mk_kn_mn.cpp new file mode 100644 index 0000000000..6e2ec55c7c --- /dev/null +++ b/example/ck_tile/03_gemm/instances/gemm_universal_comp_f16_f16_f16_mk_kn_mn.cpp @@ -0,0 +1,26 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "gemm_universal_comp_instance_common.hpp" + +using Row = ck_tile::tensor_layout::gemm::RowMajor; + +template float gemm_>(const A&, const S&); diff --git a/example/ck_tile/03_gemm/instances/gemm_universal_comp_f16_f16_f16_mk_nk_mn.cpp b/example/ck_tile/03_gemm/instances/gemm_universal_comp_f16_f16_f16_mk_nk_mn.cpp new file mode 100644 index 0000000000..4ad3ed8a98 --- /dev/null +++ b/example/ck_tile/03_gemm/instances/gemm_universal_comp_f16_f16_f16_mk_nk_mn.cpp @@ -0,0 +1,27 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "gemm_universal_comp_instance_common.hpp" + +using Row = ck_tile::tensor_layout::gemm::RowMajor; +using Col = ck_tile::tensor_layout::gemm::ColumnMajor; + +template float gemm_>(const A&, const S&); diff --git a/example/ck_tile/03_gemm/instances/gemm_universal_comp_instance_common.hpp b/example/ck_tile/03_gemm/instances/gemm_universal_comp_instance_common.hpp new file mode 100644 index 0000000000..2bdd5fc380 --- /dev/null +++ b/example/ck_tile/03_gemm/instances/gemm_universal_comp_instance_common.hpp @@ -0,0 +1,206 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. +#include +#include +#include "gemm_basic.hpp" + +using A = ck_tile::GemmHostArgs; +using S = ck_tile::stream_config; + +template +using trait_ = gemm_traits_; + +template +float gemm_(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s) +{ + using GemmShape = ck_tile::TileGemmShape< + ck_tile::sequence, + ck_tile::sequence, + ck_tile::sequence>; + using TilePartitioner = ck_tile::GemmTilePartitioner; + + using GemmEpilogue = + ck_tile::Default2DEpilogue>; + using GemmTraits = ck_tile::TileGemmTraits; + using BaseGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrCompV3< + ck_tile::GemmPipelineProblem>; + + constexpr int kBlockPerCu = 1; + + const ck_tile::index_t k_grain = args.k_batch * Traits_::K_Tile; + const ck_tile::index_t K_split = (args.K + k_grain - 1) / k_grain * Traits_::K_Tile; + const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split); + const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop); + const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop); + + float ave_time{0}; + + const auto Run = [&](const auto has_hot_loop_, const auto tail_number_) { + constexpr bool has_hot_loop_v = has_hot_loop_.value; + constexpr auto tail_number_v = tail_number_.value; + + using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV3< + ck_tile::UniversalGemmPipelineProblem>; + using Kernel = ck_tile::GemmKernel; + auto kargs = Kernel::MakeKernelArgs(args); + + const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch); + constexpr dim3 blocks = Kernel::BlockSize(); + + if(!Kernel::IsSupportedArgument(kargs)) + { + throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n"); + } + + if(s.log_level_ > 0) + { + std::cout << "Launching kernel with args:" + << " grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}" + << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}" + << std::endl; + } + + ave_time = ck_tile::launch_kernel( + s, ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); + return ave_time; + }; + + if(has_hot_loop) + { + // Tail pipeline One to Seven + if(tail_num == ck_tile::TailNumber::One) + { + Run(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); + } + else if(tail_num == ck_tile::TailNumber::Full) + { + Run(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); + } + + if constexpr(BaseGemmPipeline::PrefetchStages > 2) + { + if(tail_num == ck_tile::TailNumber::Two) + { + Run(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); + } + } + if constexpr(BaseGemmPipeline::PrefetchStages > 3) + { + static_assert(BaseGemmPipeline::PrefetchStages > 3); + if(tail_num == ck_tile::TailNumber::Three) + { + Run(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); + } + } + if constexpr(BaseGemmPipeline::PrefetchStages > 4) + { + if(tail_num == ck_tile::TailNumber::Four) + { + Run(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); + } + } + if constexpr(BaseGemmPipeline::PrefetchStages > 5) + { + if(tail_num == ck_tile::TailNumber::Five) + { + Run(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); + } + } + if constexpr(BaseGemmPipeline::PrefetchStages > 6) + { + if(tail_num == ck_tile::TailNumber::Six) + { + Run(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); + } + } + if constexpr(BaseGemmPipeline::PrefetchStages > 7) + { + if(tail_num == ck_tile::TailNumber::Seven) + { + Run(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); + } + } + } + else + { + // Tail number always Full - #PrefetchStages + if(tail_num == ck_tile::TailNumber::Full) + { + Run(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); + } + else + { + std::ostringstream err; + err << "When there's no hot loop, this tail number \"" << tail_num + << "\" is not supported! PrefetchStages: " << BaseGemmPipeline::PrefetchStages + << "\n File: " << __FILE__ << ":" << __LINE__ << ", in function: " << __func__; + throw std::runtime_error(err.str()); + } + } + + return ave_time; +} diff --git a/example/ck_tile/03_gemm/instances/gemm_universal_mem_bf16_bf16_bf16_km_kn_mn.cpp b/example/ck_tile/03_gemm/instances/gemm_universal_mem_bf16_bf16_bf16_km_kn_mn.cpp new file mode 100644 index 0000000000..f340a27001 --- /dev/null +++ b/example/ck_tile/03_gemm/instances/gemm_universal_mem_bf16_bf16_bf16_km_kn_mn.cpp @@ -0,0 +1,27 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "gemm_universal_mem_instance_common.hpp" + +using Row = ck_tile::tensor_layout::gemm::RowMajor; +using Col = ck_tile::tensor_layout::gemm::ColumnMajor; + +template float gemm_>(const A&, const S&); diff --git a/example/ck_tile/03_gemm/instances/gemm_universal_mem_bf16_bf16_bf16_km_nk_mn.cpp b/example/ck_tile/03_gemm/instances/gemm_universal_mem_bf16_bf16_bf16_km_nk_mn.cpp new file mode 100644 index 0000000000..24aec06e12 --- /dev/null +++ b/example/ck_tile/03_gemm/instances/gemm_universal_mem_bf16_bf16_bf16_km_nk_mn.cpp @@ -0,0 +1,27 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "gemm_universal_mem_instance_common.hpp" + +using Row = ck_tile::tensor_layout::gemm::RowMajor; +using Col = ck_tile::tensor_layout::gemm::ColumnMajor; + +template float gemm_>(const A&, const S&); diff --git a/example/ck_tile/03_gemm/instances/gemm_universal_mem_bf16_bf16_bf16_mk_kn_mn.cpp b/example/ck_tile/03_gemm/instances/gemm_universal_mem_bf16_bf16_bf16_mk_kn_mn.cpp new file mode 100644 index 0000000000..6ff10bfda8 --- /dev/null +++ b/example/ck_tile/03_gemm/instances/gemm_universal_mem_bf16_bf16_bf16_mk_kn_mn.cpp @@ -0,0 +1,26 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "gemm_universal_mem_instance_common.hpp" + +using Row = ck_tile::tensor_layout::gemm::RowMajor; + +template float gemm_>(const A&, const S&); diff --git a/example/ck_tile/03_gemm/instances/gemm_universal_mem_bf16_bf16_bf16_mk_nk_mn.cpp b/example/ck_tile/03_gemm/instances/gemm_universal_mem_bf16_bf16_bf16_mk_nk_mn.cpp new file mode 100644 index 0000000000..98fb82163d --- /dev/null +++ b/example/ck_tile/03_gemm/instances/gemm_universal_mem_bf16_bf16_bf16_mk_nk_mn.cpp @@ -0,0 +1,27 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "gemm_universal_mem_instance_common.hpp" + +using Row = ck_tile::tensor_layout::gemm::RowMajor; +using Col = ck_tile::tensor_layout::gemm::ColumnMajor; + +template float gemm_>(const A&, const S&); diff --git a/example/ck_tile/03_gemm/instances/gemm_universal_mem_f16_f16_f16_km_kn_mn.cpp b/example/ck_tile/03_gemm/instances/gemm_universal_mem_f16_f16_f16_km_kn_mn.cpp new file mode 100644 index 0000000000..8a462bb8f8 --- /dev/null +++ b/example/ck_tile/03_gemm/instances/gemm_universal_mem_f16_f16_f16_km_kn_mn.cpp @@ -0,0 +1,27 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "gemm_universal_mem_instance_common.hpp" + +using Row = ck_tile::tensor_layout::gemm::RowMajor; +using Col = ck_tile::tensor_layout::gemm::ColumnMajor; + +template float gemm_>(const A&, const S&); diff --git a/example/ck_tile/03_gemm/instances/gemm_universal_mem_f16_f16_f16_km_nk_mn.cpp b/example/ck_tile/03_gemm/instances/gemm_universal_mem_f16_f16_f16_km_nk_mn.cpp new file mode 100644 index 0000000000..8e78d850af --- /dev/null +++ b/example/ck_tile/03_gemm/instances/gemm_universal_mem_f16_f16_f16_km_nk_mn.cpp @@ -0,0 +1,27 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "gemm_universal_mem_instance_common.hpp" + +using Row = ck_tile::tensor_layout::gemm::RowMajor; +using Col = ck_tile::tensor_layout::gemm::ColumnMajor; + +template float gemm_>(const A&, const S&); diff --git a/example/ck_tile/03_gemm/instances/gemm_universal_mem_f16_f16_f16_mk_kn_mn.cpp b/example/ck_tile/03_gemm/instances/gemm_universal_mem_f16_f16_f16_mk_kn_mn.cpp new file mode 100644 index 0000000000..487dc07d69 --- /dev/null +++ b/example/ck_tile/03_gemm/instances/gemm_universal_mem_f16_f16_f16_mk_kn_mn.cpp @@ -0,0 +1,26 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "gemm_universal_mem_instance_common.hpp" + +using Row = ck_tile::tensor_layout::gemm::RowMajor; + +template float gemm_>(const A&, const S&); diff --git a/example/ck_tile/03_gemm/instances/gemm_universal_mem_f16_f16_f16_mk_nk_mn.cpp b/example/ck_tile/03_gemm/instances/gemm_universal_mem_f16_f16_f16_mk_nk_mn.cpp new file mode 100644 index 0000000000..50823e96cd --- /dev/null +++ b/example/ck_tile/03_gemm/instances/gemm_universal_mem_f16_f16_f16_mk_nk_mn.cpp @@ -0,0 +1,27 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "gemm_universal_mem_instance_common.hpp" + +using Row = ck_tile::tensor_layout::gemm::RowMajor; +using Col = ck_tile::tensor_layout::gemm::ColumnMajor; + +template float gemm_>(const A&, const S&); diff --git a/example/ck_tile/03_gemm/instances/gemm_universal_mem_instance_common.hpp b/example/ck_tile/03_gemm/instances/gemm_universal_mem_instance_common.hpp new file mode 100644 index 0000000000..b78efae272 --- /dev/null +++ b/example/ck_tile/03_gemm/instances/gemm_universal_mem_instance_common.hpp @@ -0,0 +1,206 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. +#include +#include +#include "gemm_basic.hpp" + +using A = ck_tile::GemmHostArgs; +using S = ck_tile::stream_config; + +template +using trait_ = gemm_traits_; + +template +float gemm_(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s) +{ + using GemmShape = ck_tile::TileGemmShape< + ck_tile::sequence, + ck_tile::sequence, + ck_tile::sequence>; + using TilePartitioner = ck_tile::GemmTilePartitioner; + + using GemmEpilogue = + ck_tile::Default2DEpilogue>; + using GemmTraits = ck_tile::TileGemmTraits; + using BaseGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrMem< + ck_tile::GemmPipelineProblem>; + + constexpr int kBlockPerCu = 1; + + const ck_tile::index_t k_grain = args.k_batch * Traits_::K_Tile; + const ck_tile::index_t K_split = (args.K + k_grain - 1) / k_grain * Traits_::K_Tile; + const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split); + const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop); + const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop); + + float ave_time{0}; + + const auto Run = [&](const auto has_hot_loop_, const auto tail_number_) { + constexpr bool has_hot_loop_v = has_hot_loop_.value; + constexpr auto tail_number_v = tail_number_.value; + + using GemmPipeline = ck_tile::GemmPipelineAgBgCrMem< + ck_tile::UniversalGemmPipelineProblem>; + using Kernel = ck_tile::GemmKernel; + auto kargs = Kernel::MakeKernelArgs(args); + + const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch); + constexpr dim3 blocks = Kernel::BlockSize(); + + if(!Kernel::IsSupportedArgument(kargs)) + { + throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n"); + } + + if(s.log_level_ > 0) + { + std::cout << "Launching kernel with args:" + << " grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}" + << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}" + << std::endl; + } + + ave_time = ck_tile::launch_kernel( + s, ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); + return ave_time; + }; + + if(has_hot_loop) + { + // Tail pipeline One to Seven + if(tail_num == ck_tile::TailNumber::One) + { + Run(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); + } + else if(tail_num == ck_tile::TailNumber::Full) + { + Run(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); + } + + if constexpr(BaseGemmPipeline::PrefetchStages > 2) + { + if(tail_num == ck_tile::TailNumber::Two) + { + Run(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); + } + } + if constexpr(BaseGemmPipeline::PrefetchStages > 3) + { + static_assert(BaseGemmPipeline::PrefetchStages > 3); + if(tail_num == ck_tile::TailNumber::Three) + { + Run(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); + } + } + if constexpr(BaseGemmPipeline::PrefetchStages > 4) + { + if(tail_num == ck_tile::TailNumber::Four) + { + Run(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); + } + } + if constexpr(BaseGemmPipeline::PrefetchStages > 5) + { + if(tail_num == ck_tile::TailNumber::Five) + { + Run(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); + } + } + if constexpr(BaseGemmPipeline::PrefetchStages > 6) + { + if(tail_num == ck_tile::TailNumber::Six) + { + Run(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); + } + } + if constexpr(BaseGemmPipeline::PrefetchStages > 7) + { + if(tail_num == ck_tile::TailNumber::Seven) + { + Run(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); + } + } + } + else + { + // Tail number always Full - #PrefetchStages + if(tail_num == ck_tile::TailNumber::Full) + { + Run(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); + } + else + { + std::ostringstream err; + err << "When there's no hot loop, this tail number \"" << tail_num + << "\" is not supported! PrefetchStages: " << BaseGemmPipeline::PrefetchStages + << "\n File: " << __FILE__ << ":" << __LINE__ << ", in function: " << __func__; + throw std::runtime_error(err.str()); + } + } + + return ave_time; +} diff --git a/example/ck_tile/03_gemm/run_gemm_example.inc b/example/ck_tile/03_gemm/run_gemm_example.inc index 56d0348bd6..b6d9d4399d 100644 --- a/example/ck_tile/03_gemm/run_gemm_example.inc +++ b/example/ck_tile/03_gemm/run_gemm_example.inc @@ -28,8 +28,13 @@ float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf, args.stride_B = stride_B; args.stride_C = stride_C; - float ave_time = gemm_calc( - args, ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat}); + gemm_traits traits{DataTypeTraits{}.name, + std::is_same_v, + std::is_same_v, + std::is_same_v}; + + float ave_time = + gemm(traits, args, ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat}); std::size_t flop = std::size_t(2) * M * N * K; std::size_t num_byte = @@ -210,9 +215,6 @@ int run_gemm_example(int argc, char* argv[]) if(!result) return -1; - using Row = ck_tile::tensor_layout::gemm::RowMajor; - using Col = ck_tile::tensor_layout::gemm::ColumnMajor; - std::string a_layout = arg_parser.get_str("a_layout"); std::string b_layout = arg_parser.get_str("b_layout"); @@ -224,16 +226,14 @@ int run_gemm_example(int argc, char* argv[]) { return run_gemm_example_with_layouts(argc, argv, Row{}, Col{}, Row{}); } - // TODO: Fixme: with latest changes to GemmPipelineAGmemBGmemCRegV1DefaultPolicy below do not - // work. - // else if(a_layout == "C" && b_layout == "C") - // { - // return run_gemm_example_with_layouts(argc, argv, Col{}, Col{}, Row{}); - // } - // else if(a_layout == "C" && b_layout == "R") - // { - // return run_gemm_example_with_layouts(argc, argv, Col{}, Row{}, Row{}); - // } + else if(a_layout == "C" && b_layout == "C") + { + return run_gemm_example_with_layouts(argc, argv, Col{}, Col{}, Row{}); + } + else if(a_layout == "C" && b_layout == "R") + { + return run_gemm_example_with_layouts(argc, argv, Col{}, Row{}, Row{}); + } else { throw std::runtime_error("Unsupported data layout configuration for A,B and C tensors!"); diff --git a/example/ck_tile/03_gemm/universal_gemm.cpp b/example/ck_tile/03_gemm/universal_gemm.cpp index 1a9e025a9b..418c390cc1 100644 --- a/example/ck_tile/03_gemm/universal_gemm.cpp +++ b/example/ck_tile/03_gemm/universal_gemm.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. #include @@ -14,207 +14,6 @@ #include "ck_tile/host.hpp" #include "gemm_basic.hpp" -#define CK_TILE_PIPELINE_COMPUTE 1 -#define CK_TILE_PIPELINE_MEMORY 2 - -#ifndef CK_TILE_PIPELINE_DEFAULT -#define CK_TILE_PIPELINE_DEFAULT CK_TILE_PIPELINE_COMPUTE -#endif - -template -float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s) -{ -#if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_MEMORY) - // Memory friendly for Interwave scheduler - constexpr ck_tile::index_t M_Tile = 128; - constexpr ck_tile::index_t N_Tile = 32; - constexpr ck_tile::index_t K_Tile = 64; - - constexpr ck_tile::index_t M_Warp = 4; - constexpr ck_tile::index_t N_Warp = 1; - constexpr ck_tile::index_t K_Warp = 1; - - constexpr ck_tile::index_t M_Warp_Tile = 32; - constexpr ck_tile::index_t N_Warp_Tile = 32; - constexpr ck_tile::index_t K_Warp_Tile = 8; - -#elif(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE) - // Compute friendly for Intrawave scheduler - constexpr ck_tile::index_t M_Tile = 256; - constexpr ck_tile::index_t N_Tile = 256; - constexpr ck_tile::index_t K_Tile = 32; - - constexpr ck_tile::index_t M_Warp = 2; - constexpr ck_tile::index_t N_Warp = 2; - constexpr ck_tile::index_t K_Warp = 1; - - constexpr ck_tile::index_t M_Warp_Tile = 32; - constexpr ck_tile::index_t N_Warp_Tile = 32; - constexpr ck_tile::index_t K_Warp_Tile = 16; -#endif - - constexpr bool kPadM = false; - constexpr bool kPadN = false; - constexpr bool kPadK = false; - - constexpr int kBlockPerCu = 1; - - // =============================================== - - using GemmShape = - ck_tile::TileGemmShape, - ck_tile::sequence, - ck_tile::sequence>; - using TilePartitioner = ck_tile::GemmTilePartitioner; - - using GemmEpilogue = ck_tile::Default2DEpilogue< - ck_tile::Default2DEpilogueProblem>; - - using Traits = ck_tile::TileGemmTraits; -#if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_MEMORY) - using BaseGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrMem< -#elif(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE) - using BaseGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrCompV3< -#endif - ck_tile::GemmPipelineProblem>; - - const ck_tile::index_t k_grain = args.k_batch * K_Tile; - const ck_tile::index_t K_split = (args.K + k_grain - 1) / k_grain * K_Tile; - const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split); - const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop); - const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop); - - float ave_time{0}; - - const auto Run = [&](const auto has_hot_loop_, const auto tail_number_) { - constexpr bool has_hot_loop_v = has_hot_loop_.value; - constexpr auto tail_number_v = tail_number_.value; - -#if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_MEMORY) - using GemmPipeline = ck_tile::GemmPipelineAgBgCrMem< -#elif(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE) - using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV3< -#endif - ck_tile::UniversalGemmPipelineProblem>; - using Kernel = ck_tile::GemmKernel; - auto kargs = Kernel::MakeKernelArgs(args); - - const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch); - constexpr dim3 blocks = Kernel::BlockSize(); - - if(!Kernel::IsSupportedArgument(kargs)) - { - throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n"); - } - - if(s.log_level_ > 0) - { - std::cout << "Launching kernel with args:" - << " grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}" - << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}" - << std::endl; - } - - ave_time = ck_tile::launch_kernel( - s, ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); - return ave_time; - }; - - if(has_hot_loop) - { - // Tail pipeline One to Seven - if(tail_num == ck_tile::TailNumber::One) - { - Run(ck_tile::bool_constant{}, - ck_tile::integral_constant{}); - } - else if(tail_num == ck_tile::TailNumber::Full) - { - Run(ck_tile::bool_constant{}, - ck_tile::integral_constant{}); - } - - if constexpr(BaseGemmPipeline::PrefetchStages > 2) - { - if(tail_num == ck_tile::TailNumber::Two) - { - Run(ck_tile::bool_constant{}, - ck_tile::integral_constant{}); - } - } - if constexpr(BaseGemmPipeline::PrefetchStages > 3) - { - if(tail_num == ck_tile::TailNumber::Three) - { - Run(ck_tile::bool_constant{}, - ck_tile::integral_constant{}); - } - } - if constexpr(BaseGemmPipeline::PrefetchStages > 4) - { - if(tail_num == ck_tile::TailNumber::Four) - { - Run(ck_tile::bool_constant{}, - ck_tile::integral_constant{}); - } - } - if constexpr(BaseGemmPipeline::PrefetchStages > 5) - { - if(tail_num == ck_tile::TailNumber::Five) - { - Run(ck_tile::bool_constant{}, - ck_tile::integral_constant{}); - } - } - if constexpr(BaseGemmPipeline::PrefetchStages > 6) - { - if(tail_num == ck_tile::TailNumber::Six) - { - Run(ck_tile::bool_constant{}, - ck_tile::integral_constant{}); - } - } - if constexpr(BaseGemmPipeline::PrefetchStages > 7) - { - if(tail_num == ck_tile::TailNumber::Seven) - { - Run(ck_tile::bool_constant{}, - ck_tile::integral_constant{}); - } - } - } - else - { - // Tail number always Full - #PrefetchStages - if(tail_num == ck_tile::TailNumber::Full) - { - Run(ck_tile::bool_constant{}, - ck_tile::integral_constant{}); - } - else - { - std::ostringstream err; - err << "When there's no hot loop, this tail number \"" << tail_num - << "\" is not supported! PrefetchStages: " << BaseGemmPipeline::PrefetchStages - << "\n File: " << __FILE__ << ":" << __LINE__ << ", in function: " << __func__; - throw std::runtime_error(err.str()); - } - } - - return ave_time; -} - #include "run_gemm_example.inc" int main(int argc, char* argv[]) { return !run_gemm_example(argc, argv); }