1. Add an option that lets users run the UNet in fp8/GGUF while keeping LoRA weights in fp16.
2. All FP16 LoRAs need no patching; others are only patched again when the LoRA weights change.
3. FP8 UNet + fp16 LoRA is now available in Forge (and, for the moment, essentially only in Forge). This also fixes some "LoRA too subtle" problems.
4. Significantly speed up all GGUF models (in async mode) by using an independent thread (CUDA stream) to compute and dequantize at the same time, even when the low-bit weights are already on the GPU.
5. Treat "online LoRA" as a module similar to ControlLoRA so that it is moved to the GPU together with the model when sampling, achieving a significant speedup and clean low-VRAM management at the same time.
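Point 4 above works by overlapping dequantization with computation on separate CUDA streams. The snippet below is a minimal illustrative sketch of that idea, not Forge's actual implementation; the `dequantize` helper and the `quantized_weights` list are placeholders, and a CUDA device is assumed.

import torch

compute_stream = torch.cuda.current_stream()
dequant_stream = torch.cuda.Stream()


def dequantize(q, dtype=torch.float16):
    # Illustrative stand-in for a real GGUF / fp8 dequantization kernel.
    return q.to(dtype)


def linear_chain_with_overlap(x, quantized_weights):
    # Dequantize weight i+1 on the side stream while weight i is used on the
    # compute stream, so dequantization cost hides behind the matmuls.
    w = dequantize(quantized_weights[0])
    for i in range(len(quantized_weights)):
        if i + 1 < len(quantized_weights):
            with torch.cuda.stream(dequant_stream):
                w_next = dequantize(quantized_weights[i + 1])
        x = x @ w.t()
        if i + 1 < len(quantized_weights):
            compute_stream.wait_stream(dequant_stream)  # sync before using w_next
            w = w_next
    return x

The module below only handles weight merging; the stream overlap itself is implemented elsewhere in Forge.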
import torch
import time

import packages_3rdparty.webui_lora_collection.lora as lora_utils_webui
import packages_3rdparty.comfyui_lora_collection.lora as lora_utils_comfyui

from tqdm import tqdm
from backend import memory_management, utils
from backend.args import dynamic_args


class ForgeLoraCollection:
    # TODO
    pass


extra_weight_calculators = {}

lora_utils_forge = ForgeLoraCollection()

lora_collection_priority = [lora_utils_forge, lora_utils_webui, lora_utils_comfyui]


def get_function(function_name: str):
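    """Return the named helper from the first collection in lora_collection_priority
    that defines it (Forge first, then the WebUI and ComfyUI collections)."""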
    for lora_collection in lora_collection_priority:
        if hasattr(lora_collection, function_name):
            return getattr(lora_collection, function_name)


def load_lora(lora, to_load):
    patch_dict, remaining_dict = get_function('load_lora')(lora, to_load)
    return patch_dict, remaining_dict


def model_lora_keys_clip(model, key_map={}):
    return get_function('model_lora_keys_clip')(model, key_map)


def model_lora_keys_unet(model, key_map={}):
    return get_function('model_lora_keys_unet')(model, key_map)


def weight_decompose(dora_scale, weight, lora_diff, alpha, strength, computation_dtype):
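    """DoRA-style merge: add the LoRA diff (scaled by alpha), rescale the result so its
    per-channel norm follows dora_scale, then blend with the original weight by strength."""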
    # Modified from https://github.com/comfyanonymous/ComfyUI/blob/39f114c44bb99d4a221e8da451d4f2a20119c674/comfy/model_patcher.py#L33

    dora_scale = memory_management.cast_to_device(dora_scale, weight.device, computation_dtype)
    lora_diff *= alpha
    weight_calc = weight + lora_diff.type(weight.dtype)
    weight_norm = (
        weight_calc.transpose(0, 1)
        .reshape(weight_calc.shape[1], -1)
        .norm(dim=1, keepdim=True)
        .reshape(weight_calc.shape[1], *[1] * (weight_calc.dim() - 1))
        .transpose(0, 1)
    )

    weight_calc *= (dora_scale / weight_norm).type(weight.dtype)
    if strength != 1.0:
        weight_calc -= weight
        weight += strength * weight_calc
    else:
        weight[:] = weight_calc
    return weight


def merge_lora_to_weight(patches, weight, key="online_lora", computation_dtype=torch.float32):
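    """Merge a list of patches (diff / lora / lokr / loha / glora / registered extras)
    into `weight` using `computation_dtype`, returning it in its original dtype."""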
    # Modified from https://github.com/comfyanonymous/ComfyUI/blob/39f114c44bb99d4a221e8da451d4f2a20119c674/comfy/model_patcher.py#L446

    weight_original_dtype = weight.dtype
    weight = weight.to(dtype=computation_dtype)

    for p in patches:
        strength = p[0]
        v = p[1]
        strength_model = p[2]
        offset = p[3]
        function = p[4]
        if function is None:
            function = lambda a: a

        old_weight = None
        if offset is not None:
            old_weight = weight
            weight = weight.narrow(offset[0], offset[1], offset[2])

        if strength_model != 1.0:
            weight *= strength_model

        if isinstance(v, list):
            v = (merge_lora_to_weight(v[1:], v[0].clone(), key),)

        patch_type = ''

        if len(v) == 1:
            patch_type = "diff"
        elif len(v) == 2:
            patch_type = v[0]
            v = v[1]

        if patch_type == "diff":
            w1 = v[0]
            if strength != 0.0:
                if w1.shape != weight.shape:
                    if w1.ndim == weight.ndim == 4:
                        new_shape = [max(n, m) for n, m in zip(weight.shape, w1.shape)]
                        print(f'Merged with {key} channel changed to {new_shape}')
                        new_diff = strength * memory_management.cast_to_device(w1, weight.device, weight.dtype)
                        new_weight = torch.zeros(size=new_shape).to(weight)
                        new_weight[:weight.shape[0], :weight.shape[1], :weight.shape[2], :weight.shape[3]] = weight
                        new_weight[:new_diff.shape[0], :new_diff.shape[1], :new_diff.shape[2], :new_diff.shape[3]] += new_diff
                        new_weight = new_weight.contiguous().clone()
                        weight = new_weight
                    else:
                        print("WARNING SHAPE MISMATCH {} WEIGHT NOT MERGED {} != {}".format(key, w1.shape, weight.shape))
                else:
                    weight += strength * memory_management.cast_to_device(w1, weight.device, weight.dtype)
        elif patch_type == "lora":
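            # Standard LoRA: lora_diff = mat1 @ mat2 (an optional mat3 handles conv
            # kernels), scaled by alpha and strength; DoRA rescaling when dora_scale is set.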
            mat1 = memory_management.cast_to_device(v[0], weight.device, computation_dtype)
            mat2 = memory_management.cast_to_device(v[1], weight.device, computation_dtype)
            dora_scale = v[4]
            if v[2] is not None:
                alpha = v[2] / mat2.shape[0]
            else:
                alpha = 1.0

            if v[3] is not None:
                mat3 = memory_management.cast_to_device(v[3], weight.device, computation_dtype)
                final_shape = [mat2.shape[1], mat2.shape[0], mat3.shape[2], mat3.shape[3]]
                mat2 = torch.mm(mat2.transpose(0, 1).flatten(start_dim=1), mat3.transpose(0, 1).flatten(start_dim=1)).reshape(final_shape).transpose(0, 1)
            try:
                lora_diff = torch.mm(mat1.flatten(start_dim=1), mat2.flatten(start_dim=1)).reshape(weight.shape)
                if dora_scale is not None:
                    weight = function(weight_decompose(dora_scale, weight, lora_diff, alpha, strength, computation_dtype))
                else:
                    weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
            except Exception as e:
                print("ERROR {} {} {}".format(patch_type, key, e))
                raise e
        elif patch_type == "lokr":
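            # LoKr: the update is the Kronecker product of two factors, each either
            # given directly or rebuilt from its own low-rank (or tucker) decomposition.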
            w1 = v[0]
            w2 = v[1]
            w1_a = v[3]
            w1_b = v[4]
            w2_a = v[5]
            w2_b = v[6]
            t2 = v[7]
            dora_scale = v[8]
            dim = None

            if w1 is None:
                dim = w1_b.shape[0]
                w1 = torch.mm(memory_management.cast_to_device(w1_a, weight.device, computation_dtype),
                              memory_management.cast_to_device(w1_b, weight.device, computation_dtype))
            else:
                w1 = memory_management.cast_to_device(w1, weight.device, computation_dtype)

            if w2 is None:
                dim = w2_b.shape[0]
                if t2 is None:
                    w2 = torch.mm(memory_management.cast_to_device(w2_a, weight.device, computation_dtype),
                                  memory_management.cast_to_device(w2_b, weight.device, computation_dtype))
                else:
                    w2 = torch.einsum('i j k l, j r, i p -> p r k l',
                                      memory_management.cast_to_device(t2, weight.device, computation_dtype),
                                      memory_management.cast_to_device(w2_b, weight.device, computation_dtype),
                                      memory_management.cast_to_device(w2_a, weight.device, computation_dtype))
            else:
                w2 = memory_management.cast_to_device(w2, weight.device, computation_dtype)

            if len(w2.shape) == 4:
                w1 = w1.unsqueeze(2).unsqueeze(2)
            if v[2] is not None and dim is not None:
                alpha = v[2] / dim
            else:
                alpha = 1.0

            try:
                lora_diff = torch.kron(w1, w2).reshape(weight.shape)
                if dora_scale is not None:
                    weight = function(weight_decompose(dora_scale, weight, lora_diff, alpha, strength, computation_dtype))
                else:
                    weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
            except Exception as e:
                print("ERROR {} {} {}".format(patch_type, key, e))
                raise e
        elif patch_type == "loha":
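            # LoHa: the update is the elementwise (Hadamard) product of two low-rank
            # reconstructions m1 and m2, with optional tucker tensors t1/t2.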
            w1a = v[0]
            w1b = v[1]
            if v[2] is not None:
                alpha = v[2] / w1b.shape[0]
            else:
                alpha = 1.0

            w2a = v[3]
            w2b = v[4]
            dora_scale = v[7]
            if v[5] is not None:
                t1 = v[5]
                t2 = v[6]
                m1 = torch.einsum('i j k l, j r, i p -> p r k l',
                                  memory_management.cast_to_device(t1, weight.device, computation_dtype),
                                  memory_management.cast_to_device(w1b, weight.device, computation_dtype),
                                  memory_management.cast_to_device(w1a, weight.device, computation_dtype))

                m2 = torch.einsum('i j k l, j r, i p -> p r k l',
                                  memory_management.cast_to_device(t2, weight.device, computation_dtype),
                                  memory_management.cast_to_device(w2b, weight.device, computation_dtype),
                                  memory_management.cast_to_device(w2a, weight.device, computation_dtype))
            else:
                m1 = torch.mm(memory_management.cast_to_device(w1a, weight.device, computation_dtype),
                              memory_management.cast_to_device(w1b, weight.device, computation_dtype))
                m2 = torch.mm(memory_management.cast_to_device(w2a, weight.device, computation_dtype),
                              memory_management.cast_to_device(w2b, weight.device, computation_dtype))

            try:
                lora_diff = (m1 * m2).reshape(weight.shape)
                if dora_scale is not None:
                    weight = function(weight_decompose(dora_scale, weight, lora_diff, alpha, strength, computation_dtype))
                else:
                    weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
            except Exception as e:
                print("ERROR {} {} {}".format(patch_type, key, e))
                raise e
        elif patch_type == "glora":
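            # GLoRA: combines a direct low-rank term (b2 @ b1) with a weight-dependent
            # term ((W @ a2) @ a1) before scaling by alpha and strength.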
            if v[4] is not None:
                alpha = v[4] / v[0].shape[0]
            else:
                alpha = 1.0

            dora_scale = v[5]

            a1 = memory_management.cast_to_device(v[0].flatten(start_dim=1), weight.device, computation_dtype)
            a2 = memory_management.cast_to_device(v[1].flatten(start_dim=1), weight.device, computation_dtype)
            b1 = memory_management.cast_to_device(v[2].flatten(start_dim=1), weight.device, computation_dtype)
            b2 = memory_management.cast_to_device(v[3].flatten(start_dim=1), weight.device, computation_dtype)

            try:
                lora_diff = (torch.mm(b2, b1) + torch.mm(torch.mm(weight.flatten(start_dim=1), a2), a1)).reshape(weight.shape)
                if dora_scale is not None:
                    weight = function(weight_decompose(dora_scale, weight, lora_diff, alpha, strength, computation_dtype))
                else:
                    weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
            except Exception as e:
                print("ERROR {} {} {}".format(patch_type, key, e))
                raise e
        elif patch_type in extra_weight_calculators:
            weight = extra_weight_calculators[patch_type](weight, strength, v)
        else:
            print("patch type not recognized {} {}".format(patch_type, key))

        if old_weight is not None:
            weight = old_weight

    weight = weight.to(dtype=weight_original_dtype)
    return weight


from backend import operations


class LoraLoader:
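    """Tracks LoRA patches for one model: add_patches() records them per weight key,
    refresh() restores any backed-up weights and then re-applies the patches, either
    merged into the weights or attached as 'online' LoRA for runtime evaluation."""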
    def __init__(self, model):
        self.model = model
        self.patches = {}
        self.backup = {}
        self.online_backup = []
        self.dirty = False

    def clear_patches(self):
        self.patches.clear()
        self.dirty = True
        return

    def add_patches(self, patches, strength_patch=1.0, strength_model=1.0):
        p = set()
        model_sd = self.model.state_dict()

        for k in patches:
            offset = None
            function = None

            if isinstance(k, str):
                key = k
            else:
                offset = k[1]
                key = k[0]
                if len(k) > 2:
                    function = k[2]

            if key in model_sd:
                p.add(k)
                current_patches = self.patches.get(key, [])
                current_patches.append([strength_patch, patches[k], strength_model, offset, function])
                self.patches[key] = current_patches

        self.dirty = True
        return list(p)

    def refresh(self, target_device=None, offload_device=torch.device('cpu')):
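        """Re-apply patches if the loader is dirty: restore original weights from backup,
        then either attach patches as online LoRA (dynamic_args['online_lora']) or merge
        them into each weight, dequantizing/re-quantizing bnb and GGUF weights as needed."""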
        if not self.dirty:
            return

        self.dirty = False

        execution_start_time = time.perf_counter()

        # Restore

        for m in set(self.online_backup):
            del m.forge_online_loras

        self.online_backup = []

        for k, w in self.backup.items():
            if not isinstance(w, torch.nn.Parameter):
                # In very few cases
                w = torch.nn.Parameter(w, requires_grad=False)

            utils.set_attr_raw(self.model, k, w)

        self.backup = {}

        online_mode = dynamic_args.get('online_lora', False)

        # Patch

        for key, current_patches in (tqdm(self.patches.items(), desc=f'Patching LoRAs for {type(self.model).__name__}') if len(self.patches) > 0 else self.patches):
            try:
                parent_layer, child_key, weight = utils.get_attr_with_parent(self.model, key)
                assert isinstance(weight, torch.nn.Parameter)
            except:
                raise ValueError(f"Wrong LoRA Key: {key}")

            if key not in self.backup:
                self.backup[key] = weight.to(device=offload_device)

            if online_mode:
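                # Online LoRA: do not merge into the weight; attach the patches to the
                # parent layer so they are evaluated at runtime and travel with the model.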
                if not hasattr(parent_layer, 'forge_online_loras'):
                    parent_layer.forge_online_loras = {}

                parent_layer.forge_online_loras[child_key] = current_patches
                self.online_backup.append(parent_layer)
                continue

            bnb_layer = None

            if operations.bnb_avaliable:
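                # bitsandbytes 4-bit weights must be dequantized (on CUDA) before merging;
                # the patched weight is handed back to the layer via reload_weight() below.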
                if hasattr(weight, 'bnb_quantized'):
                    bnb_layer = parent_layer
                    if weight.bnb_quantized:
                        weight_original_device = weight.device

                        if target_device is not None:
                            assert target_device.type == 'cuda', 'BNB Must use CUDA!'
                            weight = weight.to(target_device)
                        else:
                            weight = weight.cuda()

                        from backend.operations_bnb import functional_dequantize_4bit
                        weight = functional_dequantize_4bit(weight)

                        if target_device is None:
                            weight = weight.to(device=weight_original_device)
                    else:
                        weight = weight.data

            if target_device is not None:
                try:
                    weight = weight.to(device=target_device)
                except:
                    print('Moving layer weight failed. Retrying by offloading models.')
                    self.model.to(device=offload_device)
                    memory_management.soft_empty_cache()
                    weight = weight.to(device=target_device)

            gguf_cls, gguf_type, gguf_real_shape = None, None, None

            if hasattr(weight, 'is_gguf'):
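                # GGUF weights: remember the quantization class/type and real shape, then
                # dequantize to a dense tensor so patches can be merged; re-quantized below.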
                from backend.operations_gguf import dequantize_tensor
                gguf_cls = weight.gguf_cls
                gguf_type = weight.gguf_type
                gguf_real_shape = weight.gguf_real_shape
                weight = dequantize_tensor(weight)

            try:
                weight = merge_lora_to_weight(current_patches, weight, key, computation_dtype=torch.float32)
            except:
                print('Patching LoRA weights failed. Retrying by offloading models.')
                self.model.to(device=offload_device)
                memory_management.soft_empty_cache()
                weight = merge_lora_to_weight(current_patches, weight, key, computation_dtype=torch.float32)

            if bnb_layer is not None:
                bnb_layer.reload_weight(weight)
                continue

            if gguf_cls is not None:
                from backend.operations_gguf import ParameterGGUF
                weight = gguf_cls.quantize_pytorch(weight, gguf_real_shape)
                utils.set_attr_raw(self.model, key, ParameterGGUF.make(
                    data=weight,
                    gguf_type=gguf_type,
                    gguf_cls=gguf_cls,
                    gguf_real_shape=gguf_real_shape
                ))
                continue

            utils.set_attr_raw(self.model, key, torch.nn.Parameter(weight, requires_grad=False))

        # Time

        moving_time = time.perf_counter() - execution_start_time

        if moving_time > 0.1:
            print(f'LoRA patching has taken {moving_time:.2f} seconds')

        return