Update devices.py

lllyasviel
2024-01-24 10:51:36 -08:00
parent d2ea8793aa
commit 2ed12510ba

devices.py

@@ -1,188 +1,81 @@
-import sys
 import contextlib
-from functools import lru_cache

 import torch
-from modules import errors, shared
-from modules import torch_utils
-
-if sys.platform == "darwin":
-    from modules import mac_specific
-
-if shared.cmd_opts.use_ipex:
-    from modules import xpu_specific
+import ldm_patched.modules.model_management as model_management


 def has_xpu() -> bool:
-    return shared.cmd_opts.use_ipex and xpu_specific.has_xpu
+    return model_management.xpu_available


 def has_mps() -> bool:
-    if sys.platform != "darwin":
-        return False
-    else:
-        return mac_specific.has_mps
+    return model_management.mps_mode()


 def cuda_no_autocast(device_id=None) -> bool:
-    if device_id is None:
-        device_id = get_cuda_device_id()
-    return (
-        torch.cuda.get_device_capability(device_id) == (7, 5)
-        and torch.cuda.get_device_name(device_id).startswith("NVIDIA GeForce GTX 16")
-    )
+    return False


 def get_cuda_device_id():
-    return (
-        int(shared.cmd_opts.device_id)
-        if shared.cmd_opts.device_id is not None and shared.cmd_opts.device_id.isdigit()
-        else 0
-    ) or torch.cuda.current_device()
+    return model_management.get_torch_device().index


 def get_cuda_device_string():
-    if shared.cmd_opts.device_id is not None:
-        return f"cuda:{shared.cmd_opts.device_id}"
-    return "cuda"
+    return str(model_management.get_torch_device())


 def get_optimal_device_name():
-    if torch.cuda.is_available():
-        return get_cuda_device_string()
-    if has_mps():
-        return "mps"
-    if has_xpu():
-        return xpu_specific.get_xpu_device_string()
-    return "cpu"
+    return model_management.get_torch_device().type


 def get_optimal_device():
-    return torch.device(get_optimal_device_name())
+    return model_management.get_torch_device()


 def get_device_for(task):
-    if task in shared.cmd_opts.use_cpu or "all" in shared.cmd_opts.use_cpu:
-        return cpu
-
     return get_optimal_device()


 def torch_gc():
-    if torch.cuda.is_available():
-        with torch.cuda.device(get_cuda_device_string()):
-            torch.cuda.empty_cache()
-            torch.cuda.ipc_collect()
-
-    if has_mps():
-        mac_specific.torch_mps_gc()
-
-    if has_xpu():
-        xpu_specific.torch_xpu_gc()
+    model_management.soft_empty_cache()


 def enable_tf32():
-    if torch.cuda.is_available():
-
-        # enabling benchmark option seems to enable a range of cards to do fp16 when they otherwise can't
-        # see https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/4407
-        if cuda_no_autocast():
-            torch.backends.cudnn.benchmark = True
-
-        torch.backends.cuda.matmul.allow_tf32 = True
-        torch.backends.cudnn.allow_tf32 = True
-
-
-errors.run(enable_tf32, "Enabling TF32")
+    return


 cpu: torch.device = torch.device("cpu")
 fp8: bool = False
-device: torch.device = None
-device_interrogate: torch.device = None
-device_gfpgan: torch.device = None
-device_esrgan: torch.device = None
-device_codeformer: torch.device = None
-dtype: torch.dtype = torch.float16
-dtype_vae: torch.dtype = torch.float16
-dtype_unet: torch.dtype = torch.float16
-dtype_inference: torch.dtype = torch.float16
+device: torch.device = model_management.get_torch_device()
+device_interrogate: torch.device = model_management.text_encoder_device()
+device_gfpgan: torch.device = model_management.get_torch_device()
+device_esrgan: torch.device = model_management.get_torch_device()
+device_codeformer: torch.device = model_management.get_torch_device()
+dtype: torch.dtype = model_management.unet_dtype()
+dtype_vae: torch.dtype = model_management.vae_dtype()
+dtype_unet: torch.dtype = model_management.unet_dtype()
+dtype_inference: torch.dtype = model_management.unet_dtype()
 unet_needs_upcast = False


 def cond_cast_unet(input):
-    return input.to(dtype_unet) if unet_needs_upcast else input
+    return input


 def cond_cast_float(input):
-    return input.float() if unet_needs_upcast else input
+    return input


 nv_rng = None
-patch_module_list = [
-    torch.nn.Linear,
-    torch.nn.Conv2d,
-    torch.nn.MultiheadAttention,
-    torch.nn.GroupNorm,
-    torch.nn.LayerNorm,
-]
+patch_module_list = []


 def manual_cast_forward(target_dtype):
-    def forward_wrapper(self, *args, **kwargs):
-        if any(
-            isinstance(arg, torch.Tensor) and arg.dtype != target_dtype
-            for arg in args
-        ):
-            args = [arg.to(target_dtype) if isinstance(arg, torch.Tensor) else arg for arg in args]
-            kwargs = {k: v.to(target_dtype) if isinstance(v, torch.Tensor) else v for k, v in kwargs.items()}
-
-        org_dtype = torch_utils.get_param(self).dtype
-        if org_dtype != target_dtype:
-            self.to(target_dtype)
-        result = self.org_forward(*args, **kwargs)
-        if org_dtype != target_dtype:
-            self.to(org_dtype)
-
-        if target_dtype != dtype_inference:
-            if isinstance(result, tuple):
-                result = tuple(
-                    i.to(dtype_inference)
-                    if isinstance(i, torch.Tensor)
-                    else i
-                    for i in result
-                )
-            elif isinstance(result, torch.Tensor):
-                result = result.to(dtype_inference)
-        return result
-    return forward_wrapper
+    return


 @contextlib.contextmanager
 def manual_cast(target_dtype):
-    applied = False
-    for module_type in patch_module_list:
-        if hasattr(module_type, "org_forward"):
-            continue
-        applied = True
-        org_forward = module_type.forward
-        if module_type == torch.nn.MultiheadAttention and has_xpu():
-            module_type.forward = manual_cast_forward(torch.float32)
-        else:
-            module_type.forward = manual_cast_forward(target_dtype)
-        module_type.org_forward = org_forward
-    try:
-        yield None
-    finally:
-        if applied:
-            for module_type in patch_module_list:
-                if hasattr(module_type, "org_forward"):
-                    module_type.forward = module_type.org_forward
-                    delattr(module_type, "org_forward")
+    return


 def autocast(disable=False):
@@ -198,43 +91,9 @@ class NansException(Exception):
 def test_for_nans(x, where):
-    if shared.cmd_opts.disable_nan_check:
-        return
-
-    if not torch.all(torch.isnan(x)).item():
-        return
-
-    if where == "unet":
-        message = "A tensor with all NaNs was produced in Unet."
-
-        if not shared.cmd_opts.no_half:
-            message += " This could be either because there's not enough precision to represent the picture, or because your video card does not support half type. Try setting the \"Upcast cross attention layer to float32\" option in Settings > Stable Diffusion or using the --no-half commandline argument to fix this."
-
-    elif where == "vae":
-        message = "A tensor with all NaNs was produced in VAE."
-
-        if not shared.cmd_opts.no_half and not shared.cmd_opts.no_half_vae:
-            message += " This could be because there's not enough precision to represent the picture. Try adding --no-half-vae commandline argument to fix this."
-    else:
-        message = "A tensor with all NaNs was produced."
-
-    message += " Use --disable-nan-check commandline argument to disable this check."
-
-    raise NansException(message)
+    return


-@lru_cache
 def first_time_calculation():
-    """
-    just do any calculation with pytorch layers - the first time this is done it allocates about 700MB of memory and
-    spends about 2.7 seconds doing that, at least with NVidia.
-    """
-
-    x = torch.zeros((1, 1)).to(device, dtype)
-    linear = torch.nn.Linear(1, 1).to(device, dtype)
-    linear(x)
-
-    x = torch.zeros((1, 1, 3, 3)).to(device, dtype)
-    conv2d = torch.nn.Conv2d(1, 1, (3, 3)).to(device, dtype)
-    conv2d(x)
+    return
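
A side note on the stubbed manual_cast: the new body is a bare return, but the function keeps its @contextlib.contextmanager decorator. With no yield in the body it is no longer a generator function, so any caller that still enters it with `with manual_cast(dtype):` would fail in __enter__ with `TypeError: 'NoneType' object is not an iterator`. That is harmless if nothing in this codebase enters the context manager anymore, but a stub that stays safe to enter has to yield once. A minimal sketch, not the committed code:

    import contextlib

    @contextlib.contextmanager
    def manual_cast(target_dtype):
        # No-op stub: yield exactly once so `with manual_cast(...):` still works.
        # Actual dtype handling is assumed to live in model_management instead.
        yield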
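
The other behavioral shift is that the module-level globals are now resolved once at import time: upstream left device as None and the dtype_* values as torch.float16 placeholders to be reassigned later, while this version asks ldm_patched.modules.model_management for concrete values the moment devices is imported. A hypothetical caller, assuming the usual modules.devices import path (the tensor shape is illustrative only):

    import torch
    from modules import devices

    # device and dtype are already concrete here; no later assignment step.
    x = torch.zeros((1, 4, 64, 64), device=devices.device, dtype=devices.dtype)
    x = devices.cond_cast_unet(x)  # now an identity; unet_needs_upcast handling was removed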