Mirror of https://github.com/lllyasviel/stable-diffusion-webui-forge.git
support more t5 quants (#1482)
Let's hope this is the last time that people randomly invent new state dict key formats.
@@ -10,7 +10,7 @@ from diffusers import DiffusionPipeline
 from transformers import modeling_utils

 from backend import memory_management
-from backend.utils import read_arbitrary_config, load_torch_file
+from backend.utils import read_arbitrary_config, load_torch_file, beautiful_print_gguf_state_dict_statics
 from backend.state_dict import try_filter_state_dict, load_state_dict
 from backend.operations import using_forge_operations
 from backend.nn.vae import IntegratedAutoencoderKL
@@ -80,18 +80,27 @@ def load_huggingface_component(guess, component_name, lib_name, cls_name, repo_p
         from backend.nn.t5 import IntegratedT5
         config = read_arbitrary_config(config_path)

-        dtype = memory_management.text_encoder_dtype()
-        sd_dtype = memory_management.state_dict_dtype(state_dict)
+        storage_dtype = memory_management.text_encoder_dtype()
+        state_dict_dtype = memory_management.state_dict_dtype(state_dict)

-        if sd_dtype in [torch.float8_e4m3fn, torch.float8_e5m2]:
-            dtype = sd_dtype
-            print(f'Using Detected T5 Data Type: {dtype}')
+        if state_dict_dtype in [torch.float8_e4m3fn, torch.float8_e5m2, 'nf4', 'fp4', 'gguf']:
+            print(f'Using Detected T5 Data Type: {state_dict_dtype}')
+            storage_dtype = state_dict_dtype
+            if state_dict_dtype in ['nf4', 'fp4', 'gguf']:
+                print(f'Using pre-quant state dict!')
+                if state_dict_dtype in ['gguf']:
+                    beautiful_print_gguf_state_dict_statics(state_dict)
         else:
-            print(f'Using Default T5 Data Type: {dtype}')
+            print(f'Using Default T5 Data Type: {storage_dtype}')

-        with modeling_utils.no_init_weights():
-            with using_forge_operations(device=memory_management.cpu, dtype=dtype, manual_cast_enabled=True):
-                model = IntegratedT5(config)
+        if storage_dtype in ['nf4', 'fp4', 'gguf']:
+            with modeling_utils.no_init_weights():
+                with using_forge_operations(device=memory_management.cpu, dtype=memory_management.text_encoder_dtype(), manual_cast_enabled=False, bnb_dtype=storage_dtype):
+                    model = IntegratedT5(config)
+        else:
+            with modeling_utils.no_init_weights():
+                with using_forge_operations(device=memory_management.cpu, dtype=storage_dtype, manual_cast_enabled=True):
+                    model = IntegratedT5(config)

         load_state_dict(model, state_dict, log_name=cls_name, ignore_errors=['transformer.encoder.embed_tokens.weight', 'logit_scale'])
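The T5 loader now keys both the log message and the model construction off the dtype detected from the state dict itself. A minimal sketch of what such detection can look like (detect_state_dict_dtype here is hypothetical and far simpler than Forge's memory_management.state_dict_dtype; it assumes GGUF tensors carry a gguf_type attribute, as ParameterGGUF does):

import torch

def detect_state_dict_dtype(state_dict):
    # GGUF tensors are tagged with gguf_type; treat any such tensor as
    # marking the whole dict as pre-quantized GGUF.
    if any(getattr(v, 'gguf_type', None) is not None for v in state_dict.values()):
        return 'gguf'
    # Otherwise report a float8 storage dtype if one is present.
    for v in state_dict.values():
        if isinstance(v, torch.Tensor) and v.dtype in (torch.float8_e4m3fn, torch.float8_e5m2):
            return v.dtype
    return None  # caller falls back to the default text encoder dtype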
@@ -116,28 +125,13 @@ def load_huggingface_component(guess, component_name, lib_name, cls_name, repo_p
     if unet_storage_dtype_overwrite is not None:
         storage_dtype = unet_storage_dtype_overwrite
-    else:
-        if state_dict_dtype in [torch.float8_e4m3fn, torch.float8_e5m2, 'nf4', 'fp4', 'gguf']:
-            print(f'Using Detected UNet Type: {state_dict_dtype}')
-            storage_dtype = state_dict_dtype
-            if state_dict_dtype in ['nf4', 'fp4', 'gguf']:
-                print(f'Using pre-quant state dict!')
-
-            if state_dict_dtype in ['gguf']:
-                from gguf.constants import GGMLQuantizationType
-
-                type_counts = {}
-
-                for k, v in state_dict.items():
-                    gguf_type = getattr(v, 'gguf_type', None)
-                    if gguf_type is not None:
-                        type_name = GGMLQuantizationType(gguf_type).name
-                        if type_name in type_counts:
-                            type_counts[type_name] += 1
-                        else:
-                            type_counts[type_name] = 1
-
-                print(f'Using GGUF state dict: {type_counts}')
+    elif state_dict_dtype in [torch.float8_e4m3fn, torch.float8_e5m2, 'nf4', 'fp4', 'gguf']:
+        print(f'Using Detected UNet Type: {state_dict_dtype}')
+        storage_dtype = state_dict_dtype
+        if state_dict_dtype in ['nf4', 'fp4', 'gguf']:
+            print(f'Using pre-quant state dict!')
+            if state_dict_dtype in ['gguf']:
+                beautiful_print_gguf_state_dict_statics(state_dict)

     load_device = memory_management.get_torch_device()
     computation_dtype = memory_management.get_computation_dtype(load_device, parameters=state_dict_parameters, supported_dtypes=guess.supported_inference_dtypes)
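The UNet path follows the same precedence as the T5 path above: an explicit user overwrite wins, then a dtype detected from the state dict, then the default. Distilled into a standalone sketch (pick_storage_dtype is illustrative, not part of Forge):

import torch

QUANT_OR_FP8 = [torch.float8_e4m3fn, torch.float8_e5m2, 'nf4', 'fp4', 'gguf']

def pick_storage_dtype(overwrite, detected, default):
    # Explicit overwrite > detected pre-quant/fp8 dtype > default.
    if overwrite is not None:
        return overwrite
    if detected in QUANT_OR_FP8:
        return detected
    return default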
@@ -178,6 +172,34 @@ def replace_state_dict(sd, asd, guess):
     vae_key_prefix = guess.vae_key_prefix[0]
     text_encoder_key_prefix = guess.text_encoder_key_prefix[0]

+    if 'enc.blk.0.attn_k.weight' in asd:
+        wierd_t5_format_from_city96 = {
+            "enc.": "encoder.",
+            ".blk.": ".block.",
+            "token_embd": "shared",
+            "output_norm": "final_layer_norm",
+            "attn_q": "layer.0.SelfAttention.q",
+            "attn_k": "layer.0.SelfAttention.k",
+            "attn_v": "layer.0.SelfAttention.v",
+            "attn_o": "layer.0.SelfAttention.o",
+            "attn_norm": "layer.0.layer_norm",
+            "attn_rel_b": "layer.0.SelfAttention.relative_attention_bias",
+            "ffn_up": "layer.1.DenseReluDense.wi_1",
+            "ffn_down": "layer.1.DenseReluDense.wo",
+            "ffn_gate": "layer.1.DenseReluDense.wi_0",
+            "ffn_norm": "layer.1.layer_norm",
+        }
+        wierd_t5_pre_quant_keys_from_city96 = ['shared.weight']
+        asd_new = {}
+        for k, v in asd.items():
+            for s, d in wierd_t5_format_from_city96.items():
+                k = k.replace(s, d)
+            asd_new[k] = v
+        for k in wierd_t5_pre_quant_keys_from_city96:
+            asd_new[k] = asd_new[k].dequantize_as_pytorch_parameter()
+        asd.clear()
+        asd = asd_new
+
     if "decoder.conv_in.weight" in asd:
         keys_to_delete = [k for k in sd if k.startswith(vae_key_prefix)]
         for k in keys_to_delete:
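The remapping is a plain substring rewrite, so one key can pass through several rules in sequence. Applying a subset of the table above to a sample key:

rename = {
    "enc.": "encoder.",
    ".blk.": ".block.",
    "attn_k": "layer.0.SelfAttention.k",
}

k = "enc.blk.0.attn_k.weight"
for s, d in rename.items():
    k = k.replace(s, d)

print(k)  # encoder.block.0.layer.0.SelfAttention.k.weight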
@@ -35,6 +35,9 @@ class ParameterGGUF(torch.nn.Parameter):
     def __new__(cls, tensor=None, requires_grad=False, no_init=False):
         return super().__new__(cls, torch.tensor(tensor.data), requires_grad=requires_grad)

+    def dequantize_as_pytorch_parameter(self):
+        return torch.nn.Parameter(dequantize_tensor(self), requires_grad=False)
+
     def to(self, *args, **kwargs):
         new = ParameterGGUF(self.data.to(*args, **kwargs), no_init=True)
         new.gguf_type = self.gguf_type
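This new method is what replace_state_dict uses above to pull 'shared.weight' out of GGUF storage while leaving the rest of the dict quantized. A hypothetical wrapper in the same spirit (dequantize_entries is not part of Forge):

def dequantize_entries(state_dict, keys):
    # Replace selected GGUF-quantized entries with plain torch Parameters;
    # everything else stays in quantized storage.
    for k in keys:
        state_dict[k] = state_dict[k].dequantize_as_pytorch_parameter()
    return state_dict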
@@ -151,3 +151,18 @@ def get_state_dict_after_quant(model, prefix=''):
     sd = model.state_dict()
     sd = {(prefix + k): v.clone() for k, v in sd.items()}
     return sd
+
+
+def beautiful_print_gguf_state_dict_statics(state_dict):
+    from gguf.constants import GGMLQuantizationType
+    type_counts = {}
+    for k, v in state_dict.items():
+        gguf_type = getattr(v, 'gguf_type', None)
+        if gguf_type is not None:
+            type_name = GGMLQuantizationType(gguf_type).name
+            if type_name in type_counts:
+                type_counts[type_name] += 1
+            else:
+                type_counts[type_name] = 1
+    print(f'GGUF state dict: {type_counts}')
+    return
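beautiful_print_gguf_state_dict_statics centralizes the type counting that the UNet path above used to inline. For what it's worth, the same histogram can be expressed with collections.Counter (assuming, as above, that quantized tensors expose gguf_type):

from collections import Counter
from gguf.constants import GGMLQuantizationType

def gguf_type_histogram(state_dict):
    # Maps GGML type names (e.g. 'Q4_K', 'F32') to tensor counts.
    return Counter(
        GGMLQuantizationType(v.gguf_type).name
        for v in state_dict.values()
        if getattr(v, 'gguf_type', None) is not None
    )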
@@ -139,7 +139,7 @@ def refresh_models():
     shared_items.refresh_checkpoints()
     ckpt_list = shared_items.list_checkpoint_tiles(shared.opts.sd_checkpoint_dropdown_use_short)

-    file_extensions = ['ckpt', 'pt', 'bin', 'safetensors']
+    file_extensions = ['ckpt', 'pt', 'bin', 'safetensors', 'gguf']

     module_list.clear()