Fix: inverted sense bug with cuda sd float32

The inverted use of float16 with SD on CUDA was introduced with my MPS
hack.  Fixed.
Also, make the MPS changes more consistent with cuda_device, since we
don't need a hardcoded value any more.
This commit is contained in:
majick
2023-06-28 14:15:46 -07:00
parent aee5e71e45
commit dea4254e01

View File

@@ -147,7 +147,7 @@ if len(modules) == 0:
cuda_device = DEFAULT_CUDA_DEVICE if not args.cuda_device else args.cuda_device
device_string = cuda_device if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu'
device = torch.device(device_string)
-torch_dtype = torch.float32 if device_string != "cuda:0" else torch.float16
+torch_dtype = torch.float32 if device_string != cuda_device else torch.float16
if not torch.cuda.is_available() and not args.cpu:
print(f"{Fore.YELLOW}{Style.BRIGHT}torch-cuda is not supported on this device.{Style.RESET_ALL}")
@@ -193,7 +193,7 @@ if "sd" in modules and not sd_use_remote:
print("Initializing Stable Diffusion pipeline...")
sd_device_string = cuda_device if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu'
sd_device = torch.device(sd_device_string)
-sd_torch_dtype = torch.float32 if sd_device_string != "cpu" else torch.float16
+sd_torch_dtype = torch.float32 if sd_device_string != cuda_device else torch.float16
sd_pipe = StableDiffusionPipeline.from_pretrained(
sd_model, custom_pipeline="lpw_stable_diffusion", torch_dtype=sd_torch_dtype
).to(sd_device)