support more t5 quants (#1482)

Let's hope this is the last time that people randomly invent new state dict key formats.
This commit is contained in:
lllyasviel
2024-08-24 12:47:49 -07:00
committed by GitHub
parent 0f3309eb3d
commit f82029c5cf
4 changed files with 73 additions and 33 deletions

View File

@@ -10,7 +10,7 @@ from diffusers import DiffusionPipeline
from transformers import modeling_utils
from backend import memory_management
from backend.utils import read_arbitrary_config, load_torch_file
from backend.utils import read_arbitrary_config, load_torch_file, beautiful_print_gguf_state_dict_statics
from backend.state_dict import try_filter_state_dict, load_state_dict
from backend.operations import using_forge_operations
from backend.nn.vae import IntegratedAutoencoderKL
@@ -80,18 +80,27 @@ def load_huggingface_component(guess, component_name, lib_name, cls_name, repo_p
from backend.nn.t5 import IntegratedT5
config = read_arbitrary_config(config_path)
dtype = memory_management.text_encoder_dtype()
sd_dtype = memory_management.state_dict_dtype(state_dict)
storage_dtype = memory_management.text_encoder_dtype()
state_dict_dtype = memory_management.state_dict_dtype(state_dict)
if sd_dtype in [torch.float8_e4m3fn, torch.float8_e5m2]:
dtype = sd_dtype
print(f'Using Detected T5 Data Type: {dtype}')
if state_dict_dtype in [torch.float8_e4m3fn, torch.float8_e5m2, 'nf4', 'fp4', 'gguf']:
print(f'Using Detected T5 Data Type: {state_dict_dtype}')
storage_dtype = state_dict_dtype
if state_dict_dtype in ['nf4', 'fp4', 'gguf']:
print(f'Using pre-quant state dict!')
if state_dict_dtype in ['gguf']:
beautiful_print_gguf_state_dict_statics(state_dict)
else:
print(f'Using Default T5 Data Type: {dtype}')
print(f'Using Default T5 Data Type: {storage_dtype}')
with modeling_utils.no_init_weights():
with using_forge_operations(device=memory_management.cpu, dtype=dtype, manual_cast_enabled=True):
model = IntegratedT5(config)
if storage_dtype in ['nf4', 'fp4', 'gguf']:
with modeling_utils.no_init_weights():
with using_forge_operations(device=memory_management.cpu, dtype=memory_management.text_encoder_dtype(), manual_cast_enabled=False, bnb_dtype=storage_dtype):
model = IntegratedT5(config)
else:
with modeling_utils.no_init_weights():
with using_forge_operations(device=memory_management.cpu, dtype=storage_dtype, manual_cast_enabled=True):
model = IntegratedT5(config)
load_state_dict(model, state_dict, log_name=cls_name, ignore_errors=['transformer.encoder.embed_tokens.weight', 'logit_scale'])
@@ -116,28 +125,13 @@ def load_huggingface_component(guess, component_name, lib_name, cls_name, repo_p
if unet_storage_dtype_overwrite is not None:
storage_dtype = unet_storage_dtype_overwrite
else:
if state_dict_dtype in [torch.float8_e4m3fn, torch.float8_e5m2, 'nf4', 'fp4', 'gguf']:
print(f'Using Detected UNet Type: {state_dict_dtype}')
storage_dtype = state_dict_dtype
if state_dict_dtype in ['nf4', 'fp4', 'gguf']:
print(f'Using pre-quant state dict!')
if state_dict_dtype in ['gguf']:
from gguf.constants import GGMLQuantizationType
type_counts = {}
for k, v in state_dict.items():
gguf_type = getattr(v, 'gguf_type', None)
if gguf_type is not None:
type_name = GGMLQuantizationType(gguf_type).name
if type_name in type_counts:
type_counts[type_name] += 1
else:
type_counts[type_name] = 1
print(f'Using GGUF state dict: {type_counts}')
elif state_dict_dtype in [torch.float8_e4m3fn, torch.float8_e5m2, 'nf4', 'fp4', 'gguf']:
print(f'Using Detected UNet Type: {state_dict_dtype}')
storage_dtype = state_dict_dtype
if state_dict_dtype in ['nf4', 'fp4', 'gguf']:
print(f'Using pre-quant state dict!')
if state_dict_dtype in ['gguf']:
beautiful_print_gguf_state_dict_statics(state_dict)
load_device = memory_management.get_torch_device()
computation_dtype = memory_management.get_computation_dtype(load_device, parameters=state_dict_parameters, supported_dtypes=guess.supported_inference_dtypes)
@@ -178,6 +172,34 @@ def replace_state_dict(sd, asd, guess):
vae_key_prefix = guess.vae_key_prefix[0]
text_encoder_key_prefix = guess.text_encoder_key_prefix[0]
if 'enc.blk.0.attn_k.weight' in asd:
wierd_t5_format_from_city96 = {
"enc.": "encoder.",
".blk.": ".block.",
"token_embd": "shared",
"output_norm": "final_layer_norm",
"attn_q": "layer.0.SelfAttention.q",
"attn_k": "layer.0.SelfAttention.k",
"attn_v": "layer.0.SelfAttention.v",
"attn_o": "layer.0.SelfAttention.o",
"attn_norm": "layer.0.layer_norm",
"attn_rel_b": "layer.0.SelfAttention.relative_attention_bias",
"ffn_up": "layer.1.DenseReluDense.wi_1",
"ffn_down": "layer.1.DenseReluDense.wo",
"ffn_gate": "layer.1.DenseReluDense.wi_0",
"ffn_norm": "layer.1.layer_norm",
}
wierd_t5_pre_quant_keys_from_city96 = ['shared.weight']
asd_new = {}
for k, v in asd.items():
for s, d in wierd_t5_format_from_city96.items():
k = k.replace(s, d)
asd_new[k] = v
for k in wierd_t5_pre_quant_keys_from_city96:
asd_new[k] = asd_new[k].dequantize_as_pytorch_parameter()
asd.clear()
asd = asd_new
if "decoder.conv_in.weight" in asd:
keys_to_delete = [k for k in sd if k.startswith(vae_key_prefix)]
for k in keys_to_delete:

View File

@@ -35,6 +35,9 @@ class ParameterGGUF(torch.nn.Parameter):
def __new__(cls, tensor=None, requires_grad=False, no_init=False):
    """Allocate the parameter from a copy of ``tensor.data``.

    ``tensor.data`` is passed through ``torch.tensor`` and handed to the
    parent ``__new__`` together with ``requires_grad``. ``no_init`` is
    accepted here but not read by this method.
    """
    # NOTE(review): ``tensor`` defaults to None but is dereferenced
    # unconditionally below - callers presumably always pass one; confirm.
    payload = torch.tensor(tensor.data)
    return super().__new__(cls, payload, requires_grad=requires_grad)
def dequantize_as_pytorch_parameter(self):
    """Return a plain ``torch.nn.Parameter`` built from ``dequantize_tensor(self)``.

    The returned parameter is created with ``requires_grad=False``.
    """
    # dequantize_tensor is defined elsewhere in this module; presumably it
    # yields a regular dense tensor - verify against its definition.
    dense = dequantize_tensor(self)
    return torch.nn.Parameter(dense, requires_grad=False)
def to(self, *args, **kwargs):
new = ParameterGGUF(self.data.to(*args, **kwargs), no_init=True)
new.gguf_type = self.gguf_type

View File

@@ -151,3 +151,18 @@ def get_state_dict_after_quant(model, prefix=''):
sd = model.state_dict()
sd = {(prefix + k): v.clone() for k, v in sd.items()}
return sd
def beautiful_print_gguf_state_dict_statics(state_dict):
    """Print how many values in ``state_dict`` use each GGUF quantization type.

    Values that have no ``gguf_type`` attribute, or whose ``gguf_type`` is
    ``None``, are skipped. The tally is printed as a plain dict keyed by the
    ``GGMLQuantizationType`` member name, in order of first appearance.

    Args:
        state_dict: Mapping whose values may carry a ``gguf_type`` attribute.

    Returns:
        None. The summary line is written to stdout.
    """
    # Imported at call time, so gguf is only required when this is called.
    from collections import Counter

    from gguf.constants import GGMLQuantizationType

    # Only the values matter here; the parameter names are never used.
    gguf_types = (
        getattr(value, 'gguf_type', None) for value in state_dict.values()
    )
    type_counts = Counter(
        GGMLQuantizationType(gguf_type).name
        for gguf_type in gguf_types
        if gguf_type is not None
    )

    # Convert back to a plain dict so the printed text keeps its `{...}` form
    # rather than `Counter({...})`.
    print(f'GGUF state dict: {dict(type_counts)}')

View File

@@ -139,7 +139,7 @@ def refresh_models():
shared_items.refresh_checkpoints()
ckpt_list = shared_items.list_checkpoint_tiles(shared.opts.sd_checkpoint_dropdown_use_short)
file_extensions = ['ckpt', 'pt', 'bin', 'safetensors']
file_extensions = ['ckpt', 'pt', 'bin', 'safetensors', 'gguf']
module_list.clear()