mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-07-01 04:07:56 +00:00
Improve test scripts.
This commit is contained in:
@@ -167,18 +167,21 @@ def main():
|
||||
'Command': ktn_commands,
|
||||
})
|
||||
|
||||
# Take a randomly sampled subset of Fremont commands
|
||||
if not args.full_set:
|
||||
commands_fremont_df = commands_fremont_df.sample(n=min(n_fremont_shapes, len(commands_fremont_df)), random_state=seed)
|
||||
# The hardest cases are at the beginning of the Fremont CSV file.
|
||||
commands_fremont_df = commands_fremont_df.sample(n=min(n_fremont_shapes, len(commands_fremont_df)))
|
||||
commands_ktn_df = commands_ktn_df.sample(n=min(n_ktn_shapes, len(commands_ktn_df)), random_state=seed)
|
||||
|
||||
# Combine the two DataFrames
|
||||
commands_df = pd.concat([commands_fremont_df, commands_ktn_df], ignore_index=True)
|
||||
# Randomly permute the commands
|
||||
commands_df = commands_df.sample(frac=1, random_state=seed).reset_index(drop=True)
|
||||
|
||||
output_file = os.path.join(args.output_path, "ck_profiler_commands.csv")
|
||||
with open(output_file, "w") as f:
|
||||
csv_writer = csv.writer(f, delimiter=',', quotechar='"', quoting=csv.QUOTE_MINIMAL)
|
||||
for command in commands_fremont_df['Command']:
|
||||
for command in commands_df['Command']:
|
||||
csv_writer.writerow(command)
|
||||
for command in commands_ktn_df['Command']:
|
||||
csv_writer.writerow(command)
|
||||
print(f"Commands saved to {output_file}")
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -4,12 +4,15 @@ import os
|
||||
import argparse
|
||||
import subprocess
|
||||
import sys
|
||||
import pandas as pd
|
||||
|
||||
def parse_cli_args():
|
||||
"""Parse command line arguments"""
|
||||
parser = argparse.ArgumentParser(description="Run CK convolution profiler.")
|
||||
parser.add_argument("--csv-file", type=str, dest="csv_file", required=True, help="Path to the CSV file containing test cases.")
|
||||
parser.add_argument("--log-to-stdout", action="store_true", help="Log profiler output to stdout instead of /dev/null.")
|
||||
parser.add_argument("--start", type=int, default=None, help="Start index for the commands to run (1-based).")
|
||||
parser.add_argument("--end", type=int, default=None, help="End index for the commands to run (1-based).")
|
||||
|
||||
args, unknown_args = parser.parse_known_args()
|
||||
|
||||
@@ -20,24 +23,31 @@ def parse_cli_args():
|
||||
return args
|
||||
|
||||
def run_ck_profiler_cmd(cmd, log_to_stdout=False):
|
||||
cmd_concatenated_str = ""
|
||||
for arg in cmd:
|
||||
cmd_concatenated_str += arg + " "
|
||||
working_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
pid = os.getpid()
|
||||
env_vars = os.environ.copy()
|
||||
env_vars["CK_PROFILER_DISABLED_OPS"] = ""
|
||||
env_vars["CK_PROFILER_DISABLED_OPS"] = "DeviceGroupedConvBwdWeight_Xdl_CShuffleV3;DeviceGroupedConvBwdWeight_Dl;DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle;DeviceGroupedConvBwdWeight_Explicit_Xdl"
|
||||
env_vars["CK_PROFILER_OUTPUT_FILE"] = f"{working_dir}/conv_profiler_output_{pid}.csv"
|
||||
if log_to_stdout:
|
||||
subprocess.run(cmd, env=env_vars, stdout=devnull)
|
||||
subprocess.run(cmd, env=env_vars)
|
||||
else:
|
||||
with open(os.devnull, 'w') as devnull:
|
||||
subprocess.run(cmd, env=env_vars, stdout=devnull, stderr=devnull)
|
||||
subprocess.run(cmd, env=env_vars, stdout=devnull)
|
||||
|
||||
def get_profiler_commands(csv_file):
|
||||
profiler_commands = []
|
||||
with open(csv_file, 'r') as f:
|
||||
for line in f:
|
||||
line = line.strip()
|
||||
cmd = line.split(',')
|
||||
profiler_commands.append(cmd)
|
||||
return profiler_commands
|
||||
|
||||
def main():
|
||||
args = parse_cli_args()
|
||||
profiler_commands = get_profiler_commands(args.csv_file, args.no_verification, args.fwd_only, args.bwd_data_only, args.bwd_weight_only)
|
||||
profiler_commands = get_profiler_commands(args.csv_file)
|
||||
print(f"Got {len(profiler_commands)} commands in total to run.")
|
||||
|
||||
if args.start is not None:
|
||||
end = len(profiler_commands)
|
||||
if args.end is not None:
|
||||
@@ -45,7 +55,13 @@ def main():
|
||||
profiler_commands = profiler_commands[args.start-1:end]
|
||||
|
||||
for i, cmd in enumerate(profiler_commands):
|
||||
print(f"Running command {i + 1}/{len(profiler_commands)}: {cmd}")
|
||||
cmd_concatenated_str = ""
|
||||
for arg in cmd:
|
||||
cmd_concatenated_str += arg + " "
|
||||
cmd_concatenated_str = cmd_concatenated_str.strip()
|
||||
print(f"\n##################################################################################################################################")
|
||||
print(f"Running command {i + 1}/{len(profiler_commands)}: {cmd_concatenated_str}")
|
||||
print(f"##################################################################################################################################")
|
||||
run_ck_profiler_cmd(cmd, args.log_to_stdout)
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
Reference in New Issue
Block a user