Small fixes to runner script.

This commit is contained in:
Ville Pietilä
2026-02-04 09:55:05 -05:00
parent 73b459c5a4
commit 1c1ac4ef10

View File

@@ -15,8 +15,8 @@ profiler_commands = [
baseline_instances = [
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<256, 64, 64, 32, Default, 16, 16, 2, 2, 4, 4, 4, 1, 1, 1>",
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_DirectLoad<128, 16, 64, 64, Default, 16, 16, 1, 2, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v1>",
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_DirectLoad<256, 256, 32, 64, Default, 32, 32, 2, 1, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v1>",
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_DirectLoad<128, 16, 64, 64, Default, 16, 16, 1, 2, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v1, 1>",
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_DirectLoad<256, 256, 32, 64, Default, 32, 32, 2, 1, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v1, 1>",
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<256, 256, 128, 32, Default, 32, 32, 4, 2, 8, 8, 8, 1, 1, 1>",
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<256, 256, 128, 32, OddC, 32, 32, 4, 2, 8, 8, 8, 1, 1, 1>"
]
@@ -35,6 +35,7 @@ def main():
parser.add_argument('--profiler-path', type=str, required=True, help='Path to the profiler binary')
parser.add_argument('--baseline', action='store_true',
help='Run baseline instances (default: run improved instances)')
parser.add_argument("--print-stdout", action='store_true', help='Print CK profiler output to stdout')
args = parser.parse_args()
instances_to_run = baseline_instances if args.baseline else improved_instances
@@ -55,7 +56,10 @@ def main():
print(f"Running profiler for {instance_type} instance {i+1}/{len(profiler_commands)}:")
print(instance)
subprocess.run([ck_profiler_path] + ["grouped_conv_fwd"] + profiler_args, check=True)
res = subprocess.run([ck_profiler_path] + ["grouped_conv_fwd"] + profiler_args, check=True, timeout=300,
capture_output=True, text=True)
if args.print_stdout:
print(res.stdout)
print()
if __name__ == "__main__":