diff --git a/backend/args.py b/backend/args.py new file mode 100644 index 00000000..3ef5f5af --- /dev/null +++ b/backend/args.py @@ -0,0 +1,7 @@ +import argparse + +parser = argparse.ArgumentParser() + +parser.add_argument("--cuda-stream", action="store_true") + +args = parser.parse_known_args()[0] diff --git a/backend/stream.py b/backend/stream.py index a231caf5..3972d0e4 100644 --- a/backend/stream.py +++ b/backend/stream.py @@ -1,5 +1,5 @@ import torch -import argparse +from backend import args def stream_context(): @@ -56,11 +56,6 @@ current_stream = None mover_stream = None using_stream = False - -parser = argparse.ArgumentParser() -parser.add_argument("--cuda-stream", action="store_true") -args = parser.parse_known_args()[0] - if args.cuda_stream: current_stream = get_current_stream() mover_stream = get_new_stream()