diff --git a/python/test/executor_test.py b/python/test/executor_test.py index 83b2cb86..74dbca11 100644 --- a/python/test/executor_test.py +++ b/python/test/executor_test.py @@ -70,6 +70,7 @@ def bench_correctness( ): type_size = cp.dtype(parse_dtype(dtype_str)).itemsize + print("collective: ", collective) fill_data_kernel_name = "fill_data_%s" % dtype_str if "allgather" in collective: coll = "all_gather" @@ -78,7 +79,7 @@ def bench_correctness( elif "allreduce" in collective: coll = "all_reduce" else: - coll = "all_to_all" + coll = "sendrecv" test_data_kernel_name = "test_data_%s_%s" % (coll, dtype_str) file_dir = os.path.dirname(os.path.abspath(__file__))