mirror of
https://github.com/lllyasviel/stable-diffusion-webui-forge.git
synced 2026-03-13 08:59:51 +00:00
@@ -3,7 +3,6 @@ import gradio as gr
|
||||
|
||||
from modules import shared_items, shared, ui_common, sd_models
|
||||
from modules import sd_vae as sd_vae_module
|
||||
from modules_forge import main_thread
|
||||
from backend import args as backend_args
|
||||
|
||||
|
||||
@@ -59,7 +58,7 @@ def make_checkpoint_manager_ui():
|
||||
ui_common.create_refresh_button(ui_vae, shared_items.refresh_vae_list, sd_vae_args, f"forge_refresh_vae")
|
||||
|
||||
ui_forge_unet_storage_dtype_options = gr.Radio(label="Diffusion in FP8", value=shared.opts.forge_unet_storage_dtype, choices=list(forge_unet_storage_dtype_options.keys()))
|
||||
bind_to_opts(ui_forge_unet_storage_dtype_options, 'forge_unet_storage_dtype', save=True, callback=lambda: main_thread.async_run(model_load_entry))
|
||||
bind_to_opts(ui_forge_unet_storage_dtype_options, 'forge_unet_storage_dtype', save=True, callback=refresh_model_loading_parameters)
|
||||
|
||||
ui_clip_skip = gr.Slider(label="Clip skip", value=shared.opts.CLIP_stop_at_last_layers, **{"minimum": 1, "maximum": 12, "step": 1})
|
||||
bind_to_opts(ui_clip_skip, 'CLIP_stop_at_last_layers', save=False)
|
||||
@@ -67,12 +66,18 @@ def make_checkpoint_manager_ui():
|
||||
return
|
||||
|
||||
|
||||
def model_load_entry():
|
||||
backend_args.dynamic_args.update(dict(
|
||||
forge_unet_storage_dtype=forge_unet_storage_dtype_options[shared.opts.forge_unet_storage_dtype]
|
||||
))
|
||||
def refresh_model_loading_parameters():
|
||||
from modules.sd_models import select_checkpoint, model_data
|
||||
|
||||
checkpoint_info = select_checkpoint()
|
||||
vae_resolution = sd_vae_module.resolve_vae(checkpoint_info.filename)
|
||||
|
||||
model_data.forge_loading_parameters = dict(
|
||||
checkpoint_info=checkpoint_info,
|
||||
vae_filename=vae_resolution.vae,
|
||||
unet_storage_dtype=forge_unet_storage_dtype_options[shared.opts.forge_unet_storage_dtype]
|
||||
)
|
||||
|
||||
sd_models.forge_model_reload()
|
||||
return
|
||||
|
||||
|
||||
@@ -81,21 +86,22 @@ def checkpoint_change(ckpt_name):
|
||||
shared.opts.set('sd_model_checkpoint', ckpt_name)
|
||||
shared.opts.save(shared.config_filename)
|
||||
|
||||
model_load_entry()
|
||||
refresh_model_loading_parameters()
|
||||
return
|
||||
|
||||
|
||||
def vae_change(vae_name):
|
||||
print(f'VAE Selected: {vae_name}')
|
||||
shared.opts.set('sd_vae', vae_name)
|
||||
sd_vae_module.reload_vae_weights()
|
||||
|
||||
refresh_model_loading_parameters()
|
||||
return
|
||||
|
||||
|
||||
def forge_main_entry():
|
||||
ui_checkpoint.change(lambda x: main_thread.async_run(checkpoint_change, x), inputs=[ui_checkpoint], show_progress=False)
|
||||
ui_vae.change(lambda x: main_thread.async_run(vae_change, x), inputs=[ui_vae], show_progress=False)
|
||||
ui_checkpoint.change(checkpoint_change, inputs=[ui_checkpoint], show_progress=False)
|
||||
ui_vae.change(vae_change, inputs=[ui_vae], show_progress=False)
|
||||
|
||||
# Load Model
|
||||
main_thread.async_run(model_load_entry)
|
||||
refresh_model_loading_parameters()
|
||||
return
|
||||
|
||||
Reference in New Issue
Block a user