From b83ea0b858c56fbaf93e9c23fd07b4cbb40d43f0 Mon Sep 17 00:00:00 2001
From: Philip Maybank
Date: Fri, 21 Nov 2025 12:09:52 +0000
Subject: [PATCH] cascade to GEMM Multi D and GEMM Preshuffle operators

---
Notes (this section is commentary and is not part of the commit):

  * Both *_benchmark_single.cpp drivers gain a "json_file" argument. When it
    is non-empty and the file can be opened, profiler output is written to
    that file; otherwise output goes to stdout (with a warning on stderr if
    the open failed).
  * Both profiler singletons accept a std::ostream* (defaulting to
    &std::cout) and write their results through it instead of std::cout.
  * Both Python drivers pass -json_file=<path> and parse the file written by
    the binary rather than capturing and re-saving stdout.
  * In the gemm_multi_d files the profiler class is GemmMultiDProfiler; in
    the gemm_preshuffle files it is GemmProfiler.

 .../gemm_multi_d/gemm_multi_d_benchmark.py    | 16 ++++-------
 .../gemm_multi_d_benchmark_single.cpp         | 28 +++++++++++++++++--
 .../gemm_multi_d/gemm_multi_d_profiler.hpp    | 24 +++++++++-------
 .../gemm_preshuffle_benchmark.py              | 17 ++++-------
 .../gemm_preshuffle_benchmark_single.cpp      | 28 +++++++++++++++++--
 .../gemm_preshuffle_profiler.hpp              | 24 +++++++++-------
 6 files changed, 90 insertions(+), 47 deletions(-)

diff --git a/tile_engine/ops/gemm_multi_d/gemm_multi_d_benchmark.py b/tile_engine/ops/gemm_multi_d/gemm_multi_d_benchmark.py
index 044e08baca..53fdff3949 100755
--- a/tile_engine/ops/gemm_multi_d/gemm_multi_d_benchmark.py
+++ b/tile_engine/ops/gemm_multi_d/gemm_multi_d_benchmark.py
@@ -214,27 +214,23 @@ class GemmMultiDBenchmark:
 
         # Add JSON output flag for clean JSON output
         cmd.append("-json_output=true")
+        cmd.append(f"-json_file={json_file}")
 
         if self.verbose:
             print(f"Running: {' '.join(cmd)}")
 
         try:
-            result = subprocess.run(cmd, capture_output=True, text=True, timeout=60)
+            result = subprocess.run(cmd, timeout=60)
 
             if result.returncode != 0:
-                print(f"Error running {kernel_path.name}: {result.stderr}")
+                print(f"Error running {kernel_path.name}")
                 return None
 
-            # Save raw output to individual JSON file
-            output = result.stdout.strip()
-            if output:
-                with open(json_file, "w") as f:
-                    f.write(output)
-
-                # Parse the JSON file
+            # Parse the JSON file that was written by the C++ program
+            if json_file.exists():
                 return self.parse_json_file(json_file)
             else:
-                print(f"No output from {kernel_path.name}")
+                print(f"No output file created by {kernel_path.name}")
                 return None
 
         except subprocess.TimeoutExpired:
diff --git a/tile_engine/ops/gemm_multi_d/gemm_multi_d_benchmark_single.cpp b/tile_engine/ops/gemm_multi_d/gemm_multi_d_benchmark_single.cpp
index 41d2f736e1..39305c6239 100644
--- a/tile_engine/ops/gemm_multi_d/gemm_multi_d_benchmark_single.cpp
+++ b/tile_engine/ops/gemm_multi_d/gemm_multi_d_benchmark_single.cpp
@@ -70,7 +70,10 @@ inline auto create_args(int argc, char* argv[])
                 "false",
                 "Whether to output results in JSON format only. Possible values are true or false. "
                 "Default is "
-                "false");
+                "false")
+        .insert("json_file",
+                "",
+                "The filename for JSON output. If empty, output goes to stdout. Default is empty.");
 
     bool result = arg_parser.parse(argc, argv);
     return std::make_tuple(result, arg_parser);
@@ -128,8 +131,27 @@ void benchmark_single(const ck_tile::ArgParser& arg_parser)
                     arg_parser.get_int("rotating_count"),
                     arg_parser.get_bool("json_output")};
 
-    // Get the profiler instance
-    auto& profiler = GemmMultiDProfiler::instance(setting);
+    // Handle output stream - either file or stdout
+    std::string json_filename = arg_parser.get_str("json_file");
+    std::ofstream json_file_stream;
+    std::ostream* output_stream = &std::cout;
+
+    if(!json_filename.empty())
+    {
+        json_file_stream.open(json_filename);
+        if(json_file_stream.is_open())
+        {
+            output_stream = &json_file_stream;
+        }
+        else
+        {
+            std::cerr << "Warning: Failed to open JSON file " << json_filename
+                      << ", using stdout instead." << std::endl;
+        }
+    }
+
+    // Get the profiler instance with output stream
+    auto& profiler = GemmMultiDProfiler::instance(setting, output_stream);
 
     try
     {
diff --git a/tile_engine/ops/gemm_multi_d/gemm_multi_d_profiler.hpp b/tile_engine/ops/gemm_multi_d/gemm_multi_d_profiler.hpp
index 3a2cdc71fe..d01d13ba87 100644
--- a/tile_engine/ops/gemm_multi_d/gemm_multi_d_profiler.hpp
+++ b/tile_engine/ops/gemm_multi_d/gemm_multi_d_profiler.hpp
@@ -14,9 +14,9 @@
 class GemmMultiDProfiler
 {
     public:
-    static GemmMultiDProfiler& instance(Setting setting)
+    static GemmMultiDProfiler& instance(Setting setting, std::ostream* output_stream = &std::cout)
     {
-        static GemmMultiDProfiler instance{setting};
+        static GemmMultiDProfiler instance{setting, output_stream};
         return instance;
     }
 
@@ -199,7 +199,7 @@ class GemmMultiDProfiler
 
         if(setting_.log_ > 0 && !setting_.json_output_)
         {
-            std::cout << kernel_instance << std::endl;
+            *output_stream_ << kernel_instance << std::endl;
         }
 
         // verify result
@@ -217,7 +217,7 @@ class GemmMultiDProfiler
         }
         else
         {
-            std::cout << "Verification failed, skip kernel: " << name << std::endl;
+            *output_stream_ << "Verification failed, skip kernel: " << name << std::endl;
         }
 
         // clear tensor
@@ -240,14 +240,14 @@ class GemmMultiDProfiler
         if(setting_.json_output_)
         {
             // Output clean JSON only
-            std::cout << kernel_instance << std::endl;
+            *output_stream_ << kernel_instance << std::endl;
         }
         else
         {
-            std::cout << "**********************************" << std::endl;
-            std::cout << "According to given metrics: " << get_metric_name(metric) << "\n"
-                      << "Current kernel performance is: " << kernel_instance << std::endl;
-            std::cout << "**********************************" << std::endl;
+            *output_stream_ << "**********************************" << std::endl;
+            *output_stream_ << "According to given metrics: " << get_metric_name(metric) << "\n"
+                            << "Current kernel performance is: " << kernel_instance << std::endl;
+            *output_stream_ << "**********************************" << std::endl;
         }
 
         if(!setting_.csv_filename_.empty())
@@ -299,9 +299,13 @@ class GemmMultiDProfiler
 
     private:
     ~GemmMultiDProfiler() { kernel_instances_.clear(); }
-    GemmMultiDProfiler(Setting setting) : setting_(setting) {}
+    GemmMultiDProfiler(Setting setting, std::ostream* output_stream = &std::cout)
+        : setting_(setting), output_stream_(output_stream)
+    {
+    }
 
     Setting setting_;
+    std::ostream* output_stream_;
 
     std::vector kernel_instances_;
 };
diff --git a/tile_engine/ops/gemm_preshuffle/gemm_preshuffle_benchmark.py b/tile_engine/ops/gemm_preshuffle/gemm_preshuffle_benchmark.py
index d8892be7d6..27717ed535 100755
--- a/tile_engine/ops/gemm_preshuffle/gemm_preshuffle_benchmark.py
+++ b/tile_engine/ops/gemm_preshuffle/gemm_preshuffle_benchmark.py
@@ -214,28 +214,23 @@ class GemmPreshuffleBenchmark:
 
         # Add JSON output flag for clean JSON output
         cmd.append("-json_output=true")
+        cmd.append(f"-json_file={json_file}")
 
         if self.verbose:
             print(f"Running: {' '.join(cmd)}")
 
         try:
-            result = subprocess.run(cmd, capture_output=True, text=True, timeout=60)
+            result = subprocess.run(cmd, timeout=60)
 
             if result.returncode != 0:
-                print(f"Error running {kernel_path.name}: {result.stderr}")
+                print(f"Error running {kernel_path.name}")
                 return None
 
-            # Save raw output to individual JSON file
-            output = result.stdout.strip()
-
-            if output:
-                with open(json_file, "w") as f:
-                    f.write(output)
-
-                # Parse the JSON file
+            # Parse the JSON file that was written by the C++ program
+            if json_file.exists():
                 return self.parse_json_file(json_file)
             else:
-                print(f"No output from {kernel_path.name}")
+                print(f"No output file created by {kernel_path.name}")
                 return None
 
         except subprocess.TimeoutExpired:
diff --git a/tile_engine/ops/gemm_preshuffle/gemm_preshuffle_benchmark_single.cpp b/tile_engine/ops/gemm_preshuffle/gemm_preshuffle_benchmark_single.cpp
index 4fbb25f0c9..b546b39ab9 100644
--- a/tile_engine/ops/gemm_preshuffle/gemm_preshuffle_benchmark_single.cpp
+++ b/tile_engine/ops/gemm_preshuffle/gemm_preshuffle_benchmark_single.cpp
@@ -69,7 +69,10 @@ inline auto create_args(int argc, char* argv[])
                 "false",
                 "Whether to output results in JSON format only. Possible values are true or false. "
                 "Default is "
-                "false");
+                "false")
+        .insert("json_file",
+                "",
+                "The filename for JSON output. If empty, output goes to stdout. Default is empty.");
 
     bool result = arg_parser.parse(argc, argv);
     return std::make_tuple(result, arg_parser);
@@ -118,8 +121,27 @@ void benchmark_single(const ck_tile::ArgParser& arg_parser)
                     arg_parser.get_int("rotating_count"),
                     arg_parser.get_bool("json_output")};
 
-    // Get the profiler instance
-    auto& profiler = GemmProfiler::instance(setting);
+    // Handle output stream - either file or stdout
+    std::string json_filename = arg_parser.get_str("json_file");
+    std::ofstream json_file_stream;
+    std::ostream* output_stream = &std::cout;
+
+    if(!json_filename.empty())
+    {
+        json_file_stream.open(json_filename);
+        if(json_file_stream.is_open())
+        {
+            output_stream = &json_file_stream;
+        }
+        else
+        {
+            std::cerr << "Warning: Failed to open JSON file " << json_filename
+                      << ", using stdout instead." << std::endl;
+        }
+    }
+
+    // Get the profiler instance with output stream
+    auto& profiler = GemmProfiler::instance(setting, output_stream);
 
     try
     {
diff --git a/tile_engine/ops/gemm_preshuffle/gemm_preshuffle_profiler.hpp b/tile_engine/ops/gemm_preshuffle/gemm_preshuffle_profiler.hpp
index 739bd7e677..d18af10725 100644
--- a/tile_engine/ops/gemm_preshuffle/gemm_preshuffle_profiler.hpp
+++ b/tile_engine/ops/gemm_preshuffle/gemm_preshuffle_profiler.hpp
@@ -10,9 +10,9 @@
 class GemmProfiler
 {
     public:
-    static GemmProfiler& instance(Setting setting)
+    static GemmProfiler& instance(Setting setting, std::ostream* output_stream = &std::cout)
     {
-        static GemmProfiler instance{setting};
+        static GemmProfiler instance{setting, output_stream};
         return instance;
     }
 
@@ -182,7 +182,7 @@ class GemmProfiler
 
         if(setting_.log_ > 0 && !setting_.json_output_)
         {
-            std::cout << kernel_instance << std::endl;
+            *output_stream_ << kernel_instance << std::endl;
         }
 
         // verify result
@@ -198,7 +198,7 @@ class GemmProfiler
         }
         else
         {
-            std::cout << "Verification failed, skip kernel: " << name << std::endl;
+            *output_stream_ << "Verification failed, skip kernel: " << name << std::endl;
         }
 
         // clear tensor
@@ -221,14 +221,14 @@ class GemmProfiler
         if(setting_.json_output_)
         {
             // Output clean JSON only
-            std::cout << kernel_instance << std::endl;
+            *output_stream_ << kernel_instance << std::endl;
         }
         else
         {
-            std::cout << "**********************************" << std::endl;
-            std::cout << "According to given metrics: " << get_metric_name(metric) << "\n"
-                      << "Current kernel performance is: " << kernel_instance << std::endl;
-            std::cout << "**********************************" << std::endl;
+            *output_stream_ << "**********************************" << std::endl;
+            *output_stream_ << "According to given metrics: " << get_metric_name(metric) << "\n"
+                            << "Current kernel performance is: " << kernel_instance << std::endl;
+            *output_stream_ << "**********************************" << std::endl;
         }
 
         if(!setting_.csv_filename_.empty())
@@ -281,9 +281,13 @@ class GemmProfiler
 
     private:
     ~GemmProfiler() { kernel_instances_.clear(); }
-    GemmProfiler(Setting setting) : setting_(setting) {}
+    GemmProfiler(Setting setting, std::ostream* output_stream = &std::cout)
+        : setting_(setting), output_stream_(output_stream)
+    {
+    }
 
     Setting setting_;
+    std::ostream* output_stream_;
 
     std::vector kernel_instances_;
 };