mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-07-01 04:07:56 +00:00
Improve benchmarking script.
This commit is contained in:
@@ -9,7 +9,14 @@ import matplotlib.pyplot as plt
|
||||
plt.switch_backend('Agg')
|
||||
import numpy as np
|
||||
|
||||
import xlsxwriter
|
||||
def data_type_str_to_profiler_arg(data_type_str):
|
||||
"""Convert data type string to profiler argument"""
|
||||
data_type_map = {
|
||||
"fp16": "1",
|
||||
"fp32": "0",
|
||||
"int8": "3"
|
||||
}
|
||||
return data_type_map.get(data_type_str.lower(), 1) # Default to fp16 if unknown
|
||||
|
||||
def parse_cli_args():
|
||||
"""Parse command line arguments"""
|
||||
@@ -19,6 +26,7 @@ def parse_cli_args():
|
||||
parser.add_argument("--bin-path", type=str, dest="bin_path", required=False, help="Path to the CK/CK Tile profiler executables.")
|
||||
parser.add_argument("--results-path", type=str, dest="results_path", required=False, help="Path to store profiler results.", default=".")
|
||||
parser.add_argument("--analyze-file", type=str, dest="analyze_file", required=False, help="Path to store analysis results.", default="")
|
||||
parser.add_argument("--data-type", type=str, dest="data_type", required=False, help="Data type for the profiler (e.g., fp16, fp32).", default="fp16")
|
||||
|
||||
args, unknown_args = parser.parse_known_args()
|
||||
|
||||
@@ -61,7 +69,7 @@ def get_profiler_commands(file):
|
||||
lines = list(dict.fromkeys(lines))
|
||||
for line in lines:
|
||||
line = line.strip()
|
||||
cmd = [x.strip() for x in line.split(' ') if x.strip() and x.strip() != '']
|
||||
cmd = [x.strip() for x in line.split(',') if x.strip() and x.strip() != '']
|
||||
profiler_commands.append(cmd)
|
||||
return profiler_commands
|
||||
|
||||
@@ -408,18 +416,26 @@ def main():
|
||||
if not os.path.exists(args.results_path):
|
||||
os.makedirs(args.results_path)
|
||||
|
||||
results_file = os.path.join(args.results_path, f"ck_vs_ck_tile_results_{os.getpid()}.txt")
|
||||
results_file = os.path.join(args.results_path, f"ck_results_{args.data_type}_{os.getpid()}.txt")
|
||||
|
||||
data_type_arg = data_type_str_to_profiler_arg(args.data_type)
|
||||
|
||||
for i, cmd in enumerate(profiler_commands):
|
||||
cmd_concatenated_str = ' '.join(cmd)
|
||||
print(f"\n####################################################################################################################")
|
||||
print(f"Running command {i + 1}/{len(profiler_commands)}: {cmd_concatenated_str}")
|
||||
print(f"######################################################################################################################")
|
||||
with open(results_file, 'a') as f:
|
||||
f.write(cmd_concatenated_str + "\n")
|
||||
run_ck_profiler_cmd(cmd, ProfilerType.CK_TILE, args.bin_path, results_file, args.log_to_stdout)
|
||||
# with open(results_file, 'a') as f:
|
||||
# f.write(cmd_concatenated_str + "\n")
|
||||
# run_ck_profiler_cmd(cmd, ProfilerType.CK_TILE, args.bin_path, results_file, args.log_to_stdout)
|
||||
|
||||
# For the old CK, we don't want to run verification. We assume CK already works correctly.
|
||||
# Set the correct data type based on user input
|
||||
cmd[1] = data_type_arg
|
||||
|
||||
# Set layout arg to Input[N, Hi, Wi, G, C], Weight[G, K, Y, X, C], Output[N, Ho, Wo, G, K]
|
||||
cmd[2] = '1'
|
||||
|
||||
# We don't want to run verification. We assume CK already works correctly.
|
||||
cmd[3] = '0' # Set verification flag to 0 (no verification)
|
||||
|
||||
run_ck_profiler_cmd(cmd, ProfilerType.CK, args.bin_path, results_file, args.log_to_stdout)
|
||||
|
||||
Reference in New Issue
Block a user