mirror of
https://github.com/lllyasviel/stable-diffusion-webui-forge.git
synced 2026-04-29 10:41:25 +00:00
gguf preview
This commit is contained in:
@@ -330,6 +330,44 @@ except:
|
||||
bnb_avaliable = False
|
||||
|
||||
|
||||
from backend.operations_gguf import functional_linear_gguf
|
||||
|
||||
|
||||
class ForgeOperationsGGUF(ForgeOperations):
    """Operations set whose ``Linear`` layer runs through ``functional_linear_gguf``."""

    class Linear(torch.nn.Module):
        """Linear layer whose weight and bias are supplied by a state dict.

        ``weight`` and ``bias`` start out as ``None``; they are filled in by
        the first call to ``_load_from_state_dict``.  Constructor arguments
        are accepted but not used.
        """

        def __init__(self, *args, **kwargs):
            super().__init__()
            # Placeholder parameter created on the device/dtype configured at
            # construction time; its device is read back when the real
            # tensors are loaded.
            placeholder = torch.empty(1, device=current_device, dtype=current_dtype)
            self.dummy = torch.nn.Parameter(placeholder)
            self.weight = None
            self.bias = None
            self.parameters_manual_cast = current_manual_cast_enabled

        def _load_from_state_dict(
            self,
            state_dict,
            prefix,
            local_metadata,
            strict,
            missing_keys,
            unexpected_keys,
            error_msgs,
        ):
            """Take ``weight`` / ``bias`` from ``state_dict`` on the first load.

            While the ``dummy`` placeholder still exists, any
            ``prefix + 'weight'`` / ``prefix + 'bias'`` entries are moved to
            the placeholder's device (only ``device`` is passed to ``.to``)
            and stored on the module, after which the placeholder is
            deleted.  Absent entries are left as they are and nothing is
            appended to ``missing_keys`` on this path.  Once the placeholder
            is gone, loading is delegated to ``torch.nn.Module``.
            """
            if not hasattr(self, 'dummy'):
                super()._load_from_state_dict(
                    state_dict,
                    prefix,
                    local_metadata,
                    strict,
                    missing_keys,
                    unexpected_keys,
                    error_msgs,
                )
                return

            target_device = self.dummy.device
            for attr_name in ('weight', 'bias'):
                key = prefix + attr_name
                if key in state_dict:
                    setattr(self, attr_name, state_dict[key].to(device=target_device))

            # The placeholder has served its purpose once the first load ran.
            del self.dummy

        def _apply(self, fn, recurse=True):
            """Apply ``fn`` to ``weight`` and ``bias``, then to the rest of the module.

            NOTE(review): the explicit handling is presumably needed because
            the loaded tensors may not be registered parameters -- confirm.
            """
            for attr_name in ('weight', 'bias'):
                tensor = getattr(self, attr_name)
                if tensor is not None:
                    setattr(self, attr_name, fn(tensor))
            return super()._apply(fn, recurse=recurse)

        def forward(self, x):
            """Return ``functional_linear_gguf`` applied to ``x`` with this layer's weight and bias."""
            if not self.parameters_manual_cast:
                return functional_linear_gguf(x, self.weight, self.bias)

            # NOTE(review): skip_weight_dtype=True presumably leaves the
            # weight's dtype untouched during the cast -- confirm against
            # weights_manual_cast.
            weight, bias, signal = weights_manual_cast(self, x, skip_weight_dtype=True)
            with main_stream_worker(weight, bias, signal):
                return functional_linear_gguf(x, weight, bias)
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def using_forge_operations(operations=None, device=None, dtype=None, manual_cast_enabled=False, bnb_dtype=None):
|
||||
global current_device, current_dtype, current_manual_cast_enabled, current_bnb_dtype
|
||||
@@ -337,7 +375,9 @@ def using_forge_operations(operations=None, device=None, dtype=None, manual_cast
|
||||
current_device, current_dtype, current_manual_cast_enabled, current_bnb_dtype = device, dtype, manual_cast_enabled, bnb_dtype
|
||||
|
||||
if operations is None:
|
||||
if bnb_avaliable and bnb_dtype in ['nf4', 'fp4']:
|
||||
if bnb_dtype in ['gguf']:
|
||||
operations = ForgeOperationsGGUF
|
||||
elif bnb_avaliable and bnb_dtype in ['nf4', 'fp4']:
|
||||
operations = ForgeOperationsBNB4bits
|
||||
else:
|
||||
operations = ForgeOperations
|
||||
|
||||
Reference in New Issue
Block a user