diff --git a/example/ck_tile/03_gemm/README.md b/example/ck_tile/03_gemm/README.md index 20cc202176..59ef2640b7 100644 --- a/example/ck_tile/03_gemm/README.md +++ b/example/ck_tile/03_gemm/README.md @@ -18,7 +18,6 @@ This will result in an executable `build/bin/tile_example_gemm_basic` & `build/b ## example ``` args: - -b batch size (default:1) -m m dimension (default:1024) -n n dimension (default:2048) -k k dimension (default:64) @@ -29,9 +28,11 @@ args: -stride_b Tensor B stride (default:0) -stride_c Tensor C stride (default:0) -v 0. No validation, 1. Validation on CPU, 2. Validation on GPU (default:2) - -e Absolute error tolerance (default:1e-5) -prec data type. fp16/bf16/fp8/bf8/int8 (default:fp16) -warmup number of iterations before benchmark the kernel (default:10) -repeat number of iterations to benchmark the kernel (default:100) -timer gpu:gpu timer, cpu:cpu timer (default:gpu) + -split_k splitK value (default:1) + -init 0:random, 1:linear, 2:constant (default:1) + -persistent 0:non-persistent, 1:persistent (default:0) ``` diff --git a/example/ck_tile/03_gemm/gemm_basic.cpp b/example/ck_tile/03_gemm/gemm_basic.cpp index 0d9c2d9957..25781a4ae8 100644 --- a/example/ck_tile/03_gemm/gemm_basic.cpp +++ b/example/ck_tile/03_gemm/gemm_basic.cpp @@ -1,15 +1,6 @@ // SPDX-License-Identifier: MIT // Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. -#include - -#include -#include -#include -#include -#include - -#include "ck_tile/host.hpp" #include "gemm_utils.hpp" template ; - using CodegenGemmTraits = - ck_tile::TileGemmTraits; + using CodegenGemmTraits = ck_tile::TileGemmTraits; using CodegenPipelineProblem = ck_tile:: GemmPipelineProblem; @@ -111,28 +100,30 @@ float gemm(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s) << std::endl; } - float ave_time = ck_tile::launch_kernel( - s, ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); + float ave_time = + ck_tile::launch_kernel(s, + ck_tile::make_kernel( + Kernel{}, grids, blocks, 0, kargs)); return ave_time; }; if(args.k_batch == 1) { - return Run(ck_tile::integral_constant{}); + return Run(MemoryOpSet{}); } else { - return Run(ck_tile::integral_constant{}); + return Run(MemoryOpAtomicAdd{}); } } #include "run_gemm_example.inc" template -int run_gemm_example_prec_type(std::string a_layout, std::string b_layout, int argc, char* argv[]) +int run_gemm_example_prec_type(std::string a_layout, + std::string b_layout, + ck_tile::ArgParser& arg_parser) { using Row = ck_tile::tensor_layout::gemm::RowMajor; using Col = ck_tile::tensor_layout::gemm::ColumnMajor; @@ -142,12 +133,12 @@ int run_gemm_example_prec_type(std::string a_layout, std::string b_layout, int a if(a_layout == "R" && b_layout == "C") { return run_gemm_example_with_layouts( - argc, argv, Row{}, Col{}, Row{}); + arg_parser, Row{}, Col{}, Row{}); } else if(a_layout == "C" && b_layout == "C") { return run_gemm_example_with_layouts( - argc, argv, Col{}, Col{}, Row{}); + arg_parser, Col{}, Col{}, Row{}); } else { @@ -160,22 +151,22 @@ int run_gemm_example_prec_type(std::string a_layout, std::string b_layout, int a if(a_layout == "R" && b_layout == "C") { return run_gemm_example_with_layouts( - argc, argv, Row{}, Col{}, Row{}); + arg_parser, Row{}, Col{}, Row{}); } else if(a_layout == "R" && b_layout == "R") { return run_gemm_example_with_layouts( - argc, argv, Row{}, Row{}, Row{}); + arg_parser, Row{}, Row{}, Row{}); } else if(a_layout == "C" && b_layout == "R") { return run_gemm_example_with_layouts( - argc, argv, Col{}, Row{}, Row{}); + arg_parser, Col{}, Row{}, Row{}); } else if(a_layout == "C" && b_layout == "C") { return run_gemm_example_with_layouts( - argc, argv, Col{}, Col{}, Row{}); + arg_parser, Col{}, Col{}, Row{}); } else { @@ -184,38 +175,34 @@ int run_gemm_example_prec_type(std::string a_layout, std::string b_layout, int a } } -int run_gemm_example(int argc, char* argv[]) +int run_gemm_example(ck_tile::ArgParser& arg_parser) { - auto [result, arg_parser] = create_args(argc, argv); - if(!result) - return -1; - std::string data_type = arg_parser.get_str("prec"); std::string a_layout = arg_parser.get_str("a_layout"); std::string b_layout = arg_parser.get_str("b_layout"); if(data_type == "fp16") { - return run_gemm_example_prec_type(a_layout, b_layout, argc, argv); + return run_gemm_example_prec_type(a_layout, b_layout, arg_parser); } else if(data_type == "bf16") { - return run_gemm_example_prec_type(a_layout, b_layout, argc, argv); + return run_gemm_example_prec_type(a_layout, b_layout, arg_parser); } else if(data_type == "fp8") { return run_gemm_example_prec_type( - a_layout, b_layout, argc, argv); + a_layout, b_layout, arg_parser); } else if(data_type == "bf8") { return run_gemm_example_prec_type( - a_layout, b_layout, argc, argv); + a_layout, b_layout, arg_parser); } else if(data_type == "i8") { return run_gemm_example_prec_type( - a_layout, b_layout, argc, argv); + a_layout, b_layout, arg_parser); } else if(data_type == "pk_int4_t") { @@ -223,7 +210,7 @@ int run_gemm_example(int argc, char* argv[]) if constexpr(GemmConfigBase::Pipeline == CK_TILE_PIPELINE_COMPUTE_V3) { return run_gemm_example_prec_type( - a_layout, b_layout, argc, argv); + a_layout, b_layout, arg_parser); } else { @@ -238,9 +225,13 @@ int run_gemm_example(int argc, char* argv[]) int main(int argc, char* argv[]) { + auto [result, arg_parser] = create_args(argc, argv); + if(!result) + return -1; + try { - return !run_gemm_example(argc, argv); + return !run_gemm_example(arg_parser); } catch(const std::runtime_error& e) { diff --git a/example/ck_tile/03_gemm/gemm_utils.hpp b/example/ck_tile/03_gemm/gemm_utils.hpp index cab110597b..5f477b3821 100644 --- a/example/ck_tile/03_gemm/gemm_utils.hpp +++ b/example/ck_tile/03_gemm/gemm_utils.hpp @@ -476,6 +476,12 @@ auto create_args(int argc, char* argv[]) return std::make_tuple(result, arg_parser); } +// Type aliases for memory operation integral constants +using MemoryOpSet = + std::integral_constant; +using MemoryOpAtomicAdd = std::integral_constant; + // host API template -int run_gemm_example_prec_type(std::string a_layout, std::string b_layout, int argc, char* argv[]) +int run_gemm_example_prec_type(std::string a_layout, + std::string b_layout, + ck_tile::ArgParser& arg_parser) { - using Row = ck_tile::tensor_layout::gemm::RowMajor; - using Col = ck_tile::tensor_layout::gemm::ColumnMajor; - auto [result, arg_parser] = create_args(argc, argv); - bool preshuffle = GemmConfig::Preshuffle; + using Row = ck_tile::tensor_layout::gemm::RowMajor; + using Col = ck_tile::tensor_layout::gemm::ColumnMajor; + bool preshuffle = GemmConfig::Preshuffle; if(preshuffle && (a_layout != "R" || b_layout != "C")) { @@ -226,7 +227,7 @@ int run_gemm_example_prec_type(std::string a_layout, std::string b_layout, int a if(a_layout == "R" && b_layout == "C") { return run_gemm_example_with_layouts( - argc, argv, Row{}, Col{}, Row{}); + arg_parser, Row{}, Col{}, Row{}); } else { @@ -235,12 +236,8 @@ int run_gemm_example_prec_type(std::string a_layout, std::string b_layout, int a } template