diff --git a/backend/operations.py b/backend/operations.py
index 03a204d4..4c6adc29 100644
--- a/backend/operations.py
+++ b/backend/operations.py
@@ -1,7 +1,7 @@
 import torch
 import contextlib
 
-from modules_forge import stream
+from backend import stream
 
 
 stash = {}
diff --git a/backend/stream.py b/backend/stream.py
new file mode 100644
index 00000000..a231caf5
--- /dev/null
+++ b/backend/stream.py
@@ -0,0 +1,67 @@
+import torch
+import argparse
+
+
+def stream_context():
+    if torch.cuda.is_available():
+        return torch.cuda.stream
+
+    if torch.xpu.is_available():
+        return torch.xpu.stream
+
+    return None
+
+
+def get_current_stream():
+    try:
+        if torch.cuda.is_available():
+            device = torch.device(torch.cuda.current_device())
+            stream = torch.cuda.current_stream(device)
+            with torch.cuda.stream(stream):
+                torch.zeros((1, 1)).to(device, torch.float32)
+            stream.synchronize()
+            return stream
+        if torch.xpu.is_available():
+            device = torch.device("xpu")
+            stream = torch.xpu.current_stream(device)
+            with torch.xpu.stream(stream):
+                torch.zeros((1, 1)).to(device, torch.float32)
+            stream.synchronize()
+            return stream
+    except:
+        return None
+
+
+def get_new_stream():
+    try:
+        if torch.cuda.is_available():
+            device = torch.device(torch.cuda.current_device())
+            stream = torch.cuda.Stream(device)
+            with torch.cuda.stream(stream):
+                torch.zeros((1, 1)).to(device, torch.float32)
+            stream.synchronize()
+            return stream
+        if torch.xpu.is_available():
+            device = torch.device("xpu")
+            stream = torch.xpu.Stream(device)
+            with torch.xpu.stream(stream):
+                torch.zeros((1, 1)).to(device, torch.float32)
+            stream.synchronize()
+            return stream
+    except:
+        return None
+
+
+current_stream = None
+mover_stream = None
+using_stream = False
+
+
+parser = argparse.ArgumentParser()
+parser.add_argument("--cuda-stream", action="store_true")
+args = parser.parse_known_args()[0]
+
+if args.cuda_stream:
+    current_stream = get_current_stream()
+    mover_stream = get_new_stream()
+    using_stream = current_stream is not None and mover_stream is not None
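
For reference, a minimal usage sketch of the new module follows. It is not part of the patch: `prefetch_weight`, its signature, and the pinned-memory assumption are illustrative, showing how `mover_stream` is presumably intended to be used, namely to overlap host-to-device weight copies with compute running on `current_stream`.

```python
import torch

from backend import stream


def prefetch_weight(weight_cpu: torch.Tensor, device: torch.device) -> torch.Tensor:
    # Hypothetical helper, not in the diff: copy a weight tensor to the GPU
    # on the mover stream so the transfer overlaps with ongoing compute.
    if not stream.using_stream:
        # Streams disabled (no --cuda-stream flag, or no CUDA/XPU device):
        # fall back to an ordinary blocking copy.
        return weight_cpu.to(device)

    # stream_context() returns torch.cuda.stream (or torch.xpu.stream),
    # which acts as a context manager when given a stream object.
    with stream.stream_context()(stream.mover_stream):
        # non_blocking=True makes the H2D copy asynchronous; the overlap
        # only materializes if weight_cpu is in pinned (page-locked) memory.
        weight_gpu = weight_cpu.to(device, non_blocking=True)

    # Make the compute stream wait for the copy before the weight is read,
    # and tell the caching allocator which stream the tensor is used on.
    stream.current_stream.wait_stream(stream.mover_stream)
    weight_gpu.record_stream(stream.current_stream)
    return weight_gpu
```

Under this design the mover stream issues copies while the compute stream keeps running; `wait_stream` inserts the cross-stream dependency so a weight is never consumed before its transfer has finished.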