forge 2.0.0

see also discussions

Author: lllyasviel
Date: 2024-08-10 19:24:19 -07:00 (committed via GitHub)
Commit: cfa5242a75 (parent: 4014013d05)
28 changed files with 785 additions and 1249 deletions

@@ -1,3 +1,5 @@
+# Copyright Forge 2024
 import time
 import torch
 import contextlib
@@ -8,7 +10,7 @@ from backend import stream, memory_management
 stash = {}
-def weights_manual_cast(layer, x, skip_dtype=False):
+def weights_manual_cast(layer, x, skip_weight_dtype=False, skip_bias_dtype=False):
     weight, bias, signal = None, None, None
     non_blocking = True
@@ -18,21 +20,28 @@ def weights_manual_cast(layer, x, skip_dtype=False):
     target_dtype = x.dtype
     target_device = x.device
-    if skip_dtype:
-        target_dtype = None
+    if skip_weight_dtype:
+        weight_args = dict(device=target_device, non_blocking=non_blocking)
+    else:
+        weight_args = dict(device=target_device, dtype=target_dtype, non_blocking=non_blocking)
+    if skip_bias_dtype:
+        bias_args = dict(device=target_device, non_blocking=non_blocking)
+    else:
+        bias_args = dict(device=target_device, dtype=target_dtype, non_blocking=non_blocking)
     if stream.should_use_stream():
         with stream.stream_context()(stream.mover_stream):
             if layer.weight is not None:
-                weight = layer.weight.to(device=target_device, dtype=target_dtype, non_blocking=non_blocking)
+                weight = layer.weight.to(**weight_args)
             if layer.bias is not None:
-                bias = layer.bias.to(device=target_device, dtype=target_dtype, non_blocking=non_blocking)
+                bias = layer.bias.to(**bias_args)
             signal = stream.mover_stream.record_event()
     else:
         if layer.weight is not None:
-            weight = layer.weight.to(device=target_device, dtype=target_dtype, non_blocking=non_blocking)
+            weight = layer.weight.to(**weight_args)
         if layer.bias is not None:
-            bias = layer.bias.to(device=target_device, dtype=target_dtype, non_blocking=non_blocking)
+            bias = layer.bias.to(**bias_args)
     return weight, bias, signal
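For reference, an editor's usage sketch (not part of the commit) of how the refactored helper is driven; it relies on weights_manual_cast and main_stream_worker from this file, and the layer and input below are hypothetical:

# Usage sketch (assumption): cast a layer's parameters to the input's device and
# dtype, then compute once main_stream_worker has consumed the transfer signal.
layer = torch.nn.Linear(8, 8)                  # hypothetical layer
x = torch.randn(1, 8, dtype=torch.bfloat16)    # hypothetical input
weight, bias, signal = weights_manual_cast(layer, x)
with main_stream_worker(weight, bias, signal):
    y = torch.nn.functional.linear(x, weight, bias)
# skip_weight_dtype / skip_bias_dtype keep the stored dtype and only move the
# device, which is what the quantized paths further down rely on.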
@@ -72,19 +81,27 @@ def cleanup_cache():
 current_device = None
 current_dtype = None
 current_manual_cast_enabled = False
+current_bnb_dtype = None
 class ForgeOperations:
-    class Linear(torch.nn.Linear):
+    class Linear(torch.nn.Module):
         def __init__(self, *args, **kwargs):
-            kwargs['device'] = current_device
-            kwargs['dtype'] = current_dtype
-            super().__init__(*args, **kwargs)
+            super().__init__()
+            self.dummy = torch.nn.Parameter(torch.empty(1, device=current_device, dtype=current_dtype))
+            self.weight = None
+            self.bias = None
             self.parameters_manual_cast = current_manual_cast_enabled
         def reset_parameters(self):
             return None
+        def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs):
+            if hasattr(self, 'dummy'):
+                if prefix + 'weight' in state_dict:
+                    self.weight = torch.nn.Parameter(state_dict[prefix + 'weight'].to(self.dummy))
+                if prefix + 'bias' in state_dict:
+                    self.bias = torch.nn.Parameter(state_dict[prefix + 'bias'].to(self.dummy))
+                del self.dummy
+            else:
+                super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
         def forward(self, x):
             if self.parameters_manual_cast:
@@ -92,7 +109,7 @@ class ForgeOperations:
                 with main_stream_worker(weight, bias, signal):
                     return torch.nn.functional.linear(x, weight, bias)
             else:
-                return super().forward(x)
+                return torch.nn.functional.linear(x, self.weight, self.bias)
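For orientation (editor's note, not from the commit): the new Linear no longer allocates weights in __init__; it keeps a one-element dummy parameter that only records the target device and dtype, and _load_from_state_dict materializes weight and bias straight from the checkpoint via .to(self.dummy). A stripped-down standalone sketch of the same pattern, with a hypothetical class name:

# Standalone sketch of the lazy-materialization idea (hypothetical LazyLinear,
# not the commit's class): parameters come directly from the state dict, so no
# throwaway random initialization happens at construction time.
import torch

class LazyLinear(torch.nn.Module):
    def __init__(self, device=None, dtype=None):
        super().__init__()
        self.dummy = torch.nn.Parameter(torch.empty(1, device=device, dtype=dtype))
        self.weight = None
        self.bias = None

    def _load_from_state_dict(self, state_dict, prefix, *args):
        if prefix + 'weight' in state_dict:
            # Tensor.to(other) copies onto other's device and dtype in one step.
            self.weight = torch.nn.Parameter(state_dict[prefix + 'weight'].to(self.dummy))
        if prefix + 'bias' in state_dict:
            self.bias = torch.nn.Parameter(state_dict[prefix + 'bias'].to(self.dummy))
        del self.dummy

    def forward(self, x):
        return torch.nn.functional.linear(x, self.weight, self.bias)

m = LazyLinear(device='cpu', dtype=torch.float32)
m.load_state_dict({'weight': torch.randn(4, 4), 'bias': torch.randn(4)})   # hypothetical checkpoint
print(m(torch.randn(2, 4)).shape)   # torch.Size([2, 4])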
     class Conv2d(torch.nn.Conv2d):
@@ -269,21 +286,61 @@ class ForgeOperations:
         def forward(self, x):
             if self.parameters_manual_cast:
-                weight, bias, signal = weights_manual_cast(self, x, skip_dtype=True)
+                weight, bias, signal = weights_manual_cast(self, x, skip_weight_dtype=True, skip_bias_dtype=True)
                 with main_stream_worker(weight, bias, signal):
                     return torch.nn.functional.embedding(x, weight, self.padding_idx, self.max_norm, self.norm_type, self.scale_grad_by_freq, self.sparse)
             else:
                 return super().forward(x)
-@contextlib.contextmanager
-def using_forge_operations(operations=None, device=None, dtype=None, manual_cast_enabled=False):
-    global current_device, current_dtype, current_manual_cast_enabled
+try:
+    from backend.operations_bnb import ForgeLoader4Bit, ForgeParams4bit, functional_linear_4bits
-    current_device, current_dtype, current_manual_cast_enabled = device, dtype, manual_cast_enabled
+    class ForgeOperationsBNB4bits(ForgeOperations):
+        class Linear(ForgeLoader4Bit):
+            def __init__(self, *args, **kwargs):
+                super().__init__(device=current_device, dtype=current_dtype, quant_type=current_bnb_dtype)
+                self.parameters_manual_cast = current_manual_cast_enabled
+            def forward(self, x):
+                self.weight.quant_state = self.quant_state
+                if self.bias is not None and self.bias.dtype != x.dtype:
+                    # Maybe this can also be done for all non-bnb ops, since the cost is very low:
+                    # it is only invoked once, and most linear layers do not have a bias.
+                    self.bias.data = self.bias.data.to(x.dtype)
+                if not self.parameters_manual_cast:
+                    return functional_linear_4bits(x, self.weight, self.bias)
+                elif not self.weight.bnb_quantized:
+                    assert x.device.type == 'cuda', 'BNB Must Use CUDA as Computation Device!'
+                    layer_original_device = self.weight.device
+                    self.weight = self.weight._quantize(x.device)
+                    bias = self.bias.to(x.device) if self.bias is not None else None
+                    out = functional_linear_4bits(x, self.weight, bias)
+                    self.weight = self.weight.to(layer_original_device)
+                    return out
+                else:
+                    weight, bias, signal = weights_manual_cast(self, x, skip_weight_dtype=True, skip_bias_dtype=True)
+                    with main_stream_worker(weight, bias, signal):
+                        return functional_linear_4bits(x, weight, bias)
+    bnb_avaliable = True
+except:
+    bnb_avaliable = False
+@contextlib.contextmanager
+def using_forge_operations(operations=None, device=None, dtype=None, manual_cast_enabled=False, bnb_dtype=None):
+    global current_device, current_dtype, current_manual_cast_enabled, current_bnb_dtype
+    current_device, current_dtype, current_manual_cast_enabled, current_bnb_dtype = device, dtype, manual_cast_enabled, bnb_dtype
     if operations is None:
-        operations = ForgeOperations
+        if bnb_avaliable and bnb_dtype in ['nf4', 'fp4']:
+            operations = ForgeOperationsBNB4bits
+        else:
+            operations = ForgeOperations
     op_names = ['Linear', 'Conv1d', 'Conv2d', 'Conv3d', 'ConvTranspose1d', 'ConvTranspose2d', 'ConvTranspose3d', 'GroupNorm', 'LayerNorm', 'Embedding']
     backups = {op_name: getattr(torch.nn, op_name) for op_name in op_names}
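To close the loop, an editor's sketch (not part of the diff): the remainder of using_forge_operations is outside this hunk, but given the backups dict it evidently swaps these torch.nn classes for the Forge versions around a yield and restores them on exit. Under that assumption, construction code would use it roughly like this; build_unet and sd are hypothetical:

# Usage sketch (assumption): build a model while torch.nn layer classes are
# temporarily replaced by Forge's lazy variants, optionally selecting the
# bnb 4-bit Linear via bnb_dtype, then load real weights from a checkpoint.
with using_forge_operations(device=torch.device('cuda'), dtype=torch.float16,
                            manual_cast_enabled=True, bnb_dtype='nf4'):
    model = build_unet()      # hypothetical constructor that instantiates torch.nn.Linear etc.
model.load_state_dict(sd)     # sd: hypothetical checkpoint state dict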