mirror of
https://github.com/lllyasviel/stable-diffusion-webui-forge.git
synced 2026-05-01 03:31:30 +00:00
support all flux models
This commit is contained in:
@@ -10,7 +10,7 @@ from diffusers import DiffusionPipeline
|
|||||||
from transformers import modeling_utils
|
from transformers import modeling_utils
|
||||||
|
|
||||||
from backend import memory_management
|
from backend import memory_management
|
||||||
from backend.utils import read_arbitrary_config
|
from backend.utils import read_arbitrary_config, load_torch_file
|
||||||
from backend.state_dict import try_filter_state_dict, load_state_dict
|
from backend.state_dict import try_filter_state_dict, load_state_dict
|
||||||
from backend.operations import using_forge_operations
|
from backend.operations import using_forge_operations
|
||||||
from backend.nn.vae import IntegratedAutoencoderKL
|
from backend.nn.vae import IntegratedAutoencoderKL
|
||||||
@@ -46,6 +46,8 @@ def load_huggingface_component(guess, component_name, lib_name, cls_name, repo_p
|
|||||||
comp._eventual_warn_about_too_long_sequence = lambda *args, **kwargs: None
|
comp._eventual_warn_about_too_long_sequence = lambda *args, **kwargs: None
|
||||||
return comp
|
return comp
|
||||||
if cls_name in ['AutoencoderKL']:
|
if cls_name in ['AutoencoderKL']:
|
||||||
|
assert isinstance(state_dict, dict) and len(state_dict) > 16, 'You do not have VAE state dict!'
|
||||||
|
|
||||||
config = IntegratedAutoencoderKL.load_config(config_path)
|
config = IntegratedAutoencoderKL.load_config(config_path)
|
||||||
|
|
||||||
with using_forge_operations(device=memory_management.cpu, dtype=memory_management.vae_dtype()):
|
with using_forge_operations(device=memory_management.cpu, dtype=memory_management.vae_dtype()):
|
||||||
@@ -54,6 +56,8 @@ def load_huggingface_component(guess, component_name, lib_name, cls_name, repo_p
|
|||||||
load_state_dict(model, state_dict, ignore_start='loss.')
|
load_state_dict(model, state_dict, ignore_start='loss.')
|
||||||
return model
|
return model
|
||||||
if component_name.startswith('text_encoder') and cls_name in ['CLIPTextModel', 'CLIPTextModelWithProjection']:
|
if component_name.startswith('text_encoder') and cls_name in ['CLIPTextModel', 'CLIPTextModelWithProjection']:
|
||||||
|
assert isinstance(state_dict, dict) and len(state_dict) > 16, 'You do not have CLIP state dict!'
|
||||||
|
|
||||||
from transformers import CLIPTextConfig, CLIPTextModel
|
from transformers import CLIPTextConfig, CLIPTextModel
|
||||||
config = CLIPTextConfig.from_pretrained(config_path)
|
config = CLIPTextConfig.from_pretrained(config_path)
|
||||||
|
|
||||||
@@ -71,6 +75,8 @@ def load_huggingface_component(guess, component_name, lib_name, cls_name, repo_p
|
|||||||
|
|
||||||
return model
|
return model
|
||||||
if cls_name == 'T5EncoderModel':
|
if cls_name == 'T5EncoderModel':
|
||||||
|
assert isinstance(state_dict, dict) and len(state_dict) > 16, 'You do not have T5 state dict!'
|
||||||
|
|
||||||
from backend.nn.t5 import IntegratedT5
|
from backend.nn.t5 import IntegratedT5
|
||||||
config = read_arbitrary_config(config_path)
|
config = read_arbitrary_config(config_path)
|
||||||
|
|
||||||
@@ -78,17 +84,21 @@ def load_huggingface_component(guess, component_name, lib_name, cls_name, repo_p
|
|||||||
sd_dtype = memory_management.state_dict_dtype(state_dict)
|
sd_dtype = memory_management.state_dict_dtype(state_dict)
|
||||||
|
|
||||||
if sd_dtype in [torch.float8_e4m3fn, torch.float8_e5m2]:
|
if sd_dtype in [torch.float8_e4m3fn, torch.float8_e5m2]:
|
||||||
print(f'Using Detected T5 Data Type: {sd_dtype}')
|
|
||||||
dtype = sd_dtype
|
dtype = sd_dtype
|
||||||
|
print(f'Using Detected T5 Data Type: {dtype}')
|
||||||
|
else:
|
||||||
|
print(f'Using Default T5 Data Type: {dtype}')
|
||||||
|
|
||||||
with modeling_utils.no_init_weights():
|
with modeling_utils.no_init_weights():
|
||||||
with using_forge_operations(device=memory_management.cpu, dtype=dtype, manual_cast_enabled=True):
|
with using_forge_operations(device=memory_management.cpu, dtype=dtype, manual_cast_enabled=True):
|
||||||
model = IntegratedT5(config)
|
model = IntegratedT5(config)
|
||||||
|
|
||||||
load_state_dict(model, state_dict, log_name=cls_name, ignore_errors=['transformer.encoder.embed_tokens.weight'])
|
load_state_dict(model, state_dict, log_name=cls_name, ignore_errors=['transformer.encoder.embed_tokens.weight', 'logit_scale'])
|
||||||
|
|
||||||
return model
|
return model
|
||||||
if cls_name in ['UNet2DConditionModel', 'FluxTransformer2DModel']:
|
if cls_name in ['UNet2DConditionModel', 'FluxTransformer2DModel']:
|
||||||
|
assert isinstance(state_dict, dict) and len(state_dict) > 16, 'You do not have model state dict!'
|
||||||
|
|
||||||
model_loader = None
|
model_loader = None
|
||||||
if cls_name == 'UNet2DConditionModel':
|
if cls_name == 'UNet2DConditionModel':
|
||||||
model_loader = lambda c: IntegratedUNet2DConditionModel.from_config(c)
|
model_loader = lambda c: IntegratedUNet2DConditionModel.from_config(c)
|
||||||
@@ -148,16 +158,57 @@ def load_huggingface_component(guess, component_name, lib_name, cls_name, repo_p
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
def split_state_dict(sd, sd_vae=None):
|
def replace_state_dict(sd, asd, guess):
|
||||||
guess = huggingface_guess.guess(sd)
|
vae_key_prefix = guess.vae_key_prefix[0]
|
||||||
guess.clip_target = guess.clip_target(sd)
|
text_encoder_key_prefix = guess.text_encoder_key_prefix[0]
|
||||||
|
|
||||||
if sd_vae is not None:
|
if "decoder.conv_in.weight" in asd:
|
||||||
print(f'Using external VAE state dict: {len(sd_vae)}')
|
keys_to_delete = [k for k in sd if k.startswith(vae_key_prefix)]
|
||||||
|
for k in keys_to_delete:
|
||||||
|
del sd[k]
|
||||||
|
for k, v in asd.items():
|
||||||
|
sd[vae_key_prefix + k] = v
|
||||||
|
|
||||||
|
if 'text_model.encoder.layers.0.layer_norm1.weight' in asd:
|
||||||
|
keys_to_delete = [k for k in sd if k.startswith(f"{text_encoder_key_prefix}clip_l.")]
|
||||||
|
for k in keys_to_delete:
|
||||||
|
del sd[k]
|
||||||
|
for k, v in asd.items():
|
||||||
|
sd[f"{text_encoder_key_prefix}clip_l.transformer.{k}"] = v
|
||||||
|
|
||||||
|
if 'encoder.block.0.layer.0.SelfAttention.k.weight' in asd:
|
||||||
|
keys_to_delete = [k for k in sd if k.startswith(f"{text_encoder_key_prefix}t5xxl.")]
|
||||||
|
for k in keys_to_delete:
|
||||||
|
del sd[k]
|
||||||
|
for k, v in asd.items():
|
||||||
|
sd[f"{text_encoder_key_prefix}t5xxl.transformer.{k}"] = v
|
||||||
|
|
||||||
|
return sd
|
||||||
|
|
||||||
|
|
||||||
|
def preprocess_state_dict(sd):
|
||||||
|
if any("double_block" in k for k in sd.keys()):
|
||||||
|
if not any(k.startswith("model.diffusion_model") for k in sd.keys()):
|
||||||
|
sd = {f"model.diffusion_model.{k}": v for k, v in sd.items()}
|
||||||
|
|
||||||
|
return sd
|
||||||
|
|
||||||
|
|
||||||
|
def split_state_dict(sd, additional_state_dicts: list = None):
|
||||||
|
sd = load_torch_file(sd)
|
||||||
|
sd = preprocess_state_dict(sd)
|
||||||
|
guess = huggingface_guess.guess(sd)
|
||||||
|
|
||||||
|
if isinstance(additional_state_dicts, list):
|
||||||
|
for asd in additional_state_dicts:
|
||||||
|
asd = load_torch_file(asd)
|
||||||
|
sd = replace_state_dict(sd, asd, guess)
|
||||||
|
|
||||||
|
guess.clip_target = guess.clip_target(sd)
|
||||||
|
|
||||||
state_dict = {
|
state_dict = {
|
||||||
guess.unet_target: try_filter_state_dict(sd, guess.unet_key_prefix),
|
guess.unet_target: try_filter_state_dict(sd, guess.unet_key_prefix),
|
||||||
guess.vae_target: try_filter_state_dict(sd, guess.vae_key_prefix) if sd_vae is None else sd_vae
|
guess.vae_target: try_filter_state_dict(sd, guess.vae_key_prefix)
|
||||||
}
|
}
|
||||||
|
|
||||||
sd = guess.process_clip_state_dict(sd)
|
sd = guess.process_clip_state_dict(sd)
|
||||||
@@ -176,9 +227,9 @@ def split_state_dict(sd, sd_vae=None):
|
|||||||
|
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def forge_loader(sd, sd_vae=None):
|
def forge_loader(sd, additional_state_dicts=None):
|
||||||
try:
|
try:
|
||||||
state_dicts, estimated_config = split_state_dict(sd, sd_vae=sd_vae)
|
state_dicts, estimated_config = split_state_dict(sd, additional_state_dicts=additional_state_dicts)
|
||||||
except:
|
except:
|
||||||
raise ValueError('Failed to recognize model type!')
|
raise ValueError('Failed to recognize model type!')
|
||||||
|
|
||||||
|
|||||||
@@ -384,6 +384,7 @@ class LoadedModel:
|
|||||||
if not do_not_need_cpu_swap:
|
if not do_not_need_cpu_swap:
|
||||||
memory_in_swap = 0
|
memory_in_swap = 0
|
||||||
mem_counter = 0
|
mem_counter = 0
|
||||||
|
mem_cannot_cast = 0
|
||||||
for m in self.real_model.modules():
|
for m in self.real_model.modules():
|
||||||
if hasattr(m, "parameters_manual_cast"):
|
if hasattr(m, "parameters_manual_cast"):
|
||||||
m.prev_parameters_manual_cast = m.parameters_manual_cast
|
m.prev_parameters_manual_cast = m.parameters_manual_cast
|
||||||
@@ -399,8 +400,12 @@ class LoadedModel:
|
|||||||
m._apply(lambda x: x.pin_memory())
|
m._apply(lambda x: x.pin_memory())
|
||||||
elif hasattr(m, "weight"):
|
elif hasattr(m, "weight"):
|
||||||
m.to(self.device)
|
m.to(self.device)
|
||||||
mem_counter += module_size(m)
|
module_mem = module_size(m)
|
||||||
print(f"[Memory Management] Swap disabled for", type(m).__name__)
|
mem_counter += module_mem
|
||||||
|
mem_cannot_cast += module_mem
|
||||||
|
|
||||||
|
if mem_cannot_cast > 0:
|
||||||
|
print(f"[Memory Management] Loaded to GPU for backward capability: {mem_cannot_cast / (1024 * 1024):.2f} MB")
|
||||||
|
|
||||||
swap_flag = 'Shared' if PIN_SHARED_MEMORY else 'CPU'
|
swap_flag = 'Shared' if PIN_SHARED_MEMORY else 'CPU'
|
||||||
method_flag = 'asynchronous' if stream.should_use_stream() else 'blocked'
|
method_flag = 'asynchronous' if stream.should_use_stream() else 'blocked'
|
||||||
|
|||||||
@@ -3,6 +3,10 @@ from backend.patcher.base import ModelPatcher
|
|||||||
from backend.nn.base import ModuleDict, ObjectDict
|
from backend.nn.base import ModuleDict, ObjectDict
|
||||||
|
|
||||||
|
|
||||||
|
class JointTextEncoder(ModuleDict):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
class CLIP:
|
class CLIP:
|
||||||
def __init__(self, model_dict={}, tokenizer_dict={}, no_init=False):
|
def __init__(self, model_dict={}, tokenizer_dict={}, no_init=False):
|
||||||
if no_init:
|
if no_init:
|
||||||
@@ -11,7 +15,7 @@ class CLIP:
|
|||||||
load_device = memory_management.text_encoder_device()
|
load_device = memory_management.text_encoder_device()
|
||||||
offload_device = memory_management.text_encoder_offload_device()
|
offload_device = memory_management.text_encoder_offload_device()
|
||||||
|
|
||||||
self.cond_stage_model = ModuleDict(model_dict)
|
self.cond_stage_model = JointTextEncoder(model_dict)
|
||||||
self.tokenizer = ObjectDict(tokenizer_dict)
|
self.tokenizer = ObjectDict(tokenizer_dict)
|
||||||
self.patcher = ModelPatcher(self.cond_stage_model, load_device=load_device, offload_device=offload_device)
|
self.patcher = ModelPatcher(self.cond_stage_model, load_device=load_device, offload_device=offload_device)
|
||||||
|
|
||||||
|
|||||||
@@ -740,8 +740,8 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iter
|
|||||||
"Model": p.sd_model_name if opts.add_model_name_to_info else None,
|
"Model": p.sd_model_name if opts.add_model_name_to_info else None,
|
||||||
"FP8 weight": opts.fp8_storage if devices.fp8 else None,
|
"FP8 weight": opts.fp8_storage if devices.fp8 else None,
|
||||||
"Cache FP16 weight for LoRA": opts.cache_fp16_weight if devices.fp8 else None,
|
"Cache FP16 weight for LoRA": opts.cache_fp16_weight if devices.fp8 else None,
|
||||||
"VAE hash": p.sd_vae_hash if opts.add_vae_hash_to_info else None,
|
# "VAE hash": p.sd_vae_hash if opts.add_vae_hash_to_info else None,
|
||||||
"VAE": p.sd_vae_name if opts.add_vae_name_to_info else None,
|
# "VAE": p.sd_vae_name if opts.add_vae_name_to_info else None,
|
||||||
"Variation seed": (None if p.subseed_strength == 0 else (p.all_subseeds[0] if use_main_prompt else all_subseeds[index])),
|
"Variation seed": (None if p.subseed_strength == 0 else (p.all_subseeds[0] if use_main_prompt else all_subseeds[index])),
|
||||||
"Variation seed strength": (None if p.subseed_strength == 0 else p.subseed_strength),
|
"Variation seed strength": (None if p.subseed_strength == 0 else p.subseed_strength),
|
||||||
"Seed resize from": (None if p.seed_resize_from_w <= 0 or p.seed_resize_from_h <= 0 else f"{p.seed_resize_from_w}x{p.seed_resize_from_h}"),
|
"Seed resize from": (None if p.seed_resize_from_w <= 0 or p.seed_resize_from_h <= 0 else f"{p.seed_resize_from_w}x{p.seed_resize_from_h}"),
|
||||||
@@ -759,6 +759,10 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iter
|
|||||||
"User": p.user if opts.add_user_name_to_info else None,
|
"User": p.user if opts.add_user_name_to_info else None,
|
||||||
})
|
})
|
||||||
|
|
||||||
|
if isinstance(shared.opts.forge_additional_modules, list) and len(shared.opts.forge_additional_modules) > 0:
|
||||||
|
for i, m in enumerate(shared.opts.forge_additional_modules):
|
||||||
|
generation_params[f'Module {i+1}'] = os.path.splitext(os.path.basename(m))[0]
|
||||||
|
|
||||||
for key, value in generation_params.items():
|
for key, value in generation_params.items():
|
||||||
try:
|
try:
|
||||||
if isinstance(value, list):
|
if isinstance(value, list):
|
||||||
@@ -787,6 +791,12 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
|
|||||||
if need_global_unload and not just_reloaded:
|
if need_global_unload and not just_reloaded:
|
||||||
memory_management.unload_all_models()
|
memory_management.unload_all_models()
|
||||||
|
|
||||||
|
if need_global_unload:
|
||||||
|
StableDiffusionProcessing.cached_c = [None, None, None]
|
||||||
|
StableDiffusionProcessing.cached_uc = [None, None, None]
|
||||||
|
p.cached_c = [None, None, None]
|
||||||
|
p.cached_uc = [None, None, None]
|
||||||
|
|
||||||
need_global_unload = False
|
need_global_unload = False
|
||||||
|
|
||||||
if p.scripts is not None:
|
if p.scripts is not None:
|
||||||
|
|||||||
@@ -490,27 +490,15 @@ def forge_model_reload():
|
|||||||
timer.record("unload existing model")
|
timer.record("unload existing model")
|
||||||
|
|
||||||
checkpoint_info = model_data.forge_loading_parameters['checkpoint_info']
|
checkpoint_info = model_data.forge_loading_parameters['checkpoint_info']
|
||||||
state_dict = load_torch_file(checkpoint_info.filename)
|
state_dict = checkpoint_info.filename
|
||||||
timer.record("load state dict")
|
additional_state_dicts = model_data.forge_loading_parameters.get('additional_modules', [])
|
||||||
|
|
||||||
state_dict_vae = model_data.forge_loading_parameters.get('vae_filename', None)
|
|
||||||
|
|
||||||
if state_dict_vae is not None:
|
|
||||||
state_dict_vae = load_torch_file(state_dict_vae)
|
|
||||||
|
|
||||||
timer.record("load vae state dict")
|
|
||||||
|
|
||||||
if shared.opts.sd_checkpoint_cache > 0:
|
|
||||||
# cache newly loaded model
|
|
||||||
checkpoints_loaded[checkpoint_info] = state_dict.copy()
|
|
||||||
|
|
||||||
timer.record("cache state dict")
|
timer.record("cache state dict")
|
||||||
|
|
||||||
dynamic_args['forge_unet_storage_dtype'] = model_data.forge_loading_parameters.get('unet_storage_dtype', None)
|
dynamic_args['forge_unet_storage_dtype'] = model_data.forge_loading_parameters.get('unet_storage_dtype', None)
|
||||||
dynamic_args['embedding_dir'] = cmd_opts.embeddings_dir
|
dynamic_args['embedding_dir'] = cmd_opts.embeddings_dir
|
||||||
dynamic_args['emphasis_name'] = opts.emphasis
|
dynamic_args['emphasis_name'] = opts.emphasis
|
||||||
sd_model = forge_loader(state_dict, sd_vae=state_dict_vae)
|
sd_model = forge_loader(state_dict, additional_state_dicts=additional_state_dicts)
|
||||||
del state_dict
|
|
||||||
timer.record("forge model load")
|
timer.record("forge model load")
|
||||||
|
|
||||||
sd_model.extra_generation_params = {}
|
sd_model.extra_generation_params = {}
|
||||||
@@ -520,10 +508,6 @@ def forge_model_reload():
|
|||||||
sd_model.sd_model_hash = checkpoint_info.calculate_shorthash()
|
sd_model.sd_model_hash = checkpoint_info.calculate_shorthash()
|
||||||
timer.record("calculate hash")
|
timer.record("calculate hash")
|
||||||
|
|
||||||
# clean up cache if limit is reached
|
|
||||||
while len(checkpoints_loaded) > shared.opts.sd_checkpoint_cache:
|
|
||||||
checkpoints_loaded.popitem(last=False)
|
|
||||||
|
|
||||||
shared.opts.data["sd_checkpoint_hash"] = checkpoint_info.sha256
|
shared.opts.data["sd_checkpoint_hash"] = checkpoint_info.sha256
|
||||||
|
|
||||||
model_data.set_sd_model(sd_model)
|
model_data.set_sd_model(sd_model)
|
||||||
|
|||||||
@@ -1,9 +1,9 @@
|
|||||||
|
import os
|
||||||
import torch
|
import torch
|
||||||
import gradio as gr
|
import gradio as gr
|
||||||
|
|
||||||
from gradio.context import Context
|
from gradio.context import Context
|
||||||
from modules import shared_items, shared, ui_common, sd_models, processing, infotext_utils
|
from modules import shared_items, shared, ui_common, sd_models, processing, infotext_utils, paths
|
||||||
from modules import sd_vae as sd_vae_module
|
|
||||||
from backend import memory_management, stream
|
from backend import memory_management, stream
|
||||||
|
|
||||||
|
|
||||||
@@ -13,7 +13,6 @@ ui_forge_preset: gr.Radio = None
|
|||||||
|
|
||||||
ui_checkpoint: gr.Dropdown = None
|
ui_checkpoint: gr.Dropdown = None
|
||||||
ui_vae: gr.Dropdown = None
|
ui_vae: gr.Dropdown = None
|
||||||
ui_vae_refresh_button: gr.Button = None
|
|
||||||
ui_clip_skip: gr.Slider = None
|
ui_clip_skip: gr.Slider = None
|
||||||
|
|
||||||
ui_forge_unet_storage_dtype_options: gr.Radio = None
|
ui_forge_unet_storage_dtype_options: gr.Radio = None
|
||||||
@@ -29,6 +28,8 @@ forge_unet_storage_dtype_options = {
|
|||||||
'fp8e5': torch.float8_e5m2,
|
'fp8e5': torch.float8_e5m2,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
module_list = {}
|
||||||
|
|
||||||
|
|
||||||
def bind_to_opts(comp, k, save=False, callback=None):
|
def bind_to_opts(comp, k, save=False, callback=None):
|
||||||
def on_change(v):
|
def on_change(v):
|
||||||
@@ -44,7 +45,7 @@ def bind_to_opts(comp, k, save=False, callback=None):
|
|||||||
|
|
||||||
|
|
||||||
def make_checkpoint_manager_ui():
|
def make_checkpoint_manager_ui():
|
||||||
global ui_checkpoint, ui_vae, ui_clip_skip, ui_forge_unet_storage_dtype_options, ui_forge_async_loading, ui_forge_pin_shared_memory, ui_forge_inference_memory, ui_forge_preset, ui_vae_refresh_button
|
global ui_checkpoint, ui_vae, ui_clip_skip, ui_forge_unet_storage_dtype_options, ui_forge_async_loading, ui_forge_pin_shared_memory, ui_forge_inference_memory, ui_forge_preset
|
||||||
|
|
||||||
if shared.opts.sd_model_checkpoint in [None, 'None', 'none', '']:
|
if shared.opts.sd_model_checkpoint in [None, 'None', 'none', '']:
|
||||||
if len(sd_models.checkpoints_list) == 0:
|
if len(sd_models.checkpoints_list) == 0:
|
||||||
@@ -54,22 +55,44 @@ def make_checkpoint_manager_ui():
|
|||||||
|
|
||||||
ui_forge_preset = gr.Radio(label="UI", value=lambda: shared.opts.forge_preset, choices=['sd', 'xl', 'flux', 'all'])
|
ui_forge_preset = gr.Radio(label="UI", value=lambda: shared.opts.forge_preset, choices=['sd', 'xl', 'flux', 'all'])
|
||||||
|
|
||||||
sd_model_checkpoint_args = lambda: {"choices": shared_items.list_checkpoint_tiles(shared.opts.sd_checkpoint_dropdown_use_short)}
|
ckpt_list, vae_list = refresh_models()
|
||||||
|
|
||||||
ui_checkpoint = gr.Dropdown(
|
ui_checkpoint = gr.Dropdown(
|
||||||
value=lambda: shared.opts.sd_model_checkpoint,
|
value=lambda: shared.opts.sd_model_checkpoint,
|
||||||
label="Checkpoint",
|
label="Checkpoint",
|
||||||
elem_classes=['model_selection'],
|
elem_classes=['model_selection'],
|
||||||
**sd_model_checkpoint_args()
|
choices=ckpt_list
|
||||||
)
|
)
|
||||||
ui_common.create_refresh_button(ui_checkpoint, shared_items.refresh_checkpoints, sd_model_checkpoint_args, f"forge_refresh_checkpoint")
|
|
||||||
|
|
||||||
sd_vae_args = lambda: {"choices": shared_items.sd_vae_items()}
|
|
||||||
ui_vae = gr.Dropdown(
|
ui_vae = gr.Dropdown(
|
||||||
value=lambda: shared.opts.sd_vae,
|
value=lambda: [os.path.basename(x) for x in shared.opts.forge_additional_modules],
|
||||||
label="VAE",
|
multiselect=True,
|
||||||
**sd_vae_args()
|
label="VAE / Text Encoder",
|
||||||
|
render=False,
|
||||||
|
choices=vae_list
|
||||||
)
|
)
|
||||||
ui_vae_refresh_button = ui_common.create_refresh_button(ui_vae, shared_items.refresh_vae_list, sd_vae_args, f"forge_refresh_vae")
|
|
||||||
|
def gr_refresh_models():
|
||||||
|
a, b = refresh_models()
|
||||||
|
return gr.update(choices=a), gr.update(choices=b)
|
||||||
|
|
||||||
|
refresh_button = ui_common.ToolButton(value=ui_common.refresh_symbol, elem_id=f"forge_refresh_checkpoint", tooltip="Refresh")
|
||||||
|
refresh_button.click(
|
||||||
|
fn=gr_refresh_models,
|
||||||
|
inputs=[],
|
||||||
|
outputs=[ui_checkpoint, ui_vae],
|
||||||
|
show_progress=False,
|
||||||
|
queue=False
|
||||||
|
)
|
||||||
|
Context.root_block.load(
|
||||||
|
fn=gr_refresh_models,
|
||||||
|
inputs=[],
|
||||||
|
outputs=[ui_checkpoint, ui_vae],
|
||||||
|
show_progress=False,
|
||||||
|
queue=False
|
||||||
|
)
|
||||||
|
|
||||||
|
ui_vae.render()
|
||||||
|
|
||||||
ui_forge_unet_storage_dtype_options = gr.Radio(label="Diffusion with Low Bits", value=lambda: shared.opts.forge_unet_storage_dtype, choices=list(forge_unet_storage_dtype_options.keys()))
|
ui_forge_unet_storage_dtype_options = gr.Radio(label="Diffusion with Low Bits", value=lambda: shared.opts.forge_unet_storage_dtype, choices=list(forge_unet_storage_dtype_options.keys()))
|
||||||
bind_to_opts(ui_forge_unet_storage_dtype_options, 'forge_unet_storage_dtype', save=True, callback=refresh_model_loading_parameters)
|
bind_to_opts(ui_forge_unet_storage_dtype_options, 'forge_unet_storage_dtype', save=True, callback=refresh_model_loading_parameters)
|
||||||
@@ -94,6 +117,37 @@ def make_checkpoint_manager_ui():
|
|||||||
return
|
return
|
||||||
|
|
||||||
|
|
||||||
|
def find_files_with_extensions(base_path, extensions):
|
||||||
|
found_files = {}
|
||||||
|
for root, _, files in os.walk(base_path):
|
||||||
|
for file in files:
|
||||||
|
if any(file.endswith(ext) for ext in extensions):
|
||||||
|
full_path = os.path.join(root, file)
|
||||||
|
found_files[file] = full_path
|
||||||
|
return found_files
|
||||||
|
|
||||||
|
|
||||||
|
def refresh_models():
|
||||||
|
global module_list
|
||||||
|
|
||||||
|
shared_items.refresh_checkpoints()
|
||||||
|
ckpt_list = shared_items.list_checkpoint_tiles(shared.opts.sd_checkpoint_dropdown_use_short)
|
||||||
|
|
||||||
|
vae_path = os.path.abspath(os.path.join(paths.models_path, "VAE"))
|
||||||
|
text_encoder_path = os.path.abspath(os.path.join(paths.models_path, "text_encoder"))
|
||||||
|
file_extensions = ['ckpt', 'pt', 'bin', 'safetensors']
|
||||||
|
|
||||||
|
module_list.clear()
|
||||||
|
|
||||||
|
vae_files = find_files_with_extensions(vae_path, file_extensions)
|
||||||
|
module_list.update(vae_files)
|
||||||
|
|
||||||
|
text_encoder_files = find_files_with_extensions(text_encoder_path, file_extensions)
|
||||||
|
module_list.update(text_encoder_files)
|
||||||
|
|
||||||
|
return ckpt_list, module_list.keys()
|
||||||
|
|
||||||
|
|
||||||
def refresh_memory_management_settings(model_memory, async_loading, pin_shared_memory):
|
def refresh_memory_management_settings(model_memory, async_loading, pin_shared_memory):
|
||||||
inference_memory = total_vram - model_memory
|
inference_memory = total_vram - model_memory
|
||||||
|
|
||||||
@@ -121,11 +175,10 @@ def refresh_model_loading_parameters():
|
|||||||
from modules.sd_models import select_checkpoint, model_data
|
from modules.sd_models import select_checkpoint, model_data
|
||||||
|
|
||||||
checkpoint_info = select_checkpoint()
|
checkpoint_info = select_checkpoint()
|
||||||
vae_resolution = sd_vae_module.resolve_vae(checkpoint_info.filename)
|
|
||||||
|
|
||||||
model_data.forge_loading_parameters = dict(
|
model_data.forge_loading_parameters = dict(
|
||||||
checkpoint_info=checkpoint_info,
|
checkpoint_info=checkpoint_info,
|
||||||
vae_filename=vae_resolution.vae,
|
additional_modules=shared.opts.forge_additional_modules,
|
||||||
unet_storage_dtype=forge_unet_storage_dtype_options.get(shared.opts.forge_unet_storage_dtype, None)
|
unet_storage_dtype=forge_unet_storage_dtype_options.get(shared.opts.forge_unet_storage_dtype, None)
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -142,8 +195,15 @@ def checkpoint_change(ckpt_name):
|
|||||||
return
|
return
|
||||||
|
|
||||||
|
|
||||||
def vae_change(vae_name):
|
def vae_change(module_names):
|
||||||
shared.opts.set('sd_vae', vae_name)
|
modules = []
|
||||||
|
|
||||||
|
for n in module_names:
|
||||||
|
if n in module_list:
|
||||||
|
modules.append(module_list[n])
|
||||||
|
|
||||||
|
shared.opts.set('forge_additional_modules', modules)
|
||||||
|
shared.opts.save(shared.config_filename)
|
||||||
|
|
||||||
refresh_model_loading_parameters()
|
refresh_model_loading_parameters()
|
||||||
return
|
return
|
||||||
@@ -173,7 +233,6 @@ def forge_main_entry():
|
|||||||
|
|
||||||
output_targets = [
|
output_targets = [
|
||||||
ui_vae,
|
ui_vae,
|
||||||
ui_vae_refresh_button,
|
|
||||||
ui_clip_skip,
|
ui_clip_skip,
|
||||||
ui_forge_unet_storage_dtype_options,
|
ui_forge_unet_storage_dtype_options,
|
||||||
ui_forge_async_loading,
|
ui_forge_async_loading,
|
||||||
@@ -207,8 +266,7 @@ def on_preset_change(preset=None):
|
|||||||
|
|
||||||
if shared.opts.forge_preset == 'sd':
|
if shared.opts.forge_preset == 'sd':
|
||||||
return [
|
return [
|
||||||
gr.update(visible=True, value='Automatic'), # ui_vae
|
gr.update(visible=True), # ui_vae
|
||||||
gr.update(visible=True), # ui_vae_refresh_button
|
|
||||||
gr.update(visible=True, value=1), # ui_clip_skip
|
gr.update(visible=True, value=1), # ui_clip_skip
|
||||||
gr.update(visible=False, value='Auto'), # ui_forge_unet_storage_dtype_options
|
gr.update(visible=False, value='Auto'), # ui_forge_unet_storage_dtype_options
|
||||||
gr.update(visible=False, value='Queue'), # ui_forge_async_loading
|
gr.update(visible=False, value='Queue'), # ui_forge_async_loading
|
||||||
@@ -230,8 +288,7 @@ def on_preset_change(preset=None):
|
|||||||
|
|
||||||
if shared.opts.forge_preset == 'xl':
|
if shared.opts.forge_preset == 'xl':
|
||||||
return [
|
return [
|
||||||
gr.update(visible=False, value='Automatic'), # ui_vae
|
gr.update(visible=True), # ui_vae
|
||||||
gr.update(visible=False), # ui_vae_refresh_button
|
|
||||||
gr.update(visible=False, value=1), # ui_clip_skip
|
gr.update(visible=False, value=1), # ui_clip_skip
|
||||||
gr.update(visible=True, value='Auto'), # ui_forge_unet_storage_dtype_options
|
gr.update(visible=True, value='Auto'), # ui_forge_unet_storage_dtype_options
|
||||||
gr.update(visible=False, value='Queue'), # ui_forge_async_loading
|
gr.update(visible=False, value='Queue'), # ui_forge_async_loading
|
||||||
@@ -253,8 +310,7 @@ def on_preset_change(preset=None):
|
|||||||
|
|
||||||
if shared.opts.forge_preset == 'flux':
|
if shared.opts.forge_preset == 'flux':
|
||||||
return [
|
return [
|
||||||
gr.update(visible=False, value='Automatic'), # ui_vae
|
gr.update(visible=True), # ui_vae
|
||||||
gr.update(visible=False), # ui_vae_refresh_button
|
|
||||||
gr.update(visible=False, value=1), # ui_clip_skip
|
gr.update(visible=False, value=1), # ui_clip_skip
|
||||||
gr.update(visible=True, value='Auto'), # ui_forge_unet_storage_dtype_options
|
gr.update(visible=True, value='Auto'), # ui_forge_unet_storage_dtype_options
|
||||||
gr.update(visible=True, value='Queue'), # ui_forge_async_loading
|
gr.update(visible=True, value='Queue'), # ui_forge_async_loading
|
||||||
@@ -275,8 +331,7 @@ def on_preset_change(preset=None):
|
|||||||
]
|
]
|
||||||
|
|
||||||
return [
|
return [
|
||||||
gr.update(visible=True, value='Automatic'), # ui_vae
|
gr.update(visible=True), # ui_vae
|
||||||
gr.update(visible=True), # ui_vae_refresh_button
|
|
||||||
gr.update(visible=True, value=1), # ui_clip_skip
|
gr.update(visible=True, value=1), # ui_clip_skip
|
||||||
gr.update(visible=True, value='Auto'), # ui_forge_unet_storage_dtype_options
|
gr.update(visible=True, value='Auto'), # ui_forge_unet_storage_dtype_options
|
||||||
gr.update(visible=True, value='Queue'), # ui_forge_async_loading
|
gr.update(visible=True, value='Queue'), # ui_forge_async_loading
|
||||||
|
|||||||
@@ -6,4 +6,5 @@ def register(options_templates, options_section, OptionInfo):
|
|||||||
"forge_async_loading": OptionInfo('Queue'),
|
"forge_async_loading": OptionInfo('Queue'),
|
||||||
"forge_pin_shared_memory": OptionInfo('CPU'),
|
"forge_pin_shared_memory": OptionInfo('CPU'),
|
||||||
"forge_preset": OptionInfo('sd'),
|
"forge_preset": OptionInfo('sd'),
|
||||||
|
"forge_additional_modules": OptionInfo([]),
|
||||||
}))
|
}))
|
||||||
|
|||||||
@@ -440,7 +440,13 @@ div.toprow-compact-tools{
|
|||||||
}
|
}
|
||||||
|
|
||||||
#quicksettings > div.model_selection{
|
#quicksettings > div.model_selection{
|
||||||
min-width: 24em !important;
|
min-width: 20em !important;
|
||||||
|
}
|
||||||
|
|
||||||
|
#quicksettings .subdued{
|
||||||
|
display: block;
|
||||||
|
margin-left: auto;
|
||||||
|
width: 30px;
|
||||||
}
|
}
|
||||||
|
|
||||||
#settings{
|
#settings{
|
||||||
|
|||||||
Reference in New Issue
Block a user