diff --git a/backend/operations.py b/backend/operations.py
index 1ca15eba..1d79a4c3 100644
--- a/backend/operations.py
+++ b/backend/operations.py
@@ -4,7 +4,7 @@ import time
 import torch
 import contextlib
 
-from backend import stream, memory_management
+from backend import stream, memory_management, utils
 
 
 stash = {}
@@ -355,9 +355,9 @@ class ForgeOperationsGGUF(ForgeOperations):
 
         def _apply(self, fn, recurse=True):
             if self.weight is not None:
-                self.weight = fn(self.weight)
+                self.weight = utils.tensor2parameter(fn(self.weight))
             if self.bias is not None:
-                self.bias = fn(self.bias)
+                self.bias = utils.tensor2parameter(fn(self.bias))
             return self
 
         def forward(self, x):
diff --git a/backend/operations_bnb.py b/backend/operations_bnb.py
index 5a7089a9..654776ca 100644
--- a/backend/operations_bnb.py
+++ b/backend/operations_bnb.py
@@ -3,6 +3,7 @@
 import torch
 import bitsandbytes as bnb
 
+from backend import utils
 from bitsandbytes.nn.modules import Params4bit, QuantState
 from bitsandbytes.functional import dequantize_4bit
 
@@ -88,9 +89,9 @@ class ForgeLoader4Bit(torch.nn.Module):
 
     def _apply(self, fn, recurse=True):
         if self.weight is not None:
-            self.weight = fn(self.weight)
+            self.weight = utils.tensor2parameter(fn(self.weight))
         if self.bias is not None:
-            self.bias = torch.nn.Parameter(fn(self.bias), requires_grad=False)
+            self.bias = utils.tensor2parameter(fn(self.bias))
         return self
 
     def _save_to_state_dict(self, destination, prefix, keep_vars):
diff --git a/backend/operations_gguf.py b/backend/operations_gguf.py
index 9c54e090..922e4d33 100644
--- a/backend/operations_gguf.py
+++ b/backend/operations_gguf.py
@@ -64,6 +64,9 @@ def dequantize_tensor(tensor):
     if tensor is None:
         return None
 
+    if not hasattr(tensor, 'gguf_cls'):
+        return tensor
+
     data = torch.tensor(tensor.data)
     gguf_cls = tensor.gguf_cls
     gguf_real_shape = tensor.gguf_real_shape
diff --git a/backend/utils.py b/backend/utils.py
index 0a940af9..20f53ff7 100644
--- a/backend/utils.py
+++ b/backend/utils.py
@@ -93,6 +93,13 @@ def calculate_parameters(sd, prefix=""):
     return params
 
 
+def tensor2parameter(x):
+    if isinstance(x, torch.nn.Parameter):
+        return x
+    else:
+        return torch.nn.Parameter(x, requires_grad=False)
+
+
 def fp16_fix(x):
     # An interesting trick to avoid fp16 overflow
     # Source: https://github.com/lllyasviel/stable-diffusion-webui-forge/issues/1114
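
Note on the change: the _apply overrides above assign fn(self.weight) straight back into attributes that PyTorch has registered as parameters. When fn (driven by Module.to(), .half(), .cuda(), etc.) returns a plain torch.Tensor rather than a torch.nn.Parameter, that assignment can fail, because Module.__setattr__ refuses to put a plain tensor into a registered parameter slot. The new utils.tensor2parameter helper re-wraps such tensors before assignment. A minimal standalone sketch of the failure mode and the fix (Linear and the name "plain" are illustrative, not part of this patch):

import torch

def tensor2parameter(x):
    # Same helper as the one added in backend/utils.py: re-wrap plain
    # tensors as Parameters, pass existing Parameters through untouched.
    if isinstance(x, torch.nn.Parameter):
        return x
    return torch.nn.Parameter(x, requires_grad=False)

m = torch.nn.Linear(2, 2)
plain = torch.zeros(2, 2)

try:
    m.weight = plain  # plain Tensor into a registered parameter slot
except TypeError as e:
    print(e)  # cannot assign 'torch.FloatTensor' as parameter 'weight' ...

m.weight = tensor2parameter(plain)  # re-wrapped, assignment succeeds
print(type(m.weight))  # <class 'torch.nn.parameter.Parameter'>

The hasattr(tensor, 'gguf_cls') guard added to dequantize_tensor follows the same defensive pattern: tensors that carry no GGUF quantization metadata are now passed through unchanged instead of being fed to the dequantizer.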