Files
stable-diffusion-webui-forge/backend/utils.py
layerdiffusion 07b2d2ccac clipvision, ipadapter, and misc
backend is 75% finished
2024-08-03 14:18:16 -07:00

64 lines
1.8 KiB
Python

import torch
import safetensors.torch
import backend.misc.checkpoint_pickle
def load_torch_file(ckpt, safe_load=False, device=None):
if device is None:
device = torch.device("cpu")
if ckpt.lower().endswith(".safetensors"):
sd = safetensors.torch.load_file(ckpt, device=device.type)
else:
if safe_load:
if not 'weights_only' in torch.load.__code__.co_varnames:
print("Warning torch.load doesn't support weights_only on this pytorch version, loading unsafely.")
safe_load = False
if safe_load:
pl_sd = torch.load(ckpt, map_location=device, weights_only=True)
else:
pl_sd = torch.load(ckpt, map_location=device, pickle_module=backend.misc.checkpoint_pickle)
if "global_step" in pl_sd:
print(f"Global Step: {pl_sd['global_step']}")
if "state_dict" in pl_sd:
sd = pl_sd["state_dict"]
else:
sd = pl_sd
return sd
def set_attr(obj, attr, value):
attrs = attr.split(".")
for name in attrs[:-1]:
obj = getattr(obj, name)
setattr(obj, attrs[-1], torch.nn.Parameter(value, requires_grad=False))
def set_attr_raw(obj, attr, value):
attrs = attr.split(".")
for name in attrs[:-1]:
obj = getattr(obj, name)
setattr(obj, attrs[-1], value)
def copy_to_param(obj, attr, value):
attrs = attr.split(".")
for name in attrs[:-1]:
obj = getattr(obj, name)
prev = getattr(obj, attrs[-1])
prev.data.copy_(value)
def get_attr(obj, attr):
attrs = attr.split(".")
for name in attrs:
obj = getattr(obj, name)
return obj
def calculate_parameters(sd, prefix=""):
params = 0
for k in sd.keys():
if k.startswith(prefix):
params += sd[k].nelement()
return params