Mirror of https://github.com/lllyasviel/stable-diffusion-webui-forge.git, synced 2026-02-08.
... based on three pieces of evidence: (1) torch.Tensor.view on one big tensor is slightly faster than calling torch.Tensor.to on multiple small tensors; (2) torch.Tensor.to with a dtype change is significantly slower than torch.Tensor.view; (3) "baking" the model on the GPU is significantly faster than computing on the CPU at model-load time. This mainly affects inference of Q8_0 and Q4_0/1/K, and the loading of all quants.
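A minimal sketch (not part of the file below) of how claims (1) and (2) could be checked, assuming a CUDA device; the tensor shape, chunk count, and the `timed` helper are illustrative, and absolute timings will vary by hardware:

import time
import torch

x = torch.randn(4096, 4096, dtype=torch.float16, device='cuda')   # one big tensor
chunks = list(x.reshape(64, -1).unbind(0))                         # many small tensors

def timed(fn, n=100):
    torch.cuda.synchronize()
    t0 = time.perf_counter()
    for _ in range(n):
        fn()
    torch.cuda.synchronize()
    return (time.perf_counter() - t0) / n

print('view, one big tensor   :', timed(lambda: x.view(torch.uint8)))                          # byte reinterpretation
print('to, many small tensors :', timed(lambda: [c.to(torch.float16, copy=True) for c in chunks]))
print('to with dtype change   :', timed(lambda: x.to(torch.float32)))                          # real conversion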
# Copyright Forge 2024

import time
import torch
import contextlib

from backend import stream, memory_management, utils
from backend.patcher.lora import merge_lora_to_weight

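# Keeps (weight, bias, finished_event) triples alive until the compute stream has
# consumed them; entries are released in main_stream_worker() and cleanup_cache().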
stash = {}


def get_weight_and_bias(layer, weight_args=None, bias_args=None, weight_fn=None, bias_fn=None):
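    """
    Resolve the effective weight and bias of `layer`.

    Optionally moves/casts them via `weight_args`/`bias_args`, pipes them through
    `weight_fn`/`bias_fn` (e.g. dequantization), and applies any online LoRA patches
    stored on the layer as `forge_online_loras`.
    """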
    patches = getattr(layer, 'forge_online_loras', None)
    weight_patches, bias_patches = None, None

    if patches is not None:
        weight_patches = patches.get('weight', None)
        bias_patches = patches.get('bias', None)

    weight = None
    if layer.weight is not None:
        weight = layer.weight
        if weight_fn is not None:
            if weight_args is not None:
                fn_device = weight_args.get('device', None)
                if fn_device is not None:
                    weight = weight.to(device=fn_device)
            weight = weight_fn(weight)
        if weight_args is not None:
            weight = weight.to(**weight_args)
        if weight_patches is not None:
            weight = merge_lora_to_weight(patches=weight_patches, weight=weight, key="online weight lora", computation_dtype=weight.dtype)

    bias = None
    if layer.bias is not None:
        bias = layer.bias
        if bias_fn is not None:
            if bias_args is not None:
                fn_device = bias_args.get('device', None)
                if fn_device is not None:
                    bias = bias.to(device=fn_device)
            bias = bias_fn(bias)
        if bias_args is not None:
            bias = bias.to(**bias_args)
        if bias_patches is not None:
            bias = merge_lora_to_weight(patches=bias_patches, weight=bias, key="online bias lora", computation_dtype=bias.dtype)

    return weight, bias


def weights_manual_cast(layer, x, skip_weight_dtype=False, skip_bias_dtype=False, weight_fn=None, bias_fn=None):
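    """
    Cast `layer`'s weight and bias to the device/dtype of the activation `x`.

    When a mover stream is available, the transfer is issued on that stream and a
    CUDA event is returned as `signal` so the compute stream can wait on it;
    otherwise `signal` is None. Non-blocking copies are disabled on MPS.
    """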
    weight, bias, signal = None, None, None
    non_blocking = True

    if getattr(x.device, 'type', None) == 'mps':
        non_blocking = False

    target_dtype = x.dtype
    target_device = x.device

    if skip_weight_dtype:
        weight_args = dict(device=target_device, non_blocking=non_blocking)
    else:
        weight_args = dict(device=target_device, dtype=target_dtype, non_blocking=non_blocking)

    if skip_bias_dtype:
        bias_args = dict(device=target_device, non_blocking=non_blocking)
    else:
        bias_args = dict(device=target_device, dtype=target_dtype, non_blocking=non_blocking)

    if stream.should_use_stream():
        with stream.stream_context()(stream.mover_stream):
            weight, bias = get_weight_and_bias(layer, weight_args, bias_args, weight_fn=weight_fn, bias_fn=bias_fn)
            signal = stream.mover_stream.record_event()
    else:
        weight, bias = get_weight_and_bias(layer, weight_args, bias_args, weight_fn=weight_fn, bias_fn=bias_fn)

    return weight, bias, signal


@contextlib.contextmanager
def main_stream_worker(weight, bias, signal):
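    """
    Context manager that makes the compute stream wait for the mover stream's
    `signal` before running the wrapped computation, then records a completion
    event and parks (weight, bias, event) in `stash` until the event has finished.
    """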
    if signal is None or not stream.should_use_stream():
        yield
        return

    with stream.stream_context()(stream.current_stream):
        stream.current_stream.wait_event(signal)
        yield
        finished_signal = stream.current_stream.record_event()
        stash[id(finished_signal)] = (weight, bias, finished_signal)

    garbage = []
    for k, (w, b, s) in stash.items():
        if s.query():
            garbage.append(k)

    for k in garbage:
        del stash[k]
    return


def cleanup_cache():
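    """Synchronize both streams and drop all stashed (weight, bias, event) triples."""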
    if not stream.should_use_stream():
        return

    stream.current_stream.synchronize()
    stream.mover_stream.synchronize()
    stash.clear()
    return


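# Module-level defaults read by the operation classes below; set by using_forge_operations().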
current_device = None
current_dtype = None
current_manual_cast_enabled = False
current_bnb_dtype = None


class ForgeOperations:
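    """
    Drop-in replacements for the common torch.nn layers. Each op either keeps its
    parameters as-is or, when `parameters_manual_cast` is enabled, casts them to the
    activation's device/dtype on the fly (optionally on a separate mover stream).
    """
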
    class Linear(torch.nn.Module):
        def __init__(self, *args, **kwargs):
            super().__init__()
            self.dummy = torch.nn.Parameter(torch.empty(1, device=current_device, dtype=current_dtype))
            self.weight = None
            self.bias = None
            self.parameters_manual_cast = current_manual_cast_enabled

        def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs):
            if hasattr(self, 'dummy'):
                if prefix + 'weight' in state_dict:
                    self.weight = torch.nn.Parameter(state_dict[prefix + 'weight'].to(self.dummy))
                if prefix + 'bias' in state_dict:
                    self.bias = torch.nn.Parameter(state_dict[prefix + 'bias'].to(self.dummy))
                del self.dummy
            else:
                super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)

        def forward(self, x):
            if self.parameters_manual_cast:
                weight, bias, signal = weights_manual_cast(self, x)
                with main_stream_worker(weight, bias, signal):
                    return torch.nn.functional.linear(x, weight, bias)
            else:
                weight, bias = get_weight_and_bias(self)
                return torch.nn.functional.linear(x, weight, bias)

    class Conv2d(torch.nn.Conv2d):
        def __init__(self, *args, **kwargs):
            kwargs['device'] = current_device
            kwargs['dtype'] = current_dtype
            super().__init__(*args, **kwargs)
            self.parameters_manual_cast = current_manual_cast_enabled

        def reset_parameters(self):
            return None

        def forward(self, x):
            if self.parameters_manual_cast:
                weight, bias, signal = weights_manual_cast(self, x)
                with main_stream_worker(weight, bias, signal):
                    return self._conv_forward(x, weight, bias)
            else:
                weight, bias = get_weight_and_bias(self)
                return super()._conv_forward(x, weight, bias)

    class Conv3d(torch.nn.Conv3d):
        def __init__(self, *args, **kwargs):
            kwargs['device'] = current_device
            kwargs['dtype'] = current_dtype
            super().__init__(*args, **kwargs)
            self.parameters_manual_cast = current_manual_cast_enabled

        def reset_parameters(self):
            return None

        def forward(self, x):
            if self.parameters_manual_cast:
                weight, bias, signal = weights_manual_cast(self, x)
                with main_stream_worker(weight, bias, signal):
                    return self._conv_forward(x, weight, bias)
            else:
                weight, bias = get_weight_and_bias(self)
                return super()._conv_forward(x, weight, bias)

    class Conv1d(torch.nn.Conv1d):
        def __init__(self, *args, **kwargs):
            kwargs['device'] = current_device
            kwargs['dtype'] = current_dtype
            super().__init__(*args, **kwargs)
            self.parameters_manual_cast = current_manual_cast_enabled

        def reset_parameters(self):
            return None

        def forward(self, x):
            if self.parameters_manual_cast:
                weight, bias, signal = weights_manual_cast(self, x)
                with main_stream_worker(weight, bias, signal):
                    return self._conv_forward(x, weight, bias)
            else:
                weight, bias = get_weight_and_bias(self)
                return super()._conv_forward(x, weight, bias)

    class ConvTranspose2d(torch.nn.ConvTranspose2d):
        def __init__(self, *args, **kwargs):
            kwargs['device'] = current_device
            kwargs['dtype'] = current_dtype
            super().__init__(*args, **kwargs)
            self.parameters_manual_cast = current_manual_cast_enabled

        def reset_parameters(self):
            return None

        def forward(self, x, output_size=None):
            if self.parameters_manual_cast:
                num_spatial_dims = 2
                output_padding = self._output_padding(x, output_size, self.stride, self.padding, self.kernel_size, num_spatial_dims, self.dilation)

                weight, bias, signal = weights_manual_cast(self, x)
                with main_stream_worker(weight, bias, signal):
                    return torch.nn.functional.conv_transpose2d(x, weight, bias, self.stride, self.padding, output_padding, self.groups, self.dilation)
            else:
                weight, bias = get_weight_and_bias(self)
                num_spatial_dims = 2
                output_padding = self._output_padding(x, output_size, self.stride, self.padding, self.kernel_size, num_spatial_dims, self.dilation)
                return torch.nn.functional.conv_transpose2d(x, weight, bias, self.stride, self.padding, output_padding, self.groups, self.dilation)

    class ConvTranspose1d(torch.nn.ConvTranspose1d):
        def __init__(self, *args, **kwargs):
            kwargs['device'] = current_device
            kwargs['dtype'] = current_dtype
            super().__init__(*args, **kwargs)
            self.parameters_manual_cast = current_manual_cast_enabled

        def reset_parameters(self):
            return None

        def forward(self, x, output_size=None):
            if self.parameters_manual_cast:
                num_spatial_dims = 1
                output_padding = self._output_padding(x, output_size, self.stride, self.padding, self.kernel_size, num_spatial_dims, self.dilation)

                weight, bias, signal = weights_manual_cast(self, x)
                with main_stream_worker(weight, bias, signal):
                    return torch.nn.functional.conv_transpose1d(x, weight, bias, self.stride, self.padding, output_padding, self.groups, self.dilation)
            else:
                weight, bias = get_weight_and_bias(self)
                num_spatial_dims = 1
                output_padding = self._output_padding(x, output_size, self.stride, self.padding, self.kernel_size, num_spatial_dims, self.dilation)
                return torch.nn.functional.conv_transpose1d(x, weight, bias, self.stride, self.padding, output_padding, self.groups, self.dilation)

    class ConvTranspose3d(torch.nn.ConvTranspose3d):
        def __init__(self, *args, **kwargs):
            kwargs['device'] = current_device
            kwargs['dtype'] = current_dtype
            super().__init__(*args, **kwargs)
            self.parameters_manual_cast = current_manual_cast_enabled

        def reset_parameters(self):
            return None

        def forward(self, x, output_size=None):
            if self.parameters_manual_cast:
                num_spatial_dims = 3
                output_padding = self._output_padding(x, output_size, self.stride, self.padding, self.kernel_size, num_spatial_dims, self.dilation)

                weight, bias, signal = weights_manual_cast(self, x)
                with main_stream_worker(weight, bias, signal):
                    return torch.nn.functional.conv_transpose3d(x, weight, bias, self.stride, self.padding, output_padding, self.groups, self.dilation)
            else:
                weight, bias = get_weight_and_bias(self)
                num_spatial_dims = 3
                output_padding = self._output_padding(x, output_size, self.stride, self.padding, self.kernel_size, num_spatial_dims, self.dilation)
                return torch.nn.functional.conv_transpose3d(x, weight, bias, self.stride, self.padding, output_padding, self.groups, self.dilation)

    class GroupNorm(torch.nn.GroupNorm):
        def __init__(self, *args, **kwargs):
            kwargs['device'] = current_device
            kwargs['dtype'] = current_dtype
            super().__init__(*args, **kwargs)
            self.parameters_manual_cast = current_manual_cast_enabled

        def reset_parameters(self):
            return None

        def forward(self, x):
            if self.parameters_manual_cast:
                weight, bias, signal = weights_manual_cast(self, x)
                with main_stream_worker(weight, bias, signal):
                    return torch.nn.functional.group_norm(x, self.num_groups, weight, bias, self.eps)
            else:
                return super().forward(x)

    class LayerNorm(torch.nn.LayerNorm):
        def __init__(self, *args, **kwargs):
            kwargs['device'] = current_device
            kwargs['dtype'] = current_dtype
            super().__init__(*args, **kwargs)
            self.parameters_manual_cast = current_manual_cast_enabled

        def reset_parameters(self):
            return None

        def forward(self, x):
            if self.parameters_manual_cast:
                weight, bias, signal = weights_manual_cast(self, x)
                with main_stream_worker(weight, bias, signal):
                    return torch.nn.functional.layer_norm(x, self.normalized_shape, weight, bias, self.eps)
            else:
                return super().forward(x)

    class Embedding(torch.nn.Embedding):
        def __init__(self, *args, **kwargs):
            kwargs['device'] = current_device
            super().__init__(*args, **kwargs)
            self.parameters_manual_cast = current_manual_cast_enabled
            self.bias = None

        def reset_parameters(self):
            self.bias = None
            return None

        def forward(self, x):
            if self.parameters_manual_cast:
                weight, bias, signal = weights_manual_cast(self, x, skip_weight_dtype=True, skip_bias_dtype=True)
                with main_stream_worker(weight, bias, signal):
                    return torch.nn.functional.embedding(x, weight, self.padding_idx, self.max_norm, self.norm_type, self.scale_grad_by_freq, self.sparse)
            else:
                return super().forward(x)


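# bitsandbytes 4-bit (nf4/fp4) support is optional; fall back silently when it is unavailable.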
try:
    from backend.operations_bnb import ForgeLoader4Bit, ForgeParams4bit, functional_linear_4bits, functional_dequantize_4bit

    class ForgeOperationsBNB4bits(ForgeOperations):
        class Linear(ForgeLoader4Bit):
            def __init__(self, *args, **kwargs):
                super().__init__(device=current_device, dtype=current_dtype, quant_type=current_bnb_dtype)
                self.parameters_manual_cast = current_manual_cast_enabled

            def forward(self, x):
                if self.bias is not None and self.bias.dtype != x.dtype:
                    # Maybe this can also be applied to all non-bnb ops, since the cost is very low.
                    # It only runs once, and most Linear layers do not have a bias.
                    self.bias = utils.tensor2parameter(self.bias.to(x.dtype))

                if hasattr(self, 'forge_online_loras'):
                    weight, bias, signal = weights_manual_cast(self, x, weight_fn=functional_dequantize_4bit, bias_fn=None, skip_bias_dtype=True)
                    with main_stream_worker(weight, bias, signal):
                        return torch.nn.functional.linear(x, weight, bias)

                if not self.parameters_manual_cast:
                    return functional_linear_4bits(x, self.weight, self.bias)
                elif not self.weight.bnb_quantized:
                    assert x.device.type == 'cuda', 'BNB Must Use CUDA as Computation Device!'
                    layer_original_device = self.weight.device
                    self.weight = self.weight._quantize(x.device)
                    bias = self.bias.to(x.device) if self.bias is not None else None
                    out = functional_linear_4bits(x, self.weight, bias)
                    self.weight = self.weight.to(layer_original_device)
                    return out
                else:
                    weight, bias, signal = weights_manual_cast(self, x, skip_weight_dtype=True, skip_bias_dtype=True)
                    with main_stream_worker(weight, bias, signal):
                        return functional_linear_4bits(x, weight, bias)

    bnb_avaliable = True
except:
    bnb_avaliable = False


from backend.operations_gguf import dequantize_tensor


class ForgeOperationsGGUF(ForgeOperations):
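    """
    Operations for GGUF-quantized checkpoints: weights stay quantized in memory and
    are dequantized on the fly via `dequantize_tensor` inside Linear.forward.
    """
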
    class Linear(torch.nn.Module):
        def __init__(self, *args, **kwargs):
            super().__init__()
            self.dummy = torch.nn.Parameter(torch.empty(1, device=current_device, dtype=current_dtype))
            self.weight = None
            self.bias = None
            self.parameters_manual_cast = current_manual_cast_enabled

        def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs):
            if hasattr(self, 'dummy'):
                computation_dtype = self.dummy.dtype
                if computation_dtype not in [torch.float16, torch.bfloat16]:
                    # GGUF dequantization only supports 16-bit computation dtypes; anything else is extremely slow.
                    computation_dtype = torch.float16
                if prefix + 'weight' in state_dict:
                    self.weight = state_dict[prefix + 'weight'].to(device=self.dummy.device)
                    self.weight.computation_dtype = computation_dtype
                if prefix + 'bias' in state_dict:
                    self.bias = state_dict[prefix + 'bias'].to(device=self.dummy.device)
                    self.bias.computation_dtype = computation_dtype
                del self.dummy
            else:
                if prefix + 'weight' in state_dict:
                    self.weight = state_dict[prefix + 'weight']
                if prefix + 'bias' in state_dict:
                    self.bias = state_dict[prefix + 'bias']
            return

        def _apply(self, fn, recurse=True):
            for k, p in self.named_parameters(recurse=False, remove_duplicate=True):
                setattr(self, k, utils.tensor2parameter(fn(p)))
            return self

        def forward(self, x):
            if self.bias is not None and self.bias.dtype != x.dtype:
                self.bias = utils.tensor2parameter(dequantize_tensor(self.bias).to(x.dtype))

            if self.weight is not None and self.weight.dtype != x.dtype and getattr(self.weight, 'gguf_cls', None) is None:
                self.weight = utils.tensor2parameter(self.weight.to(x.dtype))

            weight, bias, signal = weights_manual_cast(self, x, weight_fn=dequantize_tensor, bias_fn=None, skip_bias_dtype=True)
            with main_stream_worker(weight, bias, signal):
                return torch.nn.functional.linear(x, weight, bias)


@contextlib.contextmanager
def using_forge_operations(operations=None, device=None, dtype=None, manual_cast_enabled=False, bnb_dtype=None):
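    """
    Temporarily replace the torch.nn layer classes (Linear, Conv*, Norms, Embedding)
    with the Forge implementations selected by `bnb_dtype` ('gguf', 'nf4'/'fp4', or None),
    so that models built inside the `with` block pick up the patched classes. The
    originals are restored on exit.

    Typical use (illustrative; `build_model` is a placeholder):
        with using_forge_operations(device=device, dtype=torch.float16, manual_cast_enabled=True):
            model = build_model()
    """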
    global current_device, current_dtype, current_manual_cast_enabled, current_bnb_dtype

    current_device, current_dtype, current_manual_cast_enabled, current_bnb_dtype = device, dtype, manual_cast_enabled, bnb_dtype

    if operations is None:
        if bnb_dtype in ['gguf']:
            operations = ForgeOperationsGGUF
        elif bnb_avaliable and bnb_dtype in ['nf4', 'fp4']:
            operations = ForgeOperationsBNB4bits
        else:
            operations = ForgeOperations

    op_names = ['Linear', 'Conv1d', 'Conv2d', 'Conv3d', 'ConvTranspose1d', 'ConvTranspose2d', 'ConvTranspose3d', 'GroupNorm', 'LayerNorm', 'Embedding']
    backups = {op_name: getattr(torch.nn, op_name) for op_name in op_names}

    try:
        for op_name in op_names:
            setattr(torch.nn, op_name, getattr(operations, op_name))

        yield

    finally:
        for op_name in op_names:
            setattr(torch.nn, op_name, backups[op_name])
    return


def shift_manual_cast(model, enabled):
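    """Enable or disable on-the-fly parameter casting for every op in `model` that supports it."""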
    for m in model.modules():
        if hasattr(m, 'parameters_manual_cast'):
            m.parameters_manual_cast = enabled
    return


@contextlib.contextmanager
def automatic_memory_management():
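    """
    Free roughly 3 GiB on the torch device, then track every module that is constructed
    or moved via `.to()` inside the `with` block and push all of them to the CPU on exit.
    """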
    memory_management.free_memory(
        memory_required=3 * 1024 * 1024 * 1024,
        device=memory_management.get_torch_device()
    )

    module_list = []

    original_init = torch.nn.Module.__init__
    original_to = torch.nn.Module.to

    def patched_init(self, *args, **kwargs):
        module_list.append(self)
        return original_init(self, *args, **kwargs)

    def patched_to(self, *args, **kwargs):
        module_list.append(self)
        return original_to(self, *args, **kwargs)

    try:
        torch.nn.Module.__init__ = patched_init
        torch.nn.Module.to = patched_to
        yield
    finally:
        torch.nn.Module.__init__ = original_init
        torch.nn.Module.to = original_to

    start = time.perf_counter()
    module_list = set(module_list)

    for module in module_list:
        module.cpu()

    memory_management.soft_empty_cache()
    end = time.perf_counter()

    print(f'Automatic Memory Management: {len(module_list)} Modules in {(end - start):.2f} seconds.')
    return


class DynamicSwapInstaller:
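    """
    Patches a module's class so that parameter and buffer accesses return copies on
    `target_device` while the originals stay where they are (e.g. in CPU RAM),
    enabling just-in-time device swapping without rewriting the model.
    """
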
    @staticmethod
    def _install_module(module: torch.nn.Module, target_device: torch.device):
        original_class = module.__class__
        module.__dict__['forge_backup_original_class'] = original_class

        def hacked_get_attr(self, name: str):
            if '_parameters' in self.__dict__:
                _parameters = self.__dict__['_parameters']
                if name in _parameters:
                    p = _parameters[name]
                    if p is None:
                        return None
                    if p.__class__ == torch.nn.Parameter:
                        return torch.nn.Parameter(p.to(target_device), requires_grad=p.requires_grad)
                    else:
                        return p.to(target_device)
            if '_buffers' in self.__dict__:
                _buffers = self.__dict__['_buffers']
                if name in _buffers:
                    return _buffers[name].to(target_device)
            return super(original_class, self).__getattr__(name)

        module.__class__ = type('DynamicSwap_' + original_class.__name__, (original_class,), {
            '__getattr__': hacked_get_attr,
        })

        return

    @staticmethod
    def _uninstall_module(module: torch.nn.Module):
        if 'forge_backup_original_class' in module.__dict__:
            module.__class__ = module.__dict__.pop('forge_backup_original_class')
        return

    @staticmethod
    def install_model(model: torch.nn.Module, target_device: torch.device):
        for m in model.modules():
            DynamicSwapInstaller._install_module(m, target_device)
        return

    @staticmethod
    def uninstall_model(model: torch.nn.Module):
        for m in model.modules():
            DynamicSwapInstaller._uninstall_module(m)
        return