stable-diffusion-webui-forge/modules_forge/main_entry.py
layerdiffusion 02ffb04649 revise stream
2024-08-08 19:23:23 -07:00

import torch
import gradio as gr

from modules import shared_items, shared, ui_common, sd_models
from modules import sd_vae as sd_vae_module
from backend import memory_management, stream


# Module-level handles to the model-selection UI components; created in
# make_checkpoint_manager_ui() and wired up in forge_main_entry().
ui_checkpoint: gr.Dropdown = None
ui_vae: gr.Dropdown = None
ui_clip_skip: gr.Slider = None

# Storage dtype choices for the diffusion (UNet) weights; the fp8 variants
# trade some precision for lower memory use.
forge_unet_storage_dtype_options = {
    'None': None,
    'fp8e4m3': torch.float8_e4m3fn,
    'fp8e5m2': torch.float8_e5m2,
}

def bind_to_opts(comp, k, save=False, callback=None):
    """Mirror a Gradio component's value into shared.opts[k] on every change."""

    def on_change(v):
        shared.opts.set(k, v)
        if save:
            shared.opts.save(shared.config_filename)
        if callback is not None:
            callback()
        return

    comp.change(on_change, inputs=[comp], show_progress=False)
    return

def make_checkpoint_manager_ui():
    global ui_checkpoint, ui_vae, ui_clip_skip

    # If no checkpoint is selected yet, fall back to the first one found on disk.
    if shared.opts.sd_model_checkpoint in [None, 'None', 'none', '']:
        if len(sd_models.checkpoints_list) == 0:
            sd_models.list_models()
        if len(sd_models.checkpoints_list) > 0:
            shared.opts.set('sd_model_checkpoint', next(iter(sd_models.checkpoints_list.keys())))

    sd_model_checkpoint_args = lambda: {"choices": shared_items.list_checkpoint_tiles(shared.opts.sd_checkpoint_dropdown_use_short)}

    ui_checkpoint = gr.Dropdown(
        value=shared.opts.sd_model_checkpoint,
        label="Checkpoint",
        elem_classes=['model_selection'],
        **sd_model_checkpoint_args()
    )

    ui_common.create_refresh_button(ui_checkpoint, shared_items.refresh_checkpoints, sd_model_checkpoint_args, "forge_refresh_checkpoint")

    sd_vae_args = lambda: {"choices": shared_items.sd_vae_items()}

    ui_vae = gr.Dropdown(
        value=shared.opts.sd_vae,
        label="VAE",
        **sd_vae_args()
    )

    ui_common.create_refresh_button(ui_vae, shared_items.refresh_vae_list, sd_vae_args, "forge_refresh_vae")

    ui_forge_unet_storage_dtype_options = gr.Radio(label="Diffusion in FP8", value=shared.opts.forge_unet_storage_dtype, choices=list(forge_unet_storage_dtype_options.keys()))
    bind_to_opts(ui_forge_unet_storage_dtype_options, 'forge_unet_storage_dtype', save=True, callback=refresh_model_loading_parameters)

    from backend.args import args as backend_args

    # Developer-only controls; only visible when backend_args.i_am_lllyasviel is set.
    ui_forge_inference_memory = gr.Slider(label="Inference Memory (MB)", value=shared.opts.forge_inference_memory, minimum=0, maximum=int(memory_management.total_vram), step=128, visible=backend_args.i_am_lllyasviel)
    bind_to_opts(ui_forge_inference_memory, 'forge_inference_memory', save=False, callback=refresh_memory_management_settings)

    ui_forge_async_loading = gr.Checkbox(label="Async Loader", value=shared.opts.forge_async_loading, visible=backend_args.i_am_lllyasviel)
    bind_to_opts(ui_forge_async_loading, 'forge_async_loading', save=False, callback=refresh_memory_management_settings)

    ui_clip_skip = gr.Slider(label="Clip skip", value=shared.opts.CLIP_stop_at_last_layers, minimum=1, maximum=12, step=1)
    bind_to_opts(ui_clip_skip, 'CLIP_stop_at_last_layers', save=False)
    return

def refresh_memory_management_settings():
    # Push the current option values into the backend memory manager.
    stream.stream_activated = shared.opts.forge_async_loading
    memory_management.current_inference_memory = shared.opts.forge_inference_memory * 1024 * 1024

    print(f'Stream Set to: {stream.stream_activated}')
    print(f'Stream Used by CUDA: {stream.should_use_stream()}')
    print(f'Current Inference Memory: {memory_management.minimum_inference_memory() / (1024 * 1024):.2f} MB')
    return

def refresh_model_loading_parameters():
    from modules.sd_models import select_checkpoint, model_data

    # Resolve the currently selected checkpoint and its matching VAE, then record
    # the parameters the backend will use the next time it loads a model.
    checkpoint_info = select_checkpoint()
    vae_resolution = sd_vae_module.resolve_vae(checkpoint_info.filename)

    model_data.forge_loading_parameters = dict(
        checkpoint_info=checkpoint_info,
        vae_filename=vae_resolution.vae,
        unet_storage_dtype=forge_unet_storage_dtype_options[shared.opts.forge_unet_storage_dtype]
    )

    print(f'Loading parameters: {model_data.forge_loading_parameters}')
    return

def checkpoint_change(ckpt_name):
    # Persist the new checkpoint choice and refresh the pending loading parameters.
    shared.opts.set('sd_model_checkpoint', ckpt_name)
    shared.opts.save(shared.config_filename)
    refresh_model_loading_parameters()
    return


def vae_change(vae_name):
    # The VAE choice is applied immediately but not written back to the config file.
    shared.opts.set('sd_vae', vae_name)
    refresh_model_loading_parameters()
    return

def forge_main_entry():
    # Wire the dropdown change events, then compute the initial loading parameters.
    ui_checkpoint.change(checkpoint_change, inputs=[ui_checkpoint], show_progress=False)
    ui_vae.change(vae_change, inputs=[ui_vae], show_progress=False)

    # Load Model
    refresh_model_loading_parameters()
    return
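

# --- Usage sketch (not part of the original file) ---
# A minimal, hedged illustration of how these entry points could be wired
# together. The gr.Blocks() wrapper, the call order, and running this module
# directly are assumptions for illustration only: the real Forge UI creates
# these widgets inside its own top-bar layout, and the imports above require
# the full webui environment to be importable.
if __name__ == '__main__':
    with gr.Blocks() as demo:
        make_checkpoint_manager_ui()   # build the checkpoint / VAE / FP8 / clip-skip controls
        forge_main_entry()             # attach change handlers and compute initial loading parameters
    demo.launch()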