diff --git a/modules/initialize_util.py b/modules/initialize_util.py index 0458e222..afda4206 100644 --- a/modules/initialize_util.py +++ b/modules/initialize_util.py @@ -181,14 +181,14 @@ def configure_opts_onchange(): from modules.call_queue import wrap_queued_call from modules_forge import main_thread - shared.opts.onchange("sd_model_checkpoint", wrap_queued_call(lambda: main_thread.run_and_wait_result(sd_models.reload_model_weights)), call=False) - shared.opts.onchange("sd_vae", wrap_queued_call(lambda: main_thread.run_and_wait_result(sd_vae.reload_vae_weights)), call=False) + # shared.opts.onchange("sd_model_checkpoint", wrap_queued_call(lambda: main_thread.run_and_wait_result(sd_models.reload_model_weights)), call=False) + # shared.opts.onchange("sd_vae", wrap_queued_call(lambda: main_thread.run_and_wait_result(sd_vae.reload_vae_weights)), call=False) shared.opts.onchange("sd_vae_overrides_per_model_preferences", wrap_queued_call(lambda: main_thread.run_and_wait_result(sd_vae.reload_vae_weights)), call=False) shared.opts.onchange("temp_dir", ui_tempdir.on_tmpdir_changed) shared.opts.onchange("gradio_theme", shared.reload_gradio_theme) # shared.opts.onchange("cross_attention_optimization", wrap_queued_call(lambda: sd_hijack.model_hijack.redo_hijack(shared.sd_model)), call=False) - shared.opts.onchange("fp8_storage", wrap_queued_call(lambda: sd_models.reload_model_weights()), call=False) - shared.opts.onchange("cache_fp16_weight", wrap_queued_call(lambda: sd_models.reload_model_weights(forced_reload=True)), call=False) + # shared.opts.onchange("fp8_storage", wrap_queued_call(lambda: sd_models.reload_model_weights()), call=False) + # shared.opts.onchange("cache_fp16_weight", wrap_queued_call(lambda: sd_models.reload_model_weights(forced_reload=True)), call=False) startup_timer.record("opts onchange") diff --git a/modules/processing.py b/modules/processing.py index fea50715..8eedbca4 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -1502,8 +1502,7 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): if self.enable_hr and self.hr_checkpoint_info is None: if shared.opts.hires_fix_use_firstpass_conds: self.calculate_hr_conds() - - elif shared.sd_model.sd_checkpoint_info == sd_models.select_checkpoint(): # if in lowvram mode, we need to calculate conds right away, before the cond NN is unloaded + else: with devices.autocast(): extra_networks.activate(self, self.hr_extra_network_data) diff --git a/modules/processing_scripts/refiner.py b/modules/processing_scripts/refiner.py index 669b8cac..43723073 100644 --- a/modules/processing_scripts/refiner.py +++ b/modules/processing_scripts/refiner.py @@ -24,10 +24,10 @@ class ScriptRefiner(scripts.ScriptBuiltinUI): gr.Markdown('Refiner is currently under maintenance and unavailable. Sorry for the inconvenience.') with gr.Row(): - refiner_checkpoint = gr.Dropdown(label='Checkpoint', elem_id=self.elem_id("checkpoint"), choices=["", *sd_models.checkpoint_tiles()], value='', tooltip="switch to another model in the middle of generation", interactive=False) - create_refresh_button(refiner_checkpoint, sd_models.list_models, lambda: {"choices": sd_models.checkpoint_tiles()}, self.elem_id("checkpoint_refresh")) + refiner_checkpoint = gr.Dropdown(label='Checkpoint', elem_id=self.elem_id("checkpoint"), choices=["", *sd_models.checkpoint_tiles()], value='', tooltip="switch to another model in the middle of generation", interactive=False, visible=False) + # create_refresh_button(refiner_checkpoint, sd_models.list_models, lambda: {"choices": sd_models.checkpoint_tiles()}, self.elem_id("checkpoint_refresh")) - refiner_switch_at = gr.Slider(value=0.8, label="Switch at", minimum=0.01, maximum=1.0, step=0.01, elem_id=self.elem_id("switch_at"), tooltip="fraction of sampling steps when the switch to refiner model should happen; 1=never, 0.5=switch in the middle of generation", interactive=False) + refiner_switch_at = gr.Slider(value=0.8, label="Switch at", minimum=0.01, maximum=1.0, step=0.01, elem_id=self.elem_id("switch_at"), tooltip="fraction of sampling steps when the switch to refiner model should happen; 1=never, 0.5=switch in the middle of generation", interactive=False, visible=False) def lookup_checkpoint(title): info = sd_models.get_closet_checkpoint_match(title) diff --git a/modules/sd_models.py b/modules/sd_models.py index 1de7a2ba..a000d57f 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -531,7 +531,6 @@ def get_obj_from_str(string, reload=False): @torch.no_grad() def load_model(checkpoint_info=None, already_loaded_state_dict=None): - from modules import sd_hijack checkpoint_info = checkpoint_info or select_checkpoint() timer = Timer() @@ -603,11 +602,11 @@ def reuse_model_from_already_loaded(sd_model, checkpoint_info, timer): def reload_model_weights(sd_model=None, info=None, forced_reload=False): - return load_model(info) + pass def unload_model_weights(sd_model=None, info=None): - return sd_model + pass def apply_token_merging(sd_model, token_merging_ratio): diff --git a/modules/shared_options.py b/modules/shared_options.py index df83a602..16cad9aa 100644 --- a/modules/shared_options.py +++ b/modules/shared_options.py @@ -168,7 +168,7 @@ options_templates.update(options_section(('training', "Training", "training"), { })) options_templates.update(options_section(('sd', "Stable Diffusion", "sd"), { - "sd_model_checkpoint": OptionInfo(None, "Stable Diffusion checkpoint", gr.Dropdown, lambda: {"choices": shared_items.list_checkpoint_tiles(shared.opts.sd_checkpoint_dropdown_use_short)}, refresh=shared_items.refresh_checkpoints, infotext='Model hash'), + "sd_model_checkpoint": OptionInfo(None, "(Managed by Forge)", gr.State), "sd_checkpoints_limit": OptionInfo(1, "Maximum number of checkpoints loaded at the same time", gr.Slider, {"minimum": 1, "maximum": 10, "step": 1}), "sd_checkpoints_keep_in_cpu": OptionInfo(True, "Only keep one model on device").info("will keep models other than the currently used one in RAM rather than VRAM"), "sd_checkpoint_cache": OptionInfo(0, "Checkpoints to cache in RAM", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}).info("obsolete; set to 0 and use the two settings above instead"), @@ -178,7 +178,7 @@ options_templates.update(options_section(('sd', "Stable Diffusion", "sd"), { "enable_batch_seeds": OptionInfo(True, "Make K-diffusion samplers produce same images in a batch as when making a single image"), "comma_padding_backtrack": OptionInfo(20, "Prompt word wrap length limit", gr.Slider, {"minimum": 0, "maximum": 74, "step": 1}).info("in tokens - for texts shorter than specified, if they don't fit into 75 token limit, move them to the next 75 token chunk"), "sdxl_clip_l_skip": OptionInfo(False, "Clip skip SDXL", gr.Checkbox).info("Enable Clip skip for the secondary clip model in sdxl. Has no effect on SD 1.5 or SD 2.0/2.1."), - "CLIP_stop_at_last_layers": OptionInfo(1, "Clip skip", gr.Slider, {"minimum": 1, "maximum": 12, "step": 1}, infotext="Clip skip").link("wiki", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Features#clip-skip").info("ignore last layers of CLIP network; 1 ignores none, 2 ignores one layer"), + "CLIP_stop_at_last_layers": OptionInfo(1, "(Managed by Forge)", gr.State), "upcast_attn": OptionInfo(False, "Upcast cross attention layer to float32"), "randn_source": OptionInfo("GPU", "Random number generator source.", gr.Radio, {"choices": ["GPU", "CPU", "NV"]}, infotext="RNG").info("changes seeds drastically; use CPU to produce the same picture across different videocard vendors; use NV to produce same picture as on NVidia videocards"), "tiling": OptionInfo(False, "Tiling", infotext='Tiling').info("produce a tileable picture"), @@ -204,7 +204,7 @@ image into latent space representation and back. Latent space representation is For img2img, VAE is used to process user's input image before the sampling, and to create an image after sampling. """), "sd_vae_checkpoint_cache": OptionInfo(0, "VAE Checkpoints to cache in RAM", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}), - "sd_vae": OptionInfo("Automatic", "SD VAE", gr.Dropdown, lambda: {"choices": shared_items.sd_vae_items()}, refresh=shared_items.refresh_vae_list, infotext='VAE').info("choose VAE model: Automatic = use one with same filename as checkpoint; None = use VAE from checkpoint"), + "sd_vae": OptionInfo("Automatic", "(Managed by Forge)", gr.State), "sd_vae_overrides_per_model_preferences": OptionInfo(True, "Selected VAE overrides per-model preferences").info("you can set per-model VAE either by editing user metadata for checkpoints, or by making the VAE have same name as checkpoint"), "auto_vae_precision_bfloat16": OptionInfo(False, "Automatically convert VAE to bfloat16").info("triggers when a tensor with NaNs is produced in VAE; disabling the option in this case will result in a black square image; if enabled, overrides the option below"), "auto_vae_precision": OptionInfo(True, "Automatically revert VAE to 32-bit floats").info("triggers when a tensor with NaNs is produced in VAE; disabling the option in this case will result in a black square image"), @@ -329,7 +329,7 @@ options_templates.update(options_section(('ui_alternatives', "UI alternatives", options_templates.update(options_section(('ui', "User interface", "ui"), { "localization": OptionInfo("None", "Localization", gr.Dropdown, lambda: {"choices": ["None"] + list(localization.localizations.keys())}, refresh=lambda: localization.list_localizations(cmd_opts.localizations_dir)).needs_reload_ui(), - "quicksettings_list": OptionInfo(["sd_model_checkpoint", "sd_vae", "CLIP_stop_at_last_layers"], "Quicksettings list", ui_components.DropdownMulti, lambda: {"choices": list(shared.opts.data_labels.keys())}).js("info", "settingsHintsShowQuicksettings").info("setting entries that appear at the top of page rather than in settings tab").needs_reload_ui(), + "quick_setting_list": OptionInfo([], "Quicksettings list", ui_components.DropdownMulti, lambda: {"choices": list(shared.opts.data_labels.keys())}).js("info", "settingsHintsShowQuicksettings").info("setting entries that appear at the top of page rather than in settings tab").needs_reload_ui(), "ui_tab_order": OptionInfo([], "UI tab order", ui_components.DropdownMulti, lambda: {"choices": list(shared.tab_names)}).needs_reload_ui(), "hidden_tabs": OptionInfo([], "Hidden UI tabs", ui_components.DropdownMulti, lambda: {"choices": list(shared.tab_names)}).needs_reload_ui(), "ui_reorder_list": OptionInfo([], "UI item order for txt2img/img2img tabs", ui_components.DropdownMulti, lambda: {"choices": list(shared_items.ui_reorder_categories())}).info("selected items appear first").needs_reload_ui(), diff --git a/modules/ui.py b/modules/ui.py index 088f36af..3f2455c3 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -26,6 +26,7 @@ import modules.shared as shared from modules import prompt_parser from modules.infotext_utils import image_from_url_text, PasteField from modules_forge.forge_canvas.canvas import ForgeCanvas, canvas_head +from modules_forge import main_entry create_setting_component = ui_settings.create_setting_component @@ -328,15 +329,15 @@ def create_ui(): hr_resize_x = gr.Slider(minimum=0, maximum=2048, step=8, label="Resize width to", value=0, elem_id="txt2img_hr_resize_x") hr_resize_y = gr.Slider(minimum=0, maximum=2048, step=8, label="Resize height to", value=0, elem_id="txt2img_hr_resize_y") - with FormRow(elem_id="txt2img_hires_fix_row3", variant="compact", visible=opts.hires_fix_show_sampler) as hr_sampler_container: + with FormRow(elem_id="txt2img_hires_fix_row3", variant="compact") as hr_sampler_container: - hr_checkpoint_name = gr.Dropdown(label='Checkpoint', elem_id="hr_checkpoint", choices=["Use same checkpoint"] + modules.sd_models.checkpoint_tiles(use_short=True), value="Use same checkpoint") - create_refresh_button(hr_checkpoint_name, modules.sd_models.list_models, lambda: {"choices": ["Use same checkpoint"] + modules.sd_models.checkpoint_tiles(use_short=True)}, "hr_checkpoint_refresh") + hr_checkpoint_name = gr.Dropdown(label='Checkpoint', elem_id="hr_checkpoint", choices=["Use same checkpoint"], value="Use same checkpoint", visible=False, interactive=False) + # create_refresh_button(hr_checkpoint_name, modules.sd_models.list_models, lambda: {"choices": ["Use same checkpoint"] + modules.sd_models.checkpoint_tiles(use_short=True)}, "hr_checkpoint_refresh") hr_sampler_name = gr.Dropdown(label='Hires sampling method', elem_id="hr_sampler", choices=["Use same sampler"] + sd_samplers.visible_sampler_names(), value="Use same sampler") hr_scheduler = gr.Dropdown(label='Hires schedule type', elem_id="hr_scheduler", choices=["Use same scheduler"] + [x.label for x in sd_schedulers.schedulers], value="Use same scheduler") - with FormRow(elem_id="txt2img_hires_fix_row4", variant="compact", visible=opts.hires_fix_show_prompts) as hr_prompts_container: + with FormRow(elem_id="txt2img_hires_fix_row4", variant="compact") as hr_prompts_container: with gr.Column(scale=80): with gr.Row(): hr_prompt = gr.Textbox(label="Hires prompt", elem_id="hires_prompt", show_label=False, lines=3, placeholder="Prompt for hires fix pass.\nLeave empty to use the same prompt as in first pass.", elem_classes=["prompt"]) @@ -889,8 +890,8 @@ def create_ui(): settings.create_ui(loadsave, dummy_component) interfaces = [ - (txt2img_interface, "txt2img", "txt2img"), - (img2img_interface, "img2img", "img2img"), + (txt2img_interface, "Txt2img", "txt2img"), + (img2img_interface, "Img2img", "img2img"), (extras_interface, "Extras", "extras"), (pnginfo_interface, "PNG Info", "pnginfo"), (modelmerger_ui.blocks, "Checkpoint Merger", "modelmerger"), @@ -941,7 +942,9 @@ def create_ui(): settings.text_settings.change(fn=update_image_cfg_scale_visibility, inputs=[], outputs=[image_cfg_scale]) demo.load(fn=update_image_cfg_scale_visibility, inputs=[], outputs=[image_cfg_scale]) - modelmerger_ui.setup_ui(dummy_component=dummy_component, sd_model_checkpoint_component=settings.component_dict['sd_model_checkpoint']) + modelmerger_ui.setup_ui(dummy_component=dummy_component, sd_model_checkpoint_component=main_entry.sd_model_checkpoint) + + main_entry.forge_main_entry() if ui_settings_from_file != loadsave.ui_settings: loadsave.dump_defaults() diff --git a/modules/ui_settings.py b/modules/ui_settings.py index e750d371..1d5c504c 100644 --- a/modules/ui_settings.py +++ b/modules/ui_settings.py @@ -7,6 +7,7 @@ from modules.shared import opts from modules.ui_components import FormRow from modules.ui_gradio_extensions import reload_javascript from concurrent.futures import ThreadPoolExecutor, as_completed +from modules_forge import main_entry def get_value_for_setting(key): @@ -41,6 +42,9 @@ def create_setting_component(key, is_quicksettings=False): elem_id = f"setting_{key}" + if comp == gr.State: + return gr.State(fun()) + if info.refresh is not None: if is_quicksettings: res = comp(label=info.label, value=fun(), elem_id=elem_id, **(args or {})) @@ -125,7 +129,7 @@ class UiSettings: self.result = gr.HTML(elem_id="settings_result") - self.quicksettings_names = opts.quicksettings_list + self.quicksettings_names = opts.quick_setting_list self.quicksettings_names = {x: i for i, x in enumerate(self.quicksettings_names) if x != 'quicksettings'} self.quicksettings_list = [] @@ -289,6 +293,7 @@ class UiSettings: def add_quicksettings(self): with gr.Row(elem_id="quicksettings", variant="compact"): + main_entry.make_checkpoint_manager_ui() for _i, k, _item in sorted(self.quicksettings_list, key=lambda x: self.quicksettings_names.get(x[1], x[0])): component = create_setting_component(k, is_quicksettings=True) self.component_dict[k] = component @@ -318,12 +323,15 @@ class UiSettings: show_progress=False, ) + def button_set_checkpoint_change(value, dummy): + return value, opts.dumpjson() + button_set_checkpoint = gr.Button('Change checkpoint', elem_id='change_checkpoint', visible=False) button_set_checkpoint.click( - fn=lambda value, _: self.run_settings_single(value, key='sd_model_checkpoint'), + fn=button_set_checkpoint_change, _js="function(v){ var res = desiredCheckpointName; desiredCheckpointName = ''; return [res || v, null]; }", - inputs=[self.component_dict['sd_model_checkpoint'], self.dummy_component], - outputs=[self.component_dict['sd_model_checkpoint'], self.text_settings], + inputs=[main_entry.sd_model_checkpoint, self.dummy_component], + outputs=[main_entry.sd_model_checkpoint, self.text_settings], ) component_keys = [k for k in opts.data_labels.keys() if k in self.component_dict] diff --git a/modules_forge/main_entry.py b/modules_forge/main_entry.py new file mode 100644 index 00000000..36d0a117 --- /dev/null +++ b/modules_forge/main_entry.py @@ -0,0 +1,69 @@ +import gradio as gr + +from modules import shared_items, shared, ui_common, sd_models +from modules import sd_vae as sd_vae_module +from modules_forge import main_thread + + +sd_model_checkpoint: gr.Dropdown = None +sd_vae: gr.Dropdown = None +CLIP_stop_at_last_layers: gr.Slider = None + + +def make_checkpoint_manager_ui(): + global sd_model_checkpoint, sd_vae, CLIP_stop_at_last_layers + + if shared.opts.sd_model_checkpoint in [None, 'None', 'none', '']: + if len(sd_models.checkpoints_list) == 0: + sd_models.list_models() + if len(sd_models.checkpoints_list) > 0: + shared.opts.set('sd_model_checkpoint', next(iter(sd_models.checkpoints_list.keys()))) + + sd_model_checkpoint_args = lambda: {"choices": shared_items.list_checkpoint_tiles(shared.opts.sd_checkpoint_dropdown_use_short)} + sd_model_checkpoint = gr.Dropdown( + value=shared.opts.sd_model_checkpoint, + label="Checkpoint", + **sd_model_checkpoint_args() + ) + ui_common.create_refresh_button(sd_model_checkpoint, shared_items.refresh_checkpoints, sd_model_checkpoint_args, f"forge_refresh_checkpoint") + + sd_vae_args = lambda: {"choices": shared_items.sd_vae_items()} + sd_vae = gr.Dropdown( + value="Automatic", + label="VAE", + **sd_vae_args() + ) + ui_common.create_refresh_button(sd_vae, shared_items.refresh_vae_list, sd_vae_args, f"forge_refresh_vae") + + CLIP_stop_at_last_layers = gr.Slider(label="Clip skip", value=shared.opts.CLIP_stop_at_last_layers, **{"minimum": 1, "maximum": 12, "step": 1}) + + return + + +def checkpoint_change(ckpt_name): + print(f'Checkpoint Selected: {ckpt_name}') + shared.opts.set('sd_model_checkpoint', ckpt_name) + shared.opts.save(shared.config_filename) + + sd_models.load_model() + return + + +def vae_change(vae_name): + print(f'VAE Selected: {vae_name}') + shared.opts.set('sd_vae', vae_name) + sd_vae_module.reload_vae_weights() + return + + +def clip_skip_change(clip_skip): + print(f'CLIP SKIP Selected: {clip_skip}') + shared.opts.set('CLIP_stop_at_last_layers', clip_skip) + return + + +def forge_main_entry(): + sd_model_checkpoint.change(lambda x: main_thread.async_run(checkpoint_change, x), inputs=[sd_model_checkpoint], show_progress=False) + sd_vae.change(lambda x: main_thread.async_run(vae_change, x), inputs=[sd_vae], show_progress=False) + CLIP_stop_at_last_layers.change(lambda x: main_thread.async_run(clip_skip_change, x), inputs=[CLIP_stop_at_last_layers], show_progress=False) + return diff --git a/style.css b/style.css index cd109552..06e6597d 100644 --- a/style.css +++ b/style.css @@ -431,7 +431,6 @@ div.toprow-compact-tools{ } #quicksettings > div, #quicksettings > fieldset{ - max-width: 36em; width: fit-content; flex: 0 1 fit-content; padding: 0;