Files
stable-diffusion-webui-forge/backend/patcher/lora.py
layerdiffusion 4c9380c46a Speed up quant model loading and inference ...
... based on 3 evidences:
1. torch.Tensor.view on one big tensor is slightly faster than calling torch.Tensor.to on multiple small tensors.
2. but torch.Tensor.to with dtype change is significantly slower than torch.Tensor.view
3. “baking” model on GPU is significantly faster than computing on CPU when model load.

mainly influences inference of Q8_0, Q4_0/1/K and loading of all quants
2024-08-30 00:49:05 -07:00

429 lines
16 KiB
Python

import torch
import time
import packages_3rdparty.webui_lora_collection.lora as lora_utils_webui
import packages_3rdparty.comfyui_lora_collection.lora as lora_utils_comfyui
from tqdm import tqdm
from backend import memory_management, utils
from backend.args import dynamic_args
class ForgeLoraCollection:
    """Placeholder for a Forge-native LoRA key-mapping/loading collection.

    TODO: not implemented yet — lookups currently fall through to the
    webui/comfyui collections via the priority list below.
    """
    pass
# Registry of custom per-patch-type weight mergers, keyed by patch-type name.
# merge_lora_to_weight() consults this for patch types it does not handle natively.
extra_weight_calculators = {}

# Collections are searched in order: Forge's own (not yet implemented) first,
# then webui's, then comfyui's. get_function() walks this list.
lora_utils_forge = ForgeLoraCollection()
lora_collection_priority = [lora_utils_forge, lora_utils_webui, lora_utils_comfyui]
def get_function(function_name: str):
    """Resolve ``function_name`` against the LoRA collections in priority order.

    Returns the first collection attribute with that name, or ``None`` when
    no collection provides it.
    """
    candidates = (
        getattr(collection, function_name)
        for collection in lora_collection_priority
        if hasattr(collection, function_name)
    )
    return next(candidates, None)
def load_lora(lora, to_load):
    """Delegate LoRA state-dict parsing to the highest-priority collection.

    Returns ``(patch_dict, remaining_dict)`` as produced by that collection's
    ``load_lora`` implementation.
    """
    loader = get_function('load_lora')
    patch_dict, remaining_dict = loader(lora, to_load)
    return patch_dict, remaining_dict
def model_lora_keys_clip(model, key_map=None):
    """Build the CLIP/text-encoder LoRA key map for ``model``.

    Delegates to the highest-priority collection's ``model_lora_keys_clip``.

    Fix: the default used to be a shared mutable ``{}``, so mappings from one
    call silently leaked into every later call. A fresh dict is now created
    per call; passing an explicit dict still extends it in place, as before.
    """
    if key_map is None:
        key_map = {}
    return get_function('model_lora_keys_clip')(model, key_map)
def model_lora_keys_unet(model, key_map=None):
    """Build the UNet LoRA key map for ``model``.

    Delegates to the highest-priority collection's ``model_lora_keys_unet``.

    Fix: the default used to be a shared mutable ``{}``, so mappings from one
    call silently leaked into every later call. A fresh dict is now created
    per call; passing an explicit dict still extends it in place, as before.
    """
    if key_map is None:
        key_map = {}
    return get_function('model_lora_keys_unet')(model, key_map)
@torch.inference_mode()
def weight_decompose(dora_scale, weight, lora_diff, alpha, strength, computation_dtype):
    """Apply a DoRA (weight-decompose) patch to ``weight`` in place and return it.

    Args:
        dora_scale: learned magnitude tensor of the DoRA patch.
        weight: target weight tensor; mutated in place.
        lora_diff: low-rank delta; NOTE it is scaled by ``alpha`` in place,
            so the caller's tensor is modified.
        alpha: scalar multiplier applied to ``lora_diff``.
        strength: blend factor; exactly 1.0 overwrites ``weight`` outright,
            otherwise only a fraction of the change is added.
        computation_dtype: dtype used when casting ``dora_scale``.
    """
    # Modified from https://github.com/comfyanonymous/ComfyUI/blob/39f114c44bb99d4a221e8da451d4f2a20119c674/comfy/model_patcher.py#L33
    dora_scale = memory_management.cast_to_device(dora_scale, weight.device, computation_dtype)
    lora_diff *= alpha
    weight_calc = weight + lora_diff.type(weight.dtype)
    # L2 norm taken over every axis except dim 1, then reshaped so it
    # broadcasts back over weight_calc (one norm per dim-1 slice).
    weight_norm = (
        weight_calc.transpose(0, 1)
        .reshape(weight_calc.shape[1], -1)
        .norm(dim=1, keepdim=True)
        .reshape(weight_calc.shape[1], *[1] * (weight_calc.dim() - 1))
        .transpose(0, 1)
    )
    # Renormalize: direction from (weight + diff), magnitude from dora_scale.
    weight_calc *= (dora_scale / weight_norm).type(weight.dtype)
    if strength != 1.0:
        # Partial strength: add strength * (patched - original) to the original.
        weight_calc -= weight
        weight += strength * weight_calc
    else:
        # Full strength: overwrite in place so the caller's tensor object
        # (and any views of it) stay valid.
        weight[:] = weight_calc
    return weight
@torch.inference_mode()
def merge_lora_to_weight(patches, weight, key="online_lora", computation_dtype=torch.float32):
    """Merge a list of LoRA-style patches into ``weight`` and return the result.

    Each patch ``p`` is ``[strength, value, strength_model, offset, function]``;
    the layout of ``value`` selects the patch type ("diff", "lora", "lokr",
    "loha", "glora", or a key registered in ``extra_weight_calculators``).
    The incoming ``weight`` is cloned (or dtype-converted, which copies)
    before merging, so the caller's tensor is not mutated.

    Args:
        patches: list of patch entries as described above.
        weight: base weight tensor to merge into.
        key: model key, used only for log/error messages.
        computation_dtype: dtype used for the intermediate math; the result
            is cast back to the original weight dtype at the end.
    """
    # Modified from https://github.com/comfyanonymous/ComfyUI/blob/39f114c44bb99d4a221e8da451d4f2a20119c674/comfy/model_patcher.py#L446
    weight_dtype_backup = None

    if computation_dtype == weight.dtype:
        weight = weight.clone()
    else:
        # Remember the original dtype so we can cast back after merging.
        weight_dtype_backup = weight.dtype
        weight = weight.to(dtype=computation_dtype)

    for p in patches:
        strength = p[0]
        v = p[1]
        strength_model = p[2]
        offset = p[3]
        function = p[4]
        if function is None:
            function = lambda a: a

        old_weight = None
        if offset is not None:
            # Patch only a sub-slice of the weight; keep the full tensor so
            # it can be restored after this patch is applied.
            old_weight = weight
            weight = weight.narrow(offset[0], offset[1], offset[2])

        if strength_model != 1.0:
            # Pre-scale the base weight itself (strength_model != patch strength).
            weight *= strength_model

        if isinstance(v, list):
            # Nested patch list: first element is a base tensor, the rest are
            # patches to fold into it; the result is treated as a "diff".
            # NOTE(review): computation_dtype is not forwarded here, so nested
            # lists always merge in the float32 default — confirm intended.
            v = (merge_lora_to_weight(v[1:], v[0].clone(), key),)

        patch_type = ''
        if len(v) == 1:
            patch_type = "diff"
        elif len(v) == 2:
            patch_type = v[0]
            v = v[1]

        if patch_type == "diff":
            # Plain additive delta: weight += strength * w1.
            w1 = v[0]
            if strength != 0.0:
                if w1.shape != weight.shape:
                    if w1.ndim == weight.ndim == 4:
                        # 4D conv weights with mismatched channels: zero-pad
                        # both into a max-shape tensor and add there.
                        new_shape = [max(n, m) for n, m in zip(weight.shape, w1.shape)]
                        print(f'Merged with {key} channel changed to {new_shape}')
                        new_diff = strength * memory_management.cast_to_device(w1, weight.device, weight.dtype)
                        new_weight = torch.zeros(size=new_shape).to(weight)
                        new_weight[:weight.shape[0], :weight.shape[1], :weight.shape[2], :weight.shape[3]] = weight
                        new_weight[:new_diff.shape[0], :new_diff.shape[1], :new_diff.shape[2], :new_diff.shape[3]] += new_diff
                        new_weight = new_weight.contiguous().clone()
                        weight = new_weight
                    else:
                        print("WARNING SHAPE MISMATCH {} WEIGHT NOT MERGED {} != {}".format(key, w1.shape, weight.shape))
                else:
                    weight += strength * memory_management.cast_to_device(w1, weight.device, weight.dtype)
        elif patch_type == "lora":
            # Standard LoRA: diff = mat1 @ mat2, optionally with a third
            # factor (CP decomposition for conv) and an alpha/rank scale.
            mat1 = memory_management.cast_to_device(v[0], weight.device, computation_dtype)
            mat2 = memory_management.cast_to_device(v[1], weight.device, computation_dtype)
            dora_scale = v[4]
            if v[2] is not None:
                # v[2] is the stored alpha; normalize by rank (mat2 rows).
                alpha = v[2] / mat2.shape[0]
            else:
                alpha = 1.0
            if v[3] is not None:
                # CP-decomposed conv: fold mat3 into mat2 first.
                mat3 = memory_management.cast_to_device(v[3], weight.device, computation_dtype)
                final_shape = [mat2.shape[1], mat2.shape[0], mat3.shape[2], mat3.shape[3]]
                mat2 = torch.mm(mat2.transpose(0, 1).flatten(start_dim=1), mat3.transpose(0, 1).flatten(start_dim=1)).reshape(final_shape).transpose(0, 1)
            try:
                lora_diff = torch.mm(mat1.flatten(start_dim=1), mat2.flatten(start_dim=1)).reshape(weight.shape)
                if dora_scale is not None:
                    weight = function(weight_decompose(dora_scale, weight, lora_diff, alpha, strength, computation_dtype))
                else:
                    weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
            except Exception as e:
                print("ERROR {} {} {}".format(patch_type, key, e))
                raise e
        elif patch_type == "lokr":
            # LoKr: diff = kron(w1, w2); each factor may itself be stored
            # factorized (w1_a @ w1_b, w2_a @ w2_b, or a tucker-style t2).
            w1 = v[0]
            w2 = v[1]
            w1_a = v[3]
            w1_b = v[4]
            w2_a = v[5]
            w2_b = v[6]
            t2 = v[7]
            dora_scale = v[8]
            dim = None

            if w1 is None:
                dim = w1_b.shape[0]
                w1 = torch.mm(memory_management.cast_to_device(w1_a, weight.device, computation_dtype),
                              memory_management.cast_to_device(w1_b, weight.device, computation_dtype))
            else:
                w1 = memory_management.cast_to_device(w1, weight.device, computation_dtype)

            if w2 is None:
                dim = w2_b.shape[0]
                if t2 is None:
                    w2 = torch.mm(memory_management.cast_to_device(w2_a, weight.device, computation_dtype),
                                  memory_management.cast_to_device(w2_b, weight.device, computation_dtype))
                else:
                    w2 = torch.einsum('i j k l, j r, i p -> p r k l',
                                      memory_management.cast_to_device(t2, weight.device, computation_dtype),
                                      memory_management.cast_to_device(w2_b, weight.device, computation_dtype),
                                      memory_management.cast_to_device(w2_a, weight.device, computation_dtype))
            else:
                w2 = memory_management.cast_to_device(w2, weight.device, computation_dtype)

            if len(w2.shape) == 4:
                # Conv case: give w1 trailing singleton dims so kron broadcasts.
                w1 = w1.unsqueeze(2).unsqueeze(2)
            if v[2] is not None and dim is not None:
                alpha = v[2] / dim
            else:
                alpha = 1.0

            try:
                lora_diff = torch.kron(w1, w2).reshape(weight.shape)
                if dora_scale is not None:
                    weight = function(weight_decompose(dora_scale, weight, lora_diff, alpha, strength, computation_dtype))
                else:
                    weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
            except Exception as e:
                print("ERROR {} {} {}".format(patch_type, key, e))
                raise e
        elif patch_type == "loha":
            # LoHa: diff = (w1a @ w1b) * (w2a @ w2b) (Hadamard product),
            # with optional tucker factors t1/t2 for conv weights.
            w1a = v[0]
            w1b = v[1]
            if v[2] is not None:
                alpha = v[2] / w1b.shape[0]
            else:
                alpha = 1.0

            w2a = v[3]
            w2b = v[4]
            dora_scale = v[7]
            if v[5] is not None:
                t1 = v[5]
                t2 = v[6]
                m1 = torch.einsum('i j k l, j r, i p -> p r k l',
                                  memory_management.cast_to_device(t1, weight.device, computation_dtype),
                                  memory_management.cast_to_device(w1b, weight.device, computation_dtype),
                                  memory_management.cast_to_device(w1a, weight.device, computation_dtype))

                m2 = torch.einsum('i j k l, j r, i p -> p r k l',
                                  memory_management.cast_to_device(t2, weight.device, computation_dtype),
                                  memory_management.cast_to_device(w2b, weight.device, computation_dtype),
                                  memory_management.cast_to_device(w2a, weight.device, computation_dtype))
            else:
                m1 = torch.mm(memory_management.cast_to_device(w1a, weight.device, computation_dtype),
                              memory_management.cast_to_device(w1b, weight.device, computation_dtype))
                m2 = torch.mm(memory_management.cast_to_device(w2a, weight.device, computation_dtype),
                              memory_management.cast_to_device(w2b, weight.device, computation_dtype))

            try:
                lora_diff = (m1 * m2).reshape(weight.shape)
                if dora_scale is not None:
                    weight = function(weight_decompose(dora_scale, weight, lora_diff, alpha, strength, computation_dtype))
                else:
                    weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
            except Exception as e:
                print("ERROR {} {} {}".format(patch_type, key, e))
                raise e
        elif patch_type == "glora":
            # GLoRA: diff = b2 @ b1 + (weight @ a2) @ a1 — note it depends on
            # the current weight, not only on the patch tensors.
            if v[4] is not None:
                alpha = v[4] / v[0].shape[0]
            else:
                alpha = 1.0

            dora_scale = v[5]

            a1 = memory_management.cast_to_device(v[0].flatten(start_dim=1), weight.device, computation_dtype)
            a2 = memory_management.cast_to_device(v[1].flatten(start_dim=1), weight.device, computation_dtype)
            b1 = memory_management.cast_to_device(v[2].flatten(start_dim=1), weight.device, computation_dtype)
            b2 = memory_management.cast_to_device(v[3].flatten(start_dim=1), weight.device, computation_dtype)

            try:
                lora_diff = (torch.mm(b2, b1) + torch.mm(torch.mm(weight.flatten(start_dim=1), a2), a1)).reshape(weight.shape)
                if dora_scale is not None:
                    weight = function(weight_decompose(dora_scale, weight, lora_diff, alpha, strength, computation_dtype))
                else:
                    weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
            except Exception as e:
                print("ERROR {} {} {}".format(patch_type, key, e))
                raise e
        elif patch_type in extra_weight_calculators:
            # Externally registered merger for custom patch types.
            weight = extra_weight_calculators[patch_type](weight, strength, v)
        else:
            print("patch type not recognized {} {}".format(patch_type, key))

        if old_weight is not None:
            # Undo the narrow(): continue with the full tensor (the narrowed
            # view shares storage, so the patch is already applied in place).
            weight = old_weight

    if weight_dtype_backup is not None:
        # Cast the merged result back to the weight's original dtype.
        weight = weight.to(dtype=weight_dtype_backup)

    return weight
def get_parameter_devices(model):
    """Snapshot the device of every named parameter in ``model``.

    Returns a dict mapping parameter name -> ``torch.device``, suitable for
    later restoration via ``set_parameter_devices``.
    """
    return {name: param.device for name, param in model.named_parameters()}
def set_parameter_devices(model, parameter_devices):
    """Move each named parameter of ``model`` to its recorded device.

    ``parameter_devices`` maps parameter name -> target device, as produced
    by ``get_parameter_devices``. Parameters already on their target device
    are left untouched. Returns ``model``.
    """
    for name, target_device in parameter_devices.items():
        param = utils.get_attr(model, name)
        if param.device != target_device:
            moved = utils.tensor2parameter(param.to(device=target_device))
            utils.set_attr_raw(model, name, moved)
    return model
from backend import operations
class LoraLoader:
    """Manages LoRA patches for one model: collects patches, merges them into
    weights ("offline"/precomputed) or attaches them to layers for on-the-fly
    application ("online"), and restores original weights on refresh."""

    def __init__(self, model):
        self.model = model
        # key -> list of [strength_patch, patch_value, strength_model, offset, function]
        self.patches = {}
        # key -> original (pre-merge) weight, kept on the offload device
        self.backup = {}
        # layers that were given a `forge_online_loras` attribute, for cleanup
        self.online_backup = []
        # set whenever patches change; refresh() is a no-op until then
        self.dirty = False
        # True -> attach patches to layers instead of baking them into weights
        self.online_mode = False

    def clear_patches(self):
        """Drop all recorded patches and mark the loader dirty (so the next
        refresh() restores the original weights)."""
        self.patches.clear()
        self.dirty = True
        return

    def add_patches(self, patches, strength_patch=1.0, strength_model=1.0):
        """Record patches for keys present in the model's state dict.

        ``patches`` maps either a plain key string, or a tuple
        ``(key, offset[, function])``, to a patch value. Returns the list of
        accepted keys. Also decides online vs offline mode: online is taken
        from dynamic_args['online_lora'], but forced off for models whose
        storage_dtype is a plain float type (no quantization to preserve).
        """
        p = set()
        model_sd = self.model.state_dict()
        for k in patches:
            offset = None
            function = None
            if isinstance(k, str):
                key = k
            else:
                offset = k[1]
                key = k[0]
                if len(k) > 2:
                    function = k[2]

            if key in model_sd:
                p.add(k)
                current_patches = self.patches.get(key, [])
                current_patches.append([strength_patch, patches[k], strength_model, offset, function])
                self.patches[key] = current_patches

        self.dirty = True
        self.online_mode = dynamic_args.get('online_lora', False)

        if hasattr(self.model, 'storage_dtype'):
            if self.model.storage_dtype in [torch.float32, torch.float16, torch.bfloat16]:
                self.online_mode = False

        return list(p)

    @torch.inference_mode()
    def refresh(self, offload_device=torch.device('cpu')):
        """Re-apply the current patch set to the model.

        Restores any previously patched weights from backup first, then
        either attaches patches to layers (online mode) or bakes them into
        the weights, re-quantizing bnb/gguf weights afterwards. On an
        exception during merging (presumably OOM), retries once after
        offloading all parameters to ``offload_device``.
        """
        if not self.dirty:
            return

        self.dirty = False

        # Initialize
        memory_management.signal_empty_cache = True
        parameter_devices = get_parameter_devices(self.model)

        # Restore: remove online patches and put original weights back.
        for m in set(self.online_backup):
            del m.forge_online_loras

        self.online_backup = []

        for k, w in self.backup.items():
            if not isinstance(w, torch.nn.Parameter):
                # In very few cases
                w = torch.nn.Parameter(w, requires_grad=False)
            utils.set_attr_raw(self.model, k, w)

        self.backup = {}
        set_parameter_devices(self.model, parameter_devices=parameter_devices)

        # Patch
        for key, current_patches in self.patches.items():
            try:
                parent_layer, child_key, weight = utils.get_attr_with_parent(self.model, key)
                assert isinstance(weight, torch.nn.Parameter)
            except:
                # NOTE(review): bare except also swallows KeyboardInterrupt;
                # intended only to convert lookup failures into ValueError.
                raise ValueError(f"Wrong LoRA Key: {key}")

            if self.online_mode:
                # Online: stash patches on the layer; its forward applies them.
                if not hasattr(parent_layer, 'forge_online_loras'):
                    parent_layer.forge_online_loras = {}

                parent_layer.forge_online_loras[child_key] = current_patches
                self.online_backup.append(parent_layer)
                continue

            # Offline: keep the original weight so it can be restored later.
            if key not in self.backup:
                self.backup[key] = weight.to(device=offload_device)

            bnb_layer = None

            if hasattr(weight, 'bnb_quantized') and operations.bnb_avaliable:
                # bitsandbytes 4-bit weight: dequantize before merging.
                bnb_layer = parent_layer
                from backend.operations_bnb import functional_dequantize_4bit
                weight = functional_dequantize_4bit(weight)

            gguf_cls = getattr(weight, 'gguf_cls', None)
            gguf_parameter = None

            if gguf_cls is not None:
                # GGUF-quantized weight: dequantize before merging.
                gguf_parameter = weight
                from backend.operations_gguf import dequantize_tensor
                weight = dequantize_tensor(weight)

            try:
                weight = merge_lora_to_weight(current_patches, weight, key, computation_dtype=torch.float32)
            except:
                # NOTE(review): bare except assumes the failure was OOM;
                # offload everything and retry once.
                print('Patching LoRA weights out of memory. Retrying by offloading models.')
                set_parameter_devices(self.model, parameter_devices={k: offload_device for k in parameter_devices.keys()})
                memory_management.soft_empty_cache()
                weight = merge_lora_to_weight(current_patches, weight, key, computation_dtype=torch.float32)

            if bnb_layer is not None:
                # Re-quantize the merged weight back into the bnb layer.
                bnb_layer.reload_weight(weight)
                continue

            if gguf_cls is not None:
                # Re-quantize to GGUF and re-bake (see commit note: baking on
                # GPU speeds up later inference).
                gguf_parameter.data = gguf_cls.quantize_pytorch(weight, gguf_parameter.shape)
                gguf_parameter.baked = False
                gguf_cls.bake(gguf_parameter)
                continue

            utils.set_attr_raw(self.model, key, torch.nn.Parameter(weight, requires_grad=False))

        # End
        set_parameter_devices(self.model, parameter_devices=parameter_devices)

        if len(self.patches) > 0:
            if self.online_mode:
                print(f'Patched LoRAs on-the-fly; ', end='')
            else:
                print(f'Patched LoRAs by precomputing model weights; ', end='')

        return