diff --git a/example/ck_tile/03_gemm/gemm_basic.cpp b/example/ck_tile/03_gemm/gemm_basic.cpp index d687e35f5d..f92f6ef87a 100644 --- a/example/ck_tile/03_gemm/gemm_basic.cpp +++ b/example/ck_tile/03_gemm/gemm_basic.cpp @@ -5,16 +5,32 @@ #include "run_gemm_example.inc" #include "run_gemm_example_common.hpp" #include "gemm_basic_invoker.hpp" +#include "ck_tile/core/utility/gemm_validation.hpp" int run_gemm_example(ck_tile::ArgParser& arg_parser) { std::string data_type = arg_parser.get_str("prec"); std::string a_layout = arg_parser.get_str("a_layout"); std::string b_layout = arg_parser.get_str("b_layout"); + std::string c_layout = arg_parser.get_str("c_layout"); + + std::tuple gemm_sizes = + parse_gemm_size(arg_parser); + + int m = std::get<0>(gemm_sizes); + int n = std::get<1>(gemm_sizes); + int k = std::get<2>(gemm_sizes); + + int stride_a = arg_parser.get_int("stride_a"); + int stride_b = arg_parser.get_int("stride_b"); + int stride_c = arg_parser.get_int("stride_c"); using GemmConfig = GemmConfigBase; using Invoker = BasicInvoker; + ck_tile::validate_gemm_stride( + a_layout, b_layout, c_layout, m, n, k, stride_a, stride_b, stride_c); + if(data_type == "fp16") { return run_gemm_example_prec_type( diff --git a/example/ck_tile/03_gemm/run_gemm_example.inc b/example/ck_tile/03_gemm/run_gemm_example.inc index 42a2d70692..d5f164c40f 100644 --- a/example/ck_tile/03_gemm/run_gemm_example.inc +++ b/example/ck_tile/03_gemm/run_gemm_example.inc @@ -254,6 +254,15 @@ bool do_verify(const ck_tile::HostTensor& c_m_n_dev_result, return pass; } +std::tuple +parse_gemm_size(ck_tile::ArgParser& arg_parser) +{ + ck_tile::index_t M = arg_parser.get_int("m"); + ck_tile::index_t N = arg_parser.get_int("n"); + ck_tile::index_t K = arg_parser.get_int("k"); + return std::make_tuple(M, N, K); +} + template +#include +#include "ck_tile/core/config.hpp" + +namespace ck_tile { + +inline void +validate_stride(std::string Layout, int M, int N, int stride, const std::string& stride_name) +{ + if(Layout == "C" && stride < M) + { + throw std::runtime_error("For ColumnMajor layout, " + stride_name + "(" + + std::to_string(stride) + ") must be greater or equal to dim " + + std::to_string(M)); + } + if(Layout == "R" && stride < N) + { + throw std::runtime_error("For RowMajor layout, " + stride_name + "(" + + std::to_string(stride) + ") must be greater or equal to dim " + + std::to_string(N)); + } +} + +inline void validate_gemm_stride(std::string a_layout, + std::string b_layout, + std::string c_layout, + int M, + int N, + int K, + int Stride_A, + int Stride_B, + int Stride_C) +{ + // set default stride + if(Stride_A <= 0) + Stride_A = (a_layout == "R") ? K : M; + if(Stride_B <= 0) + Stride_B = (b_layout == "R") ? N : K; + if(Stride_C <= 0) + Stride_C = (c_layout == "R") ? N : M; + + validate_stride(a_layout, M, K, Stride_A, "Stride_A"); + validate_stride(b_layout, K, N, Stride_B, "Stride_B"); + validate_stride(c_layout, M, N, Stride_C, "Stride_C"); +} +} // namespace ck_tile