mirror of https://github.com/lllyasviel/stable-diffusion-webui-forge.git
synced 2026-04-28 02:01:25 +00:00
forge 2.0.0
see also discussions
@@ -1,3 +1,5 @@
+# Copyright Forge 2024
+
 import time
 import torch
 import contextlib
@@ -8,7 +10,7 @@ from backend import stream, memory_management
 stash = {}


-def weights_manual_cast(layer, x, skip_dtype=False):
+def weights_manual_cast(layer, x, skip_weight_dtype=False, skip_bias_dtype=False):
     weight, bias, signal = None, None, None
     non_blocking = True

@@ -18,21 +20,28 @@ def weights_manual_cast(layer, x, skip_dtype=False):
     target_dtype = x.dtype
     target_device = x.device

-    if skip_dtype:
-        target_dtype = None
+    if skip_weight_dtype:
+        weight_args = dict(device=target_device, non_blocking=non_blocking)
+    else:
+        weight_args = dict(device=target_device, dtype=target_dtype, non_blocking=non_blocking)
+
+    if skip_bias_dtype:
+        bias_args = dict(device=target_device, non_blocking=non_blocking)
+    else:
+        bias_args = dict(device=target_device, dtype=target_dtype, non_blocking=non_blocking)

     if stream.should_use_stream():
         with stream.stream_context()(stream.mover_stream):
             if layer.weight is not None:
-                weight = layer.weight.to(device=target_device, dtype=target_dtype, non_blocking=non_blocking)
+                weight = layer.weight.to(**weight_args)
             if layer.bias is not None:
-                bias = layer.bias.to(device=target_device, dtype=target_dtype, non_blocking=non_blocking)
+                bias = layer.bias.to(**bias_args)
             signal = stream.mover_stream.record_event()
     else:
         if layer.weight is not None:
-            weight = layer.weight.to(device=target_device, dtype=target_dtype, non_blocking=non_blocking)
+            weight = layer.weight.to(**weight_args)
         if layer.bias is not None:
-            bias = layer.bias.to(device=target_device, dtype=target_dtype, non_blocking=non_blocking)
+            bias = layer.bias.to(**bias_args)

     return weight, bias, signal

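Note on the hunk above: the single skip_dtype switch is replaced by independent skip_weight_dtype / skip_bias_dtype flags, so a caller can move a tensor to the compute device while leaving its stored dtype untouched (e.g. a quantized weight) and still cast the bias. A minimal standalone sketch of the effect, not code from this file:

    import torch

    layer = torch.nn.Linear(4, 4)                        # weight/bias stored as float32
    x = torch.randn(1, 4, dtype=torch.float16)

    # device move only for the weight, device move plus dtype cast for the bias
    weight_args = dict(device=x.device, non_blocking=True)
    bias_args = dict(device=x.device, dtype=x.dtype, non_blocking=True)

    weight = layer.weight.to(**weight_args)              # still float32
    bias = layer.bias.to(**bias_args)                    # now float16
    print(weight.dtype, bias.dtype)                      # torch.float32 torch.float16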
@@ -72,19 +81,27 @@ def cleanup_cache():
 current_device = None
 current_dtype = None
 current_manual_cast_enabled = False
+current_bnb_dtype = None


 class ForgeOperations:
-
-    class Linear(torch.nn.Linear):
+    class Linear(torch.nn.Module):
         def __init__(self, *args, **kwargs):
-            kwargs['device'] = current_device
-            kwargs['dtype'] = current_dtype
-            super().__init__(*args, **kwargs)
+            super().__init__()
+            self.dummy = torch.nn.Parameter(torch.empty(1, device=current_device, dtype=current_dtype))
+            self.weight = None
+            self.bias = None
             self.parameters_manual_cast = current_manual_cast_enabled

-        def reset_parameters(self):
-            return None
+        def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs):
+            if hasattr(self, 'dummy'):
+                if prefix + 'weight' in state_dict:
+                    self.weight = torch.nn.Parameter(state_dict[prefix + 'weight'].to(self.dummy))
+                if prefix + 'bias' in state_dict:
+                    self.bias = torch.nn.Parameter(state_dict[prefix + 'bias'].to(self.dummy))
+                del self.dummy
+            else:
+                super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)

         def forward(self, x):
             if self.parameters_manual_cast:
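Note on the hunk above: ForgeOperations.Linear is now a plain torch.nn.Module whose weight and bias are created only when a checkpoint is loaded; the temporary dummy parameter exists just to carry the target device and dtype into _load_from_state_dict. A hedged usage sketch (it assumes this file is backend/operations.py, which the diff does not state; the shapes and state dict are invented for illustration):

    import torch
    from backend.operations import ForgeOperations, using_forge_operations  # module path assumed

    with using_forge_operations(device=torch.device('cpu'), dtype=torch.float16):
        layer = ForgeOperations.Linear(8, 8)     # weight and bias start as None

    state_dict = {'weight': torch.randn(8, 8), 'bias': torch.randn(8)}
    layer.load_state_dict(state_dict)            # parameters are built via the dummy's device/dtype
    print(layer.weight.dtype)                    # torch.float16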
@@ -92,7 +109,7 @@ class ForgeOperations:
                 with main_stream_worker(weight, bias, signal):
                     return torch.nn.functional.linear(x, weight, bias)
             else:
-                return super().forward(x)
+                return torch.nn.functional.linear(x, self.weight, self.bias)

     class Conv2d(torch.nn.Conv2d):
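Note on the hunk above: since Linear no longer derives from torch.nn.Linear, the non-manual-cast branch calls the functional op directly instead of super().forward(x). The pattern, shown on a hypothetical minimal module that is not part of this file:

    import torch

    class PlainLinear(torch.nn.Module):
        def __init__(self, weight, bias=None):
            super().__init__()
            self.weight = torch.nn.Parameter(weight)
            self.bias = torch.nn.Parameter(bias) if bias is not None else None

        def forward(self, x):
            # same call the patched Linear now makes in its else-branch
            return torch.nn.functional.linear(x, self.weight, self.bias)

    layer = PlainLinear(torch.randn(3, 4), torch.randn(3))
    print(layer(torch.randn(2, 4)).shape)        # torch.Size([2, 3])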
@@ -269,21 +286,61 @@ class ForgeOperations:

         def forward(self, x):
             if self.parameters_manual_cast:
-                weight, bias, signal = weights_manual_cast(self, x, skip_dtype=True)
+                weight, bias, signal = weights_manual_cast(self, x, skip_weight_dtype=True, skip_bias_dtype=True)
                 with main_stream_worker(weight, bias, signal):
                     return torch.nn.functional.embedding(x, weight, self.padding_idx, self.max_norm, self.norm_type, self.scale_grad_by_freq, self.sparse)
             else:
                 return super().forward(x)


-@contextlib.contextmanager
-def using_forge_operations(operations=None, device=None, dtype=None, manual_cast_enabled=False):
-    global current_device, current_dtype, current_manual_cast_enabled
+try:
+    from backend.operations_bnb import ForgeLoader4Bit, ForgeParams4bit, functional_linear_4bits

-    current_device, current_dtype, current_manual_cast_enabled = device, dtype, manual_cast_enabled
+    class ForgeOperationsBNB4bits(ForgeOperations):
+        class Linear(ForgeLoader4Bit):
+            def __init__(self, *args, **kwargs):
+                super().__init__(device=current_device, dtype=current_dtype, quant_type=current_bnb_dtype)
+                self.parameters_manual_cast = current_manual_cast_enabled
+
+            def forward(self, x):
+                self.weight.quant_state = self.quant_state
+
+                if self.bias is not None and self.bias.dtype != x.dtype:
+                    # Maybe this can also be set to all non-bnb ops since the cost is very low.
+                    # And it only invokes one time, and most linear does not have bias
+                    self.bias.data = self.bias.data.to(x.dtype)
+
+                if not self.parameters_manual_cast:
+                    return functional_linear_4bits(x, self.weight, self.bias)
+                elif not self.weight.bnb_quantized:
+                    assert x.device.type == 'cuda', 'BNB Must Use CUDA as Computation Device!'
+                    layer_original_device = self.weight.device
+                    self.weight = self.weight._quantize(x.device)
+                    bias = self.bias.to(x.device) if self.bias is not None else None
+                    out = functional_linear_4bits(x, self.weight, bias)
+                    self.weight = self.weight.to(layer_original_device)
+                    return out
+                else:
+                    weight, bias, signal = weights_manual_cast(self, x, skip_weight_dtype=True, skip_bias_dtype=True)
+                    with main_stream_worker(weight, bias, signal):
+                        return functional_linear_4bits(x, weight, bias)
+
+    bnb_avaliable = True
+except:
+    bnb_avaliable = False
+
+
+@contextlib.contextmanager
+def using_forge_operations(operations=None, device=None, dtype=None, manual_cast_enabled=False, bnb_dtype=None):
+    global current_device, current_dtype, current_manual_cast_enabled, current_bnb_dtype
+
+    current_device, current_dtype, current_manual_cast_enabled, current_bnb_dtype = device, dtype, manual_cast_enabled, bnb_dtype

     if operations is None:
-        operations = ForgeOperations
+        if bnb_avaliable and bnb_dtype in ['nf4', 'fp4']:
+            operations = ForgeOperationsBNB4bits
+        else:
+            operations = ForgeOperations

     op_names = ['Linear', 'Conv1d', 'Conv2d', 'Conv3d', 'ConvTranspose1d', 'ConvTranspose2d', 'ConvTranspose3d', 'GroupNorm', 'LayerNorm', 'Embedding']
     backups = {op_name: getattr(torch.nn, op_name) for op_name in op_names}
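Note on the hunk above: using_forge_operations now also records bnb_dtype and, when bitsandbytes imported successfully and bnb_dtype is 'nf4' or 'fp4', selects ForgeOperationsBNB4bits instead of the plain ForgeOperations. A hedged usage sketch (it assumes this file is backend/operations.py and that the context manager swaps the torch.nn classes listed in op_names, which is what the backups dict suggests; the layer here is only illustrative, since in Forge the real weights arrive later from a checkpoint state dict):

    import torch
    from backend.operations import using_forge_operations  # module path assumed

    with using_forge_operations(device=torch.device('cuda'),
                                dtype=torch.float16,
                                manual_cast_enabled=True,
                                bnb_dtype='nf4'):
        proj = torch.nn.Linear(4096, 4096)   # picked up as ForgeOperationsBNB4bits.Linear when bnb is available

    print(type(proj))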