# Design notes, based on three observations:
# 1. torch.Tensor.view on one big tensor is slightly faster than calling
#    torch.Tensor.to on multiple small tensors.
# 2. torch.Tensor.to with a dtype change is significantly slower than
#    torch.Tensor.view.
# 3. "Baking" the model on the GPU is significantly faster than computing on
#    the CPU at model load. This mainly influences inference of Q8_0 and
#    Q4_0/1/K, and the loading of all quants.
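#
# Illustrative sketch of observation 2 (hypothetical tensors, not used below):
#
#     blob = torch.empty(1024, dtype=torch.uint8)
#     reinterpreted = blob.view(torch.int8)   # reinterprets storage, no copy
#     converted = blob.to(torch.float16)      # allocates and converts every element
#
# view() with a dtype only relabels the existing bytes, while to() with a new
# dtype materializes a converted copy.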
import torch
import time

import packages_3rdparty.webui_lora_collection.lora as lora_utils_webui
import packages_3rdparty.comfyui_lora_collection.lora as lora_utils_comfyui

from tqdm import tqdm
from backend import memory_management, utils
from backend.args import dynamic_args


class ForgeLoraCollection:
    # TODO
    pass


# Optional extra patch handlers: maps a patch type name to a callable
# (weight, strength, v) -> weight.
extra_weight_calculators = {}

lora_utils_forge = ForgeLoraCollection()

lora_collection_priority = [lora_utils_forge, lora_utils_webui, lora_utils_comfyui]


def get_function(function_name: str):
    # Resolve `function_name` against the collections in priority order;
    # returns None if no collection implements it.
    for lora_collection in lora_collection_priority:
        if hasattr(lora_collection, function_name):
            return getattr(lora_collection, function_name)


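# The wrappers below resolve their implementation at call time via
# get_function(); with the forge collection still a stub, this typically
# dispatches to the webui collection first, then comfyui.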
def load_lora(lora, to_load):
    patch_dict, remaining_dict = get_function('load_lora')(lora, to_load)
    return patch_dict, remaining_dict


def model_lora_keys_clip(model, key_map={}):
    return get_function('model_lora_keys_clip')(model, key_map)


def model_lora_keys_unet(model, key_map={}):
    return get_function('model_lora_keys_unet')(model, key_map)


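# DoRA (weight-decomposed low-rank adaptation): after merging the LoRA diff,
# the result is renormalized so that its magnitude follows the learned
# dora_scale, roughly W' = dora_scale * (W + alpha * dW) / ||W + alpha * dW||,
# with the norm taken over every dimension except dim 1.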
@torch.inference_mode()
def weight_decompose(dora_scale, weight, lora_diff, alpha, strength, computation_dtype):
    # Modified from https://github.com/comfyanonymous/ComfyUI/blob/39f114c44bb99d4a221e8da451d4f2a20119c674/comfy/model_patcher.py#L33

    dora_scale = memory_management.cast_to_device(dora_scale, weight.device, computation_dtype)
    lora_diff *= alpha
    weight_calc = weight + lora_diff.type(weight.dtype)
    # L2 norm of the merged weight over every dimension except dim 1,
    # reshaped so it broadcasts back over weight_calc.
    weight_norm = (
        weight_calc.transpose(0, 1)
        .reshape(weight_calc.shape[1], -1)
        .norm(dim=1, keepdim=True)
        .reshape(weight_calc.shape[1], *[1] * (weight_calc.dim() - 1))
        .transpose(0, 1)
    )

    weight_calc *= (dora_scale / weight_norm).type(weight.dtype)
    if strength != 1.0:
        # Interpolate between the original and the decomposed weight.
        weight_calc -= weight
        weight += strength * weight_calc
    else:
        weight[:] = weight_calc
    return weight


@torch.inference_mode()
def merge_lora_to_weight(patches, weight, key="online_lora", computation_dtype=torch.float32):
    # Modified from https://github.com/comfyanonymous/ComfyUI/blob/39f114c44bb99d4a221e8da451d4f2a20119c674/comfy/model_patcher.py#L446

    weight_dtype_backup = None

    if computation_dtype == weight.dtype:
        weight = weight.clone()
    else:
        weight_dtype_backup = weight.dtype
        weight = weight.to(dtype=computation_dtype)

    for p in patches:
        # Each patch is [strength, value, strength_model, offset, function];
        # see LoraLoader.add_patches for how these are assembled.
        strength = p[0]
        v = p[1]
        strength_model = p[2]
        offset = p[3]
        function = p[4]
        if function is None:
            function = lambda a: a

        old_weight = None
        if offset is not None:
            old_weight = weight
            weight = weight.narrow(offset[0], offset[1], offset[2])

        if strength_model != 1.0:
            weight *= strength_model

        if isinstance(v, list):
            v = (merge_lora_to_weight(v[1:], v[0].clone(), key),)

        patch_type = ''

        if len(v) == 1:
            patch_type = "diff"
        elif len(v) == 2:
            patch_type = v[0]
            v = v[1]

        if patch_type == "diff":
            w1 = v[0]
            if strength != 0.0:
                if w1.shape != weight.shape:
                    if w1.ndim == weight.ndim == 4:
                        # Conv weights with mismatched channel counts: zero-pad
                        # both to the max shape before adding the diff.
                        new_shape = [max(n, m) for n, m in zip(weight.shape, w1.shape)]
                        print(f'Merged with {key} channel changed to {new_shape}')
                        new_diff = strength * memory_management.cast_to_device(w1, weight.device, weight.dtype)
                        new_weight = torch.zeros(size=new_shape).to(weight)
                        new_weight[:weight.shape[0], :weight.shape[1], :weight.shape[2], :weight.shape[3]] = weight
                        new_weight[:new_diff.shape[0], :new_diff.shape[1], :new_diff.shape[2], :new_diff.shape[3]] += new_diff
                        new_weight = new_weight.contiguous().clone()
                        weight = new_weight
                    else:
                        print("WARNING SHAPE MISMATCH {} WEIGHT NOT MERGED {} != {}".format(key, w1.shape, weight.shape))
                else:
                    weight += strength * memory_management.cast_to_device(w1, weight.device, weight.dtype)
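        # "lora" patches hold (up, down, alpha, mid, dora_scale); the diff is
        # up @ down scaled by alpha / rank, with an optional mid tensor from a
        # CP-decomposed convolution folded into the down matrix first.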
        elif patch_type == "lora":
            mat1 = memory_management.cast_to_device(v[0], weight.device, computation_dtype)
            mat2 = memory_management.cast_to_device(v[1], weight.device, computation_dtype)
            dora_scale = v[4]
            if v[2] is not None:
                alpha = v[2] / mat2.shape[0]
            else:
                alpha = 1.0

            if v[3] is not None:
                mat3 = memory_management.cast_to_device(v[3], weight.device, computation_dtype)
                final_shape = [mat2.shape[1], mat2.shape[0], mat3.shape[2], mat3.shape[3]]
                mat2 = torch.mm(mat2.transpose(0, 1).flatten(start_dim=1), mat3.transpose(0, 1).flatten(start_dim=1)).reshape(final_shape).transpose(0, 1)
            try:
                lora_diff = torch.mm(mat1.flatten(start_dim=1), mat2.flatten(start_dim=1)).reshape(weight.shape)
                if dora_scale is not None:
                    weight = function(weight_decompose(dora_scale, weight, lora_diff, alpha, strength, computation_dtype))
                else:
                    weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
            except Exception as e:
                print("ERROR {} {} {}".format(patch_type, key, e))
                raise e
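        # "lokr" (LyCORIS, Kronecker product): the diff is kron(w1, w2); each
        # factor may be stored directly or as a low-rank product, and w2 can be
        # rebuilt from a Tucker core t2 for convolution weights.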
        elif patch_type == "lokr":
            w1 = v[0]
            w2 = v[1]
            w1_a = v[3]
            w1_b = v[4]
            w2_a = v[5]
            w2_b = v[6]
            t2 = v[7]
            dora_scale = v[8]
            dim = None

            if w1 is None:
                dim = w1_b.shape[0]
                w1 = torch.mm(memory_management.cast_to_device(w1_a, weight.device, computation_dtype),
                              memory_management.cast_to_device(w1_b, weight.device, computation_dtype))
            else:
                w1 = memory_management.cast_to_device(w1, weight.device, computation_dtype)

            if w2 is None:
                dim = w2_b.shape[0]
                if t2 is None:
                    w2 = torch.mm(memory_management.cast_to_device(w2_a, weight.device, computation_dtype),
                                  memory_management.cast_to_device(w2_b, weight.device, computation_dtype))
                else:
                    w2 = torch.einsum('i j k l, j r, i p -> p r k l',
                                      memory_management.cast_to_device(t2, weight.device, computation_dtype),
                                      memory_management.cast_to_device(w2_b, weight.device, computation_dtype),
                                      memory_management.cast_to_device(w2_a, weight.device, computation_dtype))
            else:
                w2 = memory_management.cast_to_device(w2, weight.device, computation_dtype)

            if len(w2.shape) == 4:
                w1 = w1.unsqueeze(2).unsqueeze(2)
            if v[2] is not None and dim is not None:
                alpha = v[2] / dim
            else:
                alpha = 1.0

            try:
                lora_diff = torch.kron(w1, w2).reshape(weight.shape)
                if dora_scale is not None:
                    weight = function(weight_decompose(dora_scale, weight, lora_diff, alpha, strength, computation_dtype))
                else:
                    weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
            except Exception as e:
                print("ERROR {} {} {}".format(patch_type, key, e))
                raise e
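        # "loha" (LyCORIS, Hadamard product): the diff is the elementwise
        # product (w1a @ w1b) * (w2a @ w2b), with optional Tucker cores t1/t2
        # for convolution weights.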
        elif patch_type == "loha":
            w1a = v[0]
            w1b = v[1]
            if v[2] is not None:
                alpha = v[2] / w1b.shape[0]
            else:
                alpha = 1.0

            w2a = v[3]
            w2b = v[4]
            dora_scale = v[7]
            if v[5] is not None:
                t1 = v[5]
                t2 = v[6]
                m1 = torch.einsum('i j k l, j r, i p -> p r k l',
                                  memory_management.cast_to_device(t1, weight.device, computation_dtype),
                                  memory_management.cast_to_device(w1b, weight.device, computation_dtype),
                                  memory_management.cast_to_device(w1a, weight.device, computation_dtype))

                m2 = torch.einsum('i j k l, j r, i p -> p r k l',
                                  memory_management.cast_to_device(t2, weight.device, computation_dtype),
                                  memory_management.cast_to_device(w2b, weight.device, computation_dtype),
                                  memory_management.cast_to_device(w2a, weight.device, computation_dtype))
            else:
                m1 = torch.mm(memory_management.cast_to_device(w1a, weight.device, computation_dtype),
                              memory_management.cast_to_device(w1b, weight.device, computation_dtype))
                m2 = torch.mm(memory_management.cast_to_device(w2a, weight.device, computation_dtype),
                              memory_management.cast_to_device(w2b, weight.device, computation_dtype))

            try:
                lora_diff = (m1 * m2).reshape(weight.shape)
                if dora_scale is not None:
                    weight = function(weight_decompose(dora_scale, weight, lora_diff, alpha, strength, computation_dtype))
                else:
                    weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
            except Exception as e:
                print("ERROR {} {} {}".format(patch_type, key, e))
                raise e
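        # "glora": the diff combines a weight-independent term b2 @ b1 with a
        # weight-dependent term W @ a2 @ a1, so the update also adapts to the
        # base weight itself.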
        elif patch_type == "glora":
            if v[4] is not None:
                alpha = v[4] / v[0].shape[0]
            else:
                alpha = 1.0

            dora_scale = v[5]

            a1 = memory_management.cast_to_device(v[0].flatten(start_dim=1), weight.device, computation_dtype)
            a2 = memory_management.cast_to_device(v[1].flatten(start_dim=1), weight.device, computation_dtype)
            b1 = memory_management.cast_to_device(v[2].flatten(start_dim=1), weight.device, computation_dtype)
            b2 = memory_management.cast_to_device(v[3].flatten(start_dim=1), weight.device, computation_dtype)

            try:
                lora_diff = (torch.mm(b2, b1) + torch.mm(torch.mm(weight.flatten(start_dim=1), a2), a1)).reshape(weight.shape)
                if dora_scale is not None:
                    weight = function(weight_decompose(dora_scale, weight, lora_diff, alpha, strength, computation_dtype))
                else:
                    weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
            except Exception as e:
                print("ERROR {} {} {}".format(patch_type, key, e))
                raise e
        elif patch_type in extra_weight_calculators:
            weight = extra_weight_calculators[patch_type](weight, strength, v)
        else:
            print("patch type not recognized {} {}".format(patch_type, key))

        if old_weight is not None:
            # Undo the narrow() so the next patch sees the full tensor.
            weight = old_weight

    if weight_dtype_backup is not None:
        weight = weight.to(dtype=weight_dtype_backup)

    return weight


def get_parameter_devices(model):
    parameter_devices = {}
    for key, p in model.named_parameters():
        parameter_devices[key] = p.device
    return parameter_devices


def set_parameter_devices(model, parameter_devices):
    for key, device in parameter_devices.items():
        p = utils.get_attr(model, key)
        if p.device != device:
            p = utils.tensor2parameter(p.to(device=device))
            utils.set_attr_raw(model, key, p)
    return model


# Imported here rather than at the top, likely to avoid a circular import.
from backend import operations


class LoraLoader:
    def __init__(self, model):
        self.model = model
        self.patches = {}           # state_dict key -> list of pending patches
        self.backup = {}            # original weights, restored before re-patching
        self.online_backup = []     # layers carrying forge_online_loras attributes
        self.dirty = False
        self.online_mode = False

    def clear_patches(self):
        self.patches.clear()
        self.dirty = True
        return

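    # Keys in `patches` may be plain state_dict key strings, or tuples
    # (key, offset) / (key, offset, function) for narrowed or transformed
    # updates (see merge_lora_to_weight).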
    def add_patches(self, patches, strength_patch=1.0, strength_model=1.0):
        p = set()
        model_sd = self.model.state_dict()

        for k in patches:
            offset = None
            function = None

            if isinstance(k, str):
                key = k
            else:
                offset = k[1]
                key = k[0]
                if len(k) > 2:
                    function = k[2]

            if key in model_sd:
                p.add(k)
                current_patches = self.patches.get(key, [])
                current_patches.append([strength_patch, patches[k], strength_model, offset, function])
                self.patches[key] = current_patches

        self.dirty = True

        self.online_mode = dynamic_args.get('online_lora', False)

        if hasattr(self.model, 'storage_dtype'):
            if self.model.storage_dtype in [torch.float32, torch.float16, torch.bfloat16]:
                # Online evaluation is only needed for quantized storage dtypes.
                self.online_mode = False

        return list(p)

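    # refresh() applies all pending patches: in online mode the raw patches are
    # attached to the layers and evaluated at forward time; otherwise the diffs
    # are merged into the weights, dequantizing bnb/gguf tensors first and
    # re-quantizing (and re-baking, for gguf) afterwards.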
    @torch.inference_mode()
    def refresh(self, offload_device=torch.device('cpu')):
        if not self.dirty:
            return

        self.dirty = False

        # Initialize

        memory_management.signal_empty_cache = True

        parameter_devices = get_parameter_devices(self.model)

        # Restore

        for m in set(self.online_backup):
            del m.forge_online_loras

        self.online_backup = []

        for k, w in self.backup.items():
            if not isinstance(w, torch.nn.Parameter):
                # Rarely the backup is a plain tensor; rewrap it as a Parameter.
                w = torch.nn.Parameter(w, requires_grad=False)

            utils.set_attr_raw(self.model, k, w)

        self.backup = {}

        set_parameter_devices(self.model, parameter_devices=parameter_devices)

        # Patch

        for key, current_patches in self.patches.items():
            try:
                parent_layer, child_key, weight = utils.get_attr_with_parent(self.model, key)
                assert isinstance(weight, torch.nn.Parameter)
            except Exception:
                raise ValueError(f"Wrong LoRA Key: {key}")

            if self.online_mode:
                if not hasattr(parent_layer, 'forge_online_loras'):
                    parent_layer.forge_online_loras = {}

                parent_layer.forge_online_loras[child_key] = current_patches
                self.online_backup.append(parent_layer)
                continue

            if key not in self.backup:
                self.backup[key] = weight.to(device=offload_device)

            bnb_layer = None

            if hasattr(weight, 'bnb_quantized') and operations.bnb_avaliable:
                bnb_layer = parent_layer
                from backend.operations_bnb import functional_dequantize_4bit
                weight = functional_dequantize_4bit(weight)

            gguf_cls = getattr(weight, 'gguf_cls', None)
            gguf_parameter = None

            if gguf_cls is not None:
                gguf_parameter = weight
                from backend.operations_gguf import dequantize_tensor
                weight = dequantize_tensor(weight)

            try:
                weight = merge_lora_to_weight(current_patches, weight, key, computation_dtype=torch.float32)
            except Exception:
                print('Patching LoRA weights out of memory. Retrying by offloading models.')
                set_parameter_devices(self.model, parameter_devices={k: offload_device for k in parameter_devices.keys()})
                memory_management.soft_empty_cache()
                weight = merge_lora_to_weight(current_patches, weight, key, computation_dtype=torch.float32)

            if bnb_layer is not None:
                bnb_layer.reload_weight(weight)
                continue

            if gguf_cls is not None:
                # Re-quantize the merged weight, then bake it on the GPU
                # (see the performance notes at the top of this file).
                gguf_parameter.data = gguf_cls.quantize_pytorch(weight, gguf_parameter.shape)
                gguf_parameter.baked = False
                gguf_cls.bake(gguf_parameter)
                continue

            utils.set_attr_raw(self.model, key, torch.nn.Parameter(weight, requires_grad=False))

        # End

        set_parameter_devices(self.model, parameter_devices=parameter_devices)

        if len(self.patches) > 0:
            if self.online_mode:
                print('Patched LoRAs on-the-fly; ', end='')
            else:
                print('Patched LoRAs by precomputing model weights; ', end='')

        return