diff --git a/modules/devices.py b/modules/devices.py index 0bda9325..f8daafc0 100644 --- a/modules/devices.py +++ b/modules/devices.py @@ -54,7 +54,7 @@ device_interrogate: torch.device = memory_management.text_encoder_device() # fo device_gfpgan: torch.device = memory_management.get_torch_device() # will be managed by memory management system device_esrgan: torch.device = memory_management.get_torch_device() # will be managed by memory management system device_codeformer: torch.device = memory_management.get_torch_device() # will be managed by memory management system -dtype: torch.dtype = memory_management.unet_dtype() +dtype: torch.dtype = torch.float32 if memory_management.unet_dtype() is torch.float32 else torch.float16 dtype_vae: torch.dtype = memory_management.vae_dtype() dtype_unet: torch.dtype = memory_management.unet_dtype() dtype_inference: torch.dtype = memory_management.unet_dtype()