mirror of
https://github.com/lllyasviel/stable-diffusion-webui-forge.git
synced 2026-02-24 16:53:56 +00:00
gguf preview
This commit is contained in:
@@ -117,17 +117,17 @@ def load_huggingface_component(guess, component_name, lib_name, cls_name, repo_p
|
||||
if unet_storage_dtype_overwrite is not None:
|
||||
storage_dtype = unet_storage_dtype_overwrite
|
||||
else:
|
||||
if state_dict_dtype in [torch.float8_e4m3fn, torch.float8_e5m2, 'nf4', 'fp4']:
|
||||
if state_dict_dtype in [torch.float8_e4m3fn, torch.float8_e5m2, 'nf4', 'fp4', 'gguf']:
|
||||
print(f'Using Detected UNet Type: {state_dict_dtype}')
|
||||
storage_dtype = state_dict_dtype
|
||||
if state_dict_dtype in ['nf4', 'fp4']:
|
||||
if state_dict_dtype in ['nf4', 'fp4', 'gguf']:
|
||||
print(f'Using pre-quant state dict!')
|
||||
|
||||
load_device = memory_management.get_torch_device()
|
||||
computation_dtype = memory_management.get_computation_dtype(load_device, supported_dtypes=guess.supported_inference_dtypes)
|
||||
offload_device = memory_management.unet_offload_device()
|
||||
|
||||
if storage_dtype in ['nf4', 'fp4']:
|
||||
if storage_dtype in ['nf4', 'fp4', 'gguf']:
|
||||
initial_device = memory_management.unet_inital_load_device(parameters=state_dict_size, dtype=computation_dtype)
|
||||
with using_forge_operations(device=initial_device, dtype=computation_dtype, manual_cast_enabled=False, bnb_dtype=storage_dtype):
|
||||
model = model_loader(unet_config)
|
||||
|
||||
@@ -302,7 +302,9 @@ def state_dict_size(sd, exclude_device=None):
|
||||
|
||||
|
||||
def state_dict_dtype(state_dict):
|
||||
for k in state_dict.keys():
|
||||
for k, v in state_dict.items():
|
||||
if hasattr(v, 'is_gguf'):
|
||||
return 'gguf'
|
||||
if 'bitsandbytes__nf4' in k:
|
||||
return 'nf4'
|
||||
if 'bitsandbytes__fp4' in k:
|
||||
|
||||
@@ -330,6 +330,44 @@ except:
|
||||
bnb_avaliable = False
|
||||
|
||||
|
||||
from backend.operations_gguf import functional_linear_gguf
|
||||
|
||||
|
||||
class ForgeOperationsGGUF(ForgeOperations):
    # Operations set used when the loaded state dict holds GGUF-quantized
    # tensors: weights are kept in quantized form and dequantized on the fly
    # inside forward() via functional_linear_gguf.
    class Linear(torch.nn.Module):
        def __init__(self, *args, **kwargs):
            super().__init__()
            # 'dummy' only records the target device/dtype chosen by
            # using_forge_operations(); it is deleted once real weights arrive.
            self.dummy = torch.nn.Parameter(torch.empty(1, device=current_device, dtype=current_dtype))
            self.weight = None
            self.bias = None
            self.parameters_manual_cast = current_manual_cast_enabled

        def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs):
            # First load: move the (possibly quantized) tensors to the device
            # recorded by the dummy parameter, then drop the dummy so later
            # loads fall through to the default implementation.
            if hasattr(self, 'dummy'):
                if prefix + 'weight' in state_dict:
                    self.weight = state_dict[prefix + 'weight'].to(device=self.dummy.device)
                if prefix + 'bias' in state_dict:
                    self.bias = state_dict[prefix + 'bias'].to(device=self.dummy.device)
                del self.dummy
            else:
                super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)

        def _apply(self, fn, recurse=True):
            # weight/bias are plain attributes here (not registered
            # parameters), so .to()/.cuda()/etc. must be forwarded by hand.
            if self.weight is not None:
                self.weight = fn(self.weight)
            if self.bias is not None:
                self.bias = fn(self.bias)
            return super()._apply(fn, recurse=recurse)

        def forward(self, x):
            if self.parameters_manual_cast:
                # skip_weight_dtype=True: the weight is quantized, so its
                # dtype must not be force-cast before dequantization.
                weight, bias, signal = weights_manual_cast(self, x, skip_weight_dtype=True)
                with main_stream_worker(weight, bias, signal):
                    return functional_linear_gguf(x, weight, bias)
            else:
                return functional_linear_gguf(x, self.weight, self.bias)
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def using_forge_operations(operations=None, device=None, dtype=None, manual_cast_enabled=False, bnb_dtype=None):
|
||||
global current_device, current_dtype, current_manual_cast_enabled, current_bnb_dtype
|
||||
@@ -337,7 +375,9 @@ def using_forge_operations(operations=None, device=None, dtype=None, manual_cast
|
||||
current_device, current_dtype, current_manual_cast_enabled, current_bnb_dtype = device, dtype, manual_cast_enabled, bnb_dtype
|
||||
|
||||
if operations is None:
|
||||
if bnb_avaliable and bnb_dtype in ['nf4', 'fp4']:
|
||||
if bnb_dtype in ['gguf']:
|
||||
operations = ForgeOperationsGGUF
|
||||
elif bnb_avaliable and bnb_dtype in ['nf4', 'fp4']:
|
||||
operations = ForgeOperationsBNB4bits
|
||||
else:
|
||||
operations = ForgeOperations
|
||||
|
||||
109
backend/operations_gguf.py
Normal file
109
backend/operations_gguf.py
Normal file
@@ -0,0 +1,109 @@
|
||||
import gguf
|
||||
import torch
|
||||
|
||||
|
||||
def functional_linear_gguf(x, weight, bias=None):
    """Linear layer that first dequantizes GGUF weight/bias to the input dtype."""
    out_dtype = x.dtype
    dense_weight = dequantize_tensor(weight, out_dtype)
    dense_bias = dequantize_tensor(bias, out_dtype)
    return torch.nn.functional.linear(x, dense_weight, dense_bias)
|
||||
|
||||
|
||||
def dequantize_tensor(tensor, dtype=torch.float16):
    """Turn a GGUF-quantized parameter back into a dense tensor of *dtype*.

    (c) City96 || Apache-2.0 (apache.org/licenses/LICENSE-2.0)
    """
    if tensor is None:
        return None

    raw = torch.tensor(tensor.data)
    quant_type = tensor.gguf_type
    real_shape = tensor.gguf_real_shape

    # F32/F16 payloads are already plain floats — just cast.
    if quant_type in (gguf.GGMLQuantizationType.F32, gguf.GGMLQuantizationType.F16):
        return raw.to(dtype)

    if quant_type in dequantize_functions:
        # fast path: pure-pytorch block dequantization (the main op)
        return dequantize(raw, quant_type, real_shape).to(dtype)

    # fallback: numpy round-trip through the gguf reference implementation
    # (this is incredibly slow)
    dense = gguf.quants.dequantize(raw.cpu().numpy(), quant_type)
    return torch.from_numpy(dense).to(raw.device, dtype=dtype)
|
||||
|
||||
|
||||
def dequantize(data, qtype, oshape):
    """Dequantize a raw GGUF block buffer back to usable shape/dtype.

    (c) City96 || Apache-2.0 (apache.org/licenses/LICENSE-2.0)
    """
    block_size, type_size = gguf.GGML_QUANT_SIZES[qtype]
    blocks_fn = dequantize_functions[qtype]

    # Flatten to rows of raw bytes, then split into quantization blocks of
    # type_size bytes each.
    byte_rows = data.reshape((-1, data.shape[-1])).view(torch.uint8)
    num_blocks = byte_rows.numel() // type_size
    decoded = blocks_fn(byte_rows.reshape((num_blocks, type_size)), block_size, type_size)
    return decoded.reshape(oshape)
|
||||
|
||||
|
||||
def to_uint32(x):
    """Reassemble little-endian byte quadruples into one int32 column.

    torch has no uint32, so the result carries the uint32 bit pattern in an
    int32 tensor of shape (n, 1).
    (c) City96 || Apache-2.0 (apache.org/licenses/LICENSE-2.0)
    """
    b = x.view(torch.uint8).to(torch.int32)
    word = b[:, 0] | (b[:, 1] << 8) | (b[:, 2] << 16) | (b[:, 3] << 24)
    return word.unsqueeze(1)
|
||||
|
||||
|
||||
def dequantize_blocks_Q8_0(blocks, block_size, type_size):
    """Q8_0: per-block fp16 scale followed by raw int8 quants.

    (c) City96 || Apache-2.0 (apache.org/licenses/LICENSE-2.0)
    """
    scale = blocks[:, :2].view(torch.float16)
    quants = blocks[:, 2:].view(torch.int8).to(torch.float16)
    return quants * scale
|
||||
|
||||
|
||||
def dequantize_blocks_Q5_0(blocks, block_size, type_size):
    """Q5_0: fp16 scale, 32 packed high bits, then 4-bit low nibbles.

    Each quant is 5 bits (high bit << 4 | low nibble) biased by -16.
    (c) City96 || Apache-2.0 (apache.org/licenses/LICENSE-2.0)
    """
    n_blocks = blocks.shape[0]

    scale = blocks[:, :2].view(torch.float16).to(torch.float32)
    high_bits = to_uint32(blocks[:, 2:6])
    packed_low = blocks[:, 6:]

    # one high bit per quant: bit i of the packed uint32
    high = (high_bits.reshape(n_blocks, 1) >> torch.arange(32, device=scale.device, dtype=torch.int32).reshape(1, 32)) & 1
    high = high.to(torch.uint8)

    # low nibbles: each byte carries two quants (shift by 0, then by 4)
    shifts = torch.tensor([0, 4], device=scale.device, dtype=torch.uint8).reshape(1, 1, 2, 1)
    low = (packed_low.reshape(n_blocks, -1, 1, block_size // 2) >> shifts) & 0x0F
    low = low.reshape(n_blocks, -1)

    quants = (low | (high << 4)).to(torch.int8) - 16
    return scale * quants
|
||||
|
||||
|
||||
def dequantize_blocks_Q4_0(blocks, block_size, type_size):
    """Q4_0: fp16 scale followed by packed 4-bit quants (two per byte),
    biased by -8.

    (c) City96 || Apache-2.0 (apache.org/licenses/LICENSE-2.0)
    """
    n_blocks = blocks.shape[0]

    scale = blocks[:, :2].view(torch.float16)
    packed = blocks[:, 2:]

    # each byte carries two quants: low nibble first, high nibble second
    shifts = torch.tensor([0, 4], device=scale.device, dtype=torch.uint8).reshape((1, 1, 2, 1))
    nibbles = packed.reshape((n_blocks, -1, 1, block_size // 2)) >> shifts
    quants = (nibbles & 0x0F).reshape((n_blocks, -1)).to(torch.int8) - 8
    return scale * quants
|
||||
|
||||
|
||||
# GGUF quant types with a native pytorch dequantizer; anything else falls
# back to the slow numpy path in dequantize_tensor().
dequantize_functions = {
    gguf.GGMLQuantizationType.Q8_0: dequantize_blocks_Q8_0,
    gguf.GGMLQuantizationType.Q5_0: dequantize_blocks_Q5_0,
    gguf.GGMLQuantizationType.Q4_0: dequantize_blocks_Q4_0,
}
|
||||
@@ -431,7 +431,7 @@ class ControlLora(ControlNet):
|
||||
|
||||
dtype = model.storage_dtype
|
||||
|
||||
if dtype in ['nf4', 'fp4']:
|
||||
if dtype in ['nf4', 'fp4', 'gguf']:
|
||||
dtype = torch.float16
|
||||
|
||||
controlnet_config["dtype"] = dtype
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import gguf
|
||||
import torch
|
||||
import os
|
||||
import json
|
||||
@@ -5,6 +6,31 @@ import safetensors.torch
|
||||
import backend.misc.checkpoint_pickle
|
||||
|
||||
|
||||
class ParameterGGUF(torch.nn.Parameter):
    # torch.nn.Parameter subclass that stores a GGUF tensor's raw quantized
    # payload while remembering its quantization type and logical shape so it
    # can be dequantized later.

    def __init__(self, tensor=None, requires_grad=False, no_init=False):
        # `tensor` is a gguf reader tensor (exposes .data, .tensor_type,
        # .shape) on first construction; with no_init=True the gguf metadata
        # is left for the caller to fill in (used by .to() below).
        super().__init__()
        self.is_gguf = True  # marker checked via hasattr() by state_dict_dtype()

        if no_init:
            return

        self.gguf_type = tensor.tensor_type
        # GGUF stores dimensions in reverse order relative to torch.
        self.gguf_real_shape = torch.Size(reversed(list(tensor.shape)))

    @property
    def shape(self):
        # Report the logical (dequantized) shape, not the raw payload shape.
        return self.gguf_real_shape

    def __new__(cls, tensor=None, requires_grad=False, no_init=False):
        # The Parameter itself wraps only the raw payload bytes as a tensor.
        return super().__new__(cls, torch.tensor(tensor.data), requires_grad=requires_grad)

    def to(self, *args, **kwargs):
        # Move/cast only the raw payload; carry the gguf metadata over by hand.
        new = ParameterGGUF(self.data.to(*args, **kwargs), no_init=True)
        new.gguf_type = self.gguf_type
        new.gguf_real_shape = self.gguf_real_shape
        return new
|
||||
|
||||
|
||||
def read_arbitrary_config(directory):
|
||||
config_path = os.path.join(directory, 'config.json')
|
||||
|
||||
@@ -22,6 +48,11 @@ def load_torch_file(ckpt, safe_load=False, device=None):
|
||||
device = torch.device("cpu")
|
||||
if ckpt.lower().endswith(".safetensors"):
|
||||
sd = safetensors.torch.load_file(ckpt, device=device.type)
|
||||
elif ckpt.lower().endswith(".gguf"):
|
||||
reader = gguf.GGUFReader(ckpt)
|
||||
sd = {}
|
||||
for tensor in reader.tensors:
|
||||
sd[str(tensor.name)] = ParameterGGUF(tensor)
|
||||
else:
|
||||
if safe_load:
|
||||
if not 'weights_only' in torch.load.__code__.co_varnames:
|
||||
|
||||
@@ -173,7 +173,7 @@ def list_models():
|
||||
else:
|
||||
model_url = "https://huggingface.co/lllyasviel/fav_models/resolve/main/fav/realisticVisionV51_v51VAE.safetensors"
|
||||
|
||||
model_list = modelloader.load_models(model_path=model_path, model_url=model_url, command_path=shared.cmd_opts.ckpt_dir, ext_filter=[".ckpt", ".safetensors"], download_name="realisticVisionV51_v51VAE.safetensors", ext_blacklist=[".vae.ckpt", ".vae.safetensors"])
|
||||
model_list = modelloader.load_models(model_path=model_path, model_url=model_url, command_path=shared.cmd_opts.ckpt_dir, ext_filter=[".ckpt", ".safetensors", ".gguf"], download_name="realisticVisionV51_v51VAE.safetensors", ext_blacklist=[".vae.ckpt", ".vae.safetensors"])
|
||||
|
||||
if os.path.exists(cmd_ckpt):
|
||||
checkpoint_info = CheckpointInfo(cmd_ckpt)
|
||||
|
||||
@@ -37,3 +37,4 @@ basicsr==1.4.2
|
||||
diffusers==0.29.2
|
||||
gradio_rangeslider==0.0.6
|
||||
tqdm==4.66.1
|
||||
gguf==0.9.1
|
||||
|
||||
Reference in New Issue
Block a user