stable-diffusion-webui-forge/backend/patcher/lora.py
layerdiffusion d38e560e42 (2024-08-19): Implement some rethinking of the LoRA system
1. Add an option that lets users keep the UNet in fp8/GGUF while running LoRAs in fp16.
2. FP16 LoRAs never need patching; all other LoRAs are re-patched only when their weights change.
3. FP8 UNet + fp16 LoRA now works in Forge (and, arguably, only in Forge). This also fixes some “LoRA effect is too subtle” problems.
4. Significantly speed up all GGUF models (in async mode) by using an independent CUDA stream to compute and dequantize at the same time, even when the low-bit weights are already on the GPU (see the sketch after this list).
5. Treat “online LoRA” as a module, similar to ControlLoRA, so it is moved to the GPU together with the model when sampling, achieving a significant speedup and perfect low-VRAM management at the same time.
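
To make the stream trick in point 4 concrete, here is a minimal sketch of the idea with a stand-in dequantizer (the real one is backend.operations_gguf.dequantize_tensor, imported further down in this file). This is an illustration, not the actual Forge implementation:

import torch

dequant_stream = torch.cuda.Stream()  # side stream dedicated to dequantization

def dequantize_tensor(q):
    # Stand-in for the real GGUF dequantizer (backend.operations_gguf.dequantize_tensor).
    return q.to(torch.float16)

def begin_async_dequant(quantized_weight):
    # Launch dequantization on the side stream so the default stream keeps computing;
    # the caller runs torch.cuda.current_stream().wait_stream(dequant_stream)
    # right before the dequantized weight is actually consumed.
    with torch.cuda.stream(dequant_stream):
        return dequantize_tensor(quantized_weight)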

import torch
import time
import packages_3rdparty.webui_lora_collection.lora as lora_utils_webui
import packages_3rdparty.comfyui_lora_collection.lora as lora_utils_comfyui
from tqdm import tqdm
from backend import memory_management, utils
from backend.args import dynamic_args
class ForgeLoraCollection:
# TODO
pass
extra_weight_calculators = {}
lora_utils_forge = ForgeLoraCollection()
lora_collection_priority = [lora_utils_forge, lora_utils_webui, lora_utils_comfyui]
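# The collections are searched in priority order: Forge's own (currently a
# stub), then the webui collection, then the comfyui collection.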
def get_function(function_name: str):
for lora_collection in lora_collection_priority:
if hasattr(lora_collection, function_name):
return getattr(lora_collection, function_name)
def load_lora(lora, to_load):
patch_dict, remaining_dict = get_function('load_lora')(lora, to_load)
return patch_dict, remaining_dict
def model_lora_keys_clip(model, key_map={}):
return get_function('model_lora_keys_clip')(model, key_map)
def model_lora_keys_unet(model, key_map={}):
return get_function('model_lora_keys_unet')(model, key_map)
def weight_decompose(dora_scale, weight, lora_diff, alpha, strength, computation_dtype):
# Modified from https://github.com/comfyanonymous/ComfyUI/blob/39f114c44bb99d4a221e8da451d4f2a20119c674/comfy/model_patcher.py#L33
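    # DoRA: add the alpha-scaled LoRA delta to the weight, renormalize the
    # merged weight by the learned per-channel magnitude `dora_scale`, then
    # blend the result with the original weight according to `strength`.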
dora_scale = memory_management.cast_to_device(dora_scale, weight.device, computation_dtype)
lora_diff *= alpha
weight_calc = weight + lora_diff.type(weight.dtype)
weight_norm = (
weight_calc.transpose(0, 1)
.reshape(weight_calc.shape[1], -1)
.norm(dim=1, keepdim=True)
.reshape(weight_calc.shape[1], *[1] * (weight_calc.dim() - 1))
.transpose(0, 1)
)
weight_calc *= (dora_scale / weight_norm).type(weight.dtype)
if strength != 1.0:
weight_calc -= weight
weight += strength * weight_calc
else:
weight[:] = weight_calc
return weight
def merge_lora_to_weight(patches, weight, key="online_lora", computation_dtype=torch.float32):
# Modified from https://github.com/comfyanonymous/ComfyUI/blob/39f114c44bb99d4a221e8da451d4f2a20119c674/comfy/model_patcher.py#L446
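    # Each patch is a list [strength, value, strength_model, offset, function];
    # `value` determines the patch type dispatched below ("diff", "lora",
    # "lokr", "loha", "glora", or a key in extra_weight_calculators).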
weight_original_dtype = weight.dtype
weight = weight.to(dtype=computation_dtype)
for p in patches:
strength = p[0]
v = p[1]
strength_model = p[2]
offset = p[3]
function = p[4]
if function is None:
function = lambda a: a
old_weight = None
if offset is not None:
old_weight = weight
weight = weight.narrow(offset[0], offset[1], offset[2])
if strength_model != 1.0:
weight *= strength_model
if isinstance(v, list):
v = (merge_lora_to_weight(v[1:], v[0].clone(), key),)
patch_type = ''
if len(v) == 1:
patch_type = "diff"
elif len(v) == 2:
patch_type = v[0]
v = v[1]
if patch_type == "diff":
w1 = v[0]
if strength != 0.0:
if w1.shape != weight.shape:
if w1.ndim == weight.ndim == 4:
new_shape = [max(n, m) for n, m in zip(weight.shape, w1.shape)]
print(f'Merged with {key} channel changed to {new_shape}')
new_diff = strength * memory_management.cast_to_device(w1, weight.device, weight.dtype)
new_weight = torch.zeros(size=new_shape).to(weight)
new_weight[:weight.shape[0], :weight.shape[1], :weight.shape[2], :weight.shape[3]] = weight
new_weight[:new_diff.shape[0], :new_diff.shape[1], :new_diff.shape[2], :new_diff.shape[3]] += new_diff
new_weight = new_weight.contiguous().clone()
weight = new_weight
else:
print("WARNING SHAPE MISMATCH {} WEIGHT NOT MERGED {} != {}".format(key, w1.shape, weight.shape))
else:
weight += strength * memory_management.cast_to_device(w1, weight.device, weight.dtype)
elif patch_type == "lora":
mat1 = memory_management.cast_to_device(v[0], weight.device, computation_dtype)
mat2 = memory_management.cast_to_device(v[1], weight.device, computation_dtype)
dora_scale = v[4]
if v[2] is not None:
alpha = v[2] / mat2.shape[0]
else:
alpha = 1.0
if v[3] is not None:
mat3 = memory_management.cast_to_device(v[3], weight.device, computation_dtype)
final_shape = [mat2.shape[1], mat2.shape[0], mat3.shape[2], mat3.shape[3]]
mat2 = torch.mm(mat2.transpose(0, 1).flatten(start_dim=1), mat3.transpose(0, 1).flatten(start_dim=1)).reshape(final_shape).transpose(0, 1)
try:
lora_diff = torch.mm(mat1.flatten(start_dim=1), mat2.flatten(start_dim=1)).reshape(weight.shape)
if dora_scale is not None:
weight = function(weight_decompose(dora_scale, weight, lora_diff, alpha, strength, computation_dtype))
else:
weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
except Exception as e:
print("ERROR {} {} {}".format(patch_type, key, e))
raise e
elif patch_type == "lokr":
w1 = v[0]
w2 = v[1]
w1_a = v[3]
w1_b = v[4]
w2_a = v[5]
w2_b = v[6]
t2 = v[7]
dora_scale = v[8]
dim = None
if w1 is None:
dim = w1_b.shape[0]
w1 = torch.mm(memory_management.cast_to_device(w1_a, weight.device, computation_dtype),
memory_management.cast_to_device(w1_b, weight.device, computation_dtype))
else:
w1 = memory_management.cast_to_device(w1, weight.device, computation_dtype)
if w2 is None:
dim = w2_b.shape[0]
if t2 is None:
w2 = torch.mm(memory_management.cast_to_device(w2_a, weight.device, computation_dtype),
memory_management.cast_to_device(w2_b, weight.device, computation_dtype))
else:
w2 = torch.einsum('i j k l, j r, i p -> p r k l',
memory_management.cast_to_device(t2, weight.device, computation_dtype),
memory_management.cast_to_device(w2_b, weight.device, computation_dtype),
memory_management.cast_to_device(w2_a, weight.device, computation_dtype))
else:
w2 = memory_management.cast_to_device(w2, weight.device, computation_dtype)
if len(w2.shape) == 4:
w1 = w1.unsqueeze(2).unsqueeze(2)
if v[2] is not None and dim is not None:
alpha = v[2] / dim
else:
alpha = 1.0
try:
lora_diff = torch.kron(w1, w2).reshape(weight.shape)
if dora_scale is not None:
weight = function(weight_decompose(dora_scale, weight, lora_diff, alpha, strength, computation_dtype))
else:
weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
except Exception as e:
print("ERROR {} {} {}".format(patch_type, key, e))
raise e
elif patch_type == "loha":
w1a = v[0]
w1b = v[1]
if v[2] is not None:
alpha = v[2] / w1b.shape[0]
else:
alpha = 1.0
w2a = v[3]
w2b = v[4]
dora_scale = v[7]
if v[5] is not None:
t1 = v[5]
t2 = v[6]
m1 = torch.einsum('i j k l, j r, i p -> p r k l',
memory_management.cast_to_device(t1, weight.device, computation_dtype),
memory_management.cast_to_device(w1b, weight.device, computation_dtype),
memory_management.cast_to_device(w1a, weight.device, computation_dtype))
m2 = torch.einsum('i j k l, j r, i p -> p r k l',
memory_management.cast_to_device(t2, weight.device, computation_dtype),
memory_management.cast_to_device(w2b, weight.device, computation_dtype),
memory_management.cast_to_device(w2a, weight.device, computation_dtype))
else:
m1 = torch.mm(memory_management.cast_to_device(w1a, weight.device, computation_dtype),
memory_management.cast_to_device(w1b, weight.device, computation_dtype))
m2 = torch.mm(memory_management.cast_to_device(w2a, weight.device, computation_dtype),
memory_management.cast_to_device(w2b, weight.device, computation_dtype))
try:
lora_diff = (m1 * m2).reshape(weight.shape)
if dora_scale is not None:
weight = function(weight_decompose(dora_scale, weight, lora_diff, alpha, strength, computation_dtype))
else:
weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
except Exception as e:
print("ERROR {} {} {}".format(patch_type, key, e))
raise e
elif patch_type == "glora":
if v[4] is not None:
alpha = v[4] / v[0].shape[0]
else:
alpha = 1.0
dora_scale = v[5]
a1 = memory_management.cast_to_device(v[0].flatten(start_dim=1), weight.device, computation_dtype)
a2 = memory_management.cast_to_device(v[1].flatten(start_dim=1), weight.device, computation_dtype)
b1 = memory_management.cast_to_device(v[2].flatten(start_dim=1), weight.device, computation_dtype)
b2 = memory_management.cast_to_device(v[3].flatten(start_dim=1), weight.device, computation_dtype)
try:
lora_diff = (torch.mm(b2, b1) + torch.mm(torch.mm(weight.flatten(start_dim=1), a2), a1)).reshape(weight.shape)
if dora_scale is not None:
weight = function(weight_decompose(dora_scale, weight, lora_diff, alpha, strength, computation_dtype))
else:
weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
except Exception as e:
print("ERROR {} {} {}".format(patch_type, key, e))
raise e
elif patch_type in extra_weight_calculators:
weight = extra_weight_calculators[patch_type](weight, strength, v)
else:
print("patch type not recognized {} {}".format(patch_type, key))
if old_weight is not None:
weight = old_weight
weight = weight.to(dtype=weight_original_dtype)
return weight
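# A hypothetical usage sketch (kept as a comment so importing this module has
# no side effects): a single "diff" patch simply adds `delta` onto `weight`.
#
#   delta = torch.ones(4, 4)
#   patched = merge_lora_to_weight(
#       patches=[[1.0, (delta,), 1.0, None, None]],
#       weight=torch.zeros(4, 4),
#       key='example',
#   )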
from backend import operations
class LoraLoader:
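    # Tracks the LoRA state of one model: `patches` maps weight keys to lists
    # of patches, `backup` holds the original weights so they can be restored,
    # `online_backup` records layers that were given `forge_online_loras`, and
    # `dirty` marks that refresh() has work to do.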
def __init__(self, model):
self.model = model
self.patches = {}
self.backup = {}
self.online_backup = []
self.dirty = False
def clear_patches(self):
self.patches.clear()
self.dirty = True
return
def add_patches(self, patches, strength_patch=1.0, strength_model=1.0):
p = set()
model_sd = self.model.state_dict()
for k in patches:
offset = None
function = None
if isinstance(k, str):
key = k
else:
offset = k[1]
key = k[0]
if len(k) > 2:
function = k[2]
if key in model_sd:
p.add(k)
current_patches = self.patches.get(key, [])
current_patches.append([strength_patch, patches[k], strength_model, offset, function])
self.patches[key] = current_patches
self.dirty = True
return list(p)
def refresh(self, target_device=None, offload_device=torch.device('cpu')):
if not self.dirty:
return
self.dirty = False
execution_start_time = time.perf_counter()
# Restore
for m in set(self.online_backup):
del m.forge_online_loras
self.online_backup = []
for k, w in self.backup.items():
if not isinstance(w, torch.nn.Parameter):
                # Rarely, the backup is a plain tensor rather than a Parameter; rewrap it
w = torch.nn.Parameter(w, requires_grad=False)
utils.set_attr_raw(self.model, k, w)
self.backup = {}
online_mode = dynamic_args.get('online_lora', False)
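        # In online mode, patches are attached to each layer as
        # `forge_online_loras` and applied at computation time, so they move to
        # the GPU together with the model instead of being merged here.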
# Patch
for key, current_patches in (tqdm(self.patches.items(), desc=f'Patching LoRAs for {type(self.model).__name__}') if len(self.patches) > 0 else self.patches):
try:
parent_layer, child_key, weight = utils.get_attr_with_parent(self.model, key)
assert isinstance(weight, torch.nn.Parameter)
except:
raise ValueError(f"Wrong LoRA Key: {key}")
if key not in self.backup:
self.backup[key] = weight.to(device=offload_device)
if online_mode:
if not hasattr(parent_layer, 'forge_online_loras'):
parent_layer.forge_online_loras = {}
parent_layer.forge_online_loras[child_key] = current_patches
self.online_backup.append(parent_layer)
continue
bnb_layer = None
if operations.bnb_avaliable:
if hasattr(weight, 'bnb_quantized'):
bnb_layer = parent_layer
if weight.bnb_quantized:
weight_original_device = weight.device
if target_device is not None:
assert target_device.type == 'cuda', 'BNB Must use CUDA!'
weight = weight.to(target_device)
else:
weight = weight.cuda()
from backend.operations_bnb import functional_dequantize_4bit
weight = functional_dequantize_4bit(weight)
if target_device is None:
weight = weight.to(device=weight_original_device)
else:
weight = weight.data
if target_device is not None:
try:
weight = weight.to(device=target_device)
except:
print('Moving layer weight failed. Retrying by offloading models.')
self.model.to(device=offload_device)
memory_management.soft_empty_cache()
weight = weight.to(device=target_device)
gguf_cls, gguf_type, gguf_real_shape = None, None, None
if hasattr(weight, 'is_gguf'):
from backend.operations_gguf import dequantize_tensor
gguf_cls = weight.gguf_cls
gguf_type = weight.gguf_type
gguf_real_shape = weight.gguf_real_shape
weight = dequantize_tensor(weight)
try:
weight = merge_lora_to_weight(current_patches, weight, key, computation_dtype=torch.float32)
except:
print('Patching LoRA weights failed. Retrying by offloading models.')
self.model.to(device=offload_device)
memory_management.soft_empty_cache()
weight = merge_lora_to_weight(current_patches, weight, key, computation_dtype=torch.float32)
if bnb_layer is not None:
bnb_layer.reload_weight(weight)
continue
if gguf_cls is not None:
from backend.operations_gguf import ParameterGGUF
weight = gguf_cls.quantize_pytorch(weight, gguf_real_shape)
utils.set_attr_raw(self.model, key, ParameterGGUF.make(
data=weight,
gguf_type=gguf_type,
gguf_cls=gguf_cls,
gguf_real_shape=gguf_real_shape
))
continue
utils.set_attr_raw(self.model, key, torch.nn.Parameter(weight, requires_grad=False))
# Time
moving_time = time.perf_counter() - execution_start_time
if moving_time > 0.1:
print(f'LoRA patching has taken {moving_time:.2f} seconds')
return
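

# A hypothetical end-to-end sketch (names are illustrative, not a stable API):
#
#   loader = LoraLoader(model)
#   loader.add_patches(patch_dict, strength_patch=0.8)
#   loader.refresh(target_device=torch.device('cuda'))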