mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-07-01 04:07:56 +00:00
cascade to GEMM Multi D and GEMM Preshuffle operators
This commit is contained in:
@@ -214,27 +214,23 @@ class GemmMultiDBenchmark:
|
||||
|
||||
# Add JSON output flag for clean JSON output
|
||||
cmd.append("-json_output=true")
|
||||
cmd.append(f"-json_file={json_file}")
|
||||
|
||||
if self.verbose:
|
||||
print(f"Running: {' '.join(cmd)}")
|
||||
|
||||
try:
|
||||
result = subprocess.run(cmd, capture_output=True, text=True, timeout=60)
|
||||
result = subprocess.run(cmd, timeout=60)
|
||||
|
||||
if result.returncode != 0:
|
||||
print(f"Error running {kernel_path.name}: {result.stderr}")
|
||||
print(f"Error running {kernel_path.name}")
|
||||
return None
|
||||
|
||||
# Save raw output to individual JSON file
|
||||
output = result.stdout.strip()
|
||||
if output:
|
||||
with open(json_file, "w") as f:
|
||||
f.write(output)
|
||||
|
||||
# Parse the JSON file
|
||||
# Parse the JSON file that was written by the C++ program
|
||||
if json_file.exists():
|
||||
return self.parse_json_file(json_file)
|
||||
else:
|
||||
print(f"No output from {kernel_path.name}")
|
||||
print(f"No output file created by {kernel_path.name}")
|
||||
return None
|
||||
|
||||
except subprocess.TimeoutExpired:
|
||||
|
||||
@@ -70,7 +70,10 @@ inline auto create_args(int argc, char* argv[])
|
||||
"false",
|
||||
"Whether to output results in JSON format only. Possible values are true or false. "
|
||||
"Default is "
|
||||
"false");
|
||||
"false")
|
||||
.insert("json_file",
|
||||
"",
|
||||
"The filename for JSON output. If empty, output goes to stdout. Default is empty.");
|
||||
|
||||
bool result = arg_parser.parse(argc, argv);
|
||||
return std::make_tuple(result, arg_parser);
|
||||
@@ -128,8 +131,27 @@ void benchmark_single(const ck_tile::ArgParser& arg_parser)
|
||||
arg_parser.get_int("rotating_count"),
|
||||
arg_parser.get_bool("json_output")};
|
||||
|
||||
// Get the profiler instance
|
||||
auto& profiler = GemmMultiDProfiler::instance(setting);
|
||||
// Handle output stream - either file or stdout
|
||||
std::string json_filename = arg_parser.get_str("json_file");
|
||||
std::ofstream json_file_stream;
|
||||
std::ostream* output_stream = &std::cout;
|
||||
|
||||
if(!json_filename.empty())
|
||||
{
|
||||
json_file_stream.open(json_filename);
|
||||
if(json_file_stream.is_open())
|
||||
{
|
||||
output_stream = &json_file_stream;
|
||||
}
|
||||
else
|
||||
{
|
||||
std::cerr << "Warning: Failed to open JSON file " << json_filename
|
||||
<< ", using stdout instead." << std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
// Get the profiler instance with output stream
|
||||
auto& profiler = GemmProfiler::instance(setting, output_stream);
|
||||
|
||||
try
|
||||
{
|
||||
|
||||
@@ -14,9 +14,9 @@
|
||||
class GemmMultiDProfiler
|
||||
{
|
||||
public:
|
||||
static GemmMultiDProfiler& instance(Setting setting)
|
||||
static GemmProfiler& instance(Setting setting, std::ostream* output_stream = &std::cout)
|
||||
{
|
||||
static GemmMultiDProfiler instance{setting};
|
||||
static GemmProfiler instance{setting, output_stream};
|
||||
return instance;
|
||||
}
|
||||
|
||||
@@ -199,7 +199,7 @@ class GemmMultiDProfiler
|
||||
|
||||
if(setting_.log_ > 0 && !setting_.json_output_)
|
||||
{
|
||||
std::cout << kernel_instance << std::endl;
|
||||
*output_stream_ << kernel_instance << std::endl;
|
||||
}
|
||||
|
||||
// verify result
|
||||
@@ -217,7 +217,7 @@ class GemmMultiDProfiler
|
||||
}
|
||||
else
|
||||
{
|
||||
std::cout << "Verification failed, skip kernel: " << name << std::endl;
|
||||
*output_stream_ << "Verification failed, skip kernel: " << name << std::endl;
|
||||
}
|
||||
|
||||
// clear tensor
|
||||
@@ -240,14 +240,14 @@ class GemmMultiDProfiler
|
||||
if(setting_.json_output_)
|
||||
{
|
||||
// Output clean JSON only
|
||||
std::cout << kernel_instance << std::endl;
|
||||
*output_stream_ << kernel_instance << std::endl;
|
||||
}
|
||||
else
|
||||
{
|
||||
std::cout << "**********************************" << std::endl;
|
||||
std::cout << "According to given metrics: " << get_metric_name(metric) << "\n"
|
||||
<< "Current kernel performance is: " << kernel_instance << std::endl;
|
||||
std::cout << "**********************************" << std::endl;
|
||||
*output_stream_ << "**********************************" << std::endl;
|
||||
*output_stream_ << "According to given metrics: " << get_metric_name(metric) << "\n"
|
||||
<< "Current kernel performance is: " << kernel_instance << std::endl;
|
||||
*output_stream_ << "**********************************" << std::endl;
|
||||
}
|
||||
|
||||
if(!setting_.csv_filename_.empty())
|
||||
@@ -299,9 +299,13 @@ class GemmMultiDProfiler
|
||||
|
||||
private:
|
||||
~GemmMultiDProfiler() { kernel_instances_.clear(); }
|
||||
GemmMultiDProfiler(Setting setting) : setting_(setting) {}
|
||||
GemmMultiDProfiler(Setting setting, std::ostream* output_stream = &std::cout)
|
||||
: setting_(setting), output_stream_(output_stream)
|
||||
{
|
||||
}
|
||||
|
||||
Setting setting_;
|
||||
std::ostream* output_stream_;
|
||||
|
||||
std::vector<KernelInstance> kernel_instances_;
|
||||
};
|
||||
|
||||
@@ -214,28 +214,23 @@ class GemmPreshuffleBenchmark:
|
||||
|
||||
# Add JSON output flag for clean JSON output
|
||||
cmd.append("-json_output=true")
|
||||
cmd.append(f"-json_file={json_file}")
|
||||
|
||||
if self.verbose:
|
||||
print(f"Running: {' '.join(cmd)}")
|
||||
|
||||
try:
|
||||
result = subprocess.run(cmd, capture_output=True, text=True, timeout=60)
|
||||
result = subprocess.run(cmd, timeout=60)
|
||||
|
||||
if result.returncode != 0:
|
||||
print(f"Error running {kernel_path.name}: {result.stderr}")
|
||||
print(f"Error running {kernel_path.name}")
|
||||
return None
|
||||
|
||||
# Save raw output to individual JSON file
|
||||
output = result.stdout.strip()
|
||||
|
||||
if output:
|
||||
with open(json_file, "w") as f:
|
||||
f.write(output)
|
||||
|
||||
# Parse the JSON file
|
||||
# Parse the JSON file that was written by the C++ program
|
||||
if json_file.exists():
|
||||
return self.parse_json_file(json_file)
|
||||
else:
|
||||
print(f"No output from {kernel_path.name}")
|
||||
print(f"No output file created by {kernel_path.name}")
|
||||
return None
|
||||
|
||||
except subprocess.TimeoutExpired:
|
||||
|
||||
@@ -69,7 +69,10 @@ inline auto create_args(int argc, char* argv[])
|
||||
"false",
|
||||
"Whether to output results in JSON format only. Possible values are true or false. "
|
||||
"Default is "
|
||||
"false");
|
||||
"false")
|
||||
.insert("json_file",
|
||||
"",
|
||||
"The filename for JSON output. If empty, output goes to stdout. Default is empty.");
|
||||
|
||||
bool result = arg_parser.parse(argc, argv);
|
||||
return std::make_tuple(result, arg_parser);
|
||||
@@ -118,8 +121,27 @@ void benchmark_single(const ck_tile::ArgParser& arg_parser)
|
||||
arg_parser.get_int("rotating_count"),
|
||||
arg_parser.get_bool("json_output")};
|
||||
|
||||
// Get the profiler instance
|
||||
auto& profiler = GemmProfiler::instance(setting);
|
||||
// Handle output stream - either file or stdout
|
||||
std::string json_filename = arg_parser.get_str("json_file");
|
||||
std::ofstream json_file_stream;
|
||||
std::ostream* output_stream = &std::cout;
|
||||
|
||||
if(!json_filename.empty())
|
||||
{
|
||||
json_file_stream.open(json_filename);
|
||||
if(json_file_stream.is_open())
|
||||
{
|
||||
output_stream = &json_file_stream;
|
||||
}
|
||||
else
|
||||
{
|
||||
std::cerr << "Warning: Failed to open JSON file " << json_filename
|
||||
<< ", using stdout instead." << std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
// Get the profiler instance with output stream
|
||||
auto& profiler = GemmProfiler::instance(setting, output_stream);
|
||||
|
||||
try
|
||||
{
|
||||
|
||||
@@ -10,9 +10,9 @@
|
||||
class GemmProfiler
|
||||
{
|
||||
public:
|
||||
static GemmProfiler& instance(Setting setting)
|
||||
static GemmProfiler& instance(Setting setting, std::ostream* output_stream = &std::cout)
|
||||
{
|
||||
static GemmProfiler instance{setting};
|
||||
static GemmProfiler instance{setting, output_stream};
|
||||
return instance;
|
||||
}
|
||||
|
||||
@@ -182,7 +182,7 @@ class GemmProfiler
|
||||
|
||||
if(setting_.log_ > 0 && !setting_.json_output_)
|
||||
{
|
||||
std::cout << kernel_instance << std::endl;
|
||||
*output_stream_ << kernel_instance << std::endl;
|
||||
}
|
||||
|
||||
// verify result
|
||||
@@ -198,7 +198,7 @@ class GemmProfiler
|
||||
}
|
||||
else
|
||||
{
|
||||
std::cout << "Verification failed, skip kernel: " << name << std::endl;
|
||||
*output_stream_ << "Verification failed, skip kernel: " << name << std::endl;
|
||||
}
|
||||
|
||||
// clear tensor
|
||||
@@ -221,14 +221,14 @@ class GemmProfiler
|
||||
if(setting_.json_output_)
|
||||
{
|
||||
// Output clean JSON only
|
||||
std::cout << kernel_instance << std::endl;
|
||||
*output_stream_ << kernel_instance << std::endl;
|
||||
}
|
||||
else
|
||||
{
|
||||
std::cout << "**********************************" << std::endl;
|
||||
std::cout << "According to given metrics: " << get_metric_name(metric) << "\n"
|
||||
<< "Current kernel performance is: " << kernel_instance << std::endl;
|
||||
std::cout << "**********************************" << std::endl;
|
||||
*output_stream_ << "**********************************" << std::endl;
|
||||
*output_stream_ << "According to given metrics: " << get_metric_name(metric) << "\n"
|
||||
<< "Current kernel performance is: " << kernel_instance << std::endl;
|
||||
*output_stream_ << "**********************************" << std::endl;
|
||||
}
|
||||
|
||||
if(!setting_.csv_filename_.empty())
|
||||
@@ -281,9 +281,13 @@ class GemmProfiler
|
||||
|
||||
private:
|
||||
~GemmProfiler() { kernel_instances_.clear(); }
|
||||
GemmProfiler(Setting setting) : setting_(setting) {}
|
||||
GemmProfiler(Setting setting, std::ostream* output_stream = &std::cout)
|
||||
: setting_(setting), output_stream_(output_stream)
|
||||
{
|
||||
}
|
||||
|
||||
Setting setting_;
|
||||
std::ostream* output_stream_;
|
||||
|
||||
std::vector<KernelInstance> kernel_instances_;
|
||||
};
|
||||
|
||||
Reference in New Issue
Block a user