From ef6e1866ff5c44a68ad3db4b91d73dfb4ca9d210 Mon Sep 17 00:00:00 2001 From: "assistant-librarian[bot]" Date: Thu, 30 Oct 2025 03:31:18 +0000 Subject: [PATCH] Merge commit '8c4cb4f9f4d3e96813c8dd5b26e175c169d14a9c' into develop --- example/ck_tile/03_gemm/gemm_basic.cpp | 16 ++++++ example/ck_tile/03_gemm/run_gemm_example.inc | 9 ++++ include/ck_tile/core.hpp | 1 + .../ck_tile/core/utility/gemm_validation.hpp | 51 +++++++++++++++++++ 4 files changed, 77 insertions(+) create mode 100644 include/ck_tile/core/utility/gemm_validation.hpp diff --git a/example/ck_tile/03_gemm/gemm_basic.cpp b/example/ck_tile/03_gemm/gemm_basic.cpp index d687e35f5d..f92f6ef87a 100644 --- a/example/ck_tile/03_gemm/gemm_basic.cpp +++ b/example/ck_tile/03_gemm/gemm_basic.cpp @@ -5,16 +5,32 @@ #include "run_gemm_example.inc" #include "run_gemm_example_common.hpp" #include "gemm_basic_invoker.hpp" +#include "ck_tile/core/utility/gemm_validation.hpp" int run_gemm_example(ck_tile::ArgParser& arg_parser) { std::string data_type = arg_parser.get_str("prec"); std::string a_layout = arg_parser.get_str("a_layout"); std::string b_layout = arg_parser.get_str("b_layout"); + std::string c_layout = arg_parser.get_str("c_layout"); + + std::tuple gemm_sizes = + parse_gemm_size(arg_parser); + + int m = std::get<0>(gemm_sizes); + int n = std::get<1>(gemm_sizes); + int k = std::get<2>(gemm_sizes); + + int stride_a = arg_parser.get_int("stride_a"); + int stride_b = arg_parser.get_int("stride_b"); + int stride_c = arg_parser.get_int("stride_c"); using GemmConfig = GemmConfigBase; using Invoker = BasicInvoker; + ck_tile::validate_gemm_stride( + a_layout, b_layout, c_layout, m, n, k, stride_a, stride_b, stride_c); + if(data_type == "fp16") { return run_gemm_example_prec_type( diff --git a/example/ck_tile/03_gemm/run_gemm_example.inc b/example/ck_tile/03_gemm/run_gemm_example.inc index 42a2d70692..d5f164c40f 100644 --- a/example/ck_tile/03_gemm/run_gemm_example.inc +++ b/example/ck_tile/03_gemm/run_gemm_example.inc @@ -254,6 +254,15 @@ bool do_verify(const ck_tile::HostTensor& c_m_n_dev_result, return pass; } +std::tuple +parse_gemm_size(ck_tile::ArgParser& arg_parser) +{ + ck_tile::index_t M = arg_parser.get_int("m"); + ck_tile::index_t N = arg_parser.get_int("n"); + ck_tile::index_t K = arg_parser.get_int("k"); + return std::make_tuple(M, N, K); +} + template +#include +#include "ck_tile/core/config.hpp" + +namespace ck_tile { + +inline void +validate_stride(std::string Layout, int M, int N, int stride, const std::string& stride_name) +{ + if(Layout == "C" && stride < M) + { + throw std::runtime_error("For ColumnMajor layout, " + stride_name + "(" + + std::to_string(stride) + ") must be greater or equal to dim " + + std::to_string(M)); + } + if(Layout == "R" && stride < N) + { + throw std::runtime_error("For RowMajor layout, " + stride_name + "(" + + std::to_string(stride) + ") must be greater or equal to dim " + + std::to_string(N)); + } +} + +inline void validate_gemm_stride(std::string a_layout, + std::string b_layout, + std::string c_layout, + int M, + int N, + int K, + int Stride_A, + int Stride_B, + int Stride_C) +{ + // set default stride + if(Stride_A <= 0) + Stride_A = (a_layout == "R") ? K : M; + if(Stride_B <= 0) + Stride_B = (b_layout == "R") ? N : K; + if(Stride_C <= 0) + Stride_C = (c_layout == "R") ? N : M; + + validate_stride(a_layout, M, K, Stride_A, "Stride_A"); + validate_stride(b_layout, K, N, Stride_B, "Stride_B"); + validate_stride(c_layout, M, N, Stride_C, "Stride_C"); +} +} // namespace ck_tile