Merge pull request #70 from YellowRoseCx/patch-1

This commit is contained in:
Cohee
2023-06-30 10:00:21 +03:00
committed by GitHub

View File

@@ -145,7 +145,7 @@ if len(modules) == 0:
# Models init
cuda_device = DEFAULT_CUDA_DEVICE if not args.cuda_device else args.cuda_device
device_string = cuda_device if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu'
device_string = cuda_device if torch.cuda.is_available() and not args.cpu else 'mps' if torch.backends.mps.is_available() and not args.cpu else 'cpu'
device = torch.device(device_string)
torch_dtype = torch.float32 if device_string != cuda_device else torch.float16