diff --git a/backend/loader.py b/backend/loader.py
index 17818085..24dc4b40 100644
--- a/backend/loader.py
+++ b/backend/loader.py
@@ -10,7 +10,7 @@ from diffusers import DiffusionPipeline
 from transformers import modeling_utils
 
 from backend import memory_management
-from backend.utils import read_arbitrary_config, load_torch_file
+from backend.utils import read_arbitrary_config, load_torch_file, beautiful_print_gguf_state_dict_statics
 from backend.state_dict import try_filter_state_dict, load_state_dict
 from backend.operations import using_forge_operations
 from backend.nn.vae import IntegratedAutoencoderKL
@@ -80,18 +80,27 @@ def load_huggingface_component(guess, component_name, lib_name, cls_name, repo_p
         from backend.nn.t5 import IntegratedT5
 
         config = read_arbitrary_config(config_path)
 
-        dtype = memory_management.text_encoder_dtype()
-        sd_dtype = memory_management.state_dict_dtype(state_dict)
+        storage_dtype = memory_management.text_encoder_dtype()
+        state_dict_dtype = memory_management.state_dict_dtype(state_dict)
 
-        if sd_dtype in [torch.float8_e4m3fn, torch.float8_e5m2]:
-            dtype = sd_dtype
-            print(f'Using Detected T5 Data Type: {dtype}')
+        if state_dict_dtype in [torch.float8_e4m3fn, torch.float8_e5m2, 'nf4', 'fp4', 'gguf']:
+            print(f'Using Detected T5 Data Type: {state_dict_dtype}')
+            storage_dtype = state_dict_dtype
+            if state_dict_dtype in ['nf4', 'fp4', 'gguf']:
+                print(f'Using pre-quant state dict!')
+            if state_dict_dtype in ['gguf']:
+                beautiful_print_gguf_state_dict_statics(state_dict)
         else:
-            print(f'Using Default T5 Data Type: {dtype}')
+            print(f'Using Default T5 Data Type: {storage_dtype}')
 
-        with modeling_utils.no_init_weights():
-            with using_forge_operations(device=memory_management.cpu, dtype=dtype, manual_cast_enabled=True):
-                model = IntegratedT5(config)
+        if storage_dtype in ['nf4', 'fp4', 'gguf']:
+            with modeling_utils.no_init_weights():
+                with using_forge_operations(device=memory_management.cpu, dtype=memory_management.text_encoder_dtype(), manual_cast_enabled=False, bnb_dtype=storage_dtype):
+                    model = IntegratedT5(config)
+        else:
+            with modeling_utils.no_init_weights():
+                with using_forge_operations(device=memory_management.cpu, dtype=storage_dtype, manual_cast_enabled=True):
+                    model = IntegratedT5(config)
 
         load_state_dict(model, state_dict, log_name=cls_name, ignore_errors=['transformer.encoder.embed_tokens.weight', 'logit_scale'])
@@ -116,28 +125,13 @@ def load_huggingface_component(guess, component_name, lib_name, cls_name, repo_p
 
         if unet_storage_dtype_overwrite is not None:
             storage_dtype = unet_storage_dtype_overwrite
-        else:
-            if state_dict_dtype in [torch.float8_e4m3fn, torch.float8_e5m2, 'nf4', 'fp4', 'gguf']:
-                print(f'Using Detected UNet Type: {state_dict_dtype}')
-                storage_dtype = state_dict_dtype
-                if state_dict_dtype in ['nf4', 'fp4', 'gguf']:
-                    print(f'Using pre-quant state dict!')
-
-                if state_dict_dtype in ['gguf']:
-                    from gguf.constants import GGMLQuantizationType
-
-                    type_counts = {}
-
-                    for k, v in state_dict.items():
-                        gguf_type = getattr(v, 'gguf_type', None)
-                        if gguf_type is not None:
-                            type_name = GGMLQuantizationType(gguf_type).name
-                            if type_name in type_counts:
-                                type_counts[type_name] += 1
-                            else:
-                                type_counts[type_name] = 1
-
-                    print(f'Using GGUF state dict: {type_counts}')
+        elif state_dict_dtype in [torch.float8_e4m3fn, torch.float8_e5m2, 'nf4', 'fp4', 'gguf']:
+            print(f'Using Detected UNet Type: {state_dict_dtype}')
+            storage_dtype = state_dict_dtype
+            if state_dict_dtype in ['nf4', 'fp4', 'gguf']:
+                print(f'Using pre-quant state dict!')
+            if state_dict_dtype in ['gguf']:
+                beautiful_print_gguf_state_dict_statics(state_dict)
 
         load_device = memory_management.get_torch_device()
         computation_dtype = memory_management.get_computation_dtype(load_device, parameters=state_dict_parameters, supported_dtypes=guess.supported_inference_dtypes)
@@ -178,6 +172,34 @@ def replace_state_dict(sd, asd, guess):
     vae_key_prefix = guess.vae_key_prefix[0]
     text_encoder_key_prefix = guess.text_encoder_key_prefix[0]
 
+    if 'enc.blk.0.attn_k.weight' in asd:
+        wierd_t5_format_from_city96 = {
+            "enc.": "encoder.",
+            ".blk.": ".block.",
+            "token_embd": "shared",
+            "output_norm": "final_layer_norm",
+            "attn_q": "layer.0.SelfAttention.q",
+            "attn_k": "layer.0.SelfAttention.k",
+            "attn_v": "layer.0.SelfAttention.v",
+            "attn_o": "layer.0.SelfAttention.o",
+            "attn_norm": "layer.0.layer_norm",
+            "attn_rel_b": "layer.0.SelfAttention.relative_attention_bias",
+            "ffn_up": "layer.1.DenseReluDense.wi_1",
+            "ffn_down": "layer.1.DenseReluDense.wo",
+            "ffn_gate": "layer.1.DenseReluDense.wi_0",
+            "ffn_norm": "layer.1.layer_norm",
+        }
+        wierd_t5_pre_quant_keys_from_city96 = ['shared.weight']
+        asd_new = {}
+        for k, v in asd.items():
+            for s, d in wierd_t5_format_from_city96.items():
+                k = k.replace(s, d)
+            asd_new[k] = v
+        for k in wierd_t5_pre_quant_keys_from_city96:
+            asd_new[k] = asd_new[k].dequantize_as_pytorch_parameter()
+        asd.clear()
+        asd = asd_new
+
     if "decoder.conv_in.weight" in asd:
         keys_to_delete = [k for k in sd if k.startswith(vae_key_prefix)]
         for k in keys_to_delete:
diff --git a/backend/operations_gguf.py b/backend/operations_gguf.py
index 80b77382..72da4604 100644
--- a/backend/operations_gguf.py
+++ b/backend/operations_gguf.py
@@ -35,6 +35,9 @@ class ParameterGGUF(torch.nn.Parameter):
     def __new__(cls, tensor=None, requires_grad=False, no_init=False):
         return super().__new__(cls, torch.tensor(tensor.data), requires_grad=requires_grad)
 
+    def dequantize_as_pytorch_parameter(self):
+        return torch.nn.Parameter(dequantize_tensor(self), requires_grad=False)
+
     def to(self, *args, **kwargs):
         new = ParameterGGUF(self.data.to(*args, **kwargs), no_init=True)
         new.gguf_type = self.gguf_type
diff --git a/backend/utils.py b/backend/utils.py
index 56535e0c..d023f577 100644
--- a/backend/utils.py
+++ b/backend/utils.py
@@ -151,3 +151,18 @@ def get_state_dict_after_quant(model, prefix=''):
     sd = model.state_dict()
     sd = {(prefix + k): v.clone() for k, v in sd.items()}
     return sd
+
+
+def beautiful_print_gguf_state_dict_statics(state_dict):
+    from gguf.constants import GGMLQuantizationType
+    type_counts = {}
+    for k, v in state_dict.items():
+        gguf_type = getattr(v, 'gguf_type', None)
+        if gguf_type is not None:
+            type_name = GGMLQuantizationType(gguf_type).name
+            if type_name in type_counts:
+                type_counts[type_name] += 1
+            else:
+                type_counts[type_name] = 1
+    print(f'GGUF state dict: {type_counts}')
+    return
diff --git a/modules_forge/main_entry.py b/modules_forge/main_entry.py
index 86cf6670..71164bb3 100644
--- a/modules_forge/main_entry.py
+++ b/modules_forge/main_entry.py
@@ -139,7 +139,7 @@ def refresh_models():
     shared_items.refresh_checkpoints()
     ckpt_list = shared_items.list_checkpoint_tiles(shared.opts.sd_checkpoint_dropdown_use_short)
 
-    file_extensions = ['ckpt', 'pt', 'bin', 'safetensors']
+    file_extensions = ['ckpt', 'pt', 'bin', 'safetensors', 'gguf']
 
     module_list.clear()
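
Note (illustrative, not part of the patch): the wierd_t5_format_from_city96 table added to replace_state_dict() rewrites city96-style GGUF T5 tensor names into the layout IntegratedT5 expects by applying each substring replacement in order. A minimal standalone sketch of that remapping follows; CITY96_T5_KEY_MAP and remap_key are hypothetical names with an abbreviated copy of the table, not identifiers from the diff.

    # Sketch only: mirrors the substring-replacement loop added in replace_state_dict().
    CITY96_T5_KEY_MAP = {
        "enc.": "encoder.",
        ".blk.": ".block.",
        "token_embd": "shared",
        "attn_k": "layer.0.SelfAttention.k",
        # ... remaining entries as in the diff above
    }

    def remap_key(key: str) -> str:
        # Apply every replacement in table order, exactly as the loop in the patch does.
        for src, dst in CITY96_T5_KEY_MAP.items():
            key = key.replace(src, dst)
        return key

    # 'enc.' -> 'encoder.', '.blk.' -> '.block.', 'attn_k' -> 'layer.0.SelfAttention.k'
    assert remap_key("enc.blk.0.attn_k.weight") == "encoder.block.0.layer.0.SelfAttention.k.weight"
    assert remap_key("token_embd.weight") == "shared.weight"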