mirror of
https://github.com/lllyasviel/stable-diffusion-webui-forge.git
synced 2026-04-29 02:31:16 +00:00
Implement some rethinking about LoRA system
1. Add an option to allow users to use UNet in fp8/gguf but lora in fp16. 2. All FP16 loras do not need patch. Others will only patch again when lora weight change. 3. FP8 unet + fp16 lora are available (somewhat only available) in Forge now. This also solves some “LoRA too subtle” problems. 4. Significantly speed up all gguf models (in Async mode) by using independent thread (CUDA stream) to compute and dequant at the same time, even when low-bit weights are already on GPU. 5. View “online lora” as a module similar to ControlLoRA so that it is moved to GPU together with model when sampling, achieving significant speedup and perfect low VRAM management simultaneously.
This commit is contained in:
@@ -5,12 +5,53 @@ import torch
|
||||
import contextlib
|
||||
|
||||
from backend import stream, memory_management, utils
|
||||
from backend.patcher.lora import merge_lora_to_weight
|
||||
|
||||
|
||||
stash = {}
|
||||
|
||||
|
||||
def weights_manual_cast(layer, x, skip_weight_dtype=False, skip_bias_dtype=False):
|
||||
def get_weight_and_bias(layer, weight_args=None, bias_args=None, weight_fn=None, bias_fn=None):
|
||||
patches = getattr(layer, 'forge_online_loras', None)
|
||||
weight_patches, bias_patches = None, None
|
||||
|
||||
if patches is not None:
|
||||
weight_patches = patches.get('weight', None)
|
||||
|
||||
if patches is not None:
|
||||
bias_patches = patches.get('bias', None)
|
||||
|
||||
weight = None
|
||||
if layer.weight is not None:
|
||||
weight = layer.weight
|
||||
if weight_fn is not None:
|
||||
if weight_args is not None:
|
||||
fn_device = weight_args.get('device', None)
|
||||
if fn_device is not None:
|
||||
weight = weight.to(device=fn_device)
|
||||
weight = weight_fn(weight)
|
||||
if weight_args is not None:
|
||||
weight = weight.to(**weight_args)
|
||||
if weight_patches is not None:
|
||||
weight = merge_lora_to_weight(patches=weight_patches, weight=weight, key="online weight lora", computation_dtype=weight.dtype)
|
||||
|
||||
bias = None
|
||||
if layer.bias is not None:
|
||||
bias = layer.bias
|
||||
if bias_fn is not None:
|
||||
if bias_args is not None:
|
||||
fn_device = bias_args.get('device', None)
|
||||
if fn_device is not None:
|
||||
bias = bias.to(device=fn_device)
|
||||
bias = bias_fn(bias)
|
||||
if bias_args is not None:
|
||||
bias = bias.to(**bias_args)
|
||||
if bias_patches is not None:
|
||||
bias = merge_lora_to_weight(patches=bias_patches, weight=bias, key="online bias lora", computation_dtype=bias.dtype)
|
||||
return weight, bias
|
||||
|
||||
|
||||
def weights_manual_cast(layer, x, skip_weight_dtype=False, skip_bias_dtype=False, weight_fn=None, bias_fn=None):
|
||||
weight, bias, signal = None, None, None
|
||||
non_blocking = True
|
||||
|
||||
@@ -32,16 +73,10 @@ def weights_manual_cast(layer, x, skip_weight_dtype=False, skip_bias_dtype=False
|
||||
|
||||
if stream.should_use_stream():
|
||||
with stream.stream_context()(stream.mover_stream):
|
||||
if layer.weight is not None:
|
||||
weight = layer.weight.to(**weight_args)
|
||||
if layer.bias is not None:
|
||||
bias = layer.bias.to(**bias_args)
|
||||
weight, bias = get_weight_and_bias(layer, weight_args, bias_args, weight_fn=weight_fn, bias_fn=bias_fn)
|
||||
signal = stream.mover_stream.record_event()
|
||||
else:
|
||||
if layer.weight is not None:
|
||||
weight = layer.weight.to(**weight_args)
|
||||
if layer.bias is not None:
|
||||
bias = layer.bias.to(**bias_args)
|
||||
weight, bias = get_weight_and_bias(layer, weight_args, bias_args, weight_fn=weight_fn, bias_fn=bias_fn)
|
||||
|
||||
return weight, bias, signal
|
||||
|
||||
@@ -109,7 +144,8 @@ class ForgeOperations:
|
||||
with main_stream_worker(weight, bias, signal):
|
||||
return torch.nn.functional.linear(x, weight, bias)
|
||||
else:
|
||||
return torch.nn.functional.linear(x, self.weight, self.bias)
|
||||
weight, bias = get_weight_and_bias(self)
|
||||
return torch.nn.functional.linear(x, weight, bias)
|
||||
|
||||
class Conv2d(torch.nn.Conv2d):
|
||||
|
||||
@@ -128,7 +164,8 @@ class ForgeOperations:
|
||||
with main_stream_worker(weight, bias, signal):
|
||||
return self._conv_forward(x, weight, bias)
|
||||
else:
|
||||
return super().forward(x)
|
||||
weight, bias = get_weight_and_bias(self)
|
||||
return super()._conv_forward(x, weight, bias)
|
||||
|
||||
class Conv3d(torch.nn.Conv3d):
|
||||
|
||||
@@ -147,7 +184,8 @@ class ForgeOperations:
|
||||
with main_stream_worker(weight, bias, signal):
|
||||
return self._conv_forward(x, weight, bias)
|
||||
else:
|
||||
return super().forward(x)
|
||||
weight, bias = get_weight_and_bias(self)
|
||||
return super()._conv_forward(input, weight, bias)
|
||||
|
||||
class Conv1d(torch.nn.Conv1d):
|
||||
|
||||
@@ -166,7 +204,8 @@ class ForgeOperations:
|
||||
with main_stream_worker(weight, bias, signal):
|
||||
return self._conv_forward(x, weight, bias)
|
||||
else:
|
||||
return super().forward(x)
|
||||
weight, bias = get_weight_and_bias(self)
|
||||
return super()._conv_forward(input, weight, bias)
|
||||
|
||||
class ConvTranspose2d(torch.nn.ConvTranspose2d):
|
||||
|
||||
@@ -188,7 +227,10 @@ class ForgeOperations:
|
||||
with main_stream_worker(weight, bias, signal):
|
||||
return torch.nn.functional.conv_transpose2d(x, weight, bias, self.stride, self.padding, output_padding, self.groups, self.dilation)
|
||||
else:
|
||||
return super().forward(x, output_size)
|
||||
weight, bias = get_weight_and_bias(self)
|
||||
num_spatial_dims = 2
|
||||
output_padding = self._output_padding(x, output_size, self.stride, self.padding, self.kernel_size, num_spatial_dims, self.dilation)
|
||||
return torch.nn.functional.conv_transpose2d(x, weight, bias, self.stride, self.padding, output_padding, self.groups, self.dilation)
|
||||
|
||||
class ConvTranspose1d(torch.nn.ConvTranspose1d):
|
||||
|
||||
@@ -210,7 +252,10 @@ class ForgeOperations:
|
||||
with main_stream_worker(weight, bias, signal):
|
||||
return torch.nn.functional.conv_transpose1d(x, weight, bias, self.stride, self.padding, output_padding, self.groups, self.dilation)
|
||||
else:
|
||||
return super().forward(x, output_size)
|
||||
weight, bias = get_weight_and_bias(self)
|
||||
num_spatial_dims = 1
|
||||
output_padding = self._output_padding(x, output_size, self.stride, self.padding, self.kernel_size, num_spatial_dims, self.dilation)
|
||||
return torch.nn.functional.conv_transpose2d(x, weight, bias, self.stride, self.padding, output_padding, self.groups, self.dilation)
|
||||
|
||||
class ConvTranspose3d(torch.nn.ConvTranspose3d):
|
||||
|
||||
@@ -232,7 +277,10 @@ class ForgeOperations:
|
||||
with main_stream_worker(weight, bias, signal):
|
||||
return torch.nn.functional.conv_transpose3d(x, weight, bias, self.stride, self.padding, output_padding, self.groups, self.dilation)
|
||||
else:
|
||||
return super().forward(x, output_size)
|
||||
weight, bias = get_weight_and_bias(self)
|
||||
num_spatial_dims = 3
|
||||
output_padding = self._output_padding(x, output_size, self.stride, self.padding, self.kernel_size, num_spatial_dims, self.dilation)
|
||||
return torch.nn.functional.conv_transpose2d(x, weight, bias, self.stride, self.padding, output_padding, self.groups, self.dilation)
|
||||
|
||||
class GroupNorm(torch.nn.GroupNorm):
|
||||
|
||||
@@ -328,7 +376,7 @@ except:
|
||||
bnb_avaliable = False
|
||||
|
||||
|
||||
from backend.operations_gguf import functional_linear_gguf
|
||||
from backend.operations_gguf import dequantize_tensor
|
||||
|
||||
|
||||
class ForgeOperationsGGUF(ForgeOperations):
|
||||
@@ -361,12 +409,9 @@ class ForgeOperationsGGUF(ForgeOperations):
|
||||
return self
|
||||
|
||||
def forward(self, x):
|
||||
if self.parameters_manual_cast:
|
||||
weight, bias, signal = weights_manual_cast(self, x, skip_weight_dtype=True)
|
||||
with main_stream_worker(weight, bias, signal):
|
||||
return functional_linear_gguf(x, weight, bias)
|
||||
else:
|
||||
return functional_linear_gguf(x, self.weight, self.bias)
|
||||
weight, bias, signal = weights_manual_cast(self, x, weight_fn=dequantize_tensor, bias_fn=dequantize_tensor)
|
||||
with main_stream_worker(weight, bias, signal):
|
||||
return torch.nn.functional.linear(x, weight, bias)
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
|
||||
Reference in New Issue
Block a user