diff --git a/script/generate_test_data.py b/script/generate_test_data.py index d018313383..1f8a094fe1 100644 --- a/script/generate_test_data.py +++ b/script/generate_test_data.py @@ -167,18 +167,21 @@ def main(): 'Command': ktn_commands, }) - # Take a randomly sampled subset of Fremont commands if not args.full_set: - commands_fremont_df = commands_fremont_df.sample(n=min(n_fremont_shapes, len(commands_fremont_df)), random_state=seed) + # The hardest cases are at the beginning of the Fremont CSV file. + commands_fremont_df = commands_fremont_df.sample(n=min(n_fremont_shapes, len(commands_fremont_df))) commands_ktn_df = commands_ktn_df.sample(n=min(n_ktn_shapes, len(commands_ktn_df)), random_state=seed) + # Combine the two DataFrames + commands_df = pd.concat([commands_fremont_df, commands_ktn_df], ignore_index=True) + # Randomly permute the commands + commands_df = commands_df.sample(frac=1, random_state=seed).reset_index(drop=True) + output_file = os.path.join(args.output_path, "ck_profiler_commands.csv") with open(output_file, "w") as f: csv_writer = csv.writer(f, delimiter=',', quotechar='"', quoting=csv.QUOTE_MINIMAL) - for command in commands_fremont_df['Command']: + for command in commands_df['Command']: csv_writer.writerow(command) - for command in commands_ktn_df['Command']: - csv_writer.writerow(command) print(f"Commands saved to {output_file}") if __name__ == "__main__": diff --git a/script/run_conv_profiler.py b/script/run_conv_profiler.py index 6e78add8a9..4aa6e0a323 100644 --- a/script/run_conv_profiler.py +++ b/script/run_conv_profiler.py @@ -4,12 +4,15 @@ import os import argparse import subprocess import sys +import pandas as pd def parse_cli_args(): """Parse command line arguments""" parser = argparse.ArgumentParser(description="Run CK convolution profiler.") parser.add_argument("--csv-file", type=str, dest="csv_file", required=True, help="Path to the CSV file containing test cases.") parser.add_argument("--log-to-stdout", action="store_true", help="Log profiler output to stdout instead of /dev/null.") + parser.add_argument("--start", type=int, default=None, help="Start index for the commands to run (1-based).") + parser.add_argument("--end", type=int, default=None, help="End index for the commands to run (1-based).") args, unknown_args = parser.parse_known_args() @@ -20,24 +23,31 @@ def parse_cli_args(): return args def run_ck_profiler_cmd(cmd, log_to_stdout=False): - cmd_concatenated_str = "" - for arg in cmd: - cmd_concatenated_str += arg + " " working_dir = os.path.dirname(os.path.abspath(__file__)) pid = os.getpid() env_vars = os.environ.copy() - env_vars["CK_PROFILER_DISABLED_OPS"] = "" + env_vars["CK_PROFILER_DISABLED_OPS"] = "DeviceGroupedConvBwdWeight_Xdl_CShuffleV3;DeviceGroupedConvBwdWeight_Dl;DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle;DeviceGroupedConvBwdWeight_Explicit_Xdl" env_vars["CK_PROFILER_OUTPUT_FILE"] = f"{working_dir}/conv_profiler_output_{pid}.csv" if log_to_stdout: - subprocess.run(cmd, env=env_vars, stdout=devnull) + subprocess.run(cmd, env=env_vars) else: with open(os.devnull, 'w') as devnull: - subprocess.run(cmd, env=env_vars, stdout=devnull, stderr=devnull) + subprocess.run(cmd, env=env_vars, stdout=devnull) + +def get_profiler_commands(csv_file): + profiler_commands = [] + with open(csv_file, 'r') as f: + for line in f: + line = line.strip() + cmd = line.split(',') + profiler_commands.append(cmd) + return profiler_commands def main(): args = parse_cli_args() - profiler_commands = get_profiler_commands(args.csv_file, args.no_verification, args.fwd_only, args.bwd_data_only, args.bwd_weight_only) + profiler_commands = get_profiler_commands(args.csv_file) print(f"Got {len(profiler_commands)} commands in total to run.") + if args.start is not None: end = len(profiler_commands) if args.end is not None: @@ -45,7 +55,13 @@ def main(): profiler_commands = profiler_commands[args.start-1:end] for i, cmd in enumerate(profiler_commands): - print(f"Running command {i + 1}/{len(profiler_commands)}: {cmd}") + cmd_concatenated_str = "" + for arg in cmd: + cmd_concatenated_str += arg + " " + cmd_concatenated_str = cmd_concatenated_str.strip() + print(f"\n##################################################################################################################################") + print(f"Running command {i + 1}/{len(profiler_commands)}: {cmd_concatenated_str}") + print(f"##################################################################################################################################") run_ck_profiler_cmd(cmd, args.log_to_stdout) if __name__ == "__main__":