mirror of
https://github.com/lllyasviel/stable-diffusion-webui-forge.git
synced 2026-01-26 19:09:45 +00:00
1. Add an option that lets users run the UNet in fp8/gguf while keeping LoRAs in fp16. 2. FP16 LoRAs no longer need patching; other LoRAs are re-patched only when the LoRA weight changes. 3. FP8 UNet + fp16 LoRA is now available in Forge (and, for now, more or less only in Forge). This also solves some “LoRA too subtle” problems. 4. Significantly speed up all GGUF models (in Async mode) by using an independent thread (CUDA stream) to compute and dequantize at the same time, even when the low-bit weights are already on the GPU. 5. Treat “online LoRA” as a module similar to ControlLoRA, so that it is moved to the GPU together with the model during sampling, achieving a significant speedup and perfect low-VRAM management at the same time.
141 lines
4.0 KiB
Python
141 lines
4.0 KiB
Python
import gguf
|
|
import torch
|
|
import os
|
|
import json
|
|
import safetensors.torch
|
|
import backend.misc.checkpoint_pickle
|
|
from backend.operations_gguf import ParameterGGUF
|
|
|
|
|
|
def read_arbitrary_config(directory):
    """Load and parse the ``config.json`` file stored in *directory*.

    Args:
        directory: Folder that is expected to contain ``config.json``.

    Returns:
        The object produced by ``json.load`` for that file.

    Raises:
        FileNotFoundError: If ``<directory>/config.json`` does not exist.
    """
    path = os.path.join(directory, 'config.json')

    if os.path.exists(path):
        with open(path, 'rt', encoding='utf-8') as handle:
            return json.load(handle)

    raise FileNotFoundError(f"No config.json file found in the directory: {directory}")
|
|
|
|
|
|
def load_torch_file(ckpt, safe_load=False, device=None):
    """Load a checkpoint file and return its state dict.

    The loader is chosen from the (case-insensitive) file extension:

    * ``.safetensors`` -- ``safetensors.torch.load_file`` onto ``device.type``.
    * ``.gguf`` -- every tensor found by ``gguf.GGUFReader`` is wrapped in a
      ``ParameterGGUF``; ``device`` is not used for this format.
    * anything else -- ``torch.load`` with ``map_location=device``. If the
      loaded object has a ``"state_dict"`` entry, that entry is returned
      instead of the whole object.

    Args:
        ckpt: Path to the checkpoint file.
        safe_load: For ``torch.load`` checkpoints, request
            ``weights_only=True``. If the installed PyTorch does not support
            that argument, a warning is printed and an unsafe load is done.
        device: A ``torch.device`` or anything ``torch.device(...)`` accepts
            (for example ``"cpu"``). Defaults to the CPU.

    Returns:
        The loaded state dict.
    """
    # Normalize so string devices such as "cpu" work too; the safetensors
    # branch needs the `.type` attribute of a real torch.device.
    device = torch.device("cpu") if device is None else torch.device(device)
    ckpt_lower = ckpt.lower()

    if ckpt_lower.endswith(".safetensors"):
        # NOTE(review): only the device *type* is forwarded, so an index such
        # as "cuda:1" is dropped here; kept as-is to preserve existing behavior.
        return safetensors.torch.load_file(ckpt, device=device.type)

    if ckpt_lower.endswith(".gguf"):
        reader = gguf.GGUFReader(ckpt)
        return {str(tensor.name): ParameterGGUF(tensor) for tensor in reader.tensors}

    if safe_load and 'weights_only' not in torch.load.__code__.co_varnames:
        print("Warning torch.load doesn't support weights_only on this pytorch version, loading unsafely.")
        safe_load = False

    if safe_load:
        pl_sd = torch.load(ckpt, map_location=device, weights_only=True)
    else:
        # NOTE(review): this path unpickles the file. Unless checkpoint_pickle
        # restricts which objects may be loaded, it can execute arbitrary code
        # from an untrusted checkpoint; confirm before exposing it to user files.
        pl_sd = torch.load(ckpt, map_location=device, pickle_module=backend.misc.checkpoint_pickle)

    if "global_step" in pl_sd:
        print(f"Global Step: {pl_sd['global_step']}")

    # Some training checkpoints nest the weights under "state_dict".
    if "state_dict" in pl_sd:
        return pl_sd["state_dict"]
    return pl_sd
|
|
|
|
|
|
def set_attr(obj, attr, value):
    """Assign a frozen ``torch.nn.Parameter`` at the dotted path *attr*.

    All path components except the last are resolved with ``getattr`` starting
    at *obj*. The last component is then set to
    ``torch.nn.Parameter(value, requires_grad=False)``.
    """
    *path, leaf = attr.split(".")
    owner = obj
    for part in path:
        owner = getattr(owner, part)
    setattr(owner, leaf, torch.nn.Parameter(value, requires_grad=False))
|
|
|
|
|
|
def set_attr_raw(obj, attr, value):
    """Assign *value* unchanged at the dotted path *attr* on *obj*.

    Intermediate components are resolved with ``getattr``. Unlike
    ``set_attr``, the value is not wrapped in a Parameter.
    """
    *path, leaf = attr.split(".")
    owner = obj
    for part in path:
        owner = getattr(owner, part)
    setattr(owner, leaf, value)
|
|
|
|
|
|
def copy_to_param(obj, attr, value):
    """Copy *value* in place into the existing tensor at dotted path *attr*.

    The attribute object itself is not replaced. Its ``.data`` receives
    ``copy_(value)``, so identity (and anything referencing it) is preserved.
    """
    *path, leaf = attr.split(".")
    owner = obj
    for part in path:
        owner = getattr(owner, part)
    getattr(owner, leaf).data.copy_(value)
|
|
|
|
|
|
def get_attr(obj, attr):
    """Follow the dotted path *attr* from *obj* via ``getattr`` and return the end object."""
    current = obj
    parts = attr.split(".")
    for part in parts:
        current = getattr(current, part)
    return current
|
|
|
|
|
|
def get_attr_with_parent(obj, attr):
    """Resolve dotted path *attr* and also report where the final lookup happened.

    Returns:
        A ``(parent, name, value)`` tuple. ``parent`` is the object the last
        path component was read from, ``name`` is that last component, and
        ``value`` is ``getattr(parent, name)``.
    """
    *path, leaf = attr.split(".")
    owner = obj
    for part in path:
        owner = getattr(owner, part)
    return owner, leaf, getattr(owner, leaf)
|
|
|
|
|
|
def calculate_parameters(sd, prefix=""):
    """Return the total element count of all entries of *sd* whose key starts with *prefix*.

    With the default empty prefix, every entry is counted.
    """
    return sum(
        tensor.nelement()
        for key, tensor in sd.items()
        if key.startswith(prefix)
    )
|
|
|
|
|
|
def tensor2parameter(x):
    """Ensure *x* is a ``torch.nn.Parameter``.

    An existing Parameter is returned as-is (same object). Anything else is
    wrapped in a new Parameter with ``requires_grad=False``.
    """
    if not isinstance(x, torch.nn.Parameter):
        x = torch.nn.Parameter(x, requires_grad=False)
    return x
|
|
|
|
|
|
def fp16_fix(x):
|
|
# An interesting trick to avoid fp16 overflow
|
|
# Source: https://github.com/lllyasviel/stable-diffusion-webui-forge/issues/1114
|
|
# Related: https://github.com/comfyanonymous/ComfyUI/blob/f1d6cef71c70719cc3ed45a2455a4e5ac910cd5e/comfy/ldm/flux/layers.py#L180
|
|
|
|
if x.dtype in [torch.float16]:
|
|
return x.clip(-32768.0, 32768.0)
|
|
return x
|
|
|
|
|
|
def nested_compute_size(obj):
    """Sum the byte size of every tensor inside a nested dict/list/tuple structure.

    A tensor contributes ``nelement() * element_size()`` bytes. Dicts are
    walked by value, lists and tuples by item, and any other leaf contributes 0.
    """
    if isinstance(obj, dict):
        return sum(nested_compute_size(obj[key]) for key in obj)
    if isinstance(obj, (list, tuple)):
        return sum(nested_compute_size(item) for item in obj)
    if isinstance(obj, torch.Tensor):
        return obj.nelement() * obj.element_size()
    return 0
|
|
|
|
|
|
def nested_move_to_device(obj, device):
    """Move every tensor in a nested dict/list/tuple structure to *device*.

    Dicts and lists are updated in place and also returned. Tuples are
    immutable, so a new tuple is built; named tuples keep their own type.
    Non-tensor leaves are returned unchanged.

    Args:
        obj: A tensor, or an arbitrary nesting of dicts/lists/tuples holding
            tensors and other values.
        device: Anything accepted by ``Tensor.to`` (e.g. ``"cpu"`` or a
            ``torch.device``).

    Returns:
        The structure with all tensors moved to *device*.
    """
    if isinstance(obj, dict):
        # Only existing keys are reassigned, so mutating during iteration is safe.
        for key in obj:
            obj[key] = nested_move_to_device(obj[key], device)
    elif isinstance(obj, list):
        for i, item in enumerate(obj):
            obj[i] = nested_move_to_device(item, device)
    elif isinstance(obj, tuple):
        moved = [nested_move_to_device(item, device) for item in obj]
        # A plain tuple(...) rebuild would silently turn namedtuples into
        # bare tuples and lose field access; namedtuples expose `_make` for this.
        if hasattr(type(obj), "_make"):
            obj = type(obj)._make(moved)
        else:
            obj = tuple(moved)
    elif isinstance(obj, torch.Tensor):
        return obj.to(device)
    return obj
|