From 2234480b5db2ce1206860ee6cd2b75f43a6b86d6 Mon Sep 17 00:00:00 2001 From: Aviral Goel Date: Tue, 3 Jun 2025 10:26:58 -0400 Subject: [PATCH] Add 0 as an acceptable arguement for strides in CK GEMM example (Issue 2037) (#2268) * add 0 as valid default arguement for strides * add 0 as valid default arguement for strides # Conflicts: # example/01_gemm/common.hpp [ROCm/composable_kernel commit: 11f6c14e03996486fc86e7f509640098dfdf306d] --- example/01_gemm/common.hpp | 42 ++++++++++--------- example/01_gemm/run_gemm_example.inc | 2 +- example/01_gemm/run_gemm_example_streamk.inc | 2 +- .../01_gemm/run_gemm_example_streamk_v2.inc | 2 +- example/01_gemm/run_gemm_example_v2.inc | 2 +- 5 files changed, 27 insertions(+), 23 deletions(-) diff --git a/example/01_gemm/common.hpp b/example/01_gemm/common.hpp index 9073ffcfc1..d3e61b8216 100644 --- a/example/01_gemm/common.hpp +++ b/example/01_gemm/common.hpp @@ -128,11 +128,12 @@ bool parse_cmd_args(int argc, } else { - std::cerr << "arg1: verification (0=no, 1=CPU, 2=GPU, 3=CPU and GPU)" << std::endl - << "arg2: initialization (0=no init, 1=integer value, 2=decimal value)" - << std::endl - << "arg3: time kernel (0=no, 1=yes)" << std::endl - << "arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideC" << std::endl; + std::cerr + << "arg1: verification (0=no, 1=CPU, 2=GPU, 3=CPU and GPU)" << std::endl + << "arg2: initialization (0=no init, 1=integer value, 2=decimal value)" << std::endl + << "arg3: time kernel (0=no, 1=yes)" << std::endl + << "arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideC (default: -1 or 0)" + << std::endl; return false; } @@ -181,7 +182,8 @@ bool parse_cmd_args(int argc, << "arg1: verification (0=no, 1=CPU, 2=GPU, 3=CPU and GPU)" << std::endl << "arg2: initialization (0=no init, 1=integer value, 2=decimal value)" << std::endl << "arg3: time kernel (0=no, 1=yes)" << std::endl - << "arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideC" << std::endl + << "arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideC (default: -1 or 0)" + << std::endl << "arg10: stream-k select (-1: default config, 0: all DP, 1: 1-tile SK, 2: 2-tile SK)" << "\narg11: Grid_size(-1 for max occupancy)" << std::endl; return false; @@ -227,13 +229,14 @@ bool parse_cmd_args(int argc, } else { - std::cerr << "arg1: verification (0=no, 1=CPU, 2=GPU, 3=CPU and GPU)" << std::endl - << "arg2: initialization (0=no init, 1=integer value, 2=decimal value)" - << std::endl - << "arg3: time kernel (0=no, 1=yes)" << std::endl - << "arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideC" << std::endl - << "arg10: stream-k select (0: all DP, 1: 1-tile SK, 2: 2-tile SK)" - << "\narg11: Grid_size(-1 for max occupancy)" << std::endl; + std::cerr + << "arg1: verification (0=no, 1=CPU, 2=GPU, 3=CPU and GPU)" << std::endl + << "arg2: initialization (0=no init, 1=integer value, 2=decimal value)" << std::endl + << "arg3: time kernel (0=no, 1=yes)" << std::endl + << "arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideC (default: -1 or 0)" + << std::endl + << "arg10: stream-k select (0: all DP, 1: 1-tile SK, 2: 2-tile SK)" + << "\narg11: Grid_size(-1 for max occupancy)" << std::endl; return false; } @@ -277,12 +280,13 @@ bool parse_cmd_args(int argc, } else { - std::cerr << "arg1: verification (0=no, 1=CPU, 2=GPU, 3=CPU and GPU)" << std::endl - << "arg2: initialization (0=no init, 1=integer value, 2=decimal value)" - << std::endl - << "arg3: time kernel (0=no, 1=yes)" << std::endl - << "arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideC" << std::endl - << "arg10: KBatch" << std::endl; + std::cerr + << "arg1: verification (0=no, 1=CPU, 2=GPU, 3=CPU and GPU)" << std::endl + << "arg2: initialization (0=no init, 1=integer value, 2=decimal value)" << std::endl + << "arg3: time kernel (0=no, 1=yes)" << std::endl + << "arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideC (default: -1 or 0)" + << std::endl + << "arg10: KBatch" << std::endl; return false; } diff --git a/example/01_gemm/run_gemm_example.inc b/example/01_gemm/run_gemm_example.inc index c064ed500c..6c5d9f9fba 100644 --- a/example/01_gemm/run_gemm_example.inc +++ b/example/01_gemm/run_gemm_example.inc @@ -33,7 +33,7 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config) auto f_get_default_stride = [](std::size_t row, std::size_t col, ck::index_t stride, auto layout) { - if(stride == -1) + if(stride == -1 || stride == 0) { // give a chance if stride is -1, return a default packed stride if constexpr(std::is_same_v) diff --git a/example/01_gemm/run_gemm_example_streamk.inc b/example/01_gemm/run_gemm_example_streamk.inc index 438afcf71a..7e43847463 100644 --- a/example/01_gemm/run_gemm_example_streamk.inc +++ b/example/01_gemm/run_gemm_example_streamk.inc @@ -36,7 +36,7 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config) auto f_get_default_stride = [](std::size_t row, std::size_t col, ck::index_t stride, auto layout) { - if(stride == -1) + if(stride == -1 || stride == 0) { // give a chance if stride is -1, return a default packed stride if constexpr(std::is_same_v) diff --git a/example/01_gemm/run_gemm_example_streamk_v2.inc b/example/01_gemm/run_gemm_example_streamk_v2.inc index 9ee380d247..af35de0d25 100644 --- a/example/01_gemm/run_gemm_example_streamk_v2.inc +++ b/example/01_gemm/run_gemm_example_streamk_v2.inc @@ -35,7 +35,7 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config) auto f_get_default_stride = [](std::size_t row, std::size_t col, ck::index_t stride, auto layout) { - if(stride == -1) + if(stride == -1 || stride == 0) { // give a chance if stride is -1, return a default packed stride if constexpr(std::is_same_v) diff --git a/example/01_gemm/run_gemm_example_v2.inc b/example/01_gemm/run_gemm_example_v2.inc index 2b60fa5d28..4adb6f896b 100644 --- a/example/01_gemm/run_gemm_example_v2.inc +++ b/example/01_gemm/run_gemm_example_v2.inc @@ -34,7 +34,7 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config) auto f_get_default_stride = [](std::size_t row, std::size_t col, ck::index_t stride, auto layout) { - if(stride == -1) + if(stride == -1 || stride == 0) { // give a chance if stride is -1, return a default packed stride if constexpr(std::is_same_v)