mirror of https://github.com/lllyasviel/stable-diffusion-webui-forge.git
fix some gguf loras
@@ -4,7 +4,7 @@ import time
 import torch
 import contextlib
 
-from backend import stream, memory_management
+from backend import stream, memory_management, utils
 
 
 stash = {}
@@ -355,9 +355,9 @@ class ForgeOperationsGGUF(ForgeOperations):
 
     def _apply(self, fn, recurse=True):
         if self.weight is not None:
-            self.weight = fn(self.weight)
+            self.weight = utils.tensor2parameter(fn(self.weight))
         if self.bias is not None:
-            self.bias = fn(self.bias)
+            self.bias = utils.tensor2parameter(fn(self.bias))
         return self
 
     def forward(self, x):
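The fn that torch.nn.Module._apply hands in (for .to(), .half(), .cuda() and similar moves) generally returns a plain torch.Tensor, so assigning its result directly would quietly turn self.weight from a Parameter into a Tensor. Below is a minimal sketch of the behaviour the new utils.tensor2parameter call accounts for; the variable names are illustrative, and only tensor2parameter's logic comes from this commit:

import torch

p = torch.nn.Parameter(torch.zeros(4), requires_grad=False)
moved = p.half()                 # stand-in for fn(self.weight) during .half()/.to()
print(type(moved))               # <class 'torch.Tensor'> -- no longer a Parameter

# Equivalent of the utils.tensor2parameter wrapper added in this commit:
restored = moved if isinstance(moved, torch.nn.Parameter) \
    else torch.nn.Parameter(moved, requires_grad=False)
print(type(restored))            # <class 'torch.nn.Parameter'>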
@@ -3,6 +3,7 @@
 import torch
 import bitsandbytes as bnb
 
+from backend import utils
 from bitsandbytes.nn.modules import Params4bit, QuantState
 from bitsandbytes.functional import dequantize_4bit
 
@@ -88,9 +89,9 @@ class ForgeLoader4Bit(torch.nn.Module):
 
     def _apply(self, fn, recurse=True):
         if self.weight is not None:
-            self.weight = fn(self.weight)
+            self.weight = utils.tensor2parameter(fn(self.weight))
         if self.bias is not None:
-            self.bias = torch.nn.Parameter(fn(self.bias), requires_grad=False)
+            self.bias = utils.tensor2parameter(fn(self.bias))
         return self
 
     def _save_to_state_dict(self, destination, prefix, keep_vars):
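One reading of why the helper is used here instead of the old unconditional torch.nn.Parameter(...) wrap (not stated in the commit itself): tensor2parameter leaves anything that is already a Parameter untouched, so a Parameter subclass such as bitsandbytes' Params4bit would pass through with its extra state intact rather than being flattened into a plain Parameter. A small sketch with a hypothetical subclass standing in for Params4bit:

import torch

class FakeQuantParam(torch.nn.Parameter):
    # Hypothetical stand-in for a Parameter subclass that carries extra state.
    pass

def tensor2parameter(x):
    # Same logic as the helper added to backend/utils.py in this commit.
    if isinstance(x, torch.nn.Parameter):
        return x
    return torch.nn.Parameter(x, requires_grad=False)

q = FakeQuantParam(torch.zeros(2))
print(tensor2parameter(q) is q)   # True: already-a-Parameter inputs pass through unchanged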
@@ -64,6 +64,9 @@ def dequantize_tensor(tensor):
     if tensor is None:
         return None
 
+    if not hasattr(tensor, 'gguf_cls'):
+        return tensor
+
     data = torch.tensor(tensor.data)
     gguf_cls = tensor.gguf_cls
     gguf_real_shape = tensor.gguf_real_shape
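The added guard makes dequantize_tensor a no-op for tensors that never went through GGUF quantization, where the old code would have fallen through to tensor.gguf_cls and raised AttributeError. A condensed sketch of the patched control flow, with the real GGUF dequantization path stubbed out:

import torch

def dequantize_tensor_sketch(tensor):
    # Condensed sketch of the early exits in the patched dequantize_tensor.
    if tensor is None:
        return None
    if not hasattr(tensor, 'gguf_cls'):
        return tensor                 # non-GGUF weights now pass through unchanged
    raise NotImplementedError("real GGUF dequantization path omitted in this sketch")

plain = torch.randn(8, dtype=torch.float16)
print(dequantize_tensor_sketch(plain) is plain)   # True; previously this input hit
                                                  # tensor.gguf_cls -> AttributeError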
@@ -93,6 +93,13 @@ def calculate_parameters(sd, prefix=""):
     return params
 
 
+def tensor2parameter(x):
+    if isinstance(x, torch.nn.Parameter):
+        return x
+    else:
+        return torch.nn.Parameter(x, requires_grad=False)
+
+
 def fp16_fix(x):
     # An interesting trick to avoid fp16 overflow
     # Source: https://github.com/lllyasviel/stable-diffusion-webui-forge/issues/1114