gguf preview

This commit is contained in:
layerdiffusion
2024-08-14 22:26:00 -07:00
parent 59790f2cb4
commit d8b83a9501
8 changed files with 190 additions and 7 deletions

View File

@@ -330,6 +330,44 @@ except:
bnb_avaliable = False
from backend.operations_gguf import functional_linear_gguf
class ForgeOperationsGGUF(ForgeOperations):
class Linear(torch.nn.Module):
def __init__(self, *args, **kwargs):
super().__init__()
self.dummy = torch.nn.Parameter(torch.empty(1, device=current_device, dtype=current_dtype))
self.weight = None
self.bias = None
self.parameters_manual_cast = current_manual_cast_enabled
def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs):
if hasattr(self, 'dummy'):
if prefix + 'weight' in state_dict:
self.weight = state_dict[prefix + 'weight'].to(device=self.dummy.device)
if prefix + 'bias' in state_dict:
self.bias = state_dict[prefix + 'bias'].to(device=self.dummy.device)
del self.dummy
else:
super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
def _apply(self, fn, recurse=True):
if self.weight is not None:
self.weight = fn(self.weight)
if self.bias is not None:
self.bias = fn(self.bias)
return super()._apply(fn, recurse=recurse)
def forward(self, x):
if self.parameters_manual_cast:
weight, bias, signal = weights_manual_cast(self, x, skip_weight_dtype=True)
with main_stream_worker(weight, bias, signal):
return functional_linear_gguf(x, weight, bias)
else:
return functional_linear_gguf(x, self.weight, self.bias)
@contextlib.contextmanager
def using_forge_operations(operations=None, device=None, dtype=None, manual_cast_enabled=False, bnb_dtype=None):
global current_device, current_dtype, current_manual_cast_enabled, current_bnb_dtype
@@ -337,7 +375,9 @@ def using_forge_operations(operations=None, device=None, dtype=None, manual_cast
current_device, current_dtype, current_manual_cast_enabled, current_bnb_dtype = device, dtype, manual_cast_enabled, bnb_dtype
if operations is None:
if bnb_avaliable and bnb_dtype in ['nf4', 'fp4']:
if bnb_dtype in ['gguf']:
operations = ForgeOperationsGGUF
elif bnb_avaliable and bnb_dtype in ['nf4', 'fp4']:
operations = ForgeOperationsBNB4bits
else:
operations = ForgeOperations