Rework the LoRA system

1. Add an option that lets users run the UNet in fp8/GGUF while keeping LoRA weights in fp16.
2. FP16 LoRAs no longer need patching. Other LoRAs are only re-patched when their weights change.
3. FP8 UNet + fp16 LoRA is now available in Forge (and, for now, essentially only in Forge). This also fixes some “LoRA effect is too subtle” problems.
4. Significantly speed up all GGUF models (in async mode) by using an independent thread (CUDA stream) so that computation and dequantization run at the same time, even when the low-bit weights are already on the GPU (a rough sketch of this idea follows the list).
5. Treat “online LoRA” as a module similar to ControlLoRA, so it is moved to the GPU together with the model when sampling, achieving a significant speedup and correct low-VRAM management at the same time.
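
A minimal, self-contained sketch of the idea behind points 1 and 4 (not Forge's actual API; merge_lora, linear_with_online_lora and the fp8 tensor below are made up for illustration, the real logic is in the diff below): the stored low-bit weight is dequantized/cast and the fp16 LoRA delta is merged on a secondary CUDA stream while the default stream keeps computing; the compute stream only waits for the event of the layer it is about to run.

    import torch

    # Hypothetical helpers for illustration only; these names do not exist in Forge.
    def merge_lora(weight_fp16, lora_down, lora_up, alpha=1.0):
        # The LoRA delta (up @ down) stays in fp16 and is added on top of the
        # dequantized/cast weight, so low-bit storage never limits LoRA precision.
        return weight_fp16 + alpha * (lora_up @ lora_down).to(weight_fp16.dtype)

    def linear_with_online_lora(x, stored_weight, lora_down, lora_up, mover_stream):
        # Prepare the weight (dequant/cast + LoRA merge) on a secondary stream so it
        # overlaps with whatever the default stream is still computing.
        mover_stream.wait_stream(torch.cuda.current_stream())
        with torch.cuda.stream(mover_stream):
            w = stored_weight.to(dtype=torch.float16)  # stand-in for fp8/GGUF dequantization
            w = merge_lora(w, lora_down, lora_up)
            ready = mover_stream.record_event()
        # The compute stream waits only until this layer's weight is ready.
        torch.cuda.current_stream().wait_event(ready)
        return torch.nn.functional.linear(x, w)

    if torch.cuda.is_available():
        mover = torch.cuda.Stream()
        x = torch.randn(4, 64, device='cuda', dtype=torch.float16)
        w8 = torch.randn(128, 64, device='cuda').to(torch.float8_e4m3fn)  # low-bit stored weight
        down = torch.randn(8, 64, device='cuda', dtype=torch.float16)
        up = torch.randn(128, 8, device='cuda', dtype=torch.float16)
        y = linear_with_online_lora(x, w8, down, up, mover)  # (4, 128) fp16 output
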
layerdiffusion
2024-08-19 04:31:00 -07:00
parent e5f213c21e
commit d38e560e42
11 changed files with 200 additions and 159 deletions


@@ -5,12 +5,53 @@ import torch
 import contextlib
 from backend import stream, memory_management, utils
+from backend.patcher.lora import merge_lora_to_weight
 
 stash = {}
 
-def weights_manual_cast(layer, x, skip_weight_dtype=False, skip_bias_dtype=False):
+def get_weight_and_bias(layer, weight_args=None, bias_args=None, weight_fn=None, bias_fn=None):
+    patches = getattr(layer, 'forge_online_loras', None)
+    weight_patches, bias_patches = None, None
+
+    if patches is not None:
+        weight_patches = patches.get('weight', None)
+
+    if patches is not None:
+        bias_patches = patches.get('bias', None)
+
+    weight = None
+    if layer.weight is not None:
+        weight = layer.weight
+        if weight_fn is not None:
+            if weight_args is not None:
+                fn_device = weight_args.get('device', None)
+                if fn_device is not None:
+                    weight = weight.to(device=fn_device)
+            weight = weight_fn(weight)
+        if weight_args is not None:
+            weight = weight.to(**weight_args)
+        if weight_patches is not None:
+            weight = merge_lora_to_weight(patches=weight_patches, weight=weight, key="online weight lora", computation_dtype=weight.dtype)
+
+    bias = None
+    if layer.bias is not None:
+        bias = layer.bias
+        if bias_fn is not None:
+            if bias_args is not None:
+                fn_device = bias_args.get('device', None)
+                if fn_device is not None:
+                    bias = bias.to(device=fn_device)
+            bias = bias_fn(bias)
+        if bias_args is not None:
+            bias = bias.to(**bias_args)
+        if bias_patches is not None:
+            bias = merge_lora_to_weight(patches=bias_patches, weight=bias, key="online bias lora", computation_dtype=bias.dtype)
+
+    return weight, bias
+
+
+def weights_manual_cast(layer, x, skip_weight_dtype=False, skip_bias_dtype=False, weight_fn=None, bias_fn=None):
     weight, bias, signal = None, None, None
     non_blocking = True
@@ -32,16 +73,10 @@ def weights_manual_cast(layer, x, skip_weight_dtype=False, skip_bias_dtype=False
     if stream.should_use_stream():
         with stream.stream_context()(stream.mover_stream):
-            if layer.weight is not None:
-                weight = layer.weight.to(**weight_args)
-            if layer.bias is not None:
-                bias = layer.bias.to(**bias_args)
+            weight, bias = get_weight_and_bias(layer, weight_args, bias_args, weight_fn=weight_fn, bias_fn=bias_fn)
             signal = stream.mover_stream.record_event()
     else:
-        if layer.weight is not None:
-            weight = layer.weight.to(**weight_args)
-        if layer.bias is not None:
-            bias = layer.bias.to(**bias_args)
+        weight, bias = get_weight_and_bias(layer, weight_args, bias_args, weight_fn=weight_fn, bias_fn=bias_fn)
 
     return weight, bias, signal
@@ -109,7 +144,8 @@ class ForgeOperations:
                 with main_stream_worker(weight, bias, signal):
                     return torch.nn.functional.linear(x, weight, bias)
             else:
-                return torch.nn.functional.linear(x, self.weight, self.bias)
+                weight, bias = get_weight_and_bias(self)
+                return torch.nn.functional.linear(x, weight, bias)
 
     class Conv2d(torch.nn.Conv2d):
@@ -128,7 +164,8 @@ class ForgeOperations:
                 with main_stream_worker(weight, bias, signal):
                     return self._conv_forward(x, weight, bias)
             else:
-                return super().forward(x)
+                weight, bias = get_weight_and_bias(self)
+                return super()._conv_forward(x, weight, bias)
 
     class Conv3d(torch.nn.Conv3d):
@@ -147,7 +184,8 @@ class ForgeOperations:
                 with main_stream_worker(weight, bias, signal):
                     return self._conv_forward(x, weight, bias)
             else:
-                return super().forward(x)
+                weight, bias = get_weight_and_bias(self)
+                return super()._conv_forward(x, weight, bias)
 
     class Conv1d(torch.nn.Conv1d):
@@ -166,7 +204,8 @@ class ForgeOperations:
                 with main_stream_worker(weight, bias, signal):
                     return self._conv_forward(x, weight, bias)
             else:
-                return super().forward(x)
+                weight, bias = get_weight_and_bias(self)
+                return super()._conv_forward(x, weight, bias)
 
     class ConvTranspose2d(torch.nn.ConvTranspose2d):
@@ -188,7 +227,10 @@ class ForgeOperations:
                 with main_stream_worker(weight, bias, signal):
                     return torch.nn.functional.conv_transpose2d(x, weight, bias, self.stride, self.padding, output_padding, self.groups, self.dilation)
             else:
-                return super().forward(x, output_size)
+                weight, bias = get_weight_and_bias(self)
+                num_spatial_dims = 2
+                output_padding = self._output_padding(x, output_size, self.stride, self.padding, self.kernel_size, num_spatial_dims, self.dilation)
+                return torch.nn.functional.conv_transpose2d(x, weight, bias, self.stride, self.padding, output_padding, self.groups, self.dilation)
 
     class ConvTranspose1d(torch.nn.ConvTranspose1d):
@@ -210,7 +252,10 @@ class ForgeOperations:
                 with main_stream_worker(weight, bias, signal):
                     return torch.nn.functional.conv_transpose1d(x, weight, bias, self.stride, self.padding, output_padding, self.groups, self.dilation)
             else:
-                return super().forward(x, output_size)
+                weight, bias = get_weight_and_bias(self)
+                num_spatial_dims = 1
+                output_padding = self._output_padding(x, output_size, self.stride, self.padding, self.kernel_size, num_spatial_dims, self.dilation)
+                return torch.nn.functional.conv_transpose1d(x, weight, bias, self.stride, self.padding, output_padding, self.groups, self.dilation)
 
     class ConvTranspose3d(torch.nn.ConvTranspose3d):
@@ -232,7 +277,10 @@ class ForgeOperations:
                 with main_stream_worker(weight, bias, signal):
                     return torch.nn.functional.conv_transpose3d(x, weight, bias, self.stride, self.padding, output_padding, self.groups, self.dilation)
             else:
-                return super().forward(x, output_size)
+                weight, bias = get_weight_and_bias(self)
+                num_spatial_dims = 3
+                output_padding = self._output_padding(x, output_size, self.stride, self.padding, self.kernel_size, num_spatial_dims, self.dilation)
+                return torch.nn.functional.conv_transpose3d(x, weight, bias, self.stride, self.padding, output_padding, self.groups, self.dilation)
 
     class GroupNorm(torch.nn.GroupNorm):
@@ -328,7 +376,7 @@ except:
     bnb_avaliable = False
 
-from backend.operations_gguf import functional_linear_gguf
+from backend.operations_gguf import dequantize_tensor
 
 
 class ForgeOperationsGGUF(ForgeOperations):
@@ -361,12 +409,9 @@ class ForgeOperationsGGUF(ForgeOperations):
             return self
 
         def forward(self, x):
-            if self.parameters_manual_cast:
-                weight, bias, signal = weights_manual_cast(self, x, skip_weight_dtype=True)
-                with main_stream_worker(weight, bias, signal):
-                    return functional_linear_gguf(x, weight, bias)
-            else:
-                return functional_linear_gguf(x, self.weight, self.bias)
+            weight, bias, signal = weights_manual_cast(self, x, weight_fn=dequantize_tensor, bias_fn=dequantize_tensor)
+            with main_stream_worker(weight, bias, signal):
+                return torch.nn.functional.linear(x, weight, bias)
 
 
 @contextlib.contextmanager