Mirror of https://github.com/lllyasviel/stable-diffusion-webui-forge.git (synced 2026-01-26 10:59:47 +00:00)
Implement some rethinking about LoRA system
1. Add an option that lets users keep the UNet in fp8/gguf while applying LoRAs in fp16.
2. Fp16 LoRAs no longer need weight patching; other LoRAs are only patched again when the LoRA weights change.
3. FP8 UNet + fp16 LoRA is now available in Forge (and, at the moment, more or less only in Forge). This also solves some "LoRA too subtle" problems.
4. Significantly speed up all gguf models (in Async mode) by using an independent CUDA stream, so dequantization and computation run at the same time, even when the low-bit weights are already on the GPU.
5. Treat the "online LoRA" as a module, similar to ControlLoRA, so it is moved to the GPU together with the model when sampling, achieving a significant speedup and clean low-VRAM management at the same time.
@@ -378,6 +378,7 @@ class LoadedModel:

        try:
            self.real_model = self.model.forge_patch_model(patch_model_to)
            self.model.current_device = self.model.load_device
        except Exception as e:
            self.model.forge_unpatch_model(self.model.offload_device)
            self.model_unload()

@@ -5,12 +5,53 @@ import torch
import contextlib

from backend import stream, memory_management, utils
from backend.patcher.lora import merge_lora_to_weight


stash = {}


def weights_manual_cast(layer, x, skip_weight_dtype=False, skip_bias_dtype=False):
def get_weight_and_bias(layer, weight_args=None, bias_args=None, weight_fn=None, bias_fn=None):
    patches = getattr(layer, 'forge_online_loras', None)
    weight_patches, bias_patches = None, None

    if patches is not None:
        weight_patches = patches.get('weight', None)

    if patches is not None:
        bias_patches = patches.get('bias', None)

    weight = None
    if layer.weight is not None:
        weight = layer.weight
        if weight_fn is not None:
            if weight_args is not None:
                fn_device = weight_args.get('device', None)
                if fn_device is not None:
                    weight = weight.to(device=fn_device)
            weight = weight_fn(weight)
        if weight_args is not None:
            weight = weight.to(**weight_args)
        if weight_patches is not None:
            weight = merge_lora_to_weight(patches=weight_patches, weight=weight, key="online weight lora", computation_dtype=weight.dtype)

    bias = None
    if layer.bias is not None:
        bias = layer.bias
        if bias_fn is not None:
            if bias_args is not None:
                fn_device = bias_args.get('device', None)
                if fn_device is not None:
                    bias = bias.to(device=fn_device)
            bias = bias_fn(bias)
        if bias_args is not None:
            bias = bias.to(**bias_args)
        if bias_patches is not None:
            bias = merge_lora_to_weight(patches=bias_patches, weight=bias, key="online bias lora", computation_dtype=bias.dtype)

    return weight, bias


def weights_manual_cast(layer, x, skip_weight_dtype=False, skip_bias_dtype=False, weight_fn=None, bias_fn=None):
    weight, bias, signal = None, None, None
    non_blocking = True

@@ -32,16 +73,10 @@ def weights_manual_cast(layer, x, skip_weight_dtype=False, skip_bias_dtype=False

    if stream.should_use_stream():
        with stream.stream_context()(stream.mover_stream):
            if layer.weight is not None:
                weight = layer.weight.to(**weight_args)
            if layer.bias is not None:
                bias = layer.bias.to(**bias_args)
            weight, bias = get_weight_and_bias(layer, weight_args, bias_args, weight_fn=weight_fn, bias_fn=bias_fn)
            signal = stream.mover_stream.record_event()
    else:
        if layer.weight is not None:
            weight = layer.weight.to(**weight_args)
        if layer.bias is not None:
            bias = layer.bias.to(**bias_args)
        weight, bias = get_weight_and_bias(layer, weight_args, bias_args, weight_fn=weight_fn, bias_fn=bias_fn)

    return weight, bias, signal

@@ -109,7 +144,8 @@ class ForgeOperations:
                with main_stream_worker(weight, bias, signal):
                    return torch.nn.functional.linear(x, weight, bias)
            else:
                return torch.nn.functional.linear(x, self.weight, self.bias)
                weight, bias = get_weight_and_bias(self)
                return torch.nn.functional.linear(x, weight, bias)

    class Conv2d(torch.nn.Conv2d):

@@ -128,7 +164,8 @@ class ForgeOperations:
                with main_stream_worker(weight, bias, signal):
                    return self._conv_forward(x, weight, bias)
            else:
                return super().forward(x)
                weight, bias = get_weight_and_bias(self)
                return super()._conv_forward(x, weight, bias)

    class Conv3d(torch.nn.Conv3d):

@@ -147,7 +184,8 @@ class ForgeOperations:
                with main_stream_worker(weight, bias, signal):
                    return self._conv_forward(x, weight, bias)
            else:
                return super().forward(x)
                weight, bias = get_weight_and_bias(self)
                return super()._conv_forward(x, weight, bias)

    class Conv1d(torch.nn.Conv1d):

@@ -166,7 +204,8 @@ class ForgeOperations:
                with main_stream_worker(weight, bias, signal):
                    return self._conv_forward(x, weight, bias)
            else:
                return super().forward(x)
                weight, bias = get_weight_and_bias(self)
                return super()._conv_forward(x, weight, bias)

    class ConvTranspose2d(torch.nn.ConvTranspose2d):

@@ -188,7 +227,10 @@ class ForgeOperations:
                with main_stream_worker(weight, bias, signal):
                    return torch.nn.functional.conv_transpose2d(x, weight, bias, self.stride, self.padding, output_padding, self.groups, self.dilation)
            else:
                return super().forward(x, output_size)
                weight, bias = get_weight_and_bias(self)
                num_spatial_dims = 2
                output_padding = self._output_padding(x, output_size, self.stride, self.padding, self.kernel_size, num_spatial_dims, self.dilation)
                return torch.nn.functional.conv_transpose2d(x, weight, bias, self.stride, self.padding, output_padding, self.groups, self.dilation)

    class ConvTranspose1d(torch.nn.ConvTranspose1d):

@@ -210,7 +252,10 @@ class ForgeOperations:
                with main_stream_worker(weight, bias, signal):
                    return torch.nn.functional.conv_transpose1d(x, weight, bias, self.stride, self.padding, output_padding, self.groups, self.dilation)
            else:
                return super().forward(x, output_size)
                weight, bias = get_weight_and_bias(self)
                num_spatial_dims = 1
                output_padding = self._output_padding(x, output_size, self.stride, self.padding, self.kernel_size, num_spatial_dims, self.dilation)
                return torch.nn.functional.conv_transpose1d(x, weight, bias, self.stride, self.padding, output_padding, self.groups, self.dilation)

    class ConvTranspose3d(torch.nn.ConvTranspose3d):

@@ -232,7 +277,10 @@ class ForgeOperations:
                with main_stream_worker(weight, bias, signal):
                    return torch.nn.functional.conv_transpose3d(x, weight, bias, self.stride, self.padding, output_padding, self.groups, self.dilation)
            else:
                return super().forward(x, output_size)
                weight, bias = get_weight_and_bias(self)
                num_spatial_dims = 3
                output_padding = self._output_padding(x, output_size, self.stride, self.padding, self.kernel_size, num_spatial_dims, self.dilation)
                return torch.nn.functional.conv_transpose3d(x, weight, bias, self.stride, self.padding, output_padding, self.groups, self.dilation)

    class GroupNorm(torch.nn.GroupNorm):

@@ -328,7 +376,7 @@ except:
    bnb_avaliable = False


from backend.operations_gguf import functional_linear_gguf
from backend.operations_gguf import dequantize_tensor


class ForgeOperationsGGUF(ForgeOperations):

@@ -361,12 +409,9 @@ class ForgeOperationsGGUF(ForgeOperations):
            return self

        def forward(self, x):
            if self.parameters_manual_cast:
                weight, bias, signal = weights_manual_cast(self, x, skip_weight_dtype=True)
                with main_stream_worker(weight, bias, signal):
                    return functional_linear_gguf(x, weight, bias)
            else:
                return functional_linear_gguf(x, self.weight, self.bias)
            weight, bias, signal = weights_manual_cast(self, x, weight_fn=dequantize_tensor, bias_fn=dequantize_tensor)
            with main_stream_worker(weight, bias, signal):
                return torch.nn.functional.linear(x, weight, bias)


@contextlib.contextmanager

@@ -58,13 +58,6 @@ class ParameterGGUF(torch.nn.Parameter):
        return new


def functional_linear_gguf(x, weight, bias=None):
    target_dtype = x.dtype
    weight = dequantize_tensor(weight).to(target_dtype)
    bias = dequantize_tensor(bias).to(target_dtype)
    return torch.nn.functional.linear(x, weight, bias)


def dequantize_tensor(tensor):
    if tensor is None:
        return None

@@ -4,14 +4,12 @@
# are from Forge, implemented from scratch (after forge-v1.0.1), and may have
# certain level of differences.

import time
import torch

import copy
import inspect

from tqdm import tqdm
from backend import memory_management, utils, operations
from backend.patcher.lora import merge_lora_to_model_weight, LoraLoader
from backend import memory_management, utils
from backend.patcher.lora import LoraLoader


def set_model_options_patch_replace(model_options, patch, name, block_name, number, transformer_index=None):

@@ -229,7 +227,6 @@ class ModelPatcher:

        if target_device is not None:
            self.model.to(target_device)
            self.current_device = target_device

        return self.model

@@ -5,7 +5,8 @@ import packages_3rdparty.webui_lora_collection.lora as lora_utils_webui
import packages_3rdparty.comfyui_lora_collection.lora as lora_utils_comfyui

from tqdm import tqdm
from backend import memory_management, utils, operations
from backend import memory_management, utils
from backend.args import dynamic_args


class ForgeLoraCollection:

@@ -39,8 +40,10 @@ def model_lora_keys_unet(model, key_map={}):
    return get_function('model_lora_keys_unet')(model, key_map)


def weight_decompose(dora_scale, weight, lora_diff, alpha, strength):
    dora_scale = memory_management.cast_to_device(dora_scale, weight.device, torch.float32)
def weight_decompose(dora_scale, weight, lora_diff, alpha, strength, computation_dtype):
    # Modified from https://github.com/comfyanonymous/ComfyUI/blob/39f114c44bb99d4a221e8da451d4f2a20119c674/comfy/model_patcher.py#L33

    dora_scale = memory_management.cast_to_device(dora_scale, weight.device, computation_dtype)
    lora_diff *= alpha
    weight_calc = weight + lora_diff.type(weight.dtype)
    weight_norm = (

@@ -60,7 +63,12 @@ def weight_decompose(dora_scale, weight, lora_diff, alpha, strength):
    return weight


def merge_lora_to_model_weight(patches, weight, key):
def merge_lora_to_weight(patches, weight, key="online_lora", computation_dtype=torch.float32):
    # Modified from https://github.com/comfyanonymous/ComfyUI/blob/39f114c44bb99d4a221e8da451d4f2a20119c674/comfy/model_patcher.py#L446

    weight_original_dtype = weight.dtype
    weight = weight.to(dtype=computation_dtype)

    for p in patches:
        strength = p[0]
        v = p[1]

@@ -79,7 +87,7 @@ def merge_lora_to_model_weight(patches, weight, key):
            weight *= strength_model

        if isinstance(v, list):
            v = (merge_lora_to_model_weight(v[1:], v[0].clone(), key),)
            v = (merge_lora_to_weight(v[1:], v[0].clone(), key),)

        patch_type = ''

@@ -107,8 +115,8 @@ def merge_lora_to_model_weight(patches, weight, key):
            else:
                weight += strength * memory_management.cast_to_device(w1, weight.device, weight.dtype)
        elif patch_type == "lora":
            mat1 = memory_management.cast_to_device(v[0], weight.device, torch.float32)
            mat2 = memory_management.cast_to_device(v[1], weight.device, torch.float32)
            mat1 = memory_management.cast_to_device(v[0], weight.device, computation_dtype)
            mat2 = memory_management.cast_to_device(v[1], weight.device, computation_dtype)
            dora_scale = v[4]
            if v[2] is not None:
                alpha = v[2] / mat2.shape[0]

@@ -116,13 +124,13 @@ def merge_lora_to_model_weight(patches, weight, key):
                alpha = 1.0

            if v[3] is not None:
                mat3 = memory_management.cast_to_device(v[3], weight.device, torch.float32)
                mat3 = memory_management.cast_to_device(v[3], weight.device, computation_dtype)
                final_shape = [mat2.shape[1], mat2.shape[0], mat3.shape[2], mat3.shape[3]]
                mat2 = torch.mm(mat2.transpose(0, 1).flatten(start_dim=1), mat3.transpose(0, 1).flatten(start_dim=1)).reshape(final_shape).transpose(0, 1)
            try:
                lora_diff = torch.mm(mat1.flatten(start_dim=1), mat2.flatten(start_dim=1)).reshape(weight.shape)
                if dora_scale is not None:
                    weight = function(weight_decompose(dora_scale, weight, lora_diff, alpha, strength))
                    weight = function(weight_decompose(dora_scale, weight, lora_diff, alpha, strength, computation_dtype))
                else:
                    weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
            except Exception as e:

@@ -141,23 +149,23 @@ def merge_lora_to_model_weight(patches, weight, key):

            if w1 is None:
                dim = w1_b.shape[0]
                w1 = torch.mm(memory_management.cast_to_device(w1_a, weight.device, torch.float32),
                              memory_management.cast_to_device(w1_b, weight.device, torch.float32))
                w1 = torch.mm(memory_management.cast_to_device(w1_a, weight.device, computation_dtype),
                              memory_management.cast_to_device(w1_b, weight.device, computation_dtype))
            else:
                w1 = memory_management.cast_to_device(w1, weight.device, torch.float32)
                w1 = memory_management.cast_to_device(w1, weight.device, computation_dtype)

            if w2 is None:
                dim = w2_b.shape[0]
                if t2 is None:
                    w2 = torch.mm(memory_management.cast_to_device(w2_a, weight.device, torch.float32),
                                  memory_management.cast_to_device(w2_b, weight.device, torch.float32))
                    w2 = torch.mm(memory_management.cast_to_device(w2_a, weight.device, computation_dtype),
                                  memory_management.cast_to_device(w2_b, weight.device, computation_dtype))
                else:
                    w2 = torch.einsum('i j k l, j r, i p -> p r k l',
                                      memory_management.cast_to_device(t2, weight.device, torch.float32),
                                      memory_management.cast_to_device(w2_b, weight.device, torch.float32),
                                      memory_management.cast_to_device(w2_a, weight.device, torch.float32))
                                      memory_management.cast_to_device(t2, weight.device, computation_dtype),
                                      memory_management.cast_to_device(w2_b, weight.device, computation_dtype),
                                      memory_management.cast_to_device(w2_a, weight.device, computation_dtype))
            else:
                w2 = memory_management.cast_to_device(w2, weight.device, torch.float32)
                w2 = memory_management.cast_to_device(w2, weight.device, computation_dtype)

            if len(w2.shape) == 4:
                w1 = w1.unsqueeze(2).unsqueeze(2)

@@ -169,7 +177,7 @@ def merge_lora_to_model_weight(patches, weight, key):
            try:
                lora_diff = torch.kron(w1, w2).reshape(weight.shape)
                if dora_scale is not None:
                    weight = function(weight_decompose(dora_scale, weight, lora_diff, alpha, strength))
                    weight = function(weight_decompose(dora_scale, weight, lora_diff, alpha, strength, computation_dtype))
                else:
                    weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
            except Exception as e:

@@ -190,24 +198,24 @@ def merge_lora_to_model_weight(patches, weight, key):
                t1 = v[5]
                t2 = v[6]
                m1 = torch.einsum('i j k l, j r, i p -> p r k l',
                                  memory_management.cast_to_device(t1, weight.device, torch.float32),
                                  memory_management.cast_to_device(w1b, weight.device, torch.float32),
                                  memory_management.cast_to_device(w1a, weight.device, torch.float32))
                                  memory_management.cast_to_device(t1, weight.device, computation_dtype),
                                  memory_management.cast_to_device(w1b, weight.device, computation_dtype),
                                  memory_management.cast_to_device(w1a, weight.device, computation_dtype))

                m2 = torch.einsum('i j k l, j r, i p -> p r k l',
                                  memory_management.cast_to_device(t2, weight.device, torch.float32),
                                  memory_management.cast_to_device(w2b, weight.device, torch.float32),
                                  memory_management.cast_to_device(w2a, weight.device, torch.float32))
                                  memory_management.cast_to_device(t2, weight.device, computation_dtype),
                                  memory_management.cast_to_device(w2b, weight.device, computation_dtype),
                                  memory_management.cast_to_device(w2a, weight.device, computation_dtype))
            else:
                m1 = torch.mm(memory_management.cast_to_device(w1a, weight.device, torch.float32),
                              memory_management.cast_to_device(w1b, weight.device, torch.float32))
                m2 = torch.mm(memory_management.cast_to_device(w2a, weight.device, torch.float32),
                              memory_management.cast_to_device(w2b, weight.device, torch.float32))
                m1 = torch.mm(memory_management.cast_to_device(w1a, weight.device, computation_dtype),
                              memory_management.cast_to_device(w1b, weight.device, computation_dtype))
                m2 = torch.mm(memory_management.cast_to_device(w2a, weight.device, computation_dtype),
                              memory_management.cast_to_device(w2b, weight.device, computation_dtype))

            try:
                lora_diff = (m1 * m2).reshape(weight.shape)
                if dora_scale is not None:
                    weight = function(weight_decompose(dora_scale, weight, lora_diff, alpha, strength))
                    weight = function(weight_decompose(dora_scale, weight, lora_diff, alpha, strength, computation_dtype))
                else:
                    weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
            except Exception as e:

@@ -221,15 +229,15 @@ def merge_lora_to_model_weight(patches, weight, key):

            dora_scale = v[5]

            a1 = memory_management.cast_to_device(v[0].flatten(start_dim=1), weight.device, torch.float32)
            a2 = memory_management.cast_to_device(v[1].flatten(start_dim=1), weight.device, torch.float32)
            b1 = memory_management.cast_to_device(v[2].flatten(start_dim=1), weight.device, torch.float32)
            b2 = memory_management.cast_to_device(v[3].flatten(start_dim=1), weight.device, torch.float32)
            a1 = memory_management.cast_to_device(v[0].flatten(start_dim=1), weight.device, computation_dtype)
            a2 = memory_management.cast_to_device(v[1].flatten(start_dim=1), weight.device, computation_dtype)
            b1 = memory_management.cast_to_device(v[2].flatten(start_dim=1), weight.device, computation_dtype)
            b2 = memory_management.cast_to_device(v[3].flatten(start_dim=1), weight.device, computation_dtype)

            try:
                lora_diff = (torch.mm(b2, b1) + torch.mm(torch.mm(weight.flatten(start_dim=1), a2), a1)).reshape(weight.shape)
                if dora_scale is not None:
                    weight = function(weight_decompose(dora_scale, weight, lora_diff, alpha, strength))
                    weight = function(weight_decompose(dora_scale, weight, lora_diff, alpha, strength, computation_dtype))
                else:
                    weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
            except Exception as e:

@@ -243,14 +251,19 @@ def merge_lora_to_model_weight(patches, weight, key):
    if old_weight is not None:
        weight = old_weight

    weight = weight.to(dtype=weight_original_dtype)
    return weight


from backend import operations


class LoraLoader:
    def __init__(self, model):
        self.model = model
        self.patches = {}
        self.backup = {}
        self.online_backup = []
        self.dirty = False

    def clear_patches(self):

@@ -277,7 +290,7 @@ class LoraLoader:
            if key in model_sd:
                p.add(k)
                current_patches = self.patches.get(key, [])
                current_patches.append((strength_patch, patches[k], strength_model, offset, function))
                current_patches.append([strength_patch, patches[k], strength_model, offset, function])
                self.patches[key] = current_patches

        self.dirty = True

@@ -293,10 +306,12 @@ class LoraLoader:

        # Restore

        for k, w in self.backup.items():
            if target_device is not None:
                w = w.to(device=target_device)
        for m in set(self.online_backup):
            del m.forge_online_loras

        self.online_backup = []

        for k, w in self.backup.items():
            if not isinstance(w, torch.nn.Parameter):
                # In very few cases
                w = torch.nn.Parameter(w, requires_grad=False)

@@ -305,11 +320,13 @@ class LoraLoader:

        self.backup = {}

        online_mode = dynamic_args.get('online_lora', False)

        # Patch

        for key, current_patches in (tqdm(self.patches.items(), desc=f'Patching LoRAs for {type(self.model).__name__}') if len(self.patches) > 0 else self.patches):
            try:
                parent_layer, weight = utils.get_attr_with_parent(self.model, key)
                parent_layer, child_key, weight = utils.get_attr_with_parent(self.model, key)
                assert isinstance(weight, torch.nn.Parameter)
            except:
                raise ValueError(f"Wrong LoRA Key: {key}")

@@ -317,6 +334,14 @@ class LoraLoader:
            if key not in self.backup:
                self.backup[key] = weight.to(device=offload_device)

            if online_mode:
                if not hasattr(parent_layer, 'forge_online_loras'):
                    parent_layer.forge_online_loras = {}

                parent_layer.forge_online_loras[child_key] = current_patches
                self.online_backup.append(parent_layer)
                continue

            bnb_layer = None

            if operations.bnb_avaliable:

@@ -357,17 +382,13 @@ class LoraLoader:
                gguf_real_shape = weight.gguf_real_shape
                weight = dequantize_tensor(weight)

            weight_original_dtype = weight.dtype

            try:
                weight = weight.to(dtype=torch.float32)
                weight = merge_lora_to_model_weight(current_patches, weight, key).to(dtype=weight_original_dtype)
                weight = merge_lora_to_weight(current_patches, weight, key, computation_dtype=torch.float32)
            except:
                print('Patching LoRA weights failed. Retrying by offloading models.')
                self.model.to(device=offload_device)
                memory_management.soft_empty_cache()
                weight = weight.to(dtype=torch.float32)
                weight = merge_lora_to_model_weight(current_patches, weight, key).to(dtype=weight_original_dtype)
                weight = merge_lora_to_weight(current_patches, weight, key, computation_dtype=torch.float32)

            if bnb_layer is not None:
                bnb_layer.reload_weight(weight)

@@ -10,6 +10,8 @@ import collections
from backend import memory_management
from backend.sampling.condition import Condition, compile_conditions, compile_weighted_conditions
from backend.operations import cleanup_cache
from backend.args import dynamic_args
from backend import utils


def get_area_and_mult(conds, x_in, timestep_in):

@@ -353,10 +355,19 @@ def sampling_prepare(unet, x):
        additional_inference_memory += unet.controlnet_linked_list.inference_memory_requirements(unet.model_dtype())
        additional_model_patchers += unet.controlnet_linked_list.get_models()

    if dynamic_args.get('online_lora', False):
        lora_memory = utils.nested_compute_size(unet.lora_loader.patches)
        additional_inference_memory += lora_memory

    memory_management.load_models_gpu(
        models=[unet] + additional_model_patchers,
        memory_required=unet_inference_memory + additional_inference_memory)

    if dynamic_args.get('online_lora', False):
        utils.nested_move_to_device(unet.lora_loader.patches, device=unet.current_device)

    unet.lora_loader.patches = {}

    real_model = unet.model

    percent_to_timestep_function = lambda p: real_model.predictor.percent_to_sigma(p)

@@ -79,10 +79,11 @@ def get_attr(obj, attr):
def get_attr_with_parent(obj, attr):
    attrs = attr.split(".")
    parent = obj
    name = None
    for name in attrs:
        parent = obj
        obj = getattr(obj, name)
    return parent, obj
    return parent, name, obj


def calculate_parameters(sd, prefix=""):

@@ -108,3 +109,32 @@ def fp16_fix(x):
    if x.dtype in [torch.float16]:
        return x.clip(-32768.0, 32768.0)
    return x


def nested_compute_size(obj):
    module_mem = 0

    if isinstance(obj, dict):
        for key in obj:
            module_mem += nested_compute_size(obj[key])
    elif isinstance(obj, list) or isinstance(obj, tuple):
        for i in range(len(obj)):
            module_mem += nested_compute_size(obj[i])
    elif isinstance(obj, torch.Tensor):
        module_mem += obj.nelement() * obj.element_size()

    return module_mem


def nested_move_to_device(obj, device):
    if isinstance(obj, dict):
        for key in obj:
            obj[key] = nested_move_to_device(obj[key], device)
    elif isinstance(obj, list):
        for i in range(len(obj)):
            obj[i] = nested_move_to_device(obj[i], device)
    elif isinstance(obj, tuple):
        obj = tuple(nested_move_to_device(i, device) for i in obj)
    elif isinstance(obj, torch.Tensor):
        return obj.to(device)
    return obj

@@ -1,3 +1,6 @@
# import faulthandler
# faulthandler.enable()

from modules import launch_utils

args = launch_utils.args

@@ -799,6 +799,7 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
        memory_management.unload_all_models()

        if need_global_unload:
            p.sd_model.forge_objects.unet.lora_loader.dirty = True
            p.clear_prompt_cache()

        need_global_unload = False

@@ -5,6 +5,7 @@ import gradio as gr
from gradio.context import Context
from modules import shared_items, shared, ui_common, sd_models, processing, infotext_utils, paths
from backend import memory_management, stream
from backend.args import dynamic_args


total_vram = int(memory_management.total_vram)

@@ -21,11 +22,14 @@ ui_forge_pin_shared_memory: gr.Radio = None
ui_forge_inference_memory: gr.Slider = None

forge_unet_storage_dtype_options = {
    'Automatic': None,
    'bnb-nf4': 'nf4',
    'float8-e4m3fn': torch.float8_e4m3fn,
    'bnb-fp4': 'fp4',
    'float8-e5m2': torch.float8_e5m2,
    'Automatic': (None, False),
    'Automatic (fp16 LoRA)': (None, True),
    'bnb-nf4': ('nf4', False),
    'float8-e4m3fn': (torch.float8_e4m3fn, False),
    'float8-e4m3fn (fp16 LoRA)': (torch.float8_e4m3fn, True),
    'bnb-fp4': ('fp4', False),
    'float8-e5m2': (torch.float8_e5m2, False),
    'float8-e5m2 (fp16 LoRA)': (torch.float8_e5m2, True),
}

module_list = {}

@@ -180,10 +184,14 @@ def refresh_model_loading_parameters():

    checkpoint_info = select_checkpoint()

    unet_storage_dtype, lora_fp16 = forge_unet_storage_dtype_options.get(shared.opts.forge_unet_storage_dtype, (None, False))

    dynamic_args['online_lora'] = lora_fp16

    model_data.forge_loading_parameters = dict(
        checkpoint_info=checkpoint_info,
        additional_modules=shared.opts.forge_additional_modules,
        unet_storage_dtype=forge_unet_storage_dtype_options.get(shared.opts.forge_unet_storage_dtype, None)
        unet_storage_dtype=unet_storage_dtype
    )

    print(f'Model selected: {model_data.forge_loading_parameters}')

packages_3rdparty/gguf/quants.py (vendored, 71 lines changed)

@@ -158,7 +158,7 @@ class __Quant(ABC):
    @classmethod
    @abstractmethod
    def quantize_blocks_pytorch(cls, blocks, block_size, type_size) -> torch.Tensor:
        raise NotImplementedError
        raise NotImplementedError('Low bit LoRA for this data type is not implemented yet. Please select "Automatic (fp16 LoRA)" to use this LoRA.')

    @classmethod
    @abstractmethod

@@ -370,30 +370,6 @@ class Q4_1(__Quant, qtype=GGMLQuantizationType.Q4_1):

        return (d * qs) + m

    @classmethod
    def quantize_blocks_pytorch(cls, blocks, block_size, type_size) -> torch.Tensor:
        # WIP

        raise NotImplementedError('Q4_1 Lora is under construction!')

        n_blocks = blocks.shape[0]

        max_vals = blocks.max(dim=-1, keepdim=True).values
        min_vals = blocks.min(dim=-1, keepdim=True).values

        d = (max_vals - min_vals) / 15
        id = torch.where(d == 0, torch.tensor(0.0, device=d.device), 1 / d)

        qs = torch.trunc((blocks - min_vals) * id + 0.5).to(torch.uint8).clip(0, 15)

        qs = qs.view(n_blocks, 2, block_size // 2)
        qs = qs[:, 0, :] | (qs[:, 1, :] << 4)

        d = d.to(torch.float16).view(n_blocks, -1)
        m = min_vals.to(torch.float16).view(n_blocks, -1)

        return torch.cat([d, m, qs], dim=-1)


class Q5_0(__Quant, qtype=GGMLQuantizationType.Q5_0):
    @classmethod

@@ -567,31 +543,6 @@ class Q5_1(__Quant, qtype=GGMLQuantizationType.Q5_1):
        qs = (ql | (qh << 4))
        return (d * qs) + m

    @classmethod
    def quantize_blocks_pytorch(cls, blocks, block_size, type_size) -> torch.Tensor:
        # WIP

        raise NotImplementedError('Q5_1 Lora is under construction!')

        n_blocks = blocks.shape[0]

        max_val = blocks.max(dim=-1, keepdim=True)[0]
        min_val = blocks.min(dim=-1, keepdim=True)[0]

        d = (max_val - min_val) / 31
        id = torch.where(d == 0, torch.zeros_like(d), 1.0 / d)
        q = torch.trunc((blocks - min_val) * id + 0.5).clip(0, 31).to(torch.uint8)

        qs = q.view(n_blocks, 2, block_size // 2)
        qs = (qs[..., 0, :] & 0x0F) | (qs[..., 1, :] << 4)

        qh = torch.bitwise_right_shift(q.view(n_blocks, 1, 32), torch.arange(4, dtype=torch.uint8, device=blocks.device) * 8).byte()

        d = d.to(torch.float16).view(-1, 2)
        min_val = min_val.to(torch.float16).view(-1, 2)

        return torch.cat([d.byte(), min_val.byte(), qh, qs], dim=-1)


class Q8_0(__Quant, qtype=GGMLQuantizationType.Q8_0):
    @classmethod

@@ -677,10 +628,6 @@ class Q2_K(__Quant, qtype=GGMLQuantizationType.Q2_K):
        qs = dl * qs - ml
        return qs.reshape((n_blocks, -1))

    @classmethod
    def quantize_blocks_pytorch(cls, blocks, block_size, type_size) -> torch.Tensor:
        raise NotImplementedError('Not Implemented Yet')


class Q3_K(__Quant, qtype=GGMLQuantizationType.Q3_K):
    @classmethod

@@ -746,10 +693,6 @@ class Q3_K(__Quant, qtype=GGMLQuantizationType.Q3_K):
        q = (ql.to(torch.int8) - (qh << 2).to(torch.int8))
        return (dl * q).reshape((n_blocks, QK_K))

    @classmethod
    def quantize_blocks_pytorch(cls, blocks, block_size, type_size) -> torch.Tensor:
        raise NotImplementedError('Not Implemented Yet')


class Q4_K(__Quant, qtype=GGMLQuantizationType.Q4_K):
    K_SCALE_SIZE = 12

@@ -826,10 +769,6 @@ class Q4_K(__Quant, qtype=GGMLQuantizationType.Q4_K):
        qs = (qs & 0x0F).reshape((n_blocks, -1, 32))
        return (d * qs - dm).reshape((n_blocks, QK_K))

    @classmethod
    def quantize_blocks_pytorch(cls, blocks, block_size, type_size) -> torch.Tensor:
        raise NotImplementedError('Not Implemented Yet')


class Q5_K(__Quant, qtype=GGMLQuantizationType.Q5_K):
    @classmethod

@@ -876,10 +815,6 @@ class Q5_K(__Quant, qtype=GGMLQuantizationType.Q5_K):
        q = (ql | (qh << 4))
        return (d * q - dm).reshape((n_blocks, QK_K))

    @classmethod
    def quantize_blocks_pytorch(cls, blocks, block_size, type_size) -> torch.Tensor:
        raise NotImplementedError('Not Implemented Yet')


class Q6_K(__Quant, qtype=GGMLQuantizationType.Q6_K):
    @classmethod

@@ -919,10 +854,6 @@ class Q6_K(__Quant, qtype=GGMLQuantizationType.Q6_K):
        q = q.reshape((n_blocks, QK_K // 16, -1))
        return (d * q).reshape((n_blocks, QK_K))

    @classmethod
    def quantize_blocks_pytorch(cls, blocks, block_size, type_size) -> torch.Tensor:
        raise NotImplementedError('Not Implemented Yet')


class IQ2_XXS(__Quant, qtype=GGMLQuantizationType.IQ2_XXS):
    ksigns: bytes = (