Rework the LoRA system

1. Add an option that lets users keep the UNet in fp8/gguf while applying LoRAs in fp16.
2. FP16 LoRAs no longer need weight patching; other formats are re-patched only when the LoRA weights actually change.
3. FP8 UNet + fp16 LoRA is now available in Forge (and, for the moment, essentially only in Forge). This also fixes some "LoRA effect too subtle" problems.
4. Significantly speed up all GGUF models (in Async mode) by using an independent thread (CUDA stream) so that dequantization and computation run concurrently, even when the low-bit weights are already on the GPU (see the stream sketch below).
5. Treat the "online LoRA" as a module, similar to ControlLoRA, so it is moved to the GPU together with the model when sampling. This gives a significant speedup and clean low-VRAM management at the same time (see the online-merge sketch below).
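The idea behind the fp16 "online LoRA" (items 1-3 and 5) can be shown with a minimal sketch, not the actual Forge implementation: patches attached to a layer as forge_online_loras (the attribute used in the diff below) are merged into a temporary high-precision copy of the weight at forward time, so the stored fp8/gguf weight is never rewritten and switching LoRAs needs no re-patching. OnlineLoraLinear and merge_lora_to_weight_sketch are hypothetical names used only for illustration.

import torch

def merge_lora_to_weight_sketch(patches, weight):
    # Plain "lora" merge: weight + strength * (up @ down), computed in the weight's dtype.
    for strength, (up, down) in patches:
        weight = weight + strength * (up.to(weight.dtype) @ down.to(weight.dtype))
    return weight

class OnlineLoraLinear(torch.nn.Linear):
    # Hypothetical layer: LoRA patches attached as `forge_online_loras` are merged into a
    # temporary fp16 copy of the weight at forward time; the stored weight stays untouched.
    def forward(self, x):
        patches = getattr(self, 'forge_online_loras', {})
        weight, bias = self.weight, self.bias
        if 'weight' in patches:
            weight = merge_lora_to_weight_sketch(patches['weight'], weight.to(torch.float16))
        return torch.nn.functional.linear(x, weight.to(x.dtype), bias)

layer = OnlineLoraLinear(8, 8)
up, down = torch.randn(8, 4), torch.randn(4, 8)
layer.forge_online_loras = {'weight': [(0.8, (up, down))]}
y = layer(torch.randn(1, 8))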
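Item 4 uses the standard CUDA stream/event pattern: a secondary stream dequantizes (or moves) a weight while the main stream is still computing, and an event hands the result over. Below is a rough, self-contained illustration of that pattern in plain PyTorch, not the Forge code itself (Forge routes it through stream.mover_stream and main_stream_worker, visible in the diff); the float8 tensor only stands in for a GGUF low-bit weight, and the snippet assumes a CUDA device and a PyTorch version with float8 support.

import torch

device = torch.device('cuda')
main_stream = torch.cuda.current_stream(device)
mover_stream = torch.cuda.Stream(device=device)

x = torch.randn(16, 1024, device=device, dtype=torch.float16)
stored = torch.randn(1024, 1024, device=device).to(torch.float8_e4m3fn)  # stand-in for a quantized weight

mover_stream.wait_stream(main_stream)          # make sure `stored` is ready before the mover touches it
with torch.cuda.stream(mover_stream):
    weight = stored.to(torch.float16)          # "dequantize" on the mover stream
    ready = mover_stream.record_event()        # signal: the dequantized weight is usable

main_stream.wait_event(ready)                  # the compute stream blocks only if the weight is not ready yet
y = torch.nn.functional.linear(x, weight)      # compute can overlap with the next layer's dequantization
weight.record_stream(main_stream)              # keep the buffer alive until the compute stream is done with it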
layerdiffusion
2024-08-19 04:31:00 -07:00
parent e5f213c21e
commit d38e560e42
11 changed files with 200 additions and 159 deletions

View File

@@ -378,6 +378,7 @@ class LoadedModel:
try:
self.real_model = self.model.forge_patch_model(patch_model_to)
self.model.current_device = self.model.load_device
except Exception as e:
self.model.forge_unpatch_model(self.model.offload_device)
self.model_unload()

View File

@@ -5,12 +5,53 @@ import torch
import contextlib
from backend import stream, memory_management, utils
from backend.patcher.lora import merge_lora_to_weight
stash = {}
def weights_manual_cast(layer, x, skip_weight_dtype=False, skip_bias_dtype=False):
def get_weight_and_bias(layer, weight_args=None, bias_args=None, weight_fn=None, bias_fn=None):
patches = getattr(layer, 'forge_online_loras', None)
weight_patches, bias_patches = None, None
if patches is not None:
weight_patches = patches.get('weight', None)
if patches is not None:
bias_patches = patches.get('bias', None)
weight = None
if layer.weight is not None:
weight = layer.weight
if weight_fn is not None:
if weight_args is not None:
fn_device = weight_args.get('device', None)
if fn_device is not None:
weight = weight.to(device=fn_device)
weight = weight_fn(weight)
if weight_args is not None:
weight = weight.to(**weight_args)
if weight_patches is not None:
weight = merge_lora_to_weight(patches=weight_patches, weight=weight, key="online weight lora", computation_dtype=weight.dtype)
bias = None
if layer.bias is not None:
bias = layer.bias
if bias_fn is not None:
if bias_args is not None:
fn_device = bias_args.get('device', None)
if fn_device is not None:
bias = bias.to(device=fn_device)
bias = bias_fn(bias)
if bias_args is not None:
bias = bias.to(**bias_args)
if bias_patches is not None:
bias = merge_lora_to_weight(patches=bias_patches, weight=bias, key="online bias lora", computation_dtype=bias.dtype)
return weight, bias
def weights_manual_cast(layer, x, skip_weight_dtype=False, skip_bias_dtype=False, weight_fn=None, bias_fn=None):
weight, bias, signal = None, None, None
non_blocking = True
@@ -32,16 +73,10 @@ def weights_manual_cast(layer, x, skip_weight_dtype=False, skip_bias_dtype=False
if stream.should_use_stream():
with stream.stream_context()(stream.mover_stream):
if layer.weight is not None:
weight = layer.weight.to(**weight_args)
if layer.bias is not None:
bias = layer.bias.to(**bias_args)
weight, bias = get_weight_and_bias(layer, weight_args, bias_args, weight_fn=weight_fn, bias_fn=bias_fn)
signal = stream.mover_stream.record_event()
else:
if layer.weight is not None:
weight = layer.weight.to(**weight_args)
if layer.bias is not None:
bias = layer.bias.to(**bias_args)
weight, bias = get_weight_and_bias(layer, weight_args, bias_args, weight_fn=weight_fn, bias_fn=bias_fn)
return weight, bias, signal
@@ -109,7 +144,8 @@ class ForgeOperations:
with main_stream_worker(weight, bias, signal):
return torch.nn.functional.linear(x, weight, bias)
else:
return torch.nn.functional.linear(x, self.weight, self.bias)
weight, bias = get_weight_and_bias(self)
return torch.nn.functional.linear(x, weight, bias)
class Conv2d(torch.nn.Conv2d):
@@ -128,7 +164,8 @@ class ForgeOperations:
with main_stream_worker(weight, bias, signal):
return self._conv_forward(x, weight, bias)
else:
return super().forward(x)
weight, bias = get_weight_and_bias(self)
return super()._conv_forward(x, weight, bias)
class Conv3d(torch.nn.Conv3d):
@@ -147,7 +184,8 @@ class ForgeOperations:
with main_stream_worker(weight, bias, signal):
return self._conv_forward(x, weight, bias)
else:
return super().forward(x)
weight, bias = get_weight_and_bias(self)
return super()._conv_forward(x, weight, bias)
class Conv1d(torch.nn.Conv1d):
@@ -166,7 +204,8 @@ class ForgeOperations:
with main_stream_worker(weight, bias, signal):
return self._conv_forward(x, weight, bias)
else:
return super().forward(x)
weight, bias = get_weight_and_bias(self)
return super()._conv_forward(x, weight, bias)
class ConvTranspose2d(torch.nn.ConvTranspose2d):
@@ -188,7 +227,10 @@ class ForgeOperations:
with main_stream_worker(weight, bias, signal):
return torch.nn.functional.conv_transpose2d(x, weight, bias, self.stride, self.padding, output_padding, self.groups, self.dilation)
else:
return super().forward(x, output_size)
weight, bias = get_weight_and_bias(self)
num_spatial_dims = 2
output_padding = self._output_padding(x, output_size, self.stride, self.padding, self.kernel_size, num_spatial_dims, self.dilation)
return torch.nn.functional.conv_transpose2d(x, weight, bias, self.stride, self.padding, output_padding, self.groups, self.dilation)
class ConvTranspose1d(torch.nn.ConvTranspose1d):
@@ -210,7 +252,10 @@ class ForgeOperations:
with main_stream_worker(weight, bias, signal):
return torch.nn.functional.conv_transpose1d(x, weight, bias, self.stride, self.padding, output_padding, self.groups, self.dilation)
else:
return super().forward(x, output_size)
weight, bias = get_weight_and_bias(self)
num_spatial_dims = 1
output_padding = self._output_padding(x, output_size, self.stride, self.padding, self.kernel_size, num_spatial_dims, self.dilation)
return torch.nn.functional.conv_transpose1d(x, weight, bias, self.stride, self.padding, output_padding, self.groups, self.dilation)
class ConvTranspose3d(torch.nn.ConvTranspose3d):
@@ -232,7 +277,10 @@ class ForgeOperations:
with main_stream_worker(weight, bias, signal):
return torch.nn.functional.conv_transpose3d(x, weight, bias, self.stride, self.padding, output_padding, self.groups, self.dilation)
else:
return super().forward(x, output_size)
weight, bias = get_weight_and_bias(self)
num_spatial_dims = 3
output_padding = self._output_padding(x, output_size, self.stride, self.padding, self.kernel_size, num_spatial_dims, self.dilation)
return torch.nn.functional.conv_transpose3d(x, weight, bias, self.stride, self.padding, output_padding, self.groups, self.dilation)
class GroupNorm(torch.nn.GroupNorm):
@@ -328,7 +376,7 @@ except:
bnb_avaliable = False
from backend.operations_gguf import functional_linear_gguf
from backend.operations_gguf import dequantize_tensor
class ForgeOperationsGGUF(ForgeOperations):
@@ -361,12 +409,9 @@ class ForgeOperationsGGUF(ForgeOperations):
return self
def forward(self, x):
if self.parameters_manual_cast:
weight, bias, signal = weights_manual_cast(self, x, skip_weight_dtype=True)
with main_stream_worker(weight, bias, signal):
return functional_linear_gguf(x, weight, bias)
else:
return functional_linear_gguf(x, self.weight, self.bias)
weight, bias, signal = weights_manual_cast(self, x, weight_fn=dequantize_tensor, bias_fn=dequantize_tensor)
with main_stream_worker(weight, bias, signal):
return torch.nn.functional.linear(x, weight, bias)
@contextlib.contextmanager

View File

@@ -58,13 +58,6 @@ class ParameterGGUF(torch.nn.Parameter):
return new
def functional_linear_gguf(x, weight, bias=None):
target_dtype = x.dtype
weight = dequantize_tensor(weight).to(target_dtype)
bias = dequantize_tensor(bias).to(target_dtype)
return torch.nn.functional.linear(x, weight, bias)
def dequantize_tensor(tensor):
if tensor is None:
return None

View File

@@ -4,14 +4,12 @@
# are from Forge, implemented from scratch (after forge-v1.0.1), and may have
# certain level of differences.
import time
import torch
import copy
import inspect
from tqdm import tqdm
from backend import memory_management, utils, operations
from backend.patcher.lora import merge_lora_to_model_weight, LoraLoader
from backend import memory_management, utils
from backend.patcher.lora import LoraLoader
def set_model_options_patch_replace(model_options, patch, name, block_name, number, transformer_index=None):
@@ -229,7 +227,6 @@ class ModelPatcher:
if target_device is not None:
self.model.to(target_device)
self.current_device = target_device
return self.model

View File

@@ -5,7 +5,8 @@ import packages_3rdparty.webui_lora_collection.lora as lora_utils_webui
import packages_3rdparty.comfyui_lora_collection.lora as lora_utils_comfyui
from tqdm import tqdm
from backend import memory_management, utils, operations
from backend import memory_management, utils
from backend.args import dynamic_args
class ForgeLoraCollection:
@@ -39,8 +40,10 @@ def model_lora_keys_unet(model, key_map={}):
return get_function('model_lora_keys_unet')(model, key_map)
def weight_decompose(dora_scale, weight, lora_diff, alpha, strength):
dora_scale = memory_management.cast_to_device(dora_scale, weight.device, torch.float32)
def weight_decompose(dora_scale, weight, lora_diff, alpha, strength, computation_dtype):
# Modified from https://github.com/comfyanonymous/ComfyUI/blob/39f114c44bb99d4a221e8da451d4f2a20119c674/comfy/model_patcher.py#L33
dora_scale = memory_management.cast_to_device(dora_scale, weight.device, computation_dtype)
lora_diff *= alpha
weight_calc = weight + lora_diff.type(weight.dtype)
weight_norm = (
@@ -60,7 +63,12 @@ def weight_decompose(dora_scale, weight, lora_diff, alpha, strength):
return weight
def merge_lora_to_model_weight(patches, weight, key):
def merge_lora_to_weight(patches, weight, key="online_lora", computation_dtype=torch.float32):
# Modified from https://github.com/comfyanonymous/ComfyUI/blob/39f114c44bb99d4a221e8da451d4f2a20119c674/comfy/model_patcher.py#L446
weight_original_dtype = weight.dtype
weight = weight.to(dtype=computation_dtype)
for p in patches:
strength = p[0]
v = p[1]
@@ -79,7 +87,7 @@ def merge_lora_to_model_weight(patches, weight, key):
weight *= strength_model
if isinstance(v, list):
v = (merge_lora_to_model_weight(v[1:], v[0].clone(), key),)
v = (merge_lora_to_weight(v[1:], v[0].clone(), key),)
patch_type = ''
@@ -107,8 +115,8 @@ def merge_lora_to_model_weight(patches, weight, key):
else:
weight += strength * memory_management.cast_to_device(w1, weight.device, weight.dtype)
elif patch_type == "lora":
mat1 = memory_management.cast_to_device(v[0], weight.device, torch.float32)
mat2 = memory_management.cast_to_device(v[1], weight.device, torch.float32)
mat1 = memory_management.cast_to_device(v[0], weight.device, computation_dtype)
mat2 = memory_management.cast_to_device(v[1], weight.device, computation_dtype)
dora_scale = v[4]
if v[2] is not None:
alpha = v[2] / mat2.shape[0]
@@ -116,13 +124,13 @@ def merge_lora_to_model_weight(patches, weight, key):
alpha = 1.0
if v[3] is not None:
mat3 = memory_management.cast_to_device(v[3], weight.device, torch.float32)
mat3 = memory_management.cast_to_device(v[3], weight.device, computation_dtype)
final_shape = [mat2.shape[1], mat2.shape[0], mat3.shape[2], mat3.shape[3]]
mat2 = torch.mm(mat2.transpose(0, 1).flatten(start_dim=1), mat3.transpose(0, 1).flatten(start_dim=1)).reshape(final_shape).transpose(0, 1)
try:
lora_diff = torch.mm(mat1.flatten(start_dim=1), mat2.flatten(start_dim=1)).reshape(weight.shape)
if dora_scale is not None:
weight = function(weight_decompose(dora_scale, weight, lora_diff, alpha, strength))
weight = function(weight_decompose(dora_scale, weight, lora_diff, alpha, strength, computation_dtype))
else:
weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
except Exception as e:
@@ -141,23 +149,23 @@ def merge_lora_to_model_weight(patches, weight, key):
if w1 is None:
dim = w1_b.shape[0]
w1 = torch.mm(memory_management.cast_to_device(w1_a, weight.device, torch.float32),
memory_management.cast_to_device(w1_b, weight.device, torch.float32))
w1 = torch.mm(memory_management.cast_to_device(w1_a, weight.device, computation_dtype),
memory_management.cast_to_device(w1_b, weight.device, computation_dtype))
else:
w1 = memory_management.cast_to_device(w1, weight.device, torch.float32)
w1 = memory_management.cast_to_device(w1, weight.device, computation_dtype)
if w2 is None:
dim = w2_b.shape[0]
if t2 is None:
w2 = torch.mm(memory_management.cast_to_device(w2_a, weight.device, torch.float32),
memory_management.cast_to_device(w2_b, weight.device, torch.float32))
w2 = torch.mm(memory_management.cast_to_device(w2_a, weight.device, computation_dtype),
memory_management.cast_to_device(w2_b, weight.device, computation_dtype))
else:
w2 = torch.einsum('i j k l, j r, i p -> p r k l',
memory_management.cast_to_device(t2, weight.device, torch.float32),
memory_management.cast_to_device(w2_b, weight.device, torch.float32),
memory_management.cast_to_device(w2_a, weight.device, torch.float32))
memory_management.cast_to_device(t2, weight.device, computation_dtype),
memory_management.cast_to_device(w2_b, weight.device, computation_dtype),
memory_management.cast_to_device(w2_a, weight.device, computation_dtype))
else:
w2 = memory_management.cast_to_device(w2, weight.device, torch.float32)
w2 = memory_management.cast_to_device(w2, weight.device, computation_dtype)
if len(w2.shape) == 4:
w1 = w1.unsqueeze(2).unsqueeze(2)
@@ -169,7 +177,7 @@ def merge_lora_to_model_weight(patches, weight, key):
try:
lora_diff = torch.kron(w1, w2).reshape(weight.shape)
if dora_scale is not None:
weight = function(weight_decompose(dora_scale, weight, lora_diff, alpha, strength))
weight = function(weight_decompose(dora_scale, weight, lora_diff, alpha, strength, computation_dtype))
else:
weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
except Exception as e:
@@ -190,24 +198,24 @@ def merge_lora_to_model_weight(patches, weight, key):
t1 = v[5]
t2 = v[6]
m1 = torch.einsum('i j k l, j r, i p -> p r k l',
memory_management.cast_to_device(t1, weight.device, torch.float32),
memory_management.cast_to_device(w1b, weight.device, torch.float32),
memory_management.cast_to_device(w1a, weight.device, torch.float32))
memory_management.cast_to_device(t1, weight.device, computation_dtype),
memory_management.cast_to_device(w1b, weight.device, computation_dtype),
memory_management.cast_to_device(w1a, weight.device, computation_dtype))
m2 = torch.einsum('i j k l, j r, i p -> p r k l',
memory_management.cast_to_device(t2, weight.device, torch.float32),
memory_management.cast_to_device(w2b, weight.device, torch.float32),
memory_management.cast_to_device(w2a, weight.device, torch.float32))
memory_management.cast_to_device(t2, weight.device, computation_dtype),
memory_management.cast_to_device(w2b, weight.device, computation_dtype),
memory_management.cast_to_device(w2a, weight.device, computation_dtype))
else:
m1 = torch.mm(memory_management.cast_to_device(w1a, weight.device, torch.float32),
memory_management.cast_to_device(w1b, weight.device, torch.float32))
m2 = torch.mm(memory_management.cast_to_device(w2a, weight.device, torch.float32),
memory_management.cast_to_device(w2b, weight.device, torch.float32))
m1 = torch.mm(memory_management.cast_to_device(w1a, weight.device, computation_dtype),
memory_management.cast_to_device(w1b, weight.device, computation_dtype))
m2 = torch.mm(memory_management.cast_to_device(w2a, weight.device, computation_dtype),
memory_management.cast_to_device(w2b, weight.device, computation_dtype))
try:
lora_diff = (m1 * m2).reshape(weight.shape)
if dora_scale is not None:
weight = function(weight_decompose(dora_scale, weight, lora_diff, alpha, strength))
weight = function(weight_decompose(dora_scale, weight, lora_diff, alpha, strength, computation_dtype))
else:
weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
except Exception as e:
@@ -221,15 +229,15 @@ def merge_lora_to_model_weight(patches, weight, key):
dora_scale = v[5]
a1 = memory_management.cast_to_device(v[0].flatten(start_dim=1), weight.device, torch.float32)
a2 = memory_management.cast_to_device(v[1].flatten(start_dim=1), weight.device, torch.float32)
b1 = memory_management.cast_to_device(v[2].flatten(start_dim=1), weight.device, torch.float32)
b2 = memory_management.cast_to_device(v[3].flatten(start_dim=1), weight.device, torch.float32)
a1 = memory_management.cast_to_device(v[0].flatten(start_dim=1), weight.device, computation_dtype)
a2 = memory_management.cast_to_device(v[1].flatten(start_dim=1), weight.device, computation_dtype)
b1 = memory_management.cast_to_device(v[2].flatten(start_dim=1), weight.device, computation_dtype)
b2 = memory_management.cast_to_device(v[3].flatten(start_dim=1), weight.device, computation_dtype)
try:
lora_diff = (torch.mm(b2, b1) + torch.mm(torch.mm(weight.flatten(start_dim=1), a2), a1)).reshape(weight.shape)
if dora_scale is not None:
weight = function(weight_decompose(dora_scale, weight, lora_diff, alpha, strength))
weight = function(weight_decompose(dora_scale, weight, lora_diff, alpha, strength, computation_dtype))
else:
weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
except Exception as e:
@@ -243,14 +251,19 @@ def merge_lora_to_model_weight(patches, weight, key):
if old_weight is not None:
weight = old_weight
weight = weight.to(dtype=weight_original_dtype)
return weight
from backend import operations
class LoraLoader:
def __init__(self, model):
self.model = model
self.patches = {}
self.backup = {}
self.online_backup = []
self.dirty = False
def clear_patches(self):
@@ -277,7 +290,7 @@ class LoraLoader:
if key in model_sd:
p.add(k)
current_patches = self.patches.get(key, [])
current_patches.append((strength_patch, patches[k], strength_model, offset, function))
current_patches.append([strength_patch, patches[k], strength_model, offset, function])
self.patches[key] = current_patches
self.dirty = True
@@ -293,10 +306,12 @@ class LoraLoader:
# Restore
for k, w in self.backup.items():
if target_device is not None:
w = w.to(device=target_device)
for m in set(self.online_backup):
del m.forge_online_loras
self.online_backup = []
for k, w in self.backup.items():
if not isinstance(w, torch.nn.Parameter):
# In very few cases
w = torch.nn.Parameter(w, requires_grad=False)
@@ -305,11 +320,13 @@ class LoraLoader:
self.backup = {}
online_mode = dynamic_args.get('online_lora', False)
# Patch
for key, current_patches in (tqdm(self.patches.items(), desc=f'Patching LoRAs for {type(self.model).__name__}') if len(self.patches) > 0 else self.patches):
try:
parent_layer, weight = utils.get_attr_with_parent(self.model, key)
parent_layer, child_key, weight = utils.get_attr_with_parent(self.model, key)
assert isinstance(weight, torch.nn.Parameter)
except:
raise ValueError(f"Wrong LoRA Key: {key}")
@@ -317,6 +334,14 @@ class LoraLoader:
if key not in self.backup:
self.backup[key] = weight.to(device=offload_device)
if online_mode:
if not hasattr(parent_layer, 'forge_online_loras'):
parent_layer.forge_online_loras = {}
parent_layer.forge_online_loras[child_key] = current_patches
self.online_backup.append(parent_layer)
continue
bnb_layer = None
if operations.bnb_avaliable:
@@ -357,17 +382,13 @@ class LoraLoader:
gguf_real_shape = weight.gguf_real_shape
weight = dequantize_tensor(weight)
weight_original_dtype = weight.dtype
try:
weight = weight.to(dtype=torch.float32)
weight = merge_lora_to_model_weight(current_patches, weight, key).to(dtype=weight_original_dtype)
weight = merge_lora_to_weight(current_patches, weight, key, computation_dtype=torch.float32)
except:
print('Patching LoRA weights failed. Retrying by offloading models.')
self.model.to(device=offload_device)
memory_management.soft_empty_cache()
weight = weight.to(dtype=torch.float32)
weight = merge_lora_to_model_weight(current_patches, weight, key).to(dtype=weight_original_dtype)
weight = merge_lora_to_weight(current_patches, weight, key, computation_dtype=torch.float32)
if bnb_layer is not None:
bnb_layer.reload_weight(weight)

View File

@@ -10,6 +10,8 @@ import collections
from backend import memory_management
from backend.sampling.condition import Condition, compile_conditions, compile_weighted_conditions
from backend.operations import cleanup_cache
from backend.args import dynamic_args
from backend import utils
def get_area_and_mult(conds, x_in, timestep_in):
@@ -353,10 +355,19 @@ def sampling_prepare(unet, x):
additional_inference_memory += unet.controlnet_linked_list.inference_memory_requirements(unet.model_dtype())
additional_model_patchers += unet.controlnet_linked_list.get_models()
if dynamic_args.get('online_lora', False):
lora_memory = utils.nested_compute_size(unet.lora_loader.patches)
additional_inference_memory += lora_memory
memory_management.load_models_gpu(
models=[unet] + additional_model_patchers,
memory_required=unet_inference_memory + additional_inference_memory)
if dynamic_args.get('online_lora', False):
utils.nested_move_to_device(unet.lora_loader.patches, device=unet.current_device)
unet.lora_loader.patches = {}
real_model = unet.model
percent_to_timestep_function = lambda p: real_model.predictor.percent_to_sigma(p)

View File

@@ -79,10 +79,11 @@ def get_attr(obj, attr):
def get_attr_with_parent(obj, attr):
attrs = attr.split(".")
parent = obj
name = None
for name in attrs:
parent = obj
obj = getattr(obj, name)
return parent, obj
return parent, name, obj
def calculate_parameters(sd, prefix=""):
@@ -108,3 +109,32 @@ def fp16_fix(x):
if x.dtype in [torch.float16]:
return x.clip(-32768.0, 32768.0)
return x
def nested_compute_size(obj):
module_mem = 0
if isinstance(obj, dict):
for key in obj:
module_mem += nested_compute_size(obj[key])
elif isinstance(obj, list) or isinstance(obj, tuple):
for i in range(len(obj)):
module_mem += nested_compute_size(obj[i])
elif isinstance(obj, torch.Tensor):
module_mem += obj.nelement() * obj.element_size()
return module_mem
def nested_move_to_device(obj, device):
if isinstance(obj, dict):
for key in obj:
obj[key] = nested_move_to_device(obj[key], device)
elif isinstance(obj, list):
for i in range(len(obj)):
obj[i] = nested_move_to_device(obj[i], device)
elif isinstance(obj, tuple):
obj = tuple(nested_move_to_device(i, device) for i in obj)
elif isinstance(obj, torch.Tensor):
return obj.to(device)
return obj

View File

@@ -1,3 +1,6 @@
# import faulthandler
# faulthandler.enable()
from modules import launch_utils
args = launch_utils.args

View File

@@ -799,6 +799,7 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
memory_management.unload_all_models()
if need_global_unload:
p.sd_model.forge_objects.unet.lora_loader.dirty = True
p.clear_prompt_cache()
need_global_unload = False

View File

@@ -5,6 +5,7 @@ import gradio as gr
from gradio.context import Context
from modules import shared_items, shared, ui_common, sd_models, processing, infotext_utils, paths
from backend import memory_management, stream
from backend.args import dynamic_args
total_vram = int(memory_management.total_vram)
@@ -21,11 +22,14 @@ ui_forge_pin_shared_memory: gr.Radio = None
ui_forge_inference_memory: gr.Slider = None
forge_unet_storage_dtype_options = {
'Automatic': None,
'bnb-nf4': 'nf4',
'float8-e4m3fn': torch.float8_e4m3fn,
'bnb-fp4': 'fp4',
'float8-e5m2': torch.float8_e5m2,
'Automatic': (None, False),
'Automatic (fp16 LoRA)': (None, True),
'bnb-nf4': ('nf4', False),
'float8-e4m3fn': (torch.float8_e4m3fn, False),
'float8-e4m3fn (fp16 LoRA)': (torch.float8_e4m3fn, True),
'bnb-fp4': ('fp4', False),
'float8-e5m2': (torch.float8_e5m2, False),
'float8-e5m2 (fp16 LoRA)': (torch.float8_e5m2, True),
}
module_list = {}
@@ -180,10 +184,14 @@ def refresh_model_loading_parameters():
checkpoint_info = select_checkpoint()
unet_storage_dtype, lora_fp16 = forge_unet_storage_dtype_options.get(shared.opts.forge_unet_storage_dtype, (None, False))
dynamic_args['online_lora'] = lora_fp16
model_data.forge_loading_parameters = dict(
checkpoint_info=checkpoint_info,
additional_modules=shared.opts.forge_additional_modules,
unet_storage_dtype=forge_unet_storage_dtype_options.get(shared.opts.forge_unet_storage_dtype, None)
unet_storage_dtype=unet_storage_dtype
)
print(f'Model selected: {model_data.forge_loading_parameters}')

View File

@@ -158,7 +158,7 @@ class __Quant(ABC):
@classmethod
@abstractmethod
def quantize_blocks_pytorch(cls, blocks, block_size, type_size) -> torch.Tensor:
raise NotImplementedError
raise NotImplementedError('Low bit LoRA for this data type is not implemented yet. Please select "Automatic (fp16 LoRA)" to use this LoRA.')
@classmethod
@abstractmethod
@@ -370,30 +370,6 @@ class Q4_1(__Quant, qtype=GGMLQuantizationType.Q4_1):
return (d * qs) + m
@classmethod
def quantize_blocks_pytorch(cls, blocks, block_size, type_size) -> torch.Tensor:
# WIP
raise NotImplementedError('Q4_1 Lora is under construction!')
n_blocks = blocks.shape[0]
max_vals = blocks.max(dim=-1, keepdim=True).values
min_vals = blocks.min(dim=-1, keepdim=True).values
d = (max_vals - min_vals) / 15
id = torch.where(d == 0, torch.tensor(0.0, device=d.device), 1 / d)
qs = torch.trunc((blocks - min_vals) * id + 0.5).to(torch.uint8).clip(0, 15)
qs = qs.view(n_blocks, 2, block_size // 2)
qs = qs[:, 0, :] | (qs[:, 1, :] << 4)
d = d.to(torch.float16).view(n_blocks, -1)
m = min_vals.to(torch.float16).view(n_blocks, -1)
return torch.cat([d, m, qs], dim=-1)
class Q5_0(__Quant, qtype=GGMLQuantizationType.Q5_0):
@classmethod
@@ -567,31 +543,6 @@ class Q5_1(__Quant, qtype=GGMLQuantizationType.Q5_1):
qs = (ql | (qh << 4))
return (d * qs) + m
@classmethod
def quantize_blocks_pytorch(cls, blocks, block_size, type_size) -> torch.Tensor:
# WIP
raise NotImplementedError('Q5_1 Lora is under construction!')
n_blocks = blocks.shape[0]
max_val = blocks.max(dim=-1, keepdim=True)[0]
min_val = blocks.min(dim=-1, keepdim=True)[0]
d = (max_val - min_val) / 31
id = torch.where(d == 0, torch.zeros_like(d), 1.0 / d)
q = torch.trunc((blocks - min_val) * id + 0.5).clip(0, 31).to(torch.uint8)
qs = q.view(n_blocks, 2, block_size // 2)
qs = (qs[..., 0, :] & 0x0F) | (qs[..., 1, :] << 4)
qh = torch.bitwise_right_shift(q.view(n_blocks, 1, 32), torch.arange(4, dtype=torch.uint8, device=blocks.device) * 8).byte()
d = d.to(torch.float16).view(-1, 2)
min_val = min_val.to(torch.float16).view(-1, 2)
return torch.cat([d.byte(), min_val.byte(), qh, qs], dim=-1)
class Q8_0(__Quant, qtype=GGMLQuantizationType.Q8_0):
@classmethod
@@ -677,10 +628,6 @@ class Q2_K(__Quant, qtype=GGMLQuantizationType.Q2_K):
qs = dl * qs - ml
return qs.reshape((n_blocks, -1))
@classmethod
def quantize_blocks_pytorch(cls, blocks, block_size, type_size) -> torch.Tensor:
raise NotImplementedError('Not Implemented Yet')
class Q3_K(__Quant, qtype=GGMLQuantizationType.Q3_K):
@classmethod
@@ -746,10 +693,6 @@ class Q3_K(__Quant, qtype=GGMLQuantizationType.Q3_K):
q = (ql.to(torch.int8) - (qh << 2).to(torch.int8))
return (dl * q).reshape((n_blocks, QK_K))
@classmethod
def quantize_blocks_pytorch(cls, blocks, block_size, type_size) -> torch.Tensor:
raise NotImplementedError('Not Implemented Yet')
class Q4_K(__Quant, qtype=GGMLQuantizationType.Q4_K):
K_SCALE_SIZE = 12
@@ -826,10 +769,6 @@ class Q4_K(__Quant, qtype=GGMLQuantizationType.Q4_K):
qs = (qs & 0x0F).reshape((n_blocks, -1, 32))
return (d * qs - dm).reshape((n_blocks, QK_K))
@classmethod
def quantize_blocks_pytorch(cls, blocks, block_size, type_size) -> torch.Tensor:
raise NotImplementedError('Not Implemented Yet')
class Q5_K(__Quant, qtype=GGMLQuantizationType.Q5_K):
@classmethod
@@ -876,10 +815,6 @@ class Q5_K(__Quant, qtype=GGMLQuantizationType.Q5_K):
q = (ql | (qh << 4))
return (d * q - dm).reshape((n_blocks, QK_K))
@classmethod
def quantize_blocks_pytorch(cls, blocks, block_size, type_size) -> torch.Tensor:
raise NotImplementedError('Not Implemented Yet')
class Q6_K(__Quant, qtype=GGMLQuantizationType.Q6_K):
@classmethod
@@ -919,10 +854,6 @@ class Q6_K(__Quant, qtype=GGMLQuantizationType.Q6_K):
q = q.reshape((n_blocks, QK_K // 16, -1))
return (d * q).reshape((n_blocks, QK_K))
@classmethod
def quantize_blocks_pytorch(cls, blocks, block_size, type_size) -> torch.Tensor:
raise NotImplementedError('Not Implemented Yet')
class IQ2_XXS(__Quant, qtype=GGMLQuantizationType.IQ2_XXS):
ksigns: bytes = (