diff --git a/backend/loader.py b/backend/loader.py
index fd06688e..fa33ccbc 100644
--- a/backend/loader.py
+++ b/backend/loader.py
@@ -10,7 +10,7 @@ from diffusers import DiffusionPipeline
 from transformers import modeling_utils
 
 from backend import memory_management
-from backend.utils import read_arbitrary_config
+from backend.utils import read_arbitrary_config, load_torch_file
 from backend.state_dict import try_filter_state_dict, load_state_dict
 from backend.operations import using_forge_operations
 from backend.nn.vae import IntegratedAutoencoderKL
@@ -46,6 +46,8 @@ def load_huggingface_component(guess, component_name, lib_name, cls_name, repo_p
         comp._eventual_warn_about_too_long_sequence = lambda *args, **kwargs: None
         return comp
     if cls_name in ['AutoencoderKL']:
+        assert isinstance(state_dict, dict) and len(state_dict) > 16, 'You do not have VAE state dict!'
+
         config = IntegratedAutoencoderKL.load_config(config_path)
 
         with using_forge_operations(device=memory_management.cpu, dtype=memory_management.vae_dtype()):
@@ -54,6 +56,8 @@ def load_huggingface_component(guess, component_name, lib_name, cls_name, repo_p
             load_state_dict(model, state_dict, ignore_start='loss.')
         return model
     if component_name.startswith('text_encoder') and cls_name in ['CLIPTextModel', 'CLIPTextModelWithProjection']:
+        assert isinstance(state_dict, dict) and len(state_dict) > 16, 'You do not have CLIP state dict!'
+
         from transformers import CLIPTextConfig, CLIPTextModel
 
         config = CLIPTextConfig.from_pretrained(config_path)
@@ -71,6 +75,8 @@ def load_huggingface_component(guess, component_name, lib_name, cls_name, repo_p
         return model
     if cls_name == 'T5EncoderModel':
+        assert isinstance(state_dict, dict) and len(state_dict) > 16, 'You do not have T5 state dict!'
+
         from backend.nn.t5 import IntegratedT5
 
         config = read_arbitrary_config(config_path)
@@ -78,17 +84,21 @@ def load_huggingface_component(guess, component_name, lib_name, cls_name, repo_p
         sd_dtype = memory_management.state_dict_dtype(state_dict)
 
         if sd_dtype in [torch.float8_e4m3fn, torch.float8_e5m2]:
-            print(f'Using Detected T5 Data Type: {sd_dtype}')
             dtype = sd_dtype
+            print(f'Using Detected T5 Data Type: {dtype}')
+        else:
+            print(f'Using Default T5 Data Type: {dtype}')
 
         with modeling_utils.no_init_weights():
             with using_forge_operations(device=memory_management.cpu, dtype=dtype, manual_cast_enabled=True):
                 model = IntegratedT5(config)
 
-        load_state_dict(model, state_dict, log_name=cls_name, ignore_errors=['transformer.encoder.embed_tokens.weight'])
+        load_state_dict(model, state_dict, log_name=cls_name, ignore_errors=['transformer.encoder.embed_tokens.weight', 'logit_scale'])
 
         return model
 
     if cls_name in ['UNet2DConditionModel', 'FluxTransformer2DModel']:
+        assert isinstance(state_dict, dict) and len(state_dict) > 16, 'You do not have model state dict!'
+
         model_loader = None
         if cls_name == 'UNet2DConditionModel':
             model_loader = lambda c: IntegratedUNet2DConditionModel.from_config(c)
@@ -148,16 +158,57 @@ def load_huggingface_component(guess, component_name, lib_name, cls_name, repo_p
     return None
 
 
-def split_state_dict(sd, sd_vae=None):
-    guess = huggingface_guess.guess(sd)
-    guess.clip_target = guess.clip_target(sd)
+def replace_state_dict(sd, asd, guess):
+    vae_key_prefix = guess.vae_key_prefix[0]
+    text_encoder_key_prefix = guess.text_encoder_key_prefix[0]
 
-    if sd_vae is not None:
-        print(f'Using external VAE state dict: {len(sd_vae)}')
+    if "decoder.conv_in.weight" in asd:
+        keys_to_delete = [k for k in sd if k.startswith(vae_key_prefix)]
+        for k in keys_to_delete:
+            del sd[k]
+        for k, v in asd.items():
+            sd[vae_key_prefix + k] = v
+
+    if 'text_model.encoder.layers.0.layer_norm1.weight' in asd:
+        keys_to_delete = [k for k in sd if k.startswith(f"{text_encoder_key_prefix}clip_l.")]
+        for k in keys_to_delete:
+            del sd[k]
+        for k, v in asd.items():
+            sd[f"{text_encoder_key_prefix}clip_l.transformer.{k}"] = v
+
+    if 'encoder.block.0.layer.0.SelfAttention.k.weight' in asd:
+        keys_to_delete = [k for k in sd if k.startswith(f"{text_encoder_key_prefix}t5xxl.")]
+        for k in keys_to_delete:
+            del sd[k]
+        for k, v in asd.items():
+            sd[f"{text_encoder_key_prefix}t5xxl.transformer.{k}"] = v
+
+    return sd
+
+
+def preprocess_state_dict(sd):
+    if any("double_block" in k for k in sd.keys()):
+        if not any(k.startswith("model.diffusion_model") for k in sd.keys()):
+            sd = {f"model.diffusion_model.{k}": v for k, v in sd.items()}
+
+    return sd
+
+
+def split_state_dict(sd, additional_state_dicts: list = None):
+    sd = load_torch_file(sd)
+    sd = preprocess_state_dict(sd)
+    guess = huggingface_guess.guess(sd)
+
+    if isinstance(additional_state_dicts, list):
+        for asd in additional_state_dicts:
+            asd = load_torch_file(asd)
+            sd = replace_state_dict(sd, asd, guess)
+
+    guess.clip_target = guess.clip_target(sd)
 
     state_dict = {
         guess.unet_target: try_filter_state_dict(sd, guess.unet_key_prefix),
-        guess.vae_target: try_filter_state_dict(sd, guess.vae_key_prefix) if sd_vae is None else sd_vae
+        guess.vae_target: try_filter_state_dict(sd, guess.vae_key_prefix)
     }
 
     sd = guess.process_clip_state_dict(sd)
@@ -176,9 +227,9 @@ def split_state_dict(sd, sd_vae=None):
 
 
 @torch.no_grad()
-def forge_loader(sd, sd_vae=None):
+def forge_loader(sd, additional_state_dicts=None):
     try:
-        state_dicts, estimated_config = split_state_dict(sd, sd_vae=sd_vae)
+        state_dicts, estimated_config = split_state_dict(sd, additional_state_dicts=additional_state_dicts)
     except:
         raise ValueError('Failed to recognize model type!')
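Note on the loader changes above: optional VAE / text-encoder files now flow through replace_state_dict(), which identifies what an auxiliary state dict contains by probing for a signature key, then grafts its tensors into the main checkpoint under the guessed prefix. A minimal sketch of that dispatch, assuming plain dict state dicts (the probe keys are the ones used in the hunk; classify_module is a hypothetical helper, not part of the patch):

    # Probe keys copied from replace_state_dict() above; one telltale key per component.
    PROBE_KEYS = {
        'decoder.conv_in.weight': 'vae',                             # VAE decoder weights
        'text_model.encoder.layers.0.layer_norm1.weight': 'clip_l',  # CLIP-L text encoder
        'encoder.block.0.layer.0.SelfAttention.k.weight': 't5xxl',   # T5-XXL encoder
    }

    def classify_module(asd):
        # Return which component an auxiliary state dict appears to contain, else None.
        for probe, kind in PROBE_KEYS.items():
            if probe in asd:
                return kind
        return None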
diff --git a/backend/memory_management.py b/backend/memory_management.py
index 5a6098e0..b3fd3d9b 100644
--- a/backend/memory_management.py
+++ b/backend/memory_management.py
@@ -384,6 +384,7 @@ class LoadedModel:
         if not do_not_need_cpu_swap:
             memory_in_swap = 0
             mem_counter = 0
+            mem_cannot_cast = 0
             for m in self.real_model.modules():
                 if hasattr(m, "parameters_manual_cast"):
                     m.prev_parameters_manual_cast = m.parameters_manual_cast
@@ -399,8 +400,12 @@
                         m._apply(lambda x: x.pin_memory())
                 elif hasattr(m, "weight"):
                     m.to(self.device)
-                    mem_counter += module_size(m)
-                    print(f"[Memory Management] Swap disabled for", type(m).__name__)
+                    module_mem = module_size(m)
+                    mem_counter += module_mem
+                    mem_cannot_cast += module_mem
+
+            if mem_cannot_cast > 0:
+                print(f"[Memory Management] Loaded to GPU for backward capability: {mem_cannot_cast / (1024 * 1024):.2f} MB")
 
             swap_flag = 'Shared' if PIN_SHARED_MEMORY else 'CPU'
             method_flag = 'asynchronous' if stream.should_use_stream() else 'blocked'
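Note on the memory-management hunk: rather than printing "Swap disabled for <module>" once per module, the sizes of all modules that cannot be manually cast are now accumulated and reported as a single MB total. A rough equivalent of the new accounting, assuming module_size() returns a size in bytes (summarize_uncastable is a hypothetical name used only for illustration):

    def summarize_uncastable(modules, module_size):
        # Sum the sizes of modules that must stay resident on the GPU, then report once.
        total_bytes = sum(module_size(m) for m in modules if hasattr(m, 'weight'))
        if total_bytes > 0:
            print(f'[Memory Management] Loaded to GPU for backward capability: {total_bytes / (1024 * 1024):.2f} MB')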
diff --git a/backend/patcher/clip.py b/backend/patcher/clip.py
index a870d9bb..3979e96c 100644
--- a/backend/patcher/clip.py
+++ b/backend/patcher/clip.py
@@ -3,6 +3,10 @@ from backend.patcher.base import ModelPatcher
 from backend.nn.base import ModuleDict, ObjectDict
 
 
+class JointTextEncoder(ModuleDict):
+    pass
+
+
 class CLIP:
     def __init__(self, model_dict={}, tokenizer_dict={}, no_init=False):
         if no_init:
@@ -11,7 +15,7 @@ class CLIP:
         load_device = memory_management.text_encoder_device()
         offload_device = memory_management.text_encoder_offload_device()
 
-        self.cond_stage_model = ModuleDict(model_dict)
+        self.cond_stage_model = JointTextEncoder(model_dict)
         self.tokenizer = ObjectDict(tokenizer_dict)
 
         self.patcher = ModelPatcher(self.cond_stage_model, load_device=load_device, offload_device=offload_device)
diff --git a/modules/processing.py b/modules/processing.py
index 15c70693..b20a8dad 100644
--- a/modules/processing.py
+++ b/modules/processing.py
@@ -740,8 +740,8 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iter
         "Model": p.sd_model_name if opts.add_model_name_to_info else None,
         "FP8 weight": opts.fp8_storage if devices.fp8 else None,
         "Cache FP16 weight for LoRA": opts.cache_fp16_weight if devices.fp8 else None,
-        "VAE hash": p.sd_vae_hash if opts.add_vae_hash_to_info else None,
-        "VAE": p.sd_vae_name if opts.add_vae_name_to_info else None,
+        # "VAE hash": p.sd_vae_hash if opts.add_vae_hash_to_info else None,
+        # "VAE": p.sd_vae_name if opts.add_vae_name_to_info else None,
         "Variation seed": (None if p.subseed_strength == 0 else (p.all_subseeds[0] if use_main_prompt else all_subseeds[index])),
         "Variation seed strength": (None if p.subseed_strength == 0 else p.subseed_strength),
         "Seed resize from": (None if p.seed_resize_from_w <= 0 or p.seed_resize_from_h <= 0 else f"{p.seed_resize_from_w}x{p.seed_resize_from_h}"),
@@ -759,6 +759,10 @@
         "User": p.user if opts.add_user_name_to_info else None,
     })
 
+    if isinstance(shared.opts.forge_additional_modules, list) and len(shared.opts.forge_additional_modules) > 0:
+        for i, m in enumerate(shared.opts.forge_additional_modules):
+            generation_params[f'Module {i+1}'] = os.path.splitext(os.path.basename(m))[0]
+
     for key, value in generation_params.items():
         try:
             if isinstance(value, list):
@@ -787,6 +791,12 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
     if need_global_unload and not just_reloaded:
         memory_management.unload_all_models()
 
+    if need_global_unload:
+        StableDiffusionProcessing.cached_c = [None, None, None]
+        StableDiffusionProcessing.cached_uc = [None, None, None]
+        p.cached_c = [None, None, None]
+        p.cached_uc = [None, None, None]
+
     need_global_unload = False
 
     if p.scripts is not None:
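Note on the infotext change: the per-checkpoint "VAE" / "VAE hash" fields are commented out, and each selected extra module is instead recorded as "Module 1", "Module 2", and so on, using its filename without extension. A self-contained example of the naming (the paths are made up):

    import os

    # Made-up module paths, as they would appear in shared.opts.forge_additional_modules.
    modules = ['models/VAE/ae.safetensors', 'models/text_encoder/clip_l.safetensors']
    generation_params = {f'Module {i + 1}': os.path.splitext(os.path.basename(m))[0] for i, m in enumerate(modules)}
    # generation_params == {'Module 1': 'ae', 'Module 2': 'clip_l'}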
diff --git a/modules/sd_models.py b/modules/sd_models.py
index 8686c264..16c4d459 100644
--- a/modules/sd_models.py
+++ b/modules/sd_models.py
@@ -490,27 +490,15 @@ def forge_model_reload():
     timer.record("unload existing model")
 
     checkpoint_info = model_data.forge_loading_parameters['checkpoint_info']
-    state_dict = load_torch_file(checkpoint_info.filename)
-    timer.record("load state dict")
-
-    state_dict_vae = model_data.forge_loading_parameters.get('vae_filename', None)
-
-    if state_dict_vae is not None:
-        state_dict_vae = load_torch_file(state_dict_vae)
-
-    timer.record("load vae state dict")
-
-    if shared.opts.sd_checkpoint_cache > 0:
-        # cache newly loaded model
-        checkpoints_loaded[checkpoint_info] = state_dict.copy()
+    state_dict = checkpoint_info.filename
+    additional_state_dicts = model_data.forge_loading_parameters.get('additional_modules', [])
 
     timer.record("cache state dict")
 
     dynamic_args['forge_unet_storage_dtype'] = model_data.forge_loading_parameters.get('unet_storage_dtype', None)
     dynamic_args['embedding_dir'] = cmd_opts.embeddings_dir
     dynamic_args['emphasis_name'] = opts.emphasis
-    sd_model = forge_loader(state_dict, sd_vae=state_dict_vae)
-    del state_dict
+    sd_model = forge_loader(state_dict, additional_state_dicts=additional_state_dicts)
     timer.record("forge model load")
 
     sd_model.extra_generation_params = {}
@@ -520,10 +508,6 @@
         sd_model.sd_model_hash = checkpoint_info.calculate_shorthash()
         timer.record("calculate hash")
 
-    # clean up cache if limit is reached
-    while len(checkpoints_loaded) > shared.opts.sd_checkpoint_cache:
-        checkpoints_loaded.popitem(last=False)
-
     shared.opts.data["sd_checkpoint_hash"] = checkpoint_info.sha256
 
     model_data.set_sd_model(sd_model)
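Note on the reload path: forge_model_reload() no longer materializes tensors itself (and the sd_checkpoint_cache logic is removed outright); it hands forge_loader() the checkpoint filename plus the list of additional module files, and split_state_dict() calls load_torch_file() on each path. The loading parameters end up shaped roughly like this (values are illustrative only):

    # Illustrative shape of model_data.forge_loading_parameters after this change
    # (checkpoint_info would be a CheckpointInfo object; a string stands in here).
    forge_loading_parameters = dict(
        checkpoint_info='example_checkpoint_info',        # stand-in for a CheckpointInfo
        additional_modules=['/path/to/ae.safetensors'],   # mirrors shared.opts.forge_additional_modules
        unet_storage_dtype=None,                          # or a torch.float8_* dtype
    )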
diff --git a/modules_forge/main_entry.py b/modules_forge/main_entry.py
index bfd188bd..c5d6953e 100644
--- a/modules_forge/main_entry.py
+++ b/modules_forge/main_entry.py
@@ -1,9 +1,9 @@
+import os
 import torch
 import gradio as gr
 
 from gradio.context import Context
-from modules import shared_items, shared, ui_common, sd_models, processing, infotext_utils
-from modules import sd_vae as sd_vae_module
+from modules import shared_items, shared, ui_common, sd_models, processing, infotext_utils, paths
 from backend import memory_management, stream
 
 
@@ -13,7 +13,6 @@ ui_forge_preset: gr.Radio = None
 
 ui_checkpoint: gr.Dropdown = None
 ui_vae: gr.Dropdown = None
-ui_vae_refresh_button: gr.Button = None
 ui_clip_skip: gr.Slider = None
 
 ui_forge_unet_storage_dtype_options: gr.Radio = None
@@ -29,6 +28,8 @@ forge_unet_storage_dtype_options = {
     'fp8e5': torch.float8_e5m2,
 }
 
+module_list = {}
+
 
 def bind_to_opts(comp, k, save=False, callback=None):
     def on_change(v):
@@ -44,7 +45,7 @@ def bind_to_opts(comp, k, save=False, callback=None):
 
 
 def make_checkpoint_manager_ui():
-    global ui_checkpoint, ui_vae, ui_clip_skip, ui_forge_unet_storage_dtype_options, ui_forge_async_loading, ui_forge_pin_shared_memory, ui_forge_inference_memory, ui_forge_preset, ui_vae_refresh_button
+    global ui_checkpoint, ui_vae, ui_clip_skip, ui_forge_unet_storage_dtype_options, ui_forge_async_loading, ui_forge_pin_shared_memory, ui_forge_inference_memory, ui_forge_preset
 
     if shared.opts.sd_model_checkpoint in [None, 'None', 'none', '']:
         if len(sd_models.checkpoints_list) == 0:
@@ -54,22 +55,44 @@ def make_checkpoint_manager_ui():
 
     ui_forge_preset = gr.Radio(label="UI", value=lambda: shared.opts.forge_preset, choices=['sd', 'xl', 'flux', 'all'])
 
-    sd_model_checkpoint_args = lambda: {"choices": shared_items.list_checkpoint_tiles(shared.opts.sd_checkpoint_dropdown_use_short)}
+    ckpt_list, vae_list = refresh_models()
+
     ui_checkpoint = gr.Dropdown(
         value=lambda: shared.opts.sd_model_checkpoint,
         label="Checkpoint",
        elem_classes=['model_selection'],
-        **sd_model_checkpoint_args()
+        choices=ckpt_list
     )
-    ui_common.create_refresh_button(ui_checkpoint, shared_items.refresh_checkpoints, sd_model_checkpoint_args, f"forge_refresh_checkpoint")
 
-    sd_vae_args = lambda: {"choices": shared_items.sd_vae_items()}
     ui_vae = gr.Dropdown(
-        value=lambda: shared.opts.sd_vae,
-        label="VAE",
-        **sd_vae_args()
+        value=lambda: [os.path.basename(x) for x in shared.opts.forge_additional_modules],
+        multiselect=True,
+        label="VAE / Text Encoder",
+        render=False,
+        choices=vae_list
     )
-    ui_vae_refresh_button = ui_common.create_refresh_button(ui_vae, shared_items.refresh_vae_list, sd_vae_args, f"forge_refresh_vae")
+
+    def gr_refresh_models():
+        a, b = refresh_models()
+        return gr.update(choices=a), gr.update(choices=b)
+
+    refresh_button = ui_common.ToolButton(value=ui_common.refresh_symbol, elem_id=f"forge_refresh_checkpoint", tooltip="Refresh")
+    refresh_button.click(
+        fn=gr_refresh_models,
+        inputs=[],
+        outputs=[ui_checkpoint, ui_vae],
+        show_progress=False,
+        queue=False
+    )
+    Context.root_block.load(
+        fn=gr_refresh_models,
+        inputs=[],
+        outputs=[ui_checkpoint, ui_vae],
+        show_progress=False,
+        queue=False
+    )
+
+    ui_vae.render()
 
     ui_forge_unet_storage_dtype_options = gr.Radio(label="Diffusion with Low Bits", value=lambda: shared.opts.forge_unet_storage_dtype, choices=list(forge_unet_storage_dtype_options.keys()))
     bind_to_opts(ui_forge_unet_storage_dtype_options, 'forge_unet_storage_dtype', save=True, callback=refresh_model_loading_parameters)
@@ -94,6 +117,37 @@
     return
 
 
+def find_files_with_extensions(base_path, extensions):
+    found_files = {}
+    for root, _, files in os.walk(base_path):
+        for file in files:
+            if any(file.endswith(ext) for ext in extensions):
+                full_path = os.path.join(root, file)
+                found_files[file] = full_path
+    return found_files
+
+
+def refresh_models():
+    global module_list
+
+    shared_items.refresh_checkpoints()
+    ckpt_list = shared_items.list_checkpoint_tiles(shared.opts.sd_checkpoint_dropdown_use_short)
+
+    vae_path = os.path.abspath(os.path.join(paths.models_path, "VAE"))
+    text_encoder_path = os.path.abspath(os.path.join(paths.models_path, "text_encoder"))
+    file_extensions = ['ckpt', 'pt', 'bin', 'safetensors']
+
+    module_list.clear()
+
+    vae_files = find_files_with_extensions(vae_path, file_extensions)
+    module_list.update(vae_files)
+
+    text_encoder_files = find_files_with_extensions(text_encoder_path, file_extensions)
+    module_list.update(text_encoder_files)
+
+    return ckpt_list, module_list.keys()
+
+
 def refresh_memory_management_settings(model_memory, async_loading, pin_shared_memory):
     inference_memory = total_vram - model_memory
@@ -121,11 +175,10 @@
 def refresh_model_loading_parameters():
     from modules.sd_models import select_checkpoint, model_data
 
     checkpoint_info = select_checkpoint()
-    vae_resolution = sd_vae_module.resolve_vae(checkpoint_info.filename)
 
     model_data.forge_loading_parameters = dict(
         checkpoint_info=checkpoint_info,
-        vae_filename=vae_resolution.vae,
+        additional_modules=shared.opts.forge_additional_modules,
         unet_storage_dtype=forge_unet_storage_dtype_options.get(shared.opts.forge_unet_storage_dtype, None)
     )
@@ -142,8 +195,15 @@
     return
 
 
-def vae_change(vae_name):
-    shared.opts.set('sd_vae', vae_name)
+def vae_change(module_names):
+    modules = []
+
+    for n in module_names:
+        if n in module_list:
+            modules.append(module_list[n])
+
+    shared.opts.set('forge_additional_modules', modules)
+
+    shared.opts.save(shared.config_filename)
     refresh_model_loading_parameters()
     return
@@ -173,7 +233,6 @@
 def forge_main_entry():
     output_targets = [
         ui_vae,
-        ui_vae_refresh_button,
         ui_clip_skip,
         ui_forge_unet_storage_dtype_options,
         ui_forge_async_loading,
@@ -207,8 +266,7 @@ def on_preset_change(preset=None):
     if shared.opts.forge_preset == 'sd':
         return [
-            gr.update(visible=True, value='Automatic'),  # ui_vae
-            gr.update(visible=True),  # ui_vae_refresh_button
+            gr.update(visible=True),  # ui_vae
             gr.update(visible=True, value=1),  # ui_clip_skip
             gr.update(visible=False, value='Auto'),  # ui_forge_unet_storage_dtype_options
             gr.update(visible=False, value='Queue'),  # ui_forge_async_loading
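Note on the UI wiring: refresh_models() walks models/VAE and models/text_encoder and fills the global module_list with a basename-to-full-path map; the multiselect dropdown displays the basenames, and vae_change() translates selections back into paths before saving them to forge_additional_modules. A compact sketch of that round trip (the example filename and path are hypothetical):

    # Basename <-> path round trip behind the "VAE / Text Encoder" multiselect.
    module_list = {'ae.safetensors': '/abs/models/VAE/ae.safetensors'}  # built by refresh_models()
    selected = ['ae.safetensors']                                       # what the dropdown hands to vae_change()
    modules = [module_list[n] for n in selected if n in module_list]    # what gets stored in opts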
@@ -230,8 +288,7 @@
     if shared.opts.forge_preset == 'xl':
         return [
-            gr.update(visible=False, value='Automatic'),  # ui_vae
-            gr.update(visible=False),  # ui_vae_refresh_button
+            gr.update(visible=True),  # ui_vae
             gr.update(visible=False, value=1),  # ui_clip_skip
             gr.update(visible=True, value='Auto'),  # ui_forge_unet_storage_dtype_options
             gr.update(visible=False, value='Queue'),  # ui_forge_async_loading
@@ -253,8 +310,7 @@
     if shared.opts.forge_preset == 'flux':
         return [
-            gr.update(visible=False, value='Automatic'),  # ui_vae
-            gr.update(visible=False),  # ui_vae_refresh_button
+            gr.update(visible=True),  # ui_vae
             gr.update(visible=False, value=1),  # ui_clip_skip
             gr.update(visible=True, value='Auto'),  # ui_forge_unet_storage_dtype_options
             gr.update(visible=True, value='Queue'),  # ui_forge_async_loading
@@ -275,8 +331,7 @@
     ]
 
     return [
-        gr.update(visible=True, value='Automatic'),  # ui_vae
-        gr.update(visible=True),  # ui_vae_refresh_button
+        gr.update(visible=True),  # ui_vae
         gr.update(visible=True, value=1),  # ui_clip_skip
         gr.update(visible=True, value='Auto'),  # ui_forge_unet_storage_dtype_options
         gr.update(visible=True, value='Queue'),  # ui_forge_async_loading
diff --git a/modules_forge/shared_options.py b/modules_forge/shared_options.py
index f88510fa..18c47ba8 100644
--- a/modules_forge/shared_options.py
+++ b/modules_forge/shared_options.py
@@ -6,4 +6,5 @@ def register(options_templates, options_section, OptionInfo):
         "forge_async_loading": OptionInfo('Queue'),
         "forge_pin_shared_memory": OptionInfo('CPU'),
         "forge_preset": OptionInfo('sd'),
+        "forge_additional_modules": OptionInfo([]),
     }))
diff --git a/style.css b/style.css
index 680f46b1..9cc2534f 100644
--- a/style.css
+++ b/style.css
@@ -440,7 +440,13 @@ div.toprow-compact-tools{
 }
 
 #quicksettings > div.model_selection{
-    min-width: 24em !important;
+    min-width: 20em !important;
+}
+
+#quicksettings .subdued{
+    display: block;
+    margin-left: auto;
+    width: 30px;
 }
 
 #settings{
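Note on the option registration: forge_additional_modules defaults to an empty list, so a fresh install selects no extra modules and every consumer above can rely on the value being list-shaped. A trivial sketch of the invariant the other hunks depend on:

    # The registered default keeps the option list-shaped for all readers.
    options = {'forge_additional_modules': []}  # as registered above
    selected = options['forge_additional_modules']
    assert isinstance(selected, list)           # processing.py and main_entry.py rely on this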