stable-diffusion-webui-forge/backend/operations.py
layerdiffusion 4c9380c46a Speed up quant model loading and inference, based on three observations:
1. torch.Tensor.view on one big tensor is slightly faster than calling torch.Tensor.to on many small tensors.
2. However, torch.Tensor.to with a dtype change is significantly slower than torch.Tensor.view.
3. “Baking” the model on the GPU is significantly faster than computing on the CPU at model-load time.

This mainly affects inference for Q8_0 and Q4_0/1/K, and the loading of all quants (a rough timing sketch follows below).
2024-08-30 00:49:05 -07:00
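To make the three observations concrete, here is a rough, illustrative micro-benchmark sketch. It is not part of operations.py; it assumes a CUDA device, made-up tensor sizes, and a toy int8 scale-dequant as a stand-in for real quant "baking". Absolute numbers will vary by hardware.

# micro_benchmark_sketch.py -- illustrative only, not part of the Forge codebase
import time
import torch


def bench(fn, warmup=3, iters=20):
    # Average wall time per call, synchronizing CUDA around the timed region.
    for _ in range(warmup):
        fn()
    torch.cuda.synchronize()
    t0 = time.perf_counter()
    for _ in range(iters):
        fn()
    torch.cuda.synchronize()
    return (time.perf_counter() - t0) / iters


device = torch.device('cuda')

# 1) One reinterpreting view on a big tensor vs. many per-tensor .to() calls.
big = torch.randn(64 * 1024 * 1024, dtype=torch.float16, device=device)
chunks = [torch.randn(1024 * 1024, dtype=torch.float16, device=device) for _ in range(64)]
t_one_view = bench(lambda: big.view(torch.uint8))
t_many_to = bench(lambda: [c.to(dtype=torch.float16) for c in chunks])  # no-op casts; measures per-call overhead

# 2) A dtype-changing .to() (allocates and converts) vs. a byte-reinterpreting .view() (no copy).
t_to_dtype = bench(lambda: big.to(torch.float32))
t_view_only = bench(lambda: big.view(torch.int16))

# 3) Toy "baking" (int8 -> fp16 scale-dequant) on GPU vs. CPU.
q_gpu = torch.randint(-128, 127, (4096, 4096), dtype=torch.int8, device=device)
q_cpu = q_gpu.cpu()
t_gpu_bake = bench(lambda: q_gpu.to(torch.float16) * 0.01)
t_cpu_bake = bench(lambda: q_cpu.to(torch.float16) * 0.01)

print(f'1) one view: {t_one_view:.6f}s  vs  64 small .to(): {t_many_to:.6f}s')
print(f'2) .to(fp32): {t_to_dtype:.6f}s  vs  .view(int16): {t_view_only:.6f}s')
print(f'3) GPU bake: {t_gpu_bake:.6f}s  vs  CPU bake: {t_cpu_bake:.6f}s')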

557 lines · 22 KiB · Python

# Copyright Forge 2024

import time
import torch
import contextlib

from backend import stream, memory_management, utils
from backend.patcher.lora import merge_lora_to_weight


stash = {}

def get_weight_and_bias(layer, weight_args=None, bias_args=None, weight_fn=None, bias_fn=None):
    patches = getattr(layer, 'forge_online_loras', None)
    weight_patches, bias_patches = None, None

    if patches is not None:
        weight_patches = patches.get('weight', None)
        bias_patches = patches.get('bias', None)

    weight = None
    if layer.weight is not None:
        weight = layer.weight
        if weight_fn is not None:
            if weight_args is not None:
                fn_device = weight_args.get('device', None)
                if fn_device is not None:
                    weight = weight.to(device=fn_device)
            weight = weight_fn(weight)
        if weight_args is not None:
            weight = weight.to(**weight_args)
        if weight_patches is not None:
            weight = merge_lora_to_weight(patches=weight_patches, weight=weight, key="online weight lora", computation_dtype=weight.dtype)

    bias = None
    if layer.bias is not None:
        bias = layer.bias
        if bias_fn is not None:
            if bias_args is not None:
                fn_device = bias_args.get('device', None)
                if fn_device is not None:
                    bias = bias.to(device=fn_device)
            bias = bias_fn(bias)
        if bias_args is not None:
            bias = bias.to(**bias_args)
        if bias_patches is not None:
            bias = merge_lora_to_weight(patches=bias_patches, weight=bias, key="online bias lora", computation_dtype=bias.dtype)

    return weight, bias

def weights_manual_cast(layer, x, skip_weight_dtype=False, skip_bias_dtype=False, weight_fn=None, bias_fn=None):
    weight, bias, signal = None, None, None
    non_blocking = True

    if getattr(x.device, 'type', None) == 'mps':
        # Non-blocking transfers can be unreliable on MPS, so fall back to blocking copies.
        non_blocking = False

    target_dtype = x.dtype
    target_device = x.device

    if skip_weight_dtype:
        weight_args = dict(device=target_device, non_blocking=non_blocking)
    else:
        weight_args = dict(device=target_device, dtype=target_dtype, non_blocking=non_blocking)

    if skip_bias_dtype:
        bias_args = dict(device=target_device, non_blocking=non_blocking)
    else:
        bias_args = dict(device=target_device, dtype=target_dtype, non_blocking=non_blocking)

    if stream.should_use_stream():
        with stream.stream_context()(stream.mover_stream):
            weight, bias = get_weight_and_bias(layer, weight_args, bias_args, weight_fn=weight_fn, bias_fn=bias_fn)
            signal = stream.mover_stream.record_event()
    else:
        weight, bias = get_weight_and_bias(layer, weight_args, bias_args, weight_fn=weight_fn, bias_fn=bias_fn)

    return weight, bias, signal

@contextlib.contextmanager
def main_stream_worker(weight, bias, signal):
    if signal is None or not stream.should_use_stream():
        yield
        return

    with stream.stream_context()(stream.current_stream):
        stream.current_stream.wait_event(signal)
        yield
        finished_signal = stream.current_stream.record_event()
        stash[id(finished_signal)] = (weight, bias, finished_signal)

    garbage = []
    for k, (w, b, s) in stash.items():
        if s.query():
            garbage.append(k)

    for k in garbage:
        del stash[k]
    return

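# How the streamed manual-cast path above works: when a mover stream is available,
# weights_manual_cast() copies/casts the weight and bias on stream.mover_stream and returns
# an event ("signal"). main_stream_worker() makes the compute stream wait on that event
# before the layer runs, then records a "finished" event and parks (weight, bias) in `stash`
# so the casted tensors stay alive until the compute stream is done with them. Entries whose
# events have completed are swept opportunistically; cleanup_cache() below synchronizes both
# streams and clears the stash.
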
def cleanup_cache():
    if not stream.should_use_stream():
        return

    stream.current_stream.synchronize()
    stream.mover_stream.synchronize()
    stash.clear()
    return


current_device = None
current_dtype = None
current_manual_cast_enabled = False
current_bnb_dtype = None

class ForgeOperations:
    class Linear(torch.nn.Module):
        def __init__(self, *args, **kwargs):
            super().__init__()
            self.dummy = torch.nn.Parameter(torch.empty(1, device=current_device, dtype=current_dtype))
            self.weight = None
            self.bias = None
            self.parameters_manual_cast = current_manual_cast_enabled

        def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs):
            if hasattr(self, 'dummy'):
                if prefix + 'weight' in state_dict:
                    self.weight = torch.nn.Parameter(state_dict[prefix + 'weight'].to(self.dummy))
                if prefix + 'bias' in state_dict:
                    self.bias = torch.nn.Parameter(state_dict[prefix + 'bias'].to(self.dummy))
                del self.dummy
            else:
                super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)

        def forward(self, x):
            if self.parameters_manual_cast:
                weight, bias, signal = weights_manual_cast(self, x)
                with main_stream_worker(weight, bias, signal):
                    return torch.nn.functional.linear(x, weight, bias)
            else:
                weight, bias = get_weight_and_bias(self)
                return torch.nn.functional.linear(x, weight, bias)

    class Conv2d(torch.nn.Conv2d):
        def __init__(self, *args, **kwargs):
            kwargs['device'] = current_device
            kwargs['dtype'] = current_dtype
            super().__init__(*args, **kwargs)
            self.parameters_manual_cast = current_manual_cast_enabled

        def reset_parameters(self):
            return None

        def forward(self, x):
            if self.parameters_manual_cast:
                weight, bias, signal = weights_manual_cast(self, x)
                with main_stream_worker(weight, bias, signal):
                    return self._conv_forward(x, weight, bias)
            else:
                weight, bias = get_weight_and_bias(self)
                return super()._conv_forward(x, weight, bias)

    class Conv3d(torch.nn.Conv3d):
        def __init__(self, *args, **kwargs):
            kwargs['device'] = current_device
            kwargs['dtype'] = current_dtype
            super().__init__(*args, **kwargs)
            self.parameters_manual_cast = current_manual_cast_enabled

        def reset_parameters(self):
            return None

        def forward(self, x):
            if self.parameters_manual_cast:
                weight, bias, signal = weights_manual_cast(self, x)
                with main_stream_worker(weight, bias, signal):
                    return self._conv_forward(x, weight, bias)
            else:
                weight, bias = get_weight_and_bias(self)
                return super()._conv_forward(x, weight, bias)

    class Conv1d(torch.nn.Conv1d):
        def __init__(self, *args, **kwargs):
            kwargs['device'] = current_device
            kwargs['dtype'] = current_dtype
            super().__init__(*args, **kwargs)
            self.parameters_manual_cast = current_manual_cast_enabled

        def reset_parameters(self):
            return None

        def forward(self, x):
            if self.parameters_manual_cast:
                weight, bias, signal = weights_manual_cast(self, x)
                with main_stream_worker(weight, bias, signal):
                    return self._conv_forward(x, weight, bias)
            else:
                weight, bias = get_weight_and_bias(self)
                return super()._conv_forward(x, weight, bias)

    class ConvTranspose2d(torch.nn.ConvTranspose2d):
        def __init__(self, *args, **kwargs):
            kwargs['device'] = current_device
            kwargs['dtype'] = current_dtype
            super().__init__(*args, **kwargs)
            self.parameters_manual_cast = current_manual_cast_enabled

        def reset_parameters(self):
            return None

        def forward(self, x, output_size=None):
            if self.parameters_manual_cast:
                num_spatial_dims = 2
                output_padding = self._output_padding(x, output_size, self.stride, self.padding, self.kernel_size, num_spatial_dims, self.dilation)
                weight, bias, signal = weights_manual_cast(self, x)
                with main_stream_worker(weight, bias, signal):
                    return torch.nn.functional.conv_transpose2d(x, weight, bias, self.stride, self.padding, output_padding, self.groups, self.dilation)
            else:
                weight, bias = get_weight_and_bias(self)
                num_spatial_dims = 2
                output_padding = self._output_padding(x, output_size, self.stride, self.padding, self.kernel_size, num_spatial_dims, self.dilation)
                return torch.nn.functional.conv_transpose2d(x, weight, bias, self.stride, self.padding, output_padding, self.groups, self.dilation)

    class ConvTranspose1d(torch.nn.ConvTranspose1d):
        def __init__(self, *args, **kwargs):
            kwargs['device'] = current_device
            kwargs['dtype'] = current_dtype
            super().__init__(*args, **kwargs)
            self.parameters_manual_cast = current_manual_cast_enabled

        def reset_parameters(self):
            return None

        def forward(self, x, output_size=None):
            if self.parameters_manual_cast:
                num_spatial_dims = 1
                output_padding = self._output_padding(x, output_size, self.stride, self.padding, self.kernel_size, num_spatial_dims, self.dilation)
                weight, bias, signal = weights_manual_cast(self, x)
                with main_stream_worker(weight, bias, signal):
                    return torch.nn.functional.conv_transpose1d(x, weight, bias, self.stride, self.padding, output_padding, self.groups, self.dilation)
            else:
                weight, bias = get_weight_and_bias(self)
                num_spatial_dims = 1
                output_padding = self._output_padding(x, output_size, self.stride, self.padding, self.kernel_size, num_spatial_dims, self.dilation)
                return torch.nn.functional.conv_transpose1d(x, weight, bias, self.stride, self.padding, output_padding, self.groups, self.dilation)

    class ConvTranspose3d(torch.nn.ConvTranspose3d):
        def __init__(self, *args, **kwargs):
            kwargs['device'] = current_device
            kwargs['dtype'] = current_dtype
            super().__init__(*args, **kwargs)
            self.parameters_manual_cast = current_manual_cast_enabled

        def reset_parameters(self):
            return None

        def forward(self, x, output_size=None):
            if self.parameters_manual_cast:
                num_spatial_dims = 3
                output_padding = self._output_padding(x, output_size, self.stride, self.padding, self.kernel_size, num_spatial_dims, self.dilation)
                weight, bias, signal = weights_manual_cast(self, x)
                with main_stream_worker(weight, bias, signal):
                    return torch.nn.functional.conv_transpose3d(x, weight, bias, self.stride, self.padding, output_padding, self.groups, self.dilation)
            else:
                weight, bias = get_weight_and_bias(self)
                num_spatial_dims = 3
                output_padding = self._output_padding(x, output_size, self.stride, self.padding, self.kernel_size, num_spatial_dims, self.dilation)
                return torch.nn.functional.conv_transpose3d(x, weight, bias, self.stride, self.padding, output_padding, self.groups, self.dilation)

    class GroupNorm(torch.nn.GroupNorm):
        def __init__(self, *args, **kwargs):
            kwargs['device'] = current_device
            kwargs['dtype'] = current_dtype
            super().__init__(*args, **kwargs)
            self.parameters_manual_cast = current_manual_cast_enabled

        def reset_parameters(self):
            return None

        def forward(self, x):
            if self.parameters_manual_cast:
                weight, bias, signal = weights_manual_cast(self, x)
                with main_stream_worker(weight, bias, signal):
                    return torch.nn.functional.group_norm(x, self.num_groups, weight, bias, self.eps)
            else:
                return super().forward(x)

    class LayerNorm(torch.nn.LayerNorm):
        def __init__(self, *args, **kwargs):
            kwargs['device'] = current_device
            kwargs['dtype'] = current_dtype
            super().__init__(*args, **kwargs)
            self.parameters_manual_cast = current_manual_cast_enabled

        def reset_parameters(self):
            return None

        def forward(self, x):
            if self.parameters_manual_cast:
                weight, bias, signal = weights_manual_cast(self, x)
                with main_stream_worker(weight, bias, signal):
                    return torch.nn.functional.layer_norm(x, self.normalized_shape, weight, bias, self.eps)
            else:
                return super().forward(x)

    class Embedding(torch.nn.Embedding):
        def __init__(self, *args, **kwargs):
            kwargs['device'] = current_device
            super().__init__(*args, **kwargs)
            self.parameters_manual_cast = current_manual_cast_enabled
            self.bias = None

        def reset_parameters(self):
            self.bias = None
            return None

        def forward(self, x):
            if self.parameters_manual_cast:
                weight, bias, signal = weights_manual_cast(self, x, skip_weight_dtype=True, skip_bias_dtype=True)
                with main_stream_worker(weight, bias, signal):
                    return torch.nn.functional.embedding(x, weight, self.padding_idx, self.max_norm, self.norm_type, self.scale_grad_by_freq, self.sparse)
            else:
                return super().forward(x)

try:
    from backend.operations_bnb import ForgeLoader4Bit, ForgeParams4bit, functional_linear_4bits, functional_dequantize_4bit

    class ForgeOperationsBNB4bits(ForgeOperations):
        class Linear(ForgeLoader4Bit):
            def __init__(self, *args, **kwargs):
                super().__init__(device=current_device, dtype=current_dtype, quant_type=current_bnb_dtype)
                self.parameters_manual_cast = current_manual_cast_enabled

            def forward(self, x):
                if self.bias is not None and self.bias.dtype != x.dtype:
                    # Maybe this could also be done for all non-bnb ops, since the cost is very low:
                    # it only runs once, and most Linear layers do not have a bias.
                    self.bias = utils.tensor2parameter(self.bias.to(x.dtype))

                if hasattr(self, 'forge_online_loras'):
                    weight, bias, signal = weights_manual_cast(self, x, weight_fn=functional_dequantize_4bit, bias_fn=None, skip_bias_dtype=True)
                    with main_stream_worker(weight, bias, signal):
                        return torch.nn.functional.linear(x, weight, bias)

                if not self.parameters_manual_cast:
                    return functional_linear_4bits(x, self.weight, self.bias)
                elif not self.weight.bnb_quantized:
                    assert x.device.type == 'cuda', 'BNB must use CUDA as the computation device!'
                    layer_original_device = self.weight.device
                    self.weight = self.weight._quantize(x.device)
                    bias = self.bias.to(x.device) if self.bias is not None else None
                    out = functional_linear_4bits(x, self.weight, bias)
                    self.weight = self.weight.to(layer_original_device)
                    return out
                else:
                    weight, bias, signal = weights_manual_cast(self, x, skip_weight_dtype=True, skip_bias_dtype=True)
                    with main_stream_worker(weight, bias, signal):
                        return functional_linear_4bits(x, weight, bias)

    bnb_avaliable = True
except:
    bnb_avaliable = False

from backend.operations_gguf import dequantize_tensor


class ForgeOperationsGGUF(ForgeOperations):
    class Linear(torch.nn.Module):
        def __init__(self, *args, **kwargs):
            super().__init__()
            self.dummy = torch.nn.Parameter(torch.empty(1, device=current_device, dtype=current_dtype))
            self.weight = None
            self.bias = None
            self.parameters_manual_cast = current_manual_cast_enabled

        def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs):
            if hasattr(self, 'dummy'):
                computation_dtype = self.dummy.dtype
                if computation_dtype not in [torch.float16, torch.bfloat16]:
                    # GGUF dequantization only supports 16-bit computation dtypes; anything else is very slow.
                    computation_dtype = torch.float16
                if prefix + 'weight' in state_dict:
                    self.weight = state_dict[prefix + 'weight'].to(device=self.dummy.device)
                    self.weight.computation_dtype = computation_dtype
                if prefix + 'bias' in state_dict:
                    self.bias = state_dict[prefix + 'bias'].to(device=self.dummy.device)
                    self.bias.computation_dtype = computation_dtype
                del self.dummy
            else:
                if prefix + 'weight' in state_dict:
                    self.weight = state_dict[prefix + 'weight']
                if prefix + 'bias' in state_dict:
                    self.bias = state_dict[prefix + 'bias']
            return

        def _apply(self, fn, recurse=True):
            for k, p in self.named_parameters(recurse=False, remove_duplicate=True):
                setattr(self, k, utils.tensor2parameter(fn(p)))
            return self

        def forward(self, x):
            if self.bias is not None and self.bias.dtype != x.dtype:
                self.bias = utils.tensor2parameter(dequantize_tensor(self.bias).to(x.dtype))

            if self.weight is not None and self.weight.dtype != x.dtype and getattr(self.weight, 'gguf_cls', None) is None:
                self.weight = utils.tensor2parameter(self.weight.to(x.dtype))

            weight, bias, signal = weights_manual_cast(self, x, weight_fn=dequantize_tensor, bias_fn=None, skip_bias_dtype=True)
            with main_stream_worker(weight, bias, signal):
                return torch.nn.functional.linear(x, weight, bias)

@contextlib.contextmanager
def using_forge_operations(operations=None, device=None, dtype=None, manual_cast_enabled=False, bnb_dtype=None):
    global current_device, current_dtype, current_manual_cast_enabled, current_bnb_dtype
    current_device, current_dtype, current_manual_cast_enabled, current_bnb_dtype = device, dtype, manual_cast_enabled, bnb_dtype

    if operations is None:
        if bnb_dtype in ['gguf']:
            operations = ForgeOperationsGGUF
        elif bnb_avaliable and bnb_dtype in ['nf4', 'fp4']:
            operations = ForgeOperationsBNB4bits
        else:
            operations = ForgeOperations

    op_names = ['Linear', 'Conv1d', 'Conv2d', 'Conv3d', 'ConvTranspose1d', 'ConvTranspose2d', 'ConvTranspose3d', 'GroupNorm', 'LayerNorm', 'Embedding']
    backups = {op_name: getattr(torch.nn, op_name) for op_name in op_names}

    try:
        for op_name in op_names:
            setattr(torch.nn, op_name, getattr(operations, op_name))
        yield
    finally:
        for op_name in op_names:
            setattr(torch.nn, op_name, backups[op_name])
    return

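# Usage sketch: build a model while the context is active so its layers are created with the
# patched ops (empty on `device`, manually cast at runtime). Exact call sites depend on the
# rest of Forge; the model class below is hypothetical.
#
#     with using_forge_operations(device=torch.device('cuda'), dtype=torch.float16, manual_cast_enabled=True):
#         unet = UNetModel(**config)  # any torch.nn-based module works
#
# While active, torch.nn.Linear, Conv*, ConvTranspose*, GroupNorm, LayerNorm and Embedding
# resolve to the classes above; the originals are restored on exit.
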
def shift_manual_cast(model, enabled):
    for m in model.modules():
        if hasattr(m, 'parameters_manual_cast'):
            m.parameters_manual_cast = enabled
    return

@contextlib.contextmanager
def automatic_memory_management():
    memory_management.free_memory(
        memory_required=3 * 1024 * 1024 * 1024,
        device=memory_management.get_torch_device()
    )

    module_list = []

    original_init = torch.nn.Module.__init__
    original_to = torch.nn.Module.to

    def patched_init(self, *args, **kwargs):
        module_list.append(self)
        return original_init(self, *args, **kwargs)

    def patched_to(self, *args, **kwargs):
        module_list.append(self)
        return original_to(self, *args, **kwargs)

    try:
        torch.nn.Module.__init__ = patched_init
        torch.nn.Module.to = patched_to
        yield
    finally:
        torch.nn.Module.__init__ = original_init
        torch.nn.Module.to = original_to

    start = time.perf_counter()
    module_list = set(module_list)

    for module in module_list:
        module.cpu()

    memory_management.soft_empty_cache()
    end = time.perf_counter()

    print(f'Automatic Memory Management: {len(module_list)} Modules in {(end - start):.2f} seconds.')
    return

class DynamicSwapInstaller:
    @staticmethod
    def _install_module(module: torch.nn.Module, target_device: torch.device):
        original_class = module.__class__
        module.__dict__['forge_backup_original_class'] = original_class

        def hacked_get_attr(self, name: str):
            if '_parameters' in self.__dict__:
                _parameters = self.__dict__['_parameters']
                if name in _parameters:
                    p = _parameters[name]
                    if p is None:
                        return None
                    if p.__class__ == torch.nn.Parameter:
                        return torch.nn.Parameter(p.to(target_device), requires_grad=p.requires_grad)
                    else:
                        return p.to(target_device)
            if '_buffers' in self.__dict__:
                _buffers = self.__dict__['_buffers']
                if name in _buffers:
                    return _buffers[name].to(target_device)
            return super(original_class, self).__getattr__(name)

        module.__class__ = type('DynamicSwap_' + original_class.__name__, (original_class,), {
            '__getattr__': hacked_get_attr,
        })
        return

    @staticmethod
    def _uninstall_module(module: torch.nn.Module):
        if 'forge_backup_original_class' in module.__dict__:
            module.__class__ = module.__dict__.pop('forge_backup_original_class')
        return

    @staticmethod
    def install_model(model: torch.nn.Module, target_device: torch.device):
        for m in model.modules():
            DynamicSwapInstaller._install_module(m, target_device)
        return

    @staticmethod
    def uninstall_model(model: torch.nn.Module):
        for m in model.modules():
            DynamicSwapInstaller._uninstall_module(m)
        return
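

# DynamicSwapInstaller usage (sketch): keep a model's parameters where they are (e.g. on CPU)
# and have each parameter/buffer access return a copy on `target_device`, by swapping every
# module's class for a dynamically created subclass that overrides __getattr__.
#
#     DynamicSwapInstaller.install_model(model, torch.device('cuda'))
#     ...  # run inference; weight accesses are moved to the target device on the fly
#     DynamicSwapInstaller.uninstall_model(model)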