support all flux models

This commit is contained in:
lllyasviel
2024-08-13 05:42:17 -07:00
committed by GitHub
parent 3589b57ec1
commit 61f83dd610
8 changed files with 177 additions and 61 deletions

View File

@@ -10,7 +10,7 @@ from diffusers import DiffusionPipeline
from transformers import modeling_utils
from backend import memory_management
from backend.utils import read_arbitrary_config
from backend.utils import read_arbitrary_config, load_torch_file
from backend.state_dict import try_filter_state_dict, load_state_dict
from backend.operations import using_forge_operations
from backend.nn.vae import IntegratedAutoencoderKL
@@ -46,6 +46,8 @@ def load_huggingface_component(guess, component_name, lib_name, cls_name, repo_p
comp._eventual_warn_about_too_long_sequence = lambda *args, **kwargs: None
return comp
if cls_name in ['AutoencoderKL']:
assert isinstance(state_dict, dict) and len(state_dict) > 16, 'You do not have VAE state dict!'
config = IntegratedAutoencoderKL.load_config(config_path)
with using_forge_operations(device=memory_management.cpu, dtype=memory_management.vae_dtype()):
@@ -54,6 +56,8 @@ def load_huggingface_component(guess, component_name, lib_name, cls_name, repo_p
load_state_dict(model, state_dict, ignore_start='loss.')
return model
if component_name.startswith('text_encoder') and cls_name in ['CLIPTextModel', 'CLIPTextModelWithProjection']:
assert isinstance(state_dict, dict) and len(state_dict) > 16, 'You do not have CLIP state dict!'
from transformers import CLIPTextConfig, CLIPTextModel
config = CLIPTextConfig.from_pretrained(config_path)
@@ -71,6 +75,8 @@ def load_huggingface_component(guess, component_name, lib_name, cls_name, repo_p
return model
if cls_name == 'T5EncoderModel':
assert isinstance(state_dict, dict) and len(state_dict) > 16, 'You do not have T5 state dict!'
from backend.nn.t5 import IntegratedT5
config = read_arbitrary_config(config_path)
@@ -78,17 +84,21 @@ def load_huggingface_component(guess, component_name, lib_name, cls_name, repo_p
sd_dtype = memory_management.state_dict_dtype(state_dict)
if sd_dtype in [torch.float8_e4m3fn, torch.float8_e5m2]:
print(f'Using Detected T5 Data Type: {sd_dtype}')
dtype = sd_dtype
print(f'Using Detected T5 Data Type: {dtype}')
else:
print(f'Using Default T5 Data Type: {dtype}')
with modeling_utils.no_init_weights():
with using_forge_operations(device=memory_management.cpu, dtype=dtype, manual_cast_enabled=True):
model = IntegratedT5(config)
load_state_dict(model, state_dict, log_name=cls_name, ignore_errors=['transformer.encoder.embed_tokens.weight'])
load_state_dict(model, state_dict, log_name=cls_name, ignore_errors=['transformer.encoder.embed_tokens.weight', 'logit_scale'])
return model
if cls_name in ['UNet2DConditionModel', 'FluxTransformer2DModel']:
assert isinstance(state_dict, dict) and len(state_dict) > 16, 'You do not have model state dict!'
model_loader = None
if cls_name == 'UNet2DConditionModel':
model_loader = lambda c: IntegratedUNet2DConditionModel.from_config(c)
@@ -148,16 +158,57 @@ def load_huggingface_component(guess, component_name, lib_name, cls_name, repo_p
return None
def split_state_dict(sd, sd_vae=None):
guess = huggingface_guess.guess(sd)
guess.clip_target = guess.clip_target(sd)
def replace_state_dict(sd, asd, guess):
    """Merge an additional state dict (VAE / CLIP-L / T5) into *sd* in place.

    The component type carried by *asd* is detected from signature keys
    unique to each architecture; the matching prefixed section of *sd* is
    deleted and replaced with the re-prefixed contents of *asd*.

    sd:    full checkpoint state dict (mutated and returned).
    asd:   additional state dict loaded from a user-selected module file.
    guess: huggingface_guess result providing the key prefixes for this
           model family.
    """
    vae_key_prefix = guess.vae_key_prefix[0]
    text_encoder_key_prefix = guess.text_encoder_key_prefix[0]
    # NOTE(review): the two lines below reference ``sd_vae``, which is not a
    # parameter of this function — they appear to be leftover deleted lines
    # from the previous revision (flattened diff residue); verify against
    # the repository before relying on this text.
    if sd_vae is not None:
        print(f'Using external VAE state dict: {len(sd_vae)}')
    # VAE: "decoder.conv_in.weight" is a key unique to AutoencoderKL dicts.
    if "decoder.conv_in.weight" in asd:
        keys_to_delete = [k for k in sd if k.startswith(vae_key_prefix)]
        for k in keys_to_delete:
            del sd[k]
        for k, v in asd.items():
            sd[vae_key_prefix + k] = v
    # CLIP-L text encoder: first transformer layer norm is the signature key.
    if 'text_model.encoder.layers.0.layer_norm1.weight' in asd:
        keys_to_delete = [k for k in sd if k.startswith(f"{text_encoder_key_prefix}clip_l.")]
        for k in keys_to_delete:
            del sd[k]
        for k, v in asd.items():
            sd[f"{text_encoder_key_prefix}clip_l.transformer.{k}"] = v
    # T5-XXL encoder: first self-attention projection is the signature key.
    if 'encoder.block.0.layer.0.SelfAttention.k.weight' in asd:
        keys_to_delete = [k for k in sd if k.startswith(f"{text_encoder_key_prefix}t5xxl.")]
        for k in keys_to_delete:
            del sd[k]
        for k, v in asd.items():
            sd[f"{text_encoder_key_prefix}t5xxl.transformer.{k}"] = v
    return sd
def preprocess_state_dict(sd):
    """Normalize a raw checkpoint state dict before architecture guessing.

    Flux-style checkpoints carry "double_block" keys at the top level; when
    such a dict is not already namespaced under "model.diffusion_model",
    every key is re-prefixed so downstream prefix-based filtering works
    uniformly. Other dicts are returned untouched.
    """
    keys = list(sd.keys())
    looks_like_flux = any("double_block" in key for key in keys)
    already_prefixed = any(key.startswith("model.diffusion_model") for key in keys)
    if looks_like_flux and not already_prefixed:
        prefixed = {}
        for key, value in sd.items():
            prefixed["model.diffusion_model." + key] = value
        return prefixed
    return sd
def split_state_dict(sd, additional_state_dicts: list = None):
sd = load_torch_file(sd)
sd = preprocess_state_dict(sd)
guess = huggingface_guess.guess(sd)
if isinstance(additional_state_dicts, list):
for asd in additional_state_dicts:
asd = load_torch_file(asd)
sd = replace_state_dict(sd, asd, guess)
guess.clip_target = guess.clip_target(sd)
state_dict = {
guess.unet_target: try_filter_state_dict(sd, guess.unet_key_prefix),
guess.vae_target: try_filter_state_dict(sd, guess.vae_key_prefix) if sd_vae is None else sd_vae
guess.vae_target: try_filter_state_dict(sd, guess.vae_key_prefix)
}
sd = guess.process_clip_state_dict(sd)
@@ -176,9 +227,9 @@ def split_state_dict(sd, sd_vae=None):
@torch.no_grad()
def forge_loader(sd, sd_vae=None):
def forge_loader(sd, additional_state_dicts=None):
try:
state_dicts, estimated_config = split_state_dict(sd, sd_vae=sd_vae)
state_dicts, estimated_config = split_state_dict(sd, additional_state_dicts=additional_state_dicts)
except:
raise ValueError('Failed to recognize model type!')

View File

@@ -384,6 +384,7 @@ class LoadedModel:
if not do_not_need_cpu_swap:
memory_in_swap = 0
mem_counter = 0
mem_cannot_cast = 0
for m in self.real_model.modules():
if hasattr(m, "parameters_manual_cast"):
m.prev_parameters_manual_cast = m.parameters_manual_cast
@@ -399,8 +400,12 @@ class LoadedModel:
m._apply(lambda x: x.pin_memory())
elif hasattr(m, "weight"):
m.to(self.device)
mem_counter += module_size(m)
print(f"[Memory Management] Swap disabled for", type(m).__name__)
module_mem = module_size(m)
mem_counter += module_mem
mem_cannot_cast += module_mem
if mem_cannot_cast > 0:
print(f"[Memory Management] Loaded to GPU for backward capability: {mem_cannot_cast / (1024 * 1024):.2f} MB")
swap_flag = 'Shared' if PIN_SHARED_MEMORY else 'CPU'
method_flag = 'asynchronous' if stream.should_use_stream() else 'blocked'

View File

@@ -3,6 +3,10 @@ from backend.patcher.base import ModelPatcher
from backend.nn.base import ModuleDict, ObjectDict
class JointTextEncoder(ModuleDict):
pass
class CLIP:
def __init__(self, model_dict={}, tokenizer_dict={}, no_init=False):
if no_init:
@@ -11,7 +15,7 @@ class CLIP:
load_device = memory_management.text_encoder_device()
offload_device = memory_management.text_encoder_offload_device()
self.cond_stage_model = ModuleDict(model_dict)
self.cond_stage_model = JointTextEncoder(model_dict)
self.tokenizer = ObjectDict(tokenizer_dict)
self.patcher = ModelPatcher(self.cond_stage_model, load_device=load_device, offload_device=offload_device)

View File

@@ -740,8 +740,8 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iter
"Model": p.sd_model_name if opts.add_model_name_to_info else None,
"FP8 weight": opts.fp8_storage if devices.fp8 else None,
"Cache FP16 weight for LoRA": opts.cache_fp16_weight if devices.fp8 else None,
"VAE hash": p.sd_vae_hash if opts.add_vae_hash_to_info else None,
"VAE": p.sd_vae_name if opts.add_vae_name_to_info else None,
# "VAE hash": p.sd_vae_hash if opts.add_vae_hash_to_info else None,
# "VAE": p.sd_vae_name if opts.add_vae_name_to_info else None,
"Variation seed": (None if p.subseed_strength == 0 else (p.all_subseeds[0] if use_main_prompt else all_subseeds[index])),
"Variation seed strength": (None if p.subseed_strength == 0 else p.subseed_strength),
"Seed resize from": (None if p.seed_resize_from_w <= 0 or p.seed_resize_from_h <= 0 else f"{p.seed_resize_from_w}x{p.seed_resize_from_h}"),
@@ -759,6 +759,10 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iter
"User": p.user if opts.add_user_name_to_info else None,
})
if isinstance(shared.opts.forge_additional_modules, list) and len(shared.opts.forge_additional_modules) > 0:
for i, m in enumerate(shared.opts.forge_additional_modules):
generation_params[f'Module {i+1}'] = os.path.splitext(os.path.basename(m))[0]
for key, value in generation_params.items():
try:
if isinstance(value, list):
@@ -787,6 +791,12 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
if need_global_unload and not just_reloaded:
memory_management.unload_all_models()
if need_global_unload:
StableDiffusionProcessing.cached_c = [None, None, None]
StableDiffusionProcessing.cached_uc = [None, None, None]
p.cached_c = [None, None, None]
p.cached_uc = [None, None, None]
need_global_unload = False
if p.scripts is not None:

View File

@@ -490,27 +490,15 @@ def forge_model_reload():
timer.record("unload existing model")
checkpoint_info = model_data.forge_loading_parameters['checkpoint_info']
state_dict = load_torch_file(checkpoint_info.filename)
timer.record("load state dict")
state_dict_vae = model_data.forge_loading_parameters.get('vae_filename', None)
if state_dict_vae is not None:
state_dict_vae = load_torch_file(state_dict_vae)
timer.record("load vae state dict")
if shared.opts.sd_checkpoint_cache > 0:
# cache newly loaded model
checkpoints_loaded[checkpoint_info] = state_dict.copy()
state_dict = checkpoint_info.filename
additional_state_dicts = model_data.forge_loading_parameters.get('additional_modules', [])
timer.record("cache state dict")
dynamic_args['forge_unet_storage_dtype'] = model_data.forge_loading_parameters.get('unet_storage_dtype', None)
dynamic_args['embedding_dir'] = cmd_opts.embeddings_dir
dynamic_args['emphasis_name'] = opts.emphasis
sd_model = forge_loader(state_dict, sd_vae=state_dict_vae)
del state_dict
sd_model = forge_loader(state_dict, additional_state_dicts=additional_state_dicts)
timer.record("forge model load")
sd_model.extra_generation_params = {}
@@ -520,10 +508,6 @@ def forge_model_reload():
sd_model.sd_model_hash = checkpoint_info.calculate_shorthash()
timer.record("calculate hash")
# clean up cache if limit is reached
while len(checkpoints_loaded) > shared.opts.sd_checkpoint_cache:
checkpoints_loaded.popitem(last=False)
shared.opts.data["sd_checkpoint_hash"] = checkpoint_info.sha256
model_data.set_sd_model(sd_model)

View File

@@ -1,9 +1,9 @@
import os
import torch
import gradio as gr
from gradio.context import Context
from modules import shared_items, shared, ui_common, sd_models, processing, infotext_utils
from modules import sd_vae as sd_vae_module
from modules import shared_items, shared, ui_common, sd_models, processing, infotext_utils, paths
from backend import memory_management, stream
@@ -13,7 +13,6 @@ ui_forge_preset: gr.Radio = None
ui_checkpoint: gr.Dropdown = None
ui_vae: gr.Dropdown = None
ui_vae_refresh_button: gr.Button = None
ui_clip_skip: gr.Slider = None
ui_forge_unet_storage_dtype_options: gr.Radio = None
@@ -29,6 +28,8 @@ forge_unet_storage_dtype_options = {
'fp8e5': torch.float8_e5m2,
}
module_list = {}
def bind_to_opts(comp, k, save=False, callback=None):
def on_change(v):
@@ -44,7 +45,7 @@ def bind_to_opts(comp, k, save=False, callback=None):
def make_checkpoint_manager_ui():
global ui_checkpoint, ui_vae, ui_clip_skip, ui_forge_unet_storage_dtype_options, ui_forge_async_loading, ui_forge_pin_shared_memory, ui_forge_inference_memory, ui_forge_preset, ui_vae_refresh_button
global ui_checkpoint, ui_vae, ui_clip_skip, ui_forge_unet_storage_dtype_options, ui_forge_async_loading, ui_forge_pin_shared_memory, ui_forge_inference_memory, ui_forge_preset
if shared.opts.sd_model_checkpoint in [None, 'None', 'none', '']:
if len(sd_models.checkpoints_list) == 0:
@@ -54,22 +55,44 @@ def make_checkpoint_manager_ui():
ui_forge_preset = gr.Radio(label="UI", value=lambda: shared.opts.forge_preset, choices=['sd', 'xl', 'flux', 'all'])
sd_model_checkpoint_args = lambda: {"choices": shared_items.list_checkpoint_tiles(shared.opts.sd_checkpoint_dropdown_use_short)}
ckpt_list, vae_list = refresh_models()
ui_checkpoint = gr.Dropdown(
value=lambda: shared.opts.sd_model_checkpoint,
label="Checkpoint",
elem_classes=['model_selection'],
**sd_model_checkpoint_args()
choices=ckpt_list
)
ui_common.create_refresh_button(ui_checkpoint, shared_items.refresh_checkpoints, sd_model_checkpoint_args, f"forge_refresh_checkpoint")
sd_vae_args = lambda: {"choices": shared_items.sd_vae_items()}
ui_vae = gr.Dropdown(
value=lambda: shared.opts.sd_vae,
label="VAE",
**sd_vae_args()
value=lambda: [os.path.basename(x) for x in shared.opts.forge_additional_modules],
multiselect=True,
label="VAE / Text Encoder",
render=False,
choices=vae_list
)
ui_vae_refresh_button = ui_common.create_refresh_button(ui_vae, shared_items.refresh_vae_list, sd_vae_args, f"forge_refresh_vae")
def gr_refresh_models():
a, b = refresh_models()
return gr.update(choices=a), gr.update(choices=b)
refresh_button = ui_common.ToolButton(value=ui_common.refresh_symbol, elem_id=f"forge_refresh_checkpoint", tooltip="Refresh")
refresh_button.click(
fn=gr_refresh_models,
inputs=[],
outputs=[ui_checkpoint, ui_vae],
show_progress=False,
queue=False
)
Context.root_block.load(
fn=gr_refresh_models,
inputs=[],
outputs=[ui_checkpoint, ui_vae],
show_progress=False,
queue=False
)
ui_vae.render()
ui_forge_unet_storage_dtype_options = gr.Radio(label="Diffusion with Low Bits", value=lambda: shared.opts.forge_unet_storage_dtype, choices=list(forge_unet_storage_dtype_options.keys()))
bind_to_opts(ui_forge_unet_storage_dtype_options, 'forge_unet_storage_dtype', save=True, callback=refresh_model_loading_parameters)
@@ -94,6 +117,37 @@ def make_checkpoint_manager_ui():
return
def find_files_with_extensions(base_path, extensions):
    """Recursively collect files under *base_path* whose names end with any
    of the strings in *extensions*.

    Returns a dict mapping bare filename -> full path. If the same filename
    exists in several subdirectories, the entry from the directory visited
    last by os.walk wins. A nonexistent *base_path* yields an empty dict.
    """
    suffixes = tuple(extensions)
    matches = {}
    for directory, _subdirs, filenames in os.walk(base_path):
        for name in filenames:
            # str.endswith accepts a tuple: one C-level call replaces the
            # per-extension any(...) scan.
            if name.endswith(suffixes):
                matches[name] = os.path.join(directory, name)
    return matches
def refresh_models():
    """Rescan checkpoints and the VAE / text-encoder module directories.

    Refills the module-level ``module_list`` cache (bare filename -> full
    path) from the models/VAE and models/text_encoder directories, refreshes
    the checkpoint registry, and returns ``(checkpoint_titles,
    module_filenames)`` — the choice lists for the two dropdowns.
    """
    global module_list
    shared_items.refresh_checkpoints()
    ckpt_list = shared_items.list_checkpoint_tiles(shared.opts.sd_checkpoint_dropdown_use_short)
    vae_path = os.path.abspath(os.path.join(paths.models_path, "VAE"))
    text_encoder_path = os.path.abspath(os.path.join(paths.models_path, "text_encoder"))
    # Extensions are suffix strings (no leading dot), matched by
    # find_files_with_extensions via str.endswith.
    file_extensions = ['ckpt', 'pt', 'bin', 'safetensors']
    module_list.clear()
    # Text-encoder entries are merged after VAE entries, so a duplicate
    # filename in text_encoder/ overrides the VAE one.
    vae_files = find_files_with_extensions(vae_path, file_extensions)
    module_list.update(vae_files)
    text_encoder_files = find_files_with_extensions(text_encoder_path, file_extensions)
    module_list.update(text_encoder_files)
    return ckpt_list, module_list.keys()
def refresh_memory_management_settings(model_memory, async_loading, pin_shared_memory):
inference_memory = total_vram - model_memory
@@ -121,11 +175,10 @@ def refresh_model_loading_parameters():
from modules.sd_models import select_checkpoint, model_data
checkpoint_info = select_checkpoint()
vae_resolution = sd_vae_module.resolve_vae(checkpoint_info.filename)
model_data.forge_loading_parameters = dict(
checkpoint_info=checkpoint_info,
vae_filename=vae_resolution.vae,
additional_modules=shared.opts.forge_additional_modules,
unet_storage_dtype=forge_unet_storage_dtype_options.get(shared.opts.forge_unet_storage_dtype, None)
)
@@ -142,8 +195,15 @@ def checkpoint_change(ckpt_name):
return
def vae_change(vae_name):
shared.opts.set('sd_vae', vae_name)
def vae_change(module_names):
    """Persist the user's VAE / text-encoder module selection.

    module_names: display names chosen in the multiselect dropdown. Each is
    resolved to its full path through the module-level ``module_list``
    cache; names not present there are silently dropped.
    """
    modules = []
    for n in module_names:
        if n in module_list:
            modules.append(module_list[n])
    # Write the selection to options, save the config file, and re-derive
    # the forge loading parameters so the next load uses the new modules.
    shared.opts.set('forge_additional_modules', modules)
    shared.opts.save(shared.config_filename)
    refresh_model_loading_parameters()
    return
@@ -173,7 +233,6 @@ def forge_main_entry():
output_targets = [
ui_vae,
ui_vae_refresh_button,
ui_clip_skip,
ui_forge_unet_storage_dtype_options,
ui_forge_async_loading,
@@ -207,8 +266,7 @@ def on_preset_change(preset=None):
if shared.opts.forge_preset == 'sd':
return [
gr.update(visible=True, value='Automatic'), # ui_vae
gr.update(visible=True), # ui_vae_refresh_button
gr.update(visible=True), # ui_vae
gr.update(visible=True, value=1), # ui_clip_skip
gr.update(visible=False, value='Auto'), # ui_forge_unet_storage_dtype_options
gr.update(visible=False, value='Queue'), # ui_forge_async_loading
@@ -230,8 +288,7 @@ def on_preset_change(preset=None):
if shared.opts.forge_preset == 'xl':
return [
gr.update(visible=False, value='Automatic'), # ui_vae
gr.update(visible=False), # ui_vae_refresh_button
gr.update(visible=True), # ui_vae
gr.update(visible=False, value=1), # ui_clip_skip
gr.update(visible=True, value='Auto'), # ui_forge_unet_storage_dtype_options
gr.update(visible=False, value='Queue'), # ui_forge_async_loading
@@ -253,8 +310,7 @@ def on_preset_change(preset=None):
if shared.opts.forge_preset == 'flux':
return [
gr.update(visible=False, value='Automatic'), # ui_vae
gr.update(visible=False), # ui_vae_refresh_button
gr.update(visible=True), # ui_vae
gr.update(visible=False, value=1), # ui_clip_skip
gr.update(visible=True, value='Auto'), # ui_forge_unet_storage_dtype_options
gr.update(visible=True, value='Queue'), # ui_forge_async_loading
@@ -275,8 +331,7 @@ def on_preset_change(preset=None):
]
return [
gr.update(visible=True, value='Automatic'), # ui_vae
gr.update(visible=True), # ui_vae_refresh_button
gr.update(visible=True), # ui_vae
gr.update(visible=True, value=1), # ui_clip_skip
gr.update(visible=True, value='Auto'), # ui_forge_unet_storage_dtype_options
gr.update(visible=True, value='Queue'), # ui_forge_async_loading

View File

@@ -6,4 +6,5 @@ def register(options_templates, options_section, OptionInfo):
"forge_async_loading": OptionInfo('Queue'),
"forge_pin_shared_memory": OptionInfo('CPU'),
"forge_preset": OptionInfo('sd'),
"forge_additional_modules": OptionInfo([]),
}))

View File

@@ -440,7 +440,13 @@ div.toprow-compact-tools{
}
#quicksettings > div.model_selection{
min-width: 24em !important;
min-width: 20em !important;
}
#quicksettings .subdued{
display: block;
margin-left: auto;
width: 30px;
}
#settings{