mirror of
https://github.com/lllyasviel/stable-diffusion-webui-forge.git
synced 2026-04-27 17:51:22 +00:00
support all flux models
This commit is contained in:
@@ -1,9 +1,9 @@
|
||||
import os
|
||||
import torch
|
||||
import gradio as gr
|
||||
|
||||
from gradio.context import Context
|
||||
from modules import shared_items, shared, ui_common, sd_models, processing, infotext_utils
|
||||
from modules import sd_vae as sd_vae_module
|
||||
from modules import shared_items, shared, ui_common, sd_models, processing, infotext_utils, paths
|
||||
from backend import memory_management, stream
|
||||
|
||||
|
||||
@@ -13,7 +13,6 @@ ui_forge_preset: gr.Radio = None
|
||||
|
||||
ui_checkpoint: gr.Dropdown = None
|
||||
ui_vae: gr.Dropdown = None
|
||||
ui_vae_refresh_button: gr.Button = None
|
||||
ui_clip_skip: gr.Slider = None
|
||||
|
||||
ui_forge_unet_storage_dtype_options: gr.Radio = None
|
||||
@@ -29,6 +28,8 @@ forge_unet_storage_dtype_options = {
|
||||
'fp8e5': torch.float8_e5m2,
|
||||
}
|
||||
|
||||
module_list = {}
|
||||
|
||||
|
||||
def bind_to_opts(comp, k, save=False, callback=None):
|
||||
def on_change(v):
|
||||
@@ -44,7 +45,7 @@ def bind_to_opts(comp, k, save=False, callback=None):
|
||||
|
||||
|
||||
def make_checkpoint_manager_ui():
|
||||
global ui_checkpoint, ui_vae, ui_clip_skip, ui_forge_unet_storage_dtype_options, ui_forge_async_loading, ui_forge_pin_shared_memory, ui_forge_inference_memory, ui_forge_preset, ui_vae_refresh_button
|
||||
global ui_checkpoint, ui_vae, ui_clip_skip, ui_forge_unet_storage_dtype_options, ui_forge_async_loading, ui_forge_pin_shared_memory, ui_forge_inference_memory, ui_forge_preset
|
||||
|
||||
if shared.opts.sd_model_checkpoint in [None, 'None', 'none', '']:
|
||||
if len(sd_models.checkpoints_list) == 0:
|
||||
@@ -54,22 +55,44 @@ def make_checkpoint_manager_ui():
|
||||
|
||||
ui_forge_preset = gr.Radio(label="UI", value=lambda: shared.opts.forge_preset, choices=['sd', 'xl', 'flux', 'all'])
|
||||
|
||||
sd_model_checkpoint_args = lambda: {"choices": shared_items.list_checkpoint_tiles(shared.opts.sd_checkpoint_dropdown_use_short)}
|
||||
ckpt_list, vae_list = refresh_models()
|
||||
|
||||
ui_checkpoint = gr.Dropdown(
|
||||
value=lambda: shared.opts.sd_model_checkpoint,
|
||||
label="Checkpoint",
|
||||
elem_classes=['model_selection'],
|
||||
**sd_model_checkpoint_args()
|
||||
choices=ckpt_list
|
||||
)
|
||||
ui_common.create_refresh_button(ui_checkpoint, shared_items.refresh_checkpoints, sd_model_checkpoint_args, f"forge_refresh_checkpoint")
|
||||
|
||||
sd_vae_args = lambda: {"choices": shared_items.sd_vae_items()}
|
||||
ui_vae = gr.Dropdown(
|
||||
value=lambda: shared.opts.sd_vae,
|
||||
label="VAE",
|
||||
**sd_vae_args()
|
||||
value=lambda: [os.path.basename(x) for x in shared.opts.forge_additional_modules],
|
||||
multiselect=True,
|
||||
label="VAE / Text Encoder",
|
||||
render=False,
|
||||
choices=vae_list
|
||||
)
|
||||
ui_vae_refresh_button = ui_common.create_refresh_button(ui_vae, shared_items.refresh_vae_list, sd_vae_args, f"forge_refresh_vae")
|
||||
|
||||
def gr_refresh_models():
|
||||
a, b = refresh_models()
|
||||
return gr.update(choices=a), gr.update(choices=b)
|
||||
|
||||
refresh_button = ui_common.ToolButton(value=ui_common.refresh_symbol, elem_id=f"forge_refresh_checkpoint", tooltip="Refresh")
|
||||
refresh_button.click(
|
||||
fn=gr_refresh_models,
|
||||
inputs=[],
|
||||
outputs=[ui_checkpoint, ui_vae],
|
||||
show_progress=False,
|
||||
queue=False
|
||||
)
|
||||
Context.root_block.load(
|
||||
fn=gr_refresh_models,
|
||||
inputs=[],
|
||||
outputs=[ui_checkpoint, ui_vae],
|
||||
show_progress=False,
|
||||
queue=False
|
||||
)
|
||||
|
||||
ui_vae.render()
|
||||
|
||||
ui_forge_unet_storage_dtype_options = gr.Radio(label="Diffusion with Low Bits", value=lambda: shared.opts.forge_unet_storage_dtype, choices=list(forge_unet_storage_dtype_options.keys()))
|
||||
bind_to_opts(ui_forge_unet_storage_dtype_options, 'forge_unet_storage_dtype', save=True, callback=refresh_model_loading_parameters)
|
||||
@@ -94,6 +117,37 @@ def make_checkpoint_manager_ui():
|
||||
return
|
||||
|
||||
|
||||
def find_files_with_extensions(base_path, extensions):
|
||||
found_files = {}
|
||||
for root, _, files in os.walk(base_path):
|
||||
for file in files:
|
||||
if any(file.endswith(ext) for ext in extensions):
|
||||
full_path = os.path.join(root, file)
|
||||
found_files[file] = full_path
|
||||
return found_files
|
||||
|
||||
|
||||
def refresh_models():
|
||||
global module_list
|
||||
|
||||
shared_items.refresh_checkpoints()
|
||||
ckpt_list = shared_items.list_checkpoint_tiles(shared.opts.sd_checkpoint_dropdown_use_short)
|
||||
|
||||
vae_path = os.path.abspath(os.path.join(paths.models_path, "VAE"))
|
||||
text_encoder_path = os.path.abspath(os.path.join(paths.models_path, "text_encoder"))
|
||||
file_extensions = ['ckpt', 'pt', 'bin', 'safetensors']
|
||||
|
||||
module_list.clear()
|
||||
|
||||
vae_files = find_files_with_extensions(vae_path, file_extensions)
|
||||
module_list.update(vae_files)
|
||||
|
||||
text_encoder_files = find_files_with_extensions(text_encoder_path, file_extensions)
|
||||
module_list.update(text_encoder_files)
|
||||
|
||||
return ckpt_list, module_list.keys()
|
||||
|
||||
|
||||
def refresh_memory_management_settings(model_memory, async_loading, pin_shared_memory):
|
||||
inference_memory = total_vram - model_memory
|
||||
|
||||
@@ -121,11 +175,10 @@ def refresh_model_loading_parameters():
|
||||
from modules.sd_models import select_checkpoint, model_data
|
||||
|
||||
checkpoint_info = select_checkpoint()
|
||||
vae_resolution = sd_vae_module.resolve_vae(checkpoint_info.filename)
|
||||
|
||||
model_data.forge_loading_parameters = dict(
|
||||
checkpoint_info=checkpoint_info,
|
||||
vae_filename=vae_resolution.vae,
|
||||
additional_modules=shared.opts.forge_additional_modules,
|
||||
unet_storage_dtype=forge_unet_storage_dtype_options.get(shared.opts.forge_unet_storage_dtype, None)
|
||||
)
|
||||
|
||||
@@ -142,8 +195,15 @@ def checkpoint_change(ckpt_name):
|
||||
return
|
||||
|
||||
|
||||
def vae_change(vae_name):
|
||||
shared.opts.set('sd_vae', vae_name)
|
||||
def vae_change(module_names):
|
||||
modules = []
|
||||
|
||||
for n in module_names:
|
||||
if n in module_list:
|
||||
modules.append(module_list[n])
|
||||
|
||||
shared.opts.set('forge_additional_modules', modules)
|
||||
shared.opts.save(shared.config_filename)
|
||||
|
||||
refresh_model_loading_parameters()
|
||||
return
|
||||
@@ -173,7 +233,6 @@ def forge_main_entry():
|
||||
|
||||
output_targets = [
|
||||
ui_vae,
|
||||
ui_vae_refresh_button,
|
||||
ui_clip_skip,
|
||||
ui_forge_unet_storage_dtype_options,
|
||||
ui_forge_async_loading,
|
||||
@@ -207,8 +266,7 @@ def on_preset_change(preset=None):
|
||||
|
||||
if shared.opts.forge_preset == 'sd':
|
||||
return [
|
||||
gr.update(visible=True, value='Automatic'), # ui_vae
|
||||
gr.update(visible=True), # ui_vae_refresh_button
|
||||
gr.update(visible=True), # ui_vae
|
||||
gr.update(visible=True, value=1), # ui_clip_skip
|
||||
gr.update(visible=False, value='Auto'), # ui_forge_unet_storage_dtype_options
|
||||
gr.update(visible=False, value='Queue'), # ui_forge_async_loading
|
||||
@@ -230,8 +288,7 @@ def on_preset_change(preset=None):
|
||||
|
||||
if shared.opts.forge_preset == 'xl':
|
||||
return [
|
||||
gr.update(visible=False, value='Automatic'), # ui_vae
|
||||
gr.update(visible=False), # ui_vae_refresh_button
|
||||
gr.update(visible=True), # ui_vae
|
||||
gr.update(visible=False, value=1), # ui_clip_skip
|
||||
gr.update(visible=True, value='Auto'), # ui_forge_unet_storage_dtype_options
|
||||
gr.update(visible=False, value='Queue'), # ui_forge_async_loading
|
||||
@@ -253,8 +310,7 @@ def on_preset_change(preset=None):
|
||||
|
||||
if shared.opts.forge_preset == 'flux':
|
||||
return [
|
||||
gr.update(visible=False, value='Automatic'), # ui_vae
|
||||
gr.update(visible=False), # ui_vae_refresh_button
|
||||
gr.update(visible=True), # ui_vae
|
||||
gr.update(visible=False, value=1), # ui_clip_skip
|
||||
gr.update(visible=True, value='Auto'), # ui_forge_unet_storage_dtype_options
|
||||
gr.update(visible=True, value='Queue'), # ui_forge_async_loading
|
||||
@@ -275,8 +331,7 @@ def on_preset_change(preset=None):
|
||||
]
|
||||
|
||||
return [
|
||||
gr.update(visible=True, value='Automatic'), # ui_vae
|
||||
gr.update(visible=True), # ui_vae_refresh_button
|
||||
gr.update(visible=True), # ui_vae
|
||||
gr.update(visible=True, value=1), # ui_clip_skip
|
||||
gr.update(visible=True, value='Auto'), # ui_forge_unet_storage_dtype_options
|
||||
gr.update(visible=True, value='Queue'), # ui_forge_async_loading
|
||||
|
||||
Reference in New Issue
Block a user